Skip to content

Commit

Permalink
feat(thrift/apache): check methods for kitex
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaost committed Aug 6, 2024
1 parent 9ef090c commit 762726c
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 67 deletions.
93 changes: 70 additions & 23 deletions protocol/thrift/apache/apache.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"errors"
"fmt"
"reflect"
"sync"
)

var (
Expand Down Expand Up @@ -101,6 +102,8 @@ func RegisterNewTBinaryProtocol(fn interface{}) error {
return errNewFuncTypeNotMatch(t)
}
newTBinaryProtocol = v
hasThriftRead = sync.Map{}
hasThriftWrite = sync.Map{}
return nil
}

Expand All @@ -121,30 +124,83 @@ func checkThriftReadWriteFuncType(t reflect.Type) error {
return nil
}

// ThriftRead calls Read method of v.
//
// RegisterNewTBinaryProtocol must be called with `thrift.NewTBinaryProtocol`
// before using this func.
func ThriftRead(t TTransport, v interface{}) error {
var hasThriftRead = sync.Map{}

// CheckThriftRead returns nil if v has Read method and matches the func signature
func CheckThriftRead(v interface{}) error {
rv := reflect.ValueOf(v)
rt := rv.Type()
res, ok := hasThriftRead.Load(rt)
if ok {
// fast path
if res == nil {
return nil
}
return res.(error)
}
if rv.Kind() != reflect.Ptr {
// Read/Write method is always pointer receiver
hasThriftRead.Store(rt, errNotPointer)
return errNotPointer
}
rfunc := rv.MethodByName("Read")

// check Read func signature: func(thrift.TProtocol) error
if !rfunc.IsValid() || rfunc.Kind() != reflect.Func {
fv := rv.MethodByName("Read")
if !fv.IsValid() {
hasThriftRead.Store(rt, errNoReadMethod)
return errNoReadMethod
}
if err := checkThriftReadWriteFuncType(rfunc.Type()); err != nil {
if err := checkThriftReadWriteFuncType(fv.Type()); err != nil {
hasThriftRead.Store(rt, err)
return err
}
hasThriftRead.Store(rt, nil)
return nil
}

var hasThriftWrite = sync.Map{}

// CheckThriftWrite returns nil if v has Write method and matches the func signature
func CheckThriftWrite(v interface{}) error {
rv := reflect.ValueOf(v)
rt := rv.Type()
res, ok := hasThriftWrite.Load(rt)
if ok {
// fast path
if res == nil {
return nil
}
return res.(error)
}
if rv.Kind() != reflect.Ptr {
hasThriftWrite.Store(rt, errNotPointer)
return errNotPointer
}
fv := rv.MethodByName("Write")
if !fv.IsValid() {
hasThriftWrite.Store(rt, errNoWriteMethod)
return errNoWriteMethod
}
if err := checkThriftReadWriteFuncType(fv.Type()); err != nil {
hasThriftWrite.Store(rt, err)
return err
}
hasThriftWrite.Store(rt, nil)
return nil
}

// ThriftRead calls Read method of v.
//
// RegisterNewTBinaryProtocol must be called with `thrift.NewTBinaryProtocol`
// before using this func.
func ThriftRead(t TTransport, v interface{}) error {
if err := CheckThriftRead(v); err != nil {
return err
}

// iprot := NewTBinaryProtocol(t, true, true)
iprot := newTBinaryProtocol.Call([]reflect.Value{reflect.ValueOf(t), rvTrue, rvTrue})[0]

// err := v.Read(iprot)
rv := reflect.ValueOf(v)
rfunc := rv.MethodByName("Read")
err := rfunc.Call([]reflect.Value{iprot})[0]
if err.IsNil() {
return nil
Expand All @@ -157,25 +213,16 @@ func ThriftRead(t TTransport, v interface{}) error {
// RegisterNewTBinaryProtocol must be called with `thrift.NewTBinaryProtocol`
// before using this func.
func ThriftWrite(t TTransport, v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
// Read/Write method is always pointer receiver
return errNotPointer
}
wfunc := rv.MethodByName("Write")

// check Write func signature: func(thrift.TProtocol) error
if !wfunc.IsValid() || wfunc.Kind() != reflect.Func {
return errNoWriteMethod
}
if err := checkThriftReadWriteFuncType(wfunc.Type()); err != nil {
if err := CheckThriftWrite(v); err != nil {
return err
}

// oprot := NewTBinaryProtocol(t, true, true)
oprot := newTBinaryProtocol.Call([]reflect.Value{reflect.ValueOf(t), rvTrue, rvTrue})[0]

// err := v.Write(oprot)
rv := reflect.ValueOf(v)
wfunc := rv.MethodByName("Write")
err := wfunc.Call([]reflect.Value{oprot})[0]
if err.IsNil() {
return nil
Expand Down
109 changes: 78 additions & 31 deletions protocol/thrift/apache/apache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"errors"
"io"
"reflect"
"sync"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -85,58 +86,104 @@ func TestThriftWriteRead(t *testing.T) {
assert.True(t, b0)
assert.True(t, b1)
called++
return trans.(BufferTransport).Buffer
return &trans.(*bufferTransport).Buffer
}
err := RegisterNewTBinaryProtocol(fn)
require.NoError(t, err)
defer func() { newTBinaryProtocol = reflect.Value{} }()

buf := &bytes.Buffer{}
p0 := &TestingWriteRead{Msg: "hello"}
err = ThriftWrite(BufferTransport{buf}, p0) // calls p0.Write
require.NoError(t, err)
require.Equal(t, 1, called)

p1 := &TestingWriteRead{}
err = ThriftRead(BufferTransport{buf}, p1) // calls p1.Read
require.NoError(t, err)
require.Equal(t, 2, called)
require.Equal(t, p0, p1)
expectcalls := 0
for i := 0; i < 2; i++ { // run twice to test cache
buf := &bytes.Buffer{}
p0 := &TestingWriteRead{Msg: "hello"}
err = ThriftWrite(NewBufferTransport(buf), p0) // calls p0.Write
require.NoError(t, err)
expectcalls++
require.Equal(t, expectcalls, called)

p1 := &TestingWriteRead{}
err = ThriftRead(NewBufferTransport(buf), p1) // calls p1.Read
require.NoError(t, err)
expectcalls++
require.Equal(t, expectcalls, called)
require.Equal(t, p0, p1)
}
}

type TestingWriteReadMethodNotMatch struct{}

func (p *TestingWriteReadMethodNotMatch) Read(v bool) error { return nil }
func (p *TestingWriteReadMethodNotMatch) Write(v bool) error { return nil }
func (_ *TestingWriteReadMethodNotMatch) Read(v bool) error { return nil }
func (_ *TestingWriteReadMethodNotMatch) Write(v bool) error { return nil }

type TestingNoReadWriteMethod struct{}

func (_ *TestingNoReadWriteMethod) Read1(v bool) error { return nil }
func (_ *TestingNoReadWriteMethod) Write1(v bool) error { return nil }

type TestingWriteReadNotReturningErr struct{}

func (_ *TestingWriteReadNotReturningErr) Read(r io.Reader) {}
func (_ *TestingWriteReadNotReturningErr) Write(w io.Writer) {}

func TestCheckThriftReadWriteErr(t *testing.T) {
// reset type cache
hasThriftRead = sync.Map{}
hasThriftWrite = sync.Map{}

func TestThriftWriteReadErr(t *testing.T) {
var err error

// errNotPointer
p := TestingWriteRead{Msg: "hello"}
err = ThriftWrite(BufferTransport{nil}, p)
assert.Same(t, err, errNotPointer)
err = ThriftRead(BufferTransport{nil}, p)
assert.Same(t, err, errNotPointer)
for i := 0; i < 2; i++ { // run twice to test cache
err = CheckThriftRead(TestingWriteRead{})
assert.Same(t, err, errNotPointer)
err = CheckThriftWrite(TestingWriteRead{})
assert.Same(t, err, errNotPointer)
}

// errNoNewTBinaryProtocol
err = ThriftWrite(BufferTransport{nil}, &p)
assert.Same(t, err, errNoNewTBinaryProtocol)
for i := 0; i < 2; i++ { // run twice to test cache
err = CheckThriftRead(&TestingWriteRead{})
assert.Same(t, err, errNoNewTBinaryProtocol)
err = CheckThriftWrite(&TestingWriteRead{})
assert.Same(t, err, errNoNewTBinaryProtocol)
}

fn := func(trans TTransport, b0, b1 bool) *bytes.Buffer { return nil }
RegisterNewTBinaryProtocol(fn)
defer func() { newTBinaryProtocol = reflect.Value{} }()

// errMethodType
for i := 0; i < 2; i++ {
err = CheckThriftRead(&TestingWriteReadMethodNotMatch{}) // input type err
assert.ErrorIs(t, err, errMethodType)
err = CheckThriftWrite(&TestingWriteReadMethodNotMatch{})
assert.ErrorIs(t, err, errMethodType)
err = CheckThriftRead(&TestingWriteReadNotReturningErr{}) // return type err
assert.ErrorIs(t, err, errMethodType)
err = CheckThriftWrite(&TestingWriteReadNotReturningErr{})
assert.ErrorIs(t, err, errMethodType)
}

// errNoReadMethod, errNoWriteMethod
for i := 0; i < 2; i++ {
err = CheckThriftRead(&TestingNoReadWriteMethod{})
assert.ErrorIs(t, err, errNoReadMethod)
err = CheckThriftWrite(&TestingNoReadWriteMethod{})
assert.ErrorIs(t, err, errNoWriteMethod)
}
}

func TestThriftWriteReadErr(t *testing.T) {
var err error

// Read/Write returns err
p := TestingWriteRead{}
fn := func(trans TTransport, b0, b1 bool) *bytes.Buffer { return nil }
RegisterNewTBinaryProtocol(fn)
defer func() { newTBinaryProtocol = reflect.Value{} }()
p.mockErr = errors.New("mock")
err = ThriftWrite(BufferTransport{nil}, &p)
err = ThriftWrite(NewBufferTransport(nil), &p)
assert.Same(t, err, p.mockErr)
err = ThriftRead(BufferTransport{nil}, &p)
err = ThriftRead(NewBufferTransport(nil), &p)
assert.Same(t, err, p.mockErr)

// errMethodType
p1 := TestingWriteReadMethodNotMatch{}
err = ThriftWrite(BufferTransport{nil}, &p1)
assert.ErrorIs(t, err, errMethodType)
err = ThriftRead(BufferTransport{nil}, &p1)
assert.ErrorIs(t, err, errMethodType)
}
55 changes: 46 additions & 9 deletions protocol/thrift/apache/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"bytes"
"context"
"io"
"unsafe"
)

// TTransport is identical with thrift.TTransport.
Expand All @@ -31,15 +32,51 @@ type TTransport interface {
IsOpen() bool
}

// BufferTransport extends bytes.Buffer to support TTransport
type BufferTransport struct {
*bytes.Buffer
type defaultTransport struct {
io.ReadWriter
}

func (p BufferTransport) IsOpen() bool { return true }
func (p BufferTransport) Open() error { return nil }
func (p BufferTransport) Close() error { p.Reset(); return nil }
func (p BufferTransport) Flush(_ context.Context) error { return nil }
func (p BufferTransport) RemainingBytes() uint64 { return uint64(p.Len()) }
// NewDefaultTransport converts io.ReadWriter to TTransport.
// Use NewBufferTransport if using *bytes.Buffer for better performance.
func NewDefaultTransport(rw io.ReadWriter) TTransport {
if buf, ok := rw.(*bytes.Buffer); ok {
return NewBufferTransport(buf)
}
return defaultTransport{rw}
}

// remoteByteBuffer represents remote.ByteBuffer in kitex
type remoteByteBuffer interface {
ReadableLen() (n int)
}

func (p defaultTransport) IsOpen() bool { return true }
func (p defaultTransport) Open() error { return nil }
func (p defaultTransport) Close() error { return nil }
func (p defaultTransport) Flush(_ context.Context) error { return nil }

func (p defaultTransport) RemainingBytes() uint64 {
if v, ok := p.ReadWriter.(remoteByteBuffer); ok {
n := v.ReadableLen()
if n > 0 {
return uint64(n)
}
}
return ^uint64(0)
}

type bufferTransport struct {
bytes.Buffer
}

// NewBufferTransport extends bytes.Buffer to support TTransport
func NewBufferTransport(buf *bytes.Buffer) TTransport {
// reuse buf's pointer with more methods
return (*bufferTransport)(unsafe.Pointer(buf))
}

var _ TTransport = BufferTransport{nil}
func (p *bufferTransport) IsOpen() bool { return true }
func (p *bufferTransport) Open() error { return nil }
func (p *bufferTransport) Close() error { p.Reset(); return nil }
func (p *bufferTransport) Flush(_ context.Context) error { return nil }
func (p *bufferTransport) RemainingBytes() uint64 { return uint64(p.Len()) }
30 changes: 26 additions & 4 deletions protocol/thrift/apache/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,38 @@ package apache
import (
"bytes"
"context"
"io"
"testing"

"github.com/stretchr/testify/require"
)

func TestTBufferTransport(t *testing.T) {
buf := &bytes.Buffer{}
type mockReadableLen struct {
io.ReadWriter

n int
}

func (f *mockReadableLen) ReadableLen() int { return f.n }

p := BufferTransport{buf}
func TestTBufferTransport(t *testing.T) {
m := &mockReadableLen{n: 7}
p := NewDefaultTransport(m)
_ = p.IsOpen()
_ = p.Open()
_ = p.Close()
_ = p.Flush(context.Background())
_ = p.RemainingBytes()
require.Equal(t, uint64(7), p.RemainingBytes())
m.n = -1
require.Equal(t, ^uint64(0), p.RemainingBytes())

b := &bytes.Buffer{}
b.WriteByte(0)
p = NewDefaultTransport(b)
_ = p.IsOpen()
_ = p.Open()
_ = p.Flush(context.Background())
require.Equal(t, uint64(1), p.RemainingBytes())
require.NoError(t, p.Close())
require.Equal(t, uint64(0), p.RemainingBytes())
}

0 comments on commit 762726c

Please sign in to comment.