diff --git a/go.mod b/go.mod index 692abdee79..68d8f88fd4 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/lestrrat-go/strftime v1.0.4 github.com/pkg/errors v0.9.1 github.com/pmezard/go-difflib v1.0.0 - github.com/shopspring/decimal v1.3.1 + github.com/shopspring/decimal v1.4.0 github.com/sirupsen/logrus v1.8.1 github.com/stretchr/testify v1.9.0 go.opentelemetry.io/otel v1.31.0 diff --git a/go.sum b/go.sum index bab2369fe9..ee3d7c83d5 100644 --- a/go.sum +++ b/go.sum @@ -66,8 +66,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= -github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.8.1 h1:dJKuHgqk1NNQlqoA6BTlM1Wf9DOH3NBjQyu0h9+AZZE= github.com/sirupsen/logrus v1.8.1/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= diff --git a/server/handler.go b/server/handler.go index 2275ca7a2d..0cdcea2f57 100644 --- a/server/handler.go +++ b/server/handler.go @@ -495,6 +495,8 @@ func (h *Handler) doQuery( r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields) } else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) { r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, buf) + } else if vr, ok := rowIter.(sql.ValueRowIter); ok && vr.IsValueRowIter(sqlCtx) { + r, processedAtLeastOneBatch, err = h.resultForValueRowIter(sqlCtx, c, schema, vr, resultFields, buf, callback, more) } else { r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, c, schema, rowIter, callback, resultFields, more, buf) } @@ -768,6 +770,149 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s return r, processedAtLeastOneBatch, nil } +func (h *Handler) resultForValueRowIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.ValueRowIter, resultFields []*querypb.Field, buf *sql.ByteBuffer, callback func(*sqltypes.Result, bool) error, more bool) (*sqltypes.Result, bool, error) { + defer trace.StartRegion(ctx, "Handler.resultForValueRowIter").End() + + eg, ctx := ctx.NewErrgroup() + pan2err := func(err *error) { + if recoveredPanic := recover(); recoveredPanic != nil { + wrappedErr := fmt.Errorf("handler caught panic: %v\n%s", recoveredPanic, debug.Stack()) + *err = goerrors.Join(*err, wrappedErr) + } + } + + // TODO: poll for closed connections should obviously also run even if + // we're doing something with an OK result or a single row result, etc. + // This should be in the caller. + pollCtx, cancelF := ctx.NewSubContext() + eg.Go(func() (err error) { + defer pan2err(&err) + return h.pollForClosedConnection(pollCtx, c) + }) + + // Default waitTime is one minute if there is no timeout configured, in which case + // it will loop to iterate again unless the socket died by the OS timeout or other problems. + // If there is a timeout, it will be enforced to ensure that Vitess has a chance to + // call Handler.CloseConnection() + waitTime := 1 * time.Minute + if h.readTimeout > 0 { + waitTime = h.readTimeout + } + timer := time.NewTimer(waitTime) + defer timer.Stop() + + wg := sync.WaitGroup{} + wg.Add(2) + + // Wrap the callback to include a BytesBuffer.Reset() for non-cursor requests, to + // clean out rows that have already been spooled. + resetCallback := func(r *sqltypes.Result, more bool) error { + // A server-side cursor allows the caller to fetch results cached on the server-side, + // so if a cursor exists, we can't release the buffer memory yet. + if c.StatusFlags&uint16(mysql.ServerCursorExists) != 0 { + defer buf.Reset() + } + return callback(r, more) + } + + // TODO: send results instead of rows? + // Read rows from iter and send them off + var rowChan = make(chan sql.ValueRow, 512) + eg.Go(func() (err error) { + defer pan2err(&err) + defer wg.Done() + defer close(rowChan) + for { + select { + case <-ctx.Done(): + return context.Cause(ctx) + default: + row, err := iter.NextValueRow(ctx) + if err == io.EOF { + return nil + } + if err != nil { + return err + } + select { + case rowChan <- row: + case <-ctx.Done(): + return nil + } + } + } + }) + + var res *sqltypes.Result + var processedAtLeastOneBatch bool + eg.Go(func() (err error) { + defer pan2err(&err) + defer cancelF() + defer wg.Done() + for { + if res == nil { + res = &sqltypes.Result{ + Fields: resultFields, + Rows: make([][]sqltypes.Value, 0, rowsBatch), + } + } + if res.RowsAffected == rowsBatch { + if err := resetCallback(res, more); err != nil { + return err + } + res = nil + processedAtLeastOneBatch = true + continue + } + + select { + case <-ctx.Done(): + return context.Cause(ctx) + case <-timer.C: + if h.readTimeout != 0 { + // Cancel and return so Vitess can call the CloseConnection callback + ctx.GetLogger().Tracef("connection timeout") + return ErrRowTimeout.New() + } + case row, ok := <-rowChan: + if !ok { + return nil + } + resRow, err := RowValueToSQLValues(ctx, schema, row, buf) + if err != nil { + return err + } + ctx.GetLogger().Tracef("spooling result row %s", resRow) + res.Rows = append(res.Rows, resRow) + res.RowsAffected++ + if !timer.Stop() { + <-timer.C + } + } + timer.Reset(waitTime) + } + }) + + // Close() kills this PID in the process list, + // wait until all rows have be sent over the wire + eg.Go(func() (err error) { + defer pan2err(&err) + wg.Wait() + return iter.Close(ctx) + }) + + err := eg.Wait() + if err != nil { + ctx.GetLogger().WithError(err).Warn("error running query") + if verboseErrorLogging { + fmt.Printf("Err: %+v", err) + } + return nil, false, err + } + + return res, processedAtLeastOneBatch, nil +} + // See https://dev.mysql.com/doc/internals/en/status-flags.html func setConnStatusFlags(ctx *sql.Context, c *mysql.Conn) error { ok, err := isSessionAutocommit(ctx) @@ -994,7 +1139,7 @@ func toSqlHelper(ctx *sql.Context, typ sql.Type, buf *sql.ByteBuffer, val interf return typ.SQL(ctx, nil, val) } ret, err := typ.SQL(ctx, buf.Get(), val) - buf.Grow(ret.Len()) + buf.Grow(ret.Len()) // TODO: shouldn't we check capacity beforehand? return ret, err } @@ -1037,6 +1182,39 @@ func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Express return outVals, nil } +func RowValueToSQLValues(ctx *sql.Context, sch sql.Schema, row sql.ValueRow, buf *sql.ByteBuffer) ([]sqltypes.Value, error) { + if len(sch) == 0 { + return []sqltypes.Value{}, nil + } + var err error + outVals := make([]sqltypes.Value, len(sch)) + for i, col := range sch { + // TODO: remove this check once all Types implement this + valType, ok := col.Type.(sql.ValueType) + if !ok { + if row[i].IsNull() { + outVals[i] = sqltypes.NULL + continue + } + outVals[i] = sqltypes.MakeTrusted(row[i].Typ, row[i].Val) + continue + } + if buf == nil { + outVals[i], err = valType.SQLValue(ctx, row[i], nil) + if err != nil { + return nil, err + } + continue + } + outVals[i], err = valType.SQLValue(ctx, row[i], buf.Get()) + if err != nil { + return nil, err + } + buf.Grow(outVals[i].Len()) + } + return outVals, nil +} + func schemaToFields(ctx *sql.Context, s sql.Schema) []*querypb.Field { charSetResults := ctx.GetCharacterSetResults() fields := make([]*querypb.Field, len(s)) diff --git a/sql/analyzer/replace_sort.go b/sql/analyzer/replace_sort.go index 7a2e4b3ff7..534bd48b93 100644 --- a/sql/analyzer/replace_sort.go +++ b/sql/analyzer/replace_sort.go @@ -175,12 +175,12 @@ func replaceIdxSortHelper(ctx *sql.Context, scope *plan.Scope, node sql.Node, so sortFields[i] = sortField } else { sameSortFields = false - col2, _ := col.(sql.Expression2) + valCol, _ := col.(sql.ValueExpression) sortFields[i] = sql.SortField{ - Column: col, - Column2: col2, - NullOrdering: sortField.NullOrdering, - Order: sortField.Order, + Column: col, + ValueExprColumn: valCol, + NullOrdering: sortField.NullOrdering, + Order: sortField.Order, } } } diff --git a/sql/cache.go b/sql/cache.go index d664b2c957..ec772a0e9d 100644 --- a/sql/cache.go +++ b/sql/cache.go @@ -71,10 +71,10 @@ func (l *lruCache) Dispose() { } type rowsCache struct { - memory Freeable - reporter Reporter - rows []Row - rows2 []Row2 + memory Freeable + reporter Reporter + rows []Row + valueRows []ValueRow } func newRowsCache(memory Freeable, r Reporter) *rowsCache { @@ -92,17 +92,17 @@ func (c *rowsCache) Add(row Row) error { func (c *rowsCache) Get() []Row { return c.rows } -func (c *rowsCache) Add2(row2 Row2) error { +func (c *rowsCache) AddValueRow(row ValueRow) error { if !releaseMemoryIfNeeded(c.reporter, c.memory.Free) { return ErrNoMemoryAvailable.New() } - c.rows2 = append(c.rows2, row2) + c.valueRows = append(c.valueRows, row) return nil } -func (c *rowsCache) Get2() []Row2 { - return c.rows2 +func (c *rowsCache) GetValueRow() []ValueRow { + return c.valueRows } func (c *rowsCache) Dispose() { diff --git a/sql/convert_value.go b/sql/convert_value.go index d46fe4de4e..880b9f2f58 100644 --- a/sql/convert_value.go +++ b/sql/convert_value.go @@ -3,9 +3,9 @@ package sql import ( "fmt" - "github.com/dolthub/vitess/go/vt/proto/query" - "github.com/dolthub/go-mysql-server/sql/values" + + "github.com/dolthub/vitess/go/vt/proto/query" ) // ConvertToValue converts the interface to a sql value. @@ -90,11 +90,3 @@ func ConvertToValue(v interface{}) (Value, error) { return Value{}, fmt.Errorf("type %T not implemented", v) } } - -func MustConvertToValue(v interface{}) Value { - ret, err := ConvertToValue(v) - if err != nil { - panic(err) - } - return ret -} diff --git a/sql/core.go b/sql/core.go index c1e1f90b2a..44100f02c4 100644 --- a/sql/core.go +++ b/sql/core.go @@ -460,13 +460,13 @@ func DebugString(nodeOrExpression interface{}) string { panic(fmt.Sprintf("Expected sql.DebugString or fmt.Stringer for %T", nodeOrExpression)) } -// Expression2 is an experimental future interface alternative to Expression to provide faster access. -type Expression2 interface { +// ValueExpression is an experimental future interface alternative to Expression to provide faster access. +type ValueExpression interface { Expression - // Eval2 evaluates the given row frame and returns a result. - Eval2(ctx *Context, row Row2) (Value, error) - // Type2 returns the expression type. - Type2() Type2 + // EvalValue evaluates the given row frame and returns a result. + EvalValue(ctx *Context, row ValueRow) (Value, error) + // IsValueExpression indicates whether this expression and all its children support ValueExpression. + IsValueExpression() bool } var SystemVariables SystemVariableRegistry diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index e7cb0b8e15..c9612656e7 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -17,7 +17,7 @@ package expression import ( "fmt" - errors "gopkg.in/src-d/go-errors.v1" + "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" @@ -157,6 +157,60 @@ func (c *comparison) Compare(ctx *sql.Context, row sql.Row) (int, error) { return compareType.Compare(ctx, l, r) } +// CompareValue the two given values using the types of the expressions in the comparison. +func (c *comparison) CompareValue(ctx *sql.Context, row sql.ValueRow) (int, error) { + // TODO: avoid type assertions + lv, err := c.LeftChild.(sql.ValueExpression).EvalValue(ctx, row) + if err != nil { + return 0, err + } + rv, err := c.RightChild.(sql.ValueExpression).EvalValue(ctx, row) + if err != nil { + return 0, err + } + + if lv.IsNull() || rv.IsNull() { + return 0, nil + } + + lTyp, rTyp := c.LeftChild.Type().(sql.ValueType), c.RightChild.Type().(sql.ValueType) + if types.TypesEqual(lTyp, rTyp) { + return lTyp.(sql.ValueType).CompareValue(ctx, lv, rv) + } + + if types.IsNumber(lTyp) || types.IsNumber(rTyp) { + if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) { + return types.Uint64.(sql.ValueType).CompareValue(ctx, lv, rv) + } + if types.IsSigned(lTyp) && types.IsSigned(rTyp) { + return types.Int64.(sql.ValueType).CompareValue(ctx, lv, rv) + } + if types.IsDecimal(lTyp) || types.IsDecimal(rTyp) { + return types.InternalDecimalType.(sql.ValueType).CompareValue(ctx, lv, rv) + } + return types.Float64.(sql.ValueType).CompareValue(ctx, lv, rv) + } + + return lTyp.CompareValue(ctx, lv, rv) +} + +// IsValueExpression returns whether every child supports sql.ValueExpression +func (c *comparison) IsValueExpression() bool { + l, ok := c.LeftChild.(sql.ValueExpression) + if !ok { + return false + } + r, ok := c.RightChild.(sql.ValueExpression) + if !ok { + return false + } + // TODO: only allow comparisons between Integers, Floats, Decimals, Bits and Year for now + if !types.IsNumber(c.LeftChild.Type()) || !types.IsNumber(c.RightChild.Type()) { + return false + } + return l.IsValueExpression() && r.IsValueExpression() +} + func (c *comparison) evalLeftAndRight(ctx *sql.Context, row sql.Row) (interface{}, interface{}, error) { left, err := c.Left().Eval(ctx, row) if err != nil { @@ -495,6 +549,7 @@ type GreaterThan struct { } var _ sql.Expression = (*GreaterThan)(nil) +var _ sql.ValueExpression = (*GreaterThan)(nil) var _ sql.CollationCoercible = (*GreaterThan)(nil) // NewGreaterThan creates a new GreaterThan expression. @@ -541,6 +596,23 @@ func (gt *GreaterThan) DebugString() string { return pr.String() } +// EvalValue implements the sql.ValueExpression interface. +func (gt *GreaterThan) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { + cmp, err := gt.CompareValue(ctx, row) + if err != nil { + return sql.NullValue, err + } + if cmp != 1 { + return sql.FalseValue, nil + } + return sql.TrueValue, nil +} + +// IsValueExpression implements the ValueExpression interface. +func (gt *GreaterThan) IsValueExpression() bool { + return gt.comparison.IsValueExpression() +} + // LessThan is a comparison that checks an expression is less than another. type LessThan struct { comparison @@ -566,10 +638,8 @@ func (lt *LessThan) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { if ErrNilOperand.Is(err) { return nil, nil } - return nil, err } - return result == -1, nil } @@ -593,6 +663,23 @@ func (lt *LessThan) DebugString() string { return pr.String() } +// EvalValue implements the sql.ValueExpression interface. +func (lt *LessThan) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { + cmp, err := lt.CompareValue(ctx, row) + if err != nil { + return sql.NullValue, err + } + if cmp != -1 { + return sql.FalseValue, nil + } + return sql.TrueValue, nil +} + +// IsValueExpression implements the ValueExpression interface. +func (lt *LessThan) IsValueExpression() bool { + return lt.comparison.IsValueExpression() +} + // GreaterThanOrEqual is a comparison that checks an expression is greater or equal to // another. type GreaterThanOrEqual struct { @@ -619,10 +706,8 @@ func (gte *GreaterThanOrEqual) Eval(ctx *sql.Context, row sql.Row) (interface{}, if ErrNilOperand.Is(err) { return nil, nil } - return nil, err } - return result > -1, nil } @@ -646,6 +731,23 @@ func (gte *GreaterThanOrEqual) DebugString() string { return pr.String() } +// EvalValue implements the sql.ValueExpression interface. +func (gte *GreaterThanOrEqual) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { + cmp, err := gte.CompareValue(ctx, row) + if err != nil { + return sql.NullValue, err + } + if cmp == -1 { + return sql.FalseValue, nil + } + return sql.TrueValue, nil +} + +// IsValueExpression implements the ValueExpression interface. +func (gte *GreaterThanOrEqual) IsValueExpression() bool { + return gte.comparison.IsValueExpression() +} + // LessThanOrEqual is a comparison that checks an expression is equal or lower than // another. type LessThanOrEqual struct { @@ -699,6 +801,23 @@ func (lte *LessThanOrEqual) DebugString() string { return pr.String() } +// EvalValue implements the sql.ValueExpression interface. +func (lte *LessThanOrEqual) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { + cmp, err := lte.CompareValue(ctx, row) + if err != nil { + return sql.NullValue, err + } + if cmp == 1 { + return sql.FalseValue, nil + } + return sql.TrueValue, nil +} + +// IsValueExpression implements the ValueExpression interface. +func (lte *LessThanOrEqual) IsValueExpression() bool { + return lte.comparison.IsValueExpression() +} + var ( // ErrUnsupportedInOperand is returned when there is an invalid righthand // operand in an IN operator. diff --git a/sql/expression/comparison_test.go b/sql/expression/comparison_test.go index a2ce6eeae1..710333d121 100644 --- a/sql/expression/comparison_test.go +++ b/sql/expression/comparison_test.go @@ -12,15 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -package expression_test +package expression import ( + "encoding/binary" "testing" + "github.com/dolthub/vitess/go/sqltypes" "github.com/stretchr/testify/require" "github.com/dolthub/go-mysql-server/sql" - "github.com/dolthub/go-mysql-server/sql/expression" "github.com/dolthub/go-mysql-server/sql/types" ) @@ -109,11 +110,11 @@ var likeComparisonCases = map[sql.Type]map[int][][]interface{}{ func TestEquals(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := expression.NewGetField(0, resultType, "col1", true) + get0 := NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := expression.NewGetField(1, resultType, "col2", true) + get1 := NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := expression.NewEquals(get0, get1) + eq := NewEquals(get0, get1) require.NotNil(eq) require.Equal(types.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -136,11 +137,11 @@ func TestEquals(t *testing.T) { func TestNullSafeEquals(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := expression.NewGetField(0, resultType, "col1", true) + get0 := NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := expression.NewGetField(1, resultType, "col2", true) + get1 := NewGetField(1, resultType, "col2", true) require.NotNil(get1) - seq := expression.NewNullSafeEquals(get0, get1) + seq := NewNullSafeEquals(get0, get1) require.NotNil(seq) require.Equal(types.Boolean, seq.Type()) for cmpResult, cases := range cmpCase { @@ -167,11 +168,11 @@ func TestNullSafeEquals(t *testing.T) { func TestLessThan(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := expression.NewGetField(0, resultType, "col1", true) + get0 := NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := expression.NewGetField(1, resultType, "col2", true) + get1 := NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := expression.NewLessThan(get0, get1) + eq := NewLessThan(get0, get1) require.NotNil(eq) require.Equal(types.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -194,11 +195,11 @@ func TestLessThan(t *testing.T) { func TestGreaterThan(t *testing.T) { require := require.New(t) for resultType, cmpCase := range comparisonCases { - get0 := expression.NewGetField(0, resultType, "col1", true) + get0 := NewGetField(0, resultType, "col1", true) require.NotNil(get0) - get1 := expression.NewGetField(1, resultType, "col2", true) + get1 := NewGetField(1, resultType, "col2", true) require.NotNil(get1) - eq := expression.NewGreaterThan(get0, get1) + eq := NewGreaterThan(get0, get1) require.NotNil(eq) require.Equal(types.Boolean, eq.Type()) for cmpResult, cases := range cmpCase { @@ -218,9 +219,49 @@ func TestGreaterThan(t *testing.T) { } } -func eval(t *testing.T, e sql.Expression, row sql.Row) interface{} { - t.Helper() - v, err := e.Eval(sql.NewEmptyContext(), row) - require.NoError(t, err) - return v +func TestValueComparison(t *testing.T) { + t.Skip("TODO: write tests for comparison between sql.Values") +} + +// BenchmarkComparison +// BenchmarkComparison-14 4426766 264.4 ns/op +func BenchmarkComparison(b *testing.B) { + ctx := sql.NewEmptyContext() + gf1 := NewGetField(0, types.Int64, "col1", true) + gf2 := NewGetField(1, types.Int64, "col2", true) + cmp := newComparison(gf1, gf2) + row := sql.Row{1, 1} + b.ResetTimer() + + for i := 0; i < b.N; i++ { + res, err := cmp.Compare(ctx, row) + require.NoError(b, err) + require.Equal(b, 0, res) + } +} + +// BenchmarkValueComparison +// BenchmarkValueComparison-14 4115744 285.8 ns/op +func BenchmarkValueComparison(b *testing.B) { + ctx := sql.NewEmptyContext() + gf1 := NewGetField(0, types.Int64, "col1", true) + gf2 := NewGetField(1, types.Int64, "col2", true) + cmp := newComparison(gf1, gf2) + row := sql.ValueRow{ + { + Val: binary.LittleEndian.AppendUint64(nil, uint64(1)), + Typ: sqltypes.Int64, + }, + { + Val: binary.LittleEndian.AppendUint64(nil, uint64(1)), + Typ: sqltypes.Int64, + }, + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + res, err := cmp.CompareValue(ctx, row) + require.NoError(b, err) + require.Equal(b, 0, res) + } } diff --git a/sql/expression/get_field.go b/sql/expression/get_field.go index f4ff9b429e..205b7b505f 100644 --- a/sql/expression/get_field.go +++ b/sql/expression/get_field.go @@ -25,8 +25,7 @@ import ( // GetField is an expression to get the field of a table. type GetField struct { - fieldType sql.Type - fieldType2 sql.Type2 + fieldType sql.Type // schemaFormatter is the schemaFormatter used to quote field names schemaFormatter sql.SchemaFormatter @@ -47,7 +46,7 @@ type GetField struct { } var _ sql.Expression = (*GetField)(nil) -var _ sql.Expression2 = (*GetField)(nil) +var _ sql.ValueExpression = (*GetField)(nil) var _ sql.CollationCoercible = (*GetField)(nil) var _ sql.IdExpression = (*GetField)(nil) @@ -58,13 +57,11 @@ func NewGetField(index int, fieldType sql.Type, fieldName string, nullable bool) // NewGetFieldWithTable creates a GetField expression with table name. The table name may be an alias. func NewGetFieldWithTable(index, tableId int, fieldType sql.Type, db, table, fieldName string, nullable bool) *GetField { - fieldType2, _ := fieldType.(sql.Type2) return &GetField{ db: db, table: table, fieldIndex: index, fieldType: fieldType, - fieldType2: fieldType2, name: fieldName, nullable: nullable, exprId: sql.ColumnId(index), @@ -133,13 +130,8 @@ func (p *GetField) Type() sql.Type { return p.fieldType } -// Type2 returns the type of the field, if this field has a sql.Type2. -func (p *GetField) Type2() sql.Type2 { - return p.fieldType2 -} - // ErrIndexOutOfBounds is returned when the field index is out of the bounds. -var ErrIndexOutOfBounds = errors.NewKind("unable to find field with index %d in row of %d columns") +var ErrIndexOutOfBounds = errors.NewKind("unable to find field with index %d in row of %d columns. \n This is a bug. Please file an issue here: https://github.com/dolthub/dolt/issues") // Eval implements the Expression interface. func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { @@ -149,12 +141,17 @@ func (p *GetField) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return row[p.fieldIndex], nil } -func (p *GetField) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { - if p.fieldIndex < 0 || p.fieldIndex >= row.Len() { - return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, row.Len()) +// EvalValue implements the ValueExpression interface. +func (p *GetField) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { + if p.fieldIndex < 0 || p.fieldIndex >= len(row) { + return sql.Value{}, ErrIndexOutOfBounds.New(p.fieldIndex, len(row)) } + return row[p.fieldIndex], nil +} - return row.GetField(p.fieldIndex), nil +// IsValueRowIter implements the ValueExpression interface. +func (p *GetField) IsValueExpression() bool { + return true } // WithChildren implements the Expression interface. diff --git a/sql/expression/literal.go b/sql/expression/literal.go index 8fff9557a7..104c04fd97 100644 --- a/sql/expression/literal.go +++ b/sql/expression/literal.go @@ -34,7 +34,7 @@ type Literal struct { } var _ sql.Expression = &Literal{} -var _ sql.Expression2 = &Literal{} +var _ sql.ValueExpression = &Literal{} var _ sql.CollationCoercible = &Literal{} var _ sqlparser.Injectable = &Literal{} @@ -136,21 +136,19 @@ func (*Literal) Children() []sql.Expression { return nil } -func (lit *Literal) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { +// EvalValue implements the sql.ValueExpression interface. +func (lit *Literal) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { return lit.val2, nil } -func (lit *Literal) Type2() sql.Type2 { - t2, ok := lit.Typ.(sql.Type2) - if !ok { - panic(fmt.Errorf("expected Type2, but was %T", lit.Typ)) - } - return t2 +// IsValueExpression implements the ValueExpression interface. +func (lit *Literal) IsValueExpression() bool { + return types.IsInteger(lit.Typ) } // Value returns the literal value. -func (p *Literal) Value() interface{} { - return p.Val +func (lit *Literal) Value() interface{} { + return lit.Val } func (lit *Literal) WithResolvedChildren(children []any) (any, error) { diff --git a/sql/expression/namedliteral.go b/sql/expression/namedliteral.go index ebf8d80ded..ce5550dd54 100644 --- a/sql/expression/namedliteral.go +++ b/sql/expression/namedliteral.go @@ -25,7 +25,7 @@ type NamedLiteral struct { } var _ sql.Expression = NamedLiteral{} -var _ sql.Expression2 = NamedLiteral{} +var _ sql.ValueExpression = NamedLiteral{} var _ sql.CollationCoercible = NamedLiteral{} // NewNamedLiteral returns a new NamedLiteral. diff --git a/sql/expression/sort.go b/sql/expression/sort.go index d54d13ea77..6b75031677 100644 --- a/sql/expression/sort.go +++ b/sql/expression/sort.go @@ -86,72 +86,22 @@ func (s *Sorter) Less(i, j int) bool { return false } -// Sorter2 is a version of Sorter that operates on Row2 -type Sorter2 struct { +// ValueRowSorter is a version of Sorter that operates on ValueRow +type ValueRowSorter struct { LastError error Ctx *sql.Context SortFields []sql.SortField - Rows []sql.Row2 + Rows []sql.ValueRow } -func (s *Sorter2) Len() int { +func (s *ValueRowSorter) Len() int { return len(s.Rows) } -func (s *Sorter2) Swap(i, j int) { +func (s *ValueRowSorter) Swap(i, j int) { s.Rows[i], s.Rows[j] = s.Rows[j], s.Rows[i] } -func (s *Sorter2) Less(i, j int) bool { - if s.LastError != nil { - return false - } - - a := s.Rows[i] - b := s.Rows[j] - for _, sf := range s.SortFields { - typ := sf.Column2.Type2() - av, err := sf.Column2.Eval2(s.Ctx, a) - if err != nil { - s.LastError = sql.ErrUnableSort.Wrap(err) - return false - } - - bv, err := sf.Column2.Eval2(s.Ctx, b) - if err != nil { - s.LastError = sql.ErrUnableSort.Wrap(err) - return false - } - - if sf.Order == sql.Descending { - av, bv = bv, av - } - - if av.IsNull() && bv.IsNull() { - continue - } else if av.IsNull() { - return sf.NullOrdering == sql.NullsFirst - } else if bv.IsNull() { - return sf.NullOrdering != sql.NullsFirst - } - - cmp, err := typ.Compare2(av, bv) - if err != nil { - s.LastError = err - return false - } - - switch cmp { - case -1: - return true - case 1: - return false - } - } - - return false -} - // TopRowsHeap implements heap.Interface based on Sorter. It inverts the Less() // function so that it can be used to implement TopN. heap.Push() rows into it, // and if Len() > MAX; heap.Pop() the current min row. Then, at the end of diff --git a/sql/expression/unresolved.go b/sql/expression/unresolved.go index 78e2e9d0b9..a18c7ccf73 100644 --- a/sql/expression/unresolved.go +++ b/sql/expression/unresolved.go @@ -32,7 +32,7 @@ type UnresolvedColumn struct { } var _ sql.Expression = (*UnresolvedColumn)(nil) -var _ sql.Expression2 = (*UnresolvedColumn)(nil) +var _ sql.ValueExpression = (*UnresolvedColumn)(nil) var _ sql.CollationCoercible = (*UnresolvedColumn)(nil) // NewUnresolvedColumn creates a new UnresolvedColumn expression. @@ -71,12 +71,14 @@ func (*UnresolvedColumn) CollationCoercibility(ctx *sql.Context) (collation sql. return sql.Collation_binary, 7 } -func (uc *UnresolvedColumn) Eval2(ctx *sql.Context, row sql.Row2) (sql.Value, error) { - panic("unresolved column is a placeholder node, but Eval2 was called") +// EvalValue implements the sql.ValueExpression interface. +func (uc *UnresolvedColumn) EvalValue(ctx *sql.Context, row sql.ValueRow) (sql.Value, error) { + panic("unresolved column is a placeholder node, but EvalValue was called") } -func (uc *UnresolvedColumn) Type2() sql.Type2 { - panic("unresolved column is a placeholder node, but Type2 was called") +// IsValueRowIter implements the ValueExpression interface. +func (uc *UnresolvedColumn) IsValueExpression() bool { + panic("unresolved column is a placeholder node, but IsValueExpression was called") } // Name implements the Nameable interface. diff --git a/sql/memory.go b/sql/memory.go index 651ed9e68f..1bc49bab1b 100644 --- a/sql/memory.go +++ b/sql/memory.go @@ -64,15 +64,15 @@ type RowsCache interface { Get() []Row } -// Rows2Cache is a cache of Row2s. -type Rows2Cache interface { +// ValueRowsCache is a cache of ValueRows. +type ValueRowsCache interface { RowsCache - // Add2 a new row to the cache. If there is no memory available, it will try to + // AddValueRow a new row to the cache. If there is no memory available, it will try to // free some memory. If after that there is still no memory available, it // will return an error and erase all the content of the cache. - Add2(Row2) error - // Get2 gets all rows. - Get2() []Row2 + AddValueRow(ValueRow) error + // GetValueRow gets all rows. + GetValueRow() []ValueRow } // ErrNoMemoryAvailable is returned when there is no more available memory. @@ -200,7 +200,7 @@ func (m *MemoryManager) NewRowsCache() (RowsCache, DisposeFunc) { // NewRowsCache returns an empty rows cache and a function to dispose it when it's // no longer needed. -func (m *MemoryManager) NewRows2Cache() (Rows2Cache, DisposeFunc) { +func (m *MemoryManager) NewRows2Cache() (ValueRowsCache, DisposeFunc) { c := newRowsCache(m, m.reporter) pos := m.addCache(c) return c, func() { diff --git a/sql/plan/filter.go b/sql/plan/filter.go index f2c0691112..4e78b50500 100644 --- a/sql/plan/filter.go +++ b/sql/plan/filter.go @@ -106,6 +106,9 @@ type FilterIter struct { childIter sql.RowIter } +var _ sql.RowIter = (*FilterIter)(nil) +var _ sql.ValueRowIter = (*FilterIter)(nil) + // NewFilterIter creates a new FilterIter. func NewFilterIter( cond sql.Expression, @@ -133,6 +136,36 @@ func (i *FilterIter) Next(ctx *sql.Context) (sql.Row, error) { } } +// NextValueRow implements the sql.ValueRowIter interface. +func (i *FilterIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { + for { + row, err := i.childIter.(sql.ValueRowIter).NextValueRow(ctx) + if err != nil { + return nil, err + } + res, err := i.cond.(sql.ValueExpression).EvalValue(ctx, row) + if err != nil { + return nil, err + } + if res.Val[0] == 1 { + return row, nil + } + } +} + +// IsValueRowIter implements the sql.ValueRowIter interface. +func (i *FilterIter) IsValueRowIter(ctx *sql.Context) bool { + cond, ok := i.cond.(sql.ValueExpression) + if !ok || !cond.IsValueExpression() { + return false + } + childIter, ok := i.childIter.(sql.ValueRowIter) + if !ok || !childIter.IsValueRowIter(ctx) { + return false + } + return true +} + // Close implements the RowIter interface. func (i *FilterIter) Close(ctx *sql.Context) error { return i.childIter.Close(ctx) diff --git a/sql/plan/indexed_table_access.go b/sql/plan/indexed_table_access.go index 7a708986b7..efb54cad54 100644 --- a/sql/plan/indexed_table_access.go +++ b/sql/plan/indexed_table_access.go @@ -307,13 +307,13 @@ func (i *IndexedTableAccess) GetLookup(ctx *sql.Context, row sql.Row) (sql.Index return i.lb.GetLookup(ctx, key) } -func (i *IndexedTableAccess) getLookup2(ctx *sql.Context, row sql.Row2) (sql.IndexLookup, error) { +func (i *IndexedTableAccess) getValueLookup(ctx *sql.Context, row sql.ValueRow) (sql.IndexLookup, error) { // if the lookup was provided at analysis time (static evaluation), use it. if !i.lookup.IsEmpty() { return i.lookup, nil } - key, err := i.lb.GetKey2(ctx, row) + key, err := i.lb.GetValueRowKey(ctx, row) if err != nil { return sql.IndexLookup{}, err } @@ -500,9 +500,9 @@ type lookupBuilderKey []interface{} // IndexedTableAccess nodes below an indexed join, for example. This struct is // also used to implement Expressioner on the IndexedTableAccess node. type LookupBuilder struct { - index sql.Index - keyExprs []sql.Expression - keyExprs2 []sql.Expression2 + index sql.Index + keyExprs []sql.Expression + keyValExprs []sql.ValueExpression // When building the lookup, we will use an MySQLIndexBuilder. If the // extracted lookup value is NULL, but we have a non-NULL safe // comparison, then the lookup should return no values. But if the @@ -636,13 +636,13 @@ func (lb *LookupBuilder) GetKey(ctx *sql.Context, row sql.Row) (lookupBuilderKey return lb.key, nil } -func (lb *LookupBuilder) GetKey2(ctx *sql.Context, row sql.Row2) (lookupBuilderKey, error) { +func (lb *LookupBuilder) GetValueRowKey(ctx *sql.Context, row sql.ValueRow) (lookupBuilderKey, error) { if lb.key == nil { lb.key = make([]interface{}, len(lb.keyExprs)) } for i := range lb.keyExprs { var err error - lb.key[i], err = lb.keyExprs2[i].Eval2(ctx, row) + lb.key[i], err = lb.keyValExprs[i].EvalValue(ctx, row) if err != nil { return nil, err } diff --git a/sql/plan/process.go b/sql/plan/process.go index ee95249f10..adcf3cdf68 100644 --- a/sql/plan/process.go +++ b/sql/plan/process.go @@ -317,6 +317,25 @@ func (i *TrackedRowIter) Next(ctx *sql.Context) (sql.Row, error) { return row, nil } +// NextValueRow implements the sql.ValueRowIter interface. +func (i *TrackedRowIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { + row, err := i.iter.(sql.ValueRowIter).NextValueRow(ctx) + if err != nil { + return nil, err + } + i.numRows++ + if i.onNext != nil { + i.onNext() + } + return row, nil +} + +// IsValueRowIter implements the sql.ValueRowIter interface. +func (i *TrackedRowIter) IsValueRowIter(ctx *sql.Context) bool { + iter, ok := i.iter.(sql.ValueRowIter) + return ok && iter.IsValueRowIter(ctx) +} + func (i *TrackedRowIter) Close(ctx *sql.Context) error { err := i.iter.Close(ctx) diff --git a/sql/planbuilder/parse_old_test.go b/sql/planbuilder/parse_old_test.go index 01c79e9e78..66906fea55 100644 --- a/sql/planbuilder/parse_old_test.go +++ b/sql/planbuilder/parse_old_test.go @@ -1805,7 +1805,7 @@ func TestParse(t *testing.T) { // []sql.SortField{ // { // Column: expression.NewUnresolvedColumn("baz"), - // Column2: expression.NewUnresolvedColumn("baz"), + // ValueExprColumn: expression.NewUnresolvedColumn("baz"), // Order: sql.Descending, // NullOrdering: sql.NullsFirst, // }, @@ -1844,7 +1844,7 @@ func TestParse(t *testing.T) { // []sql.SortField{ // { // Column: expression.NewUnresolvedColumn("baz"), - // Column2: expression.NewUnresolvedColumn("baz"), + // ValueExprColumn: expression.NewUnresolvedColumn("baz"), // Order: sql.Descending, // NullOrdering: sql.NullsFirst, // }, @@ -1866,7 +1866,7 @@ func TestParse(t *testing.T) { // []sql.SortField{ // { // Column: expression.NewUnresolvedColumn("baz"), - // Column2: expression.NewUnresolvedColumn("baz"), + // ValueExprColumn: expression.NewUnresolvedColumn("baz"), // Order: sql.Descending, // NullOrdering: sql.NullsFirst, // }, @@ -2634,13 +2634,13 @@ func TestParse(t *testing.T) { // []sql.SortField{ // { // Column: expression.NewLiteral(int8(2), types.Int8), - // Column2: expression.NewLiteral(int8(2), types.Int8), + // ValueExprColumn: expression.NewLiteral(int8(2), types.Int8), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, // { // Column: expression.NewLiteral(int8(1), types.Int8), - // Column2: expression.NewLiteral(int8(1), types.Int8), + // ValueExprColumn: expression.NewLiteral(int8(1), types.Int8), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -3863,7 +3863,7 @@ func TestParse(t *testing.T) { // }, sql.SortFields{ // { // Column: expression.NewUnresolvedColumn("x"), - // Column2: expression.NewUnresolvedColumn("x"), + // ValueExprColumn: expression.NewUnresolvedColumn("x"), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -3896,7 +3896,7 @@ func TestParse(t *testing.T) { // expression.NewUnresolvedFunction("row_number", true, sql.NewWindowDefinition([]sql.Expression{}, sql.SortFields{ // { // Column: expression.NewUnresolvedColumn("x"), - // Column2: expression.NewUnresolvedColumn("x"), + // ValueExprColumn: expression.NewUnresolvedColumn("x"), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -3920,7 +3920,7 @@ func TestParse(t *testing.T) { // expression.NewUnresolvedFunction("row_number", true, sql.NewWindowDefinition([]sql.Expression{}, sql.SortFields{ // { // Column: expression.NewUnresolvedColumn("x"), - // Column2: expression.NewUnresolvedColumn("x"), + // ValueExprColumn: expression.NewUnresolvedColumn("x"), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -3986,7 +3986,7 @@ func TestParse(t *testing.T) { // expression.NewUnresolvedFunction("count", true, sql.NewWindowDefinition([]sql.Expression{}, sql.SortFields{ // { // Column: expression.NewUnresolvedColumn("x"), - // Column2: expression.NewUnresolvedColumn("x"), + // ValueExprColumn: expression.NewUnresolvedColumn("x"), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -4023,7 +4023,7 @@ func TestParse(t *testing.T) { // expression.NewUnresolvedFunction("row_number", true, sql.NewWindowDefinition([]sql.Expression{}, sql.SortFields{ // { // Column: expression.NewUnresolvedColumn("a"), - // Column2: expression.NewUnresolvedColumn("a"), + // ValueExprColumn: expression.NewUnresolvedColumn("a"), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -4299,7 +4299,7 @@ func TestParse(t *testing.T) { // "w1": sql.NewWindowDefinition([]sql.Expression{}, sql.SortFields{ // { // Column: expression.NewUnresolvedColumn("x"), - // Column2: expression.NewUnresolvedColumn("x"), + // ValueExprColumn: expression.NewUnresolvedColumn("x"), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -4329,7 +4329,7 @@ func TestParse(t *testing.T) { // "w1": sql.NewWindowDefinition([]sql.Expression{}, sql.SortFields{ // { // Column: expression.NewUnresolvedColumn("x"), - // Column2: expression.NewUnresolvedColumn("x"), + // ValueExprColumn: expression.NewUnresolvedColumn("x"), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, @@ -4838,7 +4838,7 @@ func TestParse(t *testing.T) { // ), true, nil, nil, []sql.SortField{ // { // Column: expression.NewLiteral(int8(2), types.Int8), - // Column2: expression.NewLiteral(int8(2), types.Int8), + // ValueExprColumn: expression.NewLiteral(int8(2), types.Int8), // Order: sql.Ascending, // NullOrdering: sql.NullsFirst, // }, diff --git a/sql/rowexec/transaction_iters.go b/sql/rowexec/transaction_iters.go index db69cf5327..99b0041436 100644 --- a/sql/rowexec/transaction_iters.go +++ b/sql/rowexec/transaction_iters.go @@ -99,6 +99,17 @@ func (t *TransactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) { return t.childIter.Next(ctx) } +// NextValueRow implements the sql.ValueRowIter interface. +func (t *TransactionCommittingIter) NextValueRow(ctx *sql.Context) (sql.ValueRow, error) { + return t.childIter.(sql.ValueRowIter).NextValueRow(ctx) +} + +// IsValueRowIter implements the sql.ValueRowIter interface. +func (t *TransactionCommittingIter) IsValueRowIter(ctx *sql.Context) bool { + childIter, ok := t.childIter.(sql.ValueRowIter) + return ok && childIter.IsValueRowIter(ctx) +} + func (t *TransactionCommittingIter) Close(ctx *sql.Context) error { var err error if t.childIter != nil { diff --git a/sql/rows.go b/sql/rows.go index a9e5f55d5c..061727b7f3 100644 --- a/sql/rows.go +++ b/sql/rows.go @@ -18,10 +18,6 @@ import ( "fmt" "io" "strings" - - "github.com/dolthub/vitess/go/vt/proto/query" - - "github.com/dolthub/go-mysql-server/sql/values" ) // Row is a tuple of values. @@ -87,11 +83,21 @@ func FormatRow(row Row) string { // TODO: most row iters need to be Disposable for CachedResult safety type RowIter interface { // Next retrieves the next row. It will return io.EOF if it's the last row. - // After retrieving the last row, Close will be automatically closed. + // After retrieving the last row, Close will be automatically called. Next(ctx *Context) (Row, error) Closer } +// ValueRowIter is an iterator that produces sql.ValueRows. +type ValueRowIter interface { + // NextValueRow retrieves the next ValueRow. It will return io.EOF if it's the last ValueRow. + // After retrieving the last ValueRow, Close will be automatically called. + NextValueRow(ctx *Context) (ValueRow, error) + // IsValueRowIter checks whether this implementor and all its children support ValueRowIter. + IsValueRowIter(ctx *Context) bool + Closer +} + // RowIterToRows converts a row iterator to a slice of rows. func RowIterToRows(ctx *Context, i RowIter) ([]Row, error) { var rows []Row @@ -112,71 +118,6 @@ func RowIterToRows(ctx *Context, i RowIter) ([]Row, error) { return rows, i.Close(ctx) } -func rowFromRow2(sch Schema, r Row2) Row { - row := make(Row, len(sch)) - for i, col := range sch { - switch col.Type.Type() { - case query.Type_INT8: - row[i] = values.ReadInt8(r.GetField(i).Val) - case query.Type_UINT8: - row[i] = values.ReadUint8(r.GetField(i).Val) - case query.Type_INT16: - row[i] = values.ReadInt16(r.GetField(i).Val) - case query.Type_UINT16: - row[i] = values.ReadUint16(r.GetField(i).Val) - case query.Type_INT32: - row[i] = values.ReadInt32(r.GetField(i).Val) - case query.Type_UINT32: - row[i] = values.ReadUint32(r.GetField(i).Val) - case query.Type_INT64: - row[i] = values.ReadInt64(r.GetField(i).Val) - case query.Type_UINT64: - row[i] = values.ReadUint64(r.GetField(i).Val) - case query.Type_FLOAT32: - row[i] = values.ReadFloat32(r.GetField(i).Val) - case query.Type_FLOAT64: - row[i] = values.ReadFloat64(r.GetField(i).Val) - case query.Type_TEXT, query.Type_VARCHAR, query.Type_CHAR: - row[i] = values.ReadString(r.GetField(i).Val, values.ByteOrderCollation) - case query.Type_BLOB, query.Type_VARBINARY, query.Type_BINARY: - row[i] = values.ReadBytes(r.GetField(i).Val, values.ByteOrderCollation) - case query.Type_BIT: - fallthrough - case query.Type_ENUM: - fallthrough - case query.Type_SET: - fallthrough - case query.Type_TUPLE: - fallthrough - case query.Type_GEOMETRY: - fallthrough - case query.Type_JSON: - fallthrough - case query.Type_EXPRESSION: - fallthrough - case query.Type_INT24: - fallthrough - case query.Type_UINT24: - fallthrough - case query.Type_TIMESTAMP: - fallthrough - case query.Type_DATE: - fallthrough - case query.Type_TIME: - fallthrough - case query.Type_DATETIME: - fallthrough - case query.Type_YEAR: - fallthrough - case query.Type_DECIMAL: - panic(fmt.Sprintf("Unimplemented type conversion: %T", col.Type)) - default: - panic(fmt.Sprintf("unknown type %T", col.Type)) - } - } - return row -} - // RowsToRowIter creates a RowIter that iterates over the given rows. func RowsToRowIter(rows ...Row) RowIter { return &sliceRowIter{rows: rows} diff --git a/sql/sort_field.go b/sql/sort_field.go index 02b844a07f..9cbcbbaef1 100644 --- a/sql/sort_field.go +++ b/sql/sort_field.go @@ -24,8 +24,8 @@ import ( type SortField struct { // Column to order by. Column Expression - // Column Expression2 to order by. This is always the same value as Column, but avoids a type cast - Column2 Expression2 + // Column ValueExpression to order by. This is always the same value as Column, but avoids a type cast + ValueExprColumn ValueExpression // Order type. Order SortOrder // NullOrdering defining how nulls will be ordered. @@ -50,12 +50,12 @@ func (sf SortFields) FromExpressions(exprs ...Expression) SortFields { } for i, expr := range exprs { - expr2, _ := expr.(Expression2) + valueExpr, _ := expr.(ValueExpression) fields[i] = SortField{ - Column: expr, - Column2: expr2, - NullOrdering: sf[i].NullOrdering, - Order: sf[i].Order, + Column: expr, + ValueExprColumn: valueExpr, + NullOrdering: sf[i].NullOrdering, + Order: sf[i].Order, } } return fields diff --git a/sql/table_iter.go b/sql/table_iter.go index e302d5428a..1c82ac05ad 100644 --- a/sql/table_iter.go +++ b/sql/table_iter.go @@ -24,6 +24,7 @@ type TableRowIter struct { partitions PartitionIter partition Partition rows RowIter + valueRows ValueRowIter } var _ RowIter = (*TableRowIter)(nil) @@ -76,6 +77,70 @@ func (i *TableRowIter) Next(ctx *Context) (Row, error) { return row, err } +// NextValueRow implements the sql.ValueRowIter interface +func (i *TableRowIter) NextValueRow(ctx *Context) (ValueRow, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + if i.partition == nil { + partition, err := i.partitions.Next(ctx) + if err != nil { + if err == io.EOF { + if e := i.partitions.Close(ctx); e != nil { + return nil, e + } + } + return nil, err + } + i.partition = partition + } + + if i.valueRows == nil { + rows, err := i.table.PartitionRows(ctx, i.partition) + if err != nil { + return nil, err + } + i.valueRows = rows.(ValueRowIter) + } + + row, err := i.valueRows.NextValueRow(ctx) + if err != nil && err == io.EOF { + if err = i.valueRows.Close(ctx); err != nil { + return nil, err + } + i.partition = nil + i.valueRows = nil + row, err = i.NextValueRow(ctx) + } + return row, err +} + +// IsValueRowIter implements the sql.ValueRowIter interface. +func (i *TableRowIter) IsValueRowIter(ctx *Context) bool { + if i.partition == nil { + partition, err := i.partitions.Next(ctx) + if err != nil { + return false + } + i.partition = partition + } + if i.valueRows == nil { + rows, err := i.table.PartitionRows(ctx, i.partition) + if err != nil { + return false + } + valRowIter, ok := rows.(ValueRowIter) + if !ok { + return false + } + i.valueRows = valRowIter + } + return i.valueRows.IsValueRowIter(ctx) +} + func (i *TableRowIter) Close(ctx *Context) error { if i.rows != nil { if err := i.rows.Close(ctx); err != nil { @@ -83,5 +148,11 @@ func (i *TableRowIter) Close(ctx *Context) error { return err } } + if i.valueRows != nil { + if err := i.valueRows.Close(ctx); err != nil { + _ = i.partitions.Close(ctx) + return err + } + } return i.partitions.Close(ctx) } diff --git a/sql/type.go b/sql/type.go index 59af5360f1..3a36255839 100644 --- a/sql/type.go +++ b/sql/type.go @@ -105,6 +105,18 @@ type Type interface { fmt.Stringer } +// ValueType is an extension of the Type interface, that operates over sql.Values. +type ValueType interface { + Type + // CompareValue returns an integer comparing two sql.Values. + // The result will be 0 if a == b, -1 if a < b, and +1 if a > b. + CompareValue(*Context, Value, Value) (int, error) + // SQLValue returns the sqltypes.Value for the given sql.Value. + // Implementations can optionally use |dest| to append + // serialized data, but should not mutate existing data. + SQLValue(*Context, Value, []byte) (sqltypes.Value, error) +} + // TrimStringToNumberPrefix will remove any white space for s and truncate any trailing non-numeric characters. func TrimStringToNumberPrefix(ctx *Context, s string, isInt bool) string { if isInt { @@ -292,19 +304,6 @@ func IsDecimalType(t Type) bool { return ok } -type Type2 interface { - Type - - // Compare2 returns an integer comparing two Values. - Compare2(Value, Value) (int, error) - // Convert2 converts a value of a compatible type. - Convert2(Value) (Value, error) - // Zero2 returns the zero Value for this type. - Zero2() Value - // SQL2 returns the sqltypes.Value for the given value - SQL2(Value) (sqltypes.Value, error) -} - // SpatialColumnType is a node that contains a reference to all spatial types. type SpatialColumnType interface { // GetSpatialTypeSRID returns the SRID value for spatial types. diff --git a/sql/types/bit.go b/sql/types/bit.go index 7f9ef77d95..50a5b71cc3 100644 --- a/sql/types/bit.go +++ b/sql/types/bit.go @@ -103,6 +103,31 @@ func (t BitType_) Compare(ctx context.Context, a interface{}, b interface{}) (in return 0, nil } +// CompareValue implements the ValueType interface +func (t BitType_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + if hasNulls, res := CompareNullValues(a, b); hasNulls { + return res, nil + } + + av, _, err := convertValueToUint64(ctx, a) + if err != nil { + return 0, err + } + bv, _, err := convertValueToUint64(ctx, b) + if err != nil { + return 0, err + } + + switch { + case av < bv: + return -1, nil + case av > bv: + return 1, nil + default: + return 0, nil + } +} + // Convert implements Type interface. func (t BitType_) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { if v == nil { @@ -211,6 +236,30 @@ func (t BitType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va return sqltypes.MakeTrusted(sqltypes.Bit, val), nil } +// SQLValue implements ValueType interface. +func (t BitType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + + // Trim/Pad result to the appropriate length + numBytes := t.numOfBits / 8 + if t.numOfBits%8 != 0 { + numBytes += 1 + } + for i := uint8(len(v.Val)); i < numBytes; i++ { + v.Val = append(v.Val, 0) + } + v.Val = v.Val[:numBytes] + + // want the results in big endian + dest = append(dest, v.Val...) + for i, j := 0, len(dest)-1; i < j; i, j = i+1, j-1 { + dest[i], dest[j] = dest[j], dest[i] + } + return sqltypes.MakeTrusted(sqltypes.Bit, dest), nil +} + // String implements Type interface. func (t BitType_) String() string { return fmt.Sprintf("bit(%v)", t.numOfBits) diff --git a/sql/types/conversion.go b/sql/types/conversion.go index 2e5ca52748..0006f94cc1 100644 --- a/sql/types/conversion.go +++ b/sql/types/conversion.go @@ -472,6 +472,24 @@ func CompareNulls(a interface{}, b interface{}) (bool, int) { return false, 0 } +// CompareNullValues compares two sql.Values, and returns true if either is null. +// The returned integer represents the ordering, with a rule that states nulls +// as being ordered before non-nulls. +func CompareNullValues(a, b sql.Value) (bool, int) { + aIsNull := a.IsNull() + bIsNull := b.IsNull() + switch { + case aIsNull && bIsNull: + return true, 0 + case aIsNull && !bIsNull: + return false, 1 + case !aIsNull && bIsNull: + return false, -1 + default: + return false, 0 + } +} + // NumColumns returns the number of columns in a type. This is one for all // types, except tuples. func NumColumns(t sql.Type) int { diff --git a/sql/types/datetime.go b/sql/types/datetime.go index 387956f0fb..c10dc759fb 100644 --- a/sql/types/datetime.go +++ b/sql/types/datetime.go @@ -28,6 +28,7 @@ import ( "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/values" ) const ZeroDateStr = "0000-00-00" @@ -188,6 +189,11 @@ func (t datetimeType) Compare(ctx context.Context, a interface{}, b interface{}) return 0, nil } +// CompareValue implements the ValueType interface +func (t datetimeType) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + panic("TODO: implement CompareValue for DatetimeType") +} + // Convert implements Type interface. func (t datetimeType) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { if v == nil { @@ -273,14 +279,14 @@ func (t datetimeType) ConvertWithoutRangeCheck(ctx context.Context, v interface{ } // TODO: consider not using time.Parse if we want to match MySQL exactly ('2010-06-03 11:22.:.:.:.:' is a valid timestamp) var parsed bool - res, parsed, err = t.parseDatetime(value) + res, parsed, err = parseDatetime(value) if !parsed { return zeroTime, ErrConvertingToTime.New(v) } case time.Time: res = value.UTC() - // For most integer values, we just return an error (but MySQL is more lenient for some of these). A special case - // is zero values, which are important when converting from postgres defaults. + // For most integer values, we just return an error (but MySQL is more lenient for some of these). A special case + // is zero values, which are important when converting from postgres defaults. case int: if value == 0 { return zeroTime, nil @@ -370,7 +376,7 @@ func (t datetimeType) ConvertWithoutRangeCheck(ctx context.Context, v interface{ return res, err } -func (t datetimeType) parseDatetime(value string) (time.Time, bool, error) { +func parseDatetime(value string) (time.Time, bool, error) { if t, err := time.Parse(TimezoneTimestampDatetimeLayout, value); err == nil { return t.UTC(), true, nil } @@ -474,6 +480,25 @@ func (t datetimeType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltype return sqltypes.MakeTrusted(typ, valBytes), nil } +// SQLValue implements the ValueType interface. +func (t datetimeType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + switch t.baseType { + case sqltypes.Date: + t := values.ReadDate(v.Val) + dest = t.AppendFormat(dest, sql.DateLayout) + case sqltypes.Datetime, sqltypes.Timestamp: + x := values.ReadInt64(v.Val) + t := time.UnixMicro(x).UTC() + dest = t.AppendFormat(dest, sql.TimestampDatetimeLayout) + default: + return sqltypes.Value{}, sql.ErrInvalidBaseType.New(t.baseType.String(), "datetime") + } + return sqltypes.MakeTrusted(t.baseType, dest), nil +} + func (t datetimeType) String() string { switch t.baseType { case sqltypes.Date: diff --git a/sql/types/decimal.go b/sql/types/decimal.go index ccfa6eb321..ea51479d78 100644 --- a/sql/types/decimal.go +++ b/sql/types/decimal.go @@ -27,6 +27,7 @@ import ( "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/values" ) const ( @@ -138,6 +139,22 @@ func (t DecimalType_) Compare(s context.Context, a interface{}, b interface{}) ( return af.Decimal.Cmp(bf.Decimal), nil } +// CompareValue implements the ValueType interface +func (t DecimalType_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + if hasNulls, res := CompareNullValues(a, b); hasNulls { + return res, nil + } + aDec, err := convertValueToDecimal(ctx, a) + if err != nil { + return 0, err + } + bDec, err := convertValueToDecimal(ctx, b) + if err != nil { + return 0, err + } + return aDec.Cmp(bDec), nil +} + // Convert implements Type interface. func (t DecimalType_) Convert(c context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { dec, err := t.ConvertToNullDecimal(v) @@ -199,7 +216,7 @@ func (t DecimalType_) ConvertToNullDecimal(v interface{}) (decimal.NullDecimal, case int64: return t.ConvertToNullDecimal(decimal.NewFromInt(value)) case uint64: - return t.ConvertToNullDecimal(decimal.NewFromBigInt(new(big.Int).SetUint64(value), 0)) + return t.ConvertToNullDecimal(decimal.NewFromUint64(value)) case float32: return t.ConvertToNullDecimal(decimal.NewFromFloat32(value)) case float64: @@ -329,6 +346,14 @@ func (t DecimalType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltype return sqltypes.MakeTrusted(sqltypes.Decimal, val), nil } +func (t DecimalType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + d := values.ReadDecimal(v.Val) + return sqltypes.MakeTrusted(sqltypes.Decimal, []byte(t.DecimalValueStringFixed(d))), nil +} + // String implements Type interface. func (t DecimalType_) String() string { return fmt.Sprintf("decimal(%v,%v)", t.precision, t.scale) @@ -385,3 +410,73 @@ func (t DecimalType_) DecimalValueStringFixed(v decimal.Decimal) string { return v.StringFixed(v.Exponent() * -1) } } + +func convertValueToDecimal(ctx *sql.Context, v sql.Value) (decimal.Decimal, error) { + switch v.Typ { + case sqltypes.Int8: + x := values.ReadInt8(v.Val) + return decimal.NewFromInt(int64(x)), nil + case sqltypes.Int16: + x := values.ReadInt16(v.Val) + return decimal.NewFromInt(int64(x)), nil + case sqltypes.Int32: + x := values.ReadInt32(v.Val) + return decimal.NewFromInt(int64(x)), nil + case sqltypes.Int64: + x := values.ReadInt64(v.Val) + return decimal.NewFromInt(x), nil + case sqltypes.Uint8: + x := values.ReadUint8(v.Val) + return decimal.NewFromInt(int64(x)), nil + case sqltypes.Uint16: + x := values.ReadUint16(v.Val) + return decimal.NewFromInt(int64(x)), nil + case sqltypes.Uint32: + x := values.ReadUint32(v.Val) + return decimal.NewFromInt(int64(x)), nil + case sqltypes.Uint64: + x := values.ReadUint64(v.Val) + return decimal.NewFromUint64(x), nil + case sqltypes.Float32: + x := values.ReadFloat32(v.Val) + return decimal.NewFromFloat32(x), nil + case sqltypes.Float64: + x := values.ReadFloat64(v.Val) + return decimal.NewFromFloat(x), nil + case sqltypes.Decimal: + x := values.ReadDecimal(v.Val) + return x, nil + case sqltypes.Bit: + x := values.ReadUint64(v.Val) + return decimal.NewFromUint64(x), nil + case sqltypes.Year: + x := values.ReadUint16(v.Val) + return decimal.NewFromInt(int64(x)), nil + case sqltypes.Date: + x := values.ReadDate(v.Val) + s := x.UTC().Unix() + return decimal.NewFromInt(s), nil + case sqltypes.Time: + x := values.ReadInt64(v.Val) + return decimal.NewFromInt(x), nil + case sqltypes.Datetime, sqltypes.Timestamp: + x := values.ReadDatetime(v.Val) + return decimal.NewFromInt(x.UTC().Unix()), nil + case sqltypes.Text, sqltypes.Blob: + var err error + if v.Val == nil { + v.Val, err = v.WrappedVal.Unwrap(ctx) + if err != nil { + return decimal.Decimal{}, err + } + } + x := values.ReadString(v.Val) + res, err := decimal.NewFromString(x) + if err != nil { + return decimal.Decimal{}, err + } + return res, nil + default: + return decimal.Decimal{}, ErrConvertingToDecimal.New(v) + } +} diff --git a/sql/types/decimal_test.go b/sql/types/decimal_test.go index e39e3496b6..740d0a1dc0 100644 --- a/sql/types/decimal_test.go +++ b/sql/types/decimal_test.go @@ -16,18 +16,21 @@ package types import ( "context" + "encoding/binary" "fmt" + "math" "math/big" "reflect" "strings" "testing" "time" - "github.com/dolthub/go-mysql-server/sql" - + "github.com/dolthub/vitess/go/sqltypes" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/sql" ) func TestDecimalAccuracy(t *testing.T) { @@ -426,3 +429,455 @@ func TestDecimalZero(t *testing.T) { }) } } + +func TestConvertValueToDecimal(t *testing.T) { + ctx := sql.NewEmptyContext() + + zeroDec := serializeDecimal(decimal.Zero) + testDec := serializeDecimal(decimal.NewFromFloat(123.456)) + minInt64Dec := serializeDecimal(decimal.NewFromInt(math.MinInt64)) + maxInt64Dec := serializeDecimal(decimal.NewFromInt(math.MaxInt64)) + + tests := []struct { + val sql.Value + exp decimal.Decimal + err bool + }{ + // Int8 -> Decimal + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Int8, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Int8, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: []byte{127}, + Typ: sqltypes.Int8, + }, + exp: decimal.NewFromInt(127), + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Int8, + }, + exp: decimal.NewFromInt(-1), + }, + + // Int16 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Int16, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Int16, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Int16, + }, + exp: decimal.NewFromInt(math.MaxInt16), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16+1), + Typ: sqltypes.Int16, + }, + exp: decimal.NewFromInt(math.MinInt16), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Int16, + }, + exp: decimal.NewFromInt(-1), + }, + + // Int32 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Int32, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Int32, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Int32, + }, + exp: decimal.NewFromInt(math.MaxInt32), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32+1), + Typ: sqltypes.Int32, + }, + exp: decimal.NewFromInt(math.MinInt32), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Int32, + }, + exp: decimal.NewFromInt(-1), + }, + + // Int64 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Int64, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Int64, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Int64, + }, + exp: decimal.NewFromInt(math.MaxInt64), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64+1), + Typ: sqltypes.Int64, + }, + exp: decimal.NewFromInt(math.MinInt64), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Int64, + }, + exp: decimal.NewFromInt(-1), + }, + + // Uint8 -> Decimal + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Uint8, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Uint8, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: []byte{128}, + Typ: sqltypes.Uint8, + }, + exp: decimal.NewFromInt(128), + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Uint8, + }, + exp: decimal.NewFromInt(255), + }, + + // Uint16 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Uint16, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Uint16, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Uint16, + }, + exp: decimal.NewFromInt(math.MaxInt16), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Uint16, + }, + exp: decimal.NewFromInt(math.MaxUint16), + }, + + // Uint32 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Uint32, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Uint32, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Uint32, + }, + exp: decimal.NewFromInt(math.MaxInt32), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Uint32, + }, + exp: decimal.NewFromInt(math.MaxUint32), + }, + + // Uint64 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Uint64, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Uint64, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Uint64, + }, + exp: decimal.NewFromInt(math.MaxInt64), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Uint64, + }, + exp: decimal.NewFromUint64(math.MaxUint64), + }, + + // Float32 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(0)), + Typ: sqltypes.Float32, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(123.456)), + Typ: sqltypes.Float32, + }, + exp: decimal.NewFromFloat32(123.456), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(-math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: decimal.NewFromFloat32(-math.MaxFloat32), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: decimal.NewFromFloat32(math.MaxFloat32), + }, + + // Float64 -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(0)), + Typ: sqltypes.Float64, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(123.456)), + Typ: sqltypes.Float64, + }, + exp: decimal.NewFromFloat(123.456), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: decimal.NewFromFloat(-math.MaxFloat32), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: decimal.NewFromFloat(math.MaxFloat32), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: decimal.NewFromFloat(-math.MaxFloat64), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: decimal.NewFromFloat(math.MaxFloat64), + }, + + // Decimal -> Decimal + { + val: sql.Value{ + Val: zeroDec, + Typ: sqltypes.Decimal, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: testDec, + Typ: sqltypes.Decimal, + }, + exp: decimal.NewFromFloat(123.456), + }, + { + val: sql.Value{ + Val: minInt64Dec, + Typ: sqltypes.Decimal, + }, + exp: decimal.NewFromInt(math.MinInt64), + }, + { + val: sql.Value{ + Val: maxInt64Dec, + Typ: sqltypes.Decimal, + }, + exp: decimal.NewFromInt(math.MaxInt64), + }, + + // Bit -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Bit, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Bit, + }, + exp: decimal.NewFromInt(67), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Bit, + }, + exp: decimal.NewFromInt(math.MaxInt64), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Bit, + }, + exp: decimal.NewFromUint64(math.MaxUint64), + }, + + // Year -> Decimal + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Year, + }, + exp: decimal.Zero, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1967)), + Typ: sqltypes.Year, + }, + exp: decimal.NewFromInt(1967), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1901)), + Typ: sqltypes.Year, + }, + exp: decimal.NewFromInt(1901), + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(2155)), + Typ: sqltypes.Year, + }, + exp: decimal.NewFromInt(2155), + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%v", test.val), func(t *testing.T) { + res, err := convertValueToDecimal(ctx, test.val) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.True(t, test.exp.Equal(res), fmt.Sprintf("%v != %v", test.exp, res)) + }) + } +} diff --git a/sql/types/enum.go b/sql/types/enum.go index 3dcfa27147..5db9cfde81 100644 --- a/sql/types/enum.go +++ b/sql/types/enum.go @@ -29,6 +29,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/encodings" + "github.com/dolthub/go-mysql-server/sql/values" ) const ( @@ -157,6 +158,11 @@ func (t EnumType) Compare(ctx context.Context, a interface{}, b interface{}) (in return 0, nil } +// CompareValue implements the ValueType interface +func (t EnumType) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + panic("TODO: implement CompareValue for EnumType") +} + // Convert implements Type interface. func (t EnumType) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { if v == nil { @@ -268,6 +274,33 @@ func (t EnumType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Va return sqltypes.MakeTrusted(sqltypes.Enum, val), nil } +// SQLValue implements the ValueType interface. +func (t EnumType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + + idx := values.ReadUint16(v.Val) + value, _ := t.At(int(idx)) + + charset := ctx.GetCharacterSetResults() + if charset == sql.CharacterSet_Unspecified || charset == sql.CharacterSet_binary { + charset = t.collation.CharacterSet() + } + + // TODO: write append style encoder + res, ok := charset.Encoder().Encode([]byte(value)) + if !ok { + if len(value) > 50 { + value = value[:50] + } + value = strings.ToValidUTF8(value, string(utf8.RuneError)) + return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(charset.Name(), utf8.ValidString(value), value) + } + + return sqltypes.MakeTrusted(sqltypes.Enum, res), nil +} + // String implements Type interface. func (t EnumType) String() string { return t.StringWithTableCollation(sql.Collation_Default) diff --git a/sql/types/number.go b/sql/types/number.go index e9ecfc04f7..e10e668b6e 100644 --- a/sql/types/number.go +++ b/sql/types/number.go @@ -102,7 +102,7 @@ type NumberTypeImpl_ struct { } var _ sql.Type = NumberTypeImpl_{} -var _ sql.Type2 = NumberTypeImpl_{} +var _ sql.ValueType = NumberTypeImpl_{} var _ sql.CollationCoercible = NumberTypeImpl_{} var _ sql.NumberType = NumberTypeImpl_{} var _ sql.RoundingNumberType = NumberTypeImpl_{} @@ -210,6 +210,67 @@ func (t NumberTypeImpl_) Compare(s context.Context, a interface{}, b interface{} } } +// CompareValue implements the ValueType interface +func (t NumberTypeImpl_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + if hasNulls, res := CompareNullValues(a, b); hasNulls { + return res, nil + } + + switch t.baseType { + case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: + ca, _, err := convertValueToUint64(ctx, a) + if err != nil { + return 0, err + } + cb, _, err := convertValueToUint64(ctx, b) + if err != nil { + return 0, err + } + + if ca == cb { + return 0, nil + } + if ca < cb { + return -1, nil + } + return +1, nil + case sqltypes.Float32, sqltypes.Float64: + ca, err := convertValueToFloat64(ctx, a) + if err != nil { + return 0, err + } + cb, err := convertValueToFloat64(ctx, b) + if err != nil { + return 0, err + } + + if ca == cb { + return 0, nil + } + if ca < cb { + return -1, nil + } + return +1, nil + default: + ca, _, err := convertValueToInt64(ctx, a) + if err != nil { + return 0, err + } + cb, _, err := convertValueToInt64(ctx, b) + if err != nil { + return 0, err + } + + if ca == cb { + return 0, nil + } + if ca < cb { + return -1, nil + } + return +1, nil + } +} + // Convert implements Type interface. func (t NumberTypeImpl_) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { var err error @@ -728,194 +789,54 @@ func (t NumberTypeImpl_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqlt return sqltypes.MakeTrusted(t.baseType, val), nil } -func (t NumberTypeImpl_) Compare2(a sql.Value, b sql.Value) (int, error) { - switch t.baseType { - case sqltypes.Uint8, sqltypes.Uint16, sqltypes.Uint24, sqltypes.Uint32, sqltypes.Uint64: - ca, err := convertValueToUint64(t, a) - if err != nil { - return 0, err - } - cb, err := convertValueToUint64(t, b) - if err != nil { - return 0, err - } - - if ca == cb { - return 0, nil - } - if ca < cb { - return -1, nil - } - return +1, nil - case sqltypes.Float32, sqltypes.Float64: - ca, err := convertValueToFloat64(t, a) - if err != nil { - return 0, err - } - cb, err := convertValueToFloat64(t, b) - if err != nil { - return 0, err - } - - if ca == cb { - return 0, nil - } - if ca < cb { - return -1, nil - } - return +1, nil - default: - ca, err := convertValueToInt64(t, a) - if err != nil { - return 0, err - } - cb, err := convertValueToInt64(t, b) - if err != nil { - return 0, err - } - - if ca == cb { - return 0, nil - } - if ca < cb { - return -1, nil - } - return +1, nil - } -} - -func (t NumberTypeImpl_) Convert2(value sql.Value) (sql.Value, error) { - panic("implement me") -} - -func (t NumberTypeImpl_) Zero2() sql.Value { - switch t.baseType { - case sqltypes.Int8: - x := values.WriteInt8(make([]byte, values.Int8Size), 0) - return sql.Value{ - Typ: query.Type_INT8, - Val: x, - } - case sqltypes.Int16: - x := values.WriteInt16(make([]byte, values.Int16Size), 0) - return sql.Value{ - Typ: query.Type_INT16, - Val: x, - } - case sqltypes.Int24: - x := values.WriteInt24(make([]byte, values.Int24Size), 0) - return sql.Value{ - Typ: query.Type_INT24, - Val: x, - } - case sqltypes.Int32: - x := values.WriteInt32(make([]byte, values.Int32Size), 0) - return sql.Value{ - Typ: query.Type_INT32, - Val: x, - } - case sqltypes.Int64: - x := values.WriteInt64(make([]byte, values.Int64Size), 0) - return sql.Value{ - Typ: query.Type_INT64, - Val: x, - } - case sqltypes.Uint8: - x := values.WriteUint8(make([]byte, values.Uint8Size), 0) - return sql.Value{ - Typ: query.Type_UINT8, - Val: x, - } - case sqltypes.Uint16: - x := values.WriteUint16(make([]byte, values.Uint16Size), 0) - return sql.Value{ - Typ: query.Type_UINT16, - Val: x, - } - case sqltypes.Uint24: - x := values.WriteUint24(make([]byte, values.Uint24Size), 0) - return sql.Value{ - Typ: query.Type_UINT24, - Val: x, - } - case sqltypes.Uint32: - x := values.WriteUint32(make([]byte, values.Uint32Size), 0) - return sql.Value{ - Typ: query.Type_UINT32, - Val: x, - } - case sqltypes.Uint64: - x := values.WriteUint64(make([]byte, values.Uint64Size), 0) - return sql.Value{ - Typ: query.Type_UINT64, - Val: x, - } - case sqltypes.Float32: - x := values.WriteFloat32(make([]byte, values.Float32Size), 0) - return sql.Value{ - Typ: query.Type_FLOAT32, - Val: x, - } - case sqltypes.Float64: - x := values.WriteUint64(make([]byte, values.Uint64Size), 0) - return sql.Value{ - Typ: query.Type_UINT64, - Val: x, - } - default: - panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) - } -} - -// SQL2 implements Type2 interface. -func (t NumberTypeImpl_) SQL2(v sql.Value) (sqltypes.Value, error) { +// SQLValue implements ValueType interface. +func (t NumberTypeImpl_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { if v.IsNull() { return sqltypes.NULL, nil } - var val []byte switch t.baseType { case sqltypes.Int8: x := values.ReadInt8(v.Val) - val = []byte(strconv.FormatInt(int64(x), 10)) + dest = strconv.AppendInt(dest, int64(x), 10) case sqltypes.Int16: x := values.ReadInt16(v.Val) - val = []byte(strconv.FormatInt(int64(x), 10)) + dest = strconv.AppendInt(dest, int64(x), 10) case sqltypes.Int24: x := values.ReadInt24(v.Val) - val = []byte(strconv.FormatInt(int64(x), 10)) + dest = strconv.AppendInt(dest, int64(x), 10) case sqltypes.Int32: x := values.ReadInt32(v.Val) - val = []byte(strconv.FormatInt(int64(x), 10)) + dest = strconv.AppendInt(dest, int64(x), 10) case sqltypes.Int64: x := values.ReadInt64(v.Val) - val = []byte(strconv.FormatInt(x, 10)) + dest = strconv.AppendInt(dest, x, 10) case sqltypes.Uint8: x := values.ReadUint8(v.Val) - val = []byte(strconv.FormatUint(uint64(x), 10)) + dest = strconv.AppendUint(dest, uint64(x), 10) case sqltypes.Uint16: x := values.ReadUint16(v.Val) - val = []byte(strconv.FormatUint(uint64(x), 10)) + dest = strconv.AppendUint(dest, uint64(x), 10) case sqltypes.Uint24: x := values.ReadUint24(v.Val) - val = []byte(strconv.FormatUint(uint64(x), 10)) + dest = strconv.AppendUint(dest, uint64(x), 10) case sqltypes.Uint32: x := values.ReadUint32(v.Val) - val = []byte(strconv.FormatUint(uint64(x), 10)) + dest = strconv.AppendUint(dest, uint64(x), 10) case sqltypes.Uint64: x := values.ReadUint64(v.Val) - val = []byte(strconv.FormatUint(x, 10)) + dest = strconv.AppendUint(dest, x, 10) case sqltypes.Float32: x := values.ReadFloat32(v.Val) - val = []byte(strconv.FormatFloat(float64(x), 'f', -1, 32)) + dest = strconv.AppendFloat(dest, float64(x), 'f', -1, 32) case sqltypes.Float64: x := values.ReadFloat64(v.Val) - val = []byte(strconv.FormatFloat(x, 'f', -1, 64)) + dest = strconv.AppendFloat(dest, x, 'f', -1, 64) default: panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) } - return sqltypes.MakeTrusted(t.baseType, val), nil + return sqltypes.MakeTrusted(t.baseType, dest), nil } // String implements Type interface. @@ -1152,103 +1073,15 @@ func convertToInt64(t NumberTypeImpl_, v any, round Round) (int64, sql.ConvertIn } } -func convertValueToInt64(t NumberTypeImpl_, v sql.Value) (int64, error) { - switch v.Typ { - case query.Type_INT8: - return int64(values.ReadInt8(v.Val)), nil - case query.Type_INT16: - return int64(values.ReadInt16(v.Val)), nil - case query.Type_INT24: - return int64(values.ReadInt24(v.Val)), nil - case query.Type_INT32: - return int64(values.ReadInt32(v.Val)), nil - case query.Type_INT64: - return values.ReadInt64(v.Val), nil - case query.Type_UINT8: - return int64(values.ReadUint8(v.Val)), nil - case query.Type_UINT16: - return int64(values.ReadUint16(v.Val)), nil - case query.Type_UINT24: - return int64(values.ReadUint24(v.Val)), nil - case query.Type_UINT32: - return int64(values.ReadUint32(v.Val)), nil - case query.Type_UINT64: - v := values.ReadUint64(v.Val) - if v > math.MaxInt64 { - return math.MaxInt64, nil - } - return int64(v), nil - case query.Type_FLOAT32: - v := values.ReadFloat32(v.Val) - if v > float32(math.MaxInt64) { - return math.MaxInt64, nil - } else if v < float32(math.MinInt64) { - return math.MinInt64, nil - } - return int64(math.Round(float64(v))), nil - case query.Type_FLOAT64: - v := values.ReadFloat64(v.Val) - if v > float64(math.MaxInt64) { - return math.MaxInt64, nil - } else if v < float64(math.MinInt64) { - return math.MinInt64, nil - } - return int64(math.Round(v)), nil - // TODO: add more conversions - default: - panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) - } -} - -func convertValueToUint64(t NumberTypeImpl_, v sql.Value) (uint64, error) { - switch v.Typ { - case query.Type_INT8: - return uint64(values.ReadInt8(v.Val)), nil - case query.Type_INT16: - return uint64(values.ReadInt16(v.Val)), nil - case query.Type_INT24: - return uint64(values.ReadInt24(v.Val)), nil - case query.Type_INT32: - return uint64(values.ReadInt32(v.Val)), nil - case query.Type_INT64: - return uint64(values.ReadInt64(v.Val)), nil - case query.Type_UINT8: - return uint64(values.ReadUint8(v.Val)), nil - case query.Type_UINT16: - return uint64(values.ReadUint16(v.Val)), nil - case query.Type_UINT24: - return uint64(values.ReadUint24(v.Val)), nil - case query.Type_UINT32: - return uint64(values.ReadUint32(v.Val)), nil - case query.Type_UINT64: - return values.ReadUint64(v.Val), nil - case query.Type_FLOAT32: - v := values.ReadFloat32(v.Val) - if v >= float32(math.MaxUint64) { - return math.MaxUint64, nil - } - return uint64(math.Round(float64(v))), nil - case query.Type_FLOAT64: - v := values.ReadFloat64(v.Val) - if v >= float64(math.MaxUint64) { - return math.MaxUint64, nil - } - return uint64(math.Round(v)), nil - // TODO: add more conversions - default: - panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) - } -} - func convertToUint64(t NumberTypeImpl_, v any, round Round) (uint64, sql.ConvertInRange, error) { switch v := v.(type) { case time.Time: return uint64(v.UTC().Unix()), sql.InRange, nil case int: if v < 0 { - return uint64(math.MaxUint64 - uint(-v-1)), sql.OutOfRange, nil + return uint64(v), sql.OutOfRange, nil } - return uint64(v), sql.InRange, nil + return uint64(v), v > 0, nil case int8: if v < 0 { return uint64(math.MaxUint64 - uint(-v-1)), sql.OutOfRange, nil @@ -1428,34 +1261,167 @@ func convertToFloat64(t NumberTypeImpl_, v interface{}) (float64, error) { } } -func convertValueToFloat64(t NumberTypeImpl_, v sql.Value) (float64, error) { +func convertValueToInt64(ctx *sql.Context, v sql.Value) (int64, sql.ConvertInRange, error) { switch v.Typ { - case query.Type_INT8: + case sqltypes.Int8: + return int64(values.ReadInt8(v.Val)), sql.InRange, nil + case sqltypes.Int16: + return int64(values.ReadInt16(v.Val)), sql.InRange, nil + case sqltypes.Int24: + return int64(values.ReadInt24(v.Val)), sql.InRange, nil + case sqltypes.Int32: + return int64(values.ReadInt32(v.Val)), sql.InRange, nil + case sqltypes.Int64: + return values.ReadInt64(v.Val), sql.InRange, nil + case sqltypes.Uint8: + return int64(values.ReadUint8(v.Val)), sql.InRange, nil + case sqltypes.Uint16: + return int64(values.ReadUint16(v.Val)), sql.InRange, nil + case sqltypes.Uint24: + return int64(values.ReadUint24(v.Val)), sql.InRange, nil + case sqltypes.Uint32: + return int64(values.ReadUint32(v.Val)), sql.InRange, nil + case sqltypes.Uint64: + x := values.ReadUint64(v.Val) + if x > math.MaxInt64 { + return math.MaxInt64, sql.OutOfRange, nil + } + return int64(x), sql.InRange, nil + case sqltypes.Float32: + x := values.ReadFloat32(v.Val) + if x > float32(math.MaxInt64) { + return math.MaxInt64, sql.OutOfRange, nil + } + if x < float32(math.MinInt64) { + return math.MinInt64, sql.OutOfRange, nil + } + return int64(math.Round(float64(x))), sql.InRange, nil + case sqltypes.Float64: + x := values.ReadFloat64(v.Val) + if x > float64(math.MaxInt64) { + return math.MaxInt64, sql.OutOfRange, nil + } + if x < float64(math.MinInt64) { + return math.MinInt64, sql.OutOfRange, nil + } + return int64(math.Round(x)), sql.InRange, nil + case sqltypes.Decimal: + x := values.ReadDecimal(v.Val) + if x.GreaterThan(dec_int64_max) { + return math.MaxInt64, sql.OutOfRange, nil + } + if x.LessThan(dec_int64_min) { + return math.MinInt64, sql.OutOfRange, nil + } + return x.Round(0).IntPart(), sql.InRange, nil + case sqltypes.Bit: + x := values.ReadUint64(v.Val) + if x > math.MaxInt64 { + return math.MaxInt64, sql.OutOfRange, nil + } + return int64(x), sql.InRange, nil + case sqltypes.Year: + return int64(values.ReadUint16(v.Val)), sql.InRange, nil + default: + return 0, sql.InRange, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") + } +} + +func convertValueToUint64(ctx *sql.Context, v sql.Value) (uint64, sql.ConvertInRange, error) { + switch v.Typ { + case sqltypes.Int8: + return uint64(values.ReadInt8(v.Val)), sql.InRange, nil + case sqltypes.Int16: + return uint64(values.ReadInt16(v.Val)), sql.InRange, nil + case sqltypes.Int24: + return uint64(values.ReadInt24(v.Val)), sql.InRange, nil + case sqltypes.Int32: + return uint64(values.ReadInt32(v.Val)), sql.InRange, nil + case sqltypes.Int64: + return uint64(values.ReadInt64(v.Val)), sql.InRange, nil + case sqltypes.Uint8: + return uint64(values.ReadUint8(v.Val)), sql.InRange, nil + case sqltypes.Uint16: + return uint64(values.ReadUint16(v.Val)), sql.InRange, nil + case sqltypes.Uint24: + return uint64(values.ReadUint24(v.Val)), sql.InRange, nil + case sqltypes.Uint32: + return uint64(values.ReadUint32(v.Val)), sql.InRange, nil + case sqltypes.Uint64: + return values.ReadUint64(v.Val), sql.InRange, nil + case sqltypes.Float32: + x := values.ReadFloat32(v.Val) + if x > float32(math.MaxUint64) { + return math.MaxUint64, sql.OutOfRange, nil + } + if x < 0 { + return uint64(x), sql.OutOfRange, nil + } + return uint64(math.Round(float64(x))), sql.InRange, nil + case sqltypes.Float64: + x := values.ReadFloat64(v.Val) + if x > float64(math.MaxUint64) { + return math.MaxUint64, sql.OutOfRange, nil + } + if x < 0 { + return uint64(x), sql.OutOfRange, nil + } + return uint64(math.Round(x)), sql.InRange, nil + case sqltypes.Decimal: + x := values.ReadDecimal(v.Val) + if x.GreaterThan(dec_uint64_max) { + return math.MaxUint64, sql.OutOfRange, nil + } + if x.LessThan(dec_zero) { + ret, _ := dec_uint64_max.Sub(x).Float64() + return uint64(math.Round(ret)), sql.OutOfRange, nil + } + return uint64(x.Round(0).IntPart()), sql.InRange, nil + case sqltypes.Bit: + return values.ReadUint64(v.Val), sql.InRange, nil + case sqltypes.Year: + return uint64(values.ReadUint16(v.Val)), sql.InRange, nil + default: + return 0, sql.InRange, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") + } +} + +func convertValueToFloat64(ctx *sql.Context, v sql.Value) (float64, error) { + switch v.Typ { + case sqltypes.Int8: return float64(values.ReadInt8(v.Val)), nil - case query.Type_INT16: + case sqltypes.Int16: return float64(values.ReadInt16(v.Val)), nil - case query.Type_INT24: + case sqltypes.Int24: return float64(values.ReadInt24(v.Val)), nil - case query.Type_INT32: + case sqltypes.Int32: return float64(values.ReadInt32(v.Val)), nil - case query.Type_INT64: + case sqltypes.Int64: return float64(values.ReadInt64(v.Val)), nil - case query.Type_UINT8: + case sqltypes.Uint8: return float64(values.ReadUint8(v.Val)), nil - case query.Type_UINT16: + case sqltypes.Uint16: return float64(values.ReadUint16(v.Val)), nil - case query.Type_UINT24: + case sqltypes.Uint24: return float64(values.ReadUint24(v.Val)), nil - case query.Type_UINT32: + case sqltypes.Uint32: return float64(values.ReadUint32(v.Val)), nil - case query.Type_UINT64: + case sqltypes.Uint64: return float64(values.ReadUint64(v.Val)), nil - case query.Type_FLOAT32: + case sqltypes.Float32: return float64(values.ReadFloat32(v.Val)), nil - case query.Type_FLOAT64: + case sqltypes.Float64: return values.ReadFloat64(v.Val), nil + case sqltypes.Decimal: + x := values.ReadDecimal(v.Val) + f, _ := x.Float64() + return f, nil + case sqltypes.Bit: + return float64(values.ReadUint64(v.Val)), nil + case sqltypes.Year: + return float64(values.ReadUint16(v.Val)), nil default: - panic(sql.ErrInvalidBaseType.New(t.baseType.String(), "number")) + return 0, sql.ErrInvalidBaseType.New(v.Typ.String(), "number") } } diff --git a/sql/types/number_test.go b/sql/types/number_test.go index 695f61fcc0..9980284e22 100644 --- a/sql/types/number_test.go +++ b/sql/types/number_test.go @@ -15,6 +15,7 @@ package types import ( + "encoding/binary" "fmt" "math" "reflect" @@ -24,6 +25,7 @@ import ( "github.com/dolthub/vitess/go/sqltypes" "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -711,3 +713,1458 @@ func TestTruncateStringToDouble(t *testing.T) { }) } } + +func serializeDecimal(dec decimal.Decimal) []byte { + var res []byte + coef := dec.Coefficient() + res = binary.LittleEndian.AppendUint32(res, uint32(dec.Exponent())) + res = append(res, byte(coef.Sign())) + res = append(res, coef.Bytes()...) + return res +} + +func TestConvertValueToInt64(t *testing.T) { + ctx := sql.NewEmptyContext() + + zeroDec := serializeDecimal(decimal.Zero) + testDec := serializeDecimal(decimal.NewFromFloat(123.456)) + minInt64Dec := serializeDecimal(decimal.NewFromInt(math.MinInt64)) + maxInt64Dec := serializeDecimal(decimal.NewFromInt(math.MaxInt64)) + + tests := []struct { + val sql.Value + exp int64 + rng sql.ConvertInRange + err bool + }{ + // Int8 -> Int64 + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Int8, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Int8, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: []byte{128}, + Typ: sqltypes.Int8, + }, + exp: -128, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Int8, + }, + exp: -1, + rng: sql.InRange, + }, + + // Int16 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Int16, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Int16, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16+1), + Typ: sqltypes.Int16, + }, + exp: math.MinInt16, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Int16, + }, + exp: math.MaxInt16, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Int16, + }, + exp: -1, + rng: sql.InRange, + }, + + // Int32 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Int32, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Int32, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32+1), + Typ: sqltypes.Int32, + }, + exp: math.MinInt32, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Int32, + }, + exp: math.MaxInt32, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Int32, + }, + exp: -1, + rng: sql.InRange, + }, + + // Int64 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Int64, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Int64, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64+1), + Typ: sqltypes.Int64, + }, + exp: math.MinInt64, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Int64, + }, + exp: math.MaxInt64, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Int64, + }, + exp: -1, + rng: sql.InRange, + }, + + // Uint8 -> Int64 + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Uint8, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Uint8, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: []byte{128}, + Typ: sqltypes.Uint8, + }, + exp: 128, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Uint8, + }, + exp: 255, + rng: sql.InRange, + }, + + // Uint16 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Uint16, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Uint16, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Uint16, + }, + exp: math.MaxInt16, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Uint16, + }, + exp: math.MaxUint16, + rng: sql.InRange, + }, + + // Uint32 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Uint32, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Uint32, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Uint32, + }, + exp: math.MaxInt32, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Uint32, + }, + exp: math.MaxUint32, + rng: sql.InRange, + }, + + // Uint64 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Uint64, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Uint64, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Uint64, + }, + exp: math.MaxInt64, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Uint64, + }, + exp: math.MaxInt64, + rng: sql.OutOfRange, + }, + + // Float32 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(0)), + Typ: sqltypes.Float32, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(123.456)), + Typ: sqltypes.Float32, + }, + exp: 123, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(-math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: math.MinInt64, + rng: sql.OutOfRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: math.MaxInt64, + rng: sql.OutOfRange, + }, + + // Float64 -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(0)), + Typ: sqltypes.Float64, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(123.456)), + Typ: sqltypes.Float64, + }, + exp: 123, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: math.MinInt64, + rng: sql.OutOfRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: math.MaxInt64, + rng: sql.OutOfRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: math.MinInt64, + rng: sql.OutOfRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: math.MaxInt64, + rng: sql.OutOfRange, + }, + + // Decimal -> Int64 + { + val: sql.Value{ + Val: zeroDec, + Typ: sqltypes.Decimal, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: testDec, + Typ: sqltypes.Decimal, + }, + exp: 123, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: minInt64Dec, + Typ: sqltypes.Decimal, + }, + exp: math.MinInt64, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: maxInt64Dec, + Typ: sqltypes.Decimal, + }, + exp: math.MaxInt64, + rng: sql.InRange, + }, + + // Bit -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Bit, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Bit, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Bit, + }, + exp: math.MaxInt64, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Bit, + }, + exp: math.MaxInt64, + rng: sql.OutOfRange, + }, + + // Year -> Int64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Year, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1967)), + Typ: sqltypes.Year, + }, + exp: 1967, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1901)), + Typ: sqltypes.Year, + }, + exp: 1901, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(2155)), + Typ: sqltypes.Year, + }, + exp: 2155, + rng: sql.InRange, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("%v", test.val), func(t *testing.T) { + res, rng, err := convertValueToInt64(ctx, test.val) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.exp, res) + require.Equal(t, test.rng, rng) + }) + } +} + +func TestConvertValueToUint64(t *testing.T) { + ctx := sql.NewEmptyContext() + + zeroDec := serializeDecimal(decimal.Zero) + testDec := serializeDecimal(decimal.NewFromFloat(123.456)) + maxInt64Dec := serializeDecimal(decimal.NewFromInt(math.MaxInt64)) + + tests := []struct { + val sql.Value + exp uint64 + rng sql.ConvertInRange + err bool + }{ + // Int8 -> Uint64 + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Int8, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Int8, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: []byte{127}, + Typ: sqltypes.Int8, + }, + exp: 127, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Int8, + }, + exp: math.MaxUint64, + rng: sql.InRange, + }, + + // Int16 -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Int16, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Int16, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Int16, + }, + exp: math.MaxInt16, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16+1), + Typ: sqltypes.Int16, + }, + exp: math.MaxUint64 - math.MaxInt16, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Int16, + }, + exp: math.MaxUint64, + rng: sql.InRange, + }, + + // Int32 -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Int32, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Int32, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Int32, + }, + exp: math.MaxInt32, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32+1), + Typ: sqltypes.Int32, + }, + exp: math.MaxUint64 - math.MaxInt32, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Int32, + }, + exp: math.MaxUint64, + rng: sql.InRange, + }, + + // Int64 -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Int64, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Int64, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Int64, + }, + exp: math.MaxInt64, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64+1), + Typ: sqltypes.Int64, + }, + exp: math.MaxInt64 + 1, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Int64, + }, + exp: math.MaxUint64, + rng: sql.InRange, + }, + + // Uint8 -> Uint64 + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Uint8, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Uint8, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: []byte{128}, + Typ: sqltypes.Uint8, + }, + exp: 128, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Uint8, + }, + exp: 255, + rng: sql.InRange, + }, + + // Uint16 -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Uint16, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Uint16, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Uint16, + }, + exp: math.MaxInt16, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Uint16, + }, + exp: math.MaxUint16, + rng: sql.InRange, + }, + + // Uint32 -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Uint32, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Uint32, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Uint32, + }, + exp: math.MaxInt32, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Uint32, + }, + exp: math.MaxUint32, + rng: sql.InRange, + }, + + // Uint64 -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Uint64, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Uint64, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Uint64, + }, + exp: math.MaxInt64, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Uint64, + }, + exp: math.MaxUint64, + rng: sql.InRange, + }, + + // Float32 -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(0)), + Typ: sqltypes.Float32, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(123.456)), + Typ: sqltypes.Float32, + }, + exp: 123, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: math.MaxUint64, + rng: sql.OutOfRange, + }, + + // Float64 -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(0)), + Typ: sqltypes.Float64, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(123.456)), + Typ: sqltypes.Float64, + }, + exp: 123, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: math.MaxUint64, + rng: sql.OutOfRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: math.MaxUint64, + rng: sql.OutOfRange, + }, + + // Decimal -> Uint64 + { + val: sql.Value{ + Val: zeroDec, + Typ: sqltypes.Decimal, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: testDec, + Typ: sqltypes.Decimal, + }, + exp: 123, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: maxInt64Dec, + Typ: sqltypes.Decimal, + }, + exp: math.MaxInt64, + rng: sql.InRange, + }, + + // Bit -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Bit, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Bit, + }, + exp: 67, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Bit, + }, + exp: math.MaxInt64, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Bit, + }, + exp: math.MaxUint64, + rng: sql.InRange, + }, + + // Year -> Uint64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Year, + }, + exp: 0, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1967)), + Typ: sqltypes.Year, + }, + exp: 1967, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1901)), + Typ: sqltypes.Year, + }, + exp: 1901, + rng: sql.InRange, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(2155)), + Typ: sqltypes.Year, + }, + exp: 2155, + rng: sql.InRange, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("Val: %v Typ: %v to UINT64", test.val.Val, test.val.Typ), func(t *testing.T) { + res, rng, err := convertValueToUint64(ctx, test.val) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, test.exp, res) + require.Equal(t, test.rng, rng) + }) + } +} + +func TestConvertValueToFloat64(t *testing.T) { + ctx := sql.NewEmptyContext() + + zeroDec := serializeDecimal(decimal.Zero) + testDec := serializeDecimal(decimal.NewFromFloat(123.456)) + minInt64Dec := serializeDecimal(decimal.NewFromInt(math.MinInt64)) + maxInt64Dec := serializeDecimal(decimal.NewFromInt(math.MaxInt64)) + + tests := []struct { + val sql.Value + exp float64 + err bool + }{ + // Int8 -> Float64 + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Int8, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Int8, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: []byte{127}, + Typ: sqltypes.Int8, + }, + exp: 127, + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Int8, + }, + exp: -1, + }, + + // Int16 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Int16, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Int16, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Int16, + }, + exp: math.MaxInt16, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16+1), + Typ: sqltypes.Int16, + }, + exp: math.MinInt16, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Int16, + }, + exp: -1, + }, + + // Int32 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Int32, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Int32, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Int32, + }, + exp: math.MaxInt32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32+1), + Typ: sqltypes.Int32, + }, + exp: math.MinInt32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Int32, + }, + exp: -1, + }, + + // Int64 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Int64, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Int64, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Int64, + }, + exp: math.MaxInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64+1), + Typ: sqltypes.Int64, + }, + exp: math.MinInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Int64, + }, + exp: -1, + }, + + // Uint8 -> Float64 + { + val: sql.Value{ + Val: []byte{0}, + Typ: sqltypes.Uint8, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: []byte{67}, + Typ: sqltypes.Uint8, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: []byte{128}, + Typ: sqltypes.Uint8, + }, + exp: 128, + }, + { + val: sql.Value{ + Val: []byte{255}, + Typ: sqltypes.Uint8, + }, + exp: 255, + }, + + // Uint16 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Uint16, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(67)), + Typ: sqltypes.Uint16, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxInt16), + Typ: sqltypes.Uint16, + }, + exp: math.MaxInt16, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, math.MaxUint16), + Typ: sqltypes.Uint16, + }, + exp: math.MaxUint16, + }, + + // Uint32 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(0)), + Typ: sqltypes.Uint32, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, uint32(67)), + Typ: sqltypes.Uint32, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxInt32), + Typ: sqltypes.Uint32, + }, + exp: math.MaxInt32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.MaxUint32), + Typ: sqltypes.Uint32, + }, + exp: math.MaxUint32, + }, + + // Uint64 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Uint64, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Uint64, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Uint64, + }, + exp: math.MaxInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Uint64, + }, + exp: math.MaxUint64, + }, + + // Float32 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(0)), + Typ: sqltypes.Float32, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(123.456)), + Typ: sqltypes.Float32, + }, + exp: 123.456, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(-math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: -math.MaxFloat32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint32(nil, math.Float32bits(math.MaxFloat32)), + Typ: sqltypes.Float32, + }, + exp: math.MaxFloat32, + }, + + // Float64 -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(0)), + Typ: sqltypes.Float64, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(123.456)), + Typ: sqltypes.Float64, + }, + exp: 123, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: -math.MaxFloat32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat32)), + Typ: sqltypes.Float64, + }, + exp: math.MaxFloat32, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(-math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: -math.MaxFloat64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.Float64bits(math.MaxFloat64)), + Typ: sqltypes.Float64, + }, + exp: math.MaxFloat64, + }, + + // Decimal -> Float64 + { + val: sql.Value{ + Val: zeroDec, + Typ: sqltypes.Decimal, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: testDec, + Typ: sqltypes.Decimal, + }, + exp: 123.456, + }, + { + val: sql.Value{ + Val: minInt64Dec, + Typ: sqltypes.Decimal, + }, + exp: math.MinInt64, + }, + { + val: sql.Value{ + Val: maxInt64Dec, + Typ: sqltypes.Decimal, + }, + exp: math.MaxInt64, + }, + + // Bit -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(0)), + Typ: sqltypes.Bit, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, uint64(67)), + Typ: sqltypes.Bit, + }, + exp: 67, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxInt64), + Typ: sqltypes.Bit, + }, + exp: math.MaxInt64, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint64(nil, math.MaxUint64), + Typ: sqltypes.Bit, + }, + exp: math.MaxUint64, + }, + + // Year -> Float64 + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(0)), + Typ: sqltypes.Year, + }, + exp: 0, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1967)), + Typ: sqltypes.Year, + }, + exp: 1967, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(1901)), + Typ: sqltypes.Year, + }, + exp: 1901, + }, + { + val: sql.Value{ + Val: binary.LittleEndian.AppendUint16(nil, uint16(2155)), + Typ: sqltypes.Year, + }, + exp: 2155, + }, + } + + epsilon := 0.01 + for _, test := range tests { + t.Run(fmt.Sprintf("%v", test.val), func(t *testing.T) { + res, err := convertValueToFloat64(ctx, test.val) + if test.err { + require.Error(t, err) + return + } + require.NoError(t, err) + if test.exp == 0 { + require.Zero(t, res) + return + } + require.InEpsilonf(t, test.exp, res, epsilon, fmt.Sprintf("Actual is: %v", res)) + }) + } +} diff --git a/sql/types/set.go b/sql/types/set.go index 98b96f1390..6bba11ac11 100644 --- a/sql/types/set.go +++ b/sql/types/set.go @@ -30,6 +30,7 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/encodings" + "github.com/dolthub/go-mysql-server/sql/values" ) const ( @@ -155,6 +156,11 @@ func (t SetType) Compare(ctx context.Context, a interface{}, b interface{}) (int return 0, nil } +// CompareValue implements the ValueType interface +func (t SetType) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + panic("TODO: implement CompareValue for SetType") +} + // Convert implements Type interface. // Returns the string representing the given value if applicable. func (t SetType) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { @@ -261,6 +267,36 @@ func (t SetType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.Val return sqltypes.MakeTrusted(sqltypes.Set, val), nil } +// SQLValue implements ValueType interface. +func (t SetType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + + bits := values.ReadUint64(v.Val) + value, err := t.BitsToString(bits) + if err != nil { + return sqltypes.Value{}, err + } + + resultCharset := ctx.GetCharacterSetResults() + if resultCharset == sql.CharacterSet_Unspecified || resultCharset == sql.CharacterSet_binary { + resultCharset = t.collation.CharacterSet() + } + + // TODO: write append style encoder + res, ok := resultCharset.Encoder().Encode([]byte(value)) + if !ok { + if len(value) > 50 { + value = value[:50] + } + value = strings.ToValidUTF8(value, string(utf8.RuneError)) + return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(resultCharset.Name(), utf8.ValidString(value), value) + } + + return sqltypes.MakeTrusted(sqltypes.Set, res), nil +} + // String implements Type interface. func (t SetType) String() string { return t.StringWithTableCollation(sql.Collation_Default) diff --git a/sql/types/strings.go b/sql/types/strings.go index e44e3d5a54..3dcdac8148 100644 --- a/sql/types/strings.go +++ b/sql/types/strings.go @@ -334,6 +334,11 @@ func (t StringType) Compare(ctx context.Context, a interface{}, b interface{}) ( } } +// CompareValue implements the sql.ValueType interface. +func (t StringType) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + panic("TODO: implement CompareValue for StringTypes") +} + // Convert implements Type interface. func (t StringType) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { if v == nil { @@ -790,6 +795,37 @@ func (t StringType) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes. return sqltypes.MakeTrusted(t.baseType, val), nil } +// SQLValue implements ValueType interface. +func (t StringType) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + + // TODO: deal with casting numbers? + // No need to use dest buffer as we have already allocated []byte + var err error + if v.Val == nil && v.WrappedVal != nil { + v.Val, err = v.WrappedVal.Unwrap(ctx) + if err != nil { + return sqltypes.Value{}, err + } + } + charset := ctx.GetCharacterSetResults() + if charset == sql.CharacterSet_Unspecified || charset == sql.CharacterSet_binary { + charset = t.collation.CharacterSet() + } + res, ok := charset.Encoder().Encode(v.Val) + if !ok { + if len(v.Val) > 50 { + v.Val = v.Val[:50] + } + snippetStr := strings2.ToValidUTF8(string(v.Val), string(utf8.RuneError)) + return sqltypes.Value{}, sql.ErrCharSetFailedToEncode.New(charset.Name(), utf8.ValidString(snippetStr), v.Val) + } + + return sqltypes.MakeTrusted(t.baseType, res), nil +} + // String implements Type interface. func (t StringType) String() string { return t.StringWithTableCollation(sql.Collation_Default) diff --git a/sql/types/time.go b/sql/types/time.go index b8ca8b005e..29ac916948 100644 --- a/sql/types/time.go +++ b/sql/types/time.go @@ -28,6 +28,7 @@ import ( "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/values" ) var ( @@ -99,6 +100,11 @@ func (t TimespanType_) Compare(s context.Context, a interface{}, b interface{}) return as.Compare(bs), nil } +// CompareValue implements the ValueType interface +func (t TimespanType_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + panic("TODO: implement CompareValue for TimespanType") +} + func (t TimespanType_) Convert(c context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { if v == nil { return nil, sql.InRange, nil @@ -267,6 +273,16 @@ func (t TimespanType_) SQL(_ *sql.Context, dest []byte, v interface{}) (sqltypes return sqltypes.MakeTrusted(sqltypes.Time, val), nil } +// SQLValue implements ValueType interface. +func (t TimespanType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + x := values.ReadInt64(v.Val) + dest = Timespan(x).AppendBytes(dest) + return sqltypes.MakeTrusted(sqltypes.Time, dest), nil +} + // String implements Type interface. func (t TimespanType_) String() string { return "time(6)" @@ -485,7 +501,45 @@ func (t Timespan) Bytes() []byte { return ret[:i] } -// appendDigit format prints 0-entended integer into buffer +func (t Timespan) AppendBytes(dest []byte) []byte { + isNegative, hours, minutes, seconds, microseconds := t.timespanToUnits() + sz := 10 + if microseconds > 0 { + sz += 7 + } + if isNegative { + dest = append(dest, '-') + } + + if hours < 10 { + dest = append(dest, '0') + } + dest = strconv.AppendInt(dest, int64(hours), 10) + dest = append(dest, ':') + + if minutes < 10 { + dest = append(dest, '0') + } + dest = strconv.AppendInt(dest, int64(minutes), 10) + dest = append(dest, ':') + + if seconds < 10 { + dest = append(dest, '0') + } + dest = strconv.AppendInt(dest, int64(seconds), 10) + if microseconds > 0 { + dest = append(dest, '.') + cmp := int32(100000) + for cmp > 0 && microseconds < cmp { + dest = append(dest, '0') + cmp /= 10 + } + dest = strconv.AppendInt(dest, int64(microseconds), 10) + } + return dest +} + +// appendDigit format prints 0-extended integer into buffer func appendDigit(v int64, extend int, buf []byte, i int) int { cmp := int64(1) for _ = range extend - 1 { diff --git a/sql/types/year.go b/sql/types/year.go index c1e1ddc5ff..8dbc0ba311 100644 --- a/sql/types/year.go +++ b/sql/types/year.go @@ -26,6 +26,7 @@ import ( "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/values" ) var ( @@ -64,6 +65,29 @@ func (t YearType_) Compare(ctx context.Context, a interface{}, b interface{}) (i return 1, nil } +// CompareValue implements the ValueType interface. +func (t YearType_) CompareValue(ctx *sql.Context, a, b sql.Value) (int, error) { + if hasNulls, res := CompareNullValues(a, b); hasNulls { + return res, nil + } + ay, err := ConvertValueToYear(ctx, a) + if err != nil { + return 0, err + } + by, err := ConvertValueToYear(ctx, b) + if err != nil { + return 0, err + } + switch { + case ay < by: + return -1, nil + case ay > by: + return 1, nil + default: + return 0, nil + } +} + // Convert implements Type interface. func (t YearType_) Convert(ctx context.Context, v interface{}) (interface{}, sql.ConvertInRange, error) { if v == nil { @@ -171,6 +195,16 @@ func (t YearType_) SQL(ctx *sql.Context, dest []byte, v interface{}) (sqltypes.V return sqltypes.MakeTrusted(sqltypes.Year, val), nil } +// SQLValue implements ValueType interface. +func (t YearType_) SQLValue(ctx *sql.Context, v sql.Value, dest []byte) (sqltypes.Value, error) { + if v.IsNull() { + return sqltypes.NULL, nil + } + x := values.ReadUint16(v.Val) + dest = strconv.AppendInt(dest, int64(x), 10) + return sqltypes.MakeTrusted(sqltypes.Year, dest), nil +} + // String implements Type interface. func (t YearType_) String() string { return "year" @@ -195,3 +229,73 @@ func (t YearType_) Zero() interface{} { func (YearType_) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { return sql.Collation_binary, 5 } + +func ConvertValueToYear(ctx *sql.Context, v sql.Value) (uint16, error) { + switch v.Typ { + case sqltypes.Int8: + x := values.ReadInt8(v.Val) + return uint16(x), nil + case sqltypes.Int16: + x := values.ReadInt16(v.Val) + return uint16(x), nil + case sqltypes.Int32: + x := values.ReadInt32(v.Val) + return uint16(x), nil + case sqltypes.Int64: + x := values.ReadInt64(v.Val) + return uint16(x), nil + case sqltypes.Uint8: + x := values.ReadUint8(v.Val) + return uint16(x), nil + case sqltypes.Uint16: + x := values.ReadUint16(v.Val) + return x, nil + case sqltypes.Uint32: + x := values.ReadUint32(v.Val) + return uint16(x), nil + case sqltypes.Uint64: + x := values.ReadUint64(v.Val) + return uint16(x), nil + case sqltypes.Float32: + x := values.ReadFloat32(v.Val) + return uint16(x), nil + case sqltypes.Float64: + x := values.ReadFloat64(v.Val) + return uint16(x), nil + case sqltypes.Decimal: + x := values.ReadDecimal(v.Val) + return uint16(x.IntPart()), nil + case sqltypes.Year: + x := values.ReadUint16(v.Val) + return x, nil + case sqltypes.Date: + x := values.ReadDate(v.Val) + return uint16(x.UTC().Unix()), nil + case sqltypes.Time: + x := values.ReadInt64(v.Val) + return uint16(x), nil + case sqltypes.Datetime, sqltypes.Timestamp: + x := values.ReadDatetime(v.Val) + return uint16(x.UTC().Unix()), nil + case sqltypes.Text, sqltypes.Blob: + var err error + if v.Val == nil { + v.Val, err = v.WrappedVal.Unwrap(ctx) + if err != nil { + return 0, err + } + } + val := values.ReadString(v.Val) + truncStr, didTrunc := TruncateStringToInt(val) + if didTrunc { + err = sql.ErrTruncatedIncorrect.New(v.Typ, val) + } + i, pErr := strconv.ParseInt(truncStr, 10, 64) + if pErr != nil { + return 0, sql.ErrInvalidValue.New(v, v.Typ.String()) + } + return uint16(i), err + default: + return 0, ErrConvertingToYear.New(v) + } +} diff --git a/sql/row_frame.go b/sql/value_row.go similarity index 77% rename from sql/row_frame.go rename to sql/value_row.go index ef3ea6010f..f9140c41c5 100644 --- a/sql/row_frame.go +++ b/sql/value_row.go @@ -17,7 +17,7 @@ package sql import ( "sync" - querypb "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/dolthub/vitess/go/vt/proto/query" ) const ( @@ -25,34 +25,35 @@ const ( fieldArrSize = 2048 ) -// Row2 is a slice of values -type Row2 []Value +type ValueBytes []byte -// GetField returns the Value for the ith field in this row. -func (r Row2) GetField(i int) Value { - return r[i] +// Value is a logical index into a ValueRow. For efficiency reasons, use sparingly. +type Value struct { + Val ValueBytes + WrappedVal BytesWrapper + Typ query.Type } -// Len returns the number of fields of this row -func (r Row2) Len() int { - return len(r) +var NullValue = Value{} +var FalseValue = Value{ + Val: []byte{0}, + Typ: query.Type_INT8, } - -// Value is a logical index into a Row2. For efficiency reasons, use sparingly. -type Value struct { - Val ValueBytes - Typ querypb.Type +var TrueValue = Value{ + Val: []byte{1}, + Typ: query.Type_INT8, } +// ValueRow is a slice of values +type ValueRow []Value + // IsNull returns whether this value represents NULL func (v Value) IsNull() bool { - return v.Val == nil || v.Typ == querypb.Type_NULL_TYPE + return (v.Val == nil && v.WrappedVal == nil) || v.Typ == query.Type_NULL_TYPE } -type ValueBytes []byte - type RowFrame struct { - Types []querypb.Type + Types []query.Type // Values are the values this row. Values []ValueBytes @@ -88,34 +89,34 @@ func (f *RowFrame) Recycle() { framePool.Put(f) } -// Row2 returns the underlying row value in this frame. Does not make a deep copy of underlying byte arrays, so +// AsValueRow returns the underlying row value in this frame. Does not make a deep copy of underlying byte arrays, so // further modification to this frame may result in the returned value changing as well. -func (f *RowFrame) Row2() Row2 { +func (f *RowFrame) AsValueRow() ValueRow { if f == nil { return nil } - rs := make(Row2, len(f.Values)) + rs := make(ValueRow, len(f.Values)) for i := range f.Values { rs[i] = Value{ - Typ: f.Types[i], Val: f.Values[i], + Typ: f.Types[i], } } return rs } -// Row2Copy returns the row in this frame as a deep copy of the underlying byte arrays. Useful when reusing the +// ValueRowCopy returns the row in this frame as a deep copy of the underlying byte arrays. Useful when reusing the // RowFrame object via Clear() -func (f *RowFrame) Row2Copy() Row2 { - rs := make(Row2, len(f.Values)) +func (f *RowFrame) ValueRowCopy() ValueRow { + rs := make(ValueRow, len(f.Values)) // TODO: it would be faster here to just copy the entire value backing array in one pass for i := range f.Values { v := make(ValueBytes, len(f.Values[i])) copy(v, f.Values[i]) rs[i] = Value{ - Typ: f.Types[i], Val: v, + Typ: f.Types[i], } } return rs @@ -137,7 +138,7 @@ func (f *RowFrame) Append(vals ...Value) { } // AppendMany appends the types and values given, as two parallel arrays, into this frame. -func (f *RowFrame) AppendMany(types []querypb.Type, vals []ValueBytes) { +func (f *RowFrame) AppendMany(types []query.Type, vals []ValueBytes) { // TODO: one big copy here would be better probably, need to benchmark for i := range vals { f.appendTypeAndVal(types[i], vals[i]) @@ -156,7 +157,7 @@ func (f *RowFrame) append(v Value) { f.Values = append(f.Values, v.Val) } -func (f *RowFrame) appendTypeAndVal(typ querypb.Type, val ValueBytes) { +func (f *RowFrame) appendTypeAndVal(typ query.Type, val ValueBytes) { v := f.bufferForBytes(val) copy(v, val) diff --git a/sql/values/encoding.go b/sql/values/encoding.go index 3472e870e5..2cf4f2a765 100644 --- a/sql/values/encoding.go +++ b/sql/values/encoding.go @@ -17,7 +17,13 @@ package values import ( "bytes" "encoding/binary" + "fmt" "math" + "math/big" + "time" + "unsafe" + + "github.com/shopspring/decimal" ) type Type struct { @@ -37,15 +43,18 @@ const ( Uint24Size ByteSize = 3 Int32Size ByteSize = 4 Uint32Size ByteSize = 4 - Int48Size ByteSize = 6 - Uint48Size ByteSize = 6 Int64Size ByteSize = 8 Uint64Size ByteSize = 8 Float32Size ByteSize = 4 Float64Size ByteSize = 8 + DecimalSize ByteSize = 5 + + DateSize ByteSize = 4 + TimeSize ByteSize = 8 + DatetimeSize ByteSize = 8 + TimestampSize ByteSize = 8 ) -const maxUint48 = uint64(1<<48 - 1) const maxUint24 = uint32(1<<24 - 1) type Collation uint16 @@ -108,6 +117,7 @@ func ReadBool(val []byte) bool { expectSize(val, Int8Size) return val[0] == 1 } + func ReadInt8(val []byte) int8 { expectSize(val, Int8Size) return int8(val[0]) @@ -129,22 +139,14 @@ func ReadUint16(val []byte) uint16 { } func ReadInt24(val []byte) (i int32) { - expectSize(val, Int24Size) - var tmp [4]byte - // copy |val| to |tmp| - tmp[3], tmp[2] = val[3], val[2] - tmp[1], tmp[0] = val[1], val[0] - i = int32(binary.LittleEndian.Uint32(tmp[:])) + expectSize(val, Int32Size) + i = int32(binary.LittleEndian.Uint32(val)) return } func ReadUint24(val []byte) (u uint32) { - expectSize(val, Int24Size) - var tmp [4]byte - // copy |val| to |tmp| - tmp[3], tmp[2] = val[3], val[2] - tmp[1], tmp[0] = val[1], val[0] - u = binary.LittleEndian.Uint32(tmp[:]) + expectSize(val, Int32Size) + u = binary.LittleEndian.Uint32(val) return } @@ -158,28 +160,6 @@ func ReadUint32(val []byte) uint32 { return binary.LittleEndian.Uint32(val) } -func ReadInt48(val []byte) (i int64) { - expectSize(val, Int48Size) - var tmp [8]byte - // copy |val| to |tmp| - tmp[5], tmp[4] = val[5], val[4] - tmp[3], tmp[2] = val[3], val[2] - tmp[1], tmp[0] = val[1], val[0] - i = int64(binary.LittleEndian.Uint64(tmp[:])) - return -} - -func ReadUint48(val []byte) (u uint64) { - expectSize(val, Uint48Size) - var tmp [8]byte - // copy |val| to |tmp| - tmp[5], tmp[4] = val[5], val[4] - tmp[3], tmp[2] = val[3], val[2] - tmp[1], tmp[0] = val[1], val[0] - u = binary.LittleEndian.Uint64(tmp[:]) - return -} - func ReadInt64(val []byte) int64 { expectSize(val, Int64Size) return int64(binary.LittleEndian.Uint64(val)) @@ -192,20 +172,47 @@ func ReadUint64(val []byte) uint64 { func ReadFloat32(val []byte) float32 { expectSize(val, Float32Size) - return math.Float32frombits(ReadUint32(val)) + x := binary.LittleEndian.Uint32(val) + return math.Float32frombits(x) } func ReadFloat64(val []byte) float64 { expectSize(val, Float64Size) - return math.Float64frombits(ReadUint64(val)) + x := binary.LittleEndian.Uint64(val) + return math.Float64frombits(x) +} + +func ReadDecimal(val []byte) decimal.Decimal { + e := ReadInt32(val[:Int32Size]) + s := ReadInt8(val[Int32Size : Int32Size+Int8Size]) + b := big.NewInt(0).SetBytes(val[Int32Size+Int8Size:]) + if s < 0 { + b = b.Neg(b) + } + return decimal.NewFromBigInt(b, e) +} + +func ReadDate(val []byte) time.Time { + expectSize(val, Uint32Size) + x := binary.LittleEndian.Uint32(val) + y := x >> 16 + m := (x & (255 << 8)) >> 8 + d := x & 255 + return time.Date(int(y), time.Month(m), int(d), 0, 0, 0, 0, time.UTC) +} + +func ReadDatetime(val []byte) time.Time { + expectSize(val, DatetimeSize) + ms := int64(binary.LittleEndian.Uint64(val)) + return time.UnixMicro(ms).UTC() } -func ReadString(val []byte, coll Collation) string { - // todo: fix allocation - return string(val) +func ReadString(val []byte) string { + // TODO: this is essentially encoding.BytesToString + return *(*string)(unsafe.Pointer(&val)) } -func ReadBytes(val []byte, coll Collation) []byte { +func ReadBytes(val []byte) []byte { // todo: fix collation return val } @@ -243,29 +250,6 @@ func WriteUint16(buf []byte, val uint16) []byte { return buf } -func WriteInt24(buf []byte, val int32) []byte { - expectSize(buf, Int24Size) - - var tmp [4]byte - binary.LittleEndian.PutUint32(tmp[:], uint32(val)) - // copy |tmp| to |buf| - buf[2], buf[1], buf[0] = tmp[2], tmp[1], tmp[0] - return buf -} - -func WriteUint24(buf []byte, val uint32) []byte { - expectSize(buf, Uint24Size) - if val > maxUint24 { - panic("uint is greater than max uint24") - } - - var tmp [4]byte - binary.LittleEndian.PutUint32(tmp[:], uint32(val)) - // copy |tmp| to |buf| - buf[2], buf[1], buf[0] = tmp[2], tmp[1], tmp[0] - return buf -} - func WriteInt32(buf []byte, val int32) []byte { expectSize(buf, Int32Size) binary.LittleEndian.PutUint32(buf, uint32(val)) @@ -278,20 +262,6 @@ func WriteUint32(buf []byte, val uint32) []byte { return buf } -func WriteUint48(buf []byte, u uint64) []byte { - expectSize(buf, Uint48Size) - if u > maxUint48 { - panic("uint is greater than max uint48") - } - var tmp [8]byte - binary.LittleEndian.PutUint64(tmp[:], u) - // copy |tmp| to |buf| - buf[5], buf[4] = tmp[5], tmp[4] - buf[3], buf[2] = tmp[3], tmp[2] - buf[1], buf[0] = tmp[1], tmp[0] - return buf -} - func WriteInt64(buf []byte, val int64) []byte { expectSize(buf, Int64Size) binary.LittleEndian.PutUint64(buf, uint64(val)) @@ -332,7 +302,7 @@ func WriteBytes(buf, val []byte, coll Collation) []byte { func expectSize(buf []byte, sz ByteSize) { if ByteSize(len(buf)) != sz { - panic("byte slice is not of expected size") + panic(fmt.Sprintf("byte slice is length %v expected %v", len(buf), sz)) } } @@ -378,9 +348,9 @@ func compare(typ Type, left, right []byte) int { case Float64Enc: return compareFloat64(ReadFloat64(left), ReadFloat64(right)) case StringEnc: - return compareString(ReadString(left, typ.Coll), ReadString(right, typ.Coll), typ.Coll) + return compareString(ReadString(left), ReadString(right), typ.Coll) case BytesEnc: - return compareBytes(ReadBytes(left, typ.Coll), ReadBytes(right, typ.Coll), typ.Coll) + return compareBytes(ReadBytes(left), ReadBytes(right), typ.Coll) default: panic("unknown encoding") }