Skip to content

Commit

Permalink
Add Before/After callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 23, 2020
1 parent fa22807 commit e2a360b
Show file tree
Hide file tree
Showing 14 changed files with 325 additions and 43 deletions.
64 changes: 60 additions & 4 deletions callbacks/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,36 @@ import (
)

func BeforeCreate(db *gorm.DB) {
// before save
// before create
if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok {
ok = true
i.BeforeSave(db)
}
}

if db.Statement.Schema.BeforeCreate {
if i, ok := value.(gorm.BeforeCreateInterface); ok {
ok = true
i.BeforeCreate(db)
}
}
return ok
}

if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}

func SaveBeforeAssociations(db *gorm.DB) {
Expand Down Expand Up @@ -48,8 +76,36 @@ func SaveAfterAssociations(db *gorm.DB) {
}

func AfterCreate(db *gorm.DB) {
// after save
// after create
if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok {
ok = true
i.AfterSave(db)
}
}

if db.Statement.Schema.AfterCreate {
if i, ok := value.(gorm.AfterCreateInterface); ok {
ok = true
i.AfterCreate(db)
}
}
return ok
}

if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}

// ConvertToCreateValues convert to create values
Expand Down
50 changes: 49 additions & 1 deletion callbacks/delete.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,60 @@
package callbacks

import "github.com/jinzhu/gorm"
import (
"reflect"

"github.com/jinzhu/gorm"
)

func BeforeDelete(db *gorm.DB) {
if db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
callMethod := func(value interface{}) bool {
if db.Statement.Schema.BeforeDelete {
if i, ok := value.(gorm.BeforeDeleteInterface); ok {
i.BeforeDelete(db)
return true
}
}
return false
}

if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}

func Delete(db *gorm.DB) {
}

func AfterDelete(db *gorm.DB) {
if db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
callMethod := func(value interface{}) bool {
if db.Statement.Schema.AfterDelete {
if i, ok := value.(gorm.AfterDeleteInterface); ok {
i.AfterDelete(db)
return true
}
}
return false
}

if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}
27 changes: 25 additions & 2 deletions callbacks/query.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package callbacks

import (
"reflect"

"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
)
Expand All @@ -13,13 +15,34 @@ func Query(db *gorm.DB) {
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
}

rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
_, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
db.AddError(err)
}

func Preload(db *gorm.DB) {
}

func AfterQuery(db *gorm.DB) {
// after find
if db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
callMethod := func(value interface{}) bool {
if db.Statement.Schema.AfterFind {
if i, ok := value.(gorm.AfterFindInterface); ok {
i.AfterFind(db)
return true
}
}
return false
}

if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}
66 changes: 65 additions & 1 deletion callbacks/update.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,76 @@
package callbacks

import "github.com/jinzhu/gorm"
import (
"reflect"

"github.com/jinzhu/gorm"
)

func BeforeUpdate(db *gorm.DB) {
if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.BeforeSave {
if i, ok := value.(gorm.BeforeSaveInterface); ok {
ok = true
i.BeforeSave(db)
}
}

if db.Statement.Schema.BeforeUpdate {
if i, ok := value.(gorm.BeforeUpdateInterface); ok {
ok = true
i.BeforeUpdate(db)
}
}
return ok
}

if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}

func Update(db *gorm.DB) {
}

func AfterUpdate(db *gorm.DB) {
if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod := func(value interface{}) bool {
var ok bool
if db.Statement.Schema.AfterSave {
if i, ok := value.(gorm.AfterSaveInterface); ok {
ok = true
i.AfterSave(db)
}
}

if db.Statement.Schema.AfterUpdate {
if i, ok := value.(gorm.AfterUpdateInterface); ok {
ok = true
i.AfterUpdate(db)
}
}
return ok
}

if ok := callMethod(db.Statement.Dest); !ok {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i <= db.Statement.ReflectValue.Len(); i++ {
callMethod(db.Statement.ReflectValue.Index(i).Interface())
}
case reflect.Struct:
callMethod(db.Statement.ReflectValue.Interface())
}
}
}
}
4 changes: 2 additions & 2 deletions clause/benchmarks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

func BenchmarkSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)

for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
Expand All @@ -27,7 +27,7 @@ func BenchmarkSelect(b *testing.B) {
}

func BenchmarkComplexSelect(b *testing.B) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)

for i := 0; i < b.N; i++ {
stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
Expand Down
2 changes: 1 addition & 1 deletion clause/clause_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string,
var (
buildNames []string
buildNamesMap = map[string]bool{}
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
user, _, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
)

Expand Down
2 changes: 1 addition & 1 deletion clause/expression_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestExpr(t *testing.T) {

for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
user, _, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}}
clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt)
if stmt.SQL.String() != result.Result {
Expand Down
36 changes: 36 additions & 0 deletions interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,39 @@ type CommonDB interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}

type BeforeCreateInterface interface {
BeforeCreate(*DB)
}

type AfterCreateInterface interface {
AfterCreate(*DB)
}

type BeforeUpdateInterface interface {
BeforeUpdate(*DB)
}

type AfterUpdateInterface interface {
AfterUpdate(*DB)
}

type BeforeSaveInterface interface {
BeforeSave(*DB)
}

type AfterSaveInterface interface {
AfterSave(*DB)
}

type BeforeDeleteInterface interface {
BeforeDelete(*DB)
}

type AfterDeleteInterface interface {
AfterDelete(*DB)
}

type AfterFindInterface interface {
AfterFind(*DB)
}
38 changes: 38 additions & 0 deletions schema/callbacks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package schema_test

import (
"reflect"
"sync"
"testing"

"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/schema"
)

type UserWithCallback struct {
}

func (UserWithCallback) BeforeSave(*gorm.DB) {
}

func (UserWithCallback) AfterCreate(*gorm.DB) {
}

func TestCallback(t *testing.T) {
user, _, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user with callback, got error %v", err)
}

for _, str := range []string{"BeforeSave", "AfterCreate"} {
if !reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
t.Errorf("%v should be true", str)
}
}

for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} {
if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
t.Errorf("%v should be false", str)
}
}
}
2 changes: 1 addition & 1 deletion schema/check_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type UserCheck struct {
}

func TestParseCheck(t *testing.T) {
user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
user, _, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
if err != nil {
t.Fatalf("failed to parse user check, got error %v", err)
}
Expand Down
Loading

0 comments on commit e2a360b

Please sign in to comment.