Skip to content

Commit

Permalink
Work on create callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 3, 2020
1 parent 728c0d4 commit d52ee0a
Show file tree
Hide file tree
Showing 11 changed files with 222 additions and 62 deletions.
11 changes: 9 additions & 2 deletions callbacks/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"

"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
)

func BeforeCreate(db *gorm.DB) {
Expand All @@ -17,8 +18,14 @@ func SaveBeforeAssociations(db *gorm.DB) {
}

func Create(db *gorm.DB) {
db.Statement.Build("WITH", "INSERT", "VALUES", "ON_CONFLICT", "RETURNING")
db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
db.Statement.AddClauseIfNotExists(clause.Insert{
Table: clause.Table{Table: db.Statement.Table},
})

db.Statement.Build("INSERT", "VALUES", "ON_CONFLICT", "RETURNING")
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
fmt.Println(err)
fmt.Println(result)
fmt.Println(db.Statement.SQL.String(), db.Statement.Vars)
}

Expand Down
12 changes: 9 additions & 3 deletions chainable_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,19 @@ func (db *DB) Omit(columns ...string) (tx *DB) {

func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Where{AndConditions: tx.Statement.BuildCondtion(query, args...)})
tx.Statement.AddClause(clause.Where{
AndConditions: tx.Statement.BuildCondtion(query, args...),
})
return
}

// Not add NOT condition
func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Where{
AndConditions: []clause.Expression{clause.NotConditions(tx.Statement.BuildCondtion(query, args...))},
AndConditions: []clause.Expression{
clause.NotConditions(tx.Statement.BuildCondtion(query, args...)),
},
})
return
}
Expand All @@ -72,7 +76,9 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Where{
ORConditions: []clause.ORConditions{tx.Statement.BuildCondtion(query, args...)},
ORConditions: []clause.ORConditions{
tx.Statement.BuildCondtion(query, args...),
},
})
return
}
Expand Down
34 changes: 34 additions & 0 deletions clause/insert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package clause

type Insert struct {
Table Table
Priority string
}

// Name insert clause name
func (insert Insert) Name() string {
return "INSERT"
}

// Build build insert clause
func (insert Insert) Build(builder Builder) {
if insert.Priority != "" {
builder.Write(insert.Priority)
builder.WriteByte(' ')
}

builder.Write("INTO ")
builder.WriteQuoted(insert.Table)
}

// MergeExpression merge insert clauses
func (insert Insert) MergeExpression(expr Expression) {
if v, ok := expr.(Insert); ok {
if insert.Priority == "" {
insert.Priority = v.Priority
}
if insert.Table.Table == "" {
insert.Table = v.Table
}
}
}
39 changes: 39 additions & 0 deletions clause/value.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package clause

type Values struct {
Columns []Column
Values [][]interface{}
}

// Name from clause name
func (Values) Name() string {
return ""
}

// Build build from clause
func (values Values) Build(builder Builder) {
if len(values.Columns) > 0 {
builder.WriteByte('(')
for idx, column := range values.Columns {
if idx > 0 {
builder.WriteByte(',')
}
builder.WriteQuoted(column)
}
builder.WriteByte(')')

builder.Write(" VALUES ")

for idx, value := range values.Values {
builder.WriteByte('(')
if idx > 0 {
builder.WriteByte(',')
}

builder.Write(builder.AddVar(value...))
builder.WriteByte(')')
}
} else {
builder.Write("DEFAULT VALUES")
}
}
33 changes: 33 additions & 0 deletions dialects/postgres/postgres.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package postgres

import (
"database/sql"

"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks"
_ "github.com/lib/pq"
)

type Dialector struct {
DSN string
}

func Open(dsn string) gorm.Dialector {
return &Dialector{DSN: dsn}
}

func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks
callbacks.RegisterDefaultCallbacks(db)

db.DB, err = sql.Open("postgres", dialector.DSN)
return
}

func (Dialector) Migrator() gorm.Migrator {
return nil
}

func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}
12 changes: 8 additions & 4 deletions dialects/sqlite/sqlite.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,33 @@
package sqlite

import (
"database/sql"

"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/callbacks"
_ "github.com/mattn/go-sqlite3"
)

type Dialector struct {
DSN string
}

func Open(dsn string) gorm.Dialector {
return &Dialector{}
return &Dialector{DSN: dsn}
}

func (Dialector) Initialize(db *gorm.DB) error {
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
// register callbacks
callbacks.RegisterDefaultCallbacks(db)

return nil
db.DB, err = sql.Open("sqlite3", dialector.DSN)
return
}

func (Dialector) Migrator() gorm.Migrator {
return nil
}

func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string {
func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}
66 changes: 34 additions & 32 deletions finisher_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@ import (
"database/sql"
)

func (db *DB) Count(sql string, values ...interface{}) (tx *DB) {
// Create insert the value into database
func (db *DB) Create(value interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.Dest = value
tx.callbacks.Create().Execute(tx)
return
}

// Save update value in database, if the value doesn't have primary key, will insert it
func (db *DB) Save(value interface{}) (tx *DB) {
tx = db.getInstance()
return
}
Expand Down Expand Up @@ -36,32 +45,12 @@ func (db *DB) Find(out interface{}, where ...interface{}) (tx *DB) {
return
}

func (db *DB) Row() *sql.Row {
return nil
}

func (db *DB) Rows() (*sql.Rows, error) {
return nil, nil
}

// Scan scan value to a struct
func (db *DB) Scan(dest interface{}) (tx *DB) {
tx = db.getInstance()
return
}

func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error {
return nil
}

// Create insert the value into database
func (db *DB) Create(value interface{}) (tx *DB) {
func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance()
return
}

// Save update value in database, if the value doesn't have primary key, will insert it
func (db *DB) Save(value interface{}) (tx *DB) {
func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance()
return
}
Expand All @@ -78,7 +67,7 @@ func (db *DB) Updates(values interface{}) (tx *DB) {
return
}

func (db *DB) UpdateColumn(attrs ...interface{}) (tx *DB) {
func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) {
tx = db.getInstance()
return
}
Expand All @@ -88,34 +77,47 @@ func (db *DB) UpdateColumns(values interface{}) (tx *DB) {
return
}

func (db *DB) FirstOrInit(out interface{}, where ...interface{}) (tx *DB) {
// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) {
tx = db.getInstance()
return
}

func (db *DB) FirstOrCreate(out interface{}, where ...interface{}) (tx *DB) {
func (db *DB) Related(value interface{}, foreignKeys ...string) (tx *DB) {
tx = db.getInstance()
return
}

// Delete delete value match given conditions, if the value has primary key, then will including the primary key as condition
func (db *DB) Delete(value interface{}, where ...interface{}) (tx *DB) {
//Preloads only preloads relations, don`t touch out
func (db *DB) Preloads(out interface{}) (tx *DB) {
tx = db.getInstance()
return
}

func (db *DB) Related(value interface{}, foreignKeys ...string) (tx *DB) {
func (db *DB) Association(column string) *Association {
return nil
}

func (db *DB) Count(value interface{}) (tx *DB) {
tx = db.getInstance()
return
}

//Preloads only preloads relations, don`t touch out
func (db *DB) Preloads(out interface{}) (tx *DB) {
func (db *DB) Row() *sql.Row {
return nil
}

func (db *DB) Rows() (*sql.Rows, error) {
return nil, nil
}

// Scan scan value to a struct
func (db *DB) Scan(dest interface{}) (tx *DB) {
tx = db.getInstance()
return
}

func (db *DB) Association(column string) *Association {
func (db *DB) ScanRows(rows *sql.Rows, result interface{}) error {
return nil
}

Expand Down
6 changes: 5 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,8 @@ module github.com/jinzhu/gorm

go 1.13

require github.com/jinzhu/inflection v1.0.0
require (
github.com/jinzhu/inflection v1.0.0
github.com/lib/pq v1.3.0
github.com/mattn/go-sqlite3 v2.0.3+incompatible
)
20 changes: 11 additions & 9 deletions gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ type DB struct {
*Config
Dialector
Instance
DB CommonDB
clone bool
callbacks *callbacks
cacheStore *sync.Map
DB CommonDB
ClauseBuilders map[string]clause.ClauseBuilder
clone bool
callbacks *callbacks
cacheStore *sync.Map
}

// Session session config when create session with Session() method
Expand Down Expand Up @@ -140,11 +141,12 @@ func (db *DB) getInstance() *DB {
Context: ctx,
Statement: &Statement{DB: db, Clauses: map[string]clause.Clause{}},
},
Config: db.Config,
Dialector: db.Dialector,
DB: db.DB,
callbacks: db.callbacks,
cacheStore: db.cacheStore,
Config: db.Config,
Dialector: db.Dialector,
ClauseBuilders: db.ClauseBuilders,
DB: db.DB,
callbacks: db.callbacks,
cacheStore: db.cacheStore,
}
}

Expand Down
2 changes: 1 addition & 1 deletion interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
type Dialector interface {
Initialize(*DB) error
Migrator() Migrator
BindVar(stmt Statement, v interface{}) string
BindVar(stmt *Statement, v interface{}) string
}

// CommonDB common db interface
Expand Down
Loading

0 comments on commit d52ee0a

Please sign in to comment.