diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..c7d06c3 --- /dev/null +++ b/go.mod @@ -0,0 +1,11 @@ +module github.com/cloudwego/gopkg + +go 1.19 + +require github.com/stretchr/testify v1.9.0 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..60ce688 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/unsafe/unsafe.go b/internal/unsafe/unsafe.go new file mode 100644 index 0000000..d461664 --- /dev/null +++ b/internal/unsafe/unsafe.go @@ -0,0 +1,30 @@ +package unsafe + +import "unsafe" + +type sliceHeader struct { + Data uintptr + Len int + Cap int +} + +type strHeader struct { + Data uintptr + Len int +} + +// ByteSliceToString converts []byte to string without copy +func ByteSliceToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// StringToByteSlice converts string to []byte without copy +func StringToByteSlice(s string) []byte { + var v []byte + p0 := (*sliceHeader)(unsafe.Pointer(&v)) + p1 := (*strHeader)(unsafe.Pointer(&s)) + p0.Data = p1.Data + p0.Len = p1.Len + p0.Cap = p1.Len + return v +} diff --git a/pkg/protocol/thrift/binary.go b/pkg/protocol/thrift/binary.go new file mode 100644 index 0000000..6e991fd --- /dev/null +++ b/pkg/protocol/thrift/binary.go @@ -0,0 +1,379 @@ +package thrift + +import ( + "encoding/binary" + "fmt" + "math" + + "github.com/cloudwego/gopkg/internal/unsafe" +) + +var Binary binaryProtocol + +type binaryProtocol struct{} + +func (binaryProtocol) WriteMessageBegin(buf []byte, name string, typeID TMessageType, seq int32) int { + binary.BigEndian.PutUint32(buf, uint32(msgVersion1)|uint32(typeID&msgTypeMask)) + binary.BigEndian.PutUint32(buf, uint32(len(name))) + off := 8 + copy(buf[8:], name) + binary.BigEndian.PutUint32(buf[off:], uint32(seq)) + return off + 4 +} + +func (binaryProtocol) WriteFieldBegin(buf []byte, typeID TType, id int16) int { + buf[0] = byte(typeID) + binary.BigEndian.PutUint16(buf[1:], uint16(id)) + return 3 +} + +func (binaryProtocol) WriteFieldStop(buf []byte) int { + buf[0] = byte(STOP) + return 1 +} + +func (binaryProtocol) WriteMapBegin(buf []byte, kt, vt TType, size int) int { + buf[0] = byte(kt) + buf[1] = byte(vt) + binary.BigEndian.PutUint32(buf[2:], uint32(size)) + return 6 +} + +func (binaryProtocol) WriteListBegin(buf []byte, et TType, size int) int { + buf[0] = byte(et) + binary.BigEndian.PutUint32(buf[1:], uint32(size)) + return 5 +} + +func (binaryProtocol) WriteSetBegin(buf []byte, et TType, size int) int { + buf[0] = byte(et) + binary.BigEndian.PutUint32(buf[1:], uint32(size)) + return 5 +} + +func (binaryProtocol) WriteBool(buf []byte, v bool) int { + if v { + buf[0] = 1 + } else { + buf[0] = 0 + } + return 1 +} + +func (binaryProtocol) WriteByte(buf []byte, v int8) int { + buf[0] = byte(v) + return 1 +} + +func (binaryProtocol) WriteI16(buf []byte, v int16) int { + binary.BigEndian.PutUint16(buf, uint16(v)) + return 2 +} + +func (binaryProtocol) WriteI32(buf []byte, v int32) int { + binary.BigEndian.PutUint32(buf, uint32(v)) + return 4 +} + +func (binaryProtocol) WriteI64(buf []byte, v int64) int { + binary.BigEndian.PutUint64(buf, uint64(v)) + return 8 +} + +func (binaryProtocol) WriteDouble(buf []byte, v float64) int { + binary.BigEndian.PutUint64(buf, math.Float64bits(v)) + return 8 +} + +func (binaryProtocol) WriteBinary(buf, v []byte) int { + binary.BigEndian.PutUint32(buf, uint32(len(v))) + return 4 + copy(buf[4:], v) +} + +func (binaryProtocol) WriteBinaryNocopy(buf []byte, w NocopyWriter, v []byte) int { + if w == nil { + return Binary.WriteBinary(buf, v) + } + binary.BigEndian.PutUint32(buf, uint32(len(v))) + _ = w.WriteDirect(v, len(buf[4:])) // always err == nil ? + return 4 +} + +func (binaryProtocol) WriteString(buf []byte, v string) int { + binary.BigEndian.PutUint32(buf, uint32(len(v))) + return 4 + copy(buf[4:], v) +} + +func (binaryProtocol) WriteStringNocopy(buf []byte, w NocopyWriter, v string) int { + return Binary.WriteBinaryNocopy(buf, w, unsafe.StringToByteSlice(v)) +} + +// Length methods + +func (binaryProtocol) MessageBeginLength(name string, _ TMessageType, _ int32) int { + return 4 + (4 + len(name)) + 4 +} + +func (binaryProtocol) FieldBeginLength(_ TType, _ int16) int { return 3 } +func (binaryProtocol) FieldStopLength() int { return 1 } +func (binaryProtocol) MapBeginLength() int { return 6 } +func (binaryProtocol) ListBeginLength() int { return 5 } +func (binaryProtocol) SetBeginLength() int { return 5 } +func (binaryProtocol) BoolLength(_ bool) int { return 1 } +func (binaryProtocol) ByteLength(_ int8) int { return 1 } +func (binaryProtocol) I16Length(_ int16) int { return 2 } +func (binaryProtocol) I32Length(_ int32) int { return 4 } +func (binaryProtocol) I64Length(_ int64) int { return 8 } +func (binaryProtocol) DoubleLength(_ float64) int { return 8 } +func (binaryProtocol) StringLength(v string) int { return 4 + len(v) } +func (binaryProtocol) BinaryLength(v []byte) int { return 4 + len(v) } +func (binaryProtocol) StringLengthNocopy(v string) int { return 4 + len(v) } +func (binaryProtocol) BinaryLengthNocopy(v []byte) int { return 4 + len(v) } + +// Read methods + +var ( + errReadMessage = NewProtocolException(INVALID_DATA, "ReadMessageBegin: buf too small") + errBadVersion = NewProtocolException(BAD_VERSION, "ReadMessageBegin: bad version") +) + +func (binaryProtocol) ReadMessageBegin(buf []byte) (name string, typeID TMessageType, seq int32, l int, err error) { + if len(buf) < 4 { // version+type header + name header + return "", 0, 0, 0, errReadMessage + } + + // read header for version and type + header := binary.BigEndian.Uint32(buf) + if header&msgVersionMask != msgVersion1 { + return "", 0, 0, 0, errBadVersion + } + typeID = TMessageType(header & msgTypeMask) + + off := 4 + + // read method name + name, l, err1 := Binary.ReadString(buf[off:]) + if err1 != nil { + return "", 0, 0, 0, errReadMessage + } + off += l + + // read seq + seq, l, err2 := Binary.ReadI32(buf[off:]) + if err2 != nil { + return "", 0, 0, 0, errReadMessage + } + off += l + return name, typeID, seq, off, nil +} + +var ( + errReadField = NewProtocolException(INVALID_DATA, "ReadFieldBegin: buf too small") + errReadMap = NewProtocolException(INVALID_DATA, "ReadMapBegin: buf too small") + errReadList = NewProtocolException(INVALID_DATA, "ReadListBegin: buf too small") + errReadSet = NewProtocolException(INVALID_DATA, "ReadSetBegin: buf too small") + errReadStr = NewProtocolException(INVALID_DATA, "ReadString: buf too small") + + errReadBool = NewProtocolException(INVALID_DATA, "ReadBool: len(buf) < 1") + errReadByte = NewProtocolException(INVALID_DATA, "ReadByte: len(buf) < 1") + errReadI16 = NewProtocolException(INVALID_DATA, "ReadI16: len(buf) < 2") + errReadI32 = NewProtocolException(INVALID_DATA, "ReadI32: len(buf) < 4") + errReadI64 = NewProtocolException(INVALID_DATA, "ReadI64: len(buf) < 8") +) + +func (binaryProtocol) ReadFieldBegin(buf []byte) (typeID TType, id int16, l int, err error) { + if len(buf) < 1 { + return 0, 0, 0, errReadField + } + typeID = TType(buf[0]) + if typeID == STOP { + return STOP, 0, 1, nil + } + if len(buf) < 3 { + return 0, 0, 0, errReadField + } + return typeID, int16(binary.BigEndian.Uint16(buf[1:])), 3, nil +} + +func (binaryProtocol) ReadMapBegin(buf []byte) (kt, vt TType, size, l int, err error) { + if len(buf) < 6 { + return 0, 0, 0, 0, errReadMap + } + return TType(buf[0]), TType(buf[1]), int(binary.BigEndian.Uint32(buf[2:])), 6, nil +} + +func (binaryProtocol) ReadListBegin(buf []byte) (et TType, size, l int, err error) { + if len(buf) < 5 { + return 0, 0, 0, errReadList + } + return TType(buf[0]), int(binary.BigEndian.Uint32(buf[1:])), 5, nil +} + +func (binaryProtocol) ReadSetBegin(buf []byte) (et TType, size, l int, err error) { + if len(buf) < 5 { + return 0, 0, 0, errReadSet + } + return TType(buf[0]), int(binary.BigEndian.Uint32(buf[1:])), 5, nil +} + +func (binaryProtocol) ReadString(buf []byte) (s string, l int, err error) { + sz, _, err := Binary.ReadI32(buf) + if err != nil { + return "", 0, errReadStr + } + l = 4 + int(sz) + if len(buf) < l { + return "", 4, errReadStr + } + // TODO: use span + return string(buf[4:l]), l, nil +} + +func (binaryProtocol) ReadBool(buf []byte) (v bool, l int, err error) { + if len(buf) < 1 { + return false, 0, errReadBool + } + if buf[0] == 1 { + return true, 1, nil + } + return false, 1, nil +} + +func (binaryProtocol) ReadByte(buf []byte) (v int8, l int, err error) { + if len(buf) < 1 { + return 0, 0, errReadByte + } + return int8(buf[0]), 1, nil +} + +func (binaryProtocol) ReadI16(buf []byte) (v int16, l int, err error) { + if len(buf) < 2 { + return 0, 0, errReadI16 + } + return int16(binary.BigEndian.Uint16(buf)), 2, nil +} + +func (binaryProtocol) ReadI32(buf []byte) (v int32, l int, err error) { + if len(buf) < 4 { + return 0, 0, errReadI32 + } + return int32(binary.BigEndian.Uint32(buf)), 4, nil +} + +func (binaryProtocol) ReadI64(buf []byte) (v int64, l int, err error) { + if len(buf) < 8 { + return 0, 0, errReadI64 + } + return int64(binary.BigEndian.Uint64(buf)), 8, nil +} + +func (binaryProtocol) ReadDouble(buf []byte) (v float64, l int, err error) { + if len(buf) < 8 { + return 0, 0, errReadI64 + } + return math.Float64frombits(binary.BigEndian.Uint64(buf)), 8, nil +} + +var ( + errDepthLimitExceeded = NewProtocolException(DEPTH_LIMIT, "depth limit exceeded") + errNegativeSize = NewProtocolException(NEGATIVE_SIZE, "negative size") +) + +var typeToSize = [256]int8{ + BOOL: 1, + BYTE: 1, + DOUBLE: 8, + I16: 2, + I32: 4, + I64: 8, +} + +func skipstr(b []byte) int { + return 4 + int(binary.BigEndian.Uint32(b)) +} + +// Skip skips over the value for the given type using Go implementation. +func (binaryProtocol) Skip(b []byte, t TType) (int, error) { + return skipType(b, t, defaultRecursionDepth) +} + +func skipType(b []byte, t TType, maxdepth int) (int, error) { + if maxdepth == 0 { + return 0, errDepthLimitExceeded + } + if n := typeToSize[t]; n > 0 { + return int(n), nil + } + switch t { + case STRING: + return skipstr(b), nil + case MAP: + i := 6 + kt, vt, sz := TType(b[0]), TType(b[1]), int32(binary.BigEndian.Uint32(b[2:])) + if sz < 0 { + return 0, errNegativeSize + } + ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt]) + if ksz > 0 && vsz > 0 { + return i + (int(sz) * (ksz + vsz)), nil + } + for j := int32(0); j < sz; j++ { + if ksz > 0 { + i += ksz + } else if kt == STRING { + i += skipstr(b[i:]) + } else if n, err := skipType(b[i:], kt, maxdepth-1); err != nil { + return i, err + } else { + i += n + } + if vsz > 0 { + i += vsz + } else if vt == STRING { + i += skipstr(b[i:]) + } else if n, err := skipType(b[i:], vt, maxdepth-1); err != nil { + return i, err + } else { + i += n + } + } + return i, nil + case LIST, SET: + i := 5 + vt, sz := TType(b[0]), int32(binary.BigEndian.Uint32(b[1:])) + if sz < 0 { + return 0, errNegativeSize + } + if typeToSize[vt] > 0 { + return i + int(sz)*int(typeToSize[vt]), nil + } + for j := int32(0); j < sz; j++ { + if vt == STRING { + i += skipstr(b[i:]) + } else if n, err := skipType(b[i:], vt, maxdepth-1); err != nil { + return i, err + } else { + i += n + } + } + return i, nil + case STRUCT: + i := 0 + for { + ft := TType(b[i]) + i += 1 // TType + if ft == STOP { + return i, nil + } + i += 2 // Field ID + if typeToSize[ft] > 0 { + i += int(typeToSize[ft]) + } else if n, err := skipType(b[i:], ft, maxdepth-1); err != nil { + return i, err + } else { + i += n + } + } + default: + return 0, NewProtocolException(INVALID_DATA, fmt.Sprintf("unknown data type %d", t)) + } +} diff --git a/pkg/protocol/thrift/exception.go b/pkg/protocol/thrift/exception.go new file mode 100644 index 0000000..ac9bba3 --- /dev/null +++ b/pkg/protocol/thrift/exception.go @@ -0,0 +1,199 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "errors" + "fmt" +) + +const ( // ApplicationException codes from apache thrift + UNKNOWN_APPLICATION_EXCEPTION = 0 + UNKNOWN_METHOD = 1 + INVALID_MESSAGE_TYPE_EXCEPTION = 2 + WRONG_METHOD_NAME = 3 + BAD_SEQUENCE_ID = 4 + MISSING_RESULT = 5 + INTERNAL_ERROR = 6 + PROTOCOL_ERROR = 7 + INVALID_TRANSFORM = 8 + INVALID_PROTOCOL = 9 + UNSUPPORTED_CLIENT_TYPE = 10 +) + +// ApplicationException is for replacing apache.TApplicationException +// it implements ThriftFastCodec interface. +type ApplicationException struct { + t int32 + m string +} + +// NewApplicationException creates an ApplicationException instance +func NewApplicationException(t int32, msg string) *ApplicationException { + return &ApplicationException{t: t, m: msg} +} + +// Msg ... +func (e *ApplicationException) Msg() string { return e.m } + +// TypeID ... +func (e *ApplicationException) TypeID() int32 { return e.t } + +// TypeId ... for apache ApplicationException compatibility +func (e *ApplicationException) TypeId() int32 { return e.t } + +// BLength returns the len of encoded buffer. +func (e *ApplicationException) BLength() int { + // Msg Field: 1 (type) + 2 (id) + 4(strlen) + len(m) + // Type Field: 1 (type) + 2 (id) + 4(ex type) + // STOP: 1 byte + return (1 + 2 + 4 + len(e.m)) + (1 + 2 + 4) + 1 +} + +// FastRead ... +func (e *ApplicationException) FastRead(b []byte) (off int, err error) { + for i := 0; i < 2; i++ { + tp, id, l, err := Binary.ReadFieldBegin(b[off:]) + if err != nil { + return 0, err + } + off += l + switch { + case id == 1 && tp == STRING: // Msg + e.m, l, err = Binary.ReadString(b[off:]) + case id == 2 && tp == I32: // TypeID + e.t, l, err = Binary.ReadI32(b[off:]) + default: + l, err = Binary.Skip(b, tp) + } + if err != nil { + return 0, err + } + off += l + } + v, l, err := Binary.ReadByte(b[off:]) + if err != nil { + return 0, err + } + if v != STOP { + return 0, fmt.Errorf("expects thrift.STOP, found: %d", v) + } + off += l + return off, nil +} + +// FastWrite ... +func (e *ApplicationException) FastWrite(b []byte) (off int) { + off += Binary.WriteFieldBegin(b[off:], STRING, 1) + off += Binary.WriteString(b[off:], e.m) + off += Binary.WriteFieldBegin(b[off:], I32, 2) + off += Binary.WriteI32(b[off:], e.t) + off += Binary.WriteByte(b[off:], STOP) + return off +} + +// FastWriteNocopy ... +func (e *ApplicationException) FastWriteNocopy(b []byte, _ NocopyWriter) int { + return e.FastWrite(b) +} + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/exception.go +var defaultApplicationExceptionMessage = map[int32]string{ + UNKNOWN_APPLICATION_EXCEPTION: "unknown application exception", + UNKNOWN_METHOD: "unknown method", + INVALID_MESSAGE_TYPE_EXCEPTION: "invalid message type", + WRONG_METHOD_NAME: "wrong method name", + BAD_SEQUENCE_ID: "bad sequence ID", + MISSING_RESULT: "missing result", + INTERNAL_ERROR: "unknown internal error", + PROTOCOL_ERROR: "unknown protocol error", + INVALID_TRANSFORM: "Invalid transform", + INVALID_PROTOCOL: "Invalid protocol", + UNSUPPORTED_CLIENT_TYPE: "Unsupported client type", +} + +// Error implements apache.Exception +func (e *ApplicationException) Error() string { + if e.m != "" { + return e.m + } + if m, ok := defaultApplicationExceptionMessage[e.t]; ok { + return m + } + return fmt.Sprintf("unknown exception type [%d]", e.t) +} + +// TransportException is for replacing apache.TransportException +// it implements ThriftFastCodec interface. +type TransportException struct { + ApplicationException // same implementation ... +} + +// NewTransportException ... +func NewTransportException(t int32, m string) *TransportException { + ret := TransportException{} + ret.t = t + ret.m = m + return &ret +} + +// ProtocolException is for replacing apache.ProtocolException +// it implements ThriftFastCodec interface. +type ProtocolException struct { + ApplicationException // same implementation ... +} + +const ( // ProtocolException codes from apache thrift + UNKNOWN_PROTOCOL_EXCEPTION = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + NOT_IMPLEMENTED = 5 + DEPTH_LIMIT = 6 +) + +// NewTransportException ... +func NewProtocolException(t int32, m string) *ProtocolException { + ret := ProtocolException{} + ret.t = t + ret.m = m + return &ret +} + +// Generic Thrift exception with TypeId method +type tException interface { + Error() string + TypeId() int32 +} + +// Prepends additional information to an error without losing the Thrift exception interface +func PrependError(prepend string, err error) error { + if t, ok := err.(*TransportException); ok { + return NewTransportException(t.TypeID(), prepend+t.Error()) + } + if t, ok := err.(*ProtocolException); ok { + return NewProtocolException(t.TypeID(), prepend+err.Error()) + } + if t, ok := err.(*ApplicationException); ok { + return NewApplicationException(t.TypeID(), prepend+t.Error()) + } + if t, ok := err.(tException); ok { // apache thrift exception? + return NewApplicationException(t.TypeId(), prepend+t.Error()) + } + return errors.New(prepend + err.Error()) +} diff --git a/pkg/protocol/thrift/exception_test.go b/pkg/protocol/thrift/exception_test.go new file mode 100644 index 0000000..1d3d051 --- /dev/null +++ b/pkg/protocol/thrift/exception_test.go @@ -0,0 +1,68 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplicationException(t *testing.T) { + ex1 := NewApplicationException(1, "t1") + b := make([]byte, ex1.BLength()) + n := ex1.FastWrite(b) + assert.Equal(t, len(b), n) + + ex2 := NewApplicationException(0, "") + n, err := ex2.FastRead(b) + require.NoError(t, err) + assert.Equal(t, len(b), n) + assert.Equal(t, int32(1), ex2.TypeID()) + assert.Equal(t, "t1", ex2.Msg()) +} + +func TestPrependError(t *testing.T) { + var ok bool + ex0 := NewTransportException(1, "world") + err0 := PrependError("hello ", ex0) + ex0, ok = err0.(*TransportException) + require.True(t, ok) + assert.Equal(t, int32(1), ex0.TypeID()) + assert.Equal(t, "hello world", ex0.Error()) + + ex1 := NewProtocolException(2, "world") + err1 := PrependError("hello ", ex1) + ex1, ok = err1.(*ProtocolException) + require.True(t, ok) + assert.Equal(t, int32(2), ex1.TypeID()) + assert.Equal(t, "hello world", ex1.Error()) + + ex2 := NewApplicationException(3, "world") + err2 := PrependError("hello ", ex2) + ex2, ok = err2.(*ApplicationException) + require.True(t, ok) + assert.Equal(t, int32(3), ex2.TypeID()) + assert.Equal(t, "hello world", ex2.Error()) + + err3 := PrependError("hello ", errors.New("world")) + _, ok = err3.(tException) + require.False(t, ok) + assert.Equal(t, "hello world", err3.Error()) +} diff --git a/pkg/protocol/thrift/thrift.go b/pkg/protocol/thrift/thrift.go new file mode 100644 index 0000000..9a6968c --- /dev/null +++ b/pkg/protocol/thrift/thrift.go @@ -0,0 +1,58 @@ +package thrift + +// TMessageType represents message type constants in the Thrift protocol. +// originally from github.com/apache/thrift +type TMessageType = int32 // use alias for better flexibility of interfaces + +const ( + INVALID_TMESSAGE_TYPE TMessageType = 0 + CALL TMessageType = 1 + REPLY TMessageType = 2 + EXCEPTION TMessageType = 3 + ONEWAY TMessageType = 4 +) + +// TType represents field type constants in the Thrift protocol +// originally from github.com/apache/thrift +type TType = int8 // use alias for better flexibility of interfaces + +const ( + STOP TType = 0 + VOID TType = 1 + BOOL TType = 2 + BYTE TType = 3 + I08 TType = 3 + DOUBLE TType = 4 + I16 TType = 6 + I32 TType = 8 + I64 TType = 10 + STRING TType = 11 + UTF7 TType = 11 + STRUCT TType = 12 + MAP TType = 13 + SET TType = 14 + LIST TType = 15 + UTF8 TType = 16 + UTF16 TType = 17 +) + +const defaultRecursionDepth = 64 // for skip + +const ( // for Write/ReadMessage + msgVersion1 = 0x80010000 + msgVersionMask = 0xffff0000 + msgTypeMask = 0x0000ffff // for TMessageType +) + +// BinaryWriter represents the method used in thrift encoding for nocopy writes +// It supports netpoll nocopy feature, see: https://github.com/cloudwego/netpoll/blob/develop/nocopy.go +type NocopyWriter interface { + WriteDirect(b []byte, remainCap int) error +} + +// ThriftFastCodec represents the interface of thrift fastcodec generated structs +type ThriftFastCodec interface { + BLength() int + FastWriteNocopy(buf []byte, bw NocopyWriter) int + FastRead(buf []byte) (int, error) +}