Skip to content

Commit

Permalink
WIP.
Browse files Browse the repository at this point in the history
  • Loading branch information
efritz committed Sep 18, 2024
1 parent 86c7a4b commit 1c79d89
Show file tree
Hide file tree
Showing 9 changed files with 865 additions and 417 deletions.
6 changes: 0 additions & 6 deletions describe_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

type ColumnDescription struct {
Name string
Index int
Type string
IsNullable bool
Default string
Expand All @@ -21,7 +20,6 @@ type ColumnDescription struct {
func (d ColumnDescription) Equals(other ColumnDescription) bool {
return true &&
d.Name == other.Name &&
d.Index == other.Index &&
d.Type == other.Type &&
d.IsNullable == other.IsNullable &&
d.Default == other.Default &&
Expand All @@ -36,7 +34,6 @@ type column struct {
Namespace string
TableName string
Name string
Index int
Type string
IsNullable bool
Default *string
Expand All @@ -58,7 +55,6 @@ var scanColumns = NewSliceScanner(func(s Scanner) (c column, _ error) {
&c.Namespace,
&c.TableName,
&c.Name,
&c.Index,
&c.Type,
&isNullable,
&c.Default,
Expand All @@ -81,7 +77,6 @@ func describeColumns(ctx context.Context, db DB) (map[string][]ColumnDescription
c.table_schema AS namespace,
c.table_name AS name,
c.column_name AS column_name,
c.ordinal_position AS index,
CASE
WHEN c.data_type = 'ARRAY' THEN COALESCE((
SELECT e.data_type
Expand Down Expand Up @@ -124,7 +119,6 @@ func describeColumns(ctx context.Context, db DB) (map[string][]ColumnDescription

columnMap[key] = append(columnMap[key], ColumnDescription{
Name: column.Name,
Index: column.Index,
Type: column.Type,
IsNullable: column.IsNullable,
Default: deref(column.Default),
Expand Down
165 changes: 108 additions & 57 deletions drift.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"cmp"
"fmt"
"slices"
"sort"
"strings"
)

Expand Down Expand Up @@ -66,46 +67,122 @@ func Compare(a, b SchemaDescription) (statements []string) {
// views : depends on tables, columns, views
// triggers : depends on tables, columns, functions

cmpWithKey := func(a, b ddlStatement) int { return cmp.Compare(a.key, b.key) }
defaultSort := func(statements []ddlStatement) { slices.SortFunc(statements, cmpWithKey) }

//
// TODO - rewrite this
//

// sortByClosure := func(cls closure) func([]ddlStatement) {
// return func(statements []ddlStatement) {
// slices.SortFunc(statements, func(a, b ddlStatement) int {
// if _, ok := cls[a.key][b.key]; ok {
// return -1
// }
// if _, ok := cls[b.key][a.key]; ok {
// return +1
// }

// return cmp.Compare(a.key, b.key)
// })
// }
// }

sortByClosure := func(cls closure) func([]ddlStatement) {
return func(statements []ddlStatement) {
keys := map[string]ddlStatement{}
for _, stmt := range statements {
keys[stmt.key] = stmt
}

cls2 := map[string][]string{}
for _, stmt := range statements {
cls2[stmt.key] = nil
}

for k, v := range cls {
for k2 := range v {
if _, ok := keys[k]; ok {
if _, ok := cls2[k2]; ok {
cls2[k2] = append(cls2[k2], k)
}
}
}
}

var sorted []ddlStatement
for len(cls2) > 0 {
var candidates []string
for k, v := range cls2 {
if len(v) == 0 {
candidates = append(candidates, k)
}
}

if len(candidates) == 0 {
panic("cycle detected in closure, cannot perform topological sort")
}

sort.Strings(candidates)
top := candidates[0]
sorted = append(sorted, keys[top])

for k, v := range cls2 {
for i := 0; i < len(v); i++ {
if v[i] == top {
cls2[k] = append(v[:i], v[i+1:]...)
}
}
}

delete(cls2, top)
}

for i := range statements {
statements[i] = sorted[i]
}
}
}

createClosure, dropClosure := viewClosures(a, b)
sortCreateViews := func(statements []ddlStatement) { slices.SortFunc(statements, cmpWithClosure(createClosure)) }
sortDropViews := func(statements []ddlStatement) { slices.SortFunc(statements, cmpWithClosure(dropClosure)) }
sortCreateViews := sortByClosure(createClosure)
sortDropViews := sortByClosure(dropClosure)

order := []struct {
statementType string
objectType string
order func(statements []ddlStatement)
}{
{"drop", "trigger", nil},
{"drop", "trigger", defaultSort},
{"drop", "view", sortDropViews},
{"drop", "index", nil},
{"drop", "constraint", nil},
{"drop", "column", nil},
{"drop", "sequence", nil},
{"drop", "table", nil},
{"drop", "function", nil},
{"drop", "enum", nil},
{"drop", "extension", nil},
{"create", "extension", nil},
{"create", "enum", nil},
{"replace", "enum", nil},
{"create", "function", nil},
{"replace", "function", nil},
{"create", "table", nil},
{"create", "sequence", nil},
{"replace", "sequence", nil},
{"create", "column", nil},
{"replace", "column", nil},
{"create", "constraint", nil},
{"create", "index", nil},
{"drop", "index", defaultSort},
{"drop", "constraint", defaultSort},
{"drop", "column", defaultSort},
{"drop", "sequence", defaultSort},
{"drop", "table", defaultSort},
{"drop", "function", defaultSort},
{"drop", "enum", defaultSort},
{"drop", "extension", defaultSort},
{"create", "extension", defaultSort},
{"create", "enum", defaultSort},
{"replace", "enum", defaultSort},
{"create", "function", defaultSort},
{"replace", "function", defaultSort},
{"create", "table", defaultSort},
{"create", "sequence", defaultSort},
{"replace", "sequence", defaultSort},
{"create", "column", defaultSort},
{"replace", "column", defaultSort},
{"create", "constraint", defaultSort},
{"create", "index", defaultSort},
{"create", "view", sortCreateViews},
{"create", "trigger", nil},
{"create", "trigger", defaultSort},
}

for _, o := range order {
filtered := filter(o.statementType, o.objectType)
if o.order != nil {
o.order(filtered)
}
o.order(filtered)

for _, statement := range filtered {
statements = append(statements, statement.statements...)
Expand All @@ -119,44 +196,30 @@ func Compare(a, b SchemaDescription) (statements []string) {
//
//

// TODO - document
func viewClosures(a, b SchemaDescription) (createClosure, dropClosure closure) {
// TODO
// keys := map[string]struct{}{}
// for _, view := range a.Views {
// keys[fmt.Sprintf("%q.%q", view.Namespace, view.Name)] = struct{}{}
// }
// for _, view := range a.Views {
// keys[fmt.Sprintf("%q.%q", view.Namespace, view.Name)] = struct{}{}
// }

createClosure = closure{}
for _, dependency := range a.ColumnDependencies {
sourceKey := fmt.Sprintf("%q.%q", dependency.SourceNamespace, dependency.SourceTableOrViewName)
dependencyKey := fmt.Sprintf("%q.%q", dependency.UsedNamespace, dependency.UsedTableOrView)

// TODO
// if _, ok := keys[dependencyKey]; ok {
if _, ok := createClosure[sourceKey]; !ok {
createClosure[sourceKey] = map[string]struct{}{}
}

createClosure[sourceKey][dependencyKey] = struct{}{}
// }
}

dropClosure = closure{}
for _, dependency := range b.ColumnDependencies {
sourceKey := fmt.Sprintf("%q.%q", dependency.SourceNamespace, dependency.SourceTableOrViewName)
dependencyKey := fmt.Sprintf("%q.%q", dependency.UsedNamespace, dependency.UsedTableOrView)

// TODO
// if _, ok := keys[sourceKey]; ok {
if _, ok := dropClosure[dependencyKey]; !ok {
dropClosure[dependencyKey] = map[string]struct{}{}
}

dropClosure[dependencyKey][sourceKey] = struct{}{}
// }
}

transitiveClosure(createClosure)
Expand All @@ -170,6 +233,7 @@ func viewClosures(a, b SchemaDescription) (createClosure, dropClosure closure) {

type closure map[string]map[string]struct{}

// TODO - rename internals
func transitiveClosure(m closure) {
changed := true
for changed {
Expand All @@ -188,19 +252,6 @@ func transitiveClosure(m closure) {
}
}

func cmpWithClosure(createClosure closure) func(a, b ddlStatement) int {
return func(a, b ddlStatement) int {
if _, ok := createClosure[a.key][b.key]; ok {
return -1
}
if _, ok := createClosure[b.key][a.key]; ok {
return +1
}

return cmp.Compare(a.key, b.key)
}
}

//
//
//
Expand All @@ -227,8 +278,8 @@ type alterer[T any] interface {

type ddlStatement struct {
key string
statementType string // TODO - enum
objectType string // TODO - enum
statementType string
objectType string
statements []string
}

Expand Down
Loading

0 comments on commit 1c79d89

Please sign in to comment.