From 4a3d9504dc5b6ed55b4697c13edc51f230446d13 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Tue, 17 Dec 2024 14:22:46 +0800 Subject: [PATCH] perf(thrift): simplify skipdecoder tpl code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit goos: darwin goarch: arm64 pkg: github.com/cloudwego/gopkg/protocol/thrift cpu: Apple M2 Pro │ ./old.txt │ ./new.txt │ │ sec/op │ sec/op vs base │ SkipDecoder-12 15.18n ± 6% 13.24n ± 1% -12.75% (p=0.000 n=10) --- protocol/thrift/skipdecoder.go | 8 +-- protocol/thrift/skipdecoder_test.go | 39 +++++++++++--- protocol/thrift/skipdecoder_tpl.go | 82 +++++++++++++++++------------ 3 files changed, 83 insertions(+), 46 deletions(-) diff --git a/protocol/thrift/skipdecoder.go b/protocol/thrift/skipdecoder.go index ee83704..bcd7aa8 100644 --- a/protocol/thrift/skipdecoder.go +++ b/protocol/thrift/skipdecoder.go @@ -59,7 +59,7 @@ func (p *SkipDecoder) Release() { // The returned buf is directly from bufiox.Reader with the same lifecycle. func (p *SkipDecoder) Next(t TType) (buf []byte, err error) { p.rn = 0 - if err = NewSkipDecoderTpl(p).Skip(t, defaultRecursionDepth); err != nil { + if err = skipDecoderImpl(p, t, defaultRecursionDepth); err != nil { return } buf, err = p.r.Next(p.rn) @@ -116,7 +116,7 @@ func (p *BytesSkipDecoder) Reset(b []byte) { // // The returned buf refers to the input []byte without copy func (p *BytesSkipDecoder) Next(t TType) (b []byte, err error) { - if err = NewSkipDecoderTpl(p).Skip(t, defaultRecursionDepth); err != nil { + if err = skipDecoderImpl(p, t, defaultRecursionDepth); err != nil { return } b = p.b[:p.n] @@ -125,7 +125,7 @@ func (p *BytesSkipDecoder) Next(t TType) (b []byte, err error) { return } -// SkipN implements SkipDecoderIface +// SkipN implements skipDecoderIface func (p *BytesSkipDecoder) SkipN(n int) ([]byte, error) { if len(p.b) >= p.n+n { p.n += n @@ -194,7 +194,7 @@ func (p *ReaderSkipDecoder) growSlow(n int) { // The returned []byte is valid before the next `Next` call or `Release` func (p *ReaderSkipDecoder) Next(t TType) (b []byte, err error) { p.n = 0 - if err = NewSkipDecoderTpl(p).Skip(t, defaultRecursionDepth); err != nil { + if err = skipDecoderImpl(p, t, defaultRecursionDepth); err != nil { return } return p.b[:p.n], nil diff --git a/protocol/thrift/skipdecoder_test.go b/protocol/thrift/skipdecoder_test.go index 906b868..f8fe338 100644 --- a/protocol/thrift/skipdecoder_test.go +++ b/protocol/thrift/skipdecoder_test.go @@ -214,8 +214,8 @@ func BenchmarkSkipDecoder(b *testing.B) { // MAP, fid=8 bs = Binary.AppendFieldBegin(bs, MAP, 8) - bs = Binary.AppendMapBegin(bs, DOUBLE, DOUBLE, 1) - bs = Binary.AppendDouble(bs, 8.1) + bs = Binary.AppendMapBegin(bs, I16, DOUBLE, 1) + bs = Binary.AppendI16(bs, 8) bs = Binary.AppendDouble(bs, 8.2) // SET, fid=9 @@ -234,21 +234,44 @@ func BenchmarkSkipDecoder(b *testing.B) { bs = Binary.AppendI64(bs, 11) bs = Binary.AppendFieldStop(bs) + // MAP(STRUCT), fid=12 + bs = Binary.AppendFieldBegin(bs, MAP, 12) + bs = Binary.AppendMapBegin(bs, I64, STRUCT, 1) + bs = Binary.AppendI64(bs, 12) + bs = Binary.AppendFieldStop( + Binary.AppendI64(Binary.AppendFieldBegin(bs, I64, 1), 121), // fid=1, I64, v=121 + ) + + // LIST(STRUCT), fid=13 + bs = Binary.AppendFieldBegin(bs, LIST, 13) + bs = Binary.AppendListBegin(bs, STRUCT, 1) + bs = Binary.AppendFieldStop( + Binary.AppendI64(Binary.AppendFieldBegin(bs, I64, 1), 121), // fid=1, I64, v=121 + ) + // Finish struct bs = Binary.AppendFieldStop(bs) + // run a test first before benchmark + sr := NewBytesSkipDecoder(bs) + buf, err := sr.Next(STRUCT) + if err != nil { + b.Fatal(err) + } + if !bytes.Equal(buf, bs) { + b.Fatal("bytes not equal") + } + sr.Release() + b.ResetTimer() b.RunParallel(func(pb *testing.PB) { + sr := &BytesSkipDecoder{} for pb.Next() { - sr := NewBytesSkipDecoder(bs) - buf, err := sr.Next(STRUCT) + sr.Reset(bs) + _, err := sr.Next(STRUCT) if err != nil { b.Fatal(err) } - if !bytes.Equal(buf, bs) { - b.Fatal("bytes not equal") - } - sr.Release() } }) } diff --git a/protocol/thrift/skipdecoder_tpl.go b/protocol/thrift/skipdecoder_tpl.go index d23b4f3..3eada56 100644 --- a/protocol/thrift/skipdecoder_tpl.go +++ b/protocol/thrift/skipdecoder_tpl.go @@ -21,41 +21,32 @@ import ( "fmt" ) -// SkipDecoderIface represent the generics constraint of a SkipDecoder. +// skipDecoderIface represent the generics constraint of a SkipDecoder. // -// It's used by SkipDecoderTpl -type SkipDecoderIface interface { - // SkipN read and skip n bytes +// It's used by skipDecoderImpl +type skipDecoderIface interface { + // SkipN reads and skips n bytes // - // SkipDecoderTpl will not hold or modify the bytes between two `SkipN` calls. + // SkipDecoderIface will not hold or modify the bytes between two `SkipN` calls. // It's safe to reuse buffer for next `SkipN` call. - // - // if SkipN is short enough, it can be inlined. SkipN(n int) ([]byte, error) } -// SkipDecoderTpl is the core logic of skipping thrift binary -type SkipDecoderTpl[T SkipDecoderIface] struct { - r T -} - -// NewSkipDecoderTpl ... -func NewSkipDecoderTpl[T SkipDecoderIface](r T) SkipDecoderTpl[T] { - return SkipDecoderTpl[T]{r} -} - -// Skip ... -func (p SkipDecoderTpl[T]) Skip(t TType, maxdepth int) error { +// NOTE: At the time of writing Go generics doesn't fully support monomorphization, and +// it doesn't generate code copies for specific types which means +// inline calling of SkipN is not working ... +// This would be fixed in the future hopefully, so we use generics here. +func skipDecoderImpl[T skipDecoderIface](r T, t TType, maxdepth int) error { if maxdepth == 0 { return errDepthLimitExceeded } if sz := typeToSize[t]; sz > 0 { - _, err := p.r.SkipN(int(sz)) + _, err := r.SkipN(int(sz)) return err } switch t { case STRING: - b, err := p.r.SkipN(4) + b, err := r.SkipN(4) if err != nil { return err } @@ -63,12 +54,12 @@ func (p SkipDecoderTpl[T]) Skip(t TType, maxdepth int) error { if sz < 0 { return errDataLength } - if _, err := p.r.SkipN(sz); err != nil { + if _, err := r.SkipN(sz); err != nil { return err } case STRUCT: for { - b, err := p.r.SkipN(1) // TType + b, err := r.SkipN(1) // TType if err != nil { return err } @@ -76,15 +67,26 @@ func (p SkipDecoderTpl[T]) Skip(t TType, maxdepth int) error { if tp == STOP { break } - if _, err := p.r.SkipN(2); err != nil { // Field ID + if sz := typeToSize[tp]; sz > 0 { + // fastpath + // Field ID + Value + if _, err := r.SkipN(2 + int(sz)); err != nil { + return err + } + continue + } + + // Field ID + if _, err := r.SkipN(2); err != nil { return err } - if err := p.Skip(tp, maxdepth-1); err != nil { + // Field Value + if err := skipDecoderImpl(r, tp, maxdepth-1); err != nil { return err } } case MAP: - b, err := p.r.SkipN(6) // 1 byte key TType, 1 byte value TType, 4 bytes Len + b, err := r.SkipN(6) // 1 byte key TType, 1 byte value TType, 4 bytes Len if err != nil { return err } @@ -94,19 +96,31 @@ func (p SkipDecoderTpl[T]) Skip(t TType, maxdepth int) error { } ksz, vsz := int(typeToSize[kt]), int(typeToSize[vt]) if ksz > 0 && vsz > 0 { - _, err := p.r.SkipN(int(sz) * (ksz + vsz)) + _, err := r.SkipN(int(sz) * (ksz + vsz)) return err } for i := int32(0); i < sz; i++ { - if err := p.Skip(kt, maxdepth-1); err != nil { - return err + if ksz > 0 { + if _, err := r.SkipN(ksz); err != nil { + return err + } + } else { + if err := skipDecoderImpl(r, kt, maxdepth-1); err != nil { + return err + } } - if err := p.Skip(vt, maxdepth-1); err != nil { - return err + if vsz > 0 { + if _, err := r.SkipN(vsz); err != nil { + return err + } + } else { + if err := skipDecoderImpl(r, vt, maxdepth-1); err != nil { + return err + } } } case SET, LIST: - b, err := p.r.SkipN(5) // 1 byte value type, 4 bytes Len + b, err := r.SkipN(5) // 1 byte value type, 4 bytes Len if err != nil { return err } @@ -115,11 +129,11 @@ func (p SkipDecoderTpl[T]) Skip(t TType, maxdepth int) error { return errDataLength } if vsz := typeToSize[vt]; vsz > 0 { - _, err := p.r.SkipN(int(sz) * int(vsz)) + _, err := r.SkipN(int(sz) * int(vsz)) return err } for i := int32(0); i < sz; i++ { - if err := p.Skip(vt, maxdepth-1); err != nil { + if err := skipDecoderImpl(r, vt, maxdepth-1); err != nil { return err } }