From a9fe025ef53b419ea5d6406f5f79a2bc7e52d71a Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Fri, 19 Mar 2021 15:54:32 +0800 Subject: [PATCH] Add GetDBConnector interface --- gorm.go | 4 ++-- interfaces.go | 4 ++++ prepare_stmt.go | 12 ++++++++++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/gorm.go b/gorm.go index 88212e942..9323c46db 100644 --- a/gorm.go +++ b/gorm.go @@ -331,8 +331,8 @@ func (db *DB) AddError(err error) error { func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool - if stmtDB, ok := connPool.(*PreparedStmtDB); ok { - connPool = stmtDB.ConnPool + if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() } if sqldb, ok := connPool.(*sql.DB); ok { diff --git a/interfaces.go b/interfaces.go index e933952bb..44b2fcedb 100644 --- a/interfaces.go +++ b/interfaces.go @@ -57,3 +57,7 @@ type TxCommitter interface { type Valuer interface { GormValue(context.Context, *DB) clause.Expr } + +type GetDBConnector interface { + GetDBConn() (*sql.DB, error) +} diff --git a/prepare_stmt.go b/prepare_stmt.go index 78a8adb48..bc7ef180f 100644 --- a/prepare_stmt.go +++ b/prepare_stmt.go @@ -18,6 +18,18 @@ type PreparedStmtDB struct { ConnPool } +func (db *PreparedStmtDB) GetDB() (*sql.DB, error) { + if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + + if sqldb, ok := db.ConnPool.(*sql.DB); ok { + return sqldb, nil + } + + return nil, ErrInvaildDB +} + func (db *PreparedStmtDB) Close() { db.Mux.Lock() for _, query := range db.PreparedSQL {