Skip to content

Commit

Permalink
fix(thrift): ReadXXX return errDataLength if sz < 0 (#36)
Browse files Browse the repository at this point in the history
align with previous implementation of kitex/pkg/protocol/bthrift
  • Loading branch information
xiaost authored Dec 17, 2024
1 parent 22433c9 commit 3172135
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 16 deletions.
34 changes: 26 additions & 8 deletions protocol/thrift/binary.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,21 +306,39 @@ func (BinaryProtocol) ReadMapBegin(buf []byte) (kt, vt TType, size, l int, err e
if len(buf) < 6 {
return 0, 0, 0, 0, errReadMap
}
return TType(buf[0]), TType(buf[1]), int(binary.BigEndian.Uint32(buf[2:])), 6, nil
l = 6
kt, vt = TType(buf[0]), TType(buf[1])
size = int(int32(binary.BigEndian.Uint32(buf[2:])))
if size < 0 {
err = errDataLength
}
return
}

func (BinaryProtocol) ReadListBegin(buf []byte) (et TType, size, l int, err error) {
if len(buf) < 5 {
return 0, 0, 0, errReadList
}
return TType(buf[0]), int(binary.BigEndian.Uint32(buf[1:])), 5, nil
l = 5
et = TType(buf[0])
size = int(int32(binary.BigEndian.Uint32(buf[1:])))
if size < 0 {
err = errDataLength
}
return
}

func (BinaryProtocol) ReadSetBegin(buf []byte) (et TType, size, l int, err error) {
if len(buf) < 5 {
return 0, 0, 0, errReadSet
}
return TType(buf[0]), int(binary.BigEndian.Uint32(buf[1:])), 5, nil
l = 5
et = TType(buf[0])
size = int(int32(binary.BigEndian.Uint32(buf[1:])))
if size < 0 {
err = errDataLength
}
return
}

func (p BinaryProtocol) ReadBinary(buf []byte) (b []byte, l int, err error) {
Expand All @@ -329,7 +347,7 @@ func (p BinaryProtocol) ReadBinary(buf []byte) (b []byte, l int, err error) {
return nil, 0, errReadBin
}
if sz < 0 {
return nil, 0, errNegativeSize
return nil, 0, errDataLength
}
l = 4 + int(sz)
if len(buf) < l {
Expand All @@ -349,7 +367,7 @@ func (p BinaryProtocol) ReadString(buf []byte) (s string, l int, err error) {
return "", 0, errReadStr
}
if sz < 0 {
return "", 0, errNegativeSize
return "", 0, errDataLength
}
l = 4 + int(sz)
if len(buf) < l {
Expand Down Expand Up @@ -424,7 +442,7 @@ func skipstr(p unsafe.Pointer, e uintptr) (int, error) {
if uintptr(p)+uintptr(4) <= e {
n := int(p2i32(p))
if n < 0 {
return 0, errNegativeSize
return 0, errDataLength
}
if uintptr(p)+uintptr(4+n) <= e {
return 4 + n, nil
Expand Down Expand Up @@ -463,7 +481,7 @@ func skipType(p unsafe.Pointer, e uintptr, t TType, maxdepth int) (int, error) {
}
kt, vt, sz := TType(*(*byte)(p)), TType(*(*byte)(unsafe.Add(p, 1))), p2i32(unsafe.Add(p, 2))
if sz < 0 {
return 0, errNegativeSize
return 0, errDataLength
}
ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt])
if ksz > 0 && vsz > 0 { // fast path, fast skip
Expand Down Expand Up @@ -513,7 +531,7 @@ func skipType(p unsafe.Pointer, e uintptr, t TType, maxdepth int) (int, error) {
}
vt, sz := TType(*(*byte)(p)), p2i32(unsafe.Add(p, 1))
if sz < 0 {
return 0, errNegativeSize
return 0, errDataLength
}
vsz := int(typeToSize[vt])
if vsz > 0 { // fast path, fast skip
Expand Down
36 changes: 36 additions & 0 deletions protocol/thrift/binary_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,42 @@ func TestBinary(t *testing.T) {
}
}

func TestBinary_ErrDataLength(t *testing.T) {
x := BinaryProtocol{}
{ // String
b := x.AppendI32([]byte(nil), -1)
_, _, err := x.ReadString(b)
require.Same(t, errDataLength, err)
}

{ // Binary
b := x.AppendI32([]byte(nil), -1)
_, _, err := x.ReadBinary(b)
require.Same(t, errDataLength, err)
}

{ // Map
testkt, testvt, testsize := I64, I32, -1
b := x.AppendMapBegin([]byte(nil), testkt, testvt, testsize)
_, _, _, _, err := x.ReadMapBegin(b)
require.Same(t, errDataLength, err)
}

{ // List
testvt, testsize := I32, -1
b := x.AppendListBegin([]byte(nil), testvt, testsize)
_, _, _, err := x.ReadListBegin(b)
require.Same(t, errDataLength, err)
}

{ // Set
testvt, testsize := I32, -1
b := x.AppendSetBegin([]byte(nil), testvt, testsize)
_, _, _, err := x.ReadSetBegin(b)
require.Same(t, errDataLength, err)
}
}

func TestBinarySkip(t *testing.T) {
// byte
b := Binary.AppendByte([]byte(nil), 1)
Expand Down
8 changes: 4 additions & 4 deletions protocol/thrift/bufferreader.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func (r *BufferReader) readBinary(bs []byte) (n int, err error) {

func (r *BufferReader) skipn(n int) (err error) {
if n < 0 {
return errNegativeSize
return errDataLength
}
if err = r.r.Skip(n); err != nil {
return NewProtocolExceptionWithErr(err)
Expand Down Expand Up @@ -149,7 +149,7 @@ func (r *BufferReader) ReadBinary() (b []byte, err error) {
return nil, err
}
if sz < 0 {
return nil, errNegativeSize
return nil, errDataLength
}
b = dirtmake.Bytes(int(sz), int(sz))
_, err = r.readBinary(b)
Expand Down Expand Up @@ -270,7 +270,7 @@ func (r *BufferReader) skipType(t TType, maxdepth int) error {
return err
}
if sz < 0 {
return errNegativeSize
return errDataLength
}
ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt])
if ksz > 0 && vsz > 0 {
Expand Down Expand Up @@ -305,7 +305,7 @@ func (r *BufferReader) skipType(t TType, maxdepth int) error {
return err
}
if sz < 0 {
return errNegativeSize
return errDataLength
}
if vsz := typeToSize[vt]; vsz > 0 {
return r.skipn(sz * int(vsz))
Expand Down
2 changes: 1 addition & 1 deletion protocol/thrift/exception.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ const ( // ProtocolException codes from apache thrift

var (
errBufferTooShort = NewProtocolException(INVALID_DATA, "buffer too short")
errNegativeSize = NewProtocolException(NEGATIVE_SIZE, "negative size")
errDataLength = NewProtocolException(INVALID_DATA, "invalid data length")
)

// NewTransportExceptionWithType
Expand Down
6 changes: 3 additions & 3 deletions protocol/thrift/skipdecoder_tpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (p SkipDecoderTpl[T]) Skip(t TType, maxdepth int) error {
}
sz := int(binary.BigEndian.Uint32(b))
if sz < 0 {
return errNegativeSize
return errDataLength
}
if _, err := p.r.SkipN(sz); err != nil {
return err
Expand Down Expand Up @@ -90,7 +90,7 @@ func (p SkipDecoderTpl[T]) Skip(t TType, maxdepth int) error {
}
kt, vt, sz := TType(b[0]), TType(b[1]), int32(binary.BigEndian.Uint32(b[2:]))
if sz < 0 {
return errNegativeSize
return errDataLength
}
ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt])
if ksz > 0 && vsz > 0 {
Expand All @@ -112,7 +112,7 @@ func (p SkipDecoderTpl[T]) Skip(t TType, maxdepth int) error {
}
vt, sz := TType(b[0]), int32(binary.BigEndian.Uint32(b[1:]))
if sz < 0 {
return errNegativeSize
return errDataLength
}
if vsz := typeToSize[vt]; vsz > 0 {
_, err := p.r.SkipN(int(sz) * int(vsz))
Expand Down

0 comments on commit 3172135

Please sign in to comment.