From e5d78d464b94b78438cf275b4c35f713d129961d Mon Sep 17 00:00:00 2001 From: Peter Magnusson Date: Fri, 29 Jul 2022 15:38:39 +0200 Subject: [PATCH 1/3] feat: conditions not supporting composite in --- dialect/feature/feature.go | 1 + dialect/pgdialect/dialect.go | 3 +- dialect/sqlitedialect/dialect.go | 3 +- relation_join.go | 93 ++++++++++++++++++++++++++++++++ 4 files changed, 98 insertions(+), 2 deletions(-) diff --git a/dialect/feature/feature.go b/dialect/feature/feature.go index 956dc4985..e311394d5 100644 --- a/dialect/feature/feature.go +++ b/dialect/feature/feature.go @@ -31,4 +31,5 @@ const ( UpdateFromTable MSSavepoint GeneratedIdentity + CompositeIn // ... WHERE (A,B) IN ((N, NN), (N, NN)...) ) diff --git a/dialect/pgdialect/dialect.go b/dialect/pgdialect/dialect.go index d524f0a1a..6ff85e166 100644 --- a/dialect/pgdialect/dialect.go +++ b/dialect/pgdialect/dialect.go @@ -47,7 +47,8 @@ func New() *Dialect { feature.TableNotExists | feature.InsertOnConflict | feature.SelectExists | - feature.GeneratedIdentity + feature.GeneratedIdentity | + feature.CompositeIn return d } diff --git a/dialect/sqlitedialect/dialect.go b/dialect/sqlitedialect/dialect.go index e79dcb004..720e979f5 100644 --- a/dialect/sqlitedialect/dialect.go +++ b/dialect/sqlitedialect/dialect.go @@ -38,7 +38,8 @@ func New() *Dialect { feature.DeleteTableAlias | feature.InsertOnConflict | feature.TableNotExists | - feature.SelectExists + feature.SelectExists | + feature.CompositeIn return d } diff --git a/relation_join.go b/relation_join.go index e8074e0c6..f6c0cad11 100644 --- a/relation_join.go +++ b/relation_join.go @@ -4,6 +4,7 @@ import ( "context" "reflect" + "github.com/uptrace/bun/dialect/feature" "github.com/uptrace/bun/internal" "github.com/uptrace/bun/schema" ) @@ -60,6 +61,14 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery { q = q.Model(hasManyModel) var where []byte + + if q.db.dialect.Features().Has(feature.CompositeIn) { + return j.manyQueryCompositeIn(where, q) + } + return j.manyQueryMulti(where, q) +} + +func (j *relationJoin) manyQueryCompositeIn(where []byte, q *SelectQuery) *SelectQuery { if len(j.Relation.JoinFields) > 1 { where = append(where, '(') } @@ -88,6 +97,29 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery { return q } +func (j *relationJoin) manyQueryMulti(where []byte, q *SelectQuery) *SelectQuery { + where = appendMultiValues( + q.db.Formatter(), + where, + j.JoinModel.rootValue(), + j.JoinModel.parentIndex(), + j.Relation.BaseFields, + j.Relation.JoinFields, + j.JoinModel.Table().SQLAlias, + ) + + q = q.Where(internal.String(where)) + + if j.Relation.PolymorphicField != nil { + q = q.Where("? = ?", j.Relation.PolymorphicField.SQLName, j.Relation.PolymorphicValue) + } + + j.applyTo(q) + q = q.Apply(j.hasManyColumns) + + return q +} + func (j *relationJoin) hasManyColumns(q *SelectQuery) *SelectQuery { b := make([]byte, 0, 32) @@ -312,3 +344,64 @@ func appendChildValues( } return b } + +func getColumns(table schema.Safe, fields []*schema.Field) [][]byte { + //Based upon query_base.appendColumns + var list [][]byte + for _, f := range fields { + b := []byte{} + + if len(table) > 0 { + b = append(b, table...) + b = append(b, '.') + } + b = append(b, f.SQLName...) + list = append(list, b) + } + return list +} + +func appendMultiValues( + fmter schema.Formatter, b []byte, v reflect.Value, index []int, baseFields, joinFields []*schema.Field, table schema.Safe, +) []byte { + // This is a mix of appendChildValues and query_base.appendColumns + if len(joinFields) != len(baseFields) { + panic("asdfasdf") + } + + // First get the columns + joins := getColumns(table, joinFields) + // Then values + b = append(b, '(') + seen := make(map[string]struct{}) + walk(v, index, func(v reflect.Value) { + start := len(b) + for i, f := range baseFields { + if i > 0 { + b = append(b, " AND "...) + } + if len(baseFields) > 1 { + b = append(b, '(') + } + b = append(b, joins[i]...) + b = append(b, '=') + b = f.AppendValue(fmter, b, v) + if len(baseFields) > 1 { + b = append(b, ')') + } + } + + b = append(b, ") OR ("...) + + if _, ok := seen[string(b[start:])]; ok { + b = b[:start] + } else { + seen[string(b[start:])] = struct{}{} + } + }) + if len(seen) > 0 { + b = b[:len(b)-6] // trim ") OR (" + } + b = append(b, ')') + return b +} From a757e78a27473ce04a36ccf1f74cdad5cca29664 Mon Sep 17 00:00:00 2001 From: Peter Magnusson Date: Fri, 29 Jul 2022 16:26:14 +0200 Subject: [PATCH 2/3] test: has-many with composite keys --- internal/dbtest/orm_test.go | 29 +++++++++++++++++++++++++++ internal/dbtest/testdata/fixture.yaml | 22 ++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/internal/dbtest/orm_test.go b/internal/dbtest/orm_test.go index fdcfb5037..8edacdc3c 100644 --- a/internal/dbtest/orm_test.go +++ b/internal/dbtest/orm_test.go @@ -31,6 +31,7 @@ func TestORM(t *testing.T) { {testRelationExcludeAll}, {testM2MRelationExcludeColumn}, {testRelationBelongsToSelf}, + {testCompositeHasMany}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -429,6 +430,18 @@ func testM2MRelationExcludeColumn(t *testing.T, db *bun.DB) { require.NoError(t, err) } +func testCompositeHasMany(t *testing.T, db *bun.DB) { + department := new(Department) + err := db.NewSelect(). + Model(department). + Where("company_no=? AND no=?", "company one", "hr"). + Relation("Employees"). + Scan(ctx) + require.NoError(t, err) + require.Equal(t, "hr", department.No) + require.Equal(t, 2, len(department.Employees)) +} + type Genre struct { ID int `bun:",pk"` Name string @@ -530,6 +543,20 @@ type Comment struct { Text string } +type Department struct { + bun.BaseModel `bun:"alias:d"` + CompanyNo string `bun:",pk"` + No string `bun:",pk"` + Employees []Employee `bun:"rel:has-many,join:company_no=company_no,join:no=department_no"` +} + +type Employee struct { + bun.BaseModel `bun:"alias:p"` + CompanyNo string `bun:",pk"` + DepartmentNo string `bun:",pk"` + Name string `bun:",pk"` +} + func createTestSchema(t *testing.T, db *bun.DB) { _ = db.Table(reflect.TypeOf((*BookGenre)(nil)).Elem()) @@ -541,6 +568,8 @@ func createTestSchema(t *testing.T, db *bun.DB) { (*BookGenre)(nil), (*Translation)(nil), (*Comment)(nil), + (*Department)(nil), + (*Employee)(nil), } for _, model := range models { _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx) diff --git a/internal/dbtest/testdata/fixture.yaml b/internal/dbtest/testdata/fixture.yaml index 8264c26d0..c158dfb33 100644 --- a/internal/dbtest/testdata/fixture.yaml +++ b/internal/dbtest/testdata/fixture.yaml @@ -80,3 +80,25 @@ - trackable_id: 1000 trackable_type: translation text: comment3 + +- model: Department + rows: + - company_no: company one + no: accounting + - company_no: company one + no: 'hr' + +- model: Employee + rows: + - company_no: company one + department_no: accounting + name: 'adam' + - company_no: company one + department_no: accounting + name: 'bravo' + - company_no: company one + department_no: hr + name: 'charlie' + - company_no: company one + department_no: hr + name: 'foxtrot' From 91348c59c184a2f182aeb639594793a02f37889e Mon Sep 17 00:00:00 2001 From: Peter Magnusson Date: Mon, 8 Aug 2022 07:34:04 +0200 Subject: [PATCH 3/3] refactor: change how things are done based on comments --- relation_join.go | 37 ++++++++++++++----------------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/relation_join.go b/relation_join.go index f6c0cad11..4ca4075b6 100644 --- a/relation_join.go +++ b/relation_join.go @@ -345,33 +345,19 @@ func appendChildValues( return b } -func getColumns(table schema.Safe, fields []*schema.Field) [][]byte { - //Based upon query_base.appendColumns - var list [][]byte - for _, f := range fields { - b := []byte{} - - if len(table) > 0 { - b = append(b, table...) - b = append(b, '.') - } - b = append(b, f.SQLName...) - list = append(list, b) - } - return list -} - +// appendMultiValues is an alternative to appendChildValues that doesn't use the sql keyword ID +// but instead use a old style ((k1=v1) AND (k2=v2)) OR (...) of conditions. func appendMultiValues( - fmter schema.Formatter, b []byte, v reflect.Value, index []int, baseFields, joinFields []*schema.Field, table schema.Safe, + fmter schema.Formatter, b []byte, v reflect.Value, index []int, baseFields, joinFields []*schema.Field, joinTable schema.Safe, ) []byte { - // This is a mix of appendChildValues and query_base.appendColumns + // This is based on a mix of appendChildValues and query_base.appendColumns + + // These should never missmatch in length but nice to know if it does if len(joinFields) != len(baseFields) { - panic("asdfasdf") + panic("not reached") } - // First get the columns - joins := getColumns(table, joinFields) - // Then values + // walk the relations b = append(b, '(') seen := make(map[string]struct{}) walk(v, index, func(v reflect.Value) { @@ -383,7 +369,12 @@ func appendMultiValues( if len(baseFields) > 1 { b = append(b, '(') } - b = append(b, joins[i]...) + // Field name + b = append(b, joinTable...) + b = append(b, '.') + b = append(b, []byte(joinFields[i].SQLName)...) + + // Equals value b = append(b, '=') b = f.AppendValue(fmter, b, v) if len(baseFields) > 1 {