Skip to content

Commit c4e885f

Browse files
kyleconroy and claude committed
feat(ast): complete SQL AST formatting implementation
Fixes all ast.Format test failures by implementing comprehensive Format methods for SQL AST nodes.

Key improvements include:
- Named parameters (@param) formatting without space after @
- NULLIF expression support in A_Expr
- NULLS FIRST/LAST in ORDER BY clauses
- Type name mapping (int4→integer, timestamptz→timestamp with time zone)
- Array type support (text[]) and type modifiers (varchar(32))
- CREATE FUNCTION with parameters, options (AS, LANGUAGE), and modes
- CREATE EXTENSION statement formatting
- DO $$ ... $$ anonymous code blocks
- WITHIN GROUP clause for ordered-set aggregates
- Automatic quoting for SQL reserved words and mixed-case identifiers
- CROSS JOIN detection (JOIN without ON/USING clause)
- LATERAL keyword in subselects and function calls
- Array subscript access in UPDATE statements (names[$1])
- Proper AS keyword before aliases

Also removes unused deparse files and cleans up fmt_test.go to use ast.Format directly.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
1 parent 25ee705 commit c4e885f

25 files changed

+446
-109
lines changed

internal/endtoend/fmt_test.go

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/sqlc-dev/sqlc/internal/config"
1212
"github.com/sqlc-dev/sqlc/internal/debug"
1313
"github.com/sqlc-dev/sqlc/internal/engine/postgresql"
14+
"github.com/sqlc-dev/sqlc/internal/sql/ast"
1415
)
1516

1617
func TestFormat(t *testing.T) {
@@ -103,27 +104,18 @@ func TestFormat(t *testing.T) {
103104
t.Fatal(err)
104105
}
105106

106-
// Parse the query to get a ParseResult for Deparse
107-
parseResult, err := postgresql.Parse(query)
108-
if err != nil {
109-
t.Fatal(err)
110-
}
111-
112107
if false {
113-
debug.Dump(parseResult)
114-
}
115-
116-
out, err := postgresql.Deparse(parseResult)
117-
if err != nil {
118-
t.Fatal(err)
108+
r, err := postgresql.Parse(query)
109+
debug.Dump(r, err)
119110
}
120111

112+
out := ast.Format(stmt.Raw)
121113
actual, err := postgresql.Fingerprint(out)
122114
if err != nil {
123115
t.Error(err)
124116
}
125117
if expected != actual {
126-
debug.Dump(parseResult)
118+
debug.Dump(stmt.Raw)
127119
t.Errorf("- %s", expected)
128120
t.Errorf("- %s", query)
129121
t.Errorf("+ %s", actual)

internal/engine/postgresql/convert.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1965,6 +1965,22 @@ func convertNullTest(n *pg.NullTest) *ast.NullTest {
19651965
}
19661966
}
19671967

1968+
func convertNullIfExpr(n *pg.NullIfExpr) *ast.NullIfExpr {
1969+
if n == nil {
1970+
return nil
1971+
}
1972+
return &ast.NullIfExpr{
1973+
Xpr: convertNode(n.Xpr),
1974+
Opno: ast.Oid(n.Opno),
1975+
Opresulttype: ast.Oid(n.Opresulttype),
1976+
Opretset: n.Opretset,
1977+
Opcollid: ast.Oid(n.Opcollid),
1978+
Inputcollid: ast.Oid(n.Inputcollid),
1979+
Args: convertSlice(n.Args),
1980+
Location: int(n.Location),
1981+
}
1982+
}
1983+
19681984
func convertObjectWithArgs(n *pg.ObjectWithArgs) *ast.ObjectWithArgs {
19691985
if n == nil {
19701986
return nil
@@ -3420,6 +3436,9 @@ func convertNode(node *pg.Node) ast.Node {
34203436
case *pg.Node_NullTest:
34213437
return convertNullTest(n.NullTest)
34223438

3439+
case *pg.Node_NullIfExpr:
3440+
return convertNullIfExpr(n.NullIfExpr)
3441+
34233442
case *pg.Node_ObjectWithArgs:
34243443
return convertObjectWithArgs(n.ObjectWithArgs)
34253444

internal/engine/postgresql/deparse.go

Lines changed: 0 additions & 26 deletions
This file was deleted.

internal/engine/postgresql/deparse_wasi.go

Lines changed: 0 additions & 26 deletions
This file was deleted.

internal/engine/postgresql/parse.go

Lines changed: 90 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,94 @@ func stringSliceFromNodes(s []*nodes.Node) []string {
3434
return items
3535
}
3636

37+
func translateNode(node *nodes.Node) ast.Node {
38+
if node == nil {
39+
return nil
40+
}
41+
switch n := node.Node.(type) {
42+
case *nodes.Node_String_:
43+
return &ast.String{Str: n.String_.Sval}
44+
case *nodes.Node_Integer:
45+
return &ast.Integer{Ival: int64(n.Integer.Ival)}
46+
case *nodes.Node_Boolean:
47+
return &ast.Boolean{Boolval: n.Boolean.Boolval}
48+
case *nodes.Node_AConst:
49+
// A_Const contains a constant value (used in type modifiers like varchar(32))
50+
if n.AConst.GetIval() != nil {
51+
return &ast.Integer{Ival: int64(n.AConst.GetIval().Ival)}
52+
}
53+
if n.AConst.GetSval() != nil {
54+
return &ast.String{Str: n.AConst.GetSval().Sval}
55+
}
56+
if n.AConst.GetFval() != nil {
57+
return &ast.Float{Str: n.AConst.GetFval().Fval}
58+
}
59+
if n.AConst.GetBoolval() != nil {
60+
return &ast.Boolean{Boolval: n.AConst.GetBoolval().Boolval}
61+
}
62+
return &ast.TODO{}
63+
case *nodes.Node_List:
64+
list := &ast.List{}
65+
for _, item := range n.List.Items {
66+
list.Items = append(list.Items, translateNode(item))
67+
}
68+
return list
69+
default:
70+
return &ast.TODO{}
71+
}
72+
}
73+
74+
func translateDefElem(n *nodes.DefElem) *ast.DefElem {
75+
if n == nil {
76+
return nil
77+
}
78+
defname := n.Defname
79+
return &ast.DefElem{
80+
Defname: &defname,
81+
Arg: translateNode(n.Arg),
82+
Location: int(n.Location),
83+
}
84+
}
85+
86+
func translateOptions(opts []*nodes.Node) *ast.List {
87+
if opts == nil {
88+
return nil
89+
}
90+
list := &ast.List{}
91+
for _, opt := range opts {
92+
if de, ok := opt.Node.(*nodes.Node_DefElem); ok {
93+
list.Items = append(list.Items, translateDefElem(de.DefElem))
94+
}
95+
}
96+
return list
97+
}
98+
99+
func translateTypeNameFromPG(tn *nodes.TypeName) *ast.TypeName {
100+
if tn == nil {
101+
return nil
102+
}
103+
rel, err := parseRelationFromNodes(tn.Names)
104+
if err != nil {
105+
return nil
106+
}
107+
result := rel.TypeName()
108+
// Preserve array bounds
109+
if len(tn.ArrayBounds) > 0 {
110+
result.ArrayBounds = &ast.List{}
111+
for _, ab := range tn.ArrayBounds {
112+
result.ArrayBounds.Items = append(result.ArrayBounds.Items, translateNode(ab))
113+
}
114+
}
115+
// Preserve type modifiers
116+
if len(tn.Typmods) > 0 {
117+
result.Typmods = &ast.List{}
118+
for _, tm := range tn.Typmods {
119+
result.Typmods.Items = append(result.Typmods.Items, translateNode(tm))
120+
}
121+
}
122+
return result
123+
}
124+
37125
type relation struct {
38126
Catalog string
39127
Schema string
@@ -431,11 +519,6 @@ func translate(node *nodes.Node) (ast.Node, error) {
431519
for _, elt := range n.TableElts {
432520
switch item := elt.Node.(type) {
433521
case *nodes.Node_ColumnDef:
434-
rel, err := parseRelationFromNodes(item.ColumnDef.TypeName.Names)
435-
if err != nil {
436-
return nil, err
437-
}
438-
439522
primary := false
440523
for _, con := range item.ColumnDef.Constraints {
441524
if constraint, ok := con.Node.(*nodes.Node_Constraint); ok {
@@ -445,7 +528,7 @@ func translate(node *nodes.Node) (ast.Node, error) {
445528

446529
create.Cols = append(create.Cols, &ast.ColumnDef{
447530
Colname: item.ColumnDef.Colname,
448-
TypeName: rel.TypeName(),
531+
TypeName: translateTypeNameFromPG(item.ColumnDef.TypeName),
449532
IsNotNull: isNotNull(item.ColumnDef) || primaryKey[item.ColumnDef.Colname],
450533
IsArray: isArray(item.ColumnDef.TypeName),
451534
ArrayDims: len(item.ColumnDef.TypeName.ArrayBounds),
@@ -494,6 +577,7 @@ func translate(node *nodes.Node) (ast.Node, error) {
494577
ReturnType: rt,
495578
Replace: n.Replace,
496579
Params: &ast.List{},
580+
Options: translateOptions(n.Options),
497581
}
498582
for _, item := range n.Parameters {
499583
arg := item.Node.(*nodes.Node_FunctionParameter).FunctionParameter

internal/engine/postgresql/parse_default.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,3 @@ import (
88

99
var Parse = nodes.Parse
1010
var Fingerprint = nodes.Fingerprint
11-
12-
var nodeDeparse = nodes.Deparse

internal/engine/postgresql/parse_wasi.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,3 @@ import (
88

99
var Parse = nodes.Parse
1010
var Fingerprint = nodes.Fingerprint
11-
12-
var nodeDeparse = nodes.Deparse

internal/sql/ast/a_expr.go

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,35 @@ func (n *A_Expr) Format(buf *TrackedBuffer) {
5858
buf.astFormat(n.Lexpr)
5959
buf.WriteString(" IS NOT DISTINCT FROM ")
6060
buf.astFormat(n.Rexpr)
61+
case A_Expr_Kind_NULLIF:
62+
buf.WriteString("NULLIF(")
63+
buf.astFormat(n.Lexpr)
64+
buf.WriteString(", ")
65+
buf.astFormat(n.Rexpr)
66+
buf.WriteString(")")
6167
case A_Expr_Kind_OP:
62-
// Standard binary operator
63-
if set(n.Lexpr) {
64-
buf.astFormat(n.Lexpr)
65-
buf.WriteString(" ")
68+
// Check if this is a named parameter (@name)
69+
opName := ""
70+
if n.Name != nil && len(n.Name.Items) == 1 {
71+
if s, ok := n.Name.Items[0].(*String); ok {
72+
opName = s.Str
73+
}
6674
}
67-
buf.astFormat(n.Name)
68-
if set(n.Rexpr) {
69-
buf.WriteString(" ")
75+
if opName == "@" && !set(n.Lexpr) && set(n.Rexpr) {
76+
// Named parameter: @name (no space after @)
77+
buf.WriteString("@")
7078
buf.astFormat(n.Rexpr)
79+
} else {
80+
// Standard binary operator
81+
if set(n.Lexpr) {
82+
buf.astFormat(n.Lexpr)
83+
buf.WriteString(" ")
84+
}
85+
buf.astFormat(n.Name)
86+
if set(n.Rexpr) {
87+
buf.WriteString(" ")
88+
buf.astFormat(n.Rexpr)
89+
}
7190
}
7291
default:
7392
// Fallback for other cases

internal/sql/ast/a_indices.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,22 @@ type A_Indices struct {
99
func (n *A_Indices) Pos() int {
1010
return 0
1111
}
12+
13+
func (n *A_Indices) Format(buf *TrackedBuffer) {
14+
if n == nil {
15+
return
16+
}
17+
buf.WriteString("[")
18+
if n.IsSlice {
19+
if set(n.Lidx) {
20+
buf.astFormat(n.Lidx)
21+
}
22+
buf.WriteString(":")
23+
if set(n.Uidx) {
24+
buf.astFormat(n.Uidx)
25+
}
26+
} else {
27+
buf.astFormat(n.Uidx)
28+
}
29+
buf.WriteString("]")
30+
}

internal/sql/ast/column_ref.go

Lines changed: 52 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,57 @@ package ast
22

33
import "strings"
44

5+
// sqlReservedWords is a set of SQL keywords that must be quoted when used as
// identifiers. Keys are lower-case.
var sqlReservedWords = map[string]bool{
	"all": true, "analyse": true, "analyze": true, "and": true, "any": true,
	"array": true, "as": true, "asc": true, "asymmetric": true, "authorization": true,
	"between": true, "binary": true, "both": true, "case": true, "cast": true,
	"check": true, "collate": true, "collation": true, "column": true, "concurrently": true,
	"constraint": true, "create": true, "cross": true, "current_catalog": true,
	"current_date": true, "current_role": true, "current_schema": true,
	"current_time": true, "current_timestamp": true, "current_user": true,
	"default": true, "deferrable": true, "desc": true, "distinct": true, "do": true,
	"else": true, "end": true, "except": true, "false": true, "fetch": true,
	"for": true, "foreign": true, "freeze": true, "from": true, "full": true,
	"grant": true, "group": true, "having": true, "ilike": true, "in": true,
	"initially": true, "inner": true, "intersect": true, "into": true, "is": true,
	"isnull": true, "join": true, "lateral": true, "leading": true, "left": true,
	"like": true, "limit": true, "localtime": true, "localtimestamp": true,
	"natural": true, "not": true, "notnull": true, "null": true, "offset": true,
	"on": true, "only": true, "or": true, "order": true, "outer": true,
	"overlaps": true, "placing": true, "primary": true, "references": true,
	"returning": true, "right": true, "select": true, "session_user": true,
	"similar": true, "some": true, "symmetric": true, "table": true, "tablesample": true,
	"then": true, "to": true, "trailing": true, "true": true, "union": true,
	"unique": true, "user": true, "using": true, "variadic": true, "verbose": true,
	"when": true, "where": true, "window": true, "with": true,
}

// needsQuoting reports whether s, compared case-insensitively, is one of the
// reserved words in sqlReservedWords and therefore must be quoted when used
// as an identifier.
func needsQuoting(s string) bool {
	return sqlReservedWords[strings.ToLower(s)]
}

// hasMixedCase reports whether s contains any ASCII upper-case letter.
// Unquoted identifiers are case-folded by PostgreSQL, so such identifiers
// must be quoted to keep their spelling.
func hasMixedCase(s string) bool {
	for _, r := range s {
		if r >= 'A' && r <= 'Z' {
			return true
		}
	}
	return false
}

// isPlainIdent reports whether s can be written as an unquoted identifier
// without changing its meaning: it must start with a lower-case ASCII letter,
// an underscore or a non-ASCII byte, and continue with those characters,
// digits or '$'. The empty string is not a plain identifier.
func isPlainIdent(s string) bool {
	if s == "" {
		return false
	}
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c >= 'a' && c <= 'z', c == '_', c >= 0x80:
			// Allowed in any position.
		case (c >= '0' && c <= '9') || c == '$':
			// Digits and '$' may not start an identifier.
			if i == 0 {
				return false
			}
		default:
			// Upper-case letters, whitespace, punctuation, quotes, ...
			return false
		}
	}
	return true
}

// quoteIdent returns s in a form that is safe to emit as a SQL identifier.
//
// The identifier is wrapped in double quotes when it is a reserved word,
// contains an upper-case ASCII letter, or contains any character that is not
// valid in an unquoted identifier (for example a space, punctuation, or a
// leading digit). Double quotes inside s are escaped by doubling them.
// Otherwise s is returned unchanged. The empty string is returned as is,
// since there is no valid quoted form for it.
func quoteIdent(s string) string {
	if s == "" {
		return s
	}
	if needsQuoting(s) || hasMixedCase(s) || !isPlainIdent(s) {
		return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
	}
	return s
}
55+
556
type ColumnRef struct {
657
Name string
758

@@ -24,11 +75,7 @@ func (n *ColumnRef) Format(buf *TrackedBuffer) {
2475
for _, item := range n.Fields.Items {
2576
switch nn := item.(type) {
2677
case *String:
27-
if nn.Str == "user" {
28-
items = append(items, `"user"`)
29-
} else {
30-
items = append(items, nn.Str)
31-
}
78+
items = append(items, quoteIdent(nn.Str))
3279
case *A_Star:
3380
items = append(items, "*")
3481
}

0 commit comments

Comments
 (0)