Skip to content

Commit

Permalink
Integrate zstd compression into chain exchange (#842)
Browse files Browse the repository at this point in the history
* Integrate zstd compression into chain exchange

The GPBFT message exchange over pubsub already uses zstd compression on
top of CBOR encoded messages. The work here integrates the same style
of compression for chain exchange messages, with additional
unification of the encoding mechanism across the two.

The work refactors the root level encoding implementation into a generic
encoder/decoder that both chain exchange and gpbft use. Tests and
benchmarks are updated to reflect this.

The benchmarking of partial gmessage encoding is also adjusted to fix a
few redundant statements and bugs in testing.

Fixes #819

* Strictly Limit the size of decompressed values to 1 MiB

The default message size limit in GossipSub is 1 MiB, which is unchanged
in Lotus. This means when decompressing values, we can never have a
valid compressed message that expands to larger than 1 MiB.

Set this limit explicitly in the zstd decoder.

* Massage the flaky test into submission
  • Loading branch information
masih authored Jan 27, 2025
1 parent ded3d04 commit 2d636c9
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 112 deletions.
18 changes: 11 additions & 7 deletions chainexchange/pubsub.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
package chainexchange

import (
"bytes"
"context"
"fmt"
"sync"
"time"

"github.com/filecoin-project/go-f3/gpbft"
"github.com/filecoin-project/go-f3/internal/encoding"
"github.com/filecoin-project/go-f3/internal/psutil"
lru "github.com/hashicorp/golang-lru/v2"
logging "github.com/ipfs/go-log/v2"
Expand Down Expand Up @@ -38,18 +38,24 @@ type PubSubChainExchange struct {
pendingCacheAsWanted chan Message
topic *pubsub.Topic
stop func() error
encoding *encoding.ZSTD[*Message]
}

// NewPubSubChainExchange builds a PubSubChainExchange from the given options.
//
// It returns an error if the options cannot be resolved or if the zstd-based
// message encoding cannot be created. The wanted and discovered chain maps
// start out empty, and the pending-cache channel is buffered.
func NewPubSubChainExchange(o ...Option) (*PubSubChainExchange, error) {
	resolved, err := newOptions(o...)
	if err != nil {
		return nil, err
	}
	msgEncoding, err := encoding.NewZSTD[*Message]()
	if err != nil {
		return nil, err
	}
	exchange := &PubSubChainExchange{
		options:              resolved,
		chainsWanted:         map[uint64]*lru.Cache[gpbft.ECChainKey, *chainPortion]{},
		chainsDiscovered:     map[uint64]*lru.Cache[gpbft.ECChainKey, *chainPortion]{},
		pendingCacheAsWanted: make(chan Message, 100), // TODO: parameterise.
		encoding:             msgEncoding,
	}
	return exchange, nil
}

Expand Down Expand Up @@ -189,8 +195,7 @@ func (p *PubSubChainExchange) newChainPortionCache(capacity int) *lru.Cache[gpbf

func (p *PubSubChainExchange) validatePubSubMessage(_ context.Context, _ peer.ID, msg *pubsub.Message) pubsub.ValidationResult {
var cmsg Message
buf := bytes.NewBuffer(msg.Data)
if err := cmsg.UnmarshalCBOR(buf); err != nil {
if err := p.encoding.Decode(msg.Data, &cmsg); err != nil {
log.Debugw("failed to decode message", "from", msg.GetFrom(), "err", err)
return pubsub.ValidationReject
}
Expand Down Expand Up @@ -266,12 +271,11 @@ func (p *PubSubChainExchange) Broadcast(ctx context.Context, msg Message) error
log.Warnw("Dropping wanted cache entry. Chain exchange is too slow to process chains as wanted", "msg", msg)
}

// TODO: integrate zstd compression.
var buf bytes.Buffer
if err := msg.MarshalCBOR(&buf); err != nil {
encoded, err := p.encoding.Encode(&msg)
if err != nil {
return fmt.Errorf("failed to marshal message: %w", err)
}
if err := p.topic.Publish(ctx, buf.Bytes()); err != nil {
if err := p.topic.Publish(ctx, encoded); err != nil {
return fmt.Errorf("failed to publish message: %w", err)
}
return nil
Expand Down
50 changes: 27 additions & 23 deletions msg_encoding_test.go → encoding_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/filecoin-project/go-bitfield"
"github.com/filecoin-project/go-f3/gpbft"
"github.com/filecoin-project/go-f3/internal/encoding"
"github.com/ipfs/go-cid"
"github.com/multiformats/go-multihash"
"github.com/stretchr/testify/require"
Expand All @@ -15,23 +16,23 @@ const seed = 1413

func BenchmarkCborEncoding(b *testing.B) {
rng := rand.New(rand.NewSource(seed))
encoder := &cborGMessageEncoding{}
encoder := encoding.NewCBOR[*PartialGMessage]()
msg := generateRandomPartialGMessage(b, rng)

b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if _, err := encoder.Encode(msg); err != nil {
require.NoError(b, err)
}
got, err := encoder.Encode(msg)
require.NoError(b, err)
require.NotEmpty(b, got)
}
})
}

func BenchmarkCborDecoding(b *testing.B) {
rng := rand.New(rand.NewSource(seed))
encoder := &cborGMessageEncoding{}
encoder := encoding.NewCBOR[*PartialGMessage]()
msg := generateRandomPartialGMessage(b, rng)
data, err := encoder.Encode(msg)
require.NoError(b, err)
Expand All @@ -40,34 +41,33 @@ func BenchmarkCborDecoding(b *testing.B) {
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if got, err := encoder.Decode(data); err != nil {
require.NoError(b, err)
require.Equal(b, msg, got)
}
var got PartialGMessage
require.NoError(b, encoder.Decode(data, &got))
require.Equal(b, msg, &got)
}
})
}

func BenchmarkZstdEncoding(b *testing.B) {
rng := rand.New(rand.NewSource(seed))
encoder, err := newZstdGMessageEncoding()
encoder, err := encoding.NewZSTD[*PartialGMessage]()
require.NoError(b, err)
msg := generateRandomPartialGMessage(b, rng)

b.ResetTimer()
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if _, err := encoder.Encode(msg); err != nil {
require.NoError(b, err)
}
got, err := encoder.Encode(msg)
require.NoError(b, err)
require.NotEmpty(b, got)
}
})
}

func BenchmarkZstdDecoding(b *testing.B) {
rng := rand.New(rand.NewSource(seed))
encoder, err := newZstdGMessageEncoding()
encoder, err := encoding.NewZSTD[*PartialGMessage]()
require.NoError(b, err)
msg := generateRandomPartialGMessage(b, rng)
data, err := encoder.Encode(msg)
Expand All @@ -77,10 +77,9 @@ func BenchmarkZstdDecoding(b *testing.B) {
b.ReportAllocs()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if got, err := encoder.Decode(data); err != nil {
require.NoError(b, err)
require.Equal(b, msg, got)
}
var got PartialGMessage
require.NoError(b, encoder.Decode(data, &got))
require.Equal(b, msg, &got)
}
})
}
Expand All @@ -99,9 +98,8 @@ func generateRandomPartialGMessage(b *testing.B, rng *rand.Rand) *PartialGMessag
func generateRandomGMessage(b *testing.B, rng *rand.Rand) *gpbft.GMessage {
var maybeTicket []byte
if rng.Float64() < 0.5 {
generateRandomBytes(b, rng, 96)
maybeTicket = generateRandomBytes(b, rng, 96)
}

return &gpbft.GMessage{
Sender: gpbft.ActorID(rng.Uint64()),
Vote: generateRandomPayload(b, rng),
Expand All @@ -114,7 +112,7 @@ func generateRandomGMessage(b *testing.B, rng *rand.Rand) *gpbft.GMessage {
func generateRandomJustification(b *testing.B, rng *rand.Rand) *gpbft.Justification {
return &gpbft.Justification{
Vote: generateRandomPayload(b, rng),
Signers: generateRandomBitfield(rng),
Signers: generateRandomBitfield(b, rng),
Signature: generateRandomBytes(b, rng, 96),
}
}
Expand All @@ -138,12 +136,18 @@ func generateRandomPayload(b *testing.B, rng *rand.Rand) gpbft.Payload {
}
}

func generateRandomBitfield(rng *rand.Rand) bitfield.BitField {
func generateRandomBitfield(b *testing.B, rng *rand.Rand) bitfield.BitField {
ids := make([]uint64, rng.Intn(2_000)+1)
for i := range ids {
ids[i] = rng.Uint64()
}
return bitfield.NewFromSet(ids)
// Copy the bitfield once to force initialization of internal bit field state.
// This is to work around the equality assertions in tests, where under the hood
// reflection is used to check for equality. This way we can avoid writing custom
// equality checking for bitfields.
bitField, err := bitfield.NewFromSet(ids).Copy()
require.NoError(b, err)
return bitField
}

func generateRandomECChain(b *testing.B, rng *rand.Rand, length int) *gpbft.ECChain {
Expand Down
2 changes: 1 addition & 1 deletion f3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ func (e *testEnv) waitForEpochFinalized(epoch int64) {
// here and reduce the clock advance to give messages a chance of being
// delivered in time. See:
// - https://github.com/filecoin-project/go-f3/issues/818
time.Sleep(10 * time.Millisecond)
time.Sleep(20 * time.Millisecond)
for _, nd := range e.nodes {
if nd.f3 == nil || !nd.f3.IsRunning() {
continue
Expand Down
15 changes: 8 additions & 7 deletions host.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/filecoin-project/go-f3/gpbft"
"github.com/filecoin-project/go-f3/internal/caching"
"github.com/filecoin-project/go-f3/internal/clock"
"github.com/filecoin-project/go-f3/internal/encoding"
"github.com/filecoin-project/go-f3/internal/psutil"
"github.com/filecoin-project/go-f3/internal/writeaheadlog"
"github.com/filecoin-project/go-f3/manifest"
Expand Down Expand Up @@ -52,7 +53,7 @@ type gpbftRunner struct {
selfMessages map[uint64]map[roundPhase][]*gpbft.GMessage

inputs gpbftInputs
msgEncoding gMessageEncoding
msgEncoding encoding.EncodeDecoder[*PartialGMessage]
pmm *partialMessageManager
pmv *cachingPartialValidator
pmCache *caching.GroupedSet
Expand Down Expand Up @@ -138,12 +139,12 @@ func newRunner(
runner.participant = p

if runner.manifest.PubSub.CompressionEnabled {
runner.msgEncoding, err = newZstdGMessageEncoding()
runner.msgEncoding, err = encoding.NewZSTD[*PartialGMessage]()
if err != nil {
return nil, err
}
} else {
runner.msgEncoding = &cborGMessageEncoding{}
runner.msgEncoding = encoding.NewCBOR[*PartialGMessage]()
}

runner.pmm, err = newPartialMessageManager(runner.Progress, ps, m)
Expand Down Expand Up @@ -541,15 +542,15 @@ func (h *gpbftRunner) validatePubsubMessage(ctx context.Context, _ peer.ID, msg
recordValidationTime(ctx, start, _result)
}(time.Now())

pgmsg, err := h.msgEncoding.Decode(msg.Data)
if err != nil {
var pgmsg PartialGMessage
if err := h.msgEncoding.Decode(msg.Data, &pgmsg); err != nil {
log.Debugw("failed to decode message", "from", msg.GetFrom(), "err", err)
return pubsub.ValidationReject
}

gmsg, completed := h.pmm.CompleteMessage(ctx, pgmsg)
gmsg, completed := h.pmm.CompleteMessage(ctx, &pgmsg)
if !completed {
partiallyValidatedMessage, err := h.pmv.PartiallyValidateMessage(pgmsg)
partiallyValidatedMessage, err := h.pmv.PartiallyValidateMessage(&pgmsg)
result := pubsubValidationResultFromError(err)
if result == pubsub.ValidationAccept {
msg.ValidatorData = partiallyValidatedMessage
Expand Down
86 changes: 86 additions & 0 deletions internal/encoding/encoding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package encoding

import (
"bytes"
"fmt"

"github.com/klauspost/compress/zstd"
cbg "github.com/whyrusleeping/cbor-gen"
)

// maxDecompressedSize is the upper bound, in bytes, on the size of a value once
// decompressed. It is passed to the zstd decoder as its maximum memory limit,
// and ZSTD.Encode rejects CBOR payloads larger than this before compressing
// them. The limit of 1 MiB is chosen based on the default maximum message size
// in GossipSub.
const maxDecompressedSize = 1 << 20

// CBORMarshalUnmarshaler combines the cbor-gen CBORMarshaler and
// CBORUnmarshaler interfaces. It is the type constraint used by the encoders in
// this package.
type CBORMarshalUnmarshaler interface {
cbg.CBORMarshaler
cbg.CBORUnmarshaler
}

// EncodeDecoder converts values of type T to and from a serialized byte
// representation.
//
// Encode returns the serialized form of v. Decode takes the serialized bytes
// and the destination value to populate; it reports only an error, so the
// decoded result is delivered through the second argument.
type EncodeDecoder[T CBORMarshalUnmarshaler] interface {
Encode(v T) ([]byte, error)
Decode([]byte, T) error
}

// CBOR is an EncodeDecoder that serializes values using their own CBOR
// marshalling and unmarshalling methods, without compression. It has no fields.
type CBOR[T CBORMarshalUnmarshaler] struct{}

// NewCBOR returns a CBOR encoder/decoder for values of type T.
func NewCBOR[T CBORMarshalUnmarshaler]() *CBOR[T] {
	return new(CBOR[T])
}

// Encode marshals m into its CBOR representation and returns the resulting
// bytes. An error reported by m.MarshalCBOR is returned unchanged, in which
// case the returned slice is nil.
func (c *CBOR[T]) Encode(m T) ([]byte, error) {
	out := new(bytes.Buffer)
	err := m.MarshalCBOR(out)
	if err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}

// Decode populates t by unmarshalling the CBOR bytes in v. The error, if any,
// is whatever t.UnmarshalCBOR reports.
func (c *CBOR[T]) Decode(v []byte, t T) error {
	return t.UnmarshalCBOR(bytes.NewReader(v))
}

// ZSTD is an EncodeDecoder that CBOR-encodes values and then compresses the
// result with zstd; decoding reverses the two steps.
type ZSTD[T CBORMarshalUnmarshaler] struct {
// cborEncoding produces and parses the uncompressed CBOR payload.
cborEncoding *CBOR[T]
// compressor is used by Encode to compress the CBOR payload.
compressor *zstd.Encoder
// decompressor is used by Decode to recover the CBOR payload.
decompressor *zstd.Decoder
}

// NewZSTD constructs a ZSTD encoder/decoder for values of type T.
//
// The zstd encoder and decoder are created with a nil writer and reader
// respectively; within this file they are only driven through EncodeAll and
// DecodeAll. The decoder is configured with a maximum memory limit of
// maxDecompressedSize. An error from either zstd constructor is returned
// unchanged.
func NewZSTD[T CBORMarshalUnmarshaler]() (*ZSTD[T], error) {
	compressor, err := zstd.NewWriter(nil)
	if err != nil {
		return nil, err
	}
	decompressor, err := zstd.NewReader(nil, zstd.WithDecoderMaxMemory(maxDecompressedSize))
	if err != nil {
		return nil, err
	}
	codec := &ZSTD[T]{
		cborEncoding: new(CBOR[T]),
		compressor:   compressor,
		decompressor: decompressor,
	}
	return codec, nil
}

// Encode marshals m to CBOR and compresses the result with zstd.
//
// It returns an error if CBOR marshalling fails, or if the CBOR-encoded value
// is larger than maxDecompressedSize: the decoder is configured with that same
// limit, so a larger value could not be decompressed again.
func (c *ZSTD[T]) Encode(m T) ([]byte, error) {
	cborEncoded, err := c.cborEncoding.Encode(m)
	if err != nil {
		// Check the error before looking at the result: bytes returned alongside an
		// error are not meaningful and must not mask the marshalling failure.
		return nil, err
	}
	if len(cborEncoded) > maxDecompressedSize {
		// Error out early if the encoded value is too large to be decompressed.
		return nil, fmt.Errorf("encoded value cannot exceed maximum size: %d > %d", len(cborEncoded), maxDecompressedSize)
	}
	compressed := c.compressor.EncodeAll(cborEncoded, make([]byte, 0, len(cborEncoded)))
	return compressed, nil
}

// Decode decompresses v with zstd and unmarshals the resulting CBOR bytes into
// t. A decompression error is returned as-is; otherwise the result of the CBOR
// decode is returned.
func (c *ZSTD[T]) Decode(v []byte, t T) error {
	// Destination buffer for the decompressed bytes, with initial capacity equal
	// to the compressed length.
	scratch := make([]byte, 0, len(v))
	plain, decompressErr := c.decompressor.DecodeAll(v, scratch)
	if decompressErr != nil {
		return decompressErr
	}
	return c.cborEncoding.Decode(plain, t)
}
Loading

0 comments on commit 2d636c9

Please sign in to comment.