diff --git a/mapstructure.go b/mapstructure.go index 7581806a..e13c7acf 100644 --- a/mapstructure.go +++ b/mapstructure.go @@ -518,13 +518,13 @@ func (d *Decoder) decodeBasic(name string, data interface{}, val reflect.Value) copied = true // Make *T - copy := reflect.New(elem.Type()) + clone := reflect.New(elem.Type()) // *T = elem - copy.Elem().Set(elem) + clone.Elem().Set(elem) // Set elem so we decode into it - elem = copy + elem = clone } // Decode. If we have an error then return. We also return right @@ -857,7 +857,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle valElemType := valType.Elem() // Accumulate errors - errors := make([]string, 0) + errs := make([]string, 0) // If the input data is empty, then we just match what the input data is. if dataVal.Len() == 0 { @@ -879,7 +879,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle // First decode the key into the proper type currentKey := reflect.Indirect(reflect.New(valKeyType)) if err := d.decode(fieldName, k.Interface(), currentKey); err != nil { - errors = appendErrors(errors, err) + errs = appendErrors(errs, err) continue } @@ -887,7 +887,7 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle v := dataVal.MapIndex(k).Interface() currentVal := reflect.Indirect(reflect.New(valElemType)) if err := d.decode(fieldName, v, currentVal); err != nil { - errors = appendErrors(errors, err) + errs = appendErrors(errs, err) continue } @@ -898,14 +898,14 @@ func (d *Decoder) decodeMapFromMap(name string, dataVal reflect.Value, val refle val.Set(valMap) // If we had errors, return those - if len(errors) > 0 { - return &Error{errors} + if len(errs) > 0 { + return &Error{errs} } return nil } -func (d *Decoder) decodeMapFromStruct(name string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error { +func (d *Decoder) decodeMapFromStruct(_ string, dataVal reflect.Value, val reflect.Value, valMap reflect.Value) error { typ := dataVal.Type() for i := 0; i < typ.NumField(); i++ { // Get the StructField first since this is a cheap operation. If the @@ -1128,7 +1128,7 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) } // Accumulate any errors - errors := make([]string, 0) + errs := make([]string, 0) for i := 0; i < dataVal.Len(); i++ { currentData := dataVal.Index(i).Interface() @@ -1139,7 +1139,7 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) fieldName := name + "[" + strconv.Itoa(i) + "]" if err := d.decode(fieldName, currentData, currentField); err != nil { - errors = appendErrors(errors, err) + errs = appendErrors(errs, err) } } @@ -1147,8 +1147,8 @@ func (d *Decoder) decodeSlice(name string, data interface{}, val reflect.Value) val.Set(valSlice) // If there were errors, we return those - if len(errors) > 0 { - return &Error{errors} + if len(errs) > 0 { + return &Error{errs} } return nil @@ -1198,7 +1198,7 @@ func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value) } // Accumulate any errors - errors := make([]string, 0) + errs := make([]string, 0) for i := 0; i < dataVal.Len(); i++ { currentData := dataVal.Index(i).Interface() @@ -1206,7 +1206,7 @@ func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value) fieldName := name + "[" + strconv.Itoa(i) + "]" if err := d.decode(fieldName, currentData, currentField); err != nil { - errors = appendErrors(errors, err) + errs = appendErrors(errs, err) } } @@ -1214,8 +1214,8 @@ func (d *Decoder) decodeArray(name string, data interface{}, val reflect.Value) val.Set(valArray) // If there were errors, we return those - if len(errors) > 0 { - return &Error{errors} + if len(errs) > 0 { + return &Error{errs} } return nil @@ -1280,7 +1280,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e } targetValKeysUnused := make(map[interface{}]struct{}) - errors := make([]string, 0) + errs := make([]string, 0) // This slice will keep track of all the structs we'll be decoding. // There can be more than one struct if there are embedded structs @@ -1291,20 +1291,23 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e // Compile the list of all the fields that we're going to be decoding // from all the structs. type field struct { - field reflect.StructField - val reflect.Value + field reflect.StructField + val reflect.Value + prefix string } // remainField is set to a valid field set with the "remain" tag if // we are keeping track of remaining values. var remainField *field - fields := []field{} + var fields []field + fieldPrefixes := make(map[reflect.Value]string) for len(structs) > 0 { structVal := structs[0] - structs = structs[1:] - structType := structVal.Type() + fieldPrefix := fieldPrefixes[structVal] + + structs = structs[1:] for i := 0; i < structType.NumField(); i++ { fieldType := structType.Field(i) @@ -1334,27 +1337,29 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e if squash { if fieldVal.Kind() != reflect.Struct { - errors = appendErrors(errors, + errs = appendErrors(errs, fmt.Errorf("%s: unsupported type for squash: %s", fieldType.Name, fieldVal.Kind())) } else { structs = append(structs, fieldVal) + if prefix := tagParts[0]; prefix != "" { + fieldPrefixes[fieldVal] = addPrefix(prefix, fieldPrefix) + } } continue } // Build our field if remain { - remainField = &field{fieldType, fieldVal} + remainField = &field{fieldType, fieldVal, fieldPrefix} } else { // Normal struct field, store it away - fields = append(fields, field{fieldType, fieldVal}) + fields = append(fields, field{fieldType, fieldVal, fieldPrefix}) } } } - // for fieldType, field := range fields { for _, f := range fields { - field, fieldValue := f.field, f.val + field, fieldValue, fieldPrefix := f.field, f.val, f.prefix fieldName := field.Name tagValue := field.Tag.Get(d.config.TagName) @@ -1362,6 +1367,9 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e if tagValue != "" { fieldName = tagValue } + if fieldPrefix != "" { + fieldName = addPrefix(fieldName, fieldPrefix) + } rawMapKey := reflect.ValueOf(fieldName) rawMapVal := dataVal.MapIndex(rawMapKey) @@ -1411,7 +1419,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e } if err := d.decode(fieldName, rawMapVal.Interface(), fieldValue); err != nil { - errors = appendErrors(errors, err) + errs = appendErrors(errs, err) } } @@ -1426,7 +1434,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e // Decode it as-if we were just decoding this map onto our map. if err := d.decodeMap(name, remain, remainField.val); err != nil { - errors = appendErrors(errors, err) + errs = appendErrors(errs, err) } // Set the map to nil so we have none so that the next check will @@ -1442,7 +1450,7 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e sort.Strings(keys) err := fmt.Errorf("'%s' has invalid keys: %s", name, strings.Join(keys, ", ")) - errors = appendErrors(errors, err) + errs = appendErrors(errs, err) } if d.config.ErrorUnset && len(targetValKeysUnused) > 0 { @@ -1453,11 +1461,11 @@ func (d *Decoder) decodeStructFromMap(name string, dataVal, val reflect.Value) e sort.Strings(keys) err := fmt.Errorf("'%s' has unset fields: %s", name, strings.Join(keys, ", ")) - errors = appendErrors(errors, err) + errs = appendErrors(errs, err) } - if len(errors) > 0 { - return &Error{errors} + if len(errs) > 0 { + return &Error{errs} } // Add the unused keys to the list of unused keys if we're tracking metadata @@ -1540,3 +1548,10 @@ func dereferencePtrToStructIfNeeded(v reflect.Value, tagName string) reflect.Val } return v } + +func addPrefix(s string, prefix string) string { + if prefix == "" { + return s + } + return prefix + "_" + s +} diff --git a/mapstructure_test.go b/mapstructure_test.go index d31129d7..c9b5198a 100644 --- a/mapstructure_test.go +++ b/mapstructure_test.go @@ -2732,6 +2732,51 @@ func TestDecoder_IgnoreUntaggedFields(t *testing.T) { } } +func TestDecoder_Decode_SquashWithPrefix(t *testing.T) { + type Git struct { + Remote string `mapstructure:"remote"` + } + + type GitHub struct { + Git `mapstructure:"git,squash"` + Token string `mapstructure:"token"` + } + + type Config struct { + GitHub `mapstructure:"github,squash"` + } + + var cnf Config + decoder, err := NewDecoder(&DecoderConfig{ + DecodeHook: nil, + ErrorUnused: false, + ZeroFields: false, + WeaklyTypedInput: false, + Squash: false, + Metadata: nil, + Result: &cnf, + TagName: "", + MatchName: nil, + }) + if err != nil { + t.Fatalf("err: %s", err) + } + + input := map[string]interface{}{ + "GITHUB_GIT_REMOTE": "git@github.com:mitchellh/mapstructure.git", + "GITHUB_TOKEN": "secret", + } + if err := decoder.Decode(input); err != nil { + t.Fatalf("err: %s", err) + } + if cnf.Remote != input["GITHUB_GIT_REMOTE"].(string) { + t.Errorf("expected: %#v, obtained: %#v", input["GITHUB_GIT_REMOTE"], cnf.Remote) + } + if cnf.Token != input["GITHUB_TOKEN"].(string) { + t.Errorf("expected: %#v, obtained: %#v", input["GITHUB_TOKEN"], cnf.Token) + } +} + func testSliceInput(t *testing.T, input map[string]interface{}, expected *Slice) { var result Slice err := Decode(input, &result)