Skip to content

Commit 41ac4d7

Browse files
committed
test: Add tests for multi statement parsing
1 parent e8edcb1 commit 41ac4d7

File tree

1 file changed

+206
-3
lines changed

1 file changed

+206
-3
lines changed

database/firebird/firebird_test.go

Lines changed: 206 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,17 @@ import (
55
"database/sql"
66
sqldriver "database/sql/driver"
77
"fmt"
8-
"log"
9-
10-
"github.com/golang-migrate/migrate/v4"
118
"io"
9+
"log"
10+
nurl "net/url"
11+
"strconv"
1212
"strings"
1313
"testing"
1414

1515
"github.com/dhui/dktest"
1616

17+
"github.com/golang-migrate/migrate/v4"
18+
"github.com/golang-migrate/migrate/v4/database/multistmt"
1719
dt "github.com/golang-migrate/migrate/v4/database/testing"
1820
"github.com/golang-migrate/migrate/v4/dktesting"
1921
_ "github.com/golang-migrate/migrate/v4/source/file"
@@ -126,6 +128,41 @@ func TestMigrate(t *testing.T) {
126128
})
127129
}
128130

131+
func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
132+
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
133+
ip, port, err := c.FirstPort()
134+
if err != nil {
135+
t.Fatal(err)
136+
}
137+
138+
addr := fbConnectionString(ip, port) + "?x-multi-statement=true"
139+
p := &Firebird{}
140+
d, err := p.Open(addr)
141+
if err != nil {
142+
t.Fatal(err)
143+
}
144+
defer func() {
145+
if err := d.Close(); err != nil {
146+
t.Error(err)
147+
}
148+
}()
149+
// Use CREATE INDEX instead of CONCURRENTLY (Firebird doesn't support CREATE INDEX CONCURRENTLY)
150+
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo VARCHAR(40)); CREATE INDEX idx_foo ON foo (foo);")); err != nil {
151+
t.Fatalf("expected err to be nil, got %v", err)
152+
}
153+
154+
// make sure created index exists
155+
var exists bool
156+
query := "SELECT CASE WHEN EXISTS (SELECT 1 FROM RDB$INDICES WHERE RDB$INDEX_NAME = 'IDX_FOO') THEN 1 ELSE 0 END FROM RDB$DATABASE"
157+
if err := d.(*Firebird).conn.QueryRowContext(context.Background(), query).Scan(&exists); err != nil {
158+
t.Fatal(err)
159+
}
160+
if !exists {
161+
t.Fatalf("expected index idx_foo to exist")
162+
}
163+
})
164+
}
165+
129166
func TestErrorParsing(t *testing.T) {
130167
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
131168
ip, port, err := c.FirstPort()
@@ -225,3 +262,169 @@ func Test_Lock(t *testing.T) {
225262
}
226263
})
227264
}
265+
266+
func TestMultiStatementURLParsing(t *testing.T) {
267+
tests := []struct {
268+
name string
269+
url string
270+
expectedMultiStmt bool
271+
expectedMultiStmtSize int
272+
shouldError bool
273+
}{
274+
{
275+
name: "multi-statement enabled",
276+
url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true",
277+
expectedMultiStmt: true,
278+
expectedMultiStmtSize: DefaultMultiStatementMaxSize,
279+
shouldError: false,
280+
},
281+
{
282+
name: "multi-statement disabled",
283+
url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=false",
284+
expectedMultiStmt: false,
285+
expectedMultiStmtSize: DefaultMultiStatementMaxSize,
286+
shouldError: false,
287+
},
288+
{
289+
name: "multi-statement with custom size",
290+
url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=5242880",
291+
expectedMultiStmt: true,
292+
expectedMultiStmtSize: 5242880,
293+
shouldError: false,
294+
},
295+
{
296+
name: "multi-statement with invalid size falls back to default",
297+
url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=0",
298+
expectedMultiStmt: true,
299+
expectedMultiStmtSize: DefaultMultiStatementMaxSize,
300+
shouldError: false,
301+
},
302+
{
303+
name: "invalid boolean value should error",
304+
url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=invalid",
305+
expectedMultiStmt: false,
306+
expectedMultiStmtSize: DefaultMultiStatementMaxSize,
307+
shouldError: true,
308+
},
309+
{
310+
name: "invalid size value should error",
311+
url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=invalid",
312+
expectedMultiStmt: true,
313+
expectedMultiStmtSize: DefaultMultiStatementMaxSize,
314+
shouldError: true,
315+
},
316+
}
317+
318+
for _, tt := range tests {
319+
t.Run(tt.name, func(t *testing.T) {
320+
// We can't actually open a database connection without Docker,
321+
// but we can test the URL parsing logic by examining how Open would behave
322+
purl, err := nurl.Parse(tt.url)
323+
if err != nil {
324+
if !tt.shouldError {
325+
t.Fatalf("parseURL failed: %v", err)
326+
}
327+
return
328+
}
329+
330+
// Test multi-statement parameter parsing
331+
multiStatementEnabled := false
332+
multiStatementMaxSize := DefaultMultiStatementMaxSize
333+
334+
if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
335+
multiStatementEnabled, err = strconv.ParseBool(s)
336+
if err != nil {
337+
if tt.shouldError {
338+
return // Expected error
339+
}
340+
t.Fatalf("unable to parse option x-multi-statement: %v", err)
341+
}
342+
}
343+
344+
if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
345+
multiStatementMaxSize, err = strconv.Atoi(s)
346+
if err != nil {
347+
if tt.shouldError {
348+
return // Expected error
349+
}
350+
t.Fatalf("unable to parse x-multi-statement-max-size: %v", err)
351+
}
352+
if multiStatementMaxSize <= 0 {
353+
multiStatementMaxSize = DefaultMultiStatementMaxSize
354+
}
355+
}
356+
357+
if tt.shouldError {
358+
t.Fatalf("expected error but got none")
359+
}
360+
361+
if multiStatementEnabled != tt.expectedMultiStmt {
362+
t.Errorf("expected MultiStatementEnabled to be %v, got %v", tt.expectedMultiStmt, multiStatementEnabled)
363+
}
364+
365+
if multiStatementMaxSize != tt.expectedMultiStmtSize {
366+
t.Errorf("expected MultiStatementMaxSize to be %d, got %d", tt.expectedMultiStmtSize, multiStatementMaxSize)
367+
}
368+
})
369+
}
370+
}
371+
372+
func TestMultiStatementParsing(t *testing.T) {
373+
tests := []struct {
374+
name string
375+
input string
376+
expected []string
377+
}{
378+
{
379+
name: "single statement",
380+
input: "CREATE TABLE test (id INTEGER);",
381+
expected: []string{"CREATE TABLE test (id INTEGER);"},
382+
},
383+
{
384+
name: "multiple statements",
385+
input: "CREATE TABLE foo (id INTEGER); CREATE TABLE bar (name VARCHAR(50));",
386+
expected: []string{"CREATE TABLE foo (id INTEGER);", "CREATE TABLE bar (name VARCHAR(50));"},
387+
},
388+
{
389+
name: "statements with whitespace",
390+
input: "CREATE TABLE foo (id INTEGER);\n\n CREATE TABLE bar (name VARCHAR(50)); \n",
391+
expected: []string{"CREATE TABLE foo (id INTEGER);", "CREATE TABLE bar (name VARCHAR(50));"},
392+
},
393+
{
394+
name: "empty statements ignored",
395+
input: "CREATE TABLE foo (id INTEGER);;CREATE TABLE bar (name VARCHAR(50));",
396+
expected: []string{"CREATE TABLE foo (id INTEGER);", "CREATE TABLE bar (name VARCHAR(50));"},
397+
},
398+
}
399+
400+
for _, tt := range tests {
401+
t.Run(tt.name, func(t *testing.T) {
402+
var statements []string
403+
reader := strings.NewReader(tt.input)
404+
405+
// Simulate what the Firebird driver does with multi-statement parsing
406+
err := multistmt.Parse(reader, multiStmtDelimiter, DefaultMultiStatementMaxSize, func(stmt []byte) bool {
407+
query := strings.TrimSpace(string(stmt))
408+
// Skip empty statements and standalone semicolons
409+
if len(query) > 0 && query != ";" {
410+
statements = append(statements, query)
411+
}
412+
return true // continue parsing
413+
})
414+
415+
if err != nil {
416+
t.Fatalf("parsing failed: %v", err)
417+
}
418+
419+
if len(statements) != len(tt.expected) {
420+
t.Fatalf("expected %d statements, got %d: %v", len(tt.expected), len(statements), statements)
421+
}
422+
423+
for i, expected := range tt.expected {
424+
if statements[i] != expected {
425+
t.Errorf("statement %d: expected %q, got %q", i, expected, statements[i])
426+
}
427+
}
428+
})
429+
}
430+
}

0 commit comments

Comments
 (0)