From b2288fcf8f4b48147b969aabe0f1b7e4d32c70bd Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 13 Nov 2024 10:48:38 +0100 Subject: [PATCH 1/3] refactor: pass schemaName and excludeTables as InspectorOptions --- dialect/pgdialect/inspector.go | 22 +++---- internal/dbtest/inspect_test.go | 28 ++++----- internal/dbtest/migrate_test.go | 14 +++-- migrate/auto.go | 10 ++-- migrate/sqlschema/inspector.go | 103 ++++++++++++++++++++++---------- 5 files changed, 112 insertions(+), 65 deletions(-) diff --git a/dialect/pgdialect/inspector.go b/dialect/pgdialect/inspector.go index ae2b7cc7e..d21e21911 100644 --- a/dialect/pgdialect/inspector.go +++ b/dialect/pgdialect/inspector.go @@ -15,40 +15,42 @@ type ( Column = sqlschema.BaseColumn ) -func (d *Dialect) Inspector(db *bun.DB, excludeTables ...string) sqlschema.Inspector { - return newInspector(db, excludeTables...) +func (d *Dialect) Inspector(db *bun.DB, options ...sqlschema.InspectorOption) sqlschema.Inspector { + return newInspector(db, options...) } type Inspector struct { - db *bun.DB - excludeTables []string + sqlschema.InspectorConfig + db *bun.DB } var _ sqlschema.Inspector = (*Inspector)(nil) -func newInspector(db *bun.DB, excludeTables ...string) *Inspector { - return &Inspector{db: db, excludeTables: excludeTables} +func newInspector(db *bun.DB, options ...sqlschema.InspectorOption) *Inspector { + i := &Inspector{db: db} + sqlschema.ApplyInspectorOptions(&i.InspectorConfig, options...) + return i } -func (in *Inspector) Inspect(ctx context.Context, schemaName string) (sqlschema.Database, error) { +func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) { dbSchema := Schema{ Tables: orderedmap.New[string, sqlschema.Table](), ForeignKeys: make(map[sqlschema.ForeignKey]string), } - exclude := in.excludeTables + exclude := in.ExcludeTables if len(exclude) == 0 { // Avoid getting NOT IN (NULL) if bun.In() is called with an empty slice. exclude = []string{""} } var tables []*InformationSchemaTable - if err := in.db.NewRaw(sqlInspectTables, schemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil { + if err := in.db.NewRaw(sqlInspectTables, in.SchemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil { return dbSchema, err } var fks []*ForeignKey - if err := in.db.NewRaw(sqlInspectForeignKeys, schemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil { + if err := in.db.NewRaw(sqlInspectForeignKeys, in.SchemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil { return dbSchema, err } dbSchema.ForeignKeys = make(map[sqlschema.ForeignKey]string, len(fks)) diff --git a/internal/dbtest/inspect_test.go b/internal/dbtest/inspect_test.go index dd37e2f13..9ecea49da 100644 --- a/internal/dbtest/inspect_test.go +++ b/internal/dbtest/inspect_test.go @@ -335,7 +335,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) { t.Run(tt.name, func(t *testing.T) { db.RegisterModel((*PublisherToJournalist)(nil)) - dbInspector, err := sqlschema.NewInspector(db, migrationsTable, migrationLocksTable) + dbInspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(tt.schemaName), sqlschema.WithExcludeTables(migrationsTable, migrationLocksTable)) if err != nil { t.Skip(err) } @@ -353,7 +353,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) { (*Article)(nil), // references Journalist and Publisher ) - got, err := dbInspector.Inspect(ctx, tt.schemaName) + got, err := dbInspector.Inspect(ctx) require.NoError(t, err) // State.FKs store their database names, which differ from dialect to dialect. @@ -523,7 +523,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables := schema.NewTables(dialect) tables.Register((*Model)(nil)) - inspector := sqlschema.NewBunModelInspector(tables) + inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema())) want := orderedmap.New[string, sqlschema.Column](orderedmap.WithInitialData( orderedmap.Pair[string, sqlschema.Column]{ @@ -542,7 +542,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { }, )) - got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema()) + got, err := inspector.Inspect(context.Background()) require.NoError(t, err) gotTables := got.GetTables() @@ -562,7 +562,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables := schema.NewTables(dialect) tables.Register((*Model)(nil)) - inspector := sqlschema.NewBunModelInspector(tables) + inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema())) want := orderedmap.New[string, sqlschema.Column](orderedmap.WithInitialData( orderedmap.Pair[string, sqlschema.Column]{ @@ -587,7 +587,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { }, )) - got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema()) + got, err := inspector.Inspect(context.Background()) require.NoError(t, err) gotTables := got.GetTables() @@ -606,7 +606,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables := schema.NewTables(dialect) tables.Register((*Model)(nil)) - inspector := sqlschema.NewBunModelInspector(tables) + inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema())) want := &sqlschema.BaseTable{ Name: "models", @@ -616,7 +616,7 @@ func TestBunModelInspector_Inspect(t *testing.T) { }, } - got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema()) + got, err := inspector.Inspect(context.Background()) require.NoError(t, err) gotTables := got.GetTables() @@ -635,10 +635,10 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables := schema.NewTables(dialect) tables.Register((*Model)(nil)) - inspector := sqlschema.NewBunModelInspector(tables) + inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema())) want := sqlschema.NewColumns("id", "email") - got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema()) + got, err := inspector.Inspect(context.Background()) require.NoError(t, err) gotTables := got.GetTables() @@ -658,9 +658,9 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables := schema.NewTables(dialect) tables.Register((*Model)(nil)) - inspector := sqlschema.NewBunModelInspector(tables) + inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName("custom_schema")) - got, err := inspector.Inspect(context.Background(), "custom_schema") + got, err := inspector.Inspect(context.Background()) require.NoError(t, err) gotTables := got.GetTables() @@ -683,9 +683,9 @@ func TestBunModelInspector_Inspect(t *testing.T) { tables := schema.NewTables(dialect) tables.Register((*KeepMe)(nil), (*LoseMe)(nil)) - inspector := sqlschema.NewBunModelInspector(tables) + inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName("want")) - got, err := inspector.Inspect(context.Background(), "want") + got, err := inspector.Inspect(context.Background()) require.NoError(t, err) gotTables := got.GetTables() diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index 06bc531af..0705aec30 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -217,19 +217,21 @@ func newAutoMigratorOrSkip(tb testing.TB, db *bun.DB, opts ...migrate.AutoMigrat // and fail if the inspector cannot successfully retrieve database state. func inspectDbOrSkip(tb testing.TB, db *bun.DB, schemaName ...string) func(context.Context) sqlschema.BaseDatabase { tb.Helper() - // AutoMigrator excludes these tables by default, but here we need to do this explicitly. - inspector, err := sqlschema.NewInspector(db, migrationsTable, migrationLocksTable) - if err != nil { - tb.Skip(err) - } // For convenience, schemaName is an optional parameter in this function. inspectSchema := db.Dialect().DefaultSchema() if len(schemaName) > 0 { inspectSchema = schemaName[0] } + + // AutoMigrator excludes these tables by default, but here we need to do this explicitly. + inspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(inspectSchema), sqlschema.WithExcludeTables(migrationsTable, migrationLocksTable)) + if err != nil { + tb.Skip(err) + } + return func(ctx context.Context) sqlschema.BaseDatabase { - state, err := inspector.Inspect(ctx, inspectSchema) + state, err := inspector.Inspect(ctx) require.NoError(tb, err) return state.(sqlschema.BaseDatabase) } diff --git a/migrate/auto.go b/migrate/auto.go index 32582eba3..be6954f01 100644 --- a/migrate/auto.go +++ b/migrate/auto.go @@ -27,6 +27,8 @@ func WithModel(models ...interface{}) AutoMigratorOption { // WithExcludeTable tells the AutoMigrator to ignore a table in the database. // This prevents AutoMigrator from dropping tables which may exist in the schema // but which are not used by the application. +// +// Do not exclude tables included via WithModel, as BunModelInspector ignores this setting. func WithExcludeTable(tables ...string) AutoMigratorOption { return func(m *AutoMigrator) { m.excludeTables = append(m.excludeTables, tables...) @@ -148,7 +150,7 @@ func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, err } am.excludeTables = append(am.excludeTables, am.table, am.locksTable) - dbInspector, err := sqlschema.NewInspector(db, am.excludeTables...) + dbInspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(am.schemaName), sqlschema.WithExcludeTables(am.excludeTables...)) if err != nil { return nil, err } @@ -163,7 +165,7 @@ func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, err tables := schema.NewTables(db.Dialect()) tables.Register(am.includeModels...) - am.modelInspector = sqlschema.NewBunModelInspector(tables) + am.modelInspector = sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(am.schemaName)) return am, nil } @@ -171,12 +173,12 @@ func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, err func (am *AutoMigrator) plan(ctx context.Context) (*changeset, error) { var err error - got, err := am.dbInspector.Inspect(ctx, am.schemaName) + got, err := am.dbInspector.Inspect(ctx) if err != nil { return nil, err } - want, err := am.modelInspector.Inspect(ctx, am.schemaName) + want, err := am.modelInspector.Inspect(ctx) if err != nil { return nil, err } diff --git a/migrate/sqlschema/inspector.go b/migrate/sqlschema/inspector.go index ed474ed95..1464e5ccf 100644 --- a/migrate/sqlschema/inspector.go +++ b/migrate/sqlschema/inspector.go @@ -13,7 +13,13 @@ import ( type InspectorDialect interface { schema.Dialect - Inspector(db *bun.DB, excludeTables ...string) Inspector + + // Inspector returns a new instance of Inspector for the dialect. + // Dialects MAY set their default InspectorConfig values in constructor + // but MUST apply InspectorOptions to ensure they can be overriden. + // + // Use ApplyInspectorOptions to reduce boilerplate. + Inspector(db *bun.DB, options ...InspectorOption) Inspector // EquivalentType returns true if col1 and co2 SQL types are equivalent, // i.e. they might use dialect-specifc type aliases (SERIAL ~ SMALLINT) @@ -21,61 +27,77 @@ type InspectorDialect interface { EquivalentType(Column, Column) bool } +// InspectorConfig controls the scope of migration by limiting the objects Inspector should return. +// Inspectors SHOULD use the configuration directly instead of copying it, or MAY choose to embed it, +// to make sure options are always applied correctly. +type InspectorConfig struct { + // SchemaName limits inspection to tables in a particular schema. + SchemaName string + + // ExcludeTables from inspection. + ExcludeTables []string +} + // Inspector reads schema state. type Inspector interface { - Inspect(ctx context.Context, schemaName string) (Database, error) + Inspect(ctx context.Context) (Database, error) } -// inspector is opaque pointer to a databse inspector. -type inspector struct { - Inspector +func WithSchemaName(schemaName string) InspectorOption { + return func(cfg *InspectorConfig) { + cfg.SchemaName = schemaName + } +} + +// WithExcludeTables works in append-only mode, i.e. tables cannot be re-included. +func WithExcludeTables(tables ...string) InspectorOption { + return func(cfg *InspectorConfig) { + cfg.ExcludeTables = append(cfg.ExcludeTables, tables...) + } } // NewInspector creates a new database inspector, if the dialect supports it. -func NewInspector(db *bun.DB, excludeTables ...string) (Inspector, error) { +func NewInspector(db *bun.DB, options ...InspectorOption) (Inspector, error) { dialect, ok := (db.Dialect()).(InspectorDialect) if !ok { return nil, fmt.Errorf("%s does not implement sqlschema.Inspector", db.Dialect().Name()) } return &inspector{ - Inspector: dialect.Inspector(db, excludeTables...), + Inspector: dialect.Inspector(db, options...), }, nil } -// BunModelInspector creates the current project state from the passed bun.Models. -// Do not recycle BunModelInspector for different sets of models, as older models will not be de-registerred before the next run. -type BunModelInspector struct { - tables *schema.Tables -} - -var _ Inspector = (*BunModelInspector)(nil) - -func NewBunModelInspector(tables *schema.Tables) *BunModelInspector { - return &BunModelInspector{ +func NewBunModelInspector(tables *schema.Tables, options ...InspectorOption) *BunModelInspector { + bmi := &BunModelInspector{ tables: tables, } + ApplyInspectorOptions(&bmi.InspectorConfig, options...) + return bmi } -// BunModelSchema is the schema state derived from bun table models. -type BunModelSchema struct { - BaseDatabase +type InspectorOption func(*InspectorConfig) - Tables *orderedmap.OrderedMap[string, Table] +func ApplyInspectorOptions(cfg *InspectorConfig, options ...InspectorOption) { + for _, opt := range options { + opt(cfg) + } } -func (ms BunModelSchema) GetTables() *orderedmap.OrderedMap[string, Table] { - return ms.Tables +// inspector is opaque pointer to a database inspector. +type inspector struct { + Inspector } -// BunTable provides additional table metadata that is only accessible from scanning bun models. -type BunTable struct { - BaseTable - - // Model stores the zero interface to the underlying Go struct. - Model interface{} +// BunModelInspector creates the current project state from the passed bun.Models. +// Do not recycle BunModelInspector for different sets of models, as older models will not be de-registerred before the next run. +type BunModelInspector struct { + InspectorConfig + tables *schema.Tables } -func (bmi *BunModelInspector) Inspect(ctx context.Context, schemaName string) (Database, error) { +var _ Inspector = (*BunModelInspector)(nil) + +func (bmi *BunModelInspector) Inspect(ctx context.Context) (Database, error) { state := BunModelSchema{ BaseDatabase: BaseDatabase{ ForeignKeys: make(map[ForeignKey]string), @@ -83,7 +105,7 @@ func (bmi *BunModelInspector) Inspect(ctx context.Context, schemaName string) (D Tables: orderedmap.New[string, Table](), } for _, t := range bmi.tables.All() { - if t.Schema != schemaName { + if t.Schema != bmi.SchemaName { continue } @@ -198,3 +220,22 @@ func exprToLower(s string) string { } return strings.ToLower(s) } + +// BunModelSchema is the schema state derived from bun table models. +type BunModelSchema struct { + BaseDatabase + + Tables *orderedmap.OrderedMap[string, Table] +} + +func (ms BunModelSchema) GetTables() *orderedmap.OrderedMap[string, Table] { + return ms.Tables +} + +// BunTable provides additional table metadata that is only accessible from scanning bun models. +type BunTable struct { + BaseTable + + // Model stores the zero interface to the underlying Go struct. + Model interface{} +} From 5cff0dd3cd4ef2c8897c78cc9dbe86a67ac9843b Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 13 Nov 2024 11:02:58 +0100 Subject: [PATCH 2/3] refactor: rename dialect methods to NewInspector and NewMigrator --- dialect/pgdialect/alter_table.go | 2 +- dialect/pgdialect/inspector.go | 2 +- migrate/sqlschema/inspector.go | 4 ++-- migrate/sqlschema/migrator.go | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/dialect/pgdialect/alter_table.go b/dialect/pgdialect/alter_table.go index dac827a20..654921472 100644 --- a/dialect/pgdialect/alter_table.go +++ b/dialect/pgdialect/alter_table.go @@ -10,7 +10,7 @@ import ( "github.com/uptrace/bun/schema" ) -func (d *Dialect) Migrator(db *bun.DB, schemaName string) sqlschema.Migrator { +func (d *Dialect) NewMigrator(db *bun.DB, schemaName string) sqlschema.Migrator { return &migrator{db: db, schemaName: schemaName, BaseMigrator: sqlschema.NewBaseMigrator(db)} } diff --git a/dialect/pgdialect/inspector.go b/dialect/pgdialect/inspector.go index d21e21911..42bbbe84f 100644 --- a/dialect/pgdialect/inspector.go +++ b/dialect/pgdialect/inspector.go @@ -15,7 +15,7 @@ type ( Column = sqlschema.BaseColumn ) -func (d *Dialect) Inspector(db *bun.DB, options ...sqlschema.InspectorOption) sqlschema.Inspector { +func (d *Dialect) NewInspector(db *bun.DB, options ...sqlschema.InspectorOption) sqlschema.Inspector { return newInspector(db, options...) } diff --git a/migrate/sqlschema/inspector.go b/migrate/sqlschema/inspector.go index 1464e5ccf..1532d036c 100644 --- a/migrate/sqlschema/inspector.go +++ b/migrate/sqlschema/inspector.go @@ -19,7 +19,7 @@ type InspectorDialect interface { // but MUST apply InspectorOptions to ensure they can be overriden. // // Use ApplyInspectorOptions to reduce boilerplate. - Inspector(db *bun.DB, options ...InspectorOption) Inspector + NewInspector(db *bun.DB, options ...InspectorOption) Inspector // EquivalentType returns true if col1 and co2 SQL types are equivalent, // i.e. they might use dialect-specifc type aliases (SERIAL ~ SMALLINT) @@ -63,7 +63,7 @@ func NewInspector(db *bun.DB, options ...InspectorOption) (Inspector, error) { return nil, fmt.Errorf("%s does not implement sqlschema.Inspector", db.Dialect().Name()) } return &inspector{ - Inspector: dialect.Inspector(db, options...), + Inspector: dialect.NewInspector(db, options...), }, nil } diff --git a/migrate/sqlschema/migrator.go b/migrate/sqlschema/migrator.go index c9f9d2592..00500061b 100644 --- a/migrate/sqlschema/migrator.go +++ b/migrate/sqlschema/migrator.go @@ -9,7 +9,7 @@ import ( type MigratorDialect interface { schema.Dialect - Migrator(db *bun.DB, schemaName string) Migrator + NewMigrator(db *bun.DB, schemaName string) Migrator } type Migrator interface { @@ -27,7 +27,7 @@ func NewMigrator(db *bun.DB, schemaName string) (Migrator, error) { return nil, fmt.Errorf("%q dialect does not implement sqlschema.Migrator", db.Dialect().Name()) } return &migrator{ - Migrator: md.Migrator(db, schemaName), + Migrator: md.NewMigrator(db, schemaName), }, nil } From c228b0e155d30c48ceaac0b62c1a03cd36482c92 Mon Sep 17 00:00:00 2001 From: dyma solovei Date: Wed, 13 Nov 2024 11:08:13 +0100 Subject: [PATCH 3/3] refactor: rename EquivalentType to CompareType --- dialect/pgdialect/alter_table.go | 2 +- dialect/pgdialect/sqltype.go | 2 +- dialect/pgdialect/sqltype_test.go | 6 +++--- internal/dbtest/inspect_test.go | 2 +- migrate/auto.go | 2 +- migrate/diff.go | 28 ++++++++++++++-------------- migrate/sqlschema/inspector.go | 4 ++-- 7 files changed, 23 insertions(+), 23 deletions(-) diff --git a/dialect/pgdialect/alter_table.go b/dialect/pgdialect/alter_table.go index 654921472..d20f8c069 100644 --- a/dialect/pgdialect/alter_table.go +++ b/dialect/pgdialect/alter_table.go @@ -202,7 +202,7 @@ func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *mi got, want := colDef.From, colDef.To inspector := m.db.Dialect().(sqlschema.InspectorDialect) - if !inspector.EquivalentType(want, got) { + if !inspector.CompareType(want, got) { appendAlterColumn() b = append(b, " SET DATA TYPE "...) if b, err = want.AppendQuery(fmter, b); err != nil { diff --git a/dialect/pgdialect/sqltype.go b/dialect/pgdialect/sqltype.go index fcb9f8ebb..bacc00e86 100644 --- a/dialect/pgdialect/sqltype.go +++ b/dialect/pgdialect/sqltype.go @@ -125,7 +125,7 @@ var ( timestampTz = newAliases(sqltype.Timestamp, pgTypeTimestampTz, pgTypeTimestampWithTz) ) -func (d *Dialect) EquivalentType(col1, col2 sqlschema.Column) bool { +func (d *Dialect) CompareType(col1, col2 sqlschema.Column) bool { typ1, typ2 := strings.ToUpper(col1.GetSQLType()), strings.ToUpper(col2.GetSQLType()) if typ1 == typ2 { diff --git a/dialect/pgdialect/sqltype_test.go b/dialect/pgdialect/sqltype_test.go index 4f707a7e9..8181e599d 100644 --- a/dialect/pgdialect/sqltype_test.go +++ b/dialect/pgdialect/sqltype_test.go @@ -8,7 +8,7 @@ import ( "github.com/uptrace/bun/migrate/sqlschema" ) -func TestInspectorDialect_EquivalentType(t *testing.T) { +func TestInspectorDialect_CompareType(t *testing.T) { d := New() t.Run("common types", func(t *testing.T) { @@ -41,7 +41,7 @@ func TestInspectorDialect_EquivalentType(t *testing.T) { eq = " !~ " } t.Run(tt.typ1+eq+tt.typ2, func(t *testing.T) { - got := d.EquivalentType( + got := d.CompareType( &sqlschema.BaseColumn{SQLType: tt.typ1}, &sqlschema.BaseColumn{SQLType: tt.typ2}, ) @@ -77,7 +77,7 @@ func TestInspectorDialect_EquivalentType(t *testing.T) { }, } { t.Run(tt.name, func(t *testing.T) { - got := d.EquivalentType(&tt.col1, &tt.col2) + got := d.CompareType(&tt.col1, &tt.col2) require.Equal(t, tt.want, got) }) } diff --git a/internal/dbtest/inspect_test.go b/internal/dbtest/inspect_test.go index 9ecea49da..943846106 100644 --- a/internal/dbtest/inspect_test.go +++ b/internal/dbtest/inspect_test.go @@ -433,7 +433,7 @@ func cmpColumns( continue } - if !d.EquivalentType(wantCol, gotCol) { + if !d.CompareType(wantCol, gotCol) { errorf("sql types are not equivalent:\n\t(+want)\t%s\n\t(-got)\t%s", formatType(wantCol), formatType(gotCol)) } diff --git a/migrate/auto.go b/migrate/auto.go index be6954f01..e56fa23a0 100644 --- a/migrate/auto.go +++ b/migrate/auto.go @@ -155,7 +155,7 @@ func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, err return nil, err } am.dbInspector = dbInspector - am.diffOpts = append(am.diffOpts, withTypeEquivalenceFunc(db.Dialect().(sqlschema.InspectorDialect).EquivalentType)) + am.diffOpts = append(am.diffOpts, withCompareTypeFunc(db.Dialect().(sqlschema.InspectorDialect).CompareType)) dbMigrator, err := sqlschema.NewMigrator(db, am.schemaName) if err != nil { diff --git a/migrate/diff.go b/migrate/diff.go index facd47c74..42e55dcde 100644 --- a/migrate/diff.go +++ b/migrate/diff.go @@ -217,7 +217,7 @@ Drop: func newDetector(got, want sqlschema.Database, opts ...diffOption) *detector { cfg := &detectorConfig{ - EqType: func(c1, c2 sqlschema.Column) bool { + cmpType: func(c1, c2 sqlschema.Column) bool { return c1.GetSQLType() == c2.GetSQLType() && c1.GetVarcharLen() == c2.GetVarcharLen() }, } @@ -229,21 +229,21 @@ func newDetector(got, want sqlschema.Database, opts ...diffOption) *detector { current: got, target: want, refMap: newRefMap(got.GetForeignKeys()), - eqType: cfg.EqType, + cmpType: cfg.cmpType, } } type diffOption func(*detectorConfig) -func withTypeEquivalenceFunc(f TypeEquivalenceFunc) diffOption { +func withCompareTypeFunc(f CompareTypeFunc) diffOption { return func(cfg *detectorConfig) { - cfg.EqType = f + cfg.cmpType = f } } // detectorConfig controls how differences in the model states are resolved. type detectorConfig struct { - EqType TypeEquivalenceFunc + cmpType CompareTypeFunc } // detector may modify the passed database schemas, so it isn't safe to re-use them. @@ -257,11 +257,11 @@ type detector struct { changes changeset refMap refMap - // eqType determines column type equivalence. + // cmpType determines column type equivalence. // Default is direct comparison with '==' operator, which is inaccurate // due to the existence of dialect-specific type aliases. The caller // should pass a concrete InspectorDialect.EquuivalentType for robust comparison. - eqType TypeEquivalenceFunc + cmpType CompareTypeFunc } // canRename checks if t1 can be renamed to t2. @@ -270,7 +270,7 @@ func (d detector) canRename(t1, t2 sqlschema.Table) bool { } func (d detector) equalColumns(col1, col2 sqlschema.Column) bool { - return d.eqType(col1, col2) && + return d.cmpType(col1, col2) && col1.GetDefaultValue() == col2.GetDefaultValue() && col1.GetIsNullable() == col2.GetIsNullable() && col1.GetIsAutoIncrement() == col2.GetIsAutoIncrement() && @@ -279,7 +279,7 @@ func (d detector) equalColumns(col1, col2 sqlschema.Column) bool { func (d detector) makeTargetColDef(current, target sqlschema.Column) sqlschema.Column { // Avoid unneccessary type-change migrations if the types are equivalent. - if d.eqType(current, target) { + if d.cmpType(current, target) { target = &sqlschema.BaseColumn{ Name: target.GetName(), DefaultValue: target.GetDefaultValue(), @@ -294,10 +294,10 @@ func (d detector) makeTargetColDef(current, target sqlschema.Column) sqlschema.C return target } -type TypeEquivalenceFunc func(sqlschema.Column, sqlschema.Column) bool +type CompareTypeFunc func(sqlschema.Column, sqlschema.Column) bool // equalSignatures determines if two tables have the same "signature". -func equalSignatures(t1, t2 sqlschema.Table, eq TypeEquivalenceFunc) bool { +func equalSignatures(t1, t2 sqlschema.Table, eq CompareTypeFunc) bool { sig1 := newSignature(t1, eq) sig2 := newSignature(t2, eq) return sig1.Equals(sig2) @@ -311,10 +311,10 @@ type signature struct { // It helps to account for the fact that a table might have multiple columns that have the same type. underlying map[sqlschema.BaseColumn]int - eq TypeEquivalenceFunc + eq CompareTypeFunc } -func newSignature(t sqlschema.Table, eq TypeEquivalenceFunc) signature { +func newSignature(t sqlschema.Table, eq CompareTypeFunc) signature { s := signature{ underlying: make(map[sqlschema.BaseColumn]int), eq: eq, @@ -338,7 +338,7 @@ func (s *signature) scan(t sqlschema.Table) { } } -// getCount uses TypeEquivalenceFunc to find a column with the same (equivalent) SQL type +// getCount uses CompareTypeFunc to find a column with the same (equivalent) SQL type // and returns its count. Count 0 means there are no columns with of this type. func (s *signature) getCount(keyCol sqlschema.BaseColumn) (key sqlschema.BaseColumn, count int) { for col, cnt := range s.underlying { diff --git a/migrate/sqlschema/inspector.go b/migrate/sqlschema/inspector.go index 1532d036c..fc9af06fc 100644 --- a/migrate/sqlschema/inspector.go +++ b/migrate/sqlschema/inspector.go @@ -21,10 +21,10 @@ type InspectorDialect interface { // Use ApplyInspectorOptions to reduce boilerplate. NewInspector(db *bun.DB, options ...InspectorOption) Inspector - // EquivalentType returns true if col1 and co2 SQL types are equivalent, + // CompareType returns true if col1 and co2 SQL types are equivalent, // i.e. they might use dialect-specifc type aliases (SERIAL ~ SMALLINT) // or specify the same VARCHAR length differently (VARCHAR(255) ~ VARCHAR). - EquivalentType(Column, Column) bool + CompareType(Column, Column) bool } // InspectorConfig controls the scope of migration by limiting the objects Inspector should return.