Skip to content

Commit afd50ce

Browse files
committed
feat(sqlite): support CTEs in INSERT / UPDATE / DELETE statements
CTE handling differs slighlty between SELECT and other kinds of statements, as SELECT grammar uses a `common_table_stmt` token, while other DML statements use a `with_clause` token
1 parent 5252c63 commit afd50ce

File tree

1 file changed

+41
-2
lines changed

1 file changed

+41
-2
lines changed

internal/engine/sqlite/convert.go

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,18 @@ func (c *cc) convertCreate_view_stmtContext(n *parser.Create_view_stmtContext) a
195195
type Delete_stmt interface {
196196
node
197197

198+
With_clause() parser.IWith_clauseContext
198199
Qualified_table_name() parser.IQualified_table_nameContext
199200
WHERE_() antlr.TerminalNode
200201
Expr() parser.IExprContext
201202
}
202203

203204
func (c *cc) convertDelete_stmtContext(n Delete_stmt) ast.Node {
205+
var withClause *ast.WithClause
206+
if w := n.With_clause(); w != nil {
207+
withClause = c.convertWithClause(w)
208+
}
209+
204210
if qualifiedName, ok := n.Qualified_table_name().(*parser.Qualified_table_nameContext); ok {
205211

206212
tableName := identifier(qualifiedName.Table_name().GetText())
@@ -223,8 +229,8 @@ func (c *cc) convertDelete_stmtContext(n Delete_stmt) ast.Node {
223229
relations.Items = append(relations.Items, relation)
224230

225231
delete := &ast.DeleteStmt{
232+
WithClause: withClause,
226233
Relations: relations,
227-
WithClause: nil,
228234
}
229235

230236
if n.WHERE_() != nil && n.Expr() != nil {
@@ -854,6 +860,11 @@ func (c *cc) convertReturning_caluseContext(n parser.IReturning_clauseContext) *
854860
}
855861

856862
func (c *cc) convertInsert_stmtContext(n *parser.Insert_stmtContext) ast.Node {
863+
var withClause *ast.WithClause
864+
if w := n.With_clause(); w != nil {
865+
withClause = c.convertWithClause(w)
866+
}
867+
857868
tableName := identifier(n.Table_name().GetText())
858869
rel := &ast.RangeVar{
859870
Relname: &tableName,
@@ -870,6 +881,7 @@ func (c *cc) convertInsert_stmtContext(n *parser.Insert_stmtContext) ast.Node {
870881
}
871882

872883
insert := &ast.InsertStmt{
884+
WithClause: withClause,
873885
Relation: rel,
874886
Cols: c.convertColumnNames(n.AllColumn_name()),
875887
ReturningList: c.convertReturning_caluseContext(n.Returning_clause()),
@@ -1020,7 +1032,29 @@ func (c *cc) convertTablesOrSubquery(n []parser.ITable_or_subqueryContext) []ast
10201032
return tables
10211033
}
10221034

1035+
func (c *cc) convertWithClause(w parser.IWith_clauseContext) *ast.WithClause {
1036+
var ctes ast.List
1037+
recursive := w.RECURSIVE_() != nil
1038+
for idx, cte := range w.AllCte_table_name() {
1039+
tableName := identifier(cte.Table_name().GetText())
1040+
var cteCols ast.List
1041+
for _, col := range cte.AllColumn_name() {
1042+
cteCols.Items = append(cteCols.Items, NewIdentifier(col.GetText()))
1043+
}
1044+
ctes.Items = append(ctes.Items, &ast.CommonTableExpr{
1045+
Ctename: &tableName,
1046+
Ctequery: c.convert(w.Select_stmt(idx)),
1047+
Location: cte.GetStart().GetStart(),
1048+
Cterecursive: recursive,
1049+
Ctecolnames: &cteCols,
1050+
})
1051+
}
1052+
1053+
return &ast.WithClause{Ctes: &ctes}
1054+
}
1055+
10231056
type Update_stmt interface {
1057+
With_clause() parser.IWith_clauseContext
10241058
Qualified_table_name() parser.IQualified_table_nameContext
10251059
GetStart() antlr.Token
10261060
AllColumn_name() []parser.IColumn_nameContext
@@ -1034,6 +1068,11 @@ func (c *cc) convertUpdate_stmtContext(n Update_stmt) ast.Node {
10341068
return nil
10351069
}
10361070

1071+
var withClause *ast.WithClause
1072+
if w := n.With_clause(); w != nil {
1073+
withClause = c.convertWithClause(w)
1074+
}
1075+
10371076
relations := &ast.List{}
10381077
tableName := identifier(n.Qualified_table_name().GetText())
10391078
rel := ast.RangeVar{
@@ -1062,7 +1101,7 @@ func (c *cc) convertUpdate_stmtContext(n Update_stmt) ast.Node {
10621101
TargetList: list,
10631102
WhereClause: where,
10641103
FromClause: &ast.List{},
1065-
WithClause: nil, // TODO: support with clause
1104+
WithClause: withClause,
10661105
}
10671106
if n, ok := n.(interface {
10681107
Returning_clause() parser.IReturning_clauseContext

0 commit comments

Comments
 (0)