Skip to content

Commit

Permalink
Make inesrt into db works
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 23, 2020
1 parent 868ae05 commit fa22807
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 66 deletions.
2 changes: 1 addition & 1 deletion callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (p *processor) Execute(db *DB) {

if stmt := db.Statement; stmt != nil {
db.Logger.Trace(curTime, func() (string, int64) {
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars), db.RowsAffected
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...), db.RowsAffected
}, db.Error)
}
}
Expand Down
58 changes: 33 additions & 25 deletions callbacks/create.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package callbacks

import (
"fmt"
"reflect"

"github.com/jinzhu/gorm"
Expand All @@ -11,8 +10,6 @@ import (
func BeforeCreate(db *gorm.DB) {
// before save
// before create

// assign timestamp
}

func SaveBeforeAssociations(db *gorm.DB) {
Expand All @@ -22,16 +19,29 @@ func Create(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Insert{
Table: clause.Table{Name: db.Statement.Table},
})
values, _ := ConvertToCreateValues(db.Statement)
db.Statement.AddClause(values)
db.Statement.AddClause(ConvertToCreateValues(db.Statement))

db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT")
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)

fmt.Printf("%+v\n", values)
fmt.Println(err)
fmt.Println(result)
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
if err == nil {
if db.Statement.Schema != nil {
if insertID, err := result.LastInsertId(); err == nil {
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue.Index(i), insertID)
insertID--
}
case reflect.Struct:
db.Statement.Schema.PrioritizedPrimaryField.Set(db.Statement.ReflectValue, insertID)
}
}
}
db.RowsAffected, _ = result.RowsAffected()
} else {
db.AddError(err)
}
}

func SaveAfterAssociations(db *gorm.DB) {
Expand All @@ -43,19 +53,18 @@ func AfterCreate(db *gorm.DB) {
}

// ConvertToCreateValues convert to create values
func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]interface{}) {
func ConvertToCreateValues(stmt *gorm.Statement) clause.Values {
switch value := stmt.Dest.(type) {
case map[string]interface{}:
return ConvertMapToValues(stmt, value), nil
return ConvertMapToValues(stmt, value)
case []map[string]interface{}:
return ConvertSliceOfMapToValues(stmt, value), nil
return ConvertSliceOfMapToValues(stmt, value)
default:
var (
values = clause.Values{}
selectColumns, restricted = SelectAndOmitColumns(stmt)
curTime = stmt.DB.NowFunc()
isZero = false
returnningValues []map[string]interface{}
)

for _, db := range stmt.Schema.DBNames {
Expand All @@ -66,13 +75,12 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
}
}

reflectValue := reflect.Indirect(reflect.ValueOf(stmt.Dest))
switch reflectValue.Kind() {
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
values.Values = make([][]interface{}, reflectValue.Len())
values.Values = make([][]interface{}, stmt.ReflectValue.Len())
defaultValueFieldsHavingValue := map[string][]interface{}{}
for i := 0; i < reflectValue.Len(); i++ {
rv := reflect.Indirect(reflectValue.Index(i))
for i := 0; i < stmt.ReflectValue.Len(); i++ {
rv := reflect.Indirect(stmt.ReflectValue.Index(i))
values.Values[i] = make([]interface{}, len(values.Columns))
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
Expand All @@ -91,7 +99,7 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(rv); !isZero {
if len(defaultValueFieldsHavingValue[db]) == 0 {
defaultValueFieldsHavingValue[db] = make([]interface{}, reflectValue.Len())
defaultValueFieldsHavingValue[db] = make([]interface{}, stmt.ReflectValue.Len())
}
defaultValueFieldsHavingValue[db][i] = v
}
Expand All @@ -113,27 +121,27 @@ func ConvertToCreateValues(stmt *gorm.Statement) (clause.Values, []map[string]in
values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
for idx, column := range values.Columns {
field := stmt.Schema.FieldsByDBName[column.Name]
if values.Values[0][idx], isZero = field.ValueOf(reflectValue); isZero {
if values.Values[0][idx], isZero = field.ValueOf(stmt.ReflectValue); isZero {
if field.DefaultValueInterface != nil {
values.Values[0][idx] = field.DefaultValueInterface
field.Set(reflectValue, field.DefaultValueInterface)
field.Set(stmt.ReflectValue, field.DefaultValueInterface)
} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
field.Set(reflectValue, curTime)
values.Values[0][idx], _ = field.ValueOf(reflectValue)
field.Set(stmt.ReflectValue, curTime)
values.Values[0][idx], _ = field.ValueOf(stmt.ReflectValue)
}
}
}

for db, field := range stmt.Schema.FieldsWithDefaultDBValue {
if v, ok := selectColumns[db]; (ok && v) || (!ok && !restricted) {
if v, isZero := field.ValueOf(reflectValue); !isZero {
if v, isZero := field.ValueOf(stmt.ReflectValue); !isZero {
values.Columns = append(values.Columns, clause.Column{Name: db})
values.Values[0] = append(values.Values[0], v)
}
}
}
}

return values, returnningValues
return values
}
}
8 changes: 2 additions & 6 deletions callbacks/query.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
package callbacks

import (
"fmt"

"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
)
Expand All @@ -15,10 +13,8 @@ func Query(db *gorm.DB) {
db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
}

result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
fmt.Println(err)
fmt.Println(result)
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
rows, err := db.DB.QueryContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
db.AddError(err)
}

func Preload(db *gorm.DB) {
Expand Down
23 changes: 13 additions & 10 deletions logger/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ func New(writer Writer, config Config) Interface {
)

if config.Colorful {
infoPrefix = Green + "%s\n" + Reset + Green + "[info]" + Reset
warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn]" + Reset
errPrefix = Magenta + "%s\n" + Reset + Red + "[error]" + Reset
infoPrefix = Green + "%s\n" + Reset + Green + "[info] " + Reset
warnPrefix = Blue + "%s\n" + Reset + Magenta + "[warn] " + Reset
errPrefix = Magenta + "%s\n" + Reset + Red + "[error] " + Reset
tracePrefix = Green + "%s\n" + Reset + YellowBold + "[%.3fms] " + Green + "[rows:%d]" + Reset + " %s"
traceErrPrefix = Magenta + "%s\n" + Reset + Redbold + "[%.3fms] " + Yellow + "[rows:%d]" + Reset + " %s"
}
Expand All @@ -93,37 +93,40 @@ type logger struct {

// LogMode log mode
func (l logger) LogMode(level LogLevel) Interface {
config := l.Config
config.LogLevel = level
return logger{Writer: l.Writer, Config: config}
l.LogLevel = level
return l
}

// Info print info
func (l logger) Info(msg string, data ...interface{}) {
if l.LogLevel >= Info {
l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...))
l.Printf(l.infoPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}

// Warn print warn messages
func (l logger) Warn(msg string, data ...interface{}) {
if l.LogLevel >= Warn {
l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...))
l.Printf(l.warnPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}

// Error print error messages
func (l logger) Error(msg string, data ...interface{}) {
if l.LogLevel >= Error {
l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...))
l.Printf(l.errPrefix+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...)
}
}

// Trace print sql message
func (l logger) Trace(begin time.Time, fc func() (string, int64), err error) {
if elapsed := time.Now().Sub(begin); err != nil || elapsed > l.SlowThreshold {
sql, rows := fc()
l.Printf(l.traceErrPrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
fileline := utils.FileWithLineNum()
if err != nil {
fileline += " " + err.Error()
}
l.Printf(l.traceErrPrefix, fileline, float64(elapsed.Nanoseconds())/1e6, rows, sql)
} else if l.LogLevel >= Info {
sql, rows := fc()
l.Printf(l.tracePrefix, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql)
Expand Down
11 changes: 10 additions & 1 deletion logger/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
case bool:
vars[idx] = fmt.Sprint(v)
case time.Time:
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
if v.IsZero() {
vars[idx] = escaper + "0000-00-00 00:00:00" + escaper
} else {
vars[idx] = escaper + v.Format("2006-01-02 15:04:05") + escaper
}
case []byte:
if isPrintable(v) {
vars[idx] = escaper + strings.Replace(string(v), escaper, "\\"+escaper, -1) + escaper
Expand All @@ -48,6 +52,11 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, v
vars[idx] = "NULL"
} else {
rv := reflect.Indirect(reflect.ValueOf(v))
if !rv.IsValid() {
vars[idx] = "NULL"
return
}

for _, t := range convertableTypes {
if rv.Type().ConvertibleTo(t) {
convertParams(rv.Convert(t).Interface(), idx)
Expand Down
2 changes: 1 addition & 1 deletion schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
var err error
field.Creatable = false
field.Updatable = false
if field.EmbeddedSchema, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
if field.EmbeddedSchema, _, err = Parse(fieldValue.Interface(), &sync.Map{}, schema.namer); err != nil {
schema.err = err
}
for _, ef := range field.EmbeddedSchema.Fields {
Expand Down
4 changes: 2 additions & 2 deletions schema/relationship.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (schema *Schema) parseRelation(field *Field) {
}
)

if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
if relation.FieldSchema, _, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
schema.err = err
return
}
Expand Down Expand Up @@ -192,7 +192,7 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
}

if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
if relation.JoinTable, _, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil {
schema.err = err
}
relation.JoinTable.Name = many2many
Expand Down
15 changes: 8 additions & 7 deletions schema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,22 @@ func (schema Schema) LookUpField(name string) *Field {
}

// get data type from dialector
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
modelType := reflect.ValueOf(dest).Type()
func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, reflect.Value, error) {
reflectValue := reflect.ValueOf(dest)
modelType := reflectValue.Type()
for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Ptr {
modelType = modelType.Elem()
}

if modelType.Kind() != reflect.Struct {
if modelType.PkgPath() == "" {
return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
return nil, reflectValue, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
}
return nil, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
return nil, reflectValue, fmt.Errorf("%w: %v.%v", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
}

if v, ok := cacheStore.Load(modelType); ok {
return v.(*Schema), nil
return v.(*Schema), reflectValue, nil
}

schema := &Schema{
Expand Down Expand Up @@ -167,10 +168,10 @@ func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error)
for _, field := range schema.Fields {
if field.DataType == "" && field.Creatable {
if schema.parseRelation(field); schema.err != nil {
return schema, schema.err
return schema, reflectValue, schema.err
}
}
}

return schema, schema.err
return schema, reflectValue, schema.err
}
32 changes: 19 additions & 13 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -32,22 +33,23 @@ func (instance *Instance) ToSQL(clauses ...string) (string, []interface{}) {
func (inst *Instance) AddError(err error) {
if inst.Error == nil {
inst.Error = err
} else {
} else if err != nil {
inst.Error = fmt.Errorf("%v; %w", inst.Error, err)
}
}

// Statement statement
type Statement struct {
Table string
Model interface{}
Dest interface{}
Clauses map[string]clause.Clause
Selects []string // selected columns
Omits []string // omit columns
Settings sync.Map
DB *DB
Schema *schema.Schema
Table string
Model interface{}
Dest interface{}
ReflectValue reflect.Value
Clauses map[string]clause.Clause
Selects []string // selected columns
Omits []string // omit columns
Settings sync.Map
DB *DB
Schema *schema.Schema

// SQL Builder
SQL strings.Builder
Expand Down Expand Up @@ -197,7 +199,7 @@ func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
// BuildCondtion build condition
func (stmt Statement) BuildCondtion(query interface{}, args ...interface{}) (conditions []clause.Expression) {
if sql, ok := query.(string); ok {
if i, err := strconv.Atoi(sql); err != nil {
if i, err := strconv.Atoi(sql); err == nil {
query = i
} else if len(args) == 0 || (len(args) > 0 && strings.Contains(sql, "?")) || strings.Contains(sql, "@") {
return []clause.Expression{clause.Expr{SQL: sql, Vars: args}}
Expand Down Expand Up @@ -272,8 +274,12 @@ func (stmt *Statement) Build(clauses ...string) {
}

func (stmt *Statement) Parse(value interface{}) (err error) {
if stmt.Schema, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil && stmt.Table == "" {
stmt.Table = stmt.Schema.Table
if stmt.Schema, stmt.ReflectValue, err = schema.Parse(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
stmt.ReflectValue = reflect.Indirect(stmt.ReflectValue)

if stmt.Table == "" {
stmt.Table = stmt.Schema.Table
}
}
return err
}
3 changes: 3 additions & 0 deletions tests/tests.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ func RunTestsSuit(t *testing.T, db *gorm.DB) {
}

func TestCreate(t *testing.T, db *gorm.DB) {
db.AutoMigrate(&User{})
db = db.Debug()

t.Run("Create", func(t *testing.T) {
var user = User{
Name: "create",
Expand Down

0 comments on commit fa22807

Please sign in to comment.