From 062501292d5101962d09f7272930974744a7ea46 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Wed, 30 Oct 2024 12:03:58 -0700 Subject: [PATCH] Make `Message` and `VerifiedMessage` generic over `Verifier` --- examples/tests/async_runner.rs | 4 ++-- manul/src/session.rs | 2 +- manul/src/session/message.rs | 31 +++++++++++++++++++------------ manul/src/session/session.rs | 31 ++++++++++++++++--------------- manul/src/testing/run_sync.rs | 2 +- 5 files changed, 39 insertions(+), 31 deletions(-) diff --git a/examples/tests/async_runner.rs b/examples/tests/async_runner.rs index 1c52472..008ee82 100644 --- a/examples/tests/async_runner.rs +++ b/examples/tests/async_runner.rs @@ -23,12 +23,12 @@ use tracing_subscriber::{util::SubscriberInitExt, EnvFilter}; struct MessageOut { from: SP::Verifier, to: SP::Verifier, - message: Message, + message: Message, } struct MessageIn { from: SP::Verifier, - message: Message, + message: Message, } /// Runs a session. Simulates what each participating party would run as the protocol progresses. diff --git a/manul/src/session.rs b/manul/src/session.rs index 83791b8..3710d96 100644 --- a/manul/src/session.rs +++ b/manul/src/session.rs @@ -17,7 +17,7 @@ mod wire_format; pub use crate::protocol::{LocalError, RemoteError}; pub use evidence::{Evidence, EvidenceError}; -pub use message::{VerifiedMessage, Message}; +pub use message::{Message, VerifiedMessage}; pub use session::{CanFinalize, RoundAccumulator, RoundOutcome, Session, SessionId, SessionParameters}; pub use transcript::{SessionOutcome, SessionReport}; pub use wire_format::WireFormat; diff --git a/manul/src/session/message.rs b/manul/src/session/message.rs index e25fe66..51dbc55 100644 --- a/manul/src/session/message.rs +++ b/manul/src/session/message.rs @@ -177,19 +177,24 @@ impl VerifiedMessagePart { /// A signed message destined for another node. #[derive(Clone, Debug)] -pub struct Message { +pub struct Message { + destination: Verifier, direct_message: SignedMessagePart, echo_broadcast: SignedMessagePart, normal_broadcast: SignedMessagePart, } -impl Message { +impl Message +where + Verifier: Clone, +{ #[allow(clippy::too_many_arguments)] pub(crate) fn new( rng: &mut impl CryptoRngCore, signer: &SP::Signer, session_id: &SessionId, round_id: RoundId, + destination: &Verifier, direct_message: DirectMessage, echo_broadcast: SignedMessagePart, normal_broadcast: SignedMessagePart, @@ -199,12 +204,18 @@ impl Message { { let direct_message = SignedMessagePart::new::(rng, signer, session_id, round_id, direct_message)?; Ok(Self { + destination: destination.clone(), direct_message, echo_broadcast, normal_broadcast, }) } + /// The verifier of the party this message is intended for. + pub fn destination(&self) -> &Verifier { + &self.destination + } + pub(crate) fn unify_metadata(self) -> Option { if self.echo_broadcast.metadata() != self.direct_message.metadata() { return None; @@ -241,7 +252,7 @@ impl CheckedMessage { &self.metadata } - pub fn verify(self, verifier: &SP::Verifier) -> Result, MessageVerificationError> + pub fn verify(self, verifier: &SP::Verifier) -> Result, MessageVerificationError> where SP: SessionParameters, { @@ -264,25 +275,21 @@ impl CheckedMessage { // signatures of message parts (direct, broadcast etc) from the original [`Message`] successfully verified. /// A [`Message`] that had its metadata and signatures verified. -#[derive_where::derive_where(Debug)] -#[derive(Clone)] -pub struct VerifiedMessage { - from: SP::Verifier, +#[derive(Debug, Clone)] +pub struct VerifiedMessage { + from: Verifier, metadata: MessageMetadata, direct_message: VerifiedMessagePart, echo_broadcast: VerifiedMessagePart, normal_broadcast: VerifiedMessagePart, } -impl VerifiedMessage -where - SP: SessionParameters, -{ +impl VerifiedMessage { pub(crate) fn metadata(&self) -> &MessageMetadata { &self.metadata } - pub(crate) fn from(&self) -> &SP::Verifier { + pub(crate) fn from(&self) -> &Verifier { &self.from } diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index b1a084b..210e0b7 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -130,7 +130,7 @@ pub enum RoundOutcome { /// The session object for the new round. session: Session, /// The messages intended for the new round cached during the previous round. - cached_messages: Vec>, + cached_messages: Vec>, }, } @@ -231,7 +231,7 @@ where &self, rng: &mut impl CryptoRngCore, destination: &SP::Verifier, - ) -> Result<(Message, ProcessedArtifact), LocalError> { + ) -> Result<(Message, ProcessedArtifact), LocalError> { let (direct_message, artifact) = self.round .make_direct_message_with_artifact(rng, &self.serializer, destination)?; @@ -241,6 +241,7 @@ where &self.signer, &self.session_id, self.round.id(), + destination, direct_message, self.echo_broadcast.clone(), self.normal_broadcast.clone(), @@ -285,8 +286,8 @@ where &self, accum: &mut RoundAccumulator, from: &SP::Verifier, - message: Message, - ) -> Result>, LocalError> { + message: Message, + ) -> Result>, LocalError> { // Quick preliminary checks, before we proceed with more expensive verification let key = self.verifier(); if self.transcript.is_banned(from) || accum.is_banned(from) { @@ -381,7 +382,7 @@ where pub fn process_message( &self, rng: &mut impl CryptoRngCore, - message: VerifiedMessage, + message: VerifiedMessage, ) -> ProcessedMessage { let processed = self.round.receive_message( rng, @@ -544,7 +545,7 @@ pub struct RoundAccumulator { processing: BTreeSet, payloads: BTreeMap, artifacts: BTreeMap, - cached: BTreeMap>>, + cached: BTreeMap>>, echo_broadcasts: BTreeMap>, normal_broadcasts: BTreeMap>, direct_messages: BTreeMap>, @@ -625,7 +626,7 @@ where } } - fn mark_processing(&mut self, message: &VerifiedMessage) -> Result<(), LocalError> { + fn mark_processing(&mut self, message: &VerifiedMessage) -> Result<(), LocalError> { if !self.processing.insert(message.from().clone()) { Err(LocalError::new(format!( "A message from {:?} is already marked as being processed", @@ -732,7 +733,7 @@ where } } - fn cache_message(&mut self, message: VerifiedMessage) -> Result<(), LocalError> { + fn cache_message(&mut self, message: VerifiedMessage) -> Result<(), LocalError> { let from = message.from().clone(); let round_id = message.metadata().round_id(); let cached = self.cached.entry(from.clone()).or_default(); @@ -754,14 +755,14 @@ pub struct ProcessedArtifact { #[derive(Debug)] pub struct ProcessedMessage { - message: VerifiedMessage, + message: VerifiedMessage, processed: Result>, } -fn filter_messages( - messages: BTreeMap>>, +fn filter_messages( + messages: BTreeMap>>, round_id: RoundId, -) -> Vec> { +) -> Vec> { messages .into_values() .filter_map(|mut messages| messages.remove(&round_id)) @@ -781,7 +782,7 @@ mod tests { Deserializer, DirectMessage, EchoBroadcast, NormalBroadcast, Protocol, ProtocolError, ProtocolValidationError, RoundId, }, - testing::{BinaryFormat, TestSessionParams}, + testing::{BinaryFormat, TestSessionParams, TestVerifier}, }; #[test] @@ -835,9 +836,9 @@ mod tests { assert!(impls!(Session: Sync)); // These objects are sent to/from message processing tasks - assert!(impls!(Message: Send)); + assert!(impls!(Message: Send)); assert!(impls!(ProcessedArtifact: Send)); - assert!(impls!(VerifiedMessage: Send)); + assert!(impls!(VerifiedMessage: Send)); assert!(impls!(ProcessedMessage: Send)); } } diff --git a/manul/src/testing/run_sync.rs b/manul/src/testing/run_sync.rs index 9cd08de..1ca8669 100644 --- a/manul/src/testing/run_sync.rs +++ b/manul/src/testing/run_sync.rs @@ -24,7 +24,7 @@ enum State { struct RoundMessage { from: SP::Verifier, to: SP::Verifier, - message: Message, + message: Message, } #[allow(clippy::type_complexity)]