diff --git a/bufiox/defaultbuf.go b/bufiox/defaultbuf.go index cdc709d..58b750a 100644 --- a/bufiox/defaultbuf.go +++ b/bufiox/defaultbuf.go @@ -175,6 +175,10 @@ func (r *DefaultReader) ReadLen() (n int) { return r.ri } +func (r *DefaultReader) ReadableLen() (n int) { + return len(r.buf) - r.ri +} + func (r *DefaultReader) ReadBinary(bs []byte) (m int, err error) { m = r.acquire(len(bs)) copy(bs, r.buf[r.ri:r.ri+m]) @@ -195,7 +199,7 @@ func (r *DefaultReader) Release(e error) error { if len(r.buf)-r.ri == 0 { // release buf r.maxSizeStats.update(cap(r.buf)) - if !r.bufReadOnly { + if !r.bufReadOnly && cap(r.buf) > 0 { mcache.Free(r.buf) } r.buf = nil @@ -348,7 +352,9 @@ func (w *DefaultWriter) Flush() (err error) { } w.maxSizeStats.update(cap(w.buf)) if !w.disableCache { - mcache.Free(w.buf) + if cap(w.buf) > 0 { + mcache.Free(w.buf) + } if w.pendingBuf != nil { for _, buf := range w.pendingBuf { mcache.Free(buf) diff --git a/protocol/thrift/apache/apache.go b/protocol/thrift/apache/apache.go index be42536..76423fd 100644 --- a/protocol/thrift/apache/apache.go +++ b/protocol/thrift/apache/apache.go @@ -37,14 +37,15 @@ package apache import ( "errors" - "io" + + "github.com/cloudwego/gopkg/bufiox" ) var ( fnCheckTStruct func(v interface{}) error - fnThriftRead func(rw io.ReadWriter, v interface{}) error - fnThriftWrite func(rw io.ReadWriter, v interface{}) error + fnThriftRead func(r bufiox.Reader, v interface{}) error + fnThriftWrite func(w bufiox.Writer, v interface{}) error ) // RegisterCheckTStruct accepts `thrift.TStruct check` func and save it for later use. @@ -53,12 +54,12 @@ func RegisterCheckTStruct(fn func(v interface{}) error) { } // RegisterThriftRead ... -func RegisterThriftRead(fn func(rw io.ReadWriter, v interface{}) error) { +func RegisterThriftRead(fn func(r bufiox.Reader, v interface{}) error) { fnThriftRead = fn } // RegisterThriftWrite ... -func RegisterThriftWrite(fn func(rw io.ReadWriter, v interface{}) error) { +func RegisterThriftWrite(fn func(w bufiox.Writer, v interface{}) error) { fnThriftWrite = fn } @@ -77,17 +78,17 @@ func CheckTStruct(v interface{}) error { } // ThriftRead ... -func ThriftRead(rw io.ReadWriter, v interface{}) error { +func ThriftRead(r bufiox.Reader, v interface{}) error { if fnThriftRead == nil { return errThriftReadNotRegistered } - return fnThriftRead(rw, v) + return fnThriftRead(r, v) } // ThriftWrite ... -func ThriftWrite(rw io.ReadWriter, v interface{}) error { +func ThriftWrite(w bufiox.Writer, v interface{}) error { if fnThriftWrite == nil { return errThriftWriteNotRegistered } - return fnThriftWrite(rw, v) + return fnThriftWrite(w, v) } diff --git a/protocol/thrift/apache/apache_test.go b/protocol/thrift/apache/apache_test.go index bcbf85e..0cbb9a2 100644 --- a/protocol/thrift/apache/apache_test.go +++ b/protocol/thrift/apache/apache_test.go @@ -18,16 +18,14 @@ package apache import ( "bytes" - "encoding/json" "errors" - "io" "testing" + "github.com/cloudwego/gopkg/bufiox" "github.com/stretchr/testify/require" ) func TestThriftReadWrite(t *testing.T) { - v := &TestingWriteRead{Msg: "Hello"} err := CheckTStruct(v) @@ -37,41 +35,56 @@ func TestThriftReadWrite(t *testing.T) { require.NoError(t, err) buf := &bytes.Buffer{} + bw := bufiox.NewDefaultWriter(buf) - err = ThriftWrite(buf, v) + err = ThriftWrite(bw, v) require.Same(t, err, errThriftWriteNotRegistered) RegisterThriftWrite(callThriftWrite) - err = ThriftWrite(NewBufferTransport(buf), v) // calls v.Write + err = ThriftWrite(bw, v) // calls v.Write + require.NoError(t, err) + err = bw.Flush() require.NoError(t, err) p := &TestingWriteRead{} - err = ThriftRead(NewBufferTransport(buf), p) + br := bufiox.NewDefaultReader(buf) + + err = ThriftRead(br, p) require.Same(t, err, errThriftReadNotRegistered) RegisterThriftRead(callThriftRead) - err = ThriftRead(NewBufferTransport(buf), p) // calls p.Read + err = ThriftRead(br, p) // calls p.Read require.NoError(t, err) require.Equal(t, v.Msg, p.Msg) } type TStruct interface { // simulate thrift.TStruct - Read(r io.Reader) error - Write(w io.Writer) error + Read(r bufiox.Reader) error + Write(w bufiox.Writer) error } type TestingWriteRead struct { Msg string } -func (t *TestingWriteRead) Read(r io.Reader) error { - return json.NewDecoder(r).Decode(t) +func (t *TestingWriteRead) Read(r bufiox.Reader) error { + b, err := r.Next(5) + if err != nil { + return err + } + t.Msg = string(b) + return nil } -func (t *TestingWriteRead) Write(w io.Writer) error { - return json.NewEncoder(w).Encode(t) +func (t *TestingWriteRead) Write(w bufiox.Writer) error { + b, err := w.Malloc(5) + if err != nil { + return err + } + copy(b, t.Msg) + return nil } var errNotThriftTStruct = errors.New("errNotThriftTStruct") @@ -84,7 +97,7 @@ func checkTStruct(v interface{}) error { return nil } -func callThriftRead(rw io.ReadWriter, v interface{}) error { +func callThriftRead(rw bufiox.Reader, v interface{}) error { p, ok := v.(TStruct) if !ok { return errNotThriftTStruct @@ -92,7 +105,7 @@ func callThriftRead(rw io.ReadWriter, v interface{}) error { return p.Read(rw) } -func callThriftWrite(rw io.ReadWriter, v interface{}) error { +func callThriftWrite(rw bufiox.Writer, v interface{}) error { p, ok := v.(TStruct) if !ok { return errNotThriftTStruct diff --git a/protocol/thrift/binary.go b/protocol/thrift/binary.go index 7f50d53..455348a 100644 --- a/protocol/thrift/binary.go +++ b/protocol/thrift/binary.go @@ -22,9 +22,20 @@ import ( "math" "unsafe" + "github.com/bytedance/gopkg/lang/span" "github.com/cloudwego/gopkg/internal/hack" ) +var ( + spanCache = span.NewSpanCache(1024 * 1024) + spanCacheEnable bool = false +) + +// SetSpanCache enable/disable binary protocol bytes/string allocator +func SetSpanCache(enable bool) { + spanCacheEnable = enable +} + var Binary BinaryProtocol type BinaryProtocol struct{} @@ -316,8 +327,12 @@ func (p BinaryProtocol) ReadBinary(buf []byte) (b []byte, l int, err error) { if len(buf) < l { return nil, 4, errReadBin } - // TODO: use span - return []byte(string(buf[4:l])), l, nil + if spanCacheEnable { + b = spanCache.Copy(buf[4:l]) + } else { + b = []byte(string(buf[4:l])) + } + return b, l, nil } func (p BinaryProtocol) ReadString(buf []byte) (s string, l int, err error) { @@ -329,8 +344,13 @@ func (p BinaryProtocol) ReadString(buf []byte) (s string, l int, err error) { if len(buf) < l { return "", 4, errReadStr } - // TODO: use span - return string(buf[4:l]), l, nil + if spanCacheEnable { + data := spanCache.Copy(buf[4:l]) + s = hack.ByteSliceToString(data) + } else { + s = string(buf[4:l]) + } + return s, l, nil } func (BinaryProtocol) ReadBool(buf []byte) (v bool, l int, err error) { diff --git a/protocol/thrift/binarywriter.go b/protocol/thrift/binarywriter.go deleted file mode 100644 index d77636e..0000000 --- a/protocol/thrift/binarywriter.go +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package thrift - -import ( - "math" - "sync" -) - -const defaultBinaryWriterBufferSize = 4096 - -type BinaryWriter struct { - buf []byte -} - -var poolBinaryWriter = sync.Pool{ - New: func() interface{} { - return &BinaryWriter{buf: make([]byte, 0, defaultBinaryWriterBufferSize)} - }, -} - -func NewBinaryWriter() *BinaryWriter { - return NewBinaryWriterSize(0) -} - -func NewBinaryWriterSize(sz int) *BinaryWriter { - w := poolBinaryWriter.Get().(*BinaryWriter) - if cap(w.buf) < sz { - w.Release() - w = &BinaryWriter{buf: make([]byte, 0, sz)} - } - w.Reset() - return w -} - -func (w *BinaryWriter) Release() { - poolBinaryWriter.Put(w) -} - -func (w *BinaryWriter) Reset() { - w.buf = w.buf[:0] -} - -func (w *BinaryWriter) Bytes() []byte { - return w.buf -} - -func (w *BinaryWriter) WriteMessageBegin(name string, typeID TMessageType, seq int32) { - w.buf = Binary.AppendMessageBegin(w.buf, name, typeID, seq) -} - -func (w *BinaryWriter) WriteFieldBegin(typeID TType, id int16) { - w.buf = Binary.AppendFieldBegin(w.buf, typeID, id) -} - -func (w *BinaryWriter) WriteFieldStop() { - w.buf = append(w.buf, byte(STOP)) -} - -func (w *BinaryWriter) WriteMapBegin(kt, vt TType, size int) { - w.buf = Binary.AppendMapBegin(w.buf, kt, vt, size) -} - -func (w *BinaryWriter) WriteListBegin(et TType, size int) { - w.buf = Binary.AppendListBegin(w.buf, et, size) -} - -func (w *BinaryWriter) WriteSetBegin(et TType, size int) { - w.buf = Binary.AppendSetBegin(w.buf, et, size) -} - -func (w *BinaryWriter) WriteBinary(v []byte) { - w.buf = Binary.AppendBinary(w.buf, v) -} - -func (w *BinaryWriter) WriteString(v string) { - w.buf = Binary.AppendString(w.buf, v) -} - -func (w *BinaryWriter) WriteBool(v bool) { - if v { - w.buf = append(w.buf, 1) - } else { - w.buf = append(w.buf, 0) - } -} - -func (w *BinaryWriter) WriteByte(v int8) { - w.buf = append(w.buf, byte(v)) -} - -func (w *BinaryWriter) WriteI16(v int16) { - w.buf = append(w.buf, byte(uint16(v)>>8), byte(v)) -} - -func (w *BinaryWriter) WriteI32(v int32) { - w.buf = appendUint32(w.buf, uint32(v)) -} - -func (w *BinaryWriter) WriteI64(v int64) { - w.buf = appendUint64(w.buf, uint64(v)) -} - -func (w *BinaryWriter) WriteDouble(v float64) { - w.buf = appendUint64(w.buf, math.Float64bits(v)) -} diff --git a/protocol/thrift/binaryreader.go b/protocol/thrift/bufferreader.go similarity index 69% rename from protocol/thrift/binaryreader.go rename to protocol/thrift/bufferreader.go index a08ebf0..7df0fa7 100644 --- a/protocol/thrift/binaryreader.go +++ b/protocol/thrift/bufferreader.go @@ -19,90 +19,72 @@ package thrift import ( "encoding/binary" "fmt" - "io" "math" "sync" -) -// BinaryReader represents a reader for binary protocol -type BinaryReader struct { - r nextIface - d discardIface + "github.com/bytedance/gopkg/lang/dirtmake" + "github.com/cloudwego/gopkg/bufiox" + "github.com/cloudwego/gopkg/internal/hack" +) - rn int64 +// BufferReader represents a reader for binary protocol +type BufferReader struct { + r bufiox.Reader } -var poolBinaryReader = sync.Pool{ +var poolBufferReader = sync.Pool{ New: func() interface{} { - return &BinaryReader{} + return &BufferReader{} }, } -// NewBinaryReader ... call Release if no longer use for reusing -func NewBinaryReader(r io.Reader) *BinaryReader { - ret := poolBinaryReader.Get().(*BinaryReader) - ret.reset() - if nextr, ok := r.(nextIface); ok { - ret.r = nextr - } else { - nextr := newNextReader(r) - ret.r = nextr - ret.d = nextr - } +// NewBufferReader ... call Release if no longer use for reusing +func NewBufferReader(r bufiox.Reader) *BufferReader { + ret := poolBufferReader.Get().(*BufferReader) + ret.r = r return ret } -// Release ... -func (r *BinaryReader) Release() { - nextr, ok := r.r.(*nextReader) - if ok { - nextr.Release() - } - r.reset() - poolBinaryReader.Put(r) -} - -func (r *BinaryReader) reset() { +// Recycle ... +func (r *BufferReader) Recycle() { r.r = nil - r.d = nil - r.rn = 0 + poolBufferReader.Put(r) + return } -func (r *BinaryReader) next(n int) (b []byte, err error) { +func (r *BufferReader) next(n int) (b []byte, err error) { b, err = r.r.Next(n) if err != nil { err = NewProtocolExceptionWithErr(err) } - r.rn += int64(len(b)) return } -func (r *BinaryReader) skipn(n int) (err error) { +func (r *BufferReader) readBinary(bs []byte) (n int, err error) { + n, err = r.r.ReadBinary(bs) + if err != nil { + err = NewProtocolExceptionWithErr(err) + } + return +} + +func (r *BufferReader) skipn(n int) (err error) { if n < 0 { return errNegativeSize } - if r.d != nil { - var sz int - sz, err = r.d.Discard(n) - r.rn += int64(sz) - } else { - var b []byte - b, err = r.r.Next(n) - r.rn += int64(len(b)) - } - if err != nil { + if err = r.r.Skip(n); err != nil { return NewProtocolExceptionWithErr(err) } return nil } // Readn returns total bytes read from underlying reader -func (r *BinaryReader) Readn() int64 { - return r.rn +func (r *BufferReader) Readn() int64 { + return int64(r.r.ReadLen()) } // ReadBool ... -func (r *BinaryReader) ReadBool() (v bool, err error) { +func (r *BufferReader) ReadBool() (v bool, err error) { b, err := r.next(1) if err != nil { return false, err @@ -112,7 +94,7 @@ func (r *BinaryReader) ReadBool() (v bool, err error) { } // ReadByte ... -func (r *BinaryReader) ReadByte() (v int8, err error) { +func (r *BufferReader) ReadByte() (v int8, err error) { b, err := r.next(1) if err != nil { return 0, err @@ -122,7 +104,7 @@ func (r *BinaryReader) ReadByte() (v int8, err error) { } // ReadI16 ... -func (r *BinaryReader) ReadI16() (v int16, err error) { +func (r *BufferReader) ReadI16() (v int16, err error) { b, err := r.next(2) if err != nil { return 0, err @@ -132,7 +114,7 @@ func (r *BinaryReader) ReadI16() (v int16, err error) { } // ReadI32 ... -func (r *BinaryReader) ReadI32() (v int32, err error) { +func (r *BufferReader) ReadI32() (v int32, err error) { b, err := r.next(4) if err != nil { return 0, err @@ -142,7 +124,7 @@ func (r *BinaryReader) ReadI32() (v int32, err error) { } // ReadI64 ... -func (r *BinaryReader) ReadI64() (v int64, err error) { +func (r *BufferReader) ReadI64() (v int64, err error) { b, err := r.next(8) if err != nil { return 0, err @@ -152,7 +134,7 @@ func (r *BinaryReader) ReadI64() (v int64, err error) { } // ReadDouble ... -func (r *BinaryReader) ReadDouble() (v float64, err error) { +func (r *BufferReader) ReadDouble() (v float64, err error) { b, err := r.next(8) if err != nil { return 0, err @@ -162,32 +144,27 @@ func (r *BinaryReader) ReadDouble() (v float64, err error) { } // ReadBinary ... -func (r *BinaryReader) ReadBinary() (b []byte, err error) { +func (r *BufferReader) ReadBinary() (b []byte, err error) { sz, err := r.ReadI32() if err != nil { return nil, err } - b, err = r.next(int(sz)) - b = []byte(string(b)) // copy. use span cache? + b = dirtmake.Bytes(int(sz), int(sz)) + _, err = r.readBinary(b) return } // ReadString ... -func (r *BinaryReader) ReadString() (s string, err error) { - sz, err := r.ReadI32() +func (r *BufferReader) ReadString() (s string, err error) { + b, err := r.ReadBinary() if err != nil { return "", err } - b, err := r.next(int(sz)) - if err != nil { - return "", err - } - s = string(b) // copy. use span cache? - return + return hack.ByteSliceToString(b), nil } // ReadMessageBegin ... -func (r *BinaryReader) ReadMessageBegin() (name string, typeID TMessageType, seq int32, err error) { +func (r *BufferReader) ReadMessageBegin() (name string, typeID TMessageType, seq int32, err error) { var header int32 header, err = r.ReadI32() if err != nil { @@ -215,7 +192,7 @@ func (r *BinaryReader) ReadMessageBegin() (name string, typeID TMessageType, seq } // ReadFieldBegin ... -func (r *BinaryReader) ReadFieldBegin() (typeID TType, id int16, err error) { +func (r *BufferReader) ReadFieldBegin() (typeID TType, id int16, err error) { b, err := r.next(1) if err != nil { return 0, 0, err @@ -233,7 +210,7 @@ func (r *BinaryReader) ReadFieldBegin() (typeID TType, id int16, err error) { } // ReadMapBegin ... -func (r *BinaryReader) ReadMapBegin() (kt, vt TType, size int, err error) { +func (r *BufferReader) ReadMapBegin() (kt, vt TType, size int, err error) { b, err := r.next(6) if err != nil { return 0, 0, 0, err @@ -243,7 +220,7 @@ func (r *BinaryReader) ReadMapBegin() (kt, vt TType, size int, err error) { } // ReadListBegin ... -func (r *BinaryReader) ReadListBegin() (et TType, size int, err error) { +func (r *BufferReader) ReadListBegin() (et TType, size int, err error) { b, err := r.next(5) if err != nil { return 0, 0, err @@ -253,7 +230,7 @@ func (r *BinaryReader) ReadListBegin() (et TType, size int, err error) { } // ReadSetBegin ... -func (r *BinaryReader) ReadSetBegin() (et TType, size int, err error) { +func (r *BufferReader) ReadSetBegin() (et TType, size int, err error) { b, err := r.next(5) if err != nil { return 0, 0, err @@ -263,11 +240,11 @@ func (r *BinaryReader) ReadSetBegin() (et TType, size int, err error) { } // Skip ... -func (r *BinaryReader) Skip(t TType) error { +func (r *BufferReader) Skip(t TType) error { return r.skipType(t, defaultRecursionDepth) } -func (r *BinaryReader) skipstr() error { +func (r *BufferReader) skipstr() error { n, err := r.ReadI32() if err != nil { return err @@ -275,7 +252,7 @@ func (r *BinaryReader) skipstr() error { return r.skipn(int(n)) } -func (r *BinaryReader) skipType(t TType, maxdepth int) error { +func (r *BufferReader) skipType(t TType, maxdepth int) error { if maxdepth == 0 { return errDepthLimitExceeded } diff --git a/protocol/thrift/binaryreader_test.go b/protocol/thrift/bufferreader_test.go similarity index 97% rename from protocol/thrift/binaryreader_test.go rename to protocol/thrift/bufferreader_test.go index accd165..f437015 100644 --- a/protocol/thrift/binaryreader_test.go +++ b/protocol/thrift/bufferreader_test.go @@ -17,9 +17,9 @@ package thrift import ( - "bytes" "testing" + "github.com/cloudwego/gopkg/bufiox" "github.com/stretchr/testify/require" ) @@ -55,8 +55,7 @@ func TestBinaryReader(t *testing.T) { b = x.AppendDouble(b, 18.5) sz13 := len(b) - r := NewBinaryReader(bytes.NewReader(b)) - defer r.Release() + r := NewBufferReader(bufiox.NewBytesReader(b)) name, mt, seq, err := r.ReadMessageBegin() require.NoError(t, err) require.Equal(t, "hello", name) @@ -206,8 +205,7 @@ func TestBinaryReaderSkip(t *testing.T) { b = x.AppendFieldStop(b) sz10 := len(b) - r := NewBinaryReader(bytes.NewReader(b)) - defer r.Release() + r := NewBufferReader(bufiox.NewBytesReader(b)) err := r.Skip(BYTE) // byte require.NoError(t, err) @@ -242,6 +240,7 @@ func TestBinaryReaderSkip(t *testing.T) { err = r.Skip(STRUCT) // struct i32, list require.NoError(t, err) require.Equal(t, int64(sz10), r.Readn()) + r.Recycle() { // other cases // errDepthLimitExceeded @@ -249,7 +248,7 @@ func TestBinaryReaderSkip(t *testing.T) { for i := 0; i < defaultRecursionDepth+1; i++ { b = x.AppendFieldBegin(b, STRUCT, 1) } - r := NewBinaryReader(bytes.NewReader(b)) + r := NewBufferReader(bufiox.NewBytesReader(b)) err := r.Skip(STRUCT) require.Same(t, errDepthLimitExceeded, err) diff --git a/protocol/thrift/bufferwriter.go b/protocol/thrift/bufferwriter.go new file mode 100644 index 0000000..7459947 --- /dev/null +++ b/protocol/thrift/bufferwriter.go @@ -0,0 +1,180 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "encoding/binary" + "math" + "sync" + + "github.com/cloudwego/gopkg/bufiox" + "github.com/cloudwego/gopkg/internal/hack" +) + +type BufferWriter struct { + w bufiox.Writer +} + +var poolBufferWriter = sync.Pool{ + New: func() interface{} { + return &BufferWriter{} + }, +} + +func NewBufferWriter(iw bufiox.Writer) *BufferWriter { + w := poolBufferWriter.Get().(*BufferWriter) + w.w = iw + return w +} + +func (w *BufferWriter) Recycle() { + w.w = nil + poolBufferWriter.Put(w) + return +} + +func (w *BufferWriter) WriteMessageBegin(name string, typeID TMessageType, seq int32) error { + buf, err := w.w.Malloc(Binary.MessageBeginLength(name)) + if err != nil { + return err + } + binary.BigEndian.PutUint32(buf, uint32(msgVersion1)|uint32(typeID&msgTypeMask)) + binary.BigEndian.PutUint32(buf[4:], uint32(len(name))) + copy(buf[8:], name) + binary.BigEndian.PutUint32(buf[8+len(name):], uint32(seq)) + return nil +} + +func (w *BufferWriter) WriteFieldBegin(typeID TType, id int16) error { + buf, err := w.w.Malloc(3) + if err != nil { + return err + } + buf[0], buf[1], buf[2] = byte(typeID), byte(uint16(id>>8)), byte(id) + return nil +} + +func (w *BufferWriter) WriteFieldStop() error { + buf, err := w.w.Malloc(1) + if err != nil { + return err + } + buf[0] = byte(STOP) + return nil +} + +func (w *BufferWriter) WriteMapBegin(kt, vt TType, size int) error { + buf, err := w.w.Malloc(6) + if err != nil { + return err + } + buf[0], buf[1] = byte(kt), byte(vt) + binary.BigEndian.PutUint32(buf[2:], uint32(size)) + return nil +} + +func (w *BufferWriter) WriteListBegin(et TType, size int) error { + buf, err := w.w.Malloc(5) + if err != nil { + return err + } + buf[0] = byte(et) + binary.BigEndian.PutUint32(buf[1:], uint32(size)) + return nil +} + +func (w *BufferWriter) WriteSetBegin(et TType, size int) error { + buf, err := w.w.Malloc(5) + if err != nil { + return err + } + buf[0] = byte(et) + binary.BigEndian.PutUint32(buf[1:], uint32(size)) + return nil +} + +func (w *BufferWriter) WriteBinary(v []byte) error { + buf, err := w.w.Malloc(4) + if err != nil { + return err + } + binary.BigEndian.PutUint32(buf, uint32(len(v))) + _, err = w.w.WriteBinary(v) + return err +} + +func (w *BufferWriter) WriteString(v string) error { + return w.WriteBinary(hack.StringToByteSlice(v)) +} + +func (w *BufferWriter) WriteBool(v bool) error { + buf, err := w.w.Malloc(1) + if err != nil { + return err + } + if v { + buf[0] = 1 + } else { + buf[0] = 0 + } + return nil +} + +func (w *BufferWriter) WriteByte(v int8) error { + buf, err := w.w.Malloc(1) + if err != nil { + return err + } + buf[0] = byte(v) + return nil +} + +func (w *BufferWriter) WriteI16(v int16) error { + buf, err := w.w.Malloc(2) + if err != nil { + return err + } + binary.BigEndian.PutUint16(buf, uint16(v)) + return nil +} + +func (w *BufferWriter) WriteI32(v int32) error { + buf, err := w.w.Malloc(4) + if err != nil { + return err + } + binary.BigEndian.PutUint32(buf, uint32(v)) + return nil +} + +func (w *BufferWriter) WriteI64(v int64) error { + buf, err := w.w.Malloc(8) + if err != nil { + return err + } + binary.BigEndian.PutUint64(buf, uint64(v)) + return nil +} + +func (w *BufferWriter) WriteDouble(v float64) error { + buf, err := w.w.Malloc(8) + if err != nil { + return err + } + binary.BigEndian.PutUint64(buf, math.Float64bits(v)) + return nil +} diff --git a/protocol/thrift/binarywriter_test.go b/protocol/thrift/bufferwriter_test.go similarity index 75% rename from protocol/thrift/binarywriter_test.go rename to protocol/thrift/bufferwriter_test.go index ee77817..c9e176f 100644 --- a/protocol/thrift/binarywriter_test.go +++ b/protocol/thrift/bufferwriter_test.go @@ -19,68 +19,64 @@ package thrift import ( "testing" + "github.com/bytedance/gopkg/lang/dirtmake" + "github.com/cloudwego/gopkg/bufiox" "github.com/stretchr/testify/require" ) +const defaultBinaryWriterBufferSize = 4096 + func TestBinaryWriter(t *testing.T) { - w := NewBinaryWriterSize(defaultBinaryWriterBufferSize * 2) + buf := dirtmake.Bytes(0, defaultBinaryWriterBufferSize*2) + w := NewBufferWriter(bufiox.NewBytesWriter(&buf)) x := BinaryProtocol{} b := x.AppendMessageBegin(nil, "hello", 1, 2) w.WriteMessageBegin("hello", 1, 2) - require.Equal(t, b, w.Bytes()) b = x.AppendFieldBegin(b, 3, 4) w.WriteFieldBegin(3, 4) - require.Equal(t, b, w.Bytes()) b = x.AppendFieldStop(b) w.WriteFieldStop() - require.Equal(t, b, w.Bytes()) b = x.AppendMapBegin(b, 5, 6, 7) w.WriteMapBegin(5, 6, 7) - require.Equal(t, b, w.Bytes()) b = x.AppendListBegin(b, 8, 9) w.WriteListBegin(8, 9) - require.Equal(t, b, w.Bytes()) b = x.AppendSetBegin(b, 10, 11) w.WriteSetBegin(10, 11) - require.Equal(t, b, w.Bytes()) b = x.AppendBinary(b, []byte("12")) w.WriteBinary([]byte("12")) - require.Equal(t, b, w.Bytes()) b = x.AppendString(b, "13") w.WriteString("13") - require.Equal(t, b, w.Bytes()) b = x.AppendBool(b, true) b = x.AppendBool(b, false) w.WriteBool(true) w.WriteBool(false) - require.Equal(t, b, w.Bytes()) b = x.AppendByte(b, 14) w.WriteByte(14) - require.Equal(t, b, w.Bytes()) b = x.AppendI16(b, 15) w.WriteI16(15) - require.Equal(t, b, w.Bytes()) b = x.AppendI32(b, 16) w.WriteI32(16) - require.Equal(t, b, w.Bytes()) b = x.AppendI64(b, 17) w.WriteI64(17) - require.Equal(t, b, w.Bytes()) b = x.AppendDouble(b, 18.5) w.WriteDouble(18.5) - require.Equal(t, b, w.Bytes()) + + w.w.Flush() + w.Recycle() + + require.Equal(t, b, buf) } diff --git a/protocol/thrift/skipdecoder.go b/protocol/thrift/skipdecoder.go index 3e5f1c4..114ab28 100644 --- a/protocol/thrift/skipdecoder.go +++ b/protocol/thrift/skipdecoder.go @@ -19,10 +19,14 @@ package thrift import ( "encoding/binary" "fmt" - "io" "sync" + + "github.com/bytedance/gopkg/lang/mcache" + "github.com/cloudwego/gopkg/bufiox" ) +const defaultSkipDecoderSize = 4096 + var poolSkipDecoder = sync.Pool{ New: func() interface{} { return &SkipDecoder{} @@ -31,53 +35,46 @@ var poolSkipDecoder = sync.Pool{ // SkipDecoder scans the underlying io.Reader and returns the bytes of a type type SkipDecoder struct { - p skipReaderIface + r bufiox.Reader + + // for storing Next(ttype) buffer + nextBuf []byte + + // for reusing buffer + pendingBuf [][]byte } // NewSkipDecoder ... call Release if no longer use -func NewSkipDecoder(r io.Reader) *SkipDecoder { +func NewSkipDecoder(r bufiox.Reader) *SkipDecoder { p := poolSkipDecoder.Get().(*SkipDecoder) - p.Reset(r) + p.r = r return p } -// Reset ... -func (p *SkipDecoder) Reset(r io.Reader) { - // fast path without returning to pool if remote.ByteBuffer && *skipByteBuffer - if buf, ok := r.(remoteByteBuffer); ok { - if p.p != nil { - r, ok := p.p.(*skipByteBuffer) - if ok { - r.Reset(buf) - return - } - p.p.Release() - } - p.p = newSkipByteBuffer(buf) - return - } - - // not remote.ByteBuffer - - if p.p != nil { - p.p.Release() - } - p.p = newSkipReader(r) -} - -// Release ... +// Release releases the peekAck decoder, callers cannot use the returned data of Next after calling Release. func (p *SkipDecoder) Release() { - p.p.Release() - p.p = nil + if p.nextBuf != nil { + mcache.Free(p.nextBuf) + } + *p = SkipDecoder{} poolSkipDecoder.Put(p) } -// Next skips a specific type and returns its bytes +// Next skips a specific type and returns its bytes. +// Callers cannot use the returned data after calling Release. func (p *SkipDecoder) Next(t TType) (buf []byte, err error) { - if err := p.skip(t, defaultRecursionDepth); err != nil { - return nil, err + p.nextBuf = mcache.Malloc(0, defaultSkipDecoderSize) + if err = p.skip(t, defaultRecursionDepth); err != nil { + return + } + var offset int + for _, b := range p.pendingBuf { + offset += copy(p.nextBuf[offset:], b[offset:]) + mcache.Free(b) } - return p.p.Bytes() + p.pendingBuf = nil + buf = p.nextBuf + return } func (p *SkipDecoder) skip(t TType, maxdepth int) error { @@ -85,12 +82,12 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { return errDepthLimitExceeded } if sz := typeToSize[t]; sz > 0 { - _, err := p.p.Next(int(sz)) + _, err := p.next(int(sz)) return err } switch t { case STRING: - b, err := p.p.Next(4) + b, err := p.next(4) if err != nil { return err } @@ -98,12 +95,12 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { if sz < 0 { return errNegativeSize } - if _, err := p.p.Next(sz); err != nil { + if _, err := p.next(sz); err != nil { return err } case STRUCT: for { - b, err := p.p.Next(1) // TType + b, err := p.next(1) // TType if err != nil { return err } @@ -111,7 +108,7 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { if tp == STOP { break } - if _, err := p.p.Next(2); err != nil { // Field ID + if _, err := p.next(2); err != nil { // Field ID return err } if err := p.skip(tp, maxdepth-1); err != nil { @@ -119,7 +116,7 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { } } case MAP: - b, err := p.p.Next(6) // 1 byte key TType, 1 byte value TType, 4 bytes Len + b, err := p.next(6) // 1 byte key TType, 1 byte value TType, 4 bytes Len if err != nil { return err } @@ -129,7 +126,7 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { } ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt]) if ksz > 0 && vsz > 0 { - _, err := p.p.Next(int(sz) * (ksz + vsz)) + _, err := p.next(int(sz) * (ksz + vsz)) return err } for i := int32(0); i < sz; i++ { @@ -141,7 +138,7 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { } } case SET, LIST: - b, err := p.p.Next(5) // 1 byte value type, 4 bytes Len + b, err := p.next(5) // 1 byte value type, 4 bytes Len if err != nil { return err } @@ -150,7 +147,7 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { return errNegativeSize } if vsz := typeToSize[vt]; vsz > 0 { - _, err := p.p.Next(int(sz) * int(vsz)) + _, err := p.next(int(sz) * int(vsz)) return err } for i := int32(0); i < sz; i++ { @@ -163,3 +160,20 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { } return nil } + +func (p *SkipDecoder) next(n int) (buf []byte, err error) { + if buf, err = p.r.Next(n); err != nil { + return + } + if cap(p.nextBuf)-len(p.nextBuf) < n { + var ncap int + for ncap = cap(p.nextBuf) * 2; ncap-len(p.nextBuf) < n; ncap *= 2 { + } + nbs := mcache.Malloc(ncap, ncap) + p.pendingBuf = append(p.pendingBuf, p.nextBuf) + p.nextBuf = nbs[:len(p.nextBuf)] + } + cn := copy(p.nextBuf[len(p.nextBuf):cap(p.nextBuf)], buf) + p.nextBuf = p.nextBuf[:len(p.nextBuf)+cn] + return +} diff --git a/protocol/thrift/skipdecoder_test.go b/protocol/thrift/skipdecoder_test.go index e5d21b2..6e3eafa 100644 --- a/protocol/thrift/skipdecoder_test.go +++ b/protocol/thrift/skipdecoder_test.go @@ -18,10 +18,10 @@ package thrift import ( "bytes" - "math/rand" "strings" "testing" + "github.com/cloudwego/gopkg/bufiox" "github.com/stretchr/testify/require" ) @@ -92,7 +92,7 @@ func TestSkipDecoder(t *testing.T) { b = x.AppendFieldStop(b) sz10 := len(b) - r := NewSkipDecoder(bytes.NewReader(b)) + r := NewSkipDecoder(bufiox.NewBytesReader(b)) defer r.Release() readn := 0 @@ -147,7 +147,7 @@ func TestSkipDecoder(t *testing.T) { for i := 0; i < defaultRecursionDepth+1; i++ { b = x.AppendFieldBegin(b, STRUCT, 1) } - r := NewSkipDecoder(bytes.NewReader(b)) + r := NewSkipDecoder(bufiox.NewBytesReader(b)) _, err := r.Next(STRUCT) require.Same(t, errDepthLimitExceeded, err) @@ -157,19 +157,79 @@ func TestSkipDecoder(t *testing.T) { } } -func TestSkipDecoderReset(t *testing.T) { - x := BinaryProtocol{} - b := x.AppendString([]byte(nil), "hello") - - r := NewSkipDecoder(nil) - for i := 0; i < 10; i++ { - if rand.Intn(2) == 1 { // random skipreader to test Reset - r.Reset(&remoteByteBufferImplForT{b: b}) - } else { - r.Reset(bytes.NewReader(b)) +var mockString = make([]byte, 5000) + +func BenchmarkSkipDecoder(b *testing.B) { + // prepare data + bs := make([]byte, 0, 1024) + + // BOOL, fid=1 + bs = Binary.AppendFieldBegin(bs, BOOL, 1) + bs = Binary.AppendBool(bs, true) + + // BYTE, fid=2 + bs = Binary.AppendFieldBegin(bs, BYTE, 2) + bs = Binary.AppendByte(bs, 2) + + // I16, fid=3 + bs = Binary.AppendFieldBegin(bs, I16, 3) + bs = Binary.AppendI16(bs, 3) + + // I32, fid=4 + bs = Binary.AppendFieldBegin(bs, I32, 4) + bs = Binary.AppendI32(bs, 4) + + // I64, fid=5 + bs = Binary.AppendFieldBegin(bs, I64, 5) + bs = Binary.AppendI64(bs, 5) + + // DOUBLE, fid=6 + bs = Binary.AppendFieldBegin(bs, DOUBLE, 6) + bs = Binary.AppendDouble(bs, 6) + + // STRING, fid=7 + bs = Binary.AppendFieldBegin(bs, STRING, 7) + bs = Binary.AppendString(bs, string(mockString)) + + // MAP, fid=8 + bs = Binary.AppendFieldBegin(bs, MAP, 8) + bs = Binary.AppendMapBegin(bs, DOUBLE, DOUBLE, 1) + bs = Binary.AppendDouble(bs, 8.1) + bs = Binary.AppendDouble(bs, 8.2) + + // SET, fid=9 + bs = Binary.AppendFieldBegin(bs, SET, 9) + bs = Binary.AppendSetBegin(bs, I64, 1) + bs = Binary.AppendI64(bs, 9) + + // LIST, fid=10 + bs = Binary.AppendFieldBegin(bs, LIST, 10) + bs = Binary.AppendListBegin(bs, I64, 1) + bs = Binary.AppendI64(bs, 10) + + // STRUCT with 1 field I64, fid=11,1 + bs = Binary.AppendFieldBegin(bs, STRUCT, 11) + bs = Binary.AppendFieldBegin(bs, I64, 1) + bs = Binary.AppendI64(bs, 11) + bs = Binary.AppendFieldStop(bs) + + // Finish struct + bs = Binary.AppendFieldStop(bs) + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + bufReader := bufiox.NewBytesReader(bs) + sr := NewSkipDecoder(bufReader) + buf, err := sr.Next(STRUCT) + if err != nil { + b.Fatal(err) + } + if !bytes.Equal(buf, bs) { + b.Fatal("bytes not equal") + } + sr.Release() + bufReader.Release(nil) } - retb, err := r.Next(STRING) - require.NoError(t, err) - require.Equal(t, b, retb) - } + }) } diff --git a/protocol/thrift/skipreader.go b/protocol/thrift/skipreader.go deleted file mode 100644 index eb04580..0000000 --- a/protocol/thrift/skipreader.go +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package thrift - -import ( - "io" - "sync" -) - -// this file contains readers for SkipDecoder - -type skipReaderIface interface { - Next(n int) (buf []byte, err error) - Bytes() (buf []byte, err error) - Release() -} - -var poolSkipReader = sync.Pool{ - New: func() interface{} { - return &skipReader{b: make([]byte, 1024)} - }, -} - -var poolSkipRemoteBuffer = sync.Pool{ - New: func() interface{} { - return &skipByteBuffer{} - }, -} - -// skipReader ... general skip reader for io.Reader -type skipReader struct { - r io.Reader - - p int - b []byte -} - -func newSkipReader(r io.Reader) *skipReader { - ret := poolSkipReader.Get().(*skipReader) - ret.Reset(r) - return ret -} - -func (p *skipReader) Release() { - poolSkipReader.Put(p) -} - -func (p *skipReader) Reset(r io.Reader) { - p.r = r - p.p = 0 -} - -func (p *skipReader) Bytes() ([]byte, error) { - ret := p.b[:p.p] - p.p = 0 - return ret, nil -} - -func (p *skipReader) grow(n int) { - // assert: len(p.b)-p.p < n - sz := 2 * cap(p.b) - if sz < p.p+n { - sz = p.p + n - } - b := make([]byte, sz) - copy(b, p.b[:p.p]) - p.b = b -} - -func (p *skipReader) Next(n int) (buf []byte, err error) { - if len(p.b)-p.p < n { - p.grow(n) - } - if _, err := io.ReadFull(p.r, p.b[p.p:p.p+n]); err != nil { - return nil, NewProtocolExceptionWithErr(err) - } - ret := p.b[p.p : p.p+n] - p.p += n - return ret, nil -} - -// remoteByteBuffer ... github.com/cloudwego/kitex/pkg/remote.ByteBuffer -type remoteByteBuffer interface { - Peek(n int) (buf []byte, err error) - ReadableLen() (n int) - Skip(n int) (err error) -} - -// skipByteBuffer ... optimized zero copy skipreader for remote.ByteBuffer -type skipByteBuffer struct { - p remoteByteBuffer - - r int - b []byte -} - -func newSkipByteBuffer(buf remoteByteBuffer) *skipByteBuffer { - ret := poolSkipRemoteBuffer.Get().(*skipByteBuffer) - ret.Reset(buf) - return ret -} - -func (p *skipByteBuffer) Release() { - poolSkipRemoteBuffer.Put(p) -} - -func (p *skipByteBuffer) Reset(buf remoteByteBuffer) { - p.r = 0 - p.b = nil - p.p = buf -} - -func (p *skipByteBuffer) Bytes() ([]byte, error) { - ret := p.b[:p.r] - if err := p.p.Skip(p.r); err != nil { - return nil, err - } - p.r = 0 - return ret, nil -} - -// Next ... -func (p *skipByteBuffer) Next(n int) (ret []byte, err error) { - if p.r+n < len(p.b) { // fast path - ret, p.r = p.b[p.r:p.r+n], p.r+n - return - } - return p.nextSlow(n) -} - -func (p *skipByteBuffer) nextSlow(n int) ([]byte, error) { - // trigger underlying conn to read more - _, err := p.p.Peek(p.r + n) - if err != nil { - return nil, err - } - // read as much as possible, luckily, we will have a full buffer - // then we no need to call p.Peek many times - p.b, err = p.p.Peek(p.p.ReadableLen()) - if err != nil { - return nil, err - } - // after calling p.p.Peek, p.buf MUST be at least (p.r + n) len - ret := p.b[p.r : p.r+n] - p.r += n - return ret, nil -} diff --git a/protocol/thrift/skipreader_test.go b/protocol/thrift/skipreader_test.go deleted file mode 100644 index 4307022..0000000 --- a/protocol/thrift/skipreader_test.go +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package thrift - -import ( - "bytes" - "errors" - "io" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestSkipReader(t *testing.T) { - b := make([]byte, 2048) - for i := 0; i < len(b); i++ { - b[i] = byte(i) - } - - r := newSkipReader(bytes.NewReader(b)) - defer r.Release() - for i := 0; i < len(b); i++ { - b, err := r.Next(1) - require.NoError(t, err) - require.True(t, b[0] == byte(i)) - } - - retb, err := r.Bytes() - require.NoError(t, err) - require.Equal(t, b, retb) -} - -type remoteByteBufferImplForT struct { - p int - b []byte -} - -func (remoteByteBufferImplForT) Read(_ []byte) (int, error) { return 0, errors.New("not implemented") } - -func (p *remoteByteBufferImplForT) Peek(n int) (buf []byte, err error) { - if n > len(p.b) { - return nil, io.EOF - } - return p.b[:n], nil -} - -func (p *remoteByteBufferImplForT) ReadableLen() int { - return len(p.b) - p.p -} - -func (p *remoteByteBufferImplForT) Skip(n int) error { - if n > len(p.b) { - panic("bug") - } - p.p += n - return nil -} - -func TestSkipRemoteBuffer(t *testing.T) { - b := make([]byte, 2048) - for i := 0; i < len(b); i++ { - b[i] = byte(i) - } - - r := newSkipByteBuffer(&remoteByteBufferImplForT{b: b}) - defer r.Release() - for i := 0; i < len(b); i++ { - b, err := r.Next(1) - require.NoError(t, err) - require.True(t, b[0] == byte(i)) - } - - retb, err := r.Bytes() - require.NoError(t, err) - require.Equal(t, b, retb) -} diff --git a/protocol/thrift/utils.go b/protocol/thrift/utils.go index dfd1d4a..3b4a34f 100644 --- a/protocol/thrift/utils.go +++ b/protocol/thrift/utils.go @@ -17,8 +17,6 @@ package thrift import ( - "io" - "sync" "unsafe" ) @@ -37,60 +35,3 @@ type nextIface interface { type discardIface interface { Discard(n int) (int, error) } - -// nextReader provides a wrapper for io.Reader to use BinaryReader -type nextReader struct { - r io.Reader - b [4096]byte -} - -var poolNextReader = sync.Pool{ - New: func() interface{} { - return &nextReader{} - }, -} - -func newNextReader(r io.Reader) *nextReader { - ret := poolNextReader.Get().(*nextReader) - ret.Reset(r) - return ret -} - -// Release ... -func (r *nextReader) Release() { poolNextReader.Put(r) } - -// Reset ... for reusing nextReader -func (r *nextReader) Reset(rd io.Reader) { r.r = rd } - -// Next implements nextIface of BinaryReader -func (r *nextReader) Next(n int) ([]byte, error) { - b := r.b[:] - if n <= len(b) { - b = b[:n] - } else { - b = make([]byte, n) - } - _, err := io.ReadFull(r.r, b) - if err != nil { - return nil, err - } - return b, nil -} - -// Discard implements discardIface of BinaryReader -func (r *nextReader) Discard(n int) (int, error) { - ret := 0 - b := r.b[:] - for n > 0 { - if len(b) > n { - b = b[:n] - } - readn, err := r.r.Read(b) - ret += readn - if err != nil { - return ret, err - } - n -= readn - } - return ret, nil -}