From d3211908a030169184801800ba74a3a3d93ea6ea Mon Sep 17 00:00:00 2001 From: Jason Lee Date: Mon, 25 Oct 2021 11:26:44 +0800 Subject: [PATCH] Refactor ParseWithSchemaTable method and improve test. (#4789) * Refactor ParseWithSchemaTable method and improve test. * Fix schema.ParseWithSchemaTable method for only use schemaTable in migrator and improve test. * Rename `schemaTable` to `specialTableName` for clearly argument. --- migrator/migrator.go | 2 +- schema/schema.go | 44 ++++++++++++++++++++++++------------------- statement.go | 6 +++++- tests/migrate_test.go | 33 ++++++++++++++++++++------------ 4 files changed, 52 insertions(+), 33 deletions(-) diff --git a/migrator/migrator.go b/migrator/migrator.go index 48db151e0..30586a8cf 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -43,7 +43,7 @@ func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error if table, ok := value.(string); ok { stmt.Table = table - } else if err := stmt.Parse(value); err != nil { + } else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil { return err } diff --git a/schema/schema.go b/schema/schema.go index c8d79ddc1..ce7cf3b13 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -73,15 +73,11 @@ type Tabler interface { // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { - return parse(dest, cacheStore, namer, "") + return ParseWithSpecialTableName(dest, cacheStore, namer, "") } -// ParseWithSchemaTable get data type from dialector with extra schema table -func ParseWithSchemaTable(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) { - return parse(dest, cacheStore, namer, schemaTable) -} - -func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable string) (*Schema, error) { +// ParseWithSpecialTableName get data type from dialector with extra schema table +func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) { if dest == nil { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } @@ -107,7 +103,17 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } - if v, ok := cacheStore.Load(modelType); ok { + // Cache the Schema for performance, + // Use the modelType or modelType + schemaTable (if it present) as cache key. + var schemaCacheKey interface{} + if specialTableName != "" { + schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) + } else { + schemaCacheKey = modelType + } + + // Load exist schmema cache, return if exists + if v, ok := cacheStore.Load(schemaCacheKey); ok { s := v.(*Schema) // Wait for the initialization of other goroutines to complete <-s.initialized @@ -116,15 +122,15 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri modelValue := reflect.New(modelType) tableName := namer.TableName(modelType.Name()) - if schemaTable != "" { - tableName = schemaTable - } if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } if en, ok := namer.(embeddedNamer); ok { tableName = en.Table } + if specialTableName != "" && specialTableName != tableName { + tableName = specialTableName + } schema := &Schema{ Name: modelType.Name(), @@ -140,7 +146,8 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri // When the schema initialization is completed, the channel will be closed defer close(schema.initialized) - if v, loaded := cacheStore.Load(modelType); loaded { + // Load exist schmema cache, return if exists + if v, ok := cacheStore.Load(schemaCacheKey); ok { s := v.(*Schema) // Wait for the initialization of other goroutines to complete <-s.initialized @@ -247,13 +254,12 @@ func parse(dest interface{}, cacheStore *sync.Map, namer Namer, schemaTable stri } } - if schemaTable == "" { - if v, loaded := cacheStore.LoadOrStore(modelType, schema); loaded { - s := v.(*Schema) - // Wait for the initialization of other goroutines to complete - <-s.initialized - return s, s.err - } + // Cache the schema + if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { + s := v.(*Schema) + // Wait for the initialization of other goroutines to complete + <-s.initialized + return s, s.err } defer func() { diff --git a/statement.go b/statement.go index bbe001063..85432e48f 100644 --- a/statement.go +++ b/statement.go @@ -456,7 +456,11 @@ func (stmt *Statement) Build(clauses ...string) { } func (stmt *Statement) Parse(value interface{}) (err error) { - if stmt.Schema, err = schema.ParseWithSchemaTable(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, stmt.DB.Statement.Table); err == nil && stmt.Table == "" { + return stmt.ParseWithSpecialTableName(value, "") +} + +func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) { + if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.Table = tables[1] diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 06eb96b33..0354e84e1 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -382,32 +382,41 @@ func TestMigrateConstraint(t *testing.T) { } } -type MigrateUser struct { +type DynamicUser struct { gorm.Model - Name string `gorm:"index"` + Name string + CompanyID string `gorm:"index"` } +// To test auto migrate crate indexes for dynamic table name // https://github.com/go-gorm/gorm/issues/4752 func TestMigrateIndexesWithDynamicTableName(t *testing.T) { - tableNameSuffixes := []string{"01", "02", "03"} - for _, v := range tableNameSuffixes { - tableName := "migrate_user_" + v + // Create primary table + if err := DB.AutoMigrate(&DynamicUser{}); err != nil { + t.Fatalf("AutoMigrate create table error: %#v", err) + } + + // Create sub tables + for _, v := range []string{"01", "02", "03"} { + tableName := "dynamic_users_" + v m := DB.Scopes(func(db *gorm.DB) *gorm.DB { return db.Table(tableName) }).Migrator() - if err := m.AutoMigrate(&MigrateUser{}); err != nil { - t.Fatalf("Failed to create table for %#v", tableName) + if err := m.AutoMigrate(&DynamicUser{}); err != nil { + t.Fatalf("AutoMigrate create table error: %#v", err) } if !m.HasTable(tableName) { - t.Fatalf("Failed to create table for %#v", tableName) + t.Fatalf("AutoMigrate expected %#v exist, but not.", tableName) } - if !m.HasIndex(&MigrateUser{}, "Name") { - t.Fatalf("Should find index for %s's name after AutoMigrate", tableName) + + if !m.HasIndex(&DynamicUser{}, "CompanyID") { + t.Fatalf("Should have index on %s", "CompanyI.") } - if !m.HasIndex(&MigrateUser{}, "DeletedAt") { - t.Fatalf("Should find index for %s's deleted_at after AutoMigrate", tableName) + + if !m.HasIndex(&DynamicUser{}, "DeletedAt") { + t.Fatalf("Should have index on deleted_at.") } } }