Skip to content

Commit

Permalink
feat: add and drop columns
Browse files Browse the repository at this point in the history
- New operations: AddColumn and DropColumn
- Fixed cmpColumns to find 'extra' columns
- Refactored alter query builder in pgdialect
  • Loading branch information
bevzzz committed Oct 21, 2024
1 parent 9faa55b commit ee2db8b
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 27 deletions.
64 changes: 42 additions & 22 deletions dialect/pgdialect/alter_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ func (m *migrator) Apply(ctx context.Context, changes ...sqlschema.Operation) er

switch change := change.(type) {
case *migrate.CreateTable:
log.Printf("create table %q", change.Name)
err = m.CreateTable(ctx, change.Model)
if err != nil {
return fmt.Errorf("apply changes: create table %s: %w", change.FQN(), err)
}
continue
case *migrate.DropTable:
log.Printf("drop table %q", change.Name)
err = m.DropTable(ctx, change.Schema, change.Name)
if err != nil {
return fmt.Errorf("apply changes: drop table %s: %w", change.FQN(), err)
Expand All @@ -56,6 +58,10 @@ func (m *migrator) Apply(ctx context.Context, changes ...sqlschema.Operation) er
b, err = m.renameTable(fmter, b, change)
case *migrate.RenameColumn:
b, err = m.renameColumn(fmter, b, change)
case *migrate.AddColumn:
b, err = m.addColumn(fmter, b, change)
case *migrate.DropColumn:
b, err = m.dropColumn(fmter, b, change)
case *migrate.DropConstraint:
b, err = m.dropContraint(fmter, b, change)
case *migrate.AddForeignKey:
Expand Down Expand Up @@ -87,9 +93,7 @@ func (m *migrator) renameTable(fmter schema.Formatter, b []byte, rename *migrate
return b, err
}
b = append(b, " RENAME TO "...)
if b, err = bun.Ident(rename.NewName).AppendQuery(fmter, b); err != nil {
return b, err
}
b = fmter.AppendIdent(b, rename.NewName)
return b, nil
}

Expand All @@ -101,14 +105,36 @@ func (m *migrator) renameColumn(fmter schema.Formatter, b []byte, rename *migrat
}

b = append(b, " RENAME COLUMN "...)
if b, err = bun.Ident(rename.OldName).AppendQuery(fmter, b); err != nil {
return b, err
}
b = fmter.AppendIdent(b, rename.OldName)

b = append(b, " TO "...)
if b, err = bun.Ident(rename.NewName).AppendQuery(fmter, b); err != nil {
return b, err
}
b = fmter.AppendIdent(b, rename.NewName)

return b, nil
}

func (m *migrator) addColumn(fmter schema.Formatter, b []byte, add *migrate.AddColumn) (_ []byte, err error) {
b = append(b, "ALTER TABLE "...)
fqn := add.FQN()
b, _ = fqn.AppendQuery(fmter, b)

b = append(b, " ADD COLUMN "...)
b = fmter.AppendName(b, add.Column)
b = append(b, " "...)

b, _ = add.ColDef.AppendQuery(fmter, b)

return b, nil
}

func (m *migrator) dropColumn(fmter schema.Formatter, b []byte, drop *migrate.DropColumn) (_ []byte, err error) {
b = append(b, "ALTER TABLE "...)
fqn := drop.FQN()
b, _ = fqn.AppendQuery(fmter, b)

b = append(b, " DROP COLUMN "...)
b = fmter.AppendName(b, drop.Column)

return b, nil
}

Expand All @@ -120,14 +146,11 @@ func (m *migrator) renameConstraint(fmter schema.Formatter, b []byte, rename *mi
}

b = append(b, " RENAME CONSTRAINT "...)
if b, err = bun.Ident(rename.OldName).AppendQuery(fmter, b); err != nil {
return b, err
}
b = fmter.AppendIdent(b, rename.OldName)

b = append(b, " TO "...)
if b, err = bun.Ident(rename.NewName).AppendQuery(fmter, b); err != nil {
return b, err
}
b = fmter.AppendIdent(b, rename.NewName)

return b, nil
}

Expand All @@ -139,9 +162,8 @@ func (m *migrator) dropContraint(fmter schema.Formatter, b []byte, drop *migrate
}

b = append(b, " DROP CONSTRAINT "...)
if b, err = bun.Ident(drop.ConstraintName).AppendQuery(fmter, b); err != nil {
return b, err
}
b = fmter.AppendIdent(b, drop.ConstraintName)

return b, nil
}

Expand All @@ -153,9 +175,7 @@ func (m *migrator) addForeignKey(fmter schema.Formatter, b []byte, add *migrate.
}

b = append(b, " ADD CONSTRAINT "...)
if b, err = bun.Ident(add.ConstraintName).AppendQuery(fmter, b); err != nil {
return b, err
}
b = fmter.AppendIdent(b, add.ConstraintName)

b = append(b, " FOREIGN KEY ("...)
if b, err = add.FK.From.Column.Safe().AppendQuery(fmter, b); err != nil {
Expand Down Expand Up @@ -192,7 +212,7 @@ func (m *migrator) changeColumnType(fmter schema.Formatter, b []byte, colDef *mi
b = append(b, ","...)
}
b = append(b, " ALTER COLUMN "...)
b, _ = bun.Ident(colDef.Column).AppendQuery(fmter, b)
b = fmter.AppendIdent(b, colDef.Column)
i++
}

Expand Down
21 changes: 20 additions & 1 deletion internal/dbtest/inspect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dbtest_test
import (
"context"
"fmt"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -306,14 +307,17 @@ func cmpTables(tb testing.TB, d sqlschema.InspectorDialect, want, got []sqlschem
}

func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, want, got map[string]sqlschema.Column) {
tb.Helper()
var errs []string

var missing []string
for colName, wantCol := range want {
errorf := func(format string, args ...interface{}) {
errs = append(errs, fmt.Sprintf("[%s.%s] "+format, append([]interface{}{tableName, colName}, args...)...))
}
gotCol, ok := got[colName]
if !ok {
errorf("column is missing")
missing = append(missing, colName)
continue
}

Expand All @@ -338,6 +342,21 @@ func cmpColumns(tb testing.TB, d sqlschema.InspectorDialect, tableName string, w
}
}

if len(missing) > 0 {
errs = append(errs, fmt.Sprintf("%q has missing columns: %q", tableName, strings.Join(missing, "\", \"")))
}

var extra []string
for colName := range got {
if _, ok := want[colName]; !ok {
extra = append(extra, colName)
}
}

if len(extra) > 0 {
errs = append(errs, fmt.Sprintf("%q has extra columns: %q", tableName, strings.Join(extra, "\", \"")))
}

for _, errMsg := range errs {
tb.Error(errMsg)
}
Expand Down
46 changes: 46 additions & 0 deletions internal/dbtest/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ func TestAutoMigrator_Run(t *testing.T) {
{testRenameColumnRenamesFK},
{testChangeColumnType_AutoCast},
{testIdentity},
{testAddDropColumn},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand Down Expand Up @@ -708,6 +709,51 @@ func testIdentity(t *testing.T, db *bun.DB) {
cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables)
}

func testAddDropColumn(t *testing.T, db *bun.DB) {
type TableBefore struct {
bun.BaseModel `bun:"table:table"`
DoNotTouch string `bun:"do_not_touch"`
DropMe string `bun:"dropme"`
}

type TableAfter struct {
bun.BaseModel `bun:"table:table"`
DoNotTouch string `bun:"do_not_touch"`
AddMe bool `bun:"addme"`
}

wantTables := []sqlschema.Table{
{
Schema: db.Dialect().DefaultSchema(),
Name: "table",
Columns: map[string]sqlschema.Column{
"do_not_touch": {
SQLType: sqltype.VarChar,
IsNullable: true,
},
"addme": {
SQLType: sqltype.Boolean,
IsNullable: true,
},
},
},
}

ctx := context.Background()
inspect := inspectDbOrSkip(t, db)
mustResetModel(t, ctx, db, (*TableBefore)(nil))
m := newAutoMigrator(t, db, migrate.WithModel((*TableAfter)(nil)))

// Act
err := m.Run(ctx)
require.NoError(t, err)

// Assert
state := inspect(ctx)
cmpTables(t, db.Dialect().(sqlschema.InspectorDialect), wantTables, state.Tables)
}


// // TODO: rewrite these tests into AutoMigrator tests, Diff should be moved to migrate/internal package
// func TestDiff(t *testing.T) {
// type Journal struct {
Expand Down
34 changes: 30 additions & 4 deletions migrate/diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,12 @@ func (d *detector) makeTargetColDef(current, target sqlschema.Column) sqlschema.

// detechColumnChanges finds renamed columns and, if checkType == true, columns with changed type.
func (d *detector) detectColumnChanges(current, target sqlschema.Table, checkType bool) {
ChangedRenamed:
for tName, tCol := range target.Columns {
// This column exists in the database, so it hasn't been renamed.

// This column exists in the database, so it hasn't been renamed, dropped, or added.
// Still, we should not delete(columns, thisColumn), because later we will need to
// check that we do not try to rename a column to an already a name that already exists.
if cCol, ok := current.Columns[tName]; ok {
if checkType && !d.equalColumns(cCol, tCol) {
d.changes.Add(&ChangeColumnType{
Expand All @@ -351,13 +355,15 @@ func (d *detector) detectColumnChanges(current, target sqlschema.Table, checkTyp
From: cCol,
To: d.makeTargetColDef(cCol, tCol),
})
// TODO: Can I delete (current.Column, tName) then? Because if it's type has changed, it will never match in the line 343.
}
continue
}

// Column tName does not exist in the database -- it's been either renamed or added.
// Find renamed columns first.
for cName, cCol := range current.Columns {
if _, keep := target.Columns[cName]; keep || !d.equalColumns(tCol, cCol) {
// Cannot rename if a column with this name already exists or the types differ.
if _, exists := current.Columns[tName]; exists || !d.equalColumns(tCol, cCol) {
continue
}
d.changes.Add(&RenameColumn{
Expand All @@ -368,7 +374,27 @@ func (d *detector) detectColumnChanges(current, target sqlschema.Table, checkTyp
})
delete(current.Columns, cName) // no need to check this column again
d.refMap.UpdateC(sqlschema.C(target.Schema, target.Name, cName), tName)
break

continue ChangedRenamed
}

d.changes.Add(&AddColumn{
Schema: target.Schema,
Table: target.Name,
Column: tName,
ColDef: tCol,
})
}

// Drop columns which do not exist in the target schema and were not renamed.
for cName, cCol := range current.Columns {
if _, keep := target.Columns[cName]; !keep {
d.changes.Add(&DropColumn{
Schema: target.Schema,
Table: target.Name,
Column: cName,
ColDef: cCol,
})
}
}
}
Expand Down
79 changes: 79 additions & 0 deletions migrate/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,84 @@ func (op *RenameColumn) DependsOn(another Operation) bool {
return ok && rt.Schema == op.Schema && rt.NewName == op.Table
}

type AddColumn struct {
Schema string
Table string
Column string
ColDef sqlschema.Column
}

var _ Operation = (*AddColumn)(nil)
var _ sqlschema.Operation = (*AddColumn)(nil)

func (op *AddColumn) FQN() schema.FQN {
return schema.FQN{
Schema: op.Schema,
Table: op.Table,
}
}

func (op *AddColumn) GetReverse() Operation {
return &DropColumn{
Schema: op.Schema,
Table: op.Table,
Column: op.Column,
}
}

type DropColumn struct {
Schema string
Table string
Column string
ColDef sqlschema.Column
}

var _ Operation = (*DropColumn)(nil)
var _ sqlschema.Operation = (*DropColumn)(nil)

func (op *DropColumn) FQN() schema.FQN {
return schema.FQN{
Schema: op.Schema,
Table: op.Table,
}
}

func (op *DropColumn) GetReverse() Operation {
return &AddColumn{
Schema: op.Schema,
Table: op.Table,
Column: op.Column,
ColDef: op.ColDef,
}
}

func (op *DropColumn) DependsOn(another Operation) bool {
// TODO: refactor
if dc, ok := another.(*DropConstraint); ok {
var fCol bool
fCols := dc.FK.From.Column.Split()
for _, c := range fCols {
if c == op.Column {
fCol = true
break
}
}

var tCol bool
tCols := dc.FK.To.Column.Split()
for _, c := range tCols {
if c == op.Column {
tCol = true
break
}
}

return (dc.FK.From.Schema == op.Schema && dc.FK.From.Table == op.Table && fCol) ||
(dc.FK.To.Schema == op.Schema && dc.FK.To.Table == op.Table && tCol)
}
return false
}

// RenameConstraint.
type RenameConstraint struct {
FK sqlschema.FK
Expand Down Expand Up @@ -168,6 +246,7 @@ func (op *AddForeignKey) FQN() schema.FQN {
func (op *AddForeignKey) DependsOn(another Operation) bool {
switch another := another.(type) {
case *RenameTable:
// TODO: provide some sort of "DependsOn" method for FK
return another.Schema == op.FK.From.Schema && another.NewName == op.FK.From.Table
case *CreateTable:
return (another.Schema == op.FK.To.Schema && another.Name == op.FK.To.Table) || // either it's the referencing one
Expand Down

0 comments on commit ee2db8b

Please sign in to comment.