diff --git a/cmd/greenmask/cmd/root.go b/cmd/greenmask/cmd/root.go index 2fc4c1b5..8860fcca 100644 --- a/cmd/greenmask/cmd/root.go +++ b/cmd/greenmask/cmd/root.go @@ -137,6 +137,8 @@ func initConfig() { decoderCfg := func(cfg *mapstructure.DecoderConfig) { cfg.DecodeHook = mapstructure.ComposeDecodeHookFunc( configUtils.ParamsToByteSliceHookFunc(), + configUtils.StringToStructHookFunc(), + configUtils.StringToSliceWithBracketHookFunc(), mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToSliceHookFunc(","), ) diff --git a/internal/utils/config/mapstructure_hook.go b/internal/utils/config/mapstructure_hook.go index 66d4e322..37058b05 100644 --- a/internal/utils/config/mapstructure_hook.go +++ b/internal/utils/config/mapstructure_hook.go @@ -46,3 +46,60 @@ func ParamsToByteSliceHookFunc() mapstructure.DecodeHookFunc { } } } + +func StringToSliceWithBracketHookFunc() mapstructure.DecodeHookFunc { + return func( + f reflect.Kind, + t reflect.Kind, + data interface{}) (interface{}, error) { + if f != reflect.String || t != reflect.Slice { + return data, nil + } + + raw := data.(string) + if raw == "" { + return []string{}, nil + } + var slice []json.RawMessage + err := json.Unmarshal([]byte(raw), &slice) + if err != nil { + return data, nil + } + + var strSlice []string + for _, v := range slice { + strSlice = append(strSlice, string(v)) + } + return strSlice, nil + } +} + +func StringToStructHookFunc() mapstructure.DecodeHookFunc { + return func( + f reflect.Type, + t reflect.Type, + data interface{}, + ) (interface{}, error) { + if f.Kind() != reflect.String || + (t.Kind() != reflect.Struct && !(t.Kind() == reflect.Pointer && t.Elem().Kind() == reflect.Struct)) { + return data, nil + } + raw := data.(string) + var val reflect.Value + // Struct or the pointer to a struct + if t.Kind() == reflect.Struct { + val = reflect.New(t) + } else { + val = reflect.New(t.Elem()) + } + + if raw == "" { + return val, nil + } + err := json.Unmarshal([]byte(raw), val.Interface()) + if err != nil { + return data, nil + } + return val.Interface(), nil + } +}