Skip to content

Commit

Permalink
Add clause tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 5, 2020
1 parent 9d19be0 commit 0160bab
Show file tree
Hide file tree
Showing 13 changed files with 92 additions and 21 deletions.
2 changes: 1 addition & 1 deletion chainable_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Where{
ORConditions: []clause.ORConditions{
OrConditions: []clause.OrConditions{
tx.Statement.BuildCondtion(query, args...),
},
})
Expand Down
14 changes: 8 additions & 6 deletions clause/clause_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,19 @@ import (
"github.com/jinzhu/gorm/tests"
)

func TestClause(t *testing.T) {
func TestClauses(t *testing.T) {
var (
db, _ = gorm.Open(nil, nil)
db, _ = gorm.Open(tests.DummyDialector{}, nil)
results = []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{{
[]clause.Interface{clause.Select{}, clause.From{}},
"SELECT * FROM users", []interface{}{},
}}
}{
{
[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{AndConditions: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}}}},
"SELECT * FROM `users` WHERE `users`.`id` = ?", []interface{}{"1"},
},
}
)

for idx, result := range results {
Expand Down
5 changes: 5 additions & 0 deletions clause/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ const (
CurrentTable string = "@@@table@@@"
)

var PrimaryColumn = Column{
Table: CurrentTable,
Name: PrimaryKey,
}

// Expression expression interface
type Expression interface {
Build(builder Builder)
Expand Down
12 changes: 10 additions & 2 deletions clause/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ import "strings"
// Query Expressions
////////////////////////////////////////////////////////////////////////////////

func Add(exprs ...Expression) AddConditions {
return AddConditions(exprs)
}

func Or(exprs ...Expression) OrConditions {
return OrConditions(exprs)
}

type AddConditions []Expression

func (cs AddConditions) Build(builder Builder) {
Expand All @@ -17,9 +25,9 @@ func (cs AddConditions) Build(builder Builder) {
}
}

type ORConditions []Expression
type OrConditions []Expression

func (cs ORConditions) Build(builder Builder) {
func (cs OrConditions) Build(builder Builder) {
for idx, c := range cs {
if idx > 0 {
builder.Write(" OR ")
Expand Down
8 changes: 4 additions & 4 deletions clause/where.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package clause
// Where where clause
type Where struct {
AndConditions AddConditions
ORConditions []ORConditions
OrConditions []OrConditions
builders []Expression
}

Expand Down Expand Up @@ -31,8 +31,8 @@ func (where Where) Build(builder Builder) {
}
}

var singleOrConditions []ORConditions
for _, or := range where.ORConditions {
var singleOrConditions []OrConditions
for _, or := range where.OrConditions {
if len(or) == 1 {
if withConditions {
builder.Write(" OR ")
Expand Down Expand Up @@ -69,7 +69,7 @@ func (where Where) Build(builder Builder) {
func (where Where) MergeExpression(expr Expression) {
if w, ok := expr.(Where); ok {
where.AndConditions = append(where.AndConditions, w.AndConditions...)
where.ORConditions = append(where.ORConditions, w.ORConditions...)
where.OrConditions = append(where.OrConditions, w.OrConditions...)
where.builders = append(where.builders, w.builders...)
} else {
where.builders = append(where.builders, expr)
Expand Down
4 changes: 4 additions & 0 deletions dialects/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ func (Dialector) Migrator() gorm.Migrator {
func (Dialector) BindVar(stmt gorm.Statement, v interface{}) string {
return "?"
}

func (Dialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name`
}
4 changes: 4 additions & 0 deletions dialects/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,7 @@ func (Dialector) Migrator() gorm.Migrator {
func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}

func (Dialector) QuoteChars() [2]byte {
return [2]byte{'"', '"'} // "name"
}
4 changes: 4 additions & 0 deletions dialects/sqlite/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,7 @@ func (Dialector) Migrator() gorm.Migrator {
func (Dialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}

func (Dialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name`
}
5 changes: 3 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ module github.com/jinzhu/gorm
go 1.13

require (
github.com/go-sql-driver/mysql v1.5.0 // indirect
github.com/jinzhu/inflection v1.0.0
github.com/lib/pq v1.3.0
github.com/mattn/go-sqlite3 v2.0.3+incompatible
github.com/lib/pq v1.3.0 // indirect
github.com/mattn/go-sqlite3 v2.0.3+incompatible // indirect
)
19 changes: 13 additions & 6 deletions gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,21 @@ type Config struct {
NowFunc func() time.Time
}

type shared struct {
callbacks *callbacks
cacheStore *sync.Map
quoteChars [2]byte
}

// DB GORM DB definition
type DB struct {
*Config
Dialector
Instance
DB CommonDB
ClauseBuilders map[string]clause.ClauseBuilder
DB CommonDB
clone bool
callbacks *callbacks
cacheStore *sync.Map
*shared
}

// Session session config when create session with Session() method
Expand Down Expand Up @@ -65,13 +70,16 @@ func Open(dialector Dialector, config *Config) (db *DB, err error) {
Dialector: dialector,
ClauseBuilders: map[string]clause.ClauseBuilder{},
clone: true,
cacheStore: &sync.Map{},
shared: &shared{
cacheStore: &sync.Map{},
},
}

db.callbacks = initializeCallbacks(db)

if dialector != nil {
err = dialector.Initialize(db)
db.quoteChars = dialector.QuoteChars()
}
return
}
Expand Down Expand Up @@ -146,8 +154,7 @@ func (db *DB) getInstance() *DB {
Dialector: db.Dialector,
ClauseBuilders: db.ClauseBuilders,
DB: db.DB,
callbacks: db.callbacks,
cacheStore: db.cacheStore,
shared: db.shared,
}
}

Expand Down
1 change: 1 addition & 0 deletions interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ type Dialector interface {
Initialize(*DB) error
Migrator() Migrator
BindVar(stmt *Statement, v interface{}) string
QuoteChars() [2]byte
}

// CommonDB common db interface
Expand Down
11 changes: 11 additions & 0 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func (stmt *Statement) WriteQuoted(field interface{}) (err error) {
// Quote returns quoted value
func (stmt Statement) Quote(field interface{}) string {
var str strings.Builder
str.WriteByte(stmt.DB.quoteChars[0])

switch v := field.(type) {
case clause.Table:
Expand All @@ -91,8 +92,11 @@ func (stmt Statement) Quote(field interface{}) string {
}

if v.Alias != "" {
str.WriteByte(stmt.DB.quoteChars[1])
str.WriteString(" AS ")
str.WriteByte(stmt.DB.quoteChars[0])
str.WriteString(v.Alias)
str.WriteByte(stmt.DB.quoteChars[1])
}
case clause.Column:
if v.Table != "" {
Expand All @@ -101,7 +105,9 @@ func (stmt Statement) Quote(field interface{}) string {
} else {
str.WriteString(v.Table)
}
str.WriteByte(stmt.DB.quoteChars[1])
str.WriteByte('.')
str.WriteByte(stmt.DB.quoteChars[0])
}

if v.Name == clause.PrimaryKey {
Expand All @@ -111,14 +117,19 @@ func (stmt Statement) Quote(field interface{}) string {
} else {
str.WriteString(v.Name)
}

if v.Alias != "" {
str.WriteByte(stmt.DB.quoteChars[1])
str.WriteString(" AS ")
str.WriteByte(stmt.DB.quoteChars[0])
str.WriteString(v.Alias)
str.WriteByte(stmt.DB.quoteChars[1])
}
default:
fmt.Sprint(field)
}

str.WriteByte(stmt.DB.quoteChars[1])
return str.String()
}

Expand Down
24 changes: 24 additions & 0 deletions tests/dummy_dialecter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package tests

import (
"github.com/jinzhu/gorm"
)

type DummyDialector struct {
}

func (DummyDialector) Initialize(*gorm.DB) error {
return nil
}

func (DummyDialector) Migrator() gorm.Migrator {
return nil
}

func (DummyDialector) BindVar(stmt *gorm.Statement, v interface{}) string {
return "?"
}

func (DummyDialector) QuoteChars() [2]byte {
return [2]byte{'`', '`'} // `name`
}

0 comments on commit 0160bab

Please sign in to comment.