diff --git a/protocol/thrift/binary.go b/protocol/thrift/binary.go index 06f13af..dbf07ed 100644 --- a/protocol/thrift/binary.go +++ b/protocol/thrift/binary.go @@ -306,21 +306,39 @@ func (BinaryProtocol) ReadMapBegin(buf []byte) (kt, vt TType, size, l int, err e if len(buf) < 6 { return 0, 0, 0, 0, errReadMap } - return TType(buf[0]), TType(buf[1]), int(binary.BigEndian.Uint32(buf[2:])), 6, nil + l = 6 + kt, vt = TType(buf[0]), TType(buf[1]) + size = int(int32(binary.BigEndian.Uint32(buf[2:]))) + if size < 0 { + err = errDataLength + } + return } func (BinaryProtocol) ReadListBegin(buf []byte) (et TType, size, l int, err error) { if len(buf) < 5 { return 0, 0, 0, errReadList } - return TType(buf[0]), int(binary.BigEndian.Uint32(buf[1:])), 5, nil + l = 5 + et = TType(buf[0]) + size = int(int32(binary.BigEndian.Uint32(buf[1:]))) + if size < 0 { + err = errDataLength + } + return } func (BinaryProtocol) ReadSetBegin(buf []byte) (et TType, size, l int, err error) { if len(buf) < 5 { return 0, 0, 0, errReadSet } - return TType(buf[0]), int(binary.BigEndian.Uint32(buf[1:])), 5, nil + l = 5 + et = TType(buf[0]) + size = int(int32(binary.BigEndian.Uint32(buf[1:]))) + if size < 0 { + err = errDataLength + } + return } func (p BinaryProtocol) ReadBinary(buf []byte) (b []byte, l int, err error) { @@ -329,7 +347,7 @@ func (p BinaryProtocol) ReadBinary(buf []byte) (b []byte, l int, err error) { return nil, 0, errReadBin } if sz < 0 { - return nil, 0, errNegativeSize + return nil, 0, errDataLength } l = 4 + int(sz) if len(buf) < l { @@ -349,7 +367,7 @@ func (p BinaryProtocol) ReadString(buf []byte) (s string, l int, err error) { return "", 0, errReadStr } if sz < 0 { - return "", 0, errNegativeSize + return "", 0, errDataLength } l = 4 + int(sz) if len(buf) < l { @@ -424,7 +442,7 @@ func skipstr(p unsafe.Pointer, e uintptr) (int, error) { if uintptr(p)+uintptr(4) <= e { n := int(p2i32(p)) if n < 0 { - return 0, errNegativeSize + return 0, errDataLength } if uintptr(p)+uintptr(4+n) <= e { return 4 + n, nil @@ -463,7 +481,7 @@ func skipType(p unsafe.Pointer, e uintptr, t TType, maxdepth int) (int, error) { } kt, vt, sz := TType(*(*byte)(p)), TType(*(*byte)(unsafe.Add(p, 1))), p2i32(unsafe.Add(p, 2)) if sz < 0 { - return 0, errNegativeSize + return 0, errDataLength } ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt]) if ksz > 0 && vsz > 0 { // fast path, fast skip @@ -513,7 +531,7 @@ func skipType(p unsafe.Pointer, e uintptr, t TType, maxdepth int) (int, error) { } vt, sz := TType(*(*byte)(p)), p2i32(unsafe.Add(p, 1)) if sz < 0 { - return 0, errNegativeSize + return 0, errDataLength } vsz := int(typeToSize[vt]) if vsz > 0 { // fast path, fast skip diff --git a/protocol/thrift/binary_test.go b/protocol/thrift/binary_test.go index 52016e8..1243658 100644 --- a/protocol/thrift/binary_test.go +++ b/protocol/thrift/binary_test.go @@ -303,6 +303,42 @@ func TestBinary(t *testing.T) { } } +func TestBinary_ErrDataLength(t *testing.T) { + x := BinaryProtocol{} + { // String + b := x.AppendI32([]byte(nil), -1) + _, _, err := x.ReadString(b) + require.Same(t, errDataLength, err) + } + + { // Binary + b := x.AppendI32([]byte(nil), -1) + _, _, err := x.ReadBinary(b) + require.Same(t, errDataLength, err) + } + + { // Map + testkt, testvt, testsize := I64, I32, -1 + b := x.AppendMapBegin([]byte(nil), testkt, testvt, testsize) + _, _, _, _, err := x.ReadMapBegin(b) + require.Same(t, errDataLength, err) + } + + { // List + testvt, testsize := I32, -1 + b := x.AppendListBegin([]byte(nil), testvt, testsize) + _, _, _, err := x.ReadListBegin(b) + require.Same(t, errDataLength, err) + } + + { // Set + testvt, testsize := I32, -1 + b := x.AppendSetBegin([]byte(nil), testvt, testsize) + _, _, _, err := x.ReadSetBegin(b) + require.Same(t, errDataLength, err) + } +} + func TestBinarySkip(t *testing.T) { // byte b := Binary.AppendByte([]byte(nil), 1) diff --git a/protocol/thrift/bufferreader.go b/protocol/thrift/bufferreader.go index 8915e54..ac187ea 100644 --- a/protocol/thrift/bufferreader.go +++ b/protocol/thrift/bufferreader.go @@ -69,7 +69,7 @@ func (r *BufferReader) readBinary(bs []byte) (n int, err error) { func (r *BufferReader) skipn(n int) (err error) { if n < 0 { - return errNegativeSize + return errDataLength } if err = r.r.Skip(n); err != nil { return NewProtocolExceptionWithErr(err) @@ -149,7 +149,7 @@ func (r *BufferReader) ReadBinary() (b []byte, err error) { return nil, err } if sz < 0 { - return nil, errNegativeSize + return nil, errDataLength } b = dirtmake.Bytes(int(sz), int(sz)) _, err = r.readBinary(b) @@ -270,7 +270,7 @@ func (r *BufferReader) skipType(t TType, maxdepth int) error { return err } if sz < 0 { - return errNegativeSize + return errDataLength } ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt]) if ksz > 0 && vsz > 0 { @@ -305,7 +305,7 @@ func (r *BufferReader) skipType(t TType, maxdepth int) error { return err } if sz < 0 { - return errNegativeSize + return errDataLength } if vsz := typeToSize[vt]; vsz > 0 { return r.skipn(sz * int(vsz)) diff --git a/protocol/thrift/exception.go b/protocol/thrift/exception.go index 05e7fb5..331e371 100644 --- a/protocol/thrift/exception.go +++ b/protocol/thrift/exception.go @@ -170,7 +170,7 @@ const ( // ProtocolException codes from apache thrift var ( errBufferTooShort = NewProtocolException(INVALID_DATA, "buffer too short") - errNegativeSize = NewProtocolException(NEGATIVE_SIZE, "negative size") + errDataLength = NewProtocolException(INVALID_DATA, "invalid data length") ) // NewTransportExceptionWithType diff --git a/protocol/thrift/skipdecoder_tpl.go b/protocol/thrift/skipdecoder_tpl.go index 00b19d8..d23b4f3 100644 --- a/protocol/thrift/skipdecoder_tpl.go +++ b/protocol/thrift/skipdecoder_tpl.go @@ -61,7 +61,7 @@ func (p SkipDecoderTpl[T]) Skip(t TType, maxdepth int) error { } sz := int(binary.BigEndian.Uint32(b)) if sz < 0 { - return errNegativeSize + return errDataLength } if _, err := p.r.SkipN(sz); err != nil { return err @@ -90,7 +90,7 @@ func (p SkipDecoderTpl[T]) Skip(t TType, maxdepth int) error { } kt, vt, sz := TType(b[0]), TType(b[1]), int32(binary.BigEndian.Uint32(b[2:])) if sz < 0 { - return errNegativeSize + return errDataLength } ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt]) if ksz > 0 && vsz > 0 { @@ -112,7 +112,7 @@ func (p SkipDecoderTpl[T]) Skip(t TType, maxdepth int) error { } vt, sz := TType(b[0]), int32(binary.BigEndian.Uint32(b[1:])) if sz < 0 { - return errNegativeSize + return errDataLength } if vsz := typeToSize[vt]; vsz > 0 { _, err := p.r.SkipN(int(sz) * int(vsz))