From e2ff8ac5e787050108f6ac68944d0e6fb1ae3553 Mon Sep 17 00:00:00 2001 From: liu Date: Wed, 20 Nov 2024 15:01:10 +0800 Subject: [PATCH] fix(aarch64): invalid skip number (#712) --- ast/decode.go | 77 ++++-------------------------- decode_test.go | 21 ++++++++ internal/decoder/optdec/helper.go | 32 +++++++------ internal/decoder/optdec/node.go | 7 ++- internal/utils/skip.go | 79 +++++++++++++++++++++++++++++++ 5 files changed, 130 insertions(+), 86 deletions(-) create mode 100644 internal/utils/skip.go diff --git a/ast/decode.go b/ast/decode.go index 6690d513e..27aaf1408 100644 --- a/ast/decode.go +++ b/ast/decode.go @@ -17,13 +17,14 @@ package ast import ( - `encoding/base64` - `runtime` - `strconv` - `unsafe` - - `github.com/bytedance/sonic/internal/native/types` - `github.com/bytedance/sonic/internal/rt` + "encoding/base64" + "runtime" + "strconv" + "unsafe" + + "github.com/bytedance/sonic/internal/native/types" + "github.com/bytedance/sonic/internal/rt" + "github.com/bytedance/sonic/internal/utils" ) // Hack: this is used for both checking space and cause firendly compile errors in 32-bit arch. @@ -290,67 +291,7 @@ func decodeValue(src string, pos int, skipnum bool) (ret int, v types.JsonState) //go:nocheckptr func skipNumber(src string, pos int) (ret int) { - sp := uintptr(rt.IndexChar(src, pos)) - se := uintptr(rt.IndexChar(src, len(src))) - if uintptr(sp) >= se { - return -int(types.ERR_EOF) - } - - if c := *(*byte)(unsafe.Pointer(sp)); c == '-' { - sp += 1 - } - ss := sp - - var pointer bool - var exponent bool - var lastIsDigit bool - var nextNeedDigit = true - - for ; sp < se; sp += uintptr(1) { - c := *(*byte)(unsafe.Pointer(sp)) - if isDigit(c) { - lastIsDigit = true - nextNeedDigit = false - continue - } else if nextNeedDigit { - return -int(types.ERR_INVALID_CHAR) - } else if c == '.' 
{ - if !lastIsDigit || pointer || exponent || sp == ss { - return -int(types.ERR_INVALID_CHAR) - } - pointer = true - lastIsDigit = false - nextNeedDigit = true - continue - } else if c == 'e' || c == 'E' { - if !lastIsDigit || exponent { - return -int(types.ERR_INVALID_CHAR) - } - if sp == se-1 { - return -int(types.ERR_EOF) - } - exponent = true - lastIsDigit = false - nextNeedDigit = false - continue - } else if c == '-' || c == '+' { - if prev := *(*byte)(unsafe.Pointer(sp - 1)); prev != 'e' && prev != 'E' { - return -int(types.ERR_INVALID_CHAR) - } - lastIsDigit = false - nextNeedDigit = true - continue - } else { - break - } - } - - if nextNeedDigit { - return -int(types.ERR_EOF) - } - - runtime.KeepAlive(src) - return int(uintptr(sp) - uintptr((*rt.GoString)(unsafe.Pointer(&src)).Ptr)) + return utils.SkipNumber(src, pos) } //go:nocheckptr diff --git a/decode_test.go b/decode_test.go index 320d8cfb9..acee43790 100644 --- a/decode_test.go +++ b/decode_test.go @@ -2967,3 +2967,24 @@ func BenchmarkDecoderRawMessage(b *testing.B) { }) }) } + + +func TestJsonNumber(t *testing.T) { + api := Config { + UseNumber: true, + }.Froze() + + + type Foo struct { + A json.Number `json:"a"` + B json.Number `json:"b"` + C json.Number `json:"c"` + } + + data := []byte(`{"a": 1 , "b": "123", "c": "0.4e+56"}`) + var foo1, foo2 Foo + serr := api.Unmarshal(data, &foo1) + jerr := json.Unmarshal(data, &foo2) + assert.Equal(t, jerr, serr) + assert.Equal(t, foo2, foo1) +} \ No newline at end of file diff --git a/internal/decoder/optdec/helper.go b/internal/decoder/optdec/helper.go index 1d76f8051..143fa6708 100644 --- a/internal/decoder/optdec/helper.go +++ b/internal/decoder/optdec/helper.go @@ -5,38 +5,42 @@ import ( "strconv" "github.com/bytedance/sonic/internal/native" + "github.com/bytedance/sonic/internal/utils" "github.com/bytedance/sonic/internal/native/types" ) func SkipNumberFast(json string, start int) (int, error) { - // find the number ending, we pasred in sonic-cpp, it 
alway valid + // find the number ending, we parsed in native, it is always valid pos := start for pos < len(json) && json[pos] != ']' && json[pos] != '}' && json[pos] != ',' { if json[pos] >= '0' && json[pos] <= '9' || json[pos] == '.' || json[pos] == '-' || json[pos] == '+' || json[pos] == 'e' || json[pos] == 'E' { pos += 1 } else { - return pos, error_syntax(pos, json, "invalid number") + break } } return pos, nil } -func ValidNumberFast(json string) error { - // find the number ending, we pasred in sonic-cpp, it alway valid - pos := 0 - for pos < len(json) && json[pos] != ']' && json[pos] != '}' && json[pos] != ',' { - if json[pos] >= '0' && json[pos] <= '9' || json[pos] == '.' || json[pos] == '-' || json[pos] == '+' || json[pos] == 'e' || json[pos] == 'E' { - pos += 1 - } else { - return error_syntax(pos, json, "invalid number") - } + +func isSpace(c byte) bool { + return c == ' ' || c == '\t' || c == '\n' || c == '\r' +} + +// ValidNumberFast reports whether the whole of raw is a number accepted by utils.SkipNumber +func ValidNumberFast(raw string) bool { + ret := utils.SkipNumber(raw, 0) + if ret < 0 { + return false } - if pos == 0 { - return error_syntax(pos, json, "invalid number") + // check trailing chars + for ret < len(raw) { + return false } - return nil + + return true } func SkipOneFast2(json string, pos *int) (int, error) { diff --git a/internal/decoder/optdec/node.go b/internal/decoder/optdec/node.go index 8b49ebb3a..690fbd5d4 100644 --- a/internal/decoder/optdec/node.go +++ b/internal/decoder/optdec/node.go @@ -509,12 +509,11 @@ func (val Node) AsNumber(ctx *Context) (json.Number, bool) { // parse JSON string as number if val.IsStr() { s, _ := val.AsStr(ctx) - err := ValidNumberFast(s) - if err != nil { + if !ValidNumberFast(s) { return "", false + } else { + return json.Number(s), true } - - return json.Number(s), true } return val.NonstrAsNumber(ctx) diff --git a/internal/utils/skip.go b/internal/utils/skip.go new file mode 100644 index 000000000..e42bfe759 --- /dev/null +++ b/internal/utils/skip.go @@ 
-0,0 +1,79 @@ + +package utils + +import ( + `runtime` + `unsafe` + + `github.com/bytedance/sonic/internal/native/types` + `github.com/bytedance/sonic/internal/rt` +) + +func isDigit(c byte) bool { + return c >= '0' && c <= '9' +} + +//go:nocheckptr +func SkipNumber(src string, pos int) (ret int) { + sp := uintptr(rt.IndexChar(src, pos)) + se := uintptr(rt.IndexChar(src, len(src))) + if uintptr(sp) >= se { + return -int(types.ERR_EOF) + } + + if c := *(*byte)(unsafe.Pointer(sp)); c == '-' { + sp += 1 + } + ss := sp + + var pointer bool + var exponent bool + var lastIsDigit bool + var nextNeedDigit = true + + for ; sp < se; sp += uintptr(1) { + c := *(*byte)(unsafe.Pointer(sp)) + if isDigit(c) { + lastIsDigit = true + nextNeedDigit = false + continue + } else if nextNeedDigit { + return -int(types.ERR_INVALID_CHAR) + } else if c == '.' { + if !lastIsDigit || pointer || exponent || sp == ss { + return -int(types.ERR_INVALID_CHAR) + } + pointer = true + lastIsDigit = false + nextNeedDigit = true + continue + } else if c == 'e' || c == 'E' { + if !lastIsDigit || exponent { + return -int(types.ERR_INVALID_CHAR) + } + if sp == se-1 { + return -int(types.ERR_EOF) + } + exponent = true + lastIsDigit = false + nextNeedDigit = false + continue + } else if c == '-' || c == '+' { + if prev := *(*byte)(unsafe.Pointer(sp - 1)); prev != 'e' && prev != 'E' { + return -int(types.ERR_INVALID_CHAR) + } + lastIsDigit = false + nextNeedDigit = true + continue + } else { + break + } + } + + if nextNeedDigit { + return -int(types.ERR_EOF) + } + + runtime.KeepAlive(src) + return int(uintptr(sp) - uintptr((*rt.GoString)(unsafe.Pointer(&src)).Ptr)) +} \ No newline at end of file