diff --git a/main.go b/main.go index df31d43f..2b03bd21 100644 --- a/main.go +++ b/main.go @@ -4,12 +4,14 @@ package main //go:generate ./gen.sh models import ( + "bytes" "crypto/sha1" "database/sql" "encoding/base64" "encoding/json" "errors" "fmt" + "go/format" "io" "io/ioutil" "log" @@ -336,6 +338,7 @@ func getFileName(args *internal.ArgType, t *internal.TBuf) (string, string, stri // file from files. If the built filename is not already defined, then it calls // the os.OpenFile with the correct parameters depending on the state of args. func getFile(args *internal.ArgType, filename string, pkg string) (*os.File, error) { + var buf bytes.Buffer var f *os.File var err error @@ -354,25 +357,25 @@ func getFile(args *internal.ArgType, filename string, pkg string) (*os.File, err // file didn't originally exist, so add package header if args.Tags != "" { - f.WriteString(`// +build ` + args.Tags + "\n\n") + buf.WriteString(`// +build ` + args.Tags + "\n\n") } generatedText := "Code generated by Xo. DO NOT EDIT.\n\n" switch { case strings.HasSuffix(filename, ".go"): - f.WriteString("// " + generatedText) + buf.WriteString("// " + generatedText) case strings.HasSuffix(filename, ".yml"): fallthrough case strings.HasSuffix(filename, ".graphql"): - f.WriteString("# " + generatedText) + buf.WriteString("# " + generatedText) case strings.HasSuffix(filename, ".sql"): - f.WriteString("-- " + generatedText) + buf.WriteString("-- " + generatedText) } if strings.HasSuffix(filename, ".go") { if strings.HasSuffix(filename, "wire.go") { - if _, err = f.WriteString("//+build wireinject\n\npackage main"); err != nil { + if _, err = buf.WriteString("//+build wireinject\n\npackage main"); err != nil { return nil, err } } else { @@ -383,7 +386,7 @@ func getFile(args *internal.ArgType, filename string, pkg string) (*os.File, err } } } else if strings.HasSuffix(filename, ".yml") { - err = args.TemplateSet().Execute(f, "gqlgen.yml.tpl", args) + err = args.TemplateSet().Execute(buf, "gqlgen.yml.tpl", args) if err != nil { return nil, err } @@ -391,6 +394,13 @@ func getFile(args *internal.ArgType, filename string, pkg string) (*os.File, err args.Package = oldArgPkg + byts, err := format.Source(buf.Bytes()) + if err != nil { + f.Write(buf.Bytes()) + } else { + f.Write(byts) + } + return f, nil }