diff --git a/model_table_has_many.go b/model_table_has_many.go index f3e977fca..26f1266ad 100644 --- a/model_table_has_many.go +++ b/model_table_has_many.go @@ -94,7 +94,7 @@ func (m *hasManyModel) Scan(src interface{}) error { for _, f := range m.rel.JoinFields { if f.Name == field.Name { - m.structKey = append(m.structKey, getFieldValue(field.Value(m.strct))) + m.structKey = append(m.structKey, indirectFieldValue(field.Value(m.strct))) break } } @@ -144,24 +144,19 @@ func baseValues(model TableModel, fields []*schema.Field) map[internal.MapKey][] func modelKey(key []interface{}, strct reflect.Value, fields []*schema.Field) []interface{} { for _, f := range fields { - key = append(key, getFieldValue(f.Value(strct))) + key = append(key, indirectFieldValue(f.Value(strct))) } return key } -// getFieldValue extracts the value from a reflect.Value, handling pointer types appropriately. -func getFieldValue(fieldValue reflect.Value) interface{} { - var keyValue interface{} - - if fieldValue.Kind() == reflect.Ptr { - if !fieldValue.IsNil() { - keyValue = fieldValue.Elem().Interface() - } else { - keyValue = nil - } - } else { - keyValue = fieldValue.Interface() +// indirectFieldValue return the field value dereferencing the pointer if necessary. +// The value is then used as a map key. +func indirectFieldValue(field reflect.Value) interface{} { + if field.Kind() != reflect.Ptr { + return field.Interface() } - - return keyValue + if field.IsNil() { + return nil + } + return field.Elem().Interface() }