Skip to content

Commit

Permalink
Merge pull request #1 from go-gorm/master
Browse files Browse the repository at this point in the history
merge
  • Loading branch information
iesreza authored Nov 22, 2020
2 parents 1e241aa + 6186a4d commit a2a5799
Show file tree
Hide file tree
Showing 26 changed files with 567 additions and 101 deletions.
8 changes: 4 additions & 4 deletions association.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (association *Association) Replace(values ...interface{}) error {

if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap)
association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error
}
case schema.Many2Many:
var (
Expand Down Expand Up @@ -154,7 +154,7 @@ func (association *Association) Replace(values ...interface{}) error {
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
}

tx.Delete(modelValue)
association.Error = tx.Delete(modelValue).Error
}
}
return association.Error
Expand Down Expand Up @@ -417,7 +417,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear)

// TODO support save slice data, sql with case?
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error
association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error
}
case reflect.Struct:
// clear old data
Expand All @@ -439,7 +439,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{
}

if len(values) > 0 {
association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error
association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error
}
}

Expand Down
66 changes: 52 additions & 14 deletions callbacks/associations.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package callbacks

import (
"reflect"
"strings"

"gorm.io/gorm"
"gorm.io/gorm/clause"
Expand Down Expand Up @@ -66,7 +67,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
}

if elems.Len() > 0 {
if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil {
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
for i := 0; i < elems.Len(); i++ {
setupReferences(objs[i], elems.Index(i))
}
Expand All @@ -79,7 +80,7 @@ func SaveBeforeAssociations(db *gorm.DB) {
rv = rv.Addr()
}

if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil {
if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
setupReferences(db.Statement.ReflectValue, rv)
}
}
Expand Down Expand Up @@ -141,9 +142,7 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}

db.AddError(db.Session(&gorm.Session{}).Clauses(
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
).Create(elems.Interface()).Error)
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
}
case reflect.Struct:
if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero {
Expand All @@ -163,9 +162,7 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}

db.AddError(db.Session(&gorm.Session{}).Clauses(
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
).Create(f.Interface()).Error)
saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
}
}
}
Expand Down Expand Up @@ -224,9 +221,7 @@ func SaveAfterAssociations(db *gorm.DB) {
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
}

db.AddError(db.Session(&gorm.Session{}).Clauses(
onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns),
).Create(elems.Interface()).Error)
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
}
}

Expand Down Expand Up @@ -291,24 +286,30 @@ func SaveAfterAssociations(db *gorm.DB) {
}

if elems.Len() > 0 {
db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error)
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
}

for i := 0; i < elems.Len(); i++ {
appendToJoins(objs[i], elems.Index(i))
}
}

if joins.Len() > 0 {
db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error)
db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error)
}
}
}
}

func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict {
func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict {
if stmt.DB.FullSaveAssociations {
defaultUpdatingColumns = make([]string, 0, len(s.DBNames))
for _, dbName := range s.DBNames {
if v, ok := selectColumns[dbName]; (ok && !v) || (!ok && restricted) {
continue
}

if !s.LookUpField(dbName).PrimaryKey {
defaultUpdatingColumns = append(defaultUpdatingColumns, dbName)
}
Expand All @@ -333,3 +334,40 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingCol

return clause.OnConflict{DoNothing: true}
}

func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
var (
selects, omits []string
onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
refName = rel.Name + "."
)

for name, ok := range selectColumns {
columnName := ""
if strings.HasPrefix(name, refName) {
columnName = strings.TrimPrefix(name, refName)
} else if strings.HasPrefix(name, clause.Associations) {
columnName = name
}

if columnName != "" {
if ok {
selects = append(selects, columnName)
} else {
omits = append(omits, columnName)
}
}
}

tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})

if len(selects) > 0 {
tx = tx.Select(selects)
}

if len(omits) > 0 {
tx = tx.Omit(omits...)
}

return db.AddError(tx.Create(values).Error)
}
2 changes: 1 addition & 1 deletion callbacks/callmethod.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
tx := db.Session(&gorm.Session{})
tx := db.Session(&gorm.Session{NewDB: true})
if called := fc(db.Statement.ReflectValue.Interface(), tx); !called {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
Expand Down
39 changes: 22 additions & 17 deletions callbacks/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

func BeforeCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.BeforeSave {
if i, ok := value.(BeforeSaveInterface); ok {
Expand Down Expand Up @@ -55,6 +55,7 @@ func Create(config *Config) func(db *gorm.DB) {

if err == nil {
db.RowsAffected, _ = result.RowsAffected()

if db.RowsAffected > 0 {
if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue {
if insertID, err := result.LastInsertId(); err == nil && insertID > 0 {
Expand Down Expand Up @@ -138,6 +139,7 @@ func CreateWithReturning(db *gorm.DB) {
}

if !db.DryRun && db.Error == nil {
db.RowsAffected = 0
rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)

if err == nil {
Expand Down Expand Up @@ -201,7 +203,7 @@ func CreateWithReturning(db *gorm.DB) {
}

func AfterCreate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
Expand Down Expand Up @@ -329,26 +331,29 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
}
}

if stmt.UpdatingColumn {
if stmt.Schema != nil && len(values.Columns) > 1 {
columns := make([]string, 0, len(values.Columns)-1)
for _, column := range values.Columns {
if field := stmt.Schema.LookUpField(column.Name); field != nil {
if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 {
columns = append(columns, column.Name)
if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
if stmt.Schema != nil && len(values.Columns) > 1 {
columns := make([]string, 0, len(values.Columns)-1)
for _, column := range values.Columns {
if field := stmt.Schema.LookUpField(column.Name); field != nil {
if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 {
columns = append(columns, column.Name)
}
}
}
}

onConflict := clause.OnConflict{
Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)),
DoUpdates: clause.AssignmentColumns(columns),
}
onConflict := clause.OnConflict{
Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)),
DoUpdates: clause.AssignmentColumns(columns),
}

for idx, field := range stmt.Schema.PrimaryFields {
onConflict.Columns[idx] = clause.Column{Name: field.DBName}
}

for idx, field := range stmt.Schema.PrimaryFields {
onConflict.Columns[idx] = clause.Column{Name: field.DBName}
stmt.AddClause(onConflict)
}
stmt.AddClause(onConflict)
}
}

Expand Down
8 changes: 4 additions & 4 deletions callbacks/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

func BeforeDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(BeforeDeleteInterface); ok {
db.AddError(i.BeforeDelete(tx))
Expand All @@ -34,7 +34,7 @@ func DeleteBeforeAssociations(db *gorm.DB) {
case schema.HasOne, schema.HasMany:
queryConds := rel.ToQueryConditions(db.Statement.ReflectValue)
modelValue := reflect.New(rel.FieldSchema.ModelType).Interface()
tx := db.Session(&gorm.Session{}).Model(modelValue)
tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue)
withoutConditions := false

if len(db.Statement.Selects) > 0 {
Expand Down Expand Up @@ -71,7 +71,7 @@ func DeleteBeforeAssociations(db *gorm.DB) {
relForeignKeys []string
modelValue = reflect.New(rel.JoinTable.ModelType).Interface()
table = rel.JoinTable.Table
tx = db.Session(&gorm.Session{}).Model(modelValue).Table(table)
tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table)
)

for _, ref := range rel.References {
Expand Down Expand Up @@ -153,7 +153,7 @@ func Delete(db *gorm.DB) {
}

func AfterDelete(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterDeleteInterface); ok {
db.AddError(i.AfterDelete(tx))
Expand Down
2 changes: 1 addition & 1 deletion callbacks/preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
var (
reflectValue = db.Statement.ReflectValue
rel = rels[len(rels)-1]
tx = db.Session(&gorm.Session{})
tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks})
relForeignKeys []string
relForeignFields []*schema.Field
foreignFields []*schema.Field
Expand Down
28 changes: 16 additions & 12 deletions callbacks/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,26 +68,28 @@ func BuildQuerySQL(db *gorm.DB) {
clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
for _, dbName := range db.Statement.Schema.DBNames {
if v, ok := selectColumns[dbName]; (ok && v) || !ok {
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Name: dbName})
clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName})
}
}
} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
smallerStruct := false
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
smallerStruct = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
case reflect.Slice:
smallerStruct = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
queryFields := db.QueryFields
if !queryFields {
switch db.Statement.ReflectValue.Kind() {
case reflect.Struct:
queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
case reflect.Slice:
queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
}
}

if smallerStruct {
if queryFields {
stmt := gorm.Statement{DB: db}
// smaller struct
if err := stmt.Parse(db.Statement.Dest); err == nil && stmt.Schema.ModelType != db.Statement.Schema.ModelType {
if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))

for idx, dbName := range stmt.Schema.DBNames {
clauseSelect.Columns[idx] = clause.Column{Name: dbName}
clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
}
}
}
Expand Down Expand Up @@ -206,13 +208,15 @@ func Preload(db *gorm.DB) {
}
}

preload(db, rels, db.Statement.Preloads[name])
if db.Error == nil {
preload(db, rels, db.Statement.Preloads[name])
}
}
}
}

func AfterQuery(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind {
callMethod(db, func(value interface{}, tx *gorm.DB) bool {
if i, ok := value.(AfterFindInterface); ok {
db.AddError(i.AfterFind(tx))
Expand Down
8 changes: 4 additions & 4 deletions callbacks/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func SetupUpdateReflectValue(db *gorm.DB) {
}

func BeforeUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.BeforeSave {
if i, ok := value.(BeforeSaveInterface); ok {
Expand Down Expand Up @@ -87,7 +87,7 @@ func Update(db *gorm.DB) {
}

func AfterUpdate(db *gorm.DB) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) {
callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
if db.Statement.Schema.AfterSave {
if i, ok := value.(AfterSaveInterface); ok {
Expand Down Expand Up @@ -198,7 +198,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
}
}

if !stmt.UpdatingColumn && stmt.Schema != nil {
if !stmt.SkipHooks && stmt.Schema != nil {
for _, dbName := range stmt.Schema.DBNames {
field := stmt.Schema.LookUpField(dbName)
if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil {
Expand Down Expand Up @@ -228,7 +228,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) {
if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) {
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
value, isZero := field.ValueOf(updatingValue)
if !stmt.UpdatingColumn {
if !stmt.SkipHooks {
if field.AutoUpdateTime > 0 {
if field.AutoUpdateTime == schema.UnixNanosecond {
value = stmt.DB.NowFunc().UnixNano()
Expand Down
Loading

0 comments on commit a2a5799

Please sign in to comment.