From 7dc255acfe2e20c033e082b532c6b1c85c7751a9 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 19 Jun 2020 18:30:04 +0800 Subject: [PATCH] Add SavePoint/RollbackTo/NestedTransaction --- errors.go | 2 + finisher_api.go | 54 +++++++++++++---- interfaces.go | 5 ++ tests/go.mod | 10 ++-- tests/transaction_test.go | 120 ++++++++++++++++++++++++++++++++++++++ wercker.yml | 6 -- 6 files changed, 176 insertions(+), 21 deletions(-) diff --git a/errors.go b/errors.go index ff06f24e4..2506ecc57 100644 --- a/errors.go +++ b/errors.go @@ -25,4 +25,6 @@ var ( ErrorPrimaryKeyRequired = errors.New("primary key required") // ErrorModelValueRequired model value required ErrorModelValueRequired = errors.New("model value required") + // ErrUnsupportedDriver unsupported driver + ErrUnsupportedDriver = errors.New("unsupported driver") ) diff --git a/finisher_api.go b/finisher_api.go index 43aff843b..92d4fe720 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -3,6 +3,7 @@ package gorm import ( "database/sql" "errors" + "fmt" "reflect" "strings" @@ -343,18 +344,33 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { // Transaction start a transaction as a block, return error will rollback, otherwise to commit. func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true - tx := db.Begin(opts...) - defer func() { - // Make sure to rollback when panic, Block error or Commit error - if panicked || err != nil { - tx.Rollback() - } - }() - err = fc(tx) + if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { + // nested transaction + db.SavePoint(fmt.Sprintf("sp%p", fc)) + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + db.RollbackTo(fmt.Sprintf("sp%p", fc)) + } + }() + + err = fc(db.Session(&Session{WithConditions: true})) + } else { + tx := db.Begin(opts...) + + defer func() { + // Make sure to rollback when panic, Block error or Commit error + if panicked || err != nil { + tx.Rollback() + } + }() + + err = fc(tx) - if err == nil { - err = tx.Commit().Error + if err == nil { + err = tx.Commit().Error + } } panicked = false @@ -409,6 +425,24 @@ func (db *DB) Rollback() *DB { return db } +func (db *DB) SavePoint(name string) *DB { + if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + savePointer.SavePoint(db, name) + } else { + db.AddError(ErrUnsupportedDriver) + } + return db +} + +func (db *DB) RollbackTo(name string) *DB { + if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + savePointer.RollbackTo(db, name) + } else { + db.AddError(ErrUnsupportedDriver) + } + return db +} + // Exec execute raw sql func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() diff --git a/interfaces.go b/interfaces.go index 4be545654..f3e5c0287 100644 --- a/interfaces.go +++ b/interfaces.go @@ -27,6 +27,11 @@ type ConnPool interface { QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } +type SavePointerDialectorInterface interface { + SavePoint(tx *DB, name string) error + RollbackTo(tx *DB, name string) error +} + type TxBeginner interface { BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) } diff --git a/tests/go.mod b/tests/go.mod index 07ec6be23..a2121b7a5 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -6,11 +6,11 @@ require ( github.com/google/uuid v1.1.1 github.com/jinzhu/now v1.1.1 github.com/lib/pq v1.6.0 - gorm.io/driver/mysql v0.2.0 - gorm.io/driver/postgres v0.2.0 - gorm.io/driver/sqlite v1.0.2 - gorm.io/driver/sqlserver v0.2.0 - gorm.io/gorm v0.0.0-00010101000000-000000000000 + gorm.io/driver/mysql v0.2.1 + gorm.io/driver/postgres v0.2.1 + gorm.io/driver/sqlite v1.0.4 + gorm.io/driver/sqlserver v0.2.1 + gorm.io/gorm v0.2.7 ) replace gorm.io/gorm => ../ diff --git a/tests/transaction_test.go b/tests/transaction_test.go index d1bf86459..c101388a0 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -142,3 +142,123 @@ func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) { t.Fatalf("Rollback after commit should raise error") } } + +func TestTransactionWithSavePoint(t *testing.T) { + tx := DB.Begin() + + user := *GetUser("transaction-save-point", Config{}) + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.SavePoint("save_point1").Error; err != nil { + t.Fatalf("Failed to save point, got error %v", err) + } + + user1 := *GetUser("transaction-save-point-1", Config{}) + tx.Create(&user1) + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.RollbackTo("save_point1").Error; err != nil { + t.Fatalf("Failed to save point, got error %v", err) + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := tx.SavePoint("save_point2").Error; err != nil { + t.Fatalf("Failed to save point, got error %v", err) + } + + user2 := *GetUser("transaction-save-point-2", Config{}) + tx.Create(&user2) + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.Commit().Error; err != nil { + t.Fatalf("Failed to commit, got error %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } +} + +func TestNestedTransactionWithBlock(t *testing.T) { + var ( + user = *GetUser("transaction-nested", Config{}) + user1 = *GetUser("transaction-nested-1", Config{}) + user2 = *GetUser("transaction-nested-2", Config{}) + ) + + if err := DB.Transaction(func(tx *gorm.DB) error { + tx.Create(&user) + + if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := tx.Transaction(func(tx1 *gorm.DB) error { + tx1.Create(&user1) + + if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return errors.New("rollback") + }); err == nil { + t.Fatalf("nested transaction should returns error") + } + + if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := tx.Transaction(func(tx2 *gorm.DB) error { + tx2.Create(&user2) + + if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + return nil + }); err != nil { + t.Fatalf("nested transaction returns error: %v", err) + } + + if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + return nil + }); err != nil { + t.Fatalf("no error should return, but got %v", err) + } + + if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } + + if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { + t.Fatalf("Should not find rollbacked record") + } + + if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { + t.Fatalf("Should find saved record") + } +} diff --git a/wercker.yml b/wercker.yml index baece1bc2..d4fb63e35 100644 --- a/wercker.yml +++ b/wercker.yml @@ -124,9 +124,3 @@ build: name: test mssql code: | GORM_DIALECT=mssql GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" ./tests/tests_all.sh - - - script: - name: codecov - code: | - go test -race -coverprofile=coverage.txt -covermode=atomic ./... - bash <(curl -s https://codecov.io/bash)