Skip to content

Commit

Permalink
feat(thrift): skipdecoder
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaost committed Jul 22, 2024
1 parent aad015e commit d1e1922
Show file tree
Hide file tree
Showing 6 changed files with 580 additions and 16 deletions.
13 changes: 2 additions & 11 deletions protocol/thrift/binaryreader.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,6 @@ import (
"sync"
)

type nextIface interface {
Next(n int) ([]byte, error)
}

type discardIface interface {
Discard(n int) (int, error)
}

// BinaryReader represents a reader for binary protocol
type BinaryReader struct {
r nextIface
Expand All @@ -53,8 +45,7 @@ func NewBinaryReader(r io.Reader) *BinaryReader {
if nextr, ok := r.(nextIface); ok {
ret.r = nextr
} else {
nextr := poolNextReader.Get().(*nextReader)
nextr.Reset(r)
nextr := newNextReader(r)
ret.r = nextr
ret.d = nextr
}
Expand All @@ -65,7 +56,7 @@ func NewBinaryReader(r io.Reader) *BinaryReader {
func (r *BinaryReader) Release() {
nextr, ok := r.r.(*nextReader)
if ok {
poolNextReader.Put(nextr)
nextr.Release()
}
r.reset()
poolBinaryReader.Put(r)
Expand Down
153 changes: 153 additions & 0 deletions protocol/thrift/skipdecoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package thrift

import (
"encoding/binary"
"fmt"
"io"
"sync"
)

// SkipDecoder scans the underlying io.Reader and returns the bytes of a type
type SkipDecoder struct {
p skipReaderIface
}

var poolSkipDecoder = sync.Pool{
New: func() interface{} {
return &SkipDecoder{}
},
}

// NewSkipDecoder ... call Release if no longer use
func NewSkipDecoder(r io.Reader) *SkipDecoder {
p := poolSkipDecoder.Get().(*SkipDecoder)
p.Reset(r)
return p
}

// Reset ...
func (p *SkipDecoder) Reset(r io.Reader) {
if p.p != nil {
p.p.Release()
}
if buf, ok := r.(remoteByteBuffer); ok {
p.p = newSkipByteBuffer(buf)
} else {
p.p = newSkipReader(r)
}
}

// Release ...
func (p *SkipDecoder) Release() {
p.p.Release()
p.p = nil
poolSkipDecoder.Put(p)
}

// Next skips a specific type and returns its bytes
func (p *SkipDecoder) Next(t TType) (buf []byte, err error) {
if err := p.skip(t, defaultRecursionDepth); err != nil {
return nil, err
}
return p.p.Bytes()
}

func (p *SkipDecoder) skip(t TType, maxdepth int) error {
if maxdepth == 0 {
return errDepthLimitExceeded
}
if sz := typeToSize[t]; sz > 0 {
_, err := p.p.Next(int(sz))
return err
}
switch t {
case STRING:
b, err := p.p.Next(4)
if err != nil {
return err
}
sz := int(binary.BigEndian.Uint32(b))
if sz < 0 {
return errNegativeSize
}
if _, err := p.p.Next(sz); err != nil {
return err
}
case STRUCT:
for {
b, err := p.p.Next(1) // TType
if err != nil {
return err
}
tp := TType(b[0])
if tp == STOP {
break
}
if _, err := p.p.Next(2); err != nil { // Field ID
return err
}
if err := p.skip(tp, maxdepth-1); err != nil {
return err
}
}
case MAP:
b, err := p.p.Next(6) // 1 byte key TType, 1 byte value TType, 4 bytes Len
if err != nil {
return err
}
kt, vt, sz := TType(b[0]), TType(b[1]), int32(binary.BigEndian.Uint32(b[2:]))
if sz < 0 {
return errNegativeSize
}
ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt])
if ksz > 0 && vsz > 0 {
_, err := p.p.Next(int(sz) * (ksz + vsz))
return err
}
for i := int32(0); i < sz; i++ {
if err := p.skip(kt, maxdepth-1); err != nil {
return err
}
if err := p.skip(vt, maxdepth-1); err != nil {
return err
}
}
case SET, LIST:
b, err := p.p.Next(5) // 1 byte value type, 4 bytes Len
if err != nil {
return err
}
vt, sz := TType(b[0]), int32(binary.BigEndian.Uint32(b[1:]))
if sz < 0 {
return errNegativeSize
}
if vsz := typeToSize[vt]; vsz > 0 {
_, err := p.p.Next(int(sz) * int(vsz))
return err
}
for i := int32(0); i < sz; i++ {
if err := p.skip(vt, maxdepth-1); err != nil {
return err
}
}
default:
return NewProtocolException(INVALID_DATA, fmt.Sprintf("unknown data type %d", t))
}
return nil
}
157 changes: 157 additions & 0 deletions protocol/thrift/skipdecoder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package thrift

import (
"bytes"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestSkipDecoder(t *testing.T) {
x := BinaryProtocol{}
// byte
b := x.AppendByte([]byte(nil), 1)
sz0 := len(b)

// string
b = x.AppendString(b, strings.Repeat("hello", 500)) // larger than buffer
sz1 := len(b)

// list<i32>
b = x.AppendListBegin(b, I32, 1)
b = x.AppendI32(b, 1)
sz2 := len(b)

// list<string>
b = x.AppendListBegin(b, STRING, 1)
b = x.AppendString(b, "hello")
sz3 := len(b)

// list<list<i32>>
b = x.AppendListBegin(b, LIST, 1)
b = x.AppendListBegin(b, I32, 1)
b = x.AppendI32(b, 1)
sz4 := len(b)

// map<i32, i64>
b = x.AppendMapBegin(b, I32, I64, 1)
b = x.AppendI32(b, 1)
b = x.AppendI64(b, 2)
sz5 := len(b)

// map<i32, string>
b = x.AppendMapBegin(b, I32, STRING, 1)
b = x.AppendI32(b, 1)
b = x.AppendString(b, "hello")
sz6 := len(b)

// map<string, i64>
b = x.AppendMapBegin(b, STRING, I64, 1)
b = x.AppendString(b, "hello")
b = x.AppendI64(b, 2)
sz7 := len(b)

// map<i32, list<i32>>
b = x.AppendMapBegin(b, I32, LIST, 1)
b = x.AppendI32(b, 1)
b = x.AppendListBegin(b, I32, 1)
b = x.AppendI32(b, 1)
sz8 := len(b)

// map<list<i32>, i32>
b = x.AppendMapBegin(b, LIST, I32, 1)
b = x.AppendListBegin(b, I32, 1)
b = x.AppendI32(b, 1)
b = x.AppendI32(b, 1)
sz9 := len(b)

// struct i32, list<i32>
b = x.AppendFieldBegin(b, I32, 1)
b = x.AppendI32(b, 1)
b = x.AppendFieldBegin(b, LIST, 1)
b = x.AppendListBegin(b, I32, 1)
b = x.AppendI32(b, 1)
b = x.AppendFieldStop(b)
sz10 := len(b)

r := NewSkipDecoder(bytes.NewReader(b))
defer r.Release()

readn := 0
b, err := r.Next(BYTE) // byte
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz0, readn)
b, err = r.Next(STRING) // string
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz1, readn)
b, err = r.Next(LIST) // list<i32>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz2, readn)
b, err = r.Next(LIST) // list<string>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz3, readn)
b, err = r.Next(LIST) // list<list<i32>>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz4, readn)
b, err = r.Next(MAP) // map<i32, i64>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz5, readn)
b, err = r.Next(MAP) // map<i32, string>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz6, readn)
b, err = r.Next(MAP) // map<string, i64>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz7, readn)
b, err = r.Next(MAP) // map<i32, list<i32>>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz8, readn)
b, err = r.Next(MAP) // map<list<i32>, i32>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz9, readn)
b, err = r.Next(STRUCT) // struct i32, list<i32>
require.NoError(t, err)
readn += len(b)
require.Equal(t, sz10, readn)

{ // other cases
// errDepthLimitExceeded
b = b[:0]
for i := 0; i < defaultRecursionDepth+1; i++ {
b = x.AppendFieldBegin(b, STRUCT, 1)
}
r := NewSkipDecoder(bytes.NewReader(b))
_, err := r.Next(STRUCT)
require.Same(t, errDepthLimitExceeded, err)

// unknown type
_, err = r.Next(TType(122))
require.Error(t, err)
}
}
Loading

0 comments on commit d1e1922

Please sign in to comment.