Skip to content

Commit 29cdfb9

Browse files
committed
fix(clickhouse): fix parameter detection and managed-db integration
- Update ClickHouse to use managed-db context instead of separate context - Fix detectParameters to count ? placeholders in addition to parsing {name:Type} style parameters with doubleclick parser - Add DROP TABLE IF EXISTS before CREATE TABLE in migrations for idempotent schema application - Set ClickHouse database URI in test config for proper connection - Add pgrep check for MySQL service detection in native tests 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent ff082bd commit 29cdfb9

File tree

4 files changed

+181
-109
lines changed

4 files changed

+181
-109
lines changed

internal/endtoend/endtoend_test.go

Lines changed: 21 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package main
33
import (
44
"bytes"
55
"context"
6-
"database/sql"
76
"os"
87
osexec "os/exec"
98
"path/filepath"
@@ -12,7 +11,6 @@ import (
1211
"strings"
1312
"testing"
1413

15-
_ "github.com/ClickHouse/clickhouse-go/v2" // ClickHouse driver
1614
"github.com/google/go-cmp/cmp"
1715
"github.com/google/go-cmp/cmp/cmpopts"
1816

@@ -129,7 +127,7 @@ func TestReplay(t *testing.T) {
129127
}
130128

131129
// Try Docker for any missing databases
132-
if postgresURI == "" || mysqlURI == "" {
130+
if postgresURI == "" || mysqlURI == "" || clickhouseURI == "" {
133131
if err := docker.Installed(); err == nil {
134132
if postgresURI == "" {
135133
host, err := docker.StartPostgreSQLServer(ctx)
@@ -147,6 +145,14 @@ func TestReplay(t *testing.T) {
147145
mysqlURI = host
148146
}
149147
}
148+
if clickhouseURI == "" {
149+
host, err := docker.StartClickHouseServer(ctx)
150+
if err != nil {
151+
t.Logf("docker clickhouse startup failed: %s", err)
152+
} else {
153+
clickhouseURI = host
154+
}
155+
}
150156
}
151157
}
152158

@@ -205,6 +211,11 @@ func TestReplay(t *testing.T) {
205211
Engine: config.EngineMySQL,
206212
URI: mysqlURI,
207213
},
214+
{
215+
Name: "clickhouse",
216+
Engine: config.EngineClickHouse,
217+
URI: clickhouseURI,
218+
},
208219
}
209220

210221
for i := range c.SQL {
@@ -221,6 +232,12 @@ func TestReplay(t *testing.T) {
221232
c.SQL[i].Database = &config.Database{
222233
Managed: true,
223234
}
235+
case config.EngineClickHouse:
236+
// ClickHouse uses URI directly (not managed mode)
237+
c.SQL[i].Database = &config.Database{
238+
URI: clickhouseURI,
239+
Managed: true,
240+
}
224241
default:
225242
// pass
226243
}
@@ -229,56 +246,7 @@ func TestReplay(t *testing.T) {
229246
},
230247
Enabled: func() bool {
231248
// Enabled if at least one database URI is available
232-
return postgresURI != "" || mysqlURI != ""
233-
},
234-
},
235-
"clickhouse": {
236-
Mutate: func(t *testing.T, path string) func(*config.Config) {
237-
return func(c *config.Config) {
238-
for i := range c.SQL {
239-
if c.SQL[i].Engine == config.EngineClickHouse {
240-
c.SQL[i].Database = &config.Database{
241-
URI: clickhouseURI,
242-
}
243-
// Apply schema migrations to ClickHouse
244-
for _, schemaPath := range c.SQL[i].Schema {
245-
fullPath := filepath.Join(path, schemaPath)
246-
schemaSQL, err := os.ReadFile(fullPath)
247-
if err != nil {
248-
t.Logf("Failed to read schema %s: %v", fullPath, err)
249-
continue
250-
}
251-
db, err := sql.Open("clickhouse", clickhouseURI)
252-
if err != nil {
253-
t.Logf("Failed to connect to ClickHouse: %v", err)
254-
continue
255-
}
256-
// Execute each statement separately
257-
for _, stmt := range strings.Split(string(schemaSQL), ";") {
258-
stmt = strings.TrimSpace(stmt)
259-
if stmt == "" {
260-
continue
261-
}
262-
// Drop table first if this is a CREATE TABLE statement
263-
if strings.HasPrefix(strings.ToUpper(stmt), "CREATE TABLE") {
264-
parts := strings.Fields(stmt)
265-
if len(parts) >= 3 {
266-
tableName := strings.TrimSuffix(parts[2], "(")
267-
db.Exec("DROP TABLE IF EXISTS " + tableName)
268-
}
269-
}
270-
if _, err := db.Exec(stmt); err != nil {
271-
t.Logf("Failed to apply schema: %v", err)
272-
}
273-
}
274-
db.Close()
275-
}
276-
}
277-
}
278-
}
279-
},
280-
Enabled: func() bool {
281-
return clickhouseURI != ""
249+
return postgresURI != "" || mysqlURI != "" || clickhouseURI != ""
282250
},
283251
},
284252
}
@@ -316,12 +284,6 @@ func TestReplay(t *testing.T) {
316284
if !slices.Contains(args.Contexts, name) {
317285
t.Skipf("unsupported context: %s", name)
318286
}
319-
} else if name == "clickhouse" {
320-
// For clickhouse context, only run tests that explicitly include it
321-
// or that have ClickHouse engine (checked by having "clickhouse" in path)
322-
if !strings.Contains(tc.Name, "clickhouse") {
323-
t.Skipf("clickhouse context: skipping non-clickhouse test")
324-
}
325287
}
326288

327289
if len(args.OS) > 0 {

internal/endtoend/testdata/clickhouse_authors/clickhouse/stdlib/exec.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{
2-
"contexts": ["clickhouse"],
2+
"contexts": ["managed-db"],
33
"env": {
44
"SQLCEXPERIMENT": "clickhouse"
55
}

internal/engine/clickhouse/analyzer/analyze.go

Lines changed: 142 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"sync"
99

1010
_ "github.com/ClickHouse/clickhouse-go/v2" // ClickHouse driver
11+
dcast "github.com/sqlc-dev/doubleclick/ast"
12+
"github.com/sqlc-dev/doubleclick/parser"
1113

1214
core "github.com/sqlc-dev/sqlc/internal/analysis"
1315
"github.com/sqlc-dev/sqlc/internal/config"
@@ -155,6 +157,18 @@ func (a *Analyzer) connect(ctx context.Context, migrations []string) error {
155157
if len(strings.TrimSpace(m)) == 0 {
156158
continue
157159
}
160+
// For CREATE TABLE statements, drop the table first if it exists
161+
upper := strings.ToUpper(strings.TrimSpace(m))
162+
if strings.HasPrefix(upper, "CREATE TABLE") {
163+
// Extract table name and drop it first
164+
parts := strings.Fields(m)
165+
if len(parts) >= 3 {
166+
tableName := parts[2]
167+
// Remove any trailing characters like "("
168+
tableName = strings.TrimSuffix(tableName, "(")
169+
a.conn.ExecContext(ctx, "DROP TABLE IF EXISTS "+tableName)
170+
}
171+
}
158172
if _, err := a.conn.ExecContext(ctx, m); err != nil {
159173
a.conn.Close()
160174
a.conn = nil
@@ -212,12 +226,16 @@ func (a *Analyzer) GetColumnNames(ctx context.Context, query string) ([]string,
212226
// Replace ? placeholders with NULL for introspection
213227
preparedQuery := strings.ReplaceAll(query, "?", "NULL")
214228

215-
// Add LIMIT 0 to avoid fetching data
216-
limitQuery := addLimit0(preparedQuery)
217-
218-
rows, err := a.conn.QueryContext(ctx, limitQuery)
229+
// Use DESCRIBE (query) to get column information
230+
describeQuery := fmt.Sprintf("DESCRIBE (%s)", preparedQuery)
231+
rows, err := a.conn.QueryContext(ctx, describeQuery)
219232
if err != nil {
220-
return nil, err
233+
// Fallback to LIMIT 0 if DESCRIBE fails
234+
limitQuery := addLimit0(preparedQuery)
235+
rows, err = a.conn.QueryContext(ctx, limitQuery)
236+
if err != nil {
237+
return nil, err
238+
}
221239
}
222240
defer rows.Close()
223241

@@ -314,65 +332,142 @@ type paramInfo struct {
314332
Type string
315333
}
316334

317-
// detectParameters finds parameters in a ClickHouse query.
318-
// ClickHouse supports {name:Type} and $1, $2 style parameters.
335+
// detectParameters finds parameters in a ClickHouse query using the doubleclick parser.
336+
// ClickHouse supports {name:Type} and ? style parameters.
319337
func detectParameters(query string) []paramInfo {
320338
var params []paramInfo
321339

322-
// Find {name:Type} style parameters
323-
i := 0
324-
for i < len(query) {
325-
if query[i] == '{' {
326-
j := i + 1
327-
for j < len(query) && query[j] != '}' {
328-
j++
329-
}
330-
if j < len(query) {
331-
paramStr := query[i+1 : j]
332-
parts := strings.SplitN(paramStr, ":", 2)
333-
if len(parts) == 2 {
334-
params = append(params, paramInfo{
335-
Name: parts[0],
336-
Type: normalizeType(parts[1]),
337-
})
338-
} else if len(parts) == 1 {
339-
params = append(params, paramInfo{
340-
Name: parts[0],
341-
Type: "any",
342-
})
340+
// First, try to find {name:Type} style parameters using the doubleclick parser
341+
ctx := context.Background()
342+
stmts, err := parser.Parse(ctx, strings.NewReader(query))
343+
if err == nil {
344+
// Walk the AST to find Parameter nodes (for {name:Type} style)
345+
for _, stmt := range stmts {
346+
walkStatement(stmt, func(expr dcast.Expression) {
347+
if param, ok := expr.(*dcast.Parameter); ok {
348+
name := param.Name
349+
dataType := "any"
350+
if param.Type != nil {
351+
dataType = normalizeType(param.Type.Name)
352+
}
353+
if name != "" {
354+
// Only add named parameters from the parser
355+
params = append(params, paramInfo{
356+
Name: name,
357+
Type: dataType,
358+
})
359+
}
343360
}
344-
}
345-
i = j + 1
346-
} else {
347-
i++
348-
}
349-
}
350-
351-
// Find $1, $2 style parameters (simpler approach)
352-
for i := 1; i <= 100; i++ {
353-
placeholder := fmt.Sprintf("$%d", i)
354-
if strings.Contains(query, placeholder) {
355-
params = append(params, paramInfo{
356-
Name: fmt.Sprintf("p%d", i),
357-
Type: "any",
358361
})
359-
} else {
360-
break
361362
}
362363
}
363364

364-
// Find ? placeholders
365+
// Count ? placeholders (the doubleclick parser doesn't fully support these)
366+
// The ? placeholders are added after any named parameters
365367
count := strings.Count(query, "?")
366-
for i := len(params); i < count; i++ {
368+
for i := 0; i < count; i++ {
367369
params = append(params, paramInfo{
368-
Name: fmt.Sprintf("p%d", i+1),
370+
Name: fmt.Sprintf("p%d", len(params)+1),
369371
Type: "any",
370372
})
371373
}
372374

373375
return params
374376
}
375377

378+
// walkStatement walks a statement and calls fn for each expression.
379+
func walkStatement(stmt dcast.Statement, fn func(dcast.Expression)) {
380+
switch s := stmt.(type) {
381+
case *dcast.SelectQuery:
382+
walkSelectQuery(s, fn)
383+
case *dcast.SelectWithUnionQuery:
384+
for _, sel := range s.Selects {
385+
walkStatement(sel, fn)
386+
}
387+
case *dcast.InsertQuery:
388+
if s.Select != nil {
389+
walkStatement(s.Select, fn)
390+
}
391+
}
392+
}
393+
394+
// walkSelectQuery walks a SELECT query and calls fn for each expression.
395+
func walkSelectQuery(s *dcast.SelectQuery, fn func(dcast.Expression)) {
396+
// Walk columns
397+
for _, col := range s.Columns {
398+
walkExpression(col, fn)
399+
}
400+
// Walk WHERE clause
401+
if s.Where != nil {
402+
walkExpression(s.Where, fn)
403+
}
404+
// Walk GROUP BY
405+
for _, g := range s.GroupBy {
406+
walkExpression(g, fn)
407+
}
408+
// Walk HAVING
409+
if s.Having != nil {
410+
walkExpression(s.Having, fn)
411+
}
412+
// Walk ORDER BY
413+
for _, o := range s.OrderBy {
414+
walkExpression(o.Expression, fn)
415+
}
416+
// Walk LIMIT
417+
if s.Limit != nil {
418+
walkExpression(s.Limit, fn)
419+
}
420+
// Walk OFFSET
421+
if s.Offset != nil {
422+
walkExpression(s.Offset, fn)
423+
}
424+
}
425+
426+
// walkExpression walks an expression and calls fn for each sub-expression.
427+
func walkExpression(expr dcast.Expression, fn func(dcast.Expression)) {
428+
if expr == nil {
429+
return
430+
}
431+
fn(expr)
432+
433+
switch e := expr.(type) {
434+
case *dcast.BinaryExpr:
435+
walkExpression(e.Left, fn)
436+
walkExpression(e.Right, fn)
437+
case *dcast.UnaryExpr:
438+
walkExpression(e.Operand, fn)
439+
case *dcast.FunctionCall:
440+
for _, arg := range e.Arguments {
441+
walkExpression(arg, fn)
442+
}
443+
case *dcast.Subquery:
444+
walkStatement(e.Query, fn)
445+
case *dcast.CaseExpr:
446+
if e.Operand != nil {
447+
walkExpression(e.Operand, fn)
448+
}
449+
for _, when := range e.Whens {
450+
walkExpression(when.Condition, fn)
451+
walkExpression(when.Result, fn)
452+
}
453+
if e.Else != nil {
454+
walkExpression(e.Else, fn)
455+
}
456+
case *dcast.InExpr:
457+
walkExpression(e.Expr, fn)
458+
for _, v := range e.List {
459+
walkExpression(v, fn)
460+
}
461+
if e.Query != nil {
462+
walkStatement(e.Query, fn)
463+
}
464+
case *dcast.BetweenExpr:
465+
walkExpression(e.Expr, fn)
466+
walkExpression(e.Low, fn)
467+
walkExpression(e.High, fn)
468+
}
469+
}
470+
376471
// addLimit0 adds LIMIT 0 to a query for schema introspection.
377472
func addLimit0(query string) string {
378473
// Simple approach: append LIMIT 0 if not already present

0 commit comments

Comments
 (0)