diff --git a/cmd/goose/main.go b/cmd/goose/main.go index 9c9386572..a9e10b282 100644 --- a/cmd/goose/main.go +++ b/cmd/goose/main.go @@ -19,20 +19,21 @@ func main() { flags.Parse(os.Args[1:]) args := flags.Args() + if len(args) == 0 || args[0] == "-h" || args[0] == "--help" { + flags.Usage() + return + } - if len(args) > 1 && args[0] == "create" { + switch args[0] { + case "create": if err := goose.Run("create", nil, *dir, args[1:]...); err != nil { log.Fatalf("goose run: %v", err) } return - } - - // TODO clean up arg/flag parsing flow - if args[0] == "fix" { + case "fix": if err := goose.Run("fix", nil, *dir); err != nil { log.Fatalf("goose run: %v", err) } - return } if len(args) < 3 { @@ -40,11 +41,6 @@ func main() { return } - if args[0] == "-h" || args[0] == "--help" { - flags.Usage() - return - } - driver, dbstring, command := args[0], args[1], args[2] if err := goose.SetDialect(driver); err != nil { diff --git a/examples/go-migrations/main.go b/examples/go-migrations/main.go index bb323afd1..28970af22 100644 --- a/examples/go-migrations/main.go +++ b/examples/go-migrations/main.go @@ -25,20 +25,21 @@ func main() { flags.Parse(os.Args[1:]) args := flags.Args() + if len(args) == 0 || args[0] == "-h" || args[0] == "--help" { + flags.Usage() + return + } - if len(args) > 1 && args[0] == "create" { + switch args[0] { + case "create": if err := goose.Run("create", nil, *dir, args[1:]...); err != nil { log.Fatalf("goose run: %v", err) } return - } - - // TODO clean up arg/flag parsing flow - if args[0] == "fix" { + case "fix": if err := goose.Run("fix", nil, *dir); err != nil { log.Fatalf("goose run: %v", err) } - return } if len(args) < 3 { diff --git a/goose_test.go b/goose_test.go index 06ae60a6f..9af0de401 100644 --- a/goose_test.go +++ b/goose_test.go @@ -16,11 +16,18 @@ func TestDefaultBinary(t *testing.T) { "./bin/goose -dir=examples/sql-migrations sqlite3 sql.db version", "./bin/goose -dir=examples/sql-migrations sqlite3 sql.db down", "./bin/goose -dir=examples/sql-migrations sqlite3 sql.db status", + "./bin/goose", } for _, cmd := range commands { args := strings.Split(cmd, " ") - out, err := exec.Command(args[0], args[1:]...).CombinedOutput() + command := args[0] + var params []string + if len(args) > 1 { + params = args[1:] + } + + out, err := exec.Command(command, params...).CombinedOutput() if err != nil { t.Fatalf("%s:\n%v\n\n%s", err, cmd, out) }