Skip to content

Commit

Permalink
Make Message and VerifiedMessage generic over Verifier
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Oct 30, 2024
1 parent a0eb1b6 commit 0625012
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 31 deletions.
4 changes: 2 additions & 2 deletions examples/tests/async_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ use tracing_subscriber::{util::SubscriberInitExt, EnvFilter};
struct MessageOut<SP: SessionParameters> {
from: SP::Verifier,
to: SP::Verifier,
message: Message,
message: Message<SP::Verifier>,
}

struct MessageIn<SP: SessionParameters> {
from: SP::Verifier,
message: Message,
message: Message<SP::Verifier>,
}

/// Runs a session. Simulates what each participating party would run as the protocol progresses.
Expand Down
2 changes: 1 addition & 1 deletion manul/src/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ mod wire_format;

pub use crate::protocol::{LocalError, RemoteError};
pub use evidence::{Evidence, EvidenceError};
pub use message::{VerifiedMessage, Message};
pub use message::{Message, VerifiedMessage};
pub use session::{CanFinalize, RoundAccumulator, RoundOutcome, Session, SessionId, SessionParameters};
pub use transcript::{SessionOutcome, SessionReport};
pub use wire_format::WireFormat;
Expand Down
31 changes: 19 additions & 12 deletions manul/src/session/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,19 +177,24 @@ impl<M> VerifiedMessagePart<M> {

/// A signed message destined for another node.
#[derive(Clone, Debug)]
pub struct Message {
pub struct Message<Verifier> {
destination: Verifier,
direct_message: SignedMessagePart<DirectMessage>,
echo_broadcast: SignedMessagePart<EchoBroadcast>,
normal_broadcast: SignedMessagePart<NormalBroadcast>,
}

impl Message {
impl<Verifier> Message<Verifier>
where
Verifier: Clone,
{
#[allow(clippy::too_many_arguments)]
pub(crate) fn new<SP>(
rng: &mut impl CryptoRngCore,
signer: &SP::Signer,
session_id: &SessionId,
round_id: RoundId,
destination: &Verifier,
direct_message: DirectMessage,
echo_broadcast: SignedMessagePart<EchoBroadcast>,
normal_broadcast: SignedMessagePart<NormalBroadcast>,
Expand All @@ -199,12 +204,18 @@ impl Message {
{
let direct_message = SignedMessagePart::new::<SP>(rng, signer, session_id, round_id, direct_message)?;
Ok(Self {
destination: destination.clone(),
direct_message,
echo_broadcast,
normal_broadcast,
})
}

/// The verifier of the party this message is intended for.
pub fn destination(&self) -> &Verifier {
&self.destination
}

pub(crate) fn unify_metadata(self) -> Option<CheckedMessage> {
if self.echo_broadcast.metadata() != self.direct_message.metadata() {
return None;
Expand Down Expand Up @@ -241,7 +252,7 @@ impl CheckedMessage {
&self.metadata
}

pub fn verify<SP>(self, verifier: &SP::Verifier) -> Result<VerifiedMessage<SP>, MessageVerificationError>
pub fn verify<SP>(self, verifier: &SP::Verifier) -> Result<VerifiedMessage<SP::Verifier>, MessageVerificationError>
where
SP: SessionParameters,
{
Expand All @@ -264,25 +275,21 @@ impl CheckedMessage {
// signatures of message parts (direct, broadcast etc) from the original [`Message`] successfully verified.

/// A [`Message`] that had its metadata and signatures verified.
#[derive_where::derive_where(Debug)]
#[derive(Clone)]
pub struct VerifiedMessage<SP: SessionParameters> {
from: SP::Verifier,
#[derive(Debug, Clone)]
pub struct VerifiedMessage<Verifier> {
from: Verifier,
metadata: MessageMetadata,
direct_message: VerifiedMessagePart<DirectMessage>,
echo_broadcast: VerifiedMessagePart<EchoBroadcast>,
normal_broadcast: VerifiedMessagePart<NormalBroadcast>,
}

impl<SP> VerifiedMessage<SP>
where
SP: SessionParameters,
{
impl<Verifier> VerifiedMessage<Verifier> {
pub(crate) fn metadata(&self) -> &MessageMetadata {
&self.metadata
}

pub(crate) fn from(&self) -> &SP::Verifier {
pub(crate) fn from(&self) -> &Verifier {
&self.from
}

Expand Down
31 changes: 16 additions & 15 deletions manul/src/session/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ pub enum RoundOutcome<P: Protocol, SP: SessionParameters> {
/// The session object for the new round.
session: Session<P, SP>,
/// The messages intended for the new round cached during the previous round.
cached_messages: Vec<VerifiedMessage<SP>>,
cached_messages: Vec<VerifiedMessage<SP::Verifier>>,
},
}

Expand Down Expand Up @@ -231,7 +231,7 @@ where
&self,
rng: &mut impl CryptoRngCore,
destination: &SP::Verifier,
) -> Result<(Message, ProcessedArtifact<SP>), LocalError> {
) -> Result<(Message<SP::Verifier>, ProcessedArtifact<SP>), LocalError> {
let (direct_message, artifact) =
self.round
.make_direct_message_with_artifact(rng, &self.serializer, destination)?;
Expand All @@ -241,6 +241,7 @@ where
&self.signer,
&self.session_id,
self.round.id(),
destination,
direct_message,
self.echo_broadcast.clone(),
self.normal_broadcast.clone(),
Expand Down Expand Up @@ -285,8 +286,8 @@ where
&self,
accum: &mut RoundAccumulator<P, SP>,
from: &SP::Verifier,
message: Message,
) -> Result<Option<VerifiedMessage<SP>>, LocalError> {
message: Message<SP::Verifier>,
) -> Result<Option<VerifiedMessage<SP::Verifier>>, LocalError> {
// Quick preliminary checks, before we proceed with more expensive verification
let key = self.verifier();
if self.transcript.is_banned(from) || accum.is_banned(from) {
Expand Down Expand Up @@ -381,7 +382,7 @@ where
pub fn process_message(
&self,
rng: &mut impl CryptoRngCore,
message: VerifiedMessage<SP>,
message: VerifiedMessage<SP::Verifier>,
) -> ProcessedMessage<P, SP> {
let processed = self.round.receive_message(
rng,
Expand Down Expand Up @@ -544,7 +545,7 @@ pub struct RoundAccumulator<P: Protocol, SP: SessionParameters> {
processing: BTreeSet<SP::Verifier>,
payloads: BTreeMap<SP::Verifier, Payload>,
artifacts: BTreeMap<SP::Verifier, Artifact>,
cached: BTreeMap<SP::Verifier, BTreeMap<RoundId, VerifiedMessage<SP>>>,
cached: BTreeMap<SP::Verifier, BTreeMap<RoundId, VerifiedMessage<SP::Verifier>>>,
echo_broadcasts: BTreeMap<SP::Verifier, SignedMessagePart<EchoBroadcast>>,
normal_broadcasts: BTreeMap<SP::Verifier, SignedMessagePart<NormalBroadcast>>,
direct_messages: BTreeMap<SP::Verifier, SignedMessagePart<DirectMessage>>,
Expand Down Expand Up @@ -625,7 +626,7 @@ where
}
}

fn mark_processing(&mut self, message: &VerifiedMessage<SP>) -> Result<(), LocalError> {
fn mark_processing(&mut self, message: &VerifiedMessage<SP::Verifier>) -> Result<(), LocalError> {
if !self.processing.insert(message.from().clone()) {
Err(LocalError::new(format!(
"A message from {:?} is already marked as being processed",
Expand Down Expand Up @@ -732,7 +733,7 @@ where
}
}

fn cache_message(&mut self, message: VerifiedMessage<SP>) -> Result<(), LocalError> {
fn cache_message(&mut self, message: VerifiedMessage<SP::Verifier>) -> Result<(), LocalError> {
let from = message.from().clone();
let round_id = message.metadata().round_id();
let cached = self.cached.entry(from.clone()).or_default();
Expand All @@ -754,14 +755,14 @@ pub struct ProcessedArtifact<SP: SessionParameters> {

#[derive(Debug)]
pub struct ProcessedMessage<P: Protocol, SP: SessionParameters> {
message: VerifiedMessage<SP>,
message: VerifiedMessage<SP::Verifier>,
processed: Result<Payload, ReceiveError<SP::Verifier, P>>,
}

fn filter_messages<SP: SessionParameters>(
messages: BTreeMap<SP::Verifier, BTreeMap<RoundId, VerifiedMessage<SP>>>,
fn filter_messages<Verifier>(
messages: BTreeMap<Verifier, BTreeMap<RoundId, VerifiedMessage<Verifier>>>,
round_id: RoundId,
) -> Vec<VerifiedMessage<SP>> {
) -> Vec<VerifiedMessage<Verifier>> {
messages
.into_values()
.filter_map(|mut messages| messages.remove(&round_id))
Expand All @@ -781,7 +782,7 @@ mod tests {
Deserializer, DirectMessage, EchoBroadcast, NormalBroadcast, Protocol, ProtocolError,
ProtocolValidationError, RoundId,
},
testing::{BinaryFormat, TestSessionParams},
testing::{BinaryFormat, TestSessionParams, TestVerifier},
};

#[test]
Expand Down Expand Up @@ -835,9 +836,9 @@ mod tests {
assert!(impls!(Session<DummyProtocol, SP>: Sync));

// These objects are sent to/from message processing tasks
assert!(impls!(Message: Send));
assert!(impls!(Message<TestVerifier>: Send));
assert!(impls!(ProcessedArtifact<SP>: Send));
assert!(impls!(VerifiedMessage<SP>: Send));
assert!(impls!(VerifiedMessage<TestVerifier>: Send));
assert!(impls!(ProcessedMessage<DummyProtocol, SP>: Send));
}
}
2 changes: 1 addition & 1 deletion manul/src/testing/run_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ enum State<P: Protocol, SP: SessionParameters> {
struct RoundMessage<SP: SessionParameters> {
from: SP::Verifier,
to: SP::Verifier,
message: Message,
message: Message<SP::Verifier>,
}

#[allow(clippy::type_complexity)]
Expand Down

0 comments on commit 0625012

Please sign in to comment.