From 6044ecf078a5a28ba37c961230075654974ce3c6 Mon Sep 17 00:00:00 2001 From: dyma solovei <53943884+bevzzz@users.noreply.github.com> Date: Sat, 6 Jan 2024 08:31:29 +0100 Subject: [PATCH] test: cleanup test databases to avoid side-effects (#927) --- internal/dbtest/bench_test.go | 8 +- internal/dbtest/db_test.go | 146 ++++++++++++---------------- internal/dbtest/migrate_test.go | 28 +++++- internal/dbtest/model_hook_test.go | 3 +- internal/dbtest/mssql_test.go | 7 +- internal/dbtest/orm_test.go | 17 +--- internal/dbtest/pg_test.go | 126 ++++++++++-------------- internal/dbtest/soft_delete_test.go | 24 ++--- 8 files changed, 156 insertions(+), 203 deletions(-) diff --git a/internal/dbtest/bench_test.go b/internal/dbtest/bench_test.go index 3c896e86a..9eeb16abe 100644 --- a/internal/dbtest/bench_test.go +++ b/internal/dbtest/bench_test.go @@ -77,7 +77,7 @@ func benchEachDB(b *testing.B, f func(b *testing.B, db *bun.DB)) { db.SetMaxOpenConns(64) db.SetMaxIdleConns(64) - err := resetBenchSchema(db) + err := resetBenchSchema(b, db) require.NoError(b, err) b.ResetTimer() @@ -86,10 +86,8 @@ func benchEachDB(b *testing.B, f func(b *testing.B, db *bun.DB)) { } } -func resetBenchSchema(db *bun.DB) error { - if err := db.ResetModel(ctx, (*Bench)(nil)); err != nil { - return err - } +func resetBenchSchema(tb testing.TB, db *bun.DB) error { + mustResetModel(tb, ctx, db, (*Bench)(nil)) for i := 0; i < 1000; i++ { bench := &Bench{ diff --git a/internal/dbtest/db_test.go b/internal/dbtest/db_test.go index d560041f6..7c72c562d 100644 --- a/internal/dbtest/db_test.go +++ b/internal/dbtest/db_test.go @@ -679,10 +679,9 @@ func testRunInTx(t *testing.T, db *bun.DB) { Count int64 } - err := db.ResetModel(ctx, (*Counter)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Counter)(nil)) - _, err = db.NewInsert().Model(&Counter{Count: 0}).Exec(ctx) + _, err := db.NewInsert().Model(&Counter{Count: 0}).Exec(ctx) require.NoError(t, err) err = db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { @@ -723,15 +722,14 @@ func testJSONSpecialChars(t *testing.T, db *bun.DB) { ctx := context.Background() - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) model := &Model{ Attrs: map[string]interface{}{ "hello": "\000world\nworld\u0000", }, } - _, err = db.NewInsert().Model(model).Exec(ctx) + _, err := db.NewInsert().Model(model).Exec(ctx) require.NoError(t, err) model = new(Model) @@ -757,11 +755,10 @@ func testJSONInterface(t *testing.T, db *bun.DB) { ctx := context.Background() - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) model := new(Model) - _, err = db.NewInsert().Model(model).Exec(ctx) + _, err := db.NewInsert().Model(model).Exec(ctx) require.NoError(t, err) model = &Model{ @@ -801,11 +798,10 @@ func testJSONValuer(t *testing.T, db *bun.DB) { ctx := context.Background() - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) model := new(Model) - _, err = db.NewInsert().Model(model).Exec(ctx) + _, err := db.NewInsert().Model(model).Exec(ctx) require.NoError(t, err) model2 := new(Model) @@ -868,18 +864,14 @@ func testFKViolation(t *testing.T, db *bun.DB) { require.NoError(t, err) } + mustResetModel(t, ctx, db, (*User)(nil)) _, err := db.NewCreateTable(). - Model((*User)(nil)). - IfNotExists(). - Exec(ctx) - require.NoError(t, err) - - _, err = db.NewCreateTable(). Model((*Deck)(nil)). IfNotExists(). ForeignKey("(user_id) REFERENCES users (id) ON DELETE CASCADE"). Exec(ctx) require.NoError(t, err) + mustDropTableOnCleanup(t, ctx, db, (*Deck)(nil)) // Empty deck should violate FK constraint. _, err = db.NewInsert().Model(new(Deck)).Exec(ctx) @@ -923,18 +915,14 @@ func testWithForeignKeysAndRules(t *testing.T, db *bun.DB) { require.NoError(t, err) } + mustResetModel(t, ctx, db, (*User)(nil)) _, err := db.NewCreateTable(). - Model((*User)(nil)). - IfNotExists(). - Exec(ctx) - require.NoError(t, err) - - _, err = db.NewCreateTable(). Model((*Deck)(nil)). IfNotExists(). WithForeignKeys(). Exec(ctx) require.NoError(t, err) + mustDropTableOnCleanup(t, ctx, db, (*Deck)(nil)) // Empty deck should violate FK constraint. _, err = db.NewInsert().Model(new(Deck)).Exec(ctx) @@ -1011,18 +999,15 @@ func testWithForeignKeys(t *testing.T, db *bun.DB) { require.NoError(t, err) } - _, err := db.NewCreateTable(). - Model((*User)(nil)). - IfNotExists(). - Exec(ctx) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*User)(nil)) - _, err = db.NewCreateTable(). + _, err := db.NewCreateTable(). Model((*Deck)(nil)). IfNotExists(). WithForeignKeys(). Exec(ctx) require.NoError(t, err) + mustDropTableOnCleanup(t, ctx, db, (*Deck)(nil)) // Empty deck should violate FK constraint. _, err = db.NewInsert().Model(new(Deck)).Exec(ctx) @@ -1109,14 +1094,13 @@ func testScanRawMessage(t *testing.T, db *bun.DB) { ctx := context.Background() - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) models := []Model{ {Value: json.RawMessage(`"hello"`)}, {Value: json.RawMessage(`"world"`)}, } - _, err = db.NewInsert().Model(&models).Exec(ctx) + _, err := db.NewInsert().Model(&models).Exec(ctx) require.NoError(t, err) var models1 []Model @@ -1139,8 +1123,7 @@ func testPointers(t *testing.T, db *bun.DB) { ctx := context.Background() for _, id := range []int64{-1, 0, 1} { - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) var model Model if id >= 0 { @@ -1150,7 +1133,7 @@ func testPointers(t *testing.T, db *bun.DB) { } - _, err = db.NewInsert().Model(&model).Exec(ctx) + _, err := db.NewInsert().Model(&model).Exec(ctx) require.NoError(t, err) var model2 Model @@ -1195,11 +1178,9 @@ func testBinaryData(t *testing.T, db *bun.DB) { } ctx := context.Background() + mustResetModel(t, ctx, db, (*Model)(nil)) - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) - - _, err = db.NewInsert().Model(&Model{Data: []byte("hello")}).Exec(ctx) + _, err := db.NewInsert().Model(&Model{Data: []byte("hello")}).Exec(ctx) require.NoError(t, err) var model Model @@ -1219,13 +1200,11 @@ func testUpsert(t *testing.T, db *bun.DB) { } ctx := context.Background() - - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) model := &Model{ID: 1, Str: "hello"} - _, err = db.NewInsert().Model(model).Exec(ctx) + _, err := db.NewInsert().Model(model).Exec(ctx) require.NoError(t, err) model.Str = "world" @@ -1256,13 +1235,11 @@ func testMultiUpdate(t *testing.T, db *bun.DB) { } ctx := context.Background() - - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) model := &Model{ID: 1, Str: "hello"} - _, err = db.NewInsert().Model(model).Exec(ctx) + _, err := db.NewInsert().Model(model).Exec(ctx) require.NoError(t, err) selq := db.NewSelect().Model(new(Model)) @@ -1285,15 +1262,13 @@ func testUpdateWithSkipupdateTag(t *testing.T, db *bun.DB) { } ctx := context.Background() - - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) createdAt := time.Now().Truncate(time.Minute).UTC() model := &Model{ID: 1, Name: "foo", CreatedAt: createdAt} - _, err = db.NewInsert().Model(model).Exec(ctx) + _, err := db.NewInsert().Model(model).Exec(ctx) require.NoError(t, err) require.NotZero(t, model.CreatedAt) @@ -1326,9 +1301,7 @@ func testScanAndCount(t *testing.T, db *bun.DB) { } ctx := context.Background() - - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) t.Run("tx", func(t *testing.T) { for i := 0; i < 100; i++ { @@ -1348,7 +1321,7 @@ func testScanAndCount(t *testing.T, db *bun.DB) { {Str: "str1"}, {Str: "str2"}, } - _, err = db.NewInsert().Model(&src).Exec(ctx) + _, err := db.NewInsert().Model(&src).Exec(ctx) require.NoError(t, err) var dest []Model @@ -1376,9 +1349,7 @@ func testEmbedModelValue(t *testing.T, db *bun.DB) { } ctx := context.Background() - - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) m1 := &Model{ X: Embed{ @@ -1406,7 +1377,7 @@ func testEmbedModelValue(t *testing.T, db *bun.DB) { }, }, } - _, err = db.NewInsert().Model(m1).Exec(ctx) + _, err := db.NewInsert().Model(m1).Exec(ctx) require.NoError(t, err) var m2 Model @@ -1432,9 +1403,7 @@ func testEmbedModelPointer(t *testing.T, db *bun.DB) { } ctx := context.Background() - - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) m1 := &Model{ X: &Embed{ @@ -1462,7 +1431,7 @@ func testEmbedModelPointer(t *testing.T, db *bun.DB) { }, }, } - _, err = db.NewInsert().Model(m1).Exec(ctx) + _, err := db.NewInsert().Model(m1).Exec(ctx) require.NoError(t, err) var m2 Model @@ -1478,14 +1447,12 @@ func testEmbedTypeField(t *testing.T, db *bun.DB) { } ctx := context.Background() - - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) m1 := &Model{ Embed: Embed("foo"), } - _, err = db.NewInsert().Model(m1).Exec(ctx) + _, err := db.NewInsert().Model(m1).Exec(ctx) require.NoError(t, err) var m2 Model @@ -1508,12 +1475,10 @@ func testJSONMarshaler(t *testing.T, db *bun.DB) { } ctx := context.Background() - - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) m1 := &Model{Field: new(JSONField)} - _, err = db.NewInsert().Model(m1).Exec(ctx) + _, err := db.NewInsert().Model(m1).Exec(ctx) require.NoError(t, err) var m2 Model @@ -1538,11 +1503,9 @@ func testNilDriverValue(t *testing.T, db *bun.DB) { } ctx := context.Background() + mustResetModel(t, ctx, db, (*Model)(nil)) - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) - - _, err = db.NewInsert().Model(&Model{}).Exec(ctx) + _, err := db.NewInsert().Model(&Model{}).Exec(ctx) require.NoError(t, err) _, err = db.NewInsert().Model(&Model{Value: &DriverValue{s: "hello"}}).Exec(ctx) @@ -1554,10 +1517,9 @@ func testRunInTxAndSavepoint(t *testing.T, db *bun.DB) { Count int64 } - err := db.ResetModel(ctx, (*Counter)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Counter)(nil)) - _, err = db.NewInsert().Model(&Counter{Count: 0}).Exec(ctx) + _, err := db.NewInsert().Model(&Counter{Count: 0}).Exec(ctx) require.NoError(t, err) err = db.RunInTx(ctx, nil, func(ctx context.Context, tx bun.Tx) error { @@ -1662,12 +1624,10 @@ func testDriverValuerReturnsItself(t *testing.T, db *bun.DB) { } ctx := context.Background() - - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) model := &Model{Value: expectedValue} - _, err = db.NewInsert().Model(model).Exec(ctx) + _, err := db.NewInsert().Model(model).Exec(ctx) require.Error(t, err) } @@ -1678,9 +1638,7 @@ func testNoPanicWhenReturningNullColumns(t *testing.T, db *bun.DB) { } ctx := context.Background() - - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) modelSlice := []*Model{{Value: "boom"}} @@ -1688,3 +1646,19 @@ func testNoPanicWhenReturningNullColumns(t *testing.T, db *bun.DB) { db.NewInsert().Model(&modelSlice).Exec(ctx) }) } + +func mustResetModel(tb testing.TB, ctx context.Context, db *bun.DB, models ...interface{}) { + err := db.ResetModel(ctx, models...) + require.NoError(tb, err, "must reset model") + mustDropTableOnCleanup(tb, ctx, db, models...) +} + +func mustDropTableOnCleanup(tb testing.TB, ctx context.Context, db *bun.DB, models ...interface{}) { + tb.Cleanup(func() { + for _, model := range models { + drop := db.NewDropTable().IfExists().Model(model) + _, err := drop.Exec(ctx) + require.NoError(tb, err, "must drop table: %q", drop.GetTableName()) + } + }) +} diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index 7ef513775..74e33eab2 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -10,6 +10,22 @@ import ( "github.com/uptrace/bun/migrate" ) +const ( + migrationsTable = "test_migrations" + migrationLocksTable = "test_migration_locks" +) + +func cleanupMigrations(tb testing.TB, ctx context.Context, db *bun.DB) { + tb.Cleanup(func() { + var err error + _, err = db.NewDropTable().ModelTableExpr(migrationsTable).Exec(ctx) + require.NoError(tb, err, "drop %q table", migrationsTable) + + _, err = db.NewDropTable().ModelTableExpr(migrationLocksTable).Exec(ctx) + require.NoError(tb, err, "drop %q table", migrationLocksTable) + }) +} + func TestMigrate(t *testing.T) { type Test struct { run func(t *testing.T, db *bun.DB) @@ -21,6 +37,8 @@ func TestMigrate(t *testing.T) { } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { + cleanupMigrations(t, ctx, db) + for _, test := range tests { t.Run(funcName(test.run), func(t *testing.T) { test.run(t, db) @@ -58,7 +76,10 @@ func testMigrateUpAndDown(t *testing.T, db *bun.DB) { }, }) - m := migrate.NewMigrator(db, migrations) + m := migrate.NewMigrator(db, migrations, + migrate.WithTableName(migrationsTable), + migrate.WithLocksTableName(migrationLocksTable), + ) err := m.Reset(ctx) require.NoError(t, err) @@ -116,7 +137,10 @@ func testMigrateUpError(t *testing.T, db *bun.DB) { }, }) - m := migrate.NewMigrator(db, migrations) + m := migrate.NewMigrator(db, migrations, + migrate.WithTableName(migrationsTable), + migrate.WithLocksTableName(migrationLocksTable), + ) err := m.Reset(ctx) require.NoError(t, err) diff --git a/internal/dbtest/model_hook_test.go b/internal/dbtest/model_hook_test.go index 7043ddd2a..7a630a1e0 100644 --- a/internal/dbtest/model_hook_test.go +++ b/internal/dbtest/model_hook_test.go @@ -39,8 +39,7 @@ func TestModelHook(t *testing.T) { } func testModelHook(t *testing.T, dbName string, db *bun.DB) { - err := db.ResetModel(ctx, (*ModelHookTest)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*ModelHookTest)(nil)) { hook := &ModelHookTest{ID: 1} diff --git a/internal/dbtest/mssql_test.go b/internal/dbtest/mssql_test.go index 579cee42b..39de96527 100644 --- a/internal/dbtest/mssql_test.go +++ b/internal/dbtest/mssql_test.go @@ -8,7 +8,7 @@ import ( func TestMssqlMerge(t *testing.T) { db := mssql2019(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) type Model struct { ID int64 `bun:",pk,autoincrement"` @@ -17,10 +17,9 @@ func TestMssqlMerge(t *testing.T) { Value string } - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) - _, err = db.NewInsert().Model(&Model{Name: "A", Value: "hello"}).Exec(ctx) + _, err := db.NewInsert().Model(&Model{Name: "A", Value: "hello"}).Exec(ctx) require.NoError(t, err) newModels := []*Model{ diff --git a/internal/dbtest/orm_test.go b/internal/dbtest/orm_test.go index 622924437..fcd1b3931 100644 --- a/internal/dbtest/orm_test.go +++ b/internal/dbtest/orm_test.go @@ -355,14 +355,13 @@ func testRelationBelongsToSelf(t *testing.T, db *bun.DB) { Model *Model `bun:"rel:belongs-to"` } - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) models := []Model{ {ID: 1}, {ID: 2, ModelID: 1}, } - _, err = db.NewInsert().Model(&models).Exec(ctx) + _, err := db.NewInsert().Model(&models).Exec(ctx) require.NoError(t, err) models = nil @@ -396,15 +395,13 @@ func testM2MRelationExcludeColumn(t *testing.T, db *bun.DB) { } db.RegisterModel((*OrderToItem)(nil)) - - err := db.ResetModel(ctx, (*Order)(nil), (*Item)(nil), (*OrderToItem)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Order)(nil), (*Item)(nil), (*OrderToItem)(nil)) items := []Item{ {ID: 1, CreatedAt: time.Unix(1, 0), UpdatedAt: time.Unix(1, 0)}, {ID: 2, CreatedAt: time.Unix(2, 0), UpdatedAt: time.Unix(1, 0)}, } - _, err = db.NewInsert().Model(&items).Exec(ctx) + _, err := db.NewInsert().Model(&items).Exec(ctx) require.NoError(t, err) orders := []Order{ @@ -574,11 +571,7 @@ func createTestSchema(t *testing.T, db *bun.DB) { (*Employee)(nil), } for _, model := range models { - _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx) - require.NoError(t, err) - - _, err = db.NewCreateTable().Model(model).Exec(ctx) - require.NoError(t, err) + mustResetModel(t, ctx, db, model) } } diff --git a/internal/dbtest/pg_test.go b/internal/dbtest/pg_test.go index 4521f11b3..cc4031033 100644 --- a/internal/dbtest/pg_test.go +++ b/internal/dbtest/pg_test.go @@ -28,20 +28,15 @@ func TestPostgresArray(t *testing.T) { } db := pg(t) - defer db.Close() - - _, err := db.NewDropTable().Model((*Model)(nil)).IfExists().Exec(ctx) - require.NoError(t, err) - - _, err = db.NewCreateTable().Model((*Model)(nil)).Exec(ctx) - require.NoError(t, err) + t.Cleanup(func() { db.Close() }) + mustResetModel(t, ctx, db, (*Model)(nil)) model1 := &Model{ ID: 123, Array1: []string{"one", "two", "three"}, Array2: &[]string{"hello", "world"}, } - _, err = db.NewInsert().Model(model1).Exec(ctx) + _, err := db.NewInsert().Model(model1).Exec(ctx) require.NoError(t, err) model2 := new(Model) @@ -65,7 +60,7 @@ func TestPostgresArray(t *testing.T) { func TestPostgresArrayQuote(t *testing.T) { db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) wanted := []string{"'", "''", "'''", "\""} var strs []string @@ -101,16 +96,15 @@ func TestPostgresArrayValuer(t *testing.T) { } db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) model1 := &Model{ ID: 123, Array: []Hash{Hash{}}, } - _, err = db.NewInsert().Model(model1).Exec(ctx) + _, err := db.NewInsert().Model(model1).Exec(ctx) require.NoError(t, err) model2 := new(Model) @@ -154,11 +148,7 @@ func TestPostgresMultiTenant(t *testing.T) { (*IngredientRecipe)(nil), } for _, model := range models { - _, err := db.NewDropTable().Model(model).IfExists().Exec(ctx) - require.NoError(t, err) - - _, err = db.NewCreateTable().Model(model).Exec(ctx) - require.NoError(t, err) + mustResetModel(t, ctx, db, model) } models = []interface{}{ @@ -191,9 +181,9 @@ func TestPostgresInsertNoRows(t *testing.T) { } db := pg(t) + t.Cleanup(func() { db.Close() }) - err := db.ResetModel(ctx, (*User)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*User)(nil)) { res, err := db.NewInsert(). @@ -228,9 +218,9 @@ func TestPostgresInsertNoRowsIdentity(t *testing.T) { } db := pg(t) + t.Cleanup(func() { db.Close() }) - err := db.ResetModel(ctx, (*User)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*User)(nil)) { res, err := db.NewInsert(). @@ -335,6 +325,7 @@ func TestPostgresTransaction(t *testing.T) { _, err = db.NewCreateTable().Conn(tx).Model((*Model)(nil)).Exec(ctx) require.NoError(t, err) + mustDropTableOnCleanup(t, ctx, db, (*Model)(nil)) n, err := db.NewSelect().Conn(tx).Model((*Model)(nil)).Count(ctx) require.NoError(t, err) @@ -349,18 +340,17 @@ func TestPostgresTransaction(t *testing.T) { } func TestPostgresScanWithoutResult(t *testing.T) { - db := pg(t) - defer db.Close() - type Model struct { ID int64 `bun:",pk,autoincrement"` } - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + db := pg(t) + t.Cleanup(func() { db.Close() }) + + mustResetModel(t, ctx, db, (*Model)(nil)) var num int64 - _, err = db.NewUpdate().Model(new(Model)).Set("id = NULL").Where("id = 0").Exec(ctx, &num) + _, err := db.NewUpdate().Model(new(Model)).Set("id = NULL").Where("id = 0").Exec(ctx, &num) require.Equal(t, sql.ErrNoRows, err) } @@ -370,10 +360,9 @@ func TestPostgresIPNet(t *testing.T) { } db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) _, ipv4Net, err := net.ParseCIDR("192.0.2.1/24") require.NoError(t, err) @@ -393,12 +382,11 @@ func TestPostgresBytea(t *testing.T) { } db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) - _, err = db.NewInsert().Model(&Model{Bytes: []byte("hello")}).Exec(ctx) + _, err := db.NewInsert().Model(&Model{Bytes: []byte("hello")}).Exec(ctx) require.NoError(t, err) model := new(Model) @@ -413,13 +401,12 @@ func TestPostgresByteaArray(t *testing.T) { } db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) model1 := &Model{BytesSlice: [][]byte{[]byte("hello"), []byte("world")}} - _, err = db.NewInsert().Model(model1).Exec(ctx) + _, err := db.NewInsert().Model(model1).Exec(ctx) require.NoError(t, err) model2 := new(Model) @@ -430,7 +417,7 @@ func TestPostgresByteaArray(t *testing.T) { func TestPostgresDate(t *testing.T) { db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) var str string err := db.NewSelect().ColumnExpr("'2021-09-15'::date").Scan(ctx, &str) @@ -455,7 +442,7 @@ func TestPostgresDate(t *testing.T) { func TestPostgresTimetz(t *testing.T) { db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) var tm time.Time err := db.NewSelect().ColumnExpr("now()::timetz").Scan(ctx, &tm) @@ -470,14 +457,11 @@ func TestPostgresTimeArray(t *testing.T) { Array2 *[]time.Time `bun:",array"` Array3 *[]time.Time `bun:",array"` } - db := pg(t) - defer db.Close() - _, err := db.NewDropTable().Model((*Model)(nil)).IfExists().Exec(ctx) - require.NoError(t, err) + db := pg(t) + t.Cleanup(func() { db.Close() }) - _, err = db.NewCreateTable().Model((*Model)(nil)).Exec(ctx) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) time1 := time.Now() time2 := time.Now().Add(time.Hour) @@ -488,7 +472,7 @@ func TestPostgresTimeArray(t *testing.T) { Array1: []time.Time{time1, time2, time3}, Array2: &[]time.Time{time1, time2, time3}, } - _, err = db.NewInsert().Model(model1).Exec(ctx) + _, err := db.NewInsert().Model(model1).Exec(ctx) require.NoError(t, err) model2 := new(Model) @@ -525,14 +509,13 @@ func TestPostgresOnConflictDoUpdate(t *testing.T) { ctx := context.Background() db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) model := &Model{ID: 1} - _, err = db.NewInsert(). + _, err := db.NewInsert(). Model(model). On("CONFLICT (id) DO UPDATE"). Set("updated_at = now()"). @@ -562,14 +545,13 @@ func TestPostgresOnConflictDoUpdateIdentity(t *testing.T) { ctx := context.Background() db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) model := &Model{ID: 1} - _, err = db.NewInsert(). + _, err := db.NewInsert(). Model(model). On("CONFLICT (id) DO UPDATE"). Set("updated_at = now()"). @@ -594,7 +576,7 @@ func TestPostgresCopyFromCopyTo(t *testing.T) { ctx := context.Background() db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) conn, err := db.Conn(ctx) require.NoError(t, err) @@ -689,13 +671,12 @@ func TestPostgresUUID(t *testing.T) { ctx := context.Background() db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) _, err := db.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp"`) require.NoError(t, err) - err = db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) model := new(Model) _, err = db.NewInsert().Model(model).Exec(ctx) @@ -712,16 +693,11 @@ func TestPostgresHStore(t *testing.T) { } db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) _, err := db.Exec(`CREATE EXTENSION IF NOT EXISTS HSTORE;`) require.NoError(t, err) - - _, err = db.NewDropTable().Model((*Model)(nil)).IfExists().Exec(ctx) - require.NoError(t, err) - - _, err = db.NewCreateTable().Model((*Model)(nil)).Exec(ctx) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) model1 := &Model{ ID: 123, @@ -760,7 +736,7 @@ func TestPostgresHStore(t *testing.T) { func TestPostgresHStoreQuote(t *testing.T) { db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) _, err := db.Exec(`CREATE EXTENSION IF NOT EXISTS HSTORE;`) require.NoError(t, err) @@ -784,16 +760,15 @@ func TestPostgresSkipupdateField(t *testing.T) { ctx := context.Background() db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) createdAt := time.Now().Truncate(time.Minute).UTC() model := &Model{ID: 1, Name: "foo", CreatedAt: createdAt} - _, err = db.NewInsert().Model(model).Exec(ctx) + _, err := db.NewInsert().Model(model).Exec(ctx) require.NoError(t, err) require.NotZero(t, model.CreatedAt) @@ -850,13 +825,12 @@ func TestPostgresCustomTypeBytes(t *testing.T) { ctx := context.Background() db := pg(t) - defer db.Close() + t.Cleanup(func() { db.Close() }) - err := db.ResetModel(ctx, (*Model)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Model)(nil)) in := &Model{Data: []*Issue722{{V: []byte("hello")}}} - _, err = db.NewInsert().Model(in).Exec(ctx) + _, err := db.NewInsert().Model(in).Exec(ctx) require.NoError(t, err) out := new(Model) diff --git a/internal/dbtest/soft_delete_test.go b/internal/dbtest/soft_delete_test.go index b8714ddfc..c66ab9262 100644 --- a/internal/dbtest/soft_delete_test.go +++ b/internal/dbtest/soft_delete_test.go @@ -39,11 +39,9 @@ type Video struct { func testSoftDeleteNilModel(t *testing.T, db *bun.DB) { ctx := context.Background() + mustResetModel(t, ctx, db, (*Video)(nil)) - err := db.ResetModel(ctx, (*Video)(nil)) - require.NoError(t, err) - - _, err = db.NewDelete().Model((*Video)(nil)).Where("1 = 1").Exec(ctx) + _, err := db.NewDelete().Model((*Video)(nil)).Where("1 = 1").Exec(ctx) require.NoError(t, err) _, err = db.NewDelete().Model((*Video)(nil)).Where("1 = 1").ForceDelete().Exec(ctx) @@ -52,14 +50,12 @@ func testSoftDeleteNilModel(t *testing.T, db *bun.DB) { func testSoftDeleteAPI(t *testing.T, db *bun.DB) { ctx := context.Background() - - err := db.ResetModel(ctx, (*Video)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Video)(nil)) video1 := &Video{ ID: 1, } - _, err = db.NewInsert().Model(video1).Exec(ctx) + _, err := db.NewInsert().Model(video1).Exec(ctx) require.NoError(t, err) // Count visible videos. @@ -107,9 +103,7 @@ func testSoftDeleteAPI(t *testing.T, db *bun.DB) { func testSoftDeleteForce(t *testing.T, db *bun.DB) { ctx := context.Background() - - err := db.ResetModel(ctx, (*Video)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Video)(nil)) videos := []Video{ {Name: "video1"}, @@ -117,7 +111,7 @@ func testSoftDeleteForce(t *testing.T, db *bun.DB) { {Name: "video3"}, } - _, err = db.NewInsert().Model(&videos).Exec(ctx) + _, err := db.NewInsert().Model(&videos).Exec(ctx) require.NoError(t, err) // Force delete video1. @@ -160,15 +154,13 @@ func testSoftDeleteForce(t *testing.T, db *bun.DB) { func testSoftDeleteBulk(t *testing.T, db *bun.DB) { ctx := context.Background() - - err := db.ResetModel(ctx, (*Video)(nil)) - require.NoError(t, err) + mustResetModel(t, ctx, db, (*Video)(nil)) videos := []Video{ {Name: "video1"}, {Name: "video2"}, } - _, err = db.NewInsert().Model(&videos).Exec(ctx) + _, err := db.NewInsert().Model(&videos).Exec(ctx) require.NoError(t, err) if db.Dialect().Features().Has(feature.CTE) {