From ba8aaaeca5ca7bf9da1fe6c5a361de6709cad334 Mon Sep 17 00:00:00 2001 From: zepatrik Date: Mon, 7 Oct 2024 14:36:59 +0200 Subject: [PATCH] chore(popx): code simplification and refactoring --- pkgerx/file.go | 30 --- pkgerx/migration_box.go | 187 ------------------ pkgerx/migration_box_test.go | 43 ---- pkgerx/sql_template_funcs.go | 22 --- ...sql_create_tablename_template.expected.sql | 0 .../0_sql_create_tablename_template.up.sql | 0 popx/migration_box_gomigration_test.go | 9 +- popx/migration_box_template_test.go | 4 +- popx/migration_box_testdata_test.go | 5 +- popx/migration_info.go | 5 +- popx/migrator.go | 81 +++----- popx/migrator_test.go | 155 +++------------ popx/test_migrator.go | 2 +- 13 files changed, 69 insertions(+), 474 deletions(-) delete mode 100644 pkgerx/file.go delete mode 100644 pkgerx/migration_box.go delete mode 100644 pkgerx/migration_box_test.go delete mode 100644 pkgerx/sql_template_funcs.go delete mode 100644 pkgerx/testdata/0_sql_create_tablename_template.expected.sql delete mode 100644 pkgerx/testdata/0_sql_create_tablename_template.up.sql diff --git a/pkgerx/file.go b/pkgerx/file.go deleted file mode 100644 index f5c21594..00000000 --- a/pkgerx/file.go +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright © 2023 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package pkgerx - -import ( - "io" - - "github.com/ory/x/ioutilx" - - "github.com/markbates/pkger/pkging" -) - -// MustRead reads a pkging.File or panics. -func MustRead(f pkging.File, err error) []byte { - if err != nil { - panic(err) - } - defer (func() { _ = f.Close() })() - return ioutilx.MustReadAll(f) -} - -// Read reads a pkging.File or returns an error -func Read(f pkging.File, err error) ([]byte, error) { - if err != nil { - return nil, err - } - defer (func() { _ = f.Close() })() - return io.ReadAll(f) -} diff --git a/pkgerx/migration_box.go b/pkgerx/migration_box.go deleted file mode 100644 index 9820ce7d..00000000 --- a/pkgerx/migration_box.go +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright © 2023 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package pkgerx - -import ( - "bytes" - "io" - "os" - "strings" - "text/template" - - "github.com/gobuffalo/fizz" - "github.com/gobuffalo/pop/v6" - "github.com/markbates/pkger" - "github.com/pkg/errors" - - "github.com/ory/x/logrusx" -) - -type ( - // MigrationBox is a wrapper around pkger.Dir and Migrator. - // This will allow you to run migrations from migrations packed - // inside of a compiled binary. - MigrationBox struct { - pop.Migrator - - Dir pkger.Dir - l *logrusx.Logger - migrationContent MigrationContent - } - MigrationContent func(mf pop.Migration, c *pop.Connection, r io.Reader, usingTemplate bool) (string, error) -) - -func templatingMigrationContent(params map[string]interface{}) func(pop.Migration, *pop.Connection, io.Reader, bool) (string, error) { - return func(mf pop.Migration, c *pop.Connection, r io.Reader, usingTemplate bool) (string, error) { - b, err := io.ReadAll(r) - if err != nil { - return "", nil - } - - content := "" - if usingTemplate { - t := template.New("migration") - t.Funcs(SQLTemplateFuncs) - t, err := t.Parse(string(b)) - if err != nil { - return "", err - } - - var bb bytes.Buffer - err = t.Execute(&bb, struct { - DialectDetails *pop.ConnectionDetails - Parameters map[string]interface{} - }{ - DialectDetails: c.Dialect.Details(), - Parameters: params, - }) - if err != nil { - return "", errors.Wrapf(err, "could not execute migration template %s", mf.Path) - } - content = bb.String() - } else { - content = string(b) - } - - if mf.Type == "fizz" { - content, err = fizz.AString(content, c.Dialect.FizzTranslator()) - if err != nil { - return "", errors.Wrapf(err, "could not fizz the migration %s", mf.Path) - } - } - - return content, nil - } -} - -func WithTemplateValues(v map[string]interface{}) func(*MigrationBox) *MigrationBox { - return func(m *MigrationBox) *MigrationBox { - m.migrationContent = templatingMigrationContent(v) - return m - } -} - -func WithMigrationContentMiddleware(middleware func(content string, err error) (string, error)) func(*MigrationBox) *MigrationBox { - return func(m *MigrationBox) *MigrationBox { - prev := m.migrationContent - m.migrationContent = func(mf pop.Migration, c *pop.Connection, r io.Reader, usingTemplate bool) (string, error) { - return middleware(prev(mf, c, r, usingTemplate)) - } - return m - } -} - -// NewMigrationBox from a packr.Dir and a Connection. -// -// migrations, err := NewMigrationBox(pkger.Dir("/migrations")) -func NewMigrationBox(dir pkger.Dir, c *pop.Connection, l *logrusx.Logger, opts ...func(*MigrationBox) *MigrationBox) (*MigrationBox, error) { - mb := &MigrationBox{ - Migrator: pop.NewMigrator(c), - Dir: dir, - l: l, - migrationContent: pop.MigrationContent, - } - - for _, o := range opts { - mb = o(mb) - } - - runner := func(f io.Reader) func(mf pop.Migration, tx *pop.Connection) error { - return func(mf pop.Migration, tx *pop.Connection) error { - content, err := mb.migrationContent(mf, tx, f, true) - if err != nil { - return errors.Wrapf(err, "error processing %s", mf.Path) - } - if content == "" { - return nil - } - err = tx.RawQuery(content).Exec() - if err != nil { - return errors.Wrapf(err, "error executing %s, sql: %s", mf.Path, content) - } - return nil - } - } - - err := mb.findMigrations(runner) - if err != nil { - return mb, err - } - - return mb, nil -} - -func (fm *MigrationBox) findMigrations(runner func(f io.Reader) func(mf pop.Migration, tx *pop.Connection) error) error { - return pkger.Walk(string(fm.Dir), func(p string, info os.FileInfo, err error) error { - if err != nil { - return errors.WithStack(err) - } - - match, err := pop.ParseMigrationFilename(info.Name()) - if err != nil { - if strings.HasPrefix(err.Error(), "unsupported dialect") { - fm.l.Debugf("Ignoring migration file %s because dialect is not supported: %s", info.Name(), err.Error()) - return nil - } - return errors.WithStack(err) - } - - if match == nil { - fm.l.Debugf("Ignoring migration file %s because it does not match the file pattern.", info.Name()) - return nil - } - - file, err := pkger.Open(p) - if err != nil { - return errors.WithStack(err) - } - defer file.Close() - - content, err := io.ReadAll(file) - if err != nil { - return errors.WithStack(err) - } - - mf := pop.Migration{ - Path: p, - Version: match.Version, - Name: match.Name, - DBType: match.DBType, - Direction: match.Direction, - Type: match.Type, - Runner: runner(bytes.NewReader(content)), - } - - switch mf.Direction { - case "down": - fm.DownMigrations.Migrations = append(fm.DownMigrations.Migrations, mf) - case "up": - fm.UpMigrations.Migrations = append(fm.UpMigrations.Migrations, mf) - default: - return errors.Errorf("unknown direction %s", mf.Direction) - } - - return nil - }) -} diff --git a/pkgerx/migration_box_test.go b/pkgerx/migration_box_test.go deleted file mode 100644 index bb518e3e..00000000 --- a/pkgerx/migration_box_test.go +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright © 2023 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package pkgerx - -import ( - "os" - "path/filepath" - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/ory/x/sqlcon/dockertest" - - "github.com/markbates/pkger" - "github.com/stretchr/testify/require" - - "github.com/ory/x/logrusx" -) - -var testData = pkger.Dir("github.com/ory/x:/pkgerx/testdata") - -func TestMigrationBoxTemplating(t *testing.T) { - templateVals := map[string]interface{}{ - "tableName": "test_table_name", - } - - expectedMigration, err := os.ReadFile(filepath.Join("testdata", "0_sql_create_tablename_template.expected.sql")) - require.NoError(t, err) - - c := dockertest.ConnectToTestCockroachDBPop(t) - - mb, err := NewMigrationBox(testData, c, logrusx.New("", ""), WithTemplateValues(templateVals), WithMigrationContentMiddleware(func(content string, err error) (string, error) { - require.NoError(t, err) - assert.Equal(t, string(expectedMigration), content) - - return content, err - })) - require.NoError(t, err) - - err = mb.Up() - require.NoError(t, err) -} diff --git a/pkgerx/sql_template_funcs.go b/pkgerx/sql_template_funcs.go deleted file mode 100644 index 223f55c5..00000000 --- a/pkgerx/sql_template_funcs.go +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright © 2023 Ory Corp -// SPDX-License-Identifier: Apache-2.0 - -package pkgerx - -import ( - "fmt" - "regexp" -) - -var SQLTemplateFuncs = map[string]interface{}{ - "identifier": Identifier, -} - -var identifierPattern = regexp.MustCompile("^[a-zA-Z][a-zA-Z0-9_]*$") - -func Identifier(i string) (string, error) { - if !identifierPattern.MatchString(i) { - return "", fmt.Errorf("invalid SQL identifier '%s'", i) - } - return i, nil -} diff --git a/pkgerx/testdata/0_sql_create_tablename_template.expected.sql b/pkgerx/testdata/0_sql_create_tablename_template.expected.sql deleted file mode 100644 index e69de29b..00000000 diff --git a/pkgerx/testdata/0_sql_create_tablename_template.up.sql b/pkgerx/testdata/0_sql_create_tablename_template.up.sql deleted file mode 100644 index e69de29b..00000000 diff --git a/popx/migration_box_gomigration_test.go b/popx/migration_box_gomigration_test.go index b685c6f0..3bf6920f 100644 --- a/popx/migration_box_gomigration_test.go +++ b/popx/migration_box_gomigration_test.go @@ -6,6 +6,7 @@ package popx_test import ( "context" "database/sql" + "github.com/ory/x/dbal" "math/rand" "testing" "time" @@ -78,7 +79,7 @@ func TestGoMigrations(t *testing.T) { called = make([]time.Time, len(goMigrations)) c, err := pop.NewConnection(&pop.ConnectionDetails{ - URL: "sqlite://file::memory:?_fk=true", + URL: dbal.NewSQLiteTestDatabase(t), }) require.NoError(t, err) require.NoError(t, c.Open()) @@ -101,7 +102,7 @@ func TestGoMigrations(t *testing.T) { t.Run("tc=errs_on_missing_down_migration", func(t *testing.T) { c, err := pop.NewConnection(&pop.ConnectionDetails{ - URL: "sqlite://file::memory:?_fk=true", + URL: dbal.NewSQLiteTestDatabase(t), }) require.NoError(t, err) require.NoError(t, c.Open()) @@ -112,7 +113,7 @@ func TestGoMigrations(t *testing.T) { t.Run("tc=runs everything in one transaction", func(t *testing.T) { c, err := pop.NewConnection(&pop.ConnectionDetails{ - URL: "sqlite://file::memory:?_fk=true", + URL: dbal.NewSQLiteTestDatabase(t), }) require.NoError(t, err) require.NoError(t, c.Open()) @@ -223,7 +224,7 @@ func TestIncompatibleRunners(t *testing.T) { func TestNoTransaction(t *testing.T) { c, err := pop.NewConnection(&pop.ConnectionDetails{ - URL: "sqlite://file::memory:", + URL: dbal.NewSQLiteTestDatabase(t), }) require.NoError(t, err) require.NoError(t, c.Open()) diff --git a/popx/migration_box_template_test.go b/popx/migration_box_template_test.go index 8d082b6c..58377641 100644 --- a/popx/migration_box_template_test.go +++ b/popx/migration_box_template_test.go @@ -8,10 +8,10 @@ import ( "testing" "github.com/gobuffalo/pop/v6" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/ory/x/dbal" "github.com/ory/x/logrusx" ) @@ -27,7 +27,7 @@ func TestMigrationBoxTemplating(t *testing.T) { require.NoError(t, err) c, err := pop.NewConnection(&pop.ConnectionDetails{ - URL: "sqlite://file::memory:?_fk=true", + URL: dbal.NewSQLiteTestDatabase(t), }) require.NoError(t, err) require.NoError(t, c.Open()) diff --git a/popx/migration_box_testdata_test.go b/popx/migration_box_testdata_test.go index fe6860a7..3455aa7b 100644 --- a/popx/migration_box_testdata_test.go +++ b/popx/migration_box_testdata_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/ory/x/dbal" "github.com/ory/x/logrusx" "github.com/ory/x/popx" ) @@ -31,7 +32,7 @@ type testdata struct { func TestMigrationBoxWithTestdata(t *testing.T) { c, err := pop.NewConnection(&pop.ConnectionDetails{ - URL: "sqlite://file::memory:?_fk=true", + URL: dbal.NewSQLiteTestDatabase(t), }) require.NoError(t, err) require.NoError(t, c.Open()) @@ -56,7 +57,7 @@ func TestMigrationBoxWithTestdata(t *testing.T) { func TestMigrationBox_CheckNoErr(t *testing.T) { c, err := pop.NewConnection(&pop.ConnectionDetails{ - URL: "sqlite://file::memory:?_fk=true", + URL: dbal.NewSQLiteTestDatabase(t), }) require.NoError(t, err) require.NoError(t, c.Open()) diff --git a/popx/migration_info.go b/popx/migration_info.go index 92b9738e..2ec94ee7 100644 --- a/popx/migration_info.go +++ b/popx/migration_info.go @@ -65,10 +65,11 @@ func (mfs Migrations) Swap(i, j int) { func (mfs Migrations) SortAndFilter(dialect string, modifiers ...func(sort.Interface) sort.Interface) Migrations { // We need to sort mfs in order to push the dbType=="all" migrations // to the back. - m := append(Migrations{}, mfs...) + m := make(Migrations, len(mfs)) + copy(m, mfs) sort.Sort(m) - vsf := make(Migrations, 0) + vsf := make(Migrations, 0, len(m)) for k, v := range m { if v.DBType == "all" { // Add "all" only if we can not find a more specific migration for the dialect. diff --git a/popx/migrator.go b/popx/migrator.go index 9bff3d53..4ef8211c 100644 --- a/popx/migrator.go +++ b/popx/migrator.go @@ -11,6 +11,7 @@ import ( "math" "os" "regexp" + "slices" "sort" "strings" "text/tabwriter" @@ -18,7 +19,6 @@ import ( "github.com/gobuffalo/pop/v6" - "github.com/cockroachdb/cockroach-go/v2/crdb" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -92,7 +92,7 @@ func (m *Migrator) Up(ctx context.Context) error { // If step <= 0 all pending migrations are run. func (m *Migrator) UpTo(ctx context.Context, step int) (applied int, err error) { span, ctx := m.startSpan(ctx, MigrationUpOpName) - defer span.End() + defer otelx.End(span, &err) c := m.Connection.WithContext(ctx) err = m.exec(ctx, func() error { @@ -101,44 +101,39 @@ func (m *Migrator) UpTo(ctx context.Context, step int) (applied int, err error) for _, mi := range mfs { l := m.l.WithField("version", mi.Version).WithField("migration_name", mi.Name).WithField("migration_file", mi.Path) - exists, err := c.Where("version = ?", mi.Version).Exists(mtn) + appliedMigrations := make([]string, 0, 2) + legacyVersion := mi.Version + if len(legacyVersion) > 14 { + legacyVersion = legacyVersion[:14] + } + err := c.RawQuery(fmt.Sprintf("SELECT version FROM %s WHERE version IN (?, ?)", mtn), mi.Version, legacyVersion).All(&appliedMigrations) if err != nil { return errors.Wrapf(err, "problem checking for migration version %s", mi.Version) } - if exists { + if slices.Contains(appliedMigrations, mi.Version) { l.Debug("Migration has already been applied, skipping.") continue } - if len(mi.Version) > 14 { - l.Debug("Migration has not been applied but it might be a legacy migration, investigating.") - - legacyVersion := mi.Version[:14] - exists, err = c.Where("version = ?", legacyVersion).Exists(mtn) - if err != nil { - return errors.Wrapf(err, "problem checking for legacy migration version %s", legacyVersion) - } + if slices.Contains(appliedMigrations, legacyVersion) { + l.WithField("legacy_version", legacyVersion).WithField("migration_table", mtn).Debug("Migration has already been applied in a legacy migration run. Updating version in migration table.") + if err := m.isolatedTransaction(ctx, "init-migrate", func(conn *pop.Connection) error { + // We do not want to remove the legacy migration version or subsequent migrations might be applied twice. + // + // Do not activate the following - it is just for reference. + // + // if _, err := tx.Store.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", mtn), legacyVersion); err != nil { + // return errors.Wrapf(err, "problem removing legacy version %s", mi.Version) + // } - if exists { - l.WithField("legacy_version", legacyVersion).WithField("migration_table", mtn).Debug("Migration has already been applied in a legacy migration run. Updating version in migration table.") - if err := m.isolatedTransaction(ctx, "init-migrate", func(conn *pop.Connection) error { - // We do not want to remove the legacy migration version or subsequent migrations might be applied twice. - // - // Do not activate the following - it is just for reference. - // - // if _, err := tx.Store.Exec(fmt.Sprintf("DELETE FROM %s WHERE version = ?", mtn), legacyVersion); err != nil { - // return errors.Wrapf(err, "problem removing legacy version %s", mi.Version) - // } - - // #nosec G201 - mtn is a system-wide const - err := conn.RawQuery(fmt.Sprintf("INSERT INTO %s (version) VALUES (?)", mtn), mi.Version).Exec() - return errors.Wrapf(err, "problem inserting migration version %s", mi.Version) - }); err != nil { - return err - } - continue + // #nosec G201 - mtn is a system-wide const + err := conn.RawQuery(fmt.Sprintf("INSERT INTO %s (version) VALUES (?)", mtn), mi.Version).Exec() + return errors.Wrapf(err, "problem inserting migration version %s", mi.Version) + }); err != nil { + return err } + continue } l.Info("Migration has not yet been applied, running migration.") @@ -506,10 +501,9 @@ func (m *Migrator) Status(ctx context.Context) (MigrationStatuses, error) { if len(migrations) == 0 { return nil, errors.Errorf("unable to find any migrations for dialect: %s", con.Dialect.Name()) } - m.sanitizedMigrationTableName(con) - var migrationRows []migrationRow - err := con.RawQuery(fmt.Sprintf("SELECT * FROM %s", m.sanitizedMigrationTableName(con))).All(&migrationRows) + alreadyApplied := make([]string, 0, len(migrations)) + err := con.RawQuery(fmt.Sprintf("SELECT version FROM %s", m.sanitizedMigrationTableName(con))).All(&alreadyApplied) if err != nil { if errIsTableNotFound(err) { // This means that no migrations have been applied and we need to apply all of them first! @@ -529,15 +523,11 @@ func (m *Migrator) Status(ctx context.Context) (MigrationStatuses, error) { Name: mf.Name, } - for _, mr := range migrationRows { - if mr.Version == mf.Version { - statuses[k].State = Applied - break - } else if len(mf.Version) > 14 { - if mr.Version == mf.Version[:14] { - statuses[k].State = Applied - } - } + if slices.ContainsFunc(alreadyApplied, func(applied string) bool { + return applied == mf.Version || (len(mf.Version) > 14 && applied == mf.Version[:14]) + }) { + statuses[k].State = Applied + continue } } @@ -596,13 +586,6 @@ func (m *Migrator) exec(ctx context.Context, fn func() error) error { } } - if m.Connection.Dialect.Name() == "cockroach" { - outer := fn - fn = func() error { - return crdb.Execute(outer) - } - } - if err := fn(); err != nil { return err } diff --git a/popx/migrator_test.go b/popx/migrator_test.go index 808ed8f2..e03ddfdc 100644 --- a/popx/migrator_test.go +++ b/popx/migrator_test.go @@ -4,12 +4,9 @@ package popx_test import ( - "bytes" "context" "embed" - "fmt" - "os" - "strings" + "github.com/ory/x/dbal" "testing" "github.com/gobuffalo/pop/v6" @@ -18,130 +15,17 @@ import ( "github.com/stretchr/testify/require" "github.com/ory/x/logrusx" - "github.com/ory/x/pkgerx" . "github.com/ory/x/popx" - "github.com/ory/x/sqlcon/dockertest" ) //go:embed stub/migrations/transactional/*.sql var transactionalMigrations embed.FS -func TestMigratorUpgrading(t *testing.T) { - litedb, err := os.CreateTemp("", "sqlite-*") - require.NoError(t, err) - require.NoError(t, litedb.Close()) - - ctx := context.Background() - - sqlite, err := pop.NewConnection(&pop.ConnectionDetails{ - URL: "sqlite://file::memory:?_fk=true", - }) - require.NoError(t, err) - require.NoError(t, sqlite.Open()) - - connections := map[string]*pop.Connection{ - "sqlite": sqlite, - } - - if !testing.Short() { - dockertest.Parallel([]func(){ - func() { - connections["postgres"] = dockertest.ConnectToTestPostgreSQLPop(t) - }, - func() { - connections["mysql"] = dockertest.ConnectToTestMySQLPop(t) - }, - func() { - connections["cockroach"] = dockertest.ConnectToTestCockroachDBPop(t) - }, - }) - } - - l := logrusx.New("", "", logrusx.ForceLevel(logrus.DebugLevel)) - - for name, c := range connections { - t.Run(fmt.Sprintf("database=%s", name), func(t *testing.T) { - t.SkipNow() - - legacy, err := pkgerx.NewMigrationBox("/popx/stub/migrations/legacy", c, l) - require.NoError(t, err) - require.NoError(t, legacy.Up()) - - var legacyStatusBuffer bytes.Buffer - require.NoError(t, legacy.Status(&legacyStatusBuffer)) - - legacyStatus := filterMySQL(t, name, legacyStatusBuffer.String()) - - require.NotContains(t, legacyStatus, Pending) - - expected := legacy.DumpMigrationSchema() - - transactional, err := NewMigrationBox(transactionalMigrations, NewMigrator(c, l, nil, 0)) - require.NoError(t, err) - - var transactionalStatusBuffer bytes.Buffer - statuses, err := transactional.Status(ctx) - require.NoError(t, err) - - require.NoError(t, statuses.Write(&transactionalStatusBuffer)) - transactionalStatus := filterMySQL(t, name, transactionalStatusBuffer.String()) - require.NotContains(t, transactionalStatus, Pending) - require.False(t, statuses.HasPending()) - - require.NoError(t, transactional.Up(ctx)) - - actual := transactional.DumpMigrationSchema(ctx) - assert.EqualValues(t, expected, actual) - - // Re-set and re-try - - require.NoError(t, legacy.Down(-1)) - require.NoError(t, transactional.Up(ctx)) - actual = transactional.DumpMigrationSchema(ctx) - assert.EqualValues(t, expected, actual) - }) - } -} - -func filterMySQL(t *testing.T, name string, status string) string { - if name == "mysql" { - return status - } - // These only run for mysql and are thus expected to be pending: - // - // 20191100000005 identities Pending - // 20191100000009 verification Pending - // 20200519101058 create_recovery_addresses Pending - // 20200601101001 verification Pending - - pending := []string{"20191100000005", "20191100000009", "20200519101058", "20200601101001"} - var lines []string - for _, l := range strings.Split(status, "\n") { - var skip bool - for _, p := range pending { - if strings.Contains(l, p) { - t.Logf("Removing expected pending line: %s", l) - skip = true - break - } - } - if !skip { - lines = append(lines, l) - } - } - - return strings.Join(lines, "\n") -} - func TestMigratorUpgradingFromStart(t *testing.T) { - litedb, err := os.CreateTemp("", "sqlite-*") - require.NoError(t, err) - require.NoError(t, litedb.Close()) - ctx := context.Background() c, err := pop.NewConnection(&pop.ConnectionDetails{ - URL: "sqlite://file::memory:?_fk=true", + URL: dbal.NewSQLiteTestDatabase(t), }) require.NoError(t, err) require.NoError(t, c.Open()) @@ -151,37 +35,44 @@ func TestMigratorUpgradingFromStart(t *testing.T) { require.NoError(t, err) status, err := transactional.Status(ctx) require.NoError(t, err) - require.True(t, status.HasPending()) + assert.True(t, status.HasPending()) + + applied, err := transactional.UpTo(ctx, 1) + require.NoError(t, err) + assert.Equal(t, 1, applied) + + status, err = transactional.Status(ctx) + require.NoError(t, err) + assert.True(t, status.HasPending()) + assert.Equal(t, Applied, status[0].State) + assert.Equal(t, Pending, status[1].State) require.NoError(t, transactional.Up(ctx)) status, err = transactional.Status(ctx) require.NoError(t, err) - require.False(t, status.HasPending()) + assert.False(t, status.HasPending()) // Are all the tables here? var rows []string - require.NoError(t, c.Store.Select(&rows, "SELECT name FROM sqlite_master WHERE type='table'")) + require.NoError(t, c.RawQuery("SELECT name FROM sqlite_master WHERE type='table'").All(&rows)) - for _, expected := range []string{ - "schema_migration", - "identities", - } { - require.Contains(t, rows, expected) - } + assert.ElementsMatch(t, rows, []string{"schema_migration", "identities", "identity_credential_types", + "identity_credentials", "identity_credential_identifiers", "selfservice_login_flows", "selfservice_login_flow_methods", + "selfservice_registration_flows", "selfservice_registration_flow_methods", "selfservice_errors", "courier_messages", + "selfservice_settings_flow_methods", "continuity_containers", "identity_recovery_addresses", + "selfservice_recovery_flows", "selfservice_recovery_flow_methods", "selfservice_settings_flows", "sessions", + "selfservice_verification_flow_methods", "selfservice_verification_flows", "identity_verification_tokens", + "identity_recovery_tokens", "identity_verifiable_addresses"}) require.NoError(t, transactional.Down(ctx, -1)) } func TestMigratorSanitizeMigrationTableName(t *testing.T) { - litedb, err := os.CreateTemp("", "sqlite-*") - require.NoError(t, err) - require.NoError(t, litedb.Close()) - ctx := context.Background() c, err := pop.NewConnection(&pop.ConnectionDetails{ - URL: `sqlite://file::memory:?_fk=true&migration_table_name=injection--`, + URL: dbal.NewSQLiteTestDatabase(t) + "&migration_table_name=injection--", }) require.NoError(t, err) require.NoError(t, c.Open()) diff --git a/popx/test_migrator.go b/popx/test_migrator.go index 01bb09f1..22ec87f0 100644 --- a/popx/test_migrator.go +++ b/popx/test_migrator.go @@ -22,7 +22,7 @@ type TestMigrator struct { *Migrator } -// Returns a new TestMigrator +// NewTestMigrator returns a new TestMigrator // After running each migration it applies it's corresponding testData sql files. // They are identified by having the same version (= number in the front of the filename). // The filenames are expected to be of the format ([0-9]+).*(_testdata(\.[dbtype])?.sql