diff --git a/callbacks/associations.go b/callbacks/associations.go index 76fc5b814..3c8c2a507 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -10,7 +10,7 @@ import ( ) func SaveBeforeAssociations(db *gorm.DB) { - if db.Statement.Schema != nil { + if db.Error == nil && db.Statement.Schema != nil { selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) // Save Belongs To associations @@ -83,7 +83,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } func SaveAfterAssociations(db *gorm.DB) { - if db.Statement.Schema != nil { + if db.Error == nil && db.Statement.Schema != nil { selectColumns, restricted := SelectAndOmitColumns(db.Statement, true, false) // Save Has One associations diff --git a/callbacks/create.go b/callbacks/create.go index f558d7aec..7a2b8bfe0 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -9,20 +9,21 @@ import ( ) func BeforeCreate(db *gorm.DB) { - if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { + if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { + tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { var ok bool if db.Statement.Schema.BeforeSave { if i, ok := value.(gorm.BeforeSaveInterface); ok { ok = true - i.BeforeSave(db) + db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeCreate { if i, ok := value.(gorm.BeforeCreateInterface); ok { ok = true - i.BeforeCreate(db) + db.AddError(i.BeforeCreate(tx)) } } return ok @@ -31,7 +32,7 @@ func BeforeCreate(db *gorm.DB) { if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { callMethod(db.Statement.ReflectValue.Index(i).Interface()) } case reflect.Struct: @@ -46,146 +47,151 @@ func Create(config *Config) func(db *gorm.DB) { return CreateWithReturning } else { return func(db *gorm.DB) { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.CreateClauses { - db.Statement.AddClause(c) + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } } - } - - if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) - db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") - } + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Insert{ + Table: clause.Table{Name: db.Statement.Table}, + }) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + } - if err == nil { - if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { - if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { - if insertID, err := result.LastInsertId(); err == nil { - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - if config.LastInsertIDReversed { - for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID-- - } - } else { - for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) - insertID++ + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + + if err == nil { + if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil { + if _, ok := db.Statement.Schema.FieldsWithDefaultDBValue[db.Statement.Schema.PrioritizedPrimaryField.DBName]; ok { + if insertID, err := result.LastInsertId(); err == nil { + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + if config.LastInsertIDReversed { + for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID-- + } + } else { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID) + insertID++ + } } + case reflect.Struct: + db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) } - case reflect.Struct: - db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID) + } else { + db.AddError(err) } - } else { - db.AddError(err) } } + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) } - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) } } } } func CreateWithReturning(db *gorm.DB) { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.CreateClauses { - db.Statement.AddClause(c) + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.CreateClauses { + db.Statement.AddClause(c) + } } - } - if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Insert{ - Table: clause.Table{Name: db.Statement.Table}, - }) - db.Statement.AddClause(ConvertToCreateValues(db.Statement)) + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Insert{ + Table: clause.Table{Name: db.Statement.Table}, + }) + db.Statement.AddClause(ConvertToCreateValues(db.Statement)) - db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") - } + db.Statement.Build("INSERT", "VALUES", "ON CONFLICT") + } - if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { - db.Statement.WriteString(" RETURNING ") + if sch := db.Statement.Schema; sch != nil && len(sch.FieldsWithDefaultDBValue) > 0 { + db.Statement.WriteString(" RETURNING ") - var ( - idx int - fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) - values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) - ) + var ( + idx int + fields = make([]*schema.Field, len(sch.FieldsWithDefaultDBValue)) + values = make([]interface{}, len(sch.FieldsWithDefaultDBValue)) + ) - for dbName, field := range sch.FieldsWithDefaultDBValue { - if idx != 0 { - db.Statement.WriteByte(',') - } + for dbName, field := range sch.FieldsWithDefaultDBValue { + if idx != 0 { + db.Statement.WriteByte(',') + } - fields[idx] = field - db.Statement.WriteQuoted(dbName) - idx++ - } + fields[idx] = field + db.Statement.WriteQuoted(dbName) + idx++ + } - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - defer rows.Close() + if err == nil { + defer rows.Close() - switch db.Statement.ReflectValue.Kind() { - case reflect.Slice, reflect.Array: - for rows.Next() { - for idx, field := range fields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + switch db.Statement.ReflectValue.Kind() { + case reflect.Slice, reflect.Array: + for rows.Next() { + for idx, field := range fields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue.Index(int(db.RowsAffected))).Addr().Interface() + } + if err := rows.Scan(values...); err != nil { + db.AddError(err) + } + db.RowsAffected++ } - if err := rows.Scan(values...); err != nil { - db.AddError(err) + case reflect.Struct: + for idx, field := range fields { + values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() } - db.RowsAffected++ - } - case reflect.Struct: - for idx, field := range fields { - values[idx] = field.ReflectValueOf(db.Statement.ReflectValue).Addr().Interface() - } - if rows.Next() { - db.RowsAffected++ - err = rows.Scan(values...) + if rows.Next() { + db.RowsAffected++ + err = rows.Scan(values...) + } } } - } - if err != nil { - db.AddError(err) - } - } else { - if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { - db.RowsAffected, _ = result.RowsAffected() + if err != nil { + db.AddError(err) + } } else { - db.AddError(err) + if result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } } } func AfterCreate(db *gorm.DB) { - if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { + if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { + tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { var ok bool if db.Statement.Schema.AfterSave { if i, ok := value.(gorm.AfterSaveInterface); ok { ok = true - i.AfterSave(db) + db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterCreate { if i, ok := value.(gorm.AfterCreateInterface); ok { ok = true - i.AfterCreate(db) + db.AddError(i.AfterCreate(tx)) } } return ok @@ -194,7 +200,7 @@ func AfterCreate(db *gorm.DB) { if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { callMethod(db.Statement.ReflectValue.Index(i).Interface()) } case reflect.Struct: diff --git a/callbacks/delete.go b/callbacks/delete.go index b3278c833..582a76f46 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -9,11 +9,12 @@ import ( ) func BeforeDelete(db *gorm.DB) { - if db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { + if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { + tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { if db.Statement.Schema.BeforeDelete { if i, ok := value.(gorm.BeforeDeleteInterface); ok { - i.BeforeDelete(db) + db.AddError(i.BeforeDelete(tx)) return true } } @@ -23,7 +24,7 @@ func BeforeDelete(db *gorm.DB) { if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { callMethod(db.Statement.ReflectValue.Index(i).Interface()) } case reflect.Struct: @@ -34,57 +35,60 @@ func BeforeDelete(db *gorm.DB) { } func Delete(db *gorm.DB) { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.DeleteClauses { - db.Statement.AddClause(c) + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.DeleteClauses { + db.Statement.AddClause(c) + } } - } - if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Delete{}) + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Delete{}) - if db.Statement.Schema != nil { - _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) - column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) - - if len(values) > 0 { - db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) - } - - if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { - _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) - column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) + if db.Statement.Schema != nil { + _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) + column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } + + if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { + _, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) + column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues) + + if len(values) > 0 { + db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) + } + } } - } - if _, ok := db.Statement.Clauses["WHERE"]; !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } + if _, ok := db.Statement.Clauses["WHERE"]; !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } - db.Statement.AddClauseIfNotExists(clause.From{}) - db.Statement.Build("DELETE", "FROM", "WHERE") - } + db.Statement.AddClauseIfNotExists(clause.From{}) + db.Statement.Build("DELETE", "FROM", "WHERE") + } - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } } func AfterDelete(db *gorm.DB) { - if db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { + if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { + tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { if db.Statement.Schema.AfterDelete { if i, ok := value.(gorm.AfterDeleteInterface); ok { - i.AfterDelete(db) + db.AddError(i.AfterDelete(tx)) return true } } @@ -94,7 +98,7 @@ func AfterDelete(db *gorm.DB) { if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { callMethod(db.Statement.ReflectValue.Index(i).Interface()) } case reflect.Struct: diff --git a/callbacks/query.go b/callbacks/query.go index 55f2c65ba..919480312 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -12,24 +12,26 @@ import ( ) func Query(db *gorm.DB) { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.QueryClauses { - db.Statement.AddClause(c) + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.QueryClauses { + db.Statement.AddClause(c) + } } - } - if db.Statement.SQL.String() == "" { - BuildQuerySQL(db) - } + if db.Statement.SQL.String() == "" { + BuildQuerySQL(db) + } - rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err != nil { - db.AddError(err) - return - } - defer rows.Close() + rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + return + } + defer rows.Close() - gorm.Scan(rows, db, false) + gorm.Scan(rows, db, false) + } } func BuildQuerySQL(db *gorm.DB) { @@ -129,50 +131,53 @@ func BuildQuerySQL(db *gorm.DB) { } func Preload(db *gorm.DB) { - if len(db.Statement.Preloads) > 0 { - preloadMap := map[string][]string{} - for name := range db.Statement.Preloads { - preloadFields := strings.Split(name, ".") - for idx := range preloadFields { - preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + if db.Error == nil { + if len(db.Statement.Preloads) > 0 { + preloadMap := map[string][]string{} + for name := range db.Statement.Preloads { + preloadFields := strings.Split(name, ".") + for idx := range preloadFields { + preloadMap[strings.Join(preloadFields[:idx+1], ".")] = preloadFields[:idx+1] + } } - } - preloadNames := make([]string, len(preloadMap)) - idx := 0 - for key := range preloadMap { - preloadNames[idx] = key - idx++ - } - sort.Strings(preloadNames) - - for _, name := range preloadNames { - var ( - curSchema = db.Statement.Schema - preloadFields = preloadMap[name] - rels = make([]*schema.Relationship, len(preloadFields)) - ) - - for idx, preloadField := range preloadFields { - if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { - rels[idx] = rel - curSchema = rel.FieldSchema - } else { - db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) - } + preloadNames := make([]string, len(preloadMap)) + idx := 0 + for key := range preloadMap { + preloadNames[idx] = key + idx++ } + sort.Strings(preloadNames) + + for _, name := range preloadNames { + var ( + curSchema = db.Statement.Schema + preloadFields = preloadMap[name] + rels = make([]*schema.Relationship, len(preloadFields)) + ) + + for idx, preloadField := range preloadFields { + if rel := curSchema.Relationships.Relations[preloadField]; rel != nil { + rels[idx] = rel + curSchema = rel.FieldSchema + } else { + db.AddError(fmt.Errorf("%v: %w", name, gorm.ErrUnsupportedRelation)) + } + } - preload(db, rels, db.Statement.Preloads[name]) + preload(db, rels, db.Statement.Preloads[name]) + } } } } func AfterQuery(db *gorm.DB) { - if db.Statement.Schema != nil && db.Statement.Schema.AfterFind { + if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { + tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { if db.Statement.Schema.AfterFind { if i, ok := value.(gorm.AfterFindInterface); ok { - i.AfterFind(db) + db.AddError(i.AfterFind(tx)) return true } } @@ -182,7 +187,7 @@ func AfterQuery(db *gorm.DB) { if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { callMethod(db.Statement.ReflectValue.Index(i).Interface()) } case reflect.Struct: diff --git a/callbacks/raw.go b/callbacks/raw.go index ce125e616..cb0cd6c9e 100644 --- a/callbacks/raw.go +++ b/callbacks/raw.go @@ -5,10 +5,12 @@ import ( ) func RawExec(db *gorm.DB) { - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err != nil { - db.AddError(err) - } else { - db.RowsAffected, _ = result.RowsAffected() + if db.Error == nil { + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if err != nil { + db.AddError(err) + } else { + db.RowsAffected, _ = result.RowsAffected() + } } } diff --git a/callbacks/row.go b/callbacks/row.go index 004a89d52..f4ff734cd 100644 --- a/callbacks/row.go +++ b/callbacks/row.go @@ -5,13 +5,15 @@ import ( ) func RowQuery(db *gorm.DB) { - if db.Statement.SQL.String() == "" { - BuildQuerySQL(db) - } + if db.Error == nil { + if db.Statement.SQL.String() == "" { + BuildQuerySQL(db) + } - if _, ok := db.Get("rows"); ok { - db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - } else { - db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + if _, ok := db.Get("rows"); ok { + db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } else { + db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + } } } diff --git a/callbacks/transaction.go b/callbacks/transaction.go index 253c4e82f..630153645 100644 --- a/callbacks/transaction.go +++ b/callbacks/transaction.go @@ -1,9 +1,25 @@ package callbacks -import "github.com/jinzhu/gorm" +import ( + "github.com/jinzhu/gorm" +) func BeginTransaction(db *gorm.DB) { + if tx := db.Begin(); tx.Error == nil { + db.Statement.ConnPool = tx.Statement.ConnPool + tx.InstanceSet("gorm:started_transaction", true) + } else { + tx.Error = nil + } } func CommitOrRollbackTransaction(db *gorm.DB) { + if _, ok := db.InstanceGet("gorm:started_transaction"); ok { + if db.Error == nil { + db.Commit() + } else { + db.Rollback() + } + db.Statement.ConnPool = db.ConnPool + } } diff --git a/callbacks/update.go b/callbacks/update.go index c16b77d16..cbbcddf7d 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -10,20 +10,21 @@ import ( ) func BeforeUpdate(db *gorm.DB) { - if db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { var ok bool if db.Statement.Schema.BeforeSave { if i, ok := value.(gorm.BeforeSaveInterface); ok { ok = true - i.BeforeSave(db) + db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeUpdate { if i, ok := value.(gorm.BeforeUpdateInterface); ok { ok = true - i.BeforeUpdate(db) + db.AddError(i.BeforeUpdate(tx)) } } return ok @@ -32,7 +33,7 @@ func BeforeUpdate(db *gorm.DB) { if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { callMethod(db.Statement.ReflectValue.Index(i).Interface()) } case reflect.Struct: @@ -43,51 +44,54 @@ func BeforeUpdate(db *gorm.DB) { } func Update(db *gorm.DB) { - if db.Statement.Schema != nil && !db.Statement.Unscoped { - for _, c := range db.Statement.Schema.UpdateClauses { - db.Statement.AddClause(c) + if db.Error == nil { + if db.Statement.Schema != nil && !db.Statement.Unscoped { + for _, c := range db.Statement.Schema.UpdateClauses { + db.Statement.AddClause(c) + } } - } - if db.Statement.SQL.String() == "" { - db.Statement.AddClauseIfNotExists(clause.Update{}) - if set := ConvertToAssignments(db.Statement); len(set) != 0 { - db.Statement.AddClause(set) - } else { - return + if db.Statement.SQL.String() == "" { + db.Statement.AddClauseIfNotExists(clause.Update{}) + if set := ConvertToAssignments(db.Statement); len(set) != 0 { + db.Statement.AddClause(set) + } else { + return + } + db.Statement.Build("UPDATE", "SET", "WHERE") } - db.Statement.Build("UPDATE", "SET", "WHERE") - } - if _, ok := db.Statement.Clauses["WHERE"]; !ok { - db.AddError(gorm.ErrMissingWhereClause) - return - } + if _, ok := db.Statement.Clauses["WHERE"]; !ok { + db.AddError(gorm.ErrMissingWhereClause) + return + } - result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) + result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) - if err == nil { - db.RowsAffected, _ = result.RowsAffected() - } else { - db.AddError(err) + if err == nil { + db.RowsAffected, _ = result.RowsAffected() + } else { + db.AddError(err) + } } } func AfterUpdate(db *gorm.DB) { - if db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + tx := db.Session(&gorm.Session{}) callMethod := func(value interface{}) bool { var ok bool if db.Statement.Schema.AfterSave { if i, ok := value.(gorm.AfterSaveInterface); ok { ok = true - i.AfterSave(db) + db.AddError(i.AfterSave(tx)) } } if db.Statement.Schema.AfterUpdate { if i, ok := value.(gorm.AfterUpdateInterface); ok { ok = true - i.AfterUpdate(db) + db.AddError(i.AfterUpdate(tx)) } } return ok @@ -96,7 +100,7 @@ func AfterUpdate(db *gorm.DB) { if ok := callMethod(db.Statement.Dest); !ok { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: - for i := 0; i <= db.Statement.ReflectValue.Len(); i++ { + for i := 0; i < db.Statement.ReflectValue.Len(); i++ { callMethod(db.Statement.ReflectValue.Index(i).Interface()) } case reflect.Struct: diff --git a/errors.go b/errors.go index 140a5186f..82f24df21 100644 --- a/errors.go +++ b/errors.go @@ -16,7 +16,7 @@ var ( // ErrNotImplemented not implemented ErrNotImplemented = errors.New("not implemented") // ErrMissingWhereClause missing where clause - ErrMissingWhereClause = errors.New("missing WHERE clause while deleting") + ErrMissingWhereClause = errors.New("WHERE conditions required") // ErrUnsupportedRelation unsupported relations ErrUnsupportedRelation = errors.New("unsupported relations") // ErrPtrStructSupported only ptr of struct supported diff --git a/gorm.go b/gorm.go index ac4bff5ed..c1d6f8da7 100644 --- a/gorm.go +++ b/gorm.go @@ -40,14 +40,15 @@ type DB struct { Error error RowsAffected int64 Statement *Statement - clone bool + clone int } // Session session config when create session with Session() method type Session struct { - Context context.Context - Logger logger.Interface - NowFunc func() time.Time + WithConditions bool + Context context.Context + Logger logger.Interface + NowFunc func() time.Time } // Open initialize db session based on dialector @@ -76,10 +77,7 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { config.cacheStore = &sync.Map{} } - db = &DB{ - Config: config, - clone: true, - } + db = &DB{Config: config, clone: 1} db.callbacks = initializeCallbacks(db) @@ -96,38 +94,54 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { // Session create new db session func (db *DB) Session(config *Session) *DB { var ( - tx = db.getInstance() - stmt = tx.Statement.clone() - txConfig = *tx.Config + txConfig = *db.Config + tx = &DB{ + Config: &txConfig, + Statement: db.Statement, + clone: 1, + } ) if config.Context != nil { - stmt.Context = config.Context + if tx.Statement != nil { + tx.Statement = tx.Statement.clone() + } else { + tx.Statement = &Statement{ + DB: tx, + Clauses: map[string]clause.Clause{}, + ConnPool: tx.ConnPool, + } + } + + tx.Statement.Context = config.Context + } + + if config.WithConditions { + tx.clone = 3 } if config.Logger != nil { - txConfig.Logger = config.Logger + tx.Config.Logger = config.Logger } if config.NowFunc != nil { - txConfig.NowFunc = config.NowFunc + tx.Config.NowFunc = config.NowFunc } - return &DB{ - Config: &txConfig, - Statement: stmt, - clone: true, - } + return tx } // WithContext change current instance db's context to ctx func (db *DB) WithContext(ctx context.Context) *DB { - return db.Session(&Session{Context: ctx}) + return db.Session(&Session{WithConditions: true, Context: ctx}) } // Debug start debug mode func (db *DB) Debug() (tx *DB) { - return db.Session(&Session{Logger: db.Logger.LogMode(logger.Info)}) + return db.Session(&Session{ + WithConditions: true, + Logger: db.Logger.LogMode(logger.Info), + }) } // Set store value with key into current db instance's context @@ -145,6 +159,21 @@ func (db *DB) Get(key string) (interface{}, bool) { return nil, false } +// InstanceSet store value with key into current db instance's context +func (db *DB) InstanceSet(key string, value interface{}) *DB { + tx := db.getInstance() + tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value) + return tx +} + +// InstanceGet get value with key from current db instance's context +func (db *DB) InstanceGet(key string) (interface{}, bool) { + if db.Statement != nil { + return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key) + } + return nil, false +} + // Callback returns callback manager func (db *DB) Callback() *callbacks { return db.callbacks @@ -166,18 +195,37 @@ func (db *DB) AddError(err error) error { } func (db *DB) getInstance() *DB { - if db.clone { - stmt := &Statement{DB: db, Clauses: map[string]clause.Clause{}} + if db.clone > 0 { + tx := &DB{Config: db.Config} + + switch db.clone { + case 1: // clone with new statement + case 2: // with old statement, generate new statement for future call, used to pass to callbacks + db.clone = 1 + tx.Statement = db.Statement + case 3: // with clone statement + if db.Statement != nil { + tx.Statement = db.Statement.clone() + tx.Statement.DB = tx + } + } + + if tx.Statement == nil { + tx.Statement = &Statement{ + DB: tx, + Clauses: map[string]clause.Clause{}, + } + } if db.Statement != nil { - stmt.Context = db.Statement.Context - stmt.ConnPool = db.Statement.ConnPool + tx.Statement.Context = db.Statement.Context + tx.Statement.ConnPool = db.Statement.ConnPool } else { - stmt.Context = context.Background() - stmt.ConnPool = db.ConnPool + tx.Statement.Context = context.Background() + tx.Statement.ConnPool = db.ConnPool } - return &DB{Config: db.Config, Statement: stmt} + return tx } return db diff --git a/interfaces.go b/interfaces.go index 9dd00c15a..14d8fa341 100644 --- a/interfaces.go +++ b/interfaces.go @@ -36,37 +36,37 @@ type TxCommiter interface { } type BeforeCreateInterface interface { - BeforeCreate(*DB) + BeforeCreate(*DB) error } type AfterCreateInterface interface { - AfterCreate(*DB) + AfterCreate(*DB) error } type BeforeUpdateInterface interface { - BeforeUpdate(*DB) + BeforeUpdate(*DB) error } type AfterUpdateInterface interface { - AfterUpdate(*DB) + AfterUpdate(*DB) error } type BeforeSaveInterface interface { - BeforeSave(*DB) + BeforeSave(*DB) error } type AfterSaveInterface interface { - AfterSave(*DB) + AfterSave(*DB) error } type BeforeDeleteInterface interface { - BeforeDelete(*DB) + BeforeDelete(*DB) error } type AfterDeleteInterface interface { - AfterDelete(*DB) + AfterDelete(*DB) error } type AfterFindInterface interface { - AfterFind(*DB) + AfterFind(*DB) error } diff --git a/schema/callbacks_test.go b/schema/callbacks_test.go index 720c9a5b0..efa01e899 100644 --- a/schema/callbacks_test.go +++ b/schema/callbacks_test.go @@ -12,10 +12,12 @@ import ( type UserWithCallback struct { } -func (UserWithCallback) BeforeSave(*gorm.DB) { +func (UserWithCallback) BeforeSave(*gorm.DB) error { + return nil } -func (UserWithCallback) AfterCreate(*gorm.DB) { +func (UserWithCallback) AfterCreate(*gorm.DB) error { + return nil } func TestCallback(t *testing.T) { diff --git a/schema/schema.go b/schema/schema.go index 77b9832cd..231ed1db0 100644 --- a/schema/schema.go +++ b/schema/schema.go @@ -200,12 +200,12 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) } } - reflectValue := reflect.Indirect(reflect.New(modelType)) + reflectValue := reflect.New(modelType) callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} for _, name := range callbacks { if methodValue := reflectValue.MethodByName(name); methodValue.IsValid() { switch methodValue.Type().String() { - case "func(*gorm.DB)": // TODO hack + case "func(*gorm.DB) error": // TODO hack reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) default: logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be %v(*gorm.DB)", schema, name, name) diff --git a/tests/hooks_test.go b/tests/hooks_test.go new file mode 100644 index 000000000..432226a32 --- /dev/null +++ b/tests/hooks_test.go @@ -0,0 +1,201 @@ +package tests_test + +import ( + "errors" + "reflect" + "testing" + + "github.com/jinzhu/gorm" + . "github.com/jinzhu/gorm/tests" +) + +type Product struct { + gorm.Model + Name string + Code string + Price float64 + AfterFindCallTimes int64 + BeforeCreateCallTimes int64 + AfterCreateCallTimes int64 + BeforeUpdateCallTimes int64 + AfterUpdateCallTimes int64 + BeforeSaveCallTimes int64 + AfterSaveCallTimes int64 + BeforeDeleteCallTimes int64 + AfterDeleteCallTimes int64 +} + +func (s *Product) BeforeCreate(tx *gorm.DB) (err error) { + if s.Code == "Invalid" { + err = errors.New("invalid product") + } + s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 + return +} + +func (s *Product) BeforeUpdate(tx *gorm.DB) (err error) { + if s.Code == "dont_update" { + err = errors.New("can't update") + } + s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 + return +} + +func (s *Product) BeforeSave(tx *gorm.DB) (err error) { + if s.Code == "dont_save" { + err = errors.New("can't save") + } + s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 + return +} + +func (s *Product) AfterFind(tx *gorm.DB) (err error) { + s.AfterFindCallTimes = s.AfterFindCallTimes + 1 + return +} + +func (s *Product) AfterCreate(tx *gorm.DB) (err error) { + return tx.Model(s).UpdateColumn("AfterCreateCallTimes", s.AfterCreateCallTimes+1).Error +} + +func (s *Product) AfterUpdate(tx *gorm.DB) (err error) { + s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 + return +} + +func (s *Product) AfterSave(tx *gorm.DB) (err error) { + if s.Code == "after_save_error" { + err = errors.New("can't save") + } + s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 + return +} + +func (s *Product) BeforeDelete(tx *gorm.DB) (err error) { + if s.Code == "dont_delete" { + err = errors.New("can't delete") + } + s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 + return +} + +func (s *Product) AfterDelete(tx *gorm.DB) (err error) { + if s.Code == "after_delete_error" { + err = errors.New("can't delete") + } + s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 + return +} + +func (s *Product) GetCallTimes() []int64 { + return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} +} + +func TestRunCallbacks(t *testing.T) { + DB.Migrator().DropTable(&Product{}) + DB.AutoMigrate(&Product{}) + + p := Product{Code: "unique_code", Price: 100} + DB.Save(&p) + + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { + t.Errorf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + DB.Where("Code = ?", "unique_code").First(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { + t.Fatalf("After callbacks values are not saved, %v", p.GetCallTimes()) + } + + p.Price = 200 + DB.Save(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { + t.Fatalf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + var products []Product + DB.Find(&products, "code = ?", "unique_code") + if products[0].AfterFindCallTimes != 1 { + t.Fatalf("AfterFind callbacks should work with slice, called %v", products[0].AfterFindCallTimes) + } + + DB.Where("Code = ?", "unique_code").First(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { + t.Fatalf("After update callbacks values are not saved, %v", p.GetCallTimes()) + } + + DB.Delete(&p) + if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { + t.Fatalf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) + } + + if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { + t.Fatalf("Can't find a deleted record") + } +} + +func TestCallbacksWithErrors(t *testing.T) { + DB.Migrator().DropTable(&Product{}) + DB.AutoMigrate(&Product{}) + + p := Product{Code: "Invalid", Price: 100} + if DB.Save(&p).Error == nil { + t.Fatalf("An error from before create callbacks happened when create with invalid value") + } + + if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { + t.Fatalf("Should not save record that have errors") + } + + if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { + t.Fatalf("An error from after create callbacks happened when create with invalid value") + } + + p2 := Product{Code: "update_callback", Price: 100} + DB.Save(&p2) + + p2.Code = "dont_update" + if DB.Save(&p2).Error == nil { + t.Fatalf("An error from before update callbacks happened when update with invalid value") + } + + if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { + t.Fatalf("Record Should not be updated due to errors happened in before update callback") + } + + if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { + t.Fatalf("Record Should not be updated due to errors happened in before update callback") + } + + p2.Code = "dont_save" + if DB.Save(&p2).Error == nil { + t.Fatalf("An error from before save callbacks happened when update with invalid value") + } + + p3 := Product{Code: "dont_delete", Price: 100} + DB.Save(&p3) + if DB.Delete(&p3).Error == nil { + t.Fatalf("An error from before delete callbacks happened when delete") + } + + if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { + t.Fatalf("An error from before delete callbacks happened") + } + + p4 := Product{Code: "after_save_error", Price: 100} + DB.Save(&p4) + if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { + t.Fatalf("Record should be reverted if get an error in after save callback") + } + + p5 := Product{Code: "after_delete_error", Price: 100} + DB.Save(&p5) + if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + t.Fatalf("Record should be found") + } + + DB.Delete(&p5) + if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { + t.Fatalf("Record shouldn't be deleted because of an error happened in after delete callback") + } +} diff --git a/tests/tests.go b/tests/tests.go index 7e216776e..d92578985 100644 --- a/tests/tests.go +++ b/tests/tests.go @@ -59,9 +59,9 @@ func OpenTestConnection() (db *gorm.DB, err error) { } if debug := os.Getenv("DEBUG"); debug == "true" { - db.Logger.LogMode(logger.Info) + db.Logger = db.Logger.LogMode(logger.Info) } else if debug == "false" { - db.Logger.LogMode(logger.Silent) + db.Logger = db.Logger.LogMode(logger.Silent) } return diff --git a/tests/transaction_test.go b/tests/transaction_test.go index 9405fd764..f39b3167c 100644 --- a/tests/transaction_test.go +++ b/tests/transaction_test.go @@ -14,37 +14,37 @@ func TestTransaction(t *testing.T) { user := *GetUser("transcation", Config{}) if err := tx.Save(&user).Error; err != nil { - t.Errorf("No error should raise, but got %v", err) + t.Fatalf("No error should raise, but got %v", err) } if err := tx.First(&User{}, "name = ?", "transcation").Error; err != nil { - t.Errorf("Should find saved record, but got %v", err) + t.Fatalf("Should find saved record, but got %v", err) } if sqlTx, ok := tx.Statement.ConnPool.(*sql.Tx); !ok || sqlTx == nil { - t.Errorf("Should return the underlying sql.Tx") + t.Fatalf("Should return the underlying sql.Tx") } tx.Rollback() if err := DB.First(&User{}, "name = ?", "transcation").Error; err == nil { - t.Errorf("Should not find record after rollback, but got %v", err) + t.Fatalf("Should not find record after rollback, but got %v", err) } tx2 := DB.Begin() user2 := *GetUser("transcation-2", Config{}) if err := tx2.Save(&user2).Error; err != nil { - t.Errorf("No error should raise, but got %v", err) + t.Fatalf("No error should raise, but got %v", err) } if err := tx2.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should find saved record, but got %v", err) + t.Fatalf("Should find saved record, but got %v", err) } tx2.Commit() if err := DB.First(&User{}, "name = ?", "transcation-2").Error; err != nil { - t.Errorf("Should be able to find committed record, but got %v", err) + t.Fatalf("Should be able to find committed record, but got %v", err) } } @@ -52,7 +52,7 @@ func TestTransactionWithBlock(t *testing.T) { assertPanic := func(f func()) { defer func() { if r := recover(); r == nil { - t.Errorf("The code did not panic") + t.Fatalf("The code did not panic") } }() f() @@ -62,39 +62,39 @@ func TestTransactionWithBlock(t *testing.T) { err := DB.Transaction(func(tx *gorm.DB) error { user := *GetUser("transcation-block", Config{}) if err := tx.Save(&user).Error; err != nil { - t.Errorf("No error should raise") + t.Fatalf("No error should raise") } if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { - t.Errorf("Should find saved record") + t.Fatalf("Should find saved record") } return errors.New("the error message") }) if err.Error() != "the error message" { - t.Errorf("Transaction return error will equal the block returns error") + t.Fatalf("Transaction return error will equal the block returns error") } if err := DB.First(&User{}, "name = ?", "transcation-block").Error; err == nil { - t.Errorf("Should not find record after rollback") + t.Fatalf("Should not find record after rollback") } // commit DB.Transaction(func(tx *gorm.DB) error { user := *GetUser("transcation-block-2", Config{}) if err := tx.Save(&user).Error; err != nil { - t.Errorf("No error should raise") + t.Fatalf("No error should raise") } if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { - t.Errorf("Should find saved record") + t.Fatalf("Should find saved record") } return nil }) if err := DB.First(&User{}, "name = ?", "transcation-block-2").Error; err != nil { - t.Errorf("Should be able to find committed record") + t.Fatalf("Should be able to find committed record") } // panic will rollback @@ -102,11 +102,11 @@ func TestTransactionWithBlock(t *testing.T) { DB.Transaction(func(tx *gorm.DB) error { user := *GetUser("transcation-block-3", Config{}) if err := tx.Save(&user).Error; err != nil { - t.Errorf("No error should raise") + t.Fatalf("No error should raise") } if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { - t.Errorf("Should find saved record") + t.Fatalf("Should find saved record") } panic("force panic") @@ -114,7 +114,7 @@ func TestTransactionWithBlock(t *testing.T) { }) if err := DB.First(&User{}, "name = ?", "transcation-block-3").Error; err == nil { - t.Errorf("Should not find record after panic rollback") + t.Fatalf("Should not find record after panic rollback") } } @@ -122,14 +122,14 @@ func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) { tx := DB.Begin() user := User{Name: "transcation"} if err := tx.Save(&user).Error; err != nil { - t.Errorf("No error should raise") + t.Fatalf("No error should raise") } if err := tx.Commit().Error; err != nil { - t.Errorf("Commit should not raise error") + t.Fatalf("Commit should not raise error") } if err := tx.Rollback().Error; err == nil { - t.Errorf("Rollback after commit should raise error") + t.Fatalf("Rollback after commit should raise error") } }