diff --git a/gorm/transaction.go b/gorm/transaction.go index e03943f8..d10710f8 100644 --- a/gorm/transaction.go +++ b/gorm/transaction.go @@ -24,8 +24,19 @@ type ctxKey int var txnKey ctxKey var ( - ErrCtxTxnMissing = errors.New("Database transaction for request missing in context") - ErrCtxTxnNoDB = errors.New("Transaction in context, but DB is nil") + ErrCtxTxnMissing = errors.New("Database transaction for request missing in context") + ErrCtxTxnNoDB = errors.New("Transaction in context, but DB is nil") + ErrCtxTxnOptMismatch = errors.New("Transaction in context, but Txn opts are mismatched") + ErrCtxDBOptMismatch = errors.New("Transaction in context, but DB opts are mismatched") +) + +// This is used to define various options to set/get the current database setup in the context +type dbType int + +const ( + dbNotSet dbType = iota + dbReadOnly + dbReadWrite ) // NewContext returns a new Context that carries value txn. @@ -45,10 +56,44 @@ func FromContext(ctx context.Context) (txn *Transaction, ok bool) { type Transaction struct { mu sync.Mutex parent *gorm.DB + parentRO *gorm.DB current *gorm.DB + currentOpts databaseOptions afterCommitHook []func(context.Context) } +type databaseOptions struct { + database dbType + txOpts *sql.TxOptions +} + +type DatabaseOption func(*databaseOptions) + +// WithRODB returns closure to set the readOnlyReplica flag +func WithRODB(readOnlyReplica bool) DatabaseOption { + return func(ops *databaseOptions) { + if !readOnlyReplica { + ops.database = dbReadWrite + } else { + ops.database = dbReadOnly + } + } +} + +// WithTxOptions returns a closure to set the TxOptions +func WithTxOptions(opts *sql.TxOptions) DatabaseOption { + return func(ops *databaseOptions) { + ops.txOpts = opts + } +} +func toDatabaseOptions(options ...DatabaseOption) *databaseOptions { + opts := &databaseOptions{} + for _, op := range options { + op(opts) + } + return opts +} + func NewTransaction(db *gorm.DB) Transaction { return Transaction{parent: db} } @@ -57,23 +102,98 @@ func (t *Transaction) AddAfterCommitHook(hooks ...func(context.Context)) { t.afterCommitHook = append(t.afterCommitHook, hooks...) } -// BeginFromContext will extract transaction wrapper from context and start new transaction. +// getReadOnlyDBInstance returns the read only db txn if RO DB available otherwise it returns read/write db txn +func getReadOnlyDBTxn(ctx context.Context, opts *databaseOptions, txn *Transaction) (*gorm.DB, error) { + var db *gorm.DB + switch { + case txn.parentRO == nil: + return getReadWriteDBTxn(ctx, opts, txn) + case opts.txOpts != nil && txn.currentOpts.txOpts != nil: + if *opts.txOpts != *txn.currentOpts.txOpts { + return nil, ErrCtxTxnOptMismatch + } + case opts.txOpts != nil: + // We should error in two cases 1. We should error if read-only DB requested with read-write txn + // 2. If no txn options provided in previous call but provided in subsequent call + if !opts.txOpts.ReadOnly || txn.currentOpts.database != dbNotSet { + return nil, ErrCtxTxnOptMismatch + } + txnOpts := *opts.txOpts + txn.currentOpts.txOpts = &txnOpts + } + if txn.current != nil { + return txn.current, nil + } + db = txn.beginReadOnlyWithContextAndOptions(ctx, txn.currentOpts.txOpts) + if db.Error != nil { + return nil, db.Error + } + if txn.currentOpts.database == dbNotSet { + txn.currentOpts.database = dbReadOnly + } + return db, nil +} + +// getReadWriteDBTxn returns the read/write db txn +func getReadWriteDBTxn(ctx context.Context, opts *databaseOptions, txn *Transaction) (*gorm.DB, error) { + var db *gorm.DB + switch { + case txn.parent == nil: + return nil, ErrCtxTxnNoDB + case opts.txOpts != nil && txn.currentOpts.txOpts != nil: + if *opts.txOpts != *txn.currentOpts.txOpts { + return nil, ErrCtxTxnOptMismatch + } + case opts.txOpts != nil: + // We should return error If no txn options provided in previous call but provided in subsequent call + if txn.currentOpts.database != dbNotSet { + return nil, ErrCtxTxnOptMismatch + } + txnOpts := *opts.txOpts + txn.currentOpts.txOpts = &txnOpts + } + if txn.current != nil { + return txn.current, nil + } + db = txn.beginWithContextAndOptions(ctx, txn.currentOpts.txOpts) + if db.Error != nil { + return nil, db.Error + } + if txn.currentOpts.database == dbNotSet { + txn.currentOpts.database = dbReadWrite + } + return db, nil +} + +// BeginFromContext will return read only db txn if readOnlyReplica flag is set otherwise it will extract transaction wrapper from context and start new transaction +// If readOnlyReplica flag is set and read only db is not available then it will check if a txn with read-write db is already in use then it will return a txn from ctx otherwise it will start a new txn with read/write db and return // As result new instance of `*gorm.DB` will be returned. // Error will be returned in case either transaction or db connection info is missing in context. // Gorm specific error can be checked by `*gorm.DB.Error`. -func BeginFromContext(ctx context.Context) (*gorm.DB, error) { +func BeginFromContext(ctx context.Context, options ...DatabaseOption) (*gorm.DB, error) { txn, ok := FromContext(ctx) if !ok { return nil, ErrCtxTxnMissing } - if txn.parent == nil { - return nil, ErrCtxTxnNoDB - } - db := txn.beginWithContext(ctx) - if db.Error != nil { - return nil, db.Error + opts := toDatabaseOptions(options...) + switch opts.database { + case dbReadOnly: + if txn.currentOpts.database == dbReadWrite && txn.parentRO != nil { + return nil, ErrCtxDBOptMismatch + } + return getReadOnlyDBTxn(ctx, opts, txn) + case dbReadWrite: + if txn.currentOpts.database == dbReadOnly { + return nil, ErrCtxDBOptMismatch + } + return getReadWriteDBTxn(ctx, opts, txn) + default: + // This is the case to handle when no database options provided + if txn.currentOpts.database == dbReadOnly { + return getReadOnlyDBTxn(ctx, opts, txn) + } + return getReadWriteDBTxn(ctx, opts, txn) } - return db, nil } // BeginWithOptionsFromContext will extract transaction wrapper from context and start new transaction, @@ -119,6 +239,18 @@ func (t *Transaction) BeginWithOptions(opts *sql.TxOptions) *gorm.DB { return t.beginWithContextAndOptions(context.Background(), opts) } +// beginReadOnlyWithContextAndOptions will start a new transaction by calling `*gorm.DB.BeginTx() if no current transaction exist +func (t *Transaction) beginReadOnlyWithContextAndOptions(ctx context.Context, opts *sql.TxOptions) *gorm.DB { + t.mu.Lock() + defer t.mu.Unlock() + + if t.current == nil { + t.current = t.parentRO.BeginTx(ctx, opts) + } + + return t.current +} + func (t *Transaction) beginWithContextAndOptions(ctx context.Context, opts *sql.TxOptions) *gorm.DB { t.mu.Lock() defer t.mu.Unlock() @@ -173,15 +305,21 @@ func (t *Transaction) Commit(ctx context.Context) error { // Client is responsible to call `txn.Begin()` to open transaction. // If call of grpc.UnaryHandler returns with an error the transaction // is aborted, otherwise committed. -func UnaryServerInterceptor(db *gorm.DB) grpc.UnaryServerInterceptor { +func UnaryServerInterceptor(db *gorm.DB, readOnlyDB ...*gorm.DB) grpc.UnaryServerInterceptor { txn := &Transaction{parent: db} + if len(readOnlyDB) > 0 { + dbRO := readOnlyDB[0] + if dbRO != nil { + txn.parentRO = dbRO + } + } return UnaryServerInterceptorTxn(txn) } func UnaryServerInterceptorTxn(txn *Transaction) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { // Deep copy is necessary as a tansaction should be created per request. - txn := &Transaction{parent: txn.parent, afterCommitHook: txn.afterCommitHook} + txn := &Transaction{parent: txn.parent, parentRO: txn.parentRO, afterCommitHook: txn.afterCommitHook} defer func() { // simple panic handler if perr := recover(); perr != nil { diff --git a/gorm/transaction_test.go b/gorm/transaction_test.go index c276cb53..36771481 100644 --- a/gorm/transaction_test.go +++ b/gorm/transaction_test.go @@ -14,6 +14,14 @@ import ( "google.golang.org/grpc/status" ) +type dbOptions int + +const ( + noOptions dbOptions = 0 + readOnly dbOptions = 1 + readWrite dbOptions = 2 +) + func TestUnaryServerInterceptor_success(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { @@ -45,6 +53,95 @@ func TestUnaryServerInterceptor_success(t *testing.T) { } } +func TestUnaryServerInterceptor_with_readonlydb(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock - %s", err) + } + readOnlyDB, dbROMock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock for read-only db - %s", err) + } + mock.ExpectBegin() + mock.ExpectCommit() + + gdb, err := gorm.Open("postgres", db) + if err != nil { + t.Fatalf("failed to open gorm db - %s", err) + } + dbRO, err := gorm.Open("postgres", readOnlyDB) + if err != nil { + t.Fatalf("failed to open read-only gorm db - %s", err) + } + + interceptor := UnaryServerInterceptor(gdb, dbRO) + _, err = interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { + txn, ok := FromContext(ctx) + if !ok { + t.Error("failed to extract transaction from context") + } + if dbRO != txn.parentRO { + t.Errorf("failed to set read-only db") + } + return nil, txn.Begin().Error + }) + if err != nil { + t.Errorf("unexpected error - %s", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to manage transaction on success response - %s", err) + } + if err := dbROMock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to manage transaction on success response for read-only db - %s", err) + } +} + +func TestUnaryServerInterceptorTxn_with_readonlydb(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock - %s", err) + } + readOnlyDB, dbROMock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock for read-only db - %s", err) + } + mock.ExpectBegin() + mock.ExpectCommit() + + gdb, err := gorm.Open("postgres", db) + if err != nil { + t.Fatalf("failed to open gorm db - %s", err) + } + dbRO, err := gorm.Open("postgres", readOnlyDB) + if err != nil { + t.Fatalf("failed to open read-only gorm db - %s", err) + } + txn := NewTransaction(gdb) + txn.parentRO = dbRO + interceptor := UnaryServerInterceptorTxn(&txn) + _, err = interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { + txn, ok := FromContext(ctx) + if !ok { + t.Error("failed to extract transaction from context") + } + if dbRO != txn.parentRO { + t.Errorf("failed to set read only db") + } + return nil, txn.Begin().Error + }) + if err != nil { + t.Errorf("unexpected error - %s", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to manage transaction on success response - %s", err) + } + if err := dbROMock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to manage transaction on success response for read-only db - %s", err) + } +} + func TestUnaryServerInterceptorTxn_success(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { @@ -329,6 +426,440 @@ func TestContext(t *testing.T) { } } +func beginFromContextWithOptions(ctx context.Context, withOpts dbOptions, txOpts *sql.TxOptions) (*gorm.DB, error) { + switch withOpts { + case noOptions: + if txOpts == nil { + return BeginFromContext(ctx) + } + return BeginFromContext(ctx, WithTxOptions(txOpts)) + case readOnly: + if txOpts == nil { + return BeginFromContext(ctx, WithRODB(true)) + } + return BeginFromContext(ctx, WithRODB(true), WithTxOptions(txOpts)) + case readWrite: + if txOpts == nil { + return BeginFromContext(ctx, WithRODB(false)) + } + return BeginFromContext(ctx, WithRODB(false), WithTxOptions(txOpts)) + } + return nil, nil +} + +func TestBeginFromContextStartWithNoOptions(t *testing.T) { + tests := []struct { + desc string + withOpts dbOptions + txOpts *sql.TxOptions + }{ + { + desc: "begin without options and without Tx options", + withOpts: noOptions, + }, + { + desc: "begin without options and with Tx options", + withOpts: noOptions, + txOpts: &sql.TxOptions{}, + }, + } + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + ctx := context.Background() + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock - %s", err) + } + gdb, err := gorm.Open("postgres", db) + if err != nil { + t.Fatalf("failed to open gorm db - %s", err) + } + readOnlyDB, _, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock for read-only db - %s", err) + } + dbRO, err := gorm.Open("postgres", readOnlyDB) + if err != nil { + t.Fatalf("failed to open gorm read-only db - %s", err) + } + mock.ExpectBegin() + ctxtxn := &Transaction{parent: gdb, parentRO: dbRO} + ctx = NewContext(ctx, ctxtxn) + if test.txOpts == nil { + txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn1 == nil { + t.Error("Did not receive a transaction from context") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to begin transaction for read-write db - %s", err) + } + test.withOpts = readOnly + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxDBOptMismatch { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } + test.withOpts = readWrite + txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn3 == nil { + t.Error("Did not receive a transaction from context") + } + // Case: Transaction begin is idempotent + if txn1 != txn3 { + t.Error("Got a different txn than was opened before") + } + test.txOpts = &sql.TxOptions{} + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxnOptionsMismatch") + } + } else { + txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn1 == nil { + t.Error("Did not receive a transaction from context") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to begin transaction for read-write db - %s", err) + } + test.txOpts.ReadOnly = true + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.ReadOnly = false + test.txOpts.Isolation = sql.LevelSerializable + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.Isolation = sql.LevelDefault + test.withOpts = readOnly + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxDBOptMismatch { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } + test.withOpts = readWrite + txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn3 == nil { + t.Error("Did not receive a transaction from context") + } + // Case: Transaction begin is idempotent + if txn1 != txn3 { + t.Error("Got a different txn than was opened before") + } + test.txOpts.ReadOnly = true + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.ReadOnly = false + test.txOpts.Isolation = sql.LevelSerializable + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + txn4, err := beginFromContextWithOptions(ctx, test.withOpts, nil) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn4 == nil { + t.Error("Did not receive a transaction from context") + } + // Case: Transaction begin is idempotent + if txn1 != txn4 { + t.Error("Got a different txn than was opened before") + } + } + }) + + } +} + +func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) { + tests := []struct { + desc string + withOpts dbOptions + txOpts *sql.TxOptions + }{ + { + desc: "begin with read only options and without Tx options", + withOpts: readOnly, + }, + { + desc: "begin with read only options and with Tx options", + withOpts: readOnly, + txOpts: &sql.TxOptions{}, + }, + } + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + ctx := context.Background() + + db, _, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock - %s", err) + } + gdb, err := gorm.Open("postgres", db) + if err != nil { + t.Fatalf("failed to open gorm db - %s", err) + } + readOnlyDB, dbROMock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock for read-only db - %s", err) + } + dbRO, err := gorm.Open("postgres", readOnlyDB) + if err != nil { + t.Fatalf("failed to open read-only gorm db - %s", err) + } + dbROMock.ExpectBegin() + ctxtxn := &Transaction{parent: gdb, parentRO: dbRO} + ctx = NewContext(ctx, ctxtxn) + if test.txOpts == nil { + txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn1 == nil { + t.Error("Did not receive a transaction from context") + } + if err := dbROMock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to begin transaction for read-only db - %s", err) + } + test.withOpts = noOptions + txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn2 == nil { + t.Error("Did not receive a transaction from context") + } + // Case: Transaction begin is idempotent + if txn1 != txn2 { + t.Error("Got a different txn than was opened before") + } + test.withOpts = readWrite + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxDBOptMismatch { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } + test.withOpts = noOptions + test.txOpts = &sql.TxOptions{} + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxnOptionsMismatch") + } + } else { + _, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.ReadOnly = true + txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn1 == nil { + t.Error("Did not receive a transaction from context") + } + if err := dbROMock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to begin transaction for read-only db - %s", err) + } + test.txOpts.ReadOnly = false + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.ReadOnly = true + test.txOpts.Isolation = sql.LevelSerializable + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.Isolation = sql.LevelDefault + test.withOpts = noOptions + txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn2 == nil { + t.Error("Did not receive a transaction from context") + } + if txn1 != txn2 { + t.Error("Got a different txn than was opened before") + } + txn3, err := beginFromContextWithOptions(ctx, test.withOpts, nil) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn3 == nil { + t.Error("Did not receive a transaction from context") + } + if txn1 != txn3 { + t.Error("Got a different txn than was opened before") + } + test.withOpts = readWrite + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxDBOptMismatch { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } + } + }) + } +} + +func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) { + tests := []struct { + desc string + withOpts dbOptions + txOpts *sql.TxOptions + }{ + { + desc: "begin with read write options and without Tx options", + withOpts: readWrite, + }, + { + desc: "begin with read write options and with Tx options", + withOpts: readWrite, + txOpts: &sql.TxOptions{}, + }, + } + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + ctx := context.Background() + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock - %s", err) + } + gdb, err := gorm.Open("postgres", db) + if err != nil { + t.Fatalf("failed to open gorm db - %s", err) + } + readOnlyDB, _, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock for read-only db - %s", err) + } + dbRO, err := gorm.Open("postgres", readOnlyDB) + if err != nil { + t.Fatalf("failed to open read-only gorm db - %s", err) + } + mock.ExpectBegin() + ctxtxn := &Transaction{parent: gdb, parentRO: dbRO} + ctx = NewContext(ctx, ctxtxn) + if test.txOpts == nil { + txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn1 == nil { + t.Error("Did not receive a transaction from context") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to begin transaction for read-write db - %s", err) + } + test.withOpts = readOnly + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxDBOptMismatch { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } + test.withOpts = noOptions + txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn3 == nil { + t.Error("Did not receive a transaction from context") + } + // Case: Transaction begin is idempotent + if txn1 != txn3 { + t.Error("Got a different txn than was opened before") + } + test.txOpts = &sql.TxOptions{} + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + } else { + txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn1 == nil { + t.Error("Did not receive a transaction from context") + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to begin transaction for read-write db - %s", err) + } + test.txOpts.ReadOnly = true + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.ReadOnly = false + test.txOpts.Isolation = sql.LevelSerializable + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.Isolation = sql.LevelDefault + test.withOpts = readOnly + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxDBOptMismatch { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } + test.withOpts = noOptions + txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn2 == nil { + t.Error("Did not receive a transaction from context") + } + // Case: Transaction begin is idempotent + if txn1 != txn2 { + t.Error("Got a different txn than was opened before") + } + test.txOpts.ReadOnly = true + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.ReadOnly = false + test.txOpts.Isolation = sql.LevelSerializable + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + txn3, err := beginFromContextWithOptions(ctx, test.withOpts, nil) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn3 == nil { + t.Error("Did not receive a transaction from context") + } + if txn1 != txn3 { + t.Error("Got a different txn than was opened before") + } + } + }) + } +} + func beginFromContext(ctx context.Context, withOpts bool) (*gorm.DB, error) { switch withOpts { case true: