diff --git a/cmd/balance-worker/main.go b/cmd/balance-worker/main.go index 8261f6c4c..ef0010da6 100644 --- a/cmd/balance-worker/main.go +++ b/cmd/balance-worker/main.go @@ -239,7 +239,7 @@ func main() { entClient := entPostgresDriver.Client() - if err := startup.DB(ctx, conf.Postgres, entClient); err != nil { + if err := startup.DB(ctx, conf.Postgres, entClient, postgresDriver.DB()); err != nil { logger.Error("failed to initialize database", "error", err) os.Exit(1) } diff --git a/cmd/notification-service/main.go b/cmd/notification-service/main.go index bc1cffa20..0c8b365d0 100644 --- a/cmd/notification-service/main.go +++ b/cmd/notification-service/main.go @@ -242,7 +242,7 @@ func main() { entClient := entPostgresDriver.Client() // Run database schema creation - if err := startup.DB(ctx, conf.Postgres, entClient); err != nil { + if err := startup.DB(ctx, conf.Postgres, entClient, postgresDriver.DB()); err != nil { logger.Error("failed to initialize database", "error", err) os.Exit(1) } diff --git a/cmd/server/main.go b/cmd/server/main.go index 5cf92a7aa..69704b17e 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -351,7 +351,7 @@ func main() { entClient := entPostgresDriver.Client() - if err := startup.DB(ctx, conf.Postgres, entClient); err != nil { + if err := startup.DB(ctx, conf.Postgres, entClient, postgresDriver.DB()); err != nil { logger.Error("failed to initialize database", "error", err) os.Exit(1) } diff --git a/go.mod b/go.mod index 88bc25eb7..9cd74d009 100644 --- a/go.mod +++ b/go.mod @@ -263,6 +263,7 @@ require ( github.com/itchyny/timefmt-go v0.1.5 // indirect github.com/jackc/chunkreader/v2 v2.0.1 // indirect github.com/jackc/pgconn v1.14.3 // indirect + github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa // indirect github.com/jackc/pgio v1.0.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgproto3/v2 v2.3.3 // indirect diff --git a/go.sum b/go.sum index d9096a79f..a28736c3e 100644 --- a/go.sum +++ b/go.sum @@ -785,6 +785,8 @@ github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8 github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= github.com/jackc/pgconn v1.14.3 h1:bVoTr12EGANZz66nZPkMInAV/KHD2TxH9npjXXgiB3w= github.com/jackc/pgconn v1.14.3/go.mod h1:RZbme4uasqzybK2RK5c65VsHxoyaml09lx3tXOcO/VM= +github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa h1:s+4MhCQ6YrzisK6hFJUX53drDT4UsSW3DEhKn0ifuHw= +github.com/jackc/pgerrcode v0.0.0-20220416144525-469b46aa5efa/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= diff --git a/openmeter/entitlement/metered/utils_test.go b/openmeter/entitlement/metered/utils_test.go index ef74b3fae..ba30a9804 100644 --- a/openmeter/entitlement/metered/utils_test.go +++ b/openmeter/entitlement/metered/utils_test.go @@ -77,8 +77,12 @@ func setupConnector(t *testing.T) (meteredentitlement.Connector, *dependencies) m.Lock() defer m.Unlock() // migrate db - if err := migrate.Up(testdb.URL); err != nil { - t.Fatalf("failed to migrate db: %s", err.Error()) + if m, err := migrate.Default(testdb.SQLDriver); err == nil { + if err := m.Up(); err != nil { + t.Fatalf("failed to migrate db: %s", err.Error()) + } + } else { + t.Fatalf("failed to create migrate instance: %s", err.Error()) } mockPublisher := eventbus.NewMock(t) diff --git a/openmeter/productcatalog/adapter/feature_test.go b/openmeter/productcatalog/adapter/feature_test.go index 12d3959cd..d37316295 100644 --- a/openmeter/productcatalog/adapter/feature_test.go +++ b/openmeter/productcatalog/adapter/feature_test.go @@ -220,8 +220,12 @@ func TestCreateFeature(t *testing.T) { dbClient := testdb.EntDriver.Client() defer dbClient.Close() - if err := migrate.Up(testdb.URL); err != nil { - t.Fatalf("failed to migrate db: %s", err.Error()) + if m, err := migrate.Default(testdb.SQLDriver); err == nil { + if err := m.Up(); err != nil { + t.Fatalf("failed to migrate db: %s", err.Error()) + } + } else { + t.Fatalf("failed to create migrate instance: %s", err.Error()) } dbConnector := adapter.NewPostgresFeatureRepo(dbClient, testutils.NewLogger(t)) @@ -239,8 +243,12 @@ func TestCreateFeature(t *testing.T) { dbClient := testdb.EntDriver.Client() defer dbClient.Close() - if err := migrate.Up(testdb.URL); err != nil { - t.Fatalf("failed to migrate db: %s", err.Error()) + if m, err := migrate.Default(testdb.SQLDriver); err == nil { + if err := m.Up(); err != nil { + t.Fatalf("failed to migrate db: %s", err.Error()) + } + } else { + t.Fatalf("failed to create migrate instance: %s", err.Error()) } dbConnector := adapter.NewPostgresFeatureRepo(dbClient, testutils.NewLogger(t)) diff --git a/openmeter/registry/startup/db.go b/openmeter/registry/startup/db.go index 0d800333e..5e9264f29 100644 --- a/openmeter/registry/startup/db.go +++ b/openmeter/registry/startup/db.go @@ -2,6 +2,7 @@ package startup import ( "context" + "database/sql" "fmt" "github.com/openmeterio/openmeter/config" @@ -9,19 +10,23 @@ import ( "github.com/openmeterio/openmeter/tools/migrate" ) -func DB(ctx context.Context, cfg config.PostgresConfig, db *db.Client) error { +func DB(ctx context.Context, cfg config.PostgresConfig, client *db.Client, db *sql.DB) error { if !cfg.AutoMigrate.Enabled() { return nil } switch cfg.AutoMigrate { case config.AutoMigrateEnt: - if err := db.Schema.Create(ctx); err != nil { + if err := client.Schema.Create(ctx); err != nil { return fmt.Errorf("failed to migrate db: %w", err) } case config.AutoMigrateMigration: - if err := migrate.Up(cfg.URL); err != nil { - return fmt.Errorf("failed to migrate db: %w", err) + if m, err := migrate.Default(db); err == nil { + if err := m.Up(); err != nil { + return fmt.Errorf("failed to migrate db: %w", err) + } + } else { + return fmt.Errorf("failed to create migrate instance: %w", err) } } diff --git a/openmeter/testutils/pg_driver.go b/openmeter/testutils/pg_driver.go index a3d74855b..21787b155 100644 --- a/openmeter/testutils/pg_driver.go +++ b/openmeter/testutils/pg_driver.go @@ -91,6 +91,7 @@ func InitPostgresDB(t *testing.T) *TestDB { return &TestDB{ PGDriver: postgresDriver, + SQLDriver: postgresDriver.DB(), EntDriver: entDriver, URL: dbConf.URL(), } diff --git a/test/entitlement/regression/framework_test.go b/test/entitlement/regression/framework_test.go index 7ddd65f3a..6b09c62e7 100644 --- a/test/entitlement/regression/framework_test.go +++ b/test/entitlement/regression/framework_test.go @@ -62,8 +62,14 @@ func setupDependencies(t *testing.T) Dependencies { driver := testutils.InitPostgresDB(t) // init db dbClient := db.NewClient(db.Driver(driver.EntDriver.Driver())) - if err := migrate.Up(driver.URL); err != nil { - t.Fatalf("failed to migrate db: %s", err.Error()) + + // Migrate + if m, err := migrate.Default(driver.SQLDriver); err == nil { + if err := m.Up(); err != nil { + t.Fatalf("failed to migrate db: %s", err.Error()) + } + } else { + t.Fatalf("failed to create migrate instance: %s", err.Error()) } // Init product catalog diff --git a/tools/migrate/migrate.go b/tools/migrate/migrate.go index 96f6ed7bd..1f2f36498 100644 --- a/tools/migrate/migrate.go +++ b/tools/migrate/migrate.go @@ -2,12 +2,12 @@ package migrate import ( + "database/sql" "embed" "io/fs" - "net/url" "github.com/golang-migrate/migrate/v4" - _ "github.com/golang-migrate/migrate/v4/database/postgres" + "github.com/golang-migrate/migrate/v4/database/pgx" "github.com/golang-migrate/migrate/v4/source/iofs" ) @@ -20,42 +20,35 @@ type Migrate = migrate.Migrate //go:embed migrations var OMMigrations embed.FS +type Options struct { + DB *sql.DB + FS fs.FS + FSPath string + PGConfig *pgx.Config +} + // NewMigrate creates a new migrate instance. -func NewMigrate(conn string, fs fs.FS, fsPath string) (*Migrate, error) { - d, err := iofs.New(fs, fsPath) +func NewMigrate(opt Options) (*Migrate, error) { + d, err := iofs.New(opt.FS, opt.FSPath) if err != nil { return nil, err } - return migrate.NewWithSourceInstance("iofs", d, conn) -} -func Up(conn string) error { - conn, err := SetMigrationTableName(conn, MigrationsTable) + driver, err := pgx.WithInstance(opt.DB, opt.PGConfig) if err != nil { - return err - } - m, err := NewMigrate(conn, OMMigrations, "migrations") - if err != nil { - return err + return nil, err } - defer m.Close() - err = m.Up() - if err != nil && err != migrate.ErrNoChange { - return err - } - return nil + return migrate.NewWithInstance("iofs", d, "postgres", driver) } -func SetMigrationTableName(conn, tableName string) (string, error) { - parsedURL, err := url.Parse(conn) - if err != nil { - return "", err - } - - values := parsedURL.Query() - values.Set("x-migrations-table", tableName) - parsedURL.RawQuery = values.Encode() - - return parsedURL.String(), nil +func Default(db *sql.DB) (*Migrate, error) { + return NewMigrate(Options{ + DB: db, + FS: OMMigrations, + FSPath: "migrations", + PGConfig: &pgx.Config{ + MigrationsTable: MigrationsTable, + }, + }) } diff --git a/tools/migrate/migrate_test.go b/tools/migrate/migrate_test.go index 1a3775ef8..942485951 100644 --- a/tools/migrate/migrate_test.go +++ b/tools/migrate/migrate_test.go @@ -12,7 +12,7 @@ func TestUpDownUp(t *testing.T) { testDB := testutils.InitPostgresDB(t) defer testDB.PGDriver.Close() - migrator, err := migrate.NewMigrate(testDB.URL, migrate.OMMigrations, "migrations") + migrator, err := migrate.Default(testDB.SQLDriver) if err != nil { t.Fatal(err) }