diff --git a/ast/parser.go b/ast/parser.go index a1f582623..506f9d86c 100644 --- a/ast/parser.go +++ b/ast/parser.go @@ -115,6 +115,10 @@ func (self *Parser) lspace(sp int) int { return sp } +func (self *Parser) backward() { + for ; self.p >= 0 && isSpace(self.s[self.p]); self.p-=1 {} +} + func (self *Parser) decodeArray(ret *linkedNodes) (Node, types.ParsingError) { sp := self.p ns := len(self.s) diff --git a/ast/visitor.go b/ast/visitor.go index d409509f5..f25407b99 100644 --- a/ast/visitor.go +++ b/ast/visitor.go @@ -18,6 +18,7 @@ package ast import ( `encoding/json` + `errors` `github.com/bytedance/sonic/internal/native/types` ) @@ -174,6 +175,19 @@ func (self *traverser) decodeArray() error { sp := self.parser.p ns := len(self.parser.s) + /* allocate array space and parse every element */ + if err := self.visitor.OnArrayBegin(_DEFAULT_NODE_CAP); err != nil { + if err == VisitOPSkip { + // NOTICE: for user needs to skip entiry object + self.parser.p -= 1 + if _, e := self.parser.skipFast(); e != 0 { + return e + } + return self.visitor.OnArrayEnd() + } + return err + } + /* check for EOF */ self.parser.p = self.parser.lspace(sp) if self.parser.p >= ns { @@ -183,16 +197,9 @@ func (self *traverser) decodeArray() error { /* check for empty array */ if self.parser.s[self.parser.p] == ']' { self.parser.p++ - if err := self.visitor.OnArrayBegin(0); err != nil { - return err - } return self.visitor.OnArrayEnd() } - /* allocate array space and parse every element */ - if err := self.visitor.OnArrayBegin(_DEFAULT_NODE_CAP); err != nil { - return err - } for { /* decode the value */ if err := self.decodeValue(); err != nil { @@ -223,6 +230,19 @@ func (self *traverser) decodeObject() error { sp := self.parser.p ns := len(self.parser.s) + /* allocate object space and decode each pair */ + if err := self.visitor.OnObjectBegin(_DEFAULT_NODE_CAP); err != nil { + if err == VisitOPSkip { + // NOTICE: for user needs to skip entiry object + self.parser.p -= 1 + if _, e := self.parser.skipFast(); e != 0 { + return e + } + return self.visitor.OnObjectEnd() + } + return err + } + /* check for EOF */ self.parser.p = self.parser.lspace(sp) if self.parser.p >= ns { @@ -231,17 +251,9 @@ func (self *traverser) decodeObject() error { /* check for empty object */ if self.parser.s[self.parser.p] == '}' { - self.parser.p++ - if err := self.visitor.OnObjectBegin(0); err != nil { - return err - } return self.visitor.OnObjectEnd() } - /* allocate object space and decode each pair */ - if err := self.visitor.OnObjectBegin(_DEFAULT_NODE_CAP); err != nil { - return err - } for { var njs types.JsonState var err types.ParsingError @@ -313,3 +325,7 @@ func (self *traverser) decodeString(iv int64, ep int) error { } return self.visitor.OnString(out) } + +// If visitor return this error on `OnObjectBegin()` or `OnArrayBegin()`, +// the transverer will skip entiry object or array +var VisitOPSkip = errors.New("") diff --git a/ast/visitor_test.go b/ast/visitor_test.go index 9ecdc4a02..c576bdd9d 100644 --- a/ast/visitor_test.go +++ b/ast/visitor_test.go @@ -648,6 +648,115 @@ func TestVisitor_UserNodeDiff(t *testing.T) { }) } +type skipVisitor struct { + sp int + Skip int + inSkip bool + CountSkip int +} + +func (self *skipVisitor) OnNull() error { + if self.sp == self.Skip+1 && self.inSkip { + panic("unexpected key") + } + return nil +} + +func (self *skipVisitor) OnFloat64(v float64, n json.Number) error { + if self.sp == self.Skip+1 && self.inSkip { + panic("unexpected key") + } + return nil +} + +func (self *skipVisitor) OnInt64(v int64, n json.Number) error { + if self.sp == self.Skip+1 && self.inSkip { + panic("unexpected key") + } + return nil +} + +func (self *skipVisitor) OnBool(v bool) error { + if self.sp == self.Skip+1 && self.inSkip { + panic("unexpected key") + } + return nil +} + +func (self *skipVisitor) OnString(v string) error { + if self.sp == self.Skip+1 && self.inSkip { + panic("unexpected key") + } + return nil +} + +func (self *skipVisitor) OnObjectBegin(capacity int) error { + println("self.sp", self.sp) + if self.sp == self.Skip { + self.inSkip = true + self.CountSkip++ + println("op skip") + self.sp++ + return VisitOPSkip + } + self.sp++ + return nil +} + +func (self *skipVisitor) OnObjectKey(key string) error { + if self.sp == self.Skip+1 && self.inSkip { + panic("unexpected key") + } + return nil +} + +func (self *skipVisitor) OnObjectEnd() error { + if self.sp == self.Skip + 1 { + if !self.inSkip { + panic("not in skip") + } + self.inSkip = false + println("finish op skip") + } + self.sp-- + return nil +} + +func (self *skipVisitor) OnArrayBegin(capacity int) error { + println("arr self.sp", self.sp) + if self.sp == self.Skip { + self.inSkip = true + self.CountSkip++ + println("arr op skip") + self.sp++ + return VisitOPSkip + } + self.sp++ + return nil +} + +func (self *skipVisitor) OnArrayEnd() error { + println("arr self.sp", self.sp) + if self.sp == self.Skip + 1 { + if !self.inSkip { + panic("arr not in skip") + } + self.inSkip = false + println("arr finish op skip") + } + self.sp-- + return nil +} + +func TestVisitor_OpSkip(t *testing.T) { + var suite skipVisitor + suite.Skip = 1 + Preorder(`{ "a": [ null ] , "b": 1, "c": { "1" : 1 } }`, &suite, nil) + if suite.CountSkip != 2 { + t.Fatal(suite.CountSkip) + } +} + func BenchmarkVisitor_UserNode(b *testing.B) { const str = _TwitterJson b.Run("AST", func(b *testing.B) { diff --git a/fuzz/go-fuzz-corpus b/fuzz/go-fuzz-corpus new file mode 160000 index 000000000..c42c1b291 --- /dev/null +++ b/fuzz/go-fuzz-corpus @@ -0,0 +1 @@ +Subproject commit c42c1b2914c7503500996ee15927d3ab3d2ba968 diff --git a/internal/encoder/assembler_stkabi_amd64.go b/internal/encoder/assembler_stkabi_amd64.go index c506ea607..83f9428f6 100644 --- a/internal/encoder/assembler_stkabi_amd64.go +++ b/internal/encoder/assembler_stkabi_amd64.go @@ -579,12 +579,12 @@ var ( func (self *_Assembler) more_space() { self.Link(_LB_more_space) - self.Emit("MOVQ", _T_byte, _AX) // MOVQ $_T_byte, _AX - self.Emit("MOVQ", _AX, jit.Ptr(_SP, 0)) // MOVQ _AX, (SP) self.Emit("MOVQ", _RP, jit.Ptr(_SP, 8)) // MOVQ RP, 8(SP) self.Emit("MOVQ", _RL, jit.Ptr(_SP, 16)) // MOVQ RL, 16(SP) self.Emit("MOVQ", _RC, jit.Ptr(_SP, 24)) // MOVQ RC, 24(SP) self.Emit("MOVQ", _AX, jit.Ptr(_SP, 32)) // MOVQ AX, 32(SP) + self.Emit("MOVQ", _T_byte, _AX) // MOVQ $_T_byte, _AX + self.Emit("MOVQ", _AX, jit.Ptr(_SP, 0)) // MOVQ _AX, (SP) self.xsave(_REG_jsr...) // SAVE $REG_jsr self.call(_F_growslice) // CALL $pc self.xload(_REG_jsr...) // LOAD $REG_jsr diff --git a/issue_test/issue634_test.go b/issue_test/issue634_test.go new file mode 100644 index 000000000..fec01d44e --- /dev/null +++ b/issue_test/issue634_test.go @@ -0,0 +1,79 @@ +/* + * Copyright 2024 ByteDance Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package issue_test + +import ( + "strings" + "testing" + + "sync" + + "github.com/bytedance/sonic" + "github.com/bytedance/sonic/option" + "github.com/stretchr/testify/assert" +) + +func marshalSingle() { + var m = map[string]interface{}{ + "1": map[string]interface{} { + `"`+strings.Repeat("a", int(option.DefaultEncoderBufferSize) - 38)+`"`: "b", + "1": map[string]int32{ + "b": 1658219785, + }, + }, + } + _, err := sonic.Marshal(&m) + if err != nil { + panic("err") + } +} + +type zoo foo + +func (z *zoo) MarshalJSON() ([]byte, error) { + marshalSingle() + return sonic.Marshal((*foo)(z)) +} + +type foo bar + +func (f *foo) MarshalJSON() ([]byte, error) { + marshalSingle() + return sonic.Marshal((*bar)(f)) +} + +type bar int + +func (b *bar) MarshalJSON() ([]byte, error) { + marshalSingle() + return sonic.Marshal(int(*b)) +} + + func TestEncodeOOM(t *testing.T) { + wg := &sync.WaitGroup{} + N := 10000 + for i:=0; i