From fe8e047da97539628361ea6bfa0531aa84693ac5 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Thu, 11 Jul 2024 14:41:36 +0800 Subject: [PATCH] feat(protocol): add thrift --- .gitignore | 3 + check_branch_name.sh | 10 - go.mod | 13 + go.sum | 26 ++ internal/unsafe/unsafe.go | 46 +++ internal/unsafe/unsafe_test.go | 30 ++ profile/README.md | 13 - protocol/thrift/binary.go | 483 ++++++++++++++++++++++++++++++ protocol/thrift/binary_test.go | 419 ++++++++++++++++++++++++++ protocol/thrift/exception.go | 198 ++++++++++++ protocol/thrift/exception_test.go | 95 ++++++ protocol/thrift/thrift.go | 76 +++++ 12 files changed, 1389 insertions(+), 23 deletions(-) delete mode 100755 check_branch_name.sh create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/unsafe/unsafe.go create mode 100644 internal/unsafe/unsafe_test.go delete mode 100644 profile/README.md create mode 100644 protocol/thrift/binary.go create mode 100644 protocol/thrift/binary_test.go create mode 100644 protocol/thrift/exception.go create mode 100644 protocol/thrift/exception_test.go create mode 100644 protocol/thrift/thrift.go diff --git a/.gitignore b/.gitignore index 644df37..4634819 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,6 @@ output/* # Vscode files .vscode +# Go workspace file +go.work +go.work.sum diff --git a/check_branch_name.sh b/check_branch_name.sh deleted file mode 100755 index 1876fc0..0000000 --- a/check_branch_name.sh +++ /dev/null @@ -1,10 +0,0 @@ -#!/usr/bin/env bash - -current=$(git status | head -n1 | sed 's/On branch //') -name=${1:-$current} -if [[ ! $name =~ ^(((opt(imize)?|feat(ure)?|(bug|hot)?fix|test|refact(or)?|ci)/.+)|(main|develop)|(release-v[0-9]+\.[0-9]+)|(release/v[0-9]+\.[0-9]+\.[0-9]+(-[a-z0-9.]+(\+[a-z0-9.]+)?)?)|revert-[a-z0-9]+)$ ]]; then - echo "branch name '$name' is invalid" - exit 1 -else - echo "branch name '$name' is valid" -fi diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..8ad4bb4 --- /dev/null +++ b/go.mod @@ -0,0 +1,13 @@ +module github.com/cloudwego/gopkg + +go 1.17 + +require github.com/stretchr/testify v1.9.0 + +require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/kr/pretty v0.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..340439d --- /dev/null +++ b/go.sum @@ -0,0 +1,26 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/unsafe/unsafe.go b/internal/unsafe/unsafe.go new file mode 100644 index 0000000..4006280 --- /dev/null +++ b/internal/unsafe/unsafe.go @@ -0,0 +1,46 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package unsafe + +import "unsafe" + +type sliceHeader struct { + Data uintptr + Len int + Cap int +} + +type strHeader struct { + Data uintptr + Len int +} + +// ByteSliceToString converts []byte to string without copy +func ByteSliceToString(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// StringToByteSlice converts string to []byte without copy +func StringToByteSlice(s string) []byte { + var v []byte + p0 := (*sliceHeader)(unsafe.Pointer(&v)) + p1 := (*strHeader)(unsafe.Pointer(&s)) + p0.Data = p1.Data + p0.Len = p1.Len + p0.Cap = p1.Len + return v +} diff --git a/internal/unsafe/unsafe_test.go b/internal/unsafe/unsafe_test.go new file mode 100644 index 0000000..92c2aa0 --- /dev/null +++ b/internal/unsafe/unsafe_test.go @@ -0,0 +1,30 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package unsafe + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUnsafe(t *testing.T) { + s := "hello" + b := []byte("hello") + assert.Equal(t, s, ByteSliceToString(b)) + assert.Equal(t, b, StringToByteSlice(s)) +} diff --git a/profile/README.md b/profile/README.md deleted file mode 100644 index 2127160..0000000 --- a/profile/README.md +++ /dev/null @@ -1,13 +0,0 @@ -## Hi there 👋 - -🙋‍♀️ A short introduction - CloudWeGo is an open-source middleware set launched by ByteDance that can be used to quickly build enterprise-class cloud native architectures. The common characteristics of CloudWeGo projects are high performance, high scalability, high reliability and focusing on microservices communication and governance. - -🌈 Community Membership - the [Responsibilities and Requirements](https://github.com/cloudwego/community/blob/main/COMMUNITY_MEMBERSHIP.md) of contributor roles in CloudWeGo. - -👩‍💻 Useful resources - [Portal](https://www.cloudwego.io/), [Community](https://www.cloudwego.io/zh/community/), [Blogs](https://www.cloudwego.io/zh/blog/), [Use Cases](https://www.cloudwego.io/zh/cooperation/) - -🍿 Security - [Vulnerability Reporting](https://www.cloudwego.io/zh/security/vulnerability-reporting/), [Safety Bulletin](https://www.cloudwego.io/zh/security/safety-bulletin/) - -🌲 Ecosystem - [Kitex-contrib](https://github.com/kitex-contrib), [Hertz-contrib](https://github.com/hertz-contrib), [Volo-rs](https://github.com/volo-rs) - -🎊 Example - [kitex-example](https://github.com/cloudwego/kitex-examples), [hertz-example](https://github.com/cloudwego/hertz-examples), [biz-demo](https://github.com/cloudwego/biz-demo), [netpoll-example](https://github.com/cloudwego/netpoll-examples) diff --git a/protocol/thrift/binary.go b/protocol/thrift/binary.go new file mode 100644 index 0000000..7645ba7 --- /dev/null +++ b/protocol/thrift/binary.go @@ -0,0 +1,483 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "encoding/binary" + "fmt" + "math" + + "github.com/cloudwego/gopkg/internal/unsafe" +) + +var Binary binaryProtocol + +type binaryProtocol struct{} + +func (binaryProtocol) WriteMessageBegin(buf []byte, name string, typeID TMessageType, seq int32) int { + binary.BigEndian.PutUint32(buf, uint32(msgVersion1)|uint32(typeID&msgTypeMask)) + binary.BigEndian.PutUint32(buf[4:], uint32(len(name))) + off := 8 + copy(buf[8:], name) + binary.BigEndian.PutUint32(buf[off:], uint32(seq)) + return off + 4 +} + +func (binaryProtocol) WriteFieldBegin(buf []byte, typeID TType, id int16) int { + buf[0] = byte(typeID) + binary.BigEndian.PutUint16(buf[1:], uint16(id)) + return 3 +} + +func (binaryProtocol) WriteFieldStop(buf []byte) int { + buf[0] = byte(STOP) + return 1 +} + +func (binaryProtocol) WriteMapBegin(buf []byte, kt, vt TType, size int) int { + buf[0] = byte(kt) + buf[1] = byte(vt) + binary.BigEndian.PutUint32(buf[2:], uint32(size)) + return 6 +} + +func (binaryProtocol) WriteListBegin(buf []byte, et TType, size int) int { + buf[0] = byte(et) + binary.BigEndian.PutUint32(buf[1:], uint32(size)) + return 5 +} + +func (binaryProtocol) WriteSetBegin(buf []byte, et TType, size int) int { + buf[0] = byte(et) + binary.BigEndian.PutUint32(buf[1:], uint32(size)) + return 5 +} + +func (binaryProtocol) WriteBool(buf []byte, v bool) int { + if v { + buf[0] = 1 + } else { + buf[0] = 0 + } + return 1 +} + +func (binaryProtocol) WriteByte(buf []byte, v int8) int { + buf[0] = byte(v) + return 1 +} + +func (binaryProtocol) WriteI16(buf []byte, v int16) int { + binary.BigEndian.PutUint16(buf, uint16(v)) + return 2 +} + +func (binaryProtocol) WriteI32(buf []byte, v int32) int { + binary.BigEndian.PutUint32(buf, uint32(v)) + return 4 +} + +func (binaryProtocol) WriteI64(buf []byte, v int64) int { + binary.BigEndian.PutUint64(buf, uint64(v)) + return 8 +} + +func (binaryProtocol) WriteDouble(buf []byte, v float64) int { + binary.BigEndian.PutUint64(buf, math.Float64bits(v)) + return 8 +} + +func (binaryProtocol) WriteBinary(buf, v []byte) int { + binary.BigEndian.PutUint32(buf, uint32(len(v))) + return 4 + copy(buf[4:], v) +} + +func (binaryProtocol) WriteBinaryNocopy(buf []byte, w NocopyWriter, v []byte) int { + if w == nil || len(buf) < NocopyWriteThreshold { + return Binary.WriteBinary(buf, v) + } + binary.BigEndian.PutUint32(buf, uint32(len(v))) + _ = w.WriteDirect(v, len(buf[4:])) // always err == nil ? + return 4 +} + +func (binaryProtocol) WriteString(buf []byte, v string) int { + binary.BigEndian.PutUint32(buf, uint32(len(v))) + return 4 + copy(buf[4:], v) +} + +func (binaryProtocol) WriteStringNocopy(buf []byte, w NocopyWriter, v string) int { + return Binary.WriteBinaryNocopy(buf, w, unsafe.StringToByteSlice(v)) +} + +// Append methods + +func (binaryProtocol) AppendMessageBegin(buf []byte, name string, typeID TMessageType, seq int32) []byte { + buf = appendUint32(buf, uint32(msgVersion1)|uint32(typeID&msgTypeMask)) + buf = Binary.AppendString(buf, name) + return Binary.AppendI32(buf, seq) +} + +func (binaryProtocol) AppendFieldBegin(buf []byte, typeID TType, id int16) []byte { + return append(buf, byte(typeID), byte(uint16(id>>8)), byte(id)) +} + +func (binaryProtocol) AppendFieldStop(buf []byte) []byte { + return append(buf, byte(STOP)) +} + +func (binaryProtocol) AppendMapBegin(buf []byte, kt, vt TType, size int) []byte { + return Binary.AppendI32(append(buf, byte(kt), byte(vt)), int32(size)) +} + +func (binaryProtocol) AppendListBegin(buf []byte, et TType, size int) []byte { + return Binary.AppendI32(append(buf, byte(et)), int32(size)) +} + +func (binaryProtocol) AppendSetBegin(buf []byte, et TType, size int) []byte { + return Binary.AppendI32(append(buf, byte(et)), int32(size)) +} + +func (binaryProtocol) AppendBinary(buf, v []byte) []byte { + return append(Binary.AppendI32(buf, int32(len(v))), v...) +} + +func (binaryProtocol) AppendString(buf []byte, v string) []byte { + return append(Binary.AppendI32(buf, int32(len(v))), v...) +} + +func (binaryProtocol) AppendBool(buf []byte, v bool) []byte { + if v { + return append(buf, 1) + } else { + return append(buf, 0) + } +} + +func (binaryProtocol) AppendByte(buf []byte, v int8) []byte { + return append(buf, byte(v)) +} + +func (binaryProtocol) AppendI16(buf []byte, v int16) []byte { + return append(buf, byte(uint16(v)>>8), byte(v)) +} + +func (binaryProtocol) AppendI32(buf []byte, v int32) []byte { + return appendUint32(buf, uint32(v)) +} + +func (binaryProtocol) AppendI64(buf []byte, v int64) []byte { + return appendUint64(buf, uint64(v)) +} + +func (binaryProtocol) AppendDouble(buf []byte, v float64) []byte { + return appendUint64(buf, math.Float64bits(v)) +} + +func appendUint32(buf []byte, v uint32) []byte { + return append(buf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +func appendUint64(buf []byte, v uint64) []byte { + return append(buf, byte(v>>56), byte(v>>48), byte(v>>40), byte(v>>32), + byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +// Length methods + +func (binaryProtocol) MessageBeginLength(name string, _ TMessageType, _ int32) int { + return 4 + (4 + len(name)) + 4 +} + +func (binaryProtocol) FieldBeginLength() int { return 3 } +func (binaryProtocol) FieldStopLength() int { return 1 } +func (binaryProtocol) MapBeginLength() int { return 6 } +func (binaryProtocol) ListBeginLength() int { return 5 } +func (binaryProtocol) SetBeginLength() int { return 5 } +func (binaryProtocol) BoolLength() int { return 1 } +func (binaryProtocol) ByteLength() int { return 1 } +func (binaryProtocol) I16Length() int { return 2 } +func (binaryProtocol) I32Length() int { return 4 } +func (binaryProtocol) I64Length() int { return 8 } +func (binaryProtocol) DoubleLength() int { return 8 } +func (binaryProtocol) StringLength(v string) int { return 4 + len(v) } +func (binaryProtocol) BinaryLength(v []byte) int { return 4 + len(v) } +func (binaryProtocol) StringLengthNocopy(v string) int { return 4 + len(v) } +func (binaryProtocol) BinaryLengthNocopy(v []byte) int { return 4 + len(v) } + +// Read methods + +var ( + errReadMessage = NewProtocolException(INVALID_DATA, "ReadMessageBegin: buf too small") + errBadVersion = NewProtocolException(BAD_VERSION, "ReadMessageBegin: bad version") +) + +func (binaryProtocol) ReadMessageBegin(buf []byte) (name string, typeID TMessageType, seq int32, l int, err error) { + if len(buf) < 4 { // version+type header + name header + return "", 0, 0, 0, errReadMessage + } + + // read header for version and type + header := binary.BigEndian.Uint32(buf) + if header&msgVersionMask != msgVersion1 { + return "", 0, 0, 0, errBadVersion + } + typeID = TMessageType(header & msgTypeMask) + + off := 4 + + // read method name + name, l, err1 := Binary.ReadString(buf[off:]) + if err1 != nil { + return "", 0, 0, 0, errReadMessage + } + off += l + + // read seq + seq, l, err2 := Binary.ReadI32(buf[off:]) + if err2 != nil { + return "", 0, 0, 0, errReadMessage + } + off += l + return name, typeID, seq, off, nil +} + +var ( + errReadField = NewProtocolException(INVALID_DATA, "ReadFieldBegin: buf too small") + errReadMap = NewProtocolException(INVALID_DATA, "ReadMapBegin: buf too small") + errReadList = NewProtocolException(INVALID_DATA, "ReadListBegin: buf too small") + errReadSet = NewProtocolException(INVALID_DATA, "ReadSetBegin: buf too small") + errReadStr = NewProtocolException(INVALID_DATA, "ReadString: buf too small") + errReadBin = NewProtocolException(INVALID_DATA, "ReadBinary: buf too small") + + errReadBool = NewProtocolException(INVALID_DATA, "ReadBool: len(buf) < 1") + errReadByte = NewProtocolException(INVALID_DATA, "ReadByte: len(buf) < 1") + errReadI16 = NewProtocolException(INVALID_DATA, "ReadI16: len(buf) < 2") + errReadI32 = NewProtocolException(INVALID_DATA, "ReadI32: len(buf) < 4") + errReadI64 = NewProtocolException(INVALID_DATA, "ReadI64: len(buf) < 8") + errReadDouble = NewProtocolException(INVALID_DATA, "ReadDouble: len(buf) < 8") +) + +func (binaryProtocol) ReadFieldBegin(buf []byte) (typeID TType, id int16, l int, err error) { + if len(buf) < 1 { + return 0, 0, 0, errReadField + } + typeID = TType(buf[0]) + if typeID == STOP { + return STOP, 0, 1, nil + } + if len(buf) < 3 { + return 0, 0, 0, errReadField + } + return typeID, int16(binary.BigEndian.Uint16(buf[1:])), 3, nil +} + +func (binaryProtocol) ReadMapBegin(buf []byte) (kt, vt TType, size, l int, err error) { + if len(buf) < 6 { + return 0, 0, 0, 0, errReadMap + } + return TType(buf[0]), TType(buf[1]), int(binary.BigEndian.Uint32(buf[2:])), 6, nil +} + +func (binaryProtocol) ReadListBegin(buf []byte) (et TType, size, l int, err error) { + if len(buf) < 5 { + return 0, 0, 0, errReadList + } + return TType(buf[0]), int(binary.BigEndian.Uint32(buf[1:])), 5, nil +} + +func (binaryProtocol) ReadSetBegin(buf []byte) (et TType, size, l int, err error) { + if len(buf) < 5 { + return 0, 0, 0, errReadSet + } + return TType(buf[0]), int(binary.BigEndian.Uint32(buf[1:])), 5, nil +} + +func (binaryProtocol) ReadBinary(buf []byte) (b []byte, l int, err error) { + sz, _, err := Binary.ReadI32(buf) + if err != nil { + return nil, 0, errReadBin + } + l = 4 + int(sz) + if len(buf) < l { + return nil, 4, errReadBin + } + // TODO: use span + return []byte(string(buf[4:l])), l, nil +} + +func (binaryProtocol) ReadString(buf []byte) (s string, l int, err error) { + sz, _, err := Binary.ReadI32(buf) + if err != nil { + return "", 0, errReadStr + } + l = 4 + int(sz) + if len(buf) < l { + return "", 4, errReadStr + } + // TODO: use span + return string(buf[4:l]), l, nil +} + +func (binaryProtocol) ReadBool(buf []byte) (v bool, l int, err error) { + if len(buf) < 1 { + return false, 0, errReadBool + } + if buf[0] == 1 { + return true, 1, nil + } + return false, 1, nil +} + +func (binaryProtocol) ReadByte(buf []byte) (v int8, l int, err error) { + if len(buf) < 1 { + return 0, 0, errReadByte + } + return int8(buf[0]), 1, nil +} + +func (binaryProtocol) ReadI16(buf []byte) (v int16, l int, err error) { + if len(buf) < 2 { + return 0, 0, errReadI16 + } + return int16(binary.BigEndian.Uint16(buf)), 2, nil +} + +func (binaryProtocol) ReadI32(buf []byte) (v int32, l int, err error) { + if len(buf) < 4 { + return 0, 0, errReadI32 + } + return int32(binary.BigEndian.Uint32(buf)), 4, nil +} + +func (binaryProtocol) ReadI64(buf []byte) (v int64, l int, err error) { + if len(buf) < 8 { + return 0, 0, errReadI64 + } + return int64(binary.BigEndian.Uint64(buf)), 8, nil +} + +func (binaryProtocol) ReadDouble(buf []byte) (v float64, l int, err error) { + if len(buf) < 8 { + return 0, 0, errReadDouble + } + return math.Float64frombits(binary.BigEndian.Uint64(buf)), 8, nil +} + +var ( + errDepthLimitExceeded = NewProtocolException(DEPTH_LIMIT, "depth limit exceeded") + errNegativeSize = NewProtocolException(NEGATIVE_SIZE, "negative size") +) + +var typeToSize = [256]int8{ + BOOL: 1, + BYTE: 1, + DOUBLE: 8, + I16: 2, + I32: 4, + I64: 8, +} + +func skipstr(b []byte) int { + return 4 + int(binary.BigEndian.Uint32(b)) +} + +// Skip skips over the value for the given type using Go implementation. +func (binaryProtocol) Skip(b []byte, t TType) (int, error) { + return skipType(b, t, defaultRecursionDepth) +} + +func skipType(b []byte, t TType, maxdepth int) (int, error) { + if maxdepth == 0 { + return 0, errDepthLimitExceeded + } + if n := typeToSize[t]; n > 0 { + return int(n), nil + } + switch t { + case STRING: + return skipstr(b), nil + case MAP: + i := 6 + kt, vt, sz := TType(b[0]), TType(b[1]), int32(binary.BigEndian.Uint32(b[2:])) + if sz < 0 { + return 0, errNegativeSize + } + ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt]) + if ksz > 0 && vsz > 0 { + return i + (int(sz) * (ksz + vsz)), nil + } + for j := int32(0); j < sz; j++ { + if ksz > 0 { + i += ksz + } else if kt == STRING { + i += skipstr(b[i:]) + } else if n, err := skipType(b[i:], kt, maxdepth-1); err != nil { + return i, err + } else { + i += n + } + if vsz > 0 { + i += vsz + } else if vt == STRING { + i += skipstr(b[i:]) + } else if n, err := skipType(b[i:], vt, maxdepth-1); err != nil { + return i, err + } else { + i += n + } + } + return i, nil + case LIST, SET: + i := 5 + vt, sz := TType(b[0]), int32(binary.BigEndian.Uint32(b[1:])) + if sz < 0 { + return 0, errNegativeSize + } + if typeToSize[vt] > 0 { + return i + int(sz)*int(typeToSize[vt]), nil + } + for j := int32(0); j < sz; j++ { + if vt == STRING { + i += skipstr(b[i:]) + } else if n, err := skipType(b[i:], vt, maxdepth-1); err != nil { + return i, err + } else { + i += n + } + } + return i, nil + case STRUCT: + i := 0 + for { + ft := TType(b[i]) + i += 1 // TType + if ft == STOP { + return i, nil + } + i += 2 // Field ID + if typeToSize[ft] > 0 { + i += int(typeToSize[ft]) + } else if n, err := skipType(b[i:], ft, maxdepth-1); err != nil { + return i, err + } else { + i += n + } + } + default: + return 0, NewProtocolException(INVALID_DATA, fmt.Sprintf("unknown data type %d", t)) + } +} diff --git a/protocol/thrift/binary_test.go b/protocol/thrift/binary_test.go new file mode 100644 index 0000000..dbb1ce7 --- /dev/null +++ b/protocol/thrift/binary_test.go @@ -0,0 +1,419 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBinary(t *testing.T) { + + { // Bool + sz := 2 * Binary.BoolLength() + + b := Binary.AppendBool([]byte(nil), true) + b = Binary.AppendBool(b, false) + require.Equal(t, int(sz), len(b)) + + b1 := make([]byte, sz) + l := Binary.WriteBool(b1, true) + l += Binary.WriteBool(b1[l:], false) + require.Equal(t, int(sz), l) + require.Equal(t, b, b1) + + v, l, _ := Binary.ReadBool(b) + require.Equal(t, 1, l) + require.True(t, v) + v, l, _ = Binary.ReadBool(b[1:]) + require.Equal(t, 1, l) + require.False(t, v) + + _, _, err := Binary.ReadBool([]byte(nil)) + require.Same(t, errReadBool, err) + } + + { // Byte + sz := Binary.ByteLength() + + b := Binary.AppendByte([]byte(nil), 1) + require.Equal(t, int(sz), len(b)) + + b1 := make([]byte, sz) + l := Binary.WriteByte(b1, 1) + require.Equal(t, int(sz), l) + require.Equal(t, b, b1) + + v, l, _ := Binary.ReadByte(b) + require.Equal(t, 1, l) + require.Equal(t, int8(1), v) + + _, _, err := Binary.ReadByte([]byte(nil)) + require.Same(t, errReadByte, err) + } + + { // I16 + testv := int16(0x7f) + sz := Binary.I16Length() + + b := Binary.AppendI16([]byte(nil), testv) + require.Equal(t, int(sz), len(b)) + + b1 := make([]byte, sz) + l := Binary.WriteI16(b1, testv) + require.Equal(t, int(sz), l) + require.Equal(t, b, b1) + + v, l, _ := Binary.ReadI16(b) + require.Equal(t, sz, l) + require.Equal(t, testv, v) + + _, _, err := Binary.ReadI16([]byte(nil)) + require.Same(t, errReadI16, err) + } + + { // I32 + testv := int32(0x7fffffff) + sz := Binary.I32Length() + + b := Binary.AppendI32([]byte(nil), testv) + require.Equal(t, int(sz), len(b)) + + b1 := make([]byte, sz) + l := Binary.WriteI32(b1, testv) + require.Equal(t, int(sz), l) + require.Equal(t, b, b1) + + v, l, _ := Binary.ReadI32(b) + require.Equal(t, sz, l) + require.Equal(t, testv, v) + + _, _, err := Binary.ReadI32([]byte(nil)) + require.Same(t, errReadI32, err) + } + + { // I64 + testv := int64(0x7fffffff7fffffff) + sz := Binary.I64Length() + + b := Binary.AppendI64([]byte(nil), testv) + require.Equal(t, int(sz), len(b)) + + b1 := make([]byte, sz) + l := Binary.WriteI64(b1, testv) + require.Equal(t, int(sz), l) + require.Equal(t, b, b1) + + v, l, _ := Binary.ReadI64(b) + require.Equal(t, sz, l) + require.Equal(t, testv, v) + + _, _, err := Binary.ReadI64([]byte(nil)) + require.Same(t, errReadI64, err) + } + + { // Double + testv := float64(0.125) + sz := Binary.DoubleLength() + + b := Binary.AppendDouble([]byte(nil), testv) + require.Equal(t, int(sz), len(b)) + + b1 := make([]byte, sz) + l := Binary.WriteDouble(b1, testv) + require.Equal(t, int(sz), l) + require.Equal(t, b, b1) + + v, l, _ := Binary.ReadDouble(b) + require.Equal(t, sz, l) + require.Equal(t, testv, v) + + _, _, err := Binary.ReadDouble([]byte(nil)) + require.Same(t, errReadDouble, err) + } + + { // Binary + testv := []byte("hello") + sz := Binary.BinaryLength(testv) + + b := Binary.AppendBinary([]byte(nil), testv) + require.Equal(t, int(sz), len(b)) + + b1 := make([]byte, sz) + l := Binary.WriteBinaryNocopy(b1, nil, testv) + require.Equal(t, int(sz), l) + require.Equal(t, b, b1) + + v, l, _ := Binary.ReadBinary(b) + require.Equal(t, sz, l) + require.Equal(t, testv, v) + + _, _, err := Binary.ReadBinary([]byte(nil)) + require.Same(t, errReadBin, err) + } + + { // String + testv := "hello" + sz := Binary.StringLength(testv) + + b := Binary.AppendString([]byte(nil), testv) + require.Equal(t, int(sz), len(b)) + + b1 := make([]byte, sz) + l := Binary.WriteStringNocopy(b1, nil, testv) + require.Equal(t, int(sz), l) + require.Equal(t, b, b1) + + v, l, _ := Binary.ReadString(b) + require.Equal(t, sz, l) + require.Equal(t, testv, v) + + _, _, err := Binary.ReadString([]byte(nil)) + require.Same(t, errReadStr, err) + } + + { // Message + testname, testtyp, testseq := "name", CALL, int32(7) + sz := Binary.MessageBeginLength(testname, testtyp, testseq) + + b := Binary.AppendMessageBegin([]byte(nil), testname, testtyp, testseq) + require.Equal(t, int(sz), len(b)) + + b1 := make([]byte, sz) + l := Binary.WriteMessageBegin(b1, testname, testtyp, testseq) + require.Equal(t, int(sz), l) + require.Equal(t, b, b1) + + name, typ, seq, l, _ := Binary.ReadMessageBegin(b) + require.Equal(t, sz, l) + require.Equal(t, testname, name) + require.Equal(t, testtyp, typ) + require.Equal(t, testseq, seq) + + _, _, _, _, err := Binary.ReadMessageBegin([]byte(nil)) + require.Same(t, errReadMessage, err) + } + + { // Field + testtyp, testfid := I64, int16(7) + sz := Binary.FieldBeginLength() + Binary.FieldStopLength() + + b := Binary.AppendFieldBegin([]byte(nil), testtyp, testfid) + b = Binary.AppendFieldStop(b) + require.Equal(t, int(sz), len(b)) + + b1 := make([]byte, sz) + l := Binary.WriteFieldBegin(b1, testtyp, testfid) + l += Binary.WriteFieldStop(b1[l:]) + require.Equal(t, int(sz), l) + require.Equal(t, b, b1) + + typ, fid, l, _ := Binary.ReadFieldBegin(b) + require.Equal(t, sz, l+1) // +STOP + require.Equal(t, testtyp, typ) + require.Equal(t, testfid, fid) + + typ, _, l, err := Binary.ReadFieldBegin(b[l:]) + require.NoError(t, err) + require.Equal(t, 1, l) + require.Equal(t, STOP, typ) + + _, _, _, err = Binary.ReadFieldBegin([]byte(nil)) + require.Same(t, errReadField, err) + } + + { // Map + testkt, testvt, testsize := I64, I32, 7 + sz := Binary.MapBeginLength() + + b := Binary.AppendMapBegin([]byte(nil), testkt, testvt, testsize) + require.Equal(t, int(sz), len(b)) + + b1 := make([]byte, sz) + l := Binary.WriteMapBegin(b1, testkt, testvt, testsize) + require.Equal(t, int(sz), l) + require.Equal(t, b, b1) + + kt, vt, size, l, _ := Binary.ReadMapBegin(b) + require.Equal(t, sz, l) + require.Equal(t, testkt, kt) + require.Equal(t, testvt, vt) + require.Equal(t, testsize, size) + + _, _, _, _, err := Binary.ReadMapBegin([]byte(nil)) + require.Same(t, errReadMap, err) + } + + { // List + testvt, testsize := I32, 7 + sz := Binary.ListBeginLength() + + b := Binary.AppendListBegin([]byte(nil), testvt, testsize) + require.Equal(t, int(sz), len(b)) + + b1 := make([]byte, sz) + l := Binary.WriteListBegin(b1, testvt, testsize) + require.Equal(t, int(sz), l) + require.Equal(t, b, b1) + + vt, size, l, _ := Binary.ReadListBegin(b) + require.Equal(t, sz, l) + require.Equal(t, testvt, vt) + require.Equal(t, testsize, size) + + _, _, _, err := Binary.ReadListBegin([]byte(nil)) + require.Same(t, errReadList, err) + } + + { // Set + testvt, testsize := I32, 7 + sz := Binary.SetBeginLength() + + b := Binary.AppendSetBegin([]byte(nil), testvt, testsize) + require.Equal(t, int(sz), len(b)) + + b1 := make([]byte, sz) + l := Binary.WriteSetBegin(b1, testvt, testsize) + require.Equal(t, int(sz), l) + require.Equal(t, b, b1) + + vt, size, l, _ := Binary.ReadSetBegin(b) + require.Equal(t, sz, l) + require.Equal(t, testvt, vt) + require.Equal(t, testsize, size) + + _, _, _, err := Binary.ReadSetBegin([]byte(nil)) + require.Same(t, errReadSet, err) + } +} + +func TestBinarySkip(t *testing.T) { + // byte + b := Binary.AppendByte([]byte(nil), 1) + + // string + b = Binary.AppendString(b, "hello") + + // list + b = Binary.AppendListBegin(b, I32, 1) + b = Binary.AppendI32(b, 1) + + // list + b = Binary.AppendListBegin(b, STRING, 1) + b = Binary.AppendString(b, "hello") + + // list> + b = Binary.AppendListBegin(b, LIST, 1) + b = Binary.AppendListBegin(b, I32, 1) + b = Binary.AppendI32(b, 1) + + // map + b = Binary.AppendMapBegin(b, I32, I64, 1) + b = Binary.AppendI32(b, 1) + b = Binary.AppendI64(b, 2) + + // map + b = Binary.AppendMapBegin(b, I32, STRING, 1) + b = Binary.AppendI32(b, 1) + b = Binary.AppendString(b, "hello") + + // map + b = Binary.AppendMapBegin(b, STRING, I64, 1) + b = Binary.AppendString(b, "hello") + b = Binary.AppendI64(b, 2) + + // map> + b = Binary.AppendMapBegin(b, I32, LIST, 1) + b = Binary.AppendI32(b, 1) + b = Binary.AppendListBegin(b, I32, 1) + b = Binary.AppendI32(b, 1) + + // map, i32> + b = Binary.AppendMapBegin(b, LIST, I32, 1) + b = Binary.AppendListBegin(b, I32, 1) + b = Binary.AppendI32(b, 1) + b = Binary.AppendI32(b, 1) + + // struct i32, list + b = Binary.AppendFieldBegin(b, I32, 1) + b = Binary.AppendI32(b, 1) + b = Binary.AppendFieldBegin(b, LIST, 1) + b = Binary.AppendListBegin(b, I32, 1) + b = Binary.AppendI32(b, 1) + b = Binary.AppendFieldStop(b) + + off := 0 + + l, err := Binary.Skip(b[off:], BYTE) + require.NoError(t, err) + off += l + + l, err = Binary.Skip(b[off:], STRING) + require.NoError(t, err) + off += l + + l, err = Binary.Skip(b[off:], LIST) // list + require.NoError(t, err) + off += l + + l, err = Binary.Skip(b[off:], LIST) // list + require.NoError(t, err) + off += l + + l, err = Binary.Skip(b[off:], LIST) // list> + require.NoError(t, err) + off += l + + l, err = Binary.Skip(b[off:], MAP) // map + require.NoError(t, err) + off += l + + l, err = Binary.Skip(b[off:], MAP) // map + require.NoError(t, err) + off += l + + l, err = Binary.Skip(b[off:], MAP) // map + require.NoError(t, err) + off += l + + l, err = Binary.Skip(b[off:], MAP) // map> + require.NoError(t, err) + off += l + + l, err = Binary.Skip(b[off:], MAP) // map, i32> + require.NoError(t, err) + off += l + + l, err = Binary.Skip(b[off:], STRUCT) // struct i32, list + require.NoError(t, err) + off += l + + require.Equal(t, len(b), off) + + // errDepthLimitExceeded + b = b[:0] + for i := 0; i < defaultRecursionDepth+1; i++ { + b = Binary.AppendFieldBegin(b, STRUCT, 1) + } + _, err = Binary.Skip(b, STRUCT) + require.Same(t, errDepthLimitExceeded, err) + + // unknown type + _, err = Binary.Skip(b, TType(122)) + require.Error(t, err) +} diff --git a/protocol/thrift/exception.go b/protocol/thrift/exception.go new file mode 100644 index 0000000..a1e069e --- /dev/null +++ b/protocol/thrift/exception.go @@ -0,0 +1,198 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "errors" + "fmt" +) + +const ( // ApplicationException codes from apache thrift + UNKNOWN_APPLICATION_EXCEPTION = 0 + UNKNOWN_METHOD = 1 + INVALID_MESSAGE_TYPE_EXCEPTION = 2 + WRONG_METHOD_NAME = 3 + BAD_SEQUENCE_ID = 4 + MISSING_RESULT = 5 + INTERNAL_ERROR = 6 + PROTOCOL_ERROR = 7 + INVALID_TRANSFORM = 8 + INVALID_PROTOCOL = 9 + UNSUPPORTED_CLIENT_TYPE = 10 +) + +// ApplicationException is for replacing apache.TApplicationException +// it implements ThriftFastCodec interface. +type ApplicationException struct { + t int32 + m string +} + +// NewApplicationException creates an ApplicationException instance +func NewApplicationException(t int32, msg string) *ApplicationException { + return &ApplicationException{t: t, m: msg} +} + +// Msg ... +func (e *ApplicationException) Msg() string { return e.m } + +// TypeID ... for kitex +func (e *ApplicationException) TypeID() int32 { return e.t } + +// TypeId ... for apache ApplicationException compatibility +func (e *ApplicationException) TypeId() int32 { return e.t } + +// BLength returns the len of encoded buffer. +func (e *ApplicationException) BLength() int { + return Binary.FieldBeginLength() + Binary.StringLength(e.m) + // e.m + Binary.FieldBeginLength() + Binary.I32Length() + // e.t + Binary.FieldStopLength() // STOP +} + +// FastRead ... +func (e *ApplicationException) FastRead(b []byte) (off int, err error) { + for { + tp, id, l, err := Binary.ReadFieldBegin(b[off:]) + if err != nil { + return off, err + } + off += l + if tp == STOP { + break + } + switch { + case id == 1 && tp == STRING: // Msg + e.m, l, err = Binary.ReadString(b[off:]) + case id == 2 && tp == I32: // TypeID + e.t, l, err = Binary.ReadI32(b[off:]) + default: + l, err = Binary.Skip(b, tp) + } + if err != nil { + return off, err + } + off += l + } + return off, nil +} + +// FastWrite ... +func (e *ApplicationException) FastWrite(b []byte) (off int) { + off += Binary.WriteFieldBegin(b[off:], STRING, 1) + off += Binary.WriteString(b[off:], e.m) + off += Binary.WriteFieldBegin(b[off:], I32, 2) + off += Binary.WriteI32(b[off:], e.t) + off += Binary.WriteByte(b[off:], STOP) + return off +} + +// FastWriteNocopy ... +func (e *ApplicationException) FastWriteNocopy(b []byte, _ NocopyWriter) int { + return e.FastWrite(b) +} + +// originally from github.com/apache/thrift@v0.13.0/lib/go/thrift/exception.go +var defaultApplicationExceptionMessage = map[int32]string{ + UNKNOWN_APPLICATION_EXCEPTION: "unknown application exception", + UNKNOWN_METHOD: "unknown method", + INVALID_MESSAGE_TYPE_EXCEPTION: "invalid message type", + WRONG_METHOD_NAME: "wrong method name", + BAD_SEQUENCE_ID: "bad sequence ID", + MISSING_RESULT: "missing result", + INTERNAL_ERROR: "unknown internal error", + PROTOCOL_ERROR: "unknown protocol error", + INVALID_TRANSFORM: "Invalid transform", + INVALID_PROTOCOL: "Invalid protocol", + UNSUPPORTED_CLIENT_TYPE: "Unsupported client type", +} + +// Error ... +func (e *ApplicationException) Error() string { + if e.m != "" { + return e.m + } + if m, ok := defaultApplicationExceptionMessage[e.t]; ok { + return m + } + return fmt.Sprintf("unknown exception type [%d]", e.t) +} + +// String ... +func (e *ApplicationException) String() string { + return fmt.Sprintf("ApplicationException(%d): %q", e.t, e.m) +} + +// TransportException is for replacing apache.TransportException +// it implements ThriftFastCodec interface. +type TransportException struct { + ApplicationException // same implementation ... +} + +// NewTransportException ... +func NewTransportException(t int32, m string) *TransportException { + ret := TransportException{} + ret.t = t + ret.m = m + return &ret +} + +// ProtocolException is for replacing apache.ProtocolException +// it implements ThriftFastCodec interface. +type ProtocolException struct { + ApplicationException // same implementation ... +} + +const ( // ProtocolException codes from apache thrift + UNKNOWN_PROTOCOL_EXCEPTION = 0 + INVALID_DATA = 1 + NEGATIVE_SIZE = 2 + SIZE_LIMIT = 3 + BAD_VERSION = 4 + NOT_IMPLEMENTED = 5 + DEPTH_LIMIT = 6 +) + +// NewTransportException ... +func NewProtocolException(t int32, m string) *ProtocolException { + ret := ProtocolException{} + ret.t = t + ret.m = m + return &ret +} + +// Generic Thrift exception with TypeId method +type tException interface { + Error() string + TypeId() int32 +} + +// Prepends additional information to an error without losing the Thrift exception interface +func PrependError(prepend string, err error) error { + if t, ok := err.(*TransportException); ok { + return NewTransportException(t.TypeID(), prepend+t.Error()) + } + if t, ok := err.(*ProtocolException); ok { + return NewProtocolException(t.TypeID(), prepend+err.Error()) + } + if t, ok := err.(*ApplicationException); ok { + return NewApplicationException(t.TypeID(), prepend+t.Error()) + } + if t, ok := err.(tException); ok { // apache thrift exception? + return NewApplicationException(t.TypeId(), prepend+t.Error()) + } + return errors.New(prepend + err.Error()) +} diff --git a/protocol/thrift/exception_test.go b/protocol/thrift/exception_test.go new file mode 100644 index 0000000..6bf98b7 --- /dev/null +++ b/protocol/thrift/exception_test.go @@ -0,0 +1,95 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplicationException(t *testing.T) { + ex1 := NewApplicationException(1, "t1") + b := make([]byte, ex1.BLength()) + n := ex1.FastWriteNocopy(b, nil) + assert.Equal(t, len(b), n) + + ex2 := NewApplicationException(0, "") + n, err := ex2.FastRead(b) + require.NoError(t, err) + assert.Equal(t, len(b), n) + assert.Equal(t, int32(1), ex2.TypeID()) + assert.Equal(t, int32(1), ex2.TypeId()) + assert.Equal(t, "t1", ex2.Msg()) + + ex3 := NewApplicationException(1, "") + assert.Equal(t, defaultApplicationExceptionMessage[ex3.TypeID()], ex3.Error()) + + ex4 := NewApplicationException(999, "") + assert.Equal(t, "unknown exception type [999]", ex4.Error()) + + t.Log(ex4.String()) // ... +} + +type testTException struct{} + +func (testTException) Error() string { return "testTException" } +func (testTException) TypeId() int32 { return -1 } + +func TestPrependError(t *testing.T) { + var ok bool + + // case TransportException + ex0 := NewTransportException(1, "world") + err0 := PrependError("hello ", ex0) + ex0, ok = err0.(*TransportException) + require.True(t, ok) + assert.Equal(t, int32(1), ex0.TypeID()) + assert.Equal(t, "hello world", ex0.Error()) + + // case ProtocolException + ex1 := NewProtocolException(2, "world") + err1 := PrependError("hello ", ex1) + ex1, ok = err1.(*ProtocolException) + require.True(t, ok) + assert.Equal(t, int32(2), ex1.TypeID()) + assert.Equal(t, "hello world", ex1.Error()) + + // case ApplicationException + ex2 := NewApplicationException(3, "world") + err2 := PrependError("hello ", ex2) + ex2, ok = err2.(*ApplicationException) + require.True(t, ok) + assert.Equal(t, int32(3), ex2.TypeID()) + assert.Equal(t, "hello world", ex2.Error()) + + // case tException + ex3 := testTException{} + err3 := PrependError("hello ", ex3) + ex4, ok := err3.(*ApplicationException) + require.True(t, ok) + assert.Equal(t, int32(-1), ex4.TypeID()) + assert.Equal(t, "hello testTException", ex4.Error()) + + // case normal error + err4 := PrependError("hello ", errors.New("world")) + _, ok = err4.(tException) + require.False(t, ok) + assert.Equal(t, "hello world", err4.Error()) +} diff --git a/protocol/thrift/thrift.go b/protocol/thrift/thrift.go new file mode 100644 index 0000000..2f9ed78 --- /dev/null +++ b/protocol/thrift/thrift.go @@ -0,0 +1,76 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +// TMessageType represents message type constants in the Thrift protocol. +// originally from github.com/apache/thrift +type TMessageType = int32 // use alias for better flexibility of interfaces + +const ( + INVALID_TMESSAGE_TYPE TMessageType = 0 + CALL TMessageType = 1 + REPLY TMessageType = 2 + EXCEPTION TMessageType = 3 + ONEWAY TMessageType = 4 +) + +// TType represents field type constants in the Thrift protocol +// originally from github.com/apache/thrift +type TType = int8 // use alias for better flexibility of interfaces + +const ( + STOP TType = 0 + VOID TType = 1 + BOOL TType = 2 + BYTE TType = 3 + I08 TType = 3 + DOUBLE TType = 4 + I16 TType = 6 + I32 TType = 8 + I64 TType = 10 + STRING TType = 11 + UTF7 TType = 11 + STRUCT TType = 12 + MAP TType = 13 + SET TType = 14 + LIST TType = 15 + UTF8 TType = 16 + UTF16 TType = 17 +) + +const defaultRecursionDepth = 64 // for skip + +const ( // for Write/ReadMessage + msgVersion1 = 0x80010000 + msgVersionMask = 0xffff0000 + msgTypeMask = 0x0000ffff // for TMessageType +) + +var NocopyWriteThreshold = 4096 // use NocopyWriter when binary or string > the value + +// BinaryWriter represents the method used in thrift encoding for nocopy writes +// It supports netpoll nocopy feature, see: https://github.com/cloudwego/netpoll/blob/develop/nocopy.go +type NocopyWriter interface { + WriteDirect(b []byte, remainCap int) error +} + +// ThriftFastCodec represents the interface of thrift fastcodec generated structs +type ThriftFastCodec interface { + BLength() int + FastWriteNocopy(buf []byte, bw NocopyWriter) int + FastRead(buf []byte) (int, error) +}