Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(thrift/apache): check methods for kitex #13

Merged
merged 1 commit into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 70 additions & 23 deletions protocol/thrift/apache/apache.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import (
"errors"
"fmt"
"reflect"
"sync"
)

var (
Expand Down Expand Up @@ -101,6 +102,8 @@ func RegisterNewTBinaryProtocol(fn interface{}) error {
return errNewFuncTypeNotMatch(t)
}
newTBinaryProtocol = v
hasThriftRead = sync.Map{}
hasThriftWrite = sync.Map{}
return nil
}

Expand All @@ -121,30 +124,83 @@ func checkThriftReadWriteFuncType(t reflect.Type) error {
return nil
}

// ThriftRead calls Read method of v.
//
// RegisterNewTBinaryProtocol must be called with `thrift.NewTBinaryProtocol`
// before using this func.
func ThriftRead(t TTransport, v interface{}) error {
var hasThriftRead = sync.Map{}

// CheckThriftRead returns nil if v has Read method and matches the func signature
func CheckThriftRead(v interface{}) error {
rv := reflect.ValueOf(v)
rt := rv.Type()
res, ok := hasThriftRead.Load(rt)
if ok {
// fast path
if res == nil {
return nil
}
return res.(error)
}
if rv.Kind() != reflect.Ptr {
// Read/Write method is always pointer receiver
hasThriftRead.Store(rt, errNotPointer)
return errNotPointer
}
rfunc := rv.MethodByName("Read")

// check Read func signature: func(thrift.TProtocol) error
if !rfunc.IsValid() || rfunc.Kind() != reflect.Func {
fv := rv.MethodByName("Read")
if !fv.IsValid() {
hasThriftRead.Store(rt, errNoReadMethod)
return errNoReadMethod
}
if err := checkThriftReadWriteFuncType(rfunc.Type()); err != nil {
if err := checkThriftReadWriteFuncType(fv.Type()); err != nil {
hasThriftRead.Store(rt, err)
return err
}
hasThriftRead.Store(rt, nil)
return nil
}

var hasThriftWrite = sync.Map{}

// CheckThriftWrite returns nil if v has Write method and matches the func signature
func CheckThriftWrite(v interface{}) error {
rv := reflect.ValueOf(v)
rt := rv.Type()
res, ok := hasThriftWrite.Load(rt)
if ok {
// fast path
if res == nil {
return nil
}
return res.(error)
}
if rv.Kind() != reflect.Ptr {
hasThriftWrite.Store(rt, errNotPointer)
return errNotPointer
}
fv := rv.MethodByName("Write")
if !fv.IsValid() {
hasThriftWrite.Store(rt, errNoWriteMethod)
return errNoWriteMethod
}
if err := checkThriftReadWriteFuncType(fv.Type()); err != nil {
hasThriftWrite.Store(rt, err)
return err
}
hasThriftWrite.Store(rt, nil)
return nil
}

// ThriftRead calls Read method of v.
//
// RegisterNewTBinaryProtocol must be called with `thrift.NewTBinaryProtocol`
// before using this func.
func ThriftRead(t TTransport, v interface{}) error {
if err := CheckThriftRead(v); err != nil {
return err
}

// iprot := NewTBinaryProtocol(t, true, true)
iprot := newTBinaryProtocol.Call([]reflect.Value{reflect.ValueOf(t), rvTrue, rvTrue})[0]

// err := v.Read(iprot)
rv := reflect.ValueOf(v)
rfunc := rv.MethodByName("Read")
err := rfunc.Call([]reflect.Value{iprot})[0]
if err.IsNil() {
return nil
Expand All @@ -157,25 +213,16 @@ func ThriftRead(t TTransport, v interface{}) error {
// RegisterNewTBinaryProtocol must be called with `thrift.NewTBinaryProtocol`
// before using this func.
func ThriftWrite(t TTransport, v interface{}) error {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
// Read/Write method is always pointer receiver
return errNotPointer
}
wfunc := rv.MethodByName("Write")

// check Write func signature: func(thrift.TProtocol) error
if !wfunc.IsValid() || wfunc.Kind() != reflect.Func {
return errNoWriteMethod
}
if err := checkThriftReadWriteFuncType(wfunc.Type()); err != nil {
if err := CheckThriftWrite(v); err != nil {
return err
}

// oprot := NewTBinaryProtocol(t, true, true)
oprot := newTBinaryProtocol.Call([]reflect.Value{reflect.ValueOf(t), rvTrue, rvTrue})[0]

// err := v.Write(oprot)
rv := reflect.ValueOf(v)
wfunc := rv.MethodByName("Write")
err := wfunc.Call([]reflect.Value{oprot})[0]
if err.IsNil() {
return nil
Expand Down
109 changes: 78 additions & 31 deletions protocol/thrift/apache/apache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"errors"
"io"
"reflect"
"sync"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -85,58 +86,104 @@ func TestThriftWriteRead(t *testing.T) {
assert.True(t, b0)
assert.True(t, b1)
called++
return trans.(BufferTransport).Buffer
return &trans.(*bufferTransport).Buffer
}
err := RegisterNewTBinaryProtocol(fn)
require.NoError(t, err)
defer func() { newTBinaryProtocol = reflect.Value{} }()

buf := &bytes.Buffer{}
p0 := &TestingWriteRead{Msg: "hello"}
err = ThriftWrite(BufferTransport{buf}, p0) // calls p0.Write
require.NoError(t, err)
require.Equal(t, 1, called)

p1 := &TestingWriteRead{}
err = ThriftRead(BufferTransport{buf}, p1) // calls p1.Read
require.NoError(t, err)
require.Equal(t, 2, called)
require.Equal(t, p0, p1)
expectcalls := 0
for i := 0; i < 2; i++ { // run twice to test cache
buf := &bytes.Buffer{}
p0 := &TestingWriteRead{Msg: "hello"}
err = ThriftWrite(NewBufferTransport(buf), p0) // calls p0.Write
require.NoError(t, err)
expectcalls++
require.Equal(t, expectcalls, called)

p1 := &TestingWriteRead{}
err = ThriftRead(NewBufferTransport(buf), p1) // calls p1.Read
require.NoError(t, err)
expectcalls++
require.Equal(t, expectcalls, called)
require.Equal(t, p0, p1)
}
}

type TestingWriteReadMethodNotMatch struct{}

func (p *TestingWriteReadMethodNotMatch) Read(v bool) error { return nil }
func (p *TestingWriteReadMethodNotMatch) Write(v bool) error { return nil }
func (_ *TestingWriteReadMethodNotMatch) Read(v bool) error { return nil }
func (_ *TestingWriteReadMethodNotMatch) Write(v bool) error { return nil }

type TestingNoReadWriteMethod struct{}

func (_ *TestingNoReadWriteMethod) Read1(v bool) error { return nil }
func (_ *TestingNoReadWriteMethod) Write1(v bool) error { return nil }

type TestingWriteReadNotReturningErr struct{}

func (_ *TestingWriteReadNotReturningErr) Read(r io.Reader) {}
func (_ *TestingWriteReadNotReturningErr) Write(w io.Writer) {}

func TestCheckThriftReadWriteErr(t *testing.T) {
// reset type cache
hasThriftRead = sync.Map{}
hasThriftWrite = sync.Map{}

func TestThriftWriteReadErr(t *testing.T) {
var err error

// errNotPointer
p := TestingWriteRead{Msg: "hello"}
err = ThriftWrite(BufferTransport{nil}, p)
assert.Same(t, err, errNotPointer)
err = ThriftRead(BufferTransport{nil}, p)
assert.Same(t, err, errNotPointer)
for i := 0; i < 2; i++ { // run twice to test cache
err = CheckThriftRead(TestingWriteRead{})
assert.Same(t, err, errNotPointer)
err = CheckThriftWrite(TestingWriteRead{})
assert.Same(t, err, errNotPointer)
}

// errNoNewTBinaryProtocol
err = ThriftWrite(BufferTransport{nil}, &p)
assert.Same(t, err, errNoNewTBinaryProtocol)
for i := 0; i < 2; i++ { // run twice to test cache
err = CheckThriftRead(&TestingWriteRead{})
assert.Same(t, err, errNoNewTBinaryProtocol)
err = CheckThriftWrite(&TestingWriteRead{})
assert.Same(t, err, errNoNewTBinaryProtocol)
}

fn := func(trans TTransport, b0, b1 bool) *bytes.Buffer { return nil }
RegisterNewTBinaryProtocol(fn)
defer func() { newTBinaryProtocol = reflect.Value{} }()

// errMethodType
for i := 0; i < 2; i++ {
err = CheckThriftRead(&TestingWriteReadMethodNotMatch{}) // input type err
assert.ErrorIs(t, err, errMethodType)
err = CheckThriftWrite(&TestingWriteReadMethodNotMatch{})
assert.ErrorIs(t, err, errMethodType)
err = CheckThriftRead(&TestingWriteReadNotReturningErr{}) // return type err
assert.ErrorIs(t, err, errMethodType)
err = CheckThriftWrite(&TestingWriteReadNotReturningErr{})
assert.ErrorIs(t, err, errMethodType)
}

// errNoReadMethod, errNoWriteMethod
for i := 0; i < 2; i++ {
err = CheckThriftRead(&TestingNoReadWriteMethod{})
assert.ErrorIs(t, err, errNoReadMethod)
err = CheckThriftWrite(&TestingNoReadWriteMethod{})
assert.ErrorIs(t, err, errNoWriteMethod)
}
}

func TestThriftWriteReadErr(t *testing.T) {
var err error

// Read/Write returns err
p := TestingWriteRead{}
fn := func(trans TTransport, b0, b1 bool) *bytes.Buffer { return nil }
RegisterNewTBinaryProtocol(fn)
defer func() { newTBinaryProtocol = reflect.Value{} }()
p.mockErr = errors.New("mock")
err = ThriftWrite(BufferTransport{nil}, &p)
err = ThriftWrite(NewBufferTransport(nil), &p)
assert.Same(t, err, p.mockErr)
err = ThriftRead(BufferTransport{nil}, &p)
err = ThriftRead(NewBufferTransport(nil), &p)
assert.Same(t, err, p.mockErr)

// errMethodType
p1 := TestingWriteReadMethodNotMatch{}
err = ThriftWrite(BufferTransport{nil}, &p1)
assert.ErrorIs(t, err, errMethodType)
err = ThriftRead(BufferTransport{nil}, &p1)
assert.ErrorIs(t, err, errMethodType)
}
55 changes: 46 additions & 9 deletions protocol/thrift/apache/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"bytes"
"context"
"io"
"unsafe"
)

// TTransport is identical with thrift.TTransport.
Expand All @@ -31,15 +32,51 @@ type TTransport interface {
IsOpen() bool
}

// BufferTransport extends bytes.Buffer to support TTransport
type BufferTransport struct {
*bytes.Buffer
type defaultTransport struct {
io.ReadWriter
}

func (p BufferTransport) IsOpen() bool { return true }
func (p BufferTransport) Open() error { return nil }
func (p BufferTransport) Close() error { p.Reset(); return nil }
func (p BufferTransport) Flush(_ context.Context) error { return nil }
func (p BufferTransport) RemainingBytes() uint64 { return uint64(p.Len()) }
// NewDefaultTransport converts io.ReadWriter to TTransport.
// Use NewBufferTransport if using *bytes.Buffer for better performance.
func NewDefaultTransport(rw io.ReadWriter) TTransport {
if buf, ok := rw.(*bytes.Buffer); ok {
return NewBufferTransport(buf)
}
return defaultTransport{rw}
}

// remoteByteBuffer represents remote.ByteBuffer in kitex
type remoteByteBuffer interface {
ReadableLen() (n int)
}

func (p defaultTransport) IsOpen() bool { return true }
func (p defaultTransport) Open() error { return nil }
func (p defaultTransport) Close() error { return nil }
func (p defaultTransport) Flush(_ context.Context) error { return nil }

func (p defaultTransport) RemainingBytes() uint64 {
if v, ok := p.ReadWriter.(remoteByteBuffer); ok {
n := v.ReadableLen()
if n > 0 {
return uint64(n)
}
}
return ^uint64(0)
}

type bufferTransport struct {
bytes.Buffer
}

// NewBufferTransport extends bytes.Buffer to support TTransport
func NewBufferTransport(buf *bytes.Buffer) TTransport {
// reuse buf's pointer with more methods
return (*bufferTransport)(unsafe.Pointer(buf))
}

var _ TTransport = BufferTransport{nil}
func (p *bufferTransport) IsOpen() bool { return true }
func (p *bufferTransport) Open() error { return nil }
func (p *bufferTransport) Close() error { p.Reset(); return nil }
func (p *bufferTransport) Flush(_ context.Context) error { return nil }
func (p *bufferTransport) RemainingBytes() uint64 { return uint64(p.Len()) }
30 changes: 26 additions & 4 deletions protocol/thrift/apache/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,38 @@ package apache
import (
"bytes"
"context"
"io"
"testing"

"github.com/stretchr/testify/require"
)

func TestTBufferTransport(t *testing.T) {
buf := &bytes.Buffer{}
type mockReadableLen struct {
io.ReadWriter

n int
}

func (f *mockReadableLen) ReadableLen() int { return f.n }

p := BufferTransport{buf}
func TestTBufferTransport(t *testing.T) {
m := &mockReadableLen{n: 7}
p := NewDefaultTransport(m)
_ = p.IsOpen()
_ = p.Open()
_ = p.Close()
_ = p.Flush(context.Background())
_ = p.RemainingBytes()
require.Equal(t, uint64(7), p.RemainingBytes())
m.n = -1
require.Equal(t, ^uint64(0), p.RemainingBytes())

b := &bytes.Buffer{}
b.WriteByte(0)
p = NewDefaultTransport(b)
_ = p.IsOpen()
_ = p.Open()
_ = p.Flush(context.Background())
require.Equal(t, uint64(1), p.RemainingBytes())
require.NoError(t, p.Close())
require.Equal(t, uint64(0), p.RemainingBytes())
}
Loading