diff --git a/.gitignore b/.gitignore
index 02f3b3a..719bfcc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,2 @@
 .vscode
 vendor
-migrations
diff --git a/db_transaction.go b/db_transaction.go
index ed0a0a7..ce570c4 100644
--- a/db_transaction.go
+++ b/db_transaction.go
@@ -54,11 +54,8 @@ func createSavepoint(ctx context.Context, tx *loggingTx) (*loggingSavepoint, err
 	}
 	savepointID := fmt.Sprintf("sp_%s", id)
 
-	if err := tx.Exec(ctx, Query(
-		// NOTE: Must interpolate identifier here as placeholders aren't valid in this position.
-		fmt.Sprintf("SAVEPOINT %s", savepointID),
-		Args{},
-	)); err != nil {
+	// NOTE: Must interpolate identifier here as placeholders aren't valid in this position.
+	if err := tx.Exec(ctx, queryf("SAVEPOINT %s", savepointID)); err != nil {
 		return nil, err
 	}
 
@@ -85,18 +82,12 @@ func (tx *loggingSavepoint) Done(err error) (combinedErr error) {
 	defer func() { logDone(tx.logger, time.Since(tx.start), combinedErr) }()
 
 	if err != nil {
-		return errors.Join(err, tx.Exec(context.Background(), Query(
-			// NOTE: Must interpolate identifier here as placeholders aren't valid in this position.
-			fmt.Sprintf("ROLLBACK TO %s", tx.savepointID),
-			Args{},
-		)))
+		// NOTE: Must interpolate identifier here as placeholders aren't valid in this position.
+		return errors.Join(err, tx.Exec(context.Background(), queryf("ROLLBACK TO %s", tx.savepointID)))
 	}
 
-	return tx.Exec(context.Background(), Query(
-		// NOTE: Must interpolate identifier here as placeholders aren't valid in this position.
-		fmt.Sprintf("RELEASE %s", tx.savepointID),
-		Args{},
-	))
+	// NOTE: Must interpolate identifier here as placeholders aren't valid in this position.
+	return tx.Exec(context.Background(), queryf("RELEASE %s", tx.savepointID))
 }
 
 var ErrPanicDuringTransaction = fmt.Errorf("encountered panic during transaction")
diff --git a/db_transaction_test.go b/db_transaction_test.go
index 6aeb65f..f51593c 100644
--- a/db_transaction_test.go
+++ b/db_transaction_test.go
@@ -18,29 +18,20 @@ func TestTransaction(t *testing.T) {
 	setupTestTransactionTable(t, db)
 
 	// Add record outside of transaction, ensure it's visible
-	err := db.Exec(context.Background(), Query(
-		`INSERT INTO test (x, y) VALUES (1, 42)`,
-		Args{},
-	))
+	err := db.Exec(context.Background(), RawQuery(`INSERT INTO test (x, y) VALUES (1, 42)`))
 	require.NoError(t, err)
 	assert.Equal(t, map[int]int{1: 42}, testTableContents(t, db))
 
 	// Add record inside of a transaction
 	tx1, err := db.Transact(context.Background())
 	require.NoError(t, err)
-	err = tx1.Exec(context.Background(), Query(
-		`INSERT INTO test (x, y) VALUES (2, 43)`,
-		Args{},
-	))
+	err = tx1.Exec(context.Background(), RawQuery(`INSERT INTO test (x, y) VALUES (2, 43)`))
 	require.NoError(t, err)
 
 	// Add record inside of another transaction
 	tx2, err := db.Transact(context.Background())
 	require.NoError(t, err)
-	err = tx2.Exec(context.Background(), Query(
-		`INSERT INTO test (x, y) VALUES (3, 44)`,
-		Args{},
-	))
+	err = tx2.Exec(context.Background(), RawQuery(`INSERT INTO test (x, y) VALUES (3, 44)`))
 	require.NoError(t, err)
 
 	// Check what's visible pre-commit/rollback
@@ -76,10 +67,7 @@ func TestConcurrentTransactions(t *testing.T) {
 			}
 			defer func() { err = tx.Done(err) }()
 
-			if err := tx.Exec(context.Background(), Query(
-				`SELECT pg_sleep(0.1)`,
-				Args{},
-			)); err != nil {
+			if err := tx.Exec(context.Background(), RawQuery(`SELECT pg_sleep(0.1)`)); err != nil {
 				return err
 			}
 
@@ -111,10 +99,7 @@ func TestConcurrentTransactions(t *testing.T) {
 		for i := 0; i < 10; i++ {
 			routine := i
 			g.Go(func() (err error) {
-				if err := tx.Exec(context.Background(), Query(
-					`SELECT pg_sleep(0.1);`,
-					Args{},
-				)); err != nil {
+				if err := tx.Exec(context.Background(), RawQuery(`SELECT pg_sleep(0.1);`)); err != nil {
 					return err
 				}
 
@@ -177,22 +162,17 @@ func recurSavepoints(t *testing.T, db DB, index, rollbackAt int) {
 }
 
 func setupTestTransactionTable(t *testing.T, db DB) {
-	require.NoError(t, db.Exec(context.Background(), Query(`
+	require.NoError(t, db.Exec(context.Background(), RawQuery(`
 		CREATE TABLE test (
 			id SERIAL PRIMARY KEY,
 			x INTEGER NOT NULL,
 			y INTEGER NOT NULL
 		);
-	`,
-		Args{},
-	)))
+	`)))
 }
 
 func testTableContents(t *testing.T, db DB) map[int]int {
-	pairs, err := scanTestPairs(db.Query(context.Background(), Query(
-		`SELECT x, y FROM test`,
-		Args{},
-	)))
+	pairs, err := scanTestPairs(db.Query(context.Background(), RawQuery(`SELECT x, y FROM test`)))
 	require.NoError(t, err)
 
 	pairsMap := make(map[int]int)
diff --git a/migration_reader.go b/migration_reader.go
index c4da427..57ba48a 100644
--- a/migration_reader.go
+++ b/migration_reader.go
@@ -23,6 +23,13 @@ type MigrationReader interface {
 	ReadAll() ([]RawDefinition, error)
 }
 
+// MigrationReaderFunc is a function adapter that implements MigrationReader.
+type MigrationReaderFunc func() ([]RawDefinition, error)
+
+func (f MigrationReaderFunc) ReadAll() ([]RawDefinition, error) {
+	return f()
+}
+
 type RawDefinition struct {
 	ID   int
 	Name string
@@ -31,22 +38,26 @@ type RawDefinition struct {
 }
 
 var (
-	identifierPattern = `[a-zA-Z0-9$_]+|"(?:[^"]+)"`
-
-	cicPatternParts = strings.Join([]string{
-		`CREATE`,
-		`(?:UNIQUE)?`,
-		`INDEX`,
-		`CONCURRENTLY`,
-		`(?:IF\s+NOT\s+EXISTS)?`,
-		`(` + identifierPattern + `)`, // capture index name
-		`ON`,
-		`(?:ONLY)?`,
-		`(` + identifierPattern + `)`, // capture table name
-	}, `\s+`)
-
-	createIndexConcurrentlyPattern    = regexp.MustCompile(cicPatternParts)
-	createIndexConcurrentlyPatternAll = regexp.MustCompile(cicPatternParts + "[^;]+;")
+	keyword = func(pattern string) string { return phrase(pattern) }
+	phrase  = func(patterns ...string) string { return strings.Join(patterns, `\s+`) + `\s+` }
+	opt     = func(pattern string) string { return `(?:` + pattern + `)?` }
+
+	capturedIdentifierPattern          = `([a-zA-Z0-9$_]+|"(?:[^"]+)")`
+	createIndexConcurrentlyPatternHead = strings.Join([]string{
+		keyword(`CREATE`),
+		opt(keyword(`UNIQUE`)),
+		keyword(`INDEX`),
+		keyword(`CONCURRENTLY`),
+		opt(phrase(`IF`, `NOT`, `EXISTS`)),
+		capturedIdentifierPattern, // capture index name
+		`\s+`,
+		keyword(`ON`),
+		opt(keyword(`ONLY`)),
+		capturedIdentifierPattern, // capture table name
+	}, ``)
+
+	createIndexConcurrentlyPattern    = regexp.MustCompile(createIndexConcurrentlyPatternHead)
+	createIndexConcurrentlyPatternAll = regexp.MustCompile(createIndexConcurrentlyPatternHead + "[^;]+;")
 )
 
 func ReadMigrations(reader MigrationReader) (definitions []Definition, _ error) {
@@ -68,7 +79,7 @@ func ReadMigrations(reader MigrationReader) (definitions []Definition, _ error)
 
 	if matches := createIndexConcurrentlyPattern.FindStringSubmatch(prunedUp); len(matches) > 0 {
 		if strings.TrimSpace(createIndexConcurrentlyPatternAll.ReplaceAllString(prunedUp, "")) != "" {
-			return nil, fmt.Errorf("CIC is not the only statement in the up migration")
+			return nil, fmt.Errorf(`"create index concurrently" is not the only statement in the up migration`)
 		}
 
 		indexMetadata = &IndexMetadata{
@@ -77,8 +88,8 @@ func ReadMigrations(reader MigrationReader) (definitions []Definition, _ error)
 		}
 	}
 
-	if len(createIndexConcurrentlyPatternAll.FindAllString(prunedDown, 1)) > 0 {
-		return nil, fmt.Errorf("CIC is not allowed in down migrations")
migrations") + if len(createIndexConcurrentlyPattern.FindAllString(prunedDown, 1)) > 0 { + return nil, fmt.Errorf(`"create index concurrently" is not allowed in down migrations`) } definitions = append(definitions, Definition{ diff --git a/migration_reader_filesystem.go b/migration_reader_filesystem.go index d44f6cc..f97c576 100644 --- a/migration_reader_filesystem.go +++ b/migration_reader_filesystem.go @@ -5,7 +5,7 @@ import ( "io/fs" "net/http" "os" - "path/filepath" + "path" "sort" "strconv" "strings" @@ -52,8 +52,8 @@ func (r *FilesystemMigrationReader) ReadAll() (definitions []RawDefinition, _ er } func (r *FilesystemMigrationReader) readDefinition(dirname string) (RawDefinition, bool, error) { - upPath := filepath.Join(dirname, "up.sql") // TODO - filepath join for embed? - downPath := filepath.Join(dirname, "down.sql") + upPath := path.Join(dirname, "up.sql") + downPath := path.Join(dirname, "down.sql") upFileContents, upErr := readFile(r.fs, upPath) downFileContents, downErr := readFile(r.fs, downPath) diff --git a/migration_reader_test.go b/migration_reader_test.go new file mode 100644 index 0000000..a6306f7 --- /dev/null +++ b/migration_reader_test.go @@ -0,0 +1,77 @@ +package pgutil + +import ( + "path" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var testUpQuery = `-- Create the comments table +CREATE TABLE comments ( + id SERIAL PRIMARY KEY, + post_id INTEGER NOT NULL REFERENCES posts(id) ON DELETE CASCADE, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + content TEXT NOT NULL, + created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), +); +` + +var testDownQuery = `-- Drop the comments table +DROP TABLE IF EXISTS comments; +` + +var testConcurrentIndexUpQuery = `-- Create a concurrent index +CREATE INDEX CONCURRENTLY idx_users_email ON users (email);` + +var testConcurrentIndexDownQuery = `-- Drop the concurrent index +DROP INDEX CONCURRENTLY IF EXISTS idx_users_email;` + +func TestReadMigrations(t *testing.T) { + t.Run("valid", func(t *testing.T) { + definitions, err := ReadMigrations(NewFilesystemMigrationReader(path.Join("testdata", "valid"))) + require.NoError(t, err) + require.Len(t, definitions, 3) + + assert.Equal(t, Definition{ + ID: 3, + Name: "third", + UpQuery: RawQuery(testUpQuery), + DownQuery: RawQuery(testDownQuery), + }, definitions[2]) + }) + + t.Run("CIC pattern", func(t *testing.T) { + t.Skip() + definitions, err := ReadMigrations(NewFilesystemMigrationReader(path.Join("testdata", "cic_pattern"))) + require.NoError(t, err) + require.Len(t, definitions, 4) + + assert.Equal(t, Definition{ + ID: 3, + Name: "third", + UpQuery: RawQuery(testConcurrentIndexUpQuery), + DownQuery: RawQuery(testConcurrentIndexDownQuery), + IndexMetadata: &IndexMetadata{ + TableName: "users", + IndexName: "idx_users_email", + }, + }, definitions[3]) + }) + + t.Run("duplicate identifiers", func(t *testing.T) { + _, err := ReadMigrations(NewFilesystemMigrationReader(path.Join("testdata", "duplicate_identifiers"))) + assert.ErrorContains(t, err, "duplicate migration identifier 2") + }) + + t.Run("CIC with additional queries", func(t *testing.T) { + _, err := ReadMigrations(NewFilesystemMigrationReader(path.Join("testdata", "cic_with_additional_queries"))) + assert.ErrorContains(t, err, `"create index concurrently" is not the only statement in the up migration`) + }) + + t.Run("CIC in down migration", func(t *testing.T) { + _, err := ReadMigrations(NewFilesystemMigrationReader(path.Join("testdata", 
"cic_in_down_migration"))) + assert.ErrorContains(t, err, `"create index concurrently" is not the only statement in the up migration`) + }) +} diff --git a/migration_runner.go b/migration_runner.go index 349007d..61f44ff 100644 --- a/migration_runner.go +++ b/migration_runner.go @@ -18,8 +18,6 @@ type Runner struct { locker *TransactionalLocker } -// TODO - additional logging - func NewMigrationRunner(db DB, reader MigrationReader, logger nacelle.Logger) (*Runner, error) { definitions, err := ReadMigrations(reader) if err != nil { @@ -132,7 +130,7 @@ func (r *Runner) applyDefinitions(ctx context.Context, definitions []Definition, } if len(migrationsToApply) == 0 { - // Nothing to apply + r.logger.Info("No migrations to apply") upToDate = true return nil } @@ -147,11 +145,17 @@ func (r *Runner) applyDefinitions(ctx context.Context, definitions []Definition, if err := r.withMigrationLog(ctx, definition, reverse, func(_ int) error { return r.db.WithTransaction(ctx, func(tx DB) error { - query := definition.UpQuery + query, direction := definition.UpQuery, "up" if reverse { - query = definition.DownQuery + query, direction = definition.DownQuery, "down" } + r.logger.InfoWithFields(log.LogFields{ + "id": definition.ID, + "name": definition.Name, + "direction": direction, + }, "Applying migration") + return tx.Exec(ctx, query) }) }); err != nil { @@ -169,6 +173,15 @@ func (r *Runner) applyConcurrentIndexCreation(ctx context.Context, definition De tableName := definition.IndexMetadata.TableName indexName := definition.IndexMetadata.IndexName + logger := r.logger.WithFields(log.LogFields{ + "id": definition.ID, + "name": definition.Name, + "direction": "up", + "tableName": tableName, + "indexName": indexName, + }) + logger.Info("Handling concurrent index creation") + indexPollLoop: for i := 0; ; i++ { if i != 0 { @@ -183,11 +196,25 @@ indexPollLoop: } if exists { + logger.InfoWithFields(log.LogFields{ + "phase": deref(indexStatus.Phase), + "lockersTotal": deref(indexStatus.LockersTotal), + "lockersDone": deref(indexStatus.LockersDone), + "blocksTotal": deref(indexStatus.BlocksTotal), + "blocksDone": deref(indexStatus.BlocksDone), + "tuplesTotal": deref(indexStatus.TuplesTotal), + "tuplesDone": deref(indexStatus.TuplesDone), + }, "Index exists") + if indexStatus.IsValid { + logger.Info("Index is valid") + if recheck, err := r.handleValidIndex(ctx, definition); err != nil { return err } else if recheck { continue indexPollLoop + } else { + return nil } } @@ -195,14 +222,16 @@ indexPollLoop: continue indexPollLoop } - if err := r.db.Exec(ctx, Query( - "DROP INDEX IF EXISTS {:indexName}", - Args{"indexName": indexName}, - )); err != nil { + logger.Info("Dropping invalid index") + + // NOTE: Must interpolate identifier here as placeholders aren't valid in this position. 
+		if err := r.db.Exec(ctx, queryf(`DROP INDEX IF EXISTS %s`, indexName)); err != nil {
 			return err
 		}
 	}
 
+	logger.Info("Creating index")
+
 	if raceDetected, err := r.createIndexConcurrently(ctx, definition); err != nil {
 		return err
 	} else if raceDetected {
@@ -219,14 +248,13 @@ func (r *Runner) handleValidIndex(ctx context.Context, definition Definition) (r
 	if err != nil {
 		return err
 	}
-
 	if !ok {
 		if err := tx.Exec(ctx, Query(`
 			INSERT INTO migration_logs (migration_id, reverse, finished_at, success)
 			VALUES ({:id}, false, current_timestamp, true)
-		`, Args{
-			"id": definition.ID,
-		})); err != nil {
+		`,
+			Args{"id": definition.ID},
+		)); err != nil {
 			return err
 		}
 
@@ -250,9 +278,9 @@ func (r *Runner) handleValidIndex(ctx context.Context, definition Definition) (r
 			UPDATE migration_logs
 			SET success = true, finished_at = current_timestamp
 			WHERE id = {:id}
-		`, Args{
-			"id": log.ID,
-		})); err != nil {
+		`,
+			Args{"id": log.ID},
+		)); err != nil {
 			return err
 		}
 
@@ -273,9 +301,9 @@ func (r *Runner) createIndexConcurrently(ctx context.Context, definition Definit
 			UPDATE migration_logs
 			SET last_heartbeat_at = current_timestamp
 			WHERE id = {:id}
-		`, Args{
-			"id": id,
-		})); err != nil {
+		`,
+			Args{"id": id},
+		)); err != nil && ctx.Err() != context.Canceled {
 			r.logger.ErrorWithFields(log.LogFields{
 				"error": err,
 			}, "Failed to update heartbeat")
@@ -293,7 +321,10 @@ func (r *Runner) createIndexConcurrently(ctx context.Context, definition Definit
 		return err
 	}
 
-	if err := r.db.Exec(ctx, Query("DELETE FROM migration_logs WHERE id = {:id}", Args{"id": id})); err != nil {
+	if err := r.db.Exec(ctx, Query(
+		`DELETE FROM migration_logs WHERE id = {:id}`,
+		Args{"id": id},
+	)); err != nil {
 		return err
 	}
 
@@ -466,9 +497,9 @@ func (r *Runner) getLogForConcurrentIndex(ctx context.Context, db DB, id int) (c
 			COALESCE(last_heartbeat_at, started_at)
 		FROM ranked_migration_logs
 		WHERE rank = 1 AND NOT reverse
-	`, Args{
-		"id": id,
-	})))
+	`,
+		Args{"id": id},
+	)))
 }
 
 //
diff --git a/migration_runner_test.go b/migration_runner_test.go
index 44ff93a..e8b9beb 100644
--- a/migration_runner_test.go
+++ b/migration_runner_test.go
@@ -1,3 +1,330 @@
 package pgutil
 
-// TODO
+import (
+	"context"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/go-nacelle/log/v2"
+	"github.com/stretchr/testify/require"
+)
+
+func TestApply(t *testing.T) {
+	definitions := []RawDefinition{
+		{ID: 1, RawUpQuery: "CREATE TABLE users (id SERIAL PRIMARY KEY, email TEXT);"},
+		{ID: 2, RawUpQuery: "INSERT INTO users (email) VALUES ('test@gmail.com');"},
+		{ID: 3, RawUpQuery: "ALTER TABLE users ADD COLUMN name TEXT;"},
+		{ID: 4, RawUpQuery: "UPDATE users SET name = 'test';"},
+		{ID: 5, RawUpQuery: "CREATE UNIQUE INDEX users_email_idx ON users (email);"},
+	}
+	definitionsWithoutUpdates := []RawDefinition{definitions[0], definitions[1], definitions[2], definitions[4]}
+	reader := MigrationReaderFunc(func() ([]RawDefinition, error) { return definitions, nil })
+	readerWithoutUpdates := MigrationReaderFunc(func() ([]RawDefinition, error) { return definitionsWithoutUpdates, nil })
+
+	t.Run("all", func(t *testing.T) {
+		db := NewTestDB(t)
+		ctx := context.Background()
+
+		// Apply all migrations from scratch
+		runner, err := NewMigrationRunner(db, reader, log.NewNilLogger())
+		require.NoError(t, err)
+		require.NoError(t, runner.ApplyAll(ctx))
+
+		// Assert last migration (unique index) was applied
+		err = db.Exec(ctx, Query(
+			"INSERT INTO users (name, email) VALUES ({:name}, {:email})",
+			Args{"name": "duplicate", "email": "test@gmail.com"},
+		))
+		require.ErrorContains(t, err, "duplicate key value violates unique constraint")
+	})
+
+	t.Run("tail", func(t *testing.T) {
+		db := NewTestDB(t)
+		ctx := context.Background()
+
+		// Apply the head first
+		runner, err := NewMigrationRunner(db, reader, log.NewNilLogger())
+		require.NoError(t, err)
+		require.NoError(t, runner.Apply(ctx, 2))
+
+		// Assert no name column yet
+		_, _, err = ScanString(db.Query(ctx, RawQuery("SELECT name FROM users WHERE email = 'test@gmail.com'")))
+		require.ErrorContains(t, err, "column \"name\" does not exist")
+
+		// Apply the tail
+		require.NoError(t, runner.Apply(ctx, 5))
+
+		// Assert name column added and populated
+		name, _, err := ScanString(db.Query(ctx, RawQuery("SELECT name FROM users WHERE email = 'test@gmail.com'")))
+		require.NoError(t, err)
+		require.Equal(t, "test", name)
+	})
+
+	t.Run("gaps", func(t *testing.T) {
+		db := NewTestDB(t)
+		ctx := context.Background()
+
+		// Apply all migrations except #4
+		runnerWithHoles, err := NewMigrationRunner(db, readerWithoutUpdates, log.NewNilLogger())
+		require.NoError(t, err)
+		require.NoError(t, runnerWithHoles.ApplyAll(ctx))
+
+		// Assert name column exists but is not yet populated
+		namePtr, _, err := ScanNilString(db.Query(ctx, RawQuery("SELECT name FROM users WHERE email = 'test@gmail.com'")))
+		require.NoError(t, err)
+		require.Nil(t, namePtr)
+
+		// Apply all missing migrations
+		runner, err := NewMigrationRunner(db, reader, log.NewNilLogger())
+		require.NoError(t, err)
+		require.NoError(t, runner.ApplyAll(ctx))
+
+		// Assert name column now populated
+		name, _, err := ScanString(db.Query(ctx, RawQuery("SELECT name FROM users WHERE email = 'test@gmail.com'")))
+		require.NoError(t, err)
+		require.Equal(t, "test", name)
+	})
+}
+
+func TestApplyCreateConcurrentIndex(t *testing.T) {
+	definitions := []RawDefinition{
+		{ID: 1, RawUpQuery: "CREATE TABLE users (id SERIAL PRIMARY KEY, name TEXT NOT NULL, email TEXT NOT NULL);"},
+		{ID: 2, RawUpQuery: "INSERT INTO users (name, email) VALUES ('test1', 'test1@gmail.com');"},
+		{ID: 3, RawUpQuery: "CREATE UNIQUE INDEX CONCURRENTLY users_email_idx ON users (email);"},
+		{ID: 4, RawUpQuery: "INSERT INTO users (name, email) VALUES ('test2', 'test2@gmail.com');"},
+	}
+	reader := MigrationReaderFunc(func() ([]RawDefinition, error) { return definitions, nil })
+
+	t.Run("CIC", func(t *testing.T) {
+		db := NewTestDB(t)
+		ctx := context.Background()
+
+		// Apply all migrations from scratch
+		runner, err := NewMigrationRunner(db, reader, log.NewNilLogger())
+		require.NoError(t, err)
+		require.NoError(t, runner.ApplyAll(ctx))
+
+		// Assert last migration (unique index) was applied
+		err = db.Exec(ctx, Query(
+			"INSERT INTO users (name, email) VALUES ({:name}, {:email})",
+			Args{"name": "duplicate", "email": "test2@gmail.com"},
+		))
+		require.ErrorContains(t, err, "duplicate key value violates unique constraint")
+	})
+
+	t.Run("CIC (already created)", func(t *testing.T) {
+		db := NewTestDB(t)
+		ctx := context.Background()
+
+		// Apply the first two migrations
+		runner, err := NewMigrationRunner(db, reader, log.NewNilLogger())
+		require.NoError(t, err)
+		require.NoError(t, runner.Apply(ctx, 2))
+
+		// Create the index outside of the migration infrastructure
+		require.NoError(t, db.Exec(ctx, RawQuery(definitions[2].RawUpQuery)))
+
+		// Apply remaining migrations
+		require.NoError(t, runner.ApplyAll(ctx))
+
+		// Assert last migration (unique index) was applied
+		err = db.Exec(ctx, Query(
+			"INSERT INTO users (name, email) VALUES ({:name}, {:email})",
+			Args{"name": "duplicate", "email": "test2@gmail.com"},
"email": "test2@gmail.com"}, + )) + require.ErrorContains(t, err, "duplicate key value violates unique constraint") + }) + + t.Run("CIC (invalid)", func(t *testing.T) { + db := NewTestDB(t) + ctx := context.Background() + + // Apply just the first migration + runner, err := NewMigrationRunner(db, reader, log.NewNilLogger()) + require.NoError(t, err) + require.NoError(t, runner.Apply(ctx, 2)) + + // Create the index outside of the migration infrastructure and force it to be invalid + require.NoError(t, db.Exec(ctx, RawQuery(definitions[2].RawUpQuery))) + require.NoError(t, db.Exec(ctx, RawQuery(` + UPDATE pg_index + SET indisvalid = false + WHERE indexrelid = ( + SELECT oid + FROM pg_class + WHERE relname = 'users_email_idx' + ); + `))) + + // Apply remaining migrations + require.NoError(t, runner.ApplyAll(ctx)) + + // Assert last migration (unique index) was applied + err = db.Exec(ctx, Query( + "INSERT INTO users (name, email) VALUES ({:name}, {:email})", + Args{"name": "duplicate", "email": "test2@gmail.com"}, + )) + require.ErrorContains(t, err, "duplicate key value violates unique constraint") + }) + + // + // TODO - rewrite this + // + + t.Run("CIC (in progress)", func(t *testing.T) { + db := NewTestDB(t) + ctx := context.Background() + + // Apply the first two migrations + runner, err := NewMigrationRunner(db, reader, log.NewNilLogger()) + require.NoError(t, err) + require.NoError(t, runner.Apply(ctx, 2)) + + var wg sync.WaitGroup + errCh := make(chan error, 1) + + // Start a transaction that holds a lock on the table + tx, err := db.Transact(ctx) + require.NoError(t, err) + defer tx.Done(nil) + + // Insert a row but don't commit, holding a lock on the 'users' table + err = tx.Exec(ctx, RawQuery("INSERT INTO users (name, email) VALUES ('blocking', 'blocking@example.com')")) + require.NoError(t, err) + + // Begin creating the index concurrently outside the migration runner + wg.Add(1) + go func() { + defer wg.Done() + + // This will block until the transaction above commits or rolls back + err := db.Exec(ctx, RawQuery(definitions[2].RawUpQuery)) + if err != nil { + errCh <- err + return + } + }() + + // Give some time for the index creation to start and block + time.Sleep(1 * time.Second) + + // Begin applying the third migration in the runner + wg.Add(1) + go func() { + defer wg.Done() + + // This should wait until the index creation completes + err := runner.ApplyAll(ctx) + if err != nil { + errCh <- err + return + } + }() + + // Hold the transaction open for a short time to simulate blocking + time.Sleep(2 * time.Second) + + // Commit the transaction to release the lock + err = tx.Done(nil) + require.NoError(t, err) + + // Wait for all goroutines to complete + wg.Wait() + close(errCh) + + // Check for errors from goroutines + for err := range errCh { + require.NoError(t, err) + } + + // Assert that the migration runner has unblocked and the index exists + err = db.Exec(ctx, Query( + "INSERT INTO users (name, email) VALUES ({:name}, {:email})", + Args{"name": "duplicate", "email": "test2@gmail.com"}, + )) + require.ErrorContains(t, err, "duplicate key value violates unique constraint") + }) +} + +func TestUndo(t *testing.T) { + definitions := []RawDefinition{ + { + ID: 1, + RawUpQuery: "CREATE TABLE users (id SERIAL PRIMARY KEY, email TEXT);", + RawDownQuery: "DROP TABLE users;", + }, + { + ID: 2, + RawUpQuery: "CREATE TABLE comments (id SERIAL PRIMARY KEY, content TEXT NOT NULL, user_id INTEGER NOT NULL);", + RawDownQuery: "DROP TABLE comments;", + }, + { + ID: 3, + RawUpQuery: 
"ALTER TABLE comments ADD COLUMN updated_at TIMESTAMP WITH TIME ZONE;", + RawDownQuery: "ALTER TABLE comments DROP COLUMN updated_at;", + }, + { + ID: 4, + RawUpQuery: "ALTER TABLE comments ADD COLUMN created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW();", + RawDownQuery: "ALTER TABLE comments DROP COLUMN created_at;", + }, + + {ID: 5, RawUpQuery: "INSERT INTO users (email) VALUES ('test@gmail.com');"}, + {ID: 6, RawUpQuery: "INSERT INTO comments (content, user_id) VALUES ('test', 1);"}, + {ID: 7, RawUpQuery: "UPDATE comments SET updated_at = NOW();"}, + } + definitionsWithoutCreatedAt := []RawDefinition{definitions[0], definitions[1], definitions[2], definitions[4], definitions[5], definitions[6]} + reader := MigrationReaderFunc(func() ([]RawDefinition, error) { return definitions, nil }) + readerWithoutCreatedAt := MigrationReaderFunc(func() ([]RawDefinition, error) { return definitionsWithoutCreatedAt, nil }) + + t.Run("tail", func(t *testing.T) { + db := NewTestDB(t) + ctx := context.Background() + + // Apply all migrations + runner, err := NewMigrationRunner(db, reader, log.NewNilLogger()) + require.NoError(t, err) + require.NoError(t, runner.ApplyAll(ctx)) + + // Assert columns exist and are populated + updatedAt, _, err := ScanNilTime(db.Query(ctx, RawQuery("SELECT created_at FROM comments WHERE user_id = 1"))) + require.NoError(t, err) + require.NotNil(t, updatedAt) + + // Undo migrations that added created_at/updated_at columns + require.NoError(t, runner.Undo(ctx, 3)) + + // Assert columns dropped + _, _, err = ScanString(db.Query(ctx, RawQuery("SELECT updated_at FROM comments WHERE user_id = 1"))) + require.ErrorContains(t, err, "column \"updated_at\" does not exist") + }) + + t.Run("gaps", func(t *testing.T) { + db := NewTestDB(t) + ctx := context.Background() + + // Apply all migrations + runner, err := NewMigrationRunner(db, reader, log.NewNilLogger()) + require.NoError(t, err) + require.NoError(t, runner.ApplyAll(ctx)) + + // Undo migrations but skip #4 + runnerWithHoles, err := NewMigrationRunner(db, readerWithoutCreatedAt, log.NewNilLogger()) + require.NoError(t, err) + require.NoError(t, runnerWithHoles.Undo(ctx, 3)) + + // Assert created_at exists but updated_at does not + _, _, err = ScanNilTime(db.Query(ctx, RawQuery("SELECT created_at FROM comments WHERE user_id = 1"))) + require.NoError(t, err) + _, _, err = ScanString(db.Query(ctx, RawQuery("SELECT updated_at FROM comments WHERE user_id = 1"))) + require.ErrorContains(t, err, "column \"updated_at\" does not exist") + + // Undo migrations including #4 + require.NoError(t, runner.Undo(ctx, 3)) + + // Assert both columns dropped + _, _, err = ScanString(db.Query(ctx, RawQuery("SELECT created_at FROM comments WHERE user_id = 1"))) + require.ErrorContains(t, err, "column \"created_at\" does not exist") + }) +} diff --git a/query.go b/query.go index 665be21..c13c04a 100644 --- a/query.go +++ b/query.go @@ -82,6 +82,10 @@ func RawQuery(format string, args ...any) Q { return Q{internalFormat: format, parameterizedArgs: args} } +func queryf(format string, args ...any) Q { + return RawQuery(fmt.Sprintf(format, args...)) +} + func (q Q) Format() (string, []any) { return replaceWithPairs(q.internalFormat, q.replacerPairs...), q.parameterizedArgs } diff --git a/rows_slice_scanner.go b/rows_slice_scanner.go index 90450b3..1b7cd37 100644 --- a/rows_slice_scanner.go +++ b/rows_slice_scanner.go @@ -1,5 +1,7 @@ package pgutil +import "time" + type SliceScannerFunc[T any] func(rows Rows, queryErr error) ([]T, error) type 
 type FirstScannerFunc[T any] func(rows Rows, queryErr error) (T, bool, error)
 
@@ -77,4 +79,7 @@ var (
 	// TODO - nulls
 	// TODO - arrays
 	// TODO - bytes
+
+	ScanNilTime   = NewFirstScanner(NewAnyValueScanner[*time.Time]())
+	ScanNilString = NewFirstScanner(NewAnyValueScanner[*string]())
 )
diff --git a/rows_slice_scanner_test.go b/rows_slice_scanner_test.go
index a86d63b..fdb532b 100644
--- a/rows_slice_scanner_test.go
+++ b/rows_slice_scanner_test.go
@@ -10,86 +10,70 @@ import (
 
 func TestSliceScanner(t *testing.T) {
 	t.Run("scalar values", func(t *testing.T) {
-		values, err := ScanInts(NewTestDB(t).Query(context.Background(), Query(
-			`SELECT * FROM (VALUES (1), (2), (3)) AS t(number)`,
-			Args{},
-		)))
-
+		values, err := ScanInts(NewTestDB(t).Query(context.Background(),
+			RawQuery(`SELECT * FROM (VALUES (1), (2), (3)) AS t(number)`),
+		))
 		require.NoError(t, err)
 		assert.Equal(t, []int{1, 2, 3}, values)
 	})
 
 	t.Run("custom struct values", func(t *testing.T) {
-		values, err := scanTestPairs(NewTestDB(t).Query(context.Background(), Query(
-			`SELECT * FROM (VALUES (1,2), (2,3), (3,4)) AS t(x,y)`,
-			Args{},
-		)))
-
+		values, err := scanTestPairs(NewTestDB(t).Query(context.Background(),
+			RawQuery(`SELECT * FROM (VALUES (1,2), (2,3), (3,4)) AS t(x,y)`),
+		))
 		require.NoError(t, err)
 		assert.Equal(t, []testPair{{1, 2}, {2, 3}, {3, 4}}, values)
 	})
 
 	t.Run("no values", func(t *testing.T) {
-		values, err := ScanInts(NewTestDB(t).Query(context.Background(), Query(
-			`SELECT * FROM (VALUES (1), (2), (3)) AS t(number) LIMIT 0`,
-			Args{},
-		)))
-
+		values, err := ScanInts(NewTestDB(t).Query(context.Background(),
+			RawQuery(`SELECT * FROM (VALUES (1), (2), (3)) AS t(number) LIMIT 0`),
+		))
 		require.NoError(t, err)
 		assert.Empty(t, values)
 	})
 }
 
 func TestMaybeSliceScanner(t *testing.T) {
-	values, err := scanMaybeTestPairs(NewTestDB(t).Query(context.Background(), Query(
-		`SELECT * FROM (VALUES (1,2), (2,3), (0,0), (3,4)) AS t(x,y)`,
-		Args{},
-	)))
-
+	values, err := scanMaybeTestPairs(NewTestDB(t).Query(context.Background(),
+		RawQuery(`SELECT * FROM (VALUES (1,2), (2,3), (0,0), (3,4)) AS t(x,y)`),
+	))
 	require.NoError(t, err)
 	assert.Equal(t, []testPair{{1, 2}, {2, 3}}, values)
 }
 
 func TestFirstScanner(t *testing.T) {
 	t.Run("scalar value", func(t *testing.T) {
-		value, ok, err := ScanInt(NewTestDB(t).Query(context.Background(), Query(
-			`SELECT * FROM (VALUES (1)) AS t(number)`,
-			Args{},
-		)))
-
+		value, ok, err := ScanInt(NewTestDB(t).Query(context.Background(),
+			RawQuery(`SELECT * FROM (VALUES (1)) AS t(number)`),
+		))
 		require.NoError(t, err)
 		assert.True(t, ok)
 		assert.Equal(t, 1, value)
 	})
 
 	t.Run("scalar value (ignores non-first values)", func(t *testing.T) {
-		value, ok, err := ScanInt(NewTestDB(t).Query(context.Background(), Query(
-			`SELECT * FROM (VALUES (1), (2), (3)) AS t(number)`,
-			Args{},
-		)))
-
+		value, ok, err := ScanInt(NewTestDB(t).Query(context.Background(),
+			RawQuery(`SELECT * FROM (VALUES (1), (2), (3)) AS t(number)`),
+		))
 		require.NoError(t, err)
 		assert.True(t, ok)
 		assert.Equal(t, 1, value)
 	})
 
 	t.Run("custom struct value", func(t *testing.T) {
-		value, ok, err := scanFirstTestPair(NewTestDB(t).Query(context.Background(), Query(
-			`SELECT * FROM (VALUES (1,2)) AS t(x,y)`,
-			Args{},
-		)))
-
+		value, ok, err := scanFirstTestPair(NewTestDB(t).Query(context.Background(),
+			RawQuery(`SELECT * FROM (VALUES (1,2)) AS t(x,y)`),
+		))
 		require.NoError(t, err)
 		assert.True(t, ok)
 		assert.Equal(t, testPair{1, 2}, value)
 	})
 
 	t.Run("no value", func(t *testing.T) {
-		_, ok, err := ScanInt(NewTestDB(t).Query(context.Background(), Query(
-			`SELECT * FROM (VALUES (1), (2), (3)) AS t(number) LIMIT 0`,
-			Args{},
-		)))
-
+		_, ok, err := ScanInt(NewTestDB(t).Query(context.Background(),
+			RawQuery(`SELECT * FROM (VALUES (1), (2), (3)) AS t(number) LIMIT 0`),
+		))
 		require.NoError(t, err)
 		assert.False(t, ok)
 	})
@@ -97,11 +81,9 @@ func TestFirstScanner(t *testing.T) {
 
 func TestMaybeFirstScanner(t *testing.T) {
 	t.Run("custom struct value", func(t *testing.T) {
-		value, ok, err := scanMaybeFirstTestPair(NewTestDB(t).Query(context.Background(), Query(
-			`SELECT * FROM (VALUES (1,2)) AS t(x,y)`,
-			Args{},
-		)))
-
+		value, ok, err := scanMaybeFirstTestPair(NewTestDB(t).Query(context.Background(),
+			RawQuery(`SELECT * FROM (VALUES (1,2)) AS t(x,y)`),
+		))
 		require.NoError(t, err)
 		assert.True(t, ok)
 		assert.Equal(t, testPair{1, 2}, value)
@@ -117,11 +99,9 @@ func TestMaybeFirstScanner(t *testing.T) {
 			return p, p.x != 0 && p.y != 0, err
 		})
 
-		_, ok, err := scanner(NewTestDB(t).Query(context.Background(), Query(
-			`SELECT * FROM (VALUES (0,0), (1,2)) AS t(x,y)`,
-			Args{},
-		)))
-
+		_, ok, err := scanner(NewTestDB(t).Query(context.Background(),
+			RawQuery(`SELECT * FROM (VALUES (0,0), (1,2)) AS t(x,y)`),
+		))
 		require.NoError(t, err)
 		assert.False(t, ok)
 	})
diff --git a/rows_value_scanner_test.go b/rows_value_scanner_test.go
index 5878f50..88c831d 100644
--- a/rows_value_scanner_test.go
+++ b/rows_value_scanner_test.go
@@ -13,15 +13,7 @@ func TestCollector(t *testing.T) {
 	collector := NewCollector[int](NewAnyValueScanner[int]())
 	scanner := NewRowScanner(collector.Scanner())
 
-	require.NoError(t, scanner(db.Query(context.Background(), Query(
-		`SELECT * FROM (VALUES (1), (2), (3)) AS t(number)`,
-		Args{},
-	))))
-
-	require.NoError(t, scanner(db.Query(context.Background(), Query(
-		`SELECT * FROM (VALUES (4), (5), (6)) AS t(number)`,
-		Args{},
-	))))
-
+	require.NoError(t, scanner(db.Query(context.Background(), RawQuery(`SELECT * FROM (VALUES (1), (2), (3)) AS t(number)`))))
+	require.NoError(t, scanner(db.Query(context.Background(), RawQuery(`SELECT * FROM (VALUES (4), (5), (6)) AS t(number)`))))
 	assert.Equal(t, []int{1, 2, 3, 4, 5, 6}, collector.Slice())
 }
diff --git a/testdata/cic_in_down_migration/1_first/down.sql b/testdata/cic_in_down_migration/1_first/down.sql
new file mode 100644
index 0000000..2332b61
--- /dev/null
+++ b/testdata/cic_in_down_migration/1_first/down.sql
@@ -0,0 +1,2 @@
+-- Drop the sample table
+DROP TABLE IF EXISTS users;
diff --git a/testdata/cic_in_down_migration/1_first/up.sql b/testdata/cic_in_down_migration/1_first/up.sql
new file mode 100644
index 0000000..e93c3ac
--- /dev/null
+++ b/testdata/cic_in_down_migration/1_first/up.sql
@@ -0,0 +1,6 @@
+-- Create a sample table
+CREATE TABLE users (
+    id SERIAL PRIMARY KEY,
+    username TEXT NOT NULL,
+    email TEXT NOT NULL
+);
diff --git a/testdata/cic_in_down_migration/2_second/down.sql b/testdata/cic_in_down_migration/2_second/down.sql
new file mode 100644
index 0000000..854c795
--- /dev/null
+++ b/testdata/cic_in_down_migration/2_second/down.sql
@@ -0,0 +1,3 @@
+-- Remove the index and drop the column
+DROP INDEX IF EXISTS idx_users_created_at;
+ALTER TABLE users DROP COLUMN IF EXISTS created_at;
diff --git a/testdata/cic_in_down_migration/2_second/up.sql b/testdata/cic_in_down_migration/2_second/up.sql
new file mode 100644
index 0000000..ae260ef
--- /dev/null
+++ b/testdata/cic_in_down_migration/2_second/up.sql
@@ -0,0 +1,3 @@
+-- Add a new column and create an index
+ALTER TABLE users ADD COLUMN created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW();
+CREATE INDEX idx_users_created_at ON users(created_at);
diff --git a/testdata/cic_in_down_migration/3_third/down.sql b/testdata/cic_in_down_migration/3_third/down.sql
new file mode 100644
index 0000000..55048d7
--- /dev/null
+++ b/testdata/cic_in_down_migration/3_third/down.sql
@@ -0,0 +1,2 @@
+-- Recreate the index using CREATE INDEX CONCURRENTLY
+CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_users_created_at ON users(created_at);
diff --git a/testdata/cic_in_down_migration/3_third/up.sql b/testdata/cic_in_down_migration/3_third/up.sql
new file mode 100644
index 0000000..d48ebc4
--- /dev/null
+++ b/testdata/cic_in_down_migration/3_third/up.sql
@@ -0,0 +1,2 @@
+-- Drop the index created in the second migration
+DROP INDEX IF EXISTS idx_users_created_at;
diff --git a/testdata/cic_pattern/1_first/down.sql b/testdata/cic_pattern/1_first/down.sql
new file mode 100755
index 0000000..26ce8a7
--- /dev/null
+++ b/testdata/cic_pattern/1_first/down.sql
@@ -0,0 +1,2 @@
+-- Drop the users table
+DROP TABLE IF EXISTS users;
diff --git a/testdata/cic_pattern/1_first/up.sql b/testdata/cic_pattern/1_first/up.sql
new file mode 100755
index 0000000..14d93e8
--- /dev/null
+++ b/testdata/cic_pattern/1_first/up.sql
@@ -0,0 +1,7 @@
+-- Create the users table
+CREATE TABLE users (
+    id SERIAL PRIMARY KEY,
+    username TEXT NOT NULL UNIQUE,
+    email TEXT NOT NULL UNIQUE,
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+);
diff --git a/testdata/cic_pattern/2_second/down.sql b/testdata/cic_pattern/2_second/down.sql
new file mode 100755
index 0000000..d032b51
--- /dev/null
+++ b/testdata/cic_pattern/2_second/down.sql
@@ -0,0 +1,2 @@
+-- Drop the posts table
+DROP TABLE IF EXISTS posts;
diff --git a/testdata/cic_pattern/2_second/up.sql b/testdata/cic_pattern/2_second/up.sql
new file mode 100755
index 0000000..25a7f72
--- /dev/null
+++ b/testdata/cic_pattern/2_second/up.sql
@@ -0,0 +1,8 @@
+-- Create the posts table
+CREATE TABLE posts (
+    id SERIAL PRIMARY KEY,
+    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    title TEXT NOT NULL,
+    content TEXT,
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+);
diff --git a/testdata/cic_pattern/3_third/down.sql b/testdata/cic_pattern/3_third/down.sql
new file mode 100755
index 0000000..549f08a
--- /dev/null
+++ b/testdata/cic_pattern/3_third/down.sql
@@ -0,0 +1,2 @@
+-- Drop the comments table
+DROP TABLE IF EXISTS comments;
diff --git a/testdata/cic_pattern/3_third/up.sql b/testdata/cic_pattern/3_third/up.sql
new file mode 100755
index 0000000..9af4f1e
--- /dev/null
+++ b/testdata/cic_pattern/3_third/up.sql
@@ -0,0 +1,8 @@
+-- Create the comments table
+CREATE TABLE comments (
+    id SERIAL PRIMARY KEY,
+    post_id INTEGER NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
+    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    content TEXT NOT NULL,
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+);
diff --git a/testdata/cic_pattern/4_fourth/down.sql b/testdata/cic_pattern/4_fourth/down.sql
new file mode 100644
index 0000000..e47a9eb
--- /dev/null
+++ b/testdata/cic_pattern/4_fourth/down.sql
@@ -0,0 +1,2 @@
+-- Drop the concurrent index
+DROP INDEX CONCURRENTLY IF EXISTS idx_users_email;
\ No newline at end of file
diff --git a/testdata/cic_pattern/4_fourth/up.sql b/testdata/cic_pattern/4_fourth/up.sql
new file mode 100644
index 0000000..70fed1d
--- /dev/null
+++ b/testdata/cic_pattern/4_fourth/up.sql
@@ -0,0 +1,2 @@
+-- Create a concurrent index
+CREATE INDEX CONCURRENTLY idx_users_email ON users (email);
\ No newline at end of file
diff --git a/testdata/cic_with_additional_queries/1_first/down.sql b/testdata/cic_with_additional_queries/1_first/down.sql
new file mode 100755
index 0000000..26ce8a7
--- /dev/null
+++ b/testdata/cic_with_additional_queries/1_first/down.sql
@@ -0,0 +1,2 @@
+-- Drop the users table
+DROP TABLE IF EXISTS users;
diff --git a/testdata/cic_with_additional_queries/1_first/up.sql b/testdata/cic_with_additional_queries/1_first/up.sql
new file mode 100755
index 0000000..14d93e8
--- /dev/null
+++ b/testdata/cic_with_additional_queries/1_first/up.sql
@@ -0,0 +1,7 @@
+-- Create the users table
+CREATE TABLE users (
+    id SERIAL PRIMARY KEY,
+    username TEXT NOT NULL UNIQUE,
+    email TEXT NOT NULL UNIQUE,
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+);
diff --git a/testdata/cic_with_additional_queries/2_second/down.sql b/testdata/cic_with_additional_queries/2_second/down.sql
new file mode 100755
index 0000000..d032b51
--- /dev/null
+++ b/testdata/cic_with_additional_queries/2_second/down.sql
@@ -0,0 +1,2 @@
+-- Drop the posts table
+DROP TABLE IF EXISTS posts;
diff --git a/testdata/cic_with_additional_queries/2_second/up.sql b/testdata/cic_with_additional_queries/2_second/up.sql
new file mode 100755
index 0000000..25a7f72
--- /dev/null
+++ b/testdata/cic_with_additional_queries/2_second/up.sql
@@ -0,0 +1,8 @@
+-- Create the posts table
+CREATE TABLE posts (
+    id SERIAL PRIMARY KEY,
+    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    title TEXT NOT NULL,
+    content TEXT,
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+);
diff --git a/testdata/cic_with_additional_queries/3_third/down.sql b/testdata/cic_with_additional_queries/3_third/down.sql
new file mode 100755
index 0000000..549f08a
--- /dev/null
+++ b/testdata/cic_with_additional_queries/3_third/down.sql
@@ -0,0 +1,2 @@
+-- Drop the comments table
+DROP TABLE IF EXISTS comments;
diff --git a/testdata/cic_with_additional_queries/3_third/up.sql b/testdata/cic_with_additional_queries/3_third/up.sql
new file mode 100755
index 0000000..9af4f1e
--- /dev/null
+++ b/testdata/cic_with_additional_queries/3_third/up.sql
@@ -0,0 +1,8 @@
+-- Create the comments table
+CREATE TABLE comments (
+    id SERIAL PRIMARY KEY,
+    post_id INTEGER NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
+    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    content TEXT NOT NULL,
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+);
diff --git a/testdata/cic_with_additional_queries/4_fourth/down.sql b/testdata/cic_with_additional_queries/4_fourth/down.sql
new file mode 100644
index 0000000..be079dc
--- /dev/null
+++ b/testdata/cic_with_additional_queries/4_fourth/down.sql
@@ -0,0 +1,6 @@
+-- Drop new indexes
+DROP INDEX idx_users_created_at;
+DROP INDEX idx_users_email;
+
+-- Drop new column
+ALTER TABLE users DROP COLUMN last_login;
diff --git a/testdata/cic_with_additional_queries/4_fourth/up.sql b/testdata/cic_with_additional_queries/4_fourth/up.sql
new file mode 100644
index 0000000..7b16a34
--- /dev/null
+++ b/testdata/cic_with_additional_queries/4_fourth/up.sql
@@ -0,0 +1,9 @@
+-- Add and backfill last_login column
+ALTER TABLE users ADD COLUMN last_login TIMESTAMP;
+UPDATE users SET last_login = NOW() WHERE email IN ('user1@example.com', 'user2@example.com');
+
+-- Create an index concurrently
+CREATE INDEX CONCURRENTLY idx_users_email ON users (email);
+
+-- Create another index
+CREATE INDEX idx_users_created_at ON users (created_at);
diff --git a/testdata/duplicate_identifiers/1_first/down.sql b/testdata/duplicate_identifiers/1_first/down.sql
new file mode 100755
index 0000000..26ce8a7
--- /dev/null
+++ b/testdata/duplicate_identifiers/1_first/down.sql
@@ -0,0 +1,2 @@
+-- Drop the users table
+DROP TABLE IF EXISTS users;
diff --git a/testdata/duplicate_identifiers/1_first/up.sql b/testdata/duplicate_identifiers/1_first/up.sql
new file mode 100755
index 0000000..14d93e8
--- /dev/null
+++ b/testdata/duplicate_identifiers/1_first/up.sql
@@ -0,0 +1,7 @@
+-- Create the users table
+CREATE TABLE users (
+    id SERIAL PRIMARY KEY,
+    username TEXT NOT NULL UNIQUE,
+    email TEXT NOT NULL UNIQUE,
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+);
diff --git a/testdata/duplicate_identifiers/2_second/down.sql b/testdata/duplicate_identifiers/2_second/down.sql
new file mode 100755
index 0000000..d032b51
--- /dev/null
+++ b/testdata/duplicate_identifiers/2_second/down.sql
@@ -0,0 +1,2 @@
+-- Drop the posts table
+DROP TABLE IF EXISTS posts;
diff --git a/testdata/duplicate_identifiers/2_second/up.sql b/testdata/duplicate_identifiers/2_second/up.sql
new file mode 100755
index 0000000..25a7f72
--- /dev/null
+++ b/testdata/duplicate_identifiers/2_second/up.sql
@@ -0,0 +1,8 @@
+-- Create the posts table
+CREATE TABLE posts (
+    id SERIAL PRIMARY KEY,
+    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    title TEXT NOT NULL,
+    content TEXT,
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+);
diff --git a/testdata/duplicate_identifiers/2_third/down.sql b/testdata/duplicate_identifiers/2_third/down.sql
new file mode 100755
index 0000000..549f08a
--- /dev/null
+++ b/testdata/duplicate_identifiers/2_third/down.sql
@@ -0,0 +1,2 @@
+-- Drop the comments table
+DROP TABLE IF EXISTS comments;
diff --git a/testdata/duplicate_identifiers/2_third/up.sql b/testdata/duplicate_identifiers/2_third/up.sql
new file mode 100755
index 0000000..9af4f1e
--- /dev/null
+++ b/testdata/duplicate_identifiers/2_third/up.sql
@@ -0,0 +1,8 @@
+-- Create the comments table
+CREATE TABLE comments (
+    id SERIAL PRIMARY KEY,
+    post_id INTEGER NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
+    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    content TEXT NOT NULL,
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+);
diff --git a/testdata/valid/1_first/down.sql b/testdata/valid/1_first/down.sql
new file mode 100755
index 0000000..26ce8a7
--- /dev/null
+++ b/testdata/valid/1_first/down.sql
@@ -0,0 +1,2 @@
+-- Drop the users table
+DROP TABLE IF EXISTS users;
diff --git a/testdata/valid/1_first/up.sql b/testdata/valid/1_first/up.sql
new file mode 100755
index 0000000..14d93e8
--- /dev/null
+++ b/testdata/valid/1_first/up.sql
@@ -0,0 +1,7 @@
+-- Create the users table
+CREATE TABLE users (
+    id SERIAL PRIMARY KEY,
+    username TEXT NOT NULL UNIQUE,
+    email TEXT NOT NULL UNIQUE,
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+);
diff --git a/testdata/valid/2_second/down.sql b/testdata/valid/2_second/down.sql
new file mode 100755
index 0000000..d032b51
--- /dev/null
+++ b/testdata/valid/2_second/down.sql
@@ -0,0 +1,2 @@
+-- Drop the posts table
+DROP TABLE IF EXISTS posts;
diff --git a/testdata/valid/2_second/up.sql b/testdata/valid/2_second/up.sql
new file mode 100755
index 0000000..25a7f72
--- /dev/null
+++ b/testdata/valid/2_second/up.sql
@@ -0,0 +1,8 @@
+-- Create the posts table
+CREATE TABLE posts (
+    id SERIAL PRIMARY KEY,
+    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    title TEXT NOT NULL,
+    content TEXT,
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+);
diff --git a/testdata/valid/3_third/down.sql b/testdata/valid/3_third/down.sql
new file mode 100755
index 0000000..549f08a
--- /dev/null
+++ b/testdata/valid/3_third/down.sql
@@ -0,0 +1,2 @@
+-- Drop the comments table
+DROP TABLE IF EXISTS comments;
diff --git a/testdata/valid/3_third/up.sql b/testdata/valid/3_third/up.sql
new file mode 100755
index 0000000..9af4f1e
--- /dev/null
+++ b/testdata/valid/3_third/up.sql
@@ -0,0 +1,8 @@
+-- Create the comments table
+CREATE TABLE comments (
+    id SERIAL PRIMARY KEY,
+    post_id INTEGER NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
+    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    content TEXT NOT NULL,
+    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
+);
diff --git a/testing.go b/testing.go
index a95dd79..a3837fe 100644
--- a/testing.go
+++ b/testing.go
@@ -29,8 +29,8 @@ func NewTestDBWithLogger(t testing.TB, logger log.Logger) DB {
 		quotedTemplateDatabaseName = pq.QuoteIdentifier(os.Getenv("TEMPLATEDB"))
 
 		// NOTE: Must interpolate identifiers here as placeholders aren't valid in this position.
-		createDatabaseQuery       = Query(fmt.Sprintf("CREATE DATABASE %s TEMPLATE %s", quotedTestDatabaseName, quotedTemplateDatabaseName), Args{})
-		dropDatabaseQuery         = Query(fmt.Sprintf("DROP DATABASE %s", quotedTestDatabaseName), Args{})
+		createDatabaseQuery       = queryf("CREATE DATABASE %s TEMPLATE %s", quotedTestDatabaseName, quotedTemplateDatabaseName)
+		dropDatabaseQuery         = queryf("DROP DATABASE %s", quotedTestDatabaseName)
 		terminateConnectionsQuery = Query("SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = {:name}", Args{"name": testDatabaseName})
 	)