From d2d2878f8a930aa9a81f83db73a96ea425f5ce94 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 8 Sep 2025 11:32:25 -0700 Subject: [PATCH 01/48] tmp patch --- sql/expression/convert.go | 3 ++- sql/types/conversion.go | 5 +++++ sql/types/number.go | 22 ++++++++++++++++------ 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/sql/expression/convert.go b/sql/expression/convert.go index b15548cd67..30baace3fa 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -370,7 +370,8 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s } return d, nil case ConvertToDouble, ConvertToReal: - value, err := prepareForNumericContext(ctx, val, originType, false) + //value, err := prepareForNumericContext(ctx, val, originType, false) + value, err := types.ConvertOrTruncate(ctx, val, originType) if err != nil { return nil, err } diff --git a/sql/types/conversion.go b/sql/types/conversion.go index 2801cbf2a1..8b406acfac 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -769,6 +769,11 @@ func TypeAwareConversion(ctx *sql.Context, val interface{}, originalType sql.Typ // value is truncated to the Zero value for type |t|. If the value does not convert and the type is not automatically // coerced, then return an error. func ConvertOrTruncate(ctx *sql.Context, i interface{}, t sql.Type) (interface{}, error) { + // Do nothing if type is no provided. + if t == nil { + return i, nil + } + converted, _, err := t.Convert(ctx, i) if err == nil { return converted, nil diff --git a/sql/types/number.go b/sql/types/number.go index cb4dee0978..e0a3ed1a07 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -87,6 +87,15 @@ var ( numre = regexp.MustCompile(`^[ ]*[0-9]*\.?[0-9]+`) ) +const ( + // IntCutSet is the set of characters that should be trimmed from the beginning and end of a string + // when converting to a signed or unsigned integer + IntCutSet = " \t" + + // NumericCutSet is the set of characters to trim from a string before converting it to a number. + NumericCutSet = " \t\n\r" +) + type NumberTypeImpl_ struct { baseType query.Type displayWidth int @@ -982,7 +991,7 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange } return i, sql.InRange, nil case string: - v = strings.Trim(v, sql.IntCutSet) + v = strings.Trim(v, IntCutSet) if v == "" { // StringType{}.Zero() returns empty string, but should represent "0" for number value return 0, sql.InRange, nil @@ -1169,7 +1178,7 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan } return i, sql.InRange, nil case string: - v = strings.Trim(v, sql.IntCutSet) + v = strings.Trim(v, IntCutSet) if i, err := strconv.ParseUint(v, 10, 64); err == nil { return i, sql.InRange, nil } else if err == strconv.ErrRange { @@ -1272,7 +1281,7 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan } return uint32(i), sql.InRange, nil case string: - v = strings.Trim(v, sql.IntCutSet) + v = strings.Trim(v, IntCutSet) if i, err := strconv.ParseUint(v, 10, 32); err == nil { return uint32(i), sql.InRange, nil } @@ -1368,7 +1377,7 @@ func convertToUint16(t NumberTypeImpl_, v interface{}) (uint16, sql.ConvertInRan } return uint16(i), sql.InRange, nil case string: - v = strings.Trim(v, sql.IntCutSet) + v = strings.Trim(v, IntCutSet) if i, err := strconv.ParseUint(v, 10, 16); err == nil { return uint16(i), sql.InRange, nil } @@ -1468,7 +1477,7 @@ func convertToUint8(t NumberTypeImpl_, v interface{}) (uint8, sql.ConvertInRange } return uint8(i), sql.InRange, nil case string: - v = strings.Trim(v, sql.IntCutSet) + v = strings.Trim(v, IntCutSet) if i, err := strconv.ParseUint(v, 10, 8); err == nil { return uint8(i), sql.InRange, nil } @@ -1528,7 +1537,8 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) { } return float64(i), nil case string: - v = strings.Trim(v, sql.NumericCutSet) + // TODO: just trimStringToPrefix here + v = strings.Trim(v, NumericCutSet) i, err := strconv.ParseFloat(v, 64) if err != nil { // parse the first longest valid numbers From 2b9554284b94692c199c976ba7fbe4b52e2d8d0b Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 9 Sep 2025 10:14:58 -0700 Subject: [PATCH 02/48] attempt at consolidating logic --- enginetest/memory_engine_test.go | 10 +---- sql/expression/convert.go | 62 +++----------------------- sql/hash/hash.go | 11 ++++- sql/types/conversion.go | 49 ++++++++++++++++++++- sql/types/decimal.go | 35 ++++++++------- sql/types/number.go | 75 +++++++++++++++++--------------- 6 files changed, 127 insertions(+), 115 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index f1ec7b45d0..fdebe8924e 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -200,20 +200,14 @@ func TestSingleQueryPrepared(t *testing.T) { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { - t.Skip() + //t.Skip() var scripts = []queries.ScriptTest{ { Name: "AS OF propagates to nested CALLs", SetUpScript: []string{}, Assertions: []queries.ScriptTestAssertion{ { - Query: "create procedure create_proc() create table t (i int primary key, j int);", - Expected: []sql.Row{ - {types.NewOkResult(0)}, - }, - }, - { - Query: "call create_proc()", + Query: "select x'20' = 32;", Expected: []sql.Row{ {types.NewOkResult(0)}, }, diff --git a/sql/expression/convert.go b/sql/expression/convert.go index 30baace3fa..48472affd4 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -15,9 +15,7 @@ package expression import ( - "encoding/hex" "fmt" - "strconv" "strings" "time" @@ -349,33 +347,20 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s } return d, nil case ConvertToDecimal: - value, err := prepareForNumericContext(ctx, val, originType, false) - if err != nil { - return nil, err - } dt := createConvertedDecimalType(typeLength, typeScale, false) - d, _, err := dt.Convert(ctx, value) + d, _, err := dt.Convert(ctx, val) if err != nil { return dt.Zero(), nil } return d, nil case ConvertToFloat: - value, err := prepareForNumericContext(ctx, val, originType, false) - if err != nil { - return nil, err - } - d, _, err := types.Float32.Convert(ctx, value) + d, _, err := types.Float32.Convert(ctx, val) if err != nil { return types.Float32.Zero(), nil } return d, nil case ConvertToDouble, ConvertToReal: - //value, err := prepareForNumericContext(ctx, val, originType, false) - value, err := types.ConvertOrTruncate(ctx, val, originType) - if err != nil { - return nil, err - } - d, _, err := types.Float64.Convert(ctx, value) + d, _, err := types.Float64.Convert(ctx, val) if err != nil { if sql.ErrTruncatedIncorrect.Is(err) { ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) @@ -391,15 +376,10 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s } return js, nil case ConvertToSigned: - value, err := prepareForNumericContext(ctx, val, originType, true) - if err != nil { - return nil, err - } - num, _, err := types.Int64.Convert(ctx, value) + num, _, err := types.Int64.Convert(ctx, val) if err != nil { return types.Int64.Zero(), nil } - return num, nil case ConvertToTime: t, _, err := types.Time.Convert(ctx, val) @@ -408,13 +388,9 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s } return t, nil case ConvertToUnsigned: - value, err := prepareForNumericContext(ctx, val, originType, true) + num, _, err := types.Uint64.Convert(ctx, val) if err != nil { - return nil, err - } - num, _, err := types.Uint64.Convert(ctx, value) - if err != nil { - num, _, err = types.Int64.Convert(ctx, value) + num, _, err = types.Int64.Convert(ctx, val) if err != nil { return types.Uint64.Zero(), nil } @@ -422,7 +398,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s } return num, nil case ConvertToYear: - value, err := convertHexBlobToDecimalForNumericContext(val, originType) + value, err := types.ConvertHexBlobToUint(val, originType) if err != nil { return nil, err } @@ -484,27 +460,3 @@ func createConvertedDecimalType(length, scale int, logErrors bool) sql.DecimalTy } return types.InternalDecimalType } - -// prepareForNumberContext makes necessary preparations to strings and byte arrays for conversions to numbers -func prepareForNumericContext(ctx *sql.Context, val interface{}, originType sql.Type, isInt bool) (interface{}, error) { - if s, isString := val.(string); isString && types.IsTextOnly(originType) { - return sql.TrimStringToNumberPrefix(ctx, s, isInt), nil - } - return convertHexBlobToDecimalForNumericContext(val, originType) -} - -// convertHexBlobToDecimalForNumericContext converts byte array value to unsigned int value if originType is BLOB type. -// This function is called when convertTo type is number type only. The hex literal values are parsed into blobs as -// binary string as default, but for numeric context, the value should be a number. -// Byte arrays of other SQL types are not handled here. -func convertHexBlobToDecimalForNumericContext(val interface{}, originType sql.Type) (interface{}, error) { - if bin, isBinary := val.([]byte); isBinary && types.IsBlobType(originType) { - stringVal := hex.EncodeToString(bin) - decimalNum, err := strconv.ParseUint(stringVal, 16, 64) - if err != nil { - return nil, errors.NewKind("failed to convert hex blob value to unsigned int").New() - } - val = decimalNum - } - return val, nil -} diff --git a/sql/hash/hash.go b/sql/hash/hash.go index 62d5ed2c85..2a46f0265a 100644 --- a/sql/hash/hash.go +++ b/sql/hash/hash.go @@ -124,7 +124,7 @@ func HashOfSimple(ctx *sql.Context, i interface{}, t sql.Type) (uint64, error) { if s, ok := i.(string); ok { str = s } else { - converted, err := types.ConvertOrTruncate(ctx, i, t) + converted, _, err := t.Convert(ctx, i) if err != nil { return 0, err } @@ -133,8 +133,15 @@ func HashOfSimple(ctx *sql.Context, i interface{}, t sql.Type) (uint64, error) { return 0, err } } + } else if types.IsEnum(t) || types.IsSet(t) { + converted, _, err := t.Convert(ctx, i) + if err != nil { + str = fmt.Sprintf("%v", nil) + } else { + str = fmt.Sprintf("%v", converted) + } } else { - x, err := types.ConvertOrTruncate(ctx, i, t.Promote()) + x, _, err := t.Promote().Convert(ctx, i) if err != nil { return 0, err } diff --git a/sql/types/conversion.go b/sql/types/conversion.go index 8b406acfac..dc5e8a2d11 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -15,11 +15,13 @@ package types import ( + "encoding/hex" "fmt" "reflect" "strconv" "strings" "time" + "unicode" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" @@ -769,7 +771,7 @@ func TypeAwareConversion(ctx *sql.Context, val interface{}, originalType sql.Typ // value is truncated to the Zero value for type |t|. If the value does not convert and the type is not automatically // coerced, then return an error. func ConvertOrTruncate(ctx *sql.Context, i interface{}, t sql.Type) (interface{}, error) { - // Do nothing if type is no provided. + // Do nothing if type is not provided. if t == nil { return i, nil } @@ -805,3 +807,48 @@ func ConvertOrTruncate(ctx *sql.Context, i interface{}, t sql.Type) (interface{} return t.Zero(), nil } + +// ConvertHexBlobToUint converts byte array value to unsigned int value if originType is BLOB type. +// This function is called when convertTo type is number type only. The hex literal values are parsed into blobs as +// binary string as default, but for numeric context, the value should be a number. +// Byte arrays of other SQL types are not handled here. +func ConvertHexBlobToUint(val interface{}, originType sql.Type) (interface{}, error) { + var err error + if bin, isBinary := val.([]byte); isBinary && IsBlobType(originType) { + stringVal := hex.EncodeToString(bin) + val, err = strconv.ParseUint(stringVal, 16, 64) + if err != nil { + return nil, errors.NewKind("failed to convert hex blob value to unsigned int").New() + } + } + return val, nil +} + +// TruncateStringToNumber truncates a string to the appropriate number prefix +func TruncateStringToNumber(s string, isInt bool) string { + if isInt { + s = strings.TrimLeft(s, IntCutSet) + } else { + s = strings.TrimLeft(s, NumericCutSet) + } + + seenDigit := false + seenDot := false + seenExp := false + signIndex := 0 + + for i := 0; i < len(s); i++ { + char := rune(s[i]) + if unicode.IsDigit(char) { + seenDigit = true + } else if char == '.' && !seenDot && !isInt { + seenDot = true + } else if (char == 'e' || char == 'E') && !seenExp && seenDigit && !isInt { + seenExp = true + signIndex = i + 1 + } else if !((char == '-' || char == '+') && i == signIndex) { + return s[:i] + } + } + return s +} diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 48fa0288bc..f23186ca66 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -19,14 +19,13 @@ import ( "fmt" "math/big" "reflect" - "strings" - "github.com/dolthub/vitess/go/sqltypes" - "github.com/dolthub/vitess/go/vt/proto/query" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" ) const ( @@ -206,20 +205,26 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, if len(value) == 0 { return t.ConvertToNullDecimal(decimal.NewFromInt(0)) } - var err error - res, err = decimal.NewFromString(value) - if err != nil { - // The decimal library cannot handle all of the different formats - bf, _, err := new(big.Float).SetPrec(217).Parse(value, 0) - if err != nil { - return decimal.NullDecimal{}, err - } - res, err = decimal.NewFromString(bf.Text('f', -1)) - if err != nil { - return decimal.NullDecimal{}, err + if dec, err := decimal.NewFromString(value); err == nil { + return t.ConvertToNullDecimal(dec) + } + // The decimal library cannot handle all the different formats + if bf, _, err := new(big.Float).SetPrec(217).Parse(value, 0); err == nil { + if res, err = decimal.NewFromString(bf.Text('f', -1)); err == nil { + return t.ConvertToNullDecimal(res) } } - return t.ConvertToNullDecimal(res) + + // TODO: how do we know that it is a hex number and not string? + value = TruncateStringToNumber(value, false) + if len(value) == 0 { + return t.ConvertToNullDecimal(decimal.NewFromInt(0)) + } + dec, err := decimal.NewFromString(value) + if err != nil { + return decimal.NullDecimal{}, err + } + return t.ConvertToNullDecimal(dec) case *big.Float: return t.ConvertToNullDecimal(value.Text('f', -1)) case *big.Int: diff --git a/sql/types/number.go b/sql/types/number.go index e0a3ed1a07..5d383864b7 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -22,7 +22,6 @@ import ( "reflect" "regexp" "strconv" - "strings" "time" "github.com/dolthub/vitess/go/sqltypes" @@ -328,7 +327,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ return convertToUint64(t, v) case sqltypes.Float32: num, err := convertToFloat64(t, v) - if err != nil { + if err != nil && !sql.ErrTruncatedType.Is(err) { return nil, sql.OutOfRange, err } if num > math.MaxFloat32 { @@ -336,7 +335,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } else if num < -math.MaxFloat32 { return float32(-math.MaxFloat32), sql.OutOfRange, nil } - return float32(num), sql.InRange, nil + return float32(num), sql.InRange, nil // TODO: pass up error for warning? case sqltypes.Float64: ret, err := convertToFloat64(t, v) return ret, sql.InRange, err @@ -991,23 +990,29 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange } return i, sql.InRange, nil case string: - v = strings.Trim(v, IntCutSet) - if v == "" { + if len(v) == 0 { // StringType{}.Zero() returns empty string, but should represent "0" for number value return 0, sql.InRange, nil } // Parse first an integer, which allows for more values than float64 - i, err := strconv.ParseInt(v, 10, 64) - if err == nil { + if i, err := strconv.ParseInt(v, 10, 64); err == nil { return i, sql.InRange, nil } - // If that fails, try as a float and truncate it to integral - f, err := strconv.ParseFloat(v, 64) - if err != nil { - return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) + // If that fails, try as a float and round to integral + if f, err := strconv.ParseFloat(v, 64); err == nil { + f = math.Round(f) + return int64(f), sql.InRange, nil } - f = math.Round(f) - return int64(f), sql.InRange, nil + // If that fails, truncate the string and parse as int + // TODO: throw error / warning? + v = TruncateStringToNumber(v, true) + if len(v) == 0 { + return 0, sql.InRange, nil + } + if i, err := strconv.ParseInt(v, 10, 64); err == nil { + return i, sql.InRange, nil + } + return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) case bool: if v { return 1, sql.InRange, nil @@ -1150,21 +1155,24 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan case float32: if v > float32(math.MaxInt64) { return math.MaxUint64, sql.OutOfRange, nil - } else if v < 0 { + } + if v < 0 { return uint64(math.MaxUint64 - v), sql.OutOfRange, nil } return uint64(math.Round(float64(v))), sql.InRange, nil case float64: if v >= float64(math.MaxUint64) { return math.MaxUint64, sql.OutOfRange, nil - } else if v <= 0 { + } + if v <= 0 { return uint64(math.MaxUint64 - v), sql.OutOfRange, nil } return uint64(math.Round(v)), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_uint64_max) { return math.MaxUint64, sql.OutOfRange, nil - } else if v.LessThan(dec_zero) { + } + if v.LessThan(dec_zero) { ret, _ := dec_uint64_max.Sub(v).Float64() return uint64(math.Round(ret)), sql.OutOfRange, nil } @@ -1178,17 +1186,15 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan } return i, sql.InRange, nil case string: - v = strings.Trim(v, IntCutSet) if i, err := strconv.ParseUint(v, 10, 64); err == nil { return i, sql.InRange, nil } else if err == strconv.ErrRange { // Number is too large for uint64, return max value and OutOfRange return math.MaxUint64, sql.OutOfRange, nil } + // If that fails, try as a float and round to integral if f, err := strconv.ParseFloat(v, 64); err == nil { - if val, inRange, err := convertToUint64(t, f); err == nil && inRange { - return val, inRange, err - } + return convertToUint64(t, f) } return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) case bool: @@ -1281,14 +1287,15 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan } return uint32(i), sql.InRange, nil case string: - v = strings.Trim(v, IntCutSet) if i, err := strconv.ParseUint(v, 10, 32); err == nil { return uint32(i), sql.InRange, nil } if f, err := strconv.ParseFloat(v, 64); err == nil { - if val, inRange, err := convertToUint32(t, f); err == nil && inRange { - return val, inRange, err - } + return convertToUint32(t, f) + } + v = TruncateStringToNumber(v, true) + if i, err := strconv.ParseUint(v, 10, 32); err == nil { + return uint32(i), sql.InRange, nil } return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) case bool: @@ -1377,14 +1384,15 @@ func convertToUint16(t NumberTypeImpl_, v interface{}) (uint16, sql.ConvertInRan } return uint16(i), sql.InRange, nil case string: - v = strings.Trim(v, IntCutSet) if i, err := strconv.ParseUint(v, 10, 16); err == nil { return uint16(i), sql.InRange, nil } if f, err := strconv.ParseFloat(v, 64); err == nil { - if val, inRange, err := convertToUint16(t, f); err == nil && inRange { - return val, inRange, err - } + return convertToUint16(t, f) + } + v = TruncateStringToNumber(v, true) + if i, err := strconv.ParseUint(v, 10, 16); err == nil { + return uint16(i), sql.InRange, nil } return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) case bool: @@ -1477,14 +1485,15 @@ func convertToUint8(t NumberTypeImpl_, v interface{}) (uint8, sql.ConvertInRange } return uint8(i), sql.InRange, nil case string: - v = strings.Trim(v, IntCutSet) if i, err := strconv.ParseUint(v, 10, 8); err == nil { return uint8(i), sql.InRange, nil } if f, err := strconv.ParseFloat(v, 64); err == nil { - if val, inRange, err := convertToUint8(t, f); err == nil && inRange { - return val, inRange, err - } + return convertToUint8(t, f) + } + v = TruncateStringToNumber(v, true) + if i, err := strconv.ParseUint(v, 10, 8); err == nil { + return uint8(i), sql.InRange, nil } return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) case bool: @@ -1537,8 +1546,6 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) { } return float64(i), nil case string: - // TODO: just trimStringToPrefix here - v = strings.Trim(v, NumericCutSet) i, err := strconv.ParseFloat(v, 64) if err != nil { // parse the first longest valid numbers From 389b1c2faa38244ee27d3714bccf3714b285191d Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 9 Sep 2025 15:37:22 -0700 Subject: [PATCH 03/48] asdf --- sql/types/number.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/types/number.go b/sql/types/number.go index 5d383864b7..3d54b5cd7c 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -327,7 +327,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ return convertToUint64(t, v) case sqltypes.Float32: num, err := convertToFloat64(t, v) - if err != nil && !sql.ErrTruncatedType.Is(err) { + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return nil, sql.OutOfRange, err } if num > math.MaxFloat32 { From 04aea421b166a00edd2cb33f64810f1058ba4851 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 9 Sep 2025 16:27:31 -0700 Subject: [PATCH 04/48] aaaaa --- sql/types/number.go | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/types/number.go b/sql/types/number.go index 3d54b5cd7c..082c803243 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -990,6 +990,7 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange } return i, sql.InRange, nil case string: + if len(v) == 0 { // StringType{}.Zero() returns empty string, but should represent "0" for number value return 0, sql.InRange, nil From 23fe5883e8550e78999e1af0314be69098aa9e16 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 10 Sep 2025 10:11:52 -0700 Subject: [PATCH 05/48] consolidate logic for truncation --- sql/hash/hash.go | 3 ++- sql/types/conversion.go | 9 ++------- sql/types/decimal.go | 6 +++--- sql/types/number.go | 28 +++++++++++++++------------- 4 files changed, 22 insertions(+), 24 deletions(-) diff --git a/sql/hash/hash.go b/sql/hash/hash.go index 2a46f0265a..ab21b08db1 100644 --- a/sql/hash/hash.go +++ b/sql/hash/hash.go @@ -142,7 +142,8 @@ func HashOfSimple(ctx *sql.Context, i interface{}, t sql.Type) (uint64, error) { } } else { x, _, err := t.Promote().Convert(ctx, i) - if err != nil { + // TODO: throw warning? + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return 0, err } diff --git a/sql/types/conversion.go b/sql/types/conversion.go index dc5e8a2d11..70235d41d3 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -824,14 +824,9 @@ func ConvertHexBlobToUint(val interface{}, originType sql.Type) (interface{}, er return val, nil } -// TruncateStringToNumber truncates a string to the appropriate number prefix +// TruncateStringToNumber truncates a string to the appropriate number prefix. +// This function expects whitespace to already be properly trimmed. func TruncateStringToNumber(s string, isInt bool) string { - if isInt { - s = strings.TrimLeft(s, IntCutSet) - } else { - s = strings.TrimLeft(s, NumericCutSet) - } - seenDigit := false seenDot := false seenExp := false diff --git a/sql/types/decimal.go b/sql/types/decimal.go index f23186ca66..228d909de3 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -19,6 +19,7 @@ import ( "fmt" "math/big" "reflect" + "strings" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" @@ -200,8 +201,7 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, case float64: return t.ConvertToNullDecimal(decimal.NewFromFloat(value)) case string: - // TODO: implement truncation here - value = strings.Trim(value, sql.NumericCutSet) + value = strings.Trim(value, NumericCutSet) if len(value) == 0 { return t.ConvertToNullDecimal(decimal.NewFromInt(0)) } @@ -215,7 +215,7 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, } } - // TODO: how do we know that it is a hex number and not string? + // TODO: hex strings should not make it this far as numbers value = TruncateStringToNumber(value, false) if len(value) == 0 { return t.ConvertToNullDecimal(decimal.NewFromInt(0)) diff --git a/sql/types/number.go b/sql/types/number.go index 082c803243..d73d3abf86 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -22,6 +22,7 @@ import ( "reflect" "regexp" "strconv" + "strings" "time" "github.com/dolthub/vitess/go/sqltypes" @@ -990,7 +991,7 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange } return i, sql.InRange, nil case string: - + v = strings.Trim(v, IntCutSet) if len(v) == 0 { // StringType{}.Zero() returns empty string, but should represent "0" for number value return 0, sql.InRange, nil @@ -1001,18 +1002,18 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange } // If that fails, try as a float and round to integral if f, err := strconv.ParseFloat(v, 64); err == nil { - f = math.Round(f) + f = math.Round(f) // TODO: inserting rounds up, while casting truncates return int64(f), sql.InRange, nil } // If that fails, truncate the string and parse as int - // TODO: throw error / warning? v = TruncateStringToNumber(v, true) if len(v) == 0 { - return 0, sql.InRange, nil + return 0, sql.InRange, sql.ErrTruncatedIncorrect.New(t.String(), v) } if i, err := strconv.ParseInt(v, 10, 64); err == nil { - return i, sql.InRange, nil + return i, sql.InRange, sql.ErrTruncatedIncorrect.New(t.String(), v) } + // TODO: what should this error be? return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) case bool: if v { @@ -1547,14 +1548,15 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) { } return float64(i), nil case string: - i, err := strconv.ParseFloat(v, 64) - if err != nil { - // parse the first longest valid numbers - s := numre.FindString(v) - i, _ = strconv.ParseFloat(s, 64) - return i, sql.ErrTruncatedIncorrect.New(t.String(), v) - } - return i, nil + v = strings.Trim(v, NumericCutSet) + if i, err := strconv.ParseFloat(v, 64); err == nil { + return i, nil + } + // TODO: what's the difference between this and TruncateStringToNumber? + // parse the first longest valid numbers + s := numre.FindString(v) + i, _ := strconv.ParseFloat(s, 64) + return i, sql.ErrTruncatedIncorrect.New(t.String(), v) case bool: if v { return 1, nil From 24ea63f7620776208fb7be728fd4bf8e00ab3755 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 10 Sep 2025 13:38:09 -0700 Subject: [PATCH 06/48] some progress --- enginetest/memory_engine_test.go | 182 ++++++++++++++++++++++++++- enginetest/queries/script_queries.go | 8 +- sql/expression/convert.go | 36 +++--- sql/types/conversion.go | 1 + sql/types/number.go | 13 +- 5 files changed, 212 insertions(+), 28 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index fdebe8924e..a671d24cda 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -203,17 +203,193 @@ func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "AS OF propagates to nested CALLs", + Skip: true, + Name: "asdf", SetUpScript: []string{}, Assertions: []queries.ScriptTestAssertion{ { - Query: "select x'20' = 32;", + Query: "select cast('-3.1a' as signed);", Expected: []sql.Row{ - {types.NewOkResult(0)}, + {-3}, }, }, }, }, + { + // https://github.com/dolthub/dolt/issues/9733 + // https://github.com/dolthub/dolt/issues/9739 + //Skip: true, + Name: "strings cast to numbers", + SetUpScript: []string{ + "create table test01(pk varchar(20) primary key);", + `insert into test01 values + (' 3 12 4'), + (' 3.2 12 4'), + ('-3.1234'), + ('-3.1a'), + ('-5+8'), + ('+3.1234'), + ('11d'), + ('11wha?'), + ('11'), + ('12'), + ('1a1'), + ('a1a1'), + ('11-5'), + ('3. 12 4'), + ('5.932887e+07'), + ('5.932887e+07abc'), + ('5.932887e7'), + ('5.932887e7abc');`, + }, + Assertions: []queries.ScriptTestAssertion{ + { + Dialect: "mysql", + Query: "select pk, cast(pk as signed) from test01", + Expected: []sql.Row{ + {" 3 12 4", 3}, + {" 3.2 12 4", 3}, + {"-3.1234", -3}, + {"-3.1a", -3}, + {"-5+8", -5}, + {"+3.1234", 3}, + {"11", 11}, + {"11-5", 11}, + {"11d", 11}, + {"11wha?", 11}, + {"12", 12}, + {"1a1", 1}, + {"3. 12 4", 3}, + {"5.932887e+07", 5}, + {"5.932887e+07abc", 5}, + {"5.932887e7", 5}, + {"5.932887e7abc", 5}, + {"a1a1", 0}, + }, + }, + { + Dialect: "mysql", + Query: "select pk, cast(pk as unsigned) from test01", + Expected: []sql.Row{ + {" 3 12 4", uint64(3)}, + {" 3.2 12 4", uint64(3)}, + {"-3.1234", uint64(18446744073709551613)}, + {"-3.1a", uint64(18446744073709551613)}, + {"-5+8", uint64(18446744073709551611)}, + {"+3.1234", uint64(3)}, + {"11", uint64(11)}, + {"11-5", uint64(11)}, + {"11d", uint64(11)}, + {"11wha?", uint64(11)}, + {"12", uint64(12)}, + {"1a1", uint64(1)}, + {"3. 12 4", uint64(3)}, + {"5.932887e+07", uint64(5)}, + {"5.932887e+07abc", uint64(5)}, + {"5.932887e7", uint64(5)}, + {"5.932887e7abc", uint64(5)}, + {"a1a1", uint64(0)}, + }, + }, + { + Dialect: "mysql", + Query: "select pk, cast(pk as decimal(12,3)) from test01", + Expected: []sql.Row{ + {" 3 12 4", "3.000"}, + {" 3.2 12 4", "3.200"}, + {"-3.1234", "-3.123"}, + {"-3.1a", "-3.100"}, + {"-5+8", "-5.000"}, + {"+3.1234", "3.123"}, + {"11", "11.000"}, + {"11-5", "11.000"}, + {"11d", "11.000"}, + {"11wha?", "11.000"}, + {"12", "12.000"}, + {"1a1", "1.000"}, + {"3. 12 4", "3.000"}, + {"5.932887e+07", "59328870.000"}, + {"5.932887e+07abc", "59328870.000"}, + {"5.932887e7", "59328870.000"}, + {"5.932887e7abc", "59328870.000"}, + {"a1a1", "0.000"}, + }, + }, + { + Query: "select * from test01 where pk in ('11')", + Expected: []sql.Row{{"11"}}, + }, + { + // https://github.com/dolthub/dolt/issues/9739 + Skip: true, + Dialect: "mysql", + Query: "select * from test01 where pk in (11)", + Expected: []sql.Row{ + {"11"}, + {"11-5"}, + {"11d"}, + {"11wha?"}, + }, + }, + { + // https://github.com/dolthub/dolt/issues/9739 + Skip: true, + Dialect: "mysql", + Query: "select * from test01 where pk=3", + Expected: []sql.Row{ + {" 3 12 4"}, + {" 3. 12 4"}, + {"3. 12 4"}, + }, + }, + { + // https://github.com/dolthub/dolt/issues/9739 + Skip: true, + Dialect: "mysql", + Query: "select * from test01 where pk>=3 and pk < 4", + Expected: []sql.Row{ + {" 3 12 4"}, + {" 3. 12 4"}, + {" 3.2 12 4"}, + {"+3.1234"}, + {"3. 12 4"}, + }, + }, + //{ + // // https://github.com/dolthub/dolt/issues/9739 + // Skip: true, + // Dialect: "mysql", + // Query: "select * from test02 where pk in ('11asdf')", + // Expected: []sql.Row{{"11"}}, + //}, + //{ + // // https://github.com/dolthub/dolt/issues/9739 + // Skip: true, + // Dialect: "mysql", + // Query: "select * from test02 where pk='11.12asdf'", + // Expected: []sql.Row{}, + //}, + }, + }, + //{ + // Name: "AS OF propagates to nested CALLs", + // SetUpScript: []string{}, + // Assertions: []queries.ScriptTestAssertion{ + // { + // Query: "select cast('123.99' as signed);", + // Expected: []sql.Row{ + // {123}, + // }, + // }, + // // TODO: some how fix this + // { + // Query: "select x'20' = 32;", + // Expected: []sql.Row{ + // {types.NewOkResult(0)}, + // }, + // }, + // }, + //}, } for _, test := range scripts { diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 2804c7fa96..a70dd2502b 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -11778,13 +11778,13 @@ select * from t1 except ( // https://github.com/dolthub/dolt/issues/9739 Name: "strings cast to numbers", SetUpScript: []string{ - "create table test01(pk varchar(20) primary key)", + "create table test01(pk varchar(20) primary key);", `insert into test01 values (' 3 12 4'), (' 3.2 12 4'),('-3.1234'),('-3.1a'),('-5+8'),('+3.1234'), ('11d'),('11wha?'),('11'),('12'),('1a1'),('a1a1'),('11-5'), - ('3. 12 4'),('5.932887e+07'),('5.932887e+07abc'),('5.932887e7'),('5.932887e7abc')`, - "create table test02(pk int primary key)", - "insert into test02 values(11),(12),(13),(14),(15)", + ('3. 12 4'),('5.932887e+07'),('5.932887e+07abc'),('5.932887e7'),('5.932887e7abc');`, + "create table test02(pk int primary key);", + "insert into test02 values(11),(12),(13),(14),(15);", }, Assertions: []ScriptTestAssertion{ { diff --git a/sql/expression/convert.go b/sql/expression/convert.go index 48472affd4..04a4147c73 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -361,14 +361,14 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s return d, nil case ConvertToDouble, ConvertToReal: d, _, err := types.Float64.Convert(ctx, val) - if err != nil { - if sql.ErrTruncatedIncorrect.Is(err) { - ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) - return d, nil - } - return types.Float64.Zero(), nil + if err == nil { + return d, nil } - return d, nil + if sql.ErrTruncatedIncorrect.Is(err) { + ctx.Warn(1265, "%s", err.Error()) + return d, nil + } + return types.Float64.Zero(), nil case ConvertToJSON: js, _, err := types.JSON.Convert(ctx, val) if err != nil { @@ -377,10 +377,14 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s return js, nil case ConvertToSigned: num, _, err := types.Int64.Convert(ctx, val) - if err != nil { - return types.Int64.Zero(), nil + if err == nil { + return num, nil } - return num, nil + if sql.ErrTruncatedIncorrect.Is(err) { + ctx.Warn(1265, "%s", err.Error()) + return num, nil + } + return types.Int64.Zero(), nil case ConvertToTime: t, _, err := types.Time.Convert(ctx, val) if err != nil { @@ -389,14 +393,14 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s return t, nil case ConvertToUnsigned: num, _, err := types.Uint64.Convert(ctx, val) + if err == nil { + return num, nil + } + num, _, err = types.Int64.Convert(ctx, val) if err != nil { - num, _, err = types.Int64.Convert(ctx, val) - if err != nil { - return types.Uint64.Zero(), nil - } - return uint64(num.(int64)), nil + return types.Uint64.Zero(), nil } - return num, nil + return uint64(num.(int64)), nil case ConvertToYear: value, err := types.ConvertHexBlobToUint(val, originType) if err != nil { diff --git a/sql/types/conversion.go b/sql/types/conversion.go index 70235d41d3..59a3091867 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -826,6 +826,7 @@ func ConvertHexBlobToUint(val interface{}, originType sql.Type) (interface{}, er // TruncateStringToNumber truncates a string to the appropriate number prefix. // This function expects whitespace to already be properly trimmed. +// TODO: separate logic for ints and floating point? func TruncateStringToNumber(s string, isInt bool) string { seenDigit := false seenDot := false diff --git a/sql/types/number.go b/sql/types/number.go index d73d3abf86..e9506cd167 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -996,15 +996,17 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange // StringType{}.Zero() returns empty string, but should represent "0" for number value return 0, sql.InRange, nil } + // TODO: always attempt to truncate and only error if truncation did something // Parse first an integer, which allows for more values than float64 if i, err := strconv.ParseInt(v, 10, 64); err == nil { return i, sql.InRange, nil } + // TODO: convert for insert should just be an entirely new function // If that fails, try as a float and round to integral - if f, err := strconv.ParseFloat(v, 64); err == nil { - f = math.Round(f) // TODO: inserting rounds up, while casting truncates - return int64(f), sql.InRange, nil - } + //if f, err := strconv.ParseFloat(v, 64); err == nil { + // f = math.Round(f) // TODO: inserting rounds up, while casting truncates + // return int64(f), sql.InRange, nil + //} // If that fails, truncate the string and parse as int v = TruncateStringToNumber(v, true) if len(v) == 0 { @@ -1553,8 +1555,9 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) { return i, nil } // TODO: what's the difference between this and TruncateStringToNumber? + // the difference is that this doesn't work lmao // parse the first longest valid numbers - s := numre.FindString(v) + s := TruncateStringToNumber(v, false) i, _ := strconv.ParseFloat(s, 64) return i, sql.ErrTruncatedIncorrect.New(t.String(), v) case bool: From d8a7a1047fd5863bdea966b23bb8940a52985493 Mon Sep 17 00:00:00 2001 From: jycor Date: Wed, 10 Sep 2025 17:14:47 +0000 Subject: [PATCH 07/48] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/types/decimal.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 228d909de3..15c8b5cb5e 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -21,6 +21,9 @@ import ( "reflect" "strings" + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/shopspring/decimal" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" From 6e8c62056a6466040873cca4a5e3559965ea2521 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 11 Sep 2025 16:41:48 -0700 Subject: [PATCH 08/48] refactoring and fixing char function --- sql/expression/convert.go | 4 + sql/expression/function/char.go | 17 +- sql/expression/function/if.go | 6 +- sql/rowexec/insert.go | 2 + sql/type.go | 7 + sql/types/conversion.go | 28 +- sql/types/decimal.go | 56 ++- sql/types/number.go | 692 ++++++++++++++++---------------- sql/types/strings.go | 4 +- 9 files changed, 419 insertions(+), 397 deletions(-) diff --git a/sql/expression/convert.go b/sql/expression/convert.go index 04a4147c73..a70ef3bd5f 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -396,6 +396,10 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s if err == nil { return num, nil } + if sql.ErrTruncatedIncorrect.Is(err) { + ctx.Warn(1265, "%s", err.Error()) + return num, nil + } num, _, err = types.Int64.Convert(ctx, val) if err != nil { return types.Uint64.Zero(), nil diff --git a/sql/expression/function/char.go b/sql/expression/function/char.go index 02c8d4a706..cbfeeb1626 100644 --- a/sql/expression/function/char.go +++ b/sql/expression/function/char.go @@ -89,9 +89,14 @@ func (c *Char) CollationCoercibility(ctx *sql.Context) (collation sql.CollationI // This function is essentially converting the number to base 256 func char(num uint32) []byte { if num == 0 { - return []byte{} + return []byte{0} } - return append(char(num>>8), byte(num&255)) + res := byte(num & 255) + nextNum := num >> 8 + if nextNum == 0 { + return []byte{res} + } + return append(char(num>>8), res) } // Eval implements the sql.Expression interface @@ -113,11 +118,11 @@ func (c *Char) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { v, _, err := types.Uint32.Convert(ctx, val) if err != nil { - ctx.Warn(1292, "Truncated incorrect INTEGER value: '%v'", val) - res = append(res, 0) - continue + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(1292, "%s", err.Error()) } - res = append(res, char(v.(uint32))...) } diff --git a/sql/expression/function/if.go b/sql/expression/function/if.go index c019357f39..39a587faf4 100644 --- a/sql/expression/function/if.go +++ b/sql/expression/function/if.go @@ -71,12 +71,12 @@ func (f *If) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if e == nil { asBool = false } else { - asBool, err = sql.ConvertToBool(ctx, e) - if err != nil { + val, _, err := types.Boolean.Convert(ctx, e) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return nil, err } + asBool = val.(int8) == 1 } - var eval interface{} if asBool { eval, err = f.ifTrue.Eval(ctx, row) diff --git a/sql/rowexec/insert.go b/sql/rowexec/insert.go index 33bebe72fb..d78ee5372f 100644 --- a/sql/rowexec/insert.go +++ b/sql/rowexec/insert.go @@ -120,6 +120,8 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) ctxWithValues := context.WithValue(ctx.Context, types.ColumnNameKey, col.Name) ctxWithValues = context.WithValue(ctxWithValues, types.RowNumberKey, i.rowNumber) ctxWithColumnInfo := ctx.WithContext(ctxWithValues) + // TODO: add a ConvertForInsert? + // TODO: check mysql strict mode converted, inRange, cErr := col.Type.Convert(ctxWithColumnInfo, row[idx]) if cErr == nil && !inRange { cErr = sql.ErrValueOutOfRange.New(row[idx], col.Type) diff --git a/sql/type.go b/sql/type.go index 379ed92221..73af3cb44b 100644 --- a/sql/type.go +++ b/sql/type.go @@ -128,6 +128,13 @@ type NumberType interface { DisplayWidth() int } +// RoundingNumberType represents Number Types that implement an additional interface +// that supports rounding when converting rather than the default truncation. +type RoundingNumberType interface { + NumberType + ConvertRound(context.Context, any) (any, ConvertInRange, error) +} + // StringType represents all string types, including VARCHAR and BLOB. // https://dev.mysql.com/doc/refman/8.0/en/char.html // https://dev.mysql.com/doc/refman/8.0/en/binary-varbinary.html diff --git a/sql/types/conversion.go b/sql/types/conversion.go index 59a3091867..5381e5c555 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -826,25 +826,41 @@ func ConvertHexBlobToUint(val interface{}, originType sql.Type) (interface{}, er // TruncateStringToNumber truncates a string to the appropriate number prefix. // This function expects whitespace to already be properly trimmed. -// TODO: separate logic for ints and floating point? -func TruncateStringToNumber(s string, isInt bool) string { +func TruncateStringToNumber(s string) (string, bool) { seenDigit := false seenDot := false seenExp := false signIndex := 0 + s = strings.Trim(s, NumericCutSet) for i := 0; i < len(s); i++ { char := rune(s[i]) if unicode.IsDigit(char) { seenDigit = true - } else if char == '.' && !seenDot && !isInt { + } else if char == '.' && !seenDot { seenDot = true - } else if (char == 'e' || char == 'E') && !seenExp && seenDigit && !isInt { + } else if (char == 'e' || char == 'E') && !seenExp && seenDigit { seenExp = true signIndex = i + 1 } else if !((char == '-' || char == '+') && i == signIndex) { - return s[:i] + return s[:i], true } } - return s + return s, false +} + +// TruncateStringToInt will trim any whitespace from s, then keep the prefix that can be properly parsed into an +// integer. This will return a flag indicating if truncation occurred. +func TruncateStringToInt(s string) (string, bool) { + s = strings.Trim(s, IntCutSet) + for i := 0; i < len(s); i++ { + char := rune(s[i]) + if !unicode.IsDigit(char) { + if (char == '-' || char == '+') && i == 0 { + continue + } + return s[:i], true + } + } + return s, false } diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 15c8b5cb5e..2a212806eb 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -17,15 +17,13 @@ package types import ( "context" "fmt" - "math/big" - "reflect" - "strings" - "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/shopspring/decimal" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" + "math/big" + "reflect" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/sqltypes" @@ -144,7 +142,7 @@ func (t DecimalType_) Compare(s context.Context, a interface{}, b interface{}) ( // Convert implements Type interface. func (t DecimalType_) Convert(c context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { dec, err := t.ConvertToNullDecimal(v) - if err != nil { + if err != nil && !sql.ErrIncorrectValue.Is(err) { return nil, sql.OutOfRange, err } if !dec.Valid { @@ -166,13 +164,10 @@ func (t DecimalType_) ConvertNoBoundsCheck(v interface{}) (decimal.Decimal, erro // ConvertToNullDecimal implements DecimalType interface. func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, error) { - if v == nil { - return decimal.NullDecimal{}, nil - } - var res decimal.Decimal - switch value := v.(type) { + case nil: + return decimal.NullDecimal{}, nil case bool: if value { return t.ConvertToNullDecimal(decimal.NewFromInt(1)) @@ -204,30 +199,30 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, case float64: return t.ConvertToNullDecimal(decimal.NewFromFloat(value)) case string: - value = strings.Trim(value, NumericCutSet) - if len(value) == 0 { - return t.ConvertToNullDecimal(decimal.NewFromInt(0)) - } - if dec, err := decimal.NewFromString(value); err == nil { - return t.ConvertToNullDecimal(dec) - } - // The decimal library cannot handle all the different formats - if bf, _, err := new(big.Float).SetPrec(217).Parse(value, 0); err == nil { - if res, err = decimal.NewFromString(bf.Text('f', -1)); err == nil { - return t.ConvertToNullDecimal(res) - } - } - // TODO: hex strings should not make it this far as numbers - value = TruncateStringToNumber(value, false) - if len(value) == 0 { - return t.ConvertToNullDecimal(decimal.NewFromInt(0)) + var err error + truncStr, didTrunc := TruncateStringToNumber(value) + if didTrunc { + err = sql.ErrIncorrectValue.New(t.String(), value) } - dec, err := decimal.NewFromString(value) - if err != nil { + var dec decimal.Decimal + if len(truncStr) == 0 { + dec = decimal.NewFromInt(0) + } else if d, err := decimal.NewFromString(truncStr); err == nil { + dec = d + } else if bf, _, err := new(big.Float).SetPrec(217).Parse(truncStr, 0); err == nil { + // The decimal library cannot handle all the different formats + if d, err = decimal.NewFromString(bf.Text('f', -1)); err == nil { + dec = d + } + } else { return decimal.NullDecimal{}, err } - return t.ConvertToNullDecimal(dec) + decRes, convErr := t.ConvertToNullDecimal(dec) + if convErr != nil { + return decRes, convErr + } + return decRes, err case *big.Float: return t.ConvertToNullDecimal(value.Text('f', -1)) case *big.Int: @@ -257,7 +252,6 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, default: return decimal.NullDecimal{}, ErrConvertingToDecimal.New(v) } - return decimal.NullDecimal{Decimal: res, Valid: true}, nil } diff --git a/sql/types/number.go b/sql/types/number.go index e9506cd167..2818adeef8 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -17,12 +17,12 @@ package types import ( "context" "encoding/hex" + "errors" "fmt" "math" "reflect" "regexp" "strconv" - "strings" "time" "github.com/dolthub/vitess/go/sqltypes" @@ -105,6 +105,7 @@ var _ sql.Type = NumberTypeImpl_{} var _ sql.Type2 = NumberTypeImpl_{} var _ sql.CollationCoercible = NumberTypeImpl_{} var _ sql.NumberType = NumberTypeImpl_{} +var _ sql.RoundingNumberType = NumberTypeImpl_{} // CreateNumberType creates a NumberType. func CreateNumberType(baseType query.Type) (sql.NumberType, error) { @@ -149,51 +150,19 @@ func MustCreateNumberTypeWithDisplayWidth(baseType query.Type, displayWidth int) return nt } -func NumericUnaryValue(t sql.Type) interface{} { - nt := t.(NumberTypeImpl_) - switch nt.baseType { - case sqltypes.Int8: - return int8(1) - case sqltypes.Uint8: - return uint8(1) - case sqltypes.Int16: - return int16(1) - case sqltypes.Uint16: - return uint16(1) - case sqltypes.Int24: - return int32(1) - case sqltypes.Uint24: - return uint32(1) - case sqltypes.Int32: - return int32(1) - case sqltypes.Uint32: - return uint32(1) - case sqltypes.Int64: - return int64(1) - case sqltypes.Uint64: - return uint64(1) - case sqltypes.Float32: - return float32(1) - case sqltypes.Float64: - return float64(1) - default: - panic(fmt.Sprintf("%v is not a valid number base type", nt.baseType.String())) - } -} - // Compare implements Type interface. -func (t NumberTypeImpl_) Compare(s context.Context, a interface{}, b interface{}) (int, error) { +func (t NumberTypeImpl_) Compare(ctx context.Context, a any, b any) (int, error) { if hasNulls, res := CompareNulls(a, b); hasNulls { return res, nil } switch t.baseType { case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: - ca, _, err := convertToUint64(t, a) + ca, _, err := convertToUint64(t, a, false) if err != nil { return 0, err } - cb, _, err := convertToUint64(t, b) + cb, _, err := convertToUint64(t, b, false) if err != nil { return 0, err } @@ -223,11 +192,11 @@ func (t NumberTypeImpl_) Compare(s context.Context, a interface{}, b interface{} } return +1, nil default: - ca, _, err := convertToInt64(t, a) + ca, _, err := convertToInt64(t, a, false) if err != nil { ca = 0 } - cb, _, err := convertToInt64(t, b) + cb, _, err := convertToInt64(t, b, false) if err != nil { cb = 0 } @@ -243,7 +212,7 @@ func (t NumberTypeImpl_) Compare(s context.Context, a interface{}, b interface{} } // Convert implements Type interface. -func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { +func (t NumberTypeImpl_) Convert(ctx context.Context, v any) (any, sql.ConvertInRange, error) { var err error if v == nil { return nil, sql.InRange, nil @@ -262,9 +231,9 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ switch t.baseType { case sqltypes.Int8: - num, _, err := convertToInt64(t, v) - if err != nil { - return nil, sql.OutOfRange, err + num, _, err := convertToInt64(t, v, false) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return 0, sql.OutOfRange, err } if num > math.MaxInt8 { return int8(math.MaxInt8), sql.OutOfRange, nil @@ -273,11 +242,12 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } return int8(num), sql.InRange, nil case sqltypes.Uint8: - return convertToUint8(t, v) + // TODO: convertToUint8 is unnecessary, we can just use convertToInt64 and handle overflow logic here + return convertToUint8(t, v, false) case sqltypes.Int16: - num, _, err := convertToInt64(t, v) - if err != nil { - return nil, sql.OutOfRange, err + num, _, err := convertToInt64(t, v, false) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return 0, sql.OutOfRange, err } if num > math.MaxInt16 { return int16(math.MaxInt16), sql.OutOfRange, nil @@ -286,11 +256,11 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } return int16(num), sql.InRange, nil case sqltypes.Uint16: - return convertToUint16(t, v) + return convertToUint16(t, v, false) case sqltypes.Int24: - num, _, err := convertToInt64(t, v) - if err != nil { - return nil, sql.OutOfRange, err + num, _, err := convertToInt64(t, v, false) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return 0, sql.OutOfRange, err } if num > (1<<23 - 1) { return int32(1<<23 - 1), sql.OutOfRange, nil @@ -299,9 +269,9 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } return int32(num), sql.InRange, nil case sqltypes.Uint24: - num, _, err := convertToInt64(t, v) - if err != nil { - return nil, sql.OutOfRange, err + num, _, err := convertToInt64(t, v, false) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return 0, sql.OutOfRange, err } if num >= (1 << 24) { return uint32(1<<24 - 1), sql.OutOfRange, nil @@ -310,9 +280,9 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } return uint32(num), sql.InRange, nil case sqltypes.Int32: - num, _, err := convertToInt64(t, v) - if err != nil { - return nil, sql.OutOfRange, err + num, _, err := convertToInt64(t, v, false) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return 0, sql.OutOfRange, err } if num > math.MaxInt32 { return int32(math.MaxInt32), sql.OutOfRange, nil @@ -321,19 +291,20 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } return int32(num), sql.InRange, nil case sqltypes.Uint32: - return convertToUint32(t, v) + return convertToUint32(t, v, false) case sqltypes.Int64: - return convertToInt64(t, v) + return convertToInt64(t, v, false) case sqltypes.Uint64: - return convertToUint64(t, v) + return convertToUint64(t, v, false) case sqltypes.Float32: num, err := convertToFloat64(t, v) if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { - return nil, sql.OutOfRange, err + return 0, sql.OutOfRange, err } if num > math.MaxFloat32 { return float32(math.MaxFloat32), sql.OutOfRange, nil - } else if num < -math.MaxFloat32 { + } + if num < -math.MaxFloat32 { return float32(-math.MaxFloat32), sql.OutOfRange, nil } return float32(num), sql.InRange, nil // TODO: pass up error for warning? @@ -345,6 +316,103 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } } +func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v any) (any, sql.ConvertInRange, error) { + switch t.baseType { + case sqltypes.Int8, sqltypes.Int16, sqltypes.Int24, sqltypes.Int32, sqltypes.Int64: + switch v.(type) { + case float32, float64: + num, inRange, err := convertToInt64(t, v, true) + if err != nil { + return nil, sql.OutOfRange, err + } + // TODO: write helper method? + switch t.baseType { + case sqltypes.Int8: + if num > math.MaxInt8 { + return int8(math.MaxInt8), sql.OutOfRange, nil + } + if num < math.MinInt8 { + return int8(math.MinInt8), sql.OutOfRange, nil + } + return int8(num), sql.InRange, nil + case sqltypes.Int16: + if num > math.MaxInt16 { + return int16(math.MaxInt16), sql.OutOfRange, nil + } + if num < math.MinInt16 { + return int16(math.MinInt16), sql.OutOfRange, nil + } + return int16(num), sql.InRange, nil + case sqltypes.Int24: + if num > (1<<23 - 1) { + return int32(1<<23 - 1), sql.OutOfRange, nil + } + if num < (-1 << 23) { + return int32(-1 << 23), sql.OutOfRange, nil + } + return int32(num), sql.InRange, nil + case sqltypes.Int32: + if num > math.MaxInt32 { + return int32(math.MaxInt32), sql.OutOfRange, nil + } + if num < math.MinInt32 { + return int32(math.MinInt32), sql.OutOfRange, nil + } + return int32(num), sql.InRange, nil + default: + return num, inRange, nil + } + default: + return t.Convert(ctx, v) + } + case sqltypes.Uint8: + switch v.(type) { + case float32, float64: + convertToUint8(t, v, true) + default: + return t.Convert(ctx, v) + } + case sqltypes.Uint16: + switch v.(type) { + case float32, float64: + convertToUint16(t, v, true) + default: + return t.Convert(ctx, v) + } + case sqltypes.Uint24: + switch v.(type) { + case float32, float64: + num, _, err := convertToInt64(t, v, true) + if err != nil { + return nil, sql.OutOfRange, err + } + if num >= (1 << 24) { + return uint32(1<<24 - 1), sql.OutOfRange, nil + } else if num < 0 { + return uint32(1<<24 - int32(-num)), sql.OutOfRange, nil + } + return uint32(num), sql.InRange, nil + default: + return t.Convert(ctx, v) + } + case sqltypes.Uint32: + switch v.(type) { + case float32, float64: + return convertToUint32(t, v, true) + default: + return t.Convert(ctx, v) + } + case sqltypes.Uint64: + switch v.(type) { + case float32, float64: + return convertToUint64(t, v, true) + default: + return t.Convert(ctx, v) + } + } + return t.Convert(ctx, v) +} + // MaxTextResponseByteLength implements the Type interface func (t NumberTypeImpl_) MaxTextResponseByteLength(*sql.Context) uint32 { // MySQL integer type limits: https://dev.mysql.com/doc/refman/8.0/en/integer-types.html @@ -398,8 +466,8 @@ func (t NumberTypeImpl_) Promote() sql.Type { } } -func (t NumberTypeImpl_) SQLInt8(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToInt64(t, v) +func (t NumberTypeImpl_) SQLInt8(ctx *sql.Context, dest []byte, v any) ([]byte, error) { + num, _, err := convertToInt64(t, v, false) if err != nil { return nil, err } @@ -412,8 +480,8 @@ func (t NumberTypeImpl_) SQLInt8(ctx *sql.Context, dest []byte, v interface{}) ( return dest, nil } -func (t NumberTypeImpl_) SQLInt16(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToInt64(t, v) +func (t NumberTypeImpl_) SQLInt16(ctx *sql.Context, dest []byte, v any) ([]byte, error) { + num, _, err := convertToInt64(t, v, false) if err != nil { return nil, err } @@ -426,8 +494,8 @@ func (t NumberTypeImpl_) SQLInt16(ctx *sql.Context, dest []byte, v interface{}) return dest, nil } -func (t NumberTypeImpl_) SQLInt24(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToInt64(t, v) +func (t NumberTypeImpl_) SQLInt24(ctx *sql.Context, dest []byte, v any) ([]byte, error) { + num, _, err := convertToInt64(t, v, false) if err != nil { return nil, err } @@ -440,8 +508,8 @@ func (t NumberTypeImpl_) SQLInt24(ctx *sql.Context, dest []byte, v interface{}) return dest, nil } -func (t NumberTypeImpl_) SQLInt32(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToInt64(t, v) +func (t NumberTypeImpl_) SQLInt32(ctx *sql.Context, dest []byte, v any) ([]byte, error) { + num, _, err := convertToInt64(t, v, false) if err != nil { return nil, err } @@ -454,8 +522,8 @@ func (t NumberTypeImpl_) SQLInt32(ctx *sql.Context, dest []byte, v interface{}) return dest, nil } -func (t NumberTypeImpl_) SQLInt64(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - vt, _, err := convertToInt64(t, v) +func (t NumberTypeImpl_) SQLInt64(ctx *sql.Context, dest []byte, v any) ([]byte, error) { + vt, _, err := convertToInt64(t, v, false) if err != nil { return nil, err } @@ -463,8 +531,8 @@ func (t NumberTypeImpl_) SQLInt64(ctx *sql.Context, dest []byte, v interface{}) return dest, nil } -func (t NumberTypeImpl_) SQLUint8(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToUint64(t, v) +func (t NumberTypeImpl_) SQLUint8(ctx *sql.Context, dest []byte, v any) ([]byte, error) { + num, _, err := convertToUint64(t, v, false) if err != nil { return nil, err } @@ -476,8 +544,8 @@ func (t NumberTypeImpl_) SQLUint8(ctx *sql.Context, dest []byte, v interface{}) return dest, nil } -func (t NumberTypeImpl_) SQLUint16(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToUint64(t, v) +func (t NumberTypeImpl_) SQLUint16(ctx *sql.Context, dest []byte, v any) ([]byte, error) { + num, _, err := convertToUint64(t, v, false) if err != nil { return nil, err } @@ -489,8 +557,8 @@ func (t NumberTypeImpl_) SQLUint16(ctx *sql.Context, dest []byte, v interface{}) return dest, nil } -func (t NumberTypeImpl_) SQLUint24(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToUint64(t, v) +func (t NumberTypeImpl_) SQLUint24(ctx *sql.Context, dest []byte, v any) ([]byte, error) { + num, _, err := convertToUint64(t, v, false) if err != nil { return nil, err } @@ -502,8 +570,8 @@ func (t NumberTypeImpl_) SQLUint24(ctx *sql.Context, dest []byte, v interface{}) return dest, nil } -func (t NumberTypeImpl_) SQLUint32(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToUint64(t, v) +func (t NumberTypeImpl_) SQLUint32(ctx *sql.Context, dest []byte, v any) ([]byte, error) { + num, _, err := convertToUint64(t, v, false) if err != nil { return nil, err } @@ -515,8 +583,8 @@ func (t NumberTypeImpl_) SQLUint32(ctx *sql.Context, dest []byte, v interface{}) return dest, nil } -func (t NumberTypeImpl_) SQLUint64(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToUint64(t, v) +func (t NumberTypeImpl_) SQLUint64(ctx *sql.Context, dest []byte, v any) ([]byte, error) { + num, _, err := convertToUint64(t, v, false) if err != nil { return nil, err } @@ -528,7 +596,7 @@ func (t NumberTypeImpl_) SQLUint64(ctx *sql.Context, dest []byte, v interface{}) return dest, nil } -func (t NumberTypeImpl_) SQLFloat64(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { +func (t NumberTypeImpl_) SQLFloat64(ctx *sql.Context, dest []byte, v any) ([]byte, error) { num, err := convertToFloat64(t, v) if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return nil, err @@ -537,7 +605,7 @@ func (t NumberTypeImpl_) SQLFloat64(ctx *sql.Context, dest []byte, v interface{} return dest, nil } -func (t NumberTypeImpl_) SQLFloat32(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { +func (t NumberTypeImpl_) SQLFloat32(ctx *sql.Context, dest []byte, v any) ([]byte, error) { num, err := convertToFloat64(t, v) if err != nil { return nil, err @@ -552,7 +620,7 @@ func (t NumberTypeImpl_) SQLFloat32(ctx *sql.Context, dest []byte, v interface{} } // SQL implements Type interface. -func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Value, error) { +func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v any) (sqltypes.Value, error) { if v == nil { return sqltypes.NULL, nil } @@ -877,7 +945,7 @@ func (t NumberTypeImpl_) ValueType() reflect.Type { } // Zero implements Type interface. -func (t NumberTypeImpl_) Zero() interface{} { +func (t NumberTypeImpl_) Zero() any { switch t.baseType { case sqltypes.Int8: return int8(0) @@ -936,7 +1004,7 @@ func (t NumberTypeImpl_) DisplayWidth() int { return t.displayWidth } -func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange, error) { +func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInRange, error) { switch v := v.(type) { case time.Time: return v.UTC().Unix(), sql.InRange, nil @@ -966,57 +1034,48 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange case float32: if v > float32(math.MaxInt64) { return math.MaxInt64, sql.OutOfRange, nil - } else if v < float32(math.MinInt64) { + } + if v < float32(math.MinInt64) { return math.MinInt64, sql.OutOfRange, nil } - return int64(math.Round(float64(v))), sql.InRange, nil + if round { + return int64(math.Round(float64(v))), sql.InRange, nil + } + return int64(v), sql.InRange, nil case float64: if v > float64(math.MaxInt64) { return math.MaxInt64, sql.OutOfRange, nil - } else if v < float64(math.MinInt64) { + } + if v < float64(math.MinInt64) { return math.MinInt64, sql.OutOfRange, nil } - return int64(math.Round(v)), sql.InRange, nil + if round { + return int64(math.Round(v)), sql.InRange, nil + } + return int64(v), sql.InRange, nil case decimal.Decimal: + // TODO: round? if v.GreaterThan(dec_int64_max) { return dec_int64_max.IntPart(), sql.OutOfRange, nil - } else if v.LessThan(dec_int64_min) { + } + if v.LessThan(dec_int64_min) { return dec_int64_min.IntPart(), sql.OutOfRange, nil } return v.Round(0).IntPart(), sql.InRange, nil case []byte: - i, err := strconv.ParseInt(hex.EncodeToString(v), 16, 64) - if err != nil { - return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) - } - return i, sql.InRange, nil + return convertToInt64(t, string(v), round) case string: - v = strings.Trim(v, IntCutSet) - if len(v) == 0 { - // StringType{}.Zero() returns empty string, but should represent "0" for number value - return 0, sql.InRange, nil - } - // TODO: always attempt to truncate and only error if truncation did something - // Parse first an integer, which allows for more values than float64 - if i, err := strconv.ParseInt(v, 10, 64); err == nil { - return i, sql.InRange, nil - } - // TODO: convert for insert should just be an entirely new function - // If that fails, try as a float and round to integral - //if f, err := strconv.ParseFloat(v, 64); err == nil { - // f = math.Round(f) // TODO: inserting rounds up, while casting truncates - // return int64(f), sql.InRange, nil - //} - // If that fails, truncate the string and parse as int - v = TruncateStringToNumber(v, true) - if len(v) == 0 { - return 0, sql.InRange, sql.ErrTruncatedIncorrect.New(t.String(), v) - } - if i, err := strconv.ParseInt(v, 10, 64); err == nil { - return i, sql.InRange, sql.ErrTruncatedIncorrect.New(t.String(), v) - } - // TODO: what should this error be? - return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) + // TODO: round? + var err error + truncStr, didTrunc := TruncateStringToInt(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t.String(), v) + } + if len(truncStr) == 0 { + return 0, sql.InRange, err + } + i, _ := strconv.ParseInt(truncStr, 10, 64) + return i, sql.InRange, err case bool: if v { return 1, sql.InRange, nil @@ -1117,7 +1176,7 @@ func convertValueToUint64(t NumberTypeImpl_, v sql.Value) (uint64, error) { } } -func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRange, error) { +func convertToUint64(t NumberTypeImpl_, v any, round bool) (uint64, sql.ConvertInRange, error) { switch v := v.(type) { case time.Time: return uint64(v.UTC().Unix()), sql.InRange, nil @@ -1163,15 +1222,21 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan if v < 0 { return uint64(math.MaxUint64 - v), sql.OutOfRange, nil } - return uint64(math.Round(float64(v))), sql.InRange, nil + if round { + return uint64(math.Round(float64(v))), sql.InRange, nil + } + return uint64(v), sql.InRange, nil case float64: if v >= float64(math.MaxUint64) { return math.MaxUint64, sql.OutOfRange, nil } - if v <= 0 { + if v < 0 { return uint64(math.MaxUint64 - v), sql.OutOfRange, nil } - return uint64(math.Round(v)), sql.InRange, nil + if round { + return uint64(math.Round(v)), sql.InRange, nil + } + return uint64(v), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_uint64_max) { return math.MaxUint64, sql.OutOfRange, nil @@ -1190,17 +1255,31 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan } return i, sql.InRange, nil case string: - if i, err := strconv.ParseUint(v, 10, 64); err == nil { - return i, sql.InRange, nil - } else if err == strconv.ErrRange { + var err error + s, ok := TruncateStringToInt(v) + if ok { + err = sql.ErrTruncatedIncorrect.New(t.String(), v) + } + if len(s) == 0 { + return 0, sql.InRange, err + } + // Trim leading sign + neg := false + if s[0] == '+' { + s = s[1:] + } else if s[0] == '-' { + neg = true + s = s[1:] + } + i, pErr := strconv.ParseUint(s, 10, 64) + if errors.Is(pErr, strconv.ErrRange) { // Number is too large for uint64, return max value and OutOfRange - return math.MaxUint64, sql.OutOfRange, nil + return math.MaxUint64, sql.OutOfRange, err } - // If that fails, try as a float and round to integral - if f, err := strconv.ParseFloat(v, 64); err == nil { - return convertToUint64(t, f) + if neg { + i = math.MaxUint64 - i + 1 } - return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) + return i, sql.InRange, err case bool: if v { return 1, sql.InRange, nil @@ -1213,7 +1292,7 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan } } -func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRange, error) { +func convertToUint32(t NumberTypeImpl_, v any, round bool) (uint32, sql.ConvertInRange, error) { switch v := v.(type) { case int: if v < 0 { @@ -1251,7 +1330,10 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan } return uint32(v), sql.InRange, nil case uint: - return convertUintToUint32(uint64(v)) + if v > math.MaxUint32 { + return uint32(math.MaxUint32), sql.OutOfRange, nil + } + return uint32(v), sql.InRange, nil case uint8: return uint32(v), sql.InRange, nil case uint16: @@ -1259,14 +1341,10 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan case uint32: return v, sql.InRange, nil case uint64: - return convertUintToUint32(v) - case float64: - if float32(v) > float32(math.MaxInt32) { - return math.MaxUint32, sql.OutOfRange, nil - } else if v < 0 { - return uint32(math.MaxUint32 - v), sql.OutOfRange, nil + if v > math.MaxUint32 { + return uint32(math.MaxUint32), sql.OutOfRange, nil } - return uint32(math.Round(float64(v))), sql.InRange, nil + return uint32(v), sql.InRange, nil case float32: if v >= float32(math.MaxUint32) { return math.MaxUint32, sql.OutOfRange, nil @@ -1274,6 +1352,13 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan return uint32(math.MaxUint32 - v), sql.OutOfRange, nil } return uint32(math.Round(float64(v))), sql.InRange, nil + case float64: + if float32(v) > float32(math.MaxInt32) { + return math.MaxUint32, sql.OutOfRange, nil + } else if v < 0 { + return uint32(math.MaxUint32 - v), sql.OutOfRange, nil + } + return uint32(math.Round(float64(v))), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_uint32_max) { return math.MaxUint32, sql.InRange, nil @@ -1291,17 +1376,16 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan } return uint32(i), sql.InRange, nil case string: - if i, err := strconv.ParseUint(v, 10, 32); err == nil { - return uint32(i), sql.InRange, nil + var err error + truncStr, didTrunc := TruncateStringToInt(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t.String(), v) } - if f, err := strconv.ParseFloat(v, 64); err == nil { - return convertToUint32(t, f) + if len(truncStr) == 0 { + return 0, sql.InRange, err } - v = TruncateStringToNumber(v, true) - if i, err := strconv.ParseUint(v, 10, 32); err == nil { - return uint32(i), sql.InRange, nil - } - return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) + i, _ := strconv.ParseInt(truncStr, 10, 32) + return uint32(i), sql.InRange, err case bool: if v { return 1, sql.InRange, nil @@ -1314,12 +1398,13 @@ func convertToUint32(t NumberTypeImpl_, v interface{}) (uint32, sql.ConvertInRan } } -func convertToUint16(t NumberTypeImpl_, v interface{}) (uint16, sql.ConvertInRange, error) { +func convertToUint16(t NumberTypeImpl_, v any, round bool) (uint16, sql.ConvertInRange, error) { switch v := v.(type) { case int: if v < 0 { return uint16(math.MaxUint16 - uint(-v-1)), sql.OutOfRange, nil - } else if v > math.MaxUint16 { + } + if v > math.MaxUint16 { return uint16(math.MaxUint16), sql.OutOfRange, nil } return uint16(v), sql.InRange, nil @@ -1336,45 +1421,65 @@ func convertToUint16(t NumberTypeImpl_, v interface{}) (uint16, sql.ConvertInRan case int32: if v < 0 { return uint16(math.MaxUint16 - uint(-v-1)), sql.OutOfRange, nil - } else if v > math.MaxUint16 { + } + if v > math.MaxUint16 { return uint16(math.MaxUint16), sql.OutOfRange, nil } return uint16(v), sql.InRange, nil case int64: if v < 0 { return uint16(math.MaxUint16 - uint(-v-1)), sql.OutOfRange, nil - } else if v > math.MaxUint16 { + } + if v > math.MaxUint16 { return uint16(math.MaxUint16), sql.OutOfRange, nil } return uint16(v), sql.InRange, nil case uint: - return convertUintToUint16(uint64(v)) + if v > math.MaxUint16 { + return uint16(math.MaxUint16), sql.OutOfRange, nil + } + return uint16(v), sql.InRange, nil case uint8: return uint16(v), sql.InRange, nil - case uint64: - return convertUintToUint16(v) - case uint32: - return convertUintToUint16(uint64(v)) case uint16: return v, sql.InRange, nil + case uint32: + if v > math.MaxUint16 { + return uint16(math.MaxUint16), sql.OutOfRange, nil + } + return uint16(v), sql.InRange, nil + case uint64: + if v > math.MaxUint16 { + return uint16(math.MaxUint16), sql.OutOfRange, nil + } + return uint16(v), sql.InRange, nil case float32: if v > float32(math.MaxInt16) { return math.MaxUint16, sql.OutOfRange, nil - } else if v < 0 { + } + if v < 0 { return uint16(math.MaxUint16 - v), sql.OutOfRange, nil } - return uint16(math.Round(float64(v))), sql.InRange, nil + if round { + return uint16(math.Round(float64(v))), sql.InRange, nil + } + return uint16(v), sql.InRange, nil case float64: if v >= float64(math.MaxUint16) { return math.MaxUint16, sql.OutOfRange, nil - } else if v <= 0 { + } + if v <= 0 { return uint16(math.MaxUint16 - v), sql.OutOfRange, nil } - return uint16(math.Round(v)), sql.InRange, nil + if round { + return uint16(math.Round(float64(v))), sql.InRange, nil + } + return uint16(v), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_uint16_max) { return math.MaxUint16, sql.InRange, nil - } else if v.LessThan(dec_zero) { + } + if v.LessThan(dec_zero) { ret, _ := dec_uint16_max.Sub(v).Float64() return uint16(math.Round(ret)), sql.OutOfRange, nil } @@ -1388,17 +1493,16 @@ func convertToUint16(t NumberTypeImpl_, v interface{}) (uint16, sql.ConvertInRan } return uint16(i), sql.InRange, nil case string: - if i, err := strconv.ParseUint(v, 10, 16); err == nil { - return uint16(i), sql.InRange, nil - } - if f, err := strconv.ParseFloat(v, 64); err == nil { - return convertToUint16(t, f) + var err error + truncStr, didTrunc := TruncateStringToNumber(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t.String(), v) } - v = TruncateStringToNumber(v, true) - if i, err := strconv.ParseUint(v, 10, 16); err == nil { - return uint16(i), sql.InRange, nil + if len(truncStr) == 0 { + return 0, sql.InRange, err } - return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) + i, _ := strconv.ParseUint(truncStr, 10, 16) + return uint16(i), sql.InRange, err case bool: if v { return 1, sql.InRange, nil @@ -1411,71 +1515,97 @@ func convertToUint16(t NumberTypeImpl_, v interface{}) (uint16, sql.ConvertInRan } } -func convertToUint8(t NumberTypeImpl_, v interface{}) (uint8, sql.ConvertInRange, error) { +func convertToUint8(t NumberTypeImpl_, v any, round bool) (uint8, sql.ConvertInRange, error) { switch v := v.(type) { case int: if v < 0 { return uint8(math.MaxUint8 - uint(-v-1)), sql.OutOfRange, nil - } else if v > math.MaxUint8 { + } + if v > math.MaxUint8 { return uint8(math.MaxUint8), sql.OutOfRange, nil } return uint8(v), sql.InRange, nil case int16: if v < 0 { return uint8(math.MaxUint8 - uint(-v-1)), sql.OutOfRange, nil - } else if v > math.MaxUint8 { + } + if v > math.MaxUint8 { return uint8(math.MaxUint8), sql.OutOfRange, nil } return uint8(v), sql.InRange, nil case int8: if v < 0 { return uint8(math.MaxUint8 - uint(-v-1)), sql.OutOfRange, nil - } else if int(v) > math.MaxUint8 { + } + if int(v) > math.MaxUint8 { return uint8(math.MaxUint8), sql.OutOfRange, nil } return uint8(v), sql.InRange, nil case int32: if v < 0 { return uint8(math.MaxUint8 - uint(-v-1)), sql.OutOfRange, nil - } else if v > math.MaxUint8 { + } + if v > math.MaxUint8 { return uint8(math.MaxUint8), sql.OutOfRange, nil } return uint8(v), sql.InRange, nil case int64: if v < 0 { return uint8(math.MaxUint8 - uint(-v-1)), sql.OutOfRange, nil - } else if v > math.MaxUint8 { + } + if v > math.MaxUint8 { return uint8(math.MaxUint8), sql.OutOfRange, nil } return uint8(v), sql.InRange, nil case uint: - return convertUintToUint8(uint64(v)) - case uint16: - return convertUintToUint8(uint64(v)) - case uint64: - return convertUintToUint8(v) - case uint32: - return convertUintToUint8(uint64(v)) + if v > math.MaxUint8 { + return uint8(math.MaxUint8), sql.OutOfRange, nil + } + return uint8(v), sql.InRange, nil case uint8: return v, sql.InRange, nil + case uint16: + if v > math.MaxUint8 { + return uint8(math.MaxUint8), sql.OutOfRange, nil + } + return uint8(v), sql.InRange, nil + case uint32: + if v > math.MaxUint8 { + return uint8(math.MaxUint8), sql.OutOfRange, nil + } + return uint8(v), sql.InRange, nil + case uint64: + if v > math.MaxUint8 { + return uint8(math.MaxUint8), sql.OutOfRange, nil + } + return uint8(v), sql.InRange, nil case float32: if v > float32(math.MaxInt8) { return math.MaxUint8, sql.OutOfRange, nil - } else if v < 0 { + } + if v < 0 { return uint8(math.MaxUint8 - v), sql.OutOfRange, nil } - return uint8(math.Round(float64(v))), sql.InRange, nil + if round { + return uint8(math.Round(float64(v))), sql.InRange, nil + } + return uint8(v), sql.InRange, nil case float64: if v >= float64(math.MaxUint8) { return math.MaxUint8, sql.OutOfRange, nil - } else if v <= 0 { + } + if v <= 0 { return uint8(math.MaxUint8 - v), sql.OutOfRange, nil } - return uint8(math.Round(v)), sql.InRange, nil + if round { + return uint8(math.Round(v)), sql.InRange, nil + } + return uint8(v), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_uint8_max) { return math.MaxUint8, sql.InRange, nil - } else if v.LessThan(dec_zero) { + } + if v.LessThan(dec_zero) { ret, _ := dec_uint8_max.Sub(v).Float64() return uint8(math.Round(ret)), sql.OutOfRange, nil } @@ -1489,17 +1619,16 @@ func convertToUint8(t NumberTypeImpl_, v interface{}) (uint8, sql.ConvertInRange } return uint8(i), sql.InRange, nil case string: - if i, err := strconv.ParseUint(v, 10, 8); err == nil { - return uint8(i), sql.InRange, nil + var err error + truncStr, didTrunc := TruncateStringToInt(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t.String(), v) } - if f, err := strconv.ParseFloat(v, 64); err == nil { - return convertToUint8(t, f) + if len(truncStr) == 0 { + return 0, sql.InRange, err } - v = TruncateStringToNumber(v, true) - if i, err := strconv.ParseUint(v, 10, 8); err == nil { - return uint8(i), sql.InRange, nil - } - return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) + i, _ := strconv.ParseUint(v, 10, 8) + return uint8(i), sql.InRange, err case bool: if v { return 1, sql.InRange, nil @@ -1512,7 +1641,7 @@ func convertToUint8(t NumberTypeImpl_, v interface{}) (uint8, sql.ConvertInRange } } -func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) { +func convertToFloat64(t NumberTypeImpl_, v any) (float64, error) { switch v := v.(type) { case time.Time: return float64(v.UTC().Unix()), nil @@ -1550,16 +1679,16 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) { } return float64(i), nil case string: - v = strings.Trim(v, NumericCutSet) - if i, err := strconv.ParseFloat(v, 64); err == nil { - return i, nil - } - // TODO: what's the difference between this and TruncateStringToNumber? - // the difference is that this doesn't work lmao - // parse the first longest valid numbers - s := TruncateStringToNumber(v, false) - i, _ := strconv.ParseFloat(s, 64) - return i, sql.ErrTruncatedIncorrect.New(t.String(), v) + var err error + truncStr, didTrunc := TruncateStringToNumber(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t.String(), v) + } + if len(truncStr) == 0 { + return 0, err + } + i, _ := strconv.ParseFloat(truncStr, 64) + return i, err case bool: if v { return 1, nil @@ -1603,116 +1732,8 @@ func convertValueToFloat64(t NumberTypeImpl_, v sql.Value) (float64, error) { } } -func mustInt64(v interface{}) int64 { - switch tv := v.(type) { - case int: - return int64(tv) - case int8: - return int64(tv) - case int16: - return int64(tv) - case int32: - return int64(tv) - case int64: - return tv - case uint: - return int64(tv) - case uint8: - return int64(tv) - case uint16: - return int64(tv) - case uint32: - return int64(tv) - case uint64: - return int64(tv) - case bool: - if tv { - return int64(1) - } - return int64(0) - case float32: - return int64(tv) - case float64: - return int64(tv) - default: - panic(fmt.Sprintf("unexpected type %v", v)) - } -} - -func mustUint64(v interface{}) uint64 { - switch tv := v.(type) { - case uint: - return uint64(tv) - case uint8: - return uint64(tv) - case uint16: - return uint64(tv) - case uint32: - return uint64(tv) - case uint64: - return tv - case int: - return uint64(tv) - case int8: - return uint64(tv) - case int16: - return uint64(tv) - case int32: - return uint64(tv) - case int64: - return uint64(tv) - case bool: - if tv { - return uint64(1) - } - return uint64(0) - case float32: - return uint64(tv) - case float64: - return uint64(tv) - default: - panic(fmt.Sprintf("unexpected type %v", v)) - } -} - -func mustFloat64(v interface{}) float64 { - switch tv := v.(type) { - case uint: - return float64(tv) - case uint8: - return float64(tv) - case uint16: - return float64(tv) - case uint32: - return float64(tv) - case uint64: - return float64(tv) - case int: - return float64(tv) - case int8: - return float64(tv) - case int16: - return float64(tv) - case int32: - return float64(tv) - case int64: - return float64(tv) - case bool: - if tv { - return float64(1) - } - return float64(0) - case float32: - return float64(tv) - case float64: - return tv - default: - panic(fmt.Sprintf("unexpected type %v", v)) - } -} - // CoalesceInt converts a int8/int16/... to int -func CoalesceInt(val interface{}) (int, bool) { +func CoalesceInt(val any) (int, bool) { switch v := val.(type) { case int: return v, true @@ -1736,30 +1757,3 @@ func CoalesceInt(val interface{}) (int, bool) { return 0, false } } - -// convertUintToUint8 converts a uint64 value to uint8 with overflow checking. -// Returns the converted value, range status, and any error. -func convertUintToUint8(v uint64) (uint8, sql.ConvertInRange, error) { - if v > math.MaxUint8 { - return uint8(math.MaxUint8), sql.OutOfRange, nil - } - return uint8(v), sql.InRange, nil -} - -// convertUintToUint16 converts a uint64 value to uint16 with overflow checking. -// Returns the converted value, range status, and any error. -func convertUintToUint16(v uint64) (uint16, sql.ConvertInRange, error) { - if v > math.MaxUint16 { - return uint16(math.MaxUint16), sql.OutOfRange, nil - } - return uint16(v), sql.InRange, nil -} - -// convertUintToUint32 converts a uint64 value to uint32 with overflow checking. -// Returns the converted value, range status, and any error. -func convertUintToUint32(v uint64) (uint32, sql.ConvertInRange, error) { - if v > math.MaxUint32 { - return uint32(math.MaxUint32), sql.OutOfRange, nil - } - return uint32(v), sql.InRange, nil -} diff --git a/sql/types/strings.go b/sql/types/strings.go index f71e009375..cc8171ee55 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -729,13 +729,13 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes. dest = append(dest, v...) valueBytes = dest[start:] case int, int8, int16, int32, int64: - num, _, err := convertToInt64(Int64.(NumberTypeImpl_), v) + num, _, err := convertToInt64(Int64.(NumberTypeImpl_), v, false) if err != nil { return sqltypes.Value{}, err } valueBytes = strconv.AppendInt(dest, num, 10) case uint, uint8, uint16, uint32, uint64: - num, _, err := convertToUint64(Int64.(NumberTypeImpl_), v) + num, _, err := convertToUint64(Int64.(NumberTypeImpl_), v, false) if err != nil { return sqltypes.Value{}, err } From 052c2042c7b566b567199e71b56b09a74ec2d1a0 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 11 Sep 2025 16:54:29 -0700 Subject: [PATCH 09/48] simplify char function --- sql/expression/function/char.go | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/sql/expression/function/char.go b/sql/expression/function/char.go index cbfeeb1626..a071ad6ccd 100644 --- a/sql/expression/function/char.go +++ b/sql/expression/function/char.go @@ -85,18 +85,20 @@ func (c *Char) CollationCoercibility(ctx *sql.Context) (collation sql.CollationI return sql.Collation_binary, 5 } -// char converts num into a byte array -// This function is essentially converting the number to base 256 -func char(num uint32) []byte { - if num == 0 { - return []byte{0} +// encodeUInt32 converts uint32 `num` into a []byte using the fewest number of bytes in big endian (no leading 0s) +func encodeUInt32(num uint32) []byte { + res := make([]byte, 0, 4) + if x := byte(num >> 24); x > 0 { + res = append(res, x) } - res := byte(num & 255) - nextNum := num >> 8 - if nextNum == 0 { - return []byte{res} + if x := byte(num >> 16); x > 0 { + res = append(res, x) } - return append(char(num>>8), res) + if x := byte(num >> 8); x > 0 { + res = append(res, x) + } + res = append(res, byte(num)) + return res } // Eval implements the sql.Expression interface @@ -106,16 +108,13 @@ func (c *Char) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if arg == nil { continue } - val, err := arg.Eval(ctx, row) if err != nil { return nil, err } - if val == nil { continue } - v, _, err := types.Uint32.Convert(ctx, val) if err != nil { if !sql.ErrTruncatedIncorrect.Is(err) { @@ -123,7 +122,7 @@ func (c *Char) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } ctx.Warn(1292, "%s", err.Error()) } - res = append(res, char(v.(uint32))...) + res = append(res, encodeUInt32(v.(uint32))...) } result, _, err := c.Type().Convert(ctx, res) From 00e92a7142afdd19985a90b3926b47fdd05774c9 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 11 Sep 2025 17:09:00 -0700 Subject: [PATCH 10/48] special case --- sql/expression/convert.go | 6 ++++++ sql/types/conversion.go | 10 +++++----- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/sql/expression/convert.go b/sql/expression/convert.go index a70ef3bd5f..90c9d53eed 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -347,6 +347,12 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s } return d, nil case ConvertToDecimal: + // TODO: HexBlobs shouldn't make it this far + var err error + val, err = types.ConvertHexBlobToUint(val, originType) + if err != nil { + return nil, err + } dt := createConvertedDecimalType(typeLength, typeScale, false) d, _, err := dt.Convert(ctx, val) if err != nil { diff --git a/sql/types/conversion.go b/sql/types/conversion.go index 5381e5c555..02648d1a14 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -33,7 +33,7 @@ import ( ) // ApproximateTypeFromValue returns the closest matching type to the given value. For example, an int16 will return SMALLINT. -func ApproximateTypeFromValue(val interface{}) sql.Type { +func ApproximateTypeFromValue(val any) sql.Type { switch v := val.(type) { case bool: return Boolean @@ -460,7 +460,7 @@ func ColumnTypeToType(ct *sqlparser.ColumnType) (sql.Type, error) { // CompareNulls compares two values, and returns true if either is null. // The returned integer represents the ordering, with a rule that states nulls // as being ordered before non-nulls. -func CompareNulls(a interface{}, b interface{}) (bool, int) { +func CompareNulls(a any, b any) (bool, int) { aIsNull := a == nil bIsNull := b == nil if aIsNull && bIsNull { @@ -751,7 +751,7 @@ func GeneralizeTypes(a, b sql.Type) sql.Type { // TypeAwareConversion converts a value to a specified type, with awareness of the value's original type. This is // necessary because some types, such as EnumType and SetType, are stored as ints and require information from the // original type to properly convert to strings. -func TypeAwareConversion(ctx *sql.Context, val interface{}, originalType sql.Type, convertedType sql.Type) (interface{}, sql.ConvertInRange, error) { +func TypeAwareConversion(ctx *sql.Context, val any, originalType sql.Type, convertedType sql.Type) (any, sql.ConvertInRange, error) { if val == nil { return nil, sql.InRange, nil } @@ -770,7 +770,7 @@ func TypeAwareConversion(ctx *sql.Context, val interface{}, originalType sql.Typ // cleanly and the type is automatically coerced (i.e. string and numeric types), then a warning is logged and the // value is truncated to the Zero value for type |t|. If the value does not convert and the type is not automatically // coerced, then return an error. -func ConvertOrTruncate(ctx *sql.Context, i interface{}, t sql.Type) (interface{}, error) { +func ConvertOrTruncate(ctx *sql.Context, i any, t sql.Type) (any, error) { // Do nothing if type is not provided. if t == nil { return i, nil @@ -812,7 +812,7 @@ func ConvertOrTruncate(ctx *sql.Context, i interface{}, t sql.Type) (interface{} // This function is called when convertTo type is number type only. The hex literal values are parsed into blobs as // binary string as default, but for numeric context, the value should be a number. // Byte arrays of other SQL types are not handled here. -func ConvertHexBlobToUint(val interface{}, originType sql.Type) (interface{}, error) { +func ConvertHexBlobToUint(val any, originType sql.Type) (any, error) { var err error if bin, isBinary := val.([]byte); isBinary && IsBlobType(originType) { stringVal := hex.EncodeToString(bin) From b1282683f8a58a9d1930cc86c68ed19d73bc5737 Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 11 Sep 2025 23:56:26 +0000 Subject: [PATCH 11/48] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/types/decimal.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 2a212806eb..4d3497dad4 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -17,13 +17,14 @@ package types import ( "context" "fmt" + "math/big" + "reflect" + "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/shopspring/decimal" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" - "math/big" - "reflect" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/sqltypes" From 8d73f9b4449f43d070089bde35427916d17525b8 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 12 Sep 2025 14:05:26 -0700 Subject: [PATCH 12/48] more edge case fixing --- enginetest/queries/alter_table_queries.go | 6 +- memory/table.go | 3 +- sql/analyzer/resolve_column_defaults.go | 2 +- sql/columndefault.go | 12 +++- sql/expression/case.go | 4 +- sql/expression/comparison.go | 4 +- sql/expression/convert.go | 4 +- sql/expression/function/inet_convert.go | 2 +- sql/expression/set.go | 3 + sql/rowexec/ddl_iters.go | 2 +- sql/rowexec/insert.go | 18 +++++- sql/types/conversion.go | 7 ++- sql/types/decimal_test.go | 4 +- sql/types/number.go | 72 ++++++++++++----------- sql/types/number_test.go | 7 ++- 15 files changed, 91 insertions(+), 59 deletions(-) diff --git a/enginetest/queries/alter_table_queries.go b/enginetest/queries/alter_table_queries.go index 4a638e31b5..fd1a37f315 100644 --- a/enginetest/queries/alter_table_queries.go +++ b/enginetest/queries/alter_table_queries.go @@ -1011,9 +1011,9 @@ var AlterTableScripts = []ScriptTest{ Name: "alter modify column type float to bigint", SetUpScript: []string{ "create table t1 (pk int primary key, c1 float);", - "insert into t1 values (1, 0.0)", - "insert into t1 values (2, 127.9)", - "insert into t1 values (3, 42.1)", + "insert into t1 values (1, 0.0);", + "insert into t1 values (2, 127.9);", + "insert into t1 values (3, 42.1);", }, Assertions: []ScriptTestAssertion{ { diff --git a/memory/table.go b/memory/table.go index a7c6a0584f..13cc1cdd2c 100644 --- a/memory/table.go +++ b/memory/table.go @@ -1430,7 +1430,8 @@ func (t *Table) ModifyColumn(ctx *sql.Context, columnName string, column *sql.Co oldRowWithoutVal = append(oldRowWithoutVal, row[:oldIdx]...) oldRowWithoutVal = append(oldRowWithoutVal, row[oldIdx+1:]...) oldType := data.schema.Schema[oldIdx].Type - newVal, inRange, err := types.TypeAwareConversion(ctx, row[oldIdx], oldType, column.Type) + // TODO: this needs to call the rounding conversion thing + newVal, inRange, err := types.TypeAwareConversion(ctx, row[oldIdx], oldType, column.Type, true) if err != nil { if sql.ErrNotMatchingSRID.Is(err) { err = sql.ErrNotMatchingSRIDWithColName.New(columnName, err) diff --git a/sql/analyzer/resolve_column_defaults.go b/sql/analyzer/resolve_column_defaults.go index a8ed9f8124..d5275808a3 100644 --- a/sql/analyzer/resolve_column_defaults.go +++ b/sql/analyzer/resolve_column_defaults.go @@ -465,7 +465,7 @@ func normalizeDefault(ctx *sql.Context, colDefault *sql.ColumnDefaultValue) (sql } val, err := colDefault.Eval(ctx, nil) if err != nil { - return colDefault, transform.SameTree, nil + return nil, transform.SameTree, err } newDefault, err := colDefault.WithChildren(expression.NewLiteral(val, typ)) diff --git a/sql/columndefault.go b/sql/columndefault.go index 1f61e01b6e..3e8a5f967f 100644 --- a/sql/columndefault.go +++ b/sql/columndefault.go @@ -82,9 +82,15 @@ func (e *ColumnDefaultValue) Eval(ctx *Context, r Row) (interface{}, error) { if e.OutType != nil { var inRange ConvertInRange - if val, inRange, err = e.OutType.Convert(ctx, val); err != nil { + if roundType, isRoundType := e.OutType.(RoundingNumberType); isRoundType { + val, inRange, err = roundType.ConvertRound(ctx, val) + } else { + val, inRange, err = e.OutType.Convert(ctx, val) + } + if err != nil { return nil, ErrIncompatibleDefaultType.New() - } else if !inRange { + } + if !inRange { return nil, ErrValueOutOfRange.New(val, e.OutType) } } @@ -228,7 +234,7 @@ func (e *ColumnDefaultValue) CheckType(ctx *Context) error { return ErrIncompatibleDefaultType.New() } _, inRange, err := e.OutType.Convert(ctx, val) - if err != nil { + if err != nil && !ErrTruncatedIncorrect.Is(err) { return ErrIncompatibleDefaultType.Wrap(err) } else if !inRange { return ErrIncompatibleDefaultType.Wrap(ErrValueOutOfRange.New(val, e.Expr)) diff --git a/sql/expression/case.go b/sql/expression/case.go index 6724ba711c..1cfc70a9e0 100644 --- a/sql/expression/case.go +++ b/sql/expression/case.go @@ -136,7 +136,7 @@ func (c *Case) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } // When unable to convert to the type of the case, return the original value // A common error here is "Out of bounds value for decimal type" - if ret, inRange, err := types.TypeAwareConversion(ctx, bval, b.Value.Type(), t); inRange && err == nil { + if ret, inRange, err := types.TypeAwareConversion(ctx, bval, b.Value.Type(), t, false); inRange && err == nil { return ret, nil } return bval, nil @@ -150,7 +150,7 @@ func (c *Case) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } // When unable to convert to the type of the case, return the original value // A common error here is "Out of bounds value for decimal type" - if ret, inRange, err := types.TypeAwareConversion(ctx, val, c.Else.Type(), t); inRange && err == nil { + if ret, inRange, err := types.TypeAwareConversion(ctx, val, c.Else.Type(), t, false); inRange && err == nil { return ret, nil } return val, nil diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 7ea42c475d..d31b98b5d6 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -188,7 +188,7 @@ func (c *comparison) CastLeftAndRight(ctx *sql.Context, left, right interface{}) if r, inRange, err := leftType.Convert(ctx, right); inRange && err == nil { return left, r, leftType, nil } else { - l, _, err := types.TypeAwareConversion(ctx, left, leftType, rightType) + l, _, err := types.TypeAwareConversion(ctx, left, leftType, rightType, false) if err != nil { return nil, nil, nil, err } @@ -200,7 +200,7 @@ func (c *comparison) CastLeftAndRight(ctx *sql.Context, left, right interface{}) if l, inRange, err := rightType.Convert(ctx, left); inRange && err == nil { return l, right, rightType, nil } else { - r, _, err := types.TypeAwareConversion(ctx, right, rightType, leftType) + r, _, err := types.TypeAwareConversion(ctx, right, rightType, leftType, false) if err != nil { return nil, nil, nil, err } diff --git a/sql/expression/convert.go b/sql/expression/convert.go index 90c9d53eed..e62cc4a089 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -299,7 +299,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s } switch strings.ToLower(castTo) { case ConvertToBinary: - b, _, err := types.TypeAwareConversion(ctx, val, originType, types.LongBlob) + b, _, err := types.TypeAwareConversion(ctx, val, originType, types.LongBlob, false) if err != nil { return nil, nil } @@ -317,7 +317,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s } return truncateConvertedValue(b, typeLength) case ConvertToChar, ConvertToNChar: - s, _, err := types.TypeAwareConversion(ctx, val, originType, types.LongText) + s, _, err := types.TypeAwareConversion(ctx, val, originType, types.LongText, false) if err != nil { return nil, nil } diff --git a/sql/expression/function/inet_convert.go b/sql/expression/function/inet_convert.go index b612ee43d0..5ec844383a 100644 --- a/sql/expression/function/inet_convert.go +++ b/sql/expression/function/inet_convert.go @@ -242,7 +242,7 @@ func (i *InetNtoa) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // Convert val into int ipv4int, _, err := types.Int32.Convert(ctx, val) - if ipv4int != nil && err != nil { + if ipv4int != nil && err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return nil, sql.ErrInvalidType.New(reflect.TypeOf(val).String()) } diff --git a/sql/expression/set.go b/sql/expression/set.go index a2eff1dbc1..c10c154a84 100644 --- a/sql/expression/set.go +++ b/sql/expression/set.go @@ -73,6 +73,9 @@ func (s *SetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } if val != nil { convertedVal, _, err := getField.fieldType.Convert(ctx, val) + if sql.ErrTruncatedIncorrect.Is(err) { + err = sql.ErrInvalidValue.New(val, getField.fieldType) + } if err != nil { // Fill in error with information if types.ErrLengthBeyondLimit.Is(err) { diff --git a/sql/rowexec/ddl_iters.go b/sql/rowexec/ddl_iters.go index 8d638f79da..5669ce5c2c 100644 --- a/sql/rowexec/ddl_iters.go +++ b/sql/rowexec/ddl_iters.go @@ -922,7 +922,7 @@ func projectRowWithTypes(ctx *sql.Context, oldSchema, newSchema sql.Schema, proj } for i := range newRow { - converted, inRange, err := types.TypeAwareConversion(ctx, newRow[i], oldSchema[i].Type, newSchema[i].Type) + converted, inRange, err := types.TypeAwareConversion(ctx, newRow[i], oldSchema[i].Type, newSchema[i].Type, false) if err != nil { if sql.ErrNotMatchingSRID.Is(err) { err = sql.ErrNotMatchingSRIDWithColName.New(newSchema[i].Name, err) diff --git a/sql/rowexec/insert.go b/sql/rowexec/insert.go index d78ee5372f..f443919260 100644 --- a/sql/rowexec/insert.go +++ b/sql/rowexec/insert.go @@ -120,12 +120,26 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) ctxWithValues := context.WithValue(ctx.Context, types.ColumnNameKey, col.Name) ctxWithValues = context.WithValue(ctxWithValues, types.RowNumberKey, i.rowNumber) ctxWithColumnInfo := ctx.WithContext(ctxWithValues) - // TODO: add a ConvertForInsert? // TODO: check mysql strict mode - converted, inRange, cErr := col.Type.Convert(ctxWithColumnInfo, row[idx]) + var converted any + var inRange sql.ConvertInRange + var cErr error + // Hex strings shouldn't make it this far + val, cErr := types.ConvertHexBlobToUint(row[idx], col.Type) + if cErr != nil { + return nil, i.ignoreOrClose(ctx, origRow, cErr) + } + if typ, ok := col.Type.(sql.RoundingNumberType); ok { + converted, inRange, cErr = typ.ConvertRound(ctx, val) + } else { + converted, inRange, cErr = col.Type.Convert(ctxWithColumnInfo, val) + } if cErr == nil && !inRange { cErr = sql.ErrValueOutOfRange.New(row[idx], col.Type) } + if sql.ErrTruncatedIncorrect.Is(cErr) { + cErr = sql.ErrInvalidValue.New(row[idx], col.Type) + } if cErr != nil { // Ignore individual column errors when INSERT IGNORE, UPDATE IGNORE, etc. is specified. // For JSON column types, always throw an error. MySQL throws the following error even when diff --git a/sql/types/conversion.go b/sql/types/conversion.go index 02648d1a14..e39d00e520 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -751,7 +751,7 @@ func GeneralizeTypes(a, b sql.Type) sql.Type { // TypeAwareConversion converts a value to a specified type, with awareness of the value's original type. This is // necessary because some types, such as EnumType and SetType, are stored as ints and require information from the // original type to properly convert to strings. -func TypeAwareConversion(ctx *sql.Context, val any, originalType sql.Type, convertedType sql.Type) (any, sql.ConvertInRange, error) { +func TypeAwareConversion(ctx *sql.Context, val any, originalType sql.Type, convertedType sql.Type, round bool) (any, sql.ConvertInRange, error) { if val == nil { return nil, sql.InRange, nil } @@ -762,6 +762,11 @@ func TypeAwareConversion(ctx *sql.Context, val any, originalType sql.Type, conve return nil, sql.OutOfRange, err } } + if round { + if roundTyp, isRoundType := convertedType.(sql.RoundingNumberType); isRoundType { + return roundTyp.ConvertRound(ctx, val) + } + } return convertedType.Convert(ctx, val) } diff --git a/sql/types/decimal_test.go b/sql/types/decimal_test.go index e39e3496b6..6c1bed0d64 100644 --- a/sql/types/decimal_test.go +++ b/sql/types/decimal_test.go @@ -317,8 +317,8 @@ func TestDecimalConvert(t *testing.T) { {5, 0, "7742", "7742", false}, {5, 0, new(big.Float).SetFloat64(-4723.875), "-4724", false}, {5, 0, 99999, "99999", false}, - {5, 0, "0xf8e1", "63713", false}, - {5, 0, "0b1001110101100110", "40294", false}, + {5, 0, "0xf8e1", "0", false}, + {5, 0, "0b1001110101100110", "0", false}, {5, 0, new(big.Rat).SetFrac64(999999, 10), "", true}, {5, 0, 673927, "", true}, diff --git a/sql/types/number.go b/sql/types/number.go index 2818adeef8..78a366eec6 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -289,7 +289,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v any) (any, sql.ConvertIn } else if num < math.MinInt32 { return int32(math.MinInt32), sql.OutOfRange, nil } - return int32(num), sql.InRange, nil + return int32(num), sql.InRange, err case sqltypes.Uint32: return convertToUint32(t, v, false) case sqltypes.Int64: @@ -320,68 +320,62 @@ func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v any) (any, sql.Conv switch t.baseType { case sqltypes.Int8, sqltypes.Int16, sqltypes.Int24, sqltypes.Int32, sqltypes.Int64: switch v.(type) { - case float32, float64: + case float32, float64, string: num, inRange, err := convertToInt64(t, v, true) - if err != nil { + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return nil, sql.OutOfRange, err } // TODO: write helper method? switch t.baseType { case sqltypes.Int8: if num > math.MaxInt8 { - return int8(math.MaxInt8), sql.OutOfRange, nil + return int8(math.MaxInt8), sql.OutOfRange, err } if num < math.MinInt8 { - return int8(math.MinInt8), sql.OutOfRange, nil + return int8(math.MinInt8), sql.OutOfRange, err } - return int8(num), sql.InRange, nil + return int8(num), sql.InRange, err case sqltypes.Int16: if num > math.MaxInt16 { - return int16(math.MaxInt16), sql.OutOfRange, nil + return int16(math.MaxInt16), sql.OutOfRange, err } if num < math.MinInt16 { - return int16(math.MinInt16), sql.OutOfRange, nil + return int16(math.MinInt16), sql.OutOfRange, err } - return int16(num), sql.InRange, nil + return int16(num), sql.InRange, err case sqltypes.Int24: if num > (1<<23 - 1) { - return int32(1<<23 - 1), sql.OutOfRange, nil + return int32(1<<23 - 1), sql.OutOfRange, err } if num < (-1 << 23) { - return int32(-1 << 23), sql.OutOfRange, nil + return int32(-1 << 23), sql.OutOfRange, err } - return int32(num), sql.InRange, nil + return int32(num), sql.InRange, err case sqltypes.Int32: if num > math.MaxInt32 { - return int32(math.MaxInt32), sql.OutOfRange, nil + return int32(math.MaxInt32), sql.OutOfRange, err } if num < math.MinInt32 { - return int32(math.MinInt32), sql.OutOfRange, nil + return int32(math.MinInt32), sql.OutOfRange, err } - return int32(num), sql.InRange, nil + return int32(num), sql.InRange, err default: - return num, inRange, nil + return num, inRange, err } - default: - return t.Convert(ctx, v) } case sqltypes.Uint8: switch v.(type) { - case float32, float64: + case float32, float64, string: convertToUint8(t, v, true) - default: - return t.Convert(ctx, v) } case sqltypes.Uint16: switch v.(type) { - case float32, float64: + case float32, float64, string: convertToUint16(t, v, true) - default: - return t.Convert(ctx, v) } case sqltypes.Uint24: switch v.(type) { - case float32, float64: + case float32, float64, string: num, _, err := convertToInt64(t, v, true) if err != nil { return nil, sql.OutOfRange, err @@ -392,22 +386,16 @@ func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v any) (any, sql.Conv return uint32(1<<24 - int32(-num)), sql.OutOfRange, nil } return uint32(num), sql.InRange, nil - default: - return t.Convert(ctx, v) } case sqltypes.Uint32: switch v.(type) { - case float32, float64: + case float32, float64, string: return convertToUint32(t, v, true) - default: - return t.Convert(ctx, v) } case sqltypes.Uint64: switch v.(type) { - case float32, float64: + case float32, float64, string: return convertToUint64(t, v, true) - default: - return t.Convert(ctx, v) } } return t.Convert(ctx, v) @@ -664,7 +652,7 @@ func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v any) (sqltypes.Val return sqltypes.Value{}, sql.ErrInvalidType.New(t.baseType.String()) } - if sql.ErrInvalidValue.Is(err) { + if sql.ErrInvalidValue.Is(err) || sql.ErrTruncatedIncorrect.Is(err) { switch str := v.(type) { case []byte: dest = str @@ -1065,8 +1053,22 @@ func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInR case []byte: return convertToInt64(t, string(v), round) case string: - // TODO: round? + // When round = true, truncation rules are less strict + // Integers will accept valid float notation without truncation error var err error + if round { + truncStr, didTrunc := TruncateStringToNumber(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t.String(), v) + } + // TODO: might not be necessary + // Parse as int first + if i, pErr := strconv.ParseInt(truncStr, 10, 64); pErr == nil { + return i, sql.InRange, nil + } + f, _ := strconv.ParseFloat(v, 64) + return int64(math.Round(f)), sql.InRange, err + } truncStr, didTrunc := TruncateStringToInt(v) if didTrunc { err = sql.ErrTruncatedIncorrect.New(t.String(), v) diff --git a/sql/types/number_test.go b/sql/types/number_test.go index 8c7d2fca0a..4d6a655e1a 100644 --- a/sql/types/number_test.go +++ b/sql/types/number_test.go @@ -42,6 +42,7 @@ func TestNumberCompare(t *testing.T) { {Uint24, 0, nil, -1}, {Float64, nil, nil, 0}, + {Boolean, 0, 1, -1}, {Boolean, 0, 1, -1}, {Int8, -1, 2, -1}, {Int16, -2, 3, -1}, @@ -181,14 +182,14 @@ func TestNumberConvert(t *testing.T) { {typ: Int32, inp: nil, exp: nil, err: false, inRange: sql.InRange}, {typ: Int32, inp: 2147483647, exp: int32(2147483647), err: false, inRange: sql.InRange}, {typ: Int64, inp: "33", exp: int64(33), err: false, inRange: sql.InRange}, - {typ: Int64, inp: "33.0", exp: int64(33), err: false, inRange: sql.InRange}, - {typ: Int64, inp: "33.1", exp: int64(33), err: false, inRange: sql.InRange}, + {typ: Int64, inp: "33.0", exp: int64(33), err: true, inRange: sql.InRange}, + {typ: Int64, inp: "33.1", exp: int64(33), err: true, inRange: sql.InRange}, {typ: Int64, inp: strconv.FormatInt(math.MaxInt64, 10), exp: int64(math.MaxInt64), err: false, inRange: sql.InRange}, {typ: Int64, inp: true, exp: int64(1), err: false, inRange: sql.InRange}, {typ: Int64, inp: false, exp: int64(0), err: false, inRange: sql.InRange}, {typ: Uint8, inp: int64(34), exp: uint8(34), err: false, inRange: sql.InRange}, {typ: Uint16, inp: int16(35), exp: uint16(35), err: false, inRange: sql.InRange}, - {typ: Uint24, inp: 36.756, exp: uint32(37), err: false, inRange: sql.InRange}, + {typ: Uint24, inp: 36.756, exp: uint32(36), err: false, inRange: sql.InRange}, {typ: Uint32, inp: uint8(37), exp: uint32(37), err: false, inRange: sql.InRange}, {typ: Uint64, inp: time.Date(2009, 1, 2, 3, 4, 5, 0, time.UTC), exp: uint64(time.Date(2009, 1, 2, 3, 4, 5, 0, time.UTC).Unix()), err: false, inRange: sql.InRange}, {typ: Uint64, inp: "01000", exp: uint64(1000), err: false, inRange: sql.InRange}, From 071c4e0c2090cc54ed0642730726312efdc5de0e Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 12 Sep 2025 16:09:35 -0700 Subject: [PATCH 13/48] almost done --- enginetest/enginetests.go | 2 +- enginetest/queries/column_default_queries.go | 1 - enginetest/queries/json_table_queries.go | 4 +-- enginetest/queries/script_queries.go | 1 + .../function/aggregation/window_framer.go | 3 ++ sql/expression/interval.go | 2 +- sql/expression/procedurereference.go | 3 ++ sql/iters/rel_iters.go | 10 ++++-- sql/rowexec/insert.go | 14 ++++---- sql/types/number.go | 33 ++++++++++++++----- 10 files changed, 51 insertions(+), 22 deletions(-) diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 6cf5577e52..4357fc6324 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -3960,9 +3960,9 @@ func TestWindowRangeFrames(t *testing.T, harness Harness) { TestQueryWithContext(t, ctx, e, harness, `SELECT sum(y) over (partition by z order by date range between unbounded preceding and interval '1' DAY following) FROM c order by x`, []sql.Row{{float64(1)}, {float64(1)}, {float64(1)}, {float64(1)}, {float64(5)}, {float64(5)}, {float64(10)}, {float64(10)}, {float64(10)}, {float64(10)}}, nil, nil, nil) TestQueryWithContext(t, ctx, e, harness, `SELECT count(y) over (partition by z order by date range between interval '1' DAY following and interval '2' DAY following) FROM c order by x`, []sql.Row{{1}, {1}, {1}, {1}, {1}, {0}, {2}, {2}, {0}, {0}}, nil, nil, nil) TestQueryWithContext(t, ctx, e, harness, `SELECT count(y) over (partition by z order by date range between interval '1' DAY preceding and interval '2' DAY following) FROM c order by x`, []sql.Row{{4}, {4}, {4}, {5}, {2}, {2}, {4}, {4}, {4}, {4}}, nil, nil, nil) + TestQueryWithContext(t, ctx, e, harness, "SELECT sum(y) over (partition by z order by date range interval 'e' DAY preceding) FROM c order by x", []sql.Row{{float64(0)}, {float64(0)}, {float64(0)}, {float64(1)}, {float64(1)}, {float64(3)}, {float64(1)}, {float64(1)}, {float64(4)}, {float64(4)}}, nil, nil, nil) AssertErr(t, e, harness, "SELECT sum(y) over (partition by z range between unbounded preceding and interval '1' DAY following) FROM c order by x", nil, aggregation.ErrRangeInvalidOrderBy) - AssertErr(t, e, harness, "SELECT sum(y) over (partition by z order by date range interval 'e' DAY preceding) FROM c order by x", nil, sql.ErrInvalidValue) } func TestNamedWindows(t *testing.T, harness Harness) { diff --git a/enginetest/queries/column_default_queries.go b/enginetest/queries/column_default_queries.go index 8ab0b9222e..cb53bded6e 100644 --- a/enginetest/queries/column_default_queries.go +++ b/enginetest/queries/column_default_queries.go @@ -571,7 +571,6 @@ var ColumnDefaultTests = []ScriptTest{ }, }, { - // Technically, MySQL does NOT allow BLOB/JSON/TEXT types to have a literal default value, and requires them // to be specified as an expression (i.e. wrapped in parens). We diverge from this behavior and allow it, for // compatibility with MariaDB. For more context, see: https://github.com/dolthub/dolt/issues/7033 Name: "BLOB types can define defaults with literals", diff --git a/enginetest/queries/json_table_queries.go b/enginetest/queries/json_table_queries.go index c4d7e13bcc..d615f068bf 100644 --- a/enginetest/queries/json_table_queries.go +++ b/enginetest/queries/json_table_queries.go @@ -571,7 +571,7 @@ var JSONTableScriptTests = []ScriptTest{ }, { Query: "SELECT * FROM JSON_TABLE('{\"c1\":\"abc\"}', '$' COLUMNS(c1 INT PATH '$.c1' DEFAULT 'def' ON ERROR)) as jt;", - ExpectedErrStr: "error: 'def' is not a valid value for 'int'", + ExpectedErrStr: "Invalid JSON text in argument 1 to function JSON_TABLE: \"Invalid value.\"", }, }, }, @@ -612,7 +612,7 @@ var JSONTableScriptTests = []ScriptTest{ }, { Query: "SELECT * FROM JSON_TABLE('{\"c1\":\"abc\"}', '$' COLUMNS(c1 INT PATH '$.c1' ERROR ON ERROR)) as jt;", - ExpectedErrStr: "error: 'abc' is not a valid value for 'int'", + ExpectedErrStr: "Invalid JSON text in argument 1 to function JSON_TABLE: \"Invalid value.\"", }, }, }, diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index a70dd2502b..8284d40d3c 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -3932,6 +3932,7 @@ CREATE TABLE tab3 ( }, }, { + Skip: true, // TODO: Aaaaaaaaaaaa Name: "Handle hex number to binary conversion", SetUpScript: []string{ "CREATE TABLE hex_nums1 (pk BIGINT PRIMARY KEY, v1 INT, v2 BIGINT UNSIGNED, v3 DOUBLE, v4 BINARY(32));", diff --git a/sql/expression/function/aggregation/window_framer.go b/sql/expression/function/aggregation/window_framer.go index 22ae82941b..a8c1884010 100644 --- a/sql/expression/function/aggregation/window_framer.go +++ b/sql/expression/function/aggregation/window_framer.go @@ -453,6 +453,9 @@ const ( // candidate. This is used as a sliding window algorithm for value ranges. func findInclusionBoundary(ctx *sql.Context, pos, searchStart, partitionEnd int, inclusion, expr sql.Expression, buf sql.WindowBuffer, stopCond stopCond) (int, error) { cur, err := inclusion.Eval(ctx, buf[pos]) + if sql.ErrTruncatedIncorrect.Is(err) { + return 0, nil + } if err != nil { return 0, err } diff --git a/sql/expression/interval.go b/sql/expression/interval.go index 0cf8e23a80..7ee22f744a 100644 --- a/sql/expression/interval.go +++ b/sql/expression/interval.go @@ -139,7 +139,7 @@ func (i *Interval) EvalDelta(ctx *sql.Context, row sql.Row) (*TimeDelta, error) } } else { val, _, err = types.Int64.Convert(ctx, val) - if err != nil { + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return nil, err } diff --git a/sql/expression/procedurereference.go b/sql/expression/procedurereference.go index 4baef784ee..ce9adfb4d4 100644 --- a/sql/expression/procedurereference.go +++ b/sql/expression/procedurereference.go @@ -68,6 +68,9 @@ func (ppr *ProcedureReference) InitializeVariable(ctx *sql.Context, name string, return fmt.Errorf("cannot initialize variable `%s` in an empty procedure reference", name) } convertedVal, _, err := sqlType.Convert(ctx, val) + if sql.ErrTruncatedIncorrect.Is(err) { + return sql.ErrInvalidValue.New(val, sqlType) + } if err != nil { return err } diff --git a/sql/iters/rel_iters.go b/sql/iters/rel_iters.go index 47439d5dc3..0bbfa2179d 100644 --- a/sql/iters/rel_iters.go +++ b/sql/iters/rel_iters.go @@ -290,11 +290,17 @@ func (c *JsonTableCol) Next(ctx *sql.Context, obj interface{}, pass bool, ord in val, _, err = c.Opts.Typ.Convert(ctx, val) if err != nil { if c.Opts.ErrOnErr { - return nil, err + if sql.ErrTruncatedIncorrect.Is(err) { + return nil, sql.ErrInvalidJSONText.New(c.pos+1, "JSON_TABLE", "Invalid value.") + } + return nil, sql.ErrInvalidJSONText.New(c.pos+1, "JSON_TABLE", err.Error()) } val, _, err = c.Opts.Typ.Convert(ctx, c.Opts.DefErrVal) if err != nil { - return nil, err + if sql.ErrTruncatedIncorrect.Is(err) { + return nil, sql.ErrInvalidJSONText.New(c.pos+1, "JSON_TABLE", "Invalid value.") + } + return nil, sql.ErrInvalidJSONText.New(c.pos+1, "JSON_TABLE", err.Error()) } } diff --git a/sql/rowexec/insert.go b/sql/rowexec/insert.go index f443919260..27aedcbc42 100644 --- a/sql/rowexec/insert.go +++ b/sql/rowexec/insert.go @@ -120,25 +120,27 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) ctxWithValues := context.WithValue(ctx.Context, types.ColumnNameKey, col.Name) ctxWithValues = context.WithValue(ctxWithValues, types.RowNumberKey, i.rowNumber) ctxWithColumnInfo := ctx.WithContext(ctxWithValues) + val := row[idx] // TODO: check mysql strict mode var converted any var inRange sql.ConvertInRange var cErr error + // TODO: AAAAHHH // Hex strings shouldn't make it this far - val, cErr := types.ConvertHexBlobToUint(row[idx], col.Type) - if cErr != nil { - return nil, i.ignoreOrClose(ctx, origRow, cErr) - } + //val, cErr = types.ConvertHexBlobToUint(row[idx], col.Type) + //if cErr != nil { + // return nil, i.ignoreOrClose(ctx, origRow, cErr) + //} if typ, ok := col.Type.(sql.RoundingNumberType); ok { converted, inRange, cErr = typ.ConvertRound(ctx, val) } else { converted, inRange, cErr = col.Type.Convert(ctxWithColumnInfo, val) } if cErr == nil && !inRange { - cErr = sql.ErrValueOutOfRange.New(row[idx], col.Type) + cErr = sql.ErrValueOutOfRange.New(val, col.Type) } if sql.ErrTruncatedIncorrect.Is(cErr) { - cErr = sql.ErrInvalidValue.New(row[idx], col.Type) + cErr = sql.ErrInvalidValue.New(val, col.Type) } if cErr != nil { // Ignore individual column errors when INSERT IGNORE, UPDATE IGNORE, etc. is specified. diff --git a/sql/types/number.go b/sql/types/number.go index 78a366eec6..7b94c0f73d 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -1258,22 +1258,22 @@ func convertToUint64(t NumberTypeImpl_, v any, round bool) (uint64, sql.ConvertI return i, sql.InRange, nil case string: var err error - s, ok := TruncateStringToInt(v) - if ok { + truncStr, didTrunc := TruncateStringToInt(v) + if didTrunc { err = sql.ErrTruncatedIncorrect.New(t.String(), v) } - if len(s) == 0 { + if len(truncStr) == 0 { return 0, sql.InRange, err } // Trim leading sign neg := false - if s[0] == '+' { - s = s[1:] - } else if s[0] == '-' { + if truncStr[0] == '+' { + truncStr = truncStr[1:] + } else if truncStr[0] == '-' { neg = true - s = s[1:] + truncStr = truncStr[1:] } - i, pErr := strconv.ParseUint(s, 10, 64) + i, pErr := strconv.ParseUint(truncStr, 10, 64) if errors.Is(pErr, strconv.ErrRange) { // Number is too large for uint64, return max value and OutOfRange return math.MaxUint64, sql.OutOfRange, err @@ -1386,7 +1386,22 @@ func convertToUint32(t NumberTypeImpl_, v any, round bool) (uint32, sql.ConvertI if len(truncStr) == 0 { return 0, sql.InRange, err } - i, _ := strconv.ParseInt(truncStr, 10, 32) + // Trim leading sign + neg := false + if truncStr[0] == '+' { + truncStr = truncStr[1:] + } else if truncStr[0] == '-' { + neg = true + truncStr = truncStr[1:] + } + i, pErr := strconv.ParseUint(truncStr, 10, 32) + if errors.Is(pErr, strconv.ErrRange) || i > math.MaxUint32 { + // Number is too large for uint32, return max value and OutOfRange + return math.MaxUint32, sql.OutOfRange, err + } + if neg { + return uint32(math.MaxUint32 - i + 1), sql.OutOfRange, err + } return uint32(i), sql.InRange, err case bool: if v { From 7f3900f21eca403f94e6161c1881ceccd928fe60 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 15 Sep 2025 10:14:19 -0700 Subject: [PATCH 14/48] cleaning up --- enginetest/queries/column_default_queries.go | 1 + memory/table.go | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/enginetest/queries/column_default_queries.go b/enginetest/queries/column_default_queries.go index cb53bded6e..8ab0b9222e 100644 --- a/enginetest/queries/column_default_queries.go +++ b/enginetest/queries/column_default_queries.go @@ -571,6 +571,7 @@ var ColumnDefaultTests = []ScriptTest{ }, }, { + // Technically, MySQL does NOT allow BLOB/JSON/TEXT types to have a literal default value, and requires them // to be specified as an expression (i.e. wrapped in parens). We diverge from this behavior and allow it, for // compatibility with MariaDB. For more context, see: https://github.com/dolthub/dolt/issues/7033 Name: "BLOB types can define defaults with literals", diff --git a/memory/table.go b/memory/table.go index 13cc1cdd2c..9932d0399a 100644 --- a/memory/table.go +++ b/memory/table.go @@ -1430,7 +1430,6 @@ func (t *Table) ModifyColumn(ctx *sql.Context, columnName string, column *sql.Co oldRowWithoutVal = append(oldRowWithoutVal, row[:oldIdx]...) oldRowWithoutVal = append(oldRowWithoutVal, row[oldIdx+1:]...) oldType := data.schema.Schema[oldIdx].Type - // TODO: this needs to call the rounding conversion thing newVal, inRange, err := types.TypeAwareConversion(ctx, row[oldIdx], oldType, column.Type, true) if err != nil { if sql.ErrNotMatchingSRID.Is(err) { From 8c5f4e82222fc1d87c5b07a43b3f20d247d39bf6 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 15 Sep 2025 10:30:06 -0700 Subject: [PATCH 15/48] fixing tests --- enginetest/memory_engine_test.go | 162 +-------------------------- sql/expression/function/bit_count.go | 9 +- sql/types/number.go | 2 +- 3 files changed, 9 insertions(+), 164 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index a671d24cda..0d22da23c5 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -203,174 +203,18 @@ func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ { - Skip: true, - Name: "asdf", + Name: "test", SetUpScript: []string{}, Assertions: []queries.ScriptTestAssertion{ { - Query: "select cast('-3.1a' as signed);", + Query: "select bit_count('2.99a');", Expected: []sql.Row{ {-3}, }, }, }, }, - { - // https://github.com/dolthub/dolt/issues/9733 - // https://github.com/dolthub/dolt/issues/9739 - //Skip: true, - Name: "strings cast to numbers", - SetUpScript: []string{ - "create table test01(pk varchar(20) primary key);", - `insert into test01 values - (' 3 12 4'), - (' 3.2 12 4'), - ('-3.1234'), - ('-3.1a'), - ('-5+8'), - ('+3.1234'), - ('11d'), - ('11wha?'), - ('11'), - ('12'), - ('1a1'), - ('a1a1'), - ('11-5'), - ('3. 12 4'), - ('5.932887e+07'), - ('5.932887e+07abc'), - ('5.932887e7'), - ('5.932887e7abc');`, - }, - Assertions: []queries.ScriptTestAssertion{ - { - Dialect: "mysql", - Query: "select pk, cast(pk as signed) from test01", - Expected: []sql.Row{ - {" 3 12 4", 3}, - {" 3.2 12 4", 3}, - {"-3.1234", -3}, - {"-3.1a", -3}, - {"-5+8", -5}, - {"+3.1234", 3}, - {"11", 11}, - {"11-5", 11}, - {"11d", 11}, - {"11wha?", 11}, - {"12", 12}, - {"1a1", 1}, - {"3. 12 4", 3}, - {"5.932887e+07", 5}, - {"5.932887e+07abc", 5}, - {"5.932887e7", 5}, - {"5.932887e7abc", 5}, - {"a1a1", 0}, - }, - }, - { - Dialect: "mysql", - Query: "select pk, cast(pk as unsigned) from test01", - Expected: []sql.Row{ - {" 3 12 4", uint64(3)}, - {" 3.2 12 4", uint64(3)}, - {"-3.1234", uint64(18446744073709551613)}, - {"-3.1a", uint64(18446744073709551613)}, - {"-5+8", uint64(18446744073709551611)}, - {"+3.1234", uint64(3)}, - {"11", uint64(11)}, - {"11-5", uint64(11)}, - {"11d", uint64(11)}, - {"11wha?", uint64(11)}, - {"12", uint64(12)}, - {"1a1", uint64(1)}, - {"3. 12 4", uint64(3)}, - {"5.932887e+07", uint64(5)}, - {"5.932887e+07abc", uint64(5)}, - {"5.932887e7", uint64(5)}, - {"5.932887e7abc", uint64(5)}, - {"a1a1", uint64(0)}, - }, - }, - { - Dialect: "mysql", - Query: "select pk, cast(pk as decimal(12,3)) from test01", - Expected: []sql.Row{ - {" 3 12 4", "3.000"}, - {" 3.2 12 4", "3.200"}, - {"-3.1234", "-3.123"}, - {"-3.1a", "-3.100"}, - {"-5+8", "-5.000"}, - {"+3.1234", "3.123"}, - {"11", "11.000"}, - {"11-5", "11.000"}, - {"11d", "11.000"}, - {"11wha?", "11.000"}, - {"12", "12.000"}, - {"1a1", "1.000"}, - {"3. 12 4", "3.000"}, - {"5.932887e+07", "59328870.000"}, - {"5.932887e+07abc", "59328870.000"}, - {"5.932887e7", "59328870.000"}, - {"5.932887e7abc", "59328870.000"}, - {"a1a1", "0.000"}, - }, - }, - { - Query: "select * from test01 where pk in ('11')", - Expected: []sql.Row{{"11"}}, - }, - { - // https://github.com/dolthub/dolt/issues/9739 - Skip: true, - Dialect: "mysql", - Query: "select * from test01 where pk in (11)", - Expected: []sql.Row{ - {"11"}, - {"11-5"}, - {"11d"}, - {"11wha?"}, - }, - }, - { - // https://github.com/dolthub/dolt/issues/9739 - Skip: true, - Dialect: "mysql", - Query: "select * from test01 where pk=3", - Expected: []sql.Row{ - {" 3 12 4"}, - {" 3. 12 4"}, - {"3. 12 4"}, - }, - }, - { - // https://github.com/dolthub/dolt/issues/9739 - Skip: true, - Dialect: "mysql", - Query: "select * from test01 where pk>=3 and pk < 4", - Expected: []sql.Row{ - {" 3 12 4"}, - {" 3. 12 4"}, - {" 3.2 12 4"}, - {"+3.1234"}, - {"3. 12 4"}, - }, - }, - //{ - // // https://github.com/dolthub/dolt/issues/9739 - // Skip: true, - // Dialect: "mysql", - // Query: "select * from test02 where pk in ('11asdf')", - // Expected: []sql.Row{{"11"}}, - //}, - //{ - // // https://github.com/dolthub/dolt/issues/9739 - // Skip: true, - // Dialect: "mysql", - // Query: "select * from test02 where pk='11.12asdf'", - // Expected: []sql.Row{}, - //}, - }, - }, + //{ // Name: "AS OF propagates to nested CALLs", // SetUpScript: []string{}, diff --git a/sql/expression/function/bit_count.go b/sql/expression/function/bit_count.go index 8fbb0279f4..1830ad67b0 100644 --- a/sql/expression/function/bit_count.go +++ b/sql/expression/function/bit_count.go @@ -97,12 +97,13 @@ func (b *BitCount) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { res += countBits(uint64(v)) } default: - num, _, err := types.Int64.Convert(ctx, child) + num, _, err := types.Int64.(sql.RoundingNumberType).ConvertRound(ctx, child) if err != nil { - ctx.Warn(1292, "Truncated incorrect INTEGER value: '%v'", child) - num = int64(0) + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(1292, "%s", err.Error()) } - // Must convert to unsigned because shifting a negative signed value fills with 1s res = countBits(uint64(num.(int64))) } diff --git a/sql/types/number.go b/sql/types/number.go index 7b94c0f73d..c3057e2b1b 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -1066,7 +1066,7 @@ func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInR if i, pErr := strconv.ParseInt(truncStr, 10, 64); pErr == nil { return i, sql.InRange, nil } - f, _ := strconv.ParseFloat(v, 64) + f, _ := strconv.ParseFloat(truncStr, 64) return int64(math.Round(f)), sql.InRange, err } truncStr, didTrunc := TruncateStringToInt(v) From 92c424fdc08f0c068801b8487a7ad789766f4f45 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 15 Sep 2025 12:14:51 -0700 Subject: [PATCH 16/48] fix more tests --- sql/expression/function/bit_count.go | 2 +- sql/expression/function/bit_count_test.go | 9 ++++++--- sql/expression/function/char.go | 21 ++++++++++---------- sql/expression/function/elt.go | 10 ++++++---- sql/expression/function/export_set.go | 5 ++++- sql/expression/function/export_set_test.go | 4 +++- sql/types/number.go | 23 +++++----------------- 7 files changed, 36 insertions(+), 38 deletions(-) diff --git a/sql/expression/function/bit_count.go b/sql/expression/function/bit_count.go index 1830ad67b0..56c4d68946 100644 --- a/sql/expression/function/bit_count.go +++ b/sql/expression/function/bit_count.go @@ -97,7 +97,7 @@ func (b *BitCount) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { res += countBits(uint64(v)) } default: - num, _, err := types.Int64.(sql.RoundingNumberType).ConvertRound(ctx, child) + num, _, err := types.Int64.Convert(ctx, child) if err != nil { if !sql.ErrTruncatedIncorrect.Is(err) { return nil, err diff --git a/sql/expression/function/bit_count_test.go b/sql/expression/function/bit_count_test.go index 3c0373e743..c3d2e5ec05 100644 --- a/sql/expression/function/bit_count_test.go +++ b/sql/expression/function/bit_count_test.go @@ -130,13 +130,16 @@ func TestBitCount(t *testing.T) { err: false, }, { - // we don't do truncation yet - // https://github.com/dolthub/dolt/issues/7302 + name: "valid float strings do not round", + arg: expression.NewLiteral("2.99", types.Text), + exp: int32(1), + err: false, + }, + { name: "scientific string is truncated", arg: expression.NewLiteral("1e1", types.Text), exp: int32(1), err: false, - skip: true, }, } diff --git a/sql/expression/function/char.go b/sql/expression/function/char.go index a071ad6ccd..531327a046 100644 --- a/sql/expression/function/char.go +++ b/sql/expression/function/char.go @@ -87,18 +87,19 @@ func (c *Char) CollationCoercibility(ctx *sql.Context) (collation sql.CollationI // encodeUInt32 converts uint32 `num` into a []byte using the fewest number of bytes in big endian (no leading 0s) func encodeUInt32(num uint32) []byte { - res := make([]byte, 0, 4) - if x := byte(num >> 24); x > 0 { - res = append(res, x) + res := []byte{ + byte(num >> 24), + byte(num >> 16), + byte(num >> 8), + byte(num), } - if x := byte(num >> 16); x > 0 { - res = append(res, x) - } - if x := byte(num >> 8); x > 0 { - res = append(res, x) + var i int + for i = 0; i < 3; i++ { + if res[i] != 0 { + break + } } - res = append(res, byte(num)) - return res + return res[i:] } // Eval implements the sql.Expression interface diff --git a/sql/expression/function/elt.go b/sql/expression/function/elt.go index ba2d050b96..00fb6bf01a 100644 --- a/sql/expression/function/elt.go +++ b/sql/expression/function/elt.go @@ -116,11 +116,13 @@ func (e *Elt) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, nil } - indexInt, _, err := types.Int64.Convert(ctx, index) + // TODO: aaaaaaaaaaaaahhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhh + indexInt, _, err := types.Int64.(sql.RoundingNumberType).ConvertRound(ctx, index) if err != nil { - // TODO: truncate - ctx.Warn(1292, "Truncated incorrect INTEGER value: '%v'", index) - indexInt = int64(0) + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(1292, "%s", err.Error()) } idx := int(indexInt.(int64)) diff --git a/sql/expression/function/export_set.go b/sql/expression/function/export_set.go index 9356ad7b22..bfd7648ea2 100644 --- a/sql/expression/function/export_set.go +++ b/sql/expression/function/export_set.go @@ -205,7 +205,10 @@ func (e *ExportSet) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // Convert arguments to proper types bitsInt, _, err := types.Uint64.Convert(ctx, bitsVal) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(1292, "%s", err.Error()) } onStr, _, err := types.LongText.Convert(ctx, onVal) diff --git a/sql/expression/function/export_set_test.go b/sql/expression/function/export_set_test.go index c6425211f3..5341866d76 100644 --- a/sql/expression/function/export_set_test.go +++ b/sql/expression/function/export_set_test.go @@ -72,7 +72,9 @@ func TestExportSet(t *testing.T) { {"null number of bits", []interface{}{5, "1", "0", ",", nil}, nil, false}, // Type conversion - {"string number", []interface{}{"5", "1", "0", ",", 4}, "1,0,1,0", false}, + {"string integer", []interface{}{"5", "1", "0", ",", 4}, "1,0,1,0", false}, + {"string float 5.99", []interface{}{"5.99", "1", "0", ",", 4}, "1,0,1,0", false}, + {"string float 5.01", []interface{}{"5.01", "1", "0", ",", 4}, "1,0,1,0", false}, {"float number", []interface{}{5.7, "1", "0", ",", 4}, "0,1,1,0", false}, {"negative number", []interface{}{-1, "1", "0", ",", 4}, "1,1,1,1", false}, } diff --git a/sql/types/number.go b/sql/types/number.go index c3057e2b1b..0cc56e9bdd 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -1026,10 +1026,7 @@ func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInR if v < float32(math.MinInt64) { return math.MinInt64, sql.OutOfRange, nil } - if round { - return int64(math.Round(float64(v))), sql.InRange, nil - } - return int64(v), sql.InRange, nil + return int64(math.Round(float64(v))), sql.InRange, nil case float64: if v > float64(math.MaxInt64) { return math.MaxInt64, sql.OutOfRange, nil @@ -1037,10 +1034,7 @@ func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInR if v < float64(math.MinInt64) { return math.MinInt64, sql.OutOfRange, nil } - if round { - return int64(math.Round(v)), sql.InRange, nil - } - return int64(v), sql.InRange, nil + return int64(math.Round(v)), sql.InRange, nil case decimal.Decimal: // TODO: round? if v.GreaterThan(dec_int64_max) { @@ -1061,13 +1055,12 @@ func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInR if didTrunc { err = sql.ErrTruncatedIncorrect.New(t.String(), v) } - // TODO: might not be necessary // Parse as int first if i, pErr := strconv.ParseInt(truncStr, 10, 64); pErr == nil { return i, sql.InRange, nil } f, _ := strconv.ParseFloat(truncStr, 64) - return int64(math.Round(f)), sql.InRange, err + return convertToInt64(t, f, round) } truncStr, didTrunc := TruncateStringToInt(v) if didTrunc { @@ -1224,10 +1217,7 @@ func convertToUint64(t NumberTypeImpl_, v any, round bool) (uint64, sql.ConvertI if v < 0 { return uint64(math.MaxUint64 - v), sql.OutOfRange, nil } - if round { - return uint64(math.Round(float64(v))), sql.InRange, nil - } - return uint64(v), sql.InRange, nil + return uint64(math.Round(float64(v))), sql.InRange, nil case float64: if v >= float64(math.MaxUint64) { return math.MaxUint64, sql.OutOfRange, nil @@ -1235,10 +1225,7 @@ func convertToUint64(t NumberTypeImpl_, v any, round bool) (uint64, sql.ConvertI if v < 0 { return uint64(math.MaxUint64 - v), sql.OutOfRange, nil } - if round { - return uint64(math.Round(v)), sql.InRange, nil - } - return uint64(v), sql.InRange, nil + return uint64(math.Round(v)), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_uint64_max) { return math.MaxUint64, sql.OutOfRange, nil From ef266d15859824ade25fe9c1a3327ab3bc58774c Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 15 Sep 2025 13:16:48 -0700 Subject: [PATCH 17/48] more tests --- sql/rowexec/insert_test.go | 1 + sql/types/number.go | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/rowexec/insert_test.go b/sql/rowexec/insert_test.go index 6213f806c1..1b669ecbe0 100644 --- a/sql/rowexec/insert_test.go +++ b/sql/rowexec/insert_test.go @@ -38,6 +38,7 @@ func TestInsertIgnoreConversions(t *testing.T) { err bool }{ { + // TODO: this only works when sql_mode does not have STRICT_TRANS_TABLES / STRICT_ALL_TABLES name: "inserting a string into a integer defaults to a 0", colType: types.Int64, value: "dadasd", diff --git a/sql/types/number.go b/sql/types/number.go index 0cc56e9bdd..132836188f 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -1060,7 +1060,11 @@ func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInR return i, sql.InRange, nil } f, _ := strconv.ParseFloat(truncStr, 64) - return convertToInt64(t, f, round) + res, outOfRange, cErr := convertToInt64(t, f, round) + if cErr != nil { + err = cErr + } + return res, outOfRange, err } truncStr, didTrunc := TruncateStringToInt(v) if didTrunc { From 55eb7aea97e4797de7a758eae0ce32a0203a7aa6 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 15 Sep 2025 13:42:03 -0700 Subject: [PATCH 18/48] fix tests --- sql/types/number_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/types/number_test.go b/sql/types/number_test.go index 4d6a655e1a..bc79a20be5 100644 --- a/sql/types/number_test.go +++ b/sql/types/number_test.go @@ -189,7 +189,7 @@ func TestNumberConvert(t *testing.T) { {typ: Int64, inp: false, exp: int64(0), err: false, inRange: sql.InRange}, {typ: Uint8, inp: int64(34), exp: uint8(34), err: false, inRange: sql.InRange}, {typ: Uint16, inp: int16(35), exp: uint16(35), err: false, inRange: sql.InRange}, - {typ: Uint24, inp: 36.756, exp: uint32(36), err: false, inRange: sql.InRange}, + {typ: Uint24, inp: 36.756, exp: uint32(37), err: false, inRange: sql.InRange}, {typ: Uint32, inp: uint8(37), exp: uint32(37), err: false, inRange: sql.InRange}, {typ: Uint64, inp: time.Date(2009, 1, 2, 3, 4, 5, 0, time.UTC), exp: uint64(time.Date(2009, 1, 2, 3, 4, 5, 0, time.UTC).Unix()), err: false, inRange: sql.InRange}, {typ: Uint64, inp: "01000", exp: uint64(1000), err: false, inRange: sql.InRange}, From e4ab60758665431dc01ef25bdc5412c97b1d089a Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 15 Sep 2025 14:12:15 -0700 Subject: [PATCH 19/48] more --- server/handler_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/handler_test.go b/server/handler_test.go index 03bf918754..6ebf543561 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -1572,7 +1572,7 @@ func TestStatusVariableMaxUsedConnections(t *testing.T) { } checkGlobalStatVar(t, "Max_used_connections", uint64(0)) - checkGlobalStatVar(t, "Max_used_connections_time", "") + checkGlobalStatVar(t, "Max_used_connections_time", uint64(0)) conn1 := newConn(1) handler.NewConnection(conn1) From 57e44c960c146de074eefba15821aac74618cb3d Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 15 Sep 2025 14:39:09 -0700 Subject: [PATCH 20/48] rebase --- sql/expression/convert.go | 1 - sql/types/decimal.go | 3 --- 2 files changed, 4 deletions(-) diff --git a/sql/expression/convert.go b/sql/expression/convert.go index e62cc4a089..d75ba20aad 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -19,7 +19,6 @@ import ( "strings" "time" - "github.com/dolthub/vitess/go/mysql" "github.com/dolthub/vitess/go/sqltypes" "github.com/sirupsen/logrus" "gopkg.in/src-d/go-errors.v1" diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 4d3497dad4..c866875fbd 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -20,9 +20,6 @@ import ( "math/big" "reflect" - "github.com/dolthub/vitess/go/sqltypes" - "github.com/dolthub/vitess/go/vt/proto/query" - "github.com/shopspring/decimal" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" From f493c9f1805c95d37dd4768ddf111b6b14c2a640 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 16 Sep 2025 11:18:17 -0700 Subject: [PATCH 21/48] asdf --- enginetest/memory_engine_test.go | 12 +- sql/expression/comparison.go | 274 +++++++++++++++++-------------- sql/expression/convert.go | 17 +- sql/expression/in.go | 13 +- sql/plan/hash_lookup.go | 42 +---- sql/types/datetime.go | 1 - sql/types/decimal.go | 10 +- sql/types/number.go | 3 +- sql/types/utils.go | 58 +++++++ 9 files changed, 242 insertions(+), 188 deletions(-) create mode 100644 sql/types/utils.go diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 0d22da23c5..b77c46d94e 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -203,13 +203,17 @@ func TestSingleScript(t *testing.T) { //t.Skip() var scripts = []queries.ScriptTest{ { - Name: "test", - SetUpScript: []string{}, + Name: "asdfasdfasdf", + SetUpScript: []string{ + "create table parent (e enum('a', 'b', 'c') primary key);", + "insert into parent values (1), (2);", + "create table child1 (e enum('x', 'y', 'z'), foreign key (e) references parent (e));", + }, Assertions: []queries.ScriptTestAssertion{ { - Query: "select bit_count('2.99a');", + Query: "insert into child1 values (1), (2);", Expected: []sql.Row{ - {-3}, + {types.NewOkResult(2)}, }, }, }, diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index d31b98b5d6..69d4ab6fb5 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -16,7 +16,7 @@ package expression import ( "fmt" - + "github.com/dolthub/vitess/go/mysql" errors "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" @@ -172,135 +172,161 @@ func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{ } func (c *comparison) CastLeftAndRight(ctx *sql.Context, left, right interface{}) (interface{}, interface{}, sql.Type, error) { - leftType := c.Left().Type() - rightType := c.Right().Type() - - leftIsEnumOrSet := types.IsEnum(leftType) || types.IsSet(leftType) - rightIsEnumOrSet := types.IsEnum(rightType) || types.IsSet(rightType) - // Only convert if same Enum or Set - if leftIsEnumOrSet && rightIsEnumOrSet { - if types.TypesEqual(leftType, rightType) { - return left, right, leftType, nil - } - } else { - // If right side is convertible to enum/set, convert. Otherwise, convert left side - if leftIsEnumOrSet && (types.IsText(rightType) || types.IsNumber(rightType)) { - if r, inRange, err := leftType.Convert(ctx, right); inRange && err == nil { - return left, r, leftType, nil - } else { - l, _, err := types.TypeAwareConversion(ctx, left, leftType, rightType, false) - if err != nil { - return nil, nil, nil, err - } - return l, right, rightType, nil - } - } - // If left side is convertible to enum/set, convert. Otherwise, convert right side - if rightIsEnumOrSet && (types.IsText(leftType) || types.IsNumber(leftType)) { - if l, inRange, err := rightType.Convert(ctx, left); inRange && err == nil { - return l, right, rightType, nil - } else { - r, _, err := types.TypeAwareConversion(ctx, right, rightType, leftType, false) - if err != nil { - return nil, nil, nil, err - } - return left, r, leftType, nil - } - } - } - - if types.IsTimespan(leftType) || types.IsTimespan(rightType) { - if l, err := types.Time.ConvertToTimespan(left); err == nil { - if r, err := types.Time.ConvertToTimespan(right); err == nil { - return l, r, types.Time, nil - } - } - } - - if types.IsTuple(leftType) && types.IsTuple(rightType) { - return left, right, c.Left().Type(), nil - } - - if types.IsTime(leftType) || types.IsTime(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDatetime) - if err != nil { - return nil, nil, nil, err - } - - return l, r, types.DatetimeMaxPrecision, nil - } - - // Rely on types.JSON.Compare to handle JSON comparisons - if types.IsJSON(leftType) || types.IsJSON(rightType) { - return left, right, types.JSON, nil - } - - if types.IsBinaryType(leftType) || types.IsBinaryType(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToBinary) - if err != nil { - return nil, nil, nil, err - } - return l, r, types.LongBlob, nil - } - - if types.IsNumber(leftType) || types.IsNumber(rightType) { - if types.IsDecimal(leftType) || types.IsDecimal(rightType) { - //TODO: We need to set to the actual DECIMAL type - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDecimal) - if err != nil { - return nil, nil, nil, err - } - - if types.IsDecimal(leftType) { - return l, r, leftType, nil - } else { - return l, r, rightType, nil - } - } - - if types.IsFloat(leftType) || types.IsFloat(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDouble) - if err != nil { - return nil, nil, nil, err - } - - return l, r, types.Float64, nil - } - - if types.IsSigned(leftType) && types.IsSigned(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToSigned) - if err != nil { - return nil, nil, nil, err - } - - return l, r, types.Int64, nil - } - - if types.IsUnsigned(leftType) && types.IsUnsigned(rightType) { - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToUnsigned) - if err != nil { - return nil, nil, nil, err - } - - return l, r, types.Uint64, nil - } - - l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDouble) - if err != nil { - return nil, nil, nil, err + //leftType := c.Left().Type() + //rightType := c.Right().Type() + // + //leftIsEnumOrSet := types.IsEnum(leftType) || types.IsSet(leftType) + //rightIsEnumOrSet := types.IsEnum(rightType) || types.IsSet(rightType) + //// Only convert if same Enum or Set + //if leftIsEnumOrSet && rightIsEnumOrSet { + // if types.TypesEqual(leftType, rightType) { + // return left, right, leftType, nil + // } + //} else { + // // If right side is convertible to enum/set, convert. Otherwise, convert left side + // if leftIsEnumOrSet && (types.IsText(rightType) || types.IsNumber(rightType)) { + // if r, inRange, err := leftType.Convert(ctx, right); inRange && err == nil { + // return left, r, leftType, nil + // } else { + // l, _, err := types.TypeAwareConversion(ctx, left, leftType, rightType, false) + // if err != nil { + // return nil, nil, nil, err + // } + // return l, right, rightType, nil + // } + // } + // // If left side is convertible to enum/set, convert. Otherwise, convert right side + // if rightIsEnumOrSet && (types.IsText(leftType) || types.IsNumber(leftType)) { + // if l, inRange, err := rightType.Convert(ctx, left); inRange && err == nil { + // return l, right, rightType, nil + // } else { + // r, _, err := types.TypeAwareConversion(ctx, right, rightType, leftType, false) + // if err != nil { + // return nil, nil, nil, err + // } + // return left, r, leftType, nil + // } + // } + //} + // + //if types.IsTimespan(leftType) || types.IsTimespan(rightType) { + // if l, err := types.Time.ConvertToTimespan(left); err == nil { + // if r, err := types.Time.ConvertToTimespan(right); err == nil { + // return l, r, types.Time, nil + // } + // } + //} + // + //if types.IsTuple(leftType) && types.IsTuple(rightType) { + // return left, right, c.Left().Type(), nil + //} + // + //if types.IsTime(leftType) || types.IsTime(rightType) { + // l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDatetime) + // if err != nil { + // return nil, nil, nil, err + // } + // + // return l, r, types.DatetimeMaxPrecision, nil + //} + // + //// Rely on types.JSON.Compare to handle JSON comparisons + //if types.IsJSON(leftType) || types.IsJSON(rightType) { + // return left, right, types.JSON, nil + //} + // + //if types.IsBinaryType(leftType) || types.IsBinaryType(rightType) { + // l, r, err := convertLeftAndRight(ctx, left, right, ConvertToBinary) + // if err != nil { + // return nil, nil, nil, err + // } + // return l, r, types.LongBlob, nil + //} + // + //if types.IsNumber(leftType) || types.IsNumber(rightType) { + // if types.IsDecimal(leftType) || types.IsDecimal(rightType) { + // //TODO: We need to set to the actual DECIMAL type + // l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDecimal) + // if err != nil { + // return nil, nil, nil, err + // } + // + // if types.IsDecimal(leftType) { + // return l, r, leftType, nil + // } else { + // return l, r, rightType, nil + // } + // } + // + // if types.IsFloat(leftType) || types.IsFloat(rightType) { + // l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDouble) + // if err != nil { + // return nil, nil, nil, err + // } + // + // return l, r, types.Float64, nil + // } + // + // if types.IsSigned(leftType) && types.IsSigned(rightType) { + // l, r, err := convertLeftAndRight(ctx, left, right, ConvertToSigned) + // if err != nil { + // return nil, nil, nil, err + // } + // + // return l, r, types.Int64, nil + // } + // + // if types.IsUnsigned(leftType) && types.IsUnsigned(rightType) { + // l, r, err := convertLeftAndRight(ctx, left, right, ConvertToUnsigned) + // if err != nil { + // return nil, nil, nil, err + // } + // + // return l, r, types.Uint64, nil + // } + // + // l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDouble) + // if err != nil { + // return nil, nil, nil, err + // } + // + // return l, r, types.Float64, nil + //} + // + //left, right, err := convertLeftAndRight(ctx, left, right, ConvertToChar) + //if err != nil { + // return nil, nil, nil, err + //} + // + //return left, right, types.LongText, nil + + lType := c.Left().Type() + rType := c.Right().Type() + compType := types.GetCompareType(lType, rType) + + // Special case for JSON types + if types.IsJSON(compType) { + return left, right, compType, nil + } + + l, _, err := types.TypeAwareConversion(ctx, left, lType, compType, false) + if err != nil { + if sql.ErrTruncatedIncorrect.Is(err) { + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } - - return l, r, types.Float64, nil + // TODO: ignore all other errors? } - - left, right, err := convertLeftAndRight(ctx, left, right, ConvertToChar) + r, _, err := types.TypeAwareConversion(ctx, right, rType, compType, false) if err != nil { - return nil, nil, nil, err + if sql.ErrTruncatedIncorrect.Is(err) { + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) + } + // TODO: ignore all other errors? } - - return left, right, types.LongText, nil + return l, r, compType, nil } +// TODO: delete if everything else works out func convertLeftAndRight(ctx *sql.Context, left, right interface{}, convertTo string) (interface{}, interface{}, error) { l, err := convertValue(ctx, left, convertTo, nil, 0, 0) if err != nil { diff --git a/sql/expression/convert.go b/sql/expression/convert.go index d75ba20aad..df6d17eda1 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -16,6 +16,7 @@ package expression import ( "fmt" + "github.com/dolthub/vitess/go/mysql" "strings" "time" @@ -355,13 +356,19 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s dt := createConvertedDecimalType(typeLength, typeScale, false) d, _, err := dt.Convert(ctx, val) if err != nil { - return dt.Zero(), nil + if !sql.ErrTruncatedIncorrect.Is(err) { + return dt.Zero(), nil + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } return d, nil case ConvertToFloat: d, _, err := types.Float32.Convert(ctx, val) if err != nil { - return types.Float32.Zero(), nil + if !sql.ErrTruncatedIncorrect.Is(err) { + return types.Float32.Zero(), nil + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } return d, nil case ConvertToDouble, ConvertToReal: @@ -370,7 +377,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s return d, nil } if sql.ErrTruncatedIncorrect.Is(err) { - ctx.Warn(1265, "%s", err.Error()) + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) return d, nil } return types.Float64.Zero(), nil @@ -386,7 +393,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s return num, nil } if sql.ErrTruncatedIncorrect.Is(err) { - ctx.Warn(1265, "%s", err.Error()) + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) return num, nil } return types.Int64.Zero(), nil @@ -402,7 +409,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s return num, nil } if sql.ErrTruncatedIncorrect.Is(err) { - ctx.Warn(1265, "%s", err.Error()) + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) return num, nil } num, _, err = types.Int64.Convert(ctx, val) diff --git a/sql/expression/in.go b/sql/expression/in.go index 622ee5779f..219a07e512 100644 --- a/sql/expression/in.go +++ b/sql/expression/in.go @@ -85,6 +85,7 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } + leftLit := NewLiteral(originalLeft, in.Left().Type()) for _, el := range right { originalRight, err := el.Eval(ctx, row) if err != nil { @@ -96,17 +97,13 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { continue } - comp := newComparison(NewLiteral(originalLeft, in.Left().Type()), NewLiteral(originalRight, el.Type())) - l, r, compareType, err := comp.CastLeftAndRight(ctx, originalLeft, originalRight) + // TODO: determine comparison type + comp := newComparison(leftLit, NewLiteral(originalRight, el.Type())) + res, err := comp.Compare(ctx, nil) if err != nil { return nil, err } - cmp, err := compareType.Compare(ctx, l, r) - if err != nil { - return nil, err - } - - if cmp == 0 { + if res == 0 { return true, nil } } diff --git a/sql/plan/hash_lookup.go b/sql/plan/hash_lookup.go index fe750db6c2..2e0e009604 100644 --- a/sql/plan/hash_lookup.go +++ b/sql/plan/hash_lookup.go @@ -34,7 +34,7 @@ import ( // simply delegates to the child. func NewHashLookup(n sql.Node, rightEntryKey sql.Expression, leftProbeKey sql.Expression, joinType JoinType) *HashLookup { leftKeySch := hash.ExprsToSchema(leftProbeKey) - compareType := GetCompareType(leftProbeKey.Type(), rightEntryKey.Type()) + compareType := types.GetCompareType(leftProbeKey.Type(), rightEntryKey.Type()) return &HashLookup{ UnaryNode: UnaryNode{n}, RightEntryKey: rightEntryKey, @@ -61,46 +61,6 @@ var _ sql.Node = (*HashLookup)(nil) var _ sql.Expressioner = (*HashLookup)(nil) var _ sql.CollationCoercible = (*HashLookup)(nil) -// GetCompareType returns the type to use when comparing values of types left and right. -func GetCompareType(left, right sql.Type) sql.Type { - // TODO: much of this logic is very similar to castLeftAndRight() from sql/expression/comparison.go - // consider consolidating - if left.Equals(right) { - return left - } - if types.IsTuple(left) && types.IsTuple(right) { - return left - } - if types.IsTime(left) || types.IsTime(right) { - return types.DatetimeMaxPrecision - } - if types.IsJSON(left) || types.IsJSON(right) { - return types.JSON - } - if types.IsBinaryType(left) || types.IsBinaryType(right) { - return types.LongBlob - } - if types.IsNumber(left) || types.IsNumber(right) { - if types.IsDecimal(left) { - return left - } - if types.IsDecimal(right) { - return right - } - if types.IsFloat(left) || types.IsFloat(right) { - return types.Float64 - } - if types.IsSigned(left) && types.IsSigned(right) { - return types.Int64 - } - if types.IsUnsigned(left) && types.IsUnsigned(right) { - return types.Uint64 - } - return types.Float64 - } - return types.LongText -} - func (n *HashLookup) Expressions() []sql.Expression { return []sql.Expression{n.RightEntryKey, n.LeftProbeKey} } diff --git a/sql/types/datetime.go b/sql/types/datetime.go index 596c71a6e0..659ac54725 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -169,7 +169,6 @@ func (t datetimeType) Compare(ctx context.Context, a interface{}, b interface{}) if err != nil { return 0, err } - } else if t.baseType == sqltypes.Date { bt = bt.Truncate(24 * time.Hour) } diff --git a/sql/types/decimal.go b/sql/types/decimal.go index c866875fbd..214ffa63a7 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -140,13 +140,17 @@ func (t DecimalType_) Compare(s context.Context, a interface{}, b interface{}) ( // Convert implements Type interface. func (t DecimalType_) Convert(c context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { dec, err := t.ConvertToNullDecimal(v) - if err != nil && !sql.ErrIncorrectValue.Is(err) { + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return nil, sql.OutOfRange, err } if !dec.Valid { return nil, sql.InRange, nil } - return t.BoundsCheck(dec.Decimal) + d, inRange, bErr := t.BoundsCheck(dec.Decimal) + if bErr != nil { + err = bErr + } + return d, inRange, err } func (t DecimalType_) ConvertNoBoundsCheck(v interface{}) (decimal.Decimal, error) { @@ -201,7 +205,7 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, var err error truncStr, didTrunc := TruncateStringToNumber(value) if didTrunc { - err = sql.ErrIncorrectValue.New(t.String(), value) + err = sql.ErrTruncatedIncorrect.New(t.String(), value) } var dec decimal.Decimal if len(truncStr) == 0 { diff --git a/sql/types/number.go b/sql/types/number.go index 132836188f..3a8a9c76a6 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -307,7 +307,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v any) (any, sql.ConvertIn if num < -math.MaxFloat32 { return float32(-math.MaxFloat32), sql.OutOfRange, nil } - return float32(num), sql.InRange, nil // TODO: pass up error for warning? + return float32(num), sql.InRange, err case sqltypes.Float64: ret, err := convertToFloat64(t, v) return ret, sql.InRange, err @@ -1036,7 +1036,6 @@ func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInR } return int64(math.Round(v)), sql.InRange, nil case decimal.Decimal: - // TODO: round? if v.GreaterThan(dec_int64_max) { return dec_int64_max.IntPart(), sql.OutOfRange, nil } diff --git a/sql/types/utils.go b/sql/types/utils.go new file mode 100644 index 0000000000..2a7dbbacfc --- /dev/null +++ b/sql/types/utils.go @@ -0,0 +1,58 @@ +// Copyright 2025 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package types + +import ( + "github.com/dolthub/go-mysql-server/sql" +) + +// GetCompareType returns the type to use when comparing values of types left and right. +func GetCompareType(left, right sql.Type) sql.Type { + if left.Equals(right) { + return left + } + + if IsTimespan(left) || IsTimespan(right) { + return left + } + if IsTuple(left) && IsTuple(right) { + return left + } + if IsTime(left) || IsTime(right) { + return DatetimeMaxPrecision + } + if IsJSON(left) || IsJSON(right) { + return JSON + } + if IsBinaryType(left) || IsBinaryType(right) { + return LongBlob + } + if IsNumber(left) || IsNumber(right) { + if IsDecimal(left) || IsDecimal(right) { + return InternalDecimalType + } + if IsFloat(left) || IsFloat(right) { + return Float64 + } + if IsSigned(left) && IsSigned(right) { + return Int64 + } + if IsUnsigned(left) && IsUnsigned(right) { + return Uint64 + } + return Float64 + } + return LongText +} From e2c9956b69ee44d9ce9177bfbf631b46eb046816 Mon Sep 17 00:00:00 2001 From: jycor Date: Mon, 15 Sep 2025 21:41:09 +0000 Subject: [PATCH 22/48] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/types/decimal.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 214ffa63a7..0958d8c01a 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -20,12 +20,12 @@ import ( "math/big" "reflect" + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/vitess/go/sqltypes" - "github.com/dolthub/vitess/go/vt/proto/query" ) const ( From 83c1f303c7bcc294045d68675c2b9065cb43a848 Mon Sep 17 00:00:00 2001 From: jycor Date: Tue, 16 Sep 2025 18:20:56 +0000 Subject: [PATCH 23/48] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/comparison.go | 1 + sql/expression/convert.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 69d4ab6fb5..a3a2b456f8 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -16,6 +16,7 @@ package expression import ( "fmt" + "github.com/dolthub/vitess/go/mysql" errors "gopkg.in/src-d/go-errors.v1" diff --git a/sql/expression/convert.go b/sql/expression/convert.go index df6d17eda1..7d21cc0c6e 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -16,10 +16,10 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/mysql" "strings" "time" + "github.com/dolthub/vitess/go/mysql" "github.com/dolthub/vitess/go/sqltypes" "github.com/sirupsen/logrus" "gopkg.in/src-d/go-errors.v1" From be3d33c698064284f9d2bf5b8373aa5c6aa7683d Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 16 Sep 2025 12:18:49 -0700 Subject: [PATCH 24/48] aaaaa --- enginetest/memory_engine_test.go | 15 ++++++++------- sql/expression/comparison.go | 26 ++++++++++++++++++++++++++ sql/types/utils.go | 10 ++++++++++ 3 files changed, 44 insertions(+), 7 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index b77c46d94e..b36a1e068d 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -200,18 +200,19 @@ func TestSingleQueryPrepared(t *testing.T) { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { - //t.Skip() + t.Skip() var scripts = []queries.ScriptTest{ { - Name: "asdfasdfasdf", + Name: "sets", SetUpScript: []string{ - "create table parent (e enum('a', 'b', 'c') primary key);", - "insert into parent values (1), (2);", - "create table child1 (e enum('x', 'y', 'z'), foreign key (e) references parent (e));", + `CREATE TABLE test (pk SET("a","b","c") PRIMARY KEY, v1 SET("w","x","y","z"));`, + `INSERT INTO test VALUES (0, 1), ("b", "y"), ("b,c", "z,z"), ("a,c,b", 10);`, + `UPDATE test SET v1 = "y,x,w" WHERE pk >= 4`, + `DELETE FROM test WHERE pk > "b,c";`, }, Assertions: []queries.ScriptTestAssertion{ { - Query: "insert into child1 values (1), (2);", + Query: `SELECT * FROM test ORDER BY pk;`, Expected: []sql.Row{ {types.NewOkResult(2)}, }, @@ -242,7 +243,7 @@ func TestSingleScript(t *testing.T) { for _, test := range scripts { harness := enginetest.NewMemoryHarness("", 1, testNumPartitions, true, nil) - //harness.UseServer() + harness.UseServer() engine, err := harness.NewEngine(t) if err != nil { panic(err) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index a3a2b456f8..54dd5b2359 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -303,6 +303,32 @@ func (c *comparison) CastLeftAndRight(ctx *sql.Context, left, right interface{}) lType := c.Left().Type() rType := c.Right().Type() + + //lIsEnumOrSet := types.IsEnum(lType) || types.IsSet(lType) + //rIsEnumOrSet := types.IsEnum(rType) || types.IsSet(rType) + //// If right side is convertible to enum/set, convert. Otherwise, convert left side + //if lIsEnumOrSet && (types.IsText(rType) || types.IsNumber(rType)) { + // if r, inRange, err := lType.Convert(ctx, right); inRange && err == nil { + // return left, r, lType, nil + // } + // l, _, err := types.TypeAwareConversion(ctx, left, lType, rType, false) + // if err != nil { + // return nil, nil, nil, err + // } + // return l, right, rType, nil + //} + //// If left side is convertible to enum/set, convert. Otherwise, convert right side + //if rIsEnumOrSet && (types.IsText(lType) || types.IsNumber(lType)) { + // if l, inRange, err := rType.Convert(ctx, left); inRange && err == nil { + // return l, right, rType, nil + // } + // r, _, err := types.TypeAwareConversion(ctx, right, rType, lType, false) + // if err != nil { + // return nil, nil, nil, err + // } + // return left, r, lType, nil + //} + compType := types.GetCompareType(lType, rType) // Special case for JSON types diff --git a/sql/types/utils.go b/sql/types/utils.go index 2a7dbbacfc..f6374cf6f6 100644 --- a/sql/types/utils.go +++ b/sql/types/utils.go @@ -16,6 +16,7 @@ package types import ( "github.com/dolthub/go-mysql-server/sql" + ) // GetCompareType returns the type to use when comparing values of types left and right. @@ -24,6 +25,15 @@ func GetCompareType(left, right sql.Type) sql.Type { return left } + // Left and right are both Enum types, but not the same, so use uint16 representation for comparison + if IsEnum(left) && IsEnum(right) { + return Uint16 + } + // Left and right are both Set types, but not the same, so use uint16 representation for comparison + if IsSet(left) && IsSet(right) { + return Uint16 + } + if IsTimespan(left) || IsTimespan(right) { return left } From 82ce778b86fffd33aef24733169a49af24a661fe Mon Sep 17 00:00:00 2001 From: jycor Date: Tue, 16 Sep 2025 19:24:38 +0000 Subject: [PATCH 25/48] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/types/utils.go | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/types/utils.go b/sql/types/utils.go index f6374cf6f6..f5c49e4237 100644 --- a/sql/types/utils.go +++ b/sql/types/utils.go @@ -16,7 +16,6 @@ package types import ( "github.com/dolthub/go-mysql-server/sql" - ) // GetCompareType returns the type to use when comparing values of types left and right. From 48c43ab481e782278792caa1092ff5860acbb1aa Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 16 Sep 2025 16:55:11 -0700 Subject: [PATCH 26/48] fixing tests --- enginetest/queries/type_wire_queries.go | 4 ++-- sql/expression/in_test.go | 1 - sql/types/datetime.go | 1 + sql/types/decimal_test.go | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/enginetest/queries/type_wire_queries.go b/enginetest/queries/type_wire_queries.go index 74b3556da8..2f26cc5950 100644 --- a/enginetest/queries/type_wire_queries.go +++ b/enginetest/queries/type_wire_queries.go @@ -685,8 +685,8 @@ var TypeWireTests = []TypeWireTest{ SetUpScript: []string{ `CREATE TABLE test (pk SET("a","b","c") PRIMARY KEY, v1 SET("w","x","y","z"));`, `INSERT INTO test VALUES (0, 1), ("b", "y"), ("b,c", "z,z"), ("a,c,b", 10);`, - `UPDATE test SET v1 = "y,x,w" WHERE pk >= 4`, - `DELETE FROM test WHERE pk > "b,c";`, + `UPDATE test SET v1 = "y,x,w" WHERE pk >= 4;`, + `DELETE FROM test WHERE pk = "a,b,c";`, }, Queries: []string{ `SELECT * FROM test ORDER BY pk;`, diff --git a/sql/expression/in_test.go b/sql/expression/in_test.go index af3bef9592..9a3344ffeb 100644 --- a/sql/expression/in_test.go +++ b/sql/expression/in_test.go @@ -178,7 +178,6 @@ func TestInTuple(t *testing.T) { expression.NewLiteral("hi", types.TinyText), expression.NewLiteral("bye", types.TinyText), ), - err: types.ErrConvertingToTime, row: nil, result: false, }} diff --git a/sql/types/datetime.go b/sql/types/datetime.go index 659ac54725..c585fee639 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -186,6 +186,7 @@ func (t datetimeType) Convert(ctx context.Context, v interface{}) (interface{}, if v == nil { return nil, sql.InRange, nil } + // TODO: implement datetime truncation res, err := ConvertToTime(ctx, v, t) if err != nil { return nil, sql.OutOfRange, err diff --git a/sql/types/decimal_test.go b/sql/types/decimal_test.go index 6c1bed0d64..597202143a 100644 --- a/sql/types/decimal_test.go +++ b/sql/types/decimal_test.go @@ -317,8 +317,8 @@ func TestDecimalConvert(t *testing.T) { {5, 0, "7742", "7742", false}, {5, 0, new(big.Float).SetFloat64(-4723.875), "-4724", false}, {5, 0, 99999, "99999", false}, - {5, 0, "0xf8e1", "0", false}, - {5, 0, "0b1001110101100110", "0", false}, + {5, 0, "0xf8e1", "0", true}, + {5, 0, "0b1001110101100110", "0", true}, {5, 0, new(big.Rat).SetFrac64(999999, 10), "", true}, {5, 0, 673927, "", true}, From 24031509b669393882398c2a898fd1f10a477189 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 16 Sep 2025 18:20:26 -0700 Subject: [PATCH 27/48] more test fixes --- sql/expression/comparison.go | 194 ++--------------------------------- sql/rowexec/insert.go | 6 -- sql/types/number.go | 181 ++++++++++++++++++++++++-------- 3 files changed, 144 insertions(+), 237 deletions(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 54dd5b2359..dbb543c46c 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -142,10 +142,7 @@ func (c *comparison) Compare(ctx *sql.Context, row sql.Row) (int, error) { return c.Left().Type().Compare(ctx, left, right) } - l, r, compareType, err := c.CastLeftAndRight(ctx, left, right) - if err != nil { - return 0, err - } + l, r, compareType := c.castLeftAndRight(ctx, left, right) // Set comparison relies on empty strings not being converted yet if types.IsSet(compareType) { @@ -172,168 +169,16 @@ func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{ return left, right, nil } -func (c *comparison) CastLeftAndRight(ctx *sql.Context, left, right interface{}) (interface{}, interface{}, sql.Type, error) { - //leftType := c.Left().Type() - //rightType := c.Right().Type() - // - //leftIsEnumOrSet := types.IsEnum(leftType) || types.IsSet(leftType) - //rightIsEnumOrSet := types.IsEnum(rightType) || types.IsSet(rightType) - //// Only convert if same Enum or Set - //if leftIsEnumOrSet && rightIsEnumOrSet { - // if types.TypesEqual(leftType, rightType) { - // return left, right, leftType, nil - // } - //} else { - // // If right side is convertible to enum/set, convert. Otherwise, convert left side - // if leftIsEnumOrSet && (types.IsText(rightType) || types.IsNumber(rightType)) { - // if r, inRange, err := leftType.Convert(ctx, right); inRange && err == nil { - // return left, r, leftType, nil - // } else { - // l, _, err := types.TypeAwareConversion(ctx, left, leftType, rightType, false) - // if err != nil { - // return nil, nil, nil, err - // } - // return l, right, rightType, nil - // } - // } - // // If left side is convertible to enum/set, convert. Otherwise, convert right side - // if rightIsEnumOrSet && (types.IsText(leftType) || types.IsNumber(leftType)) { - // if l, inRange, err := rightType.Convert(ctx, left); inRange && err == nil { - // return l, right, rightType, nil - // } else { - // r, _, err := types.TypeAwareConversion(ctx, right, rightType, leftType, false) - // if err != nil { - // return nil, nil, nil, err - // } - // return left, r, leftType, nil - // } - // } - //} - // - //if types.IsTimespan(leftType) || types.IsTimespan(rightType) { - // if l, err := types.Time.ConvertToTimespan(left); err == nil { - // if r, err := types.Time.ConvertToTimespan(right); err == nil { - // return l, r, types.Time, nil - // } - // } - //} - // - //if types.IsTuple(leftType) && types.IsTuple(rightType) { - // return left, right, c.Left().Type(), nil - //} - // - //if types.IsTime(leftType) || types.IsTime(rightType) { - // l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDatetime) - // if err != nil { - // return nil, nil, nil, err - // } - // - // return l, r, types.DatetimeMaxPrecision, nil - //} - // - //// Rely on types.JSON.Compare to handle JSON comparisons - //if types.IsJSON(leftType) || types.IsJSON(rightType) { - // return left, right, types.JSON, nil - //} - // - //if types.IsBinaryType(leftType) || types.IsBinaryType(rightType) { - // l, r, err := convertLeftAndRight(ctx, left, right, ConvertToBinary) - // if err != nil { - // return nil, nil, nil, err - // } - // return l, r, types.LongBlob, nil - //} - // - //if types.IsNumber(leftType) || types.IsNumber(rightType) { - // if types.IsDecimal(leftType) || types.IsDecimal(rightType) { - // //TODO: We need to set to the actual DECIMAL type - // l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDecimal) - // if err != nil { - // return nil, nil, nil, err - // } - // - // if types.IsDecimal(leftType) { - // return l, r, leftType, nil - // } else { - // return l, r, rightType, nil - // } - // } - // - // if types.IsFloat(leftType) || types.IsFloat(rightType) { - // l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDouble) - // if err != nil { - // return nil, nil, nil, err - // } - // - // return l, r, types.Float64, nil - // } - // - // if types.IsSigned(leftType) && types.IsSigned(rightType) { - // l, r, err := convertLeftAndRight(ctx, left, right, ConvertToSigned) - // if err != nil { - // return nil, nil, nil, err - // } - // - // return l, r, types.Int64, nil - // } - // - // if types.IsUnsigned(leftType) && types.IsUnsigned(rightType) { - // l, r, err := convertLeftAndRight(ctx, left, right, ConvertToUnsigned) - // if err != nil { - // return nil, nil, nil, err - // } - // - // return l, r, types.Uint64, nil - // } - // - // l, r, err := convertLeftAndRight(ctx, left, right, ConvertToDouble) - // if err != nil { - // return nil, nil, nil, err - // } - // - // return l, r, types.Float64, nil - //} - // - //left, right, err := convertLeftAndRight(ctx, left, right, ConvertToChar) - //if err != nil { - // return nil, nil, nil, err - //} - // - //return left, right, types.LongText, nil - +// castLeftAndRight will find the appropriate type to cast both left and right to for comparison. +// All errors are ignored, except for warnings about truncation. +func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) (interface{}, interface{}, sql.Type) { lType := c.Left().Type() rType := c.Right().Type() - - //lIsEnumOrSet := types.IsEnum(lType) || types.IsSet(lType) - //rIsEnumOrSet := types.IsEnum(rType) || types.IsSet(rType) - //// If right side is convertible to enum/set, convert. Otherwise, convert left side - //if lIsEnumOrSet && (types.IsText(rType) || types.IsNumber(rType)) { - // if r, inRange, err := lType.Convert(ctx, right); inRange && err == nil { - // return left, r, lType, nil - // } - // l, _, err := types.TypeAwareConversion(ctx, left, lType, rType, false) - // if err != nil { - // return nil, nil, nil, err - // } - // return l, right, rType, nil - //} - //// If left side is convertible to enum/set, convert. Otherwise, convert right side - //if rIsEnumOrSet && (types.IsText(lType) || types.IsNumber(lType)) { - // if l, inRange, err := rType.Convert(ctx, left); inRange && err == nil { - // return l, right, rType, nil - // } - // r, _, err := types.TypeAwareConversion(ctx, right, rType, lType, false) - // if err != nil { - // return nil, nil, nil, err - // } - // return left, r, lType, nil - //} - compType := types.GetCompareType(lType, rType) // Special case for JSON types if types.IsJSON(compType) { - return left, right, compType, nil + return left, right, compType } l, _, err := types.TypeAwareConversion(ctx, left, lType, compType, false) @@ -341,31 +186,14 @@ func (c *comparison) CastLeftAndRight(ctx *sql.Context, left, right interface{}) if sql.ErrTruncatedIncorrect.Is(err) { ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } - // TODO: ignore all other errors? } r, _, err := types.TypeAwareConversion(ctx, right, rType, compType, false) if err != nil { if sql.ErrTruncatedIncorrect.Is(err) { ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } - // TODO: ignore all other errors? } - return l, r, compType, nil -} - -// TODO: delete if everything else works out -func convertLeftAndRight(ctx *sql.Context, left, right interface{}, convertTo string) (interface{}, interface{}, error) { - l, err := convertValue(ctx, left, convertTo, nil, 0, 0) - if err != nil { - return nil, nil, err - } - - r, err := convertValue(ctx, right, convertTo, nil, 0, 0) - if err != nil { - return nil, nil, err - } - - return l, r, nil + return l, r, compType } // Type implements the Expression interface. @@ -500,15 +328,7 @@ func (e *NullSafeEquals) Compare(ctx *sql.Context, row sql.Row) (int, error) { return -1, nil } - if types.TypesEqual(e.Left().Type(), e.Right().Type()) { - return e.Left().Type().Compare(ctx, left, right) - } - - var compareType sql.Type - left, right, compareType, err = e.CastLeftAndRight(ctx, left, right) - if err != nil { - return 0, err - } + left, right, compareType := e.castLeftAndRight(ctx, left, right) return compareType.Compare(ctx, left, right) } diff --git a/sql/rowexec/insert.go b/sql/rowexec/insert.go index 27aedcbc42..d537214228 100644 --- a/sql/rowexec/insert.go +++ b/sql/rowexec/insert.go @@ -125,12 +125,6 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) var converted any var inRange sql.ConvertInRange var cErr error - // TODO: AAAAHHH - // Hex strings shouldn't make it this far - //val, cErr = types.ConvertHexBlobToUint(row[idx], col.Type) - //if cErr != nil { - // return nil, i.ignoreOrClose(ctx, origRow, cErr) - //} if typ, ok := col.Type.(sql.RoundingNumberType); ok { converted, inRange, cErr = typ.ConvertRound(ctx, val) } else { diff --git a/sql/types/number.go b/sql/types/number.go index 3a8a9c76a6..4a0da33bfd 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -213,7 +213,6 @@ func (t NumberTypeImpl_) Compare(ctx context.Context, a any, b any) (int, error) // Convert implements Type interface. func (t NumberTypeImpl_) Convert(ctx context.Context, v any) (any, sql.ConvertInRange, error) { - var err error if v == nil { return nil, sql.InRange, nil } @@ -223,6 +222,7 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v any) (any, sql.ConvertIn } if jv, ok := v.(sql.JSONWrapper); ok { + var err error v, err = jv.ToInterface(ctx) if err != nil { return nil, sql.OutOfRange, err @@ -236,62 +236,96 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v any) (any, sql.ConvertIn return 0, sql.OutOfRange, err } if num > math.MaxInt8 { - return int8(math.MaxInt8), sql.OutOfRange, nil - } else if num < math.MinInt8 { - return int8(math.MinInt8), sql.OutOfRange, nil + return int8(math.MaxInt8), sql.OutOfRange, err + } + if num < math.MinInt8 { + return int8(math.MinInt8), sql.OutOfRange, err } - return int8(num), sql.InRange, nil + return int8(num), sql.InRange, err case sqltypes.Uint8: - // TODO: convertToUint8 is unnecessary, we can just use convertToInt64 and handle overflow logic here - return convertToUint8(t, v, false) + num, _, err := convertToInt64(t, v, false) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return 0, sql.OutOfRange, err + } + if num > math.MaxUint8 { + return uint8(math.MaxUint8), sql.OutOfRange, err + } + if num < 0 { + return uint8(math.MaxUint8 + num + 1), sql.OutOfRange, err + } + return uint8(num), sql.InRange, err case sqltypes.Int16: num, _, err := convertToInt64(t, v, false) if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return 0, sql.OutOfRange, err } if num > math.MaxInt16 { - return int16(math.MaxInt16), sql.OutOfRange, nil - } else if num < math.MinInt16 { - return int16(math.MinInt16), sql.OutOfRange, nil + return int16(math.MaxInt16), sql.OutOfRange, err } - return int16(num), sql.InRange, nil + if num < math.MinInt16 { + return int16(math.MinInt16), sql.OutOfRange, err + } + return int16(num), sql.InRange, err case sqltypes.Uint16: - return convertToUint16(t, v, false) + num, _, err := convertToInt64(t, v, false) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return 0, sql.OutOfRange, err + } + if num > math.MaxUint16 { + return uint16(math.MaxUint16), sql.OutOfRange, err + } + if num < 0 { + return uint16(math.MaxUint16 + num + 1), sql.OutOfRange, err + } + return uint16(num), sql.InRange, nil case sqltypes.Int24: num, _, err := convertToInt64(t, v, false) if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return 0, sql.OutOfRange, err } if num > (1<<23 - 1) { - return int32(1<<23 - 1), sql.OutOfRange, nil - } else if num < (-1 << 23) { - return int32(-1 << 23), sql.OutOfRange, nil + return int32(1<<23 - 1), sql.OutOfRange, err + } + if num < (-1 << 23) { + return int32(-1 << 23), sql.OutOfRange, err } - return int32(num), sql.InRange, nil + return int32(num), sql.InRange, err case sqltypes.Uint24: num, _, err := convertToInt64(t, v, false) if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return 0, sql.OutOfRange, err } if num >= (1 << 24) { - return uint32(1<<24 - 1), sql.OutOfRange, nil - } else if num < 0 { - return uint32(1<<24 - int32(-num)), sql.OutOfRange, nil + return uint32(1<<24 - 1), sql.OutOfRange, err } - return uint32(num), sql.InRange, nil + if num < 0 { + return uint32(1<<24 - int32(-num)), sql.OutOfRange, err + } + return uint32(num), sql.InRange, err case sqltypes.Int32: num, _, err := convertToInt64(t, v, false) if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return 0, sql.OutOfRange, err } if num > math.MaxInt32 { - return int32(math.MaxInt32), sql.OutOfRange, nil - } else if num < math.MinInt32 { - return int32(math.MinInt32), sql.OutOfRange, nil + return int32(math.MaxInt32), sql.OutOfRange, err + } + if num < math.MinInt32 { + return int32(math.MinInt32), sql.OutOfRange, err } return int32(num), sql.InRange, err case sqltypes.Uint32: - return convertToUint32(t, v, false) + num, _, err := convertToInt64(t, v, false) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return 0, sql.OutOfRange, err + } + if num > math.MaxUint32 { + return uint32(math.MaxUint32), sql.OutOfRange, err + } + if num < 0 { + return uint32(math.MaxUint32 + num + 1), sql.OutOfRange, err + } + return uint32(num), sql.InRange, err case sqltypes.Int64: return convertToInt64(t, v, false) case sqltypes.Uint64: @@ -366,12 +400,12 @@ func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v any) (any, sql.Conv case sqltypes.Uint8: switch v.(type) { case float32, float64, string: - convertToUint8(t, v, true) + return convertToUint8(t, v, true) } case sqltypes.Uint16: switch v.(type) { case float32, float64, string: - convertToUint16(t, v, true) + return convertToUint16(t, v, true) } case sqltypes.Uint24: switch v.(type) { @@ -1247,7 +1281,25 @@ func convertToUint64(t NumberTypeImpl_, v any, round bool) (uint64, sql.ConvertI } return i, sql.InRange, nil case string: + // When round = true, truncation rules are less strict + // Integers will accept valid float notation without truncation error var err error + if round { + truncStr, didTrunc := TruncateStringToNumber(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t.String(), v) + } + // Parse as int first + if i, pErr := strconv.ParseUint(truncStr, 10, 64); pErr == nil { + return i, sql.InRange, nil + } + f, _ := strconv.ParseFloat(truncStr, 64) + res, inRange, cErr := convertToUint64(t, f, round) + if cErr != nil { + err = cErr + } + return res, inRange, err + } truncStr, didTrunc := TruncateStringToInt(v) if didTrunc { err = sql.ErrTruncatedIncorrect.New(t.String(), v) @@ -1265,11 +1317,10 @@ func convertToUint64(t NumberTypeImpl_, v any, round bool) (uint64, sql.ConvertI } i, pErr := strconv.ParseUint(truncStr, 10, 64) if errors.Is(pErr, strconv.ErrRange) { - // Number is too large for uint64, return max value and OutOfRange return math.MaxUint64, sql.OutOfRange, err } if neg { - i = math.MaxUint64 - i + 1 + return math.MaxUint64 - i + 1, sql.OutOfRange, err } return i, sql.InRange, err case bool: @@ -1368,7 +1419,25 @@ func convertToUint32(t NumberTypeImpl_, v any, round bool) (uint32, sql.ConvertI } return uint32(i), sql.InRange, nil case string: + // When round = true, truncation rules are less strict + // Integers will accept valid float notation without truncation error var err error + if round { + truncStr, didTrunc := TruncateStringToNumber(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t.String(), v) + } + // Parse as int first + if i, pErr := strconv.ParseUint(truncStr, 10, 32); pErr == nil { + return uint32(i), sql.InRange, nil + } + f, _ := strconv.ParseFloat(truncStr, 64) + res, inRange, cErr := convertToUint32(t, f, round) + if cErr != nil { + err = cErr + } + return res, inRange, err + } truncStr, didTrunc := TruncateStringToInt(v) if didTrunc { err = sql.ErrTruncatedIncorrect.New(t.String(), v) @@ -1467,10 +1536,7 @@ func convertToUint16(t NumberTypeImpl_, v any, round bool) (uint16, sql.ConvertI if v < 0 { return uint16(math.MaxUint16 - v), sql.OutOfRange, nil } - if round { - return uint16(math.Round(float64(v))), sql.InRange, nil - } - return uint16(v), sql.InRange, nil + return uint16(math.Round(float64(v))), sql.InRange, nil case float64: if v >= float64(math.MaxUint16) { return math.MaxUint16, sql.OutOfRange, nil @@ -1478,10 +1544,7 @@ func convertToUint16(t NumberTypeImpl_, v any, round bool) (uint16, sql.ConvertI if v <= 0 { return uint16(math.MaxUint16 - v), sql.OutOfRange, nil } - if round { - return uint16(math.Round(float64(v))), sql.InRange, nil - } - return uint16(v), sql.InRange, nil + return uint16(math.Round(float64(v))), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_uint16_max) { return math.MaxUint16, sql.InRange, nil @@ -1500,7 +1563,25 @@ func convertToUint16(t NumberTypeImpl_, v any, round bool) (uint16, sql.ConvertI } return uint16(i), sql.InRange, nil case string: + // When round = true, truncation rules are less strict + // Integers will accept valid float notation without truncation error var err error + if round { + truncStr, didTrunc := TruncateStringToNumber(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t.String(), v) + } + // Parse as int first + if i, pErr := strconv.ParseUint(truncStr, 10, 16); pErr == nil { + return uint16(i), sql.InRange, nil + } + f, _ := strconv.ParseFloat(truncStr, 64) + res, inRange, cErr := convertToUint16(t, f, round) + if cErr != nil { + err = cErr + } + return res, inRange, err + } truncStr, didTrunc := TruncateStringToNumber(v) if didTrunc { err = sql.ErrTruncatedIncorrect.New(t.String(), v) @@ -1593,10 +1674,7 @@ func convertToUint8(t NumberTypeImpl_, v any, round bool) (uint8, sql.ConvertInR if v < 0 { return uint8(math.MaxUint8 - v), sql.OutOfRange, nil } - if round { - return uint8(math.Round(float64(v))), sql.InRange, nil - } - return uint8(v), sql.InRange, nil + return uint8(math.Round(float64(v))), sql.InRange, nil case float64: if v >= float64(math.MaxUint8) { return math.MaxUint8, sql.OutOfRange, nil @@ -1604,10 +1682,7 @@ func convertToUint8(t NumberTypeImpl_, v any, round bool) (uint8, sql.ConvertInR if v <= 0 { return uint8(math.MaxUint8 - v), sql.OutOfRange, nil } - if round { - return uint8(math.Round(v)), sql.InRange, nil - } - return uint8(v), sql.InRange, nil + return uint8(math.Round(v)), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_uint8_max) { return math.MaxUint8, sql.InRange, nil @@ -1626,7 +1701,25 @@ func convertToUint8(t NumberTypeImpl_, v any, round bool) (uint8, sql.ConvertInR } return uint8(i), sql.InRange, nil case string: + // When round = true, truncation rules are less strict + // Integers will accept valid float notation without truncation error var err error + if round { + truncStr, didTrunc := TruncateStringToNumber(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t.String(), v) + } + // Parse as int first + if i, pErr := strconv.ParseUint(truncStr, 10, 8); pErr == nil { + return uint8(i), sql.InRange, nil + } + f, _ := strconv.ParseFloat(truncStr, 64) + res, inRange, cErr := convertToUint8(t, f, round) + if cErr != nil { + err = cErr + } + return res, inRange, err + } truncStr, didTrunc := TruncateStringToInt(v) if didTrunc { err = sql.ErrTruncatedIncorrect.New(t.String(), v) From 5038fc400ffb657f3bb48259958835213bb2da0f Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 17 Sep 2025 01:15:39 -0700 Subject: [PATCH 28/48] more test fixes --- memory/table.go | 3 +++ sql/rowexec/ddl_iters.go | 3 +++ 2 files changed, 6 insertions(+) diff --git a/memory/table.go b/memory/table.go index 9932d0399a..d48adb4111 100644 --- a/memory/table.go +++ b/memory/table.go @@ -1435,6 +1435,9 @@ func (t *Table) ModifyColumn(ctx *sql.Context, columnName string, column *sql.Co if sql.ErrNotMatchingSRID.Is(err) { err = sql.ErrNotMatchingSRIDWithColName.New(columnName, err) } + if sql.ErrTruncatedIncorrect.Is(err) { + err = sql.ErrInvalidValue.New(row[oldIdx], column.Type) + } return err } if !inRange { diff --git a/sql/rowexec/ddl_iters.go b/sql/rowexec/ddl_iters.go index 5669ce5c2c..bf1e660855 100644 --- a/sql/rowexec/ddl_iters.go +++ b/sql/rowexec/ddl_iters.go @@ -927,6 +927,9 @@ func projectRowWithTypes(ctx *sql.Context, oldSchema, newSchema sql.Schema, proj if sql.ErrNotMatchingSRID.Is(err) { err = sql.ErrNotMatchingSRIDWithColName.New(newSchema[i].Name, err) } + if sql.ErrTruncatedIncorrect.Is(err) { + err = sql.ErrInvalidValue.New(newRow[i], newSchema[i].Type) + } return nil, err } else if !inRange { return nil, sql.ErrValueOutOfRange.New(newRow[i], newSchema[i].Type) From 759a42c7efbc0a6fd91647142ef0d62d521adfaf Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 9 Sep 2025 10:14:58 -0700 Subject: [PATCH 29/48] attempt at consolidating logic --- enginetest/memory_engine_test.go | 2 +- sql/types/decimal.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index b36a1e068d..76fdf230a1 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -200,7 +200,7 @@ func TestSingleQueryPrepared(t *testing.T) { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { - t.Skip() + //t.Skip() var scripts = []queries.ScriptTest{ { Name: "sets", diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 0958d8c01a..214ffa63a7 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -20,12 +20,12 @@ import ( "math/big" "reflect" - "github.com/dolthub/vitess/go/sqltypes" - "github.com/dolthub/vitess/go/vt/proto/query" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" ) const ( From bbe38b85acd17cbf52c30127c1daaf7c2d843f5c Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 9 Sep 2025 16:27:31 -0700 Subject: [PATCH 30/48] aaaaa --- sql/types/number.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/types/number.go b/sql/types/number.go index 4a0da33bfd..2aae577437 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -1080,7 +1080,8 @@ func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInR case []byte: return convertToInt64(t, string(v), round) case string: - // When round = true, truncation rules are less strict + +// When round = true, truncation rules are less strict // Integers will accept valid float notation without truncation error var err error if round { From a0f4c95fb4ae1c08b421b4618fb957d0aeb3d031 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 10 Sep 2025 10:11:52 -0700 Subject: [PATCH 31/48] consolidate logic for truncation --- sql/types/decimal.go | 1 + sql/types/number.go | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 214ffa63a7..0f780573da 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -19,6 +19,7 @@ import ( "fmt" "math/big" "reflect" + "strings" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" diff --git a/sql/types/number.go b/sql/types/number.go index 2aae577437..4a0da33bfd 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -1080,8 +1080,7 @@ func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInR case []byte: return convertToInt64(t, string(v), round) case string: - -// When round = true, truncation rules are less strict + // When round = true, truncation rules are less strict // Integers will accept valid float notation without truncation error var err error if round { From 71f29ee1658d49da166a87415d2c941675e07c7e Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 10 Sep 2025 13:38:09 -0700 Subject: [PATCH 32/48] some progress --- enginetest/memory_engine_test.go | 116 ++++++++++++++++++++++++++++++- sql/expression/convert.go | 2 +- sql/types/conversion.go | 1 + 3 files changed, 117 insertions(+), 2 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 76fdf230a1..06789ddd7e 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -214,11 +214,125 @@ func TestSingleScript(t *testing.T) { { Query: `SELECT * FROM test ORDER BY pk;`, Expected: []sql.Row{ - {types.NewOkResult(2)}, + {" 3 12 4", uint64(3)}, + {" 3.2 12 4", uint64(3)}, + {"-3.1234", uint64(18446744073709551613)}, + {"-3.1a", uint64(18446744073709551613)}, + {"-5+8", uint64(18446744073709551611)}, + {"+3.1234", uint64(3)}, + {"11", uint64(11)}, + {"11-5", uint64(11)}, + {"11d", uint64(11)}, + {"11wha?", uint64(11)}, + {"12", uint64(12)}, + {"1a1", uint64(1)}, + {"3. 12 4", uint64(3)}, + {"5.932887e+07", uint64(5)}, + {"5.932887e+07abc", uint64(5)}, + {"5.932887e7", uint64(5)}, + {"5.932887e7abc", uint64(5)}, + {"a1a1", uint64(2)}, }, }, + { + Dialect: "mysql", + Query: "select pk, cast(pk as decimal(12,3)) from test01", + Expected: []sql.Row{ + {" 3 12 4", "3.000"}, + {" 3.2 12 4", "3.200"}, + {"-3.1234", "-3.123"}, + {"-3.1a", "-3.100"}, + {"-5+8", "-5.000"}, + {"+3.1234", "3.123"}, + {"11", "11.000"}, + {"11-5", "11.000"}, + {"11d", "11.000"}, + {"11wha?", "11.000"}, + {"12", "12.000"}, + {"1a1", "1.000"}, + {"3. 12 4", "3.000"}, + {"5.932887e+07", "59328870.000"}, + {"5.932887e+07abc", "59328870.000"}, + {"5.932887e7", "59328870.000"}, + {"5.932887e7abc", "59328870.000"}, + {"a1a1", "0.000"}, + }, + }, + { + Query: "select * from test01 where pk in ('11')", + Expected: []sql.Row{{"11"}}, + }, + { + // https://github.com/dolthub/dolt/issues/9739 + Skip: true, + Dialect: "mysql", + Query: "select * from test01 where pk in (11)", + Expected: []sql.Row{ + {"11"}, + {"11-5"}, + {"11d"}, + {"11wha?"}, + }, + }, + { + // https://github.com/dolthub/dolt/issues/9739 + Skip: true, + Dialect: "mysql", + Query: "select * from test01 where pk=3", + Expected: []sql.Row{ + {" 3 12 4"}, + {" 3. 12 4"}, + {"3. 12 4"}, + }, + }, + { + // https://github.com/dolthub/dolt/issues/9739 + Skip: true, + Dialect: "mysql", + Query: "select * from test01 where pk>=3 and pk < 4", + Expected: []sql.Row{ + {" 3 12 4"}, + {" 3. 12 4"}, + {" 3.2 12 4"}, + {"+3.1234"}, + {"3. 12 4"}, + }, + }, + //{ + // // https://github.com/dolthub/dolt/issues/9739 + // Skip: true, + // Dialect: "mysql", + // Query: "select * from test02 where pk in ('11asdf')", + // Expected: []sql.Row{{"11"}}, + //}, + //{ + // // https://github.com/dolthub/dolt/issues/9739 + // Skip: true, + // Dialect: "mysql", + // Query: "select * from test02 where pk='11.12asdf'", + // Expected: []sql.Row{}, + //}, }, }, + //{ + // Name: "AS OF propagates to nested CALLs", + // SetUpScript: []string{}, + // Assertions: []queries.ScriptTestAssertion{ + // { + // Query: "select cast('123.99' as signed);", + // Expected: []sql.Row{ + // {123}, + // }, + // }, + // // TODO: some how fix this + // { + // Query: "select x'20' = 32;", + // Expected: []sql.Row{ + // {types.NewOkResult(0)}, + // }, + // }, + // }, + //}, //{ // Name: "AS OF propagates to nested CALLs", diff --git a/sql/expression/convert.go b/sql/expression/convert.go index 7d21cc0c6e..78ac45fda9 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -377,7 +377,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s return d, nil } if sql.ErrTruncatedIncorrect.Is(err) { - ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) + ctx.Warn(1265, "%s", err.Error()) return d, nil } return types.Float64.Zero(), nil diff --git a/sql/types/conversion.go b/sql/types/conversion.go index e39d00e520..a68e00343a 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -831,6 +831,7 @@ func ConvertHexBlobToUint(val any, originType sql.Type) (any, error) { // TruncateStringToNumber truncates a string to the appropriate number prefix. // This function expects whitespace to already be properly trimmed. +// TODO: separate logic for ints and floating point? func TruncateStringToNumber(s string) (string, bool) { seenDigit := false seenDot := false From cdcaff4d648ae6aa5f5f4c18c4cd9b262509a1fa Mon Sep 17 00:00:00 2001 From: jycor Date: Wed, 10 Sep 2025 17:14:47 +0000 Subject: [PATCH 33/48] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/types/decimal.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 0f780573da..0c36b353e2 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -21,6 +21,9 @@ import ( "reflect" "strings" + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/shopspring/decimal" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" From 53b61c3987352b4d855a2f69fd212cee52098893 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 11 Sep 2025 16:41:48 -0700 Subject: [PATCH 34/48] refactoring and fixing char function --- sql/expression/function/char.go | 2 +- sql/types/conversion.go | 1 - sql/types/decimal.go | 9 ++------- sql/types/number.go | 36 ++++++++++++++++++++++++++------- 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/sql/expression/function/char.go b/sql/expression/function/char.go index 531327a046..d69493e705 100644 --- a/sql/expression/function/char.go +++ b/sql/expression/function/char.go @@ -104,7 +104,7 @@ func encodeUInt32(num uint32) []byte { // Eval implements the sql.Expression interface func (c *Char) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - res := []byte{} + var res []byte for _, arg := range c.args { if arg == nil { continue diff --git a/sql/types/conversion.go b/sql/types/conversion.go index a68e00343a..e39d00e520 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -831,7 +831,6 @@ func ConvertHexBlobToUint(val any, originType sql.Type) (any, error) { // TruncateStringToNumber truncates a string to the appropriate number prefix. // This function expects whitespace to already be properly trimmed. -// TODO: separate logic for ints and floating point? func TruncateStringToNumber(s string) (string, bool) { seenDigit := false seenDot := false diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 0c36b353e2..6eedf9f53b 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -17,15 +17,10 @@ package types import ( "context" "fmt" - "math/big" - "reflect" - "strings" - - "github.com/dolthub/vitess/go/sqltypes" - "github.com/dolthub/vitess/go/vt/proto/query" - "github.com/shopspring/decimal" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" + "math/big" + "reflect" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/sqltypes" diff --git a/sql/types/number.go b/sql/types/number.go index 4a0da33bfd..7e7493084a 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -1060,7 +1060,10 @@ func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInR if v < float32(math.MinInt64) { return math.MinInt64, sql.OutOfRange, nil } - return int64(math.Round(float64(v))), sql.InRange, nil + if round { + return int64(math.Round(float64(v))), sql.InRange, nil + } + return int64(v), sql.InRange, nil case float64: if v > float64(math.MaxInt64) { return math.MaxInt64, sql.OutOfRange, nil @@ -1068,8 +1071,12 @@ func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInR if v < float64(math.MinInt64) { return math.MinInt64, sql.OutOfRange, nil } - return int64(math.Round(v)), sql.InRange, nil + if round { + return int64(math.Round(v)), sql.InRange, nil + } + return int64(v), sql.InRange, nil case decimal.Decimal: + // TODO: round? if v.GreaterThan(dec_int64_max) { return dec_int64_max.IntPart(), sql.OutOfRange, nil } @@ -1254,7 +1261,10 @@ func convertToUint64(t NumberTypeImpl_, v any, round bool) (uint64, sql.ConvertI if v < 0 { return uint64(math.MaxUint64 - v), sql.OutOfRange, nil } - return uint64(math.Round(float64(v))), sql.InRange, nil + if round { + return uint64(math.Round(float64(v))), sql.InRange, nil + } + return uint64(v), sql.InRange, nil case float64: if v >= float64(math.MaxUint64) { return math.MaxUint64, sql.OutOfRange, nil @@ -1262,7 +1272,10 @@ func convertToUint64(t NumberTypeImpl_, v any, round bool) (uint64, sql.ConvertI if v < 0 { return uint64(math.MaxUint64 - v), sql.OutOfRange, nil } - return uint64(math.Round(v)), sql.InRange, nil + if round { + return uint64(math.Round(v)), sql.InRange, nil + } + return uint64(v), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_uint64_max) { return math.MaxUint64, sql.OutOfRange, nil @@ -1536,7 +1549,10 @@ func convertToUint16(t NumberTypeImpl_, v any, round bool) (uint16, sql.ConvertI if v < 0 { return uint16(math.MaxUint16 - v), sql.OutOfRange, nil } - return uint16(math.Round(float64(v))), sql.InRange, nil + if round { + return uint16(math.Round(float64(v))), sql.InRange, nil + } + return uint16(v), sql.InRange, nil case float64: if v >= float64(math.MaxUint16) { return math.MaxUint16, sql.OutOfRange, nil @@ -1674,7 +1690,10 @@ func convertToUint8(t NumberTypeImpl_, v any, round bool) (uint8, sql.ConvertInR if v < 0 { return uint8(math.MaxUint8 - v), sql.OutOfRange, nil } - return uint8(math.Round(float64(v))), sql.InRange, nil + if round { + return uint8(math.Round(float64(v))), sql.InRange, nil + } + return uint8(v), sql.InRange, nil case float64: if v >= float64(math.MaxUint8) { return math.MaxUint8, sql.OutOfRange, nil @@ -1682,7 +1701,10 @@ func convertToUint8(t NumberTypeImpl_, v any, round bool) (uint8, sql.ConvertInR if v <= 0 { return uint8(math.MaxUint8 - v), sql.OutOfRange, nil } - return uint8(math.Round(v)), sql.InRange, nil + if round { + return uint8(math.Round(v)), sql.InRange, nil + } + return uint8(v), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_uint8_max) { return math.MaxUint8, sql.InRange, nil From 8888b237c542424bef3ce4baa8d56c0b2a1026b2 Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 11 Sep 2025 23:56:26 +0000 Subject: [PATCH 35/48] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/types/decimal.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 6eedf9f53b..1e29879b19 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -17,11 +17,15 @@ package types import ( "context" "fmt" - "github.com/shopspring/decimal" - "gopkg.in/src-d/go-errors.v1" "math/big" "reflect" + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/shopspring/decimal" + "github.com/shopspring/decimal" + "gopkg.in/src-d/go-errors.v1" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" From 867e8f93de5b9bfd8173ea1ae1c6dcfe55d2c070 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 12 Sep 2025 14:05:26 -0700 Subject: [PATCH 36/48] more edge case fixing --- memory/table.go | 1 + sql/types/number_test.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/memory/table.go b/memory/table.go index d48adb4111..eab5886859 100644 --- a/memory/table.go +++ b/memory/table.go @@ -1430,6 +1430,7 @@ func (t *Table) ModifyColumn(ctx *sql.Context, columnName string, column *sql.Co oldRowWithoutVal = append(oldRowWithoutVal, row[:oldIdx]...) oldRowWithoutVal = append(oldRowWithoutVal, row[oldIdx+1:]...) oldType := data.schema.Schema[oldIdx].Type + // TODO: this needs to call the rounding conversion thing newVal, inRange, err := types.TypeAwareConversion(ctx, row[oldIdx], oldType, column.Type, true) if err != nil { if sql.ErrNotMatchingSRID.Is(err) { diff --git a/sql/types/number_test.go b/sql/types/number_test.go index bc79a20be5..4d6a655e1a 100644 --- a/sql/types/number_test.go +++ b/sql/types/number_test.go @@ -189,7 +189,7 @@ func TestNumberConvert(t *testing.T) { {typ: Int64, inp: false, exp: int64(0), err: false, inRange: sql.InRange}, {typ: Uint8, inp: int64(34), exp: uint8(34), err: false, inRange: sql.InRange}, {typ: Uint16, inp: int16(35), exp: uint16(35), err: false, inRange: sql.InRange}, - {typ: Uint24, inp: 36.756, exp: uint32(37), err: false, inRange: sql.InRange}, + {typ: Uint24, inp: 36.756, exp: uint32(36), err: false, inRange: sql.InRange}, {typ: Uint32, inp: uint8(37), exp: uint32(37), err: false, inRange: sql.InRange}, {typ: Uint64, inp: time.Date(2009, 1, 2, 3, 4, 5, 0, time.UTC), exp: uint64(time.Date(2009, 1, 2, 3, 4, 5, 0, time.UTC).Unix()), err: false, inRange: sql.InRange}, {typ: Uint64, inp: "01000", exp: uint64(1000), err: false, inRange: sql.InRange}, From f390e647220b4b440fc16036f0f3197a54810e58 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 12 Sep 2025 16:09:35 -0700 Subject: [PATCH 37/48] almost done --- enginetest/queries/column_default_queries.go | 1 - 1 file changed, 1 deletion(-) diff --git a/enginetest/queries/column_default_queries.go b/enginetest/queries/column_default_queries.go index 8ab0b9222e..cb53bded6e 100644 --- a/enginetest/queries/column_default_queries.go +++ b/enginetest/queries/column_default_queries.go @@ -571,7 +571,6 @@ var ColumnDefaultTests = []ScriptTest{ }, }, { - // Technically, MySQL does NOT allow BLOB/JSON/TEXT types to have a literal default value, and requires them // to be specified as an expression (i.e. wrapped in parens). We diverge from this behavior and allow it, for // compatibility with MariaDB. For more context, see: https://github.com/dolthub/dolt/issues/7033 Name: "BLOB types can define defaults with literals", From 19d4470eb3ad0f1adcdf9884afe164cf85f1140b Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 15 Sep 2025 10:14:19 -0700 Subject: [PATCH 38/48] cleaning up --- enginetest/queries/column_default_queries.go | 1 + memory/table.go | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/enginetest/queries/column_default_queries.go b/enginetest/queries/column_default_queries.go index cb53bded6e..8ab0b9222e 100644 --- a/enginetest/queries/column_default_queries.go +++ b/enginetest/queries/column_default_queries.go @@ -571,6 +571,7 @@ var ColumnDefaultTests = []ScriptTest{ }, }, { + // Technically, MySQL does NOT allow BLOB/JSON/TEXT types to have a literal default value, and requires them // to be specified as an expression (i.e. wrapped in parens). We diverge from this behavior and allow it, for // compatibility with MariaDB. For more context, see: https://github.com/dolthub/dolt/issues/7033 Name: "BLOB types can define defaults with literals", diff --git a/memory/table.go b/memory/table.go index eab5886859..d48adb4111 100644 --- a/memory/table.go +++ b/memory/table.go @@ -1430,7 +1430,6 @@ func (t *Table) ModifyColumn(ctx *sql.Context, columnName string, column *sql.Co oldRowWithoutVal = append(oldRowWithoutVal, row[:oldIdx]...) oldRowWithoutVal = append(oldRowWithoutVal, row[oldIdx+1:]...) oldType := data.schema.Schema[oldIdx].Type - // TODO: this needs to call the rounding conversion thing newVal, inRange, err := types.TypeAwareConversion(ctx, row[oldIdx], oldType, column.Type, true) if err != nil { if sql.ErrNotMatchingSRID.Is(err) { From 7f5a6077a57f5986ac3c5b95b0a3038202c0ed4b Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 15 Sep 2025 10:30:06 -0700 Subject: [PATCH 39/48] fixing tests --- sql/expression/function/bit_count.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/expression/function/bit_count.go b/sql/expression/function/bit_count.go index 56c4d68946..1830ad67b0 100644 --- a/sql/expression/function/bit_count.go +++ b/sql/expression/function/bit_count.go @@ -97,7 +97,7 @@ func (b *BitCount) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { res += countBits(uint64(v)) } default: - num, _, err := types.Int64.Convert(ctx, child) + num, _, err := types.Int64.(sql.RoundingNumberType).ConvertRound(ctx, child) if err != nil { if !sql.ErrTruncatedIncorrect.Is(err) { return nil, err From b63da5700cc21b21ed9321a9ef809301acd694f5 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 15 Sep 2025 12:14:51 -0700 Subject: [PATCH 40/48] fix more tests --- sql/expression/function/bit_count.go | 2 +- sql/types/number.go | 20 ++++---------------- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/sql/expression/function/bit_count.go b/sql/expression/function/bit_count.go index 1830ad67b0..56c4d68946 100644 --- a/sql/expression/function/bit_count.go +++ b/sql/expression/function/bit_count.go @@ -97,7 +97,7 @@ func (b *BitCount) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { res += countBits(uint64(v)) } default: - num, _, err := types.Int64.(sql.RoundingNumberType).ConvertRound(ctx, child) + num, _, err := types.Int64.Convert(ctx, child) if err != nil { if !sql.ErrTruncatedIncorrect.Is(err) { return nil, err diff --git a/sql/types/number.go b/sql/types/number.go index 7e7493084a..25fe8f2e1a 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -1060,10 +1060,7 @@ func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInR if v < float32(math.MinInt64) { return math.MinInt64, sql.OutOfRange, nil } - if round { - return int64(math.Round(float64(v))), sql.InRange, nil - } - return int64(v), sql.InRange, nil + return int64(math.Round(float64(v))), sql.InRange, nil case float64: if v > float64(math.MaxInt64) { return math.MaxInt64, sql.OutOfRange, nil @@ -1071,10 +1068,7 @@ func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInR if v < float64(math.MinInt64) { return math.MinInt64, sql.OutOfRange, nil } - if round { - return int64(math.Round(v)), sql.InRange, nil - } - return int64(v), sql.InRange, nil + return int64(math.Round(v)), sql.InRange, nil case decimal.Decimal: // TODO: round? if v.GreaterThan(dec_int64_max) { @@ -1261,10 +1255,7 @@ func convertToUint64(t NumberTypeImpl_, v any, round bool) (uint64, sql.ConvertI if v < 0 { return uint64(math.MaxUint64 - v), sql.OutOfRange, nil } - if round { - return uint64(math.Round(float64(v))), sql.InRange, nil - } - return uint64(v), sql.InRange, nil + return uint64(math.Round(float64(v))), sql.InRange, nil case float64: if v >= float64(math.MaxUint64) { return math.MaxUint64, sql.OutOfRange, nil @@ -1272,10 +1263,7 @@ func convertToUint64(t NumberTypeImpl_, v any, round bool) (uint64, sql.ConvertI if v < 0 { return uint64(math.MaxUint64 - v), sql.OutOfRange, nil } - if round { - return uint64(math.Round(v)), sql.InRange, nil - } - return uint64(v), sql.InRange, nil + return uint64(math.Round(v)), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_uint64_max) { return math.MaxUint64, sql.OutOfRange, nil From d2360238a05c62ec1a73441f40eac5f22490fa32 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 15 Sep 2025 13:42:03 -0700 Subject: [PATCH 41/48] fix tests --- sql/types/number_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/types/number_test.go b/sql/types/number_test.go index 4d6a655e1a..bc79a20be5 100644 --- a/sql/types/number_test.go +++ b/sql/types/number_test.go @@ -189,7 +189,7 @@ func TestNumberConvert(t *testing.T) { {typ: Int64, inp: false, exp: int64(0), err: false, inRange: sql.InRange}, {typ: Uint8, inp: int64(34), exp: uint8(34), err: false, inRange: sql.InRange}, {typ: Uint16, inp: int16(35), exp: uint16(35), err: false, inRange: sql.InRange}, - {typ: Uint24, inp: 36.756, exp: uint32(36), err: false, inRange: sql.InRange}, + {typ: Uint24, inp: 36.756, exp: uint32(37), err: false, inRange: sql.InRange}, {typ: Uint32, inp: uint8(37), exp: uint32(37), err: false, inRange: sql.InRange}, {typ: Uint64, inp: time.Date(2009, 1, 2, 3, 4, 5, 0, time.UTC), exp: uint64(time.Date(2009, 1, 2, 3, 4, 5, 0, time.UTC).Unix()), err: false, inRange: sql.InRange}, {typ: Uint64, inp: "01000", exp: uint64(1000), err: false, inRange: sql.InRange}, From d48d5b29de5cfa399f8b3fce35df92bfa8e12e60 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 15 Sep 2025 14:39:09 -0700 Subject: [PATCH 42/48] rebase --- sql/expression/convert.go | 1 - sql/types/decimal.go | 3 --- 2 files changed, 4 deletions(-) diff --git a/sql/expression/convert.go b/sql/expression/convert.go index 78ac45fda9..6c719449bd 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -19,7 +19,6 @@ import ( "strings" "time" - "github.com/dolthub/vitess/go/mysql" "github.com/dolthub/vitess/go/sqltypes" "github.com/sirupsen/logrus" "gopkg.in/src-d/go-errors.v1" diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 1e29879b19..214ffa63a7 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -20,9 +20,6 @@ import ( "math/big" "reflect" - "github.com/dolthub/vitess/go/sqltypes" - "github.com/dolthub/vitess/go/vt/proto/query" - "github.com/shopspring/decimal" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" From e521ce784aa8507f4161687550a0558b67de8050 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 16 Sep 2025 11:18:17 -0700 Subject: [PATCH 43/48] asdf --- sql/expression/comparison.go | 1 - sql/expression/convert.go | 3 ++- sql/types/number.go | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index dbb543c46c..5694841061 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -16,7 +16,6 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/mysql" errors "gopkg.in/src-d/go-errors.v1" diff --git a/sql/expression/convert.go b/sql/expression/convert.go index 6c719449bd..df6d17eda1 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -16,6 +16,7 @@ package expression import ( "fmt" + "github.com/dolthub/vitess/go/mysql" "strings" "time" @@ -376,7 +377,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s return d, nil } if sql.ErrTruncatedIncorrect.Is(err) { - ctx.Warn(1265, "%s", err.Error()) + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) return d, nil } return types.Float64.Zero(), nil diff --git a/sql/types/number.go b/sql/types/number.go index 25fe8f2e1a..1ed143562d 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -1070,7 +1070,6 @@ func convertToInt64(t NumberTypeImpl_, v any, round bool) (int64, sql.ConvertInR } return int64(math.Round(v)), sql.InRange, nil case decimal.Decimal: - // TODO: round? if v.GreaterThan(dec_int64_max) { return dec_int64_max.IntPart(), sql.OutOfRange, nil } From b39da9b815a038494f8b25612b203c8b04b825b6 Mon Sep 17 00:00:00 2001 From: jycor Date: Mon, 15 Sep 2025 21:41:09 +0000 Subject: [PATCH 44/48] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/types/decimal.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 214ffa63a7..0958d8c01a 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -20,12 +20,12 @@ import ( "math/big" "reflect" + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/vitess/go/sqltypes" - "github.com/dolthub/vitess/go/vt/proto/query" ) const ( From 0ccdc325bde2ac82c593b9ca198801bebfb2a434 Mon Sep 17 00:00:00 2001 From: jycor Date: Tue, 16 Sep 2025 18:20:56 +0000 Subject: [PATCH 45/48] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/comparison.go | 1 + sql/expression/convert.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 5694841061..dbb543c46c 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -16,6 +16,7 @@ package expression import ( "fmt" + "github.com/dolthub/vitess/go/mysql" errors "gopkg.in/src-d/go-errors.v1" diff --git a/sql/expression/convert.go b/sql/expression/convert.go index df6d17eda1..7d21cc0c6e 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -16,10 +16,10 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/mysql" "strings" "time" + "github.com/dolthub/vitess/go/mysql" "github.com/dolthub/vitess/go/sqltypes" "github.com/sirupsen/logrus" "gopkg.in/src-d/go-errors.v1" From 0d527bd2515fbda2a2c4703150a93120b29d970e Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 16 Sep 2025 12:18:49 -0700 Subject: [PATCH 46/48] aaaaa --- enginetest/memory_engine_test.go | 2 +- sql/expression/comparison.go | 26 ++++++++++++++++++++++++++ sql/types/utils.go | 1 + 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 06789ddd7e..f6ed571467 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -200,7 +200,7 @@ func TestSingleQueryPrepared(t *testing.T) { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { - //t.Skip() + t.Skip() var scripts = []queries.ScriptTest{ { Name: "sets", diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index dbb543c46c..80d011b2ec 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -174,6 +174,32 @@ func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) (interface{}, interface{}, sql.Type) { lType := c.Left().Type() rType := c.Right().Type() + + //lIsEnumOrSet := types.IsEnum(lType) || types.IsSet(lType) + //rIsEnumOrSet := types.IsEnum(rType) || types.IsSet(rType) + //// If right side is convertible to enum/set, convert. Otherwise, convert left side + //if lIsEnumOrSet && (types.IsText(rType) || types.IsNumber(rType)) { + // if r, inRange, err := lType.Convert(ctx, right); inRange && err == nil { + // return left, r, lType, nil + // } + // l, _, err := types.TypeAwareConversion(ctx, left, lType, rType, false) + // if err != nil { + // return nil, nil, nil, err + // } + // return l, right, rType, nil + //} + //// If left side is convertible to enum/set, convert. Otherwise, convert right side + //if rIsEnumOrSet && (types.IsText(lType) || types.IsNumber(lType)) { + // if l, inRange, err := rType.Convert(ctx, left); inRange && err == nil { + // return l, right, rType, nil + // } + // r, _, err := types.TypeAwareConversion(ctx, right, rType, lType, false) + // if err != nil { + // return nil, nil, nil, err + // } + // return left, r, lType, nil + //} + compType := types.GetCompareType(lType, rType) // Special case for JSON types diff --git a/sql/types/utils.go b/sql/types/utils.go index f5c49e4237..f6374cf6f6 100644 --- a/sql/types/utils.go +++ b/sql/types/utils.go @@ -16,6 +16,7 @@ package types import ( "github.com/dolthub/go-mysql-server/sql" + ) // GetCompareType returns the type to use when comparing values of types left and right. From 2c8eee296fe759bbbc187cddbd98699cbc4b5246 Mon Sep 17 00:00:00 2001 From: jycor Date: Tue, 16 Sep 2025 19:24:38 +0000 Subject: [PATCH 47/48] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/types/utils.go | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/types/utils.go b/sql/types/utils.go index f6374cf6f6..f5c49e4237 100644 --- a/sql/types/utils.go +++ b/sql/types/utils.go @@ -16,7 +16,6 @@ package types import ( "github.com/dolthub/go-mysql-server/sql" - ) // GetCompareType returns the type to use when comparing values of types left and right. From 187dbd9ea856788975501e1a4dc3853591ed0402 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 16 Sep 2025 18:20:26 -0700 Subject: [PATCH 48/48] more test fixes --- sql/expression/comparison.go | 26 -------------------------- sql/types/number.go | 15 +++------------ 2 files changed, 3 insertions(+), 38 deletions(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 80d011b2ec..dbb543c46c 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -174,32 +174,6 @@ func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{ func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) (interface{}, interface{}, sql.Type) { lType := c.Left().Type() rType := c.Right().Type() - - //lIsEnumOrSet := types.IsEnum(lType) || types.IsSet(lType) - //rIsEnumOrSet := types.IsEnum(rType) || types.IsSet(rType) - //// If right side is convertible to enum/set, convert. Otherwise, convert left side - //if lIsEnumOrSet && (types.IsText(rType) || types.IsNumber(rType)) { - // if r, inRange, err := lType.Convert(ctx, right); inRange && err == nil { - // return left, r, lType, nil - // } - // l, _, err := types.TypeAwareConversion(ctx, left, lType, rType, false) - // if err != nil { - // return nil, nil, nil, err - // } - // return l, right, rType, nil - //} - //// If left side is convertible to enum/set, convert. Otherwise, convert right side - //if rIsEnumOrSet && (types.IsText(lType) || types.IsNumber(lType)) { - // if l, inRange, err := rType.Convert(ctx, left); inRange && err == nil { - // return l, right, rType, nil - // } - // r, _, err := types.TypeAwareConversion(ctx, right, rType, lType, false) - // if err != nil { - // return nil, nil, nil, err - // } - // return left, r, lType, nil - //} - compType := types.GetCompareType(lType, rType) // Special case for JSON types diff --git a/sql/types/number.go b/sql/types/number.go index 1ed143562d..4a0da33bfd 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -1536,10 +1536,7 @@ func convertToUint16(t NumberTypeImpl_, v any, round bool) (uint16, sql.ConvertI if v < 0 { return uint16(math.MaxUint16 - v), sql.OutOfRange, nil } - if round { - return uint16(math.Round(float64(v))), sql.InRange, nil - } - return uint16(v), sql.InRange, nil + return uint16(math.Round(float64(v))), sql.InRange, nil case float64: if v >= float64(math.MaxUint16) { return math.MaxUint16, sql.OutOfRange, nil @@ -1677,10 +1674,7 @@ func convertToUint8(t NumberTypeImpl_, v any, round bool) (uint8, sql.ConvertInR if v < 0 { return uint8(math.MaxUint8 - v), sql.OutOfRange, nil } - if round { - return uint8(math.Round(float64(v))), sql.InRange, nil - } - return uint8(v), sql.InRange, nil + return uint8(math.Round(float64(v))), sql.InRange, nil case float64: if v >= float64(math.MaxUint8) { return math.MaxUint8, sql.OutOfRange, nil @@ -1688,10 +1682,7 @@ func convertToUint8(t NumberTypeImpl_, v any, round bool) (uint8, sql.ConvertInR if v <= 0 { return uint8(math.MaxUint8 - v), sql.OutOfRange, nil } - if round { - return uint8(math.Round(v)), sql.InRange, nil - } - return uint8(v), sql.InRange, nil + return uint8(math.Round(v)), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_uint8_max) { return math.MaxUint8, sql.InRange, nil