Skip to content

Commit 8bf4552

Browse files
Merge pull request #1570 from open2b/main
Supports scanning of Array, IPv4, IPv6, and Map types into Go values that implement the `sql.Scanner` interface.
2 parents 993aaee + 574280c commit 8bf4552

File tree

8 files changed

+199
-3
lines changed

8 files changed

+199
-3
lines changed

lib/column/array.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@
1818
package column
1919

2020
import (
21+
"database/sql"
2122
"fmt"
2223
"github.com/ClickHouse/ch-go/proto"
2324
"reflect"
2425
"strings"
2526
"time"
2627
)
2728

29+
var scanTypeAny = reflect.TypeOf((*interface{})(nil)).Elem()
30+
2831
type offset struct {
2932
values UInt64
3033
scanType reflect.Type
@@ -268,6 +271,13 @@ func (col *Array) WriteStatePrefix(buffer *proto.Buffer) error {
268271
}
269272

270273
func (col *Array) ScanRow(dest any, row int) error {
274+
if scanner, ok := dest.(sql.Scanner); ok {
275+
value, err := col.scan(scanTypeAny, row)
276+
if err != nil {
277+
return err
278+
}
279+
return scanner.Scan(value.Interface())
280+
}
271281
elem := reflect.Indirect(reflect.ValueOf(dest))
272282
value, err := col.scan(elem.Type(), row)
273283
if err != nil {

lib/column/ipv4.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package column
1919

2020
import (
21+
"database/sql"
2122
"database/sql/driver"
2223
"encoding/binary"
2324
"fmt"
@@ -98,6 +99,8 @@ func (col *IPv4) ScanRow(dest any, row int) error {
9899
}
99100
*d = new(uint32)
100101
**d = binary.BigEndian.Uint32(ipV4[:])
102+
case sql.Scanner:
103+
return d.Scan(col.row(row))
101104
default:
102105
return &ColumnConverterError{
103106
Op: "ScanRow",

lib/column/ipv6.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package column
1919

2020
import (
21+
"database/sql"
2122
"database/sql/driver"
2223
"fmt"
2324
"github.com/ClickHouse/ch-go/proto"
@@ -91,6 +92,8 @@ func (col *IPv6) ScanRow(dest any, row int) error {
9192
case **[16]byte:
9293
*d = new([16]byte)
9394
**d = col.col.Row(row)
95+
case sql.Scanner:
96+
return d.Scan(col.row(row))
9497
default:
9598
return &ColumnConverterError{
9699
Op: "ScanRow",

lib/column/map.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package column
1919

2020
import (
21+
"database/sql"
2122
"database/sql/driver"
2223
"fmt"
2324
"reflect"
@@ -114,6 +115,9 @@ func (col *Map) Row(i int, ptr bool) any {
114115
}
115116

116117
func (col *Map) ScanRow(dest any, i int) error {
118+
if scanner, ok := dest.(sql.Scanner); ok {
119+
return scanner.Scan(col.row(i).Interface())
120+
}
117121
value := reflect.Indirect(reflect.ValueOf(dest))
118122
if value.Type() == col.scanType {
119123
value.Set(col.row(i))

tests/array_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,50 @@ func TestSimpleArrayValuer(t *testing.T) {
372372
require.NoError(t, rows.Close())
373373
require.NoError(t, rows.Err())
374374
}
375+
376+
func TestSQLScannerArray(t *testing.T) {
377+
conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{
378+
Method: clickhouse.CompressionLZ4,
379+
})
380+
ctx := context.Background()
381+
require.NoError(t, err)
382+
const ddl = `
383+
CREATE TABLE test_array (
384+
Col1 Array(String)
385+
) Engine MergeTree() ORDER BY tuple()
386+
`
387+
defer func() {
388+
conn.Exec(ctx, "DROP TABLE test_array")
389+
}()
390+
require.NoError(t, conn.Exec(ctx, ddl))
391+
batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_array")
392+
require.NoError(t, err)
393+
var (
394+
col1Data = []string{"A", "b", "c"}
395+
)
396+
for i := 0; i < 10; i++ {
397+
require.NoError(t, batch.Append(col1Data))
398+
}
399+
require.Equal(t, 10, batch.Rows())
400+
require.Nil(t, batch.Send())
401+
rows, err := conn.Query(ctx, "SELECT * FROM test_array")
402+
require.NoError(t, err)
403+
for rows.Next() {
404+
var (
405+
col1 = sqlScannerArray{}
406+
)
407+
require.NoError(t, rows.Scan(&col1))
408+
assert.Equal(t, col1Data, col1.value)
409+
}
410+
require.NoError(t, rows.Close())
411+
require.NoError(t, rows.Err())
412+
}
413+
414+
type sqlScannerArray struct {
415+
value any
416+
}
417+
418+
func (s *sqlScannerArray) Scan(src any) error {
419+
s.value = src
420+
return nil
421+
}

tests/ipv4_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,3 +556,44 @@ func TestIPv4Valuer(t *testing.T) {
556556
}
557557
require.Equal(t, 1000, i)
558558
}
559+
560+
func TestSQLScannerIPv4(t *testing.T) {
561+
conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{
562+
Method: clickhouse.CompressionLZ4,
563+
})
564+
ctx := context.Background()
565+
require.NoError(t, err)
566+
const ddl = `
567+
CREATE TABLE test_ipv4 (
568+
Col1 IPv4
569+
) Engine MergeTree() ORDER BY tuple()
570+
`
571+
defer func() {
572+
conn.Exec(ctx, "DROP TABLE test_ipv4")
573+
}()
574+
575+
require.NoError(t, conn.Exec(ctx, ddl))
576+
batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_ipv4")
577+
require.NoError(t, err)
578+
579+
var (
580+
col1Data = net.ParseIP("127.0.0.1")
581+
)
582+
require.NoError(t, batch.Append(col1Data))
583+
require.Equal(t, 1, batch.Rows())
584+
require.NoError(t, batch.Send())
585+
var (
586+
col1 sqlScannerIPv4
587+
)
588+
require.NoError(t, conn.QueryRow(ctx, "SELECT * FROM test_ipv4").Scan(&col1))
589+
assert.Equal(t, col1Data.To4(), col1.value)
590+
}
591+
592+
type sqlScannerIPv4 struct {
593+
value any
594+
}
595+
596+
func (s *sqlScannerIPv4) Scan(src any) error {
597+
s.value = src
598+
return nil
599+
}

tests/ipv6_test.go

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ import (
2121
"context"
2222
"database/sql/driver"
2323
"fmt"
24-
"github.com/ClickHouse/ch-go/proto"
25-
"github.com/ClickHouse/clickhouse-go/v2/lib/column"
26-
"github.com/stretchr/testify/require"
2724
"net"
2825
"net/netip"
2926
"testing"
3027

28+
"github.com/ClickHouse/ch-go/proto"
29+
"github.com/ClickHouse/clickhouse-go/v2/lib/column"
30+
"github.com/stretchr/testify/require"
31+
3132
"github.com/ClickHouse/clickhouse-go/v2"
3233
"github.com/stretchr/testify/assert"
3334
)
@@ -520,3 +521,44 @@ func TestIPv6Valuer(t *testing.T) {
520521
}
521522
require.Equal(t, 1000, i)
522523
}
524+
525+
func TestSQLScannerIPv6(t *testing.T) {
526+
conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{
527+
Method: clickhouse.CompressionLZ4,
528+
})
529+
ctx := context.Background()
530+
require.NoError(t, err)
531+
const ddl = `
532+
CREATE TABLE test_ipv6 (
533+
Col1 IPv6
534+
) Engine MergeTree() ORDER BY tuple()
535+
`
536+
defer func() {
537+
conn.Exec(ctx, "DROP TABLE test_ipv6")
538+
}()
539+
540+
require.NoError(t, conn.Exec(ctx, ddl))
541+
batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_ipv6")
542+
require.NoError(t, err)
543+
544+
var (
545+
col1Data = net.ParseIP("2001:44c8:129:2632:33:0:252:2")
546+
)
547+
require.NoError(t, batch.Append(col1Data))
548+
require.Equal(t, 1, batch.Rows())
549+
require.NoError(t, batch.Send())
550+
var (
551+
col1 sqlScannerIPv6
552+
)
553+
require.NoError(t, conn.QueryRow(ctx, "SELECT * FROM test_ipv6").Scan(&col1))
554+
assert.Equal(t, col1Data, col1.value)
555+
}
556+
557+
type sqlScannerIPv6 struct {
558+
value any
559+
}
560+
561+
func (s *sqlScannerIPv6) Scan(src any) error {
562+
s.value = src
563+
return nil
564+
}

tests/map_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,52 @@ func (i *mapIter) Value() any {
459459
return i.om.valuesIter[i.iterIndex]
460460
}
461461

462+
func TestSQLScannerMap(t *testing.T) {
463+
conn, err := GetNativeConnection(clickhouse.Settings{}, nil, &clickhouse.Compression{
464+
Method: clickhouse.CompressionLZ4,
465+
})
466+
ctx := context.Background()
467+
require.NoError(t, err)
468+
if !CheckMinServerServerVersion(conn, 21, 9, 0) {
469+
t.Skip(fmt.Errorf("unsupported clickhouse version"))
470+
return
471+
}
472+
const ddl = `
473+
CREATE TABLE test_map (
474+
Col1 Map(String, UInt64)
475+
) Engine MergeTree() ORDER BY tuple()
476+
`
477+
defer func() {
478+
conn.Exec(ctx, "DROP TABLE IF EXISTS test_map")
479+
}()
480+
require.NoError(t, conn.Exec(ctx, ddl))
481+
batch, err := conn.PrepareBatch(ctx, "INSERT INTO test_map")
482+
require.NoError(t, err)
483+
var (
484+
col1Data = map[string]uint64{
485+
"key_col_1_1": 1,
486+
"key_col_1_2": 2,
487+
}
488+
)
489+
require.NoError(t, batch.Append(col1Data))
490+
require.Equal(t, 1, batch.Rows())
491+
require.NoError(t, batch.Send())
492+
var (
493+
col1 sqlScannerMap
494+
)
495+
require.NoError(t, conn.QueryRow(ctx, "SELECT * FROM test_map").Scan(&col1))
496+
assert.Equal(t, col1Data, col1.value)
497+
}
498+
499+
type sqlScannerMap struct {
500+
value any
501+
}
502+
503+
func (s *sqlScannerMap) Scan(src any) error {
504+
s.value = src
505+
return nil
506+
}
507+
462508
func BenchmarkOrderedMapUseChanGo(b *testing.B) {
463509
m := NewOrderedMap()
464510
for i := 0; i < 10; i++ {

0 commit comments

Comments
 (0)