diff --git a/association.go b/association.go index 140ae6acd..7adb8c914 100644 --- a/association.go +++ b/association.go @@ -118,7 +118,7 @@ func (association *Association) Replace(values ...interface{}) error { if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) - tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap) + association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error } case schema.Many2Many: var ( @@ -154,7 +154,7 @@ func (association *Association) Replace(values ...interface{}) error { tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) } - tx.Delete(modelValue) + association.Error = tx.Delete(modelValue).Error } } return association.Error @@ -417,7 +417,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) // TODO support save slice data, sql with case? - association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error + association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: // clear old data @@ -439,7 +439,7 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } if len(values) > 0 { - association.Error = association.DB.Session(&Session{}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error + association.Error = association.DB.Session(&Session{NewDB: true}).Select(selectedSaveColumns).Model(nil).Updates(reflectValue.Addr().Interface()).Error } } diff --git a/callbacks/associations.go b/callbacks/associations.go index 1e6f62c58..e66696004 100644 --- a/callbacks/associations.go +++ b/callbacks/associations.go @@ -2,6 +2,7 @@ package callbacks import ( "reflect" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -66,7 +67,7 @@ func SaveBeforeAssociations(db *gorm.DB) { } if elems.Len() > 0 { - if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) == nil { + if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } @@ -79,7 +80,7 @@ func SaveBeforeAssociations(db *gorm.DB) { rv = rv.Addr() } - if db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(rv.Interface()).Error) == nil { + if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil { setupReferences(db.Statement.ReflectValue, rv) } } @@ -141,9 +142,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses( - onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), - ).Create(elems.Interface()).Error) + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.ReflectValue); !zero { @@ -163,9 +162,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses( - onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), - ).Create(f.Interface()).Error) + saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns) } } } @@ -224,9 +221,7 @@ func SaveAfterAssociations(db *gorm.DB) { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } - db.AddError(db.Session(&gorm.Session{}).Clauses( - onConflictOption(db.Statement, rel.FieldSchema, assignmentColumns), - ).Create(elems.Interface()).Error) + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns) } } @@ -291,7 +286,9 @@ func SaveAfterAssociations(db *gorm.DB) { } if elems.Len() > 0 { - db.AddError(db.Session(&gorm.Session{}).Clauses(onConflictOption(db.Statement, rel.FieldSchema, nil)).Create(elems.Interface()).Error) + if v, ok := selectColumns[rel.Name+".*"]; !ok || v { + saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) + } for i := 0; i < elems.Len(); i++ { appendToJoins(objs[i], elems.Index(i)) @@ -299,16 +296,20 @@ func SaveAfterAssociations(db *gorm.DB) { } if joins.Len() > 0 { - db.AddError(db.Session(&gorm.Session{}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error) + db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Create(joins.Interface()).Error) } } } } -func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) clause.OnConflict { +func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) clause.OnConflict { if stmt.DB.FullSaveAssociations { defaultUpdatingColumns = make([]string, 0, len(s.DBNames)) for _, dbName := range s.DBNames { + if v, ok := selectColumns[dbName]; (ok && !v) || (!ok && restricted) { + continue + } + if !s.LookUpField(dbName).PrimaryKey { defaultUpdatingColumns = append(defaultUpdatingColumns, dbName) } @@ -333,3 +334,40 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingCol return clause.OnConflict{DoNothing: true} } + +func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { + var ( + selects, omits []string + onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns) + refName = rel.Name + "." + ) + + for name, ok := range selectColumns { + columnName := "" + if strings.HasPrefix(name, refName) { + columnName = strings.TrimPrefix(name, refName) + } else if strings.HasPrefix(name, clause.Associations) { + columnName = name + } + + if columnName != "" { + if ok { + selects = append(selects, columnName) + } else { + omits = append(omits, columnName) + } + } + } + + tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) + + if len(selects) > 0 { + tx = tx.Select(selects) + } + + if len(omits) > 0 { + tx = tx.Omit(omits...) + } + + return db.AddError(tx.Create(values).Error) +} diff --git a/callbacks/callmethod.go b/callbacks/callmethod.go index b81fc915f..bcaa03f3d 100644 --- a/callbacks/callmethod.go +++ b/callbacks/callmethod.go @@ -7,7 +7,7 @@ import ( ) func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { - tx := db.Session(&gorm.Session{}) + tx := db.Session(&gorm.Session{NewDB: true}) if called := fc(db.Statement.ReflectValue.Interface(), tx); !called { switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: diff --git a/callbacks/create.go b/callbacks/create.go index 67f3ab143..3ca56d733 100644 --- a/callbacks/create.go +++ b/callbacks/create.go @@ -10,7 +10,7 @@ import ( ) func BeforeCreate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { if i, ok := value.(BeforeSaveInterface); ok { @@ -55,6 +55,7 @@ func Create(config *Config) func(db *gorm.DB) { if err == nil { db.RowsAffected, _ = result.RowsAffected() + if db.RowsAffected > 0 { if db.Statement.Schema != nil && db.Statement.Schema.PrioritizedPrimaryField != nil && db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue { if insertID, err := result.LastInsertId(); err == nil && insertID > 0 { @@ -138,6 +139,7 @@ func CreateWithReturning(db *gorm.DB) { } if !db.DryRun && db.Error == nil { + db.RowsAffected = 0 rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err == nil { @@ -201,7 +203,7 @@ func CreateWithReturning(db *gorm.DB) { } func AfterCreate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { if i, ok := value.(AfterSaveInterface); ok { @@ -329,26 +331,29 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { } } - if stmt.UpdatingColumn { - if stmt.Schema != nil && len(values.Columns) > 1 { - columns := make([]string, 0, len(values.Columns)-1) - for _, column := range values.Columns { - if field := stmt.Schema.LookUpField(column.Name); field != nil { - if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { - columns = append(columns, column.Name) + if c, ok := stmt.Clauses["ON CONFLICT"]; ok { + if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll { + if stmt.Schema != nil && len(values.Columns) > 1 { + columns := make([]string, 0, len(values.Columns)-1) + for _, column := range values.Columns { + if field := stmt.Schema.LookUpField(column.Name); field != nil { + if !field.PrimaryKey && !field.HasDefaultValue && field.AutoCreateTime == 0 { + columns = append(columns, column.Name) + } } } - } - onConflict := clause.OnConflict{ - Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)), - DoUpdates: clause.AssignmentColumns(columns), - } + onConflict := clause.OnConflict{ + Columns: make([]clause.Column, len(stmt.Schema.PrimaryFieldDBNames)), + DoUpdates: clause.AssignmentColumns(columns), + } + + for idx, field := range stmt.Schema.PrimaryFields { + onConflict.Columns[idx] = clause.Column{Name: field.DBName} + } - for idx, field := range stmt.Schema.PrimaryFields { - onConflict.Columns[idx] = clause.Column{Name: field.DBName} + stmt.AddClause(onConflict) } - stmt.AddClause(onConflict) } } diff --git a/callbacks/delete.go b/callbacks/delete.go index 0f4bcd6be..867aa6970 100644 --- a/callbacks/delete.go +++ b/callbacks/delete.go @@ -10,7 +10,7 @@ import ( ) func BeforeDelete(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.BeforeDelete { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(BeforeDeleteInterface); ok { db.AddError(i.BeforeDelete(tx)) @@ -34,7 +34,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { case schema.HasOne, schema.HasMany: queryConds := rel.ToQueryConditions(db.Statement.ReflectValue) modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() - tx := db.Session(&gorm.Session{}).Model(modelValue) + tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) withoutConditions := false if len(db.Statement.Selects) > 0 { @@ -71,7 +71,7 @@ func DeleteBeforeAssociations(db *gorm.DB) { relForeignKeys []string modelValue = reflect.New(rel.JoinTable.ModelType).Interface() table = rel.JoinTable.Table - tx = db.Session(&gorm.Session{}).Model(modelValue).Table(table) + tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) ) for _, ref := range rel.References { @@ -153,7 +153,7 @@ func Delete(db *gorm.DB) { } func AfterDelete(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterDelete { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterDeleteInterface); ok { db.AddError(i.AfterDelete(tx)) diff --git a/callbacks/preload.go b/callbacks/preload.go index d60079e48..682427c9e 100644 --- a/callbacks/preload.go +++ b/callbacks/preload.go @@ -13,7 +13,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) { var ( reflectValue = db.Statement.ReflectValue rel = rels[len(rels)-1] - tx = db.Session(&gorm.Session{}) + tx = db.Session(&gorm.Session{NewDB: true}).Model(nil).Session(&gorm.Session{SkipHooks: db.Statement.SkipHooks}) relForeignKeys []string relForeignFields []*schema.Field foreignFields []*schema.Field diff --git a/callbacks/query.go b/callbacks/query.go index 8613e46d6..aa4629a2e 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -68,26 +68,28 @@ func BuildQuerySQL(db *gorm.DB) { clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames)) for _, dbName := range db.Statement.Schema.DBNames { if v, ok := selectColumns[dbName]; (ok && v) || !ok { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Name: dbName}) + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName}) } } } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { - smallerStruct := false - switch db.Statement.ReflectValue.Kind() { - case reflect.Struct: - smallerStruct = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType - case reflect.Slice: - smallerStruct = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType + queryFields := db.QueryFields + if !queryFields { + switch db.Statement.ReflectValue.Kind() { + case reflect.Struct: + queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType + case reflect.Slice: + queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType + } } - if smallerStruct { + if queryFields { stmt := gorm.Statement{DB: db} // smaller struct - if err := stmt.Parse(db.Statement.Dest); err == nil && stmt.Schema.ModelType != db.Statement.Schema.ModelType { + if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) { clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) for idx, dbName := range stmt.Schema.DBNames { - clauseSelect.Columns[idx] = clause.Column{Name: dbName} + clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} } } } @@ -206,13 +208,15 @@ func Preload(db *gorm.DB) { } } - preload(db, rels, db.Statement.Preloads[name]) + if db.Error == nil { + preload(db, rels, db.Statement.Preloads[name]) + } } } } func AfterQuery(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && db.Statement.Schema.AfterFind { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterFindInterface); ok { db.AddError(i.AfterFind(tx)) diff --git a/callbacks/update.go b/callbacks/update.go index 46f59157a..c8f3922eb 100644 --- a/callbacks/update.go +++ b/callbacks/update.go @@ -29,7 +29,7 @@ func SetupUpdateReflectValue(db *gorm.DB) { } func BeforeUpdate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { if i, ok := value.(BeforeSaveInterface); ok { @@ -87,7 +87,7 @@ func Update(db *gorm.DB) { } func AfterUpdate(db *gorm.DB) { - if db.Error == nil && db.Statement.Schema != nil && !db.Statement.UpdatingColumn && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { + if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterSave { if i, ok := value.(AfterSaveInterface); ok { @@ -198,7 +198,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } } - if !stmt.UpdatingColumn && stmt.Schema != nil { + if !stmt.SkipHooks && stmt.Schema != nil { for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.LookUpField(dbName) if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { @@ -228,7 +228,7 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if !field.PrimaryKey || (!updatingValue.CanAddr() || stmt.Dest != stmt.Model) { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { value, isZero := field.ValueOf(updatingValue) - if !stmt.UpdatingColumn { + if !stmt.SkipHooks { if field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = stmt.DB.NowFunc().UnixNano() diff --git a/clause/expression.go b/clause/expression.go index 40265ac6a..b30c46b03 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -160,8 +160,13 @@ func (in IN) Build(builder Builder) { case 0: builder.WriteString(" IN (NULL)") case 1: - builder.WriteString(" = ") - builder.AddVar(builder, in.Values...) + if _, ok := in.Values[0].([]interface{}); !ok { + builder.WriteString(" = ") + builder.AddVar(builder, in.Values[0]) + break + } + + fallthrough default: builder.WriteString(" IN (") builder.AddVar(builder, in.Values...) @@ -173,9 +178,14 @@ func (in IN) NegationBuild(builder Builder) { switch len(in.Values) { case 0: case 1: - builder.WriteQuoted(in.Column) - builder.WriteString(" <> ") - builder.AddVar(builder, in.Values...) + if _, ok := in.Values[0].([]interface{}); !ok { + builder.WriteQuoted(in.Column) + builder.WriteString(" <> ") + builder.AddVar(builder, in.Values[0]) + break + } + + fallthrough default: builder.WriteQuoted(in.Column) builder.WriteString(" NOT IN (") diff --git a/clause/on_conflict.go b/clause/on_conflict.go index 47f69fc9a..47fe169ca 100644 --- a/clause/on_conflict.go +++ b/clause/on_conflict.go @@ -5,6 +5,7 @@ type OnConflict struct { Where Where DoNothing bool DoUpdates Set + UpdateAll bool } func (OnConflict) Name() string { diff --git a/finisher_api.go b/finisher_api.go index 857f94198..f2aed8da6 100644 --- a/finisher_api.go +++ b/finisher_api.go @@ -21,6 +21,29 @@ func (db *DB) Create(value interface{}) (tx *DB) { return } +// CreateInBatches insert the value in batches into database +func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { + reflectValue := reflect.Indirect(reflect.ValueOf(value)) + + switch reflectValue.Kind() { + case reflect.Slice, reflect.Array: + tx = db.getInstance() + for i := 0; i < reflectValue.Len(); i += batchSize { + tx.AddError(tx.Transaction(func(tx *DB) error { + ends := i + batchSize + if ends > reflectValue.Len() { + ends = reflectValue.Len() + } + + return tx.Create(reflectValue.Slice(i, ends).Interface()).Error + })) + } + default: + return db.Create(value) + } + return +} + // Save update value in database, if the value doesn't have primary key, will insert it func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() @@ -29,7 +52,9 @@ func (db *DB) Save(value interface{}) (tx *DB) { reflectValue := reflect.Indirect(reflect.ValueOf(value)) switch reflectValue.Kind() { case reflect.Slice, reflect.Array: - tx.Statement.UpdatingColumn = true + if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { + tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) + } tx.callbacks.Create().Execute(tx) case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { @@ -53,7 +78,7 @@ func (db *DB) Save(value interface{}) (tx *DB) { if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { result := reflect.New(tx.Statement.Schema.ModelType).Interface() - if err := tx.Session(&Session{WithConditions: true}).First(result).Error; errors.Is(err, ErrRecordNotFound) { + if err := tx.Session(&Session{}).First(result).Error; errors.Is(err, ErrRecordNotFound) { return tx.Create(value) } } @@ -115,13 +140,18 @@ func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { } // FindInBatches find records in batches -func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) (tx *DB) { - tx = db.Session(&Session{WithConditions: true}) - rowsAffected := int64(0) - batch := 0 +func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { + var ( + tx = db.Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }).Session(&Session{}) + queryDB = tx + rowsAffected int64 + batch int + ) for { - result := tx.Limit(batchSize).Offset(batch * batchSize).Find(dest) + result := queryDB.Limit(batchSize).Find(dest) rowsAffected += result.RowsAffected batch++ @@ -131,11 +161,15 @@ func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, bat if tx.Error != nil || int(result.RowsAffected) < batchSize { break + } else { + resultsValue := reflect.Indirect(reflect.ValueOf(dest)) + primaryValue, _ := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(resultsValue.Index(resultsValue.Len() - 1)) + queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } } tx.RowsAffected = rowsAffected - return + return tx } func (tx *DB) assignInterfacesToValue(values ...interface{}) { @@ -186,7 +220,11 @@ func (tx *DB) assignInterfacesToValue(values ...interface{}) { } func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { - if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { + queryTx := db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + + if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { tx.assignInterfacesToValue(where.Exprs) @@ -197,7 +235,6 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { if len(tx.Statement.attrs) > 0 { tx.assignInterfacesToValue(tx.Statement.attrs...) } - tx.Error = nil } // initialize with attrs, conds @@ -208,9 +245,11 @@ func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { } func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { - if tx = db.First(dest, conds...); errors.Is(tx.Error, ErrRecordNotFound) { - tx.Error = nil + queryTx := db.Limit(1).Order(clause.OrderByColumn{ + Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, + }) + if tx = queryTx.Find(dest, conds...); queryTx.RowsAffected == 0 { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { tx.assignInterfacesToValue(where.Exprs) @@ -268,7 +307,7 @@ func (db *DB) Updates(values interface{}) (tx *DB) { func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} - tx.Statement.UpdatingColumn = true + tx.Statement.SkipHooks = true tx.callbacks.Update().Execute(tx) return } @@ -276,7 +315,7 @@ func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { func (db *DB) UpdateColumns(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values - tx.Statement.UpdatingColumn = true + tx.Statement.SkipHooks = true tx.callbacks.Update().Execute(tx) return } @@ -433,7 +472,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { // nested transaction - db.SavePoint(fmt.Sprintf("sp%p", fc)) + err = db.SavePoint(fmt.Sprintf("sp%p", fc)).Error defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { @@ -441,7 +480,9 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() - err = fc(db.Session(&Session{WithConditions: true})) + if err == nil { + err = fc(db.Session(&Session{})) + } } else { tx := db.Begin(opts...) @@ -452,7 +493,9 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er } }() - err = fc(tx) + if err = tx.Error; err == nil { + err = fc(tx) + } if err == nil { err = tx.Commit().Error @@ -467,7 +510,7 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( // clone statement - tx = db.Session(&Session{WithConditions: true, Context: db.Statement.Context}) + tx = db.Session(&Session{Context: db.Statement.Context}) opt *sql.TxOptions err error ) diff --git a/gorm.go b/gorm.go index 2dfbb855f..1947b4dff 100644 --- a/gorm.go +++ b/gorm.go @@ -36,6 +36,8 @@ type Config struct { DisableForeignKeyConstraintWhenMigrating bool // AllowGlobalUpdate allow global update AllowGlobalUpdate bool + // QueryFields executes the SQL query with all fields of the table + QueryFields bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -63,10 +65,12 @@ type DB struct { type Session struct { DryRun bool PrepareStmt bool - WithConditions bool + NewDB bool + SkipHooks bool SkipDefaultTransaction bool AllowGlobalUpdate bool FullSaveAssociations bool + QueryFields bool Context context.Context Logger logger.Interface NowFunc func() time.Time @@ -169,15 +173,17 @@ func (db *DB) Session(config *Session) *DB { txConfig.FullSaveAssociations = true } - if config.Context != nil { + if config.Context != nil || config.PrepareStmt || config.SkipHooks { tx.Statement = tx.Statement.clone() tx.Statement.DB = tx + } + + if config.Context != nil { tx.Statement.Context = config.Context } if config.PrepareStmt { if v, ok := db.cacheStore.Load("preparedStmt"); ok { - tx.Statement = tx.Statement.clone() preparedStmt := v.(*PreparedStmtDB) tx.Statement.ConnPool = &PreparedStmtDB{ ConnPool: db.Config.ConnPool, @@ -189,7 +195,11 @@ func (db *DB) Session(config *Session) *DB { } } - if config.WithConditions { + if config.SkipHooks { + tx.Statement.SkipHooks = true + } + + if !config.NewDB { tx.clone = 2 } @@ -197,6 +207,10 @@ func (db *DB) Session(config *Session) *DB { tx.Config.DryRun = true } + if config.QueryFields { + tx.Config.QueryFields = true + } + if config.Logger != nil { tx.Config.Logger = config.Logger } @@ -210,14 +224,13 @@ func (db *DB) Session(config *Session) *DB { // WithContext change current instance db's context to ctx func (db *DB) WithContext(ctx context.Context) *DB { - return db.Session(&Session{WithConditions: true, Context: ctx}) + return db.Session(&Session{Context: ctx}) } // Debug start debug mode func (db *DB) Debug() (tx *DB) { return db.Session(&Session{ - WithConditions: true, - Logger: db.Logger.LogMode(logger.Info), + Logger: db.Logger.LogMode(logger.Info), }) } diff --git a/migrator.go b/migrator.go index ac06a1444..28ac35e7d 100644 --- a/migrator.go +++ b/migrator.go @@ -7,7 +7,7 @@ import ( // Migrator returns migrator func (db *DB) Migrator() Migrator { - return db.Dialector.Migrator(db.Session(&Session{WithConditions: true})) + return db.Dialector.Migrator(db.Session(&Session{})) } // AutoMigrate run auto migration for given models diff --git a/migrator/migrator.go b/migrator/migrator.go index 016ebfc75..5de820a83 100644 --- a/migrator/migrator.go +++ b/migrator/migrator.go @@ -82,7 +82,7 @@ func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { // AutoMigrate func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { - tx := m.DB.Session(&gorm.Session{}) + tx := m.DB.Session(&gorm.Session{NewDB: true}) if !tx.Migrator().HasTable(value) { if err := tx.Migrator().CreateTable(value); err != nil { return err @@ -154,7 +154,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { - tx := m.DB.Session(&gorm.Session{}) + tx := m.DB.Session(&gorm.Session{NewDB: true}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { var ( createTableSQL = "CREATE TABLE ? (" @@ -237,7 +237,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) for i := len(values) - 1; i >= 0; i-- { - tx := m.DB.Session(&gorm.Session{}) + tx := m.DB.Session(&gorm.Session{NewDB: true}) if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error }); err != nil { @@ -404,7 +404,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy func (m Migrator) ColumnTypes(value interface{}) (columnTypes []gorm.ColumnType, err error) { columnTypes = make([]gorm.ColumnType, 0) err = m.RunWithValue(value, func(stmt *gorm.Statement) error { - rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() + rows, err := m.DB.Session(&gorm.Session{NewDB: true}).Table(stmt.Table).Limit(1).Rows() if err == nil { defer rows.Close() rawColumnTypes, err := rows.ColumnTypes() diff --git a/schema/naming.go b/schema/naming.go index dbc71e04f..e3b2104af 100644 --- a/schema/naming.go +++ b/schema/naming.go @@ -95,7 +95,7 @@ func toDBName(name string) string { if name == "" { return "" } else if v, ok := smap.Load(name); ok { - return fmt.Sprint(v) + return v.(string) } var ( @@ -134,6 +134,7 @@ func toDBName(name string) string { } else { buf.WriteByte(value[len(value)-1]) } - - return buf.String() + ret := buf.String() + smap.Store(name, ret) + return ret } diff --git a/statement.go b/statement.go index 7c0af59c1..27edf9da8 100644 --- a/statement.go +++ b/statement.go @@ -37,7 +37,7 @@ type Statement struct { Schema *schema.Schema Context context.Context RaiseErrorOnNotFound bool - UpdatingColumn bool + SkipHooks bool SQL strings.Builder Vars []interface{} CurDestIndex int @@ -190,7 +190,7 @@ func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { writer.WriteString("(NULL)") } case *DB: - subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true, WithConditions: true}).getInstance() + subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() subdb.Statement.Vars = append(subdb.Statement.Vars, stmt.Vars...) subdb.callbacks.Query().Execute(subdb) writer.WriteString(subdb.Statement.SQL.String()) @@ -421,7 +421,7 @@ func (stmt *Statement) clone() *Statement { Schema: stmt.Schema, Context: stmt.Context, RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound, - UpdatingColumn: stmt.UpdatingColumn, + SkipHooks: stmt.SkipHooks, } for k, c := range stmt.Clauses { diff --git a/tests/associations_has_one_test.go b/tests/associations_has_one_test.go index f487bd9ee..a4fc8c4fc 100644 --- a/tests/associations_has_one_test.go +++ b/tests/associations_has_one_test.go @@ -83,6 +83,20 @@ func TestHasOneAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Account", 0, "after clear") } +func TestHasOneAssociationWithSelect(t *testing.T) { + var user = *GetUser("hasone", Config{Account: true}) + + DB.Omit("Account.Number").Create(&user) + + AssertAssociationCount(t, user, "Account", 1, "") + + var account Account + DB.Model(&user).Association("Account").Find(&account) + if account.Number != "" { + t.Errorf("account's number should not be saved") + } +} + func TestHasOneAssociationForSlice(t *testing.T) { var users = []User{ *GetUser("slice-hasone-1", Config{Account: true}), diff --git a/tests/associations_many2many_test.go b/tests/associations_many2many_test.go index 2ecf7b669..1ddd3b858 100644 --- a/tests/associations_many2many_test.go +++ b/tests/associations_many2many_test.go @@ -93,6 +93,28 @@ func TestMany2ManyAssociation(t *testing.T) { AssertAssociationCount(t, user2, "Languages", 0, "after clear") } +func TestMany2ManyOmitAssociations(t *testing.T) { + var user = *GetUser("many2many_omit_associations", Config{Languages: 2}) + + if err := DB.Omit("Languages.*").Create(&user).Error; err == nil { + t.Fatalf("should raise error when create users without languages reference") + } + + if err := DB.Create(&user.Languages).Error; err != nil { + t.Fatalf("no error should happen when create languages, but got %v", err) + } + + if err := DB.Omit("Languages.*").Create(&user).Error; err != nil { + t.Fatalf("no error should happen when create user when languages exists, but got %v", err) + } + + // Find + var languages []Language + if DB.Model(&user).Association("Languages").Find(&languages); len(languages) != 2 { + t.Errorf("languages count should be %v, but got %v", 2, len(languages)) + } +} + func TestMany2ManyAssociationForSlice(t *testing.T) { var users = []User{ *GetUser("slice-many2many-1", Config{Languages: 2}), diff --git a/tests/associations_test.go b/tests/associations_test.go index c1a4e2b2a..f470338fd 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -3,6 +3,7 @@ package tests_test import ( "testing" + "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) @@ -32,6 +33,41 @@ func TestInvalidAssociation(t *testing.T) { } } +func TestAssociationNotNullClear(t *testing.T) { + type Profile struct { + gorm.Model + Number string + MemberID uint `gorm:"not null"` + } + + type Member struct { + gorm.Model + Profiles []Profile + } + + DB.Migrator().DropTable(&Member{}, &Profile{}) + + if err := DB.AutoMigrate(&Member{}, &Profile{}); err != nil { + t.Fatalf("Failed to migrate, got error: %v", err) + } + + member := &Member{ + Profiles: []Profile{{ + Number: "1", + }, { + Number: "2", + }}, + } + + if err := DB.Create(&member).Error; err != nil { + t.Fatalf("Failed to create test data, got error: %v", err) + } + + if err := DB.Model(member).Association("Profiles").Clear(); err == nil { + t.Fatalf("No error occured during clearind not null association") + } +} + func TestForeignKeyConstraints(t *testing.T) { type Profile struct { ID uint diff --git a/tests/count_test.go b/tests/count_test.go index 41bad71d8..55fb71e20 100644 --- a/tests/count_test.go +++ b/tests/count_test.go @@ -41,7 +41,7 @@ func TestCount(t *testing.T) { t.Errorf("multiple count in chain should works") } - tx := DB.Model(&User{}).Where("name = ?", user1.Name).Session(&gorm.Session{WithConditions: true}) + tx := DB.Model(&User{}).Where("name = ?", user1.Name).Session(&gorm.Session{}) tx.Count(&count1) tx.Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) if count1 != 1 || count2 != 3 { diff --git a/tests/create_test.go b/tests/create_test.go index 00674eec8..8d005d0b5 100644 --- a/tests/create_test.go +++ b/tests/create_test.go @@ -40,6 +40,32 @@ func TestCreate(t *testing.T) { } } +func TestCreateInBatches(t *testing.T) { + users := []User{ + *GetUser("create_in_batches_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), + *GetUser("create_in_batches_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), + *GetUser("create_in_batches_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), + *GetUser("create_in_batches_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), + *GetUser("create_in_batches_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), + *GetUser("create_in_batches_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), + } + + DB.CreateInBatches(&users, 2) + + for _, user := range users { + if user.ID == 0 { + t.Fatalf("failed to fill user's ID, got %v", user.ID) + } else { + var newUser User + if err := DB.Where("id = ?", user.ID).Preload(clause.Associations).First(&newUser).Error; err != nil { + t.Fatalf("errors happened when query: %v", err) + } else { + CheckUser(t, newUser, user) + } + } + } +} + func TestCreateFromMap(t *testing.T) { if err := DB.Model(&User{}).Create(map[string]interface{}{"Name": "create_from_map", "Age": 18}).Error; err != nil { t.Fatalf("failed to create data from map, got error: %v", err) diff --git a/tests/hooks_test.go b/tests/hooks_test.go index d8b1770e1..fe3f7d082 100644 --- a/tests/hooks_test.go +++ b/tests/hooks_test.go @@ -371,10 +371,22 @@ func TestSetColumn(t *testing.T) { t.Errorf("invalid data after update, got %+v", product) } + DB.Model(&product).Session(&gorm.Session{SkipHooks: true}).Updates(Product3{Code: "L1216"}) + if product.Price != 270 || product.Code != "L1216" { + t.Errorf("invalid data after update, got %+v", product) + } + var result2 Product3 DB.First(&result2, product.ID) AssertEqual(t, result2, product) + + product2 := Product3{Name: "Product", Price: 0} + DB.Session(&gorm.Session{SkipHooks: true}).Create(&product2) + + if product2.Price != 0 { + t.Errorf("invalid price after create without hooks, got %+v", product2) + } } func TestHooksForSlice(t *testing.T) { diff --git a/tests/multi_primary_keys_test.go b/tests/multi_primary_keys_test.go index 68da8a888..dcc90cd9a 100644 --- a/tests/multi_primary_keys_test.go +++ b/tests/multi_primary_keys_test.go @@ -140,7 +140,7 @@ func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { } if name := DB.Dialector.Name(); name == "postgres" { - t.Skip("skip postgers due to it only allow unique constraint matching given keys") + t.Skip("skip postgres due to it only allow unique constraint matching given keys") } DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") @@ -265,7 +265,7 @@ func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { } if name := DB.Dialector.Name(); name == "postgres" { - t.Skip("skip postgers due to it only allow unique constraint matching given keys") + t.Skip("skip postgres due to it only allow unique constraint matching given keys") } DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") diff --git a/tests/query_test.go b/tests/query_test.go index dc2907e6d..c4162bdc6 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -268,6 +268,14 @@ func TestFindInBatches(t *testing.T) { t.Errorf("Incorrect users length, expects: 2, got %v", len(results)) } + for idx := range results { + results[idx].Name = results[idx].Name + "_new" + } + + if err := tx.Save(results).Error; err != nil { + t.Errorf("failed to save users, got error %v", err) + } + return nil }); result.Error != nil || result.RowsAffected != 6 { t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected) @@ -276,6 +284,12 @@ func TestFindInBatches(t *testing.T) { if totalBatch != 6 { t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch) } + + var count int64 + DB.Model(&User{}).Where("name = ?", "find_in_batches_new").Count(&count) + if count != 6 { + t.Errorf("incorrect count after update, expects: %v, got %v", 6, count) + } } func TestFillSmallerStruct(t *testing.T) { @@ -334,6 +348,39 @@ func TestFillSmallerStruct(t *testing.T) { } } +func TestFillSmallerStructWithAllFields(t *testing.T) { + user := User{Name: "SmallerUser", Age: 100} + DB.Save(&user) + type SimpleUser struct { + ID int64 + Name string + UpdatedAt time.Time + CreatedAt time.Time + } + var simpleUsers []SimpleUser + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + + result := dryDB.Model(&User{}).Find(&simpleUsers, user.ID) + if !regexp.MustCompile("SELECT .users.*id.*users.*name.*users.*updated_at.*users.*created_at.* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should include selected names, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Model(&User{}).Find(&User{}, user.ID) + if regexp.MustCompile("SELECT \\* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include a * wildcard, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Model(&User{}).Find(&[]User{}, user.ID) + if regexp.MustCompile("SELECT \\* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include a * wildcard, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Model(&User{}).Find(&[]*User{}, user.ID) + if regexp.MustCompile("SELECT \\* FROM .*users").MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL should not include a * wildcard, but got %v", result.Statement.SQL.String()) + } +} + func TestNot(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) @@ -378,6 +425,53 @@ func TestNot(t *testing.T) { } } +func TestNotWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name" + + ".*users.*age.*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Not(map[string]interface{}{"users.name": "jinzhu"}).Find(&User{}) + + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("users.name = ?", "jinzhu1").Not("users.name = ?", "jinzhu2").Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* = .+ AND NOT .*users.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where(map[string]interface{}{"users.name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not("users.name = ?", "jinzhu").Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE NOT .*users.*name.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(map[string]interface{}{"users.name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{1, 2}).First(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*id.* NOT IN \\(.+,.+\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not([]int64{}).First(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .users.\\..deleted_at. IS NULL ORDER BY").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Not(User{Name: "jinzhu", Age: 18}).First(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build NOT condition, but got %v", result.Statement.SQL.String()) + } +} + func TestOr(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) @@ -397,6 +491,27 @@ func TestOr(t *testing.T) { } } +func TestOrWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name" + + ".*users.*age.*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("users.name = ?", "jinzhu").Or(User{Name: "jinzhu 2", Age: 18}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* = .+ OR \\(.*users.*name.* AND .*users.*age.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Where("users.name = ?", "jinzhu").Or(map[string]interface{}{"name": "jinzhu 2", "age": 18}).Find(&User{}) + if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* = .+ OR \\(.*age.* AND .*name.*\\)").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String()) + } +} + func TestPluck(t *testing.T) { users := []*User{ GetUser("pluck-user1", Config{}), @@ -529,6 +644,30 @@ func TestOmit(t *testing.T) { } } +func TestOmitWithAllFields(t *testing.T) { + user := User{Name: "OmitUser1", Age: 20} + DB.Save(&user) + + var userResult User + DB.Session(&gorm.Session{QueryFields: true}).Where("users.name = ?", user.Name).Omit("name").Find(&userResult) + if userResult.ID == 0 { + t.Errorf("Should not have ID because only selected name, %+v", userResult.ID) + } + + if userResult.Name != "" || userResult.Age != 20 { + t.Errorf("User Name should be omitted, got %v, Age should be ok, got %v", userResult.Name, userResult.Age) + } + + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*birthday" + + ".*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Omit("name, age").Find(&User{}) + if !regexp.MustCompile(userQuery).MatchString(result.Statement.SQL.String()) { + t.Fatalf("SQL must include table name and selected fields, got %v", result.Statement.SQL.String()) + } +} + func TestPluckWithSelect(t *testing.T) { users := []User{ {Name: "pluck_with_select_1", Age: 25}, @@ -671,6 +810,31 @@ func TestOrder(t *testing.T) { } } +func TestOrderWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name.*users.*age" + + ".*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* " + + result := dryDB.Order("users.age desc, users.name").Find(&User{}) + if !regexp.MustCompile(userQuery + "users.age desc, users.name").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + result = dryDB.Order("users.age desc").Order("users.name").Find(&User{}) + if !regexp.MustCompile(userQuery + "ORDER BY users.age desc,users.name").MatchString(result.Statement.SQL.String()) { + t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String()) + } + + stmt := dryDB.Clauses(clause.OrderBy{ + Expression: clause.Expr{SQL: "FIELD(id,?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, + }).Find(&User{}).Statement + + explainedSQL := dryDB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) + if !regexp.MustCompile(userQuery + "ORDER BY FIELD\\(id,1,2,3\\)").MatchString(explainedSQL) { + t.Fatalf("Build Order condition, but got %v", explainedSQL) + } +} + func TestLimit(t *testing.T) { users := []User{ {Name: "LimitUser1", Age: 1}, @@ -878,3 +1042,13 @@ func TestQueryWithTableAndConditions(t *testing.T) { t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) } } + +func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) { + result := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}).Table("user").Find(&User{}, User{Name: "jinzhu"}) + userQuery := "SELECT .*user.*id.*user.*created_at.*user.*updated_at.*user.*deleted_at.*user.*name.*user.*age" + + ".*user.*birthday.*user.*company_id.*user.*manager_id.*user.*active.* FROM .user. " + + if !regexp.MustCompile(userQuery + `WHERE .user.\..name. = .+ AND .user.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { + t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) + } +} diff --git a/tests/table_test.go b/tests/table_test.go index 647b5e191..0c6b3eb04 100644 --- a/tests/table_test.go +++ b/tests/table_test.go @@ -68,3 +68,60 @@ func TestTable(t *testing.T) { AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) } + +func TestTableWithAllFields(t *testing.T) { + dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) + userQuery := "SELECT .*user.*id.*user.*created_at.*user.*updated_at.*user.*deleted_at.*user.*name.*user.*age" + + ".*user.*birthday.*user.*company_id.*user.*manager_id.*user.*active.* " + + r := dryDB.Table("`user`").Find(&User{}).Statement + if !regexp.MustCompile(userQuery + "FROM `user`").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("user as u").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM user as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("gorm.user").Select("name").Find(&User{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Select("name").Find(&UserWithTable{}).Statement + if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Create(&UserWithTable{}).Statement + if DB.Dialector.Name() != "sqlite" { + if !regexp.MustCompile(`INSERT INTO .gorm.\..user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } else { + if !regexp.MustCompile(`INSERT INTO .user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + } + + userQueryCharacter := "SELECT .*u.*id.*u.*created_at.*u.*updated_at.*u.*deleted_at.*u.*name.*u.*age.*u.*birthday" + + ".*u.*company_id.*u.*manager_id.*u.*active.* " + + r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name"), DB.Model(&Pet{}).Select("name")).Find(&User{}).Statement + if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE .pets.\\..deleted_at. IS NULL\\) as p WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + r = dryDB.Where("name = ?", 1).Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name").Where("name = ?", 2), DB.Model(&Pet{}).Where("name = ?", 4).Select("name")).Where("name = ?", 3).Find(&User{}).Statement + if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE name = .+ AND .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE name = .+ AND .pets.\\..deleted_at. IS NULL\\) as p WHERE name = .+ AND name = .+ AND .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { + t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) + } + + AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) +} diff --git a/tests/upsert_test.go b/tests/upsert_test.go index ba7c1a9d7..0ba8b9f06 100644 --- a/tests/upsert_test.go +++ b/tests/upsert_test.go @@ -41,6 +41,16 @@ func TestUpsert(t *testing.T) { } else if langs[0].Name != "upsert-new" { t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) } + + lang = Language{Code: "upsert", Name: "Upsert-Newname"} + if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&lang).Error; err != nil { + t.Fatalf("failed to upsert, got %v", err) + } + + var result Language + if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name { + t.Fatalf("failed to upsert, got name %v", result.Name) + } } func TestUpsertSlice(t *testing.T) {