diff --git a/db.go b/db.go index a19f4cc3e..37ced9567 100644 --- a/db.go +++ b/db.go @@ -40,29 +40,36 @@ func WithReadOnlyReplica(replica *sql.DB) DBOption { } type DB struct { + // Must be a pointer so we copy the state, not the state fields. + *noCopyState + + queryHooks []QueryHook + + fmter schema.Formatter + stats DBStats +} + +type noCopyState struct { *sql.DB + dialect schema.Dialect replicas []*sql.DB healthyReplicas atomic.Pointer[[]*sql.DB] nextReplica atomic.Int64 - dialect schema.Dialect - queryHooks []QueryHook - - fmter schema.Formatter flags internal.Flag closed atomic.Bool - - stats DBStats } func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB { dialect.Init(sqldb) db := &DB{ - DB: sqldb, - dialect: dialect, - fmter: schema.NewFormatter(dialect), + noCopyState: &noCopyState{ + DB: sqldb, + dialect: dialect, + }, + fmter: schema.NewFormatter(dialect), } for _, opt := range opts { diff --git a/query_base.go b/query_base.go index 70bedb999..27f557426 100644 --- a/query_base.go +++ b/query_base.go @@ -147,10 +147,8 @@ func (q *baseQuery) GetTableName() string { } for _, wq := range q.with { - if v, ok := wq.query.(Query); ok { - if model := v.GetModel(); model != nil { - return v.GetTableName() - } + if model := wq.query.GetModel(); model != nil { + return wq.query.GetTableName() } } diff --git a/query_select.go b/query_select.go index 70be52e27..95e04f455 100644 --- a/query_select.go +++ b/query_select.go @@ -748,7 +748,7 @@ func (q *SelectQuery) Rows(ctx context.Context) (*sql.Rows, error) { query := internal.String(queryBytes) ctx, event := q.db.beforeQuery(ctx, q, query, nil, query, q.model) - rows, err := q.conn.QueryContext(ctx, query) + rows, err := q.resolveConn(q).QueryContext(ctx, query) q.db.afterQuery(ctx, event, nil, err) return rows, err } @@ -876,7 +876,7 @@ func (q *SelectQuery) Count(ctx context.Context) (int, error) { ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) var num int - err = q.conn.QueryRowContext(ctx, query).Scan(&num) + err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&num) q.db.afterQuery(ctx, event, nil, err) @@ -894,13 +894,15 @@ func (q *SelectQuery) ScanAndCount(ctx context.Context, dest ...interface{}) (in return int(n), nil } } - if _, ok := q.conn.(*DB); ok { - return q.scanAndCountConc(ctx, dest...) + if q.conn == nil { + return q.scanAndCountConcurrently(ctx, dest...) } return q.scanAndCountSeq(ctx, dest...) } -func (q *SelectQuery) scanAndCountConc(ctx context.Context, dest ...interface{}) (int, error) { +func (q *SelectQuery) scanAndCountConcurrently( + ctx context.Context, dest ...interface{}, +) (int, error) { var count int var wg sync.WaitGroup var mu sync.Mutex @@ -978,7 +980,7 @@ func (q *SelectQuery) selectExists(ctx context.Context) (bool, error) { ctx, event := q.db.beforeQuery(ctx, qq, query, nil, query, q.model) var exists bool - err = q.conn.QueryRowContext(ctx, query).Scan(&exists) + err = q.resolveConn(q).QueryRowContext(ctx, query).Scan(&exists) q.db.afterQuery(ctx, event, nil, err)