Skip to content

Commit dc547f2

Browse files
added json unmarshaler for castTo (#1825)
* added json unmarshaler for castTo * added json unmarshaler for castTo tests * json unmarshaler castTo CHANGELOG.md description --------- Co-authored-by: Aleksey Myasnikov <[email protected]>
1 parent 333e663 commit dc547f2

File tree

3 files changed

+153
-0
lines changed

3 files changed

+153
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
* Added support for the `json.Unmarshaler` interface in the `CastTo` function for use in scanners, such as the `ScanStruct` method
12
* Fixed the support of server-side session balancing in `database/sql` driver
23
* Added `ydb.WithDisableSessionBalancer()` driver option for disable server-side session balancing on table and query clients
34

internal/value/cast_test.go

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package value
22

33
import (
44
"database/sql/driver"
5+
"encoding/json"
6+
"errors"
57
"reflect"
68
"testing"
79
"time"
@@ -25,6 +27,28 @@ func unwrapPtr(v interface{}) interface{} {
2527
return reflect.ValueOf(v).Elem().Interface()
2628
}
2729

30+
type jsonUnmarshaller struct {
31+
bytes []byte
32+
}
33+
34+
func (json *jsonUnmarshaller) UnmarshalJSON(bytes []byte) error {
35+
json.bytes = bytes
36+
37+
return nil
38+
}
39+
40+
var _ json.Unmarshaler = &jsonUnmarshaller{}
41+
42+
type jsonUnmarshallerBroken struct {
43+
bytes []byte
44+
}
45+
46+
func (json *jsonUnmarshallerBroken) UnmarshalJSON(_ []byte) error {
47+
return errors.New("unmarshal error")
48+
}
49+
50+
var _ json.Unmarshaler = &jsonUnmarshallerBroken{}
51+
2852
func loadLocation(t *testing.T, name string) *time.Location {
2953
loc, err := time.LoadLocation(name)
3054
require.NoError(t, err)
@@ -138,6 +162,75 @@ func TestCastTo(t *testing.T) {
138162
err: ErrCannotCast,
139163
},
140164

165+
// JSONValue
166+
{
167+
name: xtest.CurrentFileLine(),
168+
value: JSONValue(`{"test": "text"}"`),
169+
dst: ptr[string](),
170+
exp: `{"test": "text"}"`,
171+
err: nil,
172+
},
173+
{
174+
name: xtest.CurrentFileLine(),
175+
value: JSONValue(`{"test":"text"}"`),
176+
dst: ptr[Value](),
177+
exp: JSONValue(`{"test":"text"}"`),
178+
err: nil,
179+
},
180+
{
181+
name: xtest.CurrentFileLine(),
182+
value: OptionalValue(JSONValue(`{"test": "text"}"`)),
183+
dst: ptr[*[]byte](),
184+
exp: value2ptr([]byte(`{"test": "text"}"`)),
185+
err: nil,
186+
},
187+
{
188+
name: xtest.CurrentFileLine(),
189+
value: JSONValue(`{"test":"text"}"`),
190+
dst: ptr[[]byte](),
191+
exp: []byte(`{"test":"text"}"`),
192+
err: nil,
193+
},
194+
{
195+
name: xtest.CurrentFileLine(),
196+
value: JSONValue(`{"test": "text"}"`),
197+
dst: ptr[jsonUnmarshaller](),
198+
exp: jsonUnmarshaller{[]byte(`{"test": "text"}"`)},
199+
err: nil,
200+
},
201+
{
202+
name: xtest.CurrentFileLine(),
203+
value: JSONValue(`{"test": "text"}"`),
204+
dst: ptr[jsonUnmarshallerBroken](),
205+
err: ErrCannotCast,
206+
},
207+
{
208+
name: xtest.CurrentFileLine(),
209+
value: OptionalValue(JSONValue(`{"test": "text"}"`)),
210+
dst: ptr[jsonUnmarshaller](),
211+
exp: jsonUnmarshaller{[]byte(`{"test": "text"}"`)},
212+
err: nil,
213+
},
214+
{
215+
name: xtest.CurrentFileLine(),
216+
value: OptionalValue(JSONValue(`{"test": "text"}"`)),
217+
dst: ptr[jsonUnmarshallerBroken](),
218+
err: ErrCannotCast,
219+
},
220+
{
221+
name: xtest.CurrentFileLine(),
222+
value: JSONValue(`{"test": "text"}"`),
223+
dst: ptr[int](),
224+
err: ErrCannotCast,
225+
},
226+
{
227+
name: xtest.CurrentFileLine(),
228+
value: OptionalValue(JSONValue(`{"test": "text"}"`)),
229+
dst: ptr[int](),
230+
err: ErrCannotCast,
231+
},
232+
233+
// JSONDocumentValue
141234
{
142235
name: xtest.CurrentFileLine(),
143236
value: JSONDocumentValue(`{"test": "text"}"`),
@@ -166,6 +259,44 @@ func TestCastTo(t *testing.T) {
166259
exp: []byte(`{"test":"text"}"`),
167260
err: nil,
168261
},
262+
{
263+
name: xtest.CurrentFileLine(),
264+
value: JSONDocumentValue(`{"test": "text"}"`),
265+
dst: ptr[jsonUnmarshaller](),
266+
exp: jsonUnmarshaller{[]byte(`{"test": "text"}"`)},
267+
err: nil,
268+
},
269+
{
270+
name: xtest.CurrentFileLine(),
271+
value: JSONDocumentValue(`{"test": "text"}"`),
272+
dst: ptr[jsonUnmarshallerBroken](),
273+
err: ErrCannotCast,
274+
},
275+
{
276+
name: xtest.CurrentFileLine(),
277+
value: OptionalValue(JSONDocumentValue(`{"test": "text"}"`)),
278+
dst: ptr[jsonUnmarshaller](),
279+
exp: jsonUnmarshaller{[]byte(`{"test": "text"}"`)},
280+
err: nil,
281+
},
282+
{
283+
name: xtest.CurrentFileLine(),
284+
value: OptionalValue(JSONDocumentValue(`{"test": "text"}"`)),
285+
dst: ptr[jsonUnmarshallerBroken](),
286+
err: ErrCannotCast,
287+
},
288+
{
289+
name: xtest.CurrentFileLine(),
290+
value: JSONDocumentValue(`{"test": "text"}"`),
291+
dst: ptr[int](),
292+
err: ErrCannotCast,
293+
},
294+
{
295+
name: xtest.CurrentFileLine(),
296+
value: OptionalValue(JSONDocumentValue(`{"test": "text"}"`)),
297+
dst: ptr[int](),
298+
err: ErrCannotCast,
299+
},
169300

170301
{
171302
name: xtest.CurrentFileLine(),

internal/value/value.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package value
33
import (
44
"database/sql/driver"
55
"encoding/binary"
6+
"encoding/json"
67
"fmt"
78
"math/big"
89
"reflect"
@@ -1369,6 +1370,16 @@ func (v jsonValue) castTo(dst any) error {
13691370
case *[]byte:
13701371
*vv = xstring.ToBytes(string(v))
13711372

1373+
return nil
1374+
case json.Unmarshaler:
1375+
err := vv.UnmarshalJSON(xstring.ToBytes(string(v)))
1376+
if err != nil {
1377+
return xerrors.WithStackTrace(fmt.Errorf(
1378+
"%w '%s(%+v)' to '%T' destination: %w",
1379+
ErrCannotCast, v.Type().Yql(), v, vv, err,
1380+
))
1381+
}
1382+
13721383
return nil
13731384
default:
13741385
return xerrors.WithStackTrace(fmt.Errorf(
@@ -1413,6 +1424,16 @@ func (v jsonDocumentValue) castTo(dst any) error {
14131424
case *[]byte:
14141425
*vv = xstring.ToBytes(string(v))
14151426

1427+
return nil
1428+
case json.Unmarshaler:
1429+
err := vv.UnmarshalJSON(xstring.ToBytes(string(v)))
1430+
if err != nil {
1431+
return xerrors.WithStackTrace(fmt.Errorf(
1432+
"%w '%s(%+v)' to '%T' destination: %w",
1433+
ErrCannotCast, v.Type().Yql(), v, vv, err,
1434+
))
1435+
}
1436+
14161437
return nil
14171438
default:
14181439
return xerrors.WithStackTrace(fmt.Errorf(

0 commit comments

Comments
 (0)