Skip to content

Commit

Permalink
Reuse structs (wip)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartyHub committed Aug 26, 2023
1 parent 8218707 commit 68a3154
Show file tree
Hide file tree
Showing 5 changed files with 309 additions and 38 deletions.
13 changes: 13 additions & 0 deletions internal/codegen/golang/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ type Field struct {
EmbedFields []Field
}

// Match returns true if the name and the type of the 2 fields are equal.
func (gf Field) Match(other Field) bool {
if gf.Name != other.Name {
return false
}

if gf.Type != other.Type {
return false
}

return true
}

func (gf Field) Tag() string {
return TagsToString(gf.Tags)
}
Expand Down
55 changes: 55 additions & 0 deletions internal/codegen/golang/field_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package golang

import (
"testing"
)

func TestField_Match(t *testing.T) {
t.Parallel()

field := Field{
Name: "Name",
Type: "string",
}

tests := []struct {
name string
field Field
want bool
}{
{
name: "match",
field: Field{
Name: "Name",
Type: "string",
},
want: true,
},
{
name: "name mismatch",
field: Field{
Name: "OtherName",
Type: "string",
},
},
{
name: "type mismatch",
field: Field{
Name: "Name",
Type: "int",
},
},
}

for _, tt := range tests {
tt := tt

t.Run(tt.name, func(t *testing.T) {
t.Parallel()

if got := field.Match(tt.field); got != tt.want {
t.Errorf("Match() = %v, want %v", got, tt.want)
}
})
}
}
83 changes: 45 additions & 38 deletions internal/codegen/golang/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,21 @@ func argName(name string) string {
return out
}

func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error) {
func lookups(req *plugin.CodeGenRequest, other Struct, ls ...Structs) (Struct, bool) {
for _, s := range ls {
if exists, found := s.Lookup(req, other); found {
return exists, true
}
}

return other, false
}

func buildQueries(req *plugin.CodeGenRequest, tableStructs Structs) ([]Query, error) {
var queryStructs Structs

qs := make([]Query, 0, len(req.Queries))

for _, query := range req.Queries {
if query.Name == "" {
continue
Expand Down Expand Up @@ -233,8 +246,15 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
if err != nil {
return nil, err
}

found := false

if req.Settings.Go.ReuseStructs {
*s, found = lookups(req, *s, tableStructs, queryStructs)
}

gq.Arg = QueryValue{
Emit: true,
Emit: !found,
Name: "arg",
Struct: s,
SQLDriver: sqlpkg,
Expand All @@ -259,47 +279,34 @@ func buildQueries(req *plugin.CodeGenRequest, structs []Struct) ([]Query, error)
SQLDriver: sqlpkg,
}
} else if putOutColumns(query) {
var gs *Struct
var emit bool
var columns []goColumn
for i, c := range query.Columns {
columns = append(columns, goColumn{
id: i,
Column: c,
embed: newGoEmbed(c.EmbedTable, tableStructs, req.Catalog.DefaultSchema),
})
}

for _, s := range structs {
if len(s.Fields) != len(query.Columns) {
continue
}
same := true
for i, f := range s.Fields {
c := query.Columns[i]
sameName := f.Name == StructName(columnName(c, i), req.Settings)
sameType := f.Type == goType(req, c)
sameTable := sdk.SameTableName(c.Table, s.Table, req.Catalog.DefaultSchema)
if !sameName || !sameType || !sameTable {
same = false
}
}
if same {
gs = &s
break
}
gs, err := columnsToStruct(req, gq.MethodName+"Row", columns, true)
if err != nil {
return nil, err
}

if gs == nil {
var columns []goColumn
for i, c := range query.Columns {
columns = append(columns, goColumn{
id: i,
Column: c,
embed: newGoEmbed(c.EmbedTable, structs, req.Catalog.DefaultSchema),
})
}
var err error
gs, err = columnsToStruct(req, gq.MethodName+"Row", columns, true)
if err != nil {
return nil, err
}
emit = true
found := false

*gs, found = tableStructs.Lookup(req, *gs)

if !found && req.Settings.Go.ReuseStructs {
*gs, found = queryStructs.Lookup(req, *gs)
}

if !found {
queryStructs = append(queryStructs, *gs)
}

gq.Ret = QueryValue{
Emit: emit,
Emit: !found,
Name: "i",
Struct: gs,
SQLDriver: sqlpkg,
Expand Down
37 changes: 37 additions & 0 deletions internal/codegen/golang/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"unicode"
"unicode/utf8"

"github.com/sqlc-dev/sqlc/internal/codegen/sdk"
"github.com/sqlc-dev/sqlc/internal/plugin"
)

Expand All @@ -15,6 +16,42 @@ type Struct struct {
Comment string
}

func (s Struct) Match(req *plugin.CodeGenRequest, other Struct) bool {
if len(s.Fields) != len(other.Fields) {
return false
}

for i, f := range s.Fields {
of := other.Fields[i]

if !f.Match(of) {
return false
}

if s.Table != nil && !sdk.SameTableName(of.Column.Table, s.Table, req.Catalog.DefaultSchema) {
return false
}
}

return true
}

type Structs []Struct

// Lookup search for a matching Struct in slice.
//
// - if found, returns the matching Struct and true
// - else returns the given Struct and false
func (s Structs) Lookup(req *plugin.CodeGenRequest, other Struct) (Struct, bool) {
for _, exists := range s {
if exists.Match(req, other) {
return exists, true
}
}

return other, false
}

func StructName(name string, settings *plugin.Settings) string {
if rename := settings.Rename[name]; rename != "" {
return rename
Expand Down
159 changes: 159 additions & 0 deletions internal/codegen/golang/struct_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
package golang

import (
"reflect"
"testing"

"github.com/sqlc-dev/sqlc/internal/plugin"
)

func TestStruct_Match(t *testing.T) {
t.Parallel()

req := &plugin.CodeGenRequest{Catalog: &plugin.Catalog{}}
tableId := &plugin.Identifier{Name: "Table"}
tableStruct := Struct{
Table: tableId,
Name: "Name",
Fields: []Field{
{Name: "Field", Type: "string"},
},
}

tests := []struct {
name string
str Struct
other Struct
want bool
}{
{
name: "match",
str: tableStruct,
other: Struct{
Table: tableId,
Name: "Name",
Fields: []Field{
{Name: "Field", Type: "string", Column: &plugin.Column{Table: tableId}},
},
},
want: true,
},
{
name: "table mismatch",
str: tableStruct,
other: Struct{
Table: tableId,
Name: "Name",
Fields: []Field{
{Name: "Field", Type: "string", Column: &plugin.Column{Table: &plugin.Identifier{Name: "OtherTable"}}},
},
},
},
{
name: "other table nil",
str: tableStruct,
other: Struct{
Table: tableId,
Name: "Name",
Fields: []Field{
{Name: "Field", Type: "string", Column: &plugin.Column{}},
},
},
},
{
name: "field count mismatch",
str: tableStruct,
other: Struct{
Table: tableId,
Name: "Name",
Fields: []Field{
{Name: "Field1", Type: "string"},
{Name: "Field2", Type: "string"},
},
},
},
{
name: "field mismatch",
str: tableStruct,
other: Struct{
Table: tableId,
Name: "Name",
Fields: []Field{
{Name: "OtherField", Type: "string", Column: &plugin.Column{Table: tableId}},
},
},
},
}

for _, tt := range tests {
tt := tt

t.Run(tt.name, func(t *testing.T) {
t.Parallel()

if got := tt.str.Match(req, tt.other); got != tt.want {
t.Errorf("Match() = %v, want %v", got, tt.want)
}
})
}
}

func TestStructs_Lookup(t *testing.T) {
t.Parallel()

req := &plugin.CodeGenRequest{Catalog: &plugin.Catalog{}}
str := Struct{
Fields: []Field{
{Name: "Field", Type: "string"},
},
}
other := Struct{
Fields: []Field{
{Name: "OtherField", Type: "string"},
},
Comment: "OtherStruct",
}
structs := Structs{str}

tests := []struct {
name string
other Struct
want Struct
wantFound bool
}{
{
name: "found",
other: Struct{
Fields: []Field{
{Name: "Field", Type: "string"},
},
Comment: "Matching Struct",
},
want: str,
wantFound: true,
},
{
name: "not found",
other: other,
want: other,
},
}

for _, tt := range tests {
tt := tt

t.Run(tt.name, func(t *testing.T) {
t.Parallel()

got, found := structs.Lookup(req, tt.other)

if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Lookup() got = %v, want %v", got, tt.want)
}

if found != tt.wantFound {
t.Errorf("Lookup() found = %v, want %v", found, tt.wantFound)
}
})
}
}

0 comments on commit 68a3154

Please sign in to comment.