Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Follow-up #926 #1058

Merged
merged 3 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dialect/pgdialect/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"github.com/uptrace/bun/schema"
)

func (d *Dialect) Migrator(db *bun.DB, schemaName string) sqlschema.Migrator {
func (d *Dialect) NewMigrator(db *bun.DB, schemaName string) sqlschema.Migrator {
return &migrator{db: db, schemaName: schemaName, BaseMigrator: sqlschema.NewBaseMigrator(db)}
}

Expand Down Expand Up @@ -202,7 +202,7 @@ func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *mi
got, want := colDef.From, colDef.To

inspector := m.db.Dialect().(sqlschema.InspectorDialect)
if !inspector.EquivalentType(want, got) {
if !inspector.CompareType(want, got) {
appendAlterColumn()
b = append(b, " SET DATA TYPE "...)
if b, err = want.AppendQuery(fmter, b); err != nil {
Expand Down
22 changes: 12 additions & 10 deletions dialect/pgdialect/inspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,40 +15,42 @@ type (
Column = sqlschema.BaseColumn
)

func (d *Dialect) Inspector(db *bun.DB, excludeTables ...string) sqlschema.Inspector {
return newInspector(db, excludeTables...)
func (d *Dialect) NewInspector(db *bun.DB, options ...sqlschema.InspectorOption) sqlschema.Inspector {
return newInspector(db, options...)
}

type Inspector struct {
db *bun.DB
excludeTables []string
sqlschema.InspectorConfig
db *bun.DB
}

var _ sqlschema.Inspector = (*Inspector)(nil)

func newInspector(db *bun.DB, excludeTables ...string) *Inspector {
return &Inspector{db: db, excludeTables: excludeTables}
func newInspector(db *bun.DB, options ...sqlschema.InspectorOption) *Inspector {
i := &Inspector{db: db}
sqlschema.ApplyInspectorOptions(&i.InspectorConfig, options...)
return i
}

func (in *Inspector) Inspect(ctx context.Context, schemaName string) (sqlschema.Database, error) {
func (in *Inspector) Inspect(ctx context.Context) (sqlschema.Database, error) {
dbSchema := Schema{
Tables: orderedmap.New[string, sqlschema.Table](),
ForeignKeys: make(map[sqlschema.ForeignKey]string),
}

exclude := in.excludeTables
exclude := in.ExcludeTables
if len(exclude) == 0 {
// Avoid getting NOT IN (NULL) if bun.In() is called with an empty slice.
exclude = []string{""}
}

var tables []*InformationSchemaTable
if err := in.db.NewRaw(sqlInspectTables, schemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil {
if err := in.db.NewRaw(sqlInspectTables, in.SchemaName, bun.In(exclude)).Scan(ctx, &tables); err != nil {
return dbSchema, err
}

var fks []*ForeignKey
if err := in.db.NewRaw(sqlInspectForeignKeys, schemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil {
if err := in.db.NewRaw(sqlInspectForeignKeys, in.SchemaName, bun.In(exclude), bun.In(exclude)).Scan(ctx, &fks); err != nil {
return dbSchema, err
}
dbSchema.ForeignKeys = make(map[sqlschema.ForeignKey]string, len(fks))
Expand Down
2 changes: 1 addition & 1 deletion dialect/pgdialect/sqltype.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ var (
timestampTz = newAliases(sqltype.Timestamp, pgTypeTimestampTz, pgTypeTimestampWithTz)
)

func (d *Dialect) EquivalentType(col1, col2 sqlschema.Column) bool {
func (d *Dialect) CompareType(col1, col2 sqlschema.Column) bool {
typ1, typ2 := strings.ToUpper(col1.GetSQLType()), strings.ToUpper(col2.GetSQLType())

if typ1 == typ2 {
Expand Down
6 changes: 3 additions & 3 deletions dialect/pgdialect/sqltype_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/uptrace/bun/migrate/sqlschema"
)

func TestInspectorDialect_EquivalentType(t *testing.T) {
func TestInspectorDialect_CompareType(t *testing.T) {
d := New()

t.Run("common types", func(t *testing.T) {
Expand Down Expand Up @@ -41,7 +41,7 @@ func TestInspectorDialect_EquivalentType(t *testing.T) {
eq = " !~ "
}
t.Run(tt.typ1+eq+tt.typ2, func(t *testing.T) {
got := d.EquivalentType(
got := d.CompareType(
&sqlschema.BaseColumn{SQLType: tt.typ1},
&sqlschema.BaseColumn{SQLType: tt.typ2},
)
Expand Down Expand Up @@ -77,7 +77,7 @@ func TestInspectorDialect_EquivalentType(t *testing.T) {
},
} {
t.Run(tt.name, func(t *testing.T) {
got := d.EquivalentType(&tt.col1, &tt.col2)
got := d.CompareType(&tt.col1, &tt.col2)
require.Equal(t, tt.want, got)
})
}
Expand Down
30 changes: 15 additions & 15 deletions internal/dbtest/inspect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
db.RegisterModel((*PublisherToJournalist)(nil))

dbInspector, err := sqlschema.NewInspector(db, migrationsTable, migrationLocksTable)
dbInspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(tt.schemaName), sqlschema.WithExcludeTables(migrationsTable, migrationLocksTable))
if err != nil {
t.Skip(err)
}
Expand All @@ -353,7 +353,7 @@ func TestDatabaseInspector_Inspect(t *testing.T) {
(*Article)(nil), // references Journalist and Publisher
)

got, err := dbInspector.Inspect(ctx, tt.schemaName)
got, err := dbInspector.Inspect(ctx)
require.NoError(t, err)

// State.FKs store their database names, which differ from dialect to dialect.
Expand Down Expand Up @@ -433,7 +433,7 @@ func cmpColumns(
continue
}

if !d.EquivalentType(wantCol, gotCol) {
if !d.CompareType(wantCol, gotCol) {
errorf("sql types are not equivalent:\n\t(+want)\t%s\n\t(-got)\t%s", formatType(wantCol), formatType(gotCol))
}

Expand Down Expand Up @@ -523,7 +523,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema()))

want := orderedmap.New[string, sqlschema.Column](orderedmap.WithInitialData(
orderedmap.Pair[string, sqlschema.Column]{
Expand All @@ -542,7 +542,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {
},
))

got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema())
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -562,7 +562,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema()))

want := orderedmap.New[string, sqlschema.Column](orderedmap.WithInitialData(
orderedmap.Pair[string, sqlschema.Column]{
Expand All @@ -587,7 +587,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {
},
))

got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema())
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -606,7 +606,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema()))

want := &sqlschema.BaseTable{
Name: "models",
Expand All @@ -616,7 +616,7 @@ func TestBunModelInspector_Inspect(t *testing.T) {
},
}

got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema())
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -635,10 +635,10 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(dialect.DefaultSchema()))
want := sqlschema.NewColumns("id", "email")

got, err := inspector.Inspect(context.Background(), dialect.DefaultSchema())
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -658,9 +658,9 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*Model)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName("custom_schema"))

got, err := inspector.Inspect(context.Background(), "custom_schema")
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand All @@ -683,9 +683,9 @@ func TestBunModelInspector_Inspect(t *testing.T) {

tables := schema.NewTables(dialect)
tables.Register((*KeepMe)(nil), (*LoseMe)(nil))
inspector := sqlschema.NewBunModelInspector(tables)
inspector := sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName("want"))

got, err := inspector.Inspect(context.Background(), "want")
got, err := inspector.Inspect(context.Background())
require.NoError(t, err)

gotTables := got.GetTables()
Expand Down
14 changes: 8 additions & 6 deletions internal/dbtest/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,19 +217,21 @@ func newAutoMigratorOrSkip(tb testing.TB, db *bun.DB, opts ...migrate.AutoMigrat
// and fail if the inspector cannot successfully retrieve database state.
func inspectDbOrSkip(tb testing.TB, db *bun.DB, schemaName ...string) func(context.Context) sqlschema.BaseDatabase {
tb.Helper()
// AutoMigrator excludes these tables by default, but here we need to do this explicitly.
inspector, err := sqlschema.NewInspector(db, migrationsTable, migrationLocksTable)
if err != nil {
tb.Skip(err)
}

// For convenience, schemaName is an optional parameter in this function.
inspectSchema := db.Dialect().DefaultSchema()
if len(schemaName) > 0 {
inspectSchema = schemaName[0]
}

// AutoMigrator excludes these tables by default, but here we need to do this explicitly.
inspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(inspectSchema), sqlschema.WithExcludeTables(migrationsTable, migrationLocksTable))
if err != nil {
tb.Skip(err)
}

return func(ctx context.Context) sqlschema.BaseDatabase {
state, err := inspector.Inspect(ctx, inspectSchema)
state, err := inspector.Inspect(ctx)
require.NoError(tb, err)
return state.(sqlschema.BaseDatabase)
}
Expand Down
12 changes: 7 additions & 5 deletions migrate/auto.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ func WithModel(models ...interface{}) AutoMigratorOption {
// WithExcludeTable tells the AutoMigrator to ignore a table in the database.
// This prevents AutoMigrator from dropping tables which may exist in the schema
// but which are not used by the application.
//
// Do not exclude tables included via WithModel, as BunModelInspector ignores this setting.
func WithExcludeTable(tables ...string) AutoMigratorOption {
return func(m *AutoMigrator) {
m.excludeTables = append(m.excludeTables, tables...)
Expand Down Expand Up @@ -148,12 +150,12 @@ func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, err
}
am.excludeTables = append(am.excludeTables, am.table, am.locksTable)

dbInspector, err := sqlschema.NewInspector(db, am.excludeTables...)
dbInspector, err := sqlschema.NewInspector(db, sqlschema.WithSchemaName(am.schemaName), sqlschema.WithExcludeTables(am.excludeTables...))
if err != nil {
return nil, err
}
am.dbInspector = dbInspector
am.diffOpts = append(am.diffOpts, withTypeEquivalenceFunc(db.Dialect().(sqlschema.InspectorDialect).EquivalentType))
am.diffOpts = append(am.diffOpts, withCompareTypeFunc(db.Dialect().(sqlschema.InspectorDialect).CompareType))

dbMigrator, err := sqlschema.NewMigrator(db, am.schemaName)
if err != nil {
Expand All @@ -163,20 +165,20 @@ func NewAutoMigrator(db *bun.DB, opts ...AutoMigratorOption) (*AutoMigrator, err

tables := schema.NewTables(db.Dialect())
tables.Register(am.includeModels...)
am.modelInspector = sqlschema.NewBunModelInspector(tables)
am.modelInspector = sqlschema.NewBunModelInspector(tables, sqlschema.WithSchemaName(am.schemaName))

return am, nil
}

func (am *AutoMigrator) plan(ctx context.Context) (*changeset, error) {
var err error

got, err := am.dbInspector.Inspect(ctx, am.schemaName)
got, err := am.dbInspector.Inspect(ctx)
if err != nil {
return nil, err
}

want, err := am.modelInspector.Inspect(ctx, am.schemaName)
want, err := am.modelInspector.Inspect(ctx)
if err != nil {
return nil, err
}
Expand Down
Loading
Loading