diff --git a/lib/column/array.go b/lib/column/array.go index 2a0c17d40d..f510bfb10d 100644 --- a/lib/column/array.go +++ b/lib/column/array.go @@ -18,6 +18,7 @@ package column import ( + "database/sql" "fmt" "github.com/ClickHouse/ch-go/proto" "reflect" @@ -25,6 +26,8 @@ import ( "time" ) +var scanTypeAny = reflect.TypeOf((*interface{})(nil)).Elem() + type offset struct { values UInt64 scanType reflect.Type @@ -268,6 +271,13 @@ func (col *Array) WriteStatePrefix(buffer *proto.Buffer) error { } func (col *Array) ScanRow(dest any, row int) error { + if scanner, ok := dest.(sql.Scanner); ok { + value, err := col.scan(scanTypeAny, row) + if err != nil { + return err + } + return scanner.Scan(value.Interface()) + } elem := reflect.Indirect(reflect.ValueOf(dest)) value, err := col.scan(elem.Type(), row) if err != nil { diff --git a/lib/column/ipv4.go b/lib/column/ipv4.go index 3d4c252833..e0a869cc9a 100644 --- a/lib/column/ipv4.go +++ b/lib/column/ipv4.go @@ -18,6 +18,7 @@ package column import ( + "database/sql" "database/sql/driver" "encoding/binary" "fmt" @@ -98,6 +99,8 @@ func (col *IPv4) ScanRow(dest any, row int) error { } *d = new(uint32) **d = binary.BigEndian.Uint32(ipV4[:]) + case sql.Scanner: + return d.Scan(col.row(row)) default: return &ColumnConverterError{ Op: "ScanRow", diff --git a/lib/column/ipv6.go b/lib/column/ipv6.go index a67d17abc4..0544e7931d 100644 --- a/lib/column/ipv6.go +++ b/lib/column/ipv6.go @@ -18,6 +18,7 @@ package column import ( + "database/sql" "database/sql/driver" "fmt" "github.com/ClickHouse/ch-go/proto" @@ -91,6 +92,8 @@ func (col *IPv6) ScanRow(dest any, row int) error { case **[16]byte: *d = new([16]byte) **d = col.col.Row(row) + case sql.Scanner: + return d.Scan(col.row(row)) default: return &ColumnConverterError{ Op: "ScanRow", diff --git a/lib/column/map.go b/lib/column/map.go index a33514634d..70da38e62e 100644 --- a/lib/column/map.go +++ b/lib/column/map.go @@ -18,6 +18,7 @@ package column import ( + "database/sql" "database/sql/driver" "fmt" "reflect" @@ -114,6 +115,9 @@ func (col *Map) Row(i int, ptr bool) any { } func (col *Map) ScanRow(dest any, i int) error { + if scanner, ok := dest.(sql.Scanner); ok { + return scanner.Scan(col.row(i).Interface()) + } value := reflect.Indirect(reflect.ValueOf(dest)) if value.Type() == col.scanType { value.Set(col.row(i)) diff --git a/tests/array_test.go b/tests/array_test.go index 551e133e6b..014ff33fb2 100644 --- a/tests/array_test.go +++ b/tests/array_test.go @@ -372,3 +372,50 @@ func TestSimpleArrayValuer(t *testing.T) { require.NoError(t, rows.Close()) require.NoError(t, rows.Err()) } + +func TestSQLScannerArray(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + const ddl = ` + CREATE TABLE test_array ( + Col1 Array(String) + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE test_array") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_array") + require.NoError(t, err) + var ( + col1Data = []string{"A", "b", "c"} + ) + for i := 0; i < 10; i++ { + require.NoError(t, batch.Append(col1Data)) + } + require.Equal(t, 10, batch.Rows()) + require.Nil(t, batch.Send()) + rows, err := conn.Query(ctx, "SELECT * FROM test_array") + require.NoError(t, err) + for rows.Next() { + var ( + col1 = sqlScannerArray{} + ) + require.NoError(t, rows.Scan(&col1)) + assert.Equal(t, col1Data, col1.value) + } + require.NoError(t, rows.Close()) + require.NoError(t, rows.Err()) +} + +type sqlScannerArray struct { + value any +} + +func (s *sqlScannerArray) Scan(src any) error { + s.value = src + return nil +} diff --git a/tests/ipv4_test.go b/tests/ipv4_test.go index ea67633b81..ceda75cf43 100644 --- a/tests/ipv4_test.go +++ b/tests/ipv4_test.go @@ -556,3 +556,44 @@ func TestIPv4Valuer(t *testing.T) { } require.Equal(t, 1000, i) } + +func TestSQLScannerIPv4(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + const ddl = ` + CREATE TABLE test_ipv4 ( + Col1 IPv4 + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE test_ipv4") + }() + + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_ipv4") + require.NoError(t, err) + + var ( + col1Data = net.ParseIP("127.0.0.1") + ) + require.NoError(t, batch.Append(col1Data)) + require.Equal(t, 1, batch.Rows()) + require.NoError(t, batch.Send()) + var ( + col1 sqlScannerIPv4 + ) + require.NoError(t, conn.QueryRow(ctx, "SELECT * FROM test_ipv4").Scan(&col1)) + assert.Equal(t, col1Data.To4(), col1.value) +} + +type sqlScannerIPv4 struct { + value any +} + +func (s *sqlScannerIPv4) Scan(src any) error { + s.value = src + return nil +} diff --git a/tests/ipv6_test.go b/tests/ipv6_test.go index f8b0f8df25..6aee2ae130 100644 --- a/tests/ipv6_test.go +++ b/tests/ipv6_test.go @@ -21,13 +21,14 @@ import ( "context" "database/sql/driver" "fmt" - "github.com/ClickHouse/ch-go/proto" - "github.com/ClickHouse/clickhouse-go/v2/lib/column" - "github.com/stretchr/testify/require" "net" "net/netip" "testing" + "github.com/ClickHouse/ch-go/proto" + "github.com/ClickHouse/clickhouse-go/v2/lib/column" + "github.com/stretchr/testify/require" + "github.com/ClickHouse/clickhouse-go/v2" "github.com/stretchr/testify/assert" ) @@ -520,3 +521,44 @@ func TestIPv6Valuer(t *testing.T) { } require.Equal(t, 1000, i) } + +func TestSQLScannerIPv6(t *testing.T) { + conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + const ddl = ` + CREATE TABLE test_ipv6 ( + Col1 IPv6 + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE test_ipv6") + }() + + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_ipv6") + require.NoError(t, err) + + var ( + col1Data = net.ParseIP("2001:44c8:129:2632:33:0:252:2") + ) + require.NoError(t, batch.Append(col1Data)) + require.Equal(t, 1, batch.Rows()) + require.NoError(t, batch.Send()) + var ( + col1 sqlScannerIPv6 + ) + require.NoError(t, conn.QueryRow(ctx, "SELECT * FROM test_ipv6").Scan(&col1)) + assert.Equal(t, col1Data, col1.value) +} + +type sqlScannerIPv6 struct { + value any +} + +func (s *sqlScannerIPv6) Scan(src any) error { + s.value = src + return nil +} diff --git a/tests/map_test.go b/tests/map_test.go index f3be3105dc..cecf53e11c 100644 --- a/tests/map_test.go +++ b/tests/map_test.go @@ -459,6 +459,52 @@ func (i *mapIter) Value() any { return i.om.valuesIter[i.iterIndex] } +func TestSQLScannerMap(t *testing.T) { + conn, err := GetNativeConnection(clickhouse.Settings{}, nil, &clickhouse.Compression{ + Method: clickhouse.CompressionLZ4, + }) + ctx := context.Background() + require.NoError(t, err) + if !CheckMinServerServerVersion(conn, 21, 9, 0) { + t.Skip(fmt.Errorf("unsupported clickhouse version")) + return + } + const ddl = ` + CREATE TABLE test_map ( + Col1 Map(String, UInt64) + ) Engine MergeTree() ORDER BY tuple() + ` + defer func() { + conn.Exec(ctx, "DROP TABLE IF EXISTS test_map") + }() + require.NoError(t, conn.Exec(ctx, ddl)) + batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_map") + require.NoError(t, err) + var ( + col1Data = map[string]uint64{ + "key_col_1_1": 1, + "key_col_1_2": 2, + } + ) + require.NoError(t, batch.Append(col1Data)) + require.Equal(t, 1, batch.Rows()) + require.NoError(t, batch.Send()) + var ( + col1 sqlScannerMap + ) + require.NoError(t, conn.QueryRow(ctx, "SELECT * FROM test_map").Scan(&col1)) + assert.Equal(t, col1Data, col1.value) +} + +type sqlScannerMap struct { + value any +} + +func (s *sqlScannerMap) Scan(src any) error { + s.value = src + return nil +} + func BenchmarkOrderedMapUseChanGo(b *testing.B) { m := NewOrderedMap() for i := 0; i < 10; i++ {