diff --git a/sqlite3.go b/sqlite3.go index 3025a500..162bd549 100644 --- a/sqlite3.go +++ b/sqlite3.go @@ -243,6 +243,7 @@ const ( columnDate string = "date" columnDatetime string = "datetime" columnTimestamp string = "timestamp" + columnBoolean string = "boolean" ) // This variable can be replaced with -ldflags like below: @@ -269,7 +270,7 @@ const ( SQLITE_INSERT = C.SQLITE_INSERT SQLITE_UPDATE = C.SQLITE_UPDATE - // used by authorzier - as return value + // used by authorizer - as return value SQLITE_OK = C.SQLITE_OK SQLITE_IGNORE = C.SQLITE_IGNORE SQLITE_DENY = C.SQLITE_DENY @@ -2105,7 +2106,7 @@ func (s *SQLiteStmt) execSync(args []driver.NamedValue) (driver.Result, error) { // // See: https://sqlite.org/c3ref/stmt_readonly.html func (s *SQLiteStmt) Readonly() bool { - return C.sqlite3_stmt_readonly(s.s) == 1 + return C.sqlite3_stmt_readonly(s.s) != 0 } // Close the rows. @@ -2233,8 +2234,8 @@ func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error { t = t.In(rc.s.c.loc) } dest[i] = t - case "boolean": - dest[i] = val > 0 + case columnBoolean: + dest[i] = val != 0 default: dest[i] = val } diff --git a/sqlite3_test.go b/sqlite3_test.go index 94de7386..be7c5459 100644 --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -14,7 +14,7 @@ import ( "database/sql/driver" "errors" "fmt" - "io/ioutil" + "io" "math/rand" "net/url" "os" @@ -29,7 +29,7 @@ import ( ) func TempFilename(t testing.TB) string { - f, err := ioutil.TempFile("", "go-sqlite3-test-") + f, err := os.CreateTemp("", "go-sqlite3-test-") if err != nil { t.Fatal(err) } @@ -1709,6 +1709,288 @@ func TestDeclTypes(t *testing.T) { } } +func TestScanTypes(t *testing.T) { + + d := SQLiteDriver{} + + conn, err := d.Open(":memory:") + if err != nil { + t.Fatal("Failed to begin transaction:", err) + } + defer conn.Close() + + sqlite3conn := conn.(*SQLiteConn) + + _, err = sqlite3conn.Exec("create table foo (id integer not null primary key, name text, price integer, length float, token blob, dob timestamp, jdays date, somedate datetime)", nil) + if err != nil { + t.Fatal("Failed to create table:", err) + } + expected := []reflect.Type{type_int, type_string, type_int, type_float, type_rawbytes, type_time, type_time, type_time} + + _, err = sqlite3conn.Exec("insert into foo(name, price, length, token, dob, jdays, somedate) values('bar', 10, 3.1415, x'0500', 100, 5.0, '2006-01-02 15:04:05')", nil) + if err != nil { + t.Fatal("Failed to insert:", err) + } + + rs, err := sqlite3conn.Query("select * from foo", nil) + if err != nil { + t.Fatal("Failed to select:", err) + } + defer rs.Close() + + cols := make([]driver.Value, len(rs.Columns())) + err = rs.Next(cols) + if err != nil { + t.Fatal("Failed to advance cursor:", err) + } + + rc, ok := rs.(driver.RowsColumnTypeScanType) + if !ok { + t.Fatal("SQLiteRows does not implement driver.RowsColumnTypeScanType") + } + + for i := range rc.Columns() { + if st := rc.ColumnTypeScanType(i); st != expected[i] { + t.Fatal("Unexpected ScanType. Expected:", expected[i], "Got:", st) + } + } +} + +func TestScanTypesBeforeNext(t *testing.T) { + + d := SQLiteDriver{} + + conn, err := d.Open(":memory:") + if err != nil { + t.Fatal("Failed to begin transaction:", err) + } + defer conn.Close() + + sqlite3conn := conn.(*SQLiteConn) + + _, err = sqlite3conn.Exec("create table foo (id integer not null primary key, name text, price integer, length float, token blob, dob timestamp, jdays date, somedate datetime)", nil) + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = sqlite3conn.Exec("insert into foo(name, price, length, token, dob, jdays, somedate) values('bar', 10, 3.1415, x'0500', 100, 5.0, '2006-01-02 15:04:05')", nil) + if err != nil { + t.Fatal("Failed to insert:", err) + } + + rs, err := sqlite3conn.Query("select * from foo", nil) + if err != nil { + t.Fatal("Failed to select:", err) + } + defer rs.Close() + + rc, ok := rs.(driver.RowsColumnTypeScanType) + if !ok { + t.Fatal("SQLiteRows does not implement driver.RowsColumnTypeScanType") + } + + for i := range rc.Columns() { + if st := rc.ColumnTypeScanType(i); st != type_any { + t.Fatal("Unexpected ScanType:", st) + } + } +} + +func TestScanTypesAfterClosed(t *testing.T) { + + d := SQLiteDriver{} + + conn, err := d.Open(":memory:") + if err != nil { + t.Fatal("Failed to begin transaction:", err) + } + defer conn.Close() + + sqlite3conn := conn.(*SQLiteConn) + + _, err = sqlite3conn.Exec("create table foo (id integer not null primary key, name text, price integer, length float, token blob, dob timestamp, jdays date, somedate datetime)", nil) + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = sqlite3conn.Exec("insert into foo(name, price, length, token, dob, jdays, somedate) values('bar', 10, 3.1415, x'0500', 100, 5.0, '2006-01-02 15:04:05')", nil) + if err != nil { + t.Fatal("Failed to insert:", err) + } + + rs, err := sqlite3conn.Query("select * from foo", nil) + if err != nil { + t.Fatal("Failed to select:", err) + } + defer rs.Close() + + cols := make([]driver.Value, len(rs.Columns())) + err = rs.Next(cols) + if err != nil { + t.Fatal("Failed to advance cursor:", err) + } + err = rs.Next(cols) + if err != io.EOF { + t.Fatal("Unexpected error when reaching end of dataset:", err) + } + + rc, ok := rs.(driver.RowsColumnTypeScanType) + if !ok { + t.Fatal("SQLiteRows does not implement driver.RowsColumnTypeScanType") + } + + for i := range rc.Columns() { + if st := rc.ColumnTypeScanType(i); st != type_any { + t.Fatal("Unexpected ScanType:", st) + } + } +} + +func TestScanTypesInvalidColumn(t *testing.T) { + + d := SQLiteDriver{} + + conn, err := d.Open(":memory:") + if err != nil { + t.Fatal("Failed to begin transaction:", err) + } + defer conn.Close() + + sqlite3conn := conn.(*SQLiteConn) + + _, err = sqlite3conn.Exec("create table foo (id integer not null primary key, name text, price integer, length float, token blob, dob timestamp, jdays date, somedate datetime)", nil) + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = sqlite3conn.Exec("insert into foo(name, price, length, token, dob, jdays, somedate) values('bar', 10, 3.1415, x'0500', 100, 5.0, '2006-01-02 15:04:05')", nil) + if err != nil { + t.Fatal("Failed to insert:", err) + } + + rs, err := sqlite3conn.Query("select * from foo", nil) + if err != nil { + t.Fatal("Failed to select:", err) + } + defer rs.Close() + + cols := make([]driver.Value, len(rs.Columns())) + err = rs.Next(cols) + if err != nil { + t.Fatal("Failed to advance cursor:", err) + } + + rc, ok := rs.(driver.RowsColumnTypeScanType) + if !ok { + t.Fatal("SQLiteRows does not implement driver.RowsColumnTypeScanType") + } + + if st := rc.ColumnTypeScanType(len(rc.Columns())); st != type_any { + t.Fatal("Unexpected ScanType:", st) + } + if st := rc.ColumnTypeScanType(-1); st != type_any { + t.Fatal("Unexpected ScanType:", st) + } +} + +func TestScanTypesNull(t *testing.T) { + + d := SQLiteDriver{} + + conn, err := d.Open(":memory:") + if err != nil { + t.Fatal("Failed to begin transaction:", err) + } + defer conn.Close() + + sqlite3conn := conn.(*SQLiteConn) + + _, err = sqlite3conn.Exec("create table foo (id integer not null primary key, name text, price integer, length float, token blob, dob timestamp, jdays date, somedate datetime)", nil) + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = sqlite3conn.Exec("insert into foo(name, price, length, token, dob, jdays, somedate) values(null, null, null, null, null, null, null)", nil) + if err != nil { + t.Fatal("Failed to insert:", err) + } + + rs, err := sqlite3conn.Query("select * from foo", nil) + if err != nil { + t.Fatal("Failed to select:", err) + } + defer rs.Close() + + cols := make([]driver.Value, len(rs.Columns())) + err = rs.Next(cols) + if err != nil { + t.Fatal("Failed to advance cursor:", err) + } + + rc, ok := rs.(driver.RowsColumnTypeScanType) + if !ok { + t.Fatal("SQLiteRows does not implement driver.RowsColumnTypeScanType") + } + + for i := 1; i < len(rc.Columns()); i++ { + if st := rc.ColumnTypeScanType(i); st != type_any { + t.Fatal("Unexpected ScanType:", i, st) + } + } +} + +func TestScanTypesAggregate(t *testing.T) { + + d := SQLiteDriver{} + + conn, err := d.Open(":memory:") + if err != nil { + t.Fatal("Failed to begin transaction:", err) + } + defer conn.Close() + + sqlite3conn := conn.(*SQLiteConn) + + _, err = sqlite3conn.Exec("create table foo (id integer not null primary key, price integer)", nil) + if err != nil { + t.Fatal("Failed to create table:", err) + } + + _, err = sqlite3conn.Exec("insert into foo(price) values(0)", nil) + if err != nil { + t.Fatal("Failed to insert:", err) + } + _, err = sqlite3conn.Exec("insert into foo(price) values(5)", nil) + if err != nil { + t.Fatal("Failed to insert:", err) + } + _, err = sqlite3conn.Exec("insert into foo(price) values(10)", nil) + if err != nil { + t.Fatal("Failed to insert:", err) + } + + rs, err := sqlite3conn.Query("select total(price) from foo", nil) + if err != nil { + t.Fatal("Failed to select:", err) + } + defer rs.Close() + + cols := make([]driver.Value, len(rs.Columns())) + err = rs.Next(cols) + if err != nil { + t.Fatal("Failed to advance cursor:", err) + } + + rc, ok := rs.(driver.RowsColumnTypeScanType) + if !ok { + t.Fatal("SQLiteRows does not implement driver.RowsColumnTypeScanType") + } + + if st := rc.ColumnTypeScanType(0); st != type_float { + t.Fatal("Unexpected ScanType:", st) + } +} + func TestPinger(t *testing.T) { db, err := sql.Open("sqlite3", ":memory:") if err != nil { diff --git a/sqlite3_type.go b/sqlite3_type.go index 20537a09..13eb19d1 100644 --- a/sqlite3_type.go +++ b/sqlite3_type.go @@ -15,7 +15,17 @@ import "C" import ( "database/sql" "reflect" - "strings" + "time" +) + +var ( + type_int = reflect.TypeOf(int64(0)) + type_float = reflect.TypeOf(float64(0)) + type_string = reflect.TypeOf("") + type_rawbytes = reflect.TypeOf(sql.RawBytes{}) + type_bool = reflect.TypeOf(true) + type_time = reflect.TypeOf(time.Time{}) + type_any = reflect.TypeOf(new(any)).Elem() ) // ColumnTypeDatabaseTypeName implement RowsColumnTypeDatabaseTypeName. @@ -39,70 +49,42 @@ func (rc *SQLiteRows) ColumnTypeNullable(i int) (nullable, ok bool) { } // ColumnTypeScanType implement RowsColumnTypeScanType. +// In SQLite3, this method should be called after Next() has been called, as sqlite3_column_type() +// returns the column type for a specific row. If Next() has not been called, fallback to +// sqlite3_column_decltype() func (rc *SQLiteRows) ColumnTypeScanType(i int) reflect.Type { - //ct := C.sqlite3_column_type(rc.s.s, C.int(i)) // Always returns 5 - return scanType(C.GoString(C.sqlite3_column_decltype(rc.s.s, C.int(i)))) -} - -const ( - SQLITE_INTEGER = iota - SQLITE_TEXT - SQLITE_BLOB - SQLITE_REAL - SQLITE_NUMERIC - SQLITE_TIME - SQLITE_BOOL - SQLITE_NULL -) - -func scanType(cdt string) reflect.Type { - t := strings.ToUpper(cdt) - i := databaseTypeConvSqlite(t) - switch i { - case SQLITE_INTEGER: - return reflect.TypeOf(sql.NullInt64{}) - case SQLITE_TEXT: - return reflect.TypeOf(sql.NullString{}) - case SQLITE_BLOB: - return reflect.TypeOf(sql.RawBytes{}) - case SQLITE_REAL: - return reflect.TypeOf(sql.NullFloat64{}) - case SQLITE_NUMERIC: - return reflect.TypeOf(sql.NullFloat64{}) - case SQLITE_BOOL: - return reflect.TypeOf(sql.NullBool{}) - case SQLITE_TIME: - return reflect.TypeOf(sql.NullTime{}) - } - return reflect.TypeOf(new(any)) -} + rc.s.mu.Lock() + defer rc.s.mu.Unlock() -func databaseTypeConvSqlite(t string) int { - if strings.Contains(t, "INT") { - return SQLITE_INTEGER - } - if t == "CLOB" || t == "TEXT" || - strings.Contains(t, "CHAR") { - return SQLITE_TEXT - } - if t == "BLOB" { - return SQLITE_BLOB - } - if t == "REAL" || t == "FLOAT" || - strings.Contains(t, "DOUBLE") { - return SQLITE_REAL + if isValidRow := C.sqlite3_stmt_busy(rc.s.s) != 0; !isValidRow { + return type_any } - if t == "DATE" || t == "DATETIME" || - t == "TIMESTAMP" { - return SQLITE_TIME - } - if t == "NUMERIC" || - strings.Contains(t, "DECIMAL") { - return SQLITE_NUMERIC - } - if t == "BOOLEAN" { - return SQLITE_BOOL + if isValidColumn := i >= 0 && i < int(rc.nc); !isValidColumn { + return type_any } - return SQLITE_NULL + switch C.sqlite3_column_type(rc.s.s, C.int(i)) { + case C.SQLITE_INTEGER: + switch rc.decltype[i] { + case columnTimestamp, columnDatetime, columnDate: + return type_time + case columnBoolean: + return type_bool + } + return type_int + case C.SQLITE_FLOAT: + return type_float + case C.SQLITE_TEXT: + switch rc.decltype[i] { + case columnTimestamp, columnDatetime, columnDate: + return type_time + } + return type_string + case C.SQLITE_BLOB: + return type_rawbytes + case C.SQLITE_NULL: + fallthrough + default: + return type_any + } }