
Commit

WIP.
efritz committed Sep 14, 2024
1 parent a4143c7 commit a69c5d1
Showing 47 changed files with 701 additions and 152 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -1,3 +1,2 @@
.vscode
vendor
migrations
21 changes: 6 additions & 15 deletions db_transaction.go
@@ -54,11 +54,8 @@ func createSavepoint(ctx context.Context, tx *loggingTx) (*loggingSavepoint, err
}
savepointID := fmt.Sprintf("sp_%s", id)

if err := tx.Exec(ctx, Query(
// NOTE: Must interpolate identifier here as placeholders aren't valid in this position.
fmt.Sprintf("SAVEPOINT %s", savepointID),
Args{},
)); err != nil {
// NOTE: Must interpolate identifier here as placeholders aren't valid in this position.
if err := tx.Exec(ctx, queryf("SAVEPOINT %s", savepointID)); err != nil {
return nil, err
}

@@ -85,18 +82,12 @@ func (tx *loggingSavepoint) Done(err error) (combinedErr error) {
defer func() { logDone(tx.logger, time.Since(tx.start), combinedErr) }()

if err != nil {
return errors.Join(err, tx.Exec(context.Background(), Query(
// NOTE: Must interpolate identifier here as placeholders aren't valid in this position.
fmt.Sprintf("ROLLBACK TO %s", tx.savepointID),
Args{},
)))
// NOTE: Must interpolate identifier here as placeholders aren't valid in this position.
return errors.Join(err, tx.Exec(context.Background(), queryf("ROLLBACK TO %s", tx.savepointID)))
}

return tx.Exec(context.Background(), Query(
// NOTE: Must interpolate identifier here as placeholders aren't valid in this position.
fmt.Sprintf("RELEASE %s", tx.savepointID),
Args{},
))
// NOTE: Must interpolate identifier here as placeholders aren't valid in this position.
return tx.Exec(context.Background(), queryf("RELEASE %s", tx.savepointID))
}

var ErrPanicDuringTransaction = fmt.Errorf("encountered panic during transaction")
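The substantive change in this file is the new queryf helper, which collapses the earlier Query(fmt.Sprintf(...), Args{}) construction into a single call for statements whose identifiers must be interpolated directly (savepoint names cannot be bound as placeholders). queryf itself is defined elsewhere in the repository and is not visible in this diff; the following is only a sketch of the assumed shape, with Q standing in for whatever query type Exec accepts:

// Sketch under assumptions: queryf's real definition is not shown in this commit,
// and Q is a placeholder name for the query type Exec accepts. The assumed behaviour
// is fmt.Sprintf followed by wrapping the text as a parameterless query.
func queryf(format string, args ...any) Q {
	return RawQuery(fmt.Sprintf(format, args...))
}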
36 changes: 8 additions & 28 deletions db_transaction_test.go
@@ -18,29 +18,20 @@ func TestTransaction(t *testing.T) {
setupTestTransactionTable(t, db)

// Add record outside of transaction, ensure it's visible
err := db.Exec(context.Background(), Query(
`INSERT INTO test (x, y) VALUES (1, 42)`,
Args{},
))
err := db.Exec(context.Background(), RawQuery(`INSERT INTO test (x, y) VALUES (1, 42)`))
require.NoError(t, err)
assert.Equal(t, map[int]int{1: 42}, testTableContents(t, db))

// Add record inside of a transaction
tx1, err := db.Transact(context.Background())
require.NoError(t, err)
err = tx1.Exec(context.Background(), Query(
`INSERT INTO test (x, y) VALUES (2, 43)`,
Args{},
))
err = tx1.Exec(context.Background(), RawQuery(`INSERT INTO test (x, y) VALUES (2, 43)`))
require.NoError(t, err)

// Add record inside of another transaction
tx2, err := db.Transact(context.Background())
require.NoError(t, err)
err = tx2.Exec(context.Background(), Query(
`INSERT INTO test (x, y) VALUES (3, 44)`,
Args{},
))
err = tx2.Exec(context.Background(), RawQuery(`INSERT INTO test (x, y) VALUES (3, 44)`))
require.NoError(t, err)

// Check what's visible pre-commit/rollback
@@ -76,10 +67,7 @@ func TestConcurrentTransactions(t *testing.T) {
}
defer func() { err = tx.Done(err) }()

if err := tx.Exec(context.Background(), Query(
`SELECT pg_sleep(0.1)`,
Args{},
)); err != nil {
if err := tx.Exec(context.Background(), RawQuery(`SELECT pg_sleep(0.1)`)); err != nil {
return err
}

@@ -111,10 +99,7 @@ func TestConcurrentTransactions(t *testing.T) {
for i := 0; i < 10; i++ {
routine := i
g.Go(func() (err error) {
if err := tx.Exec(context.Background(), Query(
`SELECT pg_sleep(0.1);`,
Args{},
)); err != nil {
if err := tx.Exec(context.Background(), RawQuery(`SELECT pg_sleep(0.1);`)); err != nil {
return err
}

@@ -177,22 +162,17 @@ func recurSavepoints(t *testing.T, db DB, index, rollbackAt int) {
}

func setupTestTransactionTable(t *testing.T, db DB) {
require.NoError(t, db.Exec(context.Background(), Query(`
require.NoError(t, db.Exec(context.Background(), RawQuery(`
CREATE TABLE test (
id SERIAL PRIMARY KEY,
x INTEGER NOT NULL,
y INTEGER NOT NULL
);
`,
Args{},
)))
`)))
}

func testTableContents(t *testing.T, db DB) map[int]int {
pairs, err := scanTestPairs(db.Query(context.Background(), Query(
`SELECT x, y FROM test`,
Args{},
)))
pairs, err := scanTestPairs(db.Query(context.Background(), RawQuery(`SELECT x, y FROM test`)))
require.NoError(t, err)

pairsMap := make(map[int]int)
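These tests also rely on the library's transaction-finalization idiom: Done receives the surrounding function's error and, as the savepoint code in db_transaction.go above shows, rolls back when that error is non-nil and commits or releases otherwise. A minimal usage sketch, assuming only the DB, Transact, Done, Exec, and RawQuery signatures visible in this commit:

// Minimal sketch of the Transact/Done idiom used throughout these tests.
func insertRow(ctx context.Context, db DB) (err error) {
	tx, err := db.Transact(ctx)
	if err != nil {
		return err
	}
	// Done rolls back when the supplied error is non-nil and commits otherwise;
	// whatever it returns is what the caller ultimately observes.
	defer func() { err = tx.Done(err) }()

	return tx.Exec(ctx, RawQuery(`INSERT INTO test (x, y) VALUES (1, 42)`))
}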
48 changes: 29 additions & 19 deletions migration_reader.go
@@ -23,6 +23,12 @@ type MigrationReader interface {
ReadAll() ([]RawDefinition, error)
}

type MigrationReaderFunc func() ([]RawDefinition, error)

func (f MigrationReaderFunc) ReadAll() ([]RawDefinition, error) {
return f()
}

type RawDefinition struct {
ID int
Name string
@@ -31,22 +37,26 @@ type RawDefinition struct {
}

var (
identifierPattern = `[a-zA-Z0-9$_]+|"(?:[^"]+)"`

cicPatternParts = strings.Join([]string{
`CREATE`,
`(?:UNIQUE)?`,
`INDEX`,
`CONCURRENTLY`,
`(?:IF\s+NOT\s+EXISTS)?`,
`(` + identifierPattern + `)`, // capture index name
`ON`,
`(?:ONLY)?`,
`(` + identifierPattern + `)`, // capture table name
}, `\s+`)

createIndexConcurrentlyPattern = regexp.MustCompile(cicPatternParts)
createIndexConcurrentlyPatternAll = regexp.MustCompile(cicPatternParts + "[^;]+;")
keyword = func(pattern string) string { return phrase(pattern) }
phrase = func(patterns ...string) string { return strings.Join(patterns, `\s+`) + `\s+` }
opt = func(pattern string) string { return `(?:` + pattern + `)?` }

capturedIdentifierPattern = `([a-zA-Z0-9$_]+|"(?:[^"]+)")`
createIndexConcurrentlyPatternHead = strings.Join([]string{
keyword(`CREATE`),
opt(keyword(`UNIQUE`)),
keyword(`INDEX`),
opt(keyword(`CONCURRENTLY`)),
opt(phrase(`IF`, `NOT`, `EXISTS`)),
capturedIdentifierPattern, // capture index name
`\s+`,
keyword(`ON`),
opt(keyword(`ONLY`)),
capturedIdentifierPattern, // capture table name
}, ``)

createIndexConcurrentlyPattern = regexp.MustCompile(createIndexConcurrentlyPatternHead)
createIndexConcurrentlyPatternAll = regexp.MustCompile(createIndexConcurrentlyPatternHead + "[^;]+;")
)

func ReadMigrations(reader MigrationReader) (definitions []Definition, _ error) {
@@ -68,7 +78,7 @@ func ReadMigrations(reader MigrationReader) (definitions []Definition, _ error)

if matches := createIndexConcurrentlyPattern.FindStringSubmatch(prunedUp); len(matches) > 0 {
if strings.TrimSpace(createIndexConcurrentlyPatternAll.ReplaceAllString(prunedUp, "")) != "" {
return nil, fmt.Errorf("CIC is not the only statement in the up migration")
return nil, fmt.Errorf(`"create index concurrently" is not the only statement in the up migration`)
}

indexMetadata = &IndexMetadata{
@@ -77,8 +87,8 @@ func ReadMigrations(reader MigrationReader) (definitions []Definition, _ error)
}
}

if len(createIndexConcurrentlyPatternAll.FindAllString(prunedDown, 1)) > 0 {
return nil, fmt.Errorf("CIC is not allowed in down migrations")
if len(createIndexConcurrentlyPattern.FindAllString(prunedDown, 1)) > 0 {
return nil, fmt.Errorf(`"create index concurrently" is not allowed in down migrations`)
}

definitions = append(definitions, Definition{
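Two additions here are worth calling out. First, MigrationReaderFunc is a function adapter in the style of http.HandlerFunc: any func() ([]RawDefinition, error) now satisfies MigrationReader without a dedicated struct. A usage sketch, where the definition contents are placeholders rather than values taken from this repository:

// Adapt a closure into a MigrationReader; only fields visible in this diff are set,
// and the ID/Name values are purely illustrative.
func exampleInMemoryReader() MigrationReader {
	return MigrationReaderFunc(func() ([]RawDefinition, error) {
		return []RawDefinition{{ID: 1, Name: "create_users"}}, nil
	})
}

Second, the CREATE INDEX CONCURRENTLY pattern is rebuilt from small keyword/phrase/opt helpers so that each optional token is explicit. For an input such as

CREATE UNIQUE INDEX CONCURRENTLY IF NOT EXISTS idx_users_email ON ONLY users (email);

the two capture groups yield idx_users_email and users, which is exactly what populates IndexMetadata in ReadMigrations.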
6 changes: 3 additions & 3 deletions migration_reader_filesystem.go
@@ -5,7 +5,7 @@ import (
"io/fs"
"net/http"
"os"
"path/filepath"
"path"
"sort"
"strconv"
"strings"
@@ -52,8 +52,8 @@ func (r *FilesystemMigrationReader) ReadAll() (definitions []RawDefinition, _ er
}

func (r *FilesystemMigrationReader) readDefinition(dirname string) (RawDefinition, bool, error) {
upPath := filepath.Join(dirname, "up.sql") // TODO - filepath join for embed?
downPath := filepath.Join(dirname, "down.sql")
upPath := path.Join(dirname, "up.sql")
downPath := path.Join(dirname, "down.sql")

upFileContents, upErr := readFile(r.fs, upPath)
downFileContents, downErr := readFile(r.fs, downPath)
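The switch from filepath.Join to path.Join (and the removal of the old TODO) is what keeps this reader portable when the migration directory comes from an embedded filesystem: io/fs paths are always slash-separated regardless of operating system, whereas filepath.Join would insert backslashes on Windows and lookups against an embed.FS would fail. A brief illustration, with the embed directive and directory names chosen only for the example:

import (
	"embed"
	"path"
)

//go:embed migrations
var migrationFS embed.FS

func readEmbeddedUp() ([]byte, error) {
	// fs.FS lookups require slash-separated paths on every platform,
	// so path.Join is correct here; filepath.Join is only for OS-native paths.
	return migrationFS.ReadFile(path.Join("migrations", "0001_init", "up.sql"))
}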
77 changes: 77 additions & 0 deletions migration_reader_test.go
@@ -0,0 +1,77 @@
package pgutil

import (
"path"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

var testUpQuery = `-- Create the comments table
CREATE TABLE comments (
id SERIAL PRIMARY KEY,
post_id INTEGER NOT NULL REFERENCES posts(id) ON DELETE CASCADE,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
content TEXT NOT NULL,
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
);
`

var testDownQuery = `-- Drop the comments table
DROP TABLE IF EXISTS comments;
`

var testConcurrentIndexUpQuery = `-- Create a concurrent index
CREATE INDEX CONCURRENTLY idx_users_email ON users (email);`

var testConcurrentIndexDownQuery = `-- Drop the concurrent index
DROP INDEX CONCURRENTLY IF EXISTS idx_users_email;`

func TestReadMigrations(t *testing.T) {
t.Run("valid", func(t *testing.T) {
definitions, err := ReadMigrations(NewFilesystemMigrationReader(path.Join("testdata", "valid")))
require.NoError(t, err)
require.Len(t, definitions, 3)

assert.Equal(t, Definition{
ID: 3,
Name: "third",
UpQuery: RawQuery(testUpQuery),
DownQuery: RawQuery(testDownQuery),
}, definitions[2])
})

t.Run("CIC pattern", func(t *testing.T) {
t.Skip()
definitions, err := ReadMigrations(NewFilesystemMigrationReader(path.Join("testdata", "cic_pattern")))
require.NoError(t, err)
require.Len(t, definitions, 4)

assert.Equal(t, Definition{
ID: 3,
Name: "third",
UpQuery: RawQuery(testConcurrentIndexUpQuery),
DownQuery: RawQuery(testConcurrentIndexDownQuery),
IndexMetadata: &IndexMetadata{
TableName: "users",
IndexName: "idx_users_email",
},
}, definitions[3])
})

t.Run("duplicate identifiers", func(t *testing.T) {
_, err := ReadMigrations(NewFilesystemMigrationReader(path.Join("testdata", "duplicate_identifiers")))
assert.ErrorContains(t, err, "duplicate migration identifier 2")
})

t.Run("CIC with additional queries", func(t *testing.T) {
_, err := ReadMigrations(NewFilesystemMigrationReader(path.Join("testdata", "cic_with_additional_queries")))
assert.ErrorContains(t, err, `"create index concurrently" is not the only statement in the up migration`)
})

t.Run("CIC in down migration", func(t *testing.T) {
_, err := ReadMigrations(NewFilesystemMigrationReader(path.Join("testdata", "cic_in_down_migration")))
assert.ErrorContains(t, err, `"create index concurrently" is not the only statement in the up migration`)
})
}
