diff --git a/internal/codegen/golang/field.go b/internal/codegen/golang/field.go index ae7ba63573..0086e77e8b 100644 --- a/internal/codegen/golang/field.go +++ b/internal/codegen/golang/field.go @@ -20,6 +20,19 @@ type Field struct { EmbedFields []Field } +// Match returns true if the name and the type of the 2 fields are equal. +func (gf Field) Match(other Field) bool { + if gf.Name != other.Name { + return false + } + + if gf.Type != other.Type { + return false + } + + return true +} + func (gf Field) Tag() string { return TagsToString(gf.Tags) } diff --git a/internal/codegen/golang/field_test.go b/internal/codegen/golang/field_test.go new file mode 100644 index 0000000000..1e85208f6a --- /dev/null +++ b/internal/codegen/golang/field_test.go @@ -0,0 +1,55 @@ +package golang + +import ( + "testing" +) + +func TestField_Match(t *testing.T) { + t.Parallel() + + field := Field{ + Name: "Name", + Type: "string", + } + + tests := []struct { + name string + field Field + want bool + }{ + { + name: "match", + field: Field{ + Name: "Name", + Type: "string", + }, + want: true, + }, + { + name: "name mismatch", + field: Field{ + Name: "OtherName", + Type: "string", + }, + }, + { + name: "type mismatch", + field: Field{ + Name: "Name", + Type: "int", + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := field.Match(tt.field); got != tt.want { + t.Errorf("Match() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/codegen/golang/result.go b/internal/codegen/golang/result.go index 8e1c2714f7..d7dcf86950 100644 --- a/internal/codegen/golang/result.go +++ b/internal/codegen/golang/result.go @@ -181,8 +181,21 @@ func argName(name string) string { return out } -func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) { +func lookups(req *plugin.CodeGenRequest, other Struct, ls ...Structs) (Struct, bool) { + for _, s := range ls { + if exists, found := s.Lookup(req, other); found { + return exists, true + } + } + + return other, false +} + +func buildQueries(req *plugin.CodeGenRequest, tableStructs Structs) ([]Query, error) { + var queryStructs Structs + qs := make([]Query, 0, len(req.Queries)) + for _, query := range req.Queries { if query.Name == "" { continue @@ -233,8 +246,15 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) if err != nil { return nil, err } + + found := false + + if req.Settings.Go.ReuseStructs { + *s, found = lookups(req, *s, tableStructs, queryStructs) + } + gq.Arg = QueryValue{ - Emit: true, + Emit: !found, Name: "arg", Struct: s, SQLDriver: sqlpkg, @@ -259,47 +279,34 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) SQLDriver: sqlpkg, } } else if putOutColumns(query) { - var gs *Struct - var emit bool + var columns []goColumn + for i, c := range query.Columns { + columns = append(columns, goColumn{ + id: i, + Column: c, + embed: newGoEmbed(c.EmbedTable, tableStructs, req.Catalog.DefaultSchema), + }) + } - for _, s := range structs { - if len(s.Fields) != len(query.Columns) { - continue - } - same := true - for i, f := range s.Fields { - c := query.Columns[i] - sameName := f.Name == StructName(columnName(c, i), req.Settings) - sameType := f.Type == goType(req, c) - sameTable := sdk.SameTableName(c.Table, s.Table, req.Catalog.DefaultSchema) - if !sameName || !sameType || !sameTable { - same = false - } - } - if same { - gs = &s - break - } + gs, err := columnsToStruct(req, gq.MethodName+"Row", columns, true) + if err != nil { + return nil, err } - if gs == nil { - var columns []goColumn - for i, c := range query.Columns { - columns = append(columns, goColumn{ - id: i, - Column: c, - embed: newGoEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema), - }) - } - var err error - gs, err = columnsToStruct(req, gq.MethodName+"Row", columns, true) - if err != nil { - return nil, err - } - emit = true + found := false + + *gs, found = tableStructs.Lookup(req, *gs) + + if !found && req.Settings.Go.ReuseStructs { + *gs, found = queryStructs.Lookup(req, *gs) } + + if !found { + queryStructs = append(queryStructs, *gs) + } + gq.Ret = QueryValue{ - Emit: emit, + Emit: !found, Name: "i", Struct: gs, SQLDriver: sqlpkg, diff --git a/internal/codegen/golang/struct.go b/internal/codegen/golang/struct.go index 31904205e1..5e4493d6bb 100644 --- a/internal/codegen/golang/struct.go +++ b/internal/codegen/golang/struct.go @@ -5,6 +5,7 @@ import ( "unicode" "unicode/utf8" + "github.com/sqlc-dev/sqlc/internal/codegen/sdk" "github.com/sqlc-dev/sqlc/internal/plugin" ) @@ -15,6 +16,42 @@ type Struct struct { Comment string } +func (s Struct) Match(req *plugin.CodeGenRequest, other Struct) bool { + if len(s.Fields) != len(other.Fields) { + return false + } + + for i, f := range s.Fields { + of := other.Fields[i] + + if !f.Match(of) { + return false + } + + if s.Table != nil && !sdk.SameTableName(of.Column.Table, s.Table, req.Catalog.DefaultSchema) { + return false + } + } + + return true +} + +type Structs []Struct + +// Lookup search for a matching Struct in slice. +// +// - if found, returns the matching Struct and true +// - else returns the given Struct and false +func (s Structs) Lookup(req *plugin.CodeGenRequest, other Struct) (Struct, bool) { + for _, exists := range s { + if exists.Match(req, other) { + return exists, true + } + } + + return other, false +} + func StructName(name string, settings *plugin.Settings) string { if rename := settings.Rename[name]; rename != "" { return rename diff --git a/internal/codegen/golang/struct_test.go b/internal/codegen/golang/struct_test.go new file mode 100644 index 0000000000..486b595a36 --- /dev/null +++ b/internal/codegen/golang/struct_test.go @@ -0,0 +1,159 @@ +package golang + +import ( + "reflect" + "testing" + + "github.com/sqlc-dev/sqlc/internal/plugin" +) + +func TestStruct_Match(t *testing.T) { + t.Parallel() + + req := &plugin.CodeGenRequest{Catalog: &plugin.Catalog{}} + tableId := &plugin.Identifier{Name: "Table"} + tableStruct := Struct{ + Table: tableId, + Name: "Name", + Fields: []Field{ + {Name: "Field", Type: "string"}, + }, + } + + tests := []struct { + name string + str Struct + other Struct + want bool + }{ + { + name: "match", + str: tableStruct, + other: Struct{ + Table: tableId, + Name: "Name", + Fields: []Field{ + {Name: "Field", Type: "string", Column: &plugin.Column{Table: tableId}}, + }, + }, + want: true, + }, + { + name: "table mismatch", + str: tableStruct, + other: Struct{ + Table: tableId, + Name: "Name", + Fields: []Field{ + {Name: "Field", Type: "string", Column: &plugin.Column{Table: &plugin.Identifier{Name: "OtherTable"}}}, + }, + }, + }, + { + name: "other table nil", + str: tableStruct, + other: Struct{ + Table: tableId, + Name: "Name", + Fields: []Field{ + {Name: "Field", Type: "string", Column: &plugin.Column{}}, + }, + }, + }, + { + name: "field count mismatch", + str: tableStruct, + other: Struct{ + Table: tableId, + Name: "Name", + Fields: []Field{ + {Name: "Field1", Type: "string"}, + {Name: "Field2", Type: "string"}, + }, + }, + }, + { + name: "field mismatch", + str: tableStruct, + other: Struct{ + Table: tableId, + Name: "Name", + Fields: []Field{ + {Name: "OtherField", Type: "string", Column: &plugin.Column{Table: tableId}}, + }, + }, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := tt.str.Match(req, tt.other); got != tt.want { + t.Errorf("Match() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestStructs_Lookup(t *testing.T) { + t.Parallel() + + req := &plugin.CodeGenRequest{Catalog: &plugin.Catalog{}} + str := Struct{ + Fields: []Field{ + {Name: "Field", Type: "string"}, + }, + } + other := Struct{ + Fields: []Field{ + {Name: "OtherField", Type: "string"}, + }, + Comment: "OtherStruct", + } + structs := Structs{str} + + tests := []struct { + name string + other Struct + want Struct + wantFound bool + }{ + { + name: "found", + other: Struct{ + Fields: []Field{ + {Name: "Field", Type: "string"}, + }, + Comment: "Matching Struct", + }, + want: str, + wantFound: true, + }, + { + name: "not found", + other: other, + want: other, + }, + } + + for _, tt := range tests { + tt := tt + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, found := structs.Lookup(req, tt.other) + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Lookup() got = %v, want %v", got, tt.want) + } + + if found != tt.wantFound { + t.Errorf("Lookup() found = %v, want %v", found, tt.wantFound) + } + }) + } +}