diff --git a/protocol/thrift/skipdecoder.go b/protocol/thrift/skipdecoder.go index 00c806a..78eace1 100644 --- a/protocol/thrift/skipdecoder.go +++ b/protocol/thrift/skipdecoder.go @@ -17,16 +17,13 @@ package thrift import ( - "encoding/binary" - "fmt" + "io" "sync" "github.com/bytedance/gopkg/lang/mcache" "github.com/cloudwego/gopkg/bufiox" ) -const defaultSkipDecoderSize = 4096 - var poolSkipDecoder = sync.Pool{ New: func() interface{} { return &SkipDecoder{} @@ -37,143 +34,184 @@ var poolSkipDecoder = sync.Pool{ type SkipDecoder struct { r bufiox.Reader - // for storing Next(ttype) buffer - nextBuf []byte - - // for reusing buffer - pendingBuf [][]byte + rn int } -// NewSkipDecoder ... call Release if no longer use +// NewSkipDecoder ... +// +// call Release if no longer use func NewSkipDecoder(r bufiox.Reader) *SkipDecoder { p := poolSkipDecoder.Get().(*SkipDecoder) p.r = r return p } -// Release releases the peekAck decoder, callers cannot use the returned data of Next after calling Release. +// Release puts SkipDecoder back to pool and reuse it next time. +// +// DO NOT USE SkipDecoder after calling Release. func (p *SkipDecoder) Release() { - if cap(p.nextBuf) > 0 { - mcache.Free(p.nextBuf) - } *p = SkipDecoder{} poolSkipDecoder.Put(p) } // Next skips a specific type and returns its bytes. -// Callers cannot use the returned data after calling Release. +// +// The returned buf is directly from bufiox.Reader with the same lifecycle. func (p *SkipDecoder) Next(t TType) (buf []byte, err error) { - p.nextBuf = mcache.Malloc(0, defaultSkipDecoderSize) - if err = p.skip(t, defaultRecursionDepth); err != nil { + p.rn = 0 + if err = NewSkipDecoderTpl(p).Skip(t, defaultRecursionDepth); err != nil { return } - var offset int - for _, b := range p.pendingBuf { - offset += copy(p.nextBuf[offset:], b[offset:]) - mcache.Free(b) + buf, err = p.r.Next(p.rn) + return +} + +// SkipN implements SkipDecoderIFace +func (p *SkipDecoder) SkipN(n int) (buf []byte, err error) { + // old version netpoll may have performance issue when using Peek + // see: https://github.com/cloudwego/netpoll/pull/335 + if buf, err = p.r.Peek(p.rn + n); err == nil { + buf = buf[p.rn:] + p.rn += n } - p.pendingBuf = nil - buf = p.nextBuf return } -func (p *SkipDecoder) skip(t TType, maxdepth int) error { - if maxdepth == 0 { - return errDepthLimitExceeded +// BytesSkipDecoder ... +type BytesSkipDecoder struct { + n int + b []byte +} + +var poolBytesSkipDecoder = sync.Pool{ + New: func() interface{} { + return &BytesSkipDecoder{} + }, +} + +// NewBytesSkipDecoder ... +// +// call Release if no longer use +func NewBytesSkipDecoder(b []byte) *BytesSkipDecoder { + p := poolBytesSkipDecoder.Get().(*BytesSkipDecoder) + p.Reset(b) + return p +} + +// Release puts BytesSkipDecoder back to pool and reuse it next time. +// +// DO NOT USE BytesSkipDecoder after calling Release. +func (p *BytesSkipDecoder) Release() { + p.Reset(nil) + poolBytesSkipDecoder.Put(p) +} + +// Reset ... +func (p *BytesSkipDecoder) Reset(b []byte) { + p.n = 0 + p.b = b +} + +// Next skips a specific type and returns its bytes. +// +// The returned buf refers to the input []byte without copy +func (p *BytesSkipDecoder) Next(t TType) (b []byte, err error) { + if err = NewSkipDecoderTpl(p).Skip(t, defaultRecursionDepth); err != nil { + return } - if sz := typeToSize[t]; sz > 0 { - _, err := p.next(int(sz)) - return err + b = p.b[:p.n] + p.b = p.b[p.n:] + p.n = 0 + return +} + +// SkipN implements SkipDecoderIFace +func (p *BytesSkipDecoder) SkipN(n int) ([]byte, error) { + if len(p.b) >= p.n+n { + p.n += n + return p.b[p.n-n : p.n], nil } - switch t { - case STRING: - b, err := p.next(4) - if err != nil { - return err - } - sz := int(binary.BigEndian.Uint32(b)) - if sz < 0 { - return errNegativeSize - } - if _, err := p.next(sz); err != nil { - return err - } - case STRUCT: - for { - b, err := p.next(1) // TType - if err != nil { - return err - } - tp := TType(b[0]) - if tp == STOP { - break - } - if _, err := p.next(2); err != nil { // Field ID - return err - } - if err := p.skip(tp, maxdepth-1); err != nil { - return err - } - } - case MAP: - b, err := p.next(6) // 1 byte key TType, 1 byte value TType, 4 bytes Len - if err != nil { - return err - } - kt, vt, sz := TType(b[0]), TType(b[1]), int32(binary.BigEndian.Uint32(b[2:])) - if sz < 0 { - return errNegativeSize - } - ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt]) - if ksz > 0 && vsz > 0 { - _, err := p.next(int(sz) * (ksz + vsz)) - return err - } - for i := int32(0); i < sz; i++ { - if err := p.skip(kt, maxdepth-1); err != nil { - return err - } - if err := p.skip(vt, maxdepth-1); err != nil { - return err - } - } - case SET, LIST: - b, err := p.next(5) // 1 byte value type, 4 bytes Len - if err != nil { - return err - } - vt, sz := TType(b[0]), int32(binary.BigEndian.Uint32(b[1:])) - if sz < 0 { - return errNegativeSize - } - if vsz := typeToSize[vt]; vsz > 0 { - _, err := p.next(int(sz) * int(vsz)) - return err - } - for i := int32(0); i < sz; i++ { - if err := p.skip(vt, maxdepth-1); err != nil { - return err - } - } - default: - return NewProtocolException(INVALID_DATA, fmt.Sprintf("unknown data type %d", t)) + return nil, io.EOF +} + +// ReaderSkipDecoder ... +type ReaderSkipDecoder struct { + r io.Reader + + n int // bytes read, n <= len(b) + b []byte +} + +var poolReaderSkipDecoder = sync.Pool{ + New: func() interface{} { + return &ReaderSkipDecoder{} + }, +} + +// NewReaderSkipDecoder creates a ReaderSkipDecoder from pool +// +// call Release if no longer use +func NewReaderSkipDecoder(r io.Reader) *ReaderSkipDecoder { + p := poolReaderSkipDecoder.Get().(*ReaderSkipDecoder) + p.Reset(r) + return p +} + +// Release puts ReaderSkipDecoder back to pool and reuse it next time. +// +// DO NOT USE ReaderSkipDecoder after calling Release. +func (p *ReaderSkipDecoder) Release() { + // no need to free p.b + // will make use of p.b without reallcation + p.Reset(nil) + poolReaderSkipDecoder.Put(p) +} + +// Reset ... +func (p *ReaderSkipDecoder) Reset(r io.Reader) { + p.r = r + p.n = 0 +} + +// Grow grows the underlying buffer to fit n bytes +func (p *ReaderSkipDecoder) Grow(n int) { + if len(p.b)-p.n >= n { + return } - return nil + p.growSlow(n) +} + +func (p *ReaderSkipDecoder) growSlow(n int) { + // mcache will take care of the size of newb + newb := mcache.Malloc(p.n + n) + copy(newb, p.b[:p.n]) + mcache.Free(p.b) + p.b = newb } -func (p *SkipDecoder) next(n int) (buf []byte, err error) { - if buf, err = p.r.Next(n); err != nil { +// Next skips a specific type and returns its bytes. +// +// The returned []byte is valid before the next `Next` call or `Release` +func (p *ReaderSkipDecoder) Next(t TType) (b []byte, err error) { + p.n = 0 + if err = NewSkipDecoderTpl(p).Skip(t, defaultRecursionDepth); err != nil { return } - if cap(p.nextBuf)-len(p.nextBuf) < n { - var ncap int - for ncap = cap(p.nextBuf) * 2; ncap-len(p.nextBuf) < n; ncap *= 2 { - } - nbs := mcache.Malloc(ncap, ncap) - p.pendingBuf = append(p.pendingBuf, p.nextBuf) - p.nextBuf = nbs[:len(p.nextBuf)] + return p.b[:p.n], nil +} + +// SkipN implements SkipDecoderIFace +func (p *ReaderSkipDecoder) SkipN(n int) (buf []byte, err error) { + p.Grow(n) + buf = p.b[p.n : p.n+n] + for i := 0; i < n && err == nil; { // io.ReadFull(buf) + var nn int + nn, err = p.r.Read(buf[i:]) + i += nn + } + if err != nil { + return } - cn := copy(p.nextBuf[len(p.nextBuf):cap(p.nextBuf)], buf) - p.nextBuf = p.nextBuf[:len(p.nextBuf)+cn] + p.n += n return } diff --git a/protocol/thrift/skipdecoder_test.go b/protocol/thrift/skipdecoder_test.go index 60a5e7a..4c0aea8 100644 --- a/protocol/thrift/skipdecoder_test.go +++ b/protocol/thrift/skipdecoder_test.go @@ -92,54 +92,75 @@ func TestSkipDecoder(t *testing.T) { b = x.AppendFieldStop(b) sz10 := len(b) - r := NewSkipDecoder(bufiox.NewBytesReader(b)) - defer r.Release() - - readn := 0 - b, err := r.Next(BYTE) // byte - require.NoError(t, err) - readn += len(b) - require.Equal(t, sz0, readn) - b, err = r.Next(STRING) // string - require.NoError(t, err) - readn += len(b) - require.Equal(t, sz1, readn) - b, err = r.Next(LIST) // list - require.NoError(t, err) - readn += len(b) - require.Equal(t, sz2, readn) - b, err = r.Next(LIST) // list - require.NoError(t, err) - readn += len(b) - require.Equal(t, sz3, readn) - b, err = r.Next(LIST) // list> - require.NoError(t, err) - readn += len(b) - require.Equal(t, sz4, readn) - b, err = r.Next(MAP) // map - require.NoError(t, err) - readn += len(b) - require.Equal(t, sz5, readn) - b, err = r.Next(MAP) // map - require.NoError(t, err) - readn += len(b) - require.Equal(t, sz6, readn) - b, err = r.Next(MAP) // map - require.NoError(t, err) - readn += len(b) - require.Equal(t, sz7, readn) - b, err = r.Next(MAP) // map> - require.NoError(t, err) - readn += len(b) - require.Equal(t, sz8, readn) - b, err = r.Next(MAP) // map, i32> - require.NoError(t, err) - readn += len(b) - require.Equal(t, sz9, readn) - b, err = r.Next(STRUCT) // struct i32, list - require.NoError(t, err) - readn += len(b) - require.Equal(t, sz10, readn) + type NextIFace interface { + Next(t TType) (buf []byte, err error) + } + + testNext := func(t *testing.T, r NextIFace) { + readn := 0 + b, err := r.Next(BYTE) // byte + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz0, readn) + b, err = r.Next(STRING) // string + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz1, readn) + b, err = r.Next(LIST) // list + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz2, readn) + b, err = r.Next(LIST) // list + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz3, readn) + b, err = r.Next(LIST) // list> + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz4, readn) + b, err = r.Next(MAP) // map + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz5, readn) + b, err = r.Next(MAP) // map + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz6, readn) + b, err = r.Next(MAP) // map + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz7, readn) + b, err = r.Next(MAP) // map> + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz8, readn) + b, err = r.Next(MAP) // map, i32> + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz9, readn) + b, err = r.Next(STRUCT) // struct i32, list + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz10, readn) + } + + t.Run("NewSkipDecoder", func(t *testing.T) { + r := NewSkipDecoder(bufiox.NewBytesReader(b)) + defer r.Release() + testNext(t, r) + }) + + t.Run("NewBytesSkipDecoder", func(t *testing.T) { + r := NewBytesSkipDecoder(b) + defer r.Release() + testNext(t, r) + }) + + t.Run("NewReaderSkipDecoder", func(t *testing.T) { + r := NewReaderSkipDecoder(bytes.NewBuffer(b)) + defer r.Release() + testNext(t, r) + }) { // other cases // errDepthLimitExceeded @@ -219,8 +240,7 @@ func BenchmarkSkipDecoder(b *testing.B) { b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - bufReader := bufiox.NewBytesReader(bs) - sr := NewSkipDecoder(bufReader) + sr := NewBytesSkipDecoder(bs) buf, err := sr.Next(STRUCT) if err != nil { b.Fatal(err) @@ -229,7 +249,6 @@ func BenchmarkSkipDecoder(b *testing.B) { b.Fatal("bytes not equal") } sr.Release() - _ = bufReader.Release(nil) } }) } diff --git a/protocol/thrift/skipdecoder_tpl.go b/protocol/thrift/skipdecoder_tpl.go new file mode 100644 index 0000000..1613cc9 --- /dev/null +++ b/protocol/thrift/skipdecoder_tpl.go @@ -0,0 +1,130 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "encoding/binary" + "fmt" +) + +// SkipDecoderIFace represent the generics constraint of a SkipDecoder. +// +// It's used by SkipDecoderTpl +type SkipDecoderIFace interface { + // SkipN read and skip n bytes + // + // SkipDecoderTpl will not hold or modify the bytes between two `SkipN` calls. + // It's safe to reuse buffer for next `SkipN` call. + // + // if SkipN is short enough, it can be inlined. + SkipN(n int) ([]byte, error) +} + +// SkipDecoderTpl is the core logic of skipping thrift binary +type SkipDecoderTpl[T SkipDecoderIFace] struct { + r T +} + +// NewSkipDecoderTpl ... +func NewSkipDecoderTpl[T SkipDecoderIFace](r T) SkipDecoderTpl[T] { + return SkipDecoderTpl[T]{r} +} + +// Skip ... +func (p SkipDecoderTpl[T]) Skip(t TType, maxdepth int) error { + if maxdepth == 0 { + return errDepthLimitExceeded + } + if sz := typeToSize[t]; sz > 0 { + _, err := p.r.SkipN(int(sz)) + return err + } + switch t { + case STRING: + b, err := p.r.SkipN(4) + if err != nil { + return err + } + sz := int(binary.BigEndian.Uint32(b)) + if sz < 0 { + return errNegativeSize + } + if _, err := p.r.SkipN(sz); err != nil { + return err + } + case STRUCT: + for { + b, err := p.r.SkipN(1) // TType + if err != nil { + return err + } + tp := TType(b[0]) + if tp == STOP { + break + } + if _, err := p.r.SkipN(2); err != nil { // Field ID + return err + } + if err := p.Skip(tp, maxdepth-1); err != nil { + return err + } + } + case MAP: + b, err := p.r.SkipN(6) // 1 byte key TType, 1 byte value TType, 4 bytes Len + if err != nil { + return err + } + kt, vt, sz := TType(b[0]), TType(b[1]), int32(binary.BigEndian.Uint32(b[2:])) + if sz < 0 { + return errNegativeSize + } + ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt]) + if ksz > 0 && vsz > 0 { + _, err := p.r.SkipN(int(sz) * (ksz + vsz)) + return err + } + for i := int32(0); i < sz; i++ { + if err := p.Skip(kt, maxdepth-1); err != nil { + return err + } + if err := p.Skip(vt, maxdepth-1); err != nil { + return err + } + } + case SET, LIST: + b, err := p.r.SkipN(5) // 1 byte value type, 4 bytes Len + if err != nil { + return err + } + vt, sz := TType(b[0]), int32(binary.BigEndian.Uint32(b[1:])) + if sz < 0 { + return errNegativeSize + } + if vsz := typeToSize[vt]; vsz > 0 { + _, err := p.r.SkipN(int(sz) * int(vsz)) + return err + } + for i := int32(0); i < sz; i++ { + if err := p.Skip(vt, maxdepth-1); err != nil { + return err + } + } + default: + return NewProtocolException(INVALID_DATA, fmt.Sprintf("unknown data type %d", t)) + } + return nil +}