diff --git a/protocol/thrift/binaryreader.go b/protocol/thrift/binaryreader.go index b0a77e0..a08ebf0 100644 --- a/protocol/thrift/binaryreader.go +++ b/protocol/thrift/binaryreader.go @@ -24,14 +24,6 @@ import ( "sync" ) -type nextIface interface { - Next(n int) ([]byte, error) -} - -type discardIface interface { - Discard(n int) (int, error) -} - // BinaryReader represents a reader for binary protocol type BinaryReader struct { r nextIface @@ -53,8 +45,7 @@ func NewBinaryReader(r io.Reader) *BinaryReader { if nextr, ok := r.(nextIface); ok { ret.r = nextr } else { - nextr := poolNextReader.Get().(*nextReader) - nextr.Reset(r) + nextr := newNextReader(r) ret.r = nextr ret.d = nextr } @@ -65,7 +56,7 @@ func NewBinaryReader(r io.Reader) *BinaryReader { func (r *BinaryReader) Release() { nextr, ok := r.r.(*nextReader) if ok { - poolNextReader.Put(nextr) + nextr.Release() } r.reset() poolBinaryReader.Put(r) diff --git a/protocol/thrift/skipdecoder.go b/protocol/thrift/skipdecoder.go new file mode 100644 index 0000000..51c94fc --- /dev/null +++ b/protocol/thrift/skipdecoder.go @@ -0,0 +1,153 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "encoding/binary" + "fmt" + "io" + "sync" +) + +// SkipDecoder scans the underlying io.Reader and returns the bytes of a type +type SkipDecoder struct { + p skipReaderIface +} + +var poolSkipDecoder = sync.Pool{ + New: func() interface{} { + return &SkipDecoder{} + }, +} + +// NewSkipDecoder ... call Release if no longer use +func NewSkipDecoder(r io.Reader) *SkipDecoder { + p := poolSkipDecoder.Get().(*SkipDecoder) + p.Reset(r) + return p +} + +// Reset ... +func (p *SkipDecoder) Reset(r io.Reader) { + if p.p != nil { + p.p.Release() + } + if buf, ok := r.(remoteByteBuffer); ok { + p.p = newSkipByteBuffer(buf) + } else { + p.p = newSkipReader(r) + } +} + +// Release ... +func (p *SkipDecoder) Release() { + p.p.Release() + p.p = nil + poolSkipDecoder.Put(p) +} + +// Next skips a specific type and returns its bytes +func (p *SkipDecoder) Next(t TType) (buf []byte, err error) { + if err := p.skip(t, defaultRecursionDepth); err != nil { + return nil, err + } + return p.p.Bytes() +} + +func (p *SkipDecoder) skip(t TType, maxdepth int) error { + if maxdepth == 0 { + return errDepthLimitExceeded + } + if sz := typeToSize[t]; sz > 0 { + _, err := p.p.Next(int(sz)) + return err + } + switch t { + case STRING: + b, err := p.p.Next(4) + if err != nil { + return err + } + sz := int(binary.BigEndian.Uint32(b)) + if sz < 0 { + return errNegativeSize + } + if _, err := p.p.Next(sz); err != nil { + return err + } + case STRUCT: + for { + b, err := p.p.Next(1) // TType + if err != nil { + return err + } + tp := TType(b[0]) + if tp == STOP { + break + } + if _, err := p.p.Next(2); err != nil { // Field ID + return err + } + if err := p.skip(tp, maxdepth-1); err != nil { + return err + } + } + case MAP: + b, err := p.p.Next(6) // 1 byte key TType, 1 byte value TType, 4 bytes Len + if err != nil { + return err + } + kt, vt, sz := TType(b[0]), TType(b[1]), int32(binary.BigEndian.Uint32(b[2:])) + if sz < 0 { + return errNegativeSize + } + ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt]) + if ksz > 0 && vsz > 0 { + _, err := p.p.Next(int(sz) * (ksz + vsz)) + return err + } + for i := int32(0); i < sz; i++ { + if err := p.skip(kt, maxdepth-1); err != nil { + return err + } + if err := p.skip(vt, maxdepth-1); err != nil { + return err + } + } + case SET, LIST: + b, err := p.p.Next(5) // 1 byte value type, 4 bytes Len + if err != nil { + return err + } + vt, sz := TType(b[0]), int32(binary.BigEndian.Uint32(b[1:])) + if sz < 0 { + return errNegativeSize + } + if vsz := typeToSize[vt]; vsz > 0 { + _, err := p.p.Next(int(sz) * int(vsz)) + return err + } + for i := int32(0); i < sz; i++ { + if err := p.skip(vt, maxdepth-1); err != nil { + return err + } + } + default: + return NewProtocolException(INVALID_DATA, fmt.Sprintf("unknown data type %d", t)) + } + return nil +} diff --git a/protocol/thrift/skipdecoder_test.go b/protocol/thrift/skipdecoder_test.go new file mode 100644 index 0000000..22cdd6a --- /dev/null +++ b/protocol/thrift/skipdecoder_test.go @@ -0,0 +1,157 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "bytes" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSkipDecoder(t *testing.T) { + x := BinaryProtocol{} + // byte + b := x.AppendByte([]byte(nil), 1) + sz0 := len(b) + + // string + b = x.AppendString(b, strings.Repeat("hello", 500)) // larger than buffer + sz1 := len(b) + + // list + b = x.AppendListBegin(b, I32, 1) + b = x.AppendI32(b, 1) + sz2 := len(b) + + // list + b = x.AppendListBegin(b, STRING, 1) + b = x.AppendString(b, "hello") + sz3 := len(b) + + // list> + b = x.AppendListBegin(b, LIST, 1) + b = x.AppendListBegin(b, I32, 1) + b = x.AppendI32(b, 1) + sz4 := len(b) + + // map + b = x.AppendMapBegin(b, I32, I64, 1) + b = x.AppendI32(b, 1) + b = x.AppendI64(b, 2) + sz5 := len(b) + + // map + b = x.AppendMapBegin(b, I32, STRING, 1) + b = x.AppendI32(b, 1) + b = x.AppendString(b, "hello") + sz6 := len(b) + + // map + b = x.AppendMapBegin(b, STRING, I64, 1) + b = x.AppendString(b, "hello") + b = x.AppendI64(b, 2) + sz7 := len(b) + + // map> + b = x.AppendMapBegin(b, I32, LIST, 1) + b = x.AppendI32(b, 1) + b = x.AppendListBegin(b, I32, 1) + b = x.AppendI32(b, 1) + sz8 := len(b) + + // map, i32> + b = x.AppendMapBegin(b, LIST, I32, 1) + b = x.AppendListBegin(b, I32, 1) + b = x.AppendI32(b, 1) + b = x.AppendI32(b, 1) + sz9 := len(b) + + // struct i32, list + b = x.AppendFieldBegin(b, I32, 1) + b = x.AppendI32(b, 1) + b = x.AppendFieldBegin(b, LIST, 1) + b = x.AppendListBegin(b, I32, 1) + b = x.AppendI32(b, 1) + b = x.AppendFieldStop(b) + sz10 := len(b) + + r := NewSkipDecoder(bytes.NewReader(b)) + defer r.Release() + + readn := 0 + b, err := r.Next(BYTE) // byte + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz0, readn) + b, err = r.Next(STRING) // string + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz1, readn) + b, err = r.Next(LIST) // list + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz2, readn) + b, err = r.Next(LIST) // list + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz3, readn) + b, err = r.Next(LIST) // list> + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz4, readn) + b, err = r.Next(MAP) // map + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz5, readn) + b, err = r.Next(MAP) // map + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz6, readn) + b, err = r.Next(MAP) // map + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz7, readn) + b, err = r.Next(MAP) // map> + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz8, readn) + b, err = r.Next(MAP) // map, i32> + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz9, readn) + b, err = r.Next(STRUCT) // struct i32, list + require.NoError(t, err) + readn += len(b) + require.Equal(t, sz10, readn) + + { // other cases + // errDepthLimitExceeded + b = b[:0] + for i := 0; i < defaultRecursionDepth+1; i++ { + b = x.AppendFieldBegin(b, STRUCT, 1) + } + r := NewSkipDecoder(bytes.NewReader(b)) + _, err := r.Next(STRUCT) + require.Same(t, errDepthLimitExceeded, err) + + // unknown type + _, err = r.Next(TType(122)) + require.Error(t, err) + } +} diff --git a/protocol/thrift/skipreader.go b/protocol/thrift/skipreader.go new file mode 100644 index 0000000..c2f9a00 --- /dev/null +++ b/protocol/thrift/skipreader.go @@ -0,0 +1,161 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "io" + "sync" +) + +// this file contains readers for SkipCodec + +type skipReaderIface interface { + Next(n int) (buf []byte, err error) + Bytes() (buf []byte, err error) + Release() +} + +var poolSkipReader = sync.Pool{ + New: func() interface{} { + return &skipReader{b: make([]byte, 1024)} + }, +} + +var poolSkipRemoteBuffer = sync.Pool{ + New: func() interface{} { + return &skipByteBuffer{} + }, +} + +// skipReader ... general skip reader for io.Reader +type skipReader struct { + r io.Reader + + p int + b []byte +} + +func newSkipReader(r io.Reader) *skipReader { + ret := poolSkipReader.Get().(*skipReader) + ret.Reset(r) + return ret +} + +func (p *skipReader) Release() { + poolSkipReader.Put(p) +} + +func (p *skipReader) Reset(r io.Reader) { + p.r = r + p.p = 0 +} + +func (p *skipReader) Bytes() ([]byte, error) { + ret := p.b[:p.p] + p.p = 0 + return ret, nil +} + +func (p *skipReader) grow(n int) { + // assert: len(p.b)-p.p < n + sz := 2 * cap(p.b) + if sz < p.p+n { + sz = p.p + n + } + b := make([]byte, sz) + copy(b, p.b[:p.p]) + p.b = b +} + +func (p *skipReader) Next(n int) (buf []byte, err error) { + if len(p.b)-p.p < n { + p.grow(n) + } + if _, err := io.ReadFull(p.r, p.b[p.p:p.p+n]); err != nil { + return nil, NewProtocolExceptionWithErr(err) + } + ret := p.b[p.p : p.p+n] + p.p += n + return ret, nil +} + +// remoteByteBuffer ... github.com/cloudwego/kitex/pkg/remote.ByteBuffer +type remoteByteBuffer interface { + Peek(n int) (buf []byte, err error) + ReadableLen() (n int) + Skip(n int) (err error) +} + +// skipByteBuffer ... optimized zero copy skipreader for remote.ByteBuffer +type skipByteBuffer struct { + p remoteByteBuffer + + r int + b []byte +} + +func newSkipByteBuffer(buf remoteByteBuffer) *skipByteBuffer { + ret := poolSkipRemoteBuffer.Get().(*skipByteBuffer) + ret.Reset(buf) + return ret +} + +func (p *skipByteBuffer) Release() { + poolSkipRemoteBuffer.Put(p) +} + +func (p *skipByteBuffer) Reset(buf remoteByteBuffer) { + p.r = 0 + p.b = nil + p.p = buf +} + +func (p *skipByteBuffer) Bytes() ([]byte, error) { + ret := p.b[:p.r] + if err := p.p.Skip(p.r); err != nil { + return nil, err + } + p.r = 0 + return ret, nil +} + +// Next ... +func (p *skipByteBuffer) Next(n int) (ret []byte, err error) { + if p.r+n < len(p.b) { // fast path + ret, p.r = p.b[p.r:p.r+n], p.r+n + return + } + return p.nextSlow(n) +} + +func (p *skipByteBuffer) nextSlow(n int) ([]byte, error) { + // trigger underlying conn to read more + _, err := p.p.Peek(p.r + n) + if err != nil { + return nil, err + } + // read as much as possible, luckily, we will have a full buffer + // then we no need to call p.Peek many times + p.b, err = p.p.Peek(p.p.ReadableLen()) + if err != nil { + return nil, err + } + // after calling p.loadmore, p.buf MUST be at least (p.r + n) len + ret := p.b[p.r : p.r+n] + p.r += n + return ret, nil +} diff --git a/protocol/thrift/skipreader_test.go b/protocol/thrift/skipreader_test.go new file mode 100644 index 0000000..c055391 --- /dev/null +++ b/protocol/thrift/skipreader_test.go @@ -0,0 +1,87 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSkipReader(t *testing.T) { + b := make([]byte, 2048) + for i := 0; i < len(b); i++ { + b[i] = byte(i) + } + + r := newSkipReader(bytes.NewReader(b)) + defer r.Release() + for i := 0; i < len(b); i++ { + b, err := r.Next(1) + require.NoError(t, err) + require.True(t, b[0] == byte(i)) + } + + retb, err := r.Bytes() + require.NoError(t, err) + require.Equal(t, b, retb) +} + +type remoteByteBufferImplForT struct { + p int + b []byte +} + +func (p *remoteByteBufferImplForT) Peek(n int) (buf []byte, err error) { + if n > len(p.b) { + return nil, io.EOF + } + return p.b[:n], nil +} + +func (p *remoteByteBufferImplForT) ReadableLen() int { + return len(p.b) - p.p +} + +func (p *remoteByteBufferImplForT) Skip(n int) error { + if n > len(p.b) { + panic("bug") + } + p.p += n + return nil +} + +func TestSkipRemoteBuffer(t *testing.T) { + b := make([]byte, 2048) + for i := 0; i < len(b); i++ { + b[i] = byte(i) + } + + r := newSkipByteBuffer(&remoteByteBufferImplForT{b: b}) + defer r.Release() + for i := 0; i < len(b); i++ { + b, err := r.Next(1) + require.NoError(t, err) + require.True(t, b[0] == byte(i)) + } + + retb, err := r.Bytes() + require.NoError(t, err) + require.Equal(t, b, retb) +} diff --git a/protocol/thrift/utils.go b/protocol/thrift/utils.go index 756af5f..dfd1d4a 100644 --- a/protocol/thrift/utils.go +++ b/protocol/thrift/utils.go @@ -30,6 +30,14 @@ func p2i32(p unsafe.Pointer) int32 { uint32(*(*byte)(p))<<24) } +type nextIface interface { + Next(n int) ([]byte, error) +} + +type discardIface interface { + Discard(n int) (int, error) +} + // nextReader provides a wrapper for io.Reader to use BinaryReader type nextReader struct { r io.Reader @@ -42,6 +50,18 @@ var poolNextReader = sync.Pool{ }, } +func newNextReader(r io.Reader) *nextReader { + ret := poolNextReader.Get().(*nextReader) + ret.Reset(r) + return ret +} + +// Release ... +func (r *nextReader) Release() { poolNextReader.Put(r) } + +// Reset ... for reusing nextReader +func (r *nextReader) Reset(rd io.Reader) { r.r = rd } + // Next implements nextIface of BinaryReader func (r *nextReader) Next(n int) ([]byte, error) { b := r.b[:] @@ -74,8 +94,3 @@ func (r *nextReader) Discard(n int) (int, error) { } return ret, nil } - -// Reset ... for reusing nextReader -func (r *nextReader) Reset(rd io.Reader) { - r.r = rd -}