diff --git a/internal/unsafe/unsafe.go b/internal/hack/hack.go similarity index 98% rename from internal/unsafe/unsafe.go rename to internal/hack/hack.go index 4006280..c219209 100644 --- a/internal/unsafe/unsafe.go +++ b/internal/hack/hack.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package unsafe +package hack import "unsafe" diff --git a/internal/unsafe/unsafe_test.go b/internal/hack/hack_test.go similarity index 98% rename from internal/unsafe/unsafe_test.go rename to internal/hack/hack_test.go index 92c2aa0..749dd2f 100644 --- a/internal/unsafe/unsafe_test.go +++ b/internal/hack/hack_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package unsafe +package hack import ( "testing" diff --git a/protocol/thrift/binary.go b/protocol/thrift/binary.go index 7645ba7..b4648ab 100644 --- a/protocol/thrift/binary.go +++ b/protocol/thrift/binary.go @@ -20,15 +20,16 @@ import ( "encoding/binary" "fmt" "math" + "unsafe" - "github.com/cloudwego/gopkg/internal/unsafe" + "github.com/cloudwego/gopkg/internal/hack" ) -var Binary binaryProtocol +var Binary BinaryProtocol -type binaryProtocol struct{} +type BinaryProtocol struct{} -func (binaryProtocol) WriteMessageBegin(buf []byte, name string, typeID TMessageType, seq int32) int { +func (BinaryProtocol) WriteMessageBegin(buf []byte, name string, typeID TMessageType, seq int32) int { binary.BigEndian.PutUint32(buf, uint32(msgVersion1)|uint32(typeID&msgTypeMask)) binary.BigEndian.PutUint32(buf[4:], uint32(len(name))) off := 8 + copy(buf[8:], name) @@ -36,37 +37,37 @@ func (binaryProtocol) WriteMessageBegin(buf []byte, name string, typeID TMessage return off + 4 } -func (binaryProtocol) WriteFieldBegin(buf []byte, typeID TType, id int16) int { +func (BinaryProtocol) WriteFieldBegin(buf []byte, typeID TType, id int16) int { buf[0] = byte(typeID) binary.BigEndian.PutUint16(buf[1:], uint16(id)) return 3 } -func (binaryProtocol) WriteFieldStop(buf []byte) int { +func (BinaryProtocol) WriteFieldStop(buf []byte) int { buf[0] = byte(STOP) return 1 } -func (binaryProtocol) WriteMapBegin(buf []byte, kt, vt TType, size int) int { +func (BinaryProtocol) WriteMapBegin(buf []byte, kt, vt TType, size int) int { buf[0] = byte(kt) buf[1] = byte(vt) binary.BigEndian.PutUint32(buf[2:], uint32(size)) return 6 } -func (binaryProtocol) WriteListBegin(buf []byte, et TType, size int) int { +func (BinaryProtocol) WriteListBegin(buf []byte, et TType, size int) int { buf[0] = byte(et) binary.BigEndian.PutUint32(buf[1:], uint32(size)) return 5 } -func (binaryProtocol) WriteSetBegin(buf []byte, et TType, size int) int { +func (BinaryProtocol) WriteSetBegin(buf []byte, et TType, size int) int { buf[0] = byte(et) binary.BigEndian.PutUint32(buf[1:], uint32(size)) return 5 } -func (binaryProtocol) WriteBool(buf []byte, v bool) int { +func (BinaryProtocol) WriteBool(buf []byte, v bool) int { if v { buf[0] = 1 } else { @@ -75,91 +76,91 @@ func (binaryProtocol) WriteBool(buf []byte, v bool) int { return 1 } -func (binaryProtocol) WriteByte(buf []byte, v int8) int { +func (BinaryProtocol) WriteByte(buf []byte, v int8) int { buf[0] = byte(v) return 1 } -func (binaryProtocol) WriteI16(buf []byte, v int16) int { +func (BinaryProtocol) WriteI16(buf []byte, v int16) int { binary.BigEndian.PutUint16(buf, uint16(v)) return 2 } -func (binaryProtocol) WriteI32(buf []byte, v int32) int { +func (BinaryProtocol) WriteI32(buf []byte, v int32) int { binary.BigEndian.PutUint32(buf, uint32(v)) return 4 } -func (binaryProtocol) WriteI64(buf []byte, v int64) int { +func (BinaryProtocol) WriteI64(buf []byte, v int64) int { binary.BigEndian.PutUint64(buf, uint64(v)) return 8 } -func (binaryProtocol) WriteDouble(buf []byte, v float64) int { +func (BinaryProtocol) WriteDouble(buf []byte, v float64) int { binary.BigEndian.PutUint64(buf, math.Float64bits(v)) return 8 } -func (binaryProtocol) WriteBinary(buf, v []byte) int { +func (BinaryProtocol) WriteBinary(buf, v []byte) int { binary.BigEndian.PutUint32(buf, uint32(len(v))) return 4 + copy(buf[4:], v) } -func (binaryProtocol) WriteBinaryNocopy(buf []byte, w NocopyWriter, v []byte) int { +func (p BinaryProtocol) WriteBinaryNocopy(buf []byte, w NocopyWriter, v []byte) int { if w == nil || len(buf) < NocopyWriteThreshold { - return Binary.WriteBinary(buf, v) + return p.WriteBinary(buf, v) } binary.BigEndian.PutUint32(buf, uint32(len(v))) _ = w.WriteDirect(v, len(buf[4:])) // always err == nil ? return 4 } -func (binaryProtocol) WriteString(buf []byte, v string) int { +func (BinaryProtocol) WriteString(buf []byte, v string) int { binary.BigEndian.PutUint32(buf, uint32(len(v))) return 4 + copy(buf[4:], v) } -func (binaryProtocol) WriteStringNocopy(buf []byte, w NocopyWriter, v string) int { - return Binary.WriteBinaryNocopy(buf, w, unsafe.StringToByteSlice(v)) +func (p BinaryProtocol) WriteStringNocopy(buf []byte, w NocopyWriter, v string) int { + return p.WriteBinaryNocopy(buf, w, hack.StringToByteSlice(v)) } // Append methods -func (binaryProtocol) AppendMessageBegin(buf []byte, name string, typeID TMessageType, seq int32) []byte { +func (p BinaryProtocol) AppendMessageBegin(buf []byte, name string, typeID TMessageType, seq int32) []byte { buf = appendUint32(buf, uint32(msgVersion1)|uint32(typeID&msgTypeMask)) - buf = Binary.AppendString(buf, name) - return Binary.AppendI32(buf, seq) + buf = p.AppendString(buf, name) + return p.AppendI32(buf, seq) } -func (binaryProtocol) AppendFieldBegin(buf []byte, typeID TType, id int16) []byte { +func (BinaryProtocol) AppendFieldBegin(buf []byte, typeID TType, id int16) []byte { return append(buf, byte(typeID), byte(uint16(id>>8)), byte(id)) } -func (binaryProtocol) AppendFieldStop(buf []byte) []byte { +func (BinaryProtocol) AppendFieldStop(buf []byte) []byte { return append(buf, byte(STOP)) } -func (binaryProtocol) AppendMapBegin(buf []byte, kt, vt TType, size int) []byte { - return Binary.AppendI32(append(buf, byte(kt), byte(vt)), int32(size)) +func (p BinaryProtocol) AppendMapBegin(buf []byte, kt, vt TType, size int) []byte { + return p.AppendI32(append(buf, byte(kt), byte(vt)), int32(size)) } -func (binaryProtocol) AppendListBegin(buf []byte, et TType, size int) []byte { - return Binary.AppendI32(append(buf, byte(et)), int32(size)) +func (p BinaryProtocol) AppendListBegin(buf []byte, et TType, size int) []byte { + return p.AppendI32(append(buf, byte(et)), int32(size)) } -func (binaryProtocol) AppendSetBegin(buf []byte, et TType, size int) []byte { - return Binary.AppendI32(append(buf, byte(et)), int32(size)) +func (p BinaryProtocol) AppendSetBegin(buf []byte, et TType, size int) []byte { + return p.AppendI32(append(buf, byte(et)), int32(size)) } -func (binaryProtocol) AppendBinary(buf, v []byte) []byte { - return append(Binary.AppendI32(buf, int32(len(v))), v...) +func (p BinaryProtocol) AppendBinary(buf, v []byte) []byte { + return append(p.AppendI32(buf, int32(len(v))), v...) } -func (binaryProtocol) AppendString(buf []byte, v string) []byte { - return append(Binary.AppendI32(buf, int32(len(v))), v...) +func (p BinaryProtocol) AppendString(buf []byte, v string) []byte { + return append(p.AppendI32(buf, int32(len(v))), v...) } -func (binaryProtocol) AppendBool(buf []byte, v bool) []byte { +func (BinaryProtocol) AppendBool(buf []byte, v bool) []byte { if v { return append(buf, 1) } else { @@ -167,23 +168,23 @@ func (binaryProtocol) AppendBool(buf []byte, v bool) []byte { } } -func (binaryProtocol) AppendByte(buf []byte, v int8) []byte { +func (BinaryProtocol) AppendByte(buf []byte, v int8) []byte { return append(buf, byte(v)) } -func (binaryProtocol) AppendI16(buf []byte, v int16) []byte { +func (BinaryProtocol) AppendI16(buf []byte, v int16) []byte { return append(buf, byte(uint16(v)>>8), byte(v)) } -func (binaryProtocol) AppendI32(buf []byte, v int32) []byte { +func (BinaryProtocol) AppendI32(buf []byte, v int32) []byte { return appendUint32(buf, uint32(v)) } -func (binaryProtocol) AppendI64(buf []byte, v int64) []byte { +func (BinaryProtocol) AppendI64(buf []byte, v int64) []byte { return appendUint64(buf, uint64(v)) } -func (binaryProtocol) AppendDouble(buf []byte, v float64) []byte { +func (BinaryProtocol) AppendDouble(buf []byte, v float64) []byte { return appendUint64(buf, math.Float64bits(v)) } @@ -198,25 +199,25 @@ func appendUint64(buf []byte, v uint64) []byte { // Length methods -func (binaryProtocol) MessageBeginLength(name string, _ TMessageType, _ int32) int { +func (BinaryProtocol) MessageBeginLength(name string, _ TMessageType, _ int32) int { return 4 + (4 + len(name)) + 4 } -func (binaryProtocol) FieldBeginLength() int { return 3 } -func (binaryProtocol) FieldStopLength() int { return 1 } -func (binaryProtocol) MapBeginLength() int { return 6 } -func (binaryProtocol) ListBeginLength() int { return 5 } -func (binaryProtocol) SetBeginLength() int { return 5 } -func (binaryProtocol) BoolLength() int { return 1 } -func (binaryProtocol) ByteLength() int { return 1 } -func (binaryProtocol) I16Length() int { return 2 } -func (binaryProtocol) I32Length() int { return 4 } -func (binaryProtocol) I64Length() int { return 8 } -func (binaryProtocol) DoubleLength() int { return 8 } -func (binaryProtocol) StringLength(v string) int { return 4 + len(v) } -func (binaryProtocol) BinaryLength(v []byte) int { return 4 + len(v) } -func (binaryProtocol) StringLengthNocopy(v string) int { return 4 + len(v) } -func (binaryProtocol) BinaryLengthNocopy(v []byte) int { return 4 + len(v) } +func (BinaryProtocol) FieldBeginLength() int { return 3 } +func (BinaryProtocol) FieldStopLength() int { return 1 } +func (BinaryProtocol) MapBeginLength() int { return 6 } +func (BinaryProtocol) ListBeginLength() int { return 5 } +func (BinaryProtocol) SetBeginLength() int { return 5 } +func (BinaryProtocol) BoolLength() int { return 1 } +func (BinaryProtocol) ByteLength() int { return 1 } +func (BinaryProtocol) I16Length() int { return 2 } +func (BinaryProtocol) I32Length() int { return 4 } +func (BinaryProtocol) I64Length() int { return 8 } +func (BinaryProtocol) DoubleLength() int { return 8 } +func (BinaryProtocol) StringLength(v string) int { return 4 + len(v) } +func (BinaryProtocol) BinaryLength(v []byte) int { return 4 + len(v) } +func (BinaryProtocol) StringLengthNocopy(v string) int { return 4 + len(v) } +func (BinaryProtocol) BinaryLengthNocopy(v []byte) int { return 4 + len(v) } // Read methods @@ -225,7 +226,7 @@ var ( errBadVersion = NewProtocolException(BAD_VERSION, "ReadMessageBegin: bad version") ) -func (binaryProtocol) ReadMessageBegin(buf []byte) (name string, typeID TMessageType, seq int32, l int, err error) { +func (p BinaryProtocol) ReadMessageBegin(buf []byte) (name string, typeID TMessageType, seq int32, l int, err error) { if len(buf) < 4 { // version+type header + name header return "", 0, 0, 0, errReadMessage } @@ -240,14 +241,14 @@ func (binaryProtocol) ReadMessageBegin(buf []byte) (name string, typeID TMessage off := 4 // read method name - name, l, err1 := Binary.ReadString(buf[off:]) + name, l, err1 := p.ReadString(buf[off:]) if err1 != nil { return "", 0, 0, 0, errReadMessage } off += l // read seq - seq, l, err2 := Binary.ReadI32(buf[off:]) + seq, l, err2 := p.ReadI32(buf[off:]) if err2 != nil { return "", 0, 0, 0, errReadMessage } @@ -271,7 +272,7 @@ var ( errReadDouble = NewProtocolException(INVALID_DATA, "ReadDouble: len(buf) < 8") ) -func (binaryProtocol) ReadFieldBegin(buf []byte) (typeID TType, id int16, l int, err error) { +func (BinaryProtocol) ReadFieldBegin(buf []byte) (typeID TType, id int16, l int, err error) { if len(buf) < 1 { return 0, 0, 0, errReadField } @@ -285,29 +286,29 @@ func (binaryProtocol) ReadFieldBegin(buf []byte) (typeID TType, id int16, l int, return typeID, int16(binary.BigEndian.Uint16(buf[1:])), 3, nil } -func (binaryProtocol) ReadMapBegin(buf []byte) (kt, vt TType, size, l int, err error) { +func (BinaryProtocol) ReadMapBegin(buf []byte) (kt, vt TType, size, l int, err error) { if len(buf) < 6 { return 0, 0, 0, 0, errReadMap } return TType(buf[0]), TType(buf[1]), int(binary.BigEndian.Uint32(buf[2:])), 6, nil } -func (binaryProtocol) ReadListBegin(buf []byte) (et TType, size, l int, err error) { +func (BinaryProtocol) ReadListBegin(buf []byte) (et TType, size, l int, err error) { if len(buf) < 5 { return 0, 0, 0, errReadList } return TType(buf[0]), int(binary.BigEndian.Uint32(buf[1:])), 5, nil } -func (binaryProtocol) ReadSetBegin(buf []byte) (et TType, size, l int, err error) { +func (BinaryProtocol) ReadSetBegin(buf []byte) (et TType, size, l int, err error) { if len(buf) < 5 { return 0, 0, 0, errReadSet } return TType(buf[0]), int(binary.BigEndian.Uint32(buf[1:])), 5, nil } -func (binaryProtocol) ReadBinary(buf []byte) (b []byte, l int, err error) { - sz, _, err := Binary.ReadI32(buf) +func (p BinaryProtocol) ReadBinary(buf []byte) (b []byte, l int, err error) { + sz, _, err := p.ReadI32(buf) if err != nil { return nil, 0, errReadBin } @@ -319,8 +320,8 @@ func (binaryProtocol) ReadBinary(buf []byte) (b []byte, l int, err error) { return []byte(string(buf[4:l])), l, nil } -func (binaryProtocol) ReadString(buf []byte) (s string, l int, err error) { - sz, _, err := Binary.ReadI32(buf) +func (p BinaryProtocol) ReadString(buf []byte) (s string, l int, err error) { + sz, _, err := p.ReadI32(buf) if err != nil { return "", 0, errReadStr } @@ -332,7 +333,7 @@ func (binaryProtocol) ReadString(buf []byte) (s string, l int, err error) { return string(buf[4:l]), l, nil } -func (binaryProtocol) ReadBool(buf []byte) (v bool, l int, err error) { +func (BinaryProtocol) ReadBool(buf []byte) (v bool, l int, err error) { if len(buf) < 1 { return false, 0, errReadBool } @@ -342,45 +343,42 @@ func (binaryProtocol) ReadBool(buf []byte) (v bool, l int, err error) { return false, 1, nil } -func (binaryProtocol) ReadByte(buf []byte) (v int8, l int, err error) { +func (BinaryProtocol) ReadByte(buf []byte) (v int8, l int, err error) { if len(buf) < 1 { return 0, 0, errReadByte } return int8(buf[0]), 1, nil } -func (binaryProtocol) ReadI16(buf []byte) (v int16, l int, err error) { +func (BinaryProtocol) ReadI16(buf []byte) (v int16, l int, err error) { if len(buf) < 2 { return 0, 0, errReadI16 } return int16(binary.BigEndian.Uint16(buf)), 2, nil } -func (binaryProtocol) ReadI32(buf []byte) (v int32, l int, err error) { +func (BinaryProtocol) ReadI32(buf []byte) (v int32, l int, err error) { if len(buf) < 4 { return 0, 0, errReadI32 } return int32(binary.BigEndian.Uint32(buf)), 4, nil } -func (binaryProtocol) ReadI64(buf []byte) (v int64, l int, err error) { +func (BinaryProtocol) ReadI64(buf []byte) (v int64, l int, err error) { if len(buf) < 8 { return 0, 0, errReadI64 } return int64(binary.BigEndian.Uint64(buf)), 8, nil } -func (binaryProtocol) ReadDouble(buf []byte) (v float64, l int, err error) { +func (BinaryProtocol) ReadDouble(buf []byte) (v float64, l int, err error) { if len(buf) < 8 { return 0, 0, errReadDouble } return math.Float64frombits(binary.BigEndian.Uint64(buf)), 8, nil } -var ( - errDepthLimitExceeded = NewProtocolException(DEPTH_LIMIT, "depth limit exceeded") - errNegativeSize = NewProtocolException(NEGATIVE_SIZE, "negative size") -) +var errDepthLimitExceeded = NewProtocolException(DEPTH_LIMIT, "depth limit exceeded") var typeToSize = [256]int8{ BOOL: 1, @@ -391,91 +389,155 @@ var typeToSize = [256]int8{ I64: 8, } -func skipstr(b []byte) int { - return 4 + int(binary.BigEndian.Uint32(b)) +func skipstr(p unsafe.Pointer, e uintptr) (int, error) { + if uintptr(p)+uintptr(4) <= e { + n := int(p2i32(p)) + if n < 0 { + return 0, errNegativeSize + } + if uintptr(p)+uintptr(4+n) <= e { + return 4 + n, nil + } + } + return 0, errBufferTooShort } // Skip skips over the value for the given type using Go implementation. -func (binaryProtocol) Skip(b []byte, t TType) (int, error) { - return skipType(b, t, defaultRecursionDepth) +func (BinaryProtocol) Skip(b []byte, t TType) (int, error) { + if len(b) == 0 { + return 0, errBufferTooShort + } + p := unsafe.Pointer(&b[0]) + e := uintptr(p) + uintptr(len(b)) + return skipType(p, e, t, defaultRecursionDepth) } -func skipType(b []byte, t TType, maxdepth int) (int, error) { +func skipType(p unsafe.Pointer, e uintptr, t TType, maxdepth int) (int, error) { if maxdepth == 0 { return 0, errDepthLimitExceeded } if n := typeToSize[t]; n > 0 { + if uintptr(p)+uintptr(n) > e { + return 0, errBufferTooShort + } return int(n), nil } + var err error switch t { case STRING: - return skipstr(b), nil + return skipstr(p, e) case MAP: - i := 6 - kt, vt, sz := TType(b[0]), TType(b[1]), int32(binary.BigEndian.Uint32(b[2:])) + if uintptr(p)+uintptr(6) > e { + return 0, errBufferTooShort + } + kt, vt, sz := TType(*(*byte)(p)), TType(*(*byte)(unsafe.Add(p, 1))), p2i32(unsafe.Add(p, 2)) if sz < 0 { return 0, errNegativeSize } ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt]) - if ksz > 0 && vsz > 0 { - return i + (int(sz) * (ksz + vsz)), nil + if ksz > 0 && vsz > 0 { // fast path, fast skip + mapkvsize := (int(sz) * (ksz + vsz)) + if uintptr(p)+uintptr(6+mapkvsize) > e { + return 0, errBufferTooShort + } + return 6 + mapkvsize, nil } + i := 6 for j := int32(0); j < sz; j++ { + if uintptr(p)+uintptr(i) >= e { + return 0, errBufferTooShort + } + ki := 0 if ksz > 0 { - i += ksz + ki = ksz } else if kt == STRING { - i += skipstr(b[i:]) - } else if n, err := skipType(b[i:], kt, maxdepth-1); err != nil { - return i, err + ki, err = skipstr(unsafe.Add(p, i), e) } else { - i += n + ki, err = skipType(unsafe.Add(p, i), e, kt, maxdepth-1) } + if err != nil { + return i, err + } + i += ki + if uintptr(p)+uintptr(i) >= e { + return 0, errBufferTooShort + } + vi := 0 if vsz > 0 { - i += vsz + vi = vsz } else if vt == STRING { - i += skipstr(b[i:]) - } else if n, err := skipType(b[i:], vt, maxdepth-1); err != nil { - return i, err + vi, err = skipstr(unsafe.Add(p, i), e) } else { - i += n + vi, err = skipType(unsafe.Add(p, i), e, vt, maxdepth-1) } + if err != nil { + return i, err + } + i += vi } return i, nil case LIST, SET: - i := 5 - vt, sz := TType(b[0]), int32(binary.BigEndian.Uint32(b[1:])) + if uintptr(p)+uintptr(5) > e { + return 0, errBufferTooShort + } + vt, sz := TType(*(*byte)(p)), p2i32(unsafe.Add(p, 1)) if sz < 0 { return 0, errNegativeSize } - if typeToSize[vt] > 0 { - return i + int(sz)*int(typeToSize[vt]), nil + vsz := int(typeToSize[vt]) + if vsz > 0 { // fast path, fast skip + listvsize := int(sz) * vsz + if uintptr(p)+uintptr(5+listvsize) > e { + return 0, errBufferTooShort + } + return 5 + listvsize, nil } + i := 5 for j := int32(0); j < sz; j++ { - if vt == STRING { - i += skipstr(b[i:]) - } else if n, err := skipType(b[i:], vt, maxdepth-1); err != nil { - return i, err + if uintptr(p)+uintptr(i) >= e { + return 0, errBufferTooShort + } + vi := 0 + if vsz > 0 { + vi = vsz + } else if vt == STRING { + vi, err = skipstr(unsafe.Add(p, i), e) } else { - i += n + vi, err = skipType(unsafe.Add(p, i), e, vt, maxdepth-1) + } + if err != nil { + return i, err } + i += vi } return i, nil case STRUCT: i := 0 for { - ft := TType(b[i]) + if uintptr(p)+uintptr(i) >= e { + return i, errBufferTooShort + } + ft := TType(*(*byte)(unsafe.Add(p, i))) i += 1 // TType if ft == STOP { return i, nil } i += 2 // Field ID + if uintptr(p)+uintptr(i) >= e { + return i, errBufferTooShort + } + fi := 0 if typeToSize[ft] > 0 { - i += int(typeToSize[ft]) - } else if n, err := skipType(b[i:], ft, maxdepth-1); err != nil { - return i, err + fi = int(typeToSize[ft]) + } else if ft == STRING { + fi, err = skipstr(unsafe.Add(p, i), e) } else { - i += n + fi, err = skipType(unsafe.Add(p, i), e, ft, maxdepth-1) + } + if err != nil { + return i, err } + i += fi } default: return 0, NewProtocolException(INVALID_DATA, fmt.Sprintf("unknown data type %d", t)) diff --git a/protocol/thrift/binaryreader.go b/protocol/thrift/binaryreader.go new file mode 100644 index 0000000..b0a77e0 --- /dev/null +++ b/protocol/thrift/binaryreader.go @@ -0,0 +1,375 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "encoding/binary" + "fmt" + "io" + "math" + "sync" +) + +type nextIface interface { + Next(n int) ([]byte, error) +} + +type discardIface interface { + Discard(n int) (int, error) +} + +// BinaryReader represents a reader for binary protocol +type BinaryReader struct { + r nextIface + d discardIface + + rn int64 +} + +var poolBinaryReader = sync.Pool{ + New: func() interface{} { + return &BinaryReader{} + }, +} + +// NewBinaryReader ... call Release if no longer use for reusing +func NewBinaryReader(r io.Reader) *BinaryReader { + ret := poolBinaryReader.Get().(*BinaryReader) + ret.reset() + if nextr, ok := r.(nextIface); ok { + ret.r = nextr + } else { + nextr := poolNextReader.Get().(*nextReader) + nextr.Reset(r) + ret.r = nextr + ret.d = nextr + } + return ret +} + +// Release ... +func (r *BinaryReader) Release() { + nextr, ok := r.r.(*nextReader) + if ok { + poolNextReader.Put(nextr) + } + r.reset() + poolBinaryReader.Put(r) +} + +func (r *BinaryReader) reset() { + r.r = nil + r.d = nil + r.rn = 0 +} + +func (r *BinaryReader) next(n int) (b []byte, err error) { + b, err = r.r.Next(n) + if err != nil { + err = NewProtocolExceptionWithErr(err) + } + r.rn += int64(len(b)) + return +} + +func (r *BinaryReader) skipn(n int) (err error) { + if n < 0 { + return errNegativeSize + } + if r.d != nil { + var sz int + sz, err = r.d.Discard(n) + r.rn += int64(sz) + } else { + var b []byte + b, err = r.r.Next(n) + r.rn += int64(len(b)) + } + if err != nil { + return NewProtocolExceptionWithErr(err) + } + return nil +} + +// Readn returns total bytes read from underlying reader +func (r *BinaryReader) Readn() int64 { + return r.rn +} + +// ReadBool ... +func (r *BinaryReader) ReadBool() (v bool, err error) { + b, err := r.next(1) + if err != nil { + return false, err + } + v = b[0] == 1 + return +} + +// ReadByte ... +func (r *BinaryReader) ReadByte() (v int8, err error) { + b, err := r.next(1) + if err != nil { + return 0, err + } + v = int8(b[0]) + return +} + +// ReadI16 ... +func (r *BinaryReader) ReadI16() (v int16, err error) { + b, err := r.next(2) + if err != nil { + return 0, err + } + v = int16(binary.BigEndian.Uint16(b)) + return +} + +// ReadI32 ... +func (r *BinaryReader) ReadI32() (v int32, err error) { + b, err := r.next(4) + if err != nil { + return 0, err + } + v = int32(binary.BigEndian.Uint32(b)) + return +} + +// ReadI64 ... +func (r *BinaryReader) ReadI64() (v int64, err error) { + b, err := r.next(8) + if err != nil { + return 0, err + } + v = int64(binary.BigEndian.Uint64(b)) + return +} + +// ReadDouble ... +func (r *BinaryReader) ReadDouble() (v float64, err error) { + b, err := r.next(8) + if err != nil { + return 0, err + } + v = math.Float64frombits(binary.BigEndian.Uint64(b)) + return +} + +// ReadBinary ... +func (r *BinaryReader) ReadBinary() (b []byte, err error) { + sz, err := r.ReadI32() + if err != nil { + return nil, err + } + b, err = r.next(int(sz)) + b = []byte(string(b)) // copy. use span cache? + return +} + +// ReadString ... +func (r *BinaryReader) ReadString() (s string, err error) { + sz, err := r.ReadI32() + if err != nil { + return "", err + } + b, err := r.next(int(sz)) + if err != nil { + return "", err + } + s = string(b) // copy. use span cache? + return +} + +// ReadMessageBegin ... +func (r *BinaryReader) ReadMessageBegin() (name string, typeID TMessageType, seq int32, err error) { + var header int32 + header, err = r.ReadI32() + if err != nil { + return + } + // read header for version and type + if uint32(header)&msgVersionMask != msgVersion1 { + err = errBadVersion + return + } + typeID = TMessageType(uint32(header) & msgTypeMask) + + // read method name + name, err = r.ReadString() + if err != nil { + return + } + + // read seq + seq, err = r.ReadI32() + if err != nil { + return + } + return +} + +// ReadFieldBegin ... +func (r *BinaryReader) ReadFieldBegin() (typeID TType, id int16, err error) { + b, err := r.next(1) + if err != nil { + return 0, 0, err + } + typeID = TType(b[0]) + if typeID == STOP { + return STOP, 0, nil + } + b, err = r.next(2) + if err != nil { + return 0, 0, err + } + id = int16(binary.BigEndian.Uint16(b)) + return +} + +// ReadMapBegin ... +func (r *BinaryReader) ReadMapBegin() (kt, vt TType, size int, err error) { + b, err := r.next(6) + if err != nil { + return 0, 0, 0, err + } + kt, vt, size = TType(b[0]), TType(b[1]), int(binary.BigEndian.Uint32(b[2:])) + return +} + +// ReadListBegin ... +func (r *BinaryReader) ReadListBegin() (et TType, size int, err error) { + b, err := r.next(5) + if err != nil { + return 0, 0, err + } + et, size = TType(b[0]), int(binary.BigEndian.Uint32(b[1:])) + return +} + +// ReadSetBegin ... +func (r *BinaryReader) ReadSetBegin() (et TType, size int, err error) { + b, err := r.next(5) + if err != nil { + return 0, 0, err + } + et, size = TType(b[0]), int(binary.BigEndian.Uint32(b[1:])) + return +} + +// Skip ... +func (r *BinaryReader) Skip(t TType) error { + return r.skipType(t, defaultRecursionDepth) +} + +func (r *BinaryReader) skipstr() error { + n, err := r.ReadI32() + if err != nil { + return err + } + return r.skipn(int(n)) +} + +func (r *BinaryReader) skipType(t TType, maxdepth int) error { + if maxdepth == 0 { + return errDepthLimitExceeded + } + if n := typeToSize[t]; n > 0 { + return r.skipn(int(n)) + } + switch t { + case STRING: + return r.skipstr() + case MAP: + kt, vt, sz, err := r.ReadMapBegin() + if err != nil { + return err + } + if sz < 0 { + return errNegativeSize + } + ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt]) + if ksz > 0 && vsz > 0 { + return r.skipn(int(sz) * (ksz + vsz)) + } + for j := 0; j < sz; j++ { + if ksz > 0 { + err = r.skipn(ksz) + } else if kt == STRING { + err = r.skipstr() + } else { + err = r.skipType(kt, maxdepth-1) + } + if err != nil { + return err + } + if vsz > 0 { + err = r.skipn(vsz) + } else if vt == STRING { + err = r.skipstr() + } else { + err = r.skipType(vt, maxdepth-1) + } + if err != nil { + return err + } + } + return nil + case LIST, SET: + vt, sz, err := r.ReadListBegin() + if err != nil { + return err + } + if sz < 0 { + return errNegativeSize + } + if vsz := typeToSize[vt]; vsz > 0 { + return r.skipn(sz * int(vsz)) + } + for j := 0; j < sz; j++ { + if vt == STRING { + err = r.skipstr() + } else { + err = r.skipType(vt, maxdepth-1) + } + if err != nil { + return err + } + } + return nil + case STRUCT: + for { + ft, _, err := r.ReadFieldBegin() + if err != nil { + return err + } + if ft == STOP { + return nil + } + if fsz := typeToSize[ft]; fsz > 0 { + err = r.skipn(int(fsz)) + } else { + err = r.skipType(ft, maxdepth-1) + } + if err != nil { + return err + } + } + default: + return NewProtocolException(INVALID_DATA, fmt.Sprintf("unknown data type %d", t)) + } +} diff --git a/protocol/thrift/binaryreader_test.go b/protocol/thrift/binaryreader_test.go new file mode 100644 index 0000000..accd165 --- /dev/null +++ b/protocol/thrift/binaryreader_test.go @@ -0,0 +1,260 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBinaryReader(t *testing.T) { + x := BinaryProtocol{} + b := x.AppendMessageBegin(nil, "hello", 1, 2) + sz0 := len(b) + b = x.AppendFieldBegin(b, 3, 4) + sz1 := len(b) + b = x.AppendFieldStop(b) + sz2 := len(b) + b = x.AppendMapBegin(b, 5, 6, 7) + sz3 := len(b) + b = x.AppendListBegin(b, 8, 9) + sz4 := len(b) + b = x.AppendSetBegin(b, 10, 11) + sz5 := len(b) + b = x.AppendBinary(b, []byte("12")) + sz6 := len(b) + b = x.AppendString(b, "13") + sz7 := len(b) + b = x.AppendBool(b, true) + b = x.AppendBool(b, false) + sz8 := len(b) + b = x.AppendByte(b, 14) + sz9 := len(b) + b = x.AppendI16(b, 15) + sz10 := len(b) + b = x.AppendI32(b, 16) + sz11 := len(b) + b = x.AppendI64(b, 17) + sz12 := len(b) + b = x.AppendDouble(b, 18.5) + sz13 := len(b) + + r := NewBinaryReader(bytes.NewReader(b)) + defer r.Release() + name, mt, seq, err := r.ReadMessageBegin() + require.NoError(t, err) + require.Equal(t, "hello", name) + require.Equal(t, TMessageType(1), mt) + require.Equal(t, int32(2), seq) + require.Equal(t, sz0, int(r.Readn())) + + ft, fid, err := r.ReadFieldBegin() + require.NoError(t, err) + require.Equal(t, TType(3), ft) + require.Equal(t, int16(4), fid) + require.Equal(t, sz1, int(r.Readn())) + + ft, fid, err = r.ReadFieldBegin() // for AppendFieldStop + require.NoError(t, err) + require.Equal(t, STOP, ft) + require.Equal(t, int16(0), fid) + require.Equal(t, sz2, int(r.Readn())) + + kt, vt, sz, err := r.ReadMapBegin() + require.NoError(t, err) + require.Equal(t, TType(5), kt) + require.Equal(t, TType(6), vt) + require.Equal(t, int(7), sz) + require.Equal(t, sz3, int(r.Readn())) + + et, sz, err := r.ReadListBegin() + require.NoError(t, err) + require.Equal(t, TType(8), et) + require.Equal(t, int(9), sz) + require.Equal(t, sz4, int(r.Readn())) + + et, sz, err = r.ReadSetBegin() + require.NoError(t, err) + require.Equal(t, TType(10), et) + require.Equal(t, int(11), sz) + require.Equal(t, sz5, int(r.Readn())) + + bin, err := r.ReadBinary() + require.NoError(t, err) + require.Equal(t, "12", string(bin)) + require.Equal(t, sz6, int(r.Readn())) + + s, err := r.ReadString() + require.NoError(t, err) + require.Equal(t, "13", s) + require.Equal(t, sz7, int(r.Readn())) + + vb, err := r.ReadBool() + require.NoError(t, err) + require.True(t, vb) + vb, err = r.ReadBool() + require.NoError(t, err) + require.False(t, vb) + require.Equal(t, sz8, int(r.Readn())) + + v8, err := r.ReadByte() + require.NoError(t, err) + require.Equal(t, int8(14), v8) + require.Equal(t, sz9, int(r.Readn())) + + v16, err := r.ReadI16() + require.NoError(t, err) + require.Equal(t, int16(15), v16) + require.Equal(t, sz10, int(r.Readn())) + + v32, err := r.ReadI32() + require.NoError(t, err) + require.Equal(t, int32(16), v32) + require.Equal(t, sz11, int(r.Readn())) + + v64, err := r.ReadI64() + require.NoError(t, err) + require.Equal(t, int64(17), v64) + require.Equal(t, sz12, int(r.Readn())) + + vf, err := r.ReadDouble() + require.NoError(t, err) + require.Equal(t, float64(18.5), vf) + require.Equal(t, sz13, int(r.Readn())) +} + +func TestBinaryReaderSkip(t *testing.T) { + x := BinaryProtocol{} + // byte + b := x.AppendByte([]byte(nil), 1) + sz0 := len(b) + + // string + b = x.AppendString(b, "hello") + sz1 := len(b) + + // list + b = x.AppendListBegin(b, I32, 1) + b = x.AppendI32(b, 1) + sz2 := len(b) + + // list + b = x.AppendListBegin(b, STRING, 1) + b = x.AppendString(b, "hello") + sz3 := len(b) + + // list> + b = x.AppendListBegin(b, LIST, 1) + b = x.AppendListBegin(b, I32, 1) + b = x.AppendI32(b, 1) + sz4 := len(b) + + // map + b = x.AppendMapBegin(b, I32, I64, 1) + b = x.AppendI32(b, 1) + b = x.AppendI64(b, 2) + sz5 := len(b) + + // map + b = x.AppendMapBegin(b, I32, STRING, 1) + b = x.AppendI32(b, 1) + b = x.AppendString(b, "hello") + sz6 := len(b) + + // map + b = x.AppendMapBegin(b, STRING, I64, 1) + b = x.AppendString(b, "hello") + b = x.AppendI64(b, 2) + sz7 := len(b) + + // map> + b = x.AppendMapBegin(b, I32, LIST, 1) + b = x.AppendI32(b, 1) + b = x.AppendListBegin(b, I32, 1) + b = x.AppendI32(b, 1) + sz8 := len(b) + + // map, i32> + b = x.AppendMapBegin(b, LIST, I32, 1) + b = x.AppendListBegin(b, I32, 1) + b = x.AppendI32(b, 1) + b = x.AppendI32(b, 1) + sz9 := len(b) + + // struct i32, list + b = x.AppendFieldBegin(b, I32, 1) + b = x.AppendI32(b, 1) + b = x.AppendFieldBegin(b, LIST, 1) + b = x.AppendListBegin(b, I32, 1) + b = x.AppendI32(b, 1) + b = x.AppendFieldStop(b) + sz10 := len(b) + + r := NewBinaryReader(bytes.NewReader(b)) + defer r.Release() + + err := r.Skip(BYTE) // byte + require.NoError(t, err) + require.Equal(t, int64(sz0), r.Readn()) + err = r.Skip(STRING) // string + require.NoError(t, err) + require.Equal(t, int64(sz1), r.Readn()) + err = r.Skip(LIST) // list + require.NoError(t, err) + require.Equal(t, int64(sz2), r.Readn()) + err = r.Skip(LIST) // list + require.NoError(t, err) + require.Equal(t, int64(sz3), r.Readn()) + err = r.Skip(LIST) // list> + require.NoError(t, err) + require.Equal(t, int64(sz4), r.Readn()) + err = r.Skip(MAP) // map + require.NoError(t, err) + require.Equal(t, int64(sz5), r.Readn()) + err = r.Skip(MAP) // map + require.NoError(t, err) + require.Equal(t, int64(sz6), r.Readn()) + err = r.Skip(MAP) // map + require.NoError(t, err) + require.Equal(t, int64(sz7), r.Readn()) + err = r.Skip(MAP) // map> + require.NoError(t, err) + require.Equal(t, int64(sz8), r.Readn()) + err = r.Skip(MAP) // map, i32> + require.NoError(t, err) + require.Equal(t, int64(sz9), r.Readn()) + err = r.Skip(STRUCT) // struct i32, list + require.NoError(t, err) + require.Equal(t, int64(sz10), r.Readn()) + + { // other cases + // errDepthLimitExceeded + b = b[:0] + for i := 0; i < defaultRecursionDepth+1; i++ { + b = x.AppendFieldBegin(b, STRUCT, 1) + } + r := NewBinaryReader(bytes.NewReader(b)) + err := r.Skip(STRUCT) + require.Same(t, errDepthLimitExceeded, err) + + // unknown type + err = r.Skip(TType(122)) + require.Error(t, err) + } +} diff --git a/protocol/thrift/binarywriter.go b/protocol/thrift/binarywriter.go new file mode 100644 index 0000000..d77636e --- /dev/null +++ b/protocol/thrift/binarywriter.go @@ -0,0 +1,120 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "math" + "sync" +) + +const defaultBinaryWriterBufferSize = 4096 + +type BinaryWriter struct { + buf []byte +} + +var poolBinaryWriter = sync.Pool{ + New: func() interface{} { + return &BinaryWriter{buf: make([]byte, 0, defaultBinaryWriterBufferSize)} + }, +} + +func NewBinaryWriter() *BinaryWriter { + return NewBinaryWriterSize(0) +} + +func NewBinaryWriterSize(sz int) *BinaryWriter { + w := poolBinaryWriter.Get().(*BinaryWriter) + if cap(w.buf) < sz { + w.Release() + w = &BinaryWriter{buf: make([]byte, 0, sz)} + } + w.Reset() + return w +} + +func (w *BinaryWriter) Release() { + poolBinaryWriter.Put(w) +} + +func (w *BinaryWriter) Reset() { + w.buf = w.buf[:0] +} + +func (w *BinaryWriter) Bytes() []byte { + return w.buf +} + +func (w *BinaryWriter) WriteMessageBegin(name string, typeID TMessageType, seq int32) { + w.buf = Binary.AppendMessageBegin(w.buf, name, typeID, seq) +} + +func (w *BinaryWriter) WriteFieldBegin(typeID TType, id int16) { + w.buf = Binary.AppendFieldBegin(w.buf, typeID, id) +} + +func (w *BinaryWriter) WriteFieldStop() { + w.buf = append(w.buf, byte(STOP)) +} + +func (w *BinaryWriter) WriteMapBegin(kt, vt TType, size int) { + w.buf = Binary.AppendMapBegin(w.buf, kt, vt, size) +} + +func (w *BinaryWriter) WriteListBegin(et TType, size int) { + w.buf = Binary.AppendListBegin(w.buf, et, size) +} + +func (w *BinaryWriter) WriteSetBegin(et TType, size int) { + w.buf = Binary.AppendSetBegin(w.buf, et, size) +} + +func (w *BinaryWriter) WriteBinary(v []byte) { + w.buf = Binary.AppendBinary(w.buf, v) +} + +func (w *BinaryWriter) WriteString(v string) { + w.buf = Binary.AppendString(w.buf, v) +} + +func (w *BinaryWriter) WriteBool(v bool) { + if v { + w.buf = append(w.buf, 1) + } else { + w.buf = append(w.buf, 0) + } +} + +func (w *BinaryWriter) WriteByte(v int8) { + w.buf = append(w.buf, byte(v)) +} + +func (w *BinaryWriter) WriteI16(v int16) { + w.buf = append(w.buf, byte(uint16(v)>>8), byte(v)) +} + +func (w *BinaryWriter) WriteI32(v int32) { + w.buf = appendUint32(w.buf, uint32(v)) +} + +func (w *BinaryWriter) WriteI64(v int64) { + w.buf = appendUint64(w.buf, uint64(v)) +} + +func (w *BinaryWriter) WriteDouble(v float64) { + w.buf = appendUint64(w.buf, math.Float64bits(v)) +} diff --git a/protocol/thrift/binarywriter_test.go b/protocol/thrift/binarywriter_test.go new file mode 100644 index 0000000..ee77817 --- /dev/null +++ b/protocol/thrift/binarywriter_test.go @@ -0,0 +1,86 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBinaryWriter(t *testing.T) { + w := NewBinaryWriterSize(defaultBinaryWriterBufferSize * 2) + x := BinaryProtocol{} + + b := x.AppendMessageBegin(nil, "hello", 1, 2) + w.WriteMessageBegin("hello", 1, 2) + require.Equal(t, b, w.Bytes()) + + b = x.AppendFieldBegin(b, 3, 4) + w.WriteFieldBegin(3, 4) + require.Equal(t, b, w.Bytes()) + + b = x.AppendFieldStop(b) + w.WriteFieldStop() + require.Equal(t, b, w.Bytes()) + + b = x.AppendMapBegin(b, 5, 6, 7) + w.WriteMapBegin(5, 6, 7) + require.Equal(t, b, w.Bytes()) + + b = x.AppendListBegin(b, 8, 9) + w.WriteListBegin(8, 9) + require.Equal(t, b, w.Bytes()) + + b = x.AppendSetBegin(b, 10, 11) + w.WriteSetBegin(10, 11) + require.Equal(t, b, w.Bytes()) + + b = x.AppendBinary(b, []byte("12")) + w.WriteBinary([]byte("12")) + require.Equal(t, b, w.Bytes()) + + b = x.AppendString(b, "13") + w.WriteString("13") + require.Equal(t, b, w.Bytes()) + + b = x.AppendBool(b, true) + b = x.AppendBool(b, false) + w.WriteBool(true) + w.WriteBool(false) + require.Equal(t, b, w.Bytes()) + + b = x.AppendByte(b, 14) + w.WriteByte(14) + require.Equal(t, b, w.Bytes()) + + b = x.AppendI16(b, 15) + w.WriteI16(15) + require.Equal(t, b, w.Bytes()) + + b = x.AppendI32(b, 16) + w.WriteI32(16) + require.Equal(t, b, w.Bytes()) + + b = x.AppendI64(b, 17) + w.WriteI64(17) + require.Equal(t, b, w.Bytes()) + + b = x.AppendDouble(b, 18.5) + w.WriteDouble(18.5) + require.Equal(t, b, w.Bytes()) +} diff --git a/protocol/thrift/exception.go b/protocol/thrift/exception.go index a1e069e..05e7fb5 100644 --- a/protocol/thrift/exception.go +++ b/protocol/thrift/exception.go @@ -154,6 +154,8 @@ func NewTransportException(t int32, m string) *TransportException { // it implements ThriftFastCodec interface. type ProtocolException struct { ApplicationException // same implementation ... + + err error } const ( // ProtocolException codes from apache thrift @@ -166,7 +168,12 @@ const ( // ProtocolException codes from apache thrift DEPTH_LIMIT = 6 ) -// NewTransportException ... +var ( + errBufferTooShort = NewProtocolException(INVALID_DATA, "buffer too short") + errNegativeSize = NewProtocolException(NEGATIVE_SIZE, "negative size") +) + +// NewTransportExceptionWithType func NewProtocolException(t int32, m string) *ProtocolException { ret := ProtocolException{} ret.t = t @@ -174,6 +181,29 @@ func NewProtocolException(t int32, m string) *ProtocolException { return &ret } +// NewProtocolException ... +func NewProtocolExceptionWithErr(err error) *ProtocolException { + e, ok := err.(*ProtocolException) + if ok { + return e + } + ret := NewProtocolException(UNKNOWN_PROTOCOL_EXCEPTION, err.Error()) + ret.err = err + return ret +} + +// Unwrap ... for errors pkg +func (e *ProtocolException) Unwrap() error { return e.err } + +// Is ... for errors pkg +func (e *ProtocolException) Is(err error) bool { + t, ok := err.(tException) + if ok && t.TypeId() == e.t && t.Error() == e.m { + return true + } + return errors.Is(e.err, err) +} + // Generic Thrift exception with TypeId method type tException interface { Error() string diff --git a/protocol/thrift/exception_test.go b/protocol/thrift/exception_test.go index 6bf98b7..134498f 100644 --- a/protocol/thrift/exception_test.go +++ b/protocol/thrift/exception_test.go @@ -18,6 +18,7 @@ package thrift import ( "errors" + "io" "testing" "github.com/stretchr/testify/assert" @@ -47,6 +48,12 @@ func TestApplicationException(t *testing.T) { t.Log(ex4.String()) // ... } +func TestProtocolException(t *testing.T) { + e := NewProtocolExceptionWithErr(io.EOF) + assert.ErrorIs(t, e, io.EOF) // will call errors.Is + assert.True(t, e.Is(NewProtocolException(UNKNOWN_PROTOCOL_EXCEPTION, "EOF"))) +} + type testTException struct{} func (testTException) Error() string { return "testTException" } diff --git a/protocol/thrift/utils.go b/protocol/thrift/utils.go new file mode 100644 index 0000000..756af5f --- /dev/null +++ b/protocol/thrift/utils.go @@ -0,0 +1,81 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "io" + "sync" + "unsafe" +) + +// p2i32, used by skipType which implements a fast skip with unsafe.Pointer without bounds check +func p2i32(p unsafe.Pointer) int32 { + return int32(uint32(*(*byte)(unsafe.Add(p, 3))) | + uint32(*(*byte)(unsafe.Add(p, 2)))<<8 | + uint32(*(*byte)(unsafe.Add(p, 1)))<<16 | + uint32(*(*byte)(p))<<24) +} + +// nextReader provides a wrapper for io.Reader to use BinaryReader +type nextReader struct { + r io.Reader + b [4096]byte +} + +var poolNextReader = sync.Pool{ + New: func() interface{} { + return &nextReader{} + }, +} + +// Next implements nextIface of BinaryReader +func (r *nextReader) Next(n int) ([]byte, error) { + b := r.b[:] + if n <= len(b) { + b = b[:n] + } else { + b = make([]byte, n) + } + _, err := io.ReadFull(r.r, b) + if err != nil { + return nil, err + } + return b, nil +} + +// Discard implements discardIface of BinaryReader +func (r *nextReader) Discard(n int) (int, error) { + ret := 0 + b := r.b[:] + for n > 0 { + if len(b) > n { + b = b[:n] + } + readn, err := r.r.Read(b) + ret += readn + if err != nil { + return ret, err + } + n -= readn + } + return ret, nil +} + +// Reset ... for reusing nextReader +func (r *nextReader) Reset(rd io.Reader) { + r.r = rd +}