Skip to content

Commit

Permalink
feat: allow to specify read-only replica for SELECTs
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Dec 4, 2024
1 parent 96baf25 commit cbbe1e9
Show file tree
Hide file tree
Showing 17 changed files with 123 additions and 69 deletions.
64 changes: 61 additions & 3 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"reflect"
"strings"
"sync/atomic"
"time"

"github.com/uptrace/bun/dialect/feature"
"github.com/uptrace/bun/internal"
Expand All @@ -32,15 +33,25 @@ func WithDiscardUnknownColumns() DBOption {
}
}

func WithReadOnlyReplica(replica *sql.DB) DBOption {
return func(db *DB) {
db.replicas = append(db.replicas, replica)
}
}

type DB struct {
*sql.DB

dialect schema.Dialect
replicas []*sql.DB
healthyReplicas atomic.Pointer[[]*sql.DB]
nextReplica atomic.Int64

dialect schema.Dialect
queryHooks []QueryHook

fmter schema.Formatter
flags internal.Flag
fmter schema.Formatter
flags internal.Flag
closed atomic.Bool

stats DBStats
}
Expand All @@ -58,6 +69,10 @@ func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
opt(db)
}

if len(db.replicas) > 0 {
go db.monitorReplicas()
}

return db
}

Expand All @@ -69,6 +84,11 @@ func (db *DB) String() string {
return b.String()
}

func (db *DB) Close() error {
db.closed.Store(true)
return db.DB.Close()
}

func (db *DB) DBStats() DBStats {
return DBStats{
Queries: atomic.LoadUint32(&db.stats.Queries),
Expand Down Expand Up @@ -232,6 +252,44 @@ func (db *DB) HasFeature(feat feature.Feature) bool {
return db.dialect.Features().Has(feat)
}

// healthyReplica returns a random healthy replica.
func (db *DB) healthyReplica() *sql.DB {
replicas := db.loadHealthyReplicas()
if len(replicas) == 0 {
return db.DB
}
if len(replicas) == 1 {
return replicas[0]
}
i := db.nextReplica.Add(1)
return replicas[int(i)%len(replicas)]
}

func (db *DB) loadHealthyReplicas() []*sql.DB {
if ptr := db.healthyReplicas.Load(); ptr != nil {
return *ptr
}
return nil
}

func (db *DB) monitorReplicas() {
for !db.closed.Load() {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()

healthy := make([]*sql.DB, 0, len(db.replicas))

for _, replica := range db.replicas {
if err := replica.PingContext(ctx); err == nil {
healthy = append(healthy, replica)
}
}

db.healthyReplicas.Store(&healthy)
time.Sleep(5 * time.Second)
}
}

//------------------------------------------------------------------------------

func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
Expand Down
2 changes: 0 additions & 2 deletions internal/dbtest/docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
version: '3.9'

services:
mysql8:
image: mysql:8.0
Expand Down
52 changes: 38 additions & 14 deletions query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const (

type withQuery struct {
name string
query schema.QueryAppender
query Query
recursive bool
}

Expand Down Expand Up @@ -114,8 +114,27 @@ func (q *baseQuery) DB() *DB {
return q.db
}

func (q *baseQuery) GetConn() IConn {
return q.conn
func (q *baseQuery) resolveConn(query Query) IConn {
if q.conn != nil {
return q.conn
}
if len(q.db.replicas) == 0 || !isReadOnlyQuery(query) {
return q.db.DB
}
return q.db.healthyReplica()
}

func isReadOnlyQuery(query Query) bool {
sel, ok := query.(*SelectQuery)
if !ok {
return false
}
for _, el := range sel.with {
if !isReadOnlyQuery(el.query) {
return false
}
}
return true
}

func (q *baseQuery) GetModel() Model {
Expand Down Expand Up @@ -249,7 +268,7 @@ func (q *baseQuery) isSoftDelete() bool {

//------------------------------------------------------------------------------

func (q *baseQuery) addWith(name string, query schema.QueryAppender, recursive bool) {
func (q *baseQuery) addWith(name string, query Query, recursive bool) {
q.with = append(q.with, withQuery{
name: name,
query: query,
Expand Down Expand Up @@ -565,28 +584,33 @@ func (q *baseQuery) scan(
hasDest bool,
) (sql.Result, error) {
ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model)
res, err := q._scan(ctx, iquery, query, model, hasDest)
q.db.afterQuery(ctx, event, res, err)
return res, err
}

rows, err := q.conn.QueryContext(ctx, query)
func (q *baseQuery) _scan(
ctx context.Context,
iquery Query,
query string,
model Model,
hasDest bool,
) (sql.Result, error) {
rows, err := q.resolveConn(iquery).QueryContext(ctx, query)
if err != nil {
q.db.afterQuery(ctx, event, nil, err)
return nil, err
}
defer rows.Close()

numRow, err := model.ScanRows(ctx, rows)
if err != nil {
q.db.afterQuery(ctx, event, nil, err)
return nil, err
}

if numRow == 0 && hasDest && isSingleRowModel(model) {
err = sql.ErrNoRows
return nil, sql.ErrNoRows
}

res := driver.RowsAffected(numRow)
q.db.afterQuery(ctx, event, res, err)

return res, err
return driver.RowsAffected(numRow), nil
}

func (q *baseQuery) exec(
Expand All @@ -595,7 +619,7 @@ func (q *baseQuery) exec(
query string,
) (sql.Result, error) {
ctx, event := q.db.beforeQuery(ctx, iquery, query, nil, query, q.model)
res, err := q.conn.ExecContext(ctx, query)
res, err := q.resolveConn(iquery).ExecContext(ctx, query)
q.db.afterQuery(ctx, event, res, err)
return res, err
}
Expand Down
3 changes: 1 addition & 2 deletions query_column_add.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ var _ Query = (*AddColumnQuery)(nil)
func NewAddColumnQuery(db *DB) *AddColumnQuery {
q := &AddColumnQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
}
return q
Expand Down
3 changes: 1 addition & 2 deletions query_column_drop.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ var _ Query = (*DropColumnQuery)(nil)
func NewDropColumnQuery(db *DB) *DropColumnQuery {
q := &DropColumnQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
}
return q
Expand Down
7 changes: 3 additions & 4 deletions query_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ func NewDeleteQuery(db *DB) *DeleteQuery {
q := &DeleteQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
},
}
Expand Down Expand Up @@ -56,12 +55,12 @@ func (q *DeleteQuery) Apply(fns ...func(*DeleteQuery) *DeleteQuery) *DeleteQuery
return q
}

func (q *DeleteQuery) With(name string, query schema.QueryAppender) *DeleteQuery {
func (q *DeleteQuery) With(name string, query Query) *DeleteQuery {
q.addWith(name, query, false)
return q
}

func (q *DeleteQuery) WithRecursive(name string, query schema.QueryAppender) *DeleteQuery {
func (q *DeleteQuery) WithRecursive(name string, query Query) *DeleteQuery {
q.addWith(name, query, true)
return q
}
Expand Down
3 changes: 1 addition & 2 deletions query_index_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ func NewCreateIndexQuery(db *DB) *CreateIndexQuery {
q := &CreateIndexQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
},
}
Expand Down
3 changes: 1 addition & 2 deletions query_index_drop.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ var _ Query = (*DropIndexQuery)(nil)
func NewDropIndexQuery(db *DB) *DropIndexQuery {
q := &DropIndexQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
}
return q
Expand Down
7 changes: 3 additions & 4 deletions query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ func NewInsertQuery(db *DB) *InsertQuery {
q := &InsertQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
},
}
Expand Down Expand Up @@ -63,12 +62,12 @@ func (q *InsertQuery) Apply(fns ...func(*InsertQuery) *InsertQuery) *InsertQuery
return q
}

func (q *InsertQuery) With(name string, query schema.QueryAppender) *InsertQuery {
func (q *InsertQuery) With(name string, query Query) *InsertQuery {
q.addWith(name, query, false)
return q
}

func (q *InsertQuery) WithRecursive(name string, query schema.QueryAppender) *InsertQuery {
func (q *InsertQuery) WithRecursive(name string, query Query) *InsertQuery {
q.addWith(name, query, true)
return q
}
Expand Down
7 changes: 3 additions & 4 deletions query_merge.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ var _ Query = (*MergeQuery)(nil)
func NewMergeQuery(db *DB) *MergeQuery {
q := &MergeQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
}
if q.db.dialect.Name() != dialect.MSSQL && q.db.dialect.Name() != dialect.PG {
Expand Down Expand Up @@ -60,12 +59,12 @@ func (q *MergeQuery) Apply(fns ...func(*MergeQuery) *MergeQuery) *MergeQuery {
return q
}

func (q *MergeQuery) With(name string, query schema.QueryAppender) *MergeQuery {
func (q *MergeQuery) With(name string, query Query) *MergeQuery {
q.addWith(name, query, false)
return q
}

func (q *MergeQuery) WithRecursive(name string, query schema.QueryAppender) *MergeQuery {
func (q *MergeQuery) WithRecursive(name string, query Query) *MergeQuery {
q.addWith(name, query, true)
return q
}
Expand Down
15 changes: 1 addition & 14 deletions query_raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,10 @@ type RawQuery struct {
args []interface{}
}

// Deprecated: Use NewRaw instead. When add it to IDB, it conflicts with the sql.Conn#Raw
func (db *DB) Raw(query string, args ...interface{}) *RawQuery {
return &RawQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
},
query: query,
args: args,
}
}

func NewRawQuery(db *DB, query string, args ...interface{}) *RawQuery {
return &RawQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
query: query,
args: args,
Expand Down
7 changes: 3 additions & 4 deletions query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ func NewSelectQuery(db *DB) *SelectQuery {
return &SelectQuery{
whereBaseQuery: whereBaseQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
},
}
Expand Down Expand Up @@ -72,12 +71,12 @@ func (q *SelectQuery) Apply(fns ...func(*SelectQuery) *SelectQuery) *SelectQuery
return q
}

func (q *SelectQuery) With(name string, query schema.QueryAppender) *SelectQuery {
func (q *SelectQuery) With(name string, query Query) *SelectQuery {
q.addWith(name, query, false)
return q
}

func (q *SelectQuery) WithRecursive(name string, query schema.QueryAppender) *SelectQuery {
func (q *SelectQuery) WithRecursive(name string, query Query) *SelectQuery {
q.addWith(name, query, true)
return q
}
Expand Down
3 changes: 1 addition & 2 deletions query_table_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ var _ Query = (*CreateTableQuery)(nil)
func NewCreateTableQuery(db *DB) *CreateTableQuery {
q := &CreateTableQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
varchar: db.Dialect().DefaultVarcharLen(),
}
Expand Down
3 changes: 1 addition & 2 deletions query_table_drop.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ var _ Query = (*DropTableQuery)(nil)
func NewDropTableQuery(db *DB) *DropTableQuery {
q := &DropTableQuery{
baseQuery: baseQuery{
db: db,
conn: db.DB,
db: db,
},
}
return q
Expand Down
Loading

0 comments on commit cbbe1e9

Please sign in to comment.