diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index 8055d6e4f..d44e48b52 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -258,6 +258,8 @@ func TestDB(t *testing.T) { {testFKViolation}, {testWithForeignKeysAndRules}, {testWithForeignKeys}, + {testWithForeignKeysHasMany}, + {testWithPointerForeignKeysHasMany}, {testInterfaceAny}, {testInterfaceJSON}, {testScanRawMessage}, @@ -1043,6 +1045,114 @@ func testWithForeignKeys(t *testing.T, db *bun.DB) { require.Equal(t, d.User.Name, "root") } +func testWithForeignKeysHasMany(t *testing.T, db *bun.DB) { + type User struct { + ID int `bun:",pk"` + DeckID int + Name string + } + type Deck struct { + ID int `bun:",pk"` + Users []*User `bun:"rel:has-many,join:id=deck_id"` + } + + if db.Dialect().Name() == dialect.SQLite { + _, err := db.Exec("PRAGMA foreign_keys = ON;") + require.NoError(t, err) + } + + for _, model := range []interface{}{(*Deck)(nil), (*User)(nil)} { + _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx) + require.NoError(t, err) + } + + mustResetModel(t, ctx, db, (*User)(nil)) + _, err := db.NewCreateTable(). + Model((*Deck)(nil)). + IfNotExists(). + WithForeignKeys(). + Exec(ctx) + require.NoError(t, err) + mustDropTableOnCleanup(t, ctx, db, (*Deck)(nil)) + + deckID := 1 + deck := Deck{ID: deckID} + _, err = db.NewInsert().Model(&deck).Exec(ctx) + require.NoError(t, err) + + userID1 := 1 + userID2 := 2 + users := []*User{ + {ID: userID1, DeckID: deckID, Name: "user 1"}, + {ID: userID2, DeckID: deckID, Name: "user 2"}, + } + + res, err := db.NewInsert().Model(&users).Exec(ctx) + require.NoError(t, err) + + affected, err := res.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(2), affected) + + err = db.NewSelect().Model(&deck).Relation("Users").Scan(ctx) + require.NoError(t, err) + require.Len(t, deck.Users, 2) +} + +func testWithPointerForeignKeysHasMany(t *testing.T, db *bun.DB) { + type User struct { + ID *int `bun:",pk"` + DeckID *int + Name string + } + type Deck struct { + ID *int `bun:",pk"` + Users []*User `bun:"rel:has-many,join:id=deck_id"` + } + + if db.Dialect().Name() == dialect.SQLite { + _, err := db.Exec("PRAGMA foreign_keys = ON;") + require.NoError(t, err) + } + + for _, model := range []interface{}{(*Deck)(nil), (*User)(nil)} { + _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx) + require.NoError(t, err) + } + + mustResetModel(t, ctx, db, (*User)(nil)) + _, err := db.NewCreateTable(). + Model((*Deck)(nil)). + IfNotExists(). + WithForeignKeys(). + Exec(ctx) + require.NoError(t, err) + mustDropTableOnCleanup(t, ctx, db, (*Deck)(nil)) + + deckID := 1 + deck := Deck{ID: &deckID} + _, err = db.NewInsert().Model(&deck).Exec(ctx) + require.NoError(t, err) + + userID1 := 1 + userID2 := 2 + users := []*User{ + {ID: &userID1, DeckID: &deckID, Name: "user 1"}, + {ID: &userID2, DeckID: &deckID, Name: "user 2"}, + } + + res, err := db.NewInsert().Model(&users).Exec(ctx) + require.NoError(t, err) + + affected, err := res.RowsAffected() + require.NoError(t, err) + require.Equal(t, int64(2), affected) + + err = db.NewSelect().Model(&deck).Relation("Users").Scan(ctx) + require.NoError(t, err) + require.Len(t, deck.Users, 2) +} + func testInterfaceAny(t *testing.T, db *bun.DB) { switch db.Dialect().Name() { case dialect.MySQL: diff --git a/model_table_has_many.go b/model_table_has_many.go index 3d8a5da6f..f3e977fca 100644 --- a/model_table_has_many.go +++ b/model_table_has_many.go @@ -94,7 +94,7 @@ func (m *hasManyModel) Scan(src interface{}) error { for _, f := range m.rel.JoinFields { if f.Name == field.Name { - m.structKey = append(m.structKey, field.Value(m.strct).Interface()) + m.structKey = append(m.structKey, getFieldValue(field.Value(m.strct))) break } } @@ -103,6 +103,7 @@ func (m *hasManyModel) Scan(src interface{}) error { } func (m *hasManyModel) parkStruct() error { + baseValues, ok := m.baseValues[internal.NewMapKey(m.structKey)] if !ok { return fmt.Errorf( @@ -143,7 +144,24 @@ func baseValues(model TableModel, fields []*schema.Field) map[internal.MapKey][] func modelKey(key []interface{}, strct reflect.Value, fields []*schema.Field) []interface{} { for _, f := range fields { - key = append(key, f.Value(strct).Interface()) + key = append(key, getFieldValue(f.Value(strct))) } return key } + +// getFieldValue extracts the value from a reflect.Value, handling pointer types appropriately. +func getFieldValue(fieldValue reflect.Value) interface{} { + var keyValue interface{} + + if fieldValue.Kind() == reflect.Ptr { + if !fieldValue.IsNil() { + keyValue = fieldValue.Elem().Interface() + } else { + keyValue = nil + } + } else { + keyValue = fieldValue.Interface() + } + + return keyValue +}