From 27e2753c9dfbb7c4330ea14d5ff04fd672d341be Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Mon, 29 Nov 2021 18:34:50 +0800 Subject: [PATCH] Fix create duplicated value when updating nested has many relationship, close #4796 --- callbacks/associations.go | 21 +++++++++++++++++---- tests/associations_test.go | 29 ++++++++++++++++++----------- tests/multi_primary_keys_test.go | 2 +- tests/tests_test.go | 2 +- utils/tests/models.go | 7 +++++++ 5 files changed, 44 insertions(+), 17 deletions(-) diff --git a/callbacks/associations.go b/callbacks/associations.go index 9d5b7c21c..38f21218a 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -7,6 +7,7 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func SaveBeforeAssociations(create bool) func(db *gorm.DB) { @@ -182,6 +183,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { fieldType = reflect.PtrTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + identityMap := map[string]bool{} appendToElems := func(v reflect.Value) { if _, zero := rel.Field.ValueOf(v); !zero { f := reflect.Indirect(rel.Field.ReflectValueOf(v)) @@ -197,10 +199,21 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) { } } - if isPtr { - elems = reflect.Append(elems, elem) - } else { - elems = reflect.Append(elems, elem.Addr()) + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(elem); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + + cacheKey := utils.ToStringKey(relPrimaryValues) + if len(relPrimaryValues) == 0 || (len(relPrimaryValues) == len(rel.FieldSchema.PrimaryFields) && !identityMap[cacheKey]) { + identityMap[cacheKey] = true + if isPtr { + elems = reflect.Append(elems, elem) + } else { + elems = reflect.Append(elems, elem.Addr()) + } } } } diff --git a/tests/associations_test.go b/tests/associations_test.go index a8d478867..a4b1f1f28 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -178,19 +178,21 @@ func TestForeignKeyConstraintsBelongsTo(t *testing.T) { } func TestFullSaveAssociations(t *testing.T) { + coupon := &Coupon{ + ID: "full-save-association-coupon1", + AppliesToProduct: []*CouponProduct{ + { + CouponId: "full-save-association-coupon1", + ProductId: "full-save-association-product1", + }, + }, + AmountOff: 10, + PercentOff: 0.0, + } + err := DB. Session(&gorm.Session{FullSaveAssociations: true}). - Create(&Coupon{ - ID: "full-save-association-coupon1", - AppliesToProduct: []*CouponProduct{ - { - CouponId: "full-save-association-coupon1", - ProductId: "full-save-association-product1", - }, - }, - AmountOff: 10, - PercentOff: 0.0, - }).Error + Create(coupon).Error if err != nil { t.Errorf("Failed, got error: %v", err) @@ -203,4 +205,9 @@ func TestFullSaveAssociations(t *testing.T) { if DB.First(&CouponProduct{}, "coupon_id = ? AND product_id = ?", "full-save-association-coupon1", "full-save-association-product1").Error != nil { t.Errorf("Failed to query saved association") } + + orders := []Order{{Num: "order1", Coupon: coupon}, {Num: "order2", Coupon: coupon}} + if err := DB.Create(&orders).Error; err != nil { + t.Errorf("failed to create orders, got %v", err) + } } diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index dcc90cd9a..3a8c08aa8 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -427,7 +427,7 @@ func TestCompositePrimaryKeysAssociations(t *testing.T) { DB.Migrator().DropTable(&Label{}, &Book{}) if err := DB.AutoMigrate(&Label{}, &Book{}); err != nil { - t.Fatalf("failed to migrate") + t.Fatalf("failed to migrate, got %v", err) } book := Book{ diff --git a/tests/tests_test.go b/tests/tests_test.go index 5799662fe..d1f19df30 100644 --- a/tests/tests_test.go +++ b/tests/tests_test.go @@ -87,7 +87,7 @@ func OpenTestConnection() (db *gorm.DB, err error) { func RunMigrations() { var err error - allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}} + allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) diff --git a/utils/tests/models.go b/utils/tests/models.go index 5eee84680..337682d61 100644 --- a/utils/tests/models.go +++ b/utils/tests/models.go @@ -72,3 +72,10 @@ type CouponProduct struct { CouponId string `gorm:"primarykey; size:255"` ProductId string `gorm:"primarykey; size:255"` } + +type Order struct { + gorm.Model + Num string + Coupon *Coupon + CouponID string +}