Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support disable dialect's feature #1070

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ func WithDiscardUnknownColumns() DBOption {
type DB struct {
*sql.DB

dialect schema.Dialect
features feature.Feature
dialect schema.Dialect

queryHooks []QueryHook

Expand All @@ -50,10 +49,9 @@ func NewDB(sqldb *sql.DB, dialect schema.Dialect, opts ...DBOption) *DB {
dialect.Init(sqldb)

db := &DB{
DB: sqldb,
dialect: dialect,
features: dialect.Features(),
fmter: schema.NewFormatter(dialect),
DB: sqldb,
dialect: dialect,
fmter: schema.NewFormatter(dialect),
}

for _, opt := range opts {
Expand Down Expand Up @@ -231,7 +229,7 @@ func (db *DB) UpdateFQN(alias, column string) Ident {

// HasFeature uses feature package to report whether the underlying DBMS supports this feature.
func (db *DB) HasFeature(feat feature.Feature) bool {
return db.fmter.HasFeature(feat)
return db.dialect.Features().Has(feat)
}

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -513,7 +511,7 @@ func (tx Tx) commitTX() error {
}

func (tx Tx) commitSP() error {
if tx.Dialect().Features().Has(feature.MSSavepoint) {
if tx.db.HasFeature(feature.MSSavepoint) {
return nil
}
query := "RELEASE SAVEPOINT " + tx.name
Expand All @@ -537,7 +535,7 @@ func (tx Tx) rollbackTX() error {

func (tx Tx) rollbackSP() error {
query := "ROLLBACK TO SAVEPOINT " + tx.name
if tx.Dialect().Features().Has(feature.MSSavepoint) {
if tx.db.HasFeature(feature.MSSavepoint) {
query = "ROLLBACK TRANSACTION " + tx.name
}
_, err := tx.ExecContext(tx.ctx, query)
Expand Down Expand Up @@ -601,7 +599,7 @@ func (tx Tx) BeginTx(ctx context.Context, _ *sql.TxOptions) (Tx, error) {

qName := "SP_" + hex.EncodeToString(sp)
query := "SAVEPOINT " + qName
if tx.Dialect().Features().Has(feature.MSSavepoint) {
if tx.db.HasFeature(feature.MSSavepoint) {
query = "SAVE TRANSACTION " + qName
}
_, err = tx.ExecContext(ctx, query)
Expand Down
14 changes: 13 additions & 1 deletion dialect/mssqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type Dialect struct {
features feature.Feature
}

func New() *Dialect {
func New(opts ...DialectOption) *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.CTE |
Expand All @@ -49,9 +49,21 @@ func New() *Dialect {
feature.OffsetFetch |
feature.UpdateFromTable |
feature.MSSavepoint

for _, opt := range opts {
opt(d)
}
return d
}

type DialectOption func(d *Dialect)

func WithoutFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(db *sql.DB) {
var version string
if err := db.QueryRow("SELECT @@VERSION").Scan(&version); err != nil {
Expand Down
10 changes: 8 additions & 2 deletions dialect/mysqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ func init() {
}
}

type DialectOption func(d *Dialect)

type Dialect struct {
schema.BaseDialect

Expand Down Expand Up @@ -60,6 +58,8 @@ func New(opts ...DialectOption) *Dialect {
return d
}

type DialectOption func(d *Dialect)

func WithTimeLocation(loc string) DialectOption {
return func(d *Dialect) {
location, err := time.LoadLocation(loc)
Expand All @@ -70,6 +70,12 @@ func WithTimeLocation(loc string) DialectOption {
}
}

func WithoutFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(db *sql.DB) {
var version string
if err := db.QueryRow("SELECT version()").Scan(&version); err != nil {
Expand Down
15 changes: 14 additions & 1 deletion dialect/oracledialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ type Dialect struct {
features feature.Feature
}

func New() *Dialect {
func New(opts ...DialectOption) *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.CTE |
Expand All @@ -42,9 +42,22 @@ func New() *Dialect {
feature.AutoIncrement |
feature.CompositeIn |
feature.DeleteReturning

for _, opt := range opts {
opt(d)
}

return d
}

type DialectOption func(d *Dialect)

func WithoutFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(*sql.DB) {}

func (d *Dialect) Name() dialect.Name {
Expand Down
15 changes: 14 additions & 1 deletion dialect/pgdialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ var _ schema.Dialect = (*Dialect)(nil)
var _ sqlschema.InspectorDialect = (*Dialect)(nil)
var _ sqlschema.MigratorDialect = (*Dialect)(nil)

func New() *Dialect {
func New(opts ...DialectOption) *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.CTE |
Expand All @@ -55,9 +55,22 @@ func New() *Dialect {
feature.GeneratedIdentity |
feature.CompositeIn |
feature.DeleteReturning

for _, opt := range opts {
opt(d)
}

return d
}

type DialectOption func(d *Dialect)

func WithoutFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(*sql.DB) {}

func (d *Dialect) Name() dialect.Name {
Expand Down
15 changes: 14 additions & 1 deletion dialect/sqlitedialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type Dialect struct {
features feature.Feature
}

func New() *Dialect {
func New(opts ...DialectOption) *Dialect {
d := new(Dialect)
d.tables = schema.NewTables(d)
d.features = feature.CTE |
Expand All @@ -42,9 +42,22 @@ func New() *Dialect {
feature.AutoIncrement |
feature.CompositeIn |
feature.DeleteReturning

for _, opt := range opts {
opt(d)
}

return d
}

type DialectOption func(d *Dialect)

func WithoutFeature(other feature.Feature) DialectOption {
return func(d *Dialect) {
d.features = d.features.Remove(other)
}
}

func (d *Dialect) Init(*sql.DB) {}

func (d *Dialect) Name() dialect.Name {
Expand Down
4 changes: 2 additions & 2 deletions model_map_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func (m *mapSliceModel) appendValues(fmter schema.Formatter, b []byte) (_ []byte
slice := *m.dest

b = append(b, "VALUES "...)
if m.db.features.Has(feature.ValuesRow) {
if m.db.HasFeature(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
Expand All @@ -118,7 +118,7 @@ func (m *mapSliceModel) appendValues(fmter schema.Formatter, b []byte) (_ []byte
for i, el := range slice {
if i > 0 {
b = append(b, "), "...)
if m.db.features.Has(feature.ValuesRow) {
if m.db.HasFeature(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
Expand Down
2 changes: 1 addition & 1 deletion query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (q *baseQuery) beforeAppendModel(ctx context.Context, query Query) error {
}

func (q *baseQuery) hasFeature(feature feature.Feature) bool {
return q.db.features.Has(feature)
return q.db.HasFeature(feature)
}

//------------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion query_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ func (q *DeleteQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
return upd.AppendQuery(fmter, b)
}

withAlias := q.db.features.Has(feature.DeleteTableAlias)
withAlias := q.db.HasFeature(feature.DeleteTableAlias)

b, err = q.appendWith(fmter, b)
if err != nil {
Expand Down
10 changes: 5 additions & 5 deletions query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ func (q *InsertQuery) AppendQuery(fmter schema.Formatter, b []byte) (_ []byte, e
}
b = append(b, "INTO "...)

if q.db.features.Has(feature.InsertTableAlias) && !q.on.IsZero() {
if q.db.HasFeature(feature.InsertTableAlias) && !q.on.IsZero() {
b, err = q.appendFirstTableWithAlias(fmter, b)
} else {
b, err = q.appendFirstTable(fmter, b)
Expand Down Expand Up @@ -385,9 +385,9 @@ func (q *InsertQuery) appendSliceValues(
}

func (q *InsertQuery) getFields() ([]*schema.Field, error) {
hasIdentity := q.db.features.Has(feature.Identity)
hasIdentity := q.db.HasFeature(feature.Identity)

if len(q.columns) > 0 || q.db.features.Has(feature.DefaultPlaceholder) && !hasIdentity {
if len(q.columns) > 0 || q.db.HasFeature(feature.DefaultPlaceholder) && !hasIdentity {
return q.baseQuery.getFields()
}

Expand Down Expand Up @@ -640,8 +640,8 @@ func (q *InsertQuery) afterInsertHook(ctx context.Context) error {
}

func (q *InsertQuery) tryLastInsertID(res sql.Result, dest []interface{}) error {
if q.db.features.Has(feature.Returning) ||
q.db.features.Has(feature.Output) ||
if q.db.HasFeature(feature.Returning) ||
q.db.HasFeature(feature.Output) ||
q.table == nil ||
len(q.table.PKs) != 1 ||
!q.table.PKs[0].AutoIncrement {
Expand Down
2 changes: 1 addition & 1 deletion query_table_truncate.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (q *TruncateTableQuery) AppendQuery(
return nil, err
}

if q.db.features.Has(feature.TableIdentity) {
if q.db.HasFeature(feature.TableIdentity) {
if q.continueIdentity {
b = append(b, " CONTINUE IDENTITY"...)
} else {
Expand Down
4 changes: 2 additions & 2 deletions query_values.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func (q *ValuesQuery) appendQuery(
fields []*schema.Field,
) (_ []byte, err error) {
b = append(b, "VALUES "...)
if q.db.features.Has(feature.ValuesRow) {
if q.db.HasFeature(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
Expand All @@ -168,7 +168,7 @@ func (q *ValuesQuery) appendQuery(
for i := 0; i < sliceLen; i++ {
if i > 0 {
b = append(b, "), "...)
if q.db.features.Has(feature.ValuesRow) {
if q.db.HasFeature(feature.ValuesRow) {
b = append(b, "ROW("...)
} else {
b = append(b, '(')
Expand Down
2 changes: 1 addition & 1 deletion relation_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func (j *relationJoin) manyQuery(q *SelectQuery) *SelectQuery {

var where []byte

if q.db.dialect.Features().Has(feature.CompositeIn) {
if q.db.HasFeature(feature.CompositeIn) {
return j.manyQueryCompositeIn(where, q)
}
return j.manyQueryMulti(where, q)
Expand Down
Loading