Skip to content

Commit

Permalink
Add SavePoint/RollbackTo/NestedTransaction
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jun 19, 2020
1 parent 2c1b04a commit 7dc255a
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 21 deletions.
2 changes: 2 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@ var (
ErrorPrimaryKeyRequired = errors.New("primary key required")
// ErrorModelValueRequired model value required
ErrorModelValueRequired = errors.New("model value required")
// ErrUnsupportedDriver unsupported driver
ErrUnsupportedDriver = errors.New("unsupported driver")
)
54 changes: 44 additions & 10 deletions finisher_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package gorm
import (
"database/sql"
"errors"
"fmt"
"reflect"
"strings"

Expand Down Expand Up @@ -343,18 +344,33 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true
tx := db.Begin(opts...)
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
tx.Rollback()
}
}()

err = fc(tx)
if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil {
// nested transaction
db.SavePoint(fmt.Sprintf("sp%p", fc))
defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
db.RollbackTo(fmt.Sprintf("sp%p", fc))
}
}()

err = fc(db.Session(&Session{WithConditions: true}))
} else {
tx := db.Begin(opts...)

defer func() {
// Make sure to rollback when panic, Block error or Commit error
if panicked || err != nil {
tx.Rollback()
}
}()

err = fc(tx)

if err == nil {
err = tx.Commit().Error
if err == nil {
err = tx.Commit().Error
}
}

panicked = false
Expand Down Expand Up @@ -409,6 +425,24 @@ func (db *DB) Rollback() *DB {
return db
}

func (db *DB) SavePoint(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
savePointer.SavePoint(db, name)
} else {
db.AddError(ErrUnsupportedDriver)
}
return db
}

func (db *DB) RollbackTo(name string) *DB {
if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok {
savePointer.RollbackTo(db, name)
} else {
db.AddError(ErrUnsupportedDriver)
}
return db
}

// Exec execute raw sql
func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) {
tx = db.getInstance()
Expand Down
5 changes: 5 additions & 0 deletions interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ type ConnPool interface {
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}

type SavePointerDialectorInterface interface {
SavePoint(tx *DB, name string) error
RollbackTo(tx *DB, name string) error
}

type TxBeginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}
Expand Down
10 changes: 5 additions & 5 deletions tests/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ require (
github.com/google/uuid v1.1.1
github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0
gorm.io/driver/mysql v0.2.0
gorm.io/driver/postgres v0.2.0
gorm.io/driver/sqlite v1.0.2
gorm.io/driver/sqlserver v0.2.0
gorm.io/gorm v0.0.0-00010101000000-000000000000
gorm.io/driver/mysql v0.2.1
gorm.io/driver/postgres v0.2.1
gorm.io/driver/sqlite v1.0.4
gorm.io/driver/sqlserver v0.2.1
gorm.io/gorm v0.2.7
)

replace gorm.io/gorm => ../
120 changes: 120 additions & 0 deletions tests/transaction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,123 @@ func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) {
t.Fatalf("Rollback after commit should raise error")
}
}

func TestTransactionWithSavePoint(t *testing.T) {
tx := DB.Begin()

user := *GetUser("transaction-save-point", Config{})
tx.Create(&user)

if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}

if err := tx.SavePoint("save_point1").Error; err != nil {
t.Fatalf("Failed to save point, got error %v", err)
}

user1 := *GetUser("transaction-save-point-1", Config{})
tx.Create(&user1)

if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}

if err := tx.RollbackTo("save_point1").Error; err != nil {
t.Fatalf("Failed to save point, got error %v", err)
}

if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil {
t.Fatalf("Should not find rollbacked record")
}

if err := tx.SavePoint("save_point2").Error; err != nil {
t.Fatalf("Failed to save point, got error %v", err)
}

user2 := *GetUser("transaction-save-point-2", Config{})
tx.Create(&user2)

if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}

if err := tx.Commit().Error; err != nil {
t.Fatalf("Failed to commit, got error %v", err)
}

if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}

if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil {
t.Fatalf("Should not find rollbacked record")
}

if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
}

func TestNestedTransactionWithBlock(t *testing.T) {
var (
user = *GetUser("transaction-nested", Config{})
user1 = *GetUser("transaction-nested-1", Config{})
user2 = *GetUser("transaction-nested-2", Config{})
)

if err := DB.Transaction(func(tx *gorm.DB) error {
tx.Create(&user)

if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}

if err := tx.Transaction(func(tx1 *gorm.DB) error {
tx1.Create(&user1)

if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}

return errors.New("rollback")
}); err == nil {
t.Fatalf("nested transaction should returns error")
}

if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil {
t.Fatalf("Should not find rollbacked record")
}

if err := tx.Transaction(func(tx2 *gorm.DB) error {
tx2.Create(&user2)

if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}

return nil
}); err != nil {
t.Fatalf("nested transaction returns error: %v", err)
}

if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
return nil
}); err != nil {
t.Fatalf("no error should return, but got %v", err)
}

if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}

if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil {
t.Fatalf("Should not find rollbacked record")
}

if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil {
t.Fatalf("Should find saved record")
}
}
6 changes: 0 additions & 6 deletions wercker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,3 @@ build:
name: test mssql
code: |
GORM_DIALECT=mssql GORM_VERBOSE=true GORM_DSN="sqlserver://gorm:LoremIpsum86@mssql:1433?database=gorm" ./tests/tests_all.sh
- script:
name: codecov
code: |
go test -race -coverprofile=coverage.txt -covermode=atomic ./...
bash <(curl -s https://codecov.io/bash)

0 comments on commit 7dc255a

Please sign in to comment.