diff --git a/enginetest/enginetests.go b/enginetest/enginetests.go index 2132c8a9fc..49d69c0755 100644 --- a/enginetest/enginetests.go +++ b/enginetest/enginetests.go @@ -3961,9 +3961,9 @@ func TestWindowRangeFrames(t *testing.T, harness Harness) { TestQueryWithContext(t, ctx, e, harness, `SELECT sum(y) over (partition by z order by date range between unbounded preceding and interval '1' DAY following) FROM c order by x`, []sql.Row{{float64(1)}, {float64(1)}, {float64(1)}, {float64(1)}, {float64(5)}, {float64(5)}, {float64(10)}, {float64(10)}, {float64(10)}, {float64(10)}}, nil, nil, nil) TestQueryWithContext(t, ctx, e, harness, `SELECT count(y) over (partition by z order by date range between interval '1' DAY following and interval '2' DAY following) FROM c order by x`, []sql.Row{{1}, {1}, {1}, {1}, {1}, {0}, {2}, {2}, {0}, {0}}, nil, nil, nil) TestQueryWithContext(t, ctx, e, harness, `SELECT count(y) over (partition by z order by date range between interval '1' DAY preceding and interval '2' DAY following) FROM c order by x`, []sql.Row{{4}, {4}, {4}, {5}, {2}, {2}, {4}, {4}, {4}, {4}}, nil, nil, nil) + TestQueryWithContext(t, ctx, e, harness, "SELECT sum(y) over (partition by z order by date range interval 'e' DAY preceding) FROM c order by x", []sql.Row{{float64(0)}, {float64(0)}, {float64(0)}, {float64(1)}, {float64(1)}, {float64(3)}, {float64(1)}, {float64(1)}, {float64(4)}, {float64(4)}}, nil, nil, nil) AssertErr(t, e, harness, "SELECT sum(y) over (partition by z range between unbounded preceding and interval '1' DAY following) FROM c order by x", nil, aggregation.ErrRangeInvalidOrderBy) - AssertErr(t, e, harness, "SELECT sum(y) over (partition by z order by date range interval 'e' DAY preceding) FROM c order by x", nil, sql.ErrInvalidValue) } func TestNamedWindows(t *testing.T, harness Harness) { diff --git a/enginetest/evaluation.go b/enginetest/evaluation.go index 87499b3498..bf79d42ac0 100644 --- a/enginetest/evaluation.go +++ b/enginetest/evaluation.go @@ -129,9 +129,18 @@ func TestScriptWithEngine(t *testing.T, e QueryEngine, harness Harness, script q } else if assertion.ExpectedErrStr != "" { AssertErrWithCtx(t, e, harness, ctx, assertion.Query, assertion.Bindings, nil, assertion.ExpectedErrStr) } else if assertion.ExpectedWarning != 0 { - AssertWarningAndTestQuery(t, e, nil, harness, assertion.Query, - assertion.Expected, nil, assertion.ExpectedWarning, assertion.ExpectedWarningsCount, - assertion.ExpectedWarningMessageSubstring, assertion.SkipResultsCheck) + if IsServerEngine(e) && assertion.SkipResultCheckOnServerEngine { + t.Skip() + } + AssertWarningAndTestQuery(t, e, nil, harness, + assertion.Query, + assertion.Expected, + nil, + assertion.ExpectedWarning, + assertion.ExpectedWarningsCount, + assertion.ExpectedWarningMessageSubstring, + assertion.SkipResultsCheck, + ) } else if assertion.SkipResultsCheck { RunQueryWithContext(t, e, harness, nil, assertion.Query) } else if assertion.CheckIndexedAccess { diff --git a/enginetest/queries/function_queries.go b/enginetest/queries/function_queries.go index b5f74422b1..deb08f6ece 100644 --- a/enginetest/queries/function_queries.go +++ b/enginetest/queries/function_queries.go @@ -1052,7 +1052,7 @@ var FunctionQueryTests = []QueryTest{ { Query: `SELECT FLOOR(15728640/1024/1030)`, Expected: []sql.Row{ - {"14"}, + {14}, }, }, { diff --git a/enginetest/queries/json_table_queries.go b/enginetest/queries/json_table_queries.go index c4d7e13bcc..d615f068bf 100644 --- a/enginetest/queries/json_table_queries.go +++ b/enginetest/queries/json_table_queries.go @@ -571,7 +571,7 @@ var JSONTableScriptTests = []ScriptTest{ }, { Query: "SELECT * FROM JSON_TABLE('{\"c1\":\"abc\"}', '$' COLUMNS(c1 INT PATH '$.c1' DEFAULT 'def' ON ERROR)) as jt;", - ExpectedErrStr: "error: 'def' is not a valid value for 'int'", + ExpectedErrStr: "Invalid JSON text in argument 1 to function JSON_TABLE: \"Invalid value.\"", }, }, }, @@ -612,7 +612,7 @@ var JSONTableScriptTests = []ScriptTest{ }, { Query: "SELECT * FROM JSON_TABLE('{\"c1\":\"abc\"}', '$' COLUMNS(c1 INT PATH '$.c1' ERROR ON ERROR)) as jt;", - ExpectedErrStr: "error: 'abc' is not a valid value for 'int'", + ExpectedErrStr: "Invalid JSON text in argument 1 to function JSON_TABLE: \"Invalid value.\"", }, }, }, diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index e54470338b..750d58f47c 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -5509,11 +5509,11 @@ SELECT * FROM cte WHERE d = 2;`, }, { Query: "select ceil(i + 0.5) from mytable order by 1", - Expected: []sql.Row{{"2"}, {"3"}, {"4"}}, + Expected: []sql.Row{{2}, {3}, {4}}, }, { Query: "select floor(i + 0.5) from mytable order by 1", - Expected: []sql.Row{{"1"}, {"2"}, {"3"}}, + Expected: []sql.Row{{1}, {2}, {3}}, }, { Query: "select round(i + 0.55, 1) from mytable order by 1", diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 24a43ea29d..d5d21922b0 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -501,12 +501,12 @@ FROM task_instance INNER JOIN job ON job.id = task_instance.queued_by_job_id INN }, { // TODO: 123.456 is converted to a DECIMAL by Builder.ConvertVal, when it should be a DOUBLE - Skip: true, + SkipResultCheckOnServerEngine: true, // TODO: warnings do not make it to server engine Query: "SELECT '123.456ABC' = 123.456;", Expected: []sql.Row{{true}}, ExpectedWarningsCount: 1, ExpectedWarning: mysql.ERTruncatedWrongValue, - ExpectedWarningMessageSubstring: "Truncated incorrect double value: 123A", + ExpectedWarningMessageSubstring: "Truncated incorrect decimal(65,30) value: 123.456ABC", }, { Query: "SELECT '123.456e2' = 12345.6;", @@ -528,20 +528,18 @@ FROM task_instance INNER JOIN job ON job.id = task_instance.queued_by_job_id INN }, { // Valid float strings used as arguments to functions are truncated not rounded - Skip: true, Query: "SELECT LENGTH(SPACE('1.9'));", Expected: []sql.Row{{1}}, - ExpectedWarningsCount: 2, // MySQL throws two warnings for some reason + ExpectedWarningsCount: 1, // TODO: MySQL throws two warnings for some reason ExpectedWarning: mysql.ERTruncatedWrongValue, }, { // TODO: 123.456 is converted to a DECIMAL by Builder.ConvertVal, when it should be a DOUBLE - Skip: true, Query: "SELECT -'+123.456ABC' = -123.456", Expected: []sql.Row{{true}}, ExpectedWarningsCount: 1, ExpectedWarning: mysql.ERTruncatedWrongValue, - ExpectedWarningMessageSubstring: "Truncated incorrect double value: +123.456ABC", + ExpectedWarningMessageSubstring: "Truncated incorrect decimal(65,30) value: +123.456ABC", }, { Query: "SELECT '0xBEEF' = 0;", @@ -608,12 +606,12 @@ FROM task_instance INNER JOIN job ON job.id = task_instance.queued_by_job_id INN }, { // TODO: 123.456 is converted to a DECIMAL by Builder.ConvertVal, when it should be a DOUBLE - Skip: true, + SkipResultCheckOnServerEngine: true, // TODO: warnings do not make it to server engine Query: "SELECT '123.456ABC' in (123.456);", Expected: []sql.Row{{true}}, ExpectedWarningsCount: 1, ExpectedWarning: mysql.ERTruncatedWrongValue, - ExpectedWarningMessageSubstring: "Truncated incorrect double value: 123A", + ExpectedWarningMessageSubstring: "Truncated incorrect decimal(65,30) value: 123.456ABC", }, { Query: "SELECT '123.456e2' in (12345.6);", @@ -975,13 +973,13 @@ FROM task_instance INNER JOIN job ON job.id = task_instance.queued_by_job_id INN Dialect: "mysql", Name: "strings cast to numbers", SetUpScript: []string{ - "create table test01(pk varchar(20) primary key)", + "create table test01(pk varchar(20) primary key);", `insert into test01 values (' 3 12 4'), (' 3.2 12 4'),('-3.1234'),('-3.1a'),('-5+8'),('+3.1234'), ('11d'),('11wha?'),('11'),('12'),('1a1'),('a1a1'),('11-5'), - ('3. 12 4'),('5.932887e+07'),('5.932887e+07abc'),('5.932887e7'),('5.932887e7abc')`, - "create table test02(pk int primary key)", - "insert into test02 values(11),(12),(13),(14),(15)", + ('3. 12 4'),('5.932887e+07'),('5.932887e+07abc'),('5.932887e7'),('5.932887e7abc');`, + "create table test02(pk int primary key);", + "insert into test02 values(11),(12),(13),(14),(15);", }, Assertions: []ScriptTestAssertion{ { @@ -1082,10 +1080,8 @@ FROM task_instance INNER JOIN job ON job.id = task_instance.queued_by_job_id INN {"5.932887e7abc", uint64(5)}, {"a1a1", uint64(0)}, }, - // TODO: Should be 19. Missing warnings for "Cast to unsigned converted negative integer to its positive - // complement" (1105) https://github.com/dolthub/dolt/issues/9840 - ExpectedWarningsCount: 16, - ExpectedWarning: mysql.ERTruncatedWrongValue, + ExpectedWarningsCount: 19, + // Can't check multiple different warnings }, { Query: "select pk, cast(pk as decimal(12,3)) from test01", @@ -1119,7 +1115,6 @@ FROM task_instance INNER JOIN job ON job.id = task_instance.queued_by_job_id INN ExpectedWarningsCount: 0, }, - // TODO: these are not directly testing casting { // https://github.com/dolthub/dolt/issues/9739 Skip: true, @@ -1163,19 +1158,15 @@ FROM task_instance INNER JOIN job ON job.id = task_instance.queued_by_job_id INN ExpectedWarning: mysql.ERTruncatedWrongValue, }, { - // https://github.com/dolthub/dolt/issues/9739 - Skip: true, Dialect: "mysql", - Query: "select * from test02 where pk in ('11asdf')", - Expected: []sql.Row{{"11"}}, + Query: "select * from test02 where pk in ('11asdf');", + Expected: []sql.Row{{11}}, ExpectedWarningsCount: 1, ExpectedWarning: mysql.ERTruncatedWrongValue, }, { - // https://github.com/dolthub/dolt/issues/9739 - Skip: true, Dialect: "mysql", - Query: "select * from test02 where pk='11.12asdf'", + Query: "select * from test02 where pk='11.12asdf';", Expected: []sql.Row{}, ExpectedWarningsCount: 1, ExpectedWarning: mysql.ERTruncatedWrongValue, diff --git a/server/handler_test.go b/server/handler_test.go index 03bf918754..6ebf543561 100644 --- a/server/handler_test.go +++ b/server/handler_test.go @@ -1572,7 +1572,7 @@ func TestStatusVariableMaxUsedConnections(t *testing.T) { } checkGlobalStatVar(t, "Max_used_connections", uint64(0)) - checkGlobalStatVar(t, "Max_used_connections_time", "") + checkGlobalStatVar(t, "Max_used_connections_time", uint64(0)) conn1 := newConn(1) handler.NewConnection(conn1) diff --git a/sql/columndefault.go b/sql/columndefault.go index 1f61e01b6e..1e464927ec 100644 --- a/sql/columndefault.go +++ b/sql/columndefault.go @@ -82,9 +82,15 @@ func (e *ColumnDefaultValue) Eval(ctx *Context, r Row) (interface{}, error) { if e.OutType != nil { var inRange ConvertInRange - if val, inRange, err = e.OutType.Convert(ctx, val); err != nil { + if roundType, isRoundType := e.OutType.(RoundingNumberType); isRoundType { + val, inRange, err = roundType.ConvertRound(ctx, val) + } else { + val, inRange, err = e.OutType.Convert(ctx, val) + } + if err != nil { return nil, ErrIncompatibleDefaultType.New() - } else if !inRange { + } + if !inRange { return nil, ErrValueOutOfRange.New(val, e.OutType) } } @@ -227,13 +233,18 @@ func (e *ColumnDefaultValue) CheckType(ctx *Context) error { if val == nil && !e.ReturnNil { return ErrIncompatibleDefaultType.New() } - _, inRange, err := e.OutType.Convert(ctx, val) + var inRange ConvertInRange + if roundType, isRoundType := e.OutType.(RoundingNumberType); isRoundType { + val, inRange, err = roundType.ConvertRound(ctx, val) + } else { + val, inRange, err = e.OutType.Convert(ctx, val) + } if err != nil { return ErrIncompatibleDefaultType.Wrap(err) - } else if !inRange { + } + if !inRange { return ErrIncompatibleDefaultType.Wrap(ErrValueOutOfRange.New(val, e.Expr)) } - } return nil } diff --git a/sql/core.go b/sql/core.go index d7e38e310e..c1e1f90b2a 100644 --- a/sql/core.go +++ b/sql/core.go @@ -24,10 +24,8 @@ import ( "strings" "sync/atomic" "time" - "unicode" "unsafe" - "github.com/dolthub/vitess/go/mysql" "github.com/shopspring/decimal" "gopkg.in/src-d/go-errors.v1" @@ -340,48 +338,6 @@ const ( NumericCutSet = " \t\n\r" ) -// TODO: type processing logic should all be in the types package -func TrimStringToNumberPrefix(ctx *Context, s string, isInt bool) string { - if isInt { - s = strings.TrimLeft(s, IntCutSet) - } else { - s = strings.TrimLeft(s, NumericCutSet) - } - - seenDigit := false - seenDot := false - seenExp := false - signIndex := 0 - - for i := 0; i < len(s); i++ { - char := rune(s[i]) - if unicode.IsDigit(char) { - seenDigit = true - } else if char == '.' && !seenDot && !isInt { - seenDot = true - } else if (char == 'e' || char == 'E') && !seenExp && seenDigit && !isInt { - seenExp = true - signIndex = i + 1 - } else if !((char == '-' || char == '+') && i == signIndex) { - // TODO: this should not happen here, and it should use sql.ErrIncorrectTruncation - if isInt { - ctx.Warn(mysql.ERTruncatedWrongValue, "Truncated incorrect INTEGER value: '%s'", s) - } else { - ctx.Warn(mysql.ERTruncatedWrongValue, "Truncated incorrect DOUBLE value: '%s'", s) - } - return convertEmptyStringToZero(s[:i]) - } - } - return convertEmptyStringToZero(s) -} - -func convertEmptyStringToZero(s string) string { - if s == "" { - return "0" - } - return s -} - var ErrVectorInvalidBinaryLength = errors.NewKind("cannot convert BINARY(%d) to vector, byte length must be a multiple of 4 bytes") // DecodeVector decodes a byte slice that represents a vector. This is needed for distance functions. diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 83adb6e8a1..94dde18e5c 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -691,9 +691,12 @@ func (e *UnaryMinus) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } if !types.IsNumber(e.Child.Type()) { - child, err = decimal.NewFromString(fmt.Sprintf("%v", child)) + child, _, err = types.InternalDecimalType.Convert(ctx, child) if err != nil { - child = 0.0 + if !sql.ErrTruncatedIncorrect.Is(err) { + child = 0.0 + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } } @@ -735,7 +738,7 @@ func (e *UnaryMinus) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { case uint64: return -int64(n), nil case decimal.Decimal: - return n.Neg(), err + return n.Neg(), nil case string: // try getting int out of string value i, iErr := strconv.ParseInt(n, 10, 64) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 7ea42c475d..97e5f4ba6e 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -141,7 +141,7 @@ func (c *comparison) Compare(ctx *sql.Context, row sql.Row) (int, error) { return c.Left().Type().Compare(ctx, left, right) } - l, r, compareType, err := c.CastLeftAndRight(ctx, left, right) + l, r, compareType, err := c.castLeftAndRight(ctx, left, right) if err != nil { return 0, err } @@ -171,7 +171,7 @@ func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{ return left, right, nil } -func (c *comparison) CastLeftAndRight(ctx *sql.Context, left, right interface{}) (interface{}, interface{}, sql.Type, error) { +func (c *comparison) castLeftAndRight(ctx *sql.Context, left, right interface{}) (interface{}, interface{}, sql.Type, error) { leftType := c.Left().Type() rightType := c.Right().Type() @@ -452,7 +452,7 @@ func (e *NullSafeEquals) Compare(ctx *sql.Context, row sql.Row) (int, error) { } var compareType sql.Type - left, right, compareType, err = e.CastLeftAndRight(ctx, left, right) + left, right, compareType, err = e.castLeftAndRight(ctx, left, right) if err != nil { return 0, err } diff --git a/sql/expression/convert.go b/sql/expression/convert.go index b15548cd67..3201b0da96 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -15,9 +15,7 @@ package expression import ( - "encoding/hex" "fmt" - "strconv" "strings" "time" @@ -349,38 +347,43 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s } return d, nil case ConvertToDecimal: - value, err := prepareForNumericContext(ctx, val, originType, false) + value, err := types.ConvertHexBlobToDecimalForNumericContext(val, originType) if err != nil { return nil, err } dt := createConvertedDecimalType(typeLength, typeScale, false) d, _, err := dt.Convert(ctx, value) if err != nil { - return dt.Zero(), nil + if !sql.ErrTruncatedIncorrect.Is(err) { + return dt.Zero(), nil + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } return d, nil case ConvertToFloat: - value, err := prepareForNumericContext(ctx, val, originType, false) + value, err := types.ConvertHexBlobToDecimalForNumericContext(val, originType) if err != nil { return nil, err } d, _, err := types.Float32.Convert(ctx, value) if err != nil { - return types.Float32.Zero(), nil + if !sql.ErrTruncatedIncorrect.Is(err) { + return types.Float64.Zero(), nil + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } return d, nil case ConvertToDouble, ConvertToReal: - value, err := prepareForNumericContext(ctx, val, originType, false) + value, err := types.ConvertHexBlobToDecimalForNumericContext(val, originType) if err != nil { return nil, err } d, _, err := types.Float64.Convert(ctx, value) if err != nil { - if sql.ErrTruncatedIncorrect.Is(err) { - ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) - return d, nil + if !sql.ErrTruncatedIncorrect.Is(err) { + return types.Float64.Zero(), nil } - return types.Float64.Zero(), nil + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } return d, nil case ConvertToJSON: @@ -390,15 +393,17 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s } return js, nil case ConvertToSigned: - value, err := prepareForNumericContext(ctx, val, originType, true) + value, err := types.ConvertHexBlobToDecimalForNumericContext(val, originType) if err != nil { return nil, err } num, _, err := types.Int64.Convert(ctx, value) if err != nil { - return types.Int64.Zero(), nil + if !sql.ErrTruncatedIncorrect.Is(err) { + return types.Int64.Zero(), nil + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } - return num, nil case ConvertToTime: t, _, err := types.Time.Convert(ctx, val) @@ -407,31 +412,32 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s } return t, nil case ConvertToUnsigned: - value, err := prepareForNumericContext(ctx, val, originType, true) + value, err := types.ConvertHexBlobToDecimalForNumericContext(val, originType) if err != nil { return nil, err } - num, _, err := types.Uint64.Convert(ctx, value) + num, inRange, err := types.Uint64.Convert(ctx, value) if err != nil { - num, _, err = types.Int64.Convert(ctx, value) - if err != nil { + if !sql.ErrTruncatedIncorrect.Is(err) { return types.Uint64.Zero(), nil } - return uint64(num.(int64)), nil + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) + } + if !inRange { + ctx.Warn(1105, "Cast to unsigned converted negative integer to its positive complement") } return num, nil case ConvertToYear: - value, err := convertHexBlobToDecimalForNumericContext(val, originType) + value, err := types.ConvertHexBlobToDecimalForNumericContext(val, originType) if err != nil { return nil, err } num, _, err := types.Uint64.Convert(ctx, value) if err != nil { - num, _, err = types.Int64.Convert(ctx, value) - if err != nil { - return types.Uint64.Zero(), nil + if !sql.ErrTruncatedIncorrect.Is(err) { + return types.Float64.Zero(), nil } - return uint64(num.(int64)), nil + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } return num, nil default: @@ -483,27 +489,3 @@ func createConvertedDecimalType(length, scale int, logErrors bool) sql.DecimalTy } return types.InternalDecimalType } - -// prepareForNumberContext makes necessary preparations to strings and byte arrays for conversions to numbers -func prepareForNumericContext(ctx *sql.Context, val interface{}, originType sql.Type, isInt bool) (interface{}, error) { - if s, isString := val.(string); isString && types.IsTextOnly(originType) { - return sql.TrimStringToNumberPrefix(ctx, s, isInt), nil - } - return convertHexBlobToDecimalForNumericContext(val, originType) -} - -// convertHexBlobToDecimalForNumericContext converts byte array value to unsigned int value if originType is BLOB type. -// This function is called when convertTo type is number type only. The hex literal values are parsed into blobs as -// binary string as default, but for numeric context, the value should be a number. -// Byte arrays of other SQL types are not handled here. -func convertHexBlobToDecimalForNumericContext(val interface{}, originType sql.Type) (interface{}, error) { - if bin, isBinary := val.([]byte); isBinary && types.IsBlobType(originType) { - stringVal := hex.EncodeToString(bin) - decimalNum, err := strconv.ParseUint(stringVal, 16, 64) - if err != nil { - return nil, errors.NewKind("failed to convert hex blob value to unsigned int").New() - } - val = decimalNum - } - return val, nil -} diff --git a/sql/expression/function/ceil_round_floor.go b/sql/expression/function/ceil_round_floor.go index 8db7276cd8..c261788401 100644 --- a/sql/expression/function/ceil_round_floor.go +++ b/sql/expression/function/ceil_round_floor.go @@ -18,6 +18,7 @@ import ( "fmt" "math" + "github.com/dolthub/vitess/go/mysql" "github.com/shopspring/decimal" "github.com/dolthub/go-mysql-server/sql" @@ -51,10 +52,13 @@ func (c *Ceil) Description() string { // Type implements the Expression interface. func (c *Ceil) Type() sql.Type { childType := c.Child.Type() - if types.IsInteger(childType) { - return childType + if types.IsUnsigned(childType) { + return types.Uint64 } - return types.Int32 + if types.IsNumber(childType) { + return types.Int64 + } + return types.Float64 } // CollationCoercibility implements the interface sql.CollationCoercible. @@ -77,36 +81,32 @@ func (c *Ceil) WithChildren(children ...sql.Expression) (sql.Expression, error) // Eval implements the Expression interface. func (c *Ceil) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { child, err := c.Child.Eval(ctx, row) - if err != nil { return nil, err } - if child == nil { return nil, nil } - - // non number type will be caught here if !types.IsNumber(c.Child.Type()) { child, _, err = types.Float64.Convert(ctx, child) if err != nil { - return int32(0), nil + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } - - return int32(math.Ceil(child.(float64))), nil } - // if it's number type and not float value, it does not need ceil-ing switch num := child.(type) { - case float64: - return math.Ceil(num), nil case float32: - return float32(math.Ceil(float64(num))), nil + child = math.Ceil(float64(num)) + case float64: + child = math.Ceil(num) case decimal.Decimal: - return num.Ceil(), nil - default: - return child, nil + child = num.Ceil() } + child, _, _ = c.Type().Convert(ctx, child) + return child, nil } // Floor returns the biggest integer value not less than X. @@ -135,10 +135,13 @@ func (f *Floor) Description() string { // Type implements the Expression interface. func (f *Floor) Type() sql.Type { childType := f.Child.Type() - if types.IsInteger(childType) { - return childType + if types.IsUnsigned(childType) { + return types.Uint64 } - return types.Int32 + if types.IsNumber(childType) { + return types.Int64 + } + return types.Float64 } // CollationCoercibility implements the interface sql.CollationCoercible. @@ -161,36 +164,33 @@ func (f *Floor) WithChildren(children ...sql.Expression) (sql.Expression, error) // Eval implements the Expression interface. func (f *Floor) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { child, err := f.Child.Eval(ctx, row) - if err != nil { return nil, err } - if child == nil { return nil, nil } - - // non number type will be caught here if !types.IsNumber(f.Child.Type()) { child, _, err = types.Float64.Convert(ctx, child) if err != nil { - return int32(0), nil + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } - - return int32(math.Floor(child.(float64))), nil } - // if it's number type and not float value, it does not need floor-ing + // if it's number type and not float value, it does not need ceil-ing switch num := child.(type) { - case float64: - return math.Floor(num), nil case float32: - return float32(math.Floor(float64(num))), nil + child = math.Floor(float64(num)) + case float64: + child = math.Floor(num) case decimal.Decimal: - return num.Floor(), nil - default: - return child, nil + child = num.Floor() } + child, _, _ = f.Type().Convert(ctx, child) + return child, nil } // Round returns the number (x) with (d) requested decimal places. @@ -244,62 +244,61 @@ func (r *Round) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if err != nil { return nil, err } - if val == nil { return nil, nil } - decType := types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale) - val, _, err = decType.Convert(ctx, val) - if err != nil { - // TODO: truncate - return nil, err + val, _, err = types.InternalDecimalType.Convert(ctx, val) + if err != nil && sql.ErrTruncatedIncorrect.Is(err) { + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } prec := int32(0) if r.RightChild != nil { - var tmp interface{} + var tmp any tmp, err = r.RightChild.Eval(ctx, row) if err != nil { return nil, err } - if tmp == nil { return nil, nil } - - if tmp != nil { - tmp, _, err = types.Int32.Convert(ctx, tmp) - if err != nil { - // TODO: truncate + tmp, _, err = types.Int32.Convert(ctx, tmp) + if err != nil { + if !sql.ErrTruncatedIncorrect.Is(err) { return nil, err } - prec = tmp.(int32) - // MySQL cuts off at 30 for larger values - // TODO: these limits are fine only because we can't handle decimals larger than this - if prec > types.DecimalTypeMaxPrecision { - prec = types.DecimalTypeMaxPrecision - } - if prec < -types.DecimalTypeMaxScale { - prec = -types.DecimalTypeMaxScale - } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) + } + prec = tmp.(int32) + // MySQL cuts off at 30 for larger values + // TODO: these limits are fine only because we can't handle decimals larger than this + if prec > types.DecimalTypeMaxPrecision { + prec = types.DecimalTypeMaxPrecision + } + if prec < -types.DecimalTypeMaxScale { + prec = -types.DecimalTypeMaxScale } } var res interface{} tmp := val.(decimal.Decimal).Round(prec) - if types.IsSigned(r.LeftChild.Type()) { + lType := r.LeftChild.Type() + if types.IsSigned(lType) { res, _, err = types.Int64.Convert(ctx, tmp) - } else if types.IsUnsigned(r.LeftChild.Type()) { + } else if types.IsUnsigned(lType) { res, _, err = types.Uint64.Convert(ctx, tmp) - } else if types.IsFloat(r.LeftChild.Type()) { + } else if types.IsFloat(lType) { res, _, err = types.Float64.Convert(ctx, tmp) - } else if types.IsDecimal(r.LeftChild.Type()) { + } else if types.IsDecimal(lType) { res = tmp - } else if types.IsTextBlob(r.LeftChild.Type()) { + } else if types.IsTextBlob(lType) { res, _, err = types.Float64.Convert(ctx, tmp) } - + if err != nil && sql.ErrTruncatedIncorrect.Is(err) { + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) + err = nil + } return res, err } diff --git a/sql/expression/function/ceil_round_floor_test.go b/sql/expression/function/ceil_round_floor_test.go index 3b111c0a43..373ceaa940 100644 --- a/sql/expression/function/ceil_round_floor_test.go +++ b/sql/expression/function/ceil_round_floor_test.go @@ -34,17 +34,21 @@ func TestCeil(t *testing.T) { err *errors.Kind }{ {"float64 is nil", types.Float64, sql.NewRow(nil), nil, nil}, - {"float64 is ok", types.Float64, sql.NewRow(5.8), float64(6), nil}, + {"float64 is ok", types.Float64, sql.NewRow(5.8), int64(6), nil}, {"float32 is nil", types.Float32, sql.NewRow(nil), nil, nil}, - {"float32 is ok", types.Float32, sql.NewRow(float32(5.8)), float32(6), nil}, + {"float32 is ok", types.Float32, sql.NewRow(float32(5.8)), int64(6), nil}, {"int32 is nil", types.Int32, sql.NewRow(nil), nil, nil}, - {"int32 is ok", types.Int32, sql.NewRow(int32(6)), int32(6), nil}, + {"int32 is ok", types.Int32, sql.NewRow(int32(6)), int64(6), nil}, {"int64 is nil", types.Int64, sql.NewRow(nil), nil, nil}, {"int64 is ok", types.Int64, sql.NewRow(int64(6)), int64(6), nil}, {"blob is nil", types.Blob, sql.NewRow(nil), nil, nil}, - {"blob is ok", types.Blob, sql.NewRow([]byte{1, 2, 3}), int32(66051), nil}, - {"string int is ok", types.Text, sql.NewRow("1"), int32(1), nil}, - {"string float is ok", types.Text, sql.NewRow("1.2"), int32(2), nil}, + {"blob is ok", types.Blob, sql.NewRow([]byte{1, 2, 3}), 66051.0, nil}, + {"string int is ok", types.Text, sql.NewRow("1"), 1.0, nil}, + {"string float is ok", types.Text, sql.NewRow("1.2"), 2.0, nil}, + {"empty string is 0", types.Text, sql.NewRow(""), 0.0, nil}, + {"strings are truncated", types.Text, sql.NewRow("1.2abc"), 2.0, nil}, + {"completely invalid string is truncated to 0", types.Text, sql.NewRow("notavalue"), 0.0, nil}, + {"float notation is properly truncated", types.Text, sql.NewRow("1.234e2blah"), 124.0, nil}, } for _, tt := range testCases { @@ -65,7 +69,15 @@ func TestCeil(t *testing.T) { require.Equal(tt.expected, result) } - require.True(types.IsInteger(f.Type())) + // unsigned -> unsigned, signed -> signed, everything else -> double + resType := f.Type() + if types.IsUnsigned(tt.rowType) { + require.True(resType.Equals(types.Uint64)) + } else if types.IsNumber(tt.rowType) { + require.True(resType.Equals(types.Int64)) + } else { + require.True(resType.Equals(types.Float64)) + } require.False(f.IsNullable()) }) } @@ -80,17 +92,21 @@ func TestFloor(t *testing.T) { err *errors.Kind }{ {"float64 is nil", types.Float64, sql.NewRow(nil), nil, nil}, - {"float64 is ok", types.Float64, sql.NewRow(5.8), float64(5), nil}, + {"float64 is ok", types.Float64, sql.NewRow(5.8), int64(5), nil}, {"float32 is nil", types.Float32, sql.NewRow(nil), nil, nil}, - {"float32 is ok", types.Float32, sql.NewRow(float32(5.8)), float32(5), nil}, + {"float32 is ok", types.Float32, sql.NewRow(float32(5.8)), int64(5), nil}, {"int32 is nil", types.Int32, sql.NewRow(nil), nil, nil}, - {"int32 is ok", types.Int32, sql.NewRow(int32(6)), int32(6), nil}, + {"int32 is ok", types.Int32, sql.NewRow(int32(6)), int64(6), nil}, {"int64 is nil", types.Int64, sql.NewRow(nil), nil, nil}, {"int64 is ok", types.Int64, sql.NewRow(int64(6)), int64(6), nil}, {"blob is nil", types.Blob, sql.NewRow(nil), nil, nil}, - {"blob is ok", types.Blob, sql.NewRow([]byte{1, 2, 3}), int32(66051), nil}, - {"string int is ok", types.Text, sql.NewRow("1"), int32(1), nil}, - {"string float is ok", types.Text, sql.NewRow("1.2"), int32(1), nil}, + {"blob is ok", types.Blob, sql.NewRow([]byte{1, 2, 3}), float64(66051), nil}, + {"string int is ok", types.Text, sql.NewRow("1"), float64(1), nil}, + {"string float is ok", types.Text, sql.NewRow("1.2"), float64(1), nil}, + {"empty string is 0", types.Text, sql.NewRow(""), 0.0, nil}, + {"strings are truncated", types.Text, sql.NewRow("1.2abc"), float64(1), nil}, + {"completely invalid string is truncated to 0", types.Text, sql.NewRow("notavalue"), 0.0, nil}, + {"float notation is properly truncated", types.Text, sql.NewRow("1.234e2blah"), 123.0, nil}, } for _, tt := range testCases { @@ -111,7 +127,15 @@ func TestFloor(t *testing.T) { require.Equal(tt.expected, result) } - require.True(types.IsInteger(f.Type())) + // signed -> signed, unsigned -> unsigned, everything else -> double + resType := f.Type() + if types.IsUnsigned(tt.rowType) { + require.True(resType.Equals(types.Uint64)) + } else if types.IsNumber(tt.rowType) { + require.True(resType.Equals(types.Int64)) + } else { + require.True(resType.Equals(types.Float64)) + } require.False(f.IsNullable()) }) } @@ -603,7 +627,44 @@ func TestRound(t *testing.T) { exp: 123.0, }, - // TODO: tests truncated strings + { + name: "invalid text float is just 0.0", + xExpr: expression.NewLiteral("notafloat", types.Text), + dExpr: expression.NewLiteral("stillnotafloat", types.Text), + exp: 0.0, + }, + { + name: "invalid text float with d is just 0.0", + xExpr: expression.NewLiteral("notafloat", types.Text), + exp: 0.0, + }, + { + name: "text float truncates rounds down", + xExpr: expression.NewLiteral("123.456abc", types.Text), + exp: 123.0, + }, + { + name: "text float truncates rounds up", + xExpr: expression.NewLiteral("123.999abc", types.Text), + exp: 124.0, + }, + { + name: "text float with d truncates", + xExpr: expression.NewLiteral("123.456abc", types.Text), + dExpr: expression.NewLiteral("1abc", types.Text), + exp: 123.5, + }, + { + name: "text float signed notation truncates", + xExpr: expression.NewLiteral("+1.23456e2abcefg", types.Text), + exp: 123.0, + }, + { + name: "text float signed notation with d", + xExpr: expression.NewLiteral("+1.23456e2abcde", types.Text), + dExpr: expression.NewLiteral("0.2e1abcde", types.Text), + exp: 123.0, + }, } for _, tt := range testCases { diff --git a/sql/expression/function/inet_convert.go b/sql/expression/function/inet_convert.go index b612ee43d0..5ec844383a 100644 --- a/sql/expression/function/inet_convert.go +++ b/sql/expression/function/inet_convert.go @@ -242,7 +242,7 @@ func (i *InetNtoa) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // Convert val into int ipv4int, _, err := types.Int32.Convert(ctx, val) - if ipv4int != nil && err != nil { + if ipv4int != nil && err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return nil, sql.ErrInvalidType.New(reflect.TypeOf(val).String()) } diff --git a/sql/expression/function/logarithm.go b/sql/expression/function/logarithm.go index 14923692d7..9877a748c0 100644 --- a/sql/expression/function/logarithm.go +++ b/sql/expression/function/logarithm.go @@ -17,8 +17,8 @@ package function import ( "fmt" "math" - "reflect" + "github.com/dolthub/vitess/go/mysql" "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" @@ -124,14 +124,16 @@ func (l *LogBase) Eval( if err != nil { return nil, err } - if v == nil { return nil, nil } val, _, err := types.Float64.Convert(ctx, v) if err != nil { - return nil, sql.ErrInvalidType.New(reflect.TypeOf(v)) + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } return computeLog(ctx, val.(float64), l.base) } @@ -206,28 +208,30 @@ func (l *Log) Eval( if err != nil { return nil, err } - if left == nil { return nil, nil } - lhs, _, err := types.Float64.Convert(ctx, left) if err != nil { - return nil, sql.ErrInvalidType.New(reflect.TypeOf(left)) + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } right, err := l.RightChild.Eval(ctx, row) if err != nil { return nil, err } - if right == nil { return nil, nil } - rhs, _, err := types.Float64.Convert(ctx, right) if err != nil { - return nil, sql.ErrInvalidType.New(reflect.TypeOf(right)) + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } // rhs becomes value, lhs becomes base @@ -252,6 +256,6 @@ func computeLog(ctx *sql.Context, v float64, base float64) (interface{}, error) return math.Log(v), nil default: // LOG(BASE,V) is equivalent to LOG(V) / LOG(BASE). - return float64(math.Log(v) / math.Log(base)), nil + return math.Log(v) / math.Log(base), nil } } diff --git a/sql/expression/function/logarithm_test.go b/sql/expression/function/logarithm_test.go index 87dbd6cac0..359e29bd68 100644 --- a/sql/expression/function/logarithm_test.go +++ b/sql/expression/function/logarithm_test.go @@ -40,12 +40,16 @@ func TestLn(t *testing.T) { {"Input value is null", types.Float64, sql.NewRow(nil), nil, nil}, {"Input value is zero", types.Float64, sql.NewRow(0), nil, nil}, {"Input value is negative", types.Float64, sql.NewRow(-1), nil, nil}, - {"Input value is valid string", types.Float64, sql.NewRow("2"), float64(0.6931471805599453), nil}, - {"Input value is invalid string", types.Float64, sql.NewRow("aaa"), nil, sql.ErrInvalidType}, - {"Input value is valid float64", types.Float64, sql.NewRow(3), float64(1.0986122886681096), nil}, - {"Input value is valid float32", types.Float32, sql.NewRow(float32(6)), float64(1.791759469228055), nil}, - {"Input value is valid int64", types.Int64, sql.NewRow(int64(8)), float64(2.0794415416798357), nil}, - {"Input value is valid int32", types.Int32, sql.NewRow(int32(10)), float64(2.302585092994046), nil}, + {"Input value is valid float64", types.Float64, sql.NewRow(3), 1.0986122886681096, nil}, + {"Input value is valid float32", types.Float32, sql.NewRow(float32(6)), 1.791759469228055, nil}, + {"Input value is valid int64", types.Int64, sql.NewRow(int64(8)), 2.0794415416798357, nil}, + {"Input value is valid int32", types.Int32, sql.NewRow(int32(10)), 2.302585092994046, nil}, + {"Input value is empty string", types.Text, sql.NewRow(""), nil, nil}, + {"Input value is valid string int", types.Text, sql.NewRow("2"), 0.6931471805599453, nil}, + {"Input value is valid string float", types.Text, sql.NewRow("123.456"), 4.815884817283264, nil}, + {"Input value is invalid string", types.Text, sql.NewRow("aaa"), nil, nil}, + {"Input value is truncated string float", types.Text, sql.NewRow("123.456abc"), 4.815884817283264, nil}, + {"Input value is string using float notation", types.Text, sql.NewRow("1.23456e+2notanumber"), 4.815884817283264, nil}, } for _, tt := range testCases { @@ -80,12 +84,16 @@ func TestLog2(t *testing.T) { {"Input value is null", types.Float64, sql.NewRow(nil), nil, nil}, {"Input value is zero", types.Float64, sql.NewRow(0), nil, nil}, {"Input value is negative", types.Float64, sql.NewRow(-1), nil, nil}, - {"Input value is valid string", types.Float64, sql.NewRow("2"), float64(1), nil}, - {"Input value is invalid string", types.Float64, sql.NewRow("aaa"), nil, sql.ErrInvalidType}, - {"Input value is valid float64", types.Float64, sql.NewRow(3), float64(1.5849625007211563), nil}, - {"Input value is valid float32", types.Float32, sql.NewRow(float32(6)), float64(2.584962500721156), nil}, - {"Input value is valid int64", types.Int64, sql.NewRow(int64(8)), float64(3), nil}, - {"Input value is valid int32", types.Int32, sql.NewRow(int32(10)), float64(3.321928094887362), nil}, + {"Input value is valid float64", types.Float64, sql.NewRow(3), 1.5849625007211563, nil}, + {"Input value is valid float32", types.Float32, sql.NewRow(float32(6)), 2.584962500721156, nil}, + {"Input value is valid int64", types.Int64, sql.NewRow(int64(8)), 3.0, nil}, + {"Input value is valid int32", types.Int32, sql.NewRow(int32(10)), 3.321928094887362, nil}, + {"Input value is empty string", types.Text, sql.NewRow(""), nil, nil}, + {"Input value is valid string int", types.Text, sql.NewRow("123"), 6.94251450533924, nil}, + {"Input value is valid string float", types.Text, sql.NewRow("123.456"), 6.947853143387016, nil}, + {"Input value is invalid string", types.Text, sql.NewRow("aaa"), nil, nil}, + {"Input value is truncated string float", types.Text, sql.NewRow("123.456abc"), 6.947853143387016, nil}, + {"Input value is is truncated string using float notation", types.Text, sql.NewRow("1.23456e+2notanumber"), 6.947853143387016, nil}, } for _, tt := range testCases { @@ -120,12 +128,16 @@ func TestLog10(t *testing.T) { {"Input value is null", types.Float64, sql.NewRow(0), nil, nil}, {"Input value is zero", types.Float64, sql.NewRow(0), nil, nil}, {"Input value is negative", types.Float64, sql.NewRow(-1), nil, nil}, - {"Input value is valid string", types.Float64, sql.NewRow("2"), float64(0.3010299956639812), nil}, - {"Input value is invalid string", types.Float64, sql.NewRow("aaa"), nil, sql.ErrInvalidType}, - {"Input value is valid float64", types.Float64, sql.NewRow(3), float64(0.4771212547196624), nil}, - {"Input value is valid float32", types.Float32, sql.NewRow(float32(6)), float64(0.7781512503836436), nil}, - {"Input value is valid int64", types.Int64, sql.NewRow(int64(8)), float64(0.9030899869919435), nil}, - {"Input value is valid int32", types.Int32, sql.NewRow(int32(10)), float64(1), nil}, + {"Input value is valid float64", types.Float64, sql.NewRow(3), 0.4771212547196624, nil}, + {"Input value is valid float32", types.Float32, sql.NewRow(float32(6)), 0.7781512503836436, nil}, + {"Input value is valid int64", types.Int64, sql.NewRow(int64(8)), 0.9030899869919435, nil}, + {"Input value is valid int32", types.Int32, sql.NewRow(int32(10)), 1, nil}, + {"Input value is empty string", types.Text, sql.NewRow(""), nil, nil}, + {"Input value is valid string int", types.Text, sql.NewRow("2"), 0.3010299956639812, nil}, + {"Input value is valid string float", types.Text, sql.NewRow("123.456"), 2.0915122016277716, nil}, + {"Input value is invalid string", types.Text, sql.NewRow("aaa"), nil, nil}, + {"Input value is truncated string float", types.Text, sql.NewRow("123.456abc"), 2.0915122016277716, nil}, + {"Input value is is truncated string using float notation", types.Text, sql.NewRow("1.23456e+2notanumber"), 2.0915122016277716, nil}, } for _, tt := range testCases { @@ -172,24 +184,30 @@ func TestLog(t *testing.T) { {"Input base is nil", []sql.Expression{expression.NewLiteral(nil, types.Float64), expression.NewLiteral(float64(10), types.Float64)}, nil, nil}, {"Input base is zero", []sql.Expression{expression.NewLiteral(float64(0), types.Float64), expression.NewLiteral(float64(10), types.Float64)}, nil, nil}, {"Input base is negative", []sql.Expression{expression.NewLiteral(float64(-5), types.Float64), expression.NewLiteral(float64(10), types.Float64)}, nil, nil}, - {"Input base is valid string", []sql.Expression{expression.NewLiteral("4", types.LongText), expression.NewLiteral(float64(10), types.Float64)}, float64(1.6609640474436813), nil}, - {"Input base is invalid string", []sql.Expression{expression.NewLiteral("bbb", types.LongText), expression.NewLiteral(float64(10), types.Float64)}, nil, sql.ErrInvalidType}, + + {"Input base is valid string", []sql.Expression{expression.NewLiteral("4", types.LongText), expression.NewLiteral(float64(10), types.Float64)}, 1.6609640474436813, nil}, + {"Input base is invalid string", []sql.Expression{expression.NewLiteral("bbb", types.LongText), expression.NewLiteral(float64(10), types.Float64)}, nil, nil}, + {"Input base is valid string int", []sql.Expression{expression.NewLiteral("2", types.LongText), expression.NewLiteral(float64(10), types.Float64)}, 3.321928094887362, nil}, + {"Input base is valid string float", []sql.Expression{expression.NewLiteral("1.23", types.LongText), expression.NewLiteral(float64(10), types.Float64)}, 11.122838112203077, nil}, + {"Input base is invalid string truncates", []sql.Expression{expression.NewLiteral("1.23abc", types.LongText), expression.NewLiteral(float64(10), types.Float64)}, 11.122838112203077, nil}, + {"Input value is truncated string using float notation", []sql.Expression{expression.NewLiteral("1.23456e+2notanumber", types.LongText), expression.NewLiteral(float64(10), types.Float64)}, 0.4781229577440309, nil}, {"Input value is null", []sql.Expression{expression.NewLiteral(nil, types.Float64)}, nil, nil}, {"Input value is zero", []sql.Expression{expression.NewLiteral(float64(0), types.Float64)}, nil, nil}, {"Input value is negative", []sql.Expression{expression.NewLiteral(float64(-9), types.Float64)}, nil, nil}, - {"Input value is valid string", []sql.Expression{expression.NewLiteral("7", types.LongText)}, float64(1.9459101490553132), nil}, - {"Input value is invalid string", []sql.Expression{expression.NewLiteral("766j", types.LongText)}, nil, sql.ErrInvalidType}, - - {"Input base is valid float64", []sql.Expression{expression.NewLiteral(float64(5), types.Float64), expression.NewLiteral(float64(99), types.Float64)}, float64(2.855108491376949), nil}, - {"Input base is valid float32", []sql.Expression{expression.NewLiteral(float32(6), types.Float32), expression.NewLiteral(float64(80), types.Float64)}, float64(2.4456556306420936), nil}, - {"Input base is valid int64", []sql.Expression{expression.NewLiteral(int64(8), types.Int64), expression.NewLiteral(float64(64), types.Float64)}, float64(2), nil}, - {"Input base is valid int32", []sql.Expression{expression.NewLiteral(int32(10), types.Int32), expression.NewLiteral(float64(100), types.Float64)}, float64(2), nil}, - - {"Input value is valid float64", []sql.Expression{expression.NewLiteral(float64(5), types.Float64), expression.NewLiteral(float64(66), types.Float64)}, float64(2.6031788549643564), nil}, - {"Input value is valid float32", []sql.Expression{expression.NewLiteral(float32(3), types.Float32), expression.NewLiteral(float64(50), types.Float64)}, float64(3.560876795007312), nil}, - {"Input value is valid int64", []sql.Expression{expression.NewLiteral(int64(5), types.Int64), expression.NewLiteral(float64(77), types.Float64)}, float64(2.698958057527146), nil}, - {"Input value is valid int32", []sql.Expression{expression.NewLiteral(int32(4), types.Int32), expression.NewLiteral(float64(40), types.Float64)}, float64(2.6609640474436813), nil}, + {"Input value is valid string", []sql.Expression{expression.NewLiteral("7", types.LongText)}, 1.9459101490553132, nil}, + {"Input value is invalid string", []sql.Expression{expression.NewLiteral("bbb", types.LongText)}, nil, nil}, + {"Input value is invalid string truncates", []sql.Expression{expression.NewLiteral("766j", types.LongText)}, 6.641182169740591, nil}, + + {"Input base is valid float64", []sql.Expression{expression.NewLiteral(float64(5), types.Float64), expression.NewLiteral(float64(99), types.Float64)}, 2.855108491376949, nil}, + {"Input base is valid float32", []sql.Expression{expression.NewLiteral(float32(6), types.Float32), expression.NewLiteral(float64(80), types.Float64)}, 2.4456556306420936, nil}, + {"Input base is valid int64", []sql.Expression{expression.NewLiteral(int64(8), types.Int64), expression.NewLiteral(float64(64), types.Float64)}, 2.0, nil}, + {"Input base is valid int32", []sql.Expression{expression.NewLiteral(int32(10), types.Int32), expression.NewLiteral(float64(100), types.Float64)}, 2.0, nil}, + + {"Input value is valid float64", []sql.Expression{expression.NewLiteral(float64(5), types.Float64), expression.NewLiteral(float64(66), types.Float64)}, 2.6031788549643564, nil}, + {"Input value is valid float32", []sql.Expression{expression.NewLiteral(float32(3), types.Float32), expression.NewLiteral(float64(50), types.Float64)}, 3.560876795007312, nil}, + {"Input value is valid int64", []sql.Expression{expression.NewLiteral(int64(5), types.Int64), expression.NewLiteral(float64(77), types.Float64)}, 2.698958057527146, nil}, + {"Input value is valid int32", []sql.Expression{expression.NewLiteral(int32(4), types.Int32), expression.NewLiteral(float64(40), types.Float64)}, 2.6609640474436813, nil}, } for _, tt := range testCases { diff --git a/sql/expression/function/math.go b/sql/expression/function/math.go index 613d04d685..7bc92e0b44 100644 --- a/sql/expression/function/math.go +++ b/sql/expression/function/math.go @@ -24,6 +24,7 @@ import ( "strings" "time" + "github.com/dolthub/vitess/go/mysql" "github.com/shopspring/decimal" "github.com/dolthub/go-mysql-server/sql" @@ -129,15 +130,15 @@ func (r *Rand) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - var seed int64 - if types.IsNumber(r.Child.Type()) { - e, _, err = types.Int64.Convert(ctx, e) - if err == nil { - seed = e.(int64) + e, _, err = types.Int64.Convert(ctx, e) + if err != nil { + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } - return rand.New(rand.NewSource(seed)).Float64(), nil + return rand.New(rand.NewSource(e.(int64))).Float64(), nil } // Sin is the SIN function @@ -176,7 +177,10 @@ func (s *Sin) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { n, _, err := types.Float64.Convert(ctx, val) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } return math.Sin(n.(float64)), nil @@ -225,7 +229,10 @@ func (s *Cos) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { n, _, err := types.Float64.Convert(ctx, val) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } return math.Cos(n.(float64)), nil @@ -274,8 +281,12 @@ func (t *Tan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { n, _, err := types.Float64.Convert(ctx, val) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } + res := math.Tan(n.(float64)) if math.IsNaN(res) { return nil, nil @@ -327,7 +338,10 @@ func (a *Asin) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { n, _, err := types.Float64.Convert(ctx, val) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } res := math.Asin(n.(float64)) @@ -381,7 +395,10 @@ func (a *Acos) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { n, _, err := types.Float64.Convert(ctx, val) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } res := math.Acos(n.(float64)) @@ -490,12 +507,18 @@ func (a *Atan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { nx, _, err := types.Float64.Convert(ctx, xx) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } ny, _, err := types.Float64.Convert(ctx, yy) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } return math.Atan2(ny.(float64), nx.(float64)), nil @@ -549,7 +572,10 @@ func (c *Cot) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { n, _, err := types.Float64.Convert(ctx, val) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } tan := math.Tan(n.(float64)) @@ -613,7 +639,10 @@ func (d *Degrees) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { n, _, err := types.Float64.Convert(ctx, val) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } return (n.(float64) * 180.0) / math.Pi, nil @@ -662,7 +691,10 @@ func (r *Radians) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { n, _, err := types.Float64.Convert(ctx, val) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } return (n.(float64) * math.Pi) / 180.0, nil @@ -976,13 +1008,13 @@ func (e *Exp) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { v, _, err := types.Float64.Convert(ctx, val) if err != nil { - // TODO: truncate - ctx.Warn(1292, "Truncated incorrect DOUBLE value: '%v'", val) - v = 0.0 + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } - vv := v.(float64) - res := math.Exp(vv) + res := math.Exp(v.(float64)) if math.IsNaN(res) || math.IsInf(res, 0) { return nil, nil diff --git a/sql/expression/function/math_test.go b/sql/expression/function/math_test.go index 203c4bc4fe..4771f390c2 100644 --- a/sql/expression/function/math_test.go +++ b/sql/expression/function/math_test.go @@ -84,6 +84,22 @@ func TestRandWithSeed(t *testing.T) { f642 = f.(float64) assert.Equal(t, f64, f642) + + r, _ = NewRand(expression.NewLiteral("10 not a number", types.LongText)) + assert.Equal(t, `rand('10 not a number')`, r.String()) + + f, err = r.Eval(nil, nil) + require.NoError(t, err) + f64 = f.(float64) + + assert.GreaterOrEqual(t, f64, float64(0)) + assert.Less(t, f64, float64(1)) + + f, err = r.Eval(nil, nil) + require.NoError(t, err) + f642 = f.(float64) + + assert.Equal(t, f64, f642) } func TestRadians(t *testing.T) { @@ -94,6 +110,7 @@ func TestRadians(t *testing.T) { tf.AddSucceeding(math.Pi, int16(180)) tf.AddSucceeding(math.Pi/2.0, (90)) tf.AddSucceeding(2*math.Pi, 360.0) + tf.AddSucceeding(math.Pi, "180.0abc") tf.Test(t, nil, nil) } @@ -107,6 +124,7 @@ func TestDegrees(t *testing.T) { {"decimal 2pi", decimal.NewFromFloat(2 * math.Pi), 360.0}, {"float64 pi/2", math.Pi / 2.0, 90.0}, {"float32 3*pi/2", float32(3.0 * math.Pi / 2.0), 270.0}, + {"string truncates", "3.1415926536ABC", 180.0}, } f := sql.Function1{Name: "degrees", Fn: NewDegrees} @@ -385,7 +403,7 @@ func TestExp(t *testing.T) { exp: math.Exp(0), }, { - name: "empty string", + name: "empty string is 0", arg: expression.NewLiteral("", types.Text), exp: math.Exp(0), }, @@ -395,13 +413,22 @@ func TestExp(t *testing.T) { exp: math.Exp(10), }, { - // we don't do truncation yet - // https://github.com/dolthub/dolt/issues/7302 - name: "scientific string is truncated", + name: "scientific float notation string is evaluated", arg: expression.NewLiteral("1e1", types.Text), - exp: "", + exp: math.Exp(10), + err: false, + }, + { + name: "string is truncated", + arg: expression.NewLiteral("10abc", types.Text), + exp: math.Exp(10), + err: false, + }, + { + name: "string is truncated", + arg: expression.NewLiteral("+.123e+1abc", types.Text), + exp: math.Exp(1.23), err: false, - skip: true, }, } diff --git a/sql/expression/function/space.go b/sql/expression/function/space.go index c98d7d23d2..0fd3cda8a1 100644 --- a/sql/expression/function/space.go +++ b/sql/expression/function/space.go @@ -15,6 +15,8 @@ package function import ( + "github.com/dolthub/vitess/go/mysql" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -55,8 +57,10 @@ func (s *Space) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // TODO: better truncate integer handling v, _, err := types.Int64.Convert(ctx, val) if err != nil { - ctx.Warn(1292, "Truncated incorrect INTEGER value: '%v'", val) - v = int64(0) + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } num := int(v.(int64)) diff --git a/sql/expression/function/sqrt_power.go b/sql/expression/function/sqrt_power.go index dd2be1755b..818e228b25 100644 --- a/sql/expression/function/sqrt_power.go +++ b/sql/expression/function/sqrt_power.go @@ -18,6 +18,8 @@ import ( "fmt" "math" + "github.com/dolthub/vitess/go/mysql" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/types" @@ -76,18 +78,19 @@ func (s *Sqrt) WithChildren(children ...sql.Expression) (sql.Expression, error) // Eval implements the Expression interface. func (s *Sqrt) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { child, err := s.Child.Eval(ctx, row) - if err != nil { return nil, err } - if child == nil { return nil, nil } child, _, err = types.Float64.Convert(ctx, child) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } res := math.Sqrt(child.(float64)) @@ -155,28 +158,30 @@ func (p *Power) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if err != nil { return nil, err } - if left == nil { return nil, nil } - left, _, err = types.Float64.Convert(ctx, left) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } right, err := p.RightChild.Eval(ctx, row) if err != nil { return nil, err } - if right == nil { return nil, nil } - right, _, err = types.Float64.Convert(ctx, right) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } res := math.Pow(left.(float64), right.(float64)) diff --git a/sql/expression/function/sqrt_power_test.go b/sql/expression/function/sqrt_power_test.go index ea98263cd5..31abc6a81c 100644 --- a/sql/expression/function/sqrt_power_test.go +++ b/sql/expression/function/sqrt_power_test.go @@ -36,10 +36,14 @@ func TestSqrt(t *testing.T) { err bool }{ {"null input", sql.NewRow(nil), nil, false}, - {"invalid string", sql.NewRow("foo"), nil, true}, - {"valid string", sql.NewRow("9"), float64(3), false}, - {"number is zero", sql.NewRow(0), float64(0), false}, - {"positive number", sql.NewRow(8), float64(2.8284271247461903), false}, + {"empty string", sql.NewRow("0"), 0.0, false}, + {"invalid string", sql.NewRow("foo"), 0.0, false}, + {"invalid string int truncated", sql.NewRow("123foo"), 11.090536506409418, false}, + {"invalid string float truncated", sql.NewRow("1.23abc"), 1.1090536506409416, false}, + {"scientific string notation truncated", sql.NewRow("+1.23e2"), 11.090536506409418, false}, + {"valid string", sql.NewRow("9"), 3.0, false}, + {"number is zero", sql.NewRow(0), 0.0, false}, + {"positive number", sql.NewRow(8), 2.8284271247461903, false}, {"negative number", sql.NewRow(-1), nil, false}, } for _, tt := range testCases { @@ -71,13 +75,20 @@ func TestPower(t *testing.T) { {"Base is nil", types.Float64, sql.NewRow(2, nil), nil, false}, {"Exp is nil", types.Float64, sql.NewRow(nil, 2), nil, false}, - {"Base is 0", types.Float64, sql.NewRow(0, 2), float64(0), false}, - {"Base and exp is 0", types.Float64, sql.NewRow(0, 0), float64(1), false}, - {"Exp is 0", types.Float64, sql.NewRow(2, 0), float64(1), false}, - {"Base is negative", types.Float64, sql.NewRow(-2, 2), float64(4), false}, - {"Exp is negative", types.Float64, sql.NewRow(2, -2), float64(0.25), false}, - {"Base and exp are invalid strings", types.Float64, sql.NewRow("a", "b"), nil, true}, - {"Base and exp are valid strings", types.Float64, sql.NewRow("2", "2"), float64(4), false}, + {"Base is 0", types.Float64, sql.NewRow(0, 2), 0.0, false}, + {"Base and exp is 0", types.Float64, sql.NewRow(0, 0), 1.0, false}, + {"Exp is 0", types.Float64, sql.NewRow(2, 0), 1.0, false}, + {"Base is negative", types.Float64, sql.NewRow(-2, 2), 4.0, false}, + {"Exp is negative", types.Float64, sql.NewRow(2, -2), 0.25, false}, + + {"Base and exp are invalid strings", types.Text, sql.NewRow("a", "b"), 1.0, false}, + {"Base and exp are invalid truncated strings", types.Text, sql.NewRow("2a", "2b"), 4.0, false}, + {"Base and exp are valid int strings", types.Text, sql.NewRow("2", "2"), 4.0, false}, + {"Base and exp are valid float strings", types.Text, sql.NewRow("1.2", "3.4"), 1.8587296919794811, false}, + {"Base and exp are valid int strings truncated", types.Text, sql.NewRow("12abc", "2abc"), 144.0, false}, + {"Base and exp are valid float strings truncated", types.Text, sql.NewRow("1.2abc", "2.3abc"), 1.5209567545525318, false}, + {"Base and exp are valid scientific float notation strings truncated", types.Text, sql.NewRow("0.12e1asdf", "+23e-1asdf"), 1.5209567545525318, false}, + {"positive inf", types.Float64, sql.NewRow(2, math.Inf(1)), nil, false}, {"negative inf", types.Float64, sql.NewRow(2, math.Inf(1)), nil, false}, } diff --git a/sql/expression/in.go b/sql/expression/in.go index 622ee5779f..b54b55580f 100644 --- a/sql/expression/in.go +++ b/sql/expression/in.go @@ -66,7 +66,6 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if err != nil { return nil, err } - if originalLeft == nil { return nil, nil } @@ -85,6 +84,7 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } } + leftLit := NewLiteral(originalLeft, in.Left().Type()) for _, el := range right { originalRight, err := el.Eval(ctx, row) if err != nil { @@ -96,8 +96,8 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { continue } - comp := newComparison(NewLiteral(originalLeft, in.Left().Type()), NewLiteral(originalRight, el.Type())) - l, r, compareType, err := comp.CastLeftAndRight(ctx, originalLeft, originalRight) + comp := newComparison(leftLit, NewLiteral(originalRight, el.Type())) + l, r, compareType, err := comp.castLeftAndRight(ctx, originalLeft, originalRight) if err != nil { return nil, err } @@ -105,7 +105,6 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if err != nil { return nil, err } - if cmp == 0 { return true, nil } diff --git a/sql/expression/interval.go b/sql/expression/interval.go index 0cf8e23a80..6f4dee4dca 100644 --- a/sql/expression/interval.go +++ b/sql/expression/interval.go @@ -21,6 +21,7 @@ import ( "strings" "time" + "github.com/dolthub/vitess/go/mysql" errors "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" @@ -140,7 +141,10 @@ func (i *Interval) EvalDelta(ctx *sql.Context, row sql.Row) (*TimeDelta, error) } else { val, _, err = types.Int64.Convert(ctx, val) if err != nil { - return nil, err + if !sql.ErrTruncatedIncorrect.Is(err) { + return nil, err + } + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) } num := val.(int64) diff --git a/sql/expression/procedurereference.go b/sql/expression/procedurereference.go index 4baef784ee..f9f414dad2 100644 --- a/sql/expression/procedurereference.go +++ b/sql/expression/procedurereference.go @@ -69,6 +69,9 @@ func (ppr *ProcedureReference) InitializeVariable(ctx *sql.Context, name string, } convertedVal, _, err := sqlType.Convert(ctx, val) if err != nil { + if sql.ErrTruncatedIncorrect.Is(err) { + return sql.ErrInvalidValue.New(val, sqlType) + } return err } lowerName := strings.ToLower(name) diff --git a/sql/expression/set.go b/sql/expression/set.go index a2eff1dbc1..38a567a570 100644 --- a/sql/expression/set.go +++ b/sql/expression/set.go @@ -78,6 +78,9 @@ func (s *SetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if types.ErrLengthBeyondLimit.Is(err) { return nil, sql.NewWrappedTypeConversionError(val, getField.fieldIndex, types.ErrLengthBeyondLimit.New(val, getField.Name())) } + if sql.ErrTruncatedIncorrect.Is(err) { + err = sql.ErrInvalidValue.New(val, getField.fieldType) + } return nil, sql.NewWrappedTypeConversionError(val, getField.fieldIndex, err) } val = convertedVal diff --git a/sql/index_builder.go b/sql/index_builder.go index 8652a2b3b1..a7fb3ac599 100644 --- a/sql/index_builder.go +++ b/sql/index_builder.go @@ -229,8 +229,11 @@ func (b *MySQLIndexBuilder) convertKey(ctx *Context, colType Type, keyType Type, if et, ok := colType.(ExtendedType); ok { return et.ConvertToType(ctx, keyType.(ExtendedType), key) } else { - key, _, err := colType.Convert(ctx, key) - return key, err + k, _, err := colType.Convert(ctx, key) + if err != nil && !ErrTruncatedIncorrect.Is(err) { + return nil, err + } + return k, nil } } diff --git a/sql/iters/rel_iters.go b/sql/iters/rel_iters.go index 47439d5dc3..0bbfa2179d 100644 --- a/sql/iters/rel_iters.go +++ b/sql/iters/rel_iters.go @@ -290,11 +290,17 @@ func (c *JsonTableCol) Next(ctx *sql.Context, obj interface{}, pass bool, ord in val, _, err = c.Opts.Typ.Convert(ctx, val) if err != nil { if c.Opts.ErrOnErr { - return nil, err + if sql.ErrTruncatedIncorrect.Is(err) { + return nil, sql.ErrInvalidJSONText.New(c.pos+1, "JSON_TABLE", "Invalid value.") + } + return nil, sql.ErrInvalidJSONText.New(c.pos+1, "JSON_TABLE", err.Error()) } val, _, err = c.Opts.Typ.Convert(ctx, c.Opts.DefErrVal) if err != nil { - return nil, err + if sql.ErrTruncatedIncorrect.Is(err) { + return nil, sql.ErrInvalidJSONText.New(c.pos+1, "JSON_TABLE", "Invalid value.") + } + return nil, sql.ErrInvalidJSONText.New(c.pos+1, "JSON_TABLE", err.Error()) } } diff --git a/sql/rowexec/ddl_iters.go b/sql/rowexec/ddl_iters.go index 814c50530d..67e8a95951 100644 --- a/sql/rowexec/ddl_iters.go +++ b/sql/rowexec/ddl_iters.go @@ -942,6 +942,9 @@ func projectRowWithTypes(ctx *sql.Context, oldSchema, newSchema sql.Schema, proj if sql.ErrNotMatchingSRID.Is(err) { err = sql.ErrNotMatchingSRIDWithColName.New(newSchema[i].Name, err) } + if sql.ErrTruncatedIncorrect.Is(err) { + err = sql.ErrInvalidValue.New(newSchema[i].Type, newRow[i]) + } return nil, err } else if !inRange { return nil, sql.ErrValueOutOfRange.New(newRow[i], newSchema[i].Type) diff --git a/sql/rowexec/insert.go b/sql/rowexec/insert.go index 33bebe72fb..ebc7f53f02 100644 --- a/sql/rowexec/insert.go +++ b/sql/rowexec/insert.go @@ -120,9 +120,21 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) ctxWithValues := context.WithValue(ctx.Context, types.ColumnNameKey, col.Name) ctxWithValues = context.WithValue(ctxWithValues, types.RowNumberKey, i.rowNumber) ctxWithColumnInfo := ctx.WithContext(ctxWithValues) - converted, inRange, cErr := col.Type.Convert(ctxWithColumnInfo, row[idx]) + val := row[idx] + // TODO: check mysql strict sql_mode + var converted any + var inRange sql.ConvertInRange + var cErr error + if typ, ok := col.Type.(sql.RoundingNumberType); ok { + converted, inRange, cErr = typ.ConvertRound(ctx, val) + } else { + converted, inRange, cErr = col.Type.Convert(ctxWithColumnInfo, val) + } if cErr == nil && !inRange { - cErr = sql.ErrValueOutOfRange.New(row[idx], col.Type) + cErr = sql.ErrValueOutOfRange.New(val, col.Type) + } + if sql.ErrTruncatedIncorrect.Is(cErr) { + cErr = sql.ErrInvalidValue.New(val, col.Type) } if cErr != nil { // Ignore individual column errors when INSERT IGNORE, UPDATE IGNORE, etc. is specified. diff --git a/sql/type.go b/sql/type.go index 379ed92221..17e50fd983 100644 --- a/sql/type.go +++ b/sql/type.go @@ -18,7 +18,11 @@ import ( "context" "fmt" "reflect" + "strings" "time" + "unicode" + + "github.com/dolthub/vitess/go/mysql" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" @@ -101,6 +105,46 @@ type Type interface { fmt.Stringer } +// TrimStringToNumberPrefix will remove any white space for s and truncate any trailing non-numeric characters. +func TrimStringToNumberPrefix(ctx *Context, s string, isInt bool) string { + if isInt { + s = strings.TrimLeft(s, IntCutSet) + } else { + s = strings.TrimLeft(s, NumericCutSet) + } + + seenDigit := false + seenDot := false + seenExp := false + signIndex := 0 + + var i int + for i = 0; i < len(s); i++ { + char := rune(s[i]) + if unicode.IsDigit(char) { + seenDigit = true + } else if char == '.' && !seenDot && !isInt { + seenDot = true + } else if (char == 'e' || char == 'E') && !seenExp && seenDigit && !isInt { + seenExp = true + signIndex = i + 1 + } else if !((char == '-' || char == '+') && i == signIndex) { + // TODO: this should not happen here, and it should use sql.ErrIncorrectTruncation + if isInt { + ctx.Warn(mysql.ERTruncatedWrongValue, "Truncated incorrect INTEGER value: '%s'", s) + } else { + ctx.Warn(mysql.ERTruncatedWrongValue, "Truncated incorrect DOUBLE value: '%s'", s) + } + break + } + } + s = s[:i] + if s == "" { + s = "0" + } + return s +} + // NullType represents the type of NULL values type NullType interface { Type @@ -128,6 +172,13 @@ type NumberType interface { DisplayWidth() int } +// RoundingNumberType represents Number Types that implement an additional interface +// that supports rounding when converting rather than the default truncation. +type RoundingNumberType interface { + NumberType + ConvertRound(context.Context, any) (any, ConvertInRange, error) +} + // StringType represents all string types, including VARCHAR and BLOB. // https://dev.mysql.com/doc/refman/8.0/en/char.html // https://dev.mysql.com/doc/refman/8.0/en/binary-varbinary.html diff --git a/sql/types/conversion.go b/sql/types/conversion.go index 966c61815c..2e5ca52748 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -21,6 +21,7 @@ import ( "strings" "time" + "github.com/dolthub/vitess/go/mysql" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" "github.com/dolthub/vitess/go/vt/sqlparser" @@ -773,6 +774,10 @@ func ConvertOrTruncate(ctx *sql.Context, i interface{}, t sql.Type) (interface{} if err == nil { return converted, nil } + if sql.ErrTruncatedIncorrect.Is(err) { + ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error()) + return converted, nil + } // If a value can't be converted to an enum or set type, truncate it to a value that is guaranteed // to not match any enum value. diff --git a/sql/types/decimal.go b/sql/types/decimal.go index 48fa0288bc..01fc939e05 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -141,13 +141,17 @@ func (t DecimalType_) Compare(s context.Context, a interface{}, b interface{}) ( // Convert implements Type interface. func (t DecimalType_) Convert(c context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { dec, err := t.ConvertToNullDecimal(v) - if err != nil { + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return nil, sql.OutOfRange, err } if !dec.Valid { return nil, sql.InRange, nil } - return t.BoundsCheck(dec.Decimal) + res, inRange, cErr := t.BoundsCheck(dec.Decimal) + if cErr != nil { + return nil, sql.OutOfRange, cErr + } + return res, inRange, err } func (t DecimalType_) ConvertNoBoundsCheck(v interface{}) (decimal.Decimal, error) { @@ -201,25 +205,40 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, case float64: return t.ConvertToNullDecimal(decimal.NewFromFloat(value)) case string: - // TODO: implement truncation here - value = strings.Trim(value, sql.NumericCutSet) - if len(value) == 0 { - return t.ConvertToNullDecimal(decimal.NewFromInt(0)) - } var err error - res, err = decimal.NewFromString(value) - if err != nil { - // The decimal library cannot handle all of the different formats - bf, _, err := new(big.Float).SetPrec(217).Parse(value, 0) - if err != nil { - return decimal.NullDecimal{}, err - } + truncStr := strings.Trim(value, sql.NumericCutSet) + res, err = decimal.NewFromString(truncStr) + if err == nil { + return t.ConvertToNullDecimal(res) + } + // The decimal library cannot handle all the different formats + bf, _, err := new(big.Float).SetPrec(217).Parse(truncStr, 0) + if err == nil { res, err = decimal.NewFromString(bf.Text('f', -1)) - if err != nil { - return decimal.NullDecimal{}, err + if err == nil { + return t.ConvertToNullDecimal(res) + } + } + truncStr, didTrunc := TruncateStringToDouble(value) + if truncStr == "0" { + nullDec, cErr := t.ConvertToNullDecimal(decimal.NewFromInt(0)) + if cErr != nil { + return decimal.NullDecimal{}, cErr } + if didTrunc { + return nullDec, sql.ErrTruncatedIncorrect.New(t, value) + } + return nullDec, nil + } + res, _ = decimal.NewFromString(truncStr) + nullDec, cErr := t.ConvertToNullDecimal(res) + if cErr != nil { + return decimal.NullDecimal{}, cErr + } + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t, value) } - return t.ConvertToNullDecimal(res) + return nullDec, err case *big.Float: return t.ConvertToNullDecimal(value.Text('f', -1)) case *big.Int: diff --git a/sql/types/number.go b/sql/types/number.go index 35ba4b1c91..3f6ac146e3 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -17,6 +17,7 @@ package types import ( "context" "encoding/hex" + "errors" "fmt" "math" "reflect" @@ -24,6 +25,7 @@ import ( "strconv" "strings" "time" + "unicode" "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" @@ -87,6 +89,13 @@ var ( numre = regexp.MustCompile(`^[ ]*[0-9]*\.?[0-9]+`) ) +type Round bool + +const ( + ShouldTruncate Round = false + ShouldRound Round = true +) + type NumberTypeImpl_ struct { baseType query.Type displayWidth int @@ -96,6 +105,7 @@ var _ sql.Type = NumberTypeImpl_{} var _ sql.Type2 = NumberTypeImpl_{} var _ sql.CollationCoercible = NumberTypeImpl_{} var _ sql.NumberType = NumberTypeImpl_{} +var _ sql.RoundingNumberType = NumberTypeImpl_{} // CreateNumberType creates a NumberType. func CreateNumberType(baseType query.Type) (sql.NumberType, error) { @@ -109,7 +119,6 @@ func CreateNumberTypeWithDisplayWidth(baseType query.Type, displayWidth int) (sq switch baseType { case sqltypes.Int8, sqltypes.Uint8, sqltypes.Int16, sqltypes.Uint16, sqltypes.Int24, sqltypes.Uint24, sqltypes.Int32, sqltypes.Uint32, sqltypes.Int64, sqltypes.Uint64, sqltypes.Float32, sqltypes.Float64: - // displayWidth of 0 is valid for all types, displayWidth of 1 is only valid for Int8 if displayWidth == 0 || (displayWidth == 1 && baseType == sqltypes.Int8) { return NumberTypeImpl_{ @@ -148,11 +157,11 @@ func (t NumberTypeImpl_) Compare(s context.Context, a interface{}, b interface{} switch t.baseType { case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: - ca, _, err := convertToUint64(t, a) + ca, _, err := convertToUint64(t, a, ShouldTruncate) if err != nil { return 0, err } - cb, _, err := convertToUint64(t, b) + cb, _, err := convertToUint64(t, b, ShouldTruncate) if err != nil { return 0, err } @@ -182,11 +191,11 @@ func (t NumberTypeImpl_) Compare(s context.Context, a interface{}, b interface{} } return +1, nil default: - ca, _, err := convertToInt64(t, a) + ca, _, err := convertToInt64(t, a, ShouldTruncate) if err != nil { ca = 0 } - cb, _, err := convertToInt64(t, b) + cb, _, err := convertToInt64(t, b, ShouldTruncate) if err != nil { cb = 0 } @@ -221,10 +230,136 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ switch t.baseType { case sqltypes.Int8: - num, _, err := convertToInt64(t, v) - if err != nil { + num, _, err := convertToInt64(t, v, ShouldTruncate) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return int8(num), sql.OutOfRange, err + } + if num > math.MaxInt8 { + return int8(math.MaxInt8), sql.OutOfRange, nil + } + if num < math.MinInt8 { + return int8(math.MinInt8), sql.OutOfRange, nil + } + return int8(num), sql.InRange, err + case sqltypes.Uint8: + num, _, err := convertToInt64(t, v, ShouldTruncate) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return uint8(num), sql.OutOfRange, err + } + if num > math.MaxUint8 { + return uint8(math.MaxUint8), sql.OutOfRange, nil + } + if num < 0 { + return uint8(math.MaxUint8 + num + 1), sql.OutOfRange, nil + } + return uint8(num), sql.InRange, err + case sqltypes.Int16: + num, _, err := convertToInt64(t, v, ShouldTruncate) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return int16(num), sql.OutOfRange, err + } + if num > math.MaxInt16 { + return int16(math.MaxInt16), sql.OutOfRange, nil + } + if num < math.MinInt16 { + return int16(math.MinInt16), sql.OutOfRange, nil + } + return int16(num), sql.InRange, err + case sqltypes.Uint16: + num, _, err := convertToInt64(t, v, ShouldTruncate) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return uint16(num), sql.OutOfRange, err + } + if num > math.MaxUint16 { + return uint16(math.MaxUint16), sql.OutOfRange, nil + } + if num < 0 { + return uint16(math.MaxUint16 + num + 1), sql.OutOfRange, nil + } + return uint16(num), sql.InRange, err + case sqltypes.Int24: + num, _, err := convertToInt64(t, v, ShouldTruncate) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return int32(num), sql.OutOfRange, err + } + if num > (1<<23 - 1) { + return int32(1<<23 - 1), sql.OutOfRange, nil + } + if num < (-1 << 23) { + return int32(-1 << 23), sql.OutOfRange, nil + } + return int32(num), sql.InRange, err + case sqltypes.Uint24: + num, _, err := convertToInt64(t, v, ShouldTruncate) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return uint32(num), sql.OutOfRange, err + } + if num >= (1 << 24) { + return uint32(1<<24 - 1), sql.OutOfRange, nil + } + if num < 0 { + return uint32(1<<24 + num), sql.OutOfRange, nil + } + return uint32(num), sql.InRange, err + case sqltypes.Int32: + num, _, err := convertToInt64(t, v, ShouldTruncate) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return int32(num), sql.OutOfRange, err + } + if num > math.MaxInt32 { + return int32(math.MaxInt32), sql.OutOfRange, nil + } + if num < math.MinInt32 { + return int32(math.MinInt32), sql.OutOfRange, nil + } + return int32(num), sql.InRange, err + case sqltypes.Uint32: + num, _, err := convertToInt64(t, v, ShouldTruncate) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { + return uint32(num), sql.OutOfRange, err + } + if num > math.MaxUint32 { + return uint32(math.MaxUint32), sql.OutOfRange, nil + } + if num < 0 { + return uint32(math.MaxUint32 + num + 1), sql.OutOfRange, nil + } + return uint32(num), sql.InRange, err + case sqltypes.Int64: + return convertToInt64(t, v, ShouldTruncate) + case sqltypes.Uint64: + return convertToUint64(t, v, ShouldTruncate) + case sqltypes.Float32: + num, err := convertToFloat64(t, v) + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return nil, sql.OutOfRange, err } + if num > math.MaxFloat32 { + return float32(math.MaxFloat32), sql.OutOfRange, nil + } + if num < -math.MaxFloat32 { + return float32(-math.MaxFloat32), sql.OutOfRange, nil + } + return float32(num), sql.InRange, err + case sqltypes.Float64: + ret, err := convertToFloat64(t, v) + return ret, sql.InRange, err + default: + return nil, sql.OutOfRange, sql.ErrInvalidType.New(t.baseType.String()) + } +} + +func (t NumberTypeImpl_) ConvertRound(ctx context.Context, v interface{}) (any, sql.ConvertInRange, error) { + // This operates specifically on Integer base types and when v is a string + if _, isStr := v.(string); !isStr { + return t.Convert(ctx, v) + } + switch t.baseType { + case sqltypes.Int8: + num, _, err := convertToInt64(t, v, ShouldRound) + if err != nil { + return int8(num), sql.OutOfRange, err + } if num > math.MaxInt8 { return int8(math.MaxInt8), sql.OutOfRange, nil } @@ -233,9 +368,9 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } return int8(num), sql.InRange, nil case sqltypes.Uint8: - num, _, err := convertToInt64(t, v) + num, _, err := convertToInt64(t, v, ShouldRound) if err != nil { - return nil, sql.OutOfRange, err + return uint8(num), sql.OutOfRange, err } if num > math.MaxUint8 { return uint8(math.MaxUint8), sql.OutOfRange, nil @@ -245,9 +380,9 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } return uint8(num), sql.InRange, nil case sqltypes.Int16: - num, _, err := convertToInt64(t, v) + num, _, err := convertToInt64(t, v, ShouldRound) if err != nil { - return nil, sql.OutOfRange, err + return int16(num), sql.OutOfRange, err } if num > math.MaxInt16 { return int16(math.MaxInt16), sql.OutOfRange, nil @@ -257,9 +392,9 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } return int16(num), sql.InRange, nil case sqltypes.Uint16: - num, _, err := convertToInt64(t, v) + num, _, err := convertToInt64(t, v, ShouldRound) if err != nil { - return nil, sql.OutOfRange, err + return uint16(num), sql.OutOfRange, err } if num > math.MaxUint16 { return uint16(math.MaxUint16), sql.OutOfRange, nil @@ -269,9 +404,9 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } return uint16(num), sql.InRange, nil case sqltypes.Int24: - num, _, err := convertToInt64(t, v) + num, _, err := convertToInt64(t, v, ShouldRound) if err != nil { - return nil, sql.OutOfRange, err + return int32(num), sql.OutOfRange, err } if num > (1<<23 - 1) { return int32(1<<23 - 1), sql.OutOfRange, nil @@ -281,9 +416,9 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } return int32(num), sql.InRange, nil case sqltypes.Uint24: - num, _, err := convertToInt64(t, v) + num, _, err := convertToInt64(t, v, ShouldRound) if err != nil { - return nil, sql.OutOfRange, err + return uint32(num), sql.OutOfRange, err } if num >= (1 << 24) { return uint32(1<<24 - 1), sql.OutOfRange, nil @@ -293,9 +428,9 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } return uint32(num), sql.InRange, nil case sqltypes.Int32: - num, _, err := convertToInt64(t, v) + num, _, err := convertToInt64(t, v, ShouldRound) if err != nil { - return nil, sql.OutOfRange, err + return int32(num), sql.OutOfRange, err } if num > math.MaxInt32 { return int32(math.MaxInt32), sql.OutOfRange, nil @@ -305,9 +440,9 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } return int32(num), sql.InRange, nil case sqltypes.Uint32: - num, _, err := convertToInt64(t, v) + num, _, err := convertToInt64(t, v, ShouldRound) if err != nil { - return nil, sql.OutOfRange, err + return uint32(num), sql.OutOfRange, err } if num > math.MaxUint32 { return uint32(math.MaxUint32), sql.OutOfRange, nil @@ -317,26 +452,11 @@ func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{ } return uint32(num), sql.InRange, nil case sqltypes.Int64: - return convertToInt64(t, v) + return convertToInt64(t, v, ShouldRound) case sqltypes.Uint64: - return convertToUint64(t, v) - case sqltypes.Float32: - num, err := convertToFloat64(t, v) - if err != nil { - return nil, sql.OutOfRange, err - } - if num > math.MaxFloat32 { - return float32(math.MaxFloat32), sql.OutOfRange, nil - } - if num < -math.MaxFloat32 { - return float32(-math.MaxFloat32), sql.OutOfRange, nil - } - return float32(num), sql.InRange, nil - case sqltypes.Float64: - ret, err := convertToFloat64(t, v) - return ret, sql.InRange, err + return convertToUint64(t, v, ShouldRound) default: - return nil, sql.OutOfRange, sql.ErrInvalidType.New(t.baseType.String()) + return t.Convert(ctx, v) } } @@ -394,7 +514,7 @@ func (t NumberTypeImpl_) Promote() sql.Type { } func (t NumberTypeImpl_) SQLInt8(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToInt64(t, v) + num, _, err := convertToInt64(t, v, false) if err != nil { return nil, err } @@ -408,7 +528,7 @@ func (t NumberTypeImpl_) SQLInt8(ctx *sql.Context, dest []byte, v interface{}) ( } func (t NumberTypeImpl_) SQLInt16(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToInt64(t, v) + num, _, err := convertToInt64(t, v, false) if err != nil { return nil, err } @@ -422,7 +542,7 @@ func (t NumberTypeImpl_) SQLInt16(ctx *sql.Context, dest []byte, v interface{}) } func (t NumberTypeImpl_) SQLInt24(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToInt64(t, v) + num, _, err := convertToInt64(t, v, false) if err != nil { return nil, err } @@ -436,7 +556,7 @@ func (t NumberTypeImpl_) SQLInt24(ctx *sql.Context, dest []byte, v interface{}) } func (t NumberTypeImpl_) SQLInt32(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToInt64(t, v) + num, _, err := convertToInt64(t, v, false) if err != nil { return nil, err } @@ -450,7 +570,7 @@ func (t NumberTypeImpl_) SQLInt32(ctx *sql.Context, dest []byte, v interface{}) } func (t NumberTypeImpl_) SQLInt64(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - vt, _, err := convertToInt64(t, v) + vt, _, err := convertToInt64(t, v, false) if err != nil { return nil, err } @@ -459,7 +579,7 @@ func (t NumberTypeImpl_) SQLInt64(ctx *sql.Context, dest []byte, v interface{}) } func (t NumberTypeImpl_) SQLUint8(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToUint64(t, v) + num, _, err := convertToUint64(t, v, false) if err != nil { return nil, err } @@ -472,7 +592,7 @@ func (t NumberTypeImpl_) SQLUint8(ctx *sql.Context, dest []byte, v interface{}) } func (t NumberTypeImpl_) SQLUint16(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToUint64(t, v) + num, _, err := convertToUint64(t, v, false) if err != nil { return nil, err } @@ -485,7 +605,7 @@ func (t NumberTypeImpl_) SQLUint16(ctx *sql.Context, dest []byte, v interface{}) } func (t NumberTypeImpl_) SQLUint24(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToUint64(t, v) + num, _, err := convertToUint64(t, v, false) if err != nil { return nil, err } @@ -498,7 +618,7 @@ func (t NumberTypeImpl_) SQLUint24(ctx *sql.Context, dest []byte, v interface{}) } func (t NumberTypeImpl_) SQLUint32(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToUint64(t, v) + num, _, err := convertToUint64(t, v, false) if err != nil { return nil, err } @@ -511,7 +631,7 @@ func (t NumberTypeImpl_) SQLUint32(ctx *sql.Context, dest []byte, v interface{}) } func (t NumberTypeImpl_) SQLUint64(ctx *sql.Context, dest []byte, v interface{}) ([]byte, error) { - num, _, err := convertToUint64(t, v) + num, _, err := convertToUint64(t, v, false) if err != nil { return nil, err } @@ -591,7 +711,7 @@ func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqlt return sqltypes.Value{}, sql.ErrInvalidType.New(t.baseType.String()) } - if sql.ErrInvalidValue.Is(err) { + if sql.ErrInvalidValue.Is(err) || sql.ErrTruncatedIncorrect.Is(err) { switch str := v.(type) { case []byte: dest = str @@ -931,7 +1051,7 @@ func (t NumberTypeImpl_) DisplayWidth() int { return t.displayWidth } -func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange, error) { +func convertToInt64(t NumberTypeImpl_, v any, round Round) (int64, sql.ConvertInRange, error) { switch v := v.(type) { case time.Time: return v.UTC().Unix(), sql.InRange, nil @@ -961,21 +1081,24 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange case float32: if v > float32(math.MaxInt64) { return math.MaxInt64, sql.OutOfRange, nil - } else if v < float32(math.MinInt64) { + } + if v < float32(math.MinInt64) { return math.MinInt64, sql.OutOfRange, nil } return int64(math.Round(float64(v))), sql.InRange, nil case float64: if v > float64(math.MaxInt64) { return math.MaxInt64, sql.OutOfRange, nil - } else if v < float64(math.MinInt64) { + } + if v < float64(math.MinInt64) { return math.MinInt64, sql.OutOfRange, nil } return int64(math.Round(v)), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_int64_max) { return dec_int64_max.IntPart(), sql.OutOfRange, nil - } else if v.LessThan(dec_int64_min) { + } + if v.LessThan(dec_int64_min) { return dec_int64_min.IntPart(), sql.OutOfRange, nil } return v.Round(0).IntPart(), sql.InRange, nil @@ -986,23 +1109,34 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange } return i, sql.InRange, nil case string: - v = strings.Trim(v, sql.IntCutSet) - if v == "" { - // StringType{}.Zero() returns empty string, but should represent "0" for number value - return 0, sql.InRange, nil - } - // Parse first an integer, which allows for more values than float64 - i, err := strconv.ParseInt(v, 10, 64) - if err == nil { - return i, sql.InRange, nil - } - // If that fails, try as a float and truncate it to integral - f, err := strconv.ParseFloat(v, 64) - if err != nil { - return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) + var err error + if round { + truncStr, didTrunc := TruncateStringToDouble(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t, v) + } + // Parse first an integer, which allows for more values than float64 + i, pErr := strconv.ParseInt(truncStr, 10, 64) + if pErr == nil { + return i, sql.InRange, err + } + // If that fails, try as a float + f, pErr := strconv.ParseFloat(truncStr, 64) + if pErr != nil { + return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) + } + i, inRange, _ := convertToInt64(t, f, round) + return i, inRange, err + } + truncStr, didTrunc := TruncateStringToInt(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t, v) + } + i, pErr := strconv.ParseInt(truncStr, 10, 64) + if pErr == nil { + return i, sql.InRange, err } - f = math.Round(f) - return int64(f), sql.InRange, nil + return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) case bool: if v { return 1, sql.InRange, nil @@ -1103,7 +1237,7 @@ func convertValueToUint64(t NumberTypeImpl_, v sql.Value) (uint64, error) { } } -func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRange, error) { +func convertToUint64(t NumberTypeImpl_, v any, round Round) (uint64, sql.ConvertInRange, error) { switch v := v.(type) { case time.Time: return uint64(v.UTC().Unix()), sql.InRange, nil @@ -1145,21 +1279,24 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan case float32: if v > float32(math.MaxInt64) { return math.MaxUint64, sql.OutOfRange, nil - } else if v < 0 { - return uint64(math.MaxUint64 - v), sql.OutOfRange, nil + } + if v < 0 { + return uint64(math.MaxUint64 - uint(-v-1)), sql.OutOfRange, nil } return uint64(math.Round(float64(v))), sql.InRange, nil case float64: if v >= float64(math.MaxUint64) { return math.MaxUint64, sql.OutOfRange, nil - } else if v <= 0 { - return uint64(math.MaxUint64 - v), sql.OutOfRange, nil + } + if v < 0 { + return uint64(math.MaxUint64 - uint(-v-1)), sql.OutOfRange, nil } return uint64(math.Round(v)), sql.InRange, nil case decimal.Decimal: if v.GreaterThan(dec_uint64_max) { return math.MaxUint64, sql.OutOfRange, nil - } else if v.LessThan(dec_zero) { + } + if v.LessThan(dec_zero) { ret, _ := dec_uint64_max.Sub(v).Float64() return uint64(math.Round(ret)), sql.OutOfRange, nil } @@ -1173,19 +1310,46 @@ func convertToUint64(t NumberTypeImpl_, v interface{}) (uint64, sql.ConvertInRan } return i, sql.InRange, nil case string: - v = strings.Trim(v, sql.IntCutSet) - if i, err := strconv.ParseUint(v, 10, 64); err == nil { - return i, sql.InRange, nil - } else if err == strconv.ErrRange { - // Number is too large for uint64, return max value and OutOfRange + var err error + if round { + truncStr, didTrunc := TruncateStringToDouble(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t, v) + } + // Parse first an integer, which allows for more values than float64 + i, pErr := strconv.ParseUint(truncStr, 10, 64) + if pErr == nil { + return i, sql.InRange, err + } + // If that fails, try as a float + f, pErr := strconv.ParseFloat(truncStr, 64) + if pErr != nil { + return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) + } + i, inRange, _ := convertToUint64(t, f, round) + return i, inRange, err + } + truncStr, didTrunc := TruncateStringToInt(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t, v) + } + var neg bool + if truncStr[0] == '+' { + truncStr = truncStr[1:] + } else if truncStr[0] == '-' { + truncStr = truncStr[1:] + neg = true + } + // Parse first as an integer, which allows for more values than float64 + i, pErr := strconv.ParseUint(truncStr, 10, 64) + // Number is too large for uint64, return max value and OutOfRange + if errors.Is(pErr, strconv.ErrRange) { return math.MaxUint64, sql.OutOfRange, nil } - if f, err := strconv.ParseFloat(v, 64); err == nil { - if val, inRange, err := convertToUint64(t, f); err == nil && inRange { - return val, inRange, err - } + if neg { + return math.MaxUint64 - i + 1, sql.OutOfRange, err } - return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String()) + return i, sql.InRange, err case bool: if v { return 1, sql.InRange, nil @@ -1236,15 +1400,13 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) { } return float64(i), nil case string: - v = strings.Trim(v, sql.NumericCutSet) - i, err := strconv.ParseFloat(v, 64) - if err != nil { - // parse the first longest valid numbers - s := numre.FindString(v) - i, _ = strconv.ParseFloat(s, 64) - return i, sql.ErrTruncatedIncorrect.New(t.String(), v) + var err error + truncStr, didTrunc := TruncateStringToDouble(v) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(t, v) } - return i, nil + f, _ := strconv.ParseFloat(truncStr, 64) + return f, err case bool: if v { return 1, nil @@ -1313,3 +1475,84 @@ func CoalesceInt(val interface{}) (int, bool) { return 0, false } } + +const ( + // IntCutSet is the set of characters that should be trimmed from the beginning and end of a string + // when converting to a signed or unsigned integer + IntCutSet = " \t" + + // NumericCutSet is the set of characters to trim from a string before converting it to a number. + NumericCutSet = " \t\n\r" +) + +// TruncateStringToInt trims any whitespace from s, then truncates the string to the left most characters that make +// up a valid integer. Empty strings are converted "0". Additionally, returns a flag indicating if truncation occurred. +func TruncateStringToInt(s string) (string, bool) { + var seenDigit bool + s = strings.Trim(s, IntCutSet) + i, n := 0, len(s) + for ; i < n; i++ { + c := rune(s[i]) + if unicode.IsDigit(c) { + seenDigit = true + continue + } + if i == 0 && (c == '-' || c == '+') { + continue + } + break + } + if !seenDigit { + return "0", i != n + } + return s[:i], i != n +} + +// TruncateStringToDouble trims any whitespace from s, then truncates the string to the left most characters that make +// up a valid double. Empty strings are converted "0". Additionally, returns a flag indicating if truncation occurred. +func TruncateStringToDouble(s string) (string, bool) { + var signIndex int + var seenDigit, seenDot, seenExp bool + s = strings.Trim(s, NumericCutSet) + i, n := 0, len(s) + for ; i < n; i++ { + char := rune(s[i]) + if unicode.IsDigit(char) { + seenDigit = true + continue + } + if char == '.' && !seenDot { + seenDot = true + continue + } + if (char == 'e' || char == 'E') && !seenExp && seenDigit { + seenExp = true + signIndex = i + 1 // allow a sign following exponent + continue + } + if i == signIndex && (char == '-' || char == '+') { + continue + } + break + } + if !seenDigit { + return "0", i != n + } + return s[:i], i != n +} + +// ConvertHexBlobToDecimalForNumericContext converts byte array value to unsigned int value if originType is BLOB type. +// This function is called when convertTo type is number type only. The hex literal values are parsed into blobs as +// binary string as default, but for numeric context, the value should be a number. +// Byte arrays of other SQL types are not handled here. +func ConvertHexBlobToDecimalForNumericContext(val interface{}, originType sql.Type) (interface{}, error) { + if bin, isBinary := val.([]byte); isBinary && IsBlobType(originType) { + stringVal := hex.EncodeToString(bin) + decimalNum, err := strconv.ParseUint(stringVal, 16, 64) + if err != nil { + return nil, errors.New("failed to convert hex blob value to unsigned int") + } + val = decimalNum + } + return val, nil +} diff --git a/sql/types/number_test.go b/sql/types/number_test.go index 8c7d2fca0a..695f61fcc0 100644 --- a/sql/types/number_test.go +++ b/sql/types/number_test.go @@ -194,6 +194,8 @@ func TestNumberConvert(t *testing.T) { {typ: Uint64, inp: "01000", exp: uint64(1000), err: false, inRange: sql.InRange}, {typ: Uint64, inp: true, exp: uint64(1), err: false, inRange: sql.InRange}, {typ: Uint64, inp: false, exp: uint64(0), err: false, inRange: sql.InRange}, + {typ: Uint64, inp: "123.9abc", exp: uint64(123), err: false, inRange: sql.InRange}, + {typ: Uint64, inp: "+123.9abc", exp: uint64(123), err: false, inRange: sql.InRange}, {typ: Float32, inp: "22.25", exp: float32(22.25), err: false, inRange: sql.InRange}, {typ: Float32, inp: []byte{90, 140, 228, 206, 116}, exp: float32(388910861940), err: false, inRange: sql.InRange}, {typ: Float64, inp: float32(893.875), exp: float64(893.875), err: false, inRange: sql.InRange}, @@ -216,6 +218,7 @@ func TestNumberConvert(t *testing.T) { {typ: Uint32, inp: math.MaxUint32 + 1, exp: uint32(math.MaxUint32), err: false, inRange: sql.OutOfRange}, {typ: Uint32, inp: -1, exp: uint32(math.MaxUint32), err: false, inRange: sql.OutOfRange}, {typ: Uint64, inp: -1, exp: uint64(math.MaxUint64), err: false, inRange: sql.OutOfRange}, + {typ: Uint64, inp: "-1", exp: uint64(math.MaxUint64), err: false, inRange: sql.OutOfRange}, {typ: Float32, inp: math.MaxFloat32 * 2, exp: float32(math.MaxFloat32), err: false, inRange: sql.OutOfRange}, } @@ -225,6 +228,9 @@ func TestNumberConvert(t *testing.T) { if test.err { assert.Error(t, err) } else { + if sql.ErrTruncatedIncorrect.Is(err) { + err = nil + } require.NoError(t, err) assert.Equal(t, test.exp, val) assert.Equal(t, test.inRange, inRange) @@ -236,6 +242,227 @@ func TestNumberConvert(t *testing.T) { } } +func TestNumberConvertRound(t *testing.T) { + ctx := sql.NewEmptyContext() + tests := []struct { + typ sql.Type + inp interface{} + exp interface{} + err bool + inRange sql.ConvertInRange + }{ + // Boolean, Int8, Uint8, ... all use convertToInt64 + { + typ: Int64, + inp: "", + exp: int64(0), + err: false, + inRange: sql.InRange, + }, + { + typ: Int64, + inp: " \t", + exp: int64(0), + err: false, + inRange: sql.InRange, + }, + { + typ: Int64, + inp: "!@#$%^&*()", + exp: int64(0), + err: true, + inRange: sql.InRange, + }, + { + typ: Int64, + inp: "1.1", + exp: int64(1), + err: false, + inRange: sql.InRange, + }, + { + typ: Int64, + inp: "1.9", + exp: int64(2), + err: false, + inRange: sql.InRange, + }, + { + typ: Int64, + inp: "100.1ABC", + exp: int64(100), + err: true, + inRange: sql.InRange, + }, + { + typ: Int64, + inp: "100.9ABC", + exp: int64(101), + err: true, + inRange: sql.InRange, + }, + { + typ: Int64, + inp: ".123ABC", + exp: int64(0), + err: true, + inRange: sql.InRange, + }, + { + typ: Int64, + inp: "1.1.1", + exp: int64(1), + err: true, + inRange: sql.InRange, + }, + { + typ: Int64, + inp: "+1", + exp: int64(1), + err: false, + inRange: sql.InRange, + }, + { + typ: Int64, + inp: "-1", + exp: int64(-1), + err: false, + inRange: sql.InRange, + }, + { + typ: Int64, + inp: "+ 1", + exp: int64(0), + err: true, + inRange: sql.InRange, + }, + { + typ: Int64, + inp: "- 1", + exp: int64(0), + err: true, + inRange: sql.InRange, + }, + { + typ: Int64, + inp: "+-+-1", + exp: int64(0), + err: false, + inRange: sql.InRange, + }, + + { + typ: Uint64, + inp: "", + exp: uint64(0), + err: false, + inRange: sql.InRange, + }, + { + typ: Uint64, + inp: " \t", + exp: uint64(0), + err: false, + inRange: sql.InRange, + }, + { + typ: Uint64, + inp: "!@#$%^&*()", + exp: uint64(0), + err: true, + inRange: sql.InRange, + }, + { + typ: Uint64, + inp: "1.1", + exp: uint64(1), + err: false, + inRange: sql.InRange, + }, + { + typ: Uint64, + inp: "1.9", + exp: uint64(2), + err: false, + inRange: sql.InRange, + }, + { + typ: Uint64, + inp: "100.1ABC", + exp: uint64(100), + err: true, + inRange: sql.InRange, + }, + { + typ: Uint64, + inp: "100.9ABC", + exp: uint64(101), + err: true, + inRange: sql.InRange, + }, + { + typ: Uint64, + inp: ".123ABC", + exp: uint64(0), + err: true, + inRange: sql.InRange, + }, + { + typ: Uint64, + inp: "1.1.1", + exp: uint64(1), + err: true, + inRange: sql.InRange, + }, + { + typ: Uint64, + inp: "+1", + exp: uint64(1), + err: false, + inRange: sql.InRange, + }, + { + typ: Uint64, + inp: "-1", + exp: uint64(math.MaxUint64), + err: false, + inRange: sql.OutOfRange, + }, + { + typ: Uint64, + inp: "+ 1", + exp: uint64(0), + err: true, + inRange: sql.InRange, + }, + { + typ: Uint64, + inp: "- 1", + exp: uint64(0), + err: true, + inRange: sql.InRange, + }, + { + typ: Uint64, + inp: "+-+-1", + exp: uint64(0), + err: false, + inRange: sql.InRange, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%v %v %v", test.typ, test.inp, test.exp), func(t *testing.T) { + val, inRange, err := test.typ.(sql.RoundingNumberType).ConvertRound(ctx, test.inp) + if test.err { + assert.True(t, sql.ErrTruncatedIncorrect.Is(err)) + } + assert.Equal(t, test.exp, val) + assert.Equal(t, test.inRange, inRange) + }) + } +} + func TestNumberSQL_BooleanFromBoolean(t *testing.T) { val, err := Boolean.SQL(sql.NewEmptyContext(), nil, true) require.NoError(t, err) @@ -283,3 +510,204 @@ func TestNumberString(t *testing.T) { }) } } + +func TestTruncateStringToInt(t *testing.T) { + tests := []struct { + input string + exp string + expTrunc bool + }{ + { + input: "1", + exp: "1", + expTrunc: false, + }, + { + // Whitespace does not count as truncation + input: " \t 1 \t ", + exp: "1", + expTrunc: false, + }, + { + // Newlines do count as part of truncation + input: " \t\n1", + exp: "0", + expTrunc: true, + }, + { + input: "123abc", + exp: "123", + expTrunc: true, + }, + { + input: "abc", + exp: "0", + expTrunc: true, + }, + { + // Leading sign is fine + input: "+123", + exp: "+123", + expTrunc: false, + }, + { + // Leading sign is fine + input: "-123", + exp: "-123", + expTrunc: false, + }, + { + // Repeated signs + input: "+-+-+-123", + exp: "0", + expTrunc: true, + }, + { + // Space after sign + input: "+ 123", + exp: "0", + expTrunc: true, + }, + { + // Valid float strings are not valid ints + input: "1.23", + exp: "1", + expTrunc: true, + }, + { + // Scientific float notation is not valid + input: "123.456e10", + exp: "123", + expTrunc: true, + }, + { + // Scientific notation is not valid + input: "123e10", + exp: "123", + expTrunc: true, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%v", test.input), func(t *testing.T) { + truncStr, didTrunc := TruncateStringToInt(test.input) + assert.Equal(t, test.exp, truncStr) + assert.Equal(t, test.expTrunc, didTrunc) + }) + } +} + +func TestTruncateStringToDouble(t *testing.T) { + tests := []struct { + input string + exp string + expTrunc bool + }{ + { + input: "1", + exp: "1", + expTrunc: false, + }, + { + // Whitespace does not count as truncation + input: " \t\n 1 \t\n ", + exp: "1", + expTrunc: false, + }, + { + input: "123abc", + exp: "123", + expTrunc: true, + }, + { + input: "abc", + exp: "0", + expTrunc: true, + }, + { + // Leading sign is fine + input: "+123", + exp: "+123", + expTrunc: false, + }, + { + // Leading sign is fine + input: "-123", + exp: "-123", + expTrunc: false, + }, + { + // Repeated signs + input: "+-+-+-123", + exp: "0", + expTrunc: true, + }, + { + // Space after sign + input: "+ 123", + exp: "0", + expTrunc: true, + }, + { + // Valid float strings are not valid ints + input: "1.23", + exp: "1.23", + expTrunc: false, + }, + { + // Scientific notation + input: "123.456e10", + exp: "123.456e10", + expTrunc: false, + }, + { + // Scientific notation + input: "123e10", + exp: "123e10", + expTrunc: false, + }, + { + // Scientific notation + input: "+123.456e-10", + exp: "+123.456e-10", + expTrunc: false, + }, + { + // Scientific notation truncates + input: "+123.456e-10notaumber", + exp: "+123.456e-10", + expTrunc: true, + }, + { + // Invalid Scientific notation + input: "e123", + exp: "0", + expTrunc: true, + }, + { + // Invalid Scientific notation + input: ".e1", + exp: "0", + expTrunc: true, + }, + { + // Invalid Scientific notation + input: "1e2e3", + exp: "1e2", + expTrunc: true, + }, + { + input: ".0e123", + exp: ".0e123", + expTrunc: false, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%v", test.input), func(t *testing.T) { + truncStr, didTrunc := TruncateStringToDouble(test.input) + assert.Equal(t, test.exp, truncStr) + assert.Equal(t, test.expTrunc, didTrunc) + }) + } +} diff --git a/sql/types/strings.go b/sql/types/strings.go index f71e009375..cc8171ee55 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -729,13 +729,13 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes. dest = append(dest, v...) valueBytes = dest[start:] case int, int8, int16, int32, int64: - num, _, err := convertToInt64(Int64.(NumberTypeImpl_), v) + num, _, err := convertToInt64(Int64.(NumberTypeImpl_), v, false) if err != nil { return sqltypes.Value{}, err } valueBytes = strconv.AppendInt(dest, num, 10) case uint, uint8, uint16, uint32, uint64: - num, _, err := convertToUint64(Int64.(NumberTypeImpl_), v) + num, _, err := convertToUint64(Int64.(NumberTypeImpl_), v, false) if err != nil { return sqltypes.Value{}, err } diff --git a/sql/types/tuple.go b/sql/types/tuple.go index 23bf85c730..13d62095f1 100644 --- a/sql/types/tuple.go +++ b/sql/types/tuple.go @@ -82,7 +82,7 @@ func (t TupleType) Convert(ctx context.Context, v interface{}) (interface{}, sql for i, typ := range t { var err error result[i], _, err = typ.Convert(ctx, vals[i]) - if err != nil { + if err != nil && !sql.ErrTruncatedIncorrect.Is(err) { return nil, sql.OutOfRange, err } }