diff --git a/protocol/thrift/apache/apache.go b/protocol/thrift/apache/apache.go index 5af65e7..be42536 100644 --- a/protocol/thrift/apache/apache.go +++ b/protocol/thrift/apache/apache.go @@ -37,13 +37,14 @@ package apache import ( "errors" + "io" ) var ( fnCheckTStruct func(v interface{}) error - fnThriftRead func(t TTransport, v interface{}) error - fnThriftWrite func(t TTransport, v interface{}) error + fnThriftRead func(rw io.ReadWriter, v interface{}) error + fnThriftWrite func(rw io.ReadWriter, v interface{}) error ) // RegisterCheckTStruct accepts `thrift.TStruct check` func and save it for later use. @@ -52,12 +53,12 @@ func RegisterCheckTStruct(fn func(v interface{}) error) { } // RegisterThriftRead ... -func RegisterThriftRead(fn func(t TTransport, v interface{}) error) { +func RegisterThriftRead(fn func(rw io.ReadWriter, v interface{}) error) { fnThriftRead = fn } // RegisterThriftWrite ... -func RegisterThriftWrite(fn func(t TTransport, v interface{}) error) { +func RegisterThriftWrite(fn func(rw io.ReadWriter, v interface{}) error) { fnThriftWrite = fn } @@ -76,17 +77,17 @@ func CheckTStruct(v interface{}) error { } // ThriftRead ... -func ThriftRead(t TTransport, v interface{}) error { +func ThriftRead(rw io.ReadWriter, v interface{}) error { if fnThriftRead == nil { return errThriftReadNotRegistered } - return fnThriftRead(t, v) + return fnThriftRead(rw, v) } // ThriftWrite ... -func ThriftWrite(t TTransport, v interface{}) error { +func ThriftWrite(rw io.ReadWriter, v interface{}) error { if fnThriftWrite == nil { return errThriftWriteNotRegistered } - return fnThriftWrite(t, v) + return fnThriftWrite(rw, v) } diff --git a/protocol/thrift/apache/apache_test.go b/protocol/thrift/apache/apache_test.go index 64e3fd6..bcbf85e 100644 --- a/protocol/thrift/apache/apache_test.go +++ b/protocol/thrift/apache/apache_test.go @@ -38,7 +38,7 @@ func TestThriftReadWrite(t *testing.T) { buf := &bytes.Buffer{} - err = ThriftWrite(NewBufferTransport(buf), v) + err = ThriftWrite(buf, v) require.Same(t, err, errThriftWriteNotRegistered) RegisterThriftWrite(callThriftWrite) @@ -84,18 +84,18 @@ func checkTStruct(v interface{}) error { return nil } -func callThriftRead(t TTransport, v interface{}) error { +func callThriftRead(rw io.ReadWriter, v interface{}) error { p, ok := v.(TStruct) if !ok { return errNotThriftTStruct } - return p.Read(t) + return p.Read(rw) } -func callThriftWrite(t TTransport, v interface{}) error { +func callThriftWrite(rw io.ReadWriter, v interface{}) error { p, ok := v.(TStruct) if !ok { return errNotThriftTStruct } - return p.Write(t) + return p.Write(rw) }