Skip to content

Commit

Permalink
feat: add Connection to execute multiple commands in a single conne…
Browse files Browse the repository at this point in the history
…ction; (go-gorm#4982)
  • Loading branch information
kinggo authored Jan 7, 2022
1 parent f757b8f commit 0df42e9
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
24 changes: 24 additions & 0 deletions finisher_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,30 @@ func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error {
return tx.Error
}

// Connection use a db conn to execute Multiple commands,this conn will put conn pool after it is executed.
func (db *DB) Connection(fc func(tx *DB) error) (err error) {
if db.Error != nil {
return db.Error
}

tx := db.getInstance()
sqlDB, err := tx.DB()
if err != nil {
return
}

conn, err := sqlDB.Conn(tx.Statement.Context)
if err != nil {
return
}

defer conn.Close()
tx.Statement.ConnPool = conn
err = fc(tx)

return
}

// Transaction start a transaction as a block, return error will rollback, otherwise to commit.
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) {
panicked := true
Expand Down
48 changes: 48 additions & 0 deletions tests/connection_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package tests_test

import (
"fmt"
"gorm.io/driver/mysql"
"gorm.io/gorm"
"testing"
)

func TestWithSingleConnection(t *testing.T) {

var expectedName = "test"
var actualName string

setSQL, getSQL := getSetSQL(DB.Dialector.Name())
if len(setSQL) == 0 || len(getSQL) == 0 {
return
}

err := DB.Connection(func(tx *gorm.DB) error {
if err := tx.Exec(setSQL, expectedName).Error; err != nil {
return err
}

if err := tx.Raw(getSQL).Scan(&actualName).Error; err != nil {
return err
}
return nil
})

if err != nil {
t.Errorf(fmt.Sprintf("WithSingleConnection should work, but got err %v", err))
}

if actualName != expectedName {
t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedName, actualName)
}

}

func getSetSQL(driverName string) (string, string) {
switch driverName {
case mysql.Dialector{}.Name():
return "SET @testName := ?", "SELECT @testName"
default:
return "", ""
}
}

0 comments on commit 0df42e9

Please sign in to comment.