Skip to content

Commit

Permalink
feat: support disable dialect's feature
Browse files Browse the repository at this point in the history
At the same time, we have also consolidated the logic
for determining features in the database into the dialect.
  • Loading branch information
j2gg0s committed Nov 22, 2024
1 parent cb319f8 commit 742587b
Show file tree
Hide file tree
Showing 13 changed files with 84 additions and 29 deletions.
18 changes: 8 additions & 10 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ func WithDiscardUnknownColumns() DBOption {
type DB struct {
*sql.DB

dialect schema.Dialect
features feature.Feature
dialect schema.Dialect

queryHooks []QueryHook

Expand All @@ -50,10 +49,9 @@ func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
dialect.Init(sqldb)

db := &DB{
DB: sqldb,
dialect: dialect,
features: dialect.Features(),
fmter: schema.NewFormatter(dialect),
DB: sqldb,
dialect: dialect,
fmter: schema.NewFormatter(dialect),
}

for _, opt := range opts {
Expand Down Expand Up @@ -231,7 +229,7 @@ func (db *DB) UpdateFQN(alias, column string) Ident {

// HasFeature uses feature package to report whether the underlying DBMS supports this feature.
func (db *DB) HasFeature(feat feature.Feature) bool {
return db.fmter.HasFeature(feat)
return db.dialect.Features().Has(feat)
}

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -513,7 +511,7 @@ func (tx Tx) commitTX() error {
}

func (tx Tx) commitSP() error {
if tx.Dialect().Features().Has(feature.MSSavepoint) {
if tx.db.HasFeature(feature.MSSavepoint) {
return nil
}
query := "RELEASE SAVEPOINT " + tx.name
Expand All @@ -537,7 +535,7 @@ func (tx Tx) rollbackTX() error {

func (tx Tx) rollbackSP() error {
query := "ROLLBACK TO SAVEPOINT " + tx.name
if tx.Dialect().Features().Has(feature.MSSavepoint) {
if tx.db.HasFeature(feature.MSSavepoint) {
query = "ROLLBACK TRANSACTION " + tx.name
}
_, err := tx.ExecContext(tx.ctx, query)
Expand Down Expand Up @@ -601,7 +599,7 @@ func (tx Tx) BeginTx(ctx context.Context, _ *sql.TxOptions) (Tx, error) {

qName := "SP_" + hex.EncodeToString(sp)
query := "SAVEPOINT " + qName
if tx.Dialect().Features().Has(feature.MSSavepoint) {
if tx.db.HasFeature(feature.MSSavepoint) {
query = "SAVE TRANSACTION " + qName
}
_, err = tx.ExecContext(ctx, query)
Expand Down
14 changes: 13 additions & 1 deletion dialect/mssqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type Dialect struct {
features feature.Feature
}

func New() *Dialect {
func New(opts ...DialectOption) *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.CTE |
Expand All @@ -49,9 +49,21 @@ func New() *Dialect {
feature.OffsetFetch |
feature.UpdateFromTable |
feature.MSSavepoint

for _, opt := range opts {
opt(d)
}
return d
}

type DialectOption func(d *Dialect)

func RemoveFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(db *sql.DB) {
var version string
if err := db.QueryRow("SELECT @@VERSION").Scan(&version); err != nil {
Expand Down
10 changes: 8 additions & 2 deletions dialect/mysqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ func init() {
}
}

type DialectOption func(d *Dialect)

type Dialect struct {
schema.BaseDialect

Expand Down Expand Up @@ -60,6 +58,8 @@ func New(opts ...DialectOption) *Dialect {
return d
}

type DialectOption func(d *Dialect)

func WithTimeLocation(loc string) DialectOption {
return func(d *Dialect) {
location, err := time.LoadLocation(loc)
Expand All @@ -70,6 +70,12 @@ func WithTimeLocation(loc string) DialectOption {
}
}

func RemoveFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(db *sql.DB) {
var version string
if err := db.QueryRow("SELECT version()").Scan(&version); err != nil {
Expand Down
15 changes: 14 additions & 1 deletion dialect/oracledialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type Dialect struct {
features feature.Feature
}

func New() *Dialect {
func New(opts ...DialectOption) *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.CTE |
Expand All @@ -42,9 +42,22 @@ func New() *Dialect {
feature.AutoIncrement |
feature.CompositeIn |
feature.DeleteReturning

for _, opt := range opts {
opt(d)
}

return d
}

type DialectOption func(d *Dialect)

func RemoveFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(*sql.DB) {}

func (d *Dialect) Name() dialect.Name {
Expand Down
15 changes: 14 additions & 1 deletion dialect/pgdialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ var _ schema.Dialect = (*Dialect)(nil)
var _ sqlschema.InspectorDialect = (*Dialect)(nil)
var _ sqlschema.MigratorDialect = (*Dialect)(nil)

func New() *Dialect {
func New(opts ...DialectOption) *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.CTE |
Expand All @@ -55,9 +55,22 @@ func New() *Dialect {
feature.GeneratedIdentity |
feature.CompositeIn |
feature.DeleteReturning

for _, opt := range opts {
opt(d)
}

return d
}

type DialectOption func(d *Dialect)

func RemoveFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(*sql.DB) {}

func (d *Dialect) Name() dialect.Name {
Expand Down
15 changes: 14 additions & 1 deletion dialect/sqlitedialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type Dialect struct {
features feature.Feature
}

func New() *Dialect {
func New(opts ...DialectOption) *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.CTE |
Expand All @@ -42,9 +42,22 @@ func New() *Dialect {
feature.AutoIncrement |
feature.CompositeIn |
feature.DeleteReturning

for _, opt := range opts {
opt(d)
}

return d
}

type DialectOption func(d *Dialect)

func RemoveFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(*sql.DB) {}

func (d *Dialect) Name() dialect.Name {
Expand Down
4 changes: 2 additions & 2 deletions model_map_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (m *mapSliceModel) appendValues(fmter schema.Formatter, b []byte) (_ []byte
slice := *m.dest

b = append(b, "VALUES "...)
if m.db.features.Has(feature.ValuesRow) {
if m.db.HasFeature(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
Expand All @@ -118,7 +118,7 @@ func (m *mapSliceModel) appendValues(fmter schema.Formatter, b []byte) (_ []byte
for i, el := range slice {
if i > 0 {
b = append(b, "), "...)
if m.db.features.Has(feature.ValuesRow) {
if m.db.HasFeature(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
Expand Down
2 changes: 1 addition & 1 deletion query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (q *baseQuery) beforeAppendModel(ctx context.Context, query Query) error {
}

func (q *baseQuery) hasFeature(feature feature.Feature) bool {
return q.db.features.Has(feature)
return q.db.HasFeature(feature)
}

//------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion query_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (q *DeleteQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
return upd.AppendQuery(fmter, b)
}

withAlias := q.db.features.Has(feature.DeleteTableAlias)
withAlias := q.db.HasFeature(feature.DeleteTableAlias)

b, err = q.appendWith(fmter, b)
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
}
b = append(b, "INTO "...)

if q.db.features.Has(feature.InsertTableAlias) && !q.on.IsZero() {
if q.db.HasFeature(feature.InsertTableAlias) && !q.on.IsZero() {
b, err = q.appendFirstTableWithAlias(fmter, b)
} else {
b, err = q.appendFirstTable(fmter, b)
Expand Down Expand Up @@ -385,9 +385,9 @@ func (q *InsertQuery) appendSliceValues(
}

func (q *InsertQuery) getFields() ([]*schema.Field, error) {
hasIdentity := q.db.features.Has(feature.Identity)
hasIdentity := q.db.HasFeature(feature.Identity)

if len(q.columns) > 0 || q.db.features.Has(feature.DefaultPlaceholder) && !hasIdentity {
if len(q.columns) > 0 || q.db.HasFeature(feature.DefaultPlaceholder) && !hasIdentity {
return q.baseQuery.getFields()
}

Expand Down Expand Up @@ -640,8 +640,8 @@ func (q *InsertQuery) afterInsertHook(ctx context.Context) error {
}

func (q *InsertQuery) tryLastInsertID(res sql.Result, dest []interface{}) error {
if q.db.features.Has(feature.Returning) ||
q.db.features.Has(feature.Output) ||
if q.db.HasFeature(feature.Returning) ||
q.db.HasFeature(feature.Output) ||
q.table == nil ||
len(q.table.PKs) != 1 ||
!q.table.PKs[0].AutoIncrement {
Expand Down
2 changes: 1 addition & 1 deletion query_table_truncate.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (q *TruncateTableQuery) AppendQuery(
return nil, err
}

if q.db.features.Has(feature.TableIdentity) {
if q.db.HasFeature(feature.TableIdentity) {
if q.continueIdentity {
b = append(b, " CONTINUE IDENTITY"...)
} else {
Expand Down
4 changes: 2 additions & 2 deletions query_values.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (q *ValuesQuery) appendQuery(
fields []*schema.Field,
) (_ []byte, err error) {
b = append(b, "VALUES "...)
if q.db.features.Has(feature.ValuesRow) {
if q.db.HasFeature(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
Expand All @@ -168,7 +168,7 @@ func (q *ValuesQuery) appendQuery(
for i := 0; i < sliceLen; i++ {
if i > 0 {
b = append(b, "), "...)
if q.db.features.Has(feature.ValuesRow) {
if q.db.HasFeature(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
Expand Down
2 changes: 1 addition & 1 deletion relation_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery {

var where []byte

if q.db.dialect.Features().Has(feature.CompositeIn) {
if q.db.HasFeature(feature.CompositeIn) {
return j.manyQueryCompositeIn(where, q)
}
return j.manyQueryMulti(where, q)
Expand Down

0 comments on commit 742587b

Please sign in to comment.