diff --git a/internal/decoder/optdec/compile_struct.go b/internal/decoder/optdec/compile_struct.go index 51552a287..03e3f66f8 100644 --- a/internal/decoder/optdec/compile_struct.go +++ b/internal/decoder/optdec/compile_struct.go @@ -39,7 +39,44 @@ func (c *compiler) compileIntStringOption(vt reflect.Type) decFunc { panic("unreachable") } +func isInteger(vt reflect.Type) bool { + switch vt.Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint, reflect.Uintptr, reflect.Int: return true + default: return false + } +} + +func (c *compiler) assertStringOptTypes(vt reflect.Type) { + if c.depth > _CompileMaxDepth { + panic(*stackOverflow) + } + + c.depth += 1 + defer func () { + c.depth -= 1 + }() + + if isInteger(vt) { + return + } + + switch vt.Kind() { + case reflect.String, reflect.Bool, reflect.Float32, reflect.Float64: + return + case reflect.Ptr: c.assertStringOptTypes(vt.Elem()) + default: + panicForInvalidStrType(vt) + } +} + func (c *compiler) compileFieldStringOption(vt reflect.Type) decFunc { + c.assertStringOptTypes(vt) + + unmDec := c.tryCompilePtrUnmarshaler(vt, true) + if unmDec != nil { + return unmDec + } + switch vt.Kind() { case reflect.String: if vt == jsonNumberType { @@ -80,7 +117,8 @@ func (c *compiler) compileFieldStringOption(vt reflect.Type) decFunc { deref: c.compileFieldStringOption(vt.Elem()), } default: - panic("string options should appliy only to fields of string, floating point, integer, or boolean types.") + panicForInvalidStrType(vt) + return nil } } diff --git a/internal/decoder/optdec/compiler.go b/internal/decoder/optdec/compiler.go index bb47f91f8..3a2d5ec43 100644 --- a/internal/decoder/optdec/compiler.go +++ b/internal/decoder/optdec/compiler.go @@ -34,7 +34,6 @@ type compiler struct { counts int opts option.CompileOptions namedPtr bool - } func newCompiler() *compiler { @@ -114,7 +113,7 @@ func (c *compiler) compile(vt reflect.Type) decFunc { } } - dec := c.tryCompilePtrUnmarshaler(vt) + dec := c.tryCompilePtrUnmarshaler(vt, false) if dec != nil { return dec } @@ -420,16 +419,21 @@ func (c *compiler) compileMapKey(vt reflect.Type) decKey { } // maybe vt is a named type, and not a pointer receiver, see issue 379 -func (c *compiler) tryCompilePtrUnmarshaler(vt reflect.Type) decFunc { +func (c *compiler) tryCompilePtrUnmarshaler(vt reflect.Type, strOpt bool) decFunc { pt := reflect.PtrTo(vt) /* check for `json.Unmarshaler` with pointer receiver */ if pt.Implements(jsonUnmarshalerType) { return &unmarshalJSONDecoder{ typ: rt.UnpackType(pt), + strOpt: strOpt, } } + if strOpt { + panicForInvalidStrType(vt) + } + /* check for `encoding.TextMarshaler` with pointer receiver */ if pt.Implements(encodingTextUnmarshalerType) { return &unmarshalTextDecoder{ @@ -439,3 +443,7 @@ func (c *compiler) tryCompilePtrUnmarshaler(vt reflect.Type) decFunc { return nil } + +func panicForInvalidStrType(vt reflect.Type) { + panic(error_type(rt.UnpackType(vt))) +} diff --git a/internal/decoder/optdec/interface.go b/internal/decoder/optdec/interface.go index b96d3fb1c..0c063d55f 100644 --- a/internal/decoder/optdec/interface.go +++ b/internal/decoder/optdec/interface.go @@ -131,7 +131,8 @@ func (d *unmarshalTextDecoder) FromDom(vp unsafe.Pointer, node Node, ctx *contex } type unmarshalJSONDecoder struct { - typ *rt.GoType + typ *rt.GoType + strOpt bool } func (d *unmarshalJSONDecoder) FromDom(vp unsafe.Pointer, node Node, ctx *context) error { @@ -140,15 +141,28 @@ func (d *unmarshalJSONDecoder) FromDom(vp unsafe.Pointer, node Node, ctx *contex Value: vp, })) + var input []byte + if d.strOpt && node.IsNull() { + input = []byte("null") + } else if d.strOpt { + s, ok := node.AsStringText(ctx) + if !ok { + return error_mismatch(node, ctx, d.typ.Pack()) + } + input = s + } else { + input = []byte(node.AsRaw(ctx)) + } + // fast path if u, ok := v.(json.Unmarshaler); ok { - return u.UnmarshalJSON([]byte(node.AsRaw(ctx))) + return u.UnmarshalJSON((input)) } // slow path rv := reflect.ValueOf(v) if u, ok := rv.Interface().(json.Unmarshaler); ok { - return u.UnmarshalJSON([]byte(node.AsRaw(ctx))) + return u.UnmarshalJSON(input) } return error_type(d.typ)