diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bd1a3acb..59fc9bcc8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,4 @@ +* Added support for the `json.Unmarshaler` interface in the `CastTo` function for use in scanners, such as the `ScanStruct` method * Fixed the support of server-side session balancing in `database/sql` driver * Added `ydb.WithDisableSessionBalancer()` driver option for disable server-side session balancing on table and query clients diff --git a/internal/value/cast_test.go b/internal/value/cast_test.go index 863587f59..2a9a924e6 100644 --- a/internal/value/cast_test.go +++ b/internal/value/cast_test.go @@ -2,6 +2,8 @@ package value import ( "database/sql/driver" + "encoding/json" + "errors" "reflect" "testing" "time" @@ -25,6 +27,28 @@ func unwrapPtr(v interface{}) interface{} { return reflect.ValueOf(v).Elem().Interface() } +type jsonUnmarshaller struct { + bytes []byte +} + +func (json *jsonUnmarshaller) UnmarshalJSON(bytes []byte) error { + json.bytes = bytes + + return nil +} + +var _ json.Unmarshaler = &jsonUnmarshaller{} + +type jsonUnmarshallerBroken struct { + bytes []byte +} + +func (json *jsonUnmarshallerBroken) UnmarshalJSON(_ []byte) error { + return errors.New("unmarshal error") +} + +var _ json.Unmarshaler = &jsonUnmarshallerBroken{} + func loadLocation(t *testing.T, name string) *time.Location { loc, err := time.LoadLocation(name) require.NoError(t, err) @@ -138,6 +162,75 @@ func TestCastTo(t *testing.T) { err: ErrCannotCast, }, + // JSONValue + { + name: xtest.CurrentFileLine(), + value: JSONValue(`{"test": "text"}"`), + dst: ptr[string](), + exp: `{"test": "text"}"`, + err: nil, + }, + { + name: xtest.CurrentFileLine(), + value: JSONValue(`{"test":"text"}"`), + dst: ptr[Value](), + exp: JSONValue(`{"test":"text"}"`), + err: nil, + }, + { + name: xtest.CurrentFileLine(), + value: OptionalValue(JSONValue(`{"test": "text"}"`)), + dst: ptr[*[]byte](), + exp: value2ptr([]byte(`{"test": "text"}"`)), + err: nil, + }, + { + name: xtest.CurrentFileLine(), + value: JSONValue(`{"test":"text"}"`), + dst: ptr[[]byte](), + exp: []byte(`{"test":"text"}"`), + err: nil, + }, + { + name: xtest.CurrentFileLine(), + value: JSONValue(`{"test": "text"}"`), + dst: ptr[jsonUnmarshaller](), + exp: jsonUnmarshaller{[]byte(`{"test": "text"}"`)}, + err: nil, + }, + { + name: xtest.CurrentFileLine(), + value: JSONValue(`{"test": "text"}"`), + dst: ptr[jsonUnmarshallerBroken](), + err: ErrCannotCast, + }, + { + name: xtest.CurrentFileLine(), + value: OptionalValue(JSONValue(`{"test": "text"}"`)), + dst: ptr[jsonUnmarshaller](), + exp: jsonUnmarshaller{[]byte(`{"test": "text"}"`)}, + err: nil, + }, + { + name: xtest.CurrentFileLine(), + value: OptionalValue(JSONValue(`{"test": "text"}"`)), + dst: ptr[jsonUnmarshallerBroken](), + err: ErrCannotCast, + }, + { + name: xtest.CurrentFileLine(), + value: JSONValue(`{"test": "text"}"`), + dst: ptr[int](), + err: ErrCannotCast, + }, + { + name: xtest.CurrentFileLine(), + value: OptionalValue(JSONValue(`{"test": "text"}"`)), + dst: ptr[int](), + err: ErrCannotCast, + }, + + // JSONDocumentValue { name: xtest.CurrentFileLine(), value: JSONDocumentValue(`{"test": "text"}"`), @@ -166,6 +259,44 @@ func TestCastTo(t *testing.T) { exp: []byte(`{"test":"text"}"`), err: nil, }, + { + name: xtest.CurrentFileLine(), + value: JSONDocumentValue(`{"test": "text"}"`), + dst: ptr[jsonUnmarshaller](), + exp: jsonUnmarshaller{[]byte(`{"test": "text"}"`)}, + err: nil, + }, + { + name: xtest.CurrentFileLine(), + value: JSONDocumentValue(`{"test": "text"}"`), + dst: ptr[jsonUnmarshallerBroken](), + err: ErrCannotCast, + }, + { + name: xtest.CurrentFileLine(), + value: OptionalValue(JSONDocumentValue(`{"test": "text"}"`)), + dst: ptr[jsonUnmarshaller](), + exp: jsonUnmarshaller{[]byte(`{"test": "text"}"`)}, + err: nil, + }, + { + name: xtest.CurrentFileLine(), + value: OptionalValue(JSONDocumentValue(`{"test": "text"}"`)), + dst: ptr[jsonUnmarshallerBroken](), + err: ErrCannotCast, + }, + { + name: xtest.CurrentFileLine(), + value: JSONDocumentValue(`{"test": "text"}"`), + dst: ptr[int](), + err: ErrCannotCast, + }, + { + name: xtest.CurrentFileLine(), + value: OptionalValue(JSONDocumentValue(`{"test": "text"}"`)), + dst: ptr[int](), + err: ErrCannotCast, + }, { name: xtest.CurrentFileLine(), diff --git a/internal/value/value.go b/internal/value/value.go index 2e642b886..ca09b95a1 100644 --- a/internal/value/value.go +++ b/internal/value/value.go @@ -3,6 +3,7 @@ package value import ( "database/sql/driver" "encoding/binary" + "encoding/json" "fmt" "math/big" "reflect" @@ -1369,6 +1370,16 @@ func (v jsonValue) castTo(dst any) error { case *[]byte: *vv = xstring.ToBytes(string(v)) + return nil + case json.Unmarshaler: + err := vv.UnmarshalJSON(xstring.ToBytes(string(v))) + if err != nil { + return xerrors.WithStackTrace(fmt.Errorf( + "%w '%s(%+v)' to '%T' destination: %w", + ErrCannotCast, v.Type().Yql(), v, vv, err, + )) + } + return nil default: return xerrors.WithStackTrace(fmt.Errorf( @@ -1413,6 +1424,16 @@ func (v jsonDocumentValue) castTo(dst any) error { case *[]byte: *vv = xstring.ToBytes(string(v)) + return nil + case json.Unmarshaler: + err := vv.UnmarshalJSON(xstring.ToBytes(string(v))) + if err != nil { + return xerrors.WithStackTrace(fmt.Errorf( + "%w '%s(%+v)' to '%T' destination: %w", + ErrCannotCast, v.Type().Yql(), v, vv, err, + )) + } + return nil default: return xerrors.WithStackTrace(fmt.Errorf(