From 87c3e91a7394241ff97fe394ca7883cf900fc4a5 Mon Sep 17 00:00:00 2001 From: xiezhengyao Date: Mon, 19 Aug 2024 15:28:04 +0800 Subject: [PATCH] refactor: thrift codec uses bufiox interface for encoding and decoding --- protocol/thrift/binarywriter.go | 120 ------------ .../{binaryreader.go => bufferreader.go} | 121 +++++------- ...aryreader_test.go => bufferreader_test.go} | 8 +- protocol/thrift/bufferwriter.go | 179 ++++++++++++++++++ ...arywriter_test.go => bufferwriter_test.go} | 25 +-- protocol/thrift/skipdecoder.go | 110 ++++++----- protocol/thrift/skipdecoder_test.go | 24 +-- protocol/thrift/skipreader.go | 161 ---------------- protocol/thrift/skipreader_test.go | 90 --------- protocol/thrift/utils.go | 59 ------ 10 files changed, 312 insertions(+), 585 deletions(-) delete mode 100644 protocol/thrift/binarywriter.go rename protocol/thrift/{binaryreader.go => bufferreader.go} (69%) rename protocol/thrift/{binaryreader_test.go => bufferreader_test.go} (97%) create mode 100644 protocol/thrift/bufferwriter.go rename protocol/thrift/{binarywriter_test.go => bufferwriter_test.go} (75%) delete mode 100644 protocol/thrift/skipreader.go delete mode 100644 protocol/thrift/skipreader_test.go diff --git a/protocol/thrift/binarywriter.go b/protocol/thrift/binarywriter.go deleted file mode 100644 index d77636e..0000000 --- a/protocol/thrift/binarywriter.go +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package thrift - -import ( - "math" - "sync" -) - -const defaultBinaryWriterBufferSize = 4096 - -type BinaryWriter struct { - buf []byte -} - -var poolBinaryWriter = sync.Pool{ - New: func() interface{} { - return &BinaryWriter{buf: make([]byte, 0, defaultBinaryWriterBufferSize)} - }, -} - -func NewBinaryWriter() *BinaryWriter { - return NewBinaryWriterSize(0) -} - -func NewBinaryWriterSize(sz int) *BinaryWriter { - w := poolBinaryWriter.Get().(*BinaryWriter) - if cap(w.buf) < sz { - w.Release() - w = &BinaryWriter{buf: make([]byte, 0, sz)} - } - w.Reset() - return w -} - -func (w *BinaryWriter) Release() { - poolBinaryWriter.Put(w) -} - -func (w *BinaryWriter) Reset() { - w.buf = w.buf[:0] -} - -func (w *BinaryWriter) Bytes() []byte { - return w.buf -} - -func (w *BinaryWriter) WriteMessageBegin(name string, typeID TMessageType, seq int32) { - w.buf = Binary.AppendMessageBegin(w.buf, name, typeID, seq) -} - -func (w *BinaryWriter) WriteFieldBegin(typeID TType, id int16) { - w.buf = Binary.AppendFieldBegin(w.buf, typeID, id) -} - -func (w *BinaryWriter) WriteFieldStop() { - w.buf = append(w.buf, byte(STOP)) -} - -func (w *BinaryWriter) WriteMapBegin(kt, vt TType, size int) { - w.buf = Binary.AppendMapBegin(w.buf, kt, vt, size) -} - -func (w *BinaryWriter) WriteListBegin(et TType, size int) { - w.buf = Binary.AppendListBegin(w.buf, et, size) -} - -func (w *BinaryWriter) WriteSetBegin(et TType, size int) { - w.buf = Binary.AppendSetBegin(w.buf, et, size) -} - -func (w *BinaryWriter) WriteBinary(v []byte) { - w.buf = Binary.AppendBinary(w.buf, v) -} - -func (w *BinaryWriter) WriteString(v string) { - w.buf = Binary.AppendString(w.buf, v) -} - -func (w *BinaryWriter) WriteBool(v bool) { - if v { - w.buf = append(w.buf, 1) - } else { - w.buf = append(w.buf, 0) - } -} - -func (w *BinaryWriter) WriteByte(v int8) { - w.buf = append(w.buf, byte(v)) -} - -func (w *BinaryWriter) WriteI16(v int16) { - w.buf = append(w.buf, byte(uint16(v)>>8), byte(v)) -} - -func (w *BinaryWriter) WriteI32(v int32) { - w.buf = appendUint32(w.buf, uint32(v)) -} - -func (w *BinaryWriter) WriteI64(v int64) { - w.buf = appendUint64(w.buf, uint64(v)) -} - -func (w *BinaryWriter) WriteDouble(v float64) { - w.buf = appendUint64(w.buf, math.Float64bits(v)) -} diff --git a/protocol/thrift/binaryreader.go b/protocol/thrift/bufferreader.go similarity index 69% rename from protocol/thrift/binaryreader.go rename to protocol/thrift/bufferreader.go index a08ebf0..8cfd2b4 100644 --- a/protocol/thrift/binaryreader.go +++ b/protocol/thrift/bufferreader.go @@ -19,90 +19,72 @@ package thrift import ( "encoding/binary" "fmt" - "io" "math" "sync" -) -// BinaryReader represents a reader for binary protocol -type BinaryReader struct { - r nextIface - d discardIface + "github.com/bytedance/gopkg/lang/dirtmake" + "github.com/cloudwego/gopkg/bufiox" + "github.com/cloudwego/gopkg/internal/hack" +) - rn int64 +// BufferReader represents a reader for binary protocol +type BufferReader struct { + r bufiox.Reader } -var poolBinaryReader = sync.Pool{ +var poolBufferReader = sync.Pool{ New: func() interface{} { - return &BinaryReader{} + return &BufferReader{} }, } -// NewBinaryReader ... call Release if no longer use for reusing -func NewBinaryReader(r io.Reader) *BinaryReader { - ret := poolBinaryReader.Get().(*BinaryReader) - ret.reset() - if nextr, ok := r.(nextIface); ok { - ret.r = nextr - } else { - nextr := newNextReader(r) - ret.r = nextr - ret.d = nextr - } +// NewBufferReader ... call Release if no longer use for reusing +func NewBufferReader(r bufiox.Reader) *BufferReader { + ret := poolBufferReader.Get().(*BufferReader) + ret.r = r return ret } // Release ... -func (r *BinaryReader) Release() { - nextr, ok := r.r.(*nextReader) - if ok { - nextr.Release() - } - r.reset() - poolBinaryReader.Put(r) -} - -func (r *BinaryReader) reset() { +func (r *BufferReader) Release() { + r.r.Release(nil) r.r = nil - r.d = nil - r.rn = 0 + poolBufferReader.Put(r) } -func (r *BinaryReader) next(n int) (b []byte, err error) { +func (r *BufferReader) next(n int) (b []byte, err error) { b, err = r.r.Next(n) if err != nil { err = NewProtocolExceptionWithErr(err) } - r.rn += int64(len(b)) return } -func (r *BinaryReader) skipn(n int) (err error) { +func (r *BufferReader) readBinary(bs []byte) (n int, err error) { + n, err = r.r.ReadBinary(bs) + if err != nil { + err = NewProtocolExceptionWithErr(err) + } + return +} + +func (r *BufferReader) skipn(n int) (err error) { if n < 0 { return errNegativeSize } - if r.d != nil { - var sz int - sz, err = r.d.Discard(n) - r.rn += int64(sz) - } else { - var b []byte - b, err = r.r.Next(n) - r.rn += int64(len(b)) - } - if err != nil { + if err = r.r.Skip(n); err != nil { return NewProtocolExceptionWithErr(err) } return nil } // Readn returns total bytes read from underlying reader -func (r *BinaryReader) Readn() int64 { - return r.rn +func (r *BufferReader) Readn() int64 { + return int64(r.r.ReadLen()) } // ReadBool ... -func (r *BinaryReader) ReadBool() (v bool, err error) { +func (r *BufferReader) ReadBool() (v bool, err error) { b, err := r.next(1) if err != nil { return false, err @@ -112,7 +94,7 @@ func (r *BinaryReader) ReadBool() (v bool, err error) { } // ReadByte ... -func (r *BinaryReader) ReadByte() (v int8, err error) { +func (r *BufferReader) ReadByte() (v int8, err error) { b, err := r.next(1) if err != nil { return 0, err @@ -122,7 +104,7 @@ func (r *BinaryReader) ReadByte() (v int8, err error) { } // ReadI16 ... -func (r *BinaryReader) ReadI16() (v int16, err error) { +func (r *BufferReader) ReadI16() (v int16, err error) { b, err := r.next(2) if err != nil { return 0, err @@ -132,7 +114,7 @@ func (r *BinaryReader) ReadI16() (v int16, err error) { } // ReadI32 ... -func (r *BinaryReader) ReadI32() (v int32, err error) { +func (r *BufferReader) ReadI32() (v int32, err error) { b, err := r.next(4) if err != nil { return 0, err @@ -142,7 +124,7 @@ func (r *BinaryReader) ReadI32() (v int32, err error) { } // ReadI64 ... -func (r *BinaryReader) ReadI64() (v int64, err error) { +func (r *BufferReader) ReadI64() (v int64, err error) { b, err := r.next(8) if err != nil { return 0, err @@ -152,7 +134,7 @@ func (r *BinaryReader) ReadI64() (v int64, err error) { } // ReadDouble ... -func (r *BinaryReader) ReadDouble() (v float64, err error) { +func (r *BufferReader) ReadDouble() (v float64, err error) { b, err := r.next(8) if err != nil { return 0, err @@ -162,32 +144,27 @@ func (r *BinaryReader) ReadDouble() (v float64, err error) { } // ReadBinary ... -func (r *BinaryReader) ReadBinary() (b []byte, err error) { +func (r *BufferReader) ReadBinary() (b []byte, err error) { sz, err := r.ReadI32() if err != nil { return nil, err } - b, err = r.next(int(sz)) - b = []byte(string(b)) // copy. use span cache? + b = dirtmake.Bytes(int(sz), int(sz)) + _, err = r.readBinary(b) return } // ReadString ... -func (r *BinaryReader) ReadString() (s string, err error) { - sz, err := r.ReadI32() +func (r *BufferReader) ReadString() (s string, err error) { + b, err := r.ReadBinary() if err != nil { return "", err } - b, err := r.next(int(sz)) - if err != nil { - return "", err - } - s = string(b) // copy. use span cache? - return + return hack.ByteSliceToString(b), nil } // ReadMessageBegin ... -func (r *BinaryReader) ReadMessageBegin() (name string, typeID TMessageType, seq int32, err error) { +func (r *BufferReader) ReadMessageBegin() (name string, typeID TMessageType, seq int32, err error) { var header int32 header, err = r.ReadI32() if err != nil { @@ -215,7 +192,7 @@ func (r *BinaryReader) ReadMessageBegin() (name string, typeID TMessageType, seq } // ReadFieldBegin ... -func (r *BinaryReader) ReadFieldBegin() (typeID TType, id int16, err error) { +func (r *BufferReader) ReadFieldBegin() (typeID TType, id int16, err error) { b, err := r.next(1) if err != nil { return 0, 0, err @@ -233,7 +210,7 @@ func (r *BinaryReader) ReadFieldBegin() (typeID TType, id int16, err error) { } // ReadMapBegin ... -func (r *BinaryReader) ReadMapBegin() (kt, vt TType, size int, err error) { +func (r *BufferReader) ReadMapBegin() (kt, vt TType, size int, err error) { b, err := r.next(6) if err != nil { return 0, 0, 0, err @@ -243,7 +220,7 @@ func (r *BinaryReader) ReadMapBegin() (kt, vt TType, size int, err error) { } // ReadListBegin ... -func (r *BinaryReader) ReadListBegin() (et TType, size int, err error) { +func (r *BufferReader) ReadListBegin() (et TType, size int, err error) { b, err := r.next(5) if err != nil { return 0, 0, err @@ -253,7 +230,7 @@ func (r *BinaryReader) ReadListBegin() (et TType, size int, err error) { } // ReadSetBegin ... -func (r *BinaryReader) ReadSetBegin() (et TType, size int, err error) { +func (r *BufferReader) ReadSetBegin() (et TType, size int, err error) { b, err := r.next(5) if err != nil { return 0, 0, err @@ -263,11 +240,11 @@ func (r *BinaryReader) ReadSetBegin() (et TType, size int, err error) { } // Skip ... -func (r *BinaryReader) Skip(t TType) error { +func (r *BufferReader) Skip(t TType) error { return r.skipType(t, defaultRecursionDepth) } -func (r *BinaryReader) skipstr() error { +func (r *BufferReader) skipstr() error { n, err := r.ReadI32() if err != nil { return err @@ -275,7 +252,7 @@ func (r *BinaryReader) skipstr() error { return r.skipn(int(n)) } -func (r *BinaryReader) skipType(t TType, maxdepth int) error { +func (r *BufferReader) skipType(t TType, maxdepth int) error { if maxdepth == 0 { return errDepthLimitExceeded } diff --git a/protocol/thrift/binaryreader_test.go b/protocol/thrift/bufferreader_test.go similarity index 97% rename from protocol/thrift/binaryreader_test.go rename to protocol/thrift/bufferreader_test.go index accd165..75a03c3 100644 --- a/protocol/thrift/binaryreader_test.go +++ b/protocol/thrift/bufferreader_test.go @@ -17,9 +17,9 @@ package thrift import ( - "bytes" "testing" + "github.com/cloudwego/gopkg/bufiox" "github.com/stretchr/testify/require" ) @@ -55,7 +55,7 @@ func TestBinaryReader(t *testing.T) { b = x.AppendDouble(b, 18.5) sz13 := len(b) - r := NewBinaryReader(bytes.NewReader(b)) + r := NewBufferReader(bufiox.NewBytesReader(b)) defer r.Release() name, mt, seq, err := r.ReadMessageBegin() require.NoError(t, err) @@ -206,7 +206,7 @@ func TestBinaryReaderSkip(t *testing.T) { b = x.AppendFieldStop(b) sz10 := len(b) - r := NewBinaryReader(bytes.NewReader(b)) + r := NewBufferReader(bufiox.NewBytesReader(b)) defer r.Release() err := r.Skip(BYTE) // byte @@ -249,7 +249,7 @@ func TestBinaryReaderSkip(t *testing.T) { for i := 0; i < defaultRecursionDepth+1; i++ { b = x.AppendFieldBegin(b, STRUCT, 1) } - r := NewBinaryReader(bytes.NewReader(b)) + r := NewBufferReader(bufiox.NewBytesReader(b)) err := r.Skip(STRUCT) require.Same(t, errDepthLimitExceeded, err) diff --git a/protocol/thrift/bufferwriter.go b/protocol/thrift/bufferwriter.go new file mode 100644 index 0000000..19b4783 --- /dev/null +++ b/protocol/thrift/bufferwriter.go @@ -0,0 +1,179 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "encoding/binary" + "math" + "sync" + + "github.com/cloudwego/gopkg/bufiox" + "github.com/cloudwego/gopkg/internal/hack" +) + +type BufferWriter struct { + w bufiox.Writer +} + +var poolBufferWriter = sync.Pool{ + New: func() interface{} { + return &BufferWriter{} + }, +} + +func NewBufferWriter(iw bufiox.Writer) *BufferWriter { + w := poolBufferWriter.Get().(*BufferWriter) + w.w = iw + return w +} + +func (w *BufferWriter) Release() { + w.w = nil + poolBufferWriter.Put(w) +} + +func (w *BufferWriter) WriteMessageBegin(name string, typeID TMessageType, seq int32) error { + buf, err := w.w.Malloc(Binary.MessageBeginLength(name)) + if err != nil { + return err + } + binary.BigEndian.PutUint32(buf, uint32(msgVersion1)|uint32(typeID&msgTypeMask)) + binary.BigEndian.PutUint32(buf[4:], uint32(len(name))) + copy(buf[8:], name) + binary.BigEndian.PutUint32(buf[8+len(name):], uint32(seq)) + return nil +} + +func (w *BufferWriter) WriteFieldBegin(typeID TType, id int16) error { + buf, err := w.w.Malloc(3) + if err != nil { + return err + } + buf[0], buf[1], buf[2] = byte(typeID), byte(uint16(id>>8)), byte(id) + return nil +} + +func (w *BufferWriter) WriteFieldStop() error { + buf, err := w.w.Malloc(1) + if err != nil { + return err + } + buf[0] = byte(STOP) + return nil +} + +func (w *BufferWriter) WriteMapBegin(kt, vt TType, size int) error { + buf, err := w.w.Malloc(6) + if err != nil { + return err + } + buf[0], buf[1] = byte(kt), byte(vt) + binary.BigEndian.PutUint32(buf[2:], uint32(size)) + return nil +} + +func (w *BufferWriter) WriteListBegin(et TType, size int) error { + buf, err := w.w.Malloc(5) + if err != nil { + return err + } + buf[0] = byte(et) + binary.BigEndian.PutUint32(buf[1:], uint32(size)) + return nil +} + +func (w *BufferWriter) WriteSetBegin(et TType, size int) error { + buf, err := w.w.Malloc(5) + if err != nil { + return err + } + buf[0] = byte(et) + binary.BigEndian.PutUint32(buf[1:], uint32(size)) + return nil +} + +func (w *BufferWriter) WriteBinary(v []byte) error { + buf, err := w.w.Malloc(4) + if err != nil { + return err + } + binary.BigEndian.PutUint32(buf, uint32(len(v))) + _, err = w.w.WriteBinary(v) + return err +} + +func (w *BufferWriter) WriteString(v string) error { + return w.WriteBinary(hack.StringToByteSlice(v)) +} + +func (w *BufferWriter) WriteBool(v bool) error { + buf, err := w.w.Malloc(1) + if err != nil { + return err + } + if v { + buf[0] = 1 + } else { + buf[0] = 0 + } + return nil +} + +func (w *BufferWriter) WriteByte(v int8) error { + buf, err := w.w.Malloc(1) + if err != nil { + return err + } + buf[0] = byte(v) + return nil +} + +func (w *BufferWriter) WriteI16(v int16) error { + buf, err := w.w.Malloc(2) + if err != nil { + return err + } + binary.BigEndian.PutUint16(buf, uint16(v)) + return nil +} + +func (w *BufferWriter) WriteI32(v int32) error { + buf, err := w.w.Malloc(4) + if err != nil { + return err + } + binary.BigEndian.PutUint32(buf, uint32(v)) + return nil +} + +func (w *BufferWriter) WriteI64(v int64) error { + buf, err := w.w.Malloc(8) + if err != nil { + return err + } + binary.BigEndian.PutUint64(buf, uint64(v)) + return nil +} + +func (w *BufferWriter) WriteDouble(v float64) error { + buf, err := w.w.Malloc(8) + if err != nil { + return err + } + binary.BigEndian.PutUint64(buf, math.Float64bits(v)) + return nil +} diff --git a/protocol/thrift/binarywriter_test.go b/protocol/thrift/bufferwriter_test.go similarity index 75% rename from protocol/thrift/binarywriter_test.go rename to protocol/thrift/bufferwriter_test.go index ee77817..85254d5 100644 --- a/protocol/thrift/binarywriter_test.go +++ b/protocol/thrift/bufferwriter_test.go @@ -19,68 +19,63 @@ package thrift import ( "testing" + "github.com/bytedance/gopkg/lang/dirtmake" + "github.com/cloudwego/gopkg/bufiox" "github.com/stretchr/testify/require" ) +const defaultBinaryWriterBufferSize = 4096 + func TestBinaryWriter(t *testing.T) { - w := NewBinaryWriterSize(defaultBinaryWriterBufferSize * 2) + buf := dirtmake.Bytes(0, defaultBinaryWriterBufferSize*2) + w := NewBufferWriter(bufiox.NewBytesWriter(&buf)) x := BinaryProtocol{} b := x.AppendMessageBegin(nil, "hello", 1, 2) w.WriteMessageBegin("hello", 1, 2) - require.Equal(t, b, w.Bytes()) b = x.AppendFieldBegin(b, 3, 4) w.WriteFieldBegin(3, 4) - require.Equal(t, b, w.Bytes()) b = x.AppendFieldStop(b) w.WriteFieldStop() - require.Equal(t, b, w.Bytes()) b = x.AppendMapBegin(b, 5, 6, 7) w.WriteMapBegin(5, 6, 7) - require.Equal(t, b, w.Bytes()) b = x.AppendListBegin(b, 8, 9) w.WriteListBegin(8, 9) - require.Equal(t, b, w.Bytes()) b = x.AppendSetBegin(b, 10, 11) w.WriteSetBegin(10, 11) - require.Equal(t, b, w.Bytes()) b = x.AppendBinary(b, []byte("12")) w.WriteBinary([]byte("12")) - require.Equal(t, b, w.Bytes()) b = x.AppendString(b, "13") w.WriteString("13") - require.Equal(t, b, w.Bytes()) b = x.AppendBool(b, true) b = x.AppendBool(b, false) w.WriteBool(true) w.WriteBool(false) - require.Equal(t, b, w.Bytes()) b = x.AppendByte(b, 14) w.WriteByte(14) - require.Equal(t, b, w.Bytes()) b = x.AppendI16(b, 15) w.WriteI16(15) - require.Equal(t, b, w.Bytes()) b = x.AppendI32(b, 16) w.WriteI32(16) - require.Equal(t, b, w.Bytes()) b = x.AppendI64(b, 17) w.WriteI64(17) - require.Equal(t, b, w.Bytes()) b = x.AppendDouble(b, 18.5) w.WriteDouble(18.5) - require.Equal(t, b, w.Bytes()) + + w.w.Flush() + + require.Equal(t, b, buf) } diff --git a/protocol/thrift/skipdecoder.go b/protocol/thrift/skipdecoder.go index 3e5f1c4..b6eabd6 100644 --- a/protocol/thrift/skipdecoder.go +++ b/protocol/thrift/skipdecoder.go @@ -19,10 +19,14 @@ package thrift import ( "encoding/binary" "fmt" - "io" "sync" + + "github.com/bytedance/gopkg/lang/mcache" + "github.com/cloudwego/gopkg/bufiox" ) +const defaultSkipDecoderSize = 4096 + var poolSkipDecoder = sync.Pool{ New: func() interface{} { return &SkipDecoder{} @@ -31,53 +35,54 @@ var poolSkipDecoder = sync.Pool{ // SkipDecoder scans the underlying io.Reader and returns the bytes of a type type SkipDecoder struct { - p skipReaderIface + r bufiox.Reader + bs []byte + + toRelease toReleaseBuffer } // NewSkipDecoder ... call Release if no longer use -func NewSkipDecoder(r io.Reader) *SkipDecoder { +func NewSkipDecoder(r bufiox.Reader) *SkipDecoder { p := poolSkipDecoder.Get().(*SkipDecoder) - p.Reset(r) + p.r = r return p } -// Reset ... -func (p *SkipDecoder) Reset(r io.Reader) { - // fast path without returning to pool if remote.ByteBuffer && *skipByteBuffer - if buf, ok := r.(remoteByteBuffer); ok { - if p.p != nil { - r, ok := p.p.(*skipByteBuffer) - if ok { - r.Reset(buf) - return - } - p.p.Release() - } - p.p = newSkipByteBuffer(buf) +// Release releases the skip decoder, callers cannot use the returned data of Next after calling Release. +func (p *SkipDecoder) Release() { + p.toRelease.release() + *p = SkipDecoder{} + poolSkipDecoder.Put(p) +} + +func (p *SkipDecoder) next(n int) (buf []byte, err error) { + buf, err = p.r.Next(n) + if err != nil { return } - - // not remote.ByteBuffer - - if p.p != nil { - p.p.Release() + if cap(p.bs)-len(p.bs) < n { + var ncap int + for ncap = cap(p.bs) * 2; ncap-len(p.bs) < n; ncap *= 2 { + } + nbs := mcache.Malloc(ncap, ncap) + p.toRelease.append(p.bs) + p.bs = nbs[:copy(nbs, p.bs)] } - p.p = newSkipReader(r) + cn := copy(p.bs[len(p.bs):cap(p.bs)], buf) + p.bs = p.bs[:len(p.bs)+cn] + return } -// Release ... -func (p *SkipDecoder) Release() { - p.p.Release() - p.p = nil - poolSkipDecoder.Put(p) -} - -// Next skips a specific type and returns its bytes +// Next skips a specific type and returns its bytes. +// Callers cannot use the returned data after calling Release. func (p *SkipDecoder) Next(t TType) (buf []byte, err error) { - if err := p.skip(t, defaultRecursionDepth); err != nil { - return nil, err + p.bs = mcache.Malloc(0, defaultSkipDecoderSize) + if err = p.skip(t, defaultRecursionDepth); err != nil { + return } - return p.p.Bytes() + buf = p.bs + p.toRelease.append(p.bs) + return } func (p *SkipDecoder) skip(t TType, maxdepth int) error { @@ -85,12 +90,12 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { return errDepthLimitExceeded } if sz := typeToSize[t]; sz > 0 { - _, err := p.p.Next(int(sz)) + _, err := p.next(int(sz)) return err } switch t { case STRING: - b, err := p.p.Next(4) + b, err := p.next(4) if err != nil { return err } @@ -98,12 +103,12 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { if sz < 0 { return errNegativeSize } - if _, err := p.p.Next(sz); err != nil { + if _, err := p.next(sz); err != nil { return err } case STRUCT: for { - b, err := p.p.Next(1) // TType + b, err := p.next(1) // TType if err != nil { return err } @@ -111,7 +116,7 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { if tp == STOP { break } - if _, err := p.p.Next(2); err != nil { // Field ID + if _, err := p.next(2); err != nil { // Field ID return err } if err := p.skip(tp, maxdepth-1); err != nil { @@ -119,7 +124,7 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { } } case MAP: - b, err := p.p.Next(6) // 1 byte key TType, 1 byte value TType, 4 bytes Len + b, err := p.next(6) // 1 byte key TType, 1 byte value TType, 4 bytes Len if err != nil { return err } @@ -129,7 +134,7 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { } ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt]) if ksz > 0 && vsz > 0 { - _, err := p.p.Next(int(sz) * (ksz + vsz)) + _, err := p.next(int(sz) * (ksz + vsz)) return err } for i := int32(0); i < sz; i++ { @@ -141,7 +146,7 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { } } case SET, LIST: - b, err := p.p.Next(5) // 1 byte value type, 4 bytes Len + b, err := p.next(5) // 1 byte value type, 4 bytes Len if err != nil { return err } @@ -150,7 +155,7 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { return errNegativeSize } if vsz := typeToSize[vt]; vsz > 0 { - _, err := p.p.Next(int(sz) * int(vsz)) + _, err := p.next(int(sz) * int(vsz)) return err } for i := int32(0); i < sz; i++ { @@ -163,3 +168,22 @@ func (p *SkipDecoder) skip(t TType, maxdepth int) error { } return nil } + +type toReleaseBuffer struct { + idx int + bs [16][]byte +} + +func (rb *toReleaseBuffer) append(b []byte) { + rb.bs[rb.idx%16] = b + rb.idx++ +} + +func (rb *toReleaseBuffer) release() { + for _, b := range rb.bs { + if b == nil { + break + } + mcache.Free(b) + } +} diff --git a/protocol/thrift/skipdecoder_test.go b/protocol/thrift/skipdecoder_test.go index e5d21b2..e748a38 100644 --- a/protocol/thrift/skipdecoder_test.go +++ b/protocol/thrift/skipdecoder_test.go @@ -17,11 +17,10 @@ package thrift import ( - "bytes" - "math/rand" "strings" "testing" + "github.com/cloudwego/gopkg/bufiox" "github.com/stretchr/testify/require" ) @@ -92,7 +91,7 @@ func TestSkipDecoder(t *testing.T) { b = x.AppendFieldStop(b) sz10 := len(b) - r := NewSkipDecoder(bytes.NewReader(b)) + r := NewSkipDecoder(bufiox.NewBytesReader(b)) defer r.Release() readn := 0 @@ -147,7 +146,7 @@ func TestSkipDecoder(t *testing.T) { for i := 0; i < defaultRecursionDepth+1; i++ { b = x.AppendFieldBegin(b, STRUCT, 1) } - r := NewSkipDecoder(bytes.NewReader(b)) + r := NewSkipDecoder(bufiox.NewBytesReader(b)) _, err := r.Next(STRUCT) require.Same(t, errDepthLimitExceeded, err) @@ -156,20 +155,3 @@ func TestSkipDecoder(t *testing.T) { require.Error(t, err) } } - -func TestSkipDecoderReset(t *testing.T) { - x := BinaryProtocol{} - b := x.AppendString([]byte(nil), "hello") - - r := NewSkipDecoder(nil) - for i := 0; i < 10; i++ { - if rand.Intn(2) == 1 { // random skipreader to test Reset - r.Reset(&remoteByteBufferImplForT{b: b}) - } else { - r.Reset(bytes.NewReader(b)) - } - retb, err := r.Next(STRING) - require.NoError(t, err) - require.Equal(t, b, retb) - } -} diff --git a/protocol/thrift/skipreader.go b/protocol/thrift/skipreader.go deleted file mode 100644 index eb04580..0000000 --- a/protocol/thrift/skipreader.go +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package thrift - -import ( - "io" - "sync" -) - -// this file contains readers for SkipDecoder - -type skipReaderIface interface { - Next(n int) (buf []byte, err error) - Bytes() (buf []byte, err error) - Release() -} - -var poolSkipReader = sync.Pool{ - New: func() interface{} { - return &skipReader{b: make([]byte, 1024)} - }, -} - -var poolSkipRemoteBuffer = sync.Pool{ - New: func() interface{} { - return &skipByteBuffer{} - }, -} - -// skipReader ... general skip reader for io.Reader -type skipReader struct { - r io.Reader - - p int - b []byte -} - -func newSkipReader(r io.Reader) *skipReader { - ret := poolSkipReader.Get().(*skipReader) - ret.Reset(r) - return ret -} - -func (p *skipReader) Release() { - poolSkipReader.Put(p) -} - -func (p *skipReader) Reset(r io.Reader) { - p.r = r - p.p = 0 -} - -func (p *skipReader) Bytes() ([]byte, error) { - ret := p.b[:p.p] - p.p = 0 - return ret, nil -} - -func (p *skipReader) grow(n int) { - // assert: len(p.b)-p.p < n - sz := 2 * cap(p.b) - if sz < p.p+n { - sz = p.p + n - } - b := make([]byte, sz) - copy(b, p.b[:p.p]) - p.b = b -} - -func (p *skipReader) Next(n int) (buf []byte, err error) { - if len(p.b)-p.p < n { - p.grow(n) - } - if _, err := io.ReadFull(p.r, p.b[p.p:p.p+n]); err != nil { - return nil, NewProtocolExceptionWithErr(err) - } - ret := p.b[p.p : p.p+n] - p.p += n - return ret, nil -} - -// remoteByteBuffer ... github.com/cloudwego/kitex/pkg/remote.ByteBuffer -type remoteByteBuffer interface { - Peek(n int) (buf []byte, err error) - ReadableLen() (n int) - Skip(n int) (err error) -} - -// skipByteBuffer ... optimized zero copy skipreader for remote.ByteBuffer -type skipByteBuffer struct { - p remoteByteBuffer - - r int - b []byte -} - -func newSkipByteBuffer(buf remoteByteBuffer) *skipByteBuffer { - ret := poolSkipRemoteBuffer.Get().(*skipByteBuffer) - ret.Reset(buf) - return ret -} - -func (p *skipByteBuffer) Release() { - poolSkipRemoteBuffer.Put(p) -} - -func (p *skipByteBuffer) Reset(buf remoteByteBuffer) { - p.r = 0 - p.b = nil - p.p = buf -} - -func (p *skipByteBuffer) Bytes() ([]byte, error) { - ret := p.b[:p.r] - if err := p.p.Skip(p.r); err != nil { - return nil, err - } - p.r = 0 - return ret, nil -} - -// Next ... -func (p *skipByteBuffer) Next(n int) (ret []byte, err error) { - if p.r+n < len(p.b) { // fast path - ret, p.r = p.b[p.r:p.r+n], p.r+n - return - } - return p.nextSlow(n) -} - -func (p *skipByteBuffer) nextSlow(n int) ([]byte, error) { - // trigger underlying conn to read more - _, err := p.p.Peek(p.r + n) - if err != nil { - return nil, err - } - // read as much as possible, luckily, we will have a full buffer - // then we no need to call p.Peek many times - p.b, err = p.p.Peek(p.p.ReadableLen()) - if err != nil { - return nil, err - } - // after calling p.p.Peek, p.buf MUST be at least (p.r + n) len - ret := p.b[p.r : p.r+n] - p.r += n - return ret, nil -} diff --git a/protocol/thrift/skipreader_test.go b/protocol/thrift/skipreader_test.go deleted file mode 100644 index 4307022..0000000 --- a/protocol/thrift/skipreader_test.go +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright 2024 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package thrift - -import ( - "bytes" - "errors" - "io" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestSkipReader(t *testing.T) { - b := make([]byte, 2048) - for i := 0; i < len(b); i++ { - b[i] = byte(i) - } - - r := newSkipReader(bytes.NewReader(b)) - defer r.Release() - for i := 0; i < len(b); i++ { - b, err := r.Next(1) - require.NoError(t, err) - require.True(t, b[0] == byte(i)) - } - - retb, err := r.Bytes() - require.NoError(t, err) - require.Equal(t, b, retb) -} - -type remoteByteBufferImplForT struct { - p int - b []byte -} - -func (remoteByteBufferImplForT) Read(_ []byte) (int, error) { return 0, errors.New("not implemented") } - -func (p *remoteByteBufferImplForT) Peek(n int) (buf []byte, err error) { - if n > len(p.b) { - return nil, io.EOF - } - return p.b[:n], nil -} - -func (p *remoteByteBufferImplForT) ReadableLen() int { - return len(p.b) - p.p -} - -func (p *remoteByteBufferImplForT) Skip(n int) error { - if n > len(p.b) { - panic("bug") - } - p.p += n - return nil -} - -func TestSkipRemoteBuffer(t *testing.T) { - b := make([]byte, 2048) - for i := 0; i < len(b); i++ { - b[i] = byte(i) - } - - r := newSkipByteBuffer(&remoteByteBufferImplForT{b: b}) - defer r.Release() - for i := 0; i < len(b); i++ { - b, err := r.Next(1) - require.NoError(t, err) - require.True(t, b[0] == byte(i)) - } - - retb, err := r.Bytes() - require.NoError(t, err) - require.Equal(t, b, retb) -} diff --git a/protocol/thrift/utils.go b/protocol/thrift/utils.go index dfd1d4a..3b4a34f 100644 --- a/protocol/thrift/utils.go +++ b/protocol/thrift/utils.go @@ -17,8 +17,6 @@ package thrift import ( - "io" - "sync" "unsafe" ) @@ -37,60 +35,3 @@ type nextIface interface { type discardIface interface { Discard(n int) (int, error) } - -// nextReader provides a wrapper for io.Reader to use BinaryReader -type nextReader struct { - r io.Reader - b [4096]byte -} - -var poolNextReader = sync.Pool{ - New: func() interface{} { - return &nextReader{} - }, -} - -func newNextReader(r io.Reader) *nextReader { - ret := poolNextReader.Get().(*nextReader) - ret.Reset(r) - return ret -} - -// Release ... -func (r *nextReader) Release() { poolNextReader.Put(r) } - -// Reset ... for reusing nextReader -func (r *nextReader) Reset(rd io.Reader) { r.r = rd } - -// Next implements nextIface of BinaryReader -func (r *nextReader) Next(n int) ([]byte, error) { - b := r.b[:] - if n <= len(b) { - b = b[:n] - } else { - b = make([]byte, n) - } - _, err := io.ReadFull(r.r, b) - if err != nil { - return nil, err - } - return b, nil -} - -// Discard implements discardIface of BinaryReader -func (r *nextReader) Discard(n int) (int, error) { - ret := 0 - b := r.b[:] - for n > 0 { - if len(b) > n { - b = b[:n] - } - readn, err := r.r.Read(b) - ret += readn - if err != nil { - return ret, err - } - n -= readn - } - return ret, nil -}