diff --git a/chainable_api.go b/chainable_api.go index f358d3168..cac7495d5 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -80,7 +80,7 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Where{ - ORConditions: []clause.ORConditions{ + OrConditions: []clause.OrConditions{ tx.Statement.BuildCondtion(query, args...), }, }) diff --git a/clause/clause_test.go b/clause/clause_test.go index 97d30f2d3..37f07686b 100644 --- a/clause/clause_test.go +++ b/clause/clause_test.go @@ -12,17 +12,19 @@ import ( "github.com/jinzhu/gorm/tests" ) -func TestClause(t *testing.T) { +func TestClauses(t *testing.T) { var ( - db, _ = gorm.Open(nil, nil) + db, _ = gorm.Open(tests.DummyDialector{}, nil) results = []struct { Clauses []clause.Interface Result string Vars []interface{} - }{{ - []clause.Interface{clause.Select{}, clause.From{}}, - "SELECT * FROM users", []interface{}{}, - }} + }{ + { + []clause.Interface{clause.Select{}, clause.From{}, clause.Where{AndConditions: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}}}}, + "SELECT * FROM `users` WHERE `users`.`id` = ?", []interface{}{"1"}, + }, + } ) for idx, result := range results { diff --git a/clause/expression.go b/clause/expression.go index 722df7c79..3ddc146d2 100644 --- a/clause/expression.go +++ b/clause/expression.go @@ -5,6 +5,11 @@ const ( CurrentTable string = "@@@table@@@" ) +var PrimaryColumn = Column{ + Table: CurrentTable, + Name: PrimaryKey, +} + // Expression expression interface type Expression interface { Build(builder Builder) diff --git a/clause/query.go b/clause/query.go index 949678d9c..ce609014a 100644 --- a/clause/query.go +++ b/clause/query.go @@ -6,6 +6,14 @@ import "strings" // Query Expressions //////////////////////////////////////////////////////////////////////////////// +func Add(exprs ...Expression) AddConditions { + return AddConditions(exprs) +} + +func Or(exprs ...Expression) OrConditions { + return OrConditions(exprs) +} + type AddConditions []Expression func (cs AddConditions) Build(builder Builder) { @@ -17,9 +25,9 @@ func (cs AddConditions) Build(builder Builder) { } } -type ORConditions []Expression +type OrConditions []Expression -func (cs ORConditions) Build(builder Builder) { +func (cs OrConditions) Build(builder Builder) { for idx, c := range cs { if idx > 0 { builder.Write(" OR ") diff --git a/clause/where.go b/clause/where.go index 888b9d076..de82662cc 100644 --- a/clause/where.go +++ b/clause/where.go @@ -3,7 +3,7 @@ package clause // Where where clause type Where struct { AndConditions AddConditions - ORConditions []ORConditions + OrConditions []OrConditions builders []Expression } @@ -31,8 +31,8 @@ func (where Where) Build(builder Builder) { } } - var singleOrConditions []ORConditions - for _, or := range where.ORConditions { + var singleOrConditions []OrConditions + for _, or := range where.OrConditions { if len(or) == 1 { if withConditions { builder.Write(" OR ") @@ -69,7 +69,7 @@ func (where Where) Build(builder Builder) { func (where Where) MergeExpression(expr Expression) { if w, ok := expr.(Where); ok { where.AndConditions = append(where.AndConditions, w.AndConditions...) - where.ORConditions = append(where.ORConditions, w.ORConditions...) + where.OrConditions = append(where.OrConditions, w.OrConditions...) where.builders = append(where.builders, w.builders...) } else { where.builders = append(where.builders, expr) diff --git a/dialects/mysql/mysql.go b/dialects/mysql/mysql.go index ba306889c..b402ef95e 100644 --- a/dialects/mysql/mysql.go +++ b/dialects/mysql/mysql.go @@ -27,3 +27,7 @@ func (Dialector) Migrator() gorm.Migrator { func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string { return "?" } + +func (Dialector) QuoteChars() [2]byte { + return [2]byte{'`', '`'} // `name` +} diff --git a/dialects/postgres/postgres.go b/dialects/postgres/postgres.go index 3abf05e3b..9ea0048a0 100644 --- a/dialects/postgres/postgres.go +++ b/dialects/postgres/postgres.go @@ -31,3 +31,7 @@ func (Dialector) Migrator() gorm.Migrator { func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } + +func (Dialector) QuoteChars() [2]byte { + return [2]byte{'"', '"'} // "name" +} diff --git a/dialects/sqlite/sqlite.go b/dialects/sqlite/sqlite.go index 91c3389e9..80a18cfb0 100644 --- a/dialects/sqlite/sqlite.go +++ b/dialects/sqlite/sqlite.go @@ -31,3 +31,7 @@ func (Dialector) Migrator() gorm.Migrator { func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string { return "?" } + +func (Dialector) QuoteChars() [2]byte { + return [2]byte{'`', '`'} // `name` +} diff --git a/go.mod b/go.mod index 1f4d31a20..e47297fbe 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,8 @@ module github.com/jinzhu/gorm go 1.13 require ( + github.com/go-sql-driver/mysql v1.5.0 // indirect github.com/jinzhu/inflection v1.0.0 - github.com/lib/pq v1.3.0 - github.com/mattn/go-sqlite3 v2.0.3+incompatible + github.com/lib/pq v1.3.0 // indirect + github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect ) diff --git a/gorm.go b/gorm.go index 10d61f80d..23f812d19 100644 --- a/gorm.go +++ b/gorm.go @@ -23,16 +23,21 @@ type Config struct { NowFunc func() time.Time } +type shared struct { + callbacks *callbacks + cacheStore *sync.Map + quoteChars [2]byte +} + // DB GORM DB definition type DB struct { *Config Dialector Instance - DB CommonDB ClauseBuilders map[string]clause.ClauseBuilder + DB CommonDB clone bool - callbacks *callbacks - cacheStore *sync.Map + *shared } // Session session config when create session with Session() method @@ -65,13 +70,16 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) { Dialector: dialector, ClauseBuilders: map[string]clause.ClauseBuilder{}, clone: true, - cacheStore: &sync.Map{}, + shared: &shared{ + cacheStore: &sync.Map{}, + }, } db.callbacks = initializeCallbacks(db) if dialector != nil { err = dialector.Initialize(db) + db.quoteChars = dialector.QuoteChars() } return } @@ -146,8 +154,7 @@ func (db *DB) getInstance() *DB { Dialector: db.Dialector, ClauseBuilders: db.ClauseBuilders, DB: db.DB, - callbacks: db.callbacks, - cacheStore: db.cacheStore, + shared: db.shared, } } diff --git a/interfaces.go b/interfaces.go index 6ba24dc4c..71522455d 100644 --- a/interfaces.go +++ b/interfaces.go @@ -10,6 +10,7 @@ type Dialector interface { Initialize(*DB) error Migrator() Migrator BindVar(stmt *Statement, v interface{}) string + QuoteChars() [2]byte } // CommonDB common db interface diff --git a/statement.go b/statement.go index 26acb319c..bc07b6e4b 100644 --- a/statement.go +++ b/statement.go @@ -81,6 +81,7 @@ func (stmt *Statement) WriteQuoted(field interface{}) (err error) { // Quote returns quoted value func (stmt Statement) Quote(field interface{}) string { var str strings.Builder + str.WriteByte(stmt.DB.quoteChars[0]) switch v := field.(type) { case clause.Table: @@ -91,8 +92,11 @@ func (stmt Statement) Quote(field interface{}) string { } if v.Alias != "" { + str.WriteByte(stmt.DB.quoteChars[1]) str.WriteString(" AS ") + str.WriteByte(stmt.DB.quoteChars[0]) str.WriteString(v.Alias) + str.WriteByte(stmt.DB.quoteChars[1]) } case clause.Column: if v.Table != "" { @@ -101,7 +105,9 @@ func (stmt Statement) Quote(field interface{}) string { } else { str.WriteString(v.Table) } + str.WriteByte(stmt.DB.quoteChars[1]) str.WriteByte('.') + str.WriteByte(stmt.DB.quoteChars[0]) } if v.Name == clause.PrimaryKey { @@ -111,14 +117,19 @@ func (stmt Statement) Quote(field interface{}) string { } else { str.WriteString(v.Name) } + if v.Alias != "" { + str.WriteByte(stmt.DB.quoteChars[1]) str.WriteString(" AS ") + str.WriteByte(stmt.DB.quoteChars[0]) str.WriteString(v.Alias) + str.WriteByte(stmt.DB.quoteChars[1]) } default: fmt.Sprint(field) } + str.WriteByte(stmt.DB.quoteChars[1]) return str.String() } diff --git a/tests/dummy_dialecter.go b/tests/dummy_dialecter.go new file mode 100644 index 000000000..e2cda8fc0 --- /dev/null +++ b/tests/dummy_dialecter.go @@ -0,0 +1,24 @@ +package tests + +import ( + "github.com/jinzhu/gorm" +) + +type DummyDialector struct { +} + +func (DummyDialector) Initialize(*gorm.DB) error { + return nil +} + +func (DummyDialector) Migrator() gorm.Migrator { + return nil +} + +func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string { + return "?" +} + +func (DummyDialector) QuoteChars() [2]byte { + return [2]byte{'`', '`'} // `name` +}