@@ -5,15 +5,17 @@ import (
55 "database/sql"
66 sqldriver "database/sql/driver"
77 "fmt"
8- "log"
9-
10- "github.com/golang-migrate/migrate/v4"
118 "io"
9+ "log"
10+ nurl "net/url"
11+ "strconv"
1212 "strings"
1313 "testing"
1414
1515 "github.com/dhui/dktest"
1616
17+ "github.com/golang-migrate/migrate/v4"
18+ "github.com/golang-migrate/migrate/v4/database/multistmt"
1719 dt "github.com/golang-migrate/migrate/v4/database/testing"
1820 "github.com/golang-migrate/migrate/v4/dktesting"
1921 _ "github.com/golang-migrate/migrate/v4/source/file"
@@ -126,6 +128,75 @@ func TestMigrate(t *testing.T) {
126128 })
127129}
128130
131+ func TestMultipleStatements (t * testing.T ) {
132+ dktesting .ParallelTest (t , specs , func (t * testing.T , c dktest.ContainerInfo ) {
133+ ip , port , err := c .FirstPort ()
134+ if err != nil {
135+ t .Fatal (err )
136+ }
137+
138+ addr := fbConnectionString (ip , port )
139+ p := & Firebird {}
140+ d , err := p .Open (addr )
141+ if err != nil {
142+ t .Fatal (err )
143+ }
144+ defer func () {
145+ if err := d .Close (); err != nil {
146+ t .Error (err )
147+ }
148+ }()
149+ if err := d .Run (strings .NewReader ("CREATE TABLE foo (foo VARCHAR(40)); CREATE TABLE bar (bar VARCHAR(40));" )); err != nil {
150+ t .Fatalf ("expected err to be nil, got %v" , err )
151+ }
152+
153+ // make sure second table exists
154+ var exists bool
155+ query := "SELECT CASE WHEN EXISTS (SELECT 1 FROM RDB$RELATIONS WHERE RDB$RELATION_NAME = 'BAR') THEN 1 ELSE 0 END FROM RDB$DATABASE"
156+ if err := d .(* Firebird ).conn .QueryRowContext (context .Background (), query ).Scan (& exists ); err != nil {
157+ t .Fatal (err )
158+ }
159+ if ! exists {
160+ t .Fatalf ("expected table bar to exist" )
161+ }
162+ })
163+ }
164+
165+ func TestMultipleStatementsInMultiStatementMode (t * testing.T ) {
166+ dktesting .ParallelTest (t , specs , func (t * testing.T , c dktest.ContainerInfo ) {
167+ ip , port , err := c .FirstPort ()
168+ if err != nil {
169+ t .Fatal (err )
170+ }
171+
172+ addr := fbConnectionString (ip , port ) + "?x-multi-statement=true"
173+ p := & Firebird {}
174+ d , err := p .Open (addr )
175+ if err != nil {
176+ t .Fatal (err )
177+ }
178+ defer func () {
179+ if err := d .Close (); err != nil {
180+ t .Error (err )
181+ }
182+ }()
183+ // Use CREATE INDEX instead of CONCURRENTLY (Firebird doesn't support CREATE INDEX CONCURRENTLY)
184+ if err := d .Run (strings .NewReader ("CREATE TABLE foo (foo VARCHAR(40)); CREATE INDEX idx_foo ON foo (foo);" )); err != nil {
185+ t .Fatalf ("expected err to be nil, got %v" , err )
186+ }
187+
188+ // make sure created index exists
189+ var exists bool
190+ query := "SELECT CASE WHEN EXISTS (SELECT 1 FROM RDB$INDICES WHERE RDB$INDEX_NAME = 'IDX_FOO') THEN 1 ELSE 0 END FROM RDB$DATABASE"
191+ if err := d .(* Firebird ).conn .QueryRowContext (context .Background (), query ).Scan (& exists ); err != nil {
192+ t .Fatal (err )
193+ }
194+ if ! exists {
195+ t .Fatalf ("expected index idx_foo to exist" )
196+ }
197+ })
198+ }
199+
129200func TestErrorParsing (t * testing.T ) {
130201 dktesting .ParallelTest (t , specs , func (t * testing.T , c dktest.ContainerInfo ) {
131202 ip , port , err := c .FirstPort ()
@@ -225,3 +296,169 @@ func Test_Lock(t *testing.T) {
225296 }
226297 })
227298}
299+
300+ func TestMultiStatementURLParsing (t * testing.T ) {
301+ tests := []struct {
302+ name string
303+ url string
304+ expectedMultiStmt bool
305+ expectedMultiStmtSize int
306+ shouldError bool
307+ }{
308+ {
309+ name : "multi-statement enabled" ,
310+ url : "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true" ,
311+ expectedMultiStmt : true ,
312+ expectedMultiStmtSize : DefaultMultiStatementMaxSize ,
313+ shouldError : false ,
314+ },
315+ {
316+ name : "multi-statement disabled" ,
317+ url : "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=false" ,
318+ expectedMultiStmt : false ,
319+ expectedMultiStmtSize : DefaultMultiStatementMaxSize ,
320+ shouldError : false ,
321+ },
322+ {
323+ name : "multi-statement with custom size" ,
324+ url : "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=5242880" ,
325+ expectedMultiStmt : true ,
326+ expectedMultiStmtSize : 5242880 ,
327+ shouldError : false ,
328+ },
329+ {
330+ name : "multi-statement with invalid size falls back to default" ,
331+ url : "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=0" ,
332+ expectedMultiStmt : true ,
333+ expectedMultiStmtSize : DefaultMultiStatementMaxSize ,
334+ shouldError : false ,
335+ },
336+ {
337+ name : "invalid boolean value should error" ,
338+ url : "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=invalid" ,
339+ expectedMultiStmt : false ,
340+ expectedMultiStmtSize : DefaultMultiStatementMaxSize ,
341+ shouldError : true ,
342+ },
343+ {
344+ name : "invalid size value should error" ,
345+ url : "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=invalid" ,
346+ expectedMultiStmt : true ,
347+ expectedMultiStmtSize : DefaultMultiStatementMaxSize ,
348+ shouldError : true ,
349+ },
350+ }
351+
352+ for _ , tt := range tests {
353+ t .Run (tt .name , func (t * testing.T ) {
354+ // We can't actually open a database connection without Docker,
355+ // but we can test the URL parsing logic by examining how Open would behave
356+ purl , err := nurl .Parse (tt .url )
357+ if err != nil {
358+ if ! tt .shouldError {
359+ t .Fatalf ("parseURL failed: %v" , err )
360+ }
361+ return
362+ }
363+
364+ // Test multi-statement parameter parsing
365+ multiStatementEnabled := false
366+ multiStatementMaxSize := DefaultMultiStatementMaxSize
367+
368+ if s := purl .Query ().Get ("x-multi-statement" ); len (s ) > 0 {
369+ multiStatementEnabled , err = strconv .ParseBool (s )
370+ if err != nil {
371+ if tt .shouldError {
372+ return // Expected error
373+ }
374+ t .Fatalf ("unable to parse option x-multi-statement: %v" , err )
375+ }
376+ }
377+
378+ if s := purl .Query ().Get ("x-multi-statement-max-size" ); len (s ) > 0 {
379+ multiStatementMaxSize , err = strconv .Atoi (s )
380+ if err != nil {
381+ if tt .shouldError {
382+ return // Expected error
383+ }
384+ t .Fatalf ("unable to parse x-multi-statement-max-size: %v" , err )
385+ }
386+ if multiStatementMaxSize <= 0 {
387+ multiStatementMaxSize = DefaultMultiStatementMaxSize
388+ }
389+ }
390+
391+ if tt .shouldError {
392+ t .Fatalf ("expected error but got none" )
393+ }
394+
395+ if multiStatementEnabled != tt .expectedMultiStmt {
396+ t .Errorf ("expected MultiStatementEnabled to be %v, got %v" , tt .expectedMultiStmt , multiStatementEnabled )
397+ }
398+
399+ if multiStatementMaxSize != tt .expectedMultiStmtSize {
400+ t .Errorf ("expected MultiStatementMaxSize to be %d, got %d" , tt .expectedMultiStmtSize , multiStatementMaxSize )
401+ }
402+ })
403+ }
404+ }
405+
406+ func TestMultiStatementParsing (t * testing.T ) {
407+ tests := []struct {
408+ name string
409+ input string
410+ expected []string
411+ }{
412+ {
413+ name : "single statement" ,
414+ input : "CREATE TABLE test (id INTEGER);" ,
415+ expected : []string {"CREATE TABLE test (id INTEGER);" },
416+ },
417+ {
418+ name : "multiple statements" ,
419+ input : "CREATE TABLE foo (id INTEGER); CREATE TABLE bar (name VARCHAR(50));" ,
420+ expected : []string {"CREATE TABLE foo (id INTEGER);" , "CREATE TABLE bar (name VARCHAR(50));" },
421+ },
422+ {
423+ name : "statements with whitespace" ,
424+ input : "CREATE TABLE foo (id INTEGER);\n \n CREATE TABLE bar (name VARCHAR(50)); \n " ,
425+ expected : []string {"CREATE TABLE foo (id INTEGER);" , "CREATE TABLE bar (name VARCHAR(50));" },
426+ },
427+ {
428+ name : "empty statements ignored" ,
429+ input : "CREATE TABLE foo (id INTEGER);;CREATE TABLE bar (name VARCHAR(50));" ,
430+ expected : []string {"CREATE TABLE foo (id INTEGER);" , "CREATE TABLE bar (name VARCHAR(50));" },
431+ },
432+ }
433+
434+ for _ , tt := range tests {
435+ t .Run (tt .name , func (t * testing.T ) {
436+ var statements []string
437+ reader := strings .NewReader (tt .input )
438+
439+ // Simulate what the Firebird driver does with multi-statement parsing
440+ err := multistmt .Parse (reader , multiStmtDelimiter , DefaultMultiStatementMaxSize , func (stmt []byte ) bool {
441+ query := strings .TrimSpace (string (stmt ))
442+ // Skip empty statements and standalone semicolons
443+ if len (query ) > 0 && query != ";" {
444+ statements = append (statements , query )
445+ }
446+ return true // continue parsing
447+ })
448+
449+ if err != nil {
450+ t .Fatalf ("parsing failed: %v" , err )
451+ }
452+
453+ if len (statements ) != len (tt .expected ) {
454+ t .Fatalf ("expected %d statements, got %d: %v" , len (tt .expected ), len (statements ), statements )
455+ }
456+
457+ for i , expected := range tt .expected {
458+ if statements [i ] != expected {
459+ t .Errorf ("statement %d: expected %q, got %q" , i , expected , statements [i ])
460+ }
461+ }
462+ })
463+ }
464+ }
0 commit comments