diff --git a/p2p/subscriber.go b/p2p/subscriber.go index 300cd226..7ecfde1d 100644 --- a/p2p/subscriber.go +++ b/p2p/subscriber.go @@ -3,6 +3,7 @@ package p2p import ( "context" "errors" + "sync" pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/libp2p/go-libp2p/core/peer" @@ -29,6 +30,10 @@ type Subscriber[H header.Header[H]] struct { pubsub *pubsub.PubSub topic *pubsub.Topic msgID pubsub.MsgIdFunction + + verifierMu sync.Mutex + verifierSema chan struct{} // closed when verifier is set + verifier func(context.Context, H) error } // WithSubscriberMetrics enables metrics collection for the Subscriber. @@ -71,81 +76,50 @@ func NewSubscriber[H header.Header[H]]( pubsubTopicID: PubsubTopicID(params.networkID), pubsub: ps, msgID: msgID, + verifierSema: make(chan struct{}), }, nil } // Start starts the Subscriber and joins the instance's topic. SetVerifier must // be called separately to ensure a validator is mounted on the topic. func (s *Subscriber[H]) Start(context.Context) (err error) { - log.Infow("joining topic", "topic ID", s.pubsubTopicID) - s.topic, err = s.pubsub.Join(s.pubsubTopicID, pubsub.WithTopicMessageIdFn(s.msgID)) - return err -} + log.Debugf("joining topic", "topic ID", s.pubsubTopicID) + err = s.pubsub.RegisterTopicValidator(s.pubsubTopicID, s.verifyMessage) + if err != nil { + return err + } -// Stop closes the topic and unregisters its validator. -func (s *Subscriber[H]) Stop(context.Context) error { - regErr := s.pubsub.UnregisterTopicValidator(s.pubsubTopicID) - if regErr != nil { - // do not return this error as it is non-critical and usually - // means that a validator was not mounted. - log.Warnf("unregistering validator: %s", regErr) + topic, err := s.pubsub.Join(s.pubsubTopicID, pubsub.WithTopicMessageIdFn(s.msgID)) + if err != nil { + return err } - err := s.topic.Close() - return errors.Join(err, s.metrics.Close()) + s.topic = topic + return err } -// SetVerifier set given verification func as Header PubSub topic validator -// Does not punish peers if *header.VerifyError is given with Uncertain set to true. -func (s *Subscriber[H]) SetVerifier(val func(context.Context, H) error) error { - pval := func(ctx context.Context, p peer.ID, msg *pubsub.Message) (res pubsub.ValidationResult) { - defer func() { - err := recover() - if err != nil { - log.Errorf("PANIC while unmarshalling or verifying header: %s", err) - res = pubsub.ValidationReject - } - }() - - hdr := header.New[H]() - err := hdr.UnmarshalBinary(msg.Data) - if err != nil { - log.Errorw("unmarshalling header", - "from", p.ShortString(), - "err", err) - s.metrics.reject(ctx) - return pubsub.ValidationReject - } - // ensure header validity - err = hdr.Validate() - if err != nil { - log.Errorw("invalid header", - "from", p.ShortString(), - "err", err) - s.metrics.reject(ctx) - return pubsub.ValidationReject - } - - var verErr *header.VerifyError - err = val(ctx, hdr) - switch { - case errors.As(err, &verErr) && verErr.SoftFailure: - s.metrics.ignore(ctx) - return pubsub.ValidationIgnore - case err != nil: - s.metrics.reject(ctx) - return pubsub.ValidationReject - default: - } +// Stop closes the topic and unregisters its validator. +func (s *Subscriber[H]) Stop(context.Context) (err error) { + err = errors.Join(err, s.metrics.Close()) + // we must close the topic first and then unregister the validator + // this ensures we never get a message after the validator is unregistered + err = errors.Join(err, s.topic.Close()) + err = errors.Join(err, s.pubsub.UnregisterTopicValidator(s.pubsubTopicID)) + return err +} - // keep the valid header in the msg so Subscriptions can access it without - // additional unmarshalling - msg.ValidatorData = hdr - s.metrics.accept(ctx, len(msg.Data)) - return pubsub.ValidationAccept +// SetVerifier set given verification func as Header PubSub topic validator. +// Does not punish peers if *header.VerifyError is given with SoftFailure set to true. +func (s *Subscriber[H]) SetVerifier(verifier func(context.Context, H) error) error { + s.verifierMu.Lock() + defer s.verifierMu.Unlock() + if s.verifier != nil { + return errors.New("verifier already set") } - return s.pubsub.RegisterTopicValidator(s.pubsubTopicID, pval) + s.verifier = verifier + close(s.verifierSema) + return nil } // Subscribe returns a new subscription to the Subscriber's @@ -166,3 +140,59 @@ func (s *Subscriber[H]) Broadcast(ctx context.Context, header H, opts ...pubsub. } return s.topic.Publish(ctx, bin, opts...) } + +func (s *Subscriber[H]) verifyMessage(ctx context.Context, p peer.ID, msg *pubsub.Message) (res pubsub.ValidationResult) { + defer func() { + err := recover() + if err != nil { + log.Errorf("PANIC while unmarshalling or verifying header: %s", err) + res = pubsub.ValidationReject + } + }() + + hdr := header.New[H]() + err := hdr.UnmarshalBinary(msg.Data) + if err != nil { + log.Errorw("unmarshalling header", + "from", p.ShortString(), + "err", err) + s.metrics.reject(ctx) + return pubsub.ValidationReject + } + // ensure header validity + err = hdr.Validate() + if err != nil { + log.Errorw("invalid header", + "from", p.ShortString(), + "err", err) + s.metrics.reject(ctx) + return pubsub.ValidationReject + } + + // ensure we have a verifier set before verifying the message + select { + case <-s.verifierSema: + case <-ctx.Done(): + log.Errorw("verifier was not set before incoming header verification", "from", p.ShortString()) + s.metrics.ignore(ctx) + return pubsub.ValidationIgnore + } + + var verErr *header.VerifyError + err = s.verifier(ctx, hdr) + switch { + case errors.As(err, &verErr) && verErr.SoftFailure: + s.metrics.ignore(ctx) + return pubsub.ValidationIgnore + case err != nil: + s.metrics.reject(ctx) + return pubsub.ValidationReject + default: + } + + // keep the valid header in the msg so Subscriptions can access it without + // additional unmarshalling + msg.ValidatorData = hdr + s.metrics.accept(ctx, len(msg.Data)) + return pubsub.ValidationAccept +} diff --git a/p2p/subscription_test.go b/p2p/subscription_test.go index e9e53700..9911503a 100644 --- a/p2p/subscription_test.go +++ b/p2p/subscription_test.go @@ -34,6 +34,10 @@ func TestSubscriber(t *testing.T) { require.NoError(t, err) err = p2pSub1.Start(context.Background()) require.NoError(t, err) + err = p2pSub1.SetVerifier(func(context.Context, *headertest.DummyHeader) error { + return nil + }) + require.NoError(t, err) // get mock host and create new gossipsub on it pubsub2, err := pubsub.NewGossipSub(ctx, net.Hosts()[1], @@ -45,6 +49,10 @@ func TestSubscriber(t *testing.T) { require.NoError(t, err) err = p2pSub2.Start(context.Background()) require.NoError(t, err) + err = p2pSub2.SetVerifier(func(context.Context, *headertest.DummyHeader) error { + return nil + }) + require.NoError(t, err) sub0, err := net.Hosts()[0].EventBus().Subscribe(&event.EvtPeerIdentificationCompleted{}) require.NoError(t, err) @@ -68,11 +76,6 @@ func TestSubscriber(t *testing.T) { _, err = p2pSub2.Subscribe() require.NoError(t, err) - err = p2pSub1.SetVerifier(func(context.Context, *headertest.DummyHeader) error { - return nil - }) - require.NoError(t, err) - subscription, err := p2pSub1.Subscribe() require.NoError(t, err)