From 84142fe07a2d5f1344ddd05d3e49208d189cc84c Mon Sep 17 00:00:00 2001 From: "liuqiang.06" Date: Mon, 2 Dec 2024 17:54:28 +0800 Subject: [PATCH] fix(jit): out of index when dump mismatched error --- .../decoder/jitdec/assembler_regabi_amd64.go | 23 ++++---- internal/decoder/jitdec/primitives.go | 9 +++- issue_test/issue670_test.go | 7 --- issue_test/issue716_test.go | 54 +++++++++++++++++++ 4 files changed, 76 insertions(+), 17 deletions(-) create mode 100644 issue_test/issue716_test.go diff --git a/internal/decoder/jitdec/assembler_regabi_amd64.go b/internal/decoder/jitdec/assembler_regabi_amd64.go index 4ff3b1962..76eef333b 100644 --- a/internal/decoder/jitdec/assembler_regabi_amd64.go +++ b/internal/decoder/jitdec/assembler_regabi_amd64.go @@ -483,6 +483,7 @@ var ( _V_stackOverflow = jit.Imm(int64(uintptr(unsafe.Pointer(&stackOverflow)))) _I_json_UnsupportedValueError = jit.Itab(_T_error, reflect.TypeOf(new(json.UnsupportedValueError))) _I_json_MismatchTypeError = jit.Itab(_T_error, reflect.TypeOf(new(MismatchTypeError))) + _I_json_MismatchQuotedError = jit.Itab(_T_error, reflect.TypeOf(new(MismatchQuotedError))) ) func (self *_Assembler) type_error() { @@ -1129,15 +1130,19 @@ func (self *_Assembler) unmarshal_func(t reflect.Type, fn obj.Addr, deref bool) self.Emit("MOVQ" , _ARG_sv_n, _DI) // MOVQ sv.n, DI self.call_go(fn) // CALL_GO ${fn} self.Emit("TESTQ", _ET, _ET) // TESTQ ET, ET - self.Sjmp("JZ" , "_unmarshal_func_end_{n}") // JNZ _error - self.Emit("MOVQ", _I_json_MismatchTypeError, _CX) // MOVQ ET, VAR.et - self.Emit("CMPQ", _ET, _CX) // check if MismatchedError - self.Sjmp("JNE" , _LB_error) - self.Emit("MOVQ", jit.Type(t), _CX) // store current type - self.Emit("MOVQ", _CX, _VAR_et) // store current type - self.Emit("MOVQ", _VAR_ic, _IC) // recover the pos - self.Emit("XORL", _ET, _ET) - self.Link("_unmarshal_func_end_{n}") + if fn == _F_decodeJsonUnmarshalerQuoted { + self.Sjmp("JZ" , "_unmarshal_func_end_{n}") // JZ _unmarshal_func_end_{n} + self.Emit("MOVQ", _I_json_MismatchQuotedError, _CX) // MOVQ _I_json_MismatchQuotedError, CX + self.Emit("CMPQ", _ET, _CX) // check if MismatchQuotedError + self.Sjmp("JNE" , _LB_error) // JNE _error + self.Emit("MOVQ", jit.Type(t), _CX) // store current type + self.Emit("MOVQ", _CX, _VAR_et) // store current type as mismatched type + self.Emit("MOVQ", _VAR_ic, _IC) // recover the pos at mismatched, continue to parse + self.Emit("XORL", _ET, _ET) // clear ET + self.Link("_unmarshal_func_end_{n}") + } else { + self.Sjmp("JNE" , _LB_error) // JNE _error + } } /** Dynamic Decoding Routine **/ diff --git a/internal/decoder/jitdec/primitives.go b/internal/decoder/jitdec/primitives.go index 5adfc038a..9de885007 100644 --- a/internal/decoder/jitdec/primitives.go +++ b/internal/decoder/jitdec/primitives.go @@ -39,9 +39,16 @@ func decodeJsonUnmarshaler(vv interface{}, s string) error { return vv.(json.Unmarshaler).UnmarshalJSON(rt.Str2Mem(s)) } +// used to distinguish between MismatchQuoted and other MismatchedTyped errors, see issue #670 and #716 +type MismatchQuotedError struct {} + +func (*MismatchQuotedError) Error() string { + return "mismatch quoted" +} + func decodeJsonUnmarshalerQuoted(vv interface{}, s string) error { if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' { - return &MismatchTypeError{} + return &MismatchQuotedError{} } return vv.(json.Unmarshaler).UnmarshalJSON(rt.Str2Mem(s[1:len(s)-1])) } diff --git a/issue_test/issue670_test.go b/issue_test/issue670_test.go index f4605ab8a..15f554994 100644 --- a/issue_test/issue670_test.go +++ b/issue_test/issue670_test.go @@ -30,7 +30,6 @@ func TestIssue670_JSONMarshaler(t *testing.T) { so, _ := sonic.MarshalString(obj) eo, _ := json.Marshal(obj) assert.Equal(t, string(eo), so) - println(string(eo)) } func TestIssue670_JSONUnmarshaler(t *testing.T) { @@ -50,11 +49,8 @@ func TestIssue670_JSONUnmarshaler(t *testing.T) { func testUnmarshal(t *testing.T, eo []byte, rt reflect.Type, checkobj bool) { obj := reflect.New(rt).Interface() - println(string(eo)) - println("sonic") es := sonic.Unmarshal(eo, obj) obj2 := reflect.New(rt).Interface() - println("std") ee := json.Unmarshal(eo, obj2) assert.Equal(t, ee ==nil, es == nil, es) if checkobj { @@ -107,7 +103,6 @@ func (d *Date) UnmarshalJSON(in []byte) error { return nil } - println("hook ", string(in)) t, err := time.Parse("2006-01-02", string(in)) if err != nil { return err @@ -125,7 +120,6 @@ type Issue670TextMarshaler struct { type Date2 int64 func (d Date2) MarshalText() ([]byte, error) { - println("hook 1") if d == 0 { return []byte("null"), nil } @@ -133,7 +127,6 @@ func (d Date2) MarshalText() ([]byte, error) { } func (d *Date2) UnmarshalText(in []byte) error { - println("hook 2", string(in)) if string(in) == "null" { *d = 0 return nil diff --git a/issue_test/issue716_test.go b/issue_test/issue716_test.go new file mode 100644 index 000000000..b874cb11a --- /dev/null +++ b/issue_test/issue716_test.go @@ -0,0 +1,54 @@ +// Copyright 2024 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +package issue_test + +import ( + "fmt" + "testing" + + "github.com/bytedance/sonic" + "github.com/stretchr/testify/assert" +) + +type UnmFoo struct { + Name string + Age int +} + +func (p *UnmFoo) UnmarshalJSON(data []byte) error { + var aux struct { + Name string `json:"name"` + Age int `json:"age"` + } + + if err := sonic.Unmarshal(data, &aux); err != nil { + return err + } + + p.Name = aux.Name + p.Age = aux.Age + return nil +} + +func TestIssue716(t *testing.T) { + jsonData := `{"name": "Alice", "age": "30"}` + var obj UnmFoo + err := sonic.Unmarshal([]byte(jsonData), &obj) + assert.Error(t, err) + if err != nil { + fmt.Println("Error:", err) + } +}