diff --git a/.golangci.yml b/.golangci.yml index b373a0322..1cdf55032 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -230,7 +230,6 @@ linters: - forcetypeassert - funlen - gochecknoglobals - - gocognit - godot - goerr113 - golint diff --git a/internal/backoff/backoff_test.go b/internal/backoff/backoff_test.go index 45cd7fd1e..4e2bd262c 100644 --- a/internal/backoff/backoff_test.go +++ b/internal/backoff/backoff_test.go @@ -9,6 +9,12 @@ import ( "github.com/stretchr/testify/require" ) +type exp struct { + eq time.Duration + gte time.Duration + lte time.Duration +} + func TestDelays(t *testing.T) { duration := func(s string) (d time.Duration) { d, err := time.ParseDuration(s) @@ -42,11 +48,6 @@ func TestDelays(t *testing.T) { } func TestLogBackoff(t *testing.T) { - type exp struct { - eq time.Duration - gte time.Duration - lte time.Duration - } for _, tt := range []struct { backoff Backoff exp []exp @@ -145,24 +146,28 @@ func TestLogBackoff(t *testing.T) { continue } - if gte := exp.gte; act <= gte { - t.Errorf( - "unexpected Backoff delay: %s; want >= %s", - act, gte, - ) - } - if lte := exp.lte; act >= lte { - t.Errorf( - "unexpected Backoff delay: %s; want <= %s", - act, lte, - ) - } + checkExpWithAct(t, exp, act) } } }) } } +func checkExpWithAct(t *testing.T, exp exp, act time.Duration) { + if gte := exp.gte; act <= gte { + t.Errorf( + "unexpected Backoff delay: %s; want >= %s", + act, gte, + ) + } + if lte := exp.lte; act >= lte { + t.Errorf( + "unexpected Backoff delay: %s; want <= %s", + act, lte, + ) + } +} + func TestFastSlowDelaysWithoutJitter(t *testing.T) { for _, tt := range []struct { name string diff --git a/internal/bind/params.go b/internal/bind/params.go index d0e6a82da..b68e4b226 100644 --- a/internal/bind/params.go +++ b/internal/bind/params.go @@ -179,30 +179,11 @@ func Params(args ...interface{}) (parameters []*params.Parameter, _ error) { for i, arg := range args { switch x := arg.(type) { case driver.NamedValue: - if x.Name == "" { - switch xx := x.Value.(type) { - case *params.Parameters: - if len(args) > 1 { - return nil, xerrors.WithStackTrace(errMultipleQueryParameters) - } - parameters = *xx - case *params.Parameter: - parameters = append(parameters, xx) - default: - x.Name = fmt.Sprintf("$p%d", i) - param, err := toYdbParam(x.Name, x.Value) - if err != nil { - return nil, xerrors.WithStackTrace(err) - } - parameters = append(parameters, param) - } - } else { - param, err := toYdbParam(x.Name, x.Value) - if err != nil { - return nil, xerrors.WithStackTrace(err) - } - parameters = append(parameters, param) + driverNamedParams, err := checkDriverNamedValue(i, params, x, args) + if err != nil { + return nil, err } + params = driverNamedParams case sql.NamedArg: if x.Name == "" { return nil, xerrors.WithStackTrace(errUnnamedParam) @@ -233,3 +214,40 @@ func Params(args ...interface{}) (parameters []*params.Parameter, _ error) { return parameters, nil } + +// checkDriverNamedValue checks the driver.NamedValue and adds it to the params slice. +func checkDriverNamedValue( + i int, + params []table.ParameterOption, + x driver.NamedValue, + args []interface{}, +) ([]table.ParameterOption, error) { + if x.Name == "" { + switch xx := x.Value.(type) { + case *table.QueryParameters: + if len(args) > 1 { + return nil, xerrors.WithStackTrace(errMultipleQueryParameters) + } + xx.Each(func(name string, v types.Value) { + params = append(params, table.ValueParam(name, v)) + }) + case table.ParameterOption: + params = append(params, xx) + default: + x.Name = fmt.Sprintf("$p%d", i) + param, err := toYdbParam(x.Name, x.Value) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + params = append(params, param) + } + } else { + param, err := toYdbParam(x.Name, x.Value) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + params = append(params, param) + } + + return params, nil +} diff --git a/internal/cmd/gtrace/main.go b/internal/cmd/gtrace/main.go index 99f2bf98a..384135676 100644 --- a/internal/cmd/gtrace/main.go +++ b/internal/cmd/gtrace/main.go @@ -142,6 +142,30 @@ func main() { if err != nil { panic(fmt.Sprintf("type error: %v", err)) } + items := findGtraceGen(astFiles, pkgFiles, srcFilePath) + p := Package{ + Package: pkg, + BuildConstraints: buildConstraints, + } + traces := make(map[string]*Trace) + for _, item := range items { + t := &Trace{ + Name: item.Ident.Name, + } + p.Traces = append(p.Traces, t) + traces[item.Ident.Name] = t + } + extractNameAndDetails(items, &p, info, traces) + checkErrsInWriters(writers, p) + + log.Println("OK") +} + +// findGtraceGen iterates over the astFiles and pkgFiles to find elements marked with a comment containing "gtrace:gen" +// and collects them into GenItems. +// +//nolint:gocognit +func findGtraceGen(astFiles []*ast.File, pkgFiles []*os.File, srcFilePath string) []*GenItem { var items []*GenItem for i, astFile := range astFiles { if pkgFiles[i].Name() != srcFilePath { @@ -151,66 +175,72 @@ func main() { depth int item *GenItem ) - ast.Inspect(astFile, func(n ast.Node) (next bool) { - if n == nil { - item = nil - depth-- + ast.Inspect(astFile, + func(n ast.Node) (next bool) { + if n == nil { + item = nil + depth-- - return true - } - defer func() { - if next { - depth++ + return true } - }() + defer func() { + if next { + depth++ + } + }() - switch v := n.(type) { - case *ast.FuncDecl, *ast.ValueSpec: - return false + switch v := n.(type) { + case *ast.FuncDecl, *ast.ValueSpec: + return false - case *ast.Ident: - if item != nil { - item.Ident = v - } + case *ast.Ident: + if item != nil { + item.Ident = v + } - return false + return false - case *ast.CommentGroup: - for _, c := range v.List { - if strings.Contains(strings.TrimPrefix(c.Text, "//"), "gtrace:gen") { - if item == nil { - item = &GenItem{} + case *ast.CommentGroup: + for _, c := range v.List { + if strings.Contains(strings.TrimPrefix(c.Text, "//"), "gtrace:gen") { + if item == nil { + item = &GenItem{} + } } } - } - return false + return false - case *ast.StructType: - if item != nil { - item.StructType = v - items = append(items, item) - item = nil - } + case *ast.StructType: + if item != nil { + item.StructType = v + items = append(items, item) + item = nil + } - return false - } + return false + } - return true - }) - } - p := Package{ - Package: pkg, - BuildConstraints: buildConstraints, + return true + }, + ) } - traces := make(map[string]*Trace) - for _, item := range items { - t := &Trace{ - Name: item.Ident.Name, + + return items +} + +// checkErrsInWriters iterate over each Writer in the writers slice and checks errors +func checkErrsInWriters(writers []*Writer, p Package) { + for _, w := range writers { + if err := w.Write(p); err != nil { + panic(err) } - p.Traces = append(p.Traces, t) - traces[item.Ident.Name] = t } +} + +// extractNameAndDetails extracts the name and details of functions from the given items, and populates the traces +// in the package with hooks based on the extracted functions. +func extractNameAndDetails(items []*GenItem, p *Package, info *types.Info, traces map[string]*Trace) { for i, item := range items { t := p.Traces[i] for _, field := range item.StructType.Fields.List { @@ -237,13 +267,6 @@ func main() { }) } } - for _, w := range writers { - if err := w.Write(p); err != nil { - panic(err) - } - } - - log.Println("OK") } func buildFunc(info *types.Info, traces map[string]*Trace, fn *ast.FuncType) (ret *Func, err error) { diff --git a/internal/decimal/decimal.go b/internal/decimal/decimal.go index a4753992a..72ffae8df 100644 --- a/internal/decimal/decimal.go +++ b/internal/decimal/decimal.go @@ -117,39 +117,9 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { integral := precision - scale var dot bool - for ; len(s) > 0; s = s[1:] { - c := s[0] - if c == '.' { - if dot { - return nil, syntaxError(s) - } - dot = true - - continue - } - if dot { - if scale > 0 { - scale-- - } else { - break - } - } - - if !isDigit(c) { - return nil, syntaxError(s) - } - - v.Mul(v, ten) - v.Add(v, big.NewInt(int64(c-'0'))) - - if !dot && v.Cmp(zero) > 0 && integral == 0 { - if neg { - return neginf, nil - } - - return inf, nil - } - integral-- + bInt, done, err := dotStringAnalysis(s, dot, scale, v, integral, neg) + if done { + return bInt, err } //nolint:nestif if len(s) > 0 { // Characters remaining. @@ -177,12 +147,17 @@ func Parse(s string, precision, scale uint32) (*big.Int, error) { } } } + multipliedByTen(v, scale, neg) + + return v, nil +} + +// multipliedByTen multiplies the given big.Int value by 10 raised to the power of scale. +func multipliedByTen(v *big.Int, scale uint32, neg bool) { v.Mul(v, pow(ten, scale)) if neg { v.Neg(v) } - - return v, nil } // Format returns the string representation of x with the given precision and @@ -266,6 +241,57 @@ func Format(x *big.Int, precision, scale uint32) string { return xstring.FromBytes(bts[pos:]) } +// dotStringAnalysis performs analysis on a string representation of a decimal number. +func dotStringAnalysis( + s string, + dot bool, + scale uint32, + v *big.Int, + integral uint32, + neg bool, +) ( + *big.Int, + bool, + error, +) { + for ; len(s) > 0; s = s[1:] { + c := s[0] + if c == '.' { + if dot { + return nil, true, syntaxError(s) + } + dot = true + + continue + } + if dot { + if scale > 0 { + scale-- + } else { + break + } + } + + if !isDigit(c) { + return nil, true, syntaxError(s) + } + + v.Mul(v, ten) + v.Add(v, big.NewInt(int64(c-'0'))) + + if !dot && v.Cmp(zero) > 0 && integral == 0 { + if neg { + return neginf, true, nil + } + + return inf, true, nil + } + integral-- + } + + return nil, false, nil +} + // BigIntToByte returns the 16-byte array representation of x. // // If x value does not fit in 16 bytes with given precision, it returns 16-byte diff --git a/internal/stack/record.go b/internal/stack/record.go index 3ae77866f..4455caf20 100644 --- a/internal/stack/record.go +++ b/internal/stack/record.go @@ -97,12 +97,7 @@ func (c call) Record(opts ...recordOption) string { funcName string file = c.file ) - if i := strings.LastIndex(file, "/"); i > -1 { - file = file[i+1:] - } - if i := strings.LastIndex(name, "/"); i > -1 { - pkgPath, name = name[:i], name[i+1:] - } + file, name, pkgPath = findFileNameAndPkgPath(file, name, pkgPath) split := strings.Split(name, ".") lambdas := make([]string, 0, len(split)) for i := range split { @@ -171,6 +166,18 @@ func (c call) Record(opts ...recordOption) string { return buffer.String() } +// findFileNameAndPkgPath finds the file, name and package path from the given file path, name and package path. +func findFileNameAndPkgPath(file, name, pkgPath string) (string, string, string) { + if i := strings.LastIndex(file, "/"); i > -1 { + file = file[i+1:] + } + if i := strings.LastIndex(name, "/"); i > -1 { + pkgPath, name = name[:i], name[i+1:] + } + + return file, name, pkgPath +} + func (c call) FunctionID() string { return c.Record(Lambda(false), FileName(false)) } diff --git a/internal/table/client.go b/internal/table/client.go index 505a66f01..b33660c8f 100644 --- a/internal/table/client.go +++ b/internal/table/client.go @@ -173,21 +173,7 @@ func (c *Client) createSession(ctx context.Context, opts ...createSessionOption) defer cancel() } - closeSession := func(s *session) { - if s == nil { - return - } - - closeSessionCtx := xcontext.WithoutDeadline(ctx) - - if timeout := c.config.DeleteTimeout(); timeout > 0 { - var cancel context.CancelFunc - createSessionCtx, cancel = xcontext.WithTimeout(closeSessionCtx, timeout) - defer cancel() - } - - _ = s.Close(closeSessionCtx) - } + closeSession := onCloseSession(ctx, createSessionCtx, c) s, err = c.build(createSessionCtx) @@ -223,6 +209,25 @@ func (c *Client) createSession(ctx context.Context, opts ...createSessionOption) } } +// onCloseSession is a closure function that takes a session and performs the actions to close it. +func onCloseSession(ctx, createSessionCtx context.Context, c *Client) func(s *session) { + return func(s *session) { + if s == nil { + return + } + + closeSessionCtx := xcontext.WithoutDeadline(ctx) + + if timeout := c.config.DeleteTimeout(); timeout > 0 { + var cancel context.CancelFunc + createSessionCtx, cancel = xcontext.WithTimeout(closeSessionCtx, timeout) + defer cancel() + } + + _ = s.Close(closeSessionCtx) + } +} + func (c *Client) CreateSession(ctx context.Context, opts ...table.Option) (_ table.ClosableSession, err error) { if c == nil { return nil, xerrors.WithStackTrace(errNilClient) diff --git a/internal/table/retry_test.go b/internal/table/retry_test.go index 008633786..6665cb770 100644 --- a/internal/table/retry_test.go +++ b/internal/table/retry_test.go @@ -457,30 +457,44 @@ func TestRetryWithCustomErrors(t *testing.T) { }, nil, ) - //nolint:nestif - if test.retriable { - if i != limit { - t.Fatalf("unexpected i: %d, err: %v", i, err) - } - if test.deleteSession { - if len(sessions) != limit { - t.Fatalf("unexpected len(sessions): %d, err: %v", len(sessions), err) - } - for s, n := range sessions { - if n != 1 { - t.Fatalf("unexpected session usage: %d, session: %v", n, s.ID()) - } - } - } - } else { - if i != 1 { - t.Fatalf("unexpected i: %d, err: %v", i, err) - } - if len(sessions) != 1 { - t.Fatalf("unexpected len(sessions): %d, err: %v", len(sessions), err) + checkResultsRetryWithCustomErrors(t, test, i, limit, err, sessions) + }) + } +} + +// checkResultsRetryWithCustomErrors is a helper function for testing all suspensions of a custom operation function +// against all deadlines. +func checkResultsRetryWithCustomErrors(t *testing.T, test struct { + error error + retriable bool + deleteSession bool +}, i int, + limit int, + err error, + sessions map[table.Session]int, +) { + //nolint:nestif + if test.retriable { + if i != limit { + t.Fatalf("unexpected i: %d, err: %v", i, err) + } + if test.deleteSession { + if len(sessions) != limit { + t.Fatalf("unexpected len(sessions): %d, err: %v", len(sessions), err) + } + for s, n := range sessions { + if n != 1 { + t.Fatalf("unexpected session usage: %d, session: %v", n, s.ID()) } } - }) + } + } else { + if i != 1 { + t.Fatalf("unexpected i: %d, err: %v", i, err) + } + if len(sessions) != 1 { + t.Fatalf("unexpected len(sessions): %d, err: %v", len(sessions), err) + } } } diff --git a/internal/table/scanner/scanner.go b/internal/table/scanner/scanner.go index f541a1235..f810072d4 100644 --- a/internal/table/scanner/scanner.go +++ b/internal/table/scanner/scanner.go @@ -899,153 +899,47 @@ func (s *valueScanner) scanOptional(v interface{}, defaultValueForOptional bool) } switch v := v.(type) { case **bool: - if s.isNull() { - *v = nil - } else { - src := s.bool() - *v = &src - } + handleBoolCase(s, v) case **int8: - if s.isNull() { - *v = nil - } else { - src := s.int8() - *v = &src - } + handleInt8Case(s, v) case **int16: - if s.isNull() { - *v = nil - } else { - src := s.int16() - *v = &src - } + handleInt16Case(s, v) case **int32: - if s.isNull() { - *v = nil - } else { - src := s.int32() - *v = &src - } + handleInt32Case(s, v) case **int: - if s.isNull() { - *v = nil - } else { - src := int(s.int32()) - *v = &src - } + handleIntCase(s, v) case **int64: - if s.isNull() { - *v = nil - } else { - src := s.int64() - *v = &src - } + handleInt64Case(s, v) case **uint8: - if s.isNull() { - *v = nil - } else { - src := s.uint8() - *v = &src - } + handleUint8Case(s, v) case **uint16: - if s.isNull() { - *v = nil - } else { - src := s.uint16() - *v = &src - } + handleUint16Case(s, v) case **uint32: - if s.isNull() { - *v = nil - } else { - src := s.uint32() - *v = &src - } + handleUint32Case(s, v) case **uint: - if s.isNull() { - *v = nil - } else { - src := uint(s.uint32()) - *v = &src - } + handleUintCase(s, v) case **uint64: - if s.isNull() { - *v = nil - } else { - src := s.uint64() - *v = &src - } + handleUint64Case(s, v) case **float32: - if s.isNull() { - *v = nil - } else { - src := s.float() - *v = &src - } + handleFloat32Case(s, v) case **float64: - if s.isNull() { - *v = nil - } else { - src := s.double() - *v = &src - } + handleFloat64Case(s, v) case **time.Time: - if s.isNull() { - *v = nil - } else { - s.unwrap() - var src time.Time - s.setTime(&src) - *v = &src - } + handleTimeCase(s, v) case **time.Duration: - if s.isNull() { - *v = nil - } else { - src := value.IntervalToDuration(s.int64()) - *v = &src - } + handleDurationCase(s, v) case **string: - if s.isNull() { - *v = nil - } else { - s.unwrap() - var src string - s.setString(&src) - *v = &src - } + handleStringCase(s, v) case **[]byte: - if s.isNull() { - *v = nil - } else { - s.unwrap() - var src []byte - s.setByte(&src) - *v = &src - } + handleSliceByteCase(s, v) case **[16]byte: - if s.isNull() { - *v = nil - } else { - src := s.uint128() - *v = &src - } + handleArrByte16Case(s, v) case **interface{}: - if s.isNull() { - *v = nil - } else { - src := s.any() - *v = &src - } + handleInterfaceCase(s, v) case *value.Value: *v = s.value() - case **decimal.Decimal: - if s.isNull() { - *v = nil - } else { - src := s.unwrapDecimal() - *v = &src - } + case **types.Decimal: + handleDecimalCase(s, v) case scanner.Scanner: err := v.UnmarshalYDB(s.converter) if err != nil { @@ -1092,6 +986,212 @@ func (s *valueScanner) scanOptional(v interface{}, defaultValueForOptional bool) } } +// handleBoolCase handles the special case for handling boolean values in the scanner. +func handleBoolCase(s *valueScanner, v **bool) { + if s.isNull() { + *v = nil + } else { + src := s.bool() + *v = &src + } +} + +// handleInt8Case handles the special case for handling int8 values in the scanner. +func handleInt8Case(s *valueScanner, v **int8) { + if s.isNull() { + *v = nil + } else { + src := s.int8() + *v = &src + } +} + +// handleInt16Case handles the special case for handling int16 values in the scanner. +func handleInt16Case(s *valueScanner, v **int16) { + if s.isNull() { + *v = nil + } else { + src := s.int16() + *v = &src + } +} + +// handleInt32Case handles the special case for handling int32 values in the scanner. +func handleInt32Case(s *valueScanner, v **int32) { + if s.isNull() { + *v = nil + } else { + src := s.int32() + *v = &src + } +} + +// handleIntCase handles the special case for handling int values in the scanner. +func handleIntCase(s *valueScanner, v **int) { + if s.isNull() { + *v = nil + } else { + src := int(s.int32()) + *v = &src + } +} + +// handleInt64Case handles the special case for handling int64 values in the scanner. +func handleInt64Case(s *valueScanner, v **int64) { + if s.isNull() { + *v = nil + } else { + src := s.int64() + *v = &src + } +} + +// handleUint8Case handles the special case for handling uint8 values in the scanner. +func handleUint8Case(s *valueScanner, v **uint8) { + if s.isNull() { + *v = nil + } else { + src := s.uint8() + *v = &src + } +} + +// handleUint16Case handles the special case for handling uint16 values in the scanner. +func handleUint16Case(s *valueScanner, v **uint16) { + if s.isNull() { + *v = nil + } else { + src := s.uint16() + *v = &src + } +} + +// handleUint32Case handles the special case for handling uint32 values in the scanner. +func handleUint32Case(s *valueScanner, v **uint32) { + if s.isNull() { + *v = nil + } else { + src := s.uint32() + *v = &src + } +} + +// handleUintCase handles the special case for handling uint values in the scanner. +func handleUintCase(s *valueScanner, v **uint) { + if s.isNull() { + *v = nil + } else { + src := uint(s.uint32()) + *v = &src + } +} + +// handleUint64Case handles the special case for handling uint64 values in the scanner. +func handleUint64Case(s *valueScanner, v **uint64) { + if s.isNull() { + *v = nil + } else { + src := s.uint64() + *v = &src + } +} + +// handleFloat32Case handles the special case for handling float32 values in the scanner. +func handleFloat32Case(s *valueScanner, v **float32) { + if s.isNull() { + *v = nil + } else { + src := s.float() + *v = &src + } +} + +// handleFloat64Case handles the special case for handling float64 values in the scanner. +func handleFloat64Case(s *valueScanner, v **float64) { + if s.isNull() { + *v = nil + } else { + src := s.double() + *v = &src + } +} + +// handleTimeCase handles the special case for handling time.Time values in the scanner. +func handleTimeCase(s *valueScanner, v **time.Time) { + if s.isNull() { + *v = nil + } else { + s.unwrap() + var src time.Time + s.setTime(&src) + *v = &src + } +} + +// handleDurationCase handles the special case for handling time.Duration values in the scanner. +func handleDurationCase(s *valueScanner, v **time.Duration) { + if s.isNull() { + *v = nil + } else { + src := value.IntervalToDuration(s.int64()) + *v = &src + } +} + +// handleStringCase handles the special case for handling string values in the scanner. +func handleStringCase(s *valueScanner, v **string) { + if s.isNull() { + *v = nil + } else { + s.unwrap() + var src string + s.setString(&src) + *v = &src + } +} + +// handleSliceByteCase handles the special case for handling []byte values in the scanner. +func handleSliceByteCase(s *valueScanner, v **[]byte) { + if s.isNull() { + *v = nil + } else { + s.unwrap() + var src []byte + s.setByte(&src) + *v = &src + } +} + +// handleArrByte16Case handles the special case for handling [16]byte values in the scanner. +func handleArrByte16Case(s *valueScanner, v **[16]byte) { + if s.isNull() { + *v = nil + } else { + src := s.uint128() + *v = &src + } +} + +// handleInterfaceCase handles the special case for handling interface{} values in the scanner. +func handleInterfaceCase(s *valueScanner, v **interface{}) { + if s.isNull() { + *v = nil + } else { + src := s.any() + *v = &src + } +} + +// handleDecimalCase handles the special case for handling types.Decimal values in the scanner. +func handleDecimalCase(s *valueScanner, v **types.Decimal) { + if s.isNull() { + *v = nil + } else { + src := s.unwrapDecimal() + *v = &src + } +} + func (s *valueScanner) setDefaultValue(dst interface{}) { switch v := dst.(type) { case *bool: diff --git a/internal/table/scanner/scanner_test.go b/internal/table/scanner/scanner_test.go index 44f09ff45..bd3250bd0 100644 --- a/internal/table/scanner/scanner_test.go +++ b/internal/table/scanner/scanner_test.go @@ -18,498 +18,622 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/table/types" ) -//nolint:gocyclo func valueFromPrimitiveTypeID(c *column, r xrand.Rand) (*Ydb.Value, interface{}) { rv := r.Int64(math.MaxInt16) switch c.typeID { case Ydb.Type_BOOL: - v := rv%2 == 1 - ydbval := &Ydb.Value{ - Value: &Ydb.Value_BoolValue{ - BoolValue: v, - }, - } - if c.optional && !c.testDefault { - vp := &v + return getValueBool(c, rv) + case Ydb.Type_INT8: + return getValueInt8(c, rv) + case Ydb.Type_UINT8: + return getValueUint8(c, rv) + case Ydb.Type_INT16: + return getValueInt16(c, rv) + case Ydb.Type_UINT16: + return getValueUint16(c, rv) + case Ydb.Type_INT32: + return getValueInt32(c, rv) + case Ydb.Type_UINT32: + return getValueUint32(c, rv) + case Ydb.Type_INT64: + return getValueInt64(c, rv) + case Ydb.Type_UINT64: + return getValueUint64(c, rv) + case Ydb.Type_FLOAT: + return getValueFloat32(c, rv) + case Ydb.Type_DOUBLE: + return getValueFloat64(c, rv) + case Ydb.Type_DATE: + return getValueDate(c, rv) + case Ydb.Type_DATETIME: + return getValueDateTime(c, rv) + case Ydb.Type_TIMESTAMP: + return getValueTimestamp(c, rv) + case Ydb.Type_INTERVAL: + return getValueInterval(c, rv) + case Ydb.Type_TZ_DATE: + return getValueTzDate(c) + case Ydb.Type_TZ_DATETIME: + return getValueTzDatetime(c, rv) + case Ydb.Type_TZ_TIMESTAMP: + return getValueTzTimestamp(c, rv) + case Ydb.Type_STRING: + return getValueString(c, rv) + case Ydb.Type_UTF8: + return getValueUtf8(c, rv) + case Ydb.Type_YSON: + return getValueYSON(c, rv) + case Ydb.Type_JSON: + return getValueJSON(c, rv) + case Ydb.Type_UUID: + return getValueUUID(c, rv) + case Ydb.Type_JSON_DOCUMENT: + return getValueJSONDocument(c, rv) + case Ydb.Type_DYNUMBER: + return getValueDyNumber(c, rv) + default: + panic("ydb: unexpected types") + } +} - return ydbval, &vp - } +// getValueDyNumber extracts a dynamic number value. +func getValueDyNumber(c *column, rv int64) (*Ydb.Value, interface{}) { + v := strconv.FormatUint(uint64(rv), 10) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_TextValue{ + TextValue: v, + }, + } + if c.optional && !c.testDefault { + vp := &v - return ydbval, &v - case Ydb.Type_INT8: - v := int8(rv) - ydbval := &Ydb.Value{ - Value: &Ydb.Value_Int32Value{ - Int32Value: int32(v), - }, - } - if c.optional && !c.testDefault { - vp := &v + return ydbval, &vp + } - return ydbval, &vp - } + return ydbval, &v +} - return ydbval, &v - case Ydb.Type_UINT8: - if c.nilValue { - ydbval := &Ydb.Value{ - Value: &Ydb.Value_NullFlagValue{}, - } - if c.testDefault { - var dv uint8 +// getValueJSONDocument extracts a JSON document value. +func getValueJSONDocument(c *column, rv int64) (*Ydb.Value, interface{}) { + v := strconv.FormatUint(uint64(rv), 10) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_TextValue{ + TextValue: v, + }, + } + src := []byte(v) + if c.optional && !c.testDefault { + vp := &src - return ydbval, &dv - } - var dv *uint8 + return ydbval, &vp + } - return ydbval, &dv - } - v := uint8(rv) + return ydbval, &src +} + +// getValueUUID extracts a UUID value. +func getValueUUID(c *column, rv int64) (*Ydb.Value, interface{}) { + if c.nilValue { ydbval := &Ydb.Value{ - Value: &Ydb.Value_Uint32Value{ - Uint32Value: uint32(v), - }, + Value: &Ydb.Value_NullFlagValue{}, } - if c.optional && !c.testDefault { - vp := &v + if c.testDefault { + var dv [16]byte - return ydbval, &vp + return ydbval, &dv } + var dv *[16]byte - return ydbval, &v - case Ydb.Type_INT16: - v := int16(rv) - ydbval := &Ydb.Value{ - Value: &Ydb.Value_Int32Value{ - Int32Value: int32(v), - }, - } - if c.optional && !c.testDefault { - vp := &v + return ydbval, &dv + } + v := [16]byte{} + binary.BigEndian.PutUint64(v[0:8], uint64(rv)) + binary.BigEndian.PutUint64(v[8:16], uint64(rv)) + ydbval := &Ydb.Value{ + High_128: binary.BigEndian.Uint64(v[0:8]), + Value: &Ydb.Value_Low_128{ + Low_128: binary.BigEndian.Uint64(v[8:16]), + }, + } + if c.optional && !c.testDefault { + vp := &v - return ydbval, &vp - } + return ydbval, &vp + } - return ydbval, &v - case Ydb.Type_UINT16: - v := uint16(rv) - ydbval := &Ydb.Value{ - Value: &Ydb.Value_Uint32Value{ - Uint32Value: uint32(v), - }, - } - if c.optional && !c.testDefault { - vp := &v + return ydbval, &v +} - return ydbval, &vp - } +// getValueJSON extracts a JSON value. +func getValueJSON(c *column, rv int64) (*Ydb.Value, interface{}) { + v := strconv.FormatUint(uint64(rv), 10) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_TextValue{ + TextValue: v, + }, + } + if c.ydbvalue { + vp := types.JSONValue(v) - return ydbval, &v - case Ydb.Type_INT32: - if c.nilValue { - ydbval := &Ydb.Value{ - Value: &Ydb.Value_NullFlagValue{}, - } - if c.testDefault { - var dv int32 + return ydbval, &vp + } + src := []byte(v) + if c.optional && !c.testDefault { + vp := &src - return ydbval, &dv - } - var dv *int32 + return ydbval, &vp + } - return ydbval, &dv - } - v := int32(rv) + return ydbval, &src +} + +// getValueYSON extracts a YSON (Yandex Simple Object Notation) value. +func getValueYSON(c *column, rv int64) (*Ydb.Value, interface{}) { + if c.nilValue { ydbval := &Ydb.Value{ - Value: &Ydb.Value_Int32Value{ - Int32Value: v, - }, + Value: &Ydb.Value_NullFlagValue{}, } - if c.optional && !c.testDefault { - vp := &v + if c.testDefault { + var dv []byte - return ydbval, &vp + return ydbval, &dv } + var dv *[]byte - return ydbval, &v - case Ydb.Type_UINT32: - v := uint32(rv) - ydbval := &Ydb.Value{ - Value: &Ydb.Value_Uint32Value{ - Uint32Value: v, - }, - } - if c.optional && !c.testDefault { - vp := &v + return ydbval, &dv + } + v := strconv.FormatUint(uint64(rv), 10) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_TextValue{ + TextValue: v, + }, + } + src := []byte(v) + if c.optional && !c.testDefault { + vp := &src - return ydbval, &vp - } + return ydbval, &vp + } - return ydbval, &v - case Ydb.Type_INT64: - v := rv - ydbval := &Ydb.Value{ - Value: &Ydb.Value_Int64Value{ - Int64Value: v, - }, - } - if c.ydbvalue { - vp := types.Int64Value(v) + return ydbval, &src +} - return ydbval, &vp - } - if c.scanner { - s := intIncScanner(v + 10) +// getValueUtf8 extracts an UTF-8 string value. +func getValueUtf8(c *column, rv int64) (*Ydb.Value, interface{}) { + v := strconv.FormatUint(uint64(rv), 10) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_TextValue{ + TextValue: v, + }, + } + if c.optional && !c.testDefault { + vp := &v - return ydbval, &s - } - if c.optional && !c.testDefault { - vp := &v + return ydbval, &vp + } - return ydbval, &vp - } + return ydbval, &v +} - return ydbval, &v - case Ydb.Type_UINT64: - v := uint64(rv) +// getValueString extracts a string value from an int64 value. If `c.nilValue` is true, it returns a null value. +func getValueString(c *column, rv int64) (*Ydb.Value, interface{}) { + if c.nilValue { ydbval := &Ydb.Value{ - Value: &Ydb.Value_Uint64Value{ - Uint64Value: v, - }, + Value: &Ydb.Value_NullFlagValue{}, } - if c.optional && !c.testDefault { - vp := &v + if c.testDefault { + var dv []byte - return ydbval, &vp + return ydbval, &dv } + var dv *[]byte - return ydbval, &v - case Ydb.Type_FLOAT: - v := float32(rv) - ydbval := &Ydb.Value{ - Value: &Ydb.Value_FloatValue{ - FloatValue: v, - }, - } - if c.ydbvalue { - vp := types.FloatValue(v) + return ydbval, &dv + } + v := make([]byte, 16) + binary.BigEndian.PutUint64(v[0:8], uint64(rv)) + binary.BigEndian.PutUint64(v[8:16], uint64(rv)) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_BytesValue{ + BytesValue: v, + }, + } + src := v + if c.optional && !c.testDefault { + vp := &src - return ydbval, &vp - } - if c.optional && !c.testDefault { - vp := &v + return ydbval, &vp + } - return ydbval, &vp - } + return ydbval, &src +} - return ydbval, &v - case Ydb.Type_DOUBLE: - v := float64(rv) - ydbval := &Ydb.Value{ - Value: &Ydb.Value_DoubleValue{ - DoubleValue: v, - }, - } - if c.optional && !c.testDefault { - vp := &v +// getValueTzTimestamp extracts the current timestamp in the format of "2006-01-02T15:04:05.000000,Europe/Berlin". +func getValueTzTimestamp(c *column, rv int64) (*Ydb.Value, interface{}) { + rv %= time.Now().Unix() + v := value.TimestampToTime(uint64(rv)).Format(value.LayoutTzTimestamp) + ",Europe/Berlin" + ydbval := &Ydb.Value{ + Value: &Ydb.Value_TextValue{ + TextValue: v, + }, + } + src, _ := value.TzTimestampToTime(v) + if c.optional && !c.testDefault { + vp := &src - return ydbval, &vp - } + return ydbval, &vp + } - return ydbval, &v - case Ydb.Type_DATE: - v := uint32(rv) + return ydbval, &src +} + +// getValueTzDatetime extracts the current time in the format of "2006-01-02T15:04:05,Europe/Berlin". +func getValueTzDatetime(c *column, rv int64) (*Ydb.Value, interface{}) { + if c.nilValue { ydbval := &Ydb.Value{ - Value: &Ydb.Value_Uint32Value{ - Uint32Value: v, - }, + Value: &Ydb.Value_NullFlagValue{}, } - src := value.DateToTime(v) - if c.scanner { - s := dateScanner(src) + if c.testDefault { + var dv time.Time - return ydbval, &s + return ydbval, &dv } - if c.optional && !c.testDefault { - vp := &src + var dv *time.Time - return ydbval, &vp - } + return ydbval, &dv + } + rv %= time.Now().Unix() + v := value.DatetimeToTime(uint32(rv)).Format(value.LayoutTzDatetime) + ",Europe/Berlin" + ydbval := &Ydb.Value{ + Value: &Ydb.Value_TextValue{ + TextValue: v, + }, + } + src, _ := value.TzDatetimeToTime(v) + if c.optional && !c.testDefault { + vp := &src - return ydbval, &src - case Ydb.Type_DATETIME: - v := uint32(rv) - ydbval := &Ydb.Value{ - Value: &Ydb.Value_Uint32Value{ - Uint32Value: v, - }, - } - src := value.DatetimeToTime(v) - if c.optional && !c.testDefault { - vp := &src + return ydbval, &vp + } - return ydbval, &vp - } + return ydbval, &src +} - return ydbval, &src - case Ydb.Type_TIMESTAMP: - v := uint64(rv) +// getValueTzDate extracts the current time in the format of "2006-01-02,Europe/Berlin" +func getValueTzDate(c *column) (*Ydb.Value, interface{}) { + v := time.Now().Format(value.LayoutDate) + ",Europe/Berlin" + ydbval := &Ydb.Value{ + Value: &Ydb.Value_TextValue{ + TextValue: v, + }, + } + src, _ := value.TzDateToTime(v) + if c.optional && !c.testDefault { + vp := &src + + return ydbval, &vp + } + + return ydbval, &src +} + +// getValueInterval extracts int64 value from the input int64 and returns it wrapped in a Ydb.Value struct. +func getValueInterval(c *column, rv int64) (*Ydb.Value, interface{}) { + if c.nilValue { ydbval := &Ydb.Value{ - Value: &Ydb.Value_Uint64Value{ - Uint64Value: v, - }, + Value: &Ydb.Value_NullFlagValue{}, } - src := value.TimestampToTime(v) - if c.optional && !c.testDefault { - vp := &src + if c.testDefault { + var dv time.Duration - return ydbval, &vp + return ydbval, &dv } + var dv *time.Duration - return ydbval, &src - case Ydb.Type_INTERVAL: - if c.nilValue { - ydbval := &Ydb.Value{ - Value: &Ydb.Value_NullFlagValue{}, - } - if c.testDefault { - var dv time.Duration + return ydbval, &dv + } + rv %= time.Now().Unix() + v := rv + ydbval := &Ydb.Value{ + Value: &Ydb.Value_Int64Value{ + Int64Value: v, + }, + } + src := value.IntervalToDuration(v) + if c.optional && !c.testDefault { + vp := &src - return ydbval, &dv - } - var dv *time.Duration + return ydbval, &vp + } - return ydbval, &dv - } - rv %= time.Now().Unix() - v := rv - ydbval := &Ydb.Value{ - Value: &Ydb.Value_Int64Value{ - Int64Value: v, - }, - } - src := value.IntervalToDuration(v) - if c.optional && !c.testDefault { - vp := &src + return ydbval, &src +} - return ydbval, &vp - } +// getValueTimestamp extracts uint64 value from an int64 and returns it wrapped in a Ydb.Value struct. +func getValueTimestamp(c *column, rv int64) (*Ydb.Value, interface{}) { + v := uint64(rv) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: v, + }, + } + src := value.TimestampToTime(v) + if c.optional && !c.testDefault { + vp := &src - return ydbval, &src - case Ydb.Type_TZ_DATE: - v := time.Now().Format(value.LayoutDate) + ",Europe/Berlin" - ydbval := &Ydb.Value{ - Value: &Ydb.Value_TextValue{ - TextValue: v, - }, - } - src, _ := value.TzDateToTime(v) - if c.optional && !c.testDefault { - vp := &src + return ydbval, &vp + } - return ydbval, &vp - } + return ydbval, &src +} - return ydbval, &src - case Ydb.Type_TZ_DATETIME: - if c.nilValue { - ydbval := &Ydb.Value{ - Value: &Ydb.Value_NullFlagValue{}, - } - if c.testDefault { - var dv time.Time +// getValueDateTime extracts uint32 value from an int64 and returns it wrapped in a Ydb.Value struct. +func getValueDateTime(c *column, rv int64) (*Ydb.Value, interface{}) { + v := uint32(rv) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_Uint32Value{ + Uint32Value: v, + }, + } + src := value.DatetimeToTime(v) + if c.optional && !c.testDefault { + vp := &src - return ydbval, &dv - } - var dv *time.Time + return ydbval, &vp + } - return ydbval, &dv - } - rv %= time.Now().Unix() - v := value.DatetimeToTime(uint32(rv)).Format(value.LayoutTzDatetime) + ",Europe/Berlin" - ydbval := &Ydb.Value{ - Value: &Ydb.Value_TextValue{ - TextValue: v, - }, - } - src, _ := value.TzDatetimeToTime(v) - if c.optional && !c.testDefault { - vp := &src + return ydbval, &src +} - return ydbval, &vp - } +// getValueDate extracts uint32 value from an int64 and returns it wrapped in a Ydb.Value struct. +func getValueDate(c *column, rv int64) (*Ydb.Value, interface{}) { + v := uint32(rv) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_Uint32Value{ + Uint32Value: v, + }, + } + src := value.DateToTime(v) + if c.scanner { + s := dateScanner(src) - return ydbval, &src - case Ydb.Type_TZ_TIMESTAMP: - rv %= time.Now().Unix() - v := value.TimestampToTime(uint64(rv)).Format(value.LayoutTzTimestamp) + ",Europe/Berlin" - ydbval := &Ydb.Value{ - Value: &Ydb.Value_TextValue{ - TextValue: v, - }, - } - src, _ := value.TzTimestampToTime(v) - if c.optional && !c.testDefault { - vp := &src + return ydbval, &s + } + if c.optional && !c.testDefault { + vp := &src - return ydbval, &vp - } + return ydbval, &vp + } - return ydbval, &src - case Ydb.Type_STRING: - if c.nilValue { - ydbval := &Ydb.Value{ - Value: &Ydb.Value_NullFlagValue{}, - } - if c.testDefault { - var dv []byte + return ydbval, &src +} - return ydbval, &dv - } - var dv *[]byte +// getValueFloat64 extracts float64 value from an int64 and returns it wrapped in a Ydb.Value struct. +func getValueFloat64(c *column, rv int64) (*Ydb.Value, interface{}) { + v := float64(rv) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_DoubleValue{ + DoubleValue: v, + }, + } + if c.optional && !c.testDefault { + vp := &v - return ydbval, &dv - } - v := make([]byte, 16) - binary.BigEndian.PutUint64(v[0:8], uint64(rv)) - binary.BigEndian.PutUint64(v[8:16], uint64(rv)) - ydbval := &Ydb.Value{ - Value: &Ydb.Value_BytesValue{ - BytesValue: v, - }, - } - src := v - if c.optional && !c.testDefault { - vp := &src + return ydbval, &vp + } - return ydbval, &vp - } + return ydbval, &v +} - return ydbval, &src - case Ydb.Type_UTF8: - v := strconv.FormatUint(uint64(rv), 10) - ydbval := &Ydb.Value{ - Value: &Ydb.Value_TextValue{ - TextValue: v, - }, - } - if c.optional && !c.testDefault { - vp := &v +// getValueFloat32 extracts float32 value from an int64 and returns it wrapped in a Ydb.Value struct. +func getValueFloat32(c *column, rv int64) (*Ydb.Value, interface{}) { + v := float32(rv) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_FloatValue{ + FloatValue: v, + }, + } + if c.ydbvalue { + vp := types.FloatValue(v) - return ydbval, &vp - } + return ydbval, &vp + } + if c.optional && !c.testDefault { + vp := &v - return ydbval, &v - case Ydb.Type_YSON: - if c.nilValue { - ydbval := &Ydb.Value{ - Value: &Ydb.Value_NullFlagValue{}, - } - if c.testDefault { - var dv []byte + return ydbval, &vp + } - return ydbval, &dv - } - var dv *[]byte + return ydbval, &v +} - return ydbval, &dv - } - v := strconv.FormatUint(uint64(rv), 10) - ydbval := &Ydb.Value{ - Value: &Ydb.Value_TextValue{ - TextValue: v, - }, - } - src := []byte(v) - if c.optional && !c.testDefault { - vp := &src +// getValueUint64 extracts uint64 value from an int64 and returns it wrapped in a Ydb.Value struct. +func getValueUint64(c *column, rv int64) (*Ydb.Value, interface{}) { + v := uint64(rv) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_Uint64Value{ + Uint64Value: v, + }, + } + if c.optional && !c.testDefault { + vp := &v - return ydbval, &vp - } + return ydbval, &vp + } - return ydbval, &src - case Ydb.Type_JSON: - v := strconv.FormatUint(uint64(rv), 10) - ydbval := &Ydb.Value{ - Value: &Ydb.Value_TextValue{ - TextValue: v, - }, - } - if c.ydbvalue { - vp := types.JSONValue(v) + return ydbval, &v +} - return ydbval, &vp - } - src := []byte(v) - if c.optional && !c.testDefault { - vp := &src +// getValueInt64 extracts int64 value from an int64 and returns it wrapped in a Ydb.Value struct. +func getValueInt64(c *column, rv int64) (*Ydb.Value, interface{}) { + v := rv + ydbval := &Ydb.Value{ + Value: &Ydb.Value_Int64Value{ + Int64Value: v, + }, + } + if c.ydbvalue { + vp := types.Int64Value(v) - return ydbval, &vp - } + return ydbval, &vp + } + if c.scanner { + s := intIncScanner(v + 10) - return ydbval, &src - case Ydb.Type_UUID: - if c.nilValue { - ydbval := &Ydb.Value{ - Value: &Ydb.Value_NullFlagValue{}, - } - if c.testDefault { - var dv [16]byte + return ydbval, &s + } + if c.optional && !c.testDefault { + vp := &v - return ydbval, &dv - } - var dv *[16]byte + return ydbval, &vp + } - return ydbval, &dv - } - v := [16]byte{} - binary.BigEndian.PutUint64(v[0:8], uint64(rv)) - binary.BigEndian.PutUint64(v[8:16], uint64(rv)) - ydbval := &Ydb.Value{ - High_128: binary.BigEndian.Uint64(v[0:8]), - Value: &Ydb.Value_Low_128{ - Low_128: binary.BigEndian.Uint64(v[8:16]), - }, - } - if c.optional && !c.testDefault { - vp := &v + return ydbval, &v +} - return ydbval, &vp - } +// getValueUint32 extracts uint32 value from an int64 and returns it wrapped in a Ydb.Value struct. +func getValueUint32(c *column, rv int64) (*Ydb.Value, interface{}) { + v := uint32(rv) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_Uint32Value{ + Uint32Value: v, + }, + } + if c.optional && !c.testDefault { + vp := &v - return ydbval, &v - case Ydb.Type_JSON_DOCUMENT: - v := strconv.FormatUint(uint64(rv), 10) + return ydbval, &vp + } + + return ydbval, &v +} + +// getValueInt32 extracts an int32 value from an int64 and returns it wrapped in a Ydb.Value struct. +func getValueInt32(c *column, rv int64) (*Ydb.Value, interface{}) { + if c.nilValue { ydbval := &Ydb.Value{ - Value: &Ydb.Value_TextValue{ - TextValue: v, - }, + Value: &Ydb.Value_NullFlagValue{}, } - src := []byte(v) - if c.optional && !c.testDefault { - vp := &src + if c.testDefault { + var dv int32 - return ydbval, &vp + return ydbval, &dv } + var dv *int32 - return ydbval, &src - case Ydb.Type_DYNUMBER: - v := strconv.FormatUint(uint64(rv), 10) + return ydbval, &dv + } + v := int32(rv) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_Int32Value{ + Int32Value: v, + }, + } + if c.optional && !c.testDefault { + vp := &v + + return ydbval, &vp + } + + return ydbval, &v +} + +// getValueUint16 extracts an uint16 value from an uint64 and returns it wrapped in a Ydb.Value struct. +func getValueUint16(c *column, rv int64) (*Ydb.Value, interface{}) { + v := uint16(rv) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_Uint32Value{ + Uint32Value: uint32(v), + }, + } + if c.optional && !c.testDefault { + vp := &v + + return ydbval, &vp + } + + return ydbval, &v +} + +// getValueInt16 extracts an int16 value from an int64 and returns it wrapped in a Ydb.Value struct. +func getValueInt16(c *column, rv int64) (*Ydb.Value, interface{}) { + v := int16(rv) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_Int32Value{ + Int32Value: int32(v), + }, + } + if c.optional && !c.testDefault { + vp := &v + + return ydbval, &vp + } + + return ydbval, &v +} + +// getValueUint8 extracts an uint8 value from an int64 and returns it wrapped in a Ydb.Value struct. +func getValueUint8(c *column, rv int64) (*Ydb.Value, interface{}) { + if c.nilValue { ydbval := &Ydb.Value{ - Value: &Ydb.Value_TextValue{ - TextValue: v, - }, + Value: &Ydb.Value_NullFlagValue{}, } - if c.optional && !c.testDefault { - vp := &v + if c.testDefault { + var dv uint8 - return ydbval, &vp + return ydbval, &dv } + var dv *uint8 - return ydbval, &v - default: - panic("ydb: unexpected types") + return ydbval, &dv + } + v := uint8(rv) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_Uint32Value{ + Uint32Value: uint32(v), + }, + } + if c.optional && !c.testDefault { + vp := &v + + return ydbval, &vp + } + + return ydbval, &v +} + +// getValueInt8 extracts an int8 value from an int64 and returns it wrapped in a Ydb.Value struct. +func getValueInt8(c *column, rv int64) (*Ydb.Value, interface{}) { + v := int8(rv) + ydbval := &Ydb.Value{ + Value: &Ydb.Value_Int32Value{ + Int32Value: int32(v), + }, + } + if c.optional && !c.testDefault { + vp := &v + + return ydbval, &vp + } + + return ydbval, &v +} + +// getValueBool extracts a boolean value from an int64 and returns it wrapped in a Ydb.Value struct. +func getValueBool(c *column, rv int64) (*Ydb.Value, interface{}) { + v := rv%2 == 1 + ydbval := &Ydb.Value{ + Value: &Ydb.Value_BoolValue{ + BoolValue: v, + }, } + if c.optional && !c.testDefault { + vp := &v + + return ydbval, &vp + } + + return ydbval, &v } func getResultSet(count int, col []*column) (result *Ydb.ResultSet, testValues [][]indexed.RequiredOrOptional) { @@ -579,13 +703,7 @@ func TestScanSqlTypes(t *testing.T) { t.Fatalf("test: %s; error: %s", test.name, err) } } - if test.setColumnIndexes != nil { - for i, v := range test.setColumnIndexes { - require.Equal(t, expected[0][v], test.values[i]) - } - } else { - require.Equal(t, expected[0], test.values) - } + validateColumnIndexes(t, test, expected) expected = expected[1:] } }) @@ -607,7 +725,6 @@ func TestScanNamed(t *testing.T) { s.reset(set) for s.NextRow() { values := make([]named.Value, 0, len(test.values)) - //nolint:nestif if test.columns[0].testDefault { for i := range test.values { values = append( @@ -623,51 +740,86 @@ func TestScanNamed(t *testing.T) { } } else { for i := range test.values { - if test.columns[i].optional { - if test.columns[i].testDefault { - values = append( - values, - named.OptionalWithDefault( - or(test.setColumns, i, test.columns[i].name), - test.values[i], - ), - ) - } else { - values = append( - values, - named.Optional( - or(test.setColumns, i, test.columns[i].name), - test.values[i], - ), - ) - } - } else { - values = append( - values, - named.Required( - or(test.setColumns, i, test.columns[i].name), - test.values[i], - ), - ) - } + values = processValues(test, i, values, or) } if err := s.ScanNamed(values...); err != nil { t.Fatalf("test: %s; error: %s", test.name, err) } } - if test.setColumnIndexes != nil { - for i, v := range test.setColumnIndexes { - require.Equal(t, expected[0][v], test.values[i]) - } - } else { - require.Equal(t, expected[0], test.values) - } + validateColumnIndexes(t, test, expected) expected = expected[1:] } }) } } +// validateColumnIndexes validates the column indexes of the test data against the expected values. +func validateColumnIndexes(t *testing.T, test struct { + name string + count int + columns []*column + values []indexed.RequiredOrOptional + setColumns []string + setColumnIndexes []int +}, + expected [][]indexed.RequiredOrOptional, +) { + if test.setColumnIndexes != nil { + for i, v := range test.setColumnIndexes { + require.Equal(t, expected[0][v], test.values[i]) + } + } else { + require.Equal(t, expected[0], test.values) + } +} + +// processValues processes the values based on the given test data and parameters. +func processValues(test struct { + name string + count int + columns []*column + values []indexed.RequiredOrOptional + setColumns []string + setColumnIndexes []int +}, + i int, + values []named.Value, + or func(columns []string, + i int, + defaultValue string, + ) string, +) []named.Value { + if test.columns[i].optional { + if test.columns[i].testDefault { + values = append( + values, + named.OptionalWithDefault( + or(test.setColumns, i, test.columns[i].name), + test.values[i], + ), + ) + } else { + values = append( + values, + named.Optional( + or(test.setColumns, i, test.columns[i].name), + test.values[i], + ), + ) + } + } else { + values = append( + values, + named.Required( + or(test.setColumns, i, test.columns[i].name), + test.values[i], + ), + ) + } + + return values +} + type jsonUnmarshaller struct { bytes []byte } diff --git a/internal/topic/topicwriterinternal/queue_test.go b/internal/topic/topicwriterinternal/queue_test.go index d62e6b16e..06e486907 100644 --- a/internal/topic/topicwriterinternal/queue_test.go +++ b/internal/topic/topicwriterinternal/queue_test.go @@ -174,24 +174,7 @@ func TestMessageQueue_GetMessages(t *testing.T) { waitTimeout := time.Second * 10 startWait := time.Now() - waitReader: - for { - if lastReadSeqNo.Load() == lastSentSeqNo { - readCancel() - } - select { - case <-readFinished: - break waitReader - case stack := <-fatalChan: - t.Fatal(stack) - default: - } - - runtime.Gosched() - if time.Since(startWait) > waitTimeout { - t.Fatal() - } - } + waitReader(&lastReadSeqNo, lastSentSeqNo, readCancel, readFinished, fatalChan, startWait, waitTimeout, t) }) t.Run("ClosedContext", func(t *testing.T) { @@ -233,6 +216,38 @@ func TestMessageQueue_GetMessages(t *testing.T) { }) } +// waitReader waits for a condition where the lastReadSeqNo is equal to the lastSentSeqNo. +// It periodically checks for changes in the lastReadSeqNo and waits for a read to finish or a fatal error to occur. +// If the waitTimeout is reached, a fatal error is triggered. +func waitReader( + lastReadSeqNo *xatomic.Int64, + lastSentSeqNo int64, + readCancel func(), + readFinished <-chan struct{}, + fatalChan <-chan string, + startWait time.Time, + waitTimeout time.Duration, + t *testing.T, +) { +waitReader: + for { + if lastReadSeqNo.Load() == lastSentSeqNo { + readCancel() + } + select { + case <-readFinished: + break waitReader + case stack := <-fatalChan: + t.Fatal(stack) + default: + } + runtime.Gosched() + if time.Since(startWait) > waitTimeout { + t.Fatal() + } + } +} + func TestMessageQueue_ResetSentProgress(t *testing.T) { ctx := context.Background() diff --git a/internal/xsql/dsn.go b/internal/xsql/dsn.go index 508308995..dd0a4c759 100644 --- a/internal/xsql/dsn.go +++ b/internal/xsql/dsn.go @@ -53,29 +53,10 @@ func Parse(dataSourceName string) (opts []config.Option, connectorOpts []Connect } } if info.Params.Has("go_query_bind") { - var binders []ConnectorOption queryTransformers := strings.Split(info.Params.Get("go_query_bind"), ",") - for _, transformer := range queryTransformers { - switch transformer { - case "declare": - binders = append(binders, WithQueryBind(bind.AutoDeclare{})) - case "positional": - binders = append(binders, WithQueryBind(bind.PositionalArgs{})) - case "numeric": - binders = append(binders, WithQueryBind(bind.NumericArgs{})) - default: - if strings.HasPrefix(transformer, tablePathPrefixTransformer) { - prefix, err := extractTablePathPrefixFromBinderName(transformer) - if err != nil { - return nil, nil, xerrors.WithStackTrace(err) - } - binders = append(binders, WithTablePathPrefix(prefix)) - } else { - return nil, nil, xerrors.WithStackTrace( - fmt.Errorf("unknown query rewriter: %s", transformer), - ) - } - } + binders, err := bindTablePathPrefixInConnectorOptions(queryTransformers) + if err != nil { + return nil, nil, err } connectorOpts = append(connectorOpts, binders...) } @@ -96,3 +77,30 @@ func extractTablePathPrefixFromBinderName(binderName string) (string, error) { return ss[0][1], nil } + +// bindTablePathPrefixInConnectorOptions binds table path prefix query transformers to a list of ConnectorOptions. +func bindTablePathPrefixInConnectorOptions(queryTransformers []string) ([]ConnectorOption, error) { + var binders []ConnectorOption + for _, transformer := range queryTransformers { + switch transformer { + case "declare": + binders = append(binders, WithQueryBind(bind.AutoDeclare{})) + case "positional": + binders = append(binders, WithQueryBind(bind.PositionalArgs{})) + case "numeric": + binders = append(binders, WithQueryBind(bind.NumericArgs{})) + default: + if strings.HasPrefix(transformer, tablePathPrefixTransformer) { + prefix, err := extractTablePathPrefixFromBinderName(transformer) + if err != nil { + return nil, xerrors.WithStackTrace(err) + } + binders = append(binders, WithTablePathPrefix(prefix)) + } else { + return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query rewriter: %s", transformer)) + } + } + } + + return binders, nil +} diff --git a/log/driver.go b/log/driver.go index a94ffb262..f40617c2c 100644 --- a/log/driver.go +++ b/log/driver.go @@ -13,8 +13,33 @@ func Driver(l Logger, d trace.Detailer, opts ...Option) (t trace.Driver) { return internalDriver(wrapLogger(l, opts...), d) } -func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocyclo - t.OnResolve = func( +func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { + t.OnResolve = onResolve(l, d) + t.OnInit = onInit(l, d) + t.OnClose = onClose(l, d) + t.OnConnDial = connDial(l, d) + t.OnConnStateChange = connStateChange(l, d) + t.OnConnPark = connPark(l, d) + t.OnConnClose = onConnClose(l, d) + t.OnConnInvoke = connInvoke(l, d) + t.OnConnNewStream = connNewStream(l, d) + t.OnConnBan = connBan(l, d) + t.OnConnAllow = connAllow(l, d) + t.OnRepeaterWakeUp = repeaterWakeUp(l, d) + t.OnBalancerInit = balancerInit(l, d) + t.OnBalancerClose = balancerClose(l, d) + t.OnBalancerChooseEndpoint = balancerChooseEndpoint(l, d) + t.OnBalancerUpdate = balancerUpdate(l, d) + t.OnGetCredentials = getCredentials(l, d) + + return t +} + +func onResolve( + l Logger, + d trace.Detailer, +) func(info trace.DriverResolveStartInfo) func(trace.DriverResolveDoneInfo) { + return func( info trace.DriverResolveStartInfo, ) func( trace.DriverResolveDoneInfo, @@ -46,7 +71,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy } } } - t.OnInit = func(info trace.DriverInitStartInfo) func(trace.DriverInitDoneInfo) { +} + +func onInit( + l Logger, + d trace.Detailer, +) func(info trace.DriverInitStartInfo) func(trace.DriverInitDoneInfo) { + return func(info trace.DriverInitStartInfo) func(trace.DriverInitDoneInfo) { if d.Details()&trace.DriverEvents == 0 { return nil } @@ -81,7 +112,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy } } } - t.OnClose = func(info trace.DriverCloseStartInfo) func(trace.DriverCloseDoneInfo) { +} + +func onClose( + l Logger, + d trace.Detailer, +) func(info trace.DriverCloseStartInfo) func(trace.DriverCloseDoneInfo) { + return func(info trace.DriverCloseStartInfo) func(trace.DriverCloseDoneInfo) { if d.Details()&trace.DriverEvents == 0 { return nil } @@ -103,7 +140,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy } } } - t.OnConnDial = func(info trace.DriverConnDialStartInfo) func(trace.DriverConnDialDoneInfo) { +} + +func connDial( + l Logger, + d trace.Detailer, +) func(info trace.DriverConnDialStartInfo) func(trace.DriverConnDialDoneInfo) { + return func(info trace.DriverConnDialStartInfo) func(trace.DriverConnDialDoneInfo) { if d.Details()&trace.DriverConnEvents == 0 { return nil } @@ -130,7 +173,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy } } } - t.OnConnStateChange = func(info trace.DriverConnStateChangeStartInfo) func(trace.DriverConnStateChangeDoneInfo) { +} + +func connStateChange( + l Logger, + d trace.Detailer, +) func(info trace.DriverConnStateChangeStartInfo) func(trace.DriverConnStateChangeDoneInfo) { + return func(info trace.DriverConnStateChangeStartInfo) func(trace.DriverConnStateChangeDoneInfo) { if d.Details()&trace.DriverConnEvents == 0 { return nil } @@ -150,7 +199,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy ) } } - t.OnConnPark = func(info trace.DriverConnParkStartInfo) func(trace.DriverConnParkDoneInfo) { +} + +func connPark( + l Logger, + d trace.Detailer, +) func(info trace.DriverConnParkStartInfo) func(trace.DriverConnParkDoneInfo) { + return func(info trace.DriverConnParkStartInfo) func(trace.DriverConnParkDoneInfo) { if d.Details()&trace.DriverConnEvents == 0 { return nil } @@ -177,7 +232,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy } } } - t.OnConnClose = func(info trace.DriverConnCloseStartInfo) func(trace.DriverConnCloseDoneInfo) { +} + +func onConnClose( + l Logger, + d trace.Detailer, +) func(info trace.DriverConnCloseStartInfo) func(trace.DriverConnCloseDoneInfo) { + return func(info trace.DriverConnCloseStartInfo) func(trace.DriverConnCloseDoneInfo) { if d.Details()&trace.DriverConnEvents == 0 { return nil } @@ -204,7 +265,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy } } } - t.OnConnInvoke = func(info trace.DriverConnInvokeStartInfo) func(trace.DriverConnInvokeDoneInfo) { +} + +func connInvoke( + l Logger, + d trace.Detailer, +) func(info trace.DriverConnInvokeStartInfo) func(trace.DriverConnInvokeDoneInfo) { + return func(info trace.DriverConnInvokeStartInfo) func(trace.DriverConnInvokeDoneInfo) { if d.Details()&trace.DriverConnEvents == 0 { return nil } @@ -237,7 +304,14 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy } } } - t.OnConnNewStream = func( +} + +func connNewStream( + l Logger, + d trace.Detailer, +) func(info trace.DriverConnNewStreamStartInfo) func(trace.DriverConnNewStreamRecvInfo) func( + trace.DriverConnNewStreamDoneInfo) { + return func( info trace.DriverConnNewStreamStartInfo, ) func( trace.DriverConnNewStreamRecvInfo, @@ -294,7 +368,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy } } } - t.OnConnBan = func(info trace.DriverConnBanStartInfo) func(trace.DriverConnBanDoneInfo) { +} + +func connBan( + l Logger, + d trace.Detailer, +) func(info trace.DriverConnBanStartInfo) func(trace.DriverConnBanDoneInfo) { + return func(info trace.DriverConnBanStartInfo) func(trace.DriverConnBanDoneInfo) { if d.Details()&trace.DriverConnEvents == 0 { return nil } @@ -316,7 +396,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy ) } } - t.OnConnAllow = func(info trace.DriverConnAllowStartInfo) func(trace.DriverConnAllowDoneInfo) { +} + +func connAllow( + l Logger, + d trace.Detailer, +) func(info trace.DriverConnAllowStartInfo) func(trace.DriverConnAllowDoneInfo) { + return func(info trace.DriverConnAllowStartInfo) func(trace.DriverConnAllowDoneInfo) { if d.Details()&trace.DriverConnEvents == 0 { return nil } @@ -335,7 +421,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy ) } } - t.OnRepeaterWakeUp = func(info trace.DriverRepeaterWakeUpStartInfo) func(trace.DriverRepeaterWakeUpDoneInfo) { +} + +func repeaterWakeUp( + l Logger, + d trace.Detailer, +) func(info trace.DriverRepeaterWakeUpStartInfo) func(trace.DriverRepeaterWakeUpDoneInfo) { + return func(info trace.DriverRepeaterWakeUpStartInfo) func(trace.DriverRepeaterWakeUpDoneInfo) { if d.Details()&trace.DriverRepeaterEvents == 0 { return nil } @@ -366,7 +458,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy } } } - t.OnBalancerInit = func(info trace.DriverBalancerInitStartInfo) func(trace.DriverBalancerInitDoneInfo) { +} + +func balancerInit( + l Logger, + d trace.Detailer, +) func(info trace.DriverBalancerInitStartInfo) func(trace.DriverBalancerInitDoneInfo) { + return func(info trace.DriverBalancerInitStartInfo) func(trace.DriverBalancerInitDoneInfo) { if d.Details()&trace.DriverBalancerEvents == 0 { return nil } @@ -380,7 +478,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy ) } } - t.OnBalancerClose = func(info trace.DriverBalancerCloseStartInfo) func(trace.DriverBalancerCloseDoneInfo) { +} + +func balancerClose( + l Logger, + d trace.Detailer, +) func(info trace.DriverBalancerCloseStartInfo) func(trace.DriverBalancerCloseDoneInfo) { + return func(info trace.DriverBalancerCloseStartInfo) func(trace.DriverBalancerCloseDoneInfo) { if d.Details()&trace.DriverBalancerEvents == 0 { return nil } @@ -402,7 +506,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy } } } - t.OnBalancerChooseEndpoint = func( +} + +func balancerChooseEndpoint( + l Logger, + d trace.Detailer, +) func(info trace.DriverBalancerChooseEndpointStartInfo) func(trace.DriverBalancerChooseEndpointDoneInfo) { + return func( info trace.DriverBalancerChooseEndpointStartInfo, ) func( trace.DriverBalancerChooseEndpointDoneInfo, @@ -429,7 +539,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy } } } - t.OnBalancerUpdate = func( +} + +func balancerUpdate( + l Logger, + d trace.Detailer, +) func(info trace.DriverBalancerUpdateStartInfo) func(trace.DriverBalancerUpdateDoneInfo) { + return func( info trace.DriverBalancerUpdateStartInfo, ) func( trace.DriverBalancerUpdateDoneInfo, @@ -453,7 +569,13 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy ) } } - t.OnGetCredentials = func(info trace.DriverGetCredentialsStartInfo) func(trace.DriverGetCredentialsDoneInfo) { +} + +func getCredentials( + l Logger, + d trace.Detailer, +) func(info trace.DriverGetCredentialsStartInfo) func(trace.DriverGetCredentialsDoneInfo) { + return func(info trace.DriverGetCredentialsStartInfo) func(trace.DriverGetCredentialsDoneInfo) { if d.Details()&trace.DriverCredentialsEvents == 0 { return nil } @@ -477,6 +599,4 @@ func internalDriver(l Logger, d trace.Detailer) (t trace.Driver) { //nolint:gocy } } } - - return t } diff --git a/log/sql.go b/log/sql.go index 3a71f3cb7..8adf81ff6 100644 --- a/log/sql.go +++ b/log/sql.go @@ -14,7 +14,30 @@ func DatabaseSQL(l Logger, d trace.Detailer, opts ...Option) (t trace.DatabaseSQ } func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { - t.OnConnectorConnect = func( + logger := l.logger + loggerQuery := l.logQuery + + t.OnConnectorConnect = connectorConnect(logger, d) + t.OnConnPing = connPing(logger, d) + t.OnConnClose = connClose(logger, d) + t.OnConnBegin = connBegin(logger, d) + t.OnConnPrepare = connPrepare(logger, loggerQuery, d) + t.OnConnExec = connExec(logger, loggerQuery, d) + t.OnConnQuery = connQuery(logger, loggerQuery, d) + t.OnTxCommit = txCommit(logger, d) + t.OnTxRollback = txRollback(logger, d) + t.OnStmtClose = stmtClose(logger, d) + t.OnStmtExec = stmtExec(logger, loggerQuery, d) + t.OnStmtQuery = stmtQuery(logger, loggerQuery, d) + + return t +} + +func connectorConnect( + l Logger, + d trace.Detailer, +) func(info trace.DatabaseSQLConnectorConnectStartInfo) func(trace.DatabaseSQLConnectorConnectDoneInfo) { + return func( info trace.DatabaseSQLConnectorConnectStartInfo, ) func( trace.DatabaseSQLConnectorConnectDoneInfo, @@ -42,8 +65,13 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { } } } +} - t.OnConnPing = func(info trace.DatabaseSQLConnPingStartInfo) func(trace.DatabaseSQLConnPingDoneInfo) { +func connPing( + l Logger, + d trace.Detailer, +) func(info trace.DatabaseSQLConnPingStartInfo) func(trace.DatabaseSQLConnPingDoneInfo) { + return func(info trace.DatabaseSQLConnPingStartInfo) func(trace.DatabaseSQLConnPingDoneInfo) { if d.Details()&trace.DatabaseSQLConnEvents == 0 { return nil } @@ -65,7 +93,13 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { } } } - t.OnConnClose = func(info trace.DatabaseSQLConnCloseStartInfo) func(trace.DatabaseSQLConnCloseDoneInfo) { +} + +func connClose( + l Logger, + d trace.Detailer, +) func(info trace.DatabaseSQLConnCloseStartInfo) func(trace.DatabaseSQLConnCloseDoneInfo) { + return func(info trace.DatabaseSQLConnCloseStartInfo) func(trace.DatabaseSQLConnCloseDoneInfo) { if d.Details()&trace.DatabaseSQLConnEvents == 0 { return nil } @@ -87,7 +121,13 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { } } } - t.OnConnBegin = func(info trace.DatabaseSQLConnBeginStartInfo) func(trace.DatabaseSQLConnBeginDoneInfo) { +} + +func connBegin( + l Logger, + d trace.Detailer, +) func(info trace.DatabaseSQLConnBeginStartInfo) func(trace.DatabaseSQLConnBeginDoneInfo) { + return func(info trace.DatabaseSQLConnBeginStartInfo) func(trace.DatabaseSQLConnBeginDoneInfo) { if d.Details()&trace.DatabaseSQLConnEvents == 0 { return nil } @@ -109,13 +149,20 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { } } } - t.OnConnPrepare = func(info trace.DatabaseSQLConnPrepareStartInfo) func(trace.DatabaseSQLConnPrepareDoneInfo) { +} + +func connPrepare( + l Logger, + loggerQuery bool, + d trace.Detailer, +) func(info trace.DatabaseSQLConnPrepareStartInfo) func(trace.DatabaseSQLConnPrepareDoneInfo) { + return func(info trace.DatabaseSQLConnPrepareStartInfo) func(trace.DatabaseSQLConnPrepareDoneInfo) { if d.Details()&trace.DatabaseSQLConnEvents == 0 { return nil } ctx := with(*info.Context, TRACE, "ydb", "database", "sql", "conn", "prepare", "stmt") l.Log(ctx, "start", - appendFieldByCondition(l.logQuery, + appendFieldByCondition(loggerQuery, String("query", info.Query), )..., ) @@ -129,7 +176,7 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { ) } else { l.Log(WithLevel(ctx, ERROR), "failed", - appendFieldByCondition(l.logQuery, + appendFieldByCondition(loggerQuery, String("query", query), Error(info.Error), latencyField(start), @@ -139,13 +186,20 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { } } } - t.OnConnExec = func(info trace.DatabaseSQLConnExecStartInfo) func(trace.DatabaseSQLConnExecDoneInfo) { +} + +func connExec( + l Logger, + loggerQuery bool, + d trace.Detailer, +) func(info trace.DatabaseSQLConnExecStartInfo) func(trace.DatabaseSQLConnExecDoneInfo) { + return func(info trace.DatabaseSQLConnExecStartInfo) func(trace.DatabaseSQLConnExecDoneInfo) { if d.Details()&trace.DatabaseSQLConnEvents == 0 { return nil } ctx := with(*info.Context, TRACE, "ydb", "database", "sql", "conn", "exec") l.Log(ctx, "start", - appendFieldByCondition(l.logQuery, + appendFieldByCondition(loggerQuery, String("query", info.Query), )..., ) @@ -161,7 +215,7 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { } else { m := retry.Check(info.Error) l.Log(WithLevel(ctx, ERROR), "failed", - appendFieldByCondition(l.logQuery, + appendFieldByCondition(loggerQuery, String("query", query), Bool("retryable", m.MustRetry(idempotent)), Int64("code", m.StatusCode()), @@ -174,13 +228,20 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { } } } - t.OnConnQuery = func(info trace.DatabaseSQLConnQueryStartInfo) func(trace.DatabaseSQLConnQueryDoneInfo) { +} + +func connQuery( + l Logger, + loggerQuery bool, + d trace.Detailer, +) func(info trace.DatabaseSQLConnQueryStartInfo) func(trace.DatabaseSQLConnQueryDoneInfo) { + return func(info trace.DatabaseSQLConnQueryStartInfo) func(trace.DatabaseSQLConnQueryDoneInfo) { if d.Details()&trace.DatabaseSQLConnEvents == 0 { return nil } ctx := with(*info.Context, TRACE, "ydb", "database", "sql", "conn", "query") l.Log(ctx, "start", - appendFieldByCondition(l.logQuery, + appendFieldByCondition(loggerQuery, String("query", info.Query), )..., ) @@ -196,7 +257,7 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { } else { m := retry.Check(info.Error) l.Log(WithLevel(ctx, ERROR), "failed", - appendFieldByCondition(l.logQuery, + appendFieldByCondition(loggerQuery, String("query", query), Bool("retryable", m.MustRetry(idempotent)), Int64("code", m.StatusCode()), @@ -209,7 +270,13 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { } } } - t.OnTxCommit = func(info trace.DatabaseSQLTxCommitStartInfo) func(trace.DatabaseSQLTxCommitDoneInfo) { +} + +func txCommit( + l Logger, + d trace.Detailer, +) func(info trace.DatabaseSQLTxCommitStartInfo) func(trace.DatabaseSQLTxCommitDoneInfo) { + return func(info trace.DatabaseSQLTxCommitStartInfo) func(trace.DatabaseSQLTxCommitDoneInfo) { if d.Details()&trace.DatabaseSQLTxEvents == 0 { return nil } @@ -231,7 +298,13 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { } } } - t.OnTxRollback = func(info trace.DatabaseSQLTxRollbackStartInfo) func(trace.DatabaseSQLTxRollbackDoneInfo) { +} + +func txRollback( + l Logger, + d trace.Detailer, +) func(info trace.DatabaseSQLTxRollbackStartInfo) func(trace.DatabaseSQLTxRollbackDoneInfo) { + return func(info trace.DatabaseSQLTxRollbackStartInfo) func(trace.DatabaseSQLTxRollbackDoneInfo) { if d.Details()&trace.DatabaseSQLTxEvents == 0 { return nil } @@ -253,7 +326,13 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { } } } - t.OnStmtClose = func(info trace.DatabaseSQLStmtCloseStartInfo) func(trace.DatabaseSQLStmtCloseDoneInfo) { +} + +func stmtClose( + l Logger, + d trace.Detailer, +) func(info trace.DatabaseSQLStmtCloseStartInfo) func(trace.DatabaseSQLStmtCloseDoneInfo) { + return func(info trace.DatabaseSQLStmtCloseStartInfo) func(trace.DatabaseSQLStmtCloseDoneInfo) { if d.Details()&trace.DatabaseSQLStmtEvents == 0 { return nil } @@ -275,13 +354,20 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { } } } - t.OnStmtExec = func(info trace.DatabaseSQLStmtExecStartInfo) func(trace.DatabaseSQLStmtExecDoneInfo) { +} + +func stmtExec( + l Logger, + loggerQuery bool, + d trace.Detailer, +) func(info trace.DatabaseSQLStmtExecStartInfo) func(trace.DatabaseSQLStmtExecDoneInfo) { + return func(info trace.DatabaseSQLStmtExecStartInfo) func(trace.DatabaseSQLStmtExecDoneInfo) { if d.Details()&trace.DatabaseSQLStmtEvents == 0 { return nil } ctx := with(*info.Context, TRACE, "ydb", "database", "sql", "stmt", "exec") l.Log(ctx, "start", - appendFieldByCondition(l.logQuery, + appendFieldByCondition(loggerQuery, String("query", info.Query), )..., ) @@ -296,7 +382,7 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { ) } else { l.Log(WithLevel(ctx, ERROR), "failed", - appendFieldByCondition(l.logQuery, + appendFieldByCondition(loggerQuery, String("query", query), Error(info.Error), latencyField(start), @@ -306,13 +392,20 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { } } } - t.OnStmtQuery = func(info trace.DatabaseSQLStmtQueryStartInfo) func(trace.DatabaseSQLStmtQueryDoneInfo) { +} + +func stmtQuery( + l Logger, + loggerQuery bool, + d trace.Detailer, +) func(info trace.DatabaseSQLStmtQueryStartInfo) func(trace.DatabaseSQLStmtQueryDoneInfo) { + return func(info trace.DatabaseSQLStmtQueryStartInfo) func(trace.DatabaseSQLStmtQueryDoneInfo) { if d.Details()&trace.DatabaseSQLStmtEvents == 0 { return nil } ctx := with(*info.Context, TRACE, "ydb", "database", "sql", "stmt", "query") l.Log(ctx, "start", - appendFieldByCondition(l.logQuery, + appendFieldByCondition(loggerQuery, String("query", info.Query), )..., ) @@ -326,7 +419,7 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { ) } else { l.Log(WithLevel(ctx, ERROR), "failed", - appendFieldByCondition(l.logQuery, + appendFieldByCondition(loggerQuery, String("query", query), Error(info.Error), latencyField(start), @@ -336,6 +429,4 @@ func internalDatabaseSQL(l *wrapper, d trace.Detailer) (t trace.DatabaseSQL) { } } } - - return t } diff --git a/log/table.go b/log/table.go index 5afa72bee..f7df75ae1 100644 --- a/log/table.go +++ b/log/table.go @@ -14,9 +14,195 @@ func Table(l Logger, d trace.Detailer, opts ...Option) (t trace.Table) { return internalTable(wrapLogger(l, opts...), d) } -//nolint:gocyclo func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { - t.OnDo = func( + logger := l.logger + + t.OnDo = onDo(logger, d) + t.OnDoTx = onDoTx(logger, d) + t.OnCreateSession = onCreateSession(logger, d) + t.OnSessionNew = onSessionNew(logger, d) + t.OnSessionDelete = onSessionDelete(logger, d) + t.OnSessionKeepAlive = onSessionKeepAlive(logger, d) + t.OnSessionQueryPrepare = func( + info trace.TablePrepareDataQueryStartInfo, + ) func( + trace.TablePrepareDataQueryDoneInfo, + ) { + if d.Details()&trace.TableSessionQueryInvokeEvents == 0 { + return nil + } + ctx := with(*info.Context, TRACE, "ydb", "table", "session", "query", "prepare") + session := info.Session + query := info.Query + l.Log(ctx, "start", + appendFieldByCondition(l.logQuery, + String("query", info.Query), + String("id", session.ID()), + String("status", session.Status()), + )..., + ) + start := time.Now() + + return func(info trace.TablePrepareDataQueryDoneInfo) { + if info.Error == nil { + l.Log(ctx, "done", + appendFieldByCondition(l.logQuery, + Stringer("result", info.Result), + appendFieldByCondition(l.logQuery, + String("query", query), + String("id", session.ID()), + String("status", session.Status()), + latencyField(start), + )..., + )..., + ) + } else { + l.Log(WithLevel(ctx, ERROR), "failed", + appendFieldByCondition(l.logQuery, + String("query", query), + Error(info.Error), + String("id", session.ID()), + String("status", session.Status()), + latencyField(start), + versionField(), + )..., + ) + } + } + } + t.OnSessionQueryExecute = func( + info trace.TableExecuteDataQueryStartInfo, + ) func( + trace.TableExecuteDataQueryDoneInfo, + ) { + if d.Details()&trace.TableSessionQueryInvokeEvents == 0 { + return nil + } + ctx := with(*info.Context, TRACE, "ydb", "table", "session", "query", "execute") + session := info.Session + query := info.Query + l.Log(ctx, "start", + appendFieldByCondition(l.logQuery, + Stringer("query", info.Query), + String("id", session.ID()), + String("status", session.Status()), + )..., + ) + start := time.Now() + + return func(info trace.TableExecuteDataQueryDoneInfo) { + if info.Error == nil { + tx := info.Tx + l.Log(ctx, "done", + appendFieldByCondition(l.logQuery, + Stringer("query", query), + String("id", session.ID()), + String("tx", tx.ID()), + String("status", session.Status()), + Bool("prepared", info.Prepared), + NamedError("result_err", info.Result.Err()), + latencyField(start), + )..., + ) + } else { + l.Log(WithLevel(ctx, ERROR), "failed", + appendFieldByCondition(l.logQuery, + Stringer("query", query), + Error(info.Error), + String("id", session.ID()), + String("status", session.Status()), + Bool("prepared", info.Prepared), + latencyField(start), + versionField(), + )..., + ) + } + } + } + t.OnSessionQueryStreamExecute = func( + info trace.TableSessionQueryStreamExecuteStartInfo, + ) func( + trace.TableSessionQueryStreamExecuteIntermediateInfo, + ) func( + trace.TableSessionQueryStreamExecuteDoneInfo, + ) { + if d.Details()&trace.TableSessionQueryStreamEvents == 0 { + return nil + } + ctx := with(*info.Context, TRACE, "ydb", "table", "session", "query", "stream", "execute") + session := info.Session + query := info.Query + l.Log(ctx, "start", + appendFieldByCondition(l.logQuery, + Stringer("query", info.Query), + String("id", session.ID()), + String("status", session.Status()), + )..., + ) + start := time.Now() + + return func( + info trace.TableSessionQueryStreamExecuteIntermediateInfo, + ) func( + trace.TableSessionQueryStreamExecuteDoneInfo, + ) { + if info.Error == nil { + l.Log(ctx, "intermediate") + } else { + l.Log(WithLevel(ctx, WARN), "failed", + Error(info.Error), + versionField(), + ) + } + + return func(info trace.TableSessionQueryStreamExecuteDoneInfo) { + if info.Error == nil { + l.Log(ctx, "done", + appendFieldByCondition(l.logQuery, + Stringer("query", query), + Error(info.Error), + String("id", session.ID()), + String("status", session.Status()), + latencyField(start), + )..., + ) + } else { + l.Log(WithLevel(ctx, ERROR), "failed", + appendFieldByCondition(l.logQuery, + Stringer("query", query), + Error(info.Error), + String("id", session.ID()), + String("status", session.Status()), + latencyField(start), + versionField(), + )..., + ) + } + } + } + } + + t.OnSessionQueryStreamRead = onSessionQueryStreamRead(logger, d) + t.OnSessionTransactionBegin = onSessionTransactionBegin(logger, d) + t.OnSessionTransactionCommit = onSessionTransactionCommit(logger, d) + t.OnSessionTransactionRollback = onSessionTransactionRollback(logger, d) + t.OnInit = OnInitTable(logger, d) + t.OnClose = onCloseTable(logger, d) + t.OnPoolStateChange = onPoolStateChange(logger, d) + t.OnPoolSessionAdd = onPoolSessionAdd(logger, d) + t.OnPoolSessionRemove = onPoolSessionRemove(logger, d) + t.OnPoolPut = onPoolPut(logger, d) + t.OnPoolGet = onPoolGet(logger, d) + t.OnPoolWait = onPoolWait(logger, d) + + return t +} + +func onDo(l Logger, + d trace.Detailer, +) func(info trace.TableDoStartInfo) func( + info trace.TableDoIntermediateInfo) func(trace.TableDoDoneInfo) { + return func( info trace.TableDoStartInfo, ) func( info trace.TableDoIntermediateInfo, @@ -89,7 +275,14 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { } } } - t.OnDoTx = func( +} + +func onDoTx( + l Logger, + d trace.Detailer, +) func(info trace.TableDoTxStartInfo) func( + info trace.TableDoTxIntermediateInfo) func(trace.TableDoTxDoneInfo) { + return func( info trace.TableDoTxStartInfo, ) func( info trace.TableDoTxIntermediateInfo, @@ -162,7 +355,14 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { } } } - t.OnCreateSession = func( +} + +func onCreateSession( + l Logger, + d trace.Detailer, +) func(info trace.TableCreateSessionStartInfo) func( + info trace.TableCreateSessionIntermediateInfo) func(trace.TableCreateSessionDoneInfo) { + return func( info trace.TableCreateSessionStartInfo, ) func( info trace.TableCreateSessionIntermediateInfo, @@ -208,7 +408,13 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { } } } - t.OnSessionNew = func(info trace.TableSessionNewStartInfo) func(trace.TableSessionNewDoneInfo) { +} + +func onSessionNew( + l Logger, + d trace.Detailer, +) func(info trace.TableSessionNewStartInfo) func(trace.TableSessionNewDoneInfo) { + return func(info trace.TableSessionNewStartInfo) func(trace.TableSessionNewDoneInfo) { if d.Details()&trace.TableSessionEvents == 0 { return nil } @@ -238,7 +444,13 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { } } } - t.OnSessionDelete = func(info trace.TableSessionDeleteStartInfo) func(trace.TableSessionDeleteDoneInfo) { +} + +func onSessionDelete( + l Logger, + d trace.Detailer, +) func(info trace.TableSessionDeleteStartInfo) func(trace.TableSessionDeleteDoneInfo) { + return func(info trace.TableSessionDeleteStartInfo) func(trace.TableSessionDeleteDoneInfo) { if d.Details()&trace.TableSessionEvents == 0 { return nil } @@ -268,7 +480,13 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { } } } - t.OnSessionKeepAlive = func(info trace.TableKeepAliveStartInfo) func(trace.TableKeepAliveDoneInfo) { +} + +func onSessionKeepAlive( + l Logger, + d trace.Detailer, +) func(info trace.TableKeepAliveStartInfo) func(trace.TableKeepAliveDoneInfo) { + return func(info trace.TableKeepAliveStartInfo) func(trace.TableKeepAliveDoneInfo) { if d.Details()&trace.TableSessionEvents == 0 { return nil } @@ -298,165 +516,15 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { } } } - t.OnSessionQueryPrepare = func( - info trace.TablePrepareDataQueryStartInfo, - ) func( - trace.TablePrepareDataQueryDoneInfo, - ) { - if d.Details()&trace.TableSessionQueryInvokeEvents == 0 { - return nil - } - ctx := with(*info.Context, TRACE, "ydb", "table", "session", "query", "prepare") - session := info.Session - query := info.Query - l.Log(ctx, "start", - appendFieldByCondition(l.logQuery, - String("query", info.Query), - String("id", session.ID()), - String("status", session.Status()), - )..., - ) - start := time.Now() - - return func(info trace.TablePrepareDataQueryDoneInfo) { - if info.Error == nil { - l.Log(ctx, "done", - appendFieldByCondition(l.logQuery, - Stringer("result", info.Result), - appendFieldByCondition(l.logQuery, - String("query", query), - String("id", session.ID()), - String("status", session.Status()), - latencyField(start), - )..., - )..., - ) - } else { - l.Log(WithLevel(ctx, ERROR), "failed", - appendFieldByCondition(l.logQuery, - String("query", query), - Error(info.Error), - String("id", session.ID()), - String("status", session.Status()), - latencyField(start), - versionField(), - )..., - ) - } - } - } - t.OnSessionQueryExecute = func( - info trace.TableExecuteDataQueryStartInfo, - ) func( - trace.TableExecuteDataQueryDoneInfo, - ) { - if d.Details()&trace.TableSessionQueryInvokeEvents == 0 { - return nil - } - ctx := with(*info.Context, TRACE, "ydb", "table", "session", "query", "execute") - session := info.Session - query := info.Query - l.Log(ctx, "start", - appendFieldByCondition(l.logQuery, - Stringer("query", info.Query), - String("id", session.ID()), - String("status", session.Status()), - )..., - ) - start := time.Now() - - return func(info trace.TableExecuteDataQueryDoneInfo) { - if info.Error == nil { - tx := info.Tx - l.Log(ctx, "done", - appendFieldByCondition(l.logQuery, - Stringer("query", query), - String("id", session.ID()), - String("tx", tx.ID()), - String("status", session.Status()), - Bool("prepared", info.Prepared), - NamedError("result_err", info.Result.Err()), - latencyField(start), - )..., - ) - } else { - l.Log(WithLevel(ctx, ERROR), "failed", - appendFieldByCondition(l.logQuery, - Stringer("query", query), - Error(info.Error), - String("id", session.ID()), - String("status", session.Status()), - Bool("prepared", info.Prepared), - latencyField(start), - versionField(), - )..., - ) - } - } - } - t.OnSessionQueryStreamExecute = func( - info trace.TableSessionQueryStreamExecuteStartInfo, - ) func( - trace.TableSessionQueryStreamExecuteIntermediateInfo, - ) func( - trace.TableSessionQueryStreamExecuteDoneInfo, - ) { - if d.Details()&trace.TableSessionQueryStreamEvents == 0 { - return nil - } - ctx := with(*info.Context, TRACE, "ydb", "table", "session", "query", "stream", "execute") - session := info.Session - query := info.Query - l.Log(ctx, "start", - appendFieldByCondition(l.logQuery, - Stringer("query", info.Query), - String("id", session.ID()), - String("status", session.Status()), - )..., - ) - start := time.Now() - - return func( - info trace.TableSessionQueryStreamExecuteIntermediateInfo, - ) func( - trace.TableSessionQueryStreamExecuteDoneInfo, - ) { - if info.Error == nil { - l.Log(ctx, "intermediate") - } else { - l.Log(WithLevel(ctx, WARN), "failed", - Error(info.Error), - versionField(), - ) - } +} - return func(info trace.TableSessionQueryStreamExecuteDoneInfo) { - if info.Error == nil { - l.Log(ctx, "done", - appendFieldByCondition(l.logQuery, - Stringer("query", query), - Error(info.Error), - String("id", session.ID()), - String("status", session.Status()), - latencyField(start), - )..., - ) - } else { - l.Log(WithLevel(ctx, ERROR), "failed", - appendFieldByCondition(l.logQuery, - Stringer("query", query), - Error(info.Error), - String("id", session.ID()), - String("status", session.Status()), - latencyField(start), - versionField(), - )..., - ) - } - } - } - } - t.OnSessionQueryStreamRead = func( +func onSessionQueryStreamRead( + l Logger, + d trace.Detailer, +) func(info trace.TableSessionQueryStreamReadStartInfo) func( + intermediateInfo trace.TableSessionQueryStreamReadIntermediateInfo, +) func(trace.TableSessionQueryStreamReadDoneInfo) { + return func( info trace.TableSessionQueryStreamReadStartInfo, ) func( intermediateInfo trace.TableSessionQueryStreamReadIntermediateInfo, @@ -507,7 +575,13 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { } } } - t.OnSessionTransactionBegin = func( +} + +func onSessionTransactionBegin( + l Logger, + d trace.Detailer, +) func(info trace.TableSessionTransactionBeginStartInfo) func(trace.TableSessionTransactionBeginDoneInfo) { + return func( info trace.TableSessionTransactionBeginStartInfo, ) func( trace.TableSessionTransactionBeginDoneInfo, @@ -542,7 +616,13 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { } } } - t.OnSessionTransactionCommit = func( +} + +func onSessionTransactionCommit( + l Logger, + d trace.Detailer, +) func(info trace.TableSessionTransactionCommitStartInfo) func(trace.TableSessionTransactionCommitDoneInfo) { + return func( info trace.TableSessionTransactionCommitStartInfo, ) func( trace.TableSessionTransactionCommitDoneInfo, @@ -580,7 +660,13 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { } } } - t.OnSessionTransactionRollback = func( +} + +func onSessionTransactionRollback( + l Logger, + d trace.Detailer, +) func(info trace.TableSessionTransactionRollbackStartInfo) func(trace.TableSessionTransactionRollbackDoneInfo) { + return func( info trace.TableSessionTransactionRollbackStartInfo, ) func( trace.TableSessionTransactionRollbackDoneInfo, @@ -618,7 +704,10 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { } } } - t.OnInit = func(info trace.TableInitStartInfo) func(trace.TableInitDoneInfo) { +} + +func OnInitTable(l Logger, d trace.Detailer) func(info trace.TableInitStartInfo) func(trace.TableInitDoneInfo) { + return func(info trace.TableInitStartInfo) func(trace.TableInitDoneInfo) { if d.Details()&trace.TableEvents == 0 { return nil } @@ -633,7 +722,10 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { ) } } - t.OnClose = func(info trace.TableCloseStartInfo) func(trace.TableCloseDoneInfo) { +} + +func onCloseTable(l Logger, d trace.Detailer) func(info trace.TableCloseStartInfo) func(trace.TableCloseDoneInfo) { + return func(info trace.TableCloseStartInfo) func(trace.TableCloseDoneInfo) { if d.Details()&trace.TableEvents == 0 { return nil } @@ -655,7 +747,10 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { } } } - t.OnPoolStateChange = func(info trace.TablePoolStateChangeInfo) { +} + +func onPoolStateChange(l Logger, d trace.Detailer) func(info trace.TablePoolStateChangeInfo) { + return func(info trace.TablePoolStateChangeInfo) { if d.Details()&trace.TablePoolLifeCycleEvents == 0 { return } @@ -665,7 +760,10 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { String("event", info.Event), ) } - t.OnPoolSessionAdd = func(info trace.TablePoolSessionAddInfo) { +} + +func onPoolSessionAdd(l Logger, d trace.Detailer) func(info trace.TablePoolSessionAddInfo) { + return func(info trace.TablePoolSessionAddInfo) { if d.Details()&trace.TablePoolLifeCycleEvents == 0 { return } @@ -675,7 +773,10 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { String("status", info.Session.Status()), ) } - t.OnPoolSessionRemove = func(info trace.TablePoolSessionRemoveInfo) { +} + +func onPoolSessionRemove(l Logger, d trace.Detailer) func(info trace.TablePoolSessionRemoveInfo) { + return func(info trace.TablePoolSessionRemoveInfo) { if d.Details()&trace.TablePoolLifeCycleEvents == 0 { return } @@ -685,7 +786,10 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { String("status", info.Session.Status()), ) } - t.OnPoolPut = func(info trace.TablePoolPutStartInfo) func(trace.TablePoolPutDoneInfo) { +} + +func onPoolPut(l Logger, d trace.Detailer) func(info trace.TablePoolPutStartInfo) func(trace.TablePoolPutDoneInfo) { + return func(info trace.TablePoolPutStartInfo) func(trace.TablePoolPutDoneInfo) { if d.Details()&trace.TablePoolAPIEvents == 0 { return nil } @@ -715,7 +819,10 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { } } } - t.OnPoolGet = func(info trace.TablePoolGetStartInfo) func(trace.TablePoolGetDoneInfo) { +} + +func onPoolGet(l Logger, d trace.Detailer) func(info trace.TablePoolGetStartInfo) func(trace.TablePoolGetDoneInfo) { + return func(info trace.TablePoolGetStartInfo) func(trace.TablePoolGetDoneInfo) { if d.Details()&trace.TablePoolAPIEvents == 0 { return nil } @@ -742,7 +849,13 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { } } } - t.OnPoolWait = func(info trace.TablePoolWaitStartInfo) func(trace.TablePoolWaitDoneInfo) { +} + +func onPoolWait( + l Logger, + d trace.Detailer, +) func(info trace.TablePoolWaitStartInfo) func(trace.TablePoolWaitDoneInfo) { + return func(info trace.TablePoolWaitStartInfo) func(trace.TablePoolWaitDoneInfo) { if d.Details()&trace.TablePoolAPIEvents == 0 { return nil } @@ -768,6 +881,4 @@ func internalTable(l *wrapper, d trace.Detailer) (t trace.Table) { } } } - - return t } diff --git a/log/topic.go b/log/topic.go index 87d645d9b..dfe85fea7 100644 --- a/log/topic.go +++ b/log/topic.go @@ -12,8 +12,42 @@ func Topic(l Logger, d trace.Detailer, opts ...Option) (t trace.Topic) { return internalTopic(wrapLogger(l, opts...), d) } -func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocyclo - t.OnReaderReconnect = func( +func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { + t.OnReaderReconnect = onReaderReconnect(l, d) + t.OnReaderReconnectRequest = onReaderReconnectRequest(l, d) + t.OnReaderPartitionReadStartResponse = onReaderPartitionReadStartResponse(l, d) + t.OnReaderPartitionReadStopResponse = onReaderPartitionReadStopResponse(l, d) + t.OnReaderCommit = onReaderCommit(l, d) + t.OnReaderSendCommitMessage = onReaderSendCommitMessage(l, d) + t.OnReaderCommittedNotify = onReaderCommittedNotify(l, d) + t.OnReaderClose = onReaderClose(l, d) + + t.OnReaderInit = onReaderInit(l, d) + t.OnReaderError = onReaderError(l, d) + t.OnReaderUpdateToken = onReaderUpdateToken(l, d) + t.OnReaderSentDataRequest = onReaderSentDataRequest(l, d) + t.OnReaderReceiveDataResponse = onReaderReceiveDataResponse(l, d) + t.OnReaderReadMessages = onReaderReadMessages(l, d) + t.OnReaderUnknownGrpcMessage = onReaderUnknownGrpcMessage(l, d) + + /// + /// Topic writer + /// + t.OnWriterReconnect = onWriterReconnect(l, d) + t.OnWriterInitStream = onWriterInitStream(l, d) + t.OnWriterClose = onWriterClose(l, d) + t.OnWriterCompressMessages = onWriterCompressMessages(l, d) + t.OnWriterSendMessages = onWriterSendMessages(l, d) + t.OnWriterReadUnknownGrpcMessage = onWriterReadUnknownGrpcMessage(l, d) + + return t +} + +func onReaderReconnect( + l Logger, + d trace.Detailer, +) func(info trace.TopicReaderReconnectStartInfo) func(doneInfo trace.TopicReaderReconnectDoneInfo) { + return func( info trace.TopicReaderReconnectStartInfo, ) func(doneInfo trace.TopicReaderReconnectDoneInfo) { if d.Details()&trace.TopicReaderStreamLifeCycleEvents == 0 { @@ -30,7 +64,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl ) } } - t.OnReaderReconnectRequest = func(info trace.TopicReaderReconnectRequestInfo) { +} + +func onReaderReconnectRequest( + l Logger, + d trace.Detailer, +) func(info trace.TopicReaderReconnectRequestInfo) { + return func(info trace.TopicReaderReconnectRequestInfo) { if d.Details()&trace.TopicReaderStreamLifeCycleEvents == 0 { return } @@ -40,7 +80,15 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl Bool("was_sent", info.WasSent), ) } - t.OnReaderPartitionReadStartResponse = func( +} + +func onReaderPartitionReadStartResponse( + l Logger, + d trace.Detailer, +) func( + info trace.TopicReaderPartitionReadStartResponseStartInfo) func( + stopInfo trace.TopicReaderPartitionReadStartResponseDoneInfo) { + return func( info trace.TopicReaderPartitionReadStartResponseStartInfo, ) func(stopInfo trace.TopicReaderPartitionReadStartResponseDoneInfo) { if d.Details()&trace.TopicReaderPartitionEvents == 0 { @@ -85,7 +133,15 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl } } } - t.OnReaderPartitionReadStopResponse = func( +} + +func onReaderPartitionReadStopResponse( + l Logger, + d trace.Detailer, +) func( + info trace.TopicReaderPartitionReadStopResponseStartInfo) func( + trace.TopicReaderPartitionReadStopResponseDoneInfo) { + return func( info trace.TopicReaderPartitionReadStopResponseStartInfo, ) func(trace.TopicReaderPartitionReadStopResponseDoneInfo) { if d.Details()&trace.TopicReaderPartitionEvents == 0 { @@ -123,7 +179,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl } } } - t.OnReaderCommit = func(info trace.TopicReaderCommitStartInfo) func(doneInfo trace.TopicReaderCommitDoneInfo) { +} + +func onReaderCommit( + l Logger, + d trace.Detailer, +) func(info trace.TopicReaderCommitStartInfo) func(doneInfo trace.TopicReaderCommitDoneInfo) { + return func(info trace.TopicReaderCommitStartInfo) func(doneInfo trace.TopicReaderCommitDoneInfo) { if d.Details()&trace.TopicReaderStreamEvents == 0 { return nil } @@ -158,7 +220,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl } } } - t.OnReaderSendCommitMessage = func( +} + +func onReaderSendCommitMessage( + l Logger, + d trace.Detailer, +) func(info trace.TopicReaderSendCommitMessageStartInfo) func(trace.TopicReaderSendCommitMessageDoneInfo) { + return func( info trace.TopicReaderSendCommitMessageStartInfo, ) func(trace.TopicReaderSendCommitMessageDoneInfo) { if d.Details()&trace.TopicReaderStreamEvents == 0 { @@ -201,7 +269,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl } } } - t.OnReaderCommittedNotify = func(info trace.TopicReaderCommittedNotifyInfo) { +} + +func onReaderCommittedNotify( + l Logger, + d trace.Detailer, +) func(info trace.TopicReaderCommittedNotifyInfo) { + return func(info trace.TopicReaderCommittedNotifyInfo) { if d.Details()&trace.TopicReaderStreamEvents == 0 { return } @@ -214,7 +288,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl Int64("committed_offset", info.CommittedOffset), ) } - t.OnReaderClose = func(info trace.TopicReaderCloseStartInfo) func(doneInfo trace.TopicReaderCloseDoneInfo) { +} + +func onReaderClose( + l Logger, + d trace.Detailer, +) func(info trace.TopicReaderCloseStartInfo) func(doneInfo trace.TopicReaderCloseDoneInfo) { + return func(info trace.TopicReaderCloseStartInfo) func(doneInfo trace.TopicReaderCloseDoneInfo) { if d.Details()&trace.TopicReaderStreamEvents == 0 { return nil } @@ -242,8 +322,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl } } } +} - t.OnReaderInit = func(info trace.TopicReaderInitStartInfo) func(doneInfo trace.TopicReaderInitDoneInfo) { +func onReaderInit( + l Logger, + d trace.Detailer, +) func(info trace.TopicReaderInitStartInfo) func(doneInfo trace.TopicReaderInitDoneInfo) { + return func(info trace.TopicReaderInitStartInfo) func(doneInfo trace.TopicReaderInitDoneInfo) { if d.Details()&trace.TopicReaderStreamEvents == 0 { return nil } @@ -274,7 +359,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl } } } - t.OnReaderError = func(info trace.TopicReaderErrorInfo) { +} + +func onReaderError( + l Logger, + d trace.Detailer, +) func(info trace.TopicReaderErrorInfo) { + return func(info trace.TopicReaderErrorInfo) { if d.Details()&trace.TopicReaderStreamEvents == 0 { return } @@ -285,7 +376,15 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl versionField(), ) } - t.OnReaderUpdateToken = func( +} + +func onReaderUpdateToken( + l Logger, + d trace.Detailer, +) func(info trace.OnReadUpdateTokenStartInfo) func( + updateTokenInfo trace.OnReadUpdateTokenMiddleTokenReceivedInfo) func( + doneInfo trace.OnReadStreamUpdateTokenDoneInfo) { + return func( info trace.OnReadUpdateTokenStartInfo, ) func( updateTokenInfo trace.OnReadUpdateTokenMiddleTokenReceivedInfo, @@ -337,7 +436,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl } } } - t.OnReaderSentDataRequest = func(info trace.TopicReaderSentDataRequestInfo) { +} + +func onReaderSentDataRequest( + l Logger, + d trace.Detailer, +) func(info trace.TopicReaderSentDataRequestInfo) { + return func(info trace.TopicReaderSentDataRequestInfo) { if d.Details()&trace.TopicReaderMessageEvents == 0 { return } @@ -348,7 +453,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl Int("local_capacity", info.LocalBufferSizeAfterSent), ) } - t.OnReaderReceiveDataResponse = func( +} + +func onReaderReceiveDataResponse( + l Logger, + d trace.Detailer, +) func(info trace.TopicReaderReceiveDataResponseStartInfo) func(trace.TopicReaderReceiveDataResponseDoneInfo) { + return func( info trace.TopicReaderReceiveDataResponseStartInfo, ) func(trace.TopicReaderReceiveDataResponseDoneInfo) { if d.Details()&trace.TopicReaderMessageEvents == 0 { @@ -392,7 +503,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl } } } - t.OnReaderReadMessages = func( +} + +func onReaderReadMessages( + l Logger, + d trace.Detailer, +) func(info trace.TopicReaderReadMessagesStartInfo) func(doneInfo trace.TopicReaderReadMessagesDoneInfo) { + return func( info trace.TopicReaderReadMessagesStartInfo, ) func(doneInfo trace.TopicReaderReadMessagesDoneInfo) { if d.Details()&trace.TopicReaderMessageEvents == 0 { @@ -426,7 +543,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl } } } - t.OnReaderUnknownGrpcMessage = func(info trace.OnReadUnknownGrpcMessageInfo) { +} + +func onReaderUnknownGrpcMessage( + l Logger, + d trace.Detailer, +) func(info trace.OnReadUnknownGrpcMessageInfo) { + return func(info trace.OnReadUnknownGrpcMessageInfo) { if d.Details()&trace.TopicReaderMessageEvents == 0 { return } @@ -436,11 +559,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl String("reader_connection_id", info.ReaderConnectionID), ) } +} - /// - /// Topic writer - /// - t.OnWriterReconnect = func( +func onWriterReconnect( + l Logger, + d trace.Detailer, +) func(info trace.TopicWriterReconnectStartInfo) func(doneInfo trace.TopicWriterReconnectDoneInfo) { + return func( info trace.TopicWriterReconnectStartInfo, ) func(doneInfo trace.TopicWriterReconnectDoneInfo) { if d.Details()&trace.TopicWriterStreamLifeCycleEvents == 0 { @@ -476,7 +601,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl } } } - t.OnWriterInitStream = func( +} + +func onWriterInitStream( + l Logger, + d trace.Detailer, +) func(info trace.TopicWriterInitStreamStartInfo) func(doneInfo trace.TopicWriterInitStreamDoneInfo) { + return func( info trace.TopicWriterInitStreamStartInfo, ) func(doneInfo trace.TopicWriterInitStreamDoneInfo) { if d.Details()&trace.TopicWriterStreamLifeCycleEvents == 0 { @@ -512,7 +643,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl } } } - t.OnWriterClose = func(info trace.TopicWriterCloseStartInfo) func(doneInfo trace.TopicWriterCloseDoneInfo) { +} + +func onWriterClose( + l Logger, + d trace.Detailer, +) func(info trace.TopicWriterCloseStartInfo) func(doneInfo trace.TopicWriterCloseDoneInfo) { + return func(info trace.TopicWriterCloseStartInfo) func(doneInfo trace.TopicWriterCloseDoneInfo) { if d.Details()&trace.TopicWriterStreamLifeCycleEvents == 0 { return nil } @@ -541,7 +678,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl } } } - t.OnWriterCompressMessages = func( +} + +func onWriterCompressMessages( + l Logger, + d trace.Detailer, +) func(info trace.TopicWriterCompressMessagesStartInfo) func(doneInfo trace.TopicWriterCompressMessagesDoneInfo) { + return func( info trace.TopicWriterCompressMessagesStartInfo, ) func(doneInfo trace.TopicWriterCompressMessagesDoneInfo) { if d.Details()&trace.TopicWriterStreamEvents == 0 { @@ -584,7 +727,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl } } } - t.OnWriterSendMessages = func( +} + +func onWriterSendMessages( + l Logger, + d trace.Detailer, +) func(info trace.TopicWriterSendMessagesStartInfo) func(doneInfo trace.TopicWriterSendMessagesDoneInfo) { + return func( info trace.TopicWriterSendMessagesStartInfo, ) func(doneInfo trace.TopicWriterSendMessagesDoneInfo) { if d.Details()&trace.TopicWriterStreamEvents == 0 { @@ -623,7 +772,13 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl } } } - t.OnWriterReadUnknownGrpcMessage = func(info trace.TopicOnWriterReadUnknownGrpcMessageInfo) { +} + +func onWriterReadUnknownGrpcMessage( + l Logger, + d trace.Detailer, +) func(info trace.TopicOnWriterReadUnknownGrpcMessageInfo) { + return func(info trace.TopicOnWriterReadUnknownGrpcMessageInfo) { if d.Details()&trace.TopicWriterStreamEvents == 0 { return } @@ -634,6 +789,4 @@ func internalTopic(l Logger, d trace.Detailer) (t trace.Topic) { //nolint:gocycl String("session_id", info.SessionID), ) } - - return t } diff --git a/metrics/driver.go b/metrics/driver.go index f20e7ae09..2ba051b16 100644 --- a/metrics/driver.go +++ b/metrics/driver.go @@ -8,6 +8,12 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/trace" ) +// endpointKey represents a key for mapping endpoints to their properties. +type endpointKey struct { + localDC bool // a boolean indicating if the endpoint is in the local data center + az string // a string representing the availability zone of the endpoint +} + // driver makes driver with New publishing func driver(config Config) (t trace.Driver) { config = config.WithSystem("driver") @@ -19,13 +25,29 @@ func driver(config Config) (t trace.Driver) { requests := config.WithSystem("conn").CounterVec("requests", "status", "method", "endpoint", "node_id") tli := config.CounterVec("transaction_locks_invalidated") - type endpointKey struct { - localDC bool - az string - } knownEndpoints := make(map[endpointKey]struct{}) + driverConnEvents := config.Details() & trace.DriverConnEvents + driverBalancerEvents := config.Details() & trace.DriverBalancerEvents - t.OnConnInvoke = func(info trace.DriverConnInvokeStartInfo) func(trace.DriverConnInvokeDoneInfo) { + t.OnConnInvoke = connInvoke(driverConnEvents, requests, tli) + t.OnConnNewStream = connNewStream(driverConnEvents, requests) + t.OnConnBan = connBan(driverConnEvents, banned) + t.OnBalancerClusterDiscoveryAttempt = balancerClusterDiscoveryAttempt(balancersDiscoveries) + t.OnBalancerUpdate = balancerUpdate(driverBalancerEvents, balancerUpdates, knownEndpoints, endpoints) + t.OnConnDial = connDial(driverConnEvents, conns) + t.OnConnClose = connClose(driverConnEvents, conns) + + return t +} + +// connInvoke is a function that returns a callback function to be called +// when a driver connection invoke starts and when it is done. +func connInvoke( + driverConnEvents trace.Details, + requests CounterVec, + tli CounterVec, +) func(info trace.DriverConnInvokeStartInfo) func(trace.DriverConnInvokeDoneInfo) { + return func(info trace.DriverConnInvokeStartInfo) func(trace.DriverConnInvokeDoneInfo) { var ( method = info.Method endpoint = info.Endpoint.Address() @@ -33,7 +55,7 @@ func driver(config Config) (t trace.Driver) { ) return func(info trace.DriverConnInvokeDoneInfo) { - if config.Details()&trace.DriverConnEvents != 0 { + if driverConnEvents != 0 { requests.With(map[string]string{ "status": errorBrief(info.Error), "method": string(method), @@ -46,7 +68,16 @@ func driver(config Config) (t trace.Driver) { } } } - t.OnConnNewStream = func(info trace.DriverConnNewStreamStartInfo) func( +} + +// connNewStream receives the `driverConnEvents` and `requests` parameters and returns a closure function. +func connNewStream( + driverConnEvents trace.Details, + requests CounterVec, +) func( + info trace.DriverConnNewStreamStartInfo, +) func(trace.DriverConnNewStreamRecvInfo) func(trace.DriverConnNewStreamDoneInfo) { + return func(info trace.DriverConnNewStreamStartInfo) func( trace.DriverConnNewStreamRecvInfo, ) func( trace.DriverConnNewStreamDoneInfo, @@ -59,7 +90,7 @@ func driver(config Config) (t trace.Driver) { return func(info trace.DriverConnNewStreamRecvInfo) func(trace.DriverConnNewStreamDoneInfo) { return func(info trace.DriverConnNewStreamDoneInfo) { - if config.Details()&trace.DriverConnEvents != 0 { + if driverConnEvents != 0 { requests.With(map[string]string{ "status": errorBrief(info.Error), "method": string(method), @@ -70,8 +101,15 @@ func driver(config Config) (t trace.Driver) { } } } - t.OnConnBan = func(info trace.DriverConnBanStartInfo) func(trace.DriverConnBanDoneInfo) { - if config.Details()&trace.DriverConnEvents != 0 { +} + +// connBan is a function that returns a closure wrapping the logic for tracing a connection ban event. +func connBan( + driverConnEvents trace.Details, + banned GaugeVec, +) func(info trace.DriverConnBanStartInfo) func(trace.DriverConnBanDoneInfo) { + return func(info trace.DriverConnBanStartInfo) func(trace.DriverConnBanDoneInfo) { + if driverConnEvents != 0 { banned.With(map[string]string{ "endpoint": info.Endpoint.Address(), "node_id": idToString(info.Endpoint.NodeID()), @@ -81,7 +119,15 @@ func driver(config Config) (t trace.Driver) { return nil } - t.OnBalancerClusterDiscoveryAttempt = func(info trace.DriverBalancerClusterDiscoveryAttemptStartInfo) func( +} + +// balancerClusterDiscoveryAttempt performs balancer cluster discovery attempt. +func balancerClusterDiscoveryAttempt( + balancersDiscoveries CounterVec, +) func( + info trace.DriverBalancerClusterDiscoveryAttemptStartInfo, +) func(trace.DriverBalancerClusterDiscoveryAttemptDoneInfo) { + return func(info trace.DriverBalancerClusterDiscoveryAttemptStartInfo) func( trace.DriverBalancerClusterDiscoveryAttemptDoneInfo, ) { eventType := repeater.EventType(*info.Context) @@ -93,11 +139,20 @@ func driver(config Config) (t trace.Driver) { }).Inc() } } - t.OnBalancerUpdate = func(info trace.DriverBalancerUpdateStartInfo) func(trace.DriverBalancerUpdateDoneInfo) { +} + +// balancerUpdate updates the balancer with new endpoint information. +func balancerUpdate( + driverBalancerEvents trace.Details, + balancerUpdates CounterVec, + knownEndpoints map[endpointKey]struct{}, + endpoints GaugeVec, +) func(info trace.DriverBalancerUpdateStartInfo) func(trace.DriverBalancerUpdateDoneInfo) { + return func(info trace.DriverBalancerUpdateStartInfo) func(trace.DriverBalancerUpdateDoneInfo) { eventType := repeater.EventType(*info.Context) return func(info trace.DriverBalancerUpdateDoneInfo) { - if config.Details()&trace.DriverBalancerEvents != 0 { + if driverBalancerEvents != 0 { balancerUpdates.With(map[string]string{ "cause": eventType, }).Inc() @@ -128,12 +183,20 @@ func driver(config Config) (t trace.Driver) { } } } - t.OnConnDial = func(info trace.DriverConnDialStartInfo) func(trace.DriverConnDialDoneInfo) { +} + +// connDial is a function that returns a closure function to handle the event when a driver connection dialing starts +// and completes. +func connDial( + driverConnEvents trace.Details, + conns GaugeVec, +) func(info trace.DriverConnDialStartInfo) func(trace.DriverConnDialDoneInfo) { + return func(info trace.DriverConnDialStartInfo) func(trace.DriverConnDialDoneInfo) { endpoint := info.Endpoint.Address() nodeID := info.Endpoint.NodeID() return func(info trace.DriverConnDialDoneInfo) { - if config.Details()&trace.DriverConnEvents != 0 { + if driverConnEvents != 0 { if info.Error == nil { conns.With(map[string]string{ "endpoint": endpoint, @@ -143,8 +206,15 @@ func driver(config Config) (t trace.Driver) { } } } - t.OnConnClose = func(info trace.DriverConnCloseStartInfo) func(trace.DriverConnCloseDoneInfo) { - if config.Details()&trace.DriverConnEvents != 0 { +} + +// connClose is a function that returns a closure accepting a trace.DriverConnCloseStartInfo parameter +func connClose( + driverConnEvents trace.Details, + conns GaugeVec, +) func(info trace.DriverConnCloseStartInfo) func(trace.DriverConnCloseDoneInfo) { + return func(info trace.DriverConnCloseStartInfo) func(trace.DriverConnCloseDoneInfo) { + if driverConnEvents != 0 { conns.With(map[string]string{ "endpoint": info.Endpoint.Address(), "node_id": idToString(info.Endpoint.NodeID()), @@ -153,6 +223,4 @@ func driver(config Config) (t trace.Driver) { return nil } - - return t } diff --git a/metrics/sql.go b/metrics/sql.go index 66b1303a2..1cc0aa76e 100644 --- a/metrics/sql.go +++ b/metrics/sql.go @@ -27,10 +27,34 @@ func databaseSQL(config Config) (t trace.DatabaseSQL) { txCommitLatency := config.WithSystem("commit").TimerVec("latency") txRollback := config.CounterVec("rollback", "status") txRollbackLatency := config.WithSystem("rollback").TimerVec("latency") - t.OnConnectorConnect = func(info trace.DatabaseSQLConnectorConnectStartInfo) func( + + databaseSQLConnectorEvents := config.Details() & trace.DatabaseSQLConnectorEvents + databaseSQLTxEvents := config.Details() & trace.DatabaseSQLTxEvents + databaseSQLEvents := config.Details() & trace.DatabaseSQLEvents + databaseSQLConnEvents := config.Details() & trace.DatabaseSQLConnEvents + + t.OnConnectorConnect = onConnectorConnect(databaseSQLConnectorEvents, conns) + t.OnConnClose = onConnCloseDatabaseSQL(databaseSQLConnectorEvents, conns) + t.OnConnBegin = onConnBeginDatabaseSQL(databaseSQLTxEvents, txBegin, txBeginLatency) + t.OnTxCommit = onTxCommit(databaseSQLTxEvents, txCommit, txCommitLatency) + t.OnTxExec = onTxExec(databaseSQLTxEvents, txExec, txExecLatency) + t.OnTxQuery = onTxQuery(databaseSQLTxEvents, txQuery, txQueryLatency) + t.OnTxRollback = onTxRollback(databaseSQLTxEvents, txRollback, txRollbackLatency) + t.OnConnExec = onConnExec(databaseSQLEvents, databaseSQLConnEvents, inflight, exec, execLatency) + t.OnConnQuery = onConnQuery(databaseSQLEvents, databaseSQLConnEvents, inflight, query, queryLatency) + + return t +} + +// The `onConnectorConnect` function is called when a connection is established to the database. +func onConnectorConnect( + databaseSQLConnectorEvents trace.Details, + conns GaugeVec, +) func(info trace.DatabaseSQLConnectorConnectStartInfo) func(trace.DatabaseSQLConnectorConnectDoneInfo) { + return func(info trace.DatabaseSQLConnectorConnectStartInfo) func( trace.DatabaseSQLConnectorConnectDoneInfo, ) { - if config.Details()&trace.DatabaseSQLConnectorEvents != 0 { + if databaseSQLConnectorEvents != 0 { return func(info trace.DatabaseSQLConnectorConnectDoneInfo) { if info.Error == nil { conns.With(nil).Add(1) @@ -40,8 +64,15 @@ func databaseSQL(config Config) (t trace.DatabaseSQL) { return nil } - t.OnConnClose = func(info trace.DatabaseSQLConnCloseStartInfo) func(trace.DatabaseSQLConnCloseDoneInfo) { - if config.Details()&trace.DatabaseSQLConnectorEvents != 0 { +} + +// onConnCloseDatabaseSQL is a function that is used to handle the closing of a database connection. +func onConnCloseDatabaseSQL( + databaseSQLConnectorEvents trace.Details, + conns GaugeVec, +) func(info trace.DatabaseSQLConnCloseStartInfo) func(trace.DatabaseSQLConnCloseDoneInfo) { + return func(info trace.DatabaseSQLConnCloseStartInfo) func(trace.DatabaseSQLConnCloseDoneInfo) { + if databaseSQLConnectorEvents != 0 { return func(info trace.DatabaseSQLConnCloseDoneInfo) { conns.With(nil).Add(-1) } @@ -49,9 +80,18 @@ func databaseSQL(config Config) (t trace.DatabaseSQL) { return nil } - t.OnConnBegin = func(info trace.DatabaseSQLConnBeginStartInfo) func(trace.DatabaseSQLConnBeginDoneInfo) { +} + +// onConnBeginDatabaseSQL measures `trace.DatabaseSQLConnBeginStartInfo` events and updates `txBegin` +// and `txBeginLatency` +func onConnBeginDatabaseSQL( + databaseSQLTxEvents trace.Details, + txBegin CounterVec, + txBeginLatency TimerVec, +) func(info trace.DatabaseSQLConnBeginStartInfo) func(trace.DatabaseSQLConnBeginDoneInfo) { + return func(info trace.DatabaseSQLConnBeginStartInfo) func(trace.DatabaseSQLConnBeginDoneInfo) { start := time.Now() - if config.Details()&trace.DatabaseSQLTxEvents != 0 { + if databaseSQLTxEvents != 0 { return func(info trace.DatabaseSQLConnBeginDoneInfo) { txBegin.With(map[string]string{ "status": errorBrief(info.Error), @@ -62,11 +102,20 @@ func databaseSQL(config Config) (t trace.DatabaseSQL) { return nil } - t.OnTxCommit = func(info trace.DatabaseSQLTxCommitStartInfo) func(trace.DatabaseSQLTxCommitDoneInfo) { +} + +// onTxCommit is a function that returns a closure function. The closure function +// takes a trace.DatabaseSQLTxCommitDoneInfo argument and performs certain actions based on the input. +func onTxCommit( + databaseSQLTxEvents trace.Details, + txCommit CounterVec, + txCommitLatency TimerVec, +) func(info trace.DatabaseSQLTxCommitStartInfo) func(trace.DatabaseSQLTxCommitDoneInfo) { + return func(info trace.DatabaseSQLTxCommitStartInfo) func(trace.DatabaseSQLTxCommitDoneInfo) { start := time.Now() return func(info trace.DatabaseSQLTxCommitDoneInfo) { - if config.Details()&trace.DatabaseSQLTxEvents != 0 { + if databaseSQLTxEvents != 0 { txCommit.With(map[string]string{ "status": errorBrief(info.Error), }).Inc() @@ -74,11 +123,20 @@ func databaseSQL(config Config) (t trace.DatabaseSQL) { } } } - t.OnTxExec = func(info trace.DatabaseSQLTxExecStartInfo) func(trace.DatabaseSQLTxExecDoneInfo) { +} + +// onTxExec is a function that returns a callback function to be executed when a trace.DatabaseSQLTxExecDoneInfo event +// occurs. +func onTxExec( + databaseSQLTxEvents trace.Details, + txExec CounterVec, + txExecLatency TimerVec, +) func(info trace.DatabaseSQLTxExecStartInfo) func(trace.DatabaseSQLTxExecDoneInfo) { + return func(info trace.DatabaseSQLTxExecStartInfo) func(trace.DatabaseSQLTxExecDoneInfo) { start := time.Now() return func(info trace.DatabaseSQLTxExecDoneInfo) { - if config.Details()&trace.DatabaseSQLTxEvents != 0 { + if databaseSQLTxEvents != 0 { status := errorBrief(info.Error) txExec.With(map[string]string{ "status": status, @@ -87,11 +145,20 @@ func databaseSQL(config Config) (t trace.DatabaseSQL) { } } } - t.OnTxQuery = func(info trace.DatabaseSQLTxQueryStartInfo) func(trace.DatabaseSQLTxQueryDoneInfo) { +} + +// onTxQuery is a callback function that measures trace events related to database transactions and queries +// in the "database/sql" package. +func onTxQuery( + databaseSQLTxEvents trace.Details, + txQuery CounterVec, + txQueryLatency TimerVec, +) func(info trace.DatabaseSQLTxQueryStartInfo) func(trace.DatabaseSQLTxQueryDoneInfo) { + return func(info trace.DatabaseSQLTxQueryStartInfo) func(trace.DatabaseSQLTxQueryDoneInfo) { start := time.Now() return func(info trace.DatabaseSQLTxQueryDoneInfo) { - if config.Details()&trace.DatabaseSQLTxEvents != 0 { + if databaseSQLTxEvents != 0 { status := errorBrief(info.Error) txQuery.With(map[string]string{ "status": status, @@ -100,11 +167,20 @@ func databaseSQL(config Config) (t trace.DatabaseSQL) { } } } - t.OnTxRollback = func(info trace.DatabaseSQLTxRollbackStartInfo) func(trace.DatabaseSQLTxRollbackDoneInfo) { +} + +// onTxRollback is a function that returns a closure. The closure takes in a `trace.DatabaseSQLTxRollbackStartInfo` +// argument and returns another closure that takes in a `trace.Database`. +func onTxRollback( + databaseSQLTxEvents trace.Details, + txRollback CounterVec, + txRollbackLatency TimerVec, +) func(info trace.DatabaseSQLTxRollbackStartInfo) func(trace.DatabaseSQLTxRollbackDoneInfo) { + return func(info trace.DatabaseSQLTxRollbackStartInfo) func(trace.DatabaseSQLTxRollbackDoneInfo) { start := time.Now() return func(info trace.DatabaseSQLTxRollbackDoneInfo) { - if config.Details()&trace.DatabaseSQLTxEvents != 0 { + if databaseSQLTxEvents != 0 { txRollback.With(map[string]string{ "status": errorBrief(info.Error), }).Inc() @@ -112,8 +188,18 @@ func databaseSQL(config Config) (t trace.DatabaseSQL) { } } } - t.OnConnExec = func(info trace.DatabaseSQLConnExecStartInfo) func(trace.DatabaseSQLConnExecDoneInfo) { - if config.Details()&trace.DatabaseSQLEvents != 0 { +} + +// onConnExec measures the execution of a database/sql connection command. +func onConnExec( + databaseSQLEvents trace.Details, + databaseSQLConnEvents trace.Details, + inflight GaugeVec, + exec CounterVec, + execLatency TimerVec, +) func(info trace.DatabaseSQLConnExecStartInfo) func(trace.DatabaseSQLConnExecDoneInfo) { + return func(info trace.DatabaseSQLConnExecStartInfo) func(trace.DatabaseSQLConnExecDoneInfo) { + if databaseSQLEvents != 0 { inflight.With(nil).Add(1) } var ( @@ -122,10 +208,10 @@ func databaseSQL(config Config) (t trace.DatabaseSQL) { ) return func(info trace.DatabaseSQLConnExecDoneInfo) { - if config.Details()&trace.DatabaseSQLEvents != 0 { + if databaseSQLEvents != 0 { inflight.With(nil).Add(-1) } - if config.Details()&trace.DatabaseSQLConnEvents != 0 { + if databaseSQLConnEvents != 0 { status := errorBrief(info.Error) exec.With(map[string]string{ "status": status, @@ -137,8 +223,18 @@ func databaseSQL(config Config) (t trace.DatabaseSQL) { } } } - t.OnConnQuery = func(info trace.DatabaseSQLConnQueryStartInfo) func(trace.DatabaseSQLConnQueryDoneInfo) { - if config.Details()&trace.DatabaseSQLEvents != 0 { +} + +// onConnQuery handles the start and completion of a connection query event. +func onConnQuery( + databaseSQLEvents trace.Details, + databaseSQLConnEvents trace.Details, + inflight GaugeVec, + query CounterVec, + queryLatency TimerVec, +) func(info trace.DatabaseSQLConnQueryStartInfo) func(trace.DatabaseSQLConnQueryDoneInfo) { + return func(info trace.DatabaseSQLConnQueryStartInfo) func(trace.DatabaseSQLConnQueryDoneInfo) { + if databaseSQLEvents != 0 { inflight.With(nil).Add(1) } var ( @@ -147,10 +243,10 @@ func databaseSQL(config Config) (t trace.DatabaseSQL) { ) return func(info trace.DatabaseSQLConnQueryDoneInfo) { - if config.Details()&trace.DatabaseSQLEvents != 0 { + if databaseSQLEvents != 0 { inflight.With(nil).Add(-1) } - if config.Details()&trace.DatabaseSQLConnEvents != 0 { + if databaseSQLConnEvents != 0 { status := errorBrief(info.Error) query.With(map[string]string{ "status": status, @@ -162,6 +258,4 @@ func databaseSQL(config Config) (t trace.DatabaseSQL) { } } } - - return t } diff --git a/retry/retry.go b/retry/retry.go index d936fd5a2..b49b0dbdb 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -276,21 +276,7 @@ func Retry(ctx context.Context, op retryOperation, opts ...Option) (finalErr err ) default: - err := func() (err error) { - if options.panicCallback != nil { - defer func() { - if e := recover(); e != nil { - options.panicCallback(e) - err = xerrors.WithStackTrace( - fmt.Errorf("panic recovered: %v", e), - ) - } - }() - } - - return op(ctx) - }() - + err := RecoveryCallbackWrapper(ctx, op, options) if err == nil { return nil } @@ -335,6 +321,25 @@ func Retry(ctx context.Context, op retryOperation, opts ...Option) (finalErr err } } +// RecoveryCallbackWrapper is a function that wraps the provided `op` retry operation +// with a panic recovery mechanism. If the `options.panicCallback` is specified, it +// is invoked with the recovered value, and an error is returned with the recovered value wrapped +// as the cause. This function returns the result of the `op` retry operation. +func RecoveryCallbackWrapper(context context.Context, op retryOperation, options *retryOptions) (err error) { + if options.panicCallback != nil { + defer func() { + if e := recover(); e != nil { + options.panicCallback(e) + err = xerrors.WithStackTrace( + fmt.Errorf("panic recovered: %v", e), + ) + } + }() + } + + return op(context) +} + // Check returns retry mode for queryErr. func Check(err error) (m retryMode) { code, errType, backoffType, deleteSession := xerrors.Check(err) diff --git a/retry/sql_test.go b/retry/sql_test.go index 935179834..8cf819ea0 100644 --- a/retry/sql_test.go +++ b/retry/sql_test.go @@ -180,7 +180,6 @@ func (m *mockStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( return m.conn.QueryContext(ctx, m.query, args) } -//nolint:nestif func TestDoTx(t *testing.T) { for _, idempotentType := range []idempotency{ idempotent, @@ -228,27 +227,44 @@ func TestDoTx(t *testing.T) { }, }), ) - if tt.canRetry[idempotentType] { - if err != nil { - t.Errorf("unexpected err after attempts=%d and driver conns=%d: %v)", attempts, m.conns, err) - } - if attempts <= 1 { - t.Errorf("must be attempts > 1 (actual=%d), driver conns=%d)", attempts, m.conns) - } - if tt.deleteSession { - if m.conns <= 1 { - t.Errorf("must be retry on different conns (attempts=%d, driver conns=%d)", attempts, m.conns) - } - } else { - if m.conns > 1 { - t.Errorf("must be retry on single conn (attempts=%d, driver conns=%d)", attempts, m.conns) - } - } - } else if err == nil { - t.Errorf("unexpected nil err (attempts=%d, driver conns=%d)", attempts, m.conns) - } + canRetry(t, tt, idempotentType, err, attempts, m) }) } }) } } + +// canRetry checks if a retry can be performed based on the given parameters. +// +//nolint:nestif +func canRetry(t *testing.T, tt struct { + err error + backoff backoff.Type + deleteSession bool + canRetry map[idempotency]bool +}, + idempotentType idempotency, + err error, + attempts int, + m *mockConnector, +) { + if tt.canRetry[idempotentType] { + if err != nil { + t.Errorf("unexpected err after attempts=%d and driver conns=%d: %v)", attempts, m.conns, err) + } + if attempts <= 1 { + t.Errorf("must be attempts > 1 (actual=%d), driver conns=%d)", attempts, m.conns) + } + if tt.deleteSession { + if m.conns <= 1 { + t.Errorf("must be retry on different conns (attempts=%d, driver conns=%d)", attempts, m.conns) + } + } else { + if m.conns > 1 { + t.Errorf("must be retry on single conn (attempts=%d, driver conns=%d)", attempts, m.conns) + } + } + } else if err == nil { + t.Errorf("unexpected nil err (attempts=%d, driver conns=%d)", attempts, m.conns) + } +} diff --git a/sugar/path.go b/sugar/path.go index cc22e15aa..e74af091e 100644 --- a/sugar/path.go +++ b/sugar/path.go @@ -124,36 +124,9 @@ func RemoveRecursive(ctx context.Context, db dbFoRemoveRecursive, pathToRemove s if pt == fullSysTablePath { continue } - switch t := dir.Children[j].Type; t { - case scheme.EntryDirectory: - if err = rmPath(i+1, pt); err != nil { - return xerrors.WithStackTrace( - fmt.Errorf("recursive removing directory %q failed: %w", pt, err), - ) - } - - case scheme.EntryTable, scheme.EntryColumnTable: - err = db.Table().Do(ctx, func(ctx context.Context, session table.Session) (err error) { - return session.DropTable(ctx, pt) - }, table.WithIdempotent()) - if err != nil { - return xerrors.WithStackTrace( - fmt.Errorf("removing table %q failed: %w", pt, err), - ) - } - - case scheme.EntryTopic: - err = db.Topic().Drop(ctx, pt) - if err != nil { - return xerrors.WithStackTrace( - fmt.Errorf("removing topic %q failed: %w", pt, err), - ) - } - - default: - return xerrors.WithStackTrace( - fmt.Errorf("unknown entry type: %s", t.String()), - ) + err = removeEntry(ctx, i, pt, j, db, &dir, rmPath) + if err != nil { + return err } } @@ -170,9 +143,61 @@ func RemoveRecursive(ctx context.Context, db dbFoRemoveRecursive, pathToRemove s return nil } + pathToRemove = removeWithPrefix(pathToRemove, db) + + return rmPath(0, pathToRemove) +} + +// removeWithPrefix prepends the db.Name() to the pathToRemove string if it does not already have the prefix. +func removeWithPrefix(pathToRemove string, db dbFoRemoveRecursive) string { if !strings.HasPrefix(pathToRemove, db.Name()) { pathToRemove = path.Join(db.Name(), pathToRemove) } - return rmPath(0, pathToRemove) + return pathToRemove +} + +// removeEntry removes an entry from the database. +func removeEntry( + ctx context.Context, + i int, + pt string, + j int, + db dbFoRemoveRecursive, + dir *scheme.Directory, + rmPath func(int, string) error, +) error { + switch t := dir.Children[j].Type; t { + case scheme.EntryDirectory: + if err := rmPath(i+1, pt); err != nil { + return xerrors.WithStackTrace( + fmt.Errorf("recursive removing directory %q failed: %w", pt, err), + ) + } + + case scheme.EntryTable, scheme.EntryColumnTable: + err := db.Table().Do(ctx, func(ctx context.Context, session table.Session) (err error) { + return session.DropTable(ctx, pt) + }, table.WithIdempotent()) + if err != nil { + return xerrors.WithStackTrace( + fmt.Errorf("removing table %q failed: %w", pt, err), + ) + } + + case scheme.EntryTopic: + err := db.Topic().Drop(ctx, pt) + if err != nil { + return xerrors.WithStackTrace( + fmt.Errorf("removing topic %q failed: %w", pt, err), + ) + } + + default: + return xerrors.WithStackTrace( + fmt.Errorf("unknown entry type: %s", t.String()), + ) + } + + return nil }