Skip to content

Commit

Permalink
Test Pluck
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed May 31, 2020
1 parent e26abb8 commit 95a6539
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 11 deletions.
1 change: 1 addition & 0 deletions finisher_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) {
// db.Find(&users).Pluck("age", &ages)
func (db *DB) Pluck(column string, dest interface{}) (tx *DB) {
tx = db.getInstance()
tx.Statement.AddClause(clause.Select{Columns: []clause.Column{{Name: column}}})
tx.Statement.Dest = dest
tx.callbacks.Query().Execute(tx)
return
Expand Down
32 changes: 21 additions & 11 deletions scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {
default:
switch db.Statement.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
isPtr := db.Statement.ReflectValue.Type().Elem().Kind() == reflect.Ptr
reflectValueType := db.Statement.ReflectValue.Type().Elem()
isPtr := reflectValueType.Kind() == reflect.Ptr
if isPtr {
reflectValueType = reflectValueType.Elem()
}

db.Statement.ReflectValue.Set(reflect.MakeSlice(db.Statement.ReflectValue.Type(), 0, 0))
fields := make([]*schema.Field, len(columns))
joinFields := make([][2]*schema.Field, len(columns))
Expand All @@ -81,17 +86,22 @@ func Scan(rows *sql.Rows, db *DB, initialized bool) {

for initialized || rows.Next() {
initialized = false
elem := reflect.New(db.Statement.Schema.ModelType).Elem()
for idx, field := range fields {
if field != nil {
values[idx] = field.ReflectValueOf(elem).Addr().Interface()
} else if joinFields[idx][0] != nil {
relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
relValue.Set(reflect.New(relValue.Type().Elem()))
}
elem := reflect.New(reflectValueType).Elem()

values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface()
if reflectValueType.Kind() != reflect.Struct && len(fields) == 1 {
values[0] = elem.Addr().Interface()
} else {
for idx, field := range fields {
if field != nil {
values[idx] = field.ReflectValueOf(elem).Addr().Interface()
} else if joinFields[idx][0] != nil {
relValue := joinFields[idx][0].ReflectValueOf(elem)
if relValue.Kind() == reflect.Ptr && relValue.IsNil() {
relValue.Set(reflect.New(relValue.Type().Elem()))
}

values[idx] = joinFields[idx][1].ReflectValueOf(relValue).Addr().Interface()
}
}
}

Expand Down
32 changes: 32 additions & 0 deletions tests/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,35 @@ func TestFind(t *testing.T) {
}
}
}

func TestPluck(t *testing.T) {
users := []*User{
GetUser("pluck-user1", Config{}),
GetUser("pluck-user2", Config{}),
GetUser("pluck-user3", Config{}),
}

DB.Create(&users)

var names []string
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("name", &names).Error; err != nil {
t.Errorf("Raise error when pluck name, got %v", err)
}

var ids []int
if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("id", &ids).Error; err != nil {
t.Errorf("Raise error when pluck id, got %v", err)
}

for idx, name := range names {
if name != users[idx].Name {
t.Errorf("Unexpected result on pluck name, got %+v", names)
}
}

for idx, id := range ids {
if int(id) != int(users[idx].ID) {
t.Errorf("Unexpected result on pluck id, got %+v", ids)
}
}
}

0 comments on commit 95a6539

Please sign in to comment.