From e8edcb12ce9ff4303d225501dfdf0908018612aa Mon Sep 17 00:00:00 2001 From: Jacob Alberty Date: Sun, 29 Jun 2025 02:16:09 -0500 Subject: [PATCH 1/2] feat: Add multistmt support to firebird driver --- database/firebird/README.md | 4 ++- database/firebird/firebird.go | 61 ++++++++++++++++++++++++++++++++--- 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/database/firebird/README.md b/database/firebird/README.md index bdfef8aa9..cfdb14fd3 100644 --- a/database/firebird/README.md +++ b/database/firebird/README.md @@ -5,8 +5,10 @@ | URL Query | WithInstance Config | Description | |------------|---------------------|-------------| | `x-migrations-table` | `MigrationsTable` | Name of the migrations table | +| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) | +| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) | | `auth_plugin_name` | | Authentication plugin name. Srp256/Srp/Legacy_Auth are available. (default is Srp) | | `column_name_to_lower` | | Force column name to lower. (default is false) | | `role` | | Role name | | `tzname` | | Time Zone name. (For Firebird 4.0+) | -| `wire_crypt` | | Enable wire data encryption or not. For Firebird 3.0+ (default is true) | +| `wire_crypt` | | Enable wire data encryption or not. For Firebird 3.0+ (default is true) | \ No newline at end of file diff --git a/database/firebird/firebird.go b/database/firebird/firebird.go index e15ea96b8..a725bfd0a 100644 --- a/database/firebird/firebird.go +++ b/database/firebird/firebird.go @@ -8,9 +8,12 @@ import ( "fmt" "io" nurl "net/url" + "strconv" + "strings" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database" + "github.com/golang-migrate/migrate/v4/database/multistmt" "github.com/hashicorp/go-multierror" _ "github.com/nakagami/firebirdsql" "go.uber.org/atomic" @@ -22,15 +25,22 @@ func init() { database.Register("firebirdsql", &db) } -var DefaultMigrationsTable = "schema_migrations" +var ( + multiStmtDelimiter = []byte(";") + + DefaultMigrationsTable = "schema_migrations" + DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB +) var ( ErrNilConfig = fmt.Errorf("no config") ) type Config struct { - DatabaseName string - MigrationsTable string + DatabaseName string + MigrationsTable string + MultiStatementEnabled bool + MultiStatementMaxSize int } type Firebird struct { @@ -85,9 +95,30 @@ func (f *Firebird) Open(dsn string) (database.Driver, error) { return nil, err } + multiStatementMaxSize := DefaultMultiStatementMaxSize + if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 { + multiStatementMaxSize, err = strconv.Atoi(s) + if err != nil { + return nil, err + } + if multiStatementMaxSize <= 0 { + multiStatementMaxSize = DefaultMultiStatementMaxSize + } + } + + multiStatementEnabled := false + if s := purl.Query().Get("x-multi-statement"); len(s) > 0 { + multiStatementEnabled, err = strconv.ParseBool(s) + if err != nil { + return nil, fmt.Errorf("unable to parse option x-multi-statement: %w", err) + } + } + px, err := WithInstance(db, &Config{ - MigrationsTable: purl.Query().Get("x-migrations-table"), - DatabaseName: purl.Path, + MigrationsTable: purl.Query().Get("x-migrations-table"), + DatabaseName: purl.Path, + MultiStatementEnabled: multiStatementEnabled, + MultiStatementMaxSize: multiStatementMaxSize, }) if err != nil { @@ -121,6 +152,26 @@ func (f *Firebird) Unlock() error { } func (f *Firebird) Run(migration io.Reader) error { + if f.config.MultiStatementEnabled { + var err error + + if e := multistmt.Parse(migration, multiStmtDelimiter, f.config.MultiStatementMaxSize, func(m []byte) bool { + query := strings.TrimSpace(string(m)) + if len(query) == 0 { + return true + } + if _, err = f.conn.ExecContext(context.Background(), query); err != nil { + return false // stop parsing on error + } + return true // continue parsing + }); e != nil { + return &database.Error{OrigErr: e, Err: "error parsing multi-statement migration"} + } + if err != nil { + return &database.Error{OrigErr: err, Err: "error executing multi-statement migration"} + } + return nil + } migr, err := io.ReadAll(migration) if err != nil { return err From 41ac4d761569ca48526257db7179bf71e51666eb Mon Sep 17 00:00:00 2001 From: Jacob Alberty Date: Sun, 29 Jun 2025 12:48:01 -0500 Subject: [PATCH 2/2] test: Add tests for multi statement parsing --- database/firebird/firebird_test.go | 209 ++++++++++++++++++++++++++++- 1 file changed, 206 insertions(+), 3 deletions(-) diff --git a/database/firebird/firebird_test.go b/database/firebird/firebird_test.go index 1e6701c4e..3977e8cf4 100644 --- a/database/firebird/firebird_test.go +++ b/database/firebird/firebird_test.go @@ -5,15 +5,17 @@ import ( "database/sql" sqldriver "database/sql/driver" "fmt" - "log" - - "github.com/golang-migrate/migrate/v4" "io" + "log" + nurl "net/url" + "strconv" "strings" "testing" "github.com/dhui/dktest" + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/multistmt" dt "github.com/golang-migrate/migrate/v4/database/testing" "github.com/golang-migrate/migrate/v4/dktesting" _ "github.com/golang-migrate/migrate/v4/source/file" @@ -126,6 +128,41 @@ func TestMigrate(t *testing.T) { }) } +func TestMultipleStatementsInMultiStatementMode(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := fbConnectionString(ip, port) + "?x-multi-statement=true" + p := &Firebird{} + d, err := p.Open(addr) + if err != nil { + t.Fatal(err) + } + defer func() { + if err := d.Close(); err != nil { + t.Error(err) + } + }() + // Use CREATE INDEX instead of CONCURRENTLY (Firebird doesn't support CREATE INDEX CONCURRENTLY) + if err := d.Run(strings.NewReader("CREATE TABLE foo (foo VARCHAR(40)); CREATE INDEX idx_foo ON foo (foo);")); err != nil { + t.Fatalf("expected err to be nil, got %v", err) + } + + // make sure created index exists + var exists bool + query := "SELECT CASE WHEN EXISTS (SELECT 1 FROM RDB$INDICES WHERE RDB$INDEX_NAME = 'IDX_FOO') THEN 1 ELSE 0 END FROM RDB$DATABASE" + if err := d.(*Firebird).conn.QueryRowContext(context.Background(), query).Scan(&exists); err != nil { + t.Fatal(err) + } + if !exists { + t.Fatalf("expected index idx_foo to exist") + } + }) +} + func TestErrorParsing(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() @@ -225,3 +262,169 @@ func Test_Lock(t *testing.T) { } }) } + +func TestMultiStatementURLParsing(t *testing.T) { + tests := []struct { + name string + url string + expectedMultiStmt bool + expectedMultiStmtSize int + shouldError bool + }{ + { + name: "multi-statement enabled", + url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true", + expectedMultiStmt: true, + expectedMultiStmtSize: DefaultMultiStatementMaxSize, + shouldError: false, + }, + { + name: "multi-statement disabled", + url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=false", + expectedMultiStmt: false, + expectedMultiStmtSize: DefaultMultiStatementMaxSize, + shouldError: false, + }, + { + name: "multi-statement with custom size", + url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=5242880", + expectedMultiStmt: true, + expectedMultiStmtSize: 5242880, + shouldError: false, + }, + { + name: "multi-statement with invalid size falls back to default", + url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=0", + expectedMultiStmt: true, + expectedMultiStmtSize: DefaultMultiStatementMaxSize, + shouldError: false, + }, + { + name: "invalid boolean value should error", + url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=invalid", + expectedMultiStmt: false, + expectedMultiStmtSize: DefaultMultiStatementMaxSize, + shouldError: true, + }, + { + name: "invalid size value should error", + url: "firebird://user:pass@localhost:3050//path/to/db.fdb?x-multi-statement=true&x-multi-statement-max-size=invalid", + expectedMultiStmt: true, + expectedMultiStmtSize: DefaultMultiStatementMaxSize, + shouldError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // We can't actually open a database connection without Docker, + // but we can test the URL parsing logic by examining how Open would behave + purl, err := nurl.Parse(tt.url) + if err != nil { + if !tt.shouldError { + t.Fatalf("parseURL failed: %v", err) + } + return + } + + // Test multi-statement parameter parsing + multiStatementEnabled := false + multiStatementMaxSize := DefaultMultiStatementMaxSize + + if s := purl.Query().Get("x-multi-statement"); len(s) > 0 { + multiStatementEnabled, err = strconv.ParseBool(s) + if err != nil { + if tt.shouldError { + return // Expected error + } + t.Fatalf("unable to parse option x-multi-statement: %v", err) + } + } + + if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 { + multiStatementMaxSize, err = strconv.Atoi(s) + if err != nil { + if tt.shouldError { + return // Expected error + } + t.Fatalf("unable to parse x-multi-statement-max-size: %v", err) + } + if multiStatementMaxSize <= 0 { + multiStatementMaxSize = DefaultMultiStatementMaxSize + } + } + + if tt.shouldError { + t.Fatalf("expected error but got none") + } + + if multiStatementEnabled != tt.expectedMultiStmt { + t.Errorf("expected MultiStatementEnabled to be %v, got %v", tt.expectedMultiStmt, multiStatementEnabled) + } + + if multiStatementMaxSize != tt.expectedMultiStmtSize { + t.Errorf("expected MultiStatementMaxSize to be %d, got %d", tt.expectedMultiStmtSize, multiStatementMaxSize) + } + }) + } +} + +func TestMultiStatementParsing(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "single statement", + input: "CREATE TABLE test (id INTEGER);", + expected: []string{"CREATE TABLE test (id INTEGER);"}, + }, + { + name: "multiple statements", + input: "CREATE TABLE foo (id INTEGER); CREATE TABLE bar (name VARCHAR(50));", + expected: []string{"CREATE TABLE foo (id INTEGER);", "CREATE TABLE bar (name VARCHAR(50));"}, + }, + { + name: "statements with whitespace", + input: "CREATE TABLE foo (id INTEGER);\n\n CREATE TABLE bar (name VARCHAR(50)); \n", + expected: []string{"CREATE TABLE foo (id INTEGER);", "CREATE TABLE bar (name VARCHAR(50));"}, + }, + { + name: "empty statements ignored", + input: "CREATE TABLE foo (id INTEGER);;CREATE TABLE bar (name VARCHAR(50));", + expected: []string{"CREATE TABLE foo (id INTEGER);", "CREATE TABLE bar (name VARCHAR(50));"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var statements []string + reader := strings.NewReader(tt.input) + + // Simulate what the Firebird driver does with multi-statement parsing + err := multistmt.Parse(reader, multiStmtDelimiter, DefaultMultiStatementMaxSize, func(stmt []byte) bool { + query := strings.TrimSpace(string(stmt)) + // Skip empty statements and standalone semicolons + if len(query) > 0 && query != ";" { + statements = append(statements, query) + } + return true // continue parsing + }) + + if err != nil { + t.Fatalf("parsing failed: %v", err) + } + + if len(statements) != len(tt.expected) { + t.Fatalf("expected %d statements, got %d: %v", len(tt.expected), len(statements), statements) + } + + for i, expected := range tt.expected { + if statements[i] != expected { + t.Errorf("statement %d: expected %q, got %q", i, expected, statements[i]) + } + } + }) + } +}