Skip to content

Commit

Permalink
Add returning support to delete
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Oct 27, 2021
1 parent af3fbdc commit 835d7bd
Show file tree
Hide file tree
Showing 10 changed files with 68 additions and 46 deletions.
2 changes: 1 addition & 1 deletion callbacks/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
deleteCallback.Register("gorm:before_delete", BeforeDelete)
deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
deleteCallback.Register("gorm:delete", Delete)
deleteCallback.Register("gorm:delete", Delete(config))
deleteCallback.Register("gorm:after_delete", AfterDelete)
deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
deleteCallback.Clauses = config.DeleteClauses
Expand Down
27 changes: 9 additions & 18 deletions callbacks/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)

func BeforeCreate(db *gorm.DB) {
Expand All @@ -31,18 +32,12 @@ func BeforeCreate(db *gorm.DB) {
}

func Create(config *Config) func(db *gorm.DB) {
withReturning := false
for _, clause := range config.CreateClauses {
if clause == "RETURNING" {
withReturning = true
}
}
supportReturning := utils.Contains(config.CreateClauses, "RETURNING")

return func(db *gorm.DB) {
if db.Error != nil {
return
}
onReturning := false

if db.Statement.Schema != nil {
if !db.Statement.Unscoped {
Expand All @@ -51,8 +46,7 @@ func Create(config *Config) func(db *gorm.DB) {
}
}

if withReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
onReturning = true
if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
Expand All @@ -72,18 +66,15 @@ func Create(config *Config) func(db *gorm.DB) {
}

if !db.DryRun && db.Error == nil {
if onReturning {
doNothing := false

if ok, mode := hasReturning(db, supportReturning); ok {
if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
onConflict, _ := c.Expression.(clause.OnConflict)
doNothing = onConflict.DoNothing
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.DoNothing {
mode |= gorm.ScanOnConflictDoNothing
}
}
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
if doNothing {
gorm.Scan(rows, db, gorm.ScanUpdate|gorm.ScanOnConflictDoNothing)
} else {
gorm.Scan(rows, db, gorm.ScanUpdate)
}
gorm.Scan(rows, db, mode)
rows.Close()
}
} else {
Expand Down
25 changes: 18 additions & 7 deletions callbacks/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)

func BeforeDelete(db *gorm.DB) {
Expand Down Expand Up @@ -104,8 +105,14 @@ func DeleteBeforeAssociations(db *gorm.DB) {
}
}

func Delete(db *gorm.DB) {
if db.Error == nil {
func Delete(config *Config) func(db *gorm.DB) {
supportReturning := utils.Contains(config.DeleteClauses, "RETURNING")

return func(db *gorm.DB) {
if db.Error != nil {
return
}

if db.Statement.Schema != nil && !db.Statement.Unscoped {
for _, c := range db.Statement.Schema.DeleteClauses {
db.Statement.AddClause(c)
Expand Down Expand Up @@ -144,12 +151,16 @@ func Delete(db *gorm.DB) {
}

if !db.DryRun && db.Error == nil {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)

if err == nil {
db.RowsAffected, _ = result.RowsAffected()
if ok, mode := hasReturning(db, supportReturning); ok {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, mode)
rows.Close()
}
} else {
db.AddError(err)
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
}
}
}
}
Expand Down
13 changes: 13 additions & 0 deletions callbacks/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,16 @@ func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[st
}
return
}

func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) {
if supportReturning {
if c, ok := tx.Statement.Clauses["RETURNING"]; ok {
returning, _ := c.Expression.(clause.Returning)
if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") {
return true, 0
}
return true, gorm.ScanUpdate
}
}
return false, 0
}
16 changes: 5 additions & 11 deletions callbacks/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)

func SetupUpdateReflectValue(db *gorm.DB) {
Expand Down Expand Up @@ -51,12 +52,7 @@ func BeforeUpdate(db *gorm.DB) {
}

func Update(config *Config) func(db *gorm.DB) {
withReturning := false
for _, clause := range config.UpdateClauses {
if clause == "RETURNING" {
withReturning = true
}
}
supportReturning := utils.Contains(config.UpdateClauses, "RETURNING")

return func(db *gorm.DB) {
if db.Error != nil {
Expand Down Expand Up @@ -86,18 +82,16 @@ func Update(config *Config) func(db *gorm.DB) {
}

if !db.DryRun && db.Error == nil {
if _, ok := db.Statement.Clauses["RETURNING"]; withReturning && ok {
if ok, mode := hasReturning(db, supportReturning); ok {
if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, gorm.ScanUpdate)
gorm.Scan(rows, db, mode)
rows.Close()
}
} else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)

if err == nil {
if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
}
}
Expand Down
14 changes: 9 additions & 5 deletions clause/returning.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@ func (returning Returning) Name() string {

// Build build where clause
func (returning Returning) Build(builder Builder) {
for idx, column := range returning.Columns {
if idx > 0 {
builder.WriteByte(',')
}
if len(returning.Columns) > 0 {
for idx, column := range returning.Columns {
if idx > 0 {
builder.WriteByte(',')
}

builder.WriteQuoted(column)
builder.WriteQuoted(column)
}
} else {
builder.WriteByte('*')
}
}

Expand Down
2 changes: 1 addition & 1 deletion scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ func Scan(rows *sql.Rows, db *DB, mode ScanMode) {
case reflect.Slice, reflect.Array:
var elem reflect.Value

if !update {
if !update && reflectValue.Len() != 0 {
db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
}

Expand Down
4 changes: 2 additions & 2 deletions tests/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ require (
gorm.io/driver/mysql v1.1.2
gorm.io/driver/postgres v1.2.0
gorm.io/driver/sqlite v1.2.0
gorm.io/driver/sqlserver v1.1.1
gorm.io/gorm v1.21.16
gorm.io/driver/sqlserver v1.1.2
gorm.io/gorm v1.22.0
)

replace gorm.io/gorm => ../
2 changes: 1 addition & 1 deletion tests/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ func TestUpdates(t *testing.T) {
}

// update with gorm exprs
if err := DB.Debug().Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
if err := DB.Model(&user3).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
}
var user4 User
Expand Down
9 changes: 9 additions & 0 deletions utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ func ToStringKey(values ...interface{}) string {
return strings.Join(results, "_")
}

func Contains(elems []string, elem string) bool {
for _, e := range elems {
if elem == e {
return true
}
}
return false
}

func AssertEqual(src, dst interface{}) bool {
if !reflect.DeepEqual(src, dst) {
if valuer, ok := src.(driver.Valuer); ok {
Expand Down

0 comments on commit 835d7bd

Please sign in to comment.