From e2a360b9faa72efb3f35f3edca4ed6e293d9185e Mon Sep 17 00:00:00 2001 From: Jinzhu Date: Sun, 23 Feb 2020 21:22:35 +0800 Subject: [PATCH] Add Before/After callbacks --- callbacks/create.go | 64 ++++++++++++++++++++++++++++++++++--- callbacks/delete.go | 50 ++++++++++++++++++++++++++++- callbacks/query.go | 27 ++++++++++++++-- callbacks/update.go | 66 ++++++++++++++++++++++++++++++++++++++- clause/benchmarks_test.go | 4 +-- clause/clause_test.go | 2 +- clause/expression_test.go | 2 +- interfaces.go | 36 +++++++++++++++++++++ schema/callbacks_test.go | 38 ++++++++++++++++++++++ schema/check_test.go | 2 +- schema/field_test.go | 24 +++++++------- schema/index_test.go | 2 +- schema/schema.go | 45 +++++++++++++++++--------- schema/schema_test.go | 6 ++-- 14 files changed, 325 insertions(+), 43 deletions(-) create mode 100644 schema/callbacks_test.go diff --git a/callbacks/create.go b/callbacks/create.go index 3866ddb0d..2e1b33813 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -8,8 +8,36 @@ import ( ) func BeforeCreate(db *gorm.DB) { - // before save - // before create + if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { + callMethod := func(value interface{}) bool { + var ok bool + if db.Statement.Schema.BeforeSave { + if i, ok := value.(gorm.BeforeSaveInterface); ok { + ok = true + i.BeforeSave(db) + } + } + + if db.Statement.Schema.BeforeCreate { + if i, ok := value.(gorm.BeforeCreateInterface); ok { + ok = true + i.BeforeCreate(db) + } + } + return ok + } + + if ok := callMethod(db.Statement.Dest); !ok { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + callMethod(db.Statement.ReflectValue.Index(i).Interface()) + } + case reflect.Struct: + callMethod(db.Statement.ReflectValue.Interface()) + } + } + } } func SaveBeforeAssociations(db *gorm.DB) { @@ -48,8 +76,36 @@ func SaveAfterAssociations(db *gorm.DB) { } func AfterCreate(db *gorm.DB) { - // after save - // after create + if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { + callMethod := func(value interface{}) bool { + var ok bool + if db.Statement.Schema.AfterSave { + if i, ok := value.(gorm.AfterSaveInterface); ok { + ok = true + i.AfterSave(db) + } + } + + if db.Statement.Schema.AfterCreate { + if i, ok := value.(gorm.AfterCreateInterface); ok { + ok = true + i.AfterCreate(db) + } + } + return ok + } + + if ok := callMethod(db.Statement.Dest); !ok { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + callMethod(db.Statement.ReflectValue.Index(i).Interface()) + } + case reflect.Struct: + callMethod(db.Statement.ReflectValue.Interface()) + } + } + } } // ConvertToCreateValues convert to create values diff --git a/callbacks/delete.go b/callbacks/delete.go index 96c392f24..d79f88fc4 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -1,12 +1,60 @@ package callbacks -import "github.com/jinzhu/gorm" +import ( + "reflect" + + "github.com/jinzhu/gorm" +) func BeforeDelete(db *gorm.DB) { + if db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { + callMethod := func(value interface{}) bool { + if db.Statement.Schema.BeforeDelete { + if i, ok := value.(gorm.BeforeDeleteInterface); ok { + i.BeforeDelete(db) + return true + } + } + return false + } + + if ok := callMethod(db.Statement.Dest); !ok { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + callMethod(db.Statement.ReflectValue.Index(i).Interface()) + } + case reflect.Struct: + callMethod(db.Statement.ReflectValue.Interface()) + } + } + } } func Delete(db *gorm.DB) { } func AfterDelete(db *gorm.DB) { + if db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { + callMethod := func(value interface{}) bool { + if db.Statement.Schema.AfterDelete { + if i, ok := value.(gorm.AfterDeleteInterface); ok { + i.AfterDelete(db) + return true + } + } + return false + } + + if ok := callMethod(db.Statement.Dest); !ok { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + callMethod(db.Statement.ReflectValue.Index(i).Interface()) + } + case reflect.Struct: + callMethod(db.Statement.ReflectValue.Interface()) + } + } + } } diff --git a/callbacks/query.go b/callbacks/query.go index 195709fe9..d87850575 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -1,6 +1,8 @@ package callbacks import ( + "reflect" + "github.com/jinzhu/gorm" "github.com/jinzhu/gorm/clause" ) @@ -13,7 +15,7 @@ func Query(db *gorm.DB) { db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") } - rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) + _, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) db.AddError(err) } @@ -21,5 +23,26 @@ func Preload(db *gorm.DB) { } func AfterQuery(db *gorm.DB) { - // after find + if db.Statement.Schema != nil && db.Statement.Schema.AfterFind { + callMethod := func(value interface{}) bool { + if db.Statement.Schema.AfterFind { + if i, ok := value.(gorm.AfterFindInterface); ok { + i.AfterFind(db) + return true + } + } + return false + } + + if ok := callMethod(db.Statement.Dest); !ok { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + callMethod(db.Statement.ReflectValue.Index(i).Interface()) + } + case reflect.Struct: + callMethod(db.Statement.ReflectValue.Interface()) + } + } + } } diff --git a/callbacks/update.go b/callbacks/update.go index 8e5044036..82df3e812 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -1,12 +1,76 @@ package callbacks -import "github.com/jinzhu/gorm" +import ( + "reflect" + + "github.com/jinzhu/gorm" +) func BeforeUpdate(db *gorm.DB) { + if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + callMethod := func(value interface{}) bool { + var ok bool + if db.Statement.Schema.BeforeSave { + if i, ok := value.(gorm.BeforeSaveInterface); ok { + ok = true + i.BeforeSave(db) + } + } + + if db.Statement.Schema.BeforeUpdate { + if i, ok := value.(gorm.BeforeUpdateInterface); ok { + ok = true + i.BeforeUpdate(db) + } + } + return ok + } + + if ok := callMethod(db.Statement.Dest); !ok { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + callMethod(db.Statement.ReflectValue.Index(i).Interface()) + } + case reflect.Struct: + callMethod(db.Statement.ReflectValue.Interface()) + } + } + } } func Update(db *gorm.DB) { } func AfterUpdate(db *gorm.DB) { + if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + callMethod := func(value interface{}) bool { + var ok bool + if db.Statement.Schema.AfterSave { + if i, ok := value.(gorm.AfterSaveInterface); ok { + ok = true + i.AfterSave(db) + } + } + + if db.Statement.Schema.AfterUpdate { + if i, ok := value.(gorm.AfterUpdateInterface); ok { + ok = true + i.AfterUpdate(db) + } + } + return ok + } + + if ok := callMethod(db.Statement.Dest); !ok { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + callMethod(db.Statement.ReflectValue.Index(i).Interface()) + } + case reflect.Struct: + callMethod(db.Statement.ReflectValue.Interface()) + } + } + } } diff --git a/clause/benchmarks_test.go b/clause/benchmarks_test.go index 33d3430af..3813fd8ea 100644 --- a/clause/benchmarks_test.go +++ b/clause/benchmarks_test.go @@ -11,7 +11,7 @@ import ( ) func BenchmarkSelect(b *testing.B) { - user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} @@ -27,7 +27,7 @@ func BenchmarkSelect(b *testing.B) { } func BenchmarkComplexSelect(b *testing.B) { - user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} diff --git a/clause/clause_test.go b/clause/clause_test.go index 30ea93436..8e4580437 100644 --- a/clause/clause_test.go +++ b/clause/clause_test.go @@ -18,7 +18,7 @@ func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, var ( buildNames []string buildNamesMap = map[string]bool{} - user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + user, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} ) diff --git a/clause/expression_test.go b/clause/expression_test.go index e51d189ea..363b40472 100644 --- a/clause/expression_test.go +++ b/clause/expression_test.go @@ -24,7 +24,7 @@ func TestExpr(t *testing.T) { for idx, result := range results { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { - user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) if stmt.SQL.String() != result.Result { diff --git a/interfaces.go b/interfaces.go index bf1aab460..21563b7d0 100644 --- a/interfaces.go +++ b/interfaces.go @@ -24,3 +24,39 @@ type CommonDB interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row } + +type BeforeCreateInterface interface { + BeforeCreate(*DB) +} + +type AfterCreateInterface interface { + AfterCreate(*DB) +} + +type BeforeUpdateInterface interface { + BeforeUpdate(*DB) +} + +type AfterUpdateInterface interface { + AfterUpdate(*DB) +} + +type BeforeSaveInterface interface { + BeforeSave(*DB) +} + +type AfterSaveInterface interface { + AfterSave(*DB) +} + +type BeforeDeleteInterface interface { + BeforeDelete(*DB) +} + +type AfterDeleteInterface interface { + AfterDelete(*DB) +} + +type AfterFindInterface interface { + AfterFind(*DB) +} diff --git a/schema/callbacks_test.go b/schema/callbacks_test.go new file mode 100644 index 000000000..34c0e687a --- /dev/null +++ b/schema/callbacks_test.go @@ -0,0 +1,38 @@ +package schema_test + +import ( + "reflect" + "sync" + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/schema" +) + +type UserWithCallback struct { +} + +func (UserWithCallback) BeforeSave(*gorm.DB) { +} + +func (UserWithCallback) AfterCreate(*gorm.DB) { +} + +func TestCallback(t *testing.T) { + user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse user with callback, got error %v", err) + } + + for _, str := range []string{"BeforeSave", "AfterCreate"} { + if !reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) { + t.Errorf("%v should be true", str) + } + } + + for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} { + if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) { + t.Errorf("%v should be false", str) + } + } +} diff --git a/schema/check_test.go b/schema/check_test.go index e4bc9ebe9..f0ba553cc 100644 --- a/schema/check_test.go +++ b/schema/check_test.go @@ -15,7 +15,7 @@ type UserCheck struct { } func TestParseCheck(t *testing.T) { - user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) + user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user check, got error %v", err) } diff --git a/schema/field_test.go b/schema/field_test.go index 15dfa41d4..02e6aec08 100644 --- a/schema/field_test.go +++ b/schema/field_test.go @@ -14,8 +14,8 @@ import ( func TestFieldValuerAndSetter(t *testing.T) { var ( - userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) - user = tests.User{ + userSchema, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) + user = tests.User{ Model: gorm.Model{ ID: 10, CreatedAt: time.Now(), @@ -81,11 +81,11 @@ func TestFieldValuerAndSetter(t *testing.T) { func TestPointerFieldValuerAndSetter(t *testing.T) { var ( - userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) - name = "pointer_field_valuer_and_setter" - age uint = 18 - active = true - user = User{ + userSchema, _, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + name = "pointer_field_valuer_and_setter" + age uint = 18 + active = true + user = User{ Model: &gorm.Model{ ID: 10, CreatedAt: time.Now(), @@ -151,11 +151,11 @@ func TestPointerFieldValuerAndSetter(t *testing.T) { func TestAdvancedDataTypeValuerAndSetter(t *testing.T) { var ( - userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) - name = "advanced_data_type_valuer_and_setter" - deletedAt = mytime(time.Now()) - isAdmin = mybool(false) - user = AdvancedDataTypeUser{ + userSchema, _, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + name = "advanced_data_type_valuer_and_setter" + deletedAt = mytime(time.Now()) + isAdmin = mybool(false) + user = AdvancedDataTypeUser{ ID: sql.NullInt64{Int64: 10, Valid: true}, Name: &sql.NullString{String: name, Valid: true}, Birthday: sql.NullTime{Time: time.Now(), Valid: true}, diff --git a/schema/index_test.go b/schema/index_test.go index d0e8dfe02..03d75b978 100644 --- a/schema/index_test.go +++ b/schema/index_test.go @@ -19,7 +19,7 @@ type UserIndex struct { } func TestParseIndex(t *testing.T) { - user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) + user, _, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user index, got error %v", err) } diff --git a/schema/schema.go b/schema/schema.go index c3ac2bd9a..c56932ad4 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -14,20 +14,25 @@ import ( var ErrUnsupportedDataType = errors.New("unsupported data type") type Schema struct { - Name string - ModelType reflect.Type - Table string - PrioritizedPrimaryField *Field - DBNames []string - PrimaryFields []*Field - Fields []*Field - FieldsByName map[string]*Field - FieldsByDBName map[string]*Field - FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database - Relationships Relationships - err error - namer Namer - cacheStore *sync.Map + Name string + ModelType reflect.Type + Table string + PrioritizedPrimaryField *Field + DBNames []string + PrimaryFields []*Field + Fields []*Field + FieldsByName map[string]*Field + FieldsByDBName map[string]*Field + FieldsWithDefaultDBValue map[string]*Field // fields with default value assigned by database + Relationships Relationships + BeforeCreate, AfterCreate bool + BeforeUpdate, AfterUpdate bool + BeforeDelete, AfterDelete bool + BeforeSave, AfterSave bool + AfterFind bool + err error + namer Namer + cacheStore *sync.Map } func (schema Schema) String() string { @@ -162,6 +167,18 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflec } } + callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} + for _, name := range callbacks { + if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() { + switch methodValue.Type().String() { + case "func(*gorm.DB)": // TODO hack + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) + default: + logger.Default.Warn("Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) + } + } + } + cacheStore.Store(modelType, schema) // parse relations for unidentified fields diff --git a/schema/schema_test.go b/schema/schema_test.go index ce225010f..04cd9d828 100644 --- a/schema/schema_test.go +++ b/schema/schema_test.go @@ -9,7 +9,7 @@ import ( ) func TestParseSchema(t *testing.T) { - user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) + user, _, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user, got error %v", err) } @@ -18,7 +18,7 @@ func TestParseSchema(t *testing.T) { } func TestParseSchemaWithPointerFields(t *testing.T) { - user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + user, _, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse pointer user, got error %v", err) } @@ -114,7 +114,7 @@ func checkUserSchema(t *testing.T, user *schema.Schema) { } func TestParseSchemaWithAdvancedDataType(t *testing.T) { - user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) + user, _, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse pointer user, got error %v", err) }