diff --git a/callbacks/update.go b/callbacks/update.go index cfa8c86b4..c16b77d16 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -59,6 +59,11 @@ func Update(db *gorm.DB) { db.Statement.Build("UPDATE", "SET", "WHERE") } + if _, ok := db.Statement.Clauses["WHERE"]; !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { diff --git a/tests/delete_test.go b/tests/delete_test.go index 3f17f1a16..4288253fb 100644 --- a/tests/delete_test.go +++ b/tests/delete_test.go @@ -36,10 +36,6 @@ func TestDelete(t *testing.T) { } } - if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { - t.Errorf("should returns missing WHERE clause while deleting error") - } - for _, user := range []User{users[0], users[2]} { if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { t.Errorf("no error should returns when query %v, but got %v", user.ID, err) @@ -64,3 +60,9 @@ func TestInlineCondDelete(t *testing.T) { t.Errorf("User can't be found after delete") } } + +func TestBlockGlobalDelete(t *testing.T) { + if err := DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { + t.Errorf("should returns missing WHERE clause while deleting error") + } +} diff --git a/tests/joins_test.go b/tests/joins_test.go index 8a9cdde55..d9cfd22f2 100644 --- a/tests/joins_test.go +++ b/tests/joins_test.go @@ -92,16 +92,26 @@ func TestJoinConds(t *testing.T) { func TestJoinsWithSelect(t *testing.T) { type result struct { - ID uint - Name string + ID uint + PetID uint + Name string } user := *GetUser("joins_with_select", Config{Pets: 2}) DB.Save(&user) var results []result - DB.Table("users").Select("users.id, pets.name").Joins("left join pets on pets.user_id = users.id").Where("users.name = ?", "joins_with_select").Scan(&results) + DB.Table("users").Select("users.id, pets.id as pet_id, pets.name").Joins("left join pets on pets.user_id = users.id").Where("users.name = ?", "joins_with_select").Scan(&results) + + sort.Slice(results, func(i, j int) bool { + return results[i].PetID > results[j].PetID + }) + + sort.Slice(results, func(i, j int) bool { + return user.Pets[i].ID > user.Pets[j].ID + }) + if len(results) != 2 || results[0].Name != user.Pets[0].Name || results[1].Name != user.Pets[1].Name { - t.Errorf("Should find all two pets with Join select") + t.Errorf("Should find all two pets with Join select, got %+v", results) } } diff --git a/tests/main_test.go b/tests/main_test.go index da2003d6d..095588a20 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -35,3 +35,17 @@ func TestExceptionsWithInvalidSql(t *testing.T) { t.Errorf("No user should not be deleted by invalid SQL") } } + +func TestSetAndGet(t *testing.T) { + if value, ok := DB.Set("hello", "world").Get("hello"); !ok { + t.Errorf("Should be able to get setting after set") + } else { + if value.(string) != "world" { + t.Errorf("Setted value should not be changed") + } + } + + if _, ok := DB.Get("non_existing"); ok { + t.Errorf("Get non existing key should return error") + } +} diff --git a/tests/non_std_test.go b/tests/non_std_test.go index e5e50141d..606b4fc9c 100644 --- a/tests/non_std_test.go +++ b/tests/non_std_test.go @@ -34,7 +34,7 @@ func TestNonStdPrimaryKeyAndDefaultValues(t *testing.T) { var animals []Animal DB.Find(&animals) - if count := DB.Model(Animal{}).Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { + if count := DB.Model(Animal{}).Where("1=1").Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { t.Error("RowsAffected should be correct when do batch update") } diff --git a/tests/update_test.go b/tests/update_test.go index 371a9f788..869ce4cdc 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -1,6 +1,7 @@ package tests_test import ( + "errors" "testing" "time" @@ -211,3 +212,9 @@ func TestUpdateColumn(t *testing.T) { CheckUser(t, user5, *users[0]) CheckUser(t, user6, *users[1]) } + +func TestBlockGlobalUpdate(t *testing.T) { + if err := DB.Model(&User{}).Update("name", "jinzhu").Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { + t.Errorf("should returns missing WHERE clause while updating error, got err %v", err) + } +}