diff --git a/callbacks/query.go b/callbacks/query.go index edf8f281c..8d13095e7 100644 --- a/callbacks/query.go +++ b/callbacks/query.go @@ -9,9 +9,7 @@ import ( func Query(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Select{}) - db.Statement.AddClauseIfNotExists(clause.From{ - Tables: []clause.Table{{Table: clause.CurrentTable}}, - }) + db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.Build("SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR") result, err := db.DB.ExecContext(db.Context, db.Statement.SQL.String(), db.Statement.Vars...) diff --git a/clause/clause_test.go b/clause/clause_test.go new file mode 100644 index 000000000..97d30f2d3 --- /dev/null +++ b/clause/clause_test.go @@ -0,0 +1,54 @@ +package clause_test + +import ( + "fmt" + "reflect" + "sync" + "testing" + + "github.com/jinzhu/gorm" + "github.com/jinzhu/gorm/clause" + "github.com/jinzhu/gorm/schema" + "github.com/jinzhu/gorm/tests" +) + +func TestClause(t *testing.T) { + var ( + db, _ = gorm.Open(nil, nil) + results = []struct { + Clauses []clause.Interface + Result string + Vars []interface{} + }{{ + []clause.Interface{clause.Select{}, clause.From{}}, + "SELECT * FROM users", []interface{}{}, + }} + ) + + for idx, result := range results { + t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { + var ( + user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) + stmt = gorm.Statement{ + DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}, + } + buildNames []string + ) + + for _, c := range result.Clauses { + buildNames = append(buildNames, c.Name()) + stmt.AddClause(c) + } + + stmt.Build(buildNames...) + + if stmt.SQL.String() != result.Result { + t.Errorf("SQL expects %v got %v", result.Result, stmt.SQL.String()) + } + + if reflect.DeepEqual(stmt.Vars, result.Vars) { + t.Errorf("Vars expects %+v got %v", stmt.Vars, result.Vars) + } + }) + } +} diff --git a/clause/from.go b/clause/from.go index 1a7bcb5c2..b7665bc3e 100644 --- a/clause/from.go +++ b/clause/from.go @@ -10,14 +10,20 @@ func (From) Name() string { return "FROM" } +var currentTable = Table{Table: CurrentTable} + // Build build from clause func (from From) Build(builder Builder) { - for idx, table := range from.Tables { - if idx > 0 { - builder.WriteByte(',') - } + if len(from.Tables) > 0 { + for idx, table := range from.Tables { + if idx > 0 { + builder.WriteByte(',') + } - builder.WriteQuoted(table) + builder.WriteQuoted(table) + } + } else { + builder.WriteQuoted(currentTable) } } diff --git a/statement.go b/statement.go index b24075995..26acb319c 100644 --- a/statement.go +++ b/statement.go @@ -84,6 +84,11 @@ func (stmt Statement) Quote(field interface{}) string { switch v := field.(type) { case clause.Table: + if v.Table == clause.CurrentTable { + str.WriteString(stmt.Table) + } else { + str.WriteString(v.Table) + } if v.Alias != "" { str.WriteString(" AS ")