diff --git a/.golangci.yml b/.golangci.yml index 193366c2..db3f0cd9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -9,6 +9,9 @@ run: issues: exclude-use-default: false max-same-issues: 0 # print all failures + exclude-rules: + - path: pkg/kgo/internal/bucketed_pool.go + text: "SA6002: argument should be pointer-like to avoid allocations" linters: disable-all: true diff --git a/pkg/kgo/compression.go b/pkg/kgo/compression.go index fe8ad645..0ec00a73 100644 --- a/pkg/kgo/compression.go +++ b/pkg/kgo/compression.go @@ -5,6 +5,7 @@ import ( "compress/gzip" "encoding/binary" "errors" + pool "github.com/twmb/franz-go/pkg/kgo/internal" "io" "runtime" "sync" @@ -14,8 +15,6 @@ import ( "github.com/pierrec/lz4/v4" ) -var byteBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 8<<10)) }} - type codecType int8 const ( @@ -230,9 +229,10 @@ func (c *compressor) compress(dst *bytes.Buffer, src []byte, produceRequestVersi } type decompressor struct { - ungzPool sync.Pool - unlz4Pool sync.Pool - unzstdPool sync.Pool + ungzPool sync.Pool + unlz4Pool sync.Pool + unzstdPool sync.Pool + bucketedByteBufferPool *pool.BucketedPool[byte] } func newDecompressor() *decompressor { @@ -256,6 +256,7 @@ func newDecompressor() *decompressor { return r }, }, + bucketedByteBufferPool: pool.NewBucketedPool[byte](1, 50<<20, 2, func(int) []byte { return make([]byte, 1<<10) }), } return d } @@ -270,9 +271,16 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) { if compCodec == codecNone { return src, nil } - out := byteBuffers.Get().(*bytes.Buffer) - out.Reset() - defer byteBuffers.Put(out) + out, buf, err := d.getDecodedBuffer(src, compCodec) + if err != nil { + return nil, err + } + defer func() { + if compCodec == codecSnappy { + return + } + d.bucketedByteBufferPool.Put(buf) + }() switch compCodec { case codecGzip: @@ -284,7 +292,7 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) { if _, err := io.Copy(out, ungz); err != nil { return nil, err } - return append([]byte(nil), out.Bytes()...), nil + return d.copyDecodedBuffer(out.Bytes(), compCodec), nil case codecSnappy: if len(src) > 16 && bytes.HasPrefix(src, xerialPfx) { return xerialDecode(src) @@ -293,7 +301,7 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) { if err != nil { return nil, err } - return append([]byte(nil), decoded...), nil + return d.copyDecodedBuffer(decoded, compCodec), nil case codecLZ4: unlz4 := d.unlz4Pool.Get().(*lz4.Reader) defer d.unlz4Pool.Put(unlz4) @@ -301,7 +309,7 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) { if _, err := io.Copy(out, unlz4); err != nil { return nil, err } - return append([]byte(nil), out.Bytes()...), nil + return d.copyDecodedBuffer(out.Bytes(), compCodec), nil case codecZstd: unzstd := d.unzstdPool.Get().(*zstdDecoder) defer d.unzstdPool.Put(unzstd) @@ -309,7 +317,7 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) { if err != nil { return nil, err } - return append([]byte(nil), decoded...), nil + return d.copyDecodedBuffer(decoded, compCodec), nil default: return nil, errors.New("unknown compression codec") } @@ -344,3 +352,34 @@ func xerialDecode(src []byte) ([]byte, error) { } return dst, nil } + +func (d *decompressor) getDecodedBuffer(src []byte, compCodec codecType) (*bytes.Buffer, []byte, error) { + var ( + decodedBufSize int + err error + ) + switch compCodec { + case codecSnappy: + decodedBufSize, err = s2.DecodedLen(src) + if err != nil { + return nil, nil, err + } + + default: + // Make a guess at the output size. + decodedBufSize = len(src) * 2 + } + buf := d.bucketedByteBufferPool.Get(decodedBufSize)[:0] + + return bytes.NewBuffer(buf), buf, nil +} + +func (d *decompressor) copyDecodedBuffer(decoded []byte, compCodec codecType) []byte { + if compCodec == codecSnappy { + // We already know the actual size of the decoded buffer before decompression, + // so there's no need to copy the buffer. + return decoded + } + out := d.bucketedByteBufferPool.Get(len(decoded)) + return append(out[:0], decoded...) +} diff --git a/pkg/kgo/internal/bucketed_pool.go b/pkg/kgo/internal/bucketed_pool.go new file mode 100644 index 00000000..d6af0390 --- /dev/null +++ b/pkg/kgo/internal/bucketed_pool.go @@ -0,0 +1,92 @@ +// Copyright 2017 The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pool + +import ( + "sync" +) + +// BucketedPool is a bucketed pool for variably sized slices. +type BucketedPool[T any] struct { + buckets []sync.Pool + sizes []int + // make is the function used to create an empty slice when none exist yet. + make func(int) []T +} + +// NewBucketedPool returns a new BucketedPool with size buckets for minSize to maxSize +// increasing by the given factor. +func NewBucketedPool[T any](minSize, maxSize int, factor float64, makeFunc func(int) []T) *BucketedPool[T] { + if minSize < 1 { + panic("invalid minimum pool size") + } + if maxSize < 1 { + panic("invalid maximum pool size") + } + if factor < 1 { + panic("invalid factor") + } + + var sizes []int + + for s := minSize; s <= maxSize; s = int(float64(s) * factor) { + sizes = append(sizes, s) + } + + p := &BucketedPool[T]{ + buckets: make([]sync.Pool, len(sizes)), + sizes: sizes, + make: makeFunc, + } + return p +} + +// Get returns a new slice with capacity greater than or equal to size. +func (p *BucketedPool[T]) Get(size int) []T { + for i, bktSize := range p.sizes { + if size > bktSize { + continue + } + buff := p.buckets[i].Get() + if buff == nil { + buff = p.make(bktSize) + } + return buff.([]T) + } + return p.make(size) +} + +// Put adds a slice to the right bucket in the pool. +// If the slice does not belong to any bucket in the pool, it is ignored. +func (p *BucketedPool[T]) Put(s []T) { + sCap := cap(s) + if sCap < p.sizes[0] { + return + } + + for i, size := range p.sizes { + if sCap > size { + continue + } + + if sCap == size { + // Buffer is exactly the minimum size for this bucket. Add it to this bucket. + p.buckets[i].Put(s) + } else { + // Buffer belongs in previous bucket. + p.buckets[i-1].Put(s) + } + return + } +} diff --git a/pkg/kgo/internal/bucketed_pool_test.go b/pkg/kgo/internal/bucketed_pool_test.go new file mode 100644 index 00000000..71735c1a --- /dev/null +++ b/pkg/kgo/internal/bucketed_pool_test.go @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: Apache-2.0 +// Provenance-includes-location: https://github.com/prometheus/prometheus/blob/main/util/pool/pool_test.go +// Provenance-includes-copyright: The Prometheus Authors + +package pool + +import ( + "testing" +) + +func makeFunc(size int) []int { + return make([]int, 0, size) +} + +func TestBucketedPool_HappyPath(t *testing.T) { + testPool := NewBucketedPool(1, 8, 2, makeFunc) + cases := []struct { + size int + expectedCap int + }{ + { + size: -1, + expectedCap: 1, + }, + { + size: 3, + expectedCap: 4, + }, + { + size: 10, + expectedCap: 10, + }, + } + for _, c := range cases { + ret := testPool.Get(c.size) + if cap(ret) < c.expectedCap { + t.Fatalf("expected cap >= %d, got %d", c.expectedCap, cap(ret)) + } + testPool.Put(ret) + } +} + +func TestBucketedPool_SliceNotAlignedToBuckets(t *testing.T) { + pool := NewBucketedPool(1, 1000, 10, makeFunc) + pool.Put(make([]int, 0, 2)) + s := pool.Get(3) + if cap(s) < 3 { + t.Fatalf("expected cap >= 3, got %d", cap(s)) + } +} + +func TestBucketedPool_PutEmptySlice(t *testing.T) { + pool := NewBucketedPool(1, 1000, 10, makeFunc) + pool.Put([]int{}) + s := pool.Get(1) + if cap(s) < 1 { + t.Fatalf("expected cap >= 1, got %d", cap(s)) + } +} + +func TestBucketedPool_PutSliceSmallerThanMinimum(t *testing.T) { + pool := NewBucketedPool(3, 1000, 10, makeFunc) + pool.Put([]int{1, 2}) + s := pool.Get(3) + if cap(s) < 3 { + t.Fatalf("expected cap >= 3, got %d", cap(s)) + } +} diff --git a/pkg/kgo/pools.go b/pkg/kgo/pools.go new file mode 100644 index 00000000..78055482 --- /dev/null +++ b/pkg/kgo/pools.go @@ -0,0 +1,9 @@ +package kgo + +import ( + "bytes" + "sync" +) + +// shared general purpose byte buff +var byteBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 8<<10)) }}