diff --git a/cmd/bfg/main.go b/cmd/bfg/main.go index 8c5717e..a9b368a 100644 --- a/cmd/bfg/main.go +++ b/cmd/bfg/main.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "errors" "flag" "fmt" "io" @@ -29,28 +30,15 @@ func main() { var sourceBuf, inputBuf io.ByteReader source := flag.Arg(0) input := flag.Arg(1) - if source == "" { - fmt.Println("please provide a source file") + sourceBuf, err := inputReader(source, false) + if err != nil { + fmt.Println("error opening program: err:", err) os.Exit(1) - } else if source == "-" { - sourceBuf = bufio.NewReader(os.Stdin) - } else { - sourceFile, err := os.Open(source) - if err != nil { - fmt.Println("error opening source file: err:", err) - os.Exit(1) - } - sourceBuf = bufio.NewReader(sourceFile) } - if input == "" { - inputBuf = bufio.NewReader(os.Stdin) - } else { - file, err := os.Open(input) - if err != nil { - fmt.Println("error opening input file: err:", err) - os.Exit(1) - } - inputBuf = bufio.NewReader(file) + inputBuf, err = inputReader(input, true) + if err != nil { + fmt.Println("error opening input: err:", err) + os.Exit(1) } outputBuf := bufio.NewWriter(os.Stdout) defer outputBuf.Flush() @@ -66,3 +54,21 @@ func main() { parser.Execute(program, inputBuf, outputBuf) } } + +func inputReader(pathStr string, useDefault bool) (buff io.ByteReader, err error) { + if useDefault && pathStr == "" { + pathStr = "-" + } + if pathStr == "" { + return nil, errors.New("no path provided") + } else if pathStr == "-" { + buff = bufio.NewReader(os.Stdin) + } else { + sourceFile, err := os.Open(pathStr) + if err != nil { + return nil, err + } + buff = bufio.NewReader(sourceFile) + } + return +}