Skip to content

Commit

Permalink
Add GetDBConnector interface
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Mar 19, 2021
1 parent 220349c commit a9fe025
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
4 changes: 2 additions & 2 deletions gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,8 @@ func (db *DB) AddError(err error) error {
func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool

if stmtDB, ok := connPool.(*PreparedStmtDB); ok {
connPool = stmtDB.ConnPool
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}

if sqldb, ok := connPool.(*sql.DB); ok {
Expand Down
4 changes: 4 additions & 0 deletions interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,7 @@ type TxCommitter interface {
type Valuer interface {
GormValue(context.Context, *DB) clause.Expr
}

type GetDBConnector interface {
GetDBConn() (*sql.DB, error)
}
12 changes: 12 additions & 0 deletions prepare_stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,18 @@ type PreparedStmtDB struct {
ConnPool
}

func (db *PreparedStmtDB) GetDB() (*sql.DB, error) {
if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
return dbConnector.GetDBConn()
}

if sqldb, ok := db.ConnPool.(*sql.DB); ok {
return sqldb, nil
}

return nil, ErrInvaildDB
}

func (db *PreparedStmtDB) Close() {
db.Mux.Lock()
for _, query := range db.PreparedSQL {
Expand Down

0 comments on commit a9fe025

Please sign in to comment.