Skip to content

Commit

Permalink
Add callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 2, 2020
1 parent d833efe commit 728c0d4
Show file tree
Hide file tree
Showing 13 changed files with 101 additions and 51 deletions.
29 changes: 18 additions & 11 deletions callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ import (
"github.com/jinzhu/gorm/utils"
)

func InitializeCallbacks() *callbacks {
func initializeCallbacks(db *DB) *callbacks {
return &callbacks{
processors: map[string]*processor{
"create": &processor{},
"query": &processor{},
"update": &processor{},
"delete": &processor{},
"row": &processor{},
"raw": &processor{},
"create": &processor{db: db},
"query": &processor{db: db},
"update": &processor{db: db},
"delete": &processor{db: db},
"row": &processor{db: db},
"raw": &processor{db: db},
},
}
}
Expand Down Expand Up @@ -118,7 +118,14 @@ func (p *processor) Replace(name string, fn func(*DB)) error {
return (&callback{processor: p}).Replace(name, fn)
}

func (p *processor) compile(db *DB) (err error) {
func (p *processor) compile() (err error) {
var callbacks []*callback
for _, callback := range p.callbacks {
if callback.match == nil || callback.match(p.db) {
callbacks = append(callbacks, callback)
}
}

if p.fns, err = sortCallbacks(p.callbacks); err != nil {
logger.Default.Error("Got error when compile callbacks, got %v", err)
}
Expand All @@ -139,15 +146,15 @@ func (c *callback) Register(name string, fn func(*DB)) error {
c.name = name
c.handler = fn
c.processor.callbacks = append(c.processor.callbacks, c)
return c.processor.compile(c.processor.db)
return c.processor.compile()
}

func (c *callback) Remove(name string) error {
logger.Default.Warn("removing callback `%v` from %v\n", name, utils.FileWithLineNum())
c.name = name
c.remove = true
c.processor.callbacks = append(c.processor.callbacks, c)
return c.processor.compile(c.processor.db)
return c.processor.compile()
}

func (c *callback) Replace(name string, fn func(*DB)) error {
Expand All @@ -156,7 +163,7 @@ func (c *callback) Replace(name string, fn func(*DB)) error {
c.handler = fn
c.replace = true
c.processor.callbacks = append(c.processor.callbacks, c)
return c.processor.compile(c.processor.db)
return c.processor.compile()
}

// getRIndex get right index from string slice
Expand Down
39 changes: 33 additions & 6 deletions callbacks/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,37 @@ package callbacks
import "github.com/jinzhu/gorm"

func RegisterDefaultCallbacks(db *gorm.DB) {
callback := db.Callback()
callback.Create().Register("gorm:before_create", BeforeCreate)
callback.Create().Register("gorm:save_before_associations", SaveBeforeAssociations)
callback.Create().Register("gorm:create", Create)
callback.Create().Register("gorm:save_after_associations", SaveAfterAssociations)
callback.Create().Register("gorm:after_create", AfterCreate)
enableTransaction := func(db *gorm.DB) bool {
return !db.SkipDefaultTransaction
}

createCallback := db.Callback().Create()
createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
createCallback.Register("gorm:before_create", BeforeCreate)
createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations)
createCallback.Register("gorm:create", Create)
createCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
createCallback.Register("gorm:after_create", AfterCreate)
createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)

queryCallback := db.Callback().Query()
queryCallback.Register("gorm:query", BeforeCreate)
queryCallback.Register("gorm:preload", Preload)
queryCallback.Register("gorm:after_query", AfterQuery)

deleteCallback := db.Callback().Delete()
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete", Delete)
deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)

updateCallback := db.Callback().Update()
updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
updateCallback.Register("gorm:before_update", BeforeUpdate)
updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations)
updateCallback.Register("gorm:update", Update)
updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations)
updateCallback.Register("gorm:after_update", AfterUpdate)
updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
}
16 changes: 1 addition & 15 deletions callbacks/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func SaveBeforeAssociations(db *gorm.DB) {

func Create(db *gorm.DB) {
db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING")

db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
}

Expand All @@ -29,17 +29,3 @@ func AfterCreate(db *gorm.DB) {
// after save
// after create
}

func objectToFieldsMap(stmt *gorm.Statement) {
if stmt.Schema != nil {
if s, ok := stmt.Clauses["SELECT"]; ok {
s.Attrs
}

if s, ok := stmt.Clauses["OMIT"]; ok {
s.Attrs
}

stmt.Schema.LookUpField(s.S)
}
}
12 changes: 12 additions & 0 deletions callbacks/delete.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package callbacks

import "github.com/jinzhu/gorm"

func BeforeDelete(db *gorm.DB) {
}

func Delete(db *gorm.DB) {
}

func AfterDelete(db *gorm.DB) {
}
9 changes: 9 additions & 0 deletions callbacks/transaction.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package callbacks

import "github.com/jinzhu/gorm"

func BeginTransaction(db *gorm.DB) {
}

func CommitOrRollbackTransaction(db *gorm.DB) {
}
12 changes: 12 additions & 0 deletions callbacks/update.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package callbacks

import "github.com/jinzhu/gorm"

func BeforeUpdate(db *gorm.DB) {
}

func Update(db *gorm.DB) {
}

func AfterUpdate(db *gorm.DB) {
}
5 changes: 0 additions & 5 deletions dialects/sqlite/go.mod

This file was deleted.

2 changes: 0 additions & 2 deletions dialects/sqlite/go.sum

This file was deleted.

5 changes: 1 addition & 4 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,4 @@ module github.com/jinzhu/gorm

go 1.13

require (
github.com/jinzhu/inflection v1.0.0
gopkg.in/errgo.v2 v2.1.0
)
require github.com/jinzhu/inflection v1.0.0
2 changes: 0 additions & 2 deletions go.sum

This file was deleted.

3 changes: 2 additions & 1 deletion gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,11 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
Config: config,
Dialector: dialector,
clone: true,
callbacks: InitializeCallbacks(),
cacheStore: &sync.Map{},
}

db.callbacks = initializeCallbacks(db)

if dialector != nil {
err = dialector.Initialize(db)
}
Expand Down
14 changes: 11 additions & 3 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ type Instance struct {
Statement *Statement
}

func (instance Instance) ToSQL(clauses ...string) (string, []interface{}) {
if len(clauses) > 0 {
instance.Statement.Build(clauses...)
}
return instance.Statement.SQL.String(), instance.Statement.Vars
}

// AddError add error to instance
func (inst Instance) AddError(err error) {
if inst.Error == nil {
Expand Down Expand Up @@ -205,16 +212,17 @@ func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (con

// Build build sql with clauses names
func (stmt Statement) Build(clauses ...string) {
var includeSpace bool
var firstClauseWritten bool

for _, name := range clauses {
if c, ok := stmt.Clauses[name]; ok {
if includeSpace {
if firstClauseWritten {
stmt.WriteByte(' ')
}

includeSpace = true
firstClauseWritten = true
c.Build(stmt)
}
}
// TODO handle named vars
}
4 changes: 2 additions & 2 deletions tests/callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ func TestCallbacks(t *testing.T) {
}

for idx, data := range datas {
var err error
callbacks := gorm.InitializeCallbacks()
db, err := gorm.Open(nil, nil)
callbacks := db.Callback()

for _, c := range data.callbacks {
var v interface{} = callbacks.Create()
Expand Down

0 comments on commit 728c0d4

Please sign in to comment.