Skip to content

Commit

Permalink
feat(prefix): adding extension prefix (#21)
Browse files Browse the repository at this point in the history
* adding extension prefix

* adding extension prefix

* passing through
  • Loading branch information
Jacobbrewer1 authored Oct 16, 2024
1 parent babf9b2 commit 9cd969c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
6 changes: 5 additions & 1 deletion cmd/schema/cmd_generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ type generateCmd struct {

// sqlLocation is the location of the SQL files to use.
sqlLocation string

// fileExtensionPrefix is the prefix to add to the generated file extension.
fileExtensionPrefix string
}

func (g *generateCmd) Name() string {
Expand All @@ -39,6 +42,7 @@ func (g *generateCmd) SetFlags(f *flag.FlagSet) {
f.StringVar(&g.templatesLocation, "templates", "./templates/*.tmpl", "The location of the templates to use.")
f.StringVar(&g.outputLocation, "out", ".", "The location to write the generated files to.")
f.StringVar(&g.sqlLocation, "sql", "./pkg/models/*.sql", "The location of the SQL files to use.")
f.StringVar(&g.fileExtensionPrefix, "extension", "", "The prefix to add to the generated file extension.")
}

func (g *generateCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface{}) subcommands.ExitStatus {
Expand All @@ -65,7 +69,7 @@ func (g *generateCmd) Execute(_ context.Context, _ *flag.FlagSet, _ ...interface
return subcommands.ExitFailure
}

err = generation.RenderTemplates(tables, g.templatesLocation, g.outputLocation)
err = generation.RenderTemplates(tables, g.templatesLocation, g.outputLocation, g.fileExtensionPrefix)
if err != nil {
slog.Error("Error rendering templates", slog.String("templatesLocation", g.templatesLocation), slog.String("outputLocation", g.outputLocation), slog.String("error", err.Error()))
return subcommands.ExitFailure
Expand Down
17 changes: 13 additions & 4 deletions pkg/services/generation/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type templateInfo struct {
Table *models.Table
}

func RenderTemplates(tables []*models.Table, templatesLoc, outputLoc string) error {
func RenderTemplates(tables []*models.Table, templatesLoc, outputLoc string, fileExtensionPrefix string) error {
tmpl, err := template.New("model.tmpl").Funcs(sprig.TxtFuncMap()).Funcs(Helpers).ParseGlob(templatesLoc)
if err != nil {
return fmt.Errorf("error parsing templates: %w", err)
Expand All @@ -27,16 +27,25 @@ func RenderTemplates(tables []*models.Table, templatesLoc, outputLoc string) err
if err = generate(&templateInfo{
OutputDir: outputLoc,
Table: t,
}, tmpl, outputLoc); err != nil {
}, tmpl, outputLoc, fileExtensionPrefix); err != nil {
return fmt.Errorf("error generating template: %w", err)
}
}

return nil
}

func generate(t *templateInfo, tmpl *template.Template, outputLoc string) error {
fn := filepath.Join(outputLoc, xstrings.ToSnakeCase(t.Table.Name)+".go")
func generate(t *templateInfo, tmpl *template.Template, outputLoc string, fileExtensionPrefix string) error {
ext := ".go"
if fileExtensionPrefix != "" {
// Add a period if it's not already there
if fileExtensionPrefix[0] != '.' {
fileExtensionPrefix = "." + fileExtensionPrefix
}
ext = fileExtensionPrefix + ext
}

fn := filepath.Join(outputLoc, xstrings.ToSnakeCase(t.Table.Name)+ext)
if err := os.MkdirAll(filepath.Dir(fn), 0750); err != nil {
return err
}
Expand Down

0 comments on commit 9cd969c

Please sign in to comment.