From 682bdcabc4108a3426db16d76ac600cec167b9e0 Mon Sep 17 00:00:00 2001 From: Ashok Ankam Date: Wed, 18 Dec 2024 14:27:54 -0500 Subject: [PATCH 1/7] promoting changes to add support for read only db instance --- gorm/transaction.go | 52 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 4 deletions(-) diff --git a/gorm/transaction.go b/gorm/transaction.go index e03943f8..b8ed9a2f 100644 --- a/gorm/transaction.go +++ b/gorm/transaction.go @@ -7,6 +7,7 @@ import ( "reflect" "sync" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/infobloxopen/atlas-app-toolkit/rpc/errdetails" "github.com/jinzhu/gorm" "google.golang.org/grpc" @@ -17,15 +18,18 @@ import ( // ctxKey is an unexported type for keys defined in this package. // This prevents collisions with keys defined in other packages. type ctxKey int +type readOnlyDBKey int // txnKey is the key for `*Transaction` values in `context.Context`. // It is unexported; clients use NewContext and FromContext // instead of using this key directly. var txnKey ctxKey +var roDBKey readOnlyDBKey var ( ErrCtxTxnMissing = errors.New("Database transaction for request missing in context") ErrCtxTxnNoDB = errors.New("Transaction in context, but DB is nil") + ErrNoReadOnlyDB = errors.New("No read-only DB") ) // NewContext returns a new Context that carries value txn. @@ -39,6 +43,25 @@ func FromContext(ctx context.Context) (txn *Transaction, ok bool) { return } +// GetReadOnlyDB returns the read only db instance stored in the ctx if there is no txn in use with read/write database +func GetReadOnlyDB(ctx context.Context) (*gorm.DB, error) { + logger := ctxlogrus.Extract(ctx) + txn, ok := FromContext(ctx) + if !ok { + return nil, ErrCtxTxnMissing + } + if txn.current != nil { + logger.Warnf("GetReadOnlyDB: Txn already initialized with read/write DB") + return txn.current, nil + } + dbRO, ok := ctx.Value(roDBKey).(*gorm.DB) + if !ok { + return nil, ErrNoReadOnlyDB + } + txn.readOnly = true + return dbRO, nil +} + // Transaction serves as a wrapper around `*gorm.DB` instance. // It works as a singleton to prevent an application of creating more than one // transaction instance per incoming request. @@ -46,6 +69,7 @@ type Transaction struct { mu sync.Mutex parent *gorm.DB current *gorm.DB + readOnly bool afterCommitHook []func(context.Context) } @@ -57,15 +81,29 @@ func (t *Transaction) AddAfterCommitHook(hooks ...func(context.Context)) { t.afterCommitHook = append(t.afterCommitHook, hooks...) } -// BeginFromContext will extract transaction wrapper from context and start new transaction. +// GetReadOnlyDBInstance returns the read only database instance stored in the ctx +func GetReadOnlyDBInstance(ctx context.Context) (*gorm.DB, error) { + dbRO, ok := ctx.Value(roDBKey).(*gorm.DB) + if !ok { + return nil, ErrNoReadOnlyDB + } + return dbRO, nil +} + +// BeginFromContext will extract transaction wrapper from context and start new transaction if transaction is not set to read only otherwise it will return read only database instance // As result new instance of `*gorm.DB` will be returned. // Error will be returned in case either transaction or db connection info is missing in context. // Gorm specific error can be checked by `*gorm.DB.Error`. func BeginFromContext(ctx context.Context) (*gorm.DB, error) { + logger := ctxlogrus.Extract(ctx) txn, ok := FromContext(ctx) if !ok { return nil, ErrCtxTxnMissing } + if txn.readOnly == true { + logger.Warnf("BeginFromContext: Read Only DB instance already in use!") + return GetReadOnlyDBInstance(ctx) + } if txn.parent == nil { return nil, ErrCtxTxnNoDB } @@ -173,12 +211,12 @@ func (t *Transaction) Commit(ctx context.Context) error { // Client is responsible to call `txn.Begin()` to open transaction. // If call of grpc.UnaryHandler returns with an error the transaction // is aborted, otherwise committed. -func UnaryServerInterceptor(db *gorm.DB) grpc.UnaryServerInterceptor { +func UnaryServerInterceptor(db *gorm.DB, readOnlyDB ...*gorm.DB) grpc.UnaryServerInterceptor { txn := &Transaction{parent: db} - return UnaryServerInterceptorTxn(txn) + return UnaryServerInterceptorTxn(txn, readOnlyDB...) } -func UnaryServerInterceptorTxn(txn *Transaction) grpc.UnaryServerInterceptor { +func UnaryServerInterceptorTxn(txn *Transaction, readOnlyDB ...*gorm.DB) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { // Deep copy is necessary as a tansaction should be created per request. txn := &Transaction{parent: txn.parent, afterCommitHook: txn.afterCommitHook} @@ -220,6 +258,12 @@ func UnaryServerInterceptorTxn(txn *Transaction) grpc.UnaryServerInterceptor { }() ctx = NewContext(ctx, txn) + if len(readOnlyDB) > 0 { + dbRO := readOnlyDB[0] + if dbRO != nil { + ctx = context.WithValue(ctx, roDBKey, dbRO) + } + } resp, err = handler(ctx, req) return resp, err From 8f6932c63ffa2722189f9ab5f3b3cf066be8431d Mon Sep 17 00:00:00 2001 From: Ashok Ankam Date: Sat, 21 Dec 2024 17:45:49 -0500 Subject: [PATCH 2/7] incorporated review comments --- gorm/transaction.go | 88 +++++++++++++++++++++++++++------------- gorm/transaction_test.go | 69 +++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 29 deletions(-) diff --git a/gorm/transaction.go b/gorm/transaction.go index b8ed9a2f..472ea3e8 100644 --- a/gorm/transaction.go +++ b/gorm/transaction.go @@ -7,9 +7,9 @@ import ( "reflect" "sync" - "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/infobloxopen/atlas-app-toolkit/rpc/errdetails" "github.com/jinzhu/gorm" + logger "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -18,12 +18,19 @@ import ( // ctxKey is an unexported type for keys defined in this package. // This prevents collisions with keys defined in other packages. type ctxKey int + +// readOnlyDBKey is an unexported type and used to define a key for storing read only db instance in the context. +// This prevents collisions with keys defined in other package type readOnlyDBKey int // txnKey is the key for `*Transaction` values in `context.Context`. // It is unexported; clients use NewContext and FromContext // instead of using this key directly. var txnKey ctxKey + +// roDBKey is the key used for storing read-only db instance in the context. +// It is unexported; clients use BeginFromContext with options to get the read only db instance +// instead of using this key directly. var roDBKey readOnlyDBKey var ( @@ -43,25 +50,6 @@ func FromContext(ctx context.Context) (txn *Transaction, ok bool) { return } -// GetReadOnlyDB returns the read only db instance stored in the ctx if there is no txn in use with read/write database -func GetReadOnlyDB(ctx context.Context) (*gorm.DB, error) { - logger := ctxlogrus.Extract(ctx) - txn, ok := FromContext(ctx) - if !ok { - return nil, ErrCtxTxnMissing - } - if txn.current != nil { - logger.Warnf("GetReadOnlyDB: Txn already initialized with read/write DB") - return txn.current, nil - } - dbRO, ok := ctx.Value(roDBKey).(*gorm.DB) - if !ok { - return nil, ErrNoReadOnlyDB - } - txn.readOnly = true - return dbRO, nil -} - // Transaction serves as a wrapper around `*gorm.DB` instance. // It works as a singleton to prevent an application of creating more than one // transaction instance per incoming request. @@ -73,6 +61,27 @@ type Transaction struct { afterCommitHook []func(context.Context) } +type databaseOptions struct { + readOnlyReplica bool +} + +type DatabaseOption func(*databaseOptions) + +// WithRODB returns clouser to set the readOnlyReplica flag +func WithRODB(readOnlyReplica bool) DatabaseOption { + return func(ops *databaseOptions) { + ops.readOnlyReplica = readOnlyReplica + } +} + +func toDatabaseOptions(options ...DatabaseOption) *databaseOptions { + opts := &databaseOptions{} + for _, op := range options { + op(opts) + } + return opts +} + func NewTransaction(db *gorm.DB) Transaction { return Transaction{parent: db} } @@ -81,28 +90,49 @@ func (t *Transaction) AddAfterCommitHook(hooks ...func(context.Context)) { t.afterCommitHook = append(t.afterCommitHook, hooks...) } -// GetReadOnlyDBInstance returns the read only database instance stored in the ctx -func GetReadOnlyDBInstance(ctx context.Context) (*gorm.DB, error) { +// getReadOnlyDBInstance returns the read only database instance stored in the ctx +func getReadOnlyDBInstance(ctx context.Context) (*gorm.DB, error) { + txn, ok := FromContext(ctx) + if !ok { + return nil, ErrCtxTxnMissing + } dbRO, ok := ctx.Value(roDBKey).(*gorm.DB) if !ok { - return nil, ErrNoReadOnlyDB + logger.Warnf("BeginFromContext: requested: read-only DB, returns: read-write DB, reason: read-only DB not available") + if txn.parent == nil { + return nil, ErrCtxTxnNoDB + } + db := txn.beginWithContext(ctx) + if db.Error != nil { + return nil, db.Error + } + return db, nil } + txn.readOnly = true return dbRO, nil } -// BeginFromContext will extract transaction wrapper from context and start new transaction if transaction is not set to read only otherwise it will return read only database instance +// BeginFromContext will return read only db instance if readOnlyReplica flag is set otherwise it will extract transaction wrapper from context and start new transaction +// If readOnlyReplica flag is set and a txn with read-write db is already in use then it will return a txn from ctx rather than providing a read-only db instance. // As result new instance of `*gorm.DB` will be returned. // Error will be returned in case either transaction or db connection info is missing in context. // Gorm specific error can be checked by `*gorm.DB.Error`. -func BeginFromContext(ctx context.Context) (*gorm.DB, error) { - logger := ctxlogrus.Extract(ctx) +func BeginFromContext(ctx context.Context, options ...DatabaseOption) (*gorm.DB, error) { txn, ok := FromContext(ctx) if !ok { return nil, ErrCtxTxnMissing } - if txn.readOnly == true { - logger.Warnf("BeginFromContext: Read Only DB instance already in use!") - return GetReadOnlyDBInstance(ctx) + opts := toDatabaseOptions(options...) + if opts.readOnlyReplica == true { + if txn.current == nil { + return getReadOnlyDBInstance(ctx) + } else { + logger.Warnf("BeginFromContext: requested: read-only DB, returns: read-write DB, reason: read-write DB txn in use") + return txn.current, nil + } + } else if txn.readOnly == true { + logger.Warnf("BeginFromContext: requested: read-write DB, returns: read-only DB, reason: txn set to read only") + return getReadOnlyDBInstance(ctx) } if txn.parent == nil { return nil, ErrCtxTxnNoDB diff --git a/gorm/transaction_test.go b/gorm/transaction_test.go index c276cb53..8e8b0570 100644 --- a/gorm/transaction_test.go +++ b/gorm/transaction_test.go @@ -329,6 +329,75 @@ func TestContext(t *testing.T) { } } +func beginFromContextWithOptions(ctx context.Context, withOpts bool) (*gorm.DB, error) { + switch withOpts { + case true: + return BeginFromContext(ctx, WithRODB(true)) + case false: + return BeginFromContext(ctx) + } + return nil, nil +} + +func TestBeginFromContextWithOptions(t *testing.T) { + tests := []struct { + desc string + withOpts bool + }{ + { + desc: "begin without options", + withOpts: false, + }, + { + desc: "begin with options", + withOpts: true, + }, + } + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + ctx := context.Background() + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock - %s", err) + } + gdb, err := gorm.Open("postgres", db) + if err != nil { + t.Fatalf("failed to open gorm db - %s", err) + } + dbReadOnly, dbROMock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock - %s", err) + } + dbRO, err := gorm.Open("postgres", dbReadOnly) + if err != nil { + t.Fatalf("failed to open gorm read only db - %s", err) + } + mock.ExpectBegin() + dbROMock.ExpectBegin() + ctxtxn := &Transaction{parent: gdb} + ctx = NewContext(ctx, ctxtxn) + ctx = context.WithValue(ctx, roDBKey, dbRO) + txn1, err := beginFromContextWithOptions(ctx, test.withOpts) + if txn1 == nil { + t.Error("Did not receive a transaction from context") + } + if err != nil { + t.Error("Received an error beginning transaction") + } + // Case: Transaction begin is idempotent + txn2, err := beginFromContextWithOptions(ctx, !test.withOpts) + if txn2 != txn1 { + t.Error("Got a different txn than was opened before") + } + if err != nil { + t.Error("Received an error opening transaction") + } + }) + } +} + func beginFromContext(ctx context.Context, withOpts bool) (*gorm.DB, error) { switch withOpts { case true: From 689e73cea56ebf16fbf60b78461848acf2175c2c Mon Sep 17 00:00:00 2001 From: Ashok Ankam Date: Sun, 22 Dec 2024 13:40:57 -0500 Subject: [PATCH 3/7] incorporated review comments --- gorm/transaction_test.go | 92 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/gorm/transaction_test.go b/gorm/transaction_test.go index 8e8b0570..8bb18215 100644 --- a/gorm/transaction_test.go +++ b/gorm/transaction_test.go @@ -45,6 +45,98 @@ func TestUnaryServerInterceptor_success(t *testing.T) { } } +func TestUnaryServerInterceptor_with_readonlydb(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock - %s", err) + } + dbReadOnly, dbROMock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock - %s", err) + } + mock.ExpectBegin() + mock.ExpectCommit() + dbROMock.ExpectBegin() + + gdb, err := gorm.Open("postgres", db) + if err != nil { + t.Fatalf("failed to open gorm db - %s", err) + } + dbRO, err := gorm.Open("postgres", dbReadOnly) + if err != nil { + t.Fatalf("failed to open gorm db - %s", err) + } + + interceptor := UnaryServerInterceptor(gdb, dbRO) + _, err = interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { + txn, ok := FromContext(ctx) + if !ok { + t.Error("failed to extract transaction from context") + } + readOnlyDB, err := BeginFromContext(ctx, WithRODB(true)) + if err != nil { + t.Errorf("failed to get read only db instance - %s", err) + } + if dbRO != readOnlyDB { + t.Errorf("failed to set read only db instance") + } + return nil, txn.Begin().Error + }) + if err != nil { + t.Errorf("unexpected error - %s", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to manage transaction on success response - %s", err) + } +} + +func TestUnaryServerInterceptorTxn_with_readonlydb(t *testing.T) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock - %s", err) + } + dbReadOnly, dbROMock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock - %s", err) + } + mock.ExpectBegin() + mock.ExpectCommit() + dbROMock.ExpectBegin() + + gdb, err := gorm.Open("postgres", db) + if err != nil { + t.Fatalf("failed to open gorm db - %s", err) + } + dbRO, err := gorm.Open("postgres", dbReadOnly) + if err != nil { + t.Fatalf("failed to open gorm db - %s", err) + } + txn := NewTransaction(gdb) + interceptor := UnaryServerInterceptorTxn(&txn, dbRO) + _, err = interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { + txn, ok := FromContext(ctx) + if !ok { + t.Error("failed to extract transaction from context") + } + readOnlyDB, err := BeginFromContext(ctx, WithRODB(true)) + if err != nil { + t.Errorf("failed to get read only db instance - %s", err) + } + if dbRO != readOnlyDB { + t.Errorf("failed to set read only db instance") + } + return nil, txn.Begin().Error + }) + if err != nil { + t.Errorf("unexpected error - %s", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to manage transaction on success response - %s", err) + } +} + func TestUnaryServerInterceptorTxn_success(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { From 0dc2e3d6819bd95ab5144cf6e97f66b904b70669 Mon Sep 17 00:00:00 2001 From: Ashok Ankam Date: Tue, 7 Jan 2025 01:03:15 -0500 Subject: [PATCH 4/7] incorporated review comments --- gorm/transaction.go | 186 ++++++++++++++------ gorm/transaction_test.go | 364 +++++++++++++++++++++++++++++++++++---- 2 files changed, 459 insertions(+), 91 deletions(-) diff --git a/gorm/transaction.go b/gorm/transaction.go index 472ea3e8..39d2761f 100644 --- a/gorm/transaction.go +++ b/gorm/transaction.go @@ -9,7 +9,6 @@ import ( "github.com/infobloxopen/atlas-app-toolkit/rpc/errdetails" "github.com/jinzhu/gorm" - logger "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -19,24 +18,25 @@ import ( // This prevents collisions with keys defined in other packages. type ctxKey int -// readOnlyDBKey is an unexported type and used to define a key for storing read only db instance in the context. -// This prevents collisions with keys defined in other package -type readOnlyDBKey int +// This is used to define various options to set/get the current database setup in the context +type dbType int // txnKey is the key for `*Transaction` values in `context.Context`. // It is unexported; clients use NewContext and FromContext // instead of using this key directly. var txnKey ctxKey -// roDBKey is the key used for storing read-only db instance in the context. -// It is unexported; clients use BeginFromContext with options to get the read only db instance -// instead of using this key directly. -var roDBKey readOnlyDBKey - var ( - ErrCtxTxnMissing = errors.New("Database transaction for request missing in context") - ErrCtxTxnNoDB = errors.New("Transaction in context, but DB is nil") - ErrNoReadOnlyDB = errors.New("No read-only DB") + ErrCtxTxnMissing = errors.New("Database transaction for request missing in context") + ErrCtxTxnNoDB = errors.New("Transaction in context, but DB is nil") + ErrCtxTxnOptMismatch = errors.New("Transaction in context, but Txn opts are mismatched") + ErrCtxDBOptMismatch = errors.New("Transaction in context, but DB opts are mismatched") +) + +const ( + dbNotSet dbType = iota + dbReadOnly + dbReadWrite ) // NewContext returns a new Context that carries value txn. @@ -56,24 +56,33 @@ func FromContext(ctx context.Context) (txn *Transaction, ok bool) { type Transaction struct { mu sync.Mutex parent *gorm.DB + parentRO *gorm.DB current *gorm.DB - readOnly bool + txOpts *sql.TxOptions + currentDB dbType afterCommitHook []func(context.Context) } type databaseOptions struct { readOnlyReplica bool + txOpts *sql.TxOptions } type DatabaseOption func(*databaseOptions) -// WithRODB returns clouser to set the readOnlyReplica flag +// WithRODB returns closure to set the readOnlyReplica flag func WithRODB(readOnlyReplica bool) DatabaseOption { return func(ops *databaseOptions) { ops.readOnlyReplica = readOnlyReplica } } +// WithTxOptions returns a closure to set the TxOptions +func WithTxOptions(opts *sql.TxOptions) DatabaseOption { + return func(ops *databaseOptions) { + ops.txOpts = opts + } +} func toDatabaseOptions(options ...DatabaseOption) *databaseOptions { opts := &databaseOptions{} for _, op := range options { @@ -90,30 +99,86 @@ func (t *Transaction) AddAfterCommitHook(hooks ...func(context.Context)) { t.afterCommitHook = append(t.afterCommitHook, hooks...) } -// getReadOnlyDBInstance returns the read only database instance stored in the ctx -func getReadOnlyDBInstance(ctx context.Context) (*gorm.DB, error) { - txn, ok := FromContext(ctx) - if !ok { - return nil, ErrCtxTxnMissing - } - dbRO, ok := ctx.Value(roDBKey).(*gorm.DB) - if !ok { - logger.Warnf("BeginFromContext: requested: read-only DB, returns: read-write DB, reason: read-only DB not available") - if txn.parent == nil { - return nil, ErrCtxTxnNoDB +// getReadOnlyDBInstance returns the read only db txn if RO DB available otherwise it returns read/write db txn +func getReadOnlyDBTxn(ctx context.Context, opts *databaseOptions, txn *Transaction) (*gorm.DB, error) { + var db *gorm.DB + if txn.parentRO == nil { + return getReadWriteDBTxn(ctx, opts, txn) + } else { + if opts.txOpts != nil || txn.txOpts != nil { + if opts.txOpts == nil { + return nil, ErrCtxTxnOptMismatch + } else if txn.txOpts == nil { + if opts.txOpts.ReadOnly == false { // For Read Only DB, the Txn should be set as read only before calling this method + return nil, ErrCtxTxnOptMismatch + } + txn.txOpts = &sql.TxOptions{} + txn.txOpts.ReadOnly = opts.txOpts.ReadOnly + txn.txOpts.Isolation = opts.txOpts.Isolation + } else { + if opts.txOpts.Isolation != txn.txOpts.Isolation || opts.txOpts.ReadOnly != txn.txOpts.ReadOnly { + return nil, ErrCtxTxnOptMismatch + } + } + if txn.current != nil { + return txn.current, nil + } + db = txn.beginReadOnlyWithContextAndOptions(ctx, txn.txOpts) + } else { + if txn.current != nil { + return txn.current, nil + } + db = txn.beginReadOnlyWithContext(ctx) } - db := txn.beginWithContext(ctx) if db.Error != nil { return nil, db.Error } + if txn.currentDB == dbNotSet { + txn.currentDB = dbReadOnly + } return db, nil } - txn.readOnly = true - return dbRO, nil } -// BeginFromContext will return read only db instance if readOnlyReplica flag is set otherwise it will extract transaction wrapper from context and start new transaction -// If readOnlyReplica flag is set and a txn with read-write db is already in use then it will return a txn from ctx rather than providing a read-only db instance. +// getReadWriteDBTxn returns the read/write db txn +func getReadWriteDBTxn(ctx context.Context, opts *databaseOptions, txn *Transaction) (*gorm.DB, error) { + var db *gorm.DB + if txn.parent == nil { + return nil, ErrCtxTxnNoDB + } + if opts.txOpts != nil || txn.txOpts != nil { + if opts.txOpts == nil { + return nil, ErrCtxTxnOptMismatch + } else if txn.txOpts == nil { + txn.txOpts = &sql.TxOptions{} + txn.txOpts.Isolation = opts.txOpts.Isolation + txn.txOpts.ReadOnly = opts.txOpts.ReadOnly + } else { + if opts.txOpts.Isolation != txn.txOpts.Isolation || opts.txOpts.ReadOnly != txn.txOpts.ReadOnly { + return nil, ErrCtxTxnOptMismatch + } + } + if txn.current != nil { + return txn.current, nil + } + db = txn.beginWithContextAndOptions(ctx, txn.txOpts) + } else { + if txn.current != nil { + return txn.current, nil + } + db = txn.beginWithContext(ctx) + } + if db.Error != nil { + return nil, db.Error + } + if txn.currentDB == dbNotSet { + txn.currentDB = dbReadWrite + } + return db, nil +} + +// BeginFromContext will return read only db txn if readOnlyReplica flag is set otherwise it will extract transaction wrapper from context and start new transaction +// If readOnlyReplica flag is set and read only db is not available then it will check if a txn with read-write db is already in use then it will return a txn from ctx otherwise it will start a new txn with read/write db and return // As result new instance of `*gorm.DB` will be returned. // Error will be returned in case either transaction or db connection info is missing in context. // Gorm specific error can be checked by `*gorm.DB.Error`. @@ -124,24 +189,15 @@ func BeginFromContext(ctx context.Context, options ...DatabaseOption) (*gorm.DB, } opts := toDatabaseOptions(options...) if opts.readOnlyReplica == true { - if txn.current == nil { - return getReadOnlyDBInstance(ctx) + if txn.currentDB == dbReadWrite && txn.parentRO != nil { //check if currentDB is set to read/write DB in case the RO DB is not available + return nil, ErrCtxDBOptMismatch } else { - logger.Warnf("BeginFromContext: requested: read-only DB, returns: read-write DB, reason: read-write DB txn in use") - return txn.current, nil + return getReadOnlyDBTxn(ctx, opts, txn) } - } else if txn.readOnly == true { - logger.Warnf("BeginFromContext: requested: read-write DB, returns: read-only DB, reason: txn set to read only") - return getReadOnlyDBInstance(ctx) + } else if txn.currentDB == dbReadOnly { + return nil, ErrCtxDBOptMismatch } - if txn.parent == nil { - return nil, ErrCtxTxnNoDB - } - db := txn.beginWithContext(ctx) - if db.Error != nil { - return nil, db.Error - } - return db, nil + return getReadWriteDBTxn(ctx, opts, txn) } // BeginWithOptionsFromContext will extract transaction wrapper from context and start new transaction, @@ -181,12 +237,34 @@ func (t *Transaction) beginWithContext(ctx context.Context) *gorm.DB { return t.current } +func (t *Transaction) beginReadOnlyWithContext(ctx context.Context) *gorm.DB { + t.mu.Lock() + defer t.mu.Unlock() + + if t.current == nil { + t.current = t.parentRO.BeginTx(ctx, nil) + } + + return t.current +} + // BeginWithOptions starts new transaction by calling `*gorm.DB.BeginTx()` // Returns new instance of `*gorm.DB` (error can be checked by `*gorm.DB.Error`) func (t *Transaction) BeginWithOptions(opts *sql.TxOptions) *gorm.DB { return t.beginWithContextAndOptions(context.Background(), opts) } +func (t *Transaction) beginReadOnlyWithContextAndOptions(ctx context.Context, opts *sql.TxOptions) *gorm.DB { + t.mu.Lock() + defer t.mu.Unlock() + + if t.current == nil { + t.current = t.parentRO.BeginTx(ctx, opts) + } + + return t.current +} + func (t *Transaction) beginWithContextAndOptions(ctx context.Context, opts *sql.TxOptions) *gorm.DB { t.mu.Lock() defer t.mu.Unlock() @@ -243,13 +321,19 @@ func (t *Transaction) Commit(ctx context.Context) error { // is aborted, otherwise committed. func UnaryServerInterceptor(db *gorm.DB, readOnlyDB ...*gorm.DB) grpc.UnaryServerInterceptor { txn := &Transaction{parent: db} - return UnaryServerInterceptorTxn(txn, readOnlyDB...) + if len(readOnlyDB) > 0 { + dbRO := readOnlyDB[0] + if dbRO != nil { + txn.parentRO = dbRO + } + } + return UnaryServerInterceptorTxn(txn) } -func UnaryServerInterceptorTxn(txn *Transaction, readOnlyDB ...*gorm.DB) grpc.UnaryServerInterceptor { +func UnaryServerInterceptorTxn(txn *Transaction) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { // Deep copy is necessary as a tansaction should be created per request. - txn := &Transaction{parent: txn.parent, afterCommitHook: txn.afterCommitHook} + txn := &Transaction{parent: txn.parent, parentRO: txn.parentRO, afterCommitHook: txn.afterCommitHook} defer func() { // simple panic handler if perr := recover(); perr != nil { @@ -288,12 +372,6 @@ func UnaryServerInterceptorTxn(txn *Transaction, readOnlyDB ...*gorm.DB) grpc.Un }() ctx = NewContext(ctx, txn) - if len(readOnlyDB) > 0 { - dbRO := readOnlyDB[0] - if dbRO != nil { - ctx = context.WithValue(ctx, roDBKey, dbRO) - } - } resp, err = handler(ctx, req) return resp, err diff --git a/gorm/transaction_test.go b/gorm/transaction_test.go index 8bb18215..a9f75014 100644 --- a/gorm/transaction_test.go +++ b/gorm/transaction_test.go @@ -14,6 +14,14 @@ import ( "google.golang.org/grpc/status" ) +type dbOptions int + +const ( + noOptions dbOptions = 0 + readOnly dbOptions = 1 + readWrite dbOptions = 2 +) + func TestUnaryServerInterceptor_success(t *testing.T) { db, mock, err := sqlmock.New() if err != nil { @@ -50,7 +58,7 @@ func TestUnaryServerInterceptor_with_readonlydb(t *testing.T) { if err != nil { t.Fatalf("failed to create sqlmock - %s", err) } - dbReadOnly, dbROMock, err := sqlmock.New() + readOnlyDB, dbROMock, err := sqlmock.New() if err != nil { t.Fatalf("failed to create sqlmock - %s", err) } @@ -62,7 +70,7 @@ func TestUnaryServerInterceptor_with_readonlydb(t *testing.T) { if err != nil { t.Fatalf("failed to open gorm db - %s", err) } - dbRO, err := gorm.Open("postgres", dbReadOnly) + dbRO, err := gorm.Open("postgres", readOnlyDB) if err != nil { t.Fatalf("failed to open gorm db - %s", err) } @@ -73,11 +81,7 @@ func TestUnaryServerInterceptor_with_readonlydb(t *testing.T) { if !ok { t.Error("failed to extract transaction from context") } - readOnlyDB, err := BeginFromContext(ctx, WithRODB(true)) - if err != nil { - t.Errorf("failed to get read only db instance - %s", err) - } - if dbRO != readOnlyDB { + if dbRO != txn.parentRO { t.Errorf("failed to set read only db instance") } return nil, txn.Begin().Error @@ -96,7 +100,7 @@ func TestUnaryServerInterceptorTxn_with_readonlydb(t *testing.T) { if err != nil { t.Fatalf("failed to create sqlmock - %s", err) } - dbReadOnly, dbROMock, err := sqlmock.New() + readOnlyDB, dbROMock, err := sqlmock.New() if err != nil { t.Fatalf("failed to create sqlmock - %s", err) } @@ -108,22 +112,19 @@ func TestUnaryServerInterceptorTxn_with_readonlydb(t *testing.T) { if err != nil { t.Fatalf("failed to open gorm db - %s", err) } - dbRO, err := gorm.Open("postgres", dbReadOnly) + dbRO, err := gorm.Open("postgres", readOnlyDB) if err != nil { t.Fatalf("failed to open gorm db - %s", err) } txn := NewTransaction(gdb) - interceptor := UnaryServerInterceptorTxn(&txn, dbRO) + txn.parentRO = dbRO + interceptor := UnaryServerInterceptorTxn(&txn) _, err = interceptor(context.Background(), nil, nil, func(ctx context.Context, req interface{}) (interface{}, error) { txn, ok := FromContext(ctx) if !ok { t.Error("failed to extract transaction from context") } - readOnlyDB, err := BeginFromContext(ctx, WithRODB(true)) - if err != nil { - t.Errorf("failed to get read only db instance - %s", err) - } - if dbRO != readOnlyDB { + if dbRO != txn.parentRO { t.Errorf("failed to set read only db instance") } return nil, txn.Begin().Error @@ -421,28 +422,41 @@ func TestContext(t *testing.T) { } } -func beginFromContextWithOptions(ctx context.Context, withOpts bool) (*gorm.DB, error) { +func beginFromContextWithOptions(ctx context.Context, withOpts dbOptions, txOpts *sql.TxOptions) (*gorm.DB, error) { switch withOpts { - case true: - return BeginFromContext(ctx, WithRODB(true)) - case false: - return BeginFromContext(ctx) + case noOptions: + if txOpts == nil { + return BeginFromContext(ctx) + } + return BeginFromContext(ctx, WithTxOptions(txOpts)) + case readOnly: + if txOpts == nil { + return BeginFromContext(ctx, WithRODB(true)) + } + return BeginFromContext(ctx, WithRODB(true), WithTxOptions(txOpts)) + case readWrite: + if txOpts == nil { + return BeginFromContext(ctx, WithRODB(false)) + } + return BeginFromContext(ctx, WithRODB(false), WithTxOptions(txOpts)) } return nil, nil } -func TestBeginFromContextWithOptions(t *testing.T) { +func TestBeginFromContextStartWithNoOptions(t *testing.T) { tests := []struct { desc string - withOpts bool + withOpts dbOptions + txOpts *sql.TxOptions }{ { - desc: "begin without options", - withOpts: false, + desc: "begin without options and without Tx options", + withOpts: noOptions, }, { - desc: "begin with options", - withOpts: true, + desc: "begin without options and with Tx options", + withOpts: noOptions, + txOpts: &sql.TxOptions{}, }, } for _, test := range tests { @@ -468,23 +482,299 @@ func TestBeginFromContextWithOptions(t *testing.T) { } mock.ExpectBegin() dbROMock.ExpectBegin() - ctxtxn := &Transaction{parent: gdb} + ctxtxn := &Transaction{parent: gdb, parentRO: dbRO} ctx = NewContext(ctx, ctxtxn) - ctx = context.WithValue(ctx, roDBKey, dbRO) - txn1, err := beginFromContextWithOptions(ctx, test.withOpts) - if txn1 == nil { - t.Error("Did not receive a transaction from context") + if test.txOpts == nil { + txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn1 == nil { + t.Error("Did not receive a transaction from context") + } + test.withOpts = readOnly + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } + test.withOpts = readWrite + txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn3 == nil { + t.Error("Did not receive a transaction from context") + } + // Case: Transaction begin is idempotent + if txn1 != txn3 { + t.Error("Got a different txn than was opened before") + } + } else { + txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn1 == nil { + t.Error("Did not receive a transaction from context") + } + test.txOpts.ReadOnly = true + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.ReadOnly = false + test.txOpts.Isolation = sql.LevelSerializable + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.Isolation = sql.LevelDefault + test.withOpts = readOnly + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } + test.withOpts = readWrite + txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn3 == nil { + t.Error("Did not receive a transaction from context") + } + // Case: Transaction begin is idempotent + if txn1 != txn3 { + t.Error("Got a different txn than was opened before") + } + test.txOpts.ReadOnly = true + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.ReadOnly = false + test.txOpts.Isolation = sql.LevelSerializable + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } } + }) + } +} + +func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) { + tests := []struct { + desc string + withOpts dbOptions + txOpts *sql.TxOptions + }{ + { + desc: "begin with read only options and without Tx options", + withOpts: readOnly, + }, + { + desc: "begin with read only options and with Tx options", + withOpts: readOnly, + txOpts: &sql.TxOptions{}, + }, + } + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + ctx := context.Background() + + db, mock, err := sqlmock.New() if err != nil { - t.Error("Received an error beginning transaction") + t.Fatalf("failed to create sqlmock - %s", err) } - // Case: Transaction begin is idempotent - txn2, err := beginFromContextWithOptions(ctx, !test.withOpts) - if txn2 != txn1 { - t.Error("Got a different txn than was opened before") + gdb, err := gorm.Open("postgres", db) + if err != nil { + t.Fatalf("failed to open gorm db - %s", err) + } + dbReadOnly, dbROMock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock - %s", err) } + dbRO, err := gorm.Open("postgres", dbReadOnly) if err != nil { - t.Error("Received an error opening transaction") + t.Fatalf("failed to open gorm read only db - %s", err) + } + mock.ExpectBegin() + dbROMock.ExpectBegin() + ctxtxn := &Transaction{parent: gdb, parentRO: dbRO} + ctx = NewContext(ctx, ctxtxn) + if test.txOpts == nil { + txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn1 == nil { + t.Error("Did not receive a transaction from context") + } + test.withOpts = noOptions + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } + test.withOpts = readWrite + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } + } else { + _, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.ReadOnly = true + txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn1 == nil { + t.Error("Did not receive a transaction from context") + } + test.txOpts.ReadOnly = false + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.ReadOnly = true + test.txOpts.Isolation = sql.LevelSerializable + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.Isolation = sql.LevelDefault + test.withOpts = noOptions + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } + test.withOpts = readWrite + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } + } + }) + } +} + +func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) { + tests := []struct { + desc string + withOpts dbOptions + txOpts *sql.TxOptions + }{ + { + desc: "begin with read write options and without Tx options", + withOpts: readWrite, + }, + { + desc: "begin with read write options and with Tx options", + withOpts: readWrite, + txOpts: &sql.TxOptions{}, + }, + } + for _, test := range tests { + test := test + t.Run(test.desc, func(t *testing.T) { + ctx := context.Background() + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock - %s", err) + } + gdb, err := gorm.Open("postgres", db) + if err != nil { + t.Fatalf("failed to open gorm db - %s", err) + } + dbReadOnly, dbROMock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock - %s", err) + } + dbRO, err := gorm.Open("postgres", dbReadOnly) + if err != nil { + t.Fatalf("failed to open gorm read only db - %s", err) + } + mock.ExpectBegin() + dbROMock.ExpectBegin() + ctxtxn := &Transaction{parent: gdb, parentRO: dbRO} + ctx = NewContext(ctx, ctxtxn) + if test.txOpts == nil { + txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn1 == nil { + t.Error("Did not receive a transaction from context") + } + test.withOpts = readOnly + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } + test.withOpts = noOptions + txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn3 == nil { + t.Error("Did not receive a transaction from context") + } + // Case: Transaction begin is idempotent + if txn1 != txn3 { + t.Error("Got a different txn than was opened before") + } + } else { + txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn1 == nil { + t.Error("Did not receive a transaction from context") + } + test.txOpts.ReadOnly = true + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.ReadOnly = false + test.txOpts.Isolation = sql.LevelSerializable + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.Isolation = sql.LevelDefault + test.withOpts = readOnly + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } + test.withOpts = noOptions + txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn3 == nil { + t.Error("Did not receive a transaction from context") + } + // Case: Transaction begin is idempotent + if txn1 != txn3 { + t.Error("Got a different txn than was opened before") + } + test.txOpts.ReadOnly = true + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } + test.txOpts.ReadOnly = false + test.txOpts.Isolation = sql.LevelSerializable + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error TxOptionsMismatch") + } } }) } From 1b2f990fc84f7bd2c6c80787aa1e6c2e91d69b20 Mon Sep 17 00:00:00 2001 From: Ashok Ankam Date: Wed, 8 Jan 2025 04:36:08 -0500 Subject: [PATCH 5/7] incorporated the review comments --- gorm/transaction.go | 144 ++++++++++++++++++--------------------- gorm/transaction_test.go | 107 ++++++++++++++++++++++------- 2 files changed, 148 insertions(+), 103 deletions(-) diff --git a/gorm/transaction.go b/gorm/transaction.go index 39d2761f..9547b562 100644 --- a/gorm/transaction.go +++ b/gorm/transaction.go @@ -18,9 +18,6 @@ import ( // This prevents collisions with keys defined in other packages. type ctxKey int -// This is used to define various options to set/get the current database setup in the context -type dbType int - // txnKey is the key for `*Transaction` values in `context.Context`. // It is unexported; clients use NewContext and FromContext // instead of using this key directly. @@ -33,6 +30,9 @@ var ( ErrCtxDBOptMismatch = errors.New("Transaction in context, but DB opts are mismatched") ) +// This is used to define various options to set/get the current database setup in the context +type dbType int + const ( dbNotSet dbType = iota dbReadOnly @@ -58,14 +58,13 @@ type Transaction struct { parent *gorm.DB parentRO *gorm.DB current *gorm.DB - txOpts *sql.TxOptions - currentDB dbType + currentOpts databaseOptions afterCommitHook []func(context.Context) } type databaseOptions struct { - readOnlyReplica bool - txOpts *sql.TxOptions + database dbType + txOpts *sql.TxOptions } type DatabaseOption func(*databaseOptions) @@ -73,7 +72,11 @@ type DatabaseOption func(*databaseOptions) // WithRODB returns closure to set the readOnlyReplica flag func WithRODB(readOnlyReplica bool) DatabaseOption { return func(ops *databaseOptions) { - ops.readOnlyReplica = readOnlyReplica + if readOnlyReplica == false { + ops.database = dbReadWrite + } else { + ops.database = dbReadOnly + } } } @@ -102,77 +105,66 @@ func (t *Transaction) AddAfterCommitHook(hooks ...func(context.Context)) { // getReadOnlyDBInstance returns the read only db txn if RO DB available otherwise it returns read/write db txn func getReadOnlyDBTxn(ctx context.Context, opts *databaseOptions, txn *Transaction) (*gorm.DB, error) { var db *gorm.DB - if txn.parentRO == nil { + switch { + case txn.parentRO == nil: return getReadWriteDBTxn(ctx, opts, txn) - } else { - if opts.txOpts != nil || txn.txOpts != nil { - if opts.txOpts == nil { + case opts.txOpts != nil && txn.currentOpts.txOpts != nil: + if opts.txOpts.ReadOnly != txn.currentOpts.txOpts.ReadOnly || opts.txOpts.Isolation != txn.currentOpts.txOpts.Isolation { + return nil, ErrCtxTxnOptMismatch + } + default: + // We should error in two cases 1. We should error if read-only DB requested with read-write txn + // 2. If no txn options provided in previous call but provided in subsequent call + if opts.txOpts != nil { + if opts.txOpts.ReadOnly == false || txn.currentOpts.database != dbNotSet { return nil, ErrCtxTxnOptMismatch - } else if txn.txOpts == nil { - if opts.txOpts.ReadOnly == false { // For Read Only DB, the Txn should be set as read only before calling this method - return nil, ErrCtxTxnOptMismatch - } - txn.txOpts = &sql.TxOptions{} - txn.txOpts.ReadOnly = opts.txOpts.ReadOnly - txn.txOpts.Isolation = opts.txOpts.Isolation - } else { - if opts.txOpts.Isolation != txn.txOpts.Isolation || opts.txOpts.ReadOnly != txn.txOpts.ReadOnly { - return nil, ErrCtxTxnOptMismatch - } - } - if txn.current != nil { - return txn.current, nil - } - db = txn.beginReadOnlyWithContextAndOptions(ctx, txn.txOpts) - } else { - if txn.current != nil { - return txn.current, nil } - db = txn.beginReadOnlyWithContext(ctx) - } - if db.Error != nil { - return nil, db.Error - } - if txn.currentDB == dbNotSet { - txn.currentDB = dbReadOnly + txnOpts := *opts.txOpts + txn.currentOpts.txOpts = &txnOpts } - return db, nil } + if txn.current != nil { + return txn.current, nil + } + db = txn.beginReadOnlyWithContextAndOptions(ctx, txn.currentOpts.txOpts) + if db.Error != nil { + return nil, db.Error + } + if txn.currentOpts.database == dbNotSet { + txn.currentOpts.database = dbReadOnly + } + return db, nil } // getReadWriteDBTxn returns the read/write db txn func getReadWriteDBTxn(ctx context.Context, opts *databaseOptions, txn *Transaction) (*gorm.DB, error) { var db *gorm.DB - if txn.parent == nil { + switch { + case txn.parent == nil: return nil, ErrCtxTxnNoDB - } - if opts.txOpts != nil || txn.txOpts != nil { - if opts.txOpts == nil { + case opts.txOpts != nil && txn.currentOpts.txOpts != nil: + if opts.txOpts.ReadOnly != txn.currentOpts.txOpts.ReadOnly || opts.txOpts.Isolation != txn.currentOpts.txOpts.Isolation { return nil, ErrCtxTxnOptMismatch - } else if txn.txOpts == nil { - txn.txOpts = &sql.TxOptions{} - txn.txOpts.Isolation = opts.txOpts.Isolation - txn.txOpts.ReadOnly = opts.txOpts.ReadOnly - } else { - if opts.txOpts.Isolation != txn.txOpts.Isolation || opts.txOpts.ReadOnly != txn.txOpts.ReadOnly { + } + default: + if opts.txOpts != nil { + // We should return error If no txn options provided in previous call but provided in subsequent call + if txn.currentOpts.database != dbNotSet { return nil, ErrCtxTxnOptMismatch } + txnOpts := *opts.txOpts + txn.currentOpts.txOpts = &txnOpts } - if txn.current != nil { - return txn.current, nil - } - db = txn.beginWithContextAndOptions(ctx, txn.txOpts) - } else { - if txn.current != nil { - return txn.current, nil - } - db = txn.beginWithContext(ctx) } + if txn.current != nil { + return txn.current, nil + } + db = txn.beginWithContextAndOptions(ctx, txn.currentOpts.txOpts) if db.Error != nil { return nil, db.Error } - if txn.currentDB == dbNotSet { - txn.currentDB = dbReadWrite + if txn.currentOpts.database == dbNotSet { + txn.currentOpts.database = dbReadWrite } return db, nil } @@ -188,16 +180,24 @@ func BeginFromContext(ctx context.Context, options ...DatabaseOption) (*gorm.DB, return nil, ErrCtxTxnMissing } opts := toDatabaseOptions(options...) - if opts.readOnlyReplica == true { - if txn.currentDB == dbReadWrite && txn.parentRO != nil { //check if currentDB is set to read/write DB in case the RO DB is not available + switch opts.database { + case dbReadOnly: + if txn.currentOpts.database == dbReadWrite && txn.parentRO != nil { return nil, ErrCtxDBOptMismatch - } else { + } + return getReadOnlyDBTxn(ctx, opts, txn) + case dbReadWrite: + if txn.currentOpts.database == dbReadOnly { + return nil, ErrCtxDBOptMismatch + } + return getReadWriteDBTxn(ctx, opts, txn) + default: + // This is the case to handle when no database options provided + if txn.currentOpts.database == dbReadOnly { return getReadOnlyDBTxn(ctx, opts, txn) } - } else if txn.currentDB == dbReadOnly { - return nil, ErrCtxDBOptMismatch + return getReadWriteDBTxn(ctx, opts, txn) } - return getReadWriteDBTxn(ctx, opts, txn) } // BeginWithOptionsFromContext will extract transaction wrapper from context and start new transaction, @@ -237,23 +237,13 @@ func (t *Transaction) beginWithContext(ctx context.Context) *gorm.DB { return t.current } -func (t *Transaction) beginReadOnlyWithContext(ctx context.Context) *gorm.DB { - t.mu.Lock() - defer t.mu.Unlock() - - if t.current == nil { - t.current = t.parentRO.BeginTx(ctx, nil) - } - - return t.current -} - // BeginWithOptions starts new transaction by calling `*gorm.DB.BeginTx()` // Returns new instance of `*gorm.DB` (error can be checked by `*gorm.DB.Error`) func (t *Transaction) BeginWithOptions(opts *sql.TxOptions) *gorm.DB { return t.beginWithContextAndOptions(context.Background(), opts) } +// beginReadOnlyWithContextAndOptions will start a new transaction by calling `*gorm.DB.BeginTx() if no current transaction exist func (t *Transaction) beginReadOnlyWithContextAndOptions(ctx context.Context, opts *sql.TxOptions) *gorm.DB { t.mu.Lock() defer t.mu.Unlock() diff --git a/gorm/transaction_test.go b/gorm/transaction_test.go index a9f75014..9ed65dbc 100644 --- a/gorm/transaction_test.go +++ b/gorm/transaction_test.go @@ -60,11 +60,10 @@ func TestUnaryServerInterceptor_with_readonlydb(t *testing.T) { } readOnlyDB, dbROMock, err := sqlmock.New() if err != nil { - t.Fatalf("failed to create sqlmock - %s", err) + t.Fatalf("failed to create sqlmock for read-only db - %s", err) } mock.ExpectBegin() mock.ExpectCommit() - dbROMock.ExpectBegin() gdb, err := gorm.Open("postgres", db) if err != nil { @@ -72,7 +71,7 @@ func TestUnaryServerInterceptor_with_readonlydb(t *testing.T) { } dbRO, err := gorm.Open("postgres", readOnlyDB) if err != nil { - t.Fatalf("failed to open gorm db - %s", err) + t.Fatalf("failed to open read-only gorm db - %s", err) } interceptor := UnaryServerInterceptor(gdb, dbRO) @@ -82,7 +81,7 @@ func TestUnaryServerInterceptor_with_readonlydb(t *testing.T) { t.Error("failed to extract transaction from context") } if dbRO != txn.parentRO { - t.Errorf("failed to set read only db instance") + t.Errorf("failed to set read-only db") } return nil, txn.Begin().Error }) @@ -93,6 +92,9 @@ func TestUnaryServerInterceptor_with_readonlydb(t *testing.T) { if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("failed to manage transaction on success response - %s", err) } + if err := dbROMock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to manage transaction on success response for read-only db - %s", err) + } } func TestUnaryServerInterceptorTxn_with_readonlydb(t *testing.T) { @@ -102,11 +104,10 @@ func TestUnaryServerInterceptorTxn_with_readonlydb(t *testing.T) { } readOnlyDB, dbROMock, err := sqlmock.New() if err != nil { - t.Fatalf("failed to create sqlmock - %s", err) + t.Fatalf("failed to create sqlmock for read-only db - %s", err) } mock.ExpectBegin() mock.ExpectCommit() - dbROMock.ExpectBegin() gdb, err := gorm.Open("postgres", db) if err != nil { @@ -114,7 +115,7 @@ func TestUnaryServerInterceptorTxn_with_readonlydb(t *testing.T) { } dbRO, err := gorm.Open("postgres", readOnlyDB) if err != nil { - t.Fatalf("failed to open gorm db - %s", err) + t.Fatalf("failed to open read-only gorm db - %s", err) } txn := NewTransaction(gdb) txn.parentRO = dbRO @@ -125,7 +126,7 @@ func TestUnaryServerInterceptorTxn_with_readonlydb(t *testing.T) { t.Error("failed to extract transaction from context") } if dbRO != txn.parentRO { - t.Errorf("failed to set read only db instance") + t.Errorf("failed to set read only db") } return nil, txn.Begin().Error }) @@ -136,6 +137,9 @@ func TestUnaryServerInterceptorTxn_with_readonlydb(t *testing.T) { if err := mock.ExpectationsWereMet(); err != nil { t.Errorf("failed to manage transaction on success response - %s", err) } + if err := dbROMock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to manage transaction on success response for read-only db - %s", err) + } } func TestUnaryServerInterceptorTxn_success(t *testing.T) { @@ -472,13 +476,13 @@ func TestBeginFromContextStartWithNoOptions(t *testing.T) { if err != nil { t.Fatalf("failed to open gorm db - %s", err) } - dbReadOnly, dbROMock, err := sqlmock.New() + readOnlyDB, dbROMock, err := sqlmock.New() if err != nil { - t.Fatalf("failed to create sqlmock - %s", err) + t.Fatalf("failed to create sqlmock for read-only db - %s", err) } - dbRO, err := gorm.Open("postgres", dbReadOnly) + dbRO, err := gorm.Open("postgres", readOnlyDB) if err != nil { - t.Fatalf("failed to open gorm read only db - %s", err) + t.Fatalf("failed to open gorm read-only db - %s", err) } mock.ExpectBegin() dbROMock.ExpectBegin() @@ -509,6 +513,11 @@ func TestBeginFromContextStartWithNoOptions(t *testing.T) { if txn1 != txn3 { t.Error("Got a different txn than was opened before") } + test.txOpts = &sql.TxOptions{} + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error TxnOptionsMismatch") + } } else { txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) if err != nil { @@ -557,6 +566,13 @@ func TestBeginFromContextStartWithNoOptions(t *testing.T) { if err == nil { t.Error("begin transaction should fail with an error TxOptionsMismatch") } + txn4, err := beginFromContextWithOptions(ctx, test.withOpts, nil) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn4 == nil { + t.Error("Did not receive a transaction from context") + } } }) } @@ -591,13 +607,13 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) { if err != nil { t.Fatalf("failed to open gorm db - %s", err) } - dbReadOnly, dbROMock, err := sqlmock.New() + readOnlyDB, dbROMock, err := sqlmock.New() if err != nil { - t.Fatalf("failed to create sqlmock - %s", err) + t.Fatalf("failed to create sqlmock for read-only db - %s", err) } - dbRO, err := gorm.Open("postgres", dbReadOnly) + dbRO, err := gorm.Open("postgres", readOnlyDB) if err != nil { - t.Fatalf("failed to open gorm read only db - %s", err) + t.Fatalf("failed to open read-only gorm db - %s", err) } mock.ExpectBegin() dbROMock.ExpectBegin() @@ -612,14 +628,27 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) { t.Error("Did not receive a transaction from context") } test.withOpts = noOptions + txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn2 == nil { + t.Error("Did not receive a transaction from context") + } + // Case: Transaction begin is idempotent + if txn1 != txn2 { + t.Error("Got a different txn than was opened before") + } + test.withOpts = readWrite _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) if err == nil { t.Error("begin transaction should fail with an error DBOptionsMismatch") } - test.withOpts = readWrite + test.withOpts = noOptions + test.txOpts = &sql.TxOptions{} _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) if err == nil { - t.Error("begin transaction should fail with an error DBOptionsMismatch") + t.Error("begin transaction should fail with an error TxnOptionsMismatch") } } else { _, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) @@ -647,9 +676,22 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) { } test.txOpts.Isolation = sql.LevelDefault test.withOpts = noOptions - _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { - t.Error("begin transaction should fail with an error DBOptionsMismatch") + txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn2 == nil { + t.Error("Did not receive a transaction from context") + } + if txn1 != txn2 { + t.Error("Got a different txn than was opened before") + } + txn3, err := beginFromContextWithOptions(ctx, test.withOpts, nil) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn3 == nil { + t.Error("Did not receive a transaction from context") } test.withOpts = readWrite _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) @@ -692,11 +734,11 @@ func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) { } dbReadOnly, dbROMock, err := sqlmock.New() if err != nil { - t.Fatalf("failed to create sqlmock - %s", err) + t.Fatalf("failed to create sqlmock for read-only db - %s", err) } dbRO, err := gorm.Open("postgres", dbReadOnly) if err != nil { - t.Fatalf("failed to open gorm read only db - %s", err) + t.Fatalf("failed to open read-only gorm db - %s", err) } mock.ExpectBegin() dbROMock.ExpectBegin() @@ -727,6 +769,11 @@ func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) { if txn1 != txn3 { t.Error("Got a different txn than was opened before") } + test.txOpts = &sql.TxOptions{} + _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + if err == nil { + t.Error("begin transaction should fail with an error DBOptionsMismatch") + } } else { txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) if err != nil { @@ -753,15 +800,15 @@ func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) { t.Error("begin transaction should fail with an error DBOptionsMismatch") } test.withOpts = noOptions - txn3, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) + txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) if err != nil { t.Error("Received an error beginning transaction") } - if txn3 == nil { + if txn2 == nil { t.Error("Did not receive a transaction from context") } // Case: Transaction begin is idempotent - if txn1 != txn3 { + if txn1 != txn2 { t.Error("Got a different txn than was opened before") } test.txOpts.ReadOnly = true @@ -775,6 +822,14 @@ func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) { if err == nil { t.Error("begin transaction should fail with an error TxOptionsMismatch") } + txn3, err := beginFromContextWithOptions(ctx, test.withOpts, nil) + if err != nil { + t.Error("Received an error beginning transaction") + } + if txn3 == nil { + t.Error("Did not receive a transaction from context") + } + } }) } From a0621ec96c4f7609b404611cc5c4cc1ba2338af9 Mon Sep 17 00:00:00 2001 From: Ashok Ankam Date: Wed, 8 Jan 2025 13:25:01 -0500 Subject: [PATCH 6/7] incorporated the review comments --- gorm/transaction.go | 30 +++++++-------- gorm/transaction_test.go | 83 ++++++++++++++++++++++++++-------------- 2 files changed, 67 insertions(+), 46 deletions(-) diff --git a/gorm/transaction.go b/gorm/transaction.go index 9547b562..cf471be7 100644 --- a/gorm/transaction.go +++ b/gorm/transaction.go @@ -109,19 +109,17 @@ func getReadOnlyDBTxn(ctx context.Context, opts *databaseOptions, txn *Transacti case txn.parentRO == nil: return getReadWriteDBTxn(ctx, opts, txn) case opts.txOpts != nil && txn.currentOpts.txOpts != nil: - if opts.txOpts.ReadOnly != txn.currentOpts.txOpts.ReadOnly || opts.txOpts.Isolation != txn.currentOpts.txOpts.Isolation { + if *opts.txOpts != *txn.currentOpts.txOpts { return nil, ErrCtxTxnOptMismatch } - default: + case opts.txOpts != nil: // We should error in two cases 1. We should error if read-only DB requested with read-write txn // 2. If no txn options provided in previous call but provided in subsequent call - if opts.txOpts != nil { - if opts.txOpts.ReadOnly == false || txn.currentOpts.database != dbNotSet { - return nil, ErrCtxTxnOptMismatch - } - txnOpts := *opts.txOpts - txn.currentOpts.txOpts = &txnOpts + if opts.txOpts.ReadOnly == false || txn.currentOpts.database != dbNotSet { + return nil, ErrCtxTxnOptMismatch } + txnOpts := *opts.txOpts + txn.currentOpts.txOpts = &txnOpts } if txn.current != nil { return txn.current, nil @@ -143,18 +141,16 @@ func getReadWriteDBTxn(ctx context.Context, opts *databaseOptions, txn *Transact case txn.parent == nil: return nil, ErrCtxTxnNoDB case opts.txOpts != nil && txn.currentOpts.txOpts != nil: - if opts.txOpts.ReadOnly != txn.currentOpts.txOpts.ReadOnly || opts.txOpts.Isolation != txn.currentOpts.txOpts.Isolation { + if *opts.txOpts != *txn.currentOpts.txOpts { return nil, ErrCtxTxnOptMismatch } - default: - if opts.txOpts != nil { - // We should return error If no txn options provided in previous call but provided in subsequent call - if txn.currentOpts.database != dbNotSet { - return nil, ErrCtxTxnOptMismatch - } - txnOpts := *opts.txOpts - txn.currentOpts.txOpts = &txnOpts + case opts.txOpts != nil: + // We should return error If no txn options provided in previous call but provided in subsequent call + if txn.currentOpts.database != dbNotSet { + return nil, ErrCtxTxnOptMismatch } + txnOpts := *opts.txOpts + txn.currentOpts.txOpts = &txnOpts } if txn.current != nil { return txn.current, nil diff --git a/gorm/transaction_test.go b/gorm/transaction_test.go index 9ed65dbc..36771481 100644 --- a/gorm/transaction_test.go +++ b/gorm/transaction_test.go @@ -476,7 +476,7 @@ func TestBeginFromContextStartWithNoOptions(t *testing.T) { if err != nil { t.Fatalf("failed to open gorm db - %s", err) } - readOnlyDB, dbROMock, err := sqlmock.New() + readOnlyDB, _, err := sqlmock.New() if err != nil { t.Fatalf("failed to create sqlmock for read-only db - %s", err) } @@ -485,7 +485,6 @@ func TestBeginFromContextStartWithNoOptions(t *testing.T) { t.Fatalf("failed to open gorm read-only db - %s", err) } mock.ExpectBegin() - dbROMock.ExpectBegin() ctxtxn := &Transaction{parent: gdb, parentRO: dbRO} ctx = NewContext(ctx, ctxtxn) if test.txOpts == nil { @@ -496,9 +495,12 @@ func TestBeginFromContextStartWithNoOptions(t *testing.T) { if txn1 == nil { t.Error("Did not receive a transaction from context") } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to begin transaction for read-write db - %s", err) + } test.withOpts = readOnly _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxDBOptMismatch { t.Error("begin transaction should fail with an error DBOptionsMismatch") } test.withOpts = readWrite @@ -515,7 +517,7 @@ func TestBeginFromContextStartWithNoOptions(t *testing.T) { } test.txOpts = &sql.TxOptions{} _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxTxnOptMismatch { t.Error("begin transaction should fail with an error TxnOptionsMismatch") } } else { @@ -526,21 +528,24 @@ func TestBeginFromContextStartWithNoOptions(t *testing.T) { if txn1 == nil { t.Error("Did not receive a transaction from context") } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to begin transaction for read-write db - %s", err) + } test.txOpts.ReadOnly = true _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxTxnOptMismatch { t.Error("begin transaction should fail with an error TxOptionsMismatch") } test.txOpts.ReadOnly = false test.txOpts.Isolation = sql.LevelSerializable _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxTxnOptMismatch { t.Error("begin transaction should fail with an error TxOptionsMismatch") } test.txOpts.Isolation = sql.LevelDefault test.withOpts = readOnly _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxDBOptMismatch { t.Error("begin transaction should fail with an error DBOptionsMismatch") } test.withOpts = readWrite @@ -557,13 +562,13 @@ func TestBeginFromContextStartWithNoOptions(t *testing.T) { } test.txOpts.ReadOnly = true _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxTxnOptMismatch { t.Error("begin transaction should fail with an error TxOptionsMismatch") } test.txOpts.ReadOnly = false test.txOpts.Isolation = sql.LevelSerializable _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxTxnOptMismatch { t.Error("begin transaction should fail with an error TxOptionsMismatch") } txn4, err := beginFromContextWithOptions(ctx, test.withOpts, nil) @@ -573,8 +578,13 @@ func TestBeginFromContextStartWithNoOptions(t *testing.T) { if txn4 == nil { t.Error("Did not receive a transaction from context") } + // Case: Transaction begin is idempotent + if txn1 != txn4 { + t.Error("Got a different txn than was opened before") + } } }) + } } @@ -599,7 +609,7 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) { t.Run(test.desc, func(t *testing.T) { ctx := context.Background() - db, mock, err := sqlmock.New() + db, _, err := sqlmock.New() if err != nil { t.Fatalf("failed to create sqlmock - %s", err) } @@ -615,7 +625,6 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) { if err != nil { t.Fatalf("failed to open read-only gorm db - %s", err) } - mock.ExpectBegin() dbROMock.ExpectBegin() ctxtxn := &Transaction{parent: gdb, parentRO: dbRO} ctx = NewContext(ctx, ctxtxn) @@ -627,6 +636,9 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) { if txn1 == nil { t.Error("Did not receive a transaction from context") } + if err := dbROMock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to begin transaction for read-only db - %s", err) + } test.withOpts = noOptions txn2, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) if err != nil { @@ -641,18 +653,18 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) { } test.withOpts = readWrite _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxDBOptMismatch { t.Error("begin transaction should fail with an error DBOptionsMismatch") } test.withOpts = noOptions test.txOpts = &sql.TxOptions{} _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxTxnOptMismatch { t.Error("begin transaction should fail with an error TxnOptionsMismatch") } } else { _, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxTxnOptMismatch { t.Error("begin transaction should fail with an error TxOptionsMismatch") } test.txOpts.ReadOnly = true @@ -663,15 +675,18 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) { if txn1 == nil { t.Error("Did not receive a transaction from context") } + if err := dbROMock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to begin transaction for read-only db - %s", err) + } test.txOpts.ReadOnly = false _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxTxnOptMismatch { t.Error("begin transaction should fail with an error TxOptionsMismatch") } test.txOpts.ReadOnly = true test.txOpts.Isolation = sql.LevelSerializable _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxTxnOptMismatch { t.Error("begin transaction should fail with an error TxOptionsMismatch") } test.txOpts.Isolation = sql.LevelDefault @@ -693,9 +708,12 @@ func TestBeginFromContextStartWithReadOnlyOptions(t *testing.T) { if txn3 == nil { t.Error("Did not receive a transaction from context") } + if txn1 != txn3 { + t.Error("Got a different txn than was opened before") + } test.withOpts = readWrite _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxDBOptMismatch { t.Error("begin transaction should fail with an error DBOptionsMismatch") } } @@ -732,16 +750,15 @@ func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) { if err != nil { t.Fatalf("failed to open gorm db - %s", err) } - dbReadOnly, dbROMock, err := sqlmock.New() + readOnlyDB, _, err := sqlmock.New() if err != nil { t.Fatalf("failed to create sqlmock for read-only db - %s", err) } - dbRO, err := gorm.Open("postgres", dbReadOnly) + dbRO, err := gorm.Open("postgres", readOnlyDB) if err != nil { t.Fatalf("failed to open read-only gorm db - %s", err) } mock.ExpectBegin() - dbROMock.ExpectBegin() ctxtxn := &Transaction{parent: gdb, parentRO: dbRO} ctx = NewContext(ctx, ctxtxn) if test.txOpts == nil { @@ -752,9 +769,12 @@ func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) { if txn1 == nil { t.Error("Did not receive a transaction from context") } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to begin transaction for read-write db - %s", err) + } test.withOpts = readOnly _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxDBOptMismatch { t.Error("begin transaction should fail with an error DBOptionsMismatch") } test.withOpts = noOptions @@ -771,8 +791,8 @@ func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) { } test.txOpts = &sql.TxOptions{} _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { - t.Error("begin transaction should fail with an error DBOptionsMismatch") + if err != ErrCtxTxnOptMismatch { + t.Error("begin transaction should fail with an error TxOptionsMismatch") } } else { txn1, err := beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) @@ -782,21 +802,24 @@ func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) { if txn1 == nil { t.Error("Did not receive a transaction from context") } + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("failed to begin transaction for read-write db - %s", err) + } test.txOpts.ReadOnly = true _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxTxnOptMismatch { t.Error("begin transaction should fail with an error TxOptionsMismatch") } test.txOpts.ReadOnly = false test.txOpts.Isolation = sql.LevelSerializable _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxTxnOptMismatch { t.Error("begin transaction should fail with an error TxOptionsMismatch") } test.txOpts.Isolation = sql.LevelDefault test.withOpts = readOnly _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxDBOptMismatch { t.Error("begin transaction should fail with an error DBOptionsMismatch") } test.withOpts = noOptions @@ -813,13 +836,13 @@ func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) { } test.txOpts.ReadOnly = true _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxTxnOptMismatch { t.Error("begin transaction should fail with an error TxOptionsMismatch") } test.txOpts.ReadOnly = false test.txOpts.Isolation = sql.LevelSerializable _, err = beginFromContextWithOptions(ctx, test.withOpts, test.txOpts) - if err == nil { + if err != ErrCtxTxnOptMismatch { t.Error("begin transaction should fail with an error TxOptionsMismatch") } txn3, err := beginFromContextWithOptions(ctx, test.withOpts, nil) @@ -829,7 +852,9 @@ func TestBeginFromContextStartWithReadWriteOptions(t *testing.T) { if txn3 == nil { t.Error("Did not receive a transaction from context") } - + if txn1 != txn3 { + t.Error("Got a different txn than was opened before") + } } }) } From f274246b6f7409fa1d579c6c5ac4d41a91efd858 Mon Sep 17 00:00:00 2001 From: Ashok Ankam Date: Wed, 8 Jan 2025 13:42:38 -0500 Subject: [PATCH 7/7] fixed the lint issues --- gorm/transaction.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gorm/transaction.go b/gorm/transaction.go index cf471be7..d10710f8 100644 --- a/gorm/transaction.go +++ b/gorm/transaction.go @@ -72,7 +72,7 @@ type DatabaseOption func(*databaseOptions) // WithRODB returns closure to set the readOnlyReplica flag func WithRODB(readOnlyReplica bool) DatabaseOption { return func(ops *databaseOptions) { - if readOnlyReplica == false { + if !readOnlyReplica { ops.database = dbReadWrite } else { ops.database = dbReadOnly @@ -115,7 +115,7 @@ func getReadOnlyDBTxn(ctx context.Context, opts *databaseOptions, txn *Transacti case opts.txOpts != nil: // We should error in two cases 1. We should error if read-only DB requested with read-write txn // 2. If no txn options provided in previous call but provided in subsequent call - if opts.txOpts.ReadOnly == false || txn.currentOpts.database != dbNotSet { + if !opts.txOpts.ReadOnly || txn.currentOpts.database != dbNotSet { return nil, ErrCtxTxnOptMismatch } txnOpts := *opts.txOpts