Skip to content

Commit

Permalink
Add bucketed pool for decompression
Browse files Browse the repository at this point in the history
  • Loading branch information
Sovietaced committed Nov 15, 2024
1 parent 81ceb1a commit bac08fd
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 12 deletions.
3 changes: 3 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ run:
issues:
exclude-use-default: false
max-same-issues: 0 # print all failures
exclude-rules:
- path: pkg/kgo/internal/bucketed_pool.go
text: "SA6002: argument should be pointer-like to avoid allocations"

linters:
disable-all: true
Expand Down
63 changes: 51 additions & 12 deletions pkg/kgo/compression.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"compress/gzip"
"encoding/binary"
"errors"
pool "github.com/twmb/franz-go/pkg/kgo/internal"
"io"
"runtime"
"sync"
Expand All @@ -14,8 +15,6 @@ import (
"github.com/pierrec/lz4/v4"
)

var byteBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 8<<10)) }}

type codecType int8

const (
Expand Down Expand Up @@ -230,9 +229,10 @@ func (c *compressor) compress(dst *bytes.Buffer, src []byte, produceRequestVersi
}

type decompressor struct {
ungzPool sync.Pool
unlz4Pool sync.Pool
unzstdPool sync.Pool
ungzPool sync.Pool
unlz4Pool sync.Pool
unzstdPool sync.Pool
bucketedByteBufferPool *pool.BucketedPool[byte]
}

func newDecompressor() *decompressor {
Expand All @@ -256,6 +256,7 @@ func newDecompressor() *decompressor {
return r
},
},
bucketedByteBufferPool: pool.NewBucketedPool[byte](1, 50<<20, 2, func(int) []byte { return make([]byte, 1<<10) }),
}
return d
}
Expand All @@ -270,9 +271,16 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
if compCodec == codecNone {
return src, nil
}
out := byteBuffers.Get().(*bytes.Buffer)
out.Reset()
defer byteBuffers.Put(out)
out, buf, err := d.getDecodedBuffer(src, compCodec)
if err != nil {
return nil, err
}
defer func() {
if compCodec == codecSnappy {
return
}
d.bucketedByteBufferPool.Put(buf)
}()

switch compCodec {
case codecGzip:
Expand All @@ -284,7 +292,7 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
if _, err := io.Copy(out, ungz); err != nil {
return nil, err
}
return append([]byte(nil), out.Bytes()...), nil
return d.copyDecodedBuffer(out.Bytes(), compCodec), nil
case codecSnappy:
if len(src) > 16 && bytes.HasPrefix(src, xerialPfx) {
return xerialDecode(src)
Expand All @@ -293,23 +301,23 @@ func (d *decompressor) decompress(src []byte, codec byte) ([]byte, error) {
if err != nil {
return nil, err
}
return append([]byte(nil), decoded...), nil
return d.copyDecodedBuffer(decoded, compCodec), nil
case codecLZ4:
unlz4 := d.unlz4Pool.Get().(*lz4.Reader)
defer d.unlz4Pool.Put(unlz4)
unlz4.Reset(bytes.NewReader(src))
if _, err := io.Copy(out, unlz4); err != nil {
return nil, err
}
return append([]byte(nil), out.Bytes()...), nil
return d.copyDecodedBuffer(out.Bytes(), compCodec), nil
case codecZstd:
unzstd := d.unzstdPool.Get().(*zstdDecoder)
defer d.unzstdPool.Put(unzstd)
decoded, err := unzstd.inner.DecodeAll(src, out.Bytes())
if err != nil {
return nil, err
}
return append([]byte(nil), decoded...), nil
return d.copyDecodedBuffer(decoded, compCodec), nil
default:
return nil, errors.New("unknown compression codec")
}
Expand Down Expand Up @@ -344,3 +352,34 @@ func xerialDecode(src []byte) ([]byte, error) {
}
return dst, nil
}

// getDecodedBuffer returns a pooled scratch buffer sized for decompressing
// src with the given codec, plus the raw slice backing it so the caller can
// later hand that slice back to the pool.
//
// For snappy the exact decoded length is known up front via s2.DecodedLen;
// for every other codec we guess twice the compressed size and let the
// bytes.Buffer grow if the guess falls short.
func (d *decompressor) getDecodedBuffer(src []byte, compCodec codecType) (*bytes.Buffer, []byte, error) {
	var sizeHint int
	if compCodec == codecSnappy {
		n, err := s2.DecodedLen(src)
		if err != nil {
			return nil, nil, err
		}
		sizeHint = n
	} else {
		// No exact size is available; guess at the output size.
		sizeHint = 2 * len(src)
	}
	buf := d.bucketedByteBufferPool.Get(sizeHint)
	buf = buf[:0]
	return bytes.NewBuffer(buf), buf, nil
}

// copyDecodedBuffer copies decoded into a freshly pooled buffer so the
// caller owns a slice independent of the decompressor's scratch space.
// Snappy output is returned as-is: its buffer was sized exactly before
// decompression and is not recycled, so no defensive copy is needed.
func (d *decompressor) copyDecodedBuffer(decoded []byte, compCodec codecType) []byte {
	if compCodec == codecSnappy {
		// We already know the actual size of the decoded buffer before decompression,
		// so there's no need to copy the buffer.
		return decoded
	}
	dst := d.bucketedByteBufferPool.Get(len(decoded))[:0]
	return append(dst, decoded...)
}
92 changes: 92 additions & 0 deletions pkg/kgo/internal/bucketed_pool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright 2017 The Prometheus Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pool

import (
"sync"
)

// BucketedPool is a bucketed pool for variably sized slices. Requests are
// rounded up to the nearest bucket size so freed slices can be reused by
// later requests of similar size.
type BucketedPool[T any] struct {
	buckets []sync.Pool // one pool per entry in sizes
	sizes   []int       // ascending bucket capacities, from minSize up to maxSize
	// make is the function used to create an empty slice when none exist yet.
	make func(int) []T
}

// NewBucketedPool returns a new BucketedPool with size buckets for minSize to maxSize
// increasing by the given factor.
//
// makeFunc is expected to return a slice with capacity at least the requested
// size; it is called with a bucket's size when that bucket is empty, or with
// the raw request size when it exceeds the largest bucket.
func NewBucketedPool[T any](minSize, maxSize int, factor float64, makeFunc func(int) []T) *BucketedPool[T] {
	if minSize < 1 {
		panic("invalid minimum pool size")
	}
	if maxSize < 1 {
		panic("invalid maximum pool size")
	}
	if factor < 1 {
		panic("invalid factor")
	}

	var sizes []int

	for s := minSize; s <= maxSize; {
		sizes = append(sizes, s)
		// Guarantee forward progress: int truncation can yield next <= s for
		// small s with factor < 2 (and factor == 1 never grows), which would
		// otherwise loop forever while appending to sizes unboundedly.
		next := int(float64(s) * factor)
		if next <= s {
			next = s + 1
		}
		s = next
	}

	p := &BucketedPool[T]{
		buckets: make([]sync.Pool, len(sizes)),
		sizes:   sizes,
		make:    makeFunc,
	}
	return p
}

// Get returns a new slice with capacity greater than or equal to size.
func (p *BucketedPool[T]) Get(size int) []T {
	for i, bktSize := range p.sizes {
		if size > bktSize {
			continue
		}
		buff := p.buckets[i].Get()
		if buff == nil {
			// Bucket is empty: allocate at the bucket's capacity so the
			// slice lands back in this bucket when it is Put.
			buff = p.make(bktSize)
		}
		return buff.([]T)
	}
	// Larger than the biggest bucket: allocate directly. Such outliers are
	// deliberately ignored by Put so they do not bloat the pool.
	return p.make(size)
}

// Put adds a slice to the right bucket in the pool.
// If the slice does not belong to any bucket in the pool, it is ignored.
func (p *BucketedPool[T]) Put(s []T) {
	sCap := cap(s)
	if sCap < p.sizes[0] {
		return
	}

	for i, size := range p.sizes {
		if sCap > size {
			continue
		}

		if sCap == size {
			// Buffer is exactly the minimum size for this bucket. Add it to this bucket.
			p.buckets[i].Put(s)
		} else {
			// Buffer belongs in previous bucket: its capacity sits between
			// sizes[i-1] and sizes[i], so it still satisfies any request the
			// smaller bucket serves. i > 0 here because sCap >= sizes[0].
			p.buckets[i-1].Put(s)
		}
		return
	}
}
68 changes: 68 additions & 0 deletions pkg/kgo/internal/bucketed_pool_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// SPDX-License-Identifier: Apache-2.0
// Provenance-includes-location: https://github.com/prometheus/prometheus/blob/main/util/pool/pool_test.go
// Provenance-includes-copyright: The Prometheus Authors

package pool

import (
"testing"
)

// makeFunc is the allocation callback used by the pools under test: it
// builds a zero-length int slice whose capacity honors the requested size.
func makeFunc(size int) []int {
	s := make([]int, 0, size)
	return s
}

// TestBucketedPool_HappyPath checks that Get always hands back a slice with
// at least the requested capacity — for sub-minimum, mid-bucket, and
// over-maximum request sizes — and that each slice can be returned via Put.
func TestBucketedPool_HappyPath(t *testing.T) {
	p := NewBucketedPool(1, 8, 2, makeFunc)
	for _, c := range []struct {
		size        int
		expectedCap int
	}{
		{size: -1, expectedCap: 1}, // below minimum: smallest bucket serves it
		{size: 3, expectedCap: 4},  // rounds up to the 4-capacity bucket
		{size: 10, expectedCap: 10}, // beyond max bucket: allocated directly
	} {
		ret := p.Get(c.size)
		if cap(ret) < c.expectedCap {
			t.Fatalf("expected cap >= %d, got %d", c.expectedCap, cap(ret))
		}
		p.Put(ret)
	}
}

// TestBucketedPool_SliceNotAlignedToBuckets puts back a slice whose capacity
// falls between bucket sizes and verifies a later Get still satisfies its
// requested capacity.
func TestBucketedPool_SliceNotAlignedToBuckets(t *testing.T) {
	p := NewBucketedPool(1, 1000, 10, makeFunc)
	p.Put(make([]int, 0, 2))
	got := p.Get(3)
	if cap(got) < 3 {
		t.Fatalf("expected cap >= 3, got %d", cap(got))
	}
}

// TestBucketedPool_PutEmptySlice verifies that returning a zero-capacity
// slice is harmless: it is ignored, and Get still works afterwards.
func TestBucketedPool_PutEmptySlice(t *testing.T) {
	p := NewBucketedPool(1, 1000, 10, makeFunc)
	p.Put([]int{})
	got := p.Get(1)
	if cap(got) < 1 {
		t.Fatalf("expected cap >= 1, got %d", cap(got))
	}
}

// TestBucketedPool_PutSliceSmallerThanMinimum verifies that a slice smaller
// than the minimum bucket size is dropped on Put, and that Get still hands
// back an adequately sized slice.
func TestBucketedPool_PutSliceSmallerThanMinimum(t *testing.T) {
	p := NewBucketedPool(3, 1000, 10, makeFunc)
	p.Put([]int{1, 2})
	got := p.Get(3)
	if cap(got) < 3 {
		t.Fatalf("expected cap >= 3, got %d", cap(got))
	}
}
9 changes: 9 additions & 0 deletions pkg/kgo/pools.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package kgo

import (
"bytes"
"sync"
)

// byteBuffers is a shared general-purpose pool of *bytes.Buffer values.
// Each freshly created buffer is backed by an 8 KiB slice; note that
// bytes.NewBuffer treats that slice as existing content, so callers must
// Reset a buffer taken from the pool before writing to it.
var byteBuffers = sync.Pool{New: func() any { return bytes.NewBuffer(make([]byte, 8<<10)) }}

0 comments on commit bac08fd

Please sign in to comment.