diff --git a/CHANGELOG.md b/CHANGELOG.md index 334ff1919..8e5285c4e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Added validation for the WithTxControl option in the non-interactive methods of Client and Session in the query service. + ## v3.108.1 * Supported `json.Marshaller` query parameter in `database/sql` driver diff --git a/examples/basic/native/query/series.go b/examples/basic/native/query/series.go index 40506f6f6..27fe774a4 100644 --- a/examples/basic/native/query/series.go +++ b/examples/basic/native/query/series.go @@ -25,7 +25,7 @@ func read(ctx context.Context, c query.Client, prefix string) error { FROM %s `, "`"+path.Join(prefix, "series")+"`"), - query.WithTxControl(query.TxControl(query.BeginTx(query.WithSnapshotReadOnly()))), + query.WithTxControl(query.SnapshotReadOnlyTxControl()), ) if err != nil { return err diff --git a/internal/query/client.go b/internal/query/client.go index 2dc0ad0fc..84b30eb9e 100644 --- a/internal/query/client.go +++ b/internal/query/client.go @@ -2,6 +2,7 @@ package query import ( "context" + "errors" "time" "github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1" @@ -20,6 +21,7 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/types" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xcontext" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" + "github.com/ydb-platform/ydb-go-sdk/v3/internal/xreflect" "github.com/ydb-platform/ydb-go-sdk/v3/query" "github.com/ydb-platform/ydb-go-sdk/v3/retry" "github.com/ydb-platform/ydb-go-sdk/v3/trace" @@ -32,6 +34,8 @@ var ( _ sessionPool = (*pool.Pool[*Session, Session])(nil) ) +var errNoCommit = xerrors.Wrap(errors.New("WithTxControl option is not allowed without CommitTx() option in Client methods, as these methods are non-interactive. You can either add the CommitTx() option to TxControl or use query.*TxControl methods (e.g., query.SnapshotReadOnlyTxControl) which already include the commit flag")) //nolint:lll + type ( sessionPool interface { closer.Closer @@ -173,6 +177,10 @@ func (c *Client) ExecuteScript( ), } + if err := checkTxControlWithCommit(settings.TxControl()); err != nil { + return nil, err + } + request, grpcOpts, err := executeQueryScriptRequest(q, settings) if err != nil { return op, xerrors.WithStackTrace(err) @@ -320,6 +328,10 @@ func (c *Client) QueryRow(ctx context.Context, q string, opts ...options.Execute settings := options.ExecuteSettings(opts...) + if err := checkTxControlWithCommit(settings.TxControl()); err != nil { + return nil, err + } + onDone := trace.QueryOnQueryRow(c.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*Client).QueryRow"), q, settings.Label(), @@ -366,6 +378,11 @@ func (c *Client) Exec(ctx context.Context, q string, opts ...options.Execute) (f defer cancel() settings := options.ExecuteSettings(opts...) + + if err := checkTxControlWithCommit(settings.TxControl()); err != nil { + return err + } + onDone := trace.QueryOnExec(c.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*Client).Exec"), q, @@ -415,6 +432,11 @@ func (c *Client) Query(ctx context.Context, q string, opts ...options.Execute) ( defer cancel() settings := options.ExecuteSettings(opts...) + + if err := checkTxControlWithCommit(settings.TxControl()); err != nil { + return nil, err + } + onDone := trace.QueryOnQuery(c.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*Client).Query"), q, settings.Label(), @@ -470,6 +492,10 @@ func (c *Client) QueryResultSet( err error ) + if err := checkTxControlWithCommit(settings.TxControl()); err != nil { + return nil, err + } + onDone := trace.QueryOnQueryResultSet(c.config.Trace(), &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*Client).QueryResultSet"), q, settings.Label(), @@ -612,6 +638,15 @@ func New(ctx context.Context, cc grpc.ClientConnInterface, cfg *config.Config) * } } +// checkTxControlWithCommit validates the transaction control object to ensure it includes a commit flag. +func checkTxControlWithCommit(txControl options.TxControl) error { + if !xreflect.IsContainsNilPointer(txControl) && !txControl.Commit() { + return xerrors.WithStackTrace(errNoCommit) + } + + return nil +} + func poolTrace(t *trace.Query) *pool.Trace { return &pool.Trace{ OnNew: func(ctx *context.Context, call stack.Caller) func(limit int) { diff --git a/internal/query/session.go b/internal/query/session.go index dce1cc130..87f8e070d 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -36,7 +36,12 @@ func (s *Session) QueryResultSet( onDone(finalErr) }() - r, err := s.execute(ctx, q, options.ExecuteSettings(opts...), withTrace(s.trace)) + settings := options.ExecuteSettings(opts...) + if err := checkTxControlWithCommit(settings.TxControl()); err != nil { + return nil, err + } + + r, err := s.execute(ctx, q, settings, withTrace(s.trace)) if err != nil { return nil, xerrors.WithStackTrace(err) } @@ -75,7 +80,12 @@ func (s *Session) QueryRow(ctx context.Context, q string, opts ...options.Execut onDone(finalErr) }() - row, err := s.queryRow(ctx, q, options.ExecuteSettings(opts...), withTrace(s.trace)) + settings := options.ExecuteSettings(opts...) + if err := checkTxControlWithCommit(settings.TxControl()); err != nil { + return nil, err + } + + row, err := s.queryRow(ctx, q, settings, withTrace(s.trace)) if err != nil { return nil, xerrors.WithStackTrace(err) } @@ -154,6 +164,11 @@ func (s *Session) execute( func (s *Session) Exec(ctx context.Context, q string, opts ...options.Execute) (finalErr error) { settings := options.ExecuteSettings(opts...) + + if err := checkTxControlWithCommit(settings.TxControl()); err != nil { + return err + } + onDone := trace.QueryOnSessionExec(s.trace, &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*Session).Exec"), s, @@ -182,6 +197,11 @@ func (s *Session) Exec(ctx context.Context, q string, opts ...options.Execute) ( func (s *Session) Query(ctx context.Context, q string, opts ...options.Execute) (_ query.Result, finalErr error) { settings := options.ExecuteSettings(opts...) + + if err := checkTxControlWithCommit(settings.TxControl()); err != nil { + return nil, err + } + onDone := trace.QueryOnSessionQuery(s.trace, &ctx, stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*Session).Query"), s, diff --git a/internal/xreflect/is_nil.go b/internal/xreflect/is_nil.go new file mode 100644 index 000000000..6b8634fe0 --- /dev/null +++ b/internal/xreflect/is_nil.go @@ -0,0 +1,32 @@ +package xreflect + +import "reflect" + +func IsContainsNilPointer(v any) bool { + if v == nil { + return true + } + + rVal := reflect.ValueOf(v) + + return isValPointToNil(rVal) +} + +func isValPointToNil(v reflect.Value) bool { + kind := v.Kind() + var res bool + switch kind { + case reflect.Slice: + return false + case reflect.Chan, reflect.Func, reflect.Map, reflect.UnsafePointer: + res = v.IsNil() + case reflect.Pointer, reflect.Interface: + elem := v.Elem() + if v.IsNil() { + return true + } + res = isValPointToNil(elem) + } + + return res +} diff --git a/internal/xreflect/is_nil_test.go b/internal/xreflect/is_nil_test.go new file mode 100644 index 000000000..6f3713f42 --- /dev/null +++ b/internal/xreflect/is_nil_test.go @@ -0,0 +1,92 @@ +package xreflect + +import ( + "testing" +) + +func TestIsContainsNilPointer(t *testing.T) { + var nilIntPointer *int + vInterface := nilIntPointer + + // Test cases for different nil and non-nil scenarios + tests := []struct { + name string + input any + expected bool + }{ + { + name: "nil interface", + input: nil, + expected: true, + }, + { + name: "nil pointer to int", + input: (*int)(nil), + expected: true, + }, + { + name: "non-nil pointer to int", + input: new(int), + expected: false, + }, + { + name: "nil slice", + input: []int(nil), + expected: false, + }, + { + name: "empty slice", + input: []int{}, + expected: false, + }, + { + name: "nil map", + input: map[string]int(nil), + expected: true, + }, + { + name: "empty map", + input: map[string]int{}, + expected: false, + }, + { + name: "nil channel", + input: (chan int)(nil), + expected: true, + }, + { + name: "non-nil channel", + input: make(chan int), + expected: false, + }, + { + name: "nil function", + input: (func())(nil), + expected: true, + }, + { + name: "nested nil pointer", + input: &nilIntPointer, + expected: true, + }, + { + name: "interface with stored nil pointer", + input: vInterface, + expected: true, + }, + { + name: "non-nil interface value", + input: interface{}("test"), + expected: false, + }, + } + + // Execute all test cases + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsContainsNilPointer(tt.input); got != tt.expected { + t.Errorf("IsContainsNilPointer() = %v, want %v", got, tt.expected) + } + }) + } +} diff --git a/tests/integration/database_sql_with_tx_control_test.go b/tests/integration/database_sql_with_tx_control_test.go index 7ad16b394..f0949f6b8 100644 --- a/tests/integration/database_sql_with_tx_control_test.go +++ b/tests/integration/database_sql_with_tx_control_test.go @@ -38,9 +38,9 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { ydb.WithTxControl( tx.WithTxControlHook(ctx, func(txControl *tx.Control) { hookCalled = true - require.Equal(t, tx.SerializableReadWriteTxControl(), txControl) + require.Equal(t, tx.SerializableReadWriteTxControl(tx.CommitTx()), txControl) }), - tx.SerializableReadWriteTxControl(), + tx.SerializableReadWriteTxControl(tx.CommitTx()), ), db, func(ctx context.Context, cc *sql.Conn) error { _, err := db.QueryContext(ctx, "SELECT 1") @@ -56,9 +56,9 @@ func TestDatabaseSqlWithTxControl(t *testing.T) { ydb.WithTxControl( tx.WithTxControlHook(ctx, func(txControl *tx.Control) { hookCalled = true - require.Equal(t, tx.SerializableReadWriteTxControl(), txControl) + require.Equal(t, tx.SerializableReadWriteTxControl(tx.CommitTx()), txControl) }), - tx.SerializableReadWriteTxControl(), + tx.SerializableReadWriteTxControl(tx.CommitTx()), ), db, func(ctx context.Context, cc *sql.Conn) error { _, err := db.QueryContext(ctx, "SELECT 1") diff --git a/tests/integration/query_regression_test.go b/tests/integration/query_regression_test.go index 9df274339..c32901f0f 100644 --- a/tests/integration/query_regression_test.go +++ b/tests/integration/query_regression_test.go @@ -45,7 +45,7 @@ DECLARE $val AS UUID; SELECT CAST($val AS Utf8)`, query.WithIdempotent(), query.WithParameters(ydb.ParamsBuilder().Param("$val").UUIDWithIssue1501Value(id).Build()), - query.WithTxControl(tx.SerializableReadWriteTxControl()), + query.WithTxControl(tx.SnapshotReadOnlyTxControl()), ) require.NoError(t, err) @@ -71,7 +71,7 @@ DECLARE $val AS Text; SELECT CAST($val AS UUID)`, query.WithIdempotent(), query.WithParameters(ydb.ParamsBuilder().Param("$val").Text(idString).Build()), - query.WithTxControl(tx.SerializableReadWriteTxControl()), + query.WithTxControl(tx.SnapshotReadOnlyTxControl()), ) require.NoError(t, err) @@ -97,7 +97,7 @@ DECLARE $val AS Text; SELECT CAST($val AS UUID)`, query.WithIdempotent(), query.WithParameters(ydb.ParamsBuilder().Param("$val").Text(idString).Build()), - query.WithTxControl(tx.SerializableReadWriteTxControl()), + query.WithTxControl(tx.SnapshotReadOnlyTxControl()), ) require.NoError(t, err) @@ -125,7 +125,7 @@ DECLARE $val AS Text; SELECT CAST($val AS UUID)`, query.WithIdempotent(), query.WithParameters(ydb.ParamsBuilder().Param("$val").Text(idString).Build()), - query.WithTxControl(tx.SerializableReadWriteTxControl()), + query.WithTxControl(tx.SnapshotReadOnlyTxControl()), ) require.NoError(t, err) @@ -151,7 +151,7 @@ DECLARE $val AS Text; SELECT CAST($val AS UUID)`, query.WithIdempotent(), query.WithParameters(ydb.ParamsBuilder().Param("$val").Text(idString).Build()), - query.WithTxControl(tx.SerializableReadWriteTxControl()), + query.WithTxControl(tx.SnapshotReadOnlyTxControl()), ) require.NoError(t, err) @@ -180,7 +180,7 @@ DECLARE $val AS UUID; SELECT $val`, query.WithIdempotent(), query.WithParameters(ydb.ParamsBuilder().Param("$val").UUIDWithIssue1501Value(id).Build()), - query.WithTxControl(tx.SerializableReadWriteTxControl()), + query.WithTxControl(tx.SnapshotReadOnlyTxControl()), ) require.NoError(t, err) @@ -207,7 +207,7 @@ DECLARE $val AS UUID; SELECT CAST($val AS Utf8)`, query.WithIdempotent(), - query.WithTxControl(query.SerializableReadWriteTxControl()), + query.WithTxControl(query.SnapshotReadOnlyTxControl()), query.WithParameters(ydb.ParamsBuilder().Param("$val").Uuid(id).Build()), ) @@ -233,7 +233,7 @@ DECLARE $val AS Utf8; SELECT CAST($val AS UUID)`, query.WithIdempotent(), query.WithParameters(ydb.ParamsBuilder().Param("$val").Text(idString).Build()), - query.WithTxControl(query.SerializableReadWriteTxControl()), + query.WithTxControl(query.SnapshotReadOnlyTxControl()), ) require.NoError(t, err) @@ -261,7 +261,7 @@ DECLARE $val AS UUID; SELECT $val`, query.WithIdempotent(), query.WithParameters(ydb.ParamsBuilder().Param("$val").Uuid(id).Build()), - query.WithTxControl(query.SerializableReadWriteTxControl()), + query.WithTxControl(query.SnapshotReadOnlyTxControl()), ) require.NoError(t, err)