From 3ef44ce1cdd969a21b76d6c803119cf12c375cb0 Mon Sep 17 00:00:00 2001 From: I am Goroot Date: Mon, 21 Oct 2024 07:43:14 +0200 Subject: [PATCH] feat: add transaction isolation level support to pgdriver (#1034) --- driver/pgdriver/driver.go | 17 ++++++--- driver/pgdriver/driver_test.go | 69 ++++++++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/driver/pgdriver/driver.go b/driver/pgdriver/driver.go index 1b46a752d..77246ac75 100644 --- a/driver/pgdriver/driver.go +++ b/driver/pgdriver/driver.go @@ -5,7 +5,6 @@ import ( "context" "database/sql" "database/sql/driver" - "errors" "fmt" "io" "log" @@ -213,15 +212,23 @@ var _ driver.ConnBeginTx = (*Conn)(nil) func (cn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { // No need to check if the conn is closed. ExecContext below handles that. + isolation := sql.IsolationLevel(opts.Isolation) - if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { - return nil, errors.New("pgdriver: custom IsolationLevel is not supported") + var command string + switch isolation { + case sql.LevelDefault: + command = "BEGIN" + case sql.LevelReadUncommitted, sql.LevelReadCommitted, sql.LevelRepeatableRead, sql.LevelSerializable: + command = fmt.Sprintf("BEGIN; SET TRANSACTION ISOLATION LEVEL %s", isolation.String()) + default: + return nil, fmt.Errorf("pgdriver: unsupported transaction isolation: %s", isolation.String()) } + if opts.ReadOnly { - return nil, errors.New("pgdriver: ReadOnly transactions are not supported") + command = fmt.Sprintf("%s READ ONLY", command) } - if _, err := cn.ExecContext(ctx, "BEGIN", nil); err != nil { + if _, err := cn.ExecContext(ctx, command, nil); err != nil { return nil, err } return tx{cn: cn}, nil diff --git a/driver/pgdriver/driver_test.go b/driver/pgdriver/driver_test.go index 801801671..aaaabb1c6 100644 --- a/driver/pgdriver/driver_test.go +++ b/driver/pgdriver/driver_test.go @@ -3,7 +3,9 @@ package pgdriver_test import ( "context" "database/sql" + "fmt" "os" + "strings" "sync" "sync/atomic" "testing" @@ -288,6 +290,73 @@ func TestPartialScan(t *testing.T) { } } +func TestTransactionIsolationLevels(t *testing.T) { + db := sqlDB() + t.Cleanup(func() { + require.NoError(t, db.Close()) + }) + type testCase struct { + *sql.TxOptions + supported bool + expectedIsoLvl string + } + testCases := []testCase{ + // supported + {TxOptions: &sql.TxOptions{Isolation: sql.LevelDefault, ReadOnly: true}, supported: true, expectedIsoLvl: "READ COMMITTED"}, + {TxOptions: &sql.TxOptions{Isolation: sql.LevelDefault, ReadOnly: false}, supported: true, expectedIsoLvl: "READ COMMITTED"}, + + {TxOptions: &sql.TxOptions{Isolation: sql.LevelReadUncommitted, ReadOnly: true}, supported: true, expectedIsoLvl: sql.LevelReadUncommitted.String()}, + {TxOptions: &sql.TxOptions{Isolation: sql.LevelReadUncommitted, ReadOnly: false}, supported: true, expectedIsoLvl: sql.LevelReadUncommitted.String()}, + {TxOptions: &sql.TxOptions{Isolation: sql.LevelReadCommitted, ReadOnly: true}, supported: true, expectedIsoLvl: sql.LevelReadCommitted.String()}, + {TxOptions: &sql.TxOptions{Isolation: sql.LevelReadCommitted, ReadOnly: false}, supported: true, expectedIsoLvl: sql.LevelReadCommitted.String()}, + {TxOptions: &sql.TxOptions{Isolation: sql.LevelRepeatableRead, ReadOnly: true}, supported: true, expectedIsoLvl: sql.LevelRepeatableRead.String()}, + {TxOptions: &sql.TxOptions{Isolation: sql.LevelRepeatableRead, ReadOnly: false}, supported: true, expectedIsoLvl: sql.LevelRepeatableRead.String()}, + {TxOptions: &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: true}, supported: true, expectedIsoLvl: sql.LevelSerializable.String()}, + {TxOptions: &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: false}, supported: true, expectedIsoLvl: sql.LevelSerializable.String()}, + // unsupported + {TxOptions: &sql.TxOptions{Isolation: sql.LevelLinearizable, ReadOnly: true}, supported: false}, + {TxOptions: &sql.TxOptions{Isolation: sql.LevelLinearizable, ReadOnly: false}, supported: false}, + {TxOptions: &sql.TxOptions{Isolation: sql.LevelSnapshot, ReadOnly: true}, supported: false}, + {TxOptions: &sql.TxOptions{Isolation: sql.LevelSnapshot, ReadOnly: false}, supported: false}, + {TxOptions: &sql.TxOptions{Isolation: sql.LevelWriteCommitted, ReadOnly: true}, supported: false}, + {TxOptions: &sql.TxOptions{Isolation: sql.LevelWriteCommitted, ReadOnly: false}, supported: false}, + } + testIsolationFunc := func(t *testing.T, testCase testCase) { + tx, err := db.BeginTx(context.Background(), testCase.TxOptions) + if !testCase.supported { + require.Error(t, err) + return + } + require.NoError(t, err) + t.Cleanup(func() { + err := tx.Rollback() + require.NoError(t, err) + }) + + var currentLvl string + err = tx.QueryRow("SHOW TRANSACTION ISOLATION LEVEL;").Scan(¤tLvl) + require.NoError(t, err) + expectedIsoLvl := strings.ToUpper(testCase.expectedIsoLvl) + currentIsoLvl := strings.ToUpper(currentLvl) + require.Equal(t, expectedIsoLvl, currentIsoLvl) + + var readOnlyResult string + err = tx.QueryRow("SHOW TRANSACTION_READ_ONLY").Scan(&readOnlyResult) + require.NoError(t, err) + isReadOnly := strings.ToUpper(readOnlyResult) == "ON" + + require.Equal(t, testCase.ReadOnly, isReadOnly) + } + + for i := 0; i < len(testCases); i++ { + testCase := testCases[i] + name := fmt.Sprintf("test isolation level %s read only %t", testCase.Isolation.String(), testCase.ReadOnly) + t.Run(name, func(t *testing.T) { + testIsolationFunc(t, testCase) + }) + } +} + func sqlDB() *sql.DB { db, err := sql.Open("pg", dsn()) if err != nil {