From 48ced75d1d8d8aab844ab29787ae97337095b8e1 Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sat, 19 Feb 2022 23:42:20 +0800 Subject: [PATCH] Improve support for AutoMigrate --- migrator/column_type.go | 4 ++-- migrator/migrator.go | 24 +++++++++++++++++++++ tests/go.mod | 10 ++++----- tests/migrate_test.go | 47 ++++++++++++++++++++++++++++++----------- 4 files changed, 66 insertions(+), 19 deletions(-) diff --git a/migrator/column_type.go b/migrator/column_type.go index eb8d1b7f8..cc1331b92 100644 --- a/migrator/column_type.go +++ b/migrator/column_type.go @@ -11,7 +11,7 @@ type ColumnType struct { NameValue sql.NullString DataTypeValue sql.NullString ColumnTypeValue sql.NullString - PrimayKeyValue sql.NullBool + PrimaryKeyValue sql.NullBool UniqueValue sql.NullBool AutoIncrementValue sql.NullBool LengthValue sql.NullInt64 @@ -51,7 +51,7 @@ func (ct ColumnType) ColumnType() (columnType string, ok bool) { // PrimaryKey returns the column is primary key or not. func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) { - return ct.PrimayKeyValue.Bool, ct.PrimayKeyValue.Valid + return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid } // AutoIncrement returns the column is auto increment or not. diff --git a/migrator/migrator.go b/migrator/migrator.go index 9695f3129..a50bb3ff8 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -436,6 +436,30 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy } } + // check unique + if unique, ok := columnType.Unique(); ok && unique != field.Unique { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + + // check default value + if v, ok := columnType.DefaultValue(); ok && v != field.DefaultValue { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + + // check comment + if comment, ok := columnType.Comment(); ok && comment != field.Comment { + // not primary key + if !field.PrimaryKey { + alterColumn = true + } + } + if alterColumn && !field.IgnoreMigration { return m.DB.Migrator().AlterColumn(value, field.Name) } diff --git a/tests/go.mod b/tests/go.mod index 0cd036371..1c1fb2389 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -9,11 +9,11 @@ require ( github.com/lib/pq v1.10.4 github.com/mattn/go-sqlite3 v1.14.11 // indirect golang.org/x/crypto v0.0.0-20220214200702-86341886e292 // indirect - gorm.io/driver/mysql v1.3.0 - gorm.io/driver/postgres v1.3.0 - gorm.io/driver/sqlite v1.3.0 - gorm.io/driver/sqlserver v1.3.0 - gorm.io/gorm v1.22.5 + gorm.io/driver/mysql v1.3.1 + gorm.io/driver/postgres v1.3.1 + gorm.io/driver/sqlite v1.3.1 + gorm.io/driver/sqlserver v1.3.1 + gorm.io/gorm v1.23.0 ) replace gorm.io/gorm => ../ diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 5e9c01fa8..94f562b47 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -45,7 +45,7 @@ func TestMigrate(t *testing.T) { for _, m := range allModels { if !DB.Migrator().HasTable(m) { - t.Fatalf("Failed to create table for %#v---", m) + t.Fatalf("Failed to create table for %#v", m) } } @@ -313,15 +313,16 @@ func TestMigrateIndexes(t *testing.T) { } func TestMigrateColumns(t *testing.T) { - fullSupported := map[string]bool{"sqlite": true, "mysql": true, "postgres": true, "sqlserver": true}[DB.Dialector.Name()] sqlite := DB.Dialector.Name() == "sqlite" sqlserver := DB.Dialector.Name() == "sqlserver" type ColumnStruct struct { gorm.Model - Name string - Age int `gorm:"default:18;comment:my age"` - Code string `gorm:"unique"` + Name string + Age int `gorm:"default:18;comment:my age"` + Code string `gorm:"unique;comment:my code;"` + Code2 string + Code3 string `gorm:"unique"` } DB.Migrator().DropTable(&ColumnStruct{}) @@ -332,13 +333,20 @@ func TestMigrateColumns(t *testing.T) { type ColumnStruct2 struct { gorm.Model - Name string `gorm:"size:100"` + Name string `gorm:"size:100"` + Code string `gorm:"unique;comment:my code2;default:hello"` + Code2 string `gorm:"unique"` + // Code3 string } - if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct2{}, "Name"); err != nil { + if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil { t.Fatalf("no error should happened when alter column, but got %v", err) } + if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil { + t.Fatalf("no error should happened when auto migrate column, but got %v", err) + } + if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { t.Fatalf("no error should returns for ColumnTypes") } else { @@ -348,7 +356,7 @@ func TestMigrateColumns(t *testing.T) { for _, columnType := range columnTypes { switch columnType.Name() { case "id": - if v, ok := columnType.PrimaryKey(); (fullSupported || ok) && !v { + if v, ok := columnType.PrimaryKey(); !ok || !v { t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) } case "name": @@ -356,20 +364,35 @@ func TestMigrateColumns(t *testing.T) { if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } - if length, ok := columnType.Length(); ((fullSupported && !sqlite) || ok) && length != 100 { + if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) { t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) } case "age": - if v, ok := columnType.DefaultValue(); (fullSupported || ok) && v != "18" { + if v, ok := columnType.DefaultValue(); !ok || v != "18" { t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) } - if v, ok := columnType.Comment(); ((fullSupported && !sqlite && !sqlserver) || ok) && v != "my age" { + if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") { t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) } case "code": - if v, ok := columnType.Unique(); (fullSupported || ok) && !v { + if v, ok := columnType.Unique(); !ok || !v { t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) } + if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { + t.Fatalf("column code default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { + t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code2": + if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) { + t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + } + case "code3": + // TODO + // if v, ok := columnType.Unique(); !ok || v { + // t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) + // } } } }