Skip to content

Commit

Permalink
Setup clauses tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Feb 4, 2020
1 parent 46b1c85 commit 9d19be0
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 8 deletions.
4 changes: 1 addition & 3 deletions callbacks/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ import (

func Query(db *gorm.DB) {
db.Statement.AddClauseIfNotExists(clause.Select{})
db.Statement.AddClauseIfNotExists(clause.From{
Tables: []clause.Table{{Table: clause.CurrentTable}},
})
db.Statement.AddClauseIfNotExists(clause.From{})

db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR")
result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...)
Expand Down
54 changes: 54 additions & 0 deletions clause/clause_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package clause_test

import (
"fmt"
"reflect"
"sync"
"testing"

"github.com/jinzhu/gorm"
"github.com/jinzhu/gorm/clause"
"github.com/jinzhu/gorm/schema"
"github.com/jinzhu/gorm/tests"
)

func TestClause(t *testing.T) {
var (
db, _ = gorm.Open(nil, nil)
results = []struct {
Clauses []clause.Interface
Result string
Vars []interface{}
}{{
[]clause.Interface{clause.Select{}, clause.From{}},
"SELECT * FROM users", []interface{}{},
}}
)

for idx, result := range results {
t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
var (
user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy)
stmt = gorm.Statement{
DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{},
}
buildNames []string
)

for _, c := range result.Clauses {
buildNames = append(buildNames, c.Name())
stmt.AddClause(c)
}

stmt.Build(buildNames...)

if stmt.SQL.String() != result.Result {
t.Errorf("SQL expects %v got %v", result.Result, stmt.SQL.String())
}

if reflect.DeepEqual(stmt.Vars, result.Vars) {
t.Errorf("Vars expects %+v got %v", stmt.Vars, result.Vars)
}
})
}
}
16 changes: 11 additions & 5 deletions clause/from.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,20 @@ func (From) Name() string {
return "FROM"
}

var currentTable = Table{Table: CurrentTable}

// Build build from clause
func (from From) Build(builder Builder) {
for idx, table := range from.Tables {
if idx > 0 {
builder.WriteByte(',')
}
if len(from.Tables) > 0 {
for idx, table := range from.Tables {
if idx > 0 {
builder.WriteByte(',')
}

builder.WriteQuoted(table)
builder.WriteQuoted(table)
}
} else {
builder.WriteQuoted(currentTable)
}
}

Expand Down
5 changes: 5 additions & 0 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ func (stmt Statement) Quote(field interface{}) string {

switch v := field.(type) {
case clause.Table:
if v.Table == clause.CurrentTable {
str.WriteString(stmt.Table)
} else {
str.WriteString(v.Table)
}

if v.Alias != "" {
str.WriteString(" AS ")
Expand Down

0 comments on commit 9d19be0

Please sign in to comment.