Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: didn't consider json.Marshaler/Unmarshal when handling json:",string" tag #682

Merged
merged 6 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions internal/decoder/jitdec/assembler_regabi_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -972,11 +972,13 @@ var (

// Addresses of the Go-side unmarshal helpers that the generated code calls.
// They are declared here without initializers and filled in by init() below.
var (
// Address of decodeJsonUnmarshaler.
_F_decodeJsonUnmarshaler obj.Addr
// Address of decodeJsonUnmarshalerQuoted; _asm_OP_unmarshal and
// _asm_OP_unmarshal_p select it when the instruction's integer operand is
// non-zero.
_F_decodeJsonUnmarshalerQuoted obj.Addr
// Address of decodeTextUnmarshaler.
_F_decodeTextUnmarshaler obj.Addr
)

// init resolves each helper function to an obj.Addr via jit.Func.
// NOTE(review): presumably done in init() rather than in the var initializers
// to avoid a package initialization cycle — confirm before moving.
func init() {
_F_decodeJsonUnmarshaler = jit.Func(decodeJsonUnmarshaler)
_F_decodeJsonUnmarshalerQuoted = jit.Func(decodeJsonUnmarshalerQuoted)
_F_decodeTextUnmarshaler = jit.Func(decodeTextUnmarshaler)
}

Expand Down Expand Up @@ -1061,14 +1063,15 @@ var (
_F_skip_number = jit.Imm(int64(native.S_skip_number))
)

func (self *_Assembler) unmarshal_json(t reflect.Type, deref bool) {
func (self *_Assembler) unmarshal_json(t reflect.Type, deref bool, f obj.Addr) {
self.call_sf(_F_skip_one) // CALL_SF skip_one
self.Emit("TESTQ", _AX, _AX) // TESTQ AX, AX
self.Sjmp("JS" , _LB_parsing_error_v) // JS _parse_error_v
self.Emit("MOVQ", _IC, _VAR_ic) // store for mismatche error skip
self.slice_from_r(_AX, 0) // SLICE_R AX, $0
self.Emit("MOVQ" , _DI, _ARG_sv_p) // MOVQ DI, sv.p
self.Emit("MOVQ" , _SI, _ARG_sv_n) // MOVQ SI, sv.n
self.unmarshal_func(t, _F_decodeJsonUnmarshaler, deref) // UNMARSHAL json, ${t}, ${deref}
self.unmarshal_func(t, f, deref) // UNMARSHAL json, ${t}, ${deref}
}

func (self *_Assembler) unmarshal_text(t reflect.Type, deref bool) {
Expand Down Expand Up @@ -1103,7 +1106,15 @@ func (self *_Assembler) unmarshal_func(t reflect.Type, fn obj.Addr, deref bool)
self.Emit("MOVQ" , _ARG_sv_n, _DI) // MOVQ sv.n, DI
self.call_go(fn) // CALL_GO ${fn}
self.Emit("TESTQ", _ET, _ET) // TESTQ ET, ET
self.Sjmp("JNZ" , _LB_error) // JNZ _error
self.Sjmp("JZ" , "_unmarshal_func_end_{n}") // JNZ _error
self.Emit("MOVQ", _I_json_MismatchTypeError, _CX) // MOVQ ET, VAR.et
self.Emit("CMPQ", _ET, _CX) // check if MismatchedError
self.Sjmp("JNE" , _LB_error)
self.Emit("MOVQ", jit.Type(t), _CX) // store current type
self.Emit("MOVQ", _CX, _VAR_et) // store current type
self.Emit("MOVQ", _VAR_ic, _IC) // recover the pos
self.Emit("XORL", _ET, _ET)
self.Link("_unmarshal_func_end_{n}")
}

/** Dynamic Decoding Routine **/
Expand Down Expand Up @@ -1774,11 +1785,19 @@ func (self *_Assembler) _asm_OP_struct_field(p *_Instr) {
}

func (self *_Assembler) _asm_OP_unmarshal(p *_Instr) {
self.unmarshal_json(p.vt(), true)
if iv := p.i64(); iv != 0 {
self.unmarshal_json(p.vt(), true, _F_decodeJsonUnmarshalerQuoted)
} else {
self.unmarshal_json(p.vt(), true, _F_decodeJsonUnmarshaler)
}
}

func (self *_Assembler) _asm_OP_unmarshal_p(p *_Instr) {
self.unmarshal_json(p.vt(), false)
if iv := p.i64(); iv != 0 {
self.unmarshal_json(p.vt(), false, _F_decodeJsonUnmarshalerQuoted)
} else {
self.unmarshal_json(p.vt(), false, _F_decodeJsonUnmarshaler)
}
}

func (self *_Assembler) _asm_OP_unmarshal_text(p *_Instr) {
Expand Down
77 changes: 61 additions & 16 deletions internal/decoder/jitdec/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,13 @@ func newInsVt(op _Op, vt reflect.Type) _Instr {
}
}

// newInsVtI builds an instruction that carries both a type and an integer
// operand: the opcode and iv are packed together into u, and the runtime type
// descriptor of vt is stored in p.
func newInsVtI(op _Op, vt reflect.Type, iv int) _Instr {
	var ins _Instr
	ins.u = packOp(op) | rt.PackInt(iv)
	ins.p = unsafe.Pointer(rt.UnpackType(vt))
	return ins
}

func newInsVf(op _Op, vf *caching.FieldMap) _Instr {
return _Instr {
u: packOp(op),
Expand Down Expand Up @@ -452,6 +459,10 @@ func (self *_Program) rtt(op _Op, vt reflect.Type) {
*self = append(*self, newInsVt(op, vt))
}

// rtti appends an instruction built by newInsVtI, i.e. one that carries the
// type vt together with the integer operand iv.
func (self *_Program) rtti(op _Op, vt reflect.Type, iv int) {
	ins := newInsVtI(op, vt, iv)
	*self = append(*self, ins)
}

// fmv appends an instruction built by newInsVf, i.e. one that carries the
// field map vf.
func (self *_Program) fmv(op _Op, vf *caching.FieldMap) {
	ins := newInsVf(op, vf)
	*self = append(*self, ins)
}
Expand Down Expand Up @@ -527,35 +538,54 @@ func (self *_Compiler) compile(vt reflect.Type) (ret _Program, err error) {
return
}

func (self *_Compiler) checkMarshaler(p *_Program, vt reflect.Type) bool {
// checkMarshalerFlags_quoted is the flags value handed to checkMarshaler for
// struct fields that carry the `json:",string"` option. checkMarshaler compares
// flags against it for equality (it is not tested as a bit mask) and, when it
// matches, refuses to fall back to encoding.TextUnmarshaler.
const checkMarshalerFlags_quoted = 1

func (self *_Compiler) checkMarshaler(p *_Program, vt reflect.Type, flags int, exec bool) bool {
pt := reflect.PtrTo(vt)

/* check for `json.Unmarshaler` with pointer receiver */
if pt.Implements(jsonUnmarshalerType) {
p.rtt(_OP_unmarshal_p, pt)
if exec {
p.add(_OP_lspace)
p.rtti(_OP_unmarshal_p, pt, flags)
}
return true
}

/* check for `json.Unmarshaler` */
if vt.Implements(jsonUnmarshalerType) {
p.add(_OP_lspace)
self.compileUnmarshalJson(p, vt)
if exec {
p.add(_OP_lspace)
self.compileUnmarshalJson(p, vt, flags)
}
return true
}

if flags == checkMarshalerFlags_quoted {
// text marshaler shouldn't be supported for quoted string
return false
}

/* check for `encoding.TextUnmarshaler` with pointer receiver */
if pt.Implements(encodingTextUnmarshalerType) {
p.add(_OP_lspace)
self.compileUnmarshalTextPtr(p, pt)
if exec {
p.add(_OP_lspace)
self.compileUnmarshalTextPtr(p, pt, flags)
}
return true
}

/* check for `encoding.TextUnmarshaler` */
if vt.Implements(encodingTextUnmarshalerType) {
p.add(_OP_lspace)
self.compileUnmarshalText(p, vt)
if exec {
p.add(_OP_lspace)
self.compileUnmarshalText(p, vt, flags)
}
return true
}

return false
}

Expand All @@ -567,7 +597,7 @@ func (self *_Compiler) compileOne(p *_Program, sp int, vt reflect.Type) {
return
}

if self.checkMarshaler(p, vt) {
if self.checkMarshaler(p, vt, 0, true) {
return
}

Expand Down Expand Up @@ -690,7 +720,7 @@ func (self *_Compiler) compilePtr(p *_Program, sp int, et reflect.Type) {

/* dereference all the way down */
for et.Kind() == reflect.Ptr {
if self.checkMarshaler(p, et) {
if self.checkMarshaler(p, et, 0, true) {
return
}
et = et.Elem()
Expand Down Expand Up @@ -938,7 +968,22 @@ end_of_object:
p.pin(skip)
}

// compileStructFieldStrUnmarshal emits the program for a `json:",string"`
// field whose type implements json.Unmarshaler: skip leading whitespace, test
// for null, then emit the unmarshaler call in quoted mode.
//
// The statement order is significant: the pc of the _OP_is_null instruction is
// captured before it is added, so that it can be pinned once the unmarshaler
// instructions have been emitted.
// NOTE(review): p.pin presumably patches that instruction's jump target to the
// current end of the program — confirm against _Program.pin.
func (self *_Compiler) compileStructFieldStrUnmarshal(p *_Program, vt reflect.Type) {
	p.add(_OP_lspace)

	nullCheck := p.pc()
	p.add(_OP_is_null)

	// The caller has already established that vt is an unmarshaler, so the
	// boolean result carries no new information here.
	_ = self.checkMarshaler(p, vt, checkMarshalerFlags_quoted, true)

	p.pin(nullCheck)
}

func (self *_Compiler) compileStructFieldStr(p *_Program, sp int, vt reflect.Type) {
// according to std, json.Unmarshaler should be called before stringize
// see https://github.com/bytedance/sonic/issues/670
if self.checkMarshaler(p, vt, checkMarshalerFlags_quoted, false) {
self.compileStructFieldStrUnmarshal(p, vt)
return
}

n1 := -1
ft := vt
sv := false
Expand Down Expand Up @@ -1106,7 +1151,7 @@ func (self *_Compiler) compileUnmarshalEnd(p *_Program, vt reflect.Type, i int)
p.pin(j)
}

func (self *_Compiler) compileUnmarshalJson(p *_Program, vt reflect.Type) {
func (self *_Compiler) compileUnmarshalJson(p *_Program, vt reflect.Type, flags int) {
i := p.pc()
v := _OP_unmarshal
p.add(_OP_is_null)
Expand All @@ -1117,11 +1162,11 @@ func (self *_Compiler) compileUnmarshalJson(p *_Program, vt reflect.Type) {
}

/* call the unmarshaler */
p.rtt(v, vt)
p.rtti(v, vt, flags)
self.compileUnmarshalEnd(p, vt, i)
}

func (self *_Compiler) compileUnmarshalText(p *_Program, vt reflect.Type) {
func (self *_Compiler) compileUnmarshalText(p *_Program, vt reflect.Type, iv int) {
i := p.pc()
v := _OP_unmarshal_text
p.add(_OP_is_null)
Expand All @@ -1134,15 +1179,15 @@ func (self *_Compiler) compileUnmarshalText(p *_Program, vt reflect.Type) {
}

/* call the unmarshaler */
p.rtt(v, vt)
p.rtti(v, vt, iv)
self.compileUnmarshalEnd(p, vt, i)
}

func (self *_Compiler) compileUnmarshalTextPtr(p *_Program, vt reflect.Type) {
func (self *_Compiler) compileUnmarshalTextPtr(p *_Program, vt reflect.Type, iv int) {
i := p.pc()
p.add(_OP_is_null)
p.chr(_OP_match_char, '"')
p.rtt(_OP_unmarshal_text_p, vt)
p.rtti(_OP_unmarshal_text_p, vt, iv)
p.pin(i)
}

Expand Down
1 change: 1 addition & 0 deletions internal/decoder/jitdec/generic_regabi_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ var (
_T_slice = jit.Type(reflect.TypeOf(([]interface{})(nil)))
_T_string = jit.Type(reflect.TypeOf(""))
_T_number = jit.Type(reflect.TypeOf(json.Number("")))
_T_miserr = jit.Type(reflect.TypeOf(MismatchTypeError{}))
_T_float64 = jit.Type(reflect.TypeOf(float64(0)))
)

Expand Down
7 changes: 7 additions & 0 deletions internal/decoder/jitdec/primitives.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ func decodeJsonUnmarshaler(vv interface{}, s string) error {
return vv.(json.Unmarshaler).UnmarshalJSON(rt.Str2Mem(s))
}

// decodeJsonUnmarshalerQuoted is the `json:",string"` variant of
// decodeJsonUnmarshaler. The raw value s must be wrapped in double quotes; the
// quotes are stripped and the bytes in between are passed to vv's
// UnmarshalJSON exactly as they appear (no unescaping is performed here).
//
// If s is not a quoted value, a *MismatchTypeError is returned instead of
// calling the unmarshaler. vv must implement json.Unmarshaler, otherwise the
// type assertion panics.
func decodeJsonUnmarshalerQuoted(vv interface{}, s string) error {
	n := len(s)
	if quoted := n >= 2 && s[0] == '"' && s[n-1] == '"'; !quoted {
		return &MismatchTypeError{}
	}

	inner := s[1 : n-1]
	return vv.(json.Unmarshaler).UnmarshalJSON(rt.Str2Mem(inner))
}

// decodeTextUnmarshaler passes s, viewed as a byte slice through rt.Str2Mem,
// to vv's UnmarshalText method and returns its error. vv must implement
// encoding.TextUnmarshaler, otherwise the type assertion panics.
func decodeTextUnmarshaler(vv interface{}, s string) error {
	um := vv.(encoding.TextUnmarshaler)
	return um.UnmarshalText(rt.Str2Mem(s))
}
39 changes: 38 additions & 1 deletion internal/decoder/optdec/compile_struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,43 @@ func (c *compiler) compileIntStringOption(vt reflect.Type) decFunc {
panic("unreachable")
}

// isInteger reports whether vt's kind is one of the signed or unsigned integer
// kinds, including uintptr.
func isInteger(vt reflect.Type) bool {
	switch vt.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return true
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return true
	}
	return false
}

// assertStringOptTypes checks that vt is a type the `,string` option may be
// applied to: an integer, string, bool or float kind, or a pointer (of any
// depth) to one of those. Any other kind is reported through
// panicForInvalidStrType.
//
// Each level of pointer indirection counts towards c.depth; exceeding
// _CompileMaxDepth panics with *stackOverflow. c.depth is restored on return.
func (c *compiler) assertStringOptTypes(vt reflect.Type) {
	if c.depth > _CompileMaxDepth {
		panic(*stackOverflow)
	}

	c.depth++
	defer func() { c.depth-- }()

	kind := vt.Kind()
	switch {
	case isInteger(vt):
		// integers are accepted
	case kind == reflect.String || kind == reflect.Bool:
		// accepted
	case kind == reflect.Float32 || kind == reflect.Float64:
		// accepted
	case kind == reflect.Ptr:
		// validate the pointee instead
		c.assertStringOptTypes(vt.Elem())
	default:
		panicForInvalidStrType(vt)
	}
}

func (c *compiler) compileFieldStringOption(vt reflect.Type) decFunc {
c.assertStringOptTypes(vt)
unmDec := c.tryCompilePtrUnmarshaler(vt, true)
if unmDec != nil {
return unmDec
}

switch vt.Kind() {
case reflect.String:
if vt == jsonNumberType {
Expand Down Expand Up @@ -80,7 +116,8 @@ func (c *compiler) compileFieldStringOption(vt reflect.Type) decFunc {
deref: c.compileFieldStringOption(vt.Elem()),
}
default:
panic("string options should appliy only to fields of string, floating point, integer, or boolean types.")
panicForInvalidStrType(vt)
return nil
}
}

Expand Down
14 changes: 11 additions & 3 deletions internal/decoder/optdec/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ type compiler struct {
counts int
opts option.CompileOptions
namedPtr bool

}

func newCompiler() *compiler {
Expand Down Expand Up @@ -114,7 +113,7 @@ func (c *compiler) compile(vt reflect.Type) decFunc {
}
}

dec := c.tryCompilePtrUnmarshaler(vt)
dec := c.tryCompilePtrUnmarshaler(vt, false)
if dec != nil {
return dec
}
Expand Down Expand Up @@ -420,22 +419,31 @@ func (c *compiler) compileMapKey(vt reflect.Type) decKey {
}

// maybe vt is a named type, and not a pointer receiver, see issue 379
func (c *compiler) tryCompilePtrUnmarshaler(vt reflect.Type) decFunc {
func (c *compiler) tryCompilePtrUnmarshaler(vt reflect.Type, strOpt bool) decFunc {
pt := reflect.PtrTo(vt)

/* check for `json.Unmarshaler` with pointer receiver */
if pt.Implements(jsonUnmarshalerType) {
return &unmarshalJSONDecoder{
typ: rt.UnpackType(pt),
strOpt: strOpt,
}
}

/* check for `encoding.TextUnmarshaler` with pointer receiver */
if pt.Implements(encodingTextUnmarshalerType) {
/* encoding.TextUnmarshaler does not support the `,string` tag */
if strOpt {
panicForInvalidStrType(vt)
}
return &unmarshalTextDecoder{
typ: rt.UnpackType(pt),
}
}

return nil
}

func panicForInvalidStrType(vt reflect.Type) {
panic(error_type(rt.UnpackType(vt)))
}
Loading
Loading