From cf6cf61ab35b75f24b20727110f84352843dbc4e Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Wed, 31 Jul 2024 18:23:58 +0800 Subject: [PATCH] feat(thrift): new pkg for deprecating apache (#6) --- protocol/thrift/apache/apache.go | 184 +++++++++++++++++++++++ protocol/thrift/apache/apache_test.go | 142 +++++++++++++++++ protocol/thrift/apache/transport.go | 45 ++++++ protocol/thrift/apache/transport_test.go | 34 +++++ 4 files changed, 405 insertions(+) create mode 100644 protocol/thrift/apache/apache.go create mode 100644 protocol/thrift/apache/apache_test.go create mode 100644 protocol/thrift/apache/transport.go create mode 100644 protocol/thrift/apache/transport_test.go diff --git a/protocol/thrift/apache/apache.go b/protocol/thrift/apache/apache.go new file mode 100644 index 0000000..1024bcf --- /dev/null +++ b/protocol/thrift/apache/apache.go @@ -0,0 +1,184 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package apache contains code for working with apache thrift indirectly +// +// It acts as a bridge between generated code which relies on apache codec like: +// +// Write(p thrift.TProtocol) error +// Read(p thrift.TProtocol) error +// +// and kitex ecosystem. +// +// Because we're deprecating apache thrift, all kitex ecosystem code will not rely on apache thrift +// except one pkg: `github.com/cloudwego/kitex/pkg/protocol/bthrift`. Why is the package chosen? +// All legacy generated code relies on it, and we may not be able to update the code in a brief timeframe. +// So the package is chosen to register `thrift.NewTBinaryProtocol` to this package in order to use it +// without importing `github.com/apache/thrift` +// +// ThriftRead or ThriftWrite is implemented for calling Read/Write +// without knowing the interface of `thrift.TProtocol`. +// Since we already have `thrift.NewTBinaryProtocol`, we only need to check: +// if the return value of `thrift.NewTBinaryProtocol` implements +// the input which is `thrift.TProtocol` of Read/Write +// +// For new generated code, +// it no longer uses the `github.com/cloudwego/kitex/pkg/protocol/bthrift` +package apache + +import ( + "errors" + "fmt" + "reflect" +) + +var ( + newTBinaryProtocol reflect.Value + + rvTrue = reflect.ValueOf(true) // for calling NewTBinaryProtocol +) + +var ( + ttransportType = reflect.TypeOf((*TTransport)(nil)).Elem() + errorType = reflect.TypeOf((*error)(nil)).Elem() +) + +var ( + errNoNewTBinaryProtocol = errors.New("thrift.NewTBinaryProtocol method not registered. Make sure you're using apache/thrift == 0.13.0 and clouwdwego/kitex >= 0.11.0") + errNotPointer = errors.New("input not pointer") + errNoReadMethod = errors.New("thrift.TStruct `Read` method not found") + errNoWriteMethod = errors.New("thrift.TStruct `Write` method not found") + + errMethodType = errors.New("method type not match") + errNewFuncType = errors.New("function type not match") +) + +func errNewFuncTypeNotMatch(t reflect.Type) error { + const expect = "func(thrift.TTransport, bool, bool) *thrift.TBinaryProtocol" + return fmt.Errorf("%w:\n\texpect: %s\n\t got: %s", errNewFuncType, expect, t) +} + +func errReadWriteMethodNotMatch(t reflect.Type) error { + const expect = "func(thrift.TProtocol) error" + return fmt.Errorf("%w:\n\texpect: %s\n\t got: %s", errMethodType, expect, t) +} + +// RegisterNewTBinaryProtocol accepts `thrift.NewTBinaryProtocol` func and save it for later use. +func RegisterNewTBinaryProtocol(fn interface{}) error { + v := reflect.ValueOf(fn) + t := v.Type() + + // check it's func + if t.Kind() != reflect.Func { + return errNewFuncTypeNotMatch(t) + } + + // check "func(thrift.TTransport, bool, bool) *thrift.TBinaryProtocol" + // can also check with t.String() instead of field by field? + if t.NumIn() != 3 || + !t.In(0).Implements(ttransportType) || + t.In(1).Kind() != reflect.Bool || + t.In(2).Kind() != reflect.Bool { + return errNewFuncTypeNotMatch(t) + } + if t.NumOut() != 1 { + // not checking if it's thrift.TProtocol + // but in ThriftRead/ThriftWrite, we will check if it implements the input of Read/Write + // so we can make it easier to test. + return errNewFuncTypeNotMatch(t) + } + newTBinaryProtocol = v + return nil +} + +func checkThriftReadWriteFuncType(t reflect.Type) error { + if !newTBinaryProtocol.IsValid() { + return errNoNewTBinaryProtocol + } + + // checks `func(thrift.TProtocol) error` + if t.NumIn() != 1 || t.In(0).Kind() != reflect.Interface || + !newTBinaryProtocol.Type().Out(0).Implements(t.In(0)) { + return errReadWriteMethodNotMatch(t) + } + if t.NumOut() != 1 || + !t.Out(0).Implements(errorType) { + return errReadWriteMethodNotMatch(t) + } + return nil +} + +// ThriftRead calls Read method of v. +// +// RegisterNewTBinaryProtocol must be called with `thrift.NewTBinaryProtocol` +// before using this func. +func ThriftRead(t TTransport, v interface{}) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + // Read/Write method is always pointer receiver + return errNotPointer + } + rfunc := rv.MethodByName("Read") + + // check Read func signature: func(thrift.TProtocol) error + if !rfunc.IsValid() || rfunc.Kind() != reflect.Func { + return errNoReadMethod + } + if err := checkThriftReadWriteFuncType(rfunc.Type()); err != nil { + return err + } + + // iprot := NewTBinaryProtocol(t, true, true) + iprot := newTBinaryProtocol.Call([]reflect.Value{reflect.ValueOf(t), rvTrue, rvTrue})[0] + + // err := v.Read(iprot) + err := rfunc.Call([]reflect.Value{iprot})[0] + if err.IsNil() { + return nil + } + return err.Interface().(error) +} + +// ThriftWrite calls Write method of v. +// +// RegisterNewTBinaryProtocol must be called with `thrift.NewTBinaryProtocol` +// before using this func. +func ThriftWrite(t TTransport, v interface{}) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + // Read/Write method is always pointer receiver + return errNotPointer + } + wfunc := rv.MethodByName("Write") + + // check Write func signature: func(thrift.TProtocol) error + if !wfunc.IsValid() || wfunc.Kind() != reflect.Func { + return errNoWriteMethod + } + if err := checkThriftReadWriteFuncType(wfunc.Type()); err != nil { + return err + } + + // oprot := NewTBinaryProtocol(t, true, true) + oprot := newTBinaryProtocol.Call([]reflect.Value{reflect.ValueOf(t), rvTrue, rvTrue})[0] + + // err := v.Write(oprot) + err := wfunc.Call([]reflect.Value{oprot})[0] + if err.IsNil() { + return nil + } + return err.Interface().(error) +} diff --git a/protocol/thrift/apache/apache_test.go b/protocol/thrift/apache/apache_test.go new file mode 100644 index 0000000..b3a305b --- /dev/null +++ b/protocol/thrift/apache/apache_test.go @@ -0,0 +1,142 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import ( + "bytes" + "encoding/json" + "errors" + "io" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRegisterNewTBinaryProtocol(t *testing.T) { + { // case: not func type + fn := 1 + err := RegisterNewTBinaryProtocol(fn) + t.Log(err) + assert.ErrorIs(t, err, errNewFuncType) + } + + { // case: args err + fn := func(_ TTransport, _ bool, _ int) {} + err := RegisterNewTBinaryProtocol(fn) + t.Log(err) + assert.ErrorIs(t, err, errNewFuncType) + } + + { // case: ret err + fn := func(_ TTransport, _, _ bool) {} + err := RegisterNewTBinaryProtocol(fn) + t.Log(err) + assert.ErrorIs(t, err, errNewFuncType) + } + + { // case: no err + fn := func(_ TTransport, _, _ bool) error { return nil } + err := RegisterNewTBinaryProtocol(fn) + assert.NoError(t, err) + assert.True(t, newTBinaryProtocol.IsValid()) + newTBinaryProtocol = reflect.Value{} // reset + } +} + +type TestingWriteRead struct { + Msg string + + mockErr error +} + +func (t *TestingWriteRead) Read(r io.Reader) error { + if t.mockErr != nil { + return t.mockErr + } + return json.NewDecoder(r).Decode(t) +} + +func (t *TestingWriteRead) Write(w io.Writer) error { + if t.mockErr != nil { + return t.mockErr + } + return json.NewEncoder(w).Encode(t) +} + +func TestThriftWriteRead(t *testing.T) { + called := 0 + fn := func(trans TTransport, b0, b1 bool) *bytes.Buffer { + assert.True(t, b0) + assert.True(t, b1) + called++ + return trans.(BufferTransport).Buffer + } + err := RegisterNewTBinaryProtocol(fn) + require.NoError(t, err) + defer func() { newTBinaryProtocol = reflect.Value{} }() + + buf := &bytes.Buffer{} + p0 := &TestingWriteRead{Msg: "hello"} + err = ThriftWrite(BufferTransport{buf}, p0) // calls p0.Write + require.NoError(t, err) + require.Equal(t, 1, called) + + p1 := &TestingWriteRead{} + err = ThriftRead(BufferTransport{buf}, p1) // calls p1.Read + require.NoError(t, err) + require.Equal(t, 2, called) + require.Equal(t, p0, p1) +} + +type TestingWriteReadMethodNotMatch struct{} + +func (p *TestingWriteReadMethodNotMatch) Read(v bool) error { return nil } +func (p *TestingWriteReadMethodNotMatch) Write(v bool) error { return nil } + +func TestThriftWriteReadErr(t *testing.T) { + var err error + + // errNotPointer + p := TestingWriteRead{Msg: "hello"} + err = ThriftWrite(BufferTransport{nil}, p) + assert.Same(t, err, errNotPointer) + err = ThriftRead(BufferTransport{nil}, p) + assert.Same(t, err, errNotPointer) + + // errNoNewTBinaryProtocol + err = ThriftWrite(BufferTransport{nil}, &p) + assert.Same(t, err, errNoNewTBinaryProtocol) + + // Read/Write returns err + fn := func(trans TTransport, b0, b1 bool) *bytes.Buffer { return nil } + RegisterNewTBinaryProtocol(fn) + defer func() { newTBinaryProtocol = reflect.Value{} }() + p.mockErr = errors.New("mock") + err = ThriftWrite(BufferTransport{nil}, &p) + assert.Same(t, err, p.mockErr) + err = ThriftRead(BufferTransport{nil}, &p) + assert.Same(t, err, p.mockErr) + + // errMethodType + p1 := TestingWriteReadMethodNotMatch{} + err = ThriftWrite(BufferTransport{nil}, &p1) + assert.ErrorIs(t, err, errMethodType) + err = ThriftRead(BufferTransport{nil}, &p1) + assert.ErrorIs(t, err, errMethodType) +} diff --git a/protocol/thrift/apache/transport.go b/protocol/thrift/apache/transport.go new file mode 100644 index 0000000..76e7044 --- /dev/null +++ b/protocol/thrift/apache/transport.go @@ -0,0 +1,45 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import ( + "bytes" + "context" + "io" +) + +// TTransport is identical with thrift.TTransport. +type TTransport interface { + io.ReadWriteCloser + RemainingBytes() (num_bytes uint64) + Flush(ctx context.Context) (err error) + Open() error + IsOpen() bool +} + +// BufferTransport extends bytes.Buffer to support TTransport +type BufferTransport struct { + *bytes.Buffer +} + +func (p BufferTransport) IsOpen() bool { return true } +func (p BufferTransport) Open() error { return nil } +func (p BufferTransport) Close() error { p.Reset(); return nil } +func (p BufferTransport) Flush(_ context.Context) error { return nil } +func (p BufferTransport) RemainingBytes() uint64 { return uint64(p.Len()) } + +var _ TTransport = BufferTransport{nil} diff --git a/protocol/thrift/apache/transport_test.go b/protocol/thrift/apache/transport_test.go new file mode 100644 index 0000000..e568b56 --- /dev/null +++ b/protocol/thrift/apache/transport_test.go @@ -0,0 +1,34 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package apache + +import ( + "bytes" + "context" + "testing" +) + +func TestTBufferTransport(t *testing.T) { + buf := &bytes.Buffer{} + + p := BufferTransport{buf} + _ = p.IsOpen() + _ = p.Open() + _ = p.Close() + _ = p.Flush(context.Background()) + _ = p.RemainingBytes() +}