Skip to content

Commit

Permalink
refactor: remove superficial sqlschema.Operation interface
Browse files Browse the repository at this point in the history
Each dialect has to type-switch the operation before building a query for it.
Since the migrator knows the concrete type of each operation, they are free
to provide FQN in any form. Using schema.FQN field from the start simplifies
the data structure later.

Empty inteface is better that a superficial one.
  • Loading branch information
bevzzz committed Oct 21, 2024
1 parent ee2db8b commit 4c8829d
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 152 deletions.
50 changes: 20 additions & 30 deletions dialect/pgdialect/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type migrator struct {

var _ sqlschema.Migrator = (*migrator)(nil)

func (m *migrator) Apply(ctx context.Context, changes ...sqlschema.Operation) error {
func (m *migrator) Apply(ctx context.Context, changes ...interface{}) error {
if len(changes) == 0 {
return nil
}
Expand All @@ -41,17 +41,17 @@ func (m *migrator) Apply(ctx context.Context, changes ...sqlschema.Operation) er

switch change := change.(type) {
case *migrate.CreateTable:
log.Printf("create table %q", change.Name)
log.Printf("create table %q", change.FQN.Table)
err = m.CreateTable(ctx, change.Model)
if err != nil {
return fmt.Errorf("apply changes: create table %s: %w", change.FQN(), err)
return fmt.Errorf("apply changes: create table %s: %w", change.FQN, err)
}
continue
case *migrate.DropTable:
log.Printf("drop table %q", change.Name)
err = m.DropTable(ctx, change.Schema, change.Name)
log.Printf("drop table %q", change.FQN.Table)
err = m.DropTable(ctx, change.FQN)
if err != nil {
return fmt.Errorf("apply changes: drop table %s: %w", change.FQN(), err)
return fmt.Errorf("apply changes: drop table %s: %w", change.FQN, err)
}
continue
case *migrate.RenameTable:
Expand Down Expand Up @@ -88,35 +88,29 @@ func (m *migrator) Apply(ctx context.Context, changes ...sqlschema.Operation) er

func (m *migrator) renameTable(fmter schema.Formatter, b []byte, rename *migrate.RenameTable) (_ []byte, err error) {
b = append(b, "ALTER TABLE "...)
fqn := rename.FQN()
if b, err = fqn.AppendQuery(fmter, b); err != nil {
return b, err
}
b, _ = rename.FQN.AppendQuery(fmter, b)

b = append(b, " RENAME TO "...)
b = fmter.AppendIdent(b, rename.NewName)
b = fmter.AppendName(b, rename.NewName)
return b, nil
}

func (m *migrator) renameColumn(fmter schema.Formatter, b []byte, rename *migrate.RenameColumn) (_ []byte, err error) {
b = append(b, "ALTER TABLE "...)
fqn := rename.FQN()
if b, err = fqn.AppendQuery(fmter, b); err != nil {
return b, err
}
b, _ = rename.FQN.AppendQuery(fmter, b)

b = append(b, " RENAME COLUMN "...)
b = fmter.AppendIdent(b, rename.OldName)
b = fmter.AppendName(b, rename.OldName)

b = append(b, " TO "...)
b = fmter.AppendIdent(b, rename.NewName)
b = fmter.AppendName(b, rename.NewName)

return b, nil
}

func (m *migrator) addColumn(fmter schema.Formatter, b []byte, add *migrate.AddColumn) (_ []byte, err error) {
b = append(b, "ALTER TABLE "...)
fqn := add.FQN()
b, _ = fqn.AppendQuery(fmter, b)
b, _ = add.FQN.AppendQuery(fmter, b)

b = append(b, " ADD COLUMN "...)
b = fmter.AppendName(b, add.Column)
Expand All @@ -129,8 +123,7 @@ func (m *migrator) addColumn(fmter schema.Formatter, b []byte, add *migrate.AddC

func (m *migrator) dropColumn(fmter schema.Formatter, b []byte, drop *migrate.DropColumn) (_ []byte, err error) {
b = append(b, "ALTER TABLE "...)
fqn := drop.FQN()
b, _ = fqn.AppendQuery(fmter, b)
b, _ = drop.FQN.AppendQuery(fmter, b)

b = append(b, " DROP COLUMN "...)
b = fmter.AppendName(b, drop.Column)
Expand All @@ -146,10 +139,10 @@ func (m *migrator) renameConstraint(fmter schema.Formatter, b []byte, rename *mi
}

b = append(b, " RENAME CONSTRAINT "...)
b = fmter.AppendIdent(b, rename.OldName)
b = fmter.AppendName(b, rename.OldName)

b = append(b, " TO "...)
b = fmter.AppendIdent(b, rename.NewName)
b = fmter.AppendName(b, rename.NewName)

return b, nil
}
Expand All @@ -162,7 +155,7 @@ func (m *migrator) dropContraint(fmter schema.Formatter, b []byte, drop *migrate
}

b = append(b, " DROP CONSTRAINT "...)
b = fmter.AppendIdent(b, drop.ConstraintName)
b = fmter.AppendName(b, drop.ConstraintName)

return b, nil
}
Expand All @@ -175,7 +168,7 @@ func (m *migrator) addForeignKey(fmter schema.Formatter, b []byte, add *migrate.
}

b = append(b, " ADD CONSTRAINT "...)
b = fmter.AppendIdent(b, add.ConstraintName)
b = fmter.AppendName(b, add.ConstraintName)

b = append(b, " FOREIGN KEY ("...)
if b, err = add.FK.From.Column.Safe().AppendQuery(fmter, b); err != nil {
Expand All @@ -200,10 +193,7 @@ func (m *migrator) addForeignKey(fmter schema.Formatter, b []byte, add *migrate.

func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *migrate.ChangeColumnType) (_ []byte, err error) {
b = append(b, "ALTER TABLE "...)
fqn := colDef.FQN()
if b, err = fqn.AppendQuery(fmter, b); err != nil {
return b, err
}
b, _ = colDef.FQN.AppendQuery(fmter, b)

// alterColumn never re-assigns err, so there is no need to check for err != nil after calling it
var i int
Expand All @@ -212,7 +202,7 @@ func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *mi
b = append(b, ","...)
}
b = append(b, " ALTER COLUMN "...)
b = fmter.AppendIdent(b, colDef.Column)
b = fmter.AppendName(b, colDef.Column)
i++
}

Expand Down
60 changes: 60 additions & 0 deletions internal/dbtest/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ func TestAutoMigrator_Run(t *testing.T) {
{testChangeColumnType_AutoCast},
{testIdentity},
{testAddDropColumn},
// {testUnique},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -753,6 +754,65 @@ func testAddDropColumn(t *testing.T, db *bun.DB) {
cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables)
}

func testUnique(t *testing.T, db *bun.DB) {
type TableBefore struct {
bun.BaseModel `bun:"table:table"`
FirstName string `bun:"first_name,unique:full_name"`
LastName string `bun:"last_name,unique:full_name"`
Birthday string `bun:"birthday,unique"`
}

type TableAfter struct {
bun.BaseModel `bun:"table:table"`
FirstName string `bun:"first_name,unique:full_name"`
MiddleName string `bun:"middle_name,unique:full_name"` // extend "full_name" unique group
LastName string `bun:"last_name,unique:full_name"`
Birthday string `bun:"birthday"` // doesn't have to be unique any more
Email string `bun:"email,unique"` // new column, unique
}

wantTables := []sqlschema.Table{
{
Schema: db.Dialect().DefaultSchema(),
Name: "table",
Columns: map[string]sqlschema.Column{
"first_name": {
SQLType: sqltype.VarChar,
IsNullable: true,
},
"middle_name": {
SQLType: sqltype.VarChar,
IsNullable: true,
},
"last_name": {
SQLType: sqltype.VarChar,
IsNullable: true,
},
"birthday": {
SQLType: sqltype.VarChar,
IsNullable: true,
},
"email": {
SQLType: sqltype.VarChar,
IsNullable: true,
},
},
},
}

ctx := context.Background()
inspect := inspectDbOrSkip(t, db)
mustResetModel(t, ctx, db, (*TableBefore)(nil))
m := newAutoMigrator(t, db, migrate.WithModel((*TableAfter)(nil)))

// Act
err := m.Run(ctx)
require.NoError(t, err)

// Assert
state := inspect(ctx)
cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables)
}

// // TODO: rewrite these tests into AutoMigrator tests, Diff should be moved to migrate/internal package
// func TestDiff(t *testing.T) {
Expand Down
28 changes: 11 additions & 17 deletions migrate/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/uptrace/bun"
"github.com/uptrace/bun/migrate/sqlschema"
"github.com/uptrace/bun/schema"
)

// Diff calculates the diff between the current database schema and the target state.
Expand All @@ -29,8 +30,7 @@ AddedLoop:
for _, removed := range removedTables.Values() {
if d.canRename(removed, added) {
d.changes.Add(&RenameTable{
Schema: removed.Schema,
OldName: removed.Name,
FQN: schema.FQN{removed.Schema, removed.Name},
NewName: added.Name,
})

Expand All @@ -52,9 +52,8 @@ AddedLoop:
}
// If a new table did not appear because of the rename operation, then it must've been created.
d.changes.Add(&CreateTable{
Schema: added.Schema,
Name: added.Name,
Model: added.Model,
FQN: schema.FQN{added.Schema, added.Name},
Model: added.Model,
})
created.Add(added)
}
Expand All @@ -63,8 +62,7 @@ AddedLoop:
dropped := currentTables.Sub(targetTables)
for _, t := range dropped.Values() {
d.changes.Add(&DropTable{
Schema: t.Schema,
Name: t.Name,
FQN: schema.FQN{t.Schema, t.Name},
})
}

Expand Down Expand Up @@ -144,9 +142,9 @@ func (c *changeset) Add(op ...Operation) {
// Func creates a MigrationFunc that applies all operations all the changeset.
func (c *changeset) Func(m sqlschema.Migrator) MigrationFunc {
return func(ctx context.Context, db *bun.DB) error {
var operations []sqlschema.Operation
var operations []interface{}
for _, op := range c.operations {
operations = append(operations, op.(sqlschema.Operation))
operations = append(operations, op.(interface{}))
}
return m.Apply(ctx, operations...)
}
Expand Down Expand Up @@ -349,8 +347,7 @@ ChangedRenamed:
if cCol, ok := current.Columns[tName]; ok {
if checkType && !d.equalColumns(cCol, tCol) {
d.changes.Add(&ChangeColumnType{
Schema: target.Schema,
Table: target.Name,
FQN: schema.FQN{target.Schema, target.Name},
Column: tName,
From: cCol,
To: d.makeTargetColDef(cCol, tCol),
Expand All @@ -367,8 +364,7 @@ ChangedRenamed:
continue
}
d.changes.Add(&RenameColumn{
Schema: target.Schema,
Table: target.Name,
FQN: schema.FQN{target.Schema, target.Name},
OldName: cName,
NewName: tName,
})
Expand All @@ -379,8 +375,7 @@ ChangedRenamed:
}

d.changes.Add(&AddColumn{
Schema: target.Schema,
Table: target.Name,
FQN: schema.FQN{target.Schema, target.Name},
Column: tName,
ColDef: tCol,
})
Expand All @@ -390,8 +385,7 @@ ChangedRenamed:
for cName, cCol := range current.Columns {
if _, keep := target.Columns[cName]; !keep {
d.changes.Add(&DropColumn{
Schema: target.Schema,
Table: target.Name,
FQN: schema.FQN{target.Schema, target.Name},
Column: cName,
ColDef: cCol,
})
Expand Down
Loading

0 comments on commit 4c8829d

Please sign in to comment.