diff --git a/protocol/thrift/binary.go b/protocol/thrift/binary.go index 7645ba7..5374143 100644 --- a/protocol/thrift/binary.go +++ b/protocol/thrift/binary.go @@ -24,11 +24,11 @@ import ( "github.com/cloudwego/gopkg/internal/unsafe" ) -var Binary binaryProtocol +var Binary BinaryProtocol -type binaryProtocol struct{} +type BinaryProtocol struct{} -func (binaryProtocol) WriteMessageBegin(buf []byte, name string, typeID TMessageType, seq int32) int { +func (BinaryProtocol) WriteMessageBegin(buf []byte, name string, typeID TMessageType, seq int32) int { binary.BigEndian.PutUint32(buf, uint32(msgVersion1)|uint32(typeID&msgTypeMask)) binary.BigEndian.PutUint32(buf[4:], uint32(len(name))) off := 8 + copy(buf[8:], name) @@ -36,37 +36,37 @@ func (binaryProtocol) WriteMessageBegin(buf []byte, name string, typeID TMessage return off + 4 } -func (binaryProtocol) WriteFieldBegin(buf []byte, typeID TType, id int16) int { +func (BinaryProtocol) WriteFieldBegin(buf []byte, typeID TType, id int16) int { buf[0] = byte(typeID) binary.BigEndian.PutUint16(buf[1:], uint16(id)) return 3 } -func (binaryProtocol) WriteFieldStop(buf []byte) int { +func (BinaryProtocol) WriteFieldStop(buf []byte) int { buf[0] = byte(STOP) return 1 } -func (binaryProtocol) WriteMapBegin(buf []byte, kt, vt TType, size int) int { +func (BinaryProtocol) WriteMapBegin(buf []byte, kt, vt TType, size int) int { buf[0] = byte(kt) buf[1] = byte(vt) binary.BigEndian.PutUint32(buf[2:], uint32(size)) return 6 } -func (binaryProtocol) WriteListBegin(buf []byte, et TType, size int) int { +func (BinaryProtocol) WriteListBegin(buf []byte, et TType, size int) int { buf[0] = byte(et) binary.BigEndian.PutUint32(buf[1:], uint32(size)) return 5 } -func (binaryProtocol) WriteSetBegin(buf []byte, et TType, size int) int { +func (BinaryProtocol) WriteSetBegin(buf []byte, et TType, size int) int { buf[0] = byte(et) binary.BigEndian.PutUint32(buf[1:], uint32(size)) return 5 } -func (binaryProtocol) WriteBool(buf []byte, v bool) int { +func (BinaryProtocol) WriteBool(buf []byte, v bool) int { if v { buf[0] = 1 } else { @@ -75,37 +75,37 @@ func (binaryProtocol) WriteBool(buf []byte, v bool) int { return 1 } -func (binaryProtocol) WriteByte(buf []byte, v int8) int { +func (BinaryProtocol) WriteByte(buf []byte, v int8) int { buf[0] = byte(v) return 1 } -func (binaryProtocol) WriteI16(buf []byte, v int16) int { +func (BinaryProtocol) WriteI16(buf []byte, v int16) int { binary.BigEndian.PutUint16(buf, uint16(v)) return 2 } -func (binaryProtocol) WriteI32(buf []byte, v int32) int { +func (BinaryProtocol) WriteI32(buf []byte, v int32) int { binary.BigEndian.PutUint32(buf, uint32(v)) return 4 } -func (binaryProtocol) WriteI64(buf []byte, v int64) int { +func (BinaryProtocol) WriteI64(buf []byte, v int64) int { binary.BigEndian.PutUint64(buf, uint64(v)) return 8 } -func (binaryProtocol) WriteDouble(buf []byte, v float64) int { +func (BinaryProtocol) WriteDouble(buf []byte, v float64) int { binary.BigEndian.PutUint64(buf, math.Float64bits(v)) return 8 } -func (binaryProtocol) WriteBinary(buf, v []byte) int { +func (BinaryProtocol) WriteBinary(buf, v []byte) int { binary.BigEndian.PutUint32(buf, uint32(len(v))) return 4 + copy(buf[4:], v) } -func (binaryProtocol) WriteBinaryNocopy(buf []byte, w NocopyWriter, v []byte) int { +func (BinaryProtocol) WriteBinaryNocopy(buf []byte, w NocopyWriter, v []byte) int { if w == nil || len(buf) < NocopyWriteThreshold { return Binary.WriteBinary(buf, v) } @@ -114,52 +114,52 @@ func (binaryProtocol) WriteBinaryNocopy(buf []byte, w NocopyWriter, v []byte) in return 4 } -func (binaryProtocol) WriteString(buf []byte, v string) int { +func (BinaryProtocol) WriteString(buf []byte, v string) int { binary.BigEndian.PutUint32(buf, uint32(len(v))) return 4 + copy(buf[4:], v) } -func (binaryProtocol) WriteStringNocopy(buf []byte, w NocopyWriter, v string) int { +func (BinaryProtocol) WriteStringNocopy(buf []byte, w NocopyWriter, v string) int { return Binary.WriteBinaryNocopy(buf, w, unsafe.StringToByteSlice(v)) } // Append methods -func (binaryProtocol) AppendMessageBegin(buf []byte, name string, typeID TMessageType, seq int32) []byte { +func (BinaryProtocol) AppendMessageBegin(buf []byte, name string, typeID TMessageType, seq int32) []byte { buf = appendUint32(buf, uint32(msgVersion1)|uint32(typeID&msgTypeMask)) buf = Binary.AppendString(buf, name) return Binary.AppendI32(buf, seq) } -func (binaryProtocol) AppendFieldBegin(buf []byte, typeID TType, id int16) []byte { +func (BinaryProtocol) AppendFieldBegin(buf []byte, typeID TType, id int16) []byte { return append(buf, byte(typeID), byte(uint16(id>>8)), byte(id)) } -func (binaryProtocol) AppendFieldStop(buf []byte) []byte { +func (BinaryProtocol) AppendFieldStop(buf []byte) []byte { return append(buf, byte(STOP)) } -func (binaryProtocol) AppendMapBegin(buf []byte, kt, vt TType, size int) []byte { +func (BinaryProtocol) AppendMapBegin(buf []byte, kt, vt TType, size int) []byte { return Binary.AppendI32(append(buf, byte(kt), byte(vt)), int32(size)) } -func (binaryProtocol) AppendListBegin(buf []byte, et TType, size int) []byte { +func (BinaryProtocol) AppendListBegin(buf []byte, et TType, size int) []byte { return Binary.AppendI32(append(buf, byte(et)), int32(size)) } -func (binaryProtocol) AppendSetBegin(buf []byte, et TType, size int) []byte { +func (BinaryProtocol) AppendSetBegin(buf []byte, et TType, size int) []byte { return Binary.AppendI32(append(buf, byte(et)), int32(size)) } -func (binaryProtocol) AppendBinary(buf, v []byte) []byte { +func (BinaryProtocol) AppendBinary(buf, v []byte) []byte { return append(Binary.AppendI32(buf, int32(len(v))), v...) } -func (binaryProtocol) AppendString(buf []byte, v string) []byte { +func (BinaryProtocol) AppendString(buf []byte, v string) []byte { return append(Binary.AppendI32(buf, int32(len(v))), v...) } -func (binaryProtocol) AppendBool(buf []byte, v bool) []byte { +func (BinaryProtocol) AppendBool(buf []byte, v bool) []byte { if v { return append(buf, 1) } else { @@ -167,23 +167,23 @@ func (binaryProtocol) AppendBool(buf []byte, v bool) []byte { } } -func (binaryProtocol) AppendByte(buf []byte, v int8) []byte { +func (BinaryProtocol) AppendByte(buf []byte, v int8) []byte { return append(buf, byte(v)) } -func (binaryProtocol) AppendI16(buf []byte, v int16) []byte { +func (BinaryProtocol) AppendI16(buf []byte, v int16) []byte { return append(buf, byte(uint16(v)>>8), byte(v)) } -func (binaryProtocol) AppendI32(buf []byte, v int32) []byte { +func (BinaryProtocol) AppendI32(buf []byte, v int32) []byte { return appendUint32(buf, uint32(v)) } -func (binaryProtocol) AppendI64(buf []byte, v int64) []byte { +func (BinaryProtocol) AppendI64(buf []byte, v int64) []byte { return appendUint64(buf, uint64(v)) } -func (binaryProtocol) AppendDouble(buf []byte, v float64) []byte { +func (BinaryProtocol) AppendDouble(buf []byte, v float64) []byte { return appendUint64(buf, math.Float64bits(v)) } @@ -198,25 +198,25 @@ func appendUint64(buf []byte, v uint64) []byte { // Length methods -func (binaryProtocol) MessageBeginLength(name string, _ TMessageType, _ int32) int { +func (BinaryProtocol) MessageBeginLength(name string, _ TMessageType, _ int32) int { return 4 + (4 + len(name)) + 4 } -func (binaryProtocol) FieldBeginLength() int { return 3 } -func (binaryProtocol) FieldStopLength() int { return 1 } -func (binaryProtocol) MapBeginLength() int { return 6 } -func (binaryProtocol) ListBeginLength() int { return 5 } -func (binaryProtocol) SetBeginLength() int { return 5 } -func (binaryProtocol) BoolLength() int { return 1 } -func (binaryProtocol) ByteLength() int { return 1 } -func (binaryProtocol) I16Length() int { return 2 } -func (binaryProtocol) I32Length() int { return 4 } -func (binaryProtocol) I64Length() int { return 8 } -func (binaryProtocol) DoubleLength() int { return 8 } -func (binaryProtocol) StringLength(v string) int { return 4 + len(v) } -func (binaryProtocol) BinaryLength(v []byte) int { return 4 + len(v) } -func (binaryProtocol) StringLengthNocopy(v string) int { return 4 + len(v) } -func (binaryProtocol) BinaryLengthNocopy(v []byte) int { return 4 + len(v) } +func (BinaryProtocol) FieldBeginLength() int { return 3 } +func (BinaryProtocol) FieldStopLength() int { return 1 } +func (BinaryProtocol) MapBeginLength() int { return 6 } +func (BinaryProtocol) ListBeginLength() int { return 5 } +func (BinaryProtocol) SetBeginLength() int { return 5 } +func (BinaryProtocol) BoolLength() int { return 1 } +func (BinaryProtocol) ByteLength() int { return 1 } +func (BinaryProtocol) I16Length() int { return 2 } +func (BinaryProtocol) I32Length() int { return 4 } +func (BinaryProtocol) I64Length() int { return 8 } +func (BinaryProtocol) DoubleLength() int { return 8 } +func (BinaryProtocol) StringLength(v string) int { return 4 + len(v) } +func (BinaryProtocol) BinaryLength(v []byte) int { return 4 + len(v) } +func (BinaryProtocol) StringLengthNocopy(v string) int { return 4 + len(v) } +func (BinaryProtocol) BinaryLengthNocopy(v []byte) int { return 4 + len(v) } // Read methods @@ -225,7 +225,7 @@ var ( errBadVersion = NewProtocolException(BAD_VERSION, "ReadMessageBegin: bad version") ) -func (binaryProtocol) ReadMessageBegin(buf []byte) (name string, typeID TMessageType, seq int32, l int, err error) { +func (BinaryProtocol) ReadMessageBegin(buf []byte) (name string, typeID TMessageType, seq int32, l int, err error) { if len(buf) < 4 { // version+type header + name header return "", 0, 0, 0, errReadMessage } @@ -271,7 +271,7 @@ var ( errReadDouble = NewProtocolException(INVALID_DATA, "ReadDouble: len(buf) < 8") ) -func (binaryProtocol) ReadFieldBegin(buf []byte) (typeID TType, id int16, l int, err error) { +func (BinaryProtocol) ReadFieldBegin(buf []byte) (typeID TType, id int16, l int, err error) { if len(buf) < 1 { return 0, 0, 0, errReadField } @@ -285,28 +285,28 @@ func (binaryProtocol) ReadFieldBegin(buf []byte) (typeID TType, id int16, l int, return typeID, int16(binary.BigEndian.Uint16(buf[1:])), 3, nil } -func (binaryProtocol) ReadMapBegin(buf []byte) (kt, vt TType, size, l int, err error) { +func (BinaryProtocol) ReadMapBegin(buf []byte) (kt, vt TType, size, l int, err error) { if len(buf) < 6 { return 0, 0, 0, 0, errReadMap } return TType(buf[0]), TType(buf[1]), int(binary.BigEndian.Uint32(buf[2:])), 6, nil } -func (binaryProtocol) ReadListBegin(buf []byte) (et TType, size, l int, err error) { +func (BinaryProtocol) ReadListBegin(buf []byte) (et TType, size, l int, err error) { if len(buf) < 5 { return 0, 0, 0, errReadList } return TType(buf[0]), int(binary.BigEndian.Uint32(buf[1:])), 5, nil } -func (binaryProtocol) ReadSetBegin(buf []byte) (et TType, size, l int, err error) { +func (BinaryProtocol) ReadSetBegin(buf []byte) (et TType, size, l int, err error) { if len(buf) < 5 { return 0, 0, 0, errReadSet } return TType(buf[0]), int(binary.BigEndian.Uint32(buf[1:])), 5, nil } -func (binaryProtocol) ReadBinary(buf []byte) (b []byte, l int, err error) { +func (BinaryProtocol) ReadBinary(buf []byte) (b []byte, l int, err error) { sz, _, err := Binary.ReadI32(buf) if err != nil { return nil, 0, errReadBin @@ -319,7 +319,7 @@ func (binaryProtocol) ReadBinary(buf []byte) (b []byte, l int, err error) { return []byte(string(buf[4:l])), l, nil } -func (binaryProtocol) ReadString(buf []byte) (s string, l int, err error) { +func (BinaryProtocol) ReadString(buf []byte) (s string, l int, err error) { sz, _, err := Binary.ReadI32(buf) if err != nil { return "", 0, errReadStr @@ -332,7 +332,7 @@ func (binaryProtocol) ReadString(buf []byte) (s string, l int, err error) { return string(buf[4:l]), l, nil } -func (binaryProtocol) ReadBool(buf []byte) (v bool, l int, err error) { +func (BinaryProtocol) ReadBool(buf []byte) (v bool, l int, err error) { if len(buf) < 1 { return false, 0, errReadBool } @@ -342,35 +342,35 @@ func (binaryProtocol) ReadBool(buf []byte) (v bool, l int, err error) { return false, 1, nil } -func (binaryProtocol) ReadByte(buf []byte) (v int8, l int, err error) { +func (BinaryProtocol) ReadByte(buf []byte) (v int8, l int, err error) { if len(buf) < 1 { return 0, 0, errReadByte } return int8(buf[0]), 1, nil } -func (binaryProtocol) ReadI16(buf []byte) (v int16, l int, err error) { +func (BinaryProtocol) ReadI16(buf []byte) (v int16, l int, err error) { if len(buf) < 2 { return 0, 0, errReadI16 } return int16(binary.BigEndian.Uint16(buf)), 2, nil } -func (binaryProtocol) ReadI32(buf []byte) (v int32, l int, err error) { +func (BinaryProtocol) ReadI32(buf []byte) (v int32, l int, err error) { if len(buf) < 4 { return 0, 0, errReadI32 } return int32(binary.BigEndian.Uint32(buf)), 4, nil } -func (binaryProtocol) ReadI64(buf []byte) (v int64, l int, err error) { +func (BinaryProtocol) ReadI64(buf []byte) (v int64, l int, err error) { if len(buf) < 8 { return 0, 0, errReadI64 } return int64(binary.BigEndian.Uint64(buf)), 8, nil } -func (binaryProtocol) ReadDouble(buf []byte) (v float64, l int, err error) { +func (BinaryProtocol) ReadDouble(buf []byte) (v float64, l int, err error) { if len(buf) < 8 { return 0, 0, errReadDouble } @@ -379,7 +379,6 @@ func (binaryProtocol) ReadDouble(buf []byte) (v float64, l int, err error) { var ( errDepthLimitExceeded = NewProtocolException(DEPTH_LIMIT, "depth limit exceeded") - errNegativeSize = NewProtocolException(NEGATIVE_SIZE, "negative size") ) var typeToSize = [256]int8{ @@ -396,7 +395,7 @@ func skipstr(b []byte) int { } // Skip skips over the value for the given type using Go implementation. -func (binaryProtocol) Skip(b []byte, t TType) (int, error) { +func (BinaryProtocol) Skip(b []byte, t TType) (int, error) { return skipType(b, t, defaultRecursionDepth) } diff --git a/protocol/thrift/binaryreader.go b/protocol/thrift/binaryreader.go new file mode 100644 index 0000000..29fc476 --- /dev/null +++ b/protocol/thrift/binaryreader.go @@ -0,0 +1,413 @@ +package thrift + +import ( + "bufio" + "encoding/binary" + "fmt" + "io" + "math" + "sync" +) + +type BinaryReader struct { + p int // for r.Peek + r *bufio.Reader + b []byte + + rn int64 // for r and b, total bytes read + + // if true will reuse *bufio.Reader + // it becomes false if Reader() called. + reuse bool +} + +var poolBinaryReader = sync.Pool{ + New: func() interface{} { + return &BinaryReader{r: bufio.NewReader(nil), reuse: true} + }, +} + +// NewBinaryReader creates a BinaryReader from io.Reader +// +// NOTE: +// * call `Reader()` if you want to continue using io.Reader after using BinaryReader +// * call `Release()` if no longer use BinaryReader +// * DO NOT USE `BinaryReader` after calling `Release()` +func NewBinaryReader(r io.Reader) *BinaryReader { + br := poolBinaryReader.Get().(*BinaryReader) + br.reset() + if br.r == nil { + br.r = bufio.NewReader(r) + } else { + br.r.Reset(r) + } + return br +} + +// NewBinaryReaderBuffer creates a BinaryReader with []byte +func NewBinaryReaderBuffer(b []byte) *BinaryReader { + if b == nil { + panic("b == nil") + } + br := poolBinaryReader.Get().(*BinaryReader) + br.reset() + br.b = b + return br +} + +func (r *BinaryReader) reset() { + r.p = 0 + r.b = nil + r.rn = 0 + r.reuse = (r.r != nil) +} + +// Readn returns bytes read or skipped. +// +// you can Readn() with Skip() like: +// +// r := NewBinaryReaderBuffer(b) +// defer r.Release() +// r.Skip(TStruct) +// structb := b[:r.Readn()] // <--- here is the struct bytes +func (r *BinaryReader) Readn() int64 { + return r.rn +} + +// Release ... see comment of NewBinaryReader +func (r *BinaryReader) Release() { + if !r.reuse { + r.r = nil + } + r.b = nil + poolBinaryReader.Put(r) +} + +// Reader returns the underlying buffered reader of BinaryReader. +// +// Coz the underlying buffered reader may retain unread bytes, +// user may want to continue to use the io.Reader after reading from BinaryReader. +func (r *BinaryReader) Reader() io.Reader { + r.reuse = false + return r.r +} + +// peek ... MUST call ack(n) after calling peek(n) +func (r *BinaryReader) peek(n int) ([]byte, error) { + if n < 0 { + return nil, errNegativeSize + } + if r.b != nil { + return r.peekb(n) + } + return r.peekr(n) +} + +func (r *BinaryReader) peekb(n int) ([]byte, error) { + if int(r.rn)+n > len(r.b) { + return nil, errInvalidDataLen + } + b := r.b[int(r.rn):] + ret := b[:n:n] + return ret, nil +} + +func (r *BinaryReader) peekr(n int) ([]byte, error) { + b, err := r.r.Peek(r.p + n) + if err != nil { + return nil, NewProtocolExceptionWithErr(err) + } + ret := b[r.p:] + r.p += n + return ret, nil +} + +func (r *BinaryReader) ack(n int) { + r.rn += int64(n) + if r.b == nil { + r.p -= n + _, _ = r.r.Discard(n) // always return n and nil err + } +} + +func (r *BinaryReader) ReadMessageBegin() (name string, typeID TMessageType, seq int32, err error) { + var header int32 + header, err = r.ReadI32() + if err != nil { + return + } + // read header for version and type + if uint32(header)&msgVersionMask != msgVersion1 { + err = errBadVersion + return + } + typeID = TMessageType(uint32(header) & msgTypeMask) + + // read method name + name, err = r.ReadString() + if err != nil { + return + } + + // read seq + seq, err = r.ReadI32() + if err != nil { + return + } + return +} + +func (r *BinaryReader) ReadFieldBegin() (typeID TType, id int16, err error) { + b, err := r.peek(1) + if err != nil { + return 0, 0, err + } + typeID = TType(b[0]) + r.ack(1) + if typeID == STOP { + return STOP, 0, nil + } + b, err = r.peek(2) + if err != nil { + return 0, 0, err + } + id = int16(binary.BigEndian.Uint16(b)) + r.ack(2) + return +} + +func (r *BinaryReader) ReadMapBegin() (kt, vt TType, size int, err error) { + b, err := r.peek(6) + if err != nil { + return 0, 0, 0, err + } + kt, vt, size = TType(b[0]), TType(b[1]), int(binary.BigEndian.Uint32(b[2:])) + r.ack(6) + return +} + +func (r *BinaryReader) ReadListBegin() (et TType, size int, err error) { + b, err := r.peek(5) + if err != nil { + return 0, 0, err + } + et, size = TType(b[0]), int(binary.BigEndian.Uint32(b[1:])) + r.ack(5) + return +} + +func (r *BinaryReader) ReadSetBegin() (et TType, size int, err error) { + b, err := r.peek(5) + if err != nil { + return 0, 0, err + } + et, size = TType(b[0]), int(binary.BigEndian.Uint32(b[1:])) + r.ack(5) + return +} + +func (r *BinaryReader) ReadBinary() (b []byte, err error) { + sz, err := r.ReadI32() + if err != nil { + return nil, err + } + b, err = r.peek(int(sz)) + if err != nil { + return nil, err + } + b = []byte(string(b)) // copy. use span cache? + r.ack(int(sz)) + return +} + +func (r *BinaryReader) ReadString() (s string, err error) { + sz, err := r.ReadI32() + if err != nil { + return "", err + } + b, err := r.peek(int(sz)) + if err != nil { + return "", err + } + s = string(b) // copy. use span cache? + r.ack(int(sz)) + return +} + +func (r *BinaryReader) ReadBool() (v bool, err error) { + b, err := r.peek(1) + if err != nil { + return false, err + } + v = (b[0] == 1) + r.ack(1) + return +} + +func (r *BinaryReader) ReadByte() (v int8, err error) { + b, err := r.peek(1) + if err != nil { + return 0, err + } + v = int8(b[0]) + r.ack(1) + return +} + +func (r *BinaryReader) ReadI16() (v int16, err error) { + b, err := r.peek(2) + if err != nil { + return 0, err + } + v = int16(binary.BigEndian.Uint16(b)) + r.ack(2) + return +} + +func (r *BinaryReader) ReadI32() (v int32, err error) { + b, err := r.peek(4) + if err != nil { + return 0, err + } + v = int32(binary.BigEndian.Uint32(b)) + r.ack(4) + return +} + +func (r *BinaryReader) ReadI64() (v int64, err error) { + b, err := r.peek(8) + if err != nil { + return 0, err + } + v = int64(binary.BigEndian.Uint64(b)) + r.ack(8) + return +} + +func (r *BinaryReader) ReadDouble() (v float64, err error) { + b, err := r.peek(8) + if err != nil { + return 0, err + } + v = math.Float64frombits(binary.BigEndian.Uint64(b)) + r.ack(8) + return +} + +func (r *BinaryReader) Skip(t TType) error { + return r.skipType(t, defaultRecursionDepth) +} + +func (r *BinaryReader) skipn(n int) error { + if r.b != nil { + if int(r.rn)+n > len(r.b) { + return errInvalidDataLen + } + } else { + if _, err := r.r.Discard(n); err != nil { + return NewProtocolExceptionWithErr(err) + } + } + r.rn += int64(n) + return nil +} + +func (r *BinaryReader) skipstr() error { + b, err := r.peek(4) + if err != nil { + return err + } + n := int32(binary.BigEndian.Uint32(b)) + if n < 0 { + return errNegativeSize + } + r.ack(4) + return r.skipn(int(n)) +} + +func (r *BinaryReader) skipType(t TType, maxdepth int) error { + if maxdepth == 0 { + return errDepthLimitExceeded + } + if n := typeToSize[t]; n > 0 { + return r.skipn(int(n)) + } + switch t { + case STRING: + return r.skipstr() + case MAP: + kt, vt, sz, err := r.ReadMapBegin() + if err != nil { + return err + } + if sz < 0 { + return errNegativeSize + } + ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt]) + if ksz > 0 && vsz > 0 { + return r.skipn(int(sz) * (ksz + vsz)) + } + for j := 0; j < sz; j++ { + if ksz > 0 { + err = r.skipn(ksz) + } else if kt == STRING { + err = r.skipstr() + } else { + err = r.skipType(kt, maxdepth-1) + } + if err != nil { + return err + } + if vsz > 0 { + err = r.skipn(vsz) + } else if kt == STRING { + err = r.skipstr() + } else { + err = r.skipType(kt, maxdepth-1) + } + if err != nil { + return err + } + } + return nil + case LIST, SET: + vt, sz, err := r.ReadListBegin() + if err != nil { + return err + } + if sz < 0 { + return errNegativeSize + } + if vsz := typeToSize[vt]; vsz > 0 { + return r.skipn(sz * int(vsz)) + } + for j := 0; j < sz; j++ { + if vt == STRING { + err = r.skipstr() + } else { + err = r.skipType(vt, maxdepth-1) + } + if err != nil { + return err + } + } + return nil + case STRUCT: + for { + ft, _, err := r.ReadFieldBegin() + if ft == STOP { + return nil + } + if fsz := typeToSize[ft]; fsz > 0 { + err = r.skipn(int(fsz)) + } else { + err = r.skipType(ft, maxdepth-1) + } + if err != nil { + return err + } + } + default: + return NewProtocolException(INVALID_DATA, fmt.Sprintf("unknown data type %d", t)) + } +} diff --git a/protocol/thrift/binaryreader_test.go b/protocol/thrift/binaryreader_test.go new file mode 100644 index 0000000..7cf842a --- /dev/null +++ b/protocol/thrift/binaryreader_test.go @@ -0,0 +1,135 @@ +package thrift + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBinaryReader(t *testing.T) { + x := BinaryProtocol{} + b := x.AppendMessageBegin(nil, "hello", 1, 2) + sz0 := len(b) + b = x.AppendFieldBegin(b, 3, 4) + sz1 := len(b) + b = x.AppendFieldStop(b) + sz2 := len(b) + b = x.AppendMapBegin(b, 5, 6, 7) + sz3 := len(b) + b = x.AppendListBegin(b, 8, 9) + sz4 := len(b) + b = x.AppendSetBegin(b, 10, 11) + sz5 := len(b) + b = x.AppendBinary(b, []byte("12")) + sz6 := len(b) + b = x.AppendString(b, "13") + sz7 := len(b) + b = x.AppendBool(b, true) + b = x.AppendBool(b, false) + sz8 := len(b) + b = x.AppendByte(b, 14) + sz9 := len(b) + b = x.AppendI16(b, 15) + sz10 := len(b) + b = x.AppendI32(b, 16) + sz11 := len(b) + b = x.AppendI64(b, 17) + sz12 := len(b) + b = x.AppendDouble(b, 18.5) + sz13 := len(b) + + r0 := NewBinaryReaderBuffer(b) + r1 := NewBinaryReader(bytes.NewReader(b)) + for i, r := range []*BinaryReader{r0, r1} { + name := "NewBinaryReaderBuffer" + if i != 0 { + name = "NewBinaryReader" + } + t.Run(name, func(t *testing.T) { + name, mt, seq, err := r.ReadMessageBegin() + require.NoError(t, err) + require.Equal(t, "hello", name) + require.Equal(t, TMessageType(1), mt) + require.Equal(t, int32(2), seq) + require.Equal(t, sz0, int(r.Readn())) + + ft, fid, err := r.ReadFieldBegin() + require.NoError(t, err) + require.Equal(t, TType(3), ft) + require.Equal(t, int16(4), fid) + require.Equal(t, sz1, int(r.Readn())) + + ft, fid, err = r.ReadFieldBegin() // for AppendFieldStop + require.NoError(t, err) + require.Equal(t, STOP, ft) + require.Equal(t, int16(0), fid) + require.Equal(t, sz2, int(r.Readn())) + + kt, vt, sz, err := r.ReadMapBegin() + require.NoError(t, err) + require.Equal(t, TType(5), kt) + require.Equal(t, TType(6), vt) + require.Equal(t, int(7), sz) + require.Equal(t, sz3, int(r.Readn())) + + et, sz, err := r.ReadListBegin() + require.NoError(t, err) + require.Equal(t, TType(8), et) + require.Equal(t, int(9), sz) + require.Equal(t, sz4, int(r.Readn())) + + et, sz, err = r.ReadSetBegin() + require.NoError(t, err) + require.Equal(t, TType(10), et) + require.Equal(t, int(11), sz) + require.Equal(t, sz5, int(r.Readn())) + + x, err := r.ReadBinary() + require.NoError(t, err) + require.Equal(t, "12", string(x)) + require.Equal(t, sz6, int(r.Readn())) + + s, err := r.ReadString() + require.NoError(t, err) + require.Equal(t, "13", s) + require.Equal(t, sz7, int(r.Readn())) + + vb, err := r.ReadBool() + require.NoError(t, err) + require.True(t, vb) + vb, err = r.ReadBool() + require.NoError(t, err) + require.False(t, vb) + require.Equal(t, sz8, int(r.Readn())) + + v8, err := r.ReadByte() + require.NoError(t, err) + require.Equal(t, int8(14), v8) + require.Equal(t, sz9, int(r.Readn())) + + v16, err := r.ReadI16() + require.NoError(t, err) + require.Equal(t, int16(15), v16) + require.Equal(t, sz10, int(r.Readn())) + + v32, err := r.ReadI32() + require.NoError(t, err) + require.Equal(t, int32(16), v32) + require.Equal(t, sz11, int(r.Readn())) + + v64, err := r.ReadI64() + require.NoError(t, err) + require.Equal(t, int64(17), v64) + require.Equal(t, sz12, int(r.Readn())) + + vf, err := r.ReadDouble() + require.NoError(t, err) + require.Equal(t, float64(18.5), vf) + require.Equal(t, sz13, int(r.Readn())) + }) + } +} + +func TestBinaryReaderSkip(t *testing.T) { +} diff --git a/protocol/thrift/binarywriter.go b/protocol/thrift/binarywriter.go new file mode 100644 index 0000000..f0b18df --- /dev/null +++ b/protocol/thrift/binarywriter.go @@ -0,0 +1,104 @@ +package thrift + +import ( + "math" + "sync" +) + +const defaultBinaryWriterBufferSize = 4096 + +type BinaryWriter struct { + buf []byte +} + +var poolBinaryWriter = sync.Pool{ + New: func() interface{} { + return &BinaryWriter{buf: make([]byte, 0, defaultBinaryWriterBufferSize)} + }, +} + +func NewBinaryWriter() *BinaryWriter { + return NewBinaryWriterSize(0) +} + +func NewBinaryWriterSize(sz int) *BinaryWriter { + w := poolBinaryWriter.Get().(*BinaryWriter) + if cap(w.buf) < sz { + w.Release() + w = &BinaryWriter{buf: make([]byte, 0, sz)} + } + w.Reset() + return w +} + +func (w *BinaryWriter) Release() { + poolBinaryWriter.Put(w) +} + +func (w *BinaryWriter) Reset() { + w.buf = w.buf[:0] +} + +func (w *BinaryWriter) Bytes() []byte { + return w.buf +} + +func (w *BinaryWriter) WriteMessageBegin(name string, typeID TMessageType, seq int32) { + w.buf = Binary.AppendMessageBegin(w.buf, name, typeID, seq) +} + +func (w *BinaryWriter) WriteFieldBegin(typeID TType, id int16) { + w.buf = Binary.AppendFieldBegin(w.buf, typeID, id) +} + +func (w *BinaryWriter) WriteFieldStop() { + w.buf = append(w.buf, byte(STOP)) +} + +func (w *BinaryWriter) WriteMapBegin(kt, vt TType, size int) { + w.buf = Binary.AppendMapBegin(w.buf, kt, vt, size) +} + +func (w *BinaryWriter) WriteListBegin(et TType, size int) { + w.buf = Binary.AppendListBegin(w.buf, et, size) +} + +func (w *BinaryWriter) WriteSetBegin(et TType, size int) { + w.buf = Binary.AppendSetBegin(w.buf, et, size) +} + +func (w *BinaryWriter) WriteBinary(v []byte) { + w.buf = Binary.AppendBinary(w.buf, v) +} + +func (w *BinaryWriter) WriteString(v string) { + w.buf = Binary.AppendString(w.buf, v) +} + +func (w *BinaryWriter) WriteBool(v bool) { + if v { + w.buf = append(w.buf, 1) + } else { + w.buf = append(w.buf, 0) + } +} + +func (w *BinaryWriter) WriteByte(v int8) { + w.buf = append(w.buf, byte(v)) +} + +func (w *BinaryWriter) WriteI16(v int16) { + w.buf = append(w.buf, byte(uint16(v)>>8), byte(v)) +} + +func (w *BinaryWriter) WriteI32(v int32) { + w.buf = appendUint32(w.buf, uint32(v)) +} + +func (w *BinaryWriter) WriteI64(v int64) { + w.buf = appendUint64(w.buf, uint64(v)) +} + +func (w *BinaryWriter) WriteDouble(v float64) { + w.buf = appendUint64(w.buf, math.Float64bits(v)) +} diff --git a/protocol/thrift/binarywriter_test.go b/protocol/thrift/binarywriter_test.go new file mode 100644 index 0000000..4876edf --- /dev/null +++ b/protocol/thrift/binarywriter_test.go @@ -0,0 +1,70 @@ +package thrift + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBinaryWriter(t *testing.T) { + w := NewBinaryWriterSize(defaultBinaryWriterBufferSize * 2) + x := BinaryProtocol{} + + b := x.AppendMessageBegin(nil, "hello", 1, 2) + w.WriteMessageBegin("hello", 1, 2) + require.Equal(t, b, w.Bytes()) + + b = x.AppendFieldBegin(b, 3, 4) + w.WriteFieldBegin(3, 4) + require.Equal(t, b, w.Bytes()) + + b = x.AppendFieldStop(b) + w.WriteFieldStop() + require.Equal(t, b, w.Bytes()) + + b = x.AppendMapBegin(b, 5, 6, 7) + w.WriteMapBegin(5, 6, 7) + require.Equal(t, b, w.Bytes()) + + b = x.AppendListBegin(b, 8, 9) + w.WriteListBegin(8, 9) + require.Equal(t, b, w.Bytes()) + + b = x.AppendSetBegin(b, 10, 11) + w.WriteSetBegin(10, 11) + require.Equal(t, b, w.Bytes()) + + b = x.AppendBinary(b, []byte("12")) + w.WriteBinary([]byte("12")) + require.Equal(t, b, w.Bytes()) + + b = x.AppendString(b, "13") + w.WriteString("13") + require.Equal(t, b, w.Bytes()) + + b = x.AppendBool(b, true) + b = x.AppendBool(b, false) + w.WriteBool(true) + w.WriteBool(false) + require.Equal(t, b, w.Bytes()) + + b = x.AppendByte(b, 14) + w.WriteByte(14) + require.Equal(t, b, w.Bytes()) + + b = x.AppendI16(b, 15) + w.WriteI16(15) + require.Equal(t, b, w.Bytes()) + + b = x.AppendI32(b, 16) + w.WriteI32(16) + require.Equal(t, b, w.Bytes()) + + b = x.AppendI64(b, 17) + w.WriteI64(17) + require.Equal(t, b, w.Bytes()) + + b = x.AppendDouble(b, 18.5) + w.WriteDouble(18.5) + require.Equal(t, b, w.Bytes()) +} diff --git a/protocol/thrift/exception.go b/protocol/thrift/exception.go index a1e069e..92c5d8e 100644 --- a/protocol/thrift/exception.go +++ b/protocol/thrift/exception.go @@ -154,6 +154,8 @@ func NewTransportException(t int32, m string) *TransportException { // it implements ThriftFastCodec interface. type ProtocolException struct { ApplicationException // same implementation ... + + err error } const ( // ProtocolException codes from apache thrift @@ -166,7 +168,12 @@ const ( // ProtocolException codes from apache thrift DEPTH_LIMIT = 6 ) -// NewTransportException ... +var ( + errInvalidDataLen = NewProtocolException(INVALID_DATA, "Invalid data length") + errNegativeSize = NewProtocolException(NEGATIVE_SIZE, "negative size") +) + +// NewTransportExceptionWithType func NewProtocolException(t int32, m string) *ProtocolException { ret := ProtocolException{} ret.t = t @@ -174,6 +181,29 @@ func NewProtocolException(t int32, m string) *ProtocolException { return &ret } +// NewProtocolException ... +func NewProtocolExceptionWithErr(err error) *ProtocolException { + e, ok := err.(*ProtocolException) + if ok { + return e + } + ret := NewProtocolException(UNKNOWN_PROTOCOL_EXCEPTION, err.Error()) + ret.err = err + return ret +} + +// Unwrap ... for errors pkg +func (e *ProtocolException) Unwrap() error { return e.err } + +// Is ... for errors pkg +func (e *ProtocolException) Is(err error) bool { + t, ok := err.(tException) + if ok && t.TypeId() == e.t && t.Error() == e.m { + return true + } + return errors.Is(e.err, err) +} + // Generic Thrift exception with TypeId method type tException interface { Error() string diff --git a/protocol/thrift/exception_test.go b/protocol/thrift/exception_test.go index 6bf98b7..134498f 100644 --- a/protocol/thrift/exception_test.go +++ b/protocol/thrift/exception_test.go @@ -18,6 +18,7 @@ package thrift import ( "errors" + "io" "testing" "github.com/stretchr/testify/assert" @@ -47,6 +48,12 @@ func TestApplicationException(t *testing.T) { t.Log(ex4.String()) // ... } +func TestProtocolException(t *testing.T) { + e := NewProtocolExceptionWithErr(io.EOF) + assert.ErrorIs(t, e, io.EOF) // will call errors.Is + assert.True(t, e.Is(NewProtocolException(UNKNOWN_PROTOCOL_EXCEPTION, "EOF"))) +} + type testTException struct{} func (testTException) Error() string { return "testTException" }