|
8 | 8 | "sync" |
9 | 9 |
|
10 | 10 | _ "github.com/ClickHouse/clickhouse-go/v2" // ClickHouse driver |
| 11 | + dcast "github.com/sqlc-dev/doubleclick/ast" |
| 12 | + "github.com/sqlc-dev/doubleclick/parser" |
11 | 13 |
|
12 | 14 | core "github.com/sqlc-dev/sqlc/internal/analysis" |
13 | 15 | "github.com/sqlc-dev/sqlc/internal/config" |
@@ -155,6 +157,18 @@ func (a *Analyzer) connect(ctx context.Context, migrations []string) error { |
155 | 157 | if len(strings.TrimSpace(m)) == 0 { |
156 | 158 | continue |
157 | 159 | } |
| 160 | + // For CREATE TABLE statements, drop the table first if it exists |
| 161 | + upper := strings.ToUpper(strings.TrimSpace(m)) |
| 162 | + if strings.HasPrefix(upper, "CREATE TABLE") { |
| 163 | + // Extract table name and drop it first |
| 164 | + parts := strings.Fields(m) |
| 165 | + if len(parts) >= 3 { |
| 166 | + tableName := parts[2] |
| 167 | + // Remove any trailing characters like "(" |
| 168 | + tableName = strings.TrimSuffix(tableName, "(") |
| 169 | + a.conn.ExecContext(ctx, "DROP TABLE IF EXISTS "+tableName) |
| 170 | + } |
| 171 | + } |
158 | 172 | if _, err := a.conn.ExecContext(ctx, m); err != nil { |
159 | 173 | a.conn.Close() |
160 | 174 | a.conn = nil |
@@ -212,12 +226,16 @@ func (a *Analyzer) GetColumnNames(ctx context.Context, query string) ([]string, |
212 | 226 | // Replace ? placeholders with NULL for introspection |
213 | 227 | preparedQuery := strings.ReplaceAll(query, "?", "NULL") |
214 | 228 |
|
215 | | - // Add LIMIT 0 to avoid fetching data |
216 | | - limitQuery := addLimit0(preparedQuery) |
217 | | - |
218 | | - rows, err := a.conn.QueryContext(ctx, limitQuery) |
| 229 | + // Use DESCRIBE (query) to get column information |
| 230 | + describeQuery := fmt.Sprintf("DESCRIBE (%s)", preparedQuery) |
| 231 | + rows, err := a.conn.QueryContext(ctx, describeQuery) |
219 | 232 | if err != nil { |
220 | | - return nil, err |
| 233 | + // Fallback to LIMIT 0 if DESCRIBE fails |
| 234 | + limitQuery := addLimit0(preparedQuery) |
| 235 | + rows, err = a.conn.QueryContext(ctx, limitQuery) |
| 236 | + if err != nil { |
| 237 | + return nil, err |
| 238 | + } |
221 | 239 | } |
222 | 240 | defer rows.Close() |
223 | 241 |
|
@@ -314,65 +332,142 @@ type paramInfo struct { |
314 | 332 | Type string |
315 | 333 | } |
316 | 334 |
|
317 | | -// detectParameters finds parameters in a ClickHouse query. |
318 | | -// ClickHouse supports {name:Type} and $1, $2 style parameters. |
| 335 | +// detectParameters finds parameters in a ClickHouse query using the doubleclick parser. |
| 336 | +// ClickHouse supports {name:Type} and ? style parameters. |
319 | 337 | func detectParameters(query string) []paramInfo { |
320 | 338 | var params []paramInfo |
321 | 339 |
|
322 | | - // Find {name:Type} style parameters |
323 | | - i := 0 |
324 | | - for i < len(query) { |
325 | | - if query[i] == '{' { |
326 | | - j := i + 1 |
327 | | - for j < len(query) && query[j] != '}' { |
328 | | - j++ |
329 | | - } |
330 | | - if j < len(query) { |
331 | | - paramStr := query[i+1 : j] |
332 | | - parts := strings.SplitN(paramStr, ":", 2) |
333 | | - if len(parts) == 2 { |
334 | | - params = append(params, paramInfo{ |
335 | | - Name: parts[0], |
336 | | - Type: normalizeType(parts[1]), |
337 | | - }) |
338 | | - } else if len(parts) == 1 { |
339 | | - params = append(params, paramInfo{ |
340 | | - Name: parts[0], |
341 | | - Type: "any", |
342 | | - }) |
| 340 | + // First, try to find {name:Type} style parameters using the doubleclick parser |
| 341 | + ctx := context.Background() |
| 342 | + stmts, err := parser.Parse(ctx, strings.NewReader(query)) |
| 343 | + if err == nil { |
| 344 | + // Walk the AST to find Parameter nodes (for {name:Type} style) |
| 345 | + for _, stmt := range stmts { |
| 346 | + walkStatement(stmt, func(expr dcast.Expression) { |
| 347 | + if param, ok := expr.(*dcast.Parameter); ok { |
| 348 | + name := param.Name |
| 349 | + dataType := "any" |
| 350 | + if param.Type != nil { |
| 351 | + dataType = normalizeType(param.Type.Name) |
| 352 | + } |
| 353 | + if name != "" { |
| 354 | + // Only add named parameters from the parser |
| 355 | + params = append(params, paramInfo{ |
| 356 | + Name: name, |
| 357 | + Type: dataType, |
| 358 | + }) |
| 359 | + } |
343 | 360 | } |
344 | | - } |
345 | | - i = j + 1 |
346 | | - } else { |
347 | | - i++ |
348 | | - } |
349 | | - } |
350 | | - |
351 | | - // Find $1, $2 style parameters (simpler approach) |
352 | | - for i := 1; i <= 100; i++ { |
353 | | - placeholder := fmt.Sprintf("$%d", i) |
354 | | - if strings.Contains(query, placeholder) { |
355 | | - params = append(params, paramInfo{ |
356 | | - Name: fmt.Sprintf("p%d", i), |
357 | | - Type: "any", |
358 | 361 | }) |
359 | | - } else { |
360 | | - break |
361 | 362 | } |
362 | 363 | } |
363 | 364 |
|
364 | | - // Find ? placeholders |
| 365 | + // Count ? placeholders (the doubleclick parser doesn't fully support these) |
| 366 | + // The ? placeholders are added after any named parameters |
365 | 367 | count := strings.Count(query, "?") |
366 | | - for i := len(params); i < count; i++ { |
| 368 | + for i := 0; i < count; i++ { |
367 | 369 | params = append(params, paramInfo{ |
368 | | - Name: fmt.Sprintf("p%d", i+1), |
| 370 | + Name: fmt.Sprintf("p%d", len(params)+1), |
369 | 371 | Type: "any", |
370 | 372 | }) |
371 | 373 | } |
372 | 374 |
|
373 | 375 | return params |
374 | 376 | } |
375 | 377 |
|
| 378 | +// walkStatement walks a statement and calls fn for each expression. |
| 379 | +func walkStatement(stmt dcast.Statement, fn func(dcast.Expression)) { |
| 380 | + switch s := stmt.(type) { |
| 381 | + case *dcast.SelectQuery: |
| 382 | + walkSelectQuery(s, fn) |
| 383 | + case *dcast.SelectWithUnionQuery: |
| 384 | + for _, sel := range s.Selects { |
| 385 | + walkStatement(sel, fn) |
| 386 | + } |
| 387 | + case *dcast.InsertQuery: |
| 388 | + if s.Select != nil { |
| 389 | + walkStatement(s.Select, fn) |
| 390 | + } |
| 391 | + } |
| 392 | +} |
| 393 | + |
| 394 | +// walkSelectQuery walks a SELECT query and calls fn for each expression. |
| 395 | +func walkSelectQuery(s *dcast.SelectQuery, fn func(dcast.Expression)) { |
| 396 | + // Walk columns |
| 397 | + for _, col := range s.Columns { |
| 398 | + walkExpression(col, fn) |
| 399 | + } |
| 400 | + // Walk WHERE clause |
| 401 | + if s.Where != nil { |
| 402 | + walkExpression(s.Where, fn) |
| 403 | + } |
| 404 | + // Walk GROUP BY |
| 405 | + for _, g := range s.GroupBy { |
| 406 | + walkExpression(g, fn) |
| 407 | + } |
| 408 | + // Walk HAVING |
| 409 | + if s.Having != nil { |
| 410 | + walkExpression(s.Having, fn) |
| 411 | + } |
| 412 | + // Walk ORDER BY |
| 413 | + for _, o := range s.OrderBy { |
| 414 | + walkExpression(o.Expression, fn) |
| 415 | + } |
| 416 | + // Walk LIMIT |
| 417 | + if s.Limit != nil { |
| 418 | + walkExpression(s.Limit, fn) |
| 419 | + } |
| 420 | + // Walk OFFSET |
| 421 | + if s.Offset != nil { |
| 422 | + walkExpression(s.Offset, fn) |
| 423 | + } |
| 424 | +} |
| 425 | + |
| 426 | +// walkExpression walks an expression and calls fn for each sub-expression. |
| 427 | +func walkExpression(expr dcast.Expression, fn func(dcast.Expression)) { |
| 428 | + if expr == nil { |
| 429 | + return |
| 430 | + } |
| 431 | + fn(expr) |
| 432 | + |
| 433 | + switch e := expr.(type) { |
| 434 | + case *dcast.BinaryExpr: |
| 435 | + walkExpression(e.Left, fn) |
| 436 | + walkExpression(e.Right, fn) |
| 437 | + case *dcast.UnaryExpr: |
| 438 | + walkExpression(e.Operand, fn) |
| 439 | + case *dcast.FunctionCall: |
| 440 | + for _, arg := range e.Arguments { |
| 441 | + walkExpression(arg, fn) |
| 442 | + } |
| 443 | + case *dcast.Subquery: |
| 444 | + walkStatement(e.Query, fn) |
| 445 | + case *dcast.CaseExpr: |
| 446 | + if e.Operand != nil { |
| 447 | + walkExpression(e.Operand, fn) |
| 448 | + } |
| 449 | + for _, when := range e.Whens { |
| 450 | + walkExpression(when.Condition, fn) |
| 451 | + walkExpression(when.Result, fn) |
| 452 | + } |
| 453 | + if e.Else != nil { |
| 454 | + walkExpression(e.Else, fn) |
| 455 | + } |
| 456 | + case *dcast.InExpr: |
| 457 | + walkExpression(e.Expr, fn) |
| 458 | + for _, v := range e.List { |
| 459 | + walkExpression(v, fn) |
| 460 | + } |
| 461 | + if e.Query != nil { |
| 462 | + walkStatement(e.Query, fn) |
| 463 | + } |
| 464 | + case *dcast.BetweenExpr: |
| 465 | + walkExpression(e.Expr, fn) |
| 466 | + walkExpression(e.Low, fn) |
| 467 | + walkExpression(e.High, fn) |
| 468 | + } |
| 469 | +} |
| 470 | + |
376 | 471 | // addLimit0 adds LIMIT 0 to a query for schema introspection. |
377 | 472 | func addLimit0(query string) string { |
378 | 473 | // Simple approach: append LIMIT 0 if not already present |
|
0 commit comments