Skip to content

Commit dc44702

Browse files
kyleconroyclaude
andcommitted
refactor(format): add Formatter interface for SQL dialect-specific quoting
- Create internal/sql/format package with Formatter interface - Add QuoteIdent method to TrackedBuffer that delegates to Formatter - Implement QuoteIdent on postgresql.Parser using existing IsReservedKeyword - Update all Format() methods to use buf.QuoteIdent() instead of local quoteIdent() - Remove duplicate reserved word logic from ast/column_ref.go - Update ast.Format() to accept a Formatter parameter This allows each SQL dialect to provide its own identifier quoting logic based on its reserved keywords and quoting rules. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent 0e7fa5a commit dc44702

File tree

8 files changed

+55
-66
lines changed

8 files changed

+55
-66
lines changed

internal/endtoend/fmt_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func TestFormat(t *testing.T) {
109109
debug.Dump(r, err)
110110
}
111111

112-
out := ast.Format(stmt.Raw)
112+
out := ast.Format(stmt.Raw, parse)
113113
actual, err := postgresql.Fingerprint(out)
114114
if err != nil {
115115
t.Error(err)

internal/engine/postgresql/reserved.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,26 @@ package postgresql
22

33
import "strings"
44

5+
// hasMixedCase returns true if the string has any uppercase letters
6+
// (identifiers with mixed case need quoting in PostgreSQL)
7+
func hasMixedCase(s string) bool {
8+
for _, r := range s {
9+
if r >= 'A' && r <= 'Z' {
10+
return true
11+
}
12+
}
13+
return false
14+
}
15+
16+
// QuoteIdent returns a quoted identifier if it needs quoting.
17+
// This implements the format.Formatter interface.
18+
func (p *Parser) QuoteIdent(s string) string {
19+
if p.IsReservedKeyword(s) || hasMixedCase(s) {
20+
return `"` + s + `"`
21+
}
22+
return s
23+
}
24+
525
// https://www.postgresql.org/docs/current/sql-keywords-appendix.html
626
func (p *Parser) IsReservedKeyword(s string) bool {
727
switch strings.ToLower(s) {

internal/sql/ast/column_ref.go

Lines changed: 1 addition & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,57 +2,6 @@ package ast
22

33
import "strings"
44

5-
// sqlReservedWords is a set of SQL keywords that must be quoted when used as identifiers
6-
var sqlReservedWords = map[string]bool{
7-
"all": true, "analyse": true, "analyze": true, "and": true, "any": true,
8-
"array": true, "as": true, "asc": true, "asymmetric": true, "authorization": true,
9-
"between": true, "binary": true, "both": true, "case": true, "cast": true,
10-
"check": true, "collate": true, "collation": true, "column": true, "concurrently": true,
11-
"constraint": true, "create": true, "cross": true, "current_catalog": true,
12-
"current_date": true, "current_role": true, "current_schema": true,
13-
"current_time": true, "current_timestamp": true, "current_user": true,
14-
"default": true, "deferrable": true, "desc": true, "distinct": true, "do": true,
15-
"else": true, "end": true, "except": true, "false": true, "fetch": true,
16-
"for": true, "foreign": true, "freeze": true, "from": true, "full": true,
17-
"grant": true, "group": true, "having": true, "ilike": true, "in": true,
18-
"initially": true, "inner": true, "intersect": true, "into": true, "is": true,
19-
"isnull": true, "join": true, "lateral": true, "leading": true, "left": true,
20-
"like": true, "limit": true, "localtime": true, "localtimestamp": true,
21-
"natural": true, "not": true, "notnull": true, "null": true, "offset": true,
22-
"on": true, "only": true, "or": true, "order": true, "outer": true,
23-
"overlaps": true, "placing": true, "primary": true, "references": true,
24-
"returning": true, "right": true, "select": true, "session_user": true,
25-
"similar": true, "some": true, "symmetric": true, "table": true, "tablesample": true,
26-
"then": true, "to": true, "trailing": true, "true": true, "union": true,
27-
"unique": true, "user": true, "using": true, "variadic": true, "verbose": true,
28-
"when": true, "where": true, "window": true, "with": true,
29-
}
30-
31-
// needsQuoting returns true if the identifier is a SQL reserved word
32-
// that needs to be quoted when used as an identifier
33-
func needsQuoting(s string) bool {
34-
return sqlReservedWords[strings.ToLower(s)]
35-
}
36-
37-
// hasMixedCase returns true if the string has any uppercase letters
38-
// (identifiers with mixed case need quoting in PostgreSQL)
39-
func hasMixedCase(s string) bool {
40-
for _, r := range s {
41-
if r >= 'A' && r <= 'Z' {
42-
return true
43-
}
44-
}
45-
return false
46-
}
47-
48-
// quoteIdent returns a quoted identifier if it needs quoting
49-
func quoteIdent(s string) string {
50-
if needsQuoting(s) || hasMixedCase(s) {
51-
return `"` + s + `"`
52-
}
53-
return s
54-
}
55-
565
type ColumnRef struct {
576
Name string
587

@@ -75,7 +24,7 @@ func (n *ColumnRef) Format(buf *TrackedBuffer) {
7524
for _, item := range n.Fields.Items {
7625
switch nn := item.(type) {
7726
case *String:
78-
items = append(items, quoteIdent(nn.Str))
27+
items = append(items, buf.QuoteIdent(nn.Str))
7928
case *A_Star:
8029
items = append(items, "*")
8130
}

internal/sql/ast/print.go

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,38 @@ import (
44
"strings"
55

66
"github.com/sqlc-dev/sqlc/internal/debug"
7+
"github.com/sqlc-dev/sqlc/internal/sql/format"
78
)
89

9-
type formatter interface {
10+
type nodeFormatter interface {
1011
Format(*TrackedBuffer)
1112
}
1213

1314
type TrackedBuffer struct {
1415
*strings.Builder
16+
formatter format.Formatter
1517
}
1618

17-
// NewTrackedBuffer creates a new TrackedBuffer.
18-
func NewTrackedBuffer() *TrackedBuffer {
19+
// NewTrackedBuffer creates a new TrackedBuffer with the given formatter.
20+
func NewTrackedBuffer(f format.Formatter) *TrackedBuffer {
1921
buf := &TrackedBuffer{
20-
Builder: new(strings.Builder),
22+
Builder: new(strings.Builder),
23+
formatter: f,
2124
}
2225
return buf
2326
}
2427

28+
// QuoteIdent returns a quoted identifier if it needs quoting.
29+
// If no formatter is set, it returns the identifier unchanged.
30+
func (t *TrackedBuffer) QuoteIdent(s string) string {
31+
if t.formatter != nil {
32+
return t.formatter.QuoteIdent(s)
33+
}
34+
return s
35+
}
36+
2537
func (t *TrackedBuffer) astFormat(n Node) {
26-
if ft, ok := n.(formatter); ok {
38+
if ft, ok := n.(nodeFormatter); ok {
2739
ft.Format(t)
2840
} else {
2941
debug.Dump(n)
@@ -45,9 +57,9 @@ func (t *TrackedBuffer) join(n *List, sep string) {
4557
}
4658
}
4759

48-
func Format(n Node) string {
49-
tb := NewTrackedBuffer()
50-
if ft, ok := n.(formatter); ok {
60+
func Format(n Node, f format.Formatter) string {
61+
tb := NewTrackedBuffer(f)
62+
if ft, ok := n.(nodeFormatter); ok {
5163
ft.Format(tb)
5264
}
5365
return tb.String()

internal/sql/ast/range_var.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ func (n *RangeVar) Format(buf *TrackedBuffer) {
1919
return
2020
}
2121
if n.Schemaname != nil {
22-
buf.WriteString(quoteIdent(*n.Schemaname))
22+
buf.WriteString(buf.QuoteIdent(*n.Schemaname))
2323
buf.WriteString(".")
2424
}
2525
if n.Relname != nil {
26-
buf.WriteString(quoteIdent(*n.Relname))
26+
buf.WriteString(buf.QuoteIdent(*n.Relname))
2727
}
2828
if n.Alias != nil {
2929
buf.WriteString(" ")

internal/sql/ast/res_target.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ func (n *ResTarget) Format(buf *TrackedBuffer) {
1919
buf.astFormat(n.Val)
2020
if n.Name != nil {
2121
buf.WriteString(" AS ")
22-
buf.WriteString(quoteIdent(*n.Name))
22+
buf.WriteString(buf.QuoteIdent(*n.Name))
2323
}
2424
} else {
2525
if n.Name != nil {
26-
buf.WriteString(quoteIdent(*n.Name))
26+
buf.WriteString(buf.QuoteIdent(*n.Name))
2727
}
2828
}
2929
}

internal/sql/ast/update_stmt.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ func (n *UpdateStmt) Format(buf *TrackedBuffer) {
7979
switch nn := item.(type) {
8080
case *ResTarget:
8181
if nn.Name != nil {
82-
buf.WriteString(quoteIdent(*nn.Name))
82+
buf.WriteString(buf.QuoteIdent(*nn.Name))
8383
}
8484
// Handle array subscript indirection (e.g., names[$1])
8585
if items(nn.Indirection) {

internal/sql/format/format.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package format
2+
3+
// Formatter provides SQL dialect-specific formatting behavior
4+
type Formatter interface {
5+
// QuoteIdent returns a quoted identifier if it needs quoting
6+
// (e.g., reserved words, mixed case identifiers)
7+
QuoteIdent(s string) string
8+
}

0 commit comments

Comments
 (0)