Skip to content

Commit

Permalink
Merge pull request #1070 from j2gg0s/feat-support-disable-dialect-fea…
Browse files Browse the repository at this point in the history
…ture

feat: support disable dialect's feature
  • Loading branch information
j2gg0s authored Nov 25, 2024
2 parents ae8bdd3 + 5343bd7 commit 89e9d51
Show file tree
Hide file tree
Showing 13 changed files with 84 additions and 29 deletions.
18 changes: 8 additions & 10 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ func WithDiscardUnknownColumns() DBOption {
type DB struct {
*sql.DB

dialect schema.Dialect
features feature.Feature
dialect schema.Dialect

queryHooks []QueryHook

Expand All @@ -50,10 +49,9 @@ func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
dialect.Init(sqldb)

db := &DB{
DB: sqldb,
dialect: dialect,
features: dialect.Features(),
fmter: schema.NewFormatter(dialect),
DB: sqldb,
dialect: dialect,
fmter: schema.NewFormatter(dialect),
}

for _, opt := range opts {
Expand Down Expand Up @@ -231,7 +229,7 @@ func (db *DB) UpdateFQN(alias, column string) Ident {

// HasFeature uses feature package to report whether the underlying DBMS supports this feature.
func (db *DB) HasFeature(feat feature.Feature) bool {
return db.fmter.HasFeature(feat)
return db.dialect.Features().Has(feat)
}

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -513,7 +511,7 @@ func (tx Tx) commitTX() error {
}

func (tx Tx) commitSP() error {
if tx.Dialect().Features().Has(feature.MSSavepoint) {
if tx.db.HasFeature(feature.MSSavepoint) {
return nil
}
query := "RELEASE SAVEPOINT " + tx.name
Expand All @@ -537,7 +535,7 @@ func (tx Tx) rollbackTX() error {

func (tx Tx) rollbackSP() error {
query := "ROLLBACK TO SAVEPOINT " + tx.name
if tx.Dialect().Features().Has(feature.MSSavepoint) {
if tx.db.HasFeature(feature.MSSavepoint) {
query = "ROLLBACK TRANSACTION " + tx.name
}
_, err := tx.ExecContext(tx.ctx, query)
Expand Down Expand Up @@ -601,7 +599,7 @@ func (tx Tx) BeginTx(ctx context.Context, _ *sql.TxOptions) (Tx, error) {

qName := "SP_" + hex.EncodeToString(sp)
query := "SAVEPOINT " + qName
if tx.Dialect().Features().Has(feature.MSSavepoint) {
if tx.db.HasFeature(feature.MSSavepoint) {
query = "SAVE TRANSACTION " + qName
}
_, err = tx.ExecContext(ctx, query)
Expand Down
14 changes: 13 additions & 1 deletion dialect/mssqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type Dialect struct {
features feature.Feature
}

func New() *Dialect {
func New(opts ...DialectOption) *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.CTE |
Expand All @@ -49,9 +49,21 @@ func New() *Dialect {
feature.OffsetFetch |
feature.UpdateFromTable |
feature.MSSavepoint

for _, opt := range opts {
opt(d)
}
return d
}

type DialectOption func(d *Dialect)

func WithoutFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(db *sql.DB) {
var version string
if err := db.QueryRow("SELECT @@VERSION").Scan(&version); err != nil {
Expand Down
10 changes: 8 additions & 2 deletions dialect/mysqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ func init() {
}
}

type DialectOption func(d *Dialect)

type Dialect struct {
schema.BaseDialect

Expand Down Expand Up @@ -60,6 +58,8 @@ func New(opts ...DialectOption) *Dialect {
return d
}

type DialectOption func(d *Dialect)

func WithTimeLocation(loc string) DialectOption {
return func(d *Dialect) {
location, err := time.LoadLocation(loc)
Expand All @@ -70,6 +70,12 @@ func WithTimeLocation(loc string) DialectOption {
}
}

func WithoutFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(db *sql.DB) {
var version string
if err := db.QueryRow("SELECT version()").Scan(&version); err != nil {
Expand Down
15 changes: 14 additions & 1 deletion dialect/oracledialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type Dialect struct {
features feature.Feature
}

func New() *Dialect {
func New(opts ...DialectOption) *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.CTE |
Expand All @@ -42,9 +42,22 @@ func New() *Dialect {
feature.AutoIncrement |
feature.CompositeIn |
feature.DeleteReturning

for _, opt := range opts {
opt(d)
}

return d
}

type DialectOption func(d *Dialect)

func WithoutFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(*sql.DB) {}

func (d *Dialect) Name() dialect.Name {
Expand Down
15 changes: 14 additions & 1 deletion dialect/pgdialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ var _ schema.Dialect = (*Dialect)(nil)
var _ sqlschema.InspectorDialect = (*Dialect)(nil)
var _ sqlschema.MigratorDialect = (*Dialect)(nil)

func New() *Dialect {
func New(opts ...DialectOption) *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.CTE |
Expand All @@ -55,9 +55,22 @@ func New() *Dialect {
feature.GeneratedIdentity |
feature.CompositeIn |
feature.DeleteReturning

for _, opt := range opts {
opt(d)
}

return d
}

type DialectOption func(d *Dialect)

func WithoutFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(*sql.DB) {}

func (d *Dialect) Name() dialect.Name {
Expand Down
15 changes: 14 additions & 1 deletion dialect/sqlitedialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type Dialect struct {
features feature.Feature
}

func New() *Dialect {
func New(opts ...DialectOption) *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.CTE |
Expand All @@ -42,9 +42,22 @@ func New() *Dialect {
feature.AutoIncrement |
feature.CompositeIn |
feature.DeleteReturning

for _, opt := range opts {
opt(d)
}

return d
}

type DialectOption func(d *Dialect)

func WithoutFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(*sql.DB) {}

func (d *Dialect) Name() dialect.Name {
Expand Down
4 changes: 2 additions & 2 deletions model_map_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (m *mapSliceModel) appendValues(fmter schema.Formatter, b []byte) (_ []byte
slice := *m.dest

b = append(b, "VALUES "...)
if m.db.features.Has(feature.ValuesRow) {
if m.db.HasFeature(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
Expand All @@ -118,7 +118,7 @@ func (m *mapSliceModel) appendValues(fmter schema.Formatter, b []byte) (_ []byte
for i, el := range slice {
if i > 0 {
b = append(b, "), "...)
if m.db.features.Has(feature.ValuesRow) {
if m.db.HasFeature(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
Expand Down
2 changes: 1 addition & 1 deletion query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (q *baseQuery) beforeAppendModel(ctx context.Context, query Query) error {
}

func (q *baseQuery) hasFeature(feature feature.Feature) bool {
return q.db.features.Has(feature)
return q.db.HasFeature(feature)
}

//------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion query_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (q *DeleteQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
return upd.AppendQuery(fmter, b)
}

withAlias := q.db.features.Has(feature.DeleteTableAlias)
withAlias := q.db.HasFeature(feature.DeleteTableAlias)

b, err = q.appendWith(fmter, b)
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
}
b = append(b, "INTO "...)

if q.db.features.Has(feature.InsertTableAlias) && !q.on.IsZero() {
if q.db.HasFeature(feature.InsertTableAlias) && !q.on.IsZero() {
b, err = q.appendFirstTableWithAlias(fmter, b)
} else {
b, err = q.appendFirstTable(fmter, b)
Expand Down Expand Up @@ -385,9 +385,9 @@ func (q *InsertQuery) appendSliceValues(
}

func (q *InsertQuery) getFields() ([]*schema.Field, error) {
hasIdentity := q.db.features.Has(feature.Identity)
hasIdentity := q.db.HasFeature(feature.Identity)

if len(q.columns) > 0 || q.db.features.Has(feature.DefaultPlaceholder) && !hasIdentity {
if len(q.columns) > 0 || q.db.HasFeature(feature.DefaultPlaceholder) && !hasIdentity {
return q.baseQuery.getFields()
}

Expand Down Expand Up @@ -640,8 +640,8 @@ func (q *InsertQuery) afterInsertHook(ctx context.Context) error {
}

func (q *InsertQuery) tryLastInsertID(res sql.Result, dest []interface{}) error {
if q.db.features.Has(feature.Returning) ||
q.db.features.Has(feature.Output) ||
if q.db.HasFeature(feature.Returning) ||
q.db.HasFeature(feature.Output) ||
q.table == nil ||
len(q.table.PKs) != 1 ||
!q.table.PKs[0].AutoIncrement {
Expand Down
2 changes: 1 addition & 1 deletion query_table_truncate.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (q *TruncateTableQuery) AppendQuery(
return nil, err
}

if q.db.features.Has(feature.TableIdentity) {
if q.db.HasFeature(feature.TableIdentity) {
if q.continueIdentity {
b = append(b, " CONTINUE IDENTITY"...)
} else {
Expand Down
4 changes: 2 additions & 2 deletions query_values.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (q *ValuesQuery) appendQuery(
fields []*schema.Field,
) (_ []byte, err error) {
b = append(b, "VALUES "...)
if q.db.features.Has(feature.ValuesRow) {
if q.db.HasFeature(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
Expand All @@ -168,7 +168,7 @@ func (q *ValuesQuery) appendQuery(
for i := 0; i < sliceLen; i++ {
if i > 0 {
b = append(b, "), "...)
if q.db.features.Has(feature.ValuesRow) {
if q.db.HasFeature(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
Expand Down
2 changes: 1 addition & 1 deletion relation_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery {

var where []byte

if q.db.dialect.Features().Has(feature.CompositeIn) {
if q.db.HasFeature(feature.CompositeIn) {
return j.manyQueryCompositeIn(where, q)
}
return j.manyQueryMulti(where, q)
Expand Down

0 comments on commit 89e9d51

Please sign in to comment.