Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,10 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
defer wg.Done()
for {
if r == nil {
r = &sqltypes.Result{Fields: resultFields}
r = &sqltypes.Result{
Fields: resultFields,
Rows: make([][]sqltypes.Value, 0, rowsBatch),
}
}
if r.RowsAffected == rowsBatch {
if err := resetCallback(r, more); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion sql/expression/function/aggregation/group_concat.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ func (g *groupConcatBuffer) Update(ctx *sql.Context, originalRow sql.Row) error

// Append the current value to the end of the row. We want to preserve the row's original structure
// for sort ordering in the final step.
g.rows = append(g.rows, append(originalRow, vs))
g.rows = append(g.rows, append(originalRow.Copy(), vs))

return nil
}
Expand Down
2 changes: 1 addition & 1 deletion sql/expression/function/aggregation/window_partition.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func (i *WindowPartitionIter) materializeInput(ctx *sql.Context) (sql.WindowBuff
}
return nil, nil, err
}
input = append(input, append(row, j))
input = append(input, append(append(sql.Row(nil), row...), j))
j++
}

Expand Down
10 changes: 5 additions & 5 deletions sql/rowexec/ddl_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ func (i *loggingKeyValueIter) Close(ctx *sql.Context) error {
// projectRowWithTypes projects the row given with the projections given and additionally converts them to the
// corresponding types found in the schema given, using the standard type conversion logic.
func projectRowWithTypes(ctx *sql.Context, oldSchema, newSchema sql.Schema, projections []sql.Expression, r sql.Row) (sql.Row, error) {
newRow, err := ProjectRow(ctx, projections, r)
newRow, err := ProjectRow(ctx, projections, r, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1440,7 +1440,7 @@ func (i *addColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable)
return false, err
}

newRow, err := ProjectRow(ctx, projections, r)
newRow, err := ProjectRow(ctx, projections, r, nil)
if err != nil {
_ = inserter.DiscardChanges(ctx, err)
_ = inserter.Close(ctx)
Expand Down Expand Up @@ -1736,7 +1736,7 @@ func (i *dropColumnIter) rewriteTable(ctx *sql.Context, rwt sql.RewritableTable)
return false, err
}

newRow, err := ProjectRow(ctx, projections, r)
newRow, err := ProjectRow(ctx, projections, r, nil)
if err != nil {
_ = inserter.DiscardChanges(ctx, err)
_ = inserter.Close(ctx)
Expand Down Expand Up @@ -2240,7 +2240,7 @@ func buildIndex(ctx *sql.Context, n *plan.AlterIndex, ibt sql.IndexBuildingTable
}

if isVirtual {
r, err = ProjectRow(ctx, projections, r)
r, err = ProjectRow(ctx, projections, r, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -2326,7 +2326,7 @@ func rewriteTableForIndexCreate(ctx *sql.Context, n *plan.AlterIndex, table sql.
}

if isVirtual {
r, err = ProjectRow(ctx, projections, r)
r, err = ProjectRow(ctx, projections, r, nil)
if err != nil {
return err
}
Expand Down
10 changes: 4 additions & 6 deletions sql/rowexec/insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,16 @@ func getInsertExpressions(values sql.Node) []sql.Expression {
}

func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error) {
row, err := i.rowSource.Next(ctx)
origRow, err := i.rowSource.Next(ctx)
if err == io.EOF {
return nil, err
}

if err != nil {
return nil, i.ignoreOrClose(ctx, row, err)
return nil, i.ignoreOrClose(ctx, origRow, err)
}

row := origRow.Copy()

// Increment row number for error reporting (MySQL starts at 1)
i.rowNumber++

Expand Down Expand Up @@ -107,9 +108,6 @@ func (i *insertIter) Next(ctx *sql.Context) (returnRow sql.Row, returnErr error)
return nil, i.ignoreOrClose(ctx, row, err)
}

origRow := make(sql.Row, len(row))
copy(origRow, row)

// Do any necessary type conversions to the target schema
for idx, col := range i.schema {
if row[idx] != nil {
Expand Down
104 changes: 72 additions & 32 deletions sql/rowexec/join_iters.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ type joinIter struct {
rowSize int
scopeLen int
parentLen int

rowBuffer *sql.RowBuffer
}

func newJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) {
Expand Down Expand Up @@ -75,9 +77,10 @@ func newJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row
return nil, err
}

parentLen := len(row)
rowBuffer := sql.RowBufPool.Get().(*sql.RowBuffer)

primaryRow := make(sql.Row, parentLen+len(j.Left().Schema()))
parentLen := len(row)
primaryRow := rowBuffer.Get(parentLen + len(j.Left().Schema()))
copy(primaryRow, row)

return sql.NewSpanIter(span, &joinIter{
Expand All @@ -94,6 +97,8 @@ func newJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row
rowSize: parentLen + len(j.Left().Schema()) + len(j.Right().Schema()),
scopeLen: j.ScopeLen,
parentLen: parentLen,

rowBuffer: rowBuffer,
}), nil
}

Expand Down Expand Up @@ -184,6 +189,8 @@ func (i *joinIter) Next(ctx *sql.Context) (sql.Row, error) {
}

if !sql.IsTrue(res) {
// TODO: we are trashing row here, so we can release the memory...right?
i.rowBuffer.Erase(i.rowSize)
continue
}

Expand All @@ -200,13 +207,16 @@ func (i *joinIter) removeParentRow(r sql.Row) sql.Row {

// buildRow builds the result set row using the rows from the primary and secondary tables
func (i *joinIter) buildRow(primary, secondary sql.Row) sql.Row {
row := make(sql.Row, i.rowSize)
row := i.rowBuffer.Get(i.rowSize)
copy(row, primary)
copy(row[len(primary):], secondary)
return row
}

func (i *joinIter) Close(ctx *sql.Context) (err error) {
//i.rowBuffer.Reset()
//sql.RowBufPool.Put(i.rowBuffer)

if i.primary != nil {
if err = i.primary.Close(ctx); err != nil {
if i.secondary != nil {
Expand All @@ -232,11 +242,13 @@ func newExistsIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, ro

parentLen := len(row)

rowBuffer := sql.RowBufPool.Get().(*sql.RowBuffer)

rowSize := parentLen + len(j.Left().Schema()) + len(j.Right().Schema())
fullRow := make(sql.Row, rowSize)
fullRow := rowBuffer.Get(rowSize)
copy(fullRow, row)

primaryRow := make(sql.Row, parentLen+len(j.Left().Schema()))
primaryRow := rowBuffer.Get(parentLen + len(j.Left().Schema()))
copy(primaryRow, row)

return &existsIter{
Expand All @@ -251,6 +263,7 @@ func newExistsIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, ro
scopeLen: j.ScopeLen,
rowSize: rowSize,
nullRej: !(j.Filter != nil && plan.IsNullRejecting(j.Filter)),
rowBuffer: rowBuffer,
}, nil
}

Expand All @@ -271,6 +284,8 @@ type existsIter struct {

nullRej bool
rightIterNonEmpty bool

rowBuffer *sql.RowBuffer
}

type existsState uint8
Expand Down Expand Up @@ -396,13 +411,16 @@ func (i *existsIter) removeParentRow(r sql.Row) sql.Row {

// buildRow builds the result set row using the rows from the primary and secondary tables
func (i *existsIter) buildRow(primary, secondary sql.Row) sql.Row {
row := make(sql.Row, i.rowSize)
row := i.rowBuffer.Get(i.rowSize)
copy(row, primary)
copy(row[len(primary):], secondary)
return row
}

func (i *existsIter) Close(ctx *sql.Context) (err error) {
i.rowBuffer.Reset()
sql.RowBufPool.Put(i.rowBuffer)

if i.primary != nil {
if err = i.primary.Close(ctx); err != nil {
return err
Expand All @@ -411,26 +429,6 @@ func (i *existsIter) Close(ctx *sql.Context) (err error) {
return err
}

func newFullJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) {
leftIter, err := b.Build(ctx, j.Left(), row)
if err != nil {
return nil, err
}
return &fullJoinIter{
parentRow: row,
l: leftIter,
rp: j.Right(),
cond: j.Filter,
scopeLen: j.ScopeLen,
rowSize: len(row) + len(j.Left().Schema()) + len(j.Right().Schema()),
seenLeft: make(map[uint64]struct{}),
seenRight: make(map[uint64]struct{}),
leftLen: len(j.Left().Schema()),
rightLen: len(j.Right().Schema()),
b: b,
}, nil
}

// fullJoinIter implements full join as a union of left and right join:
// FJ(A,B) => U(LJ(A,B), RJ(A,B)). The current algorithm will have a
// runtime and memory complexity O(m+n).
Expand All @@ -451,6 +449,30 @@ type fullJoinIter struct {
leftDone bool
seenLeft map[uint64]struct{}
seenRight map[uint64]struct{}

rowBuffer *sql.RowBuffer
}

func newFullJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) {
leftIter, err := b.Build(ctx, j.Left(), row)
if err != nil {
return nil, err
}
return &fullJoinIter{
parentRow: row,
l: leftIter,
rp: j.Right(),
cond: j.Filter,
scopeLen: j.ScopeLen,
rowSize: len(row) + len(j.Left().Schema()) + len(j.Right().Schema()),
seenLeft: make(map[uint64]struct{}),
seenRight: make(map[uint64]struct{}),
leftLen: len(j.Left().Schema()),
rightLen: len(j.Right().Schema()),
b: b,

rowBuffer: sql.RowBufPool.Get().(*sql.RowBuffer),
}, nil
}

func (i *fullJoinIter) Next(ctx *sql.Context) (sql.Row, error) {
Expand Down Expand Up @@ -546,7 +568,7 @@ func (i *fullJoinIter) Next(ctx *sql.Context) (sql.Row, error) {
continue
}
// (null, right) only if we haven't matched right
ret := make(sql.Row, i.rowSize)
ret := i.rowBuffer.Get(i.rowSize)
copy(ret[i.leftLen:], rightRow)
return i.removeParentRow(ret), nil
}
Expand All @@ -560,13 +582,16 @@ func (i *fullJoinIter) removeParentRow(r sql.Row) sql.Row {

// buildRow builds the result set row using the rows from the primary and secondary tables
func (i *fullJoinIter) buildRow(primary, secondary sql.Row) sql.Row {
row := make(sql.Row, i.rowSize)
row := i.rowBuffer.Get(i.rowSize)
copy(row, primary)
copy(row[len(primary):], secondary)
return row
}

func (i *fullJoinIter) Close(ctx *sql.Context) (err error) {
i.rowBuffer.Reset()
sql.RowBufPool.Put(i.rowBuffer)

if i.l != nil {
err = i.l.Close(ctx)
}
Expand All @@ -593,6 +618,8 @@ type crossJoinIterator struct {
rowSize int
scopeLen int
parentLen int

rowBuffer *sql.RowBuffer
}

func newCrossJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) {
Expand Down Expand Up @@ -620,9 +647,10 @@ func newCrossJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode,
return nil, err
}

parentLen := len(row)
rowBuffer := sql.RowBufPool.Get().(*sql.RowBuffer)

primaryRow := make(sql.Row, parentLen+len(j.Left().Schema()))
parentLen := len(row)
primaryRow := rowBuffer.Get(parentLen + len(j.Left().Schema()))
copy(primaryRow, row)

return sql.NewSpanIter(span, &crossJoinIterator{
Expand All @@ -635,6 +663,8 @@ func newCrossJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode,
rowSize: len(row) + len(j.Left().Schema()) + len(j.Right().Schema()),
scopeLen: j.ScopeLen,
parentLen: parentLen,

rowBuffer: rowBuffer,
}), nil
}

Expand Down Expand Up @@ -664,7 +694,7 @@ func (i *crossJoinIterator) Next(ctx *sql.Context) (sql.Row, error) {
return nil, err
}

row := make(sql.Row, i.rowSize)
row := i.rowBuffer.Get(i.rowSize)
copy(row, i.primaryRow)
copy(row[len(i.primaryRow):], rightRow)
return i.removeParentRow(row), nil
Expand All @@ -678,6 +708,9 @@ func (i *crossJoinIterator) removeParentRow(r sql.Row) sql.Row {
}

func (i *crossJoinIterator) Close(ctx *sql.Context) (err error) {
i.rowBuffer.Reset()
sql.RowBufPool.Put(i.rowBuffer)

if i.l != nil {
err = i.l.Close(ctx)
}
Expand Down Expand Up @@ -734,6 +767,8 @@ type lateralJoinIterator struct {
foundMatch bool

b sql.NodeExecBuilder

rowBuffer *sql.RowBuffer
}

func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNode, row sql.Row) (sql.RowIter, error) {
Expand Down Expand Up @@ -769,6 +804,8 @@ func newLateralJoinIter(ctx *sql.Context, b sql.NodeExecBuilder, j *plan.JoinNod
rowSize: len(row) + len(j.Left().Schema()) + len(j.Right().Schema()),
scopeLen: j.ScopeLen,
b: b,

rowBuffer: sql.RowBufPool.Get().(*sql.RowBuffer),
}), nil
}

Expand Down Expand Up @@ -811,7 +848,7 @@ func (i *lateralJoinIterator) loadRight(ctx *sql.Context) error {
}

func (i *lateralJoinIterator) buildRow(lRow, rRow sql.Row) sql.Row {
row := make(sql.Row, i.rowSize)
row := i.rowBuffer.Get(i.rowSize)
copy(row, lRow)
copy(row[len(lRow):], rRow)
return row
Expand Down Expand Up @@ -874,6 +911,9 @@ func (i *lateralJoinIterator) Next(ctx *sql.Context) (sql.Row, error) {
}

func (i *lateralJoinIterator) Close(ctx *sql.Context) error {
i.rowBuffer.Reset()
sql.RowBufPool.Put(i.rowBuffer)

var lerr, rerr error
if i.lIter != nil {
lerr = i.lIter.Close(ctx)
Expand Down
Loading
Loading