Skip to content

Commit bfaab5d

Browse files
author
James Cor
committed
truncate in Convert function
1 parent 2d9dd1a commit bfaab5d

File tree

11 files changed

+121
-30
lines changed

11 files changed

+121
-30
lines changed

enginetest/enginetests.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3960,9 +3960,9 @@ func TestWindowRangeFrames(t *testing.T, harness Harness) {
39603960
TestQueryWithContext(t, ctx, e, harness, `SELECT sum(y) over (partition by z order by date range between unbounded preceding and interval '1' DAY following) FROM c order by x`, []sql.Row{{float64(1)}, {float64(1)}, {float64(1)}, {float64(1)}, {float64(5)}, {float64(5)}, {float64(10)}, {float64(10)}, {float64(10)}, {float64(10)}}, nil, nil, nil)
39613961
TestQueryWithContext(t, ctx, e, harness, `SELECT count(y) over (partition by z order by date range between interval '1' DAY following and interval '2' DAY following) FROM c order by x`, []sql.Row{{1}, {1}, {1}, {1}, {1}, {0}, {2}, {2}, {0}, {0}}, nil, nil, nil)
39623962
TestQueryWithContext(t, ctx, e, harness, `SELECT count(y) over (partition by z order by date range between interval '1' DAY preceding and interval '2' DAY following) FROM c order by x`, []sql.Row{{4}, {4}, {4}, {5}, {2}, {2}, {4}, {4}, {4}, {4}}, nil, nil, nil)
3963+
TestQueryWithContext(t, ctx, e, harness, "SELECT sum(y) over (partition by z order by date range interval 'e' DAY preceding) FROM c order by x", []sql.Row{{float64(0)}, {float64(0)}, {float64(0)}, {float64(1)}, {float64(1)}, {float64(3)}, {float64(1)}, {float64(1)}, {float64(4)}, {float64(4)}}, nil, nil, nil)
39633964

39643965
AssertErr(t, e, harness, "SELECT sum(y) over (partition by z range between unbounded preceding and interval '1' DAY following) FROM c order by x", nil, aggregation.ErrRangeInvalidOrderBy)
3965-
AssertErr(t, e, harness, "SELECT sum(y) over (partition by z order by date range interval 'e' DAY preceding) FROM c order by x", nil, sql.ErrInvalidValue)
39663966
}
39673967

39683968
func TestNamedWindows(t *testing.T, harness Harness) {

enginetest/queries/json_table_queries.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ var JSONTableScriptTests = []ScriptTest{
571571
},
572572
{
573573
Query: "SELECT * FROM JSON_TABLE('{\"c1\":\"abc\"}', '$' COLUMNS(c1 INT PATH '$.c1' DEFAULT 'def' ON ERROR)) as jt;",
574-
ExpectedErrStr: "error: 'def' is not a valid value for 'int'",
574+
ExpectedErrStr: "Invalid JSON text in argument 1 to function JSON_TABLE: \"Invalid value.\"",
575575
},
576576
},
577577
},
@@ -612,7 +612,7 @@ var JSONTableScriptTests = []ScriptTest{
612612
},
613613
{
614614
Query: "SELECT * FROM JSON_TABLE('{\"c1\":\"abc\"}', '$' COLUMNS(c1 INT PATH '$.c1' ERROR ON ERROR)) as jt;",
615-
ExpectedErrStr: "error: 'abc' is not a valid value for 'int'",
615+
ExpectedErrStr: "Invalid JSON text in argument 1 to function JSON_TABLE: \"Invalid value.\"",
616616
},
617617
},
618618
},

sql/core.go

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,11 @@ func ConvertToBool(ctx *Context, v interface{}) (bool, error) {
317317
case float64:
318318
return b != 0, nil
319319
case string:
320-
bFloat, err := strconv.ParseFloat(TrimStringToNumberPrefix(ctx, b, false), 64)
320+
truncStr, didTrunc := TruncateStringToInt(b)
321+
if didTrunc {
322+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", ErrTruncatedIncorrect.New("INTEGER", b))
323+
}
324+
bFloat, err := strconv.ParseFloat(truncStr, 64)
321325
if err != nil {
322326
return false, nil
323327
}
@@ -382,6 +386,60 @@ func convertEmptyStringToZero(s string) string {
382386
return s
383387
}
384388

389+
// TruncateStringToInt trims any whitespace from s, then truncates the string to the left most characters that make
390+
// up a valid integer. Empty strings are converted "0". Additionally, returns a flag indicating if truncation occurred.
391+
func TruncateStringToInt(s string) (string, bool) {
392+
s = strings.Trim(s, IntCutSet)
393+
i, n := 0, len(s)
394+
for ; i < n; i++ {
395+
c := rune(s[i])
396+
if unicode.IsDigit(c) {
397+
continue
398+
}
399+
if i == 0 && (c == '-' || c == '+') {
400+
continue
401+
}
402+
break
403+
}
404+
if i == 0 {
405+
return "0", i != n
406+
}
407+
return s[:i], i != n
408+
}
409+
410+
// TruncateStringToDouble trims any whitespace from s, then truncates the string to the left most characters that make
411+
// up a valid double. Empty strings are converted "0". Additionally, returns a flag indicating if truncation occurred.
412+
func TruncateStringToDouble(s string) (string, bool) {
413+
var signIndex int
414+
var seenDigit, seenDot, seenExp bool
415+
s = strings.Trim(s, NumericCutSet)
416+
i, n := 0, len(s)
417+
for ; i < n; i++ {
418+
char := rune(s[i])
419+
if unicode.IsDigit(char) {
420+
seenDigit = true
421+
continue
422+
}
423+
if char == '.' && !seenDot {
424+
seenDot = true
425+
continue
426+
}
427+
if (char == 'e' || char == 'E') && !seenExp && seenDigit {
428+
seenExp = true
429+
signIndex = i + 1 // allow a sign following exponent
430+
continue
431+
}
432+
if i == signIndex && (char == '-' || char == '+') {
433+
continue
434+
}
435+
break
436+
}
437+
if i == 0 {
438+
return "0", i != n
439+
}
440+
return s[:i], i != n
441+
}
442+
385443
var ErrVectorInvalidBinaryLength = errors.NewKind("cannot convert BINARY(%d) to vector, byte length must be a multiple of 4 bytes")
386444

387445
// DecodeVector decodes a byte slice that represents a vector. This is needed for distance functions.

sql/expression/convert.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ func convertValue(ctx *sql.Context, val interface{}, castTo string, originType s
421421
}
422422
return num, nil
423423
case ConvertToYear:
424-
value, err := convertHexBlobToDecimalForNumericContext(val, originType)
424+
value, err := prepareForNumericContext(ctx, val, originType, true)
425425
if err != nil {
426426
return nil, err
427427
}

sql/expression/interval.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package expression
1616

1717
import (
1818
"fmt"
19+
"github.com/dolthub/vitess/go/mysql"
1920
"regexp"
2021
"strconv"
2122
"strings"
@@ -140,7 +141,10 @@ func (i *Interval) EvalDelta(ctx *sql.Context, row sql.Row) (*TimeDelta, error)
140141
} else {
141142
val, _, err = types.Int64.Convert(ctx, val)
142143
if err != nil {
143-
return nil, err
144+
if !sql.ErrTruncatedIncorrect.Is(err) {
145+
return nil, err
146+
}
147+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
144148
}
145149

146150
num := val.(int64)

sql/expression/procedurereference.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ func (ppr *ProcedureReference) InitializeVariable(ctx *sql.Context, name string,
6969
}
7070
convertedVal, _, err := sqlType.Convert(ctx, val)
7171
if err != nil {
72+
if sql.ErrTruncatedIncorrect.Is(err) {
73+
return sql.ErrInvalidValue.New(val, sqlType)
74+
}
7275
return err
7376
}
7477
lowerName := strings.ToLower(name)

sql/index_builder.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ import (
2121

2222
"github.com/shopspring/decimal"
2323
"gopkg.in/src-d/go-errors.v1"
24+
25+
"github.com/dolthub/vitess/go/mysql"
2426
)
2527

2628
var (
@@ -229,9 +231,14 @@ func (b *MySQLIndexBuilder) convertKey(ctx *Context, colType Type, keyType Type,
229231
if et, ok := colType.(ExtendedType); ok {
230232
return et.ConvertToType(ctx, keyType.(ExtendedType), key)
231233
} else {
232-
// TODO: would it make more sense for colType.Convert to handle the truncation or just do it here?
233-
key, _, err := colType.Convert(ctx, key)
234-
return key, err
234+
k, _, err := colType.Convert(ctx, key)
235+
if err != nil {
236+
if !ErrTruncatedIncorrect.Is(err) {
237+
return nil, err
238+
}
239+
ctx.Warn(mysql.ERTruncatedWrongValue, "%s", err.Error())
240+
}
241+
return k, nil
235242
}
236243
}
237244

sql/iters/rel_iters.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,17 @@ func (c *JsonTableCol) Next(ctx *sql.Context, obj interface{}, pass bool, ord in
290290
val, _, err = c.Opts.Typ.Convert(ctx, val)
291291
if err != nil {
292292
if c.Opts.ErrOnErr {
293-
return nil, err
293+
if sql.ErrTruncatedIncorrect.Is(err) {
294+
return nil, sql.ErrInvalidJSONText.New(c.pos+1, "JSON_TABLE", "Invalid value.")
295+
}
296+
return nil, sql.ErrInvalidJSONText.New(c.pos+1, "JSON_TABLE", err.Error())
294297
}
295298
val, _, err = c.Opts.Typ.Convert(ctx, c.Opts.DefErrVal)
296299
if err != nil {
297-
return nil, err
300+
if sql.ErrTruncatedIncorrect.Is(err) {
301+
return nil, sql.ErrInvalidJSONText.New(c.pos+1, "JSON_TABLE", "Invalid value.")
302+
}
303+
return nil, sql.ErrInvalidJSONText.New(c.pos+1, "JSON_TABLE", err.Error())
298304
}
299305
}
300306

sql/rowexec/insert.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,14 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error)
120120
ctxWithValues := context.WithValue(ctx.Context, types.ColumnNameKey, col.Name)
121121
ctxWithValues = context.WithValue(ctxWithValues, types.RowNumberKey, i.rowNumber)
122122
ctxWithColumnInfo := ctx.WithContext(ctxWithValues)
123-
converted, inRange, cErr := col.Type.Convert(ctxWithColumnInfo, row[idx])
123+
val := row[idx]
124+
// TODO: check mysql strict sql_mode
125+
converted, inRange, cErr := col.Type.Convert(ctxWithColumnInfo, val)
124126
if cErr == nil && !inRange {
125-
cErr = sql.ErrValueOutOfRange.New(row[idx], col.Type)
127+
cErr = sql.ErrValueOutOfRange.New(val, col.Type)
128+
}
129+
if sql.ErrTruncatedIncorrect.Is(cErr) {
130+
cErr = sql.ErrInvalidValue.New(val, col.Type)
126131
}
127132
if cErr != nil {
128133
// Ignore individual column errors when INSERT IGNORE, UPDATE IGNORE, etc. is specified.

sql/types/number.go

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -961,21 +961,24 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange
961961
case float32:
962962
if v > float32(math.MaxInt64) {
963963
return math.MaxInt64, sql.OutOfRange, nil
964-
} else if v < float32(math.MinInt64) {
964+
}
965+
if v < float32(math.MinInt64) {
965966
return math.MinInt64, sql.OutOfRange, nil
966967
}
967968
return int64(math.Round(float64(v))), sql.InRange, nil
968969
case float64:
969970
if v > float64(math.MaxInt64) {
970971
return math.MaxInt64, sql.OutOfRange, nil
971-
} else if v < float64(math.MinInt64) {
972+
}
973+
if v < float64(math.MinInt64) {
972974
return math.MinInt64, sql.OutOfRange, nil
973975
}
974976
return int64(math.Round(v)), sql.InRange, nil
975977
case decimal.Decimal:
976978
if v.GreaterThan(dec_int64_max) {
977979
return dec_int64_max.IntPart(), sql.OutOfRange, nil
978-
} else if v.LessThan(dec_int64_min) {
980+
}
981+
if v.LessThan(dec_int64_min) {
979982
return dec_int64_min.IntPart(), sql.OutOfRange, nil
980983
}
981984
return v.Round(0).IntPart(), sql.InRange, nil
@@ -986,23 +989,25 @@ func convertToInt64(t NumberTypeImpl_, v interface{}) (int64, sql.ConvertInRange
986989
}
987990
return i, sql.InRange, nil
988991
case string:
989-
v = strings.Trim(v, sql.IntCutSet)
990-
if v == "" {
991-
// StringType{}.Zero() returns empty string, but should represent "0" for number value
992-
return 0, sql.InRange, nil
992+
// TODO: this currently assumes we are always rounding to preserve behavior
993+
// but we should only be rounding on inserts
994+
var err error
995+
truncStr, didTrunc := sql.TruncateStringToDouble(v)
996+
if didTrunc {
997+
err = sql.ErrTruncatedIncorrect.New(t, v)
993998
}
994999
// Parse first an integer, which allows for more values than float64
995-
i, err := strconv.ParseInt(v, 10, 64)
996-
if err == nil {
997-
return i, sql.InRange, nil
1000+
i, pErr := strconv.ParseInt(truncStr, 10, 64)
1001+
if pErr == nil {
1002+
return i, sql.InRange, err
9981003
}
999-
// If that fails, try as a float and truncate it to integral
1000-
f, err := strconv.ParseFloat(v, 64)
1001-
if err != nil {
1002-
return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String())
1004+
// If that fails, try as a float and round it to integral
1005+
f, pErr := strconv.ParseFloat(truncStr, 64)
1006+
if pErr == nil {
1007+
f = math.Round(f)
1008+
return int64(f), sql.InRange, err
10031009
}
1004-
f = math.Round(f)
1005-
return int64(f), sql.InRange, nil
1010+
return 0, sql.OutOfRange, sql.ErrInvalidValue.New(v, t.String())
10061011
case bool:
10071012
if v {
10081013
return 1, sql.InRange, nil

0 commit comments

Comments
 (0)