diff --git a/.vscode/settings.json b/.vscode/settings.json index d5d00188a7..cbb1b7459c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,3 +1,6 @@ { - "files.insertFinalNewline": true + "files.insertFinalNewline": true, + "files.associations": { + "*.tmpl": "html" + } } diff --git a/Makefile b/Makefile index e9c2fbb899..92c5bd7942 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,9 @@ build: go build ./... +build-local: + go build -o ../sqlc-tilby ./cmd/sqlc + install: go install ./... diff --git a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 5b7977f500..2d6d3bf738 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -41,6 +41,7 @@ type tmplCtx struct { UsesBatch bool OmitSqlcVersion bool BuildTags string + EmitSchemaName bool } func (t *tmplCtx) OutputQuery(sourceName string) bool { @@ -60,6 +61,10 @@ func (t *tmplCtx) codegenEmitPreparedQueries() bool { return t.EmitPreparedQueries } +func (t *tmplCtx) codegenEmitSchemaName() bool { + return t.EmitSchemaName +} + func (t *tmplCtx) codegenQueryMethod(q Query) string { db := "q.db" if t.EmitMethodsWithDBArgument { @@ -187,6 +192,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, SqlcVersion: req.SqlcVersion, BuildTags: options.BuildTags, OmitSqlcVersion: options.OmitSqlcVersion, + EmitSchemaName: options.EmitSchemaName, } if tctx.UsesCopyFrom && !tctx.SQLDriver.IsPGX() && options.SqlDriver != opts.SQLDriverGoSQLDriverMySQL { @@ -218,6 +224,7 @@ func generate(req *plugin.GenerateRequest, options *opts.Options, enums []Enum, "emitPreparedQueries": tctx.codegenEmitPreparedQueries, "queryMethod": tctx.codegenQueryMethod, "queryRetval": tctx.codegenQueryRetval, + "emitSchemaName": tctx.codegenEmitSchemaName, } tmpl := template.Must( diff --git a/internal/codegen/golang/opts/options.go b/internal/codegen/golang/opts/options.go index 30a6c2246c..263b256cf2 100644 --- a/internal/codegen/golang/opts/options.go +++ b/internal/codegen/golang/opts/options.go @@ -25,6 +25,7 @@ type Options struct { EmitEnumValidMethod bool `json:"emit_enum_valid_method,omitempty" yaml:"emit_enum_valid_method"` EmitAllEnumValues bool `json:"emit_all_enum_values,omitempty" yaml:"emit_all_enum_values"` EmitSqlAsComment bool `json:"emit_sql_as_comment,omitempty" yaml:"emit_sql_as_comment"` + EmitSchemaName bool `json:"emit_schema_name,omitempty" yaml:"emit_schema_name"` JsonTagsCaseStyle string `json:"json_tags_case_style,omitempty" yaml:"json_tags_case_style"` Package string `json:"package" yaml:"package"` Out string `json:"out" yaml:"out"` diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 515d0a654f..d6388f383c 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -141,9 +141,7 @@ func newGoEmbed(embed *plugin.Identifier, structs []Struct, defaultSchema string } fields := make([]Field, len(s.Fields)) - for i, f := range s.Fields { - fields[i] = f - } + copy(fields, s.Fields) return &goEmbed{ modelType: s.Name, @@ -216,6 +214,10 @@ func buildQueries(req *plugin.GenerateRequest, options *opts.Options, structs [] } } + if options.EmitSchemaName { + query.Text = ApplySchema(query.Text) + } + gq := Query{ Cmd: query.Cmd, ConstantName: constantName, diff --git a/internal/codegen/golang/schema.go b/internal/codegen/golang/schema.go new file mode 100644 index 0000000000..d69adaa781 --- /dev/null +++ b/internal/codegen/golang/schema.go @@ -0,0 +1,62 @@ +package golang + +import ( + "fmt" + "strings" +) + +func ApplySchema(query string) string { + tables := make(map[string]bool) + ctes := make(map[string]bool) + + words := strings.Fields(query) + + // Getting all the table names and CTEs + withinCTE := false + for i, word := range words { + upperWord := strings.ToUpper(word) + + if upperWord == "WITH" { + withinCTE = true + continue + } else if withinCTE { + ctes[words[i]] = true + withinCTE = false + continue + } + + if isSQLKeyword(upperWord) { + tables[nextNonKeyword(words, i)] = true + } + } + + // Removing from tables the CTEs + for cte := range ctes { + delete(tables, cte) + } + + // Replacing the table names with the placeholder + for table := range tables { + query = strings.ReplaceAll(query, " "+table, fmt.Sprintf(" `%%s`.%s", table)) + } + + return query +} + +// Helper function to check if a word is a relevant SQL keyword +func isSQLKeyword(word string) bool { + switch word { + case "FROM", "JOIN", "UPDATE", "INTO": + return true + } + return false +} + +func nextNonKeyword(words []string, currentIndex int) string { + for i := currentIndex + 1; i < len(words); i++ { + if !isSQLKeyword(words[i]) && words[i] != "AS" && words[i] != "(" { + return words[i] + } + } + return "" +} diff --git a/internal/codegen/golang/schema_test.go b/internal/codegen/golang/schema_test.go new file mode 100644 index 0000000000..70928776d7 --- /dev/null +++ b/internal/codegen/golang/schema_test.go @@ -0,0 +1,47 @@ +package golang_test + +import ( + "testing" + + "github.com/sqlc-dev/sqlc/internal/codegen/golang" +) + +func TestApplySchema(t *testing.T) { + testCases := []struct { + name string + inputQuery string + expectedQuery string + }{ + { + name: "Simple Query with Single Table", + inputQuery: "SELECT * FROM users", + expectedQuery: "SELECT * FROM `%s`.users", + }, + { + name: "Query with Multiple Tables", + inputQuery: "SELECT * FROM users JOIN orders ON users.id = orders.user_id", + expectedQuery: "SELECT * FROM `%s`.users JOIN `%s`.orders ON `%s`.users.id = `%s`.orders.user_id", + }, + { + name: "Query with CTE", + inputQuery: "WITH user_orders AS (SELECT * FROM users JOIN orders ON users.id = orders.user_id) SELECT * FROM user_orders", + expectedQuery: "WITH user_orders AS (SELECT * FROM `%s`.users JOIN `%s`.orders ON `%s`.users.id = `%s`.orders.user_id) SELECT * FROM user_orders", + }, + { + name: "Query with CTE and Aliases", + inputQuery: "WITH user_orders AS (SELECT * FROM users u JOIN orders o ON u.id = o.user_id) SELECT * FROM user_orders uo", + expectedQuery: "WITH user_orders AS (SELECT * FROM `%s`.users u JOIN `%s`.orders o ON u.id = o.user_id) SELECT * FROM user_orders uo", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := golang.ApplySchema(tc.inputQuery) + if result != tc.expectedQuery { + t.Errorf("Expected:\n%s\nGot:\n%s", tc.expectedQuery, result) + } + }) + } +} + +// "SELECT * FROM `%s`.users JOIN `%s`.orders ON `%s`.users.id `%s`.= `%s`.orders.user_id" diff --git a/internal/codegen/golang/templates/stdlib/queryCode.tmpl b/internal/codegen/golang/templates/stdlib/queryCode.tmpl index cf56000ec6..66efae4f3d 100644 --- a/internal/codegen/golang/templates/stdlib/queryCode.tmpl +++ b/internal/codegen/golang/templates/stdlib/queryCode.tmpl @@ -22,20 +22,53 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} {{if eq .Cmd ":one"}} {{range .Comments}}//{{.}} {{end -}} -func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { +{{- if emitSchemaName }} +func (q *Queries) {{.MethodName}}(ctx context.Context, schema string, {{ dbarg }} {{.Arg.Pair}}) ({{.Ret.DefineType}}, error) { +{{- else }} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) error { +{{- end }} {{- template "queryCodeStdExec" . }} {{- if or (ne .Arg.Pair .Ret.Pair) (ne .Arg.DefineType .Ret.DefineType) }} var {{.Ret.Name}} {{.Ret.Type}} {{- end}} - err := row.Scan({{.Ret.Scan}}) + {{- if emitSchemaName }} + err := row.Scan({{.Ret.Scan}}) + {{- else }} + err := row.Scan({{.Ret.Scan}}) + {{- end }} return {{.Ret.ReturnName}}, err } {{end}} {{if eq .Cmd ":many"}} +var columnRegex = regexp.MustCompile(`SELECT\s+([\w,\s]+)\s+FROM`); + {{range .Comments}}//{{.}} {{end -}} +{{- if emitSchemaName }} + +type {{.MethodName}}Filter struct { + FieldName string + Value string +} + +type {{.MethodName}}FilterParams struct { + ExactParams []{{.MethodName}}Filter + InParams []{{.MethodName}}Filter + LikeParams []{{.MethodName}}Filter + SinceParams []{{.MethodName}}Filter + MaxParams []{{.MethodName}}Filter + SortParam string + SortOrder string + Pagination bool + PerPage int + PageNumber int +} + +func (q *Queries) {{.MethodName}}(ctx context.Context, schema string, filterParams {{.MethodName}}FilterParams, {{ dbarg }} {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) {; +{{- else }} func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) ([]{{.Ret.DefineType}}, error) { +{{- end }} {{- template "queryCodeStdExec" . }} if err != nil { return nil, err @@ -66,7 +99,11 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{if eq .Cmd ":exec"}} {{range .Comments}}//{{.}} {{end -}} +{{- if emitSchemaName }} +func (q *Queries) {{.MethodName}}(ctx context.Context, schema string, {{ dbarg }} {{.Arg.Pair}}) error { +{{- else }} func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) error { +{{- end }} {{- template "queryCodeStdExec" . }} return err } @@ -75,7 +112,11 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{if eq .Cmd ":execrows"}} {{range .Comments}}//{{.}} {{end -}} +{{- if emitSchemaName }} +func (q *Queries) {{.MethodName}}(ctx context.Context, schema string, {{ dbarg }} {{.Arg.Pair}}) (int64, error) { +{{- else }} func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (int64, error) { +{{- end }} {{- template "queryCodeStdExec" . }} if err != nil { return 0, err @@ -87,7 +128,11 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{if eq .Cmd ":execlastid"}} {{range .Comments}}//{{.}} {{end -}} +{{- if emitSchemaName }} +func (q *Queries) {{.MethodName}}(ctx context.Context, schema string, {{ dbarg }} {{.Arg.Pair}}) (int64, error) { +{{- else }} func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (int64, error) { +{{- end }} {{- template "queryCodeStdExec" . }} if err != nil { return 0, err @@ -99,7 +144,11 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{if eq .Cmd ":execresult"}} {{range .Comments}}//{{.}} {{end -}} +{{- if emitSchemaName }} +func (q *Queries) {{.MethodName}}(ctx context.Context, schema string, {{ dbarg }} {{.Arg.Pair}}) (sql.Result, error) { +{{else }} func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}}) (sql.Result, error) { +{{- end }} {{- template "queryCodeStdExec" . }} } {{end}} @@ -110,8 +159,13 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} {{define "queryCodeStdExec"}} {{- if .Arg.HasSqlcSlices }} - query := {{.ConstantName}} + query := strings.ReplaceAll({{.ConstantName}}, "%s", schema) + + // Extract the columns from the query; + columns := columnRegex.FindString(query); + var queryParams []interface{} + {{- if .Arg.Struct }} {{- $arg := .Arg }} {{- range .Arg.Struct.Fields }} @@ -137,19 +191,154 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{ dbarg }} {{.Arg.Pair}} for _, v := range {{.Arg.Name}} { queryParams = append(queryParams, v) } - query = strings.Replace(query, "/*SLICE:{{.Arg.Column.Name}}*/?", strings.Repeat(",?", len({{.Arg.Name}}))[1:], 1) + query = strings.ReplaceAll(query, "/*SLICE:{{.Arg.Column.Name}}*/?", strings.Repeat(",?", len({{.Arg.Name}}))[1:]) } else { - query = strings.Replace(query, "/*SLICE:{{.Arg.Column.Name}}*/?", "NULL", 1) + query = strings.ReplaceAll(query, "/*SLICE:{{.Arg.Column.Name}}*/?", "NULL") } {{- end }} + {{- if emitPreparedQueries }} - {{ queryRetval . }} {{ queryMethod . }}(ctx, nil, query, queryParams...) + {{- if emitSchemaName }} + replacedQuery := strings.ReplaceAll(query, "%s", schema) + {{ queryRetval . }} {{ queryMethod . }}(ctx, nil, replacedQuery, queryParams...) + {{- else }} + {{ queryRetval . }} {{ queryMethod . }}(ctx, nil, query, queryParams...) + {{- end }} {{- else}} - {{ queryRetval . }} {{ queryMethod . }}(ctx, query, queryParams...) + {{- if emitSchemaName }} + replacedQuery := strings.ReplaceAll(query, "%s", schema) + {{ queryRetval . }} {{ queryMethod . }}(ctx, replacedQuery, queryParams...) + {{- else }} + {{ queryRetval . }} {{ queryMethod . }}(ctx, query, queryParams...) + {{- end -}} {{- end -}} {{- else if emitPreparedQueries }} {{- queryRetval . }} {{ queryMethod . }}(ctx, q.{{.FieldName}}, {{.ConstantName}}, {{.Arg.Params}}) {{- else}} + {{- if emitSchemaName }} + query := strings.ReplaceAll({{.ConstantName}}, "%s", schema); + + {{- /** + If the query has a filter, we need to add the filter to the query + and the query params. + */}} + + {{- if eq .Cmd ":many"}} + + // Extract the columns from the query; + columns := columnRegex.FindString(query); + + isFirstFilter := true; + var queryParams []interface{}; + + for _, filter := range filterParams.ExactParams { + // Is the filter in the columns? + if !strings.Contains(columns, filter.FieldName) { + continue + } + + if isFirstFilter { + query += " WHERE " + isFirstFilter = false + } else { + query += " AND " + } + + query += filter.FieldName + " = ?" + queryParams = append(queryParams, filter.Value) + }; + + for _, filter := range filterParams.InParams { + // Is the filter in the columns? + if !strings.Contains(columns, filter.FieldName) { + continue + } + + if isFirstFilter { + query += " WHERE " + isFirstFilter = false + } else { + query += " AND " + } + + query += filter.FieldName + " IN (?)" + queryParams = append(queryParams, filter.Value) + }; + + for _, filter := range filterParams.LikeParams { + // Is the filter in the columns? + if !strings.Contains(columns, filter.FieldName) { + continue + } + + if isFirstFilter { + query += " WHERE " + isFirstFilter = false + } else { + query += " AND " + } + + query += filter.FieldName + " LIKE ?" + queryParams = append(queryParams, filter.Value) + }; + + for _, filter := range filterParams.SinceParams { + // Is the filter in the columns? + if !strings.Contains(columns, filter.FieldName) { + continue + } + + if isFirstFilter { + query += " WHERE " + isFirstFilter = false + } else { + query += " AND " + } + + query += filter.FieldName + " > ?" + queryParams = append(queryParams, filter.Value) + }; + + for _, filter := range filterParams.MaxParams { + // Is the filter in the columns? + if !strings.Contains(columns, filter.FieldName) { + continue + } + + if isFirstFilter { + query += " WHERE " + isFirstFilter = false + } else { + query += " AND " + } + + query += filter.FieldName + " < ?" + queryParams = append(queryParams, filter.Value) + }; + + if filterParams.SortParam != "" { + query += " ORDER BY " + filterParams.SortParam + " " + filterParams.SortOrder + } + + if filterParams.Pagination { + query += " LIMIT " + fmt.Sprint(filterParams.PerPage) + " OFFSET " + fmt.Sprint(filterParams.PageNumber * filterParams.PerPage) + } + + // If there is not the ; at the end, add it + if !strings.HasSuffix(query, ";") { + query += ";" + }; + + {{- queryRetval . }} {{ queryMethod . }}(ctx, query, queryParams...) + + {{- end }} + + {{- else }} {{- queryRetval . }} {{ queryMethod . }}(ctx, {{.ConstantName}}, {{.Arg.Params}}) + {{- end -}} + + {{- if ne .Cmd ":many"}} + {{- queryRetval . }} {{ queryMethod . }}(ctx, query, {{.Arg.Params}}) + {{- end -}} {{- end -}} {{end}} diff --git a/internal/codegen/golang/templates/template.tmpl b/internal/codegen/golang/templates/template.tmpl index afd50c01ac..5ce5f31a6b 100644 --- a/internal/codegen/golang/templates/template.tmpl +++ b/internal/codegen/golang/templates/template.tmpl @@ -11,6 +11,9 @@ package {{.Package}} {{ if hasImports .SourceName }} import ( + {{ if .EmitSchemaName }} + {{end}} + {{range imports .SourceName}} {{range .}}{{.}} {{end}} @@ -44,6 +47,9 @@ package {{.Package}} {{ if hasImports .SourceName }} import ( + {{ if .EmitSchemaName }} + {{end}} + {{range imports .SourceName}} {{range .}}{{.}} {{end}} @@ -75,6 +81,9 @@ package {{.Package}} {{ if hasImports .SourceName }} import ( + {{ if .EmitSchemaName }} + {{end}} + {{range imports .SourceName}} {{range .}}{{.}} {{end}} @@ -175,10 +184,16 @@ package {{.Package}} {{ if hasImports .SourceName }} import ( + {{ if .EmitSchemaName }} + "strings" + "regexp" + {{end}} + {{range imports .SourceName}} {{range .}}{{.}} {{end}} {{end}} + ) {{end}} @@ -206,10 +221,17 @@ package {{.Package}} {{ if hasImports .SourceName }} import ( + {{ if .EmitSchemaName }} + "strings" + "regexp" + {{end}} + {{range imports .SourceName}} {{range .}}{{.}} {{end}} {{end}} + + "fmt" ) {{end}} @@ -237,6 +259,9 @@ package {{.Package}} {{ if hasImports .SourceName }} import ( + {{ if .EmitSchemaName }} + {{end}} + {{range imports .SourceName}} {{range .}}{{.}} {{end}}