diff --git a/describe_columns.go b/describe_columns.go index dd514bc..a174510 100644 --- a/describe_columns.go +++ b/describe_columns.go @@ -7,7 +7,6 @@ import ( type ColumnDescription struct { Name string - Index int Type string IsNullable bool Default string @@ -21,7 +20,6 @@ type ColumnDescription struct { func (d ColumnDescription) Equals(other ColumnDescription) bool { return true && d.Name == other.Name && - d.Index == other.Index && d.Type == other.Type && d.IsNullable == other.IsNullable && d.Default == other.Default && @@ -36,7 +34,6 @@ type column struct { Namespace string TableName string Name string - Index int Type string IsNullable bool Default *string @@ -58,7 +55,6 @@ var scanColumns = NewSliceScanner(func(s Scanner) (c column, _ error) { &c.Namespace, &c.TableName, &c.Name, - &c.Index, &c.Type, &isNullable, &c.Default, @@ -81,7 +77,6 @@ func describeColumns(ctx context.Context, db DB) (map[string][]ColumnDescription c.table_schema AS namespace, c.table_name AS name, c.column_name AS column_name, - c.ordinal_position AS index, CASE WHEN c.data_type = 'ARRAY' THEN COALESCE(( SELECT e.data_type @@ -124,7 +119,6 @@ func describeColumns(ctx context.Context, db DB) (map[string][]ColumnDescription columnMap[key] = append(columnMap[key], ColumnDescription{ Name: column.Name, - Index: column.Index, Type: column.Type, IsNullable: column.IsNullable, Default: deref(column.Default), diff --git a/drift.go b/drift.go index 885bde8..949663c 100644 --- a/drift.go +++ b/drift.go @@ -4,6 +4,7 @@ import ( "cmp" "fmt" "slices" + "sort" "strings" ) @@ -66,46 +67,122 @@ func Compare(a, b SchemaDescription) (statements []string) { // views : depends on tables, columns, views // triggers : depends on tables, columns, functions + cmpWithKey := func(a, b ddlStatement) int { return cmp.Compare(a.key, b.key) } + defaultSort := func(statements []ddlStatement) { slices.SortFunc(statements, cmpWithKey) } + + // + // TODO - rewrite this + // + + // sortByClosure := func(cls closure) func([]ddlStatement) { + // return func(statements []ddlStatement) { + // slices.SortFunc(statements, func(a, b ddlStatement) int { + // if _, ok := cls[a.key][b.key]; ok { + // return -1 + // } + // if _, ok := cls[b.key][a.key]; ok { + // return +1 + // } + + // return cmp.Compare(a.key, b.key) + // }) + // } + // } + + sortByClosure := func(cls closure) func([]ddlStatement) { + return func(statements []ddlStatement) { + keys := map[string]ddlStatement{} + for _, stmt := range statements { + keys[stmt.key] = stmt + } + + cls2 := map[string][]string{} + for _, stmt := range statements { + cls2[stmt.key] = nil + } + + for k, v := range cls { + for k2 := range v { + if _, ok := keys[k]; ok { + if _, ok := cls2[k2]; ok { + cls2[k2] = append(cls2[k2], k) + } + } + } + } + + var sorted []ddlStatement + for len(cls2) > 0 { + var candidates []string + for k, v := range cls2 { + if len(v) == 0 { + candidates = append(candidates, k) + } + } + + if len(candidates) == 0 { + panic("cycle detected in closure, cannot perform topological sort") + } + + sort.Strings(candidates) + top := candidates[0] + sorted = append(sorted, keys[top]) + + for k, v := range cls2 { + for i := 0; i < len(v); i++ { + if v[i] == top { + cls2[k] = append(v[:i], v[i+1:]...) + } + } + } + + delete(cls2, top) + } + + for i := range statements { + statements[i] = sorted[i] + } + } + } + createClosure, dropClosure := viewClosures(a, b) - sortCreateViews := func(statements []ddlStatement) { slices.SortFunc(statements, cmpWithClosure(createClosure)) } - sortDropViews := func(statements []ddlStatement) { slices.SortFunc(statements, cmpWithClosure(dropClosure)) } + sortCreateViews := sortByClosure(createClosure) + sortDropViews := sortByClosure(dropClosure) order := []struct { statementType string objectType string order func(statements []ddlStatement) }{ - {"drop", "trigger", nil}, + {"drop", "trigger", defaultSort}, {"drop", "view", sortDropViews}, - {"drop", "index", nil}, - {"drop", "constraint", nil}, - {"drop", "column", nil}, - {"drop", "sequence", nil}, - {"drop", "table", nil}, - {"drop", "function", nil}, - {"drop", "enum", nil}, - {"drop", "extension", nil}, - {"create", "extension", nil}, - {"create", "enum", nil}, - {"replace", "enum", nil}, - {"create", "function", nil}, - {"replace", "function", nil}, - {"create", "table", nil}, - {"create", "sequence", nil}, - {"replace", "sequence", nil}, - {"create", "column", nil}, - {"replace", "column", nil}, - {"create", "constraint", nil}, - {"create", "index", nil}, + {"drop", "index", defaultSort}, + {"drop", "constraint", defaultSort}, + {"drop", "column", defaultSort}, + {"drop", "sequence", defaultSort}, + {"drop", "table", defaultSort}, + {"drop", "function", defaultSort}, + {"drop", "enum", defaultSort}, + {"drop", "extension", defaultSort}, + {"create", "extension", defaultSort}, + {"create", "enum", defaultSort}, + {"replace", "enum", defaultSort}, + {"create", "function", defaultSort}, + {"replace", "function", defaultSort}, + {"create", "table", defaultSort}, + {"create", "sequence", defaultSort}, + {"replace", "sequence", defaultSort}, + {"create", "column", defaultSort}, + {"replace", "column", defaultSort}, + {"create", "constraint", defaultSort}, + {"create", "index", defaultSort}, {"create", "view", sortCreateViews}, - {"create", "trigger", nil}, + {"create", "trigger", defaultSort}, } for _, o := range order { filtered := filter(o.statementType, o.objectType) - if o.order != nil { - o.order(filtered) - } + o.order(filtered) for _, statement := range filtered { statements = append(statements, statement.statements...) @@ -119,29 +196,18 @@ func Compare(a, b SchemaDescription) (statements []string) { // // +// TODO - document func viewClosures(a, b SchemaDescription) (createClosure, dropClosure closure) { - // TODO - // keys := map[string]struct{}{} - // for _, view := range a.Views { - // keys[fmt.Sprintf("%q.%q", view.Namespace, view.Name)] = struct{}{} - // } - // for _, view := range a.Views { - // keys[fmt.Sprintf("%q.%q", view.Namespace, view.Name)] = struct{}{} - // } - createClosure = closure{} for _, dependency := range a.ColumnDependencies { sourceKey := fmt.Sprintf("%q.%q", dependency.SourceNamespace, dependency.SourceTableOrViewName) dependencyKey := fmt.Sprintf("%q.%q", dependency.UsedNamespace, dependency.UsedTableOrView) - // TODO - // if _, ok := keys[dependencyKey]; ok { if _, ok := createClosure[sourceKey]; !ok { createClosure[sourceKey] = map[string]struct{}{} } createClosure[sourceKey][dependencyKey] = struct{}{} - // } } dropClosure = closure{} @@ -149,14 +215,11 @@ func viewClosures(a, b SchemaDescription) (createClosure, dropClosure closure) { sourceKey := fmt.Sprintf("%q.%q", dependency.SourceNamespace, dependency.SourceTableOrViewName) dependencyKey := fmt.Sprintf("%q.%q", dependency.UsedNamespace, dependency.UsedTableOrView) - // TODO - // if _, ok := keys[sourceKey]; ok { if _, ok := dropClosure[dependencyKey]; !ok { dropClosure[dependencyKey] = map[string]struct{}{} } dropClosure[dependencyKey][sourceKey] = struct{}{} - // } } transitiveClosure(createClosure) @@ -170,6 +233,7 @@ func viewClosures(a, b SchemaDescription) (createClosure, dropClosure closure) { type closure map[string]map[string]struct{} +// TODO - rename internals func transitiveClosure(m closure) { changed := true for changed { @@ -188,19 +252,6 @@ func transitiveClosure(m closure) { } } -func cmpWithClosure(createClosure closure) func(a, b ddlStatement) int { - return func(a, b ddlStatement) int { - if _, ok := createClosure[a.key][b.key]; ok { - return -1 - } - if _, ok := createClosure[b.key][a.key]; ok { - return +1 - } - - return cmp.Compare(a.key, b.key) - } -} - // // // @@ -227,8 +278,8 @@ type alterer[T any] interface { type ddlStatement struct { key string - statementType string // TODO - enum - objectType string // TODO - enum + statementType string + objectType string statements []string } diff --git a/drift_columns.go b/drift_columns.go index 89966a9..b6799ab 100644 --- a/drift_columns.go +++ b/drift_columns.go @@ -1,6 +1,8 @@ package pgutil -import "fmt" +import ( + "fmt" +) type ColumnModifier struct { t TableDescription @@ -37,7 +39,7 @@ func (m ColumnModifier) Create() string { defaultExpr = fmt.Sprintf(" DEFAULT %s", m.d.Default) } - return fmt.Sprintf("ALTER TABLE %q.%q ADD COLUMN IF NOT EXISTS %q %s %s %s;", m.t.Namespace, m.t.Name, m.d.Name, m.d.Type, nullableExpr, defaultExpr) + return fmt.Sprintf("ALTER TABLE %q.%q ADD COLUMN IF NOT EXISTS %q %s%s%s;", m.t.Namespace, m.t.Name, m.d.Name, m.d.Type, nullableExpr, defaultExpr) } func (m ColumnModifier) Drop() string { @@ -45,56 +47,41 @@ func (m ColumnModifier) Drop() string { } func (m ColumnModifier) AlterExisting(existingSchema SchemaDescription, existingObject ColumnDescription) ([]ddlStatement, bool) { - // TODO - stop tracking? - // Index int - - // TODO - implement these - // Type string - // IsNullable bool - // Default string - - // TODO - how to modify these? - // CharacterMaximumLength int - // IsIdentity bool - // IdentityGeneration string - // IsGenerated bool - // GenerationExpression string - - // statements := []string{} - - // if d.TypeName != target.TypeName { - // statements = append(statements, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET DATA TYPE %s;", table.Name, target.Name, target.TypeName)) - - // // Remove from diff below - // d.TypeName = target.TypeName - // } - // if d.IsNullable != target.IsNullable { - // var verb string - // if target.IsNullable { - // verb = "DROP" - // } else { - // verb = "SET" - // } - - // statements = append(statements, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s %s NOT NULL;", table.Name, target.Name, verb)) - - // // Remove from diff below - // d.IsNullable = target.IsNullable - // } - // if d.Default != target.Default { - // if target.Default == "" { - // statements = append(statements, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s DROP DEFAULT;", table.Name, target.Name)) - // } else { - // statements = append(statements, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s;", table.Name, target.Name, target.Default)) - // } - - // // Remove from diff below - // d.Default = target.Default - // } + statements := []string{} + alterColumn := func(format string, args ...any) { + statements = append(statements, fmt.Sprintf(fmt.Sprintf("ALTER TABLE %q.%q ALTER COLUMN %q %s;", m.t.Namespace, m.t.Name, m.d.Name, format), args...)) + } - // // Abort if there are other fields we haven't addressed - // hasAdditionalDiff := cmp.Diff(d, target) != "" - // return statements, !hasAdditionalDiff + if m.d.Type != existingObject.Type { + alterColumn("SET DATA TYPE %s", m.d.Type) + } + if m.d.Default != existingObject.Default { + if m.d.Default == "" { + alterColumn("DROP DEFAULT") + } else { + alterColumn("SET DEFAULT %s", m.d.Default) + } + } + if m.d.IsNullable != existingObject.IsNullable { + if m.d.IsNullable { + alterColumn("DROP NOT NULL") + } else { + alterColumn("SET NOT NULL") + } + } - return nil, false + // TODO - handle CharacterMaximumLength + // TODO - handle IsIdentity + // TODO - handle IdentityGeneration + // TODO - handle IsGenerated + // TODO - handle GenerationExpression + + return []ddlStatement{ + newStatement( + m.Key(), + "replace", + m.ObjectType(), + statements..., + ), + }, true } diff --git a/drift_constraints.go b/drift_constraints.go index ccd91a9..93220b7 100644 --- a/drift_constraints.go +++ b/drift_constraints.go @@ -31,5 +31,5 @@ func (m ConstraintModifier) Create() string { } func (m ConstraintModifier) Drop() string { - return fmt.Sprintf("ALTER TABLE %q.%q DROP CONSTRAINT %q;", m.t.Namespace, m.t.Name, m.d.Name) + return fmt.Sprintf("ALTER TABLE %q.%q DROP CONSTRAINT IF EXISTS %q;", m.t.Namespace, m.t.Name, m.d.Name) } diff --git a/drift_enums.go b/drift_enums.go index 32bae0c..ce7ed7c 100644 --- a/drift_enums.go +++ b/drift_enums.go @@ -6,7 +6,7 @@ import ( ) type EnumModifier struct { - s SchemaDescription // TODO - rename + s SchemaDescription d EnumDescription } @@ -32,7 +32,7 @@ func (m EnumModifier) Description() EnumDescription { func (m EnumModifier) Create() string { var quotedLabels []string for _, label := range m.d.Labels { - quotedLabels = append(quotedLabels, fmt.Sprintf("'%s'", label)) // TODO - escape '? + quotedLabels = append(quotedLabels, enumQuote(label)) } return fmt.Sprintf("CREATE TYPE %s AS ENUM (%s);", m.Key(), strings.Join(quotedLabels, ", ")) @@ -48,8 +48,9 @@ func (m EnumModifier) Drop() string { func (m EnumModifier) AlterExisting(existingSchema SchemaDescription, existingObject EnumDescription) ([]ddlStatement, bool) { reconstruction, ok := unifyLabels(m.d.Labels, existingObject.Labels) if !ok { + // TODO - document var views []string - createClosure, _ := viewClosures(existingSchema, SchemaDescription{}) // TODO : note this (woah) + createClosure, _ := viewClosures(existingSchema, SchemaDescription{}) var alters []string for _, dep := range existingSchema.EnumDependencies { @@ -160,12 +161,12 @@ func (m EnumModifier) AlterExisting(existingSchema SchemaDescription, existingOb for _, missingLabel := range reconstruction { relativeTo := "" if missingLabel.Next != nil { - relativeTo = fmt.Sprintf("BEFORE '%s'", *missingLabel.Next) + relativeTo = fmt.Sprintf("BEFORE %s", enumQuote(*missingLabel.Next)) } else { - relativeTo = fmt.Sprintf("AFTER '%s'", *missingLabel.Prev) + relativeTo = fmt.Sprintf("AFTER %s", enumQuote(*missingLabel.Prev)) } - statements = append(statements, fmt.Sprintf("ALTER TYPE %q.%q ADD VALUE '%s' %s;", m.d.Namespace, m.d.Name, missingLabel.Label, relativeTo)) + statements = append(statements, fmt.Sprintf("ALTER TYPE %q.%q ADD VALUE %s %s;", m.d.Namespace, m.d.Name, enumQuote(missingLabel.Label), relativeTo)) } return []ddlStatement{ @@ -178,6 +179,10 @@ func (m EnumModifier) AlterExisting(existingSchema SchemaDescription, existingOb }, true } +func enumQuote(label string) string { + return fmt.Sprintf("'%s'", strings.ReplaceAll(label, "'", "''")) +} + type missingLabel struct { Label string Prev *string diff --git a/drift_indexes.go b/drift_indexes.go index 1a69afc..bd904f5 100644 --- a/drift_indexes.go +++ b/drift_indexes.go @@ -36,11 +36,10 @@ func (m IndexModifier) Create() string { func (m IndexModifier) Drop() string { if m.isConstraint() { - return fmt.Sprintf("ALTER TABLE %q.%q DROP CONSTRAINT %q;", m.t.Namespace, m.t.Name, m.d.Name) + return fmt.Sprintf("ALTER TABLE %q.%q DROP CONSTRAINT IF EXISTS %q;", m.t.Namespace, m.t.Name, m.d.Name) } - // TODO - namespace? - return fmt.Sprintf("DROP INDEX IF EXISTS %q;", m.t.Name) + return fmt.Sprintf("DROP INDEX IF EXISTS %q.%q;", m.t.Namespace, m.d.Name) } func (m IndexModifier) isConstraint() bool { diff --git a/drift_test.go b/drift_test.go index 246b78d..fef8435 100644 --- a/drift_test.go +++ b/drift_test.go @@ -12,49 +12,46 @@ import ( func TestDrift_Extensions(t *testing.T) { t.Run("create", func(t *testing.T) { - testDrift(t, - // setup - `CREATE EXTENSION hstore;`, - // alter - `DROP EXTENSION hstore;`, - // expected - `CREATE EXTENSION IF NOT EXISTS "hstore" WITH SCHEMA "public";`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{`CREATE EXTENSION hstore;`}, + Alter: []string{`DROP EXTENSION hstore;`}, + Expected: []string{`CREATE EXTENSION IF NOT EXISTS "hstore" WITH SCHEMA "public";`}, + }) }) t.Run("drop", func(t *testing.T) { - testDrift(t, - // setup - ``, - // alter - `CREATE EXTENSION pg_trgm;`, - // expected - `DROP EXTENSION IF EXISTS "pg_trgm";`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{}, + Alter: []string{`CREATE EXTENSION pg_trgm;`}, + Expected: []string{`DROP EXTENSION IF EXISTS "pg_trgm";`}, + }) }) } func TestDrift_Enums(t *testing.T) { + t.Run("create", func(t *testing.T) { - testDrift(t, - // setup - `CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');`, - // alter - `DROP TYPE mood;`, - // expected - `CREATE TYPE "public"."mood" AS ENUM ('sad', 'ok', 'happy');`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{`CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');`}, + Alter: []string{`DROP TYPE mood;`}, + Expected: []string{`CREATE TYPE "public"."mood" AS ENUM ('sad', 'ok', 'happy');`}, + }) + + t.Run("escaped values", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{`CREATE TYPE spell_check AS ENUM ('there', 'their', 'they''re');`}, + Alter: []string{`DROP TYPE spell_check;`}, + Expected: []string{`CREATE TYPE "public"."spell_check" AS ENUM ('there', 'their', 'they''re');`}, + }) + }) }) t.Run("drop", func(t *testing.T) { - testDrift(t, - // setup - ``, - // alter - `CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');`, - // expected - `DROP TYPE IF EXISTS "public"."mood";`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{}, + Alter: []string{`CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');`}, + Expected: []string{`DROP TYPE IF EXISTS "public"."mood";`}, + }) }) t.Run("alter", func(t *testing.T) { @@ -94,122 +91,140 @@ func TestDrift_Enums(t *testing.T) { }, } { t.Run(testCase.name, func(t *testing.T) { - var setupLabels []string - for _, label := range testCase.expectedLabels { - setupLabels = append(setupLabels, fmt.Sprintf("'%s'", label)) + // Prepare setup queries + setupLabels := make([]string, len(testCase.expectedLabels)) + for i, label := range testCase.expectedLabels { + setupLabels[i] = fmt.Sprintf("'%s'", label) } + setupQuery := fmt.Sprintf("CREATE TYPE mood AS ENUM (%s);", strings.Join(setupLabels, ", ")) - var alterLabels []string - for _, label := range testCase.existingLabels { - alterLabels = append(alterLabels, fmt.Sprintf("'%s'", label)) + // Prepare alter queries + existingLabelsFormatted := make([]string, len(testCase.existingLabels)) + for i, label := range testCase.existingLabels { + existingLabelsFormatted[i] = fmt.Sprintf("'%s'", label) } - - testDrift(t, - // setup - `CREATE TYPE mood AS ENUM (`+strings.Join(setupLabels, ", ")+`);`, - // alter - ` - DROP TYPE mood; - CREATE TYPE mood AS ENUM (`+strings.Join(alterLabels, ", ")+`); - `, - // expected - testCase.expectedQueries..., - ) + alterQuery := fmt.Sprintf(` + DROP TYPE mood; + CREATE TYPE mood AS ENUM (%s); + `, strings.Join(existingLabelsFormatted, ", ")) + + // Execute testDrift with the new struct + testDrift(t, DriftTestCase{ + Setup: []string{setupQuery}, + Alter: []string{alterQuery}, + Expected: testCase.expectedQueries, + }) }) } + + t.Run("escaped values", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + `CREATE TYPE spell_check AS ENUM ('they''re', 'there', 'their', 'whose', 'who''s');`, + }, + Alter: []string{ + `DROP TYPE spell_check;`, + `CREATE TYPE spell_check AS ENUM ('they''re', 'their', 'whose');`, + }, + Expected: []string{ + `ALTER TYPE "public"."spell_check" ADD VALUE 'there' AFTER 'they''re';`, + `ALTER TYPE "public"."spell_check" ADD VALUE 'who''s' AFTER 'whose';`, + }, + }) + }) }) t.Run("non-repairable labels", func(t *testing.T) { - testDrift(t, - // setup - `CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');`, - // alter - ` - DROP TYPE mood; - CREATE TYPE mood AS ENUM ('happy', 'sad', 'ok', 'gleeful'); - `, - // expected - `ALTER TYPE "public"."mood" RENAME TO "mood_bak";`, - `CREATE TYPE "public"."mood" AS ENUM ('sad', 'ok', 'happy');`, - `DROP TYPE "public"."mood_bak";`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{ + `CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');`, + }, + Alter: []string{ + `DROP TYPE mood;`, + `CREATE TYPE mood AS ENUM ('happy', 'sad', 'ok', 'gleeful');`, + }, + Expected: []string{ + `ALTER TYPE "public"."mood" RENAME TO "mood_bak";`, + `CREATE TYPE "public"."mood" AS ENUM ('sad', 'ok', 'happy');`, + `DROP TYPE "public"."mood_bak";`, + }, + }) }) t.Run("updates column types", func(t *testing.T) { - testDrift(t, - // setup - ` - CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy'); - CREATE TABLE t (m mood); - `, - // alter - ` - DROP TABLE t; - DROP TYPE mood; - CREATE TYPE mood AS ENUM ('happy', 'sad', 'ok', 'gleeful'); - CREATE TABLE t (m mood); - `, - // expected - `ALTER TYPE "public"."mood" RENAME TO "mood_bak";`, - `CREATE TYPE "public"."mood" AS ENUM ('sad', 'ok', 'happy');`, - `ALTER TABLE "public"."t" ALTER COLUMN "m" TYPE "public"."mood" USING ("m"::text::"public"."mood");`, - `DROP TYPE "public"."mood_bak";`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{ + `CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');`, + `CREATE TABLE t (m mood);`, + }, + Alter: []string{ + `DROP TABLE t;`, + `DROP TYPE mood;`, + `CREATE TYPE mood AS ENUM ('happy', 'sad', 'ok', 'gleeful');`, + `CREATE TABLE t (m mood);`, + }, + Expected: []string{ + `ALTER TYPE "public"."mood" RENAME TO "mood_bak";`, + `CREATE TYPE "public"."mood" AS ENUM ('sad', 'ok', 'happy');`, + `ALTER TABLE "public"."t" ALTER COLUMN "m" TYPE "public"."mood" USING ("m"::text::"public"."mood");`, + `DROP TYPE "public"."mood_bak";`, + }, + }) }) t.Run("updates column defaults", func(t *testing.T) { - testDrift(t, - // setup - ` - CREATE SCHEMA s; - CREATE TYPE s.mood AS ENUM ('sad', 'ok', 'happy'); - CREATE TABLE s.t (m s.mood DEFAULT 'sad'); - `, - // alter - ` - DROP TABLE s.t; - DROP TYPE s.mood; - CREATE TYPE s.mood AS ENUM ('happy', 'sad', 'ok', 'gleeful'); - CREATE TABLE s.t (m s.mood DEFAULT 'sad'); - `, - // expected - `ALTER TYPE "s"."mood" RENAME TO "mood_bak";`, - `CREATE TYPE "s"."mood" AS ENUM ('sad', 'ok', 'happy');`, - `ALTER TABLE "s"."t" ALTER COLUMN "m" DROP DEFAULT, ALTER COLUMN "m" TYPE "s"."mood" USING ("m"::text::"s"."mood"), ALTER COLUMN "m" SET DEFAULT 'sad'::s.mood;`, - `DROP TYPE "s"."mood_bak";`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{ + `CREATE SCHEMA s;`, + `CREATE TYPE s.mood AS ENUM ('sad', 'ok', 'happy');`, + `CREATE TABLE s.t (m s.mood DEFAULT 'sad');`, + }, + Alter: []string{ + `DROP TABLE s.t;`, + `DROP TYPE s.mood;`, + `CREATE TYPE s.mood AS ENUM ('happy', 'sad', 'ok', 'gleeful');`, + `CREATE TABLE s.t (m s.mood DEFAULT 'sad');`, + }, + Expected: []string{ + `ALTER TYPE "s"."mood" RENAME TO "mood_bak";`, + `CREATE TYPE "s"."mood" AS ENUM ('sad', 'ok', 'happy');`, + `ALTER TABLE "s"."t" ALTER COLUMN "m" DROP DEFAULT, ALTER COLUMN "m" TYPE "s"."mood" USING ("m"::text::"s"."mood"), ALTER COLUMN "m" SET DEFAULT 'sad'::s.mood;`, + `DROP TYPE "s"."mood_bak";`, + }, + }) }) t.Run("temporarily drops dependent views", func(t *testing.T) { - testDrift(t, - // setup - ` - CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy'); - CREATE TABLE t (m mood); - CREATE VIEW v1 AS SELECT m FROM t; - CREATE VIEW v2 AS SELECT m FROM v1; - `, - // alter - ` - DROP VIEW v2; - DROP VIEW v1; - DROP TABLE t; - DROP TYPE mood; - CREATE TYPE mood AS ENUM ('happy', 'sad', 'ok', 'gleeful'); - CREATE TABLE t (m mood); - CREATE VIEW v1 AS SELECT m FROM t; - CREATE VIEW v2 AS SELECT m FROM v1; - `, - // expected - `DROP VIEW IF EXISTS "public"."v2";`, - `DROP VIEW IF EXISTS "public"."v1";`, - `ALTER TYPE "public"."mood" RENAME TO "mood_bak";`, - `CREATE TYPE "public"."mood" AS ENUM ('sad', 'ok', 'happy');`, - `ALTER TABLE "public"."t" ALTER COLUMN "m" TYPE "public"."mood" USING ("m"::text::"public"."mood");`, - `DROP TYPE "public"."mood_bak";`, - `CREATE OR REPLACE VIEW "public"."v1" AS SELECT m`+"\n"+` FROM t;`, - `CREATE OR REPLACE VIEW "public"."v2" AS SELECT m`+"\n"+` FROM v1;`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{ + `CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy');`, + `CREATE TABLE t (m mood);`, + `CREATE VIEW v1 AS SELECT m FROM t;`, + `CREATE VIEW v2 AS SELECT m FROM v1;`, + }, + Alter: []string{ + `DROP VIEW v2;`, + `DROP VIEW v1;`, + `DROP TABLE t;`, + `DROP TYPE mood;`, + `CREATE TYPE mood AS ENUM ('happy', 'sad', 'ok', 'gleeful');`, + `CREATE TABLE t (m mood);`, + `CREATE VIEW v1 AS SELECT m FROM t;`, + `CREATE VIEW v2 AS SELECT m FROM v1;`, + }, + Expected: []string{ + `DROP VIEW IF EXISTS "public"."v2";`, + `DROP VIEW IF EXISTS "public"."v1";`, + `ALTER TYPE "public"."mood" RENAME TO "mood_bak";`, + `CREATE TYPE "public"."mood" AS ENUM ('sad', 'ok', 'happy');`, + `ALTER TABLE "public"."t" ALTER COLUMN "m" TYPE "public"."mood" USING ("m"::text::"public"."mood");`, + `DROP TYPE "public"."mood_bak";`, + `CREATE OR REPLACE VIEW "public"."v1" AS SELECT m + FROM t;`, + `CREATE OR REPLACE VIEW "public"."v2" AS SELECT m + FROM v1;`, + }, + }) }) }) } @@ -222,206 +237,621 @@ AS $function$SELECT $1 + $2;$function$ func TestDrift_Functions(t *testing.T) { t.Run("create", func(t *testing.T) { - testDrift(t, - // setup - `CREATE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 + $2;' LANGUAGE SQL;`, - // alter - `DROP FUNCTION add(integer, integer);`, - // expected - postgresAddFunctionDefinition, - ) + testDrift(t, DriftTestCase{ + Setup: []string{`CREATE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 + $2;' LANGUAGE SQL;`}, + Alter: []string{`DROP FUNCTION add(integer, integer);`}, + Expected: []string{postgresAddFunctionDefinition}, + }) }) t.Run("drop", func(t *testing.T) { - testDrift(t, - // setup - ``, - // alter - `CREATE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 + $2;' LANGUAGE SQL;`, - // expected - `DROP FUNCTION IF EXISTS "public"."add"(int4, int4);`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{}, + Alter: []string{`CREATE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 + $2;' LANGUAGE SQL;`}, + Expected: []string{`DROP FUNCTION IF EXISTS "public"."add"(int4, int4);`}, + }) }) t.Run("alter", func(t *testing.T) { t.Run("mismatched definition", func(t *testing.T) { - testDrift(t, - // setup - `CREATE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 + $2;' LANGUAGE SQL;`, - // alter - `CREATE OR REPLACE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 - $2;' LANGUAGE SQL;`, - // expected - postgresAddFunctionDefinition, - ) + testDrift(t, DriftTestCase{ + Setup: []string{`CREATE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 + $2;' LANGUAGE SQL;`}, + Alter: []string{`CREATE OR REPLACE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 - $2;' LANGUAGE SQL;`}, + Expected: []string{postgresAddFunctionDefinition}, + }) }) t.Run("preserves functions with differing argument types", func(t *testing.T) { - testDrift(t, - // setup - `CREATE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 + $2;' LANGUAGE SQL;`, - // alter - `CREATE FUNCTION add(integer, integer, integer) RETURNS integer AS 'SELECT $1 + $2 + $3;' LANGUAGE SQL;`, - // expected (drops extra function) - `DROP FUNCTION IF EXISTS "public"."add"(int4, int4, int4);`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{`CREATE FUNCTION add(integer, integer) RETURNS integer AS 'SELECT $1 + $2;' LANGUAGE SQL;`}, + Alter: []string{`CREATE FUNCTION add(integer, integer, integer) RETURNS integer AS 'SELECT $1 + $2 + $3;' LANGUAGE SQL;`}, + Expected: []string{`DROP FUNCTION IF EXISTS "public"."add"(int4, int4, int4);`}, + }) }) }) } func TestDrift_Tables(t *testing.T) { - // TODO + t.Run("create", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE TABLE my_table ( + id INTEGER PRIMARY KEY, + name TEXT + ); + `, + }, + Alter: []string{ + `DROP TABLE my_table;`, + }, + Expected: []string{ + `CREATE TABLE IF NOT EXISTS "public"."my_table"();`, + `ALTER TABLE "public"."my_table" ADD COLUMN IF NOT EXISTS "id" integer NOT NULL;`, + `ALTER TABLE "public"."my_table" ADD COLUMN IF NOT EXISTS "name" text;`, + `ALTER TABLE "public"."my_table" ADD CONSTRAINT "my_table_pkey" PRIMARY KEY (id);`, + }, + }) + }) + + t.Run("drop", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{}, + Alter: []string{ + ` + CREATE TABLE my_table ( + id INTEGER PRIMARY KEY, + name TEXT + ); + `, + }, + Expected: []string{ + `ALTER TABLE "public"."my_table" DROP CONSTRAINT IF EXISTS "my_table_pkey";`, + `ALTER TABLE "public"."my_table" DROP COLUMN IF EXISTS "id";`, + `ALTER TABLE "public"."my_table" DROP COLUMN IF EXISTS "name";`, + `DROP TABLE IF EXISTS "public"."my_table";`, + }, + }) + }) + + t.Run("alter", func(t *testing.T) { + t.Run("columns", func(t *testing.T) { + t.Run("missing", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE TABLE my_table ( + id INTEGER PRIMARY KEY, + name TEXT + ); + `, + }, + Alter: []string{ + `ALTER TABLE my_table DROP COLUMN name;`, + }, + Expected: []string{ + `ALTER TABLE "public"."my_table" ADD COLUMN IF NOT EXISTS "name" text;`, + }, + }) + }) + + t.Run("extra", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE TABLE my_table ( + id INTEGER PRIMARY KEY, + name TEXT + ); + `, + }, + Alter: []string{ + `ALTER TABLE my_table ADD COLUMN age INTEGER;`, + }, + Expected: []string{ + `ALTER TABLE "public"."my_table" DROP COLUMN IF EXISTS "age";`, + }, + }) + }) + + t.Run("mismatched type", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE TABLE my_table ( + id INTEGER PRIMARY KEY, + name TEXT + ); + `, + }, + Alter: []string{ + `ALTER TABLE my_table ALTER COLUMN name TYPE VARCHAR(255);`, + }, + Expected: []string{ + `ALTER TABLE "public"."my_table" ALTER COLUMN "name" SET DATA TYPE text;`, + }, + }) + }) + + t.Run("mismatched default", func(t *testing.T) { + t.Run("add default", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE TABLE my_table ( + id INTEGER PRIMARY KEY, + name TEXT + ); + `, + }, + Alter: []string{ + `ALTER TABLE my_table ALTER COLUMN name SET DEFAULT 'foo';`, + }, + Expected: []string{ + `ALTER TABLE "public"."my_table" ALTER COLUMN "name" DROP DEFAULT;`, + }, + }) + + t.Run("drop default", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE TABLE my_table ( + id INTEGER PRIMARY KEY, + name TEXT DEFAULT 'foo' + ); + `, + }, + Alter: []string{ + `ALTER TABLE my_table ALTER COLUMN name DROP DEFAULT;`, + }, + Expected: []string{ + `ALTER TABLE "public"."my_table" ALTER COLUMN "name" SET DEFAULT 'foo'::text;`, + }, + }) + }) + + t.Run("change default", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE TABLE my_table ( + id INTEGER PRIMARY KEY, + name TEXT DEFAULT 'foo' + ); + `, + }, + Alter: []string{ + `ALTER TABLE my_table ALTER COLUMN name SET DEFAULT 'bar';`, + }, + Expected: []string{ + `ALTER TABLE "public"."my_table" ALTER COLUMN "name" SET DEFAULT 'foo'::text;`, + }, + }) + }) + }) + + t.Run("mismatched nullability", func(t *testing.T) { + t.Run("add not null", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE TABLE my_table ( + id INTEGER PRIMARY KEY, + name TEXT + ); + `, + }, + Alter: []string{ + `ALTER TABLE my_table ALTER COLUMN name SET NOT NULL;`, + }, + Expected: []string{ + `ALTER TABLE "public"."my_table" ALTER COLUMN "name" DROP NOT NULL;`, + }, + }) + }) + + t.Run("drop not null`", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE TABLE my_table ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL + ); + `, + }, + Alter: []string{ + `ALTER TABLE my_table ALTER COLUMN name DROP NOT NULL;`, + }, + Expected: []string{ + `ALTER TABLE "public"."my_table" ALTER COLUMN "name" SET NOT NULL;`, + }, + }) + }) + }) + + t.Run("multiple changes", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE TABLE my_table ( + id INTEGER PRIMARY KEY, + name TEXT DEFAULT 'foo' + ); + `, + }, + Alter: []string{ + `ALTER TABLE my_table ALTER COLUMN name SET DEFAULT 'bar';`, + `ALTER TABLE my_table ALTER COLUMN name SET NOT NULL;`, + }, + Expected: []string{ + `ALTER TABLE "public"."my_table" ALTER COLUMN "name" SET DEFAULT 'foo'::text;`, + `ALTER TABLE "public"."my_table" ALTER COLUMN "name" DROP NOT NULL;`, + }, + }) + }) + }) + + t.Run("constraint", func(t *testing.T) { + t.Run("missing", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE TABLE my_table ( + id INTEGER PRIMARY KEY, + name TEXT UNIQUE + ); + `, + }, + Alter: []string{ + `ALTER TABLE my_table DROP CONSTRAINT my_table_name_key;`, + }, + Expected: []string{ + `ALTER TABLE "public"."my_table" ADD CONSTRAINT "my_table_name_key" UNIQUE (name);`, + }, + }) + }) + + t.Run("extra", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE TABLE my_table ( + id INTEGER PRIMARY KEY, + name TEXT + ); + `, + }, + Alter: []string{ + `ALTER TABLE my_table ADD CONSTRAINT my_table_name_key UNIQUE (name);`, + }, + Expected: []string{ + `ALTER TABLE "public"."my_table" DROP CONSTRAINT IF EXISTS "my_table_name_key";`, + }, + }) + }) + + t.Run("alter", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE TABLE my_table ( + id INTEGER PRIMARY KEY, + name TEXT + ); + `, + }, + Alter: []string{ + `ALTER TABLE my_table DROP CONSTRAINT my_table_pkey;`, + `ALTER TABLE my_table ADD CONSTRAINT my_table_pkey UNIQUE (name);`, + }, + Expected: []string{ + `ALTER TABLE "public"."my_table" DROP CONSTRAINT IF EXISTS "my_table_pkey";`, + `ALTER TABLE "public"."my_table" ADD CONSTRAINT "my_table_pkey" PRIMARY KEY (id);`, + }, + }) + }) + }) + + t.Run("index", func(t *testing.T) { + t.Run("missing", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE SCHEMA s; + CREATE TABLE s.my_table ( + id INTEGER PRIMARY KEY, + name TEXT + ); + `, + `CREATE INDEX my_table_name_idx ON s.my_table (name);`, + }, + Alter: []string{ + `DROP INDEX s.my_table_name_idx;`, + }, + Expected: []string{ + `CREATE INDEX my_table_name_idx ON s.my_table USING btree (name);`, + }, + }) + }) + + t.Run("extra", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE SCHEMA s; + CREATE TABLE s.my_table ( + id INTEGER PRIMARY KEY, + name TEXT + ); + `, + }, + Alter: []string{ + `CREATE INDEX my_table_name_idx ON s.my_table (name);`, + }, + Expected: []string{ + `DROP INDEX IF EXISTS "s"."my_table_name_idx";`, + }, + }) + }) + + t.Run("alter", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + ` + CREATE SCHEMA s; + CREATE TABLE s.my_table ( + id INTEGER PRIMARY KEY, + name TEXT + ); + `, + `CREATE INDEX my_table_name_idx ON s.my_table (name);`, + }, + Alter: []string{ + `DROP INDEX s.my_table_name_idx;`, + `CREATE INDEX my_table_name_idx ON s.my_table (name DESC);`, + }, + Expected: []string{ + `DROP INDEX IF EXISTS "s"."my_table_name_idx";`, + `CREATE INDEX my_table_name_idx ON s.my_table USING btree (name);`, + }, + }) + }) + }) + }) + }) } func TestDrift_Sequences(t *testing.T) { t.Run("create", func(t *testing.T) { - testDrift(t, - // setup - `CREATE SEQUENCE my_seq AS bigint;`, - // alter - `DROP SEQUENCE my_seq;`, - // expected - `CREATE SEQUENCE IF NOT EXISTS "public"."my_seq" AS bigint INCREMENT BY 1 MINVALUE 1 MAXVALUE 9223372036854775807 START WITH 1 NO CYCLE;`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{`CREATE SEQUENCE my_seq AS bigint;`}, + Alter: []string{`DROP SEQUENCE my_seq;`}, + Expected: []string{`CREATE SEQUENCE IF NOT EXISTS "public"."my_seq" AS bigint INCREMENT BY 1 MINVALUE 1 MAXVALUE 9223372036854775807 START WITH 1 NO CYCLE;`}, + }) }) t.Run("drop", func(t *testing.T) { - testDrift(t, - // setup - ``, - // alter - `CREATE SEQUENCE my_seq;`, - // expected - `DROP SEQUENCE IF EXISTS "public"."my_seq";`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{}, + Alter: []string{`CREATE SEQUENCE my_seq;`}, + Expected: []string{`DROP SEQUENCE IF EXISTS "public"."my_seq";`}, + }) }) t.Run("alter", func(t *testing.T) { - testDrift(t, - // setup - `CREATE SEQUENCE my_seq AS bigint INCREMENT BY 2 MINVALUE 2 MAXVALUE 12000 START WITH 2 CYCLE;`, - // alter - ` - DROP SEQUENCE my_seq; - CREATE SEQUENCE my_seq AS int INCREMENT BY 1 MINVALUE 1 MAXVALUE 24000 START WITH 1 NO CYCLE; - SELECT setval('my_seq', 43, true); - `, - // expected - `ALTER SEQUENCE IF EXISTS "public"."my_seq" AS bigint INCREMENT BY 2 MINVALUE 2 MAXVALUE 12000 START WITH 2 CYCLE;`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{ + `CREATE SEQUENCE my_seq AS bigint INCREMENT BY 2 MINVALUE 2 MAXVALUE 12000 START WITH 2 CYCLE;`, + }, + Alter: []string{ + `DROP SEQUENCE my_seq;`, + `CREATE SEQUENCE my_seq AS int INCREMENT BY 1 MINVALUE 1 MAXVALUE 24000 START WITH 1 NO CYCLE;`, + `SELECT setval('my_seq', 43, true);`, + }, + Expected: []string{ + `ALTER SEQUENCE IF EXISTS "public"."my_seq" AS bigint INCREMENT BY 2 MINVALUE 2 MAXVALUE 12000 START WITH 2 CYCLE;`, + }, + }) }) } func TestDrift_Views(t *testing.T) { t.Run("create", func(t *testing.T) { - testDrift(t, - // setup - `CREATE VIEW my_view AS SELECT 1 AS one;`, - // alter - `DROP VIEW my_view;`, - // expected - `CREATE OR REPLACE VIEW "public"."my_view" AS SELECT 1 AS one;`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{`CREATE VIEW my_view AS SELECT 1 AS one;`}, + Alter: []string{`DROP VIEW my_view;`}, + Expected: []string{`CREATE OR REPLACE VIEW "public"."my_view" AS SELECT 1 AS one;`}, + }) }) t.Run("drop", func(t *testing.T) { - testDrift(t, - // setup - ``, - // alter - `CREATE VIEW my_view AS SELECT 1 AS one;`, - // expected - `DROP VIEW IF EXISTS "public"."my_view";`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{}, + Alter: []string{`CREATE VIEW my_view AS SELECT 1 AS one;`}, + Expected: []string{`DROP VIEW IF EXISTS "public"."my_view";`}, + }) }) t.Run("alter", func(t *testing.T) { - testDrift(t, - // setup - `CREATE VIEW my_view AS SELECT 1 AS one;`, - // alter - ` - DROP VIEW my_view; - CREATE OR REPLACE VIEW my_view AS SELECT 2 AS two; - `, - // expected - `DROP VIEW IF EXISTS "public"."my_view";`, - `CREATE OR REPLACE VIEW "public"."my_view" AS SELECT 1 AS one;`, - ) + testDrift(t, DriftTestCase{ + Setup: []string{ + `CREATE VIEW my_view AS SELECT 1 AS one;`, + }, + Alter: []string{ + `DROP VIEW my_view;`, + `CREATE OR REPLACE VIEW my_view AS SELECT 2 AS two;`, + }, + Expected: []string{ + `DROP VIEW IF EXISTS "public"."my_view";`, + `CREATE OR REPLACE VIEW "public"."my_view" AS SELECT 1 AS one;`, + }, + }) + }) + + t.Run("dependency closure", func(t *testing.T) { + t.Run("create", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + `CREATE TABLE t (x int);`, + `CREATE VIEW v_foo AS SELECT * FROM t;`, + `CREATE VIEW v_bar AS SELECT * FROM v_foo;`, + `CREATE VIEW v_baz AS SELECT * FROM t UNION SELECT * FROM v_foo;`, + `CREATE VIEW v_bonk AS SELECT * FROM t UNION SELECT * FROM v_bar;`, + `CREATE VIEW v_quux AS SELECT * FROM v_foo UNION SELECT * FROM v_bar;`, + `CREATE VIEW v_one AS SELECT 1 AS one;`, + `CREATE VIEW v_two AS SELECT 2 AS two;`, + }, + Alter: []string{ + `DROP VIEW v_two;`, + `DROP VIEW v_one;`, + `DROP VIEW v_quux;`, + `DROP VIEW v_bonk;`, + `DROP VIEW v_baz;`, + `DROP VIEW v_bar;`, + `DROP VIEW v_foo;`, + }, + Expected: []string{ + `CREATE OR REPLACE VIEW "public"."v_foo" AS SELECT x + FROM t;`, + `CREATE OR REPLACE VIEW "public"."v_bar" AS SELECT x + FROM v_foo;`, + `CREATE OR REPLACE VIEW "public"."v_baz" AS SELECT t.x + FROM t +UNION + SELECT v_foo.x + FROM v_foo;`, + `CREATE OR REPLACE VIEW "public"."v_bonk" AS SELECT t.x + FROM t +UNION + SELECT v_bar.x + FROM v_bar;`, + `CREATE OR REPLACE VIEW "public"."v_one" AS SELECT 1 AS one;`, + `CREATE OR REPLACE VIEW "public"."v_quux" AS SELECT v_foo.x + FROM v_foo +UNION + SELECT v_bar.x + FROM v_bar;`, + `CREATE OR REPLACE VIEW "public"."v_two" AS SELECT 2 AS two;`, + }, + }) + }) + + t.Run("drop", func(t *testing.T) { + testDrift(t, DriftTestCase{ + Setup: []string{ + `CREATE TABLE t (x int);`, + }, + Alter: []string{ + `CREATE VIEW v_foo AS SELECT * FROM t;`, + `CREATE VIEW v_bar AS SELECT * FROM v_foo;`, + `CREATE VIEW v_baz AS SELECT * FROM t UNION SELECT * FROM v_foo;`, + `CREATE VIEW v_bonk AS SELECT * FROM t UNION SELECT * FROM v_bar;`, + `CREATE VIEW v_quux AS SELECT * FROM v_foo UNION SELECT * FROM v_bar;`, + `CREATE VIEW v_one AS SELECT 1 AS one;`, + `CREATE VIEW v_two AS SELECT 2 AS two;`, + }, + Expected: []string{ + `DROP VIEW IF EXISTS "public"."v_baz";`, + `DROP VIEW IF EXISTS "public"."v_bonk";`, + `DROP VIEW IF EXISTS "public"."v_one";`, + `DROP VIEW IF EXISTS "public"."v_quux";`, + `DROP VIEW IF EXISTS "public"."v_bar";`, + `DROP VIEW IF EXISTS "public"."v_foo";`, + `DROP VIEW IF EXISTS "public"."v_two";`, + }, + }) + }) }) } func TestDrift_Triggers(t *testing.T) { t.Run("create", func(t *testing.T) { - testDrift(t, - // setup - ` - CREATE SCHEMA a; - CREATE SCHEMA b; - CREATE TABLE a.t (x int); - - CREATE FUNCTION b.f() RETURNS TRIGGER AS $$ + testDrift(t, DriftTestCase{ + Setup: []string{ + `CREATE SCHEMA a;`, + `CREATE SCHEMA b;`, + `CREATE TABLE a.t (x int);`, + + `CREATE FUNCTION b.f() RETURNS TRIGGER AS $$ BEGIN RETURN NEW; END; - $$ LANGUAGE plpgsql; - - CREATE TRIGGER "t-insert" BEFORE INSERT ON a.t FOR EACH ROW EXECUTE FUNCTION b.f(); - `, - // alter - `DROP TRIGGER "t-insert" ON a.t;`, - // expected - `CREATE TRIGGER "t-insert" BEFORE INSERT ON a.t FOR EACH ROW EXECUTE FUNCTION b.f();`, - ) + $$ LANGUAGE plpgsql;`, + + `CREATE TRIGGER "t-insert" BEFORE INSERT ON a.t FOR EACH ROW EXECUTE FUNCTION b.f();`, + }, + Alter: []string{ + `DROP TRIGGER "t-insert" ON a.t;`, + }, + Expected: []string{ + `CREATE TRIGGER "t-insert" BEFORE INSERT ON a.t FOR EACH ROW EXECUTE FUNCTION b.f();`, + }, + }) }) t.Run("drop", func(t *testing.T) { - testDrift(t, - // setup - ` - CREATE SCHEMA a; - CREATE SCHEMA b; - CREATE TABLE a.t (x int); - CREATE FUNCTION b.f() RETURNS TRIGGER AS $$ + testDrift(t, DriftTestCase{ + Setup: []string{ + `CREATE SCHEMA a;`, + `CREATE SCHEMA b;`, + `CREATE TABLE a.t (x int);`, + `CREATE FUNCTION b.f() RETURNS TRIGGER AS $$ BEGIN RETURN NEW; END; - $$ LANGUAGE plpgsql; - `, - // alter - `CREATE TRIGGER "t-insert" BEFORE INSERT ON a.t FOR EACH ROW EXECUTE FUNCTION b.f();`, - // expected - `DROP TRIGGER IF EXISTS "t-insert" ON "a"."t";`, - ) + $$ LANGUAGE plpgsql;`, + }, + Alter: []string{ + `CREATE TRIGGER "t-insert" BEFORE INSERT ON a.t FOR EACH ROW EXECUTE FUNCTION b.f();`, + }, + Expected: []string{ + `DROP TRIGGER IF EXISTS "t-insert" ON "a"."t";`, + }, + }) }) } -// TODO - more types? -// TODO - test order between drift types - // // -func testDrift(t *testing.T, setupQuery string, alterQuery string, expectedQueries ...string) { +type DriftTestCase struct { + Setup []string + Alter []string + Expected []string +} + +func testDrift(t *testing.T, testCase DriftTestCase) { t.Helper() db := NewTestDB(t) ctx := context.Background() - // Setup initial schema and describe it - require.NoError(t, db.Exec(ctx, RawQuery(setupQuery))) + // Execute all setup queries + for _, query := range testCase.Setup { + require.NoError(t, db.Exec(ctx, RawQuery(query)), "query=%q", query) + } + + // Describe the initial schema before, err := DescribeSchema(ctx, db) if err != nil { t.Fatalf("Failed to describe schema: %v", err) } - // Apply schema alterations and describe it - require.NoError(t, db.Exec(ctx, RawQuery(alterQuery))) + // Execute all alter queries + for _, query := range testCase.Alter { + require.NoError(t, db.Exec(ctx, RawQuery(query)), "query=%q", query) + } + + // Describe the altered schema after, err := DescribeSchema(ctx, db) require.NoError(t, err) - // Calculate drift between initial and altered schema - // Assert that it contains our expected repair queries - require.Equal(t, expectedQueries, Compare(before, after)) + // Compare schemas and assert expected drift + require.Equal(t, testCase.Expected, Compare(before, after)) - // Perform the repair and ensure there's no additional drift - for _, query := range expectedQueries { - require.NoError(t, db.Exec(ctx, RawQuery(query))) + // Apply the expected repair queries + for _, query := range testCase.Expected { + require.NoError(t, db.Exec(ctx, RawQuery(query)), "query=%q", query) } + + // Verify that the drift has been repaired repaired, err := DescribeSchema(ctx, db) require.NoError(t, err) assert.Empty(t, Compare(before, repaired)) diff --git a/query.go b/query.go index c13c04a..d136612 100644 --- a/query.go +++ b/query.go @@ -126,7 +126,7 @@ func tokenize(format string) []string { } // capture from last match to end of string - // note: if there were no matches offset will be zero + // NOTE: if there were no matches offset will be zero return append(parts, format[offset:]) } diff --git a/testdata/golden/TestDescribeSchema.golden b/testdata/golden/TestDescribeSchema.golden index 5532cbc..7cef74d 100644 --- a/testdata/golden/TestDescribeSchema.golden +++ b/testdata/golden/TestDescribeSchema.golden @@ -92,32 +92,27 @@ $function$ Name: "comments", Columns: []pgutil.ColumnDescription{ { - Name: "content", - Index: 4, - Type: "text", + Name: "content", + Type: "text", }, { Name: "created_at", - Index: 5, Type: "timestamp with time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP", }, { Name: "id", - Index: 1, Type: "uuid", Default: "uuid_generate_v4()", }, { - Name: "post_id", - Index: 2, - Type: "uuid", + Name: "post_id", + Type: "uuid", }, { - Name: "user_id", - Index: 3, - Type: "integer", + Name: "user_id", + Type: "integer", }, }, Constraints: []pgutil.ConstraintDescription{ @@ -159,40 +154,34 @@ $function$ Columns: []pgutil.ColumnDescription{ { Name: "content", - Index: 4, Type: "text", IsNullable: true, }, { Name: "created_at", - Index: 5, Type: "timestamp with time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP", }, { Name: "id", - Index: 1, Type: "uuid", Default: "uuid_generate_v4()", }, { Name: "last_modified", - Index: 6, Type: "timestamp with time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP", }, { Name: "title", - Index: 3, Type: "character varying(200)", CharacterMaximumLength: 200, }, { - Name: "user_id", - Index: 2, - Type: "integer", + Name: "user_id", + Type: "integer", }, }, Constraints: []pgutil.ConstraintDescription{{ @@ -226,45 +215,38 @@ $function$ Columns: []pgutil.ColumnDescription{ { Name: "created_at", - Index: 6, Type: "timestamp with time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP", }, { Name: "email", - Index: 3, Type: "character varying(100)", CharacterMaximumLength: 100, }, { Name: "id", - Index: 1, Type: "integer", Default: "nextval('user_id_seq'::regclass)", }, { Name: "last_modified", - Index: 7, Type: "timestamp with time zone", IsNullable: true, Default: "CURRENT_TIMESTAMP", }, { Name: "mood", - Index: 5, Type: "mood", IsNullable: true, }, { Name: "password_hash", - Index: 4, Type: "character varying(100)", CharacterMaximumLength: 100, }, { Name: "username", - Index: 2, Type: "character varying(50)", CharacterMaximumLength: 50, },