Skip to content

Commit

Permalink
feat: add SQLDB() to get the underlying database/sql.(*DB) from a store
Browse files Browse the repository at this point in the history
  • Loading branch information
alnr committed Sep 11, 2023
1 parent 01ebd5b commit 8ed7a2b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 0 deletions.
4 changes: 4 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ type dB struct {
*sqlx.DB
}

func (db *dB) SQLDB() *sql.DB {
return db.DB.DB
}

func (db *dB) TransactionContext(ctx context.Context) (*Tx, error) {
return newTX(ctx, db, nil)
}
Expand Down
2 changes: 2 additions & 0 deletions store.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ type store interface {
Commit() error
Close() error

SQLDB() *sql.DB

// Context versions to wrap with contextStore
SelectContext(context.Context, interface{}, string, ...interface{}) error
GetContext(context.Context, interface{}, string, ...interface{}) error
Expand Down
6 changes: 6 additions & 0 deletions tx.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ func init() {
type Tx struct {
ID int
*sqlx.Tx
db *sql.DB
}

func newTX(ctx context.Context, db *dB, opts *sql.TxOptions) (*Tx, error) {
t := &Tx{
ID: rand.Int(),
db: db.SQLDB(),
}
tx, err := db.BeginTxx(ctx, opts)
t.Tx = tx
Expand All @@ -32,6 +34,10 @@ func newTX(ctx context.Context, db *dB, opts *sql.TxOptions) (*Tx, error) {
return t, nil
}

func (tx *Tx) SQLDB() *sql.DB {
return tx.db
}

// TransactionContext simply returns the current transaction,
// this is defined so it implements the `Store` interface.
func (tx *Tx) TransactionContext(ctx context.Context) (*Tx, error) {
Expand Down

0 comments on commit 8ed7a2b

Please sign in to comment.