diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 47c2005018..be610f2a8a 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -10481,12 +10481,10 @@ where }, Assertions: []ScriptTestAssertion{ { - Skip: true, Query: "insert into tinyint_tbl values (999)", ExpectedErr: sql.ErrValueOutOfRange, }, { - Skip: true, Query: "insert into tinyint_tbl values (127)", Expected: []sql.Row{ {types.OkResult{ @@ -10496,7 +10494,6 @@ where }, }, { - Skip: true, Query: "show create table tinyint_tbl;", Expected: []sql.Row{ {"tinyint_tbl", "CREATE TABLE `tinyint_tbl` (\n" + @@ -10507,12 +10504,10 @@ where }, { - Skip: true, Query: "insert into smallint_tbl values (99999);", ExpectedErr: sql.ErrValueOutOfRange, }, { - Skip: true, Query: "insert into smallint_tbl values (32767);", Expected: []sql.Row{ {types.OkResult{ @@ -10522,23 +10517,20 @@ where }, }, { - Skip: true, Query: "show create table smallint_tbl;", Expected: []sql.Row{ {"smallint_tbl", "CREATE TABLE `smallint_tbl` (\n" + " `i` smallint NOT NULL AUTO_INCREMENT,\n" + " PRIMARY KEY (`i`)\n" + - ") ENGINE=InnoDB AUTO_INCREMENT=36727 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, + ") ENGINE=InnoDB AUTO_INCREMENT=32767 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, }, }, { - Skip: true, Query: "insert into mediumint_tbl values (99999999);", ExpectedErr: sql.ErrValueOutOfRange, }, { - Skip: true, Query: "insert into mediumint_tbl values (8388607);", Expected: []sql.Row{ {types.OkResult{ @@ -10548,7 +10540,6 @@ where }, }, { - Skip: true, Query: "show create table mediumint_tbl;", Expected: []sql.Row{ {"mediumint_tbl", "CREATE TABLE `mediumint_tbl` (\n" + @@ -10559,12 +10550,10 @@ where }, { - Skip: true, Query: "insert into int_tbl values (99999999999)", ExpectedErr: sql.ErrValueOutOfRange, }, { - Skip: true, Query: "insert into int_tbl values (2147483647)", Expected: []sql.Row{ {types.OkResult{ @@ -10574,7 +10563,6 @@ where }, }, { - Skip: true, Query: "show create table int_tbl;", Expected: []sql.Row{ {"int_tbl", "CREATE TABLE `int_tbl` (\n" + @@ -10585,12 +10573,10 @@ where }, { - Skip: true, Query: "insert into bigint_tbl values (99999999999999999999);", ExpectedErr: sql.ErrValueOutOfRange, }, { - Skip: true, Query: "insert into bigint_tbl values (9223372036854775807);", Expected: []sql.Row{ {types.OkResult{ @@ -10600,7 +10586,6 @@ where }, }, { - Skip: true, Query: "show create table bigint_tbl;", Expected: []sql.Row{ {"bigint_tbl", "CREATE TABLE `bigint_tbl` (\n" + @@ -10624,12 +10609,10 @@ where }, Assertions: []ScriptTestAssertion{ { - Skip: true, Query: "insert into tinyint_tbl values (999)", ExpectedErr: sql.ErrValueOutOfRange, }, { - Skip: true, Query: "insert into tinyint_tbl values (255)", Expected: []sql.Row{ {types.OkResult{ @@ -10639,23 +10622,20 @@ where }, }, { - Skip: true, Query: "show create table tinyint_tbl;", Expected: []sql.Row{ {"tinyint_tbl", "CREATE TABLE `tinyint_tbl` (\n" + - " `i` tinyint NOT NULL AUTO_INCREMENT,\n" + + " `i` tinyint unsigned NOT NULL AUTO_INCREMENT,\n" + " PRIMARY KEY (`i`)\n" + ") ENGINE=InnoDB AUTO_INCREMENT=255 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, }, }, { - Skip: true, Query: "insert into smallint_tbl values (99999);", ExpectedErr: sql.ErrValueOutOfRange, }, { - Skip: true, Query: "insert into smallint_tbl values (65535);", Expected: []sql.Row{ {types.OkResult{ @@ -10665,23 +10645,20 @@ where }, }, { - Skip: true, Query: "show create table smallint_tbl;", Expected: []sql.Row{ {"smallint_tbl", "CREATE TABLE `smallint_tbl` (\n" + - " `i` smallint NOT NULL AUTO_INCREMENT,\n" + + " `i` smallint unsigned NOT NULL AUTO_INCREMENT,\n" + " PRIMARY KEY (`i`)\n" + ") ENGINE=InnoDB AUTO_INCREMENT=65535 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, }, }, { - Skip: true, Query: "insert into mediumint_tbl values (999999999);", ExpectedErr: sql.ErrValueOutOfRange, }, { - Skip: true, Query: "insert into mediumint_tbl values (16777215);", Expected: []sql.Row{ {types.OkResult{ @@ -10691,23 +10668,20 @@ where }, }, { - Skip: true, Query: "show create table mediumint_tbl;", Expected: []sql.Row{ {"mediumint_tbl", "CREATE TABLE `mediumint_tbl` (\n" + - " `i` mediumint NOT NULL AUTO_INCREMENT,\n" + + " `i` mediumint unsigned NOT NULL AUTO_INCREMENT,\n" + " PRIMARY KEY (`i`)\n" + ") ENGINE=InnoDB AUTO_INCREMENT=16777215 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, }, }, { - Skip: true, Query: "insert into int_tbl values (99999999999)", ExpectedErr: sql.ErrValueOutOfRange, }, { - Skip: true, Query: "insert into int_tbl values (4294967295)", Expected: []sql.Row{ {types.OkResult{ @@ -10717,23 +10691,20 @@ where }, }, { - Skip: true, Query: "show create table int_tbl;", Expected: []sql.Row{ {"int_tbl", "CREATE TABLE `int_tbl` (\n" + - " `i` int NOT NULL AUTO_INCREMENT,\n" + + " `i` int unsigned NOT NULL AUTO_INCREMENT,\n" + " PRIMARY KEY (`i`)\n" + ") ENGINE=InnoDB AUTO_INCREMENT=4294967295 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, }, }, { - Skip: true, Query: "insert into bigint_tbl values (999999999999999999999);", ExpectedErr: sql.ErrValueOutOfRange, }, { - Skip: true, Query: "insert into bigint_tbl values (18446744073709551615);", Expected: []sql.Row{ {types.OkResult{ @@ -10743,11 +10714,10 @@ where }, }, { - Skip: true, Query: "show create table bigint_tbl;", Expected: []sql.Row{ {"bigint_tbl", "CREATE TABLE `bigint_tbl` (\n" + - " `i` bigint NOT NULL AUTO_INCREMENT,\n" + + " `i` bigint unsigned NOT NULL AUTO_INCREMENT,\n" + " PRIMARY KEY (`i`)\n" + ") ENGINE=InnoDB AUTO_INCREMENT=18446744073709551615 DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}, }, diff --git a/memory/table.go b/memory/table.go index 63bebeaacb..d9ca1652c6 100644 --- a/memory/table.go +++ b/memory/table.go @@ -19,6 +19,7 @@ import ( "encoding/gob" "fmt" "io" + "math" "sort" "strconv" "strings" @@ -1144,9 +1145,32 @@ func (t *Table) Insert(ctx *sql.Context, row sql.Row) error { func (t *Table) PeekNextAutoIncrementValue(ctx *sql.Context) (uint64, error) { data := t.sessionTableData(ctx) + // Find the auto increment column to validate the current value + autoCol := t.getAutoIncrementColumn() + if autoCol == nil { + return data.autoIncVal, nil + } + + // If the current auto increment value is out of range for the column type, + // return the maximum valid value instead + if _, inRange, err := autoCol.Type.Convert(ctx, data.autoIncVal); err == nil && inRange == sql.OutOfRange { + return data.autoIncVal - 1, nil + } + return data.autoIncVal, nil } +// getAutoIncrementColumn returns the auto increment column for this table, or nil if none exists. +// Only one auto increment column is allowed per table. +func (t *Table) getAutoIncrementColumn() *sql.Column { + for _, col := range t.Schema() { + if col.AutoIncrement { + return col + } + } + return nil +} + // GetNextAutoIncrementValue gets the next auto increment value for the memory table the increment. func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{}) (uint64, error) { data := t.sessionTableData(ctx) @@ -1163,7 +1187,6 @@ func (t *Table) GetNextAutoIncrementValue(ctx *sql.Context, insertVal interface{ } data.autoIncVal = v.(uint64) } - return data.autoIncVal, nil } @@ -1257,7 +1280,7 @@ func addColumnToSchema(ctx *sql.Context, data *TableData, newCol *sql.Column, or data.autoIncVal = 0 } - data.autoIncVal++ + updateAutoIncrementSafe(ctx, newCol, &data.autoIncVal) } newPkOrds := data.schema.PkOrdinals @@ -2551,3 +2574,21 @@ func (t *TableRevision) AddColumn(ctx *sql.Context, column *sql.Column, order *s func (t *TableRevision) IgnoreSessionData() bool { return true } + +// updateAutoIncrementSafe safely increments an auto_increment value, handling overflow +// by ensuring it doesn't exceed the column type's maximum value or wrap around. +func updateAutoIncrementSafe(ctx *sql.Context, autoCol *sql.Column, autoIncVal *uint64) { + currentVal := *autoIncVal + + // Check for arithmetic overflow before adding 1 + if currentVal == math.MaxUint64 { + // At maximum uint64 value, can't increment further + return + } + + nextVal := currentVal + 1 + if _, inRange, err := autoCol.Type.Convert(ctx, nextVal); err == nil && inRange == sql.InRange { + *autoIncVal = nextVal + } + // If next value would be out of range for the column type, stay at current value +} diff --git a/memory/table_editor.go b/memory/table_editor.go index 8d5074f005..29ace64555 100644 --- a/memory/table_editor.go +++ b/memory/table_editor.go @@ -188,16 +188,14 @@ func (t *tableEditor) Insert(ctx *sql.Context, row sql.Row) error { return err } if cmp > 0 { - // Provided value larger than autoIncVal, set autoIncVal to that value - v, _, err := types.Uint64.Convert(ctx, row[idx]) + insertedVal, _, err := types.Uint64.Convert(ctx, row[idx]) if err != nil { return err } - t.ea.TableData().autoIncVal = v.(uint64) - t.ea.TableData().autoIncVal++ // Move onto next autoIncVal + t.ea.TableData().autoIncVal = insertedVal.(uint64) + updateAutoIncrementSafe(ctx, autoCol, &t.ea.TableData().autoIncVal) } else if cmp == 0 { - // Provided value equal to autoIncVal - t.ea.TableData().autoIncVal++ // Move onto next autoIncVal + updateAutoIncrementSafe(ctx, autoCol, &t.ea.TableData().autoIncVal) } } diff --git a/sql/expression/auto_increment.go b/sql/expression/auto_increment.go index 2ee93720e3..6ad5ea521c 100644 --- a/sql/expression/auto_increment.go +++ b/sql/expression/auto_increment.go @@ -139,7 +139,10 @@ func (i *AutoIncrement) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) given = seq } - ret, _, err := i.Type().Convert(ctx, given) + ret, inRange, err := i.Type().Convert(ctx, given) + if err == nil && !inRange { + err = sql.ErrValueOutOfRange.New(given, i.Type()) + } if err != nil { return nil, err } diff --git a/sql/types/number.go b/sql/types/number.go index b6152a800d..af0d6a498f 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -1163,7 +1163,7 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan return uint64(math.Round(v)), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_uint64_max) { - return math.MaxUint64, sql.InRange, nil + return math.MaxUint64, sql.OutOfRange, nil } else if v.LessThan(dec_zero) { ret, _ := dec_uint64_max.Sub(v).Float64() return uint64(math.Round(ret)), sql.OutOfRange, nil @@ -1181,6 +1181,9 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan v = strings.Trim(v, intCutSet) if i, err := strconv.ParseUint(v, 10, 64); err == nil { return i, sql.InRange, nil + } else if err == strconv.ErrRange { + // Number is too large for uint64, return max value and OutOfRange + return math.MaxUint64, sql.OutOfRange, nil } if f, err := strconv.ParseFloat(v, 64); err == nil { if val, inRange, err := convertToUint64(t, f); err == nil && inRange { @@ -1238,7 +1241,7 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan } return uint32(v), sql.InRange, nil case uint: - return uint32(v), sql.InRange, nil + return convertUintToUint32(uint64(v)) case uint8: return uint32(v), sql.InRange, nil case uint16: @@ -1246,7 +1249,7 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan case uint32: return v, sql.InRange, nil case uint64: - return uint32(v), sql.InRange, nil + return convertUintToUint32(v) case float64: if float32(v) > float32(math.MaxInt32) { return math.MaxUint32, sql.OutOfRange, nil @@ -1334,13 +1337,13 @@ func convertToUint16(t NumberTypeImpl_, v interface{}) (uint16, sql.ConvertInRan } return uint16(v), sql.InRange, nil case uint: - return uint16(v), sql.InRange, nil + return convertUintToUint16(uint64(v)) case uint8: return uint16(v), sql.InRange, nil case uint64: - return uint16(v), sql.InRange, nil + return convertUintToUint16(v) case uint32: - return uint16(v), sql.InRange, nil + return convertUintToUint16(uint64(v)) case uint16: return v, sql.InRange, nil case float32: @@ -1434,13 +1437,13 @@ func convertToUint8(t NumberTypeImpl_, v interface{}) (uint8, sql.ConvertInRange } return uint8(v), sql.InRange, nil case uint: - return uint8(v), sql.InRange, nil + return convertUintToUint8(uint64(v)) case uint16: - return uint8(v), sql.InRange, nil + return convertUintToUint8(uint64(v)) case uint64: - return uint8(v), sql.InRange, nil + return convertUintToUint8(v) case uint32: - return uint8(v), sql.InRange, nil + return convertUintToUint8(uint64(v)) case uint8: return v, sql.InRange, nil case float32: @@ -1719,3 +1722,30 @@ func CoalesceInt(val interface{}) (int, bool) { return 0, false } } + +// convertUintToUint8 converts a uint64 value to uint8 with overflow checking. +// Returns the converted value, range status, and any error. +func convertUintToUint8(v uint64) (uint8, sql.ConvertInRange, error) { + if v > math.MaxUint8 { + return uint8(math.MaxUint8), sql.OutOfRange, nil + } + return uint8(v), sql.InRange, nil +} + +// convertUintToUint16 converts a uint64 value to uint16 with overflow checking. +// Returns the converted value, range status, and any error. +func convertUintToUint16(v uint64) (uint16, sql.ConvertInRange, error) { + if v > math.MaxUint16 { + return uint16(math.MaxUint16), sql.OutOfRange, nil + } + return uint16(v), sql.InRange, nil +} + +// convertUintToUint32 converts a uint64 value to uint32 with overflow checking. +// Returns the converted value, range status, and any error. +func convertUintToUint32(v uint64) (uint32, sql.ConvertInRange, error) { + if v > math.MaxUint32 { + return uint32(math.MaxUint32), sql.OutOfRange, nil + } + return uint32(v), sql.InRange, nil +}