From aed78d08495e0ec274a10cb0b651e2b11cc6b113 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Fri, 1 Nov 2024 14:25:43 -0700 Subject: [PATCH 1/9] Remove the make_direct_message() shortcut --- CHANGELOG.md | 1 - examples/src/simple.rs | 10 ++++++---- examples/src/simple_malicious.rs | 10 +++++----- manul/benches/empty_rounds.rs | 2 +- manul/src/protocol/object_safe.rs | 7 +++---- manul/src/protocol/round.rs | 25 +++---------------------- manul/src/session/session.rs | 4 +--- manul/src/testing/macros.rs | 24 ++---------------------- 8 files changed, 21 insertions(+), 62 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c203760..e74c543 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `SessionId::new()` renamed to `from_seed()`. ([#41]) - `FirstRound::new()` takes a `&[u8]` instead of a `SessionId` object. ([#41]) - The signatures of `Round::make_echo_broadcast()`, `Round::make_direct_message()`, and `Round::receive_message()`, take messages without `Option`s. ([#46]) -- `Round::make_direct_message_with_artifact()` is the method returning an artifact now; `Round::make_direct_message()` is a shortcut for cases where no artifact is returned. ([#46]) - `Artifact::empty()` removed, the user should return `None` instead. ([#46]) - `EchoBroadcast` and `DirectMessage` now use `ProtocolMessagePart` trait for their methods. ([#47]) - Added normal broadcasts support in addition to echo ones; signatures of `Round` methods changed accordingly; added `Round::make_normal_broadcast()`. ([#47]) diff --git a/examples/src/simple.rs b/examples/src/simple.rs index b947fac..bbf4756 100644 --- a/examples/src/simple.rs +++ b/examples/src/simple.rs @@ -228,14 +228,15 @@ impl Round for Round1 { _rng: &mut impl CryptoRngCore, serializer: &Serializer, destination: &Id, - ) -> Result { + ) -> Result<(DirectMessage, Option), LocalError> { debug!("{:?}: making direct message for {:?}", self.context.id, destination); let message = Round1Message { my_position: self.context.ids_to_positions[&self.context.id], your_position: self.context.ids_to_positions[destination], }; - DirectMessage::new(serializer, message) + let dm = DirectMessage::new(serializer, message)?; + Ok((dm, None)) } fn receive_message( @@ -325,14 +326,15 @@ impl Round for Round2 { _rng: &mut impl CryptoRngCore, serializer: &Serializer, destination: &Id, - ) -> Result { + ) -> Result<(DirectMessage, Option), LocalError> { debug!("{:?}: making direct message for {:?}", self.context.id, destination); let message = Round2Message { my_position: self.context.ids_to_positions[&self.context.id], your_position: self.context.ids_to_positions[destination], }; - DirectMessage::new(serializer, message) + let dm = DirectMessage::new(serializer, message)?; + Ok((dm, None)) } fn receive_message( diff --git a/examples/src/simple_malicious.rs b/examples/src/simple_malicious.rs index 1ed749d..950c9f2 100644 --- a/examples/src/simple_malicious.rs +++ b/examples/src/simple_malicious.rs @@ -68,15 +68,15 @@ impl RoundOverride for MaliciousRound1 { rng: &mut impl CryptoRngCore, serializer: &Serializer, destination: &Id, - ) -> Result { + ) -> Result<(DirectMessage, Option), LocalError> { if matches!(self.behavior, Behavior::SerializedGarbage) { - DirectMessage::new(serializer, [99u8]) + Ok((DirectMessage::new(serializer, [99u8])?, None)) } else if matches!(self.behavior, Behavior::AttributableFailure) { let message = Round1Message { my_position: self.round.context.ids_to_positions[&self.round.context.id], your_position: self.round.context.ids_to_positions[&self.round.context.id], }; - DirectMessage::new(serializer, message) + Ok((DirectMessage::new(serializer, message)?, None)) } else { self.inner_round_ref().make_direct_message(rng, serializer, destination) } @@ -131,13 +131,13 @@ impl RoundOverride for MaliciousRound2 { rng: &mut impl CryptoRngCore, serializer: &Serializer, destination: &Id, - ) -> Result { + ) -> Result<(DirectMessage, Option), LocalError> { if matches!(self.behavior, Behavior::AttributableFailureRound2) { let message = Round2Message { my_position: self.round.context.ids_to_positions[&self.round.context.id], your_position: self.round.context.ids_to_positions[&self.round.context.id], }; - DirectMessage::new(serializer, message) + Ok((DirectMessage::new(serializer, message)?, None)) } else { self.inner_round_ref().make_direct_message(rng, serializer, destination) } diff --git a/manul/benches/empty_rounds.rs b/manul/benches/empty_rounds.rs index 6ce657b..4420243 100644 --- a/manul/benches/empty_rounds.rs +++ b/manul/benches/empty_rounds.rs @@ -119,7 +119,7 @@ impl Round for EmptyRound { } } - fn make_direct_message_with_artifact( + fn make_direct_message( &self, _rng: &mut impl CryptoRngCore, serializer: &Serializer, diff --git a/manul/src/protocol/object_safe.rs b/manul/src/protocol/object_safe.rs index fd65f95..6d282be 100644 --- a/manul/src/protocol/object_safe.rs +++ b/manul/src/protocol/object_safe.rs @@ -48,7 +48,7 @@ pub(crate) trait ObjectSafeRound: 'static + Debug + Send + Sync { fn message_destinations(&self) -> &BTreeSet; - fn make_direct_message_with_artifact( + fn make_direct_message( &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, @@ -130,15 +130,14 @@ where self.round.message_destinations() } - fn make_direct_message_with_artifact( + fn make_direct_message( &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, destination: &Id, ) -> Result<(DirectMessage, Option), LocalError> { let mut boxed_rng = BoxedRng(rng); - self.round - .make_direct_message_with_artifact(&mut boxed_rng, serializer, destination) + self.round.make_direct_message(&mut boxed_rng, serializer, destination) } fn make_echo_broadcast( diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index e2f113b..56c1aa1 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -358,35 +358,16 @@ pub trait Round: 'static + Debug + Send + Sync { /// /// Return [`DirectMessage::none`] if this round does not send direct messages. /// - /// Falls back to [`make_direct_message`](`Self::make_direct_message`) if not implemented. - /// This is the method that will be called by the upper layer when creating direct messages. - /// /// In some protocols, when a message to another node is created, there is some associated information /// that needs to be retained for later (randomness, proofs of knowledge, and so on). /// These should be put in an [`Artifact`] and will be available at the time of [`finalize`](`Self::finalize`). - fn make_direct_message_with_artifact( - &self, - rng: &mut impl CryptoRngCore, - serializer: &Serializer, - destination: &Id, - ) -> Result<(DirectMessage, Option), LocalError> { - Ok((self.make_direct_message(rng, serializer, destination)?, None)) - } - - /// Returns the direct message to the given destination. - /// - /// This method will not be called by the upper layer directly, - /// only via [`make_direct_message_with_artifact`](`Self::make_direct_message_with_artifact`). - /// - /// Return [`DirectMessage::none`] if this round does not send direct messages. - /// This is also the blanket implementation. fn make_direct_message( &self, #[allow(unused_variables)] rng: &mut impl CryptoRngCore, #[allow(unused_variables)] serializer: &Serializer, #[allow(unused_variables)] destination: &Id, - ) -> Result { - Ok(DirectMessage::none()) + ) -> Result<(DirectMessage, Option), LocalError> { + Ok((DirectMessage::none(), None)) } /// Returns the echo broadcast for this round. @@ -438,7 +419,7 @@ pub trait Round: 'static + Debug + Send + Sync { /// /// `payloads` here are the ones previously generated by [`receive_message`](`Self::receive_message`), /// and `artifacts` are the ones previously generated by - /// [`make_direct_message_with_artifact`](`Self::make_direct_message_with_artifact`). + /// [`make_direct_message`](`Self::make_direct_message`). fn finalize( self, rng: &mut impl CryptoRngCore, diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index 5c1845e..6daec53 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -226,9 +226,7 @@ where rng: &mut impl CryptoRngCore, destination: &SP::Verifier, ) -> Result<(Message, ProcessedArtifact), LocalError> { - let (direct_message, artifact) = - self.round - .make_direct_message_with_artifact(rng, &self.serializer, destination)?; + let (direct_message, artifact) = self.round.make_direct_message(rng, &self.serializer, destination)?; let message = Message::new::( rng, diff --git a/manul/src/testing/macros.rs b/manul/src/testing/macros.rs index d9c5d70..ac8fb57 100644 --- a/manul/src/testing/macros.rs +++ b/manul/src/testing/macros.rs @@ -25,24 +25,13 @@ pub trait RoundWrapper: 'static + Sized + Send + Sync { /// /// The blanket implementations delegate to the methods of the wrapped round. pub trait RoundOverride: RoundWrapper { - /// An override for [`Round::make_direct_message_with_artifact`]. - fn make_direct_message_with_artifact( - &self, - rng: &mut impl CryptoRngCore, - serializer: &Serializer, - destination: &Id, - ) -> Result<(DirectMessage, Option), LocalError> { - let dm = self.make_direct_message(rng, serializer, destination)?; - Ok((dm, None)) - } - /// An override for [`Round::make_direct_message`]. fn make_direct_message( &self, rng: &mut impl CryptoRngCore, serializer: &Serializer, destination: &Id, - ) -> Result { + ) -> Result<(DirectMessage, Option), LocalError> { self.inner_round_ref().make_direct_message(rng, serializer, destination) } @@ -113,17 +102,8 @@ macro_rules! round_override { rng: &mut impl CryptoRngCore, serializer: &$crate::protocol::Serializer, destination: &Id, - ) -> Result<$crate::protocol::DirectMessage, $crate::protocol::LocalError> { - >::make_direct_message(self, rng, serializer, destination) - } - - fn make_direct_message_with_artifact( - &self, - rng: &mut impl CryptoRngCore, - serializer: &$crate::protocol::Serializer, - destination: &Id, ) -> Result<($crate::protocol::DirectMessage, Option<$crate::protocol::Artifact>), $crate::protocol::LocalError> { - >::make_direct_message_with_artifact(self, rng, serializer, destination) + >::make_direct_message(self, rng, serializer, destination) } fn make_echo_broadcast( From a9e6d4d0b05cdf25cbcdd0d9676afed865aab685 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Sun, 3 Nov 2024 13:18:40 -0800 Subject: [PATCH 2/9] Rename `FirstRound` to `EntryPoint` --- CHANGELOG.md | 2 ++ examples/src/simple.rs | 2 +- examples/src/simple_malicious.rs | 4 ++-- manul/benches/empty_rounds.rs | 4 ++-- manul/src/protocol.rs | 4 ++-- manul/src/protocol/round.rs | 2 +- manul/src/session/session.rs | 4 ++-- manul/src/testing/run_sync.rs | 4 ++-- 8 files changed, 14 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e74c543..27b76ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Renamed `(Verified)MessageBundle` to `(Verified)Message`. Both are now generic over `Verifier`. ([#56]) - `Session::preprocess_message()` now returns a `PreprocessOutcome` instead of just an `Option`. ([#57]) - `Session::terminate_due_to_errors()` replaces `terminate()`; `terminate()` now signals user interrupt. ([#58]) +- Renamed `FirstRound` trait to `EntryPoint`. ([#60]) ### Added @@ -44,6 +45,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#57]: https://github.com/entropyxyz/manul/pull/57 [#58]: https://github.com/entropyxyz/manul/pull/58 [#59]: https://github.com/entropyxyz/manul/pull/59 +[#60]: https://github.com/entropyxyz/manul/pull/60 ## [0.0.1] - 2024-10-12 diff --git a/examples/src/simple.rs b/examples/src/simple.rs index bbf4756..52fbe3f 100644 --- a/examples/src/simple.rs +++ b/examples/src/simple.rs @@ -149,7 +149,7 @@ struct Round1Payload { x: u8, } -impl FirstRound for Round1 { +impl EntryPoint for Round1 { type Inputs = Inputs; fn new( _rng: &mut impl CryptoRngCore, diff --git a/examples/src/simple_malicious.rs b/examples/src/simple_malicious.rs index 950c9f2..92d163e 100644 --- a/examples/src/simple_malicious.rs +++ b/examples/src/simple_malicious.rs @@ -3,7 +3,7 @@ use core::fmt::Debug; use manul::{ protocol::{ - Artifact, DirectMessage, FinalizeError, FinalizeOutcome, FirstRound, LocalError, PartyId, Payload, + Artifact, DirectMessage, EntryPoint, FinalizeError, FinalizeOutcome, LocalError, PartyId, Payload, ProtocolMessagePart, Round, Serializer, }, session::signature::Keypair, @@ -46,7 +46,7 @@ impl RoundWrapper for MaliciousRound1 { } } -impl FirstRound for MaliciousRound1 { +impl EntryPoint for MaliciousRound1 { type Inputs = MaliciousInputs; fn new( rng: &mut impl CryptoRngCore, diff --git a/manul/benches/empty_rounds.rs b/manul/benches/empty_rounds.rs index 4420243..211f0f9 100644 --- a/manul/benches/empty_rounds.rs +++ b/manul/benches/empty_rounds.rs @@ -9,7 +9,7 @@ use core::fmt::Debug; use criterion::{criterion_group, criterion_main, Criterion}; use manul::{ protocol::{ - Artifact, Deserializer, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, FirstRound, LocalError, + Artifact, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeError, FinalizeOutcome, LocalError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessagePart, ProtocolValidationError, ReceiveError, Round, RoundId, Serializer, }, @@ -73,7 +73,7 @@ struct Round1Payload; struct Round1Artifact; -impl FirstRound for EmptyRound { +impl EntryPoint for EmptyRound { type Inputs = Inputs; fn new( _rng: &mut impl CryptoRngCore, diff --git a/manul/src/protocol.rs b/manul/src/protocol.rs index 0e53744..5912a62 100644 --- a/manul/src/protocol.rs +++ b/manul/src/protocol.rs @@ -4,7 +4,7 @@ API for protocol implementors. A protocol is a directed acyclic graph with the nodes being objects of types implementing [`Round`] (to be specific, "acyclic" means that the values returned by [`Round::id`] should not repeat during the protocol execution; the types might). -The starting point is a type that implements [`FirstRound`]. +The starting point is a type that implements [`EntryPoint`]. All the rounds must have their associated type [`Round::Protocol`] set to the same [`Protocol`] instance to be executed by a [`Session`](`crate::session::Session`). @@ -23,7 +23,7 @@ pub use errors::{ }; pub use message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessagePart}; pub use round::{ - AnotherRound, Artifact, FinalizeOutcome, FirstRound, PartyId, Payload, Protocol, ProtocolError, Round, RoundId, + AnotherRound, Artifact, EntryPoint, FinalizeOutcome, PartyId, Payload, Protocol, ProtocolError, Round, RoundId, }; pub use serialization::{Deserializer, Serializer}; diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index 56c1aa1..777b521 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -301,7 +301,7 @@ impl Artifact { /// /// This is a round that can be created directly; /// all the others are only reachable throud [`Round::finalize`] by the execution layer. -pub trait FirstRound: Round + Sized { +pub trait EntryPoint: Round + Sized { /// Additional inputs for the protocol (besides the mandatory ones in [`new`](`Self::new`)). type Inputs; diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index 6daec53..6750adc 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -23,7 +23,7 @@ use super::{ LocalError, RemoteError, }; use crate::protocol::{ - Artifact, Deserializer, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, FirstRound, NormalBroadcast, + Artifact, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeError, FinalizeOutcome, NormalBroadcast, ObjectSafeRound, ObjectSafeRoundWrapper, PartyId, Payload, Protocol, ProtocolMessagePart, ReceiveError, ReceiveErrorType, Round, RoundId, Serializer, }; @@ -141,7 +141,7 @@ where inputs: R::Inputs, ) -> Result where - R: FirstRound + Round, + R: EntryPoint + Round, { let verifier = signer.verifying_key(); let first_round = Box::new(ObjectSafeRoundWrapper::new(R::new( diff --git a/manul/src/testing/run_sync.rs b/manul/src/testing/run_sync.rs index 83d24bc..9e06ac1 100644 --- a/manul/src/testing/run_sync.rs +++ b/manul/src/testing/run_sync.rs @@ -6,7 +6,7 @@ use signature::Keypair; use tracing::debug; use crate::{ - protocol::{FirstRound, Protocol}, + protocol::{EntryPoint, Protocol}, session::{ CanFinalize, LocalError, Message, RoundAccumulator, RoundOutcome, Session, SessionId, SessionParameters, SessionReport, @@ -94,7 +94,7 @@ pub fn run_sync( inputs: Vec<(SP::Signer, R::Inputs)>, ) -> Result>, LocalError> where - R: FirstRound, + R: EntryPoint, SP: SessionParameters, { let session_id = SessionId::random::(rng); From 24fb8ad962fd7365a90cba2902c40bcb5f9ed1af Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Mon, 4 Nov 2024 15:23:59 -0800 Subject: [PATCH 3/9] Add BoxedRound type to wrap Box --- CHANGELOG.md | 2 + examples/src/simple.rs | 13 +++--- examples/src/simple_malicious.rs | 21 ++++----- manul/benches/empty_rounds.rs | 19 ++++---- manul/src/protocol.rs | 6 +-- manul/src/protocol/object_safe.rs | 74 ++++++++++++++++++++++++------- manul/src/protocol/round.rs | 56 ++++------------------- manul/src/session/echo.rs | 19 ++++---- manul/src/session/session.rs | 67 ++++++++++++++-------------- manul/src/testing/run_sync.rs | 2 +- 10 files changed, 141 insertions(+), 138 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 27b76ff..ac0e5ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Session::preprocess_message()` now returns a `PreprocessOutcome` instead of just an `Option`. ([#57]) - `Session::terminate_due_to_errors()` replaces `terminate()`; `terminate()` now signals user interrupt. ([#58]) - Renamed `FirstRound` trait to `EntryPoint`. ([#60]) +- Added `Protocol` type to `EntryPoint`. ([#60]) +- `EntryPoint` and `FinalizeOutcome::AnotherRound` now use a new `BoxedRound` wrapper type. ([#60]) ### Added diff --git a/examples/src/simple.rs b/examples/src/simple.rs index 52fbe3f..eeab9ab 100644 --- a/examples/src/simple.rs +++ b/examples/src/simple.rs @@ -151,12 +151,13 @@ struct Round1Payload { impl EntryPoint for Round1 { type Inputs = Inputs; + type Protocol = SimpleProtocol; fn new( _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], id: Id, inputs: Self::Inputs, - ) -> Result { + ) -> Result, LocalError> { // Just some numbers associated with IDs to use in the dummy protocol. // They will be the same on each node since IDs are ordered. let ids_to_positions = inputs @@ -169,13 +170,13 @@ impl EntryPoint for Round1 { let mut ids = inputs.all_ids; ids.remove(&id); - Ok(Self { + Ok(BoxedRound::new_dynamic(Self { context: Context { id, other_ids: ids, ids_to_positions, }, - }) + })) } } @@ -282,11 +283,11 @@ impl Round for Round1 { let sum = self.context.ids_to_positions[&self.context.id] + typed_payloads.iter().map(|payload| payload.x).sum::(); - let round2 = Round2 { + let round2 = BoxedRound::new_dynamic(Round2 { round1_sum: sum, context: self.context, - }; - Ok(FinalizeOutcome::another_round(round2)) + }); + Ok(FinalizeOutcome::AnotherRound(round2)) } fn expecting_messages_from(&self) -> &BTreeSet { diff --git a/examples/src/simple_malicious.rs b/examples/src/simple_malicious.rs index 92d163e..90a4aa9 100644 --- a/examples/src/simple_malicious.rs +++ b/examples/src/simple_malicious.rs @@ -3,7 +3,7 @@ use core::fmt::Debug; use manul::{ protocol::{ - Artifact, DirectMessage, EntryPoint, FinalizeError, FinalizeOutcome, LocalError, PartyId, Payload, + Artifact, BoxedRound, DirectMessage, EntryPoint, FinalizeError, FinalizeOutcome, LocalError, PartyId, Payload, ProtocolMessagePart, Round, Serializer, }, session::signature::Keypair, @@ -15,7 +15,7 @@ use manul::{ use rand_core::{CryptoRngCore, OsRng}; use tracing_subscriber::EnvFilter; -use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message}; +use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message, SimpleProtocol}; #[derive(Debug, Clone, Copy)] enum Behavior { @@ -48,17 +48,18 @@ impl RoundWrapper for MaliciousRound1 { impl EntryPoint for MaliciousRound1 { type Inputs = MaliciousInputs; + type Protocol = SimpleProtocol; fn new( rng: &mut impl CryptoRngCore, shared_randomness: &[u8], id: Id, inputs: Self::Inputs, - ) -> Result { - let round = Round1::new(rng, shared_randomness, id, inputs.inputs)?; - Ok(Self { + ) -> Result, LocalError> { + let round = Round1::new(rng, shared_randomness, id, inputs.inputs)?.downcast::>()?; + Ok(BoxedRound::new_dynamic(Self { round, behavior: inputs.behavior, - }) + })) } } @@ -96,12 +97,12 @@ impl RoundOverride for MaliciousRound1 { Ok(match outcome { FinalizeOutcome::Result(res) => FinalizeOutcome::Result(res), - FinalizeOutcome::AnotherRound(another_round) => { - let round2 = another_round.downcast::>().map_err(FinalizeError::Local)?; - FinalizeOutcome::another_round(MaliciousRound2 { + FinalizeOutcome::AnotherRound(boxed_round) => { + let round2 = boxed_round.downcast::>().map_err(FinalizeError::Local)?; + FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(MaliciousRound2 { round: round2, behavior, - }) + })) } }) } diff --git a/manul/benches/empty_rounds.rs b/manul/benches/empty_rounds.rs index 211f0f9..25cbf5f 100644 --- a/manul/benches/empty_rounds.rs +++ b/manul/benches/empty_rounds.rs @@ -9,9 +9,9 @@ use core::fmt::Debug; use criterion::{criterion_group, criterion_main, Criterion}; use manul::{ protocol::{ - Artifact, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeError, FinalizeOutcome, LocalError, - NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessagePart, ProtocolValidationError, - ReceiveError, Round, RoundId, Serializer, + Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeError, FinalizeOutcome, + LocalError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessagePart, + ProtocolValidationError, ReceiveError, Round, RoundId, Serializer, }, session::{signature::Keypair, SessionOutcome}, testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, @@ -75,16 +75,17 @@ struct Round1Artifact; impl EntryPoint for EmptyRound { type Inputs = Inputs; + type Protocol = EmptyProtocol; fn new( _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], _id: Id, inputs: Self::Inputs, - ) -> Result { - Ok(Self { + ) -> Result, LocalError> { + Ok(BoxedRound::new_dynamic(Self { round_counter: 1, inputs, - }) + })) } } @@ -165,11 +166,11 @@ impl Round for EmptyRound { if self.round_counter == self.inputs.rounds_num { Ok(FinalizeOutcome::Result(())) } else { - let round = EmptyRound { + let round = BoxedRound::new_dynamic(EmptyRound { round_counter: self.round_counter + 1, inputs: self.inputs, - }; - Ok(FinalizeOutcome::another_round(round)) + }); + Ok(FinalizeOutcome::AnotherRound(round)) } } diff --git a/manul/src/protocol.rs b/manul/src/protocol.rs index 5912a62..5c19d70 100644 --- a/manul/src/protocol.rs +++ b/manul/src/protocol.rs @@ -22,12 +22,10 @@ pub use errors::{ NormalBroadcastError, ProtocolValidationError, ReceiveError, RemoteError, }; pub use message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessagePart}; -pub use round::{ - AnotherRound, Artifact, EntryPoint, FinalizeOutcome, PartyId, Payload, Protocol, ProtocolError, Round, RoundId, -}; +pub use object_safe::BoxedRound; +pub use round::{Artifact, EntryPoint, FinalizeOutcome, PartyId, Payload, Protocol, ProtocolError, Round, RoundId}; pub use serialization::{Deserializer, Serializer}; pub(crate) use errors::ReceiveErrorType; -pub(crate) use object_safe::{ObjectSafeRound, ObjectSafeRoundWrapper}; pub use digest; diff --git a/manul/src/protocol/object_safe.rs b/manul/src/protocol/object_safe.rs index 6d282be..10de974 100644 --- a/manul/src/protocol/object_safe.rs +++ b/manul/src/protocol/object_safe.rs @@ -197,23 +197,44 @@ where } } -// When we are wrapping types implementing Round and overriding `finalize()`, -// we need to unbox the result of `finalize()`, set it as an attribute of the wrapping round, -// and then box the result. -// -// Because of Rust's peculiarities, Box that we return in `finalize()` -// cannot be unboxed into an object of a concrete type with `downcast()`, -// so we have to provide this workaround. -impl dyn ObjectSafeRound -where - Id: PartyId, - P: Protocol, -{ - pub fn try_downcast>(self: Box) -> Result> { - if core::any::TypeId::of::>() == self.get_type_id() { - // This should be safe since we just checked that we are casting to a correct type. +// We do not want to expose `ObjectSafeRound` to the user, so it is hidden in a struct. +/// A wrapped new round that may be returned by [`Round::finalize`] +/// or [`EntryPoint::new`](`crate::protocol::EntryPoint::new`). +#[derive_where::derive_where(Debug)] +pub struct BoxedRound { + wrapped: bool, + round: Box>, +} + +impl BoxedRound { + /// Wraps an object implementing the dynamic round trait ([`Round`](`crate::protocol::Round`)). + pub fn new_dynamic>(round: R) -> Self { + Self { + wrapped: true, + round: Box::new(ObjectSafeRoundWrapper::new(round)), + } + } + + pub(crate) fn as_ref(&self) -> &dyn ObjectSafeRound { + self.round.as_ref() + } + + pub(crate) fn into_boxed(self) -> Box> { + self.round + } + + fn boxed_type_is(&self) -> bool { + core::any::TypeId::of::() == self.round.get_type_id() + } + + /// Attempts to extract an object of a concrete type, preserving the original on failure. + pub fn try_downcast>(self) -> Result { + if self.wrapped && self.boxed_type_is::>() { + // Safety: This is safe since we just checked that we are casting to the correct type. let boxed_downcast = unsafe { - Box::>::from_raw(Box::into_raw(self) as *mut ObjectSafeRoundWrapper) + Box::>::from_raw( + Box::into_raw(self.round) as *mut ObjectSafeRoundWrapper + ) }; Ok(boxed_downcast.round) } else { @@ -221,8 +242,27 @@ where } } - pub fn downcast>(self: Box) -> Result { + /// Attempts to extract an object of a concrete type. + /// + /// Fails if the wrapped type is not `T`. + pub fn downcast>(self) -> Result { self.try_downcast() .map_err(|_| LocalError::new(format!("Failed to downcast into type {}", core::any::type_name::()))) } + + /// Attempts to provide a reference to an object of a concrete type. + /// + /// Fails if the wrapped type is not `T`. + pub fn downcast_ref>(&self) -> Result<&T, LocalError> { + if self.wrapped && self.boxed_type_is::>() { + let ptr: *const dyn ObjectSafeRound = self.round.as_ref(); + // Safety: This is safe since we just checked that we are casting to the correct type. + Ok(unsafe { &*(ptr as *const T) }) + } else { + Err(LocalError::new(format!( + "Failed to downcast into type {}", + core::any::type_name::() + ))) + } + } } diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index 777b521..ec58c1d 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -13,62 +13,19 @@ use serde::{Deserialize, Serialize}; use super::{ errors::{FinalizeError, LocalError, MessageValidationError, ProtocolValidationError, ReceiveError}, message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessagePart}, - object_safe::{ObjectSafeRound, ObjectSafeRoundWrapper}, + object_safe::BoxedRound, serialization::{Deserializer, Serializer}, }; /// Possible successful outcomes of [`Round::finalize`]. #[derive(Debug)] -pub enum FinalizeOutcome { +pub enum FinalizeOutcome { /// Transition to a new round. - AnotherRound(AnotherRound), + AnotherRound(BoxedRound), /// The protocol reached a result. Result(P::Result), } -impl FinalizeOutcome -where - Id: PartyId, - P: Protocol, -{ - /// A helper method to create an [`AnotherRound`](`Self::AnotherRound`) variant. - pub fn another_round(round: impl Round) -> Self { - Self::AnotherRound(AnotherRound::new(round)) - } -} - -// We do not want to expose `ObjectSafeRound` to the user, so it is hidden in a struct. -/// A wrapped new round that may be returned by [`Round::finalize`]. -#[derive(Debug)] -pub struct AnotherRound(Box>); - -impl AnotherRound -where - Id: PartyId, - P: Protocol, -{ - /// Wraps an object implementing [`Round`]. - pub fn new(round: impl Round) -> Self { - Self(Box::new(ObjectSafeRoundWrapper::new(round))) - } - - /// Returns the inner boxed type. - /// This is an internal method to be used in `Session`. - pub(crate) fn into_boxed(self) -> Box> { - self.0 - } - - /// Attempts to extract an object of a concrete type. - pub fn downcast>(self) -> Result { - self.0.downcast::() - } - - /// Attempts to extract an object of a concrete type, preserving the original on failure. - pub fn try_downcast>(self) -> Result { - self.0.try_downcast::().map_err(Self) - } -} - /// A round identifier. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct RoundId { @@ -301,10 +258,13 @@ impl Artifact { /// /// This is a round that can be created directly; /// all the others are only reachable throud [`Round::finalize`] by the execution layer. -pub trait EntryPoint: Round + Sized { +pub trait EntryPoint { /// Additional inputs for the protocol (besides the mandatory ones in [`new`](`Self::new`)). type Inputs; + /// The protocol implemented by the round this entry points returns. + type Protocol: Protocol; + /// Creates the round. /// /// `session_id` can be assumed to be the same for each node participating in a session. @@ -314,7 +274,7 @@ pub trait EntryPoint: Round + Sized { shared_randomness: &[u8], id: Id, inputs: Self::Inputs, - ) -> Result; + ) -> Result, LocalError>; } /// A trait alias for the combination of traits needed for a party identifier. diff --git a/manul/src/session/echo.rs b/manul/src/session/echo.rs index 5bbd124..936df48 100644 --- a/manul/src/session/echo.rs +++ b/manul/src/session/echo.rs @@ -1,5 +1,4 @@ use alloc::{ - boxed::Box, collections::{BTreeMap, BTreeSet}, format, string::String, @@ -18,8 +17,8 @@ use super::{ }; use crate::{ protocol::{ - Artifact, Deserializer, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, MessageValidationError, - NormalBroadcast, ObjectSafeRound, Payload, Protocol, ProtocolMessagePart, ReceiveError, Round, RoundId, + Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, + MessageValidationError, NormalBroadcast, Payload, Protocol, ProtocolMessagePart, ReceiveError, Round, RoundId, Serializer, }, utils::SerializableMap, @@ -75,12 +74,12 @@ pub(crate) struct EchoRoundMessage { /// participants. The execution layer of the protocol guarantees that all participants have received /// the messages. #[derive_where::derive_where(Debug)] -pub struct EchoRound { +pub struct EchoRound { verifier: SP::Verifier, echo_broadcasts: BTreeMap>, destinations: BTreeSet, expected_echos: BTreeSet, - main_round: Box>, + main_round: BoxedRound, payloads: BTreeMap, artifacts: BTreeMap, } @@ -94,7 +93,7 @@ where verifier: SP::Verifier, my_echo_broadcast: SignedMessagePart, echo_broadcasts: BTreeMap>, - main_round: Box>, + main_round: BoxedRound, payloads: BTreeMap, artifacts: BTreeMap, ) -> Self { @@ -147,11 +146,11 @@ where type Protocol = P; fn id(&self) -> RoundId { - self.main_round.id().echo() + self.main_round.as_ref().id().echo() } fn possible_next_rounds(&self) -> BTreeSet { - self.main_round.possible_next_rounds() + self.main_round.as_ref().possible_next_rounds() } fn message_destinations(&self) -> &BTreeSet { @@ -297,6 +296,8 @@ where _payloads: BTreeMap, _artifacts: BTreeMap, ) -> Result, FinalizeError> { - self.main_round.finalize(rng, self.payloads, self.artifacts) + self.main_round + .into_boxed() + .finalize(rng, self.payloads, self.artifacts) } } diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index 6750adc..32d926b 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -23,9 +23,9 @@ use super::{ LocalError, RemoteError, }; use crate::protocol::{ - Artifact, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeError, FinalizeOutcome, NormalBroadcast, - ObjectSafeRound, ObjectSafeRoundWrapper, PartyId, Payload, Protocol, ProtocolMessagePart, ReceiveError, - ReceiveErrorType, Round, RoundId, Serializer, + Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeError, FinalizeOutcome, + NormalBroadcast, PartyId, Payload, Protocol, ProtocolMessagePart, ReceiveError, ReceiveErrorType, RoundId, + Serializer, }; /// A set of types needed to execute a session. @@ -106,7 +106,7 @@ pub struct Session { verifier: SP::Verifier, serializer: Serializer, deserializer: Deserializer, - round: Box>, + round: BoxedRound, message_destinations: BTreeSet, echo_broadcast: SignedMessagePart, normal_broadcast: SignedMessagePart, @@ -141,15 +141,10 @@ where inputs: R::Inputs, ) -> Result where - R: EntryPoint + Round, + R: EntryPoint, { let verifier = signer.verifying_key(); - let first_round = Box::new(ObjectSafeRoundWrapper::new(R::new( - rng, - session_id.as_ref(), - verifier.clone(), - inputs, - )?)); + let first_round = R::new(rng, session_id.as_ref(), verifier.clone(), inputs)?; let serializer = Serializer::new::(); let deserializer = Deserializer::new::(); Self::new_for_next_round( @@ -169,23 +164,23 @@ where signer: SP::Signer, serializer: Serializer, deserializer: Deserializer, - round: Box>, + round: BoxedRound, transcript: Transcript, ) -> Result { let verifier = signer.verifying_key(); - let echo = round.make_echo_broadcast(rng, &serializer)?; - let echo_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, round.id(), echo)?; + let echo = round.as_ref().make_echo_broadcast(rng, &serializer)?; + let echo_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, round.as_ref().id(), echo)?; - let normal = round.make_normal_broadcast(rng, &serializer)?; - let normal_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, round.id(), normal)?; + let normal = round.as_ref().make_normal_broadcast(rng, &serializer)?; + let normal_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, round.as_ref().id(), normal)?; - let message_destinations = round.message_destinations().clone(); + let message_destinations = round.as_ref().message_destinations().clone(); let possible_next_rounds = if echo_broadcast.payload().is_none() { - round.possible_next_rounds() + round.as_ref().possible_next_rounds() } else { - BTreeSet::from([round.id().echo()]) + BTreeSet::from([round.as_ref().id().echo()]) }; Ok(Self { @@ -226,13 +221,16 @@ where rng: &mut impl CryptoRngCore, destination: &SP::Verifier, ) -> Result<(Message, ProcessedArtifact), LocalError> { - let (direct_message, artifact) = self.round.make_direct_message(rng, &self.serializer, destination)?; + let (direct_message, artifact) = self + .round + .as_ref() + .make_direct_message(rng, &self.serializer, destination)?; let message = Message::new::( rng, &self.signer, &self.session_id, - self.round.id(), + self.round.as_ref().id(), destination, direct_message, self.echo_broadcast.clone(), @@ -258,7 +256,7 @@ where /// Returns the ID of the current round. pub fn round_id(&self) -> RoundId { - self.round.id() + self.round.as_ref().id() } /// Performs some preliminary checks on the message to verify its integrity. @@ -374,7 +372,7 @@ where rng: &mut impl CryptoRngCore, message: VerifiedMessage, ) -> ProcessedMessage { - let processed = self.round.receive_message( + let processed = self.round.as_ref().receive_message( rng, &self.deserializer, message.from(), @@ -398,7 +396,7 @@ where /// Makes an accumulator for a new round. pub fn make_accumulator(&self) -> RoundAccumulator { - RoundAccumulator::new(self.round.expecting_messages_from()) + RoundAccumulator::new(self.round.as_ref().expecting_messages_from()) } fn terminate_inner( @@ -460,15 +458,15 @@ where let echo_round_needed = !self.echo_broadcast.payload().is_none(); if echo_round_needed { - let round = Box::new(ObjectSafeRoundWrapper::new(EchoRound::::new( + let round = BoxedRound::new_dynamic(EchoRound::::new( verifier, self.echo_broadcast, transcript.echo_broadcasts(round_id)?, self.round, accum.payloads, accum.artifacts, - ))); - let cached_messages = filter_messages(accum.cached, round.id()); + )); + let cached_messages = filter_messages(accum.cached, round.as_ref().id()); let session = Session::new_for_next_round( rng, self.session_id, @@ -484,24 +482,25 @@ where }); } - match self.round.finalize(rng, accum.payloads, accum.artifacts) { + match self.round.into_boxed().finalize(rng, accum.payloads, accum.artifacts) { Ok(result) => Ok(match result { FinalizeOutcome::Result(result) => { RoundOutcome::Finished(SessionReport::new(SessionOutcome::Result(result), transcript)) } - FinalizeOutcome::AnotherRound(another_round) => { - let round = another_round.into_boxed(); - + FinalizeOutcome::AnotherRound(round) => { // Protecting against common bugs - if !self.possible_next_rounds.contains(&round.id()) { - return Err(LocalError::new(format!("Unexpected next round id: {:?}", round.id()))); + if !self.possible_next_rounds.contains(&round.as_ref().id()) { + return Err(LocalError::new(format!( + "Unexpected next round id: {:?}", + round.as_ref().id() + ))); } // These messages could have been cached before // processing messages from the same node for the current round. // So there might have been some new errors, and we need to check again // if the sender is already banned. - let cached_messages = filter_messages(accum.cached, round.id()) + let cached_messages = filter_messages(accum.cached, round.as_ref().id()) .into_iter() .filter(|message| !transcript.is_banned(message.from())) .collect::>(); diff --git a/manul/src/testing/run_sync.rs b/manul/src/testing/run_sync.rs index 9e06ac1..aa3db0c 100644 --- a/manul/src/testing/run_sync.rs +++ b/manul/src/testing/run_sync.rs @@ -104,7 +104,7 @@ where for (signer, inputs) in inputs { let verifier = signer.verifying_key(); - let session = Session::::new::(rng, session_id.clone(), signer, inputs)?; + let session = Session::<_, SP>::new::(rng, session_id.clone(), signer, inputs)?; let mut accum = session.make_accumulator(); let destinations = session.message_destinations(); From 3f7eafa76e60cf7f4dd19db84c276bdc3858ffba Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Sun, 3 Nov 2024 14:15:22 -0800 Subject: [PATCH 4/9] Add an impl of `ProtocolError` for `()` --- CHANGELOG.md | 1 + manul/benches/empty_rounds.rs | 33 ++++----------------------------- manul/src/protocol/round.rs | 22 ++++++++++++++++++++++ manul/src/session/session.rs | 33 ++------------------------------- 4 files changed, 29 insertions(+), 60 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ac0e5ad..5919e5d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Re-export `digest` from the `session` module. ([#56]) - Added `Message::destination()`. ([#56]) - `PartyId` trait alias for the combination of bounds needed for a party identifier. ([#59]) +- An impl of `ProtocolError` for `()` for protocols that don't use errors. ([#60]) [#32]: https://github.com/entropyxyz/manul/pull/32 diff --git a/manul/benches/empty_rounds.rs b/manul/benches/empty_rounds.rs index 25cbf5f..87eefee 100644 --- a/manul/benches/empty_rounds.rs +++ b/manul/benches/empty_rounds.rs @@ -1,17 +1,14 @@ extern crate alloc; -use alloc::{ - collections::{BTreeMap, BTreeSet}, - string::String, -}; +use alloc::collections::{BTreeMap, BTreeSet}; use core::fmt::Debug; use criterion::{criterion_group, criterion_main, Criterion}; use manul::{ protocol::{ Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeError, FinalizeOutcome, - LocalError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessagePart, - ProtocolValidationError, ReceiveError, Round, RoundId, Serializer, + LocalError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolMessagePart, ReceiveError, Round, RoundId, + Serializer, }, session::{signature::Keypair, SessionOutcome}, testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, @@ -22,31 +19,9 @@ use serde::{Deserialize, Serialize}; #[derive(Debug)] pub struct EmptyProtocol; -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct EmptyProtocolError; - -impl ProtocolError for EmptyProtocolError { - fn description(&self) -> String { - unimplemented!() - } - fn verify_messages_constitute_error( - &self, - _deserializer: &Deserializer, - _echo_broadcast: &EchoBroadcast, - _normal_broadcast: &NormalBroadcast, - _direct_message: &DirectMessage, - _echo_broadcasts: &BTreeMap, - _normal_broadcasts: &BTreeMap, - _direct_messages: &BTreeMap, - _combined_echos: &BTreeMap>, - ) -> Result<(), ProtocolValidationError> { - unimplemented!() - } -} - impl Protocol for EmptyProtocol { type Result = (); - type ProtocolError = EmptyProtocolError; + type ProtocolError = (); type CorrectnessProof = (); } diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index ec58c1d..ddc3011 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -201,6 +201,28 @@ pub trait ProtocolError: Debug + Clone + Send { ) -> Result<(), ProtocolValidationError>; } +// A convenience implementation for protocols that don't define any errors. +// Have to do it for `()`, since `!` is unstable. +impl ProtocolError for () { + fn description(&self) -> String { + panic!("Attempt to use an empty error type in an evidence. This is a bug in the protocol implementation.") + } + + fn verify_messages_constitute_error( + &self, + _deserializer: &Deserializer, + _echo_broadcast: &EchoBroadcast, + _normal_broadcast: &NormalBroadcast, + _direct_message: &DirectMessage, + _echo_broadcasts: &BTreeMap, + _normal_broadcasts: &BTreeMap, + _direct_messages: &BTreeMap, + _combined_echos: &BTreeMap>, + ) -> Result<(), ProtocolValidationError> { + panic!("Attempt to use an empty error type in an evidence. This is a bug in the protocol implementation.") + } +} + /// Message payload created in [`Round::receive_message`]. #[derive(Debug)] pub struct Payload(pub Box); diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index 32d926b..c556344 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -816,17 +816,11 @@ fn filter_messages( #[cfg(test)] mod tests { - use alloc::{collections::BTreeMap, string::String, vec::Vec}; - use impls::impls; - use serde::{Deserialize, Serialize}; use super::{Message, ProcessedArtifact, ProcessedMessage, Session, VerifiedMessage}; use crate::{ - protocol::{ - Deserializer, DirectMessage, EchoBroadcast, NormalBroadcast, Protocol, ProtocolError, - ProtocolValidationError, RoundId, - }, + protocol::Protocol, testing::{BinaryFormat, TestSessionParams, TestVerifier}, }; @@ -842,32 +836,9 @@ mod tests { struct DummyProtocol; - #[derive(Debug, Clone, Serialize, Deserialize)] - struct DummyProtocolError; - - impl ProtocolError for DummyProtocolError { - fn description(&self) -> String { - unimplemented!() - } - - fn verify_messages_constitute_error( - &self, - _deserializer: &Deserializer, - _echo_broadcast: &EchoBroadcast, - _normal_broadcast: &NormalBroadcast, - _direct_message: &DirectMessage, - _echo_broadcasts: &BTreeMap, - _normal_broadcasts: &BTreeMap, - _direct_messages: &BTreeMap, - _combined_echos: &BTreeMap>, - ) -> Result<(), ProtocolValidationError> { - unimplemented!() - } - } - impl Protocol for DummyProtocol { type Result = (); - type ProtocolError = DummyProtocolError; + type ProtocolError = (); type CorrectnessProof = (); } From be30e0dd9744fc7cb32616e1422289cadfdf9350 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Mon, 4 Nov 2024 15:31:03 -0800 Subject: [PATCH 5/9] Add a dummy `CorrectnessProof` trait, to be extended elsewhere --- CHANGELOG.md | 1 + manul/src/protocol.rs | 4 +++- manul/src/protocol/round.rs | 14 +++++++++++++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5919e5d..e0916b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `Message::destination()`. ([#56]) - `PartyId` trait alias for the combination of bounds needed for a party identifier. ([#59]) - An impl of `ProtocolError` for `()` for protocols that don't use errors. ([#60]) +- A dummy `CorrectnessProof` trait. ([#60]) [#32]: https://github.com/entropyxyz/manul/pull/32 diff --git a/manul/src/protocol.rs b/manul/src/protocol.rs index 5c19d70..f1b063f 100644 --- a/manul/src/protocol.rs +++ b/manul/src/protocol.rs @@ -23,7 +23,9 @@ pub use errors::{ }; pub use message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessagePart}; pub use object_safe::BoxedRound; -pub use round::{Artifact, EntryPoint, FinalizeOutcome, PartyId, Payload, Protocol, ProtocolError, Round, RoundId}; +pub use round::{ + Artifact, CorrectnessProof, EntryPoint, FinalizeOutcome, PartyId, Payload, Protocol, ProtocolError, Round, RoundId, +}; pub use serialization::{Deserializer, Serializer}; pub(crate) use errors::ReceiveErrorType; diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index ddc3011..c4cc544 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -89,7 +89,7 @@ pub trait Protocol: 'static { /// An object of this type will be returned when an unattributable error happens during [`Round::finalize`]. /// /// It proves that the node did its job correctly, to be adjudicated by a third party. - type CorrectnessProof: Send + Serialize + for<'de> Deserialize<'de> + Debug; + type CorrectnessProof: CorrectnessProof + Serialize + for<'de> Deserialize<'de>; /// Returns `Ok(())` if the given direct message cannot be deserialized /// assuming it is a direct message from the round `round_id`. @@ -223,6 +223,18 @@ impl ProtocolError for () { } } +/// Describes unattributable errors originating during protocol execution. +/// +/// In the situations where no specific message can be blamed for an error, +/// each node must generate a correctness proof proving that they performed their duties correctly, +/// and the collection of proofs is verified by a third party. +/// One of the proofs will necessarily be missing or invalid. +pub trait CorrectnessProof: Debug + Clone + Send {} + +// A convenience implementation for protocols that don't define any errors. +// Have to do it for `()`, since `!` is unstable. +impl CorrectnessProof for () {} + /// Message payload created in [`Round::receive_message`]. #[derive(Debug)] pub struct Payload(pub Box); From e8ff7403669466a24de40132188a75380741572b Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Mon, 4 Nov 2024 15:02:29 -0800 Subject: [PATCH 6/9] Add serialization to the bounds of PartyId, ProtocolError, and CorrectnessProof --- CHANGELOG.md | 1 + manul/src/protocol/round.rs | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e0916b9..dbd6402 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Renamed `FirstRound` trait to `EntryPoint`. ([#60]) - Added `Protocol` type to `EntryPoint`. ([#60]) - `EntryPoint` and `FinalizeOutcome::AnotherRound` now use a new `BoxedRound` wrapper type. ([#60]) +- `PartyId` and `ProtocolError` are now bound on `Serialize`/`Deserialize`. ([#60]) ### Added diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index c4cc544..1f8c556 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -84,12 +84,12 @@ pub trait Protocol: 'static { type Result: Debug; /// An object of this type will be returned when a provable error happens during [`Round::receive_message`]. - type ProtocolError: ProtocolError + Serialize + for<'de> Deserialize<'de>; + type ProtocolError: ProtocolError; /// An object of this type will be returned when an unattributable error happens during [`Round::finalize`]. /// /// It proves that the node did its job correctly, to be adjudicated by a third party. - type CorrectnessProof: CorrectnessProof + Serialize + for<'de> Deserialize<'de>; + type CorrectnessProof: CorrectnessProof; /// Returns `Ok(())` if the given direct message cannot be deserialized /// assuming it is a direct message from the round `round_id`. @@ -138,7 +138,7 @@ pub trait Protocol: 'static { /// /// Provable here means that we can create an evidence object entirely of messages signed by some party, /// which, in combination, prove the party's malicious actions. -pub trait ProtocolError: Debug + Clone + Send { +pub trait ProtocolError: Debug + Clone + Send + Serialize + for<'de> Deserialize<'de> { /// A description of the error that will be included in the generated evidence. /// /// Make it short and informative. @@ -229,7 +229,7 @@ impl ProtocolError for () { /// each node must generate a correctness proof proving that they performed their duties correctly, /// and the collection of proofs is verified by a third party. /// One of the proofs will necessarily be missing or invalid. -pub trait CorrectnessProof: Debug + Clone + Send {} +pub trait CorrectnessProof: Debug + Clone + Send + Serialize + for<'de> Deserialize<'de> {} // A convenience implementation for protocols that don't define any errors. // Have to do it for `()`, since `!` is unstable. @@ -312,9 +312,9 @@ pub trait EntryPoint { } /// A trait alias for the combination of traits needed for a party identifier. -pub trait PartyId: 'static + Debug + Clone + Ord + Send + Sync {} +pub trait PartyId: 'static + Debug + Clone + Ord + Send + Sync + Serialize + for<'de> Deserialize<'de> {} -impl PartyId for T where T: 'static + Debug + Clone + Ord + Send + Sync {} +impl PartyId for T where T: 'static + Debug + Clone + Ord + Send + Sync + Serialize + for<'de> Deserialize<'de> {} /** A type representing a single round of a protocol. From f5d5bd736c50f4ef4c2aa4309cfd21711ab28dab Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Wed, 13 Nov 2024 10:22:55 -0800 Subject: [PATCH 7/9] Add Misbehaving combinator --- CHANGELOG.md | 1 + examples/src/simple_malicious.rs | 196 ++++++------------- manul/src/combinators.rs | 3 + manul/src/combinators/misbehave.rs | 291 +++++++++++++++++++++++++++++ manul/src/lib.rs | 1 + manul/src/protocol.rs | 1 + manul/src/protocol/object_safe.rs | 28 ++- manul/src/session/echo.rs | 2 +- manul/src/session/session.rs | 33 ++-- manul/src/testing.rs | 6 - manul/src/testing/macros.rs | 157 ---------------- 11 files changed, 395 insertions(+), 324 deletions(-) create mode 100644 manul/src/combinators.rs create mode 100644 manul/src/combinators/misbehave.rs delete mode 100644 manul/src/testing/macros.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index dbd6402..3c464d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `PartyId` trait alias for the combination of bounds needed for a party identifier. ([#59]) - An impl of `ProtocolError` for `()` for protocols that don't use errors. ([#60]) - A dummy `CorrectnessProof` trait. ([#60]) +- A `misbehave` combinator, intended primarily for testing. ([#60]) [#32]: https://github.com/entropyxyz/manul/pull/32 diff --git a/examples/src/simple_malicious.rs b/examples/src/simple_malicious.rs index 90a4aa9..2f58f77 100644 --- a/examples/src/simple_malicious.rs +++ b/examples/src/simple_malicious.rs @@ -1,151 +1,75 @@ -use alloc::collections::{BTreeMap, BTreeSet}; +use alloc::collections::BTreeSet; use core::fmt::Debug; use manul::{ + combinators::misbehave::{Misbehaving, MisbehavingEntryPoint, MisbehavingInputs}, protocol::{ - Artifact, BoxedRound, DirectMessage, EntryPoint, FinalizeError, FinalizeOutcome, LocalError, PartyId, Payload, - ProtocolMessagePart, Round, Serializer, + Artifact, BoxedRound, Deserializer, DirectMessage, EntryPoint, LocalError, PartyId, ProtocolMessagePart, + RoundId, Serializer, }, session::signature::Keypair, - testing::{ - round_override, run_sync, BinaryFormat, RoundOverride, RoundWrapper, TestSessionParams, TestSigner, - TestVerifier, - }, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, }; use rand_core::{CryptoRngCore, OsRng}; use tracing_subscriber::EnvFilter; -use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message, SimpleProtocol}; +use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message}; #[derive(Debug, Clone, Copy)] enum Behavior { - Lawful, SerializedGarbage, AttributableFailure, AttributableFailureRound2, } -struct MaliciousInputs { - inputs: Inputs, - behavior: Behavior, -} - -#[derive(Debug)] -struct MaliciousRound1 { - round: Round1, - behavior: Behavior, -} +struct MaliciousLogic; -impl RoundWrapper for MaliciousRound1 { - type InnerRound = Round1; - fn inner_round_ref(&self) -> &Self::InnerRound { - &self.round - } - fn inner_round(self) -> Self::InnerRound { - self.round - } -} +impl Misbehaving for MaliciousLogic { + type EntryPoint = Round1; -impl EntryPoint for MaliciousRound1 { - type Inputs = MaliciousInputs; - type Protocol = SimpleProtocol; - fn new( - rng: &mut impl CryptoRngCore, - shared_randomness: &[u8], - id: Id, - inputs: Self::Inputs, - ) -> Result, LocalError> { - let round = Round1::new(rng, shared_randomness, id, inputs.inputs)?.downcast::>()?; - Ok(BoxedRound::new_dynamic(Self { - round, - behavior: inputs.behavior, - })) - } -} - -impl RoundOverride for MaliciousRound1 { - fn make_direct_message( - &self, - rng: &mut impl CryptoRngCore, + fn modify_direct_message( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + behavior: &Behavior, serializer: &Serializer, - destination: &Id, + _deserializer: &Deserializer, + _destination: &Id, + direct_message: DirectMessage, + artifact: Option, ) -> Result<(DirectMessage, Option), LocalError> { - if matches!(self.behavior, Behavior::SerializedGarbage) { - Ok((DirectMessage::new(serializer, [99u8])?, None)) - } else if matches!(self.behavior, Behavior::AttributableFailure) { - let message = Round1Message { - my_position: self.round.context.ids_to_positions[&self.round.context.id], - your_position: self.round.context.ids_to_positions[&self.round.context.id], - }; - Ok((DirectMessage::new(serializer, message)?, None)) - } else { - self.inner_round_ref().make_direct_message(rng, serializer, destination) - } - } - - fn finalize( - self, - rng: &mut impl CryptoRngCore, - payloads: BTreeMap, - artifacts: BTreeMap, - ) -> Result< - FinalizeOutcome>::InnerRound as Round>::Protocol>, - FinalizeError<<>::InnerRound as Round>::Protocol>, - > { - let behavior = self.behavior; - let outcome = self.inner_round().finalize(rng, payloads, artifacts)?; - - Ok(match outcome { - FinalizeOutcome::Result(res) => FinalizeOutcome::Result(res), - FinalizeOutcome::AnotherRound(boxed_round) => { - let round2 = boxed_round.downcast::>().map_err(FinalizeError::Local)?; - FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(MaliciousRound2 { - round: round2, - behavior, - })) + let dm = if round.id() == RoundId::new(1) { + match behavior { + Behavior::SerializedGarbage => DirectMessage::new(serializer, [99u8])?, + Behavior::AttributableFailure => { + let round1 = round.downcast_ref::>()?; + let message = Round1Message { + my_position: round1.context.ids_to_positions[&round1.context.id], + your_position: round1.context.ids_to_positions[&round1.context.id], + }; + DirectMessage::new(serializer, message)? + } + _ => direct_message, + } + } else if round.id() == RoundId::new(2) { + match behavior { + Behavior::AttributableFailureRound2 => { + let round2 = round.downcast_ref::>()?; + let message = Round2Message { + my_position: round2.context.ids_to_positions[&round2.context.id], + your_position: round2.context.ids_to_positions[&round2.context.id], + }; + DirectMessage::new(serializer, message)? + } + _ => direct_message, } - }) - } -} - -round_override!(MaliciousRound1); - -#[derive(Debug)] -struct MaliciousRound2 { - round: Round2, - behavior: Behavior, -} - -impl RoundWrapper for MaliciousRound2 { - type InnerRound = Round2; - fn inner_round_ref(&self) -> &Self::InnerRound { - &self.round - } - fn inner_round(self) -> Self::InnerRound { - self.round - } -} - -impl RoundOverride for MaliciousRound2 { - fn make_direct_message( - &self, - rng: &mut impl CryptoRngCore, - serializer: &Serializer, - destination: &Id, - ) -> Result<(DirectMessage, Option), LocalError> { - if matches!(self.behavior, Behavior::AttributableFailureRound2) { - let message = Round2Message { - my_position: self.round.context.ids_to_positions[&self.round.context.id], - your_position: self.round.context.ids_to_positions[&self.round.context.id], - }; - Ok((DirectMessage::new(serializer, message)?, None)) } else { - self.inner_round_ref().make_direct_message(rng, serializer, destination) - } + direct_message + }; + Ok((dm, artifact)) } } -round_override!(MaliciousRound2); +type MaliciousEntryPoint = MisbehavingEntryPoint; #[test] fn serialized_garbage() { @@ -161,13 +85,13 @@ fn serialized_garbage() { .enumerate() .map(|(idx, signer)| { let behavior = if idx == 0 { - Behavior::SerializedGarbage + Some(Behavior::SerializedGarbage) } else { - Behavior::Lawful + None }; - let malicious_inputs = MaliciousInputs { - inputs: inputs.clone(), + let malicious_inputs = MisbehavingInputs { + inner_inputs: inputs.clone(), behavior, }; (*signer, malicious_inputs) @@ -178,7 +102,7 @@ fn serialized_garbage() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() }); let v0 = signers[0].verifying_key(); @@ -207,13 +131,13 @@ fn attributable_failure() { .enumerate() .map(|(idx, signer)| { let behavior = if idx == 0 { - Behavior::AttributableFailure + Some(Behavior::AttributableFailure) } else { - Behavior::Lawful + None }; - let malicious_inputs = MaliciousInputs { - inputs: inputs.clone(), + let malicious_inputs = MisbehavingInputs { + inner_inputs: inputs.clone(), behavior, }; (*signer, malicious_inputs) @@ -224,7 +148,7 @@ fn attributable_failure() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() }); let v0 = signers[0].verifying_key(); @@ -253,13 +177,13 @@ fn attributable_failure_round2() { .enumerate() .map(|(idx, signer)| { let behavior = if idx == 0 { - Behavior::AttributableFailureRound2 + Some(Behavior::AttributableFailureRound2) } else { - Behavior::Lawful + None }; - let malicious_inputs = MaliciousInputs { - inputs: inputs.clone(), + let malicious_inputs = MisbehavingInputs { + inner_inputs: inputs.clone(), behavior, }; (*signer, malicious_inputs) @@ -270,7 +194,7 @@ fn attributable_failure_round2() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() }); let v0 = signers[0].verifying_key(); diff --git a/manul/src/combinators.rs b/manul/src/combinators.rs new file mode 100644 index 0000000..15e2d19 --- /dev/null +++ b/manul/src/combinators.rs @@ -0,0 +1,3 @@ +//! Combinators operating on protocols. + +pub mod misbehave; diff --git a/manul/src/combinators/misbehave.rs b/manul/src/combinators/misbehave.rs new file mode 100644 index 0000000..d71ddc2 --- /dev/null +++ b/manul/src/combinators/misbehave.rs @@ -0,0 +1,291 @@ +/*! +A combinator allowing one to intercept outgoing messages from a round, and replace or modify them. + +The usage is as follows: + +1. Define a behavior type, subject to [`Behavior`] bounds. + This will represent the possible actions the override may perform. + +2. Implement [`Misbehaving`] for a type of your choice. Usually it will be an empty token type. + You will need to specify the entry point for the unmodified protocol, + and some of `modify_*` methods (the blanket implementations simply pass through the original messages). + +3. The `modify_*` methods can be called from any round, use [`BoxedRound::id`](`crate::protocol::BoxedRound::id`) + on the `round` argument to determine which round it is. + +4. In the `modify_*` methods, you can get the original typed message using the provided `deserializer` argument, + and create a new one using the `serializer`. + +5. You can get access to the typed `Round` object by using + [`BoxedRound::downcast_ref`](`crate::protocol::BoxedRound::downcast_ref`). +*/ + +use alloc::{ + boxed::Box, + collections::{BTreeMap, BTreeSet}, +}; +use core::fmt::Debug; + +use rand_core::CryptoRngCore; + +use crate::protocol::{ + Artifact, BoxedRng, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeError, + FinalizeOutcome, LocalError, NormalBroadcast, ObjectSafeRound, PartyId, Payload, ReceiveError, RoundId, Serializer, +}; + +/// A trait describing required properties for a behavior type. +pub trait Behavior: 'static + Debug + Send + Sync {} + +impl Behavior for T {} + +/// The new entry point for the misbehaving rounds. +/// +/// Use as an entry point to run the session, with your ID, behavior `B` and the misbehavior definition `M` set. +#[derive_where::derive_where(Debug)] +pub struct MisbehavingEntryPoint +where + Id: PartyId, + B: Behavior, + M: Misbehaving, +{ + round: BoxedRound>::Protocol>, + behavior: Option, +} + +/// A trait defining a sequence of misbehaving rounds modifying or replacing the messages sent by some existing ones. +/// +/// Override one or more optional methods to modify the specific messages. +pub trait Misbehaving: 'static +where + Id: PartyId, + B: Behavior, +{ + /// The entry point of the wrapped rounds. + type EntryPoint: EntryPoint; + + /// Called after [`Round::make_echo_broadcast`](`crate::protocol::Round::make_echo_broadcast`) + /// and may modify its result. + /// + /// The default implementation passes through the original message. + #[allow(unused_variables)] + fn modify_echo_broadcast( + rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + behavior: &B, + serializer: &Serializer, + deserializer: &Deserializer, + echo_broadcast: EchoBroadcast, + ) -> Result { + Ok(echo_broadcast) + } + + /// Called after [`Round::make_normal_broadcast`](`crate::protocol::Round::make_normal_broadcast`) + /// and may modify its result. + /// + /// The default implementation passes through the original message. + #[allow(unused_variables)] + fn modify_normal_broadcast( + rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + behavior: &B, + serializer: &Serializer, + deserializer: &Deserializer, + normal_broadcast: NormalBroadcast, + ) -> Result { + Ok(normal_broadcast) + } + + /// Called after [`Round::make_direct_message`](`crate::protocol::Round::make_direct_message`) + /// and may modify its result. + /// + /// The default implementation passes through the original message. + #[allow(unused_variables, clippy::too_many_arguments)] + fn modify_direct_message( + rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + behavior: &B, + serializer: &Serializer, + deserializer: &Deserializer, + destination: &Id, + direct_message: DirectMessage, + artifact: Option, + ) -> Result<(DirectMessage, Option), LocalError> { + Ok((direct_message, artifact)) + } +} + +/// The inputs for the misbehaving rounds. +#[derive_where::derive_where(Debug; >::Inputs)] +pub struct MisbehavingInputs +where + Id: PartyId, + B: Behavior, + M: Misbehaving, +{ + /// The behavior for the rounds starting with these inputs. + /// + /// If `None`, all the changed behavior will be skipped. + pub behavior: Option, + /// The inputs for the wrapped rounds. + pub inner_inputs: >::Inputs, +} + +impl EntryPoint for MisbehavingEntryPoint +where + Id: PartyId, + B: Behavior, + M: Misbehaving, +{ + type Inputs = MisbehavingInputs; + type Protocol = >::Protocol; + + fn new( + rng: &mut impl CryptoRngCore, + shared_randomness: &[u8], + id: Id, + inputs: Self::Inputs, + ) -> Result>::Protocol>, LocalError> { + let round = M::EntryPoint::new(rng, shared_randomness, id, inputs.inner_inputs)?; + Ok(BoxedRound::new_object_safe(Self { + round, + behavior: inputs.behavior, + })) + } +} + +impl ObjectSafeRound for MisbehavingEntryPoint +where + Id: PartyId, + B: Behavior, + M: Misbehaving, +{ + type Protocol = >::Protocol; + + fn id(&self) -> RoundId { + self.round.as_ref().id() + } + + fn possible_next_rounds(&self) -> BTreeSet { + self.round.as_ref().possible_next_rounds() + } + + fn message_destinations(&self) -> &BTreeSet { + self.round.as_ref().message_destinations() + } + + fn make_direct_message( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + destination: &Id, + ) -> Result<(DirectMessage, Option), LocalError> { + let (direct_message, artifact) = + self.round + .as_ref() + .make_direct_message(rng, serializer, deserializer, destination)?; + if let Some(behavior) = self.behavior.as_ref() { + let mut boxed_rng = BoxedRng(rng); + M::modify_direct_message( + &mut boxed_rng, + &self.round, + behavior, + serializer, + deserializer, + destination, + direct_message, + artifact, + ) + } else { + Ok((direct_message, artifact)) + } + } + + fn make_echo_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + ) -> Result { + let echo_broadcast = self.round.as_ref().make_echo_broadcast(rng, serializer, deserializer)?; + if let Some(behavior) = self.behavior.as_ref() { + let mut boxed_rng = BoxedRng(rng); + M::modify_echo_broadcast( + &mut boxed_rng, + &self.round, + behavior, + serializer, + deserializer, + echo_broadcast, + ) + } else { + Ok(echo_broadcast) + } + } + + fn make_normal_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + ) -> Result { + let normal_broadcast = self + .round + .as_ref() + .make_normal_broadcast(rng, serializer, deserializer)?; + if let Some(behavior) = self.behavior.as_ref() { + let mut boxed_rng = BoxedRng(rng); + M::modify_normal_broadcast( + &mut boxed_rng, + &self.round, + behavior, + serializer, + deserializer, + normal_broadcast, + ) + } else { + Ok(normal_broadcast) + } + } + + fn receive_message( + &self, + rng: &mut dyn CryptoRngCore, + deserializer: &Deserializer, + from: &Id, + echo_broadcast: EchoBroadcast, + normal_broadcast: NormalBroadcast, + direct_message: DirectMessage, + ) -> Result> { + self.round.as_ref().receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) + } + + fn finalize( + self: Box, + rng: &mut dyn CryptoRngCore, + payloads: BTreeMap, + artifacts: BTreeMap, + ) -> Result, FinalizeError> { + match self.round.into_boxed().finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), + Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe( + MisbehavingEntryPoint:: { + round, + behavior: self.behavior, + }, + ))), + Err(err) => Err(err), + } + } + + fn expecting_messages_from(&self) -> &BTreeSet { + self.round.as_ref().expecting_messages_from() + } +} diff --git a/manul/src/lib.rs b/manul/src/lib.rs index 8d3323f..ad55b47 100644 --- a/manul/src/lib.rs +++ b/manul/src/lib.rs @@ -16,6 +16,7 @@ extern crate alloc; +pub mod combinators; pub mod protocol; pub mod session; pub(crate) mod utils; diff --git a/manul/src/protocol.rs b/manul/src/protocol.rs index f1b063f..46b356f 100644 --- a/manul/src/protocol.rs +++ b/manul/src/protocol.rs @@ -29,5 +29,6 @@ pub use round::{ pub use serialization::{Deserializer, Serializer}; pub(crate) use errors::ReceiveErrorType; +pub(crate) use object_safe::{BoxedRng, ObjectSafeRound}; pub use digest; diff --git a/manul/src/protocol/object_safe.rs b/manul/src/protocol/object_safe.rs index 10de974..910abc9 100644 --- a/manul/src/protocol/object_safe.rs +++ b/manul/src/protocol/object_safe.rs @@ -17,7 +17,7 @@ use super::{ /// Since object-safe trait methods cannot take `impl CryptoRngCore` arguments, /// this structure wraps the dynamic object and exposes a `CryptoRngCore` interface, /// to be passed to statically typed round methods. -struct BoxedRng<'a>(&'a mut dyn CryptoRngCore); +pub(crate) struct BoxedRng<'a>(pub(crate) &'a mut dyn CryptoRngCore); impl CryptoRng for BoxedRng<'_> {} @@ -52,6 +52,7 @@ pub(crate) trait ObjectSafeRound: 'static + Debug + Send + Sync { &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + deserializer: &Deserializer, destination: &Id, ) -> Result<(DirectMessage, Option), LocalError>; @@ -59,12 +60,14 @@ pub(crate) trait ObjectSafeRound: 'static + Debug + Send + Sync { &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + deserializer: &Deserializer, ) -> Result; fn make_normal_broadcast( &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + deserializer: &Deserializer, ) -> Result; fn receive_message( @@ -87,7 +90,9 @@ pub(crate) trait ObjectSafeRound: 'static + Debug + Send + Sync { fn expecting_messages_from(&self) -> &BTreeSet; /// Returns the type ID of the implementing type. - fn get_type_id(&self) -> core::any::TypeId; + fn get_type_id(&self) -> core::any::TypeId { + core::any::TypeId::of::() + } } // The `fn(Id) -> Id` bit is so that `ObjectSafeRoundWrapper` didn't require a bound on `Id` to be @@ -134,6 +139,7 @@ where &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + #[allow(unused_variables)] deserializer: &Deserializer, destination: &Id, ) -> Result<(DirectMessage, Option), LocalError> { let mut boxed_rng = BoxedRng(rng); @@ -144,6 +150,7 @@ where &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + #[allow(unused_variables)] deserializer: &Deserializer, ) -> Result { let mut boxed_rng = BoxedRng(rng); self.round.make_echo_broadcast(&mut boxed_rng, serializer) @@ -153,6 +160,7 @@ where &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + #[allow(unused_variables)] deserializer: &Deserializer, ) -> Result { let mut boxed_rng = BoxedRng(rng); self.round.make_normal_broadcast(&mut boxed_rng, serializer) @@ -191,10 +199,6 @@ where fn expecting_messages_from(&self) -> &BTreeSet { self.round.expecting_messages_from() } - - fn get_type_id(&self) -> core::any::TypeId { - core::any::TypeId::of::() - } } // We do not want to expose `ObjectSafeRound` to the user, so it is hidden in a struct. @@ -215,6 +219,13 @@ impl BoxedRound { } } + pub(crate) fn new_object_safe>(round: R) -> Self { + Self { + wrapped: false, + round: Box::new(round), + } + } + pub(crate) fn as_ref(&self) -> &dyn ObjectSafeRound { self.round.as_ref() } @@ -265,4 +276,9 @@ impl BoxedRound { ))) } } + + /// Returns the round's ID. + pub fn id(&self) -> RoundId { + self.round.id() + } } diff --git a/manul/src/session/echo.rs b/manul/src/session/echo.rs index 936df48..c4fbcb9 100644 --- a/manul/src/session/echo.rs +++ b/manul/src/session/echo.rs @@ -146,7 +146,7 @@ where type Protocol = P; fn id(&self) -> RoundId { - self.main_round.as_ref().id().echo() + self.main_round.id().echo() } fn possible_next_rounds(&self) -> BTreeSet { diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index c556344..ad33fc3 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -169,18 +169,18 @@ where ) -> Result { let verifier = signer.verifying_key(); - let echo = round.as_ref().make_echo_broadcast(rng, &serializer)?; - let echo_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, round.as_ref().id(), echo)?; + let echo = round.as_ref().make_echo_broadcast(rng, &serializer, &deserializer)?; + let echo_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, round.id(), echo)?; - let normal = round.as_ref().make_normal_broadcast(rng, &serializer)?; - let normal_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, round.as_ref().id(), normal)?; + let normal = round.as_ref().make_normal_broadcast(rng, &serializer, &deserializer)?; + let normal_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, round.id(), normal)?; let message_destinations = round.as_ref().message_destinations().clone(); let possible_next_rounds = if echo_broadcast.payload().is_none() { round.as_ref().possible_next_rounds() } else { - BTreeSet::from([round.as_ref().id().echo()]) + BTreeSet::from([round.id().echo()]) }; Ok(Self { @@ -221,16 +221,16 @@ where rng: &mut impl CryptoRngCore, destination: &SP::Verifier, ) -> Result<(Message, ProcessedArtifact), LocalError> { - let (direct_message, artifact) = self - .round - .as_ref() - .make_direct_message(rng, &self.serializer, destination)?; + let (direct_message, artifact) = + self.round + .as_ref() + .make_direct_message(rng, &self.serializer, &self.deserializer, destination)?; let message = Message::new::( rng, &self.signer, &self.session_id, - self.round.as_ref().id(), + self.round.id(), destination, direct_message, self.echo_broadcast.clone(), @@ -256,7 +256,7 @@ where /// Returns the ID of the current round. pub fn round_id(&self) -> RoundId { - self.round.as_ref().id() + self.round.id() } /// Performs some preliminary checks on the message to verify its integrity. @@ -466,7 +466,7 @@ where accum.payloads, accum.artifacts, )); - let cached_messages = filter_messages(accum.cached, round.as_ref().id()); + let cached_messages = filter_messages(accum.cached, round.id()); let session = Session::new_for_next_round( rng, self.session_id, @@ -489,18 +489,15 @@ where } FinalizeOutcome::AnotherRound(round) => { // Protecting against common bugs - if !self.possible_next_rounds.contains(&round.as_ref().id()) { - return Err(LocalError::new(format!( - "Unexpected next round id: {:?}", - round.as_ref().id() - ))); + if !self.possible_next_rounds.contains(&round.id()) { + return Err(LocalError::new(format!("Unexpected next round id: {:?}", round.id()))); } // These messages could have been cached before // processing messages from the same node for the current round. // So there might have been some new errors, and we need to check again // if the sender is already banned. - let cached_messages = filter_messages(accum.cached, round.as_ref().id()) + let cached_messages = filter_messages(accum.cached, round.id()) .into_iter() .filter(|message| !transcript.is_banned(message.from())) .collect::>(); diff --git a/manul/src/testing.rs b/manul/src/testing.rs index 1427c84..8846eaa 100644 --- a/manul/src/testing.rs +++ b/manul/src/testing.rs @@ -1,10 +1,6 @@ /*! Utilities for testing protocols. -When testing round based protocols it can be complicated to "inject" the proper faults into the -process, e.g. to emulate a malicious participant. This module provides facilities to make this -easier, by providing a [`RoundOverride`] type along with a [`round_override`] macro. - The [`TestSessionParams`] provides an implementation of the [`SessionParameters`](crate::session::SessionParameters) trait, which in turn is used to setup [`Session`](crate::session::Session)s to drive the protocol. @@ -12,12 +8,10 @@ which in turn is used to setup [`Session`](crate::session::Session)s to drive th The [`run_sync()`] method is helpful to execute a protocol synchronously and collect the outcomes. */ -mod macros; mod run_sync; mod session_parameters; mod wire_format; -pub use macros::{round_override, RoundOverride, RoundWrapper}; pub use run_sync::run_sync; pub use session_parameters::{TestHasher, TestSessionParams, TestSignature, TestSigner, TestVerifier}; pub use wire_format::{BinaryFormat, HumanReadableFormat}; diff --git a/manul/src/testing/macros.rs b/manul/src/testing/macros.rs deleted file mode 100644 index ac8fb57..0000000 --- a/manul/src/testing/macros.rs +++ /dev/null @@ -1,157 +0,0 @@ -use alloc::collections::BTreeMap; - -use rand_core::CryptoRngCore; - -use crate::protocol::{ - Artifact, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, LocalError, NormalBroadcast, PartyId, - Payload, Round, Serializer, -}; - -/// A trait defining a wrapper around an existing type implementing [`Round`]. -pub trait RoundWrapper: 'static + Sized + Send + Sync { - /// The inner round type. - type InnerRound: Round; - - /// Returns a reference to the inner round. - fn inner_round_ref(&self) -> &Self::InnerRound; - - /// Returns the inner round by value. - fn inner_round(self) -> Self::InnerRound; -} - -/// This trait defines overrides of some methods of [`RoundWrapper::InnerRound`]. -/// -/// Intended to be used with the [`round_override`] macro to generate the [`Round`] implementation. -/// -/// The blanket implementations delegate to the methods of the wrapped round. -pub trait RoundOverride: RoundWrapper { - /// An override for [`Round::make_direct_message`]. - fn make_direct_message( - &self, - rng: &mut impl CryptoRngCore, - serializer: &Serializer, - destination: &Id, - ) -> Result<(DirectMessage, Option), LocalError> { - self.inner_round_ref().make_direct_message(rng, serializer, destination) - } - - /// An override for [`Round::make_echo_broadcast`]. - fn make_echo_broadcast( - &self, - rng: &mut impl CryptoRngCore, - serializer: &Serializer, - ) -> Result { - self.inner_round_ref().make_echo_broadcast(rng, serializer) - } - - /// An override for [`Round::make_normal_broadcast`]. - fn make_normal_broadcast( - &self, - rng: &mut impl CryptoRngCore, - serializer: &Serializer, - ) -> Result { - self.inner_round_ref().make_normal_broadcast(rng, serializer) - } - - /// An override for [`Round::finalize`]. - #[allow(clippy::type_complexity)] - fn finalize( - self, - rng: &mut impl CryptoRngCore, - payloads: BTreeMap, - artifacts: BTreeMap, - ) -> Result< - FinalizeOutcome>::InnerRound as Round>::Protocol>, - FinalizeError<<>::InnerRound as Round>::Protocol>, - > { - self.inner_round().finalize(rng, payloads, artifacts) - } -} - -/// A macro for "inheriting" from a [`Round`]-implementing type, and overriding some of its behavior. -/// -/// The given `$round` must implement [`RoundOverride`], and is generally some type -/// with one of its fields implementing [`Round`]. -/// Then, the macro will implement the [`Round`] trait for `$round` by delegating non-overridden methods to -/// the internal [`RoundWrapper::InnerRound`]. -#[macro_export] -macro_rules! round_override { - ($round: ident) => { - impl Round for $round - where - Id: $crate::protocol::PartyId, - $round: $crate::testing::RoundOverride, - { - type Protocol = - <<$round as $crate::testing::RoundWrapper>::InnerRound as $crate::protocol::Round>::Protocol; - - fn id(&self) -> $crate::protocol::RoundId { - self.inner_round_ref().id() - } - - fn possible_next_rounds(&self) -> ::alloc::collections::BTreeSet<$crate::protocol::RoundId> { - self.inner_round_ref().possible_next_rounds() - } - - fn message_destinations(&self) -> &::alloc::collections::BTreeSet { - self.inner_round_ref().message_destinations() - } - - fn make_direct_message( - &self, - rng: &mut impl CryptoRngCore, - serializer: &$crate::protocol::Serializer, - destination: &Id, - ) -> Result<($crate::protocol::DirectMessage, Option<$crate::protocol::Artifact>), $crate::protocol::LocalError> { - >::make_direct_message(self, rng, serializer, destination) - } - - fn make_echo_broadcast( - &self, - rng: &mut impl CryptoRngCore, - serializer: &$crate::protocol::Serializer, - ) -> Result<$crate::protocol::EchoBroadcast, $crate::protocol::LocalError> { - >::make_echo_broadcast(self, rng, serializer) - } - - fn make_normal_broadcast( - &self, - rng: &mut impl CryptoRngCore, - serializer: &$crate::protocol::Serializer, - ) -> Result<$crate::protocol::NormalBroadcast, $crate::protocol::LocalError> { - >::make_normal_broadcast(self, rng, serializer) - } - - fn receive_message( - &self, - rng: &mut impl CryptoRngCore, - deserializer: &$crate::protocol::Deserializer, - from: &Id, - echo_broadcast: $crate::protocol::EchoBroadcast, - normal_broadcast: $crate::protocol::NormalBroadcast, - direct_message: $crate::protocol::DirectMessage, - ) -> Result<$crate::protocol::Payload, $crate::protocol::ReceiveError> { - self.inner_round_ref() - .receive_message(rng, deserializer, from, echo_broadcast, normal_broadcast, direct_message) - } - - fn finalize( - self, - rng: &mut impl CryptoRngCore, - payloads: ::alloc::collections::BTreeMap, - artifacts: ::alloc::collections::BTreeMap, - ) -> Result< - $crate::protocol::FinalizeOutcome, - $crate::protocol::FinalizeError - > { - >::finalize(self, rng, payloads, artifacts) - } - - fn expecting_messages_from(&self) -> &BTreeSet { - self.inner_round_ref().expecting_messages_from() - } - } - }; -} - -pub use round_override; From 891bd044577d01887c601e0996e296c45f10ec1f Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Wed, 13 Nov 2024 10:33:39 -0800 Subject: [PATCH 8/9] Add Chain combinator --- CHANGELOG.md | 2 + examples/src/lib.rs | 1 + examples/src/simple_chain.rs | 85 ++++++ manul/src/combinators.rs | 1 + manul/src/combinators/chain.rs | 524 +++++++++++++++++++++++++++++++++ manul/src/protocol/errors.rs | 33 ++- manul/src/protocol/round.rs | 89 +++++- manul/src/session/session.rs | 6 +- manul/src/testing/run_sync.rs | 2 +- 9 files changed, 730 insertions(+), 13 deletions(-) create mode 100644 examples/src/simple_chain.rs create mode 100644 manul/src/combinators/chain.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c464d9..281deb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - An impl of `ProtocolError` for `()` for protocols that don't use errors. ([#60]) - A dummy `CorrectnessProof` trait. ([#60]) - A `misbehave` combinator, intended primarily for testing. ([#60]) +- A `chain` combinator for chaining two protocols. ([#60]) +- `EntryPoint::ENTRY_ROUND` constant. ([#60]) [#32]: https://github.com/entropyxyz/manul/pull/32 diff --git a/examples/src/lib.rs b/examples/src/lib.rs index c3aff76..37b869c 100644 --- a/examples/src/lib.rs +++ b/examples/src/lib.rs @@ -1,6 +1,7 @@ extern crate alloc; pub mod simple; +pub mod simple_chain; #[cfg(test)] mod simple_malicious; diff --git a/examples/src/simple_chain.rs b/examples/src/simple_chain.rs new file mode 100644 index 0000000..e187dc6 --- /dev/null +++ b/examples/src/simple_chain.rs @@ -0,0 +1,85 @@ +use core::fmt::Debug; + +use manul::{ + combinators::chain::{Chained, ChainedEntryPoint}, + protocol::PartyId, +}; + +use super::simple::{Inputs, Round1}; + +pub struct ChainedSimple; + +#[derive(Debug)] +pub struct NewInputs(Inputs); + +impl<'a, Id: PartyId> From<&'a NewInputs> for Inputs { + fn from(source: &'a NewInputs) -> Self { + source.0.clone() + } +} + +impl From<(NewInputs, u8)> for Inputs { + fn from(source: (NewInputs, u8)) -> Self { + let (inputs, _result) = source; + inputs.0 + } +} + +impl Chained for ChainedSimple { + type Inputs = NewInputs; + type EntryPoint1 = Round1; + type EntryPoint2 = Round1; +} + +pub type DoubleSimpleEntryPoint = ChainedEntryPoint; + +#[cfg(test)] +mod tests { + use alloc::collections::BTreeSet; + + use manul::{ + session::{signature::Keypair, SessionOutcome}, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + }; + use rand_core::OsRng; + use tracing_subscriber::EnvFilter; + + use super::{DoubleSimpleEntryPoint, NewInputs}; + use crate::simple::Inputs; + + #[test] + fn round() { + let signers = (0..3).map(TestSigner::new).collect::>(); + let all_ids = signers + .iter() + .map(|signer| signer.verifying_key()) + .collect::>(); + let inputs = signers + .into_iter() + .map(|signer| { + ( + signer, + NewInputs(Inputs { + all_ids: all_ids.clone(), + }), + ) + }) + .collect::>(); + + let my_subscriber = tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .finish(); + let reports = tracing::subscriber::with_default(my_subscriber, || { + run_sync::, TestSessionParams>(&mut OsRng, inputs) + .unwrap() + }); + + for (_id, report) in reports { + if let SessionOutcome::Result(result) = report.outcome { + assert_eq!(result, 3); // 0 + 1 + 2 + } else { + panic!("Session did not finish successfully"); + } + } + } +} diff --git a/manul/src/combinators.rs b/manul/src/combinators.rs index 15e2d19..48f012b 100644 --- a/manul/src/combinators.rs +++ b/manul/src/combinators.rs @@ -1,3 +1,4 @@ //! Combinators operating on protocols. +pub mod chain; pub mod misbehave; diff --git a/manul/src/combinators/chain.rs b/manul/src/combinators/chain.rs new file mode 100644 index 0000000..13450d9 --- /dev/null +++ b/manul/src/combinators/chain.rs @@ -0,0 +1,524 @@ +/*! +A combinator representing two protocols as a new protocol that, when executed, +executes the two inner protocols in sequence, feeding the result of the first protocol +into the inputs of the second protocol. + +For the session level users (that is, the ones executing the protocols) +the new protocol is a single entity with its own [`Protocol`](`crate::protocol::Protocol`) type +and an [`EntryPoint`](`crate::protocol::EntryPoint`) type. + +For example, imagine we have a `ProtocolA` with an entry point `EntryPointA`, inputs `InputsA`, +two rounds, `RA1` and `RA2`, and the result `ResultA`; +and similarly a `ProtocolB` with an entry point `EntryPointB`, inputs `InputsB`, +two rounds, `RB1` and `RB2`, and the result `ResultB`. + +Then the chained protocol will provide `ProtocolC: Protocol` and `EntryPointC: EntryPoint`, +the user will define `InputsC` for the new protocol, and the execution will look like: +- `InputsA` is created from `InputsC` via the user-defined `From` impl; +- `EntryPointA` is initialized with `InputsA`; +- `RA1` is executed; +- `RA2` is executed, producing `ResultA`; +- `InputsB` is created from `ResultA` and `InputsC` via the user-defined `From` impl; +- `RB1` is executed; +- `RB2` is executed, producing `ResultB` (which is also the result of `ChainedProtocol`). + +If the execution happens in a [`Session`](`crate::session::Session`), and there is an error at any point, +a regular evidence or correctness proof are created using the corresponding types from the new `ProtocolC`. + +The usage is as follows. + +1. Define an input type for the new joined protocol. + Most likely it will be a union between inputs of the first and the second protocol. + +2. Implement [`Chained`] for a type of your choice. Usually it will be an empty token type. + You will have to specify the entry points of the two protocols, + and the [`From`] conversions from the new input type to the inputs of both entry points + (see the corresponding associated type bounds). + +3. The entry point for the new protocol will be [`ChainedEntryPoint`] parametrized with + the type implementing [`Chained`] from step 2. + +4. The [`Protocol`](`crate::protocol::Protocol`)-implementing type for the new protocol will be + [`ChainedProtocol`] parametrized with the type implementing [`Chained`] from the step 2. +*/ + +use alloc::{ + boxed::Box, + collections::{BTreeMap, BTreeSet}, + format, + string::String, + vec::Vec, +}; +use core::{fmt::Debug, marker::PhantomData}; + +use rand_core::CryptoRngCore; +use serde::{Deserialize, Serialize}; + +use crate::protocol::*; + +/// A trait defining two protocols executed sequentially. +pub trait Chained: 'static +where + Id: PartyId, +{ + /// The inputs of the new chained protocol. + type Inputs: Send + Sync + Debug; + + /// The entry point of the first protocol. + type EntryPoint1: EntryPoint From<&'a Self::Inputs>>; + + /// The entry point of the second protocol. + type EntryPoint2: EntryPoint< + Id, + Inputs: From<( + Self::Inputs, + <>::Protocol as Protocol>::Result, + )>, + >; +} + +/// The protocol error type for the chained protocol. +#[derive_where::derive_where(Debug, Clone)] +#[derive(Serialize, Deserialize)] +#[serde(bound(serialize = " + <>::Protocol as Protocol>::ProtocolError: Serialize, + <>::Protocol as Protocol>::ProtocolError: Serialize, +"))] +#[serde(bound(deserialize = " + <>::Protocol as Protocol>::ProtocolError: for<'x> Deserialize<'x>, + <>::Protocol as Protocol>::ProtocolError: for<'x> Deserialize<'x>, +"))] +pub enum ChainedProtocolError> { + /// A protocol error from the first protocol. + Protocol1(<>::Protocol as Protocol>::ProtocolError), + /// A protocol error from the second protocol. + Protocol2(<>::Protocol as Protocol>::ProtocolError), +} + +impl ChainedProtocolError +where + Id: PartyId, + C: Chained, +{ + fn from_protocol1(err: <>::Protocol as Protocol>::ProtocolError) -> Self { + Self::Protocol1(err) + } + + fn from_protocol2(err: <>::Protocol as Protocol>::ProtocolError) -> Self { + Self::Protocol2(err) + } +} + +impl ProtocolError for ChainedProtocolError +where + Id: PartyId, + C: Chained, +{ + fn description(&self) -> String { + match self { + Self::Protocol1(err) => format!("Protocol1: {}", err.description()), + Self::Protocol2(err) => format!("Protocol2: {}", err.description()), + } + } + + fn required_direct_messages(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_direct_messages()), + Self::Protocol2(err) => (2, err.required_direct_messages()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + fn required_echo_broadcasts(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_echo_broadcasts()), + Self::Protocol2(err) => (2, err.required_echo_broadcasts()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + fn required_normal_broadcasts(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_normal_broadcasts()), + Self::Protocol2(err) => (2, err.required_normal_broadcasts()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + fn required_combined_echos(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_combined_echos()), + Self::Protocol2(err) => (2, err.required_combined_echos()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + #[allow(clippy::too_many_arguments)] + fn verify_messages_constitute_error( + &self, + deserializer: &Deserializer, + echo_broadcast: &EchoBroadcast, + normal_broadcast: &NormalBroadcast, + direct_message: &DirectMessage, + echo_broadcasts: &BTreeMap, + normal_broadcasts: &BTreeMap, + direct_messages: &BTreeMap, + combined_echos: &BTreeMap>, + ) -> Result<(), ProtocolValidationError> { + // TODO: the cloning can be avoided if instead we provide a reference to some "transcript API", + // and can replace it here with a proxy that will remove nesting from round ID's. + let echo_broadcasts = echo_broadcasts + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let normal_broadcasts = normal_broadcasts + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let direct_messages = direct_messages + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let combined_echos = combined_echos + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + + match self { + Self::Protocol1(err) => err.verify_messages_constitute_error( + deserializer, + echo_broadcast, + normal_broadcast, + direct_message, + &echo_broadcasts, + &normal_broadcasts, + &direct_messages, + &combined_echos, + ), + Self::Protocol2(err) => err.verify_messages_constitute_error( + deserializer, + echo_broadcast, + normal_broadcast, + direct_message, + &echo_broadcasts, + &normal_broadcasts, + &direct_messages, + &combined_echos, + ), + } + } +} + +/// The correctness proof type for the chained protocol. +#[derive_where::derive_where(Debug, Clone)] +#[derive(Serialize, Deserialize)] +#[serde(bound(serialize = " + <>::Protocol as Protocol>::CorrectnessProof: Serialize, + <>::Protocol as Protocol>::CorrectnessProof: Serialize, +"))] +#[serde(bound(deserialize = " + <>::Protocol as Protocol>::CorrectnessProof: for<'x> Deserialize<'x>, + <>::Protocol as Protocol>::CorrectnessProof: for<'x> Deserialize<'x>, +"))] +pub enum ChainedCorrectnessProof +where + Id: PartyId, + C: Chained, +{ + /// A correctness proof from the first protocol. + Protocol1(<>::Protocol as Protocol>::CorrectnessProof), + /// A correctness proof from the second protocol. + Protocol2(<>::Protocol as Protocol>::CorrectnessProof), +} + +impl ChainedCorrectnessProof +where + Id: PartyId, + C: Chained, +{ + fn from_protocol1(proof: <>::Protocol as Protocol>::CorrectnessProof) -> Self { + Self::Protocol1(proof) + } + + fn from_protocol2(proof: <>::Protocol as Protocol>::CorrectnessProof) -> Self { + Self::Protocol2(proof) + } +} + +impl CorrectnessProof for ChainedCorrectnessProof +where + Id: PartyId, + C: Chained, +{ +} + +/// The protocol resulting from chaining two sub-protocols as described by `C`. +#[derive(Debug)] +#[allow(clippy::type_complexity)] +pub struct ChainedProtocol>(PhantomData (Id, C)>); + +impl Protocol for ChainedProtocol +where + Id: PartyId, + C: Chained, +{ + type Result = <>::Protocol as Protocol>::Result; + type ProtocolError = ChainedProtocolError; + type CorrectnessProof = ChainedCorrectnessProof; +} + +/// The entry point of the chained protocol. +#[derive_where::derive_where(Debug)] +pub struct ChainedEntryPoint> { + state: ChainState, +} + +#[derive_where::derive_where(Debug)] +enum ChainState +where + Id: PartyId, + C: Chained, +{ + Protocol1 { + round: BoxedRound>::Protocol>, + shared_randomness: Box<[u8]>, + id: Id, + inputs: C::Inputs, + }, + Protocol2(BoxedRound>::Protocol>), +} + +impl EntryPoint for ChainedEntryPoint +where + Id: PartyId, + C: Chained, +{ + type Inputs = C::Inputs; + type Protocol = ChainedProtocol; + + fn entry_round() -> RoundId { + >::entry_round().group_under(1) + } + + fn new( + rng: &mut impl CryptoRngCore, + shared_randomness: &[u8], + id: Id, + inputs: Self::Inputs, + ) -> Result, LocalError> { + let round = C::EntryPoint1::new(rng, shared_randomness, id.clone(), (&inputs).into())?; + let round = ChainedEntryPoint { + state: ChainState::Protocol1 { + shared_randomness: shared_randomness.into(), + id, + inputs, + round, + }, + }; + Ok(BoxedRound::new_object_safe(round)) + } +} + +impl ObjectSafeRound for ChainedEntryPoint +where + Id: PartyId, + C: Chained, +{ + type Protocol = ChainedProtocol; + + fn id(&self) -> RoundId { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().id().group_under(1), + ChainState::Protocol2(round) => round.as_ref().id().group_under(2), + } + } + + fn possible_next_rounds(&self) -> BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => { + let mut next_rounds = round + .as_ref() + .possible_next_rounds() + .into_iter() + .map(|round_id| round_id.group_under(1)) + .collect::>(); + + // If there are no next rounds, this is the result round. + // This means that in the chain the next round will be the entry round of the second protocol. + if next_rounds.is_empty() { + next_rounds.insert(C::EntryPoint2::entry_round().group_under(2)); + } + next_rounds + } + ChainState::Protocol2(round) => round + .as_ref() + .possible_next_rounds() + .into_iter() + .map(|round_id| round_id.group_under(2)) + .collect(), + } + } + + fn message_destinations(&self) -> &BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().message_destinations(), + ChainState::Protocol2(round) => round.as_ref().message_destinations(), + } + } + + fn make_direct_message( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + destination: &Id, + ) -> Result<(DirectMessage, Option), LocalError> { + match &self.state { + ChainState::Protocol1 { round, .. } => { + round + .as_ref() + .make_direct_message(rng, serializer, deserializer, destination) + } + ChainState::Protocol2(round) => { + round + .as_ref() + .make_direct_message(rng, serializer, deserializer, destination) + } + } + } + + fn make_echo_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + ) -> Result { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().make_echo_broadcast(rng, serializer, deserializer), + ChainState::Protocol2(round) => round.as_ref().make_echo_broadcast(rng, serializer, deserializer), + } + } + + fn make_normal_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + ) -> Result { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().make_normal_broadcast(rng, serializer, deserializer), + ChainState::Protocol2(round) => round.as_ref().make_normal_broadcast(rng, serializer, deserializer), + } + } + + fn receive_message( + &self, + rng: &mut dyn CryptoRngCore, + deserializer: &Deserializer, + from: &Id, + echo_broadcast: EchoBroadcast, + normal_broadcast: NormalBroadcast, + direct_message: DirectMessage, + ) -> Result> { + match &self.state { + ChainState::Protocol1 { round, .. } => match round.as_ref().receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) { + Ok(payload) => Ok(payload), + Err(err) => Err(err.map(ChainedProtocolError::from_protocol1)), + }, + ChainState::Protocol2(round) => match round.as_ref().receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) { + Ok(payload) => Ok(payload), + Err(err) => Err(err.map(ChainedProtocolError::from_protocol2)), + }, + } + } + + fn finalize( + self: Box, + rng: &mut dyn CryptoRngCore, + payloads: BTreeMap, + artifacts: BTreeMap, + ) -> Result, FinalizeError> { + match self.state { + ChainState::Protocol1 { + round, + id, + inputs, + shared_randomness, + } => match round.into_boxed().finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => { + let mut boxed_rng = BoxedRng(rng); + let round = C::EntryPoint2::new(&mut boxed_rng, &shared_randomness, id, (inputs, result).into())?; + + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe( + ChainedEntryPoint:: { + state: ChainState::Protocol2(round), + }, + ))) + } + Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound( + BoxedRound::new_object_safe(ChainedEntryPoint:: { + state: ChainState::Protocol1 { + shared_randomness, + id, + inputs, + round, + }, + }), + )), + Err(FinalizeError::Local(err)) => Err(FinalizeError::Local(err)), + Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable( + ChainedCorrectnessProof::from_protocol1(proof), + )), + }, + ChainState::Protocol2(round) => match round.into_boxed().finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), + Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound( + BoxedRound::new_object_safe(ChainedEntryPoint:: { + state: ChainState::Protocol2(round), + }), + )), + Err(FinalizeError::Local(err)) => Err(FinalizeError::Local(err)), + Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable( + ChainedCorrectnessProof::from_protocol2(proof), + )), + }, + } + } + + fn expecting_messages_from(&self) -> &BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().expecting_messages_from(), + ChainState::Protocol2(round) => round.as_ref().expecting_messages_from(), + } + } +} diff --git a/manul/src/protocol/errors.rs b/manul/src/protocol/errors.rs index a77ecc3..a5820f2 100644 --- a/manul/src/protocol/errors.rs +++ b/manul/src/protocol/errors.rs @@ -1,4 +1,4 @@ -use alloc::{format, string::String}; +use alloc::{boxed::Box, format, string::String}; use core::fmt::Debug; use super::round::Protocol; @@ -50,7 +50,7 @@ pub(crate) enum ReceiveErrorType { // so this whole enum is crate-private and the variants are created // via constructors and From impls. /// An echo round error occurred. - Echo(EchoRoundError), + Echo(Box>), } impl ReceiveError { @@ -68,6 +68,33 @@ impl ReceiveError { pub fn protocol(error: P::ProtocolError) -> Self { Self(ReceiveErrorType::Protocol(error)) } + + /// Maps the error to a different protocol, given the mapping function for protocol errors. + pub(crate) fn map(self, f: F) -> ReceiveError + where + F: Fn(P::ProtocolError) -> T::ProtocolError, + T: Protocol, + { + ReceiveError(self.0.map::(f)) + } +} + +impl ReceiveErrorType { + pub(crate) fn map(self, f: F) -> ReceiveErrorType + where + F: Fn(P::ProtocolError) -> T::ProtocolError, + T: Protocol, + { + match self { + Self::Local(err) => ReceiveErrorType::Local(err), + Self::InvalidDirectMessage(err) => ReceiveErrorType::InvalidDirectMessage(err), + Self::InvalidEchoBroadcast(err) => ReceiveErrorType::InvalidEchoBroadcast(err), + Self::InvalidNormalBroadcast(err) => ReceiveErrorType::InvalidNormalBroadcast(err), + Self::Unprovable(err) => ReceiveErrorType::Unprovable(err), + Self::Echo(err) => ReceiveErrorType::Echo(err), + Self::Protocol(err) => ReceiveErrorType::Protocol(f(err)), + } + } } impl From for ReceiveError @@ -93,7 +120,7 @@ where P: Protocol, { fn from(error: EchoRoundError) -> Self { - Self(ReceiveErrorType::Echo(error)) + Self(ReceiveErrorType::Echo(Box::new(error))) } } diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index 1f8c556..794677c 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -5,7 +5,10 @@ use alloc::{ string::String, vec::Vec, }; -use core::{any::Any, fmt::Debug}; +use core::{ + any::Any, + fmt::{self, Debug, Display}, +}; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; @@ -26,22 +29,89 @@ pub enum FinalizeOutcome { Result(P::Result), } +// Maximum depth of group nesting in RoundIds. +// We need this to be limited to allow the nesting to be performed in `const` context +// (since we cannot use heap there). +const ROUND_ID_DEPTH: usize = 8; + /// A round identifier. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct RoundId { - round_num: u8, + depth: u8, + round_nums: [u8; ROUND_ID_DEPTH], is_echo: bool, } +impl Display for RoundId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Round ")?; + for i in (0..self.depth as usize).rev() { + write!(f, "{}", self.round_nums.get(i).expect("Depth within range"))?; + if i != 0 { + write!(f, "-")?; + } + } + if self.is_echo { + write!(f, " (echo)")?; + } + Ok(()) + } +} + impl RoundId { /// Creates a new round identifier. - pub fn new(round_num: u8) -> Self { + pub const fn new(round_num: u8) -> Self { + let mut round_nums = [0u8; ROUND_ID_DEPTH]; + #[allow(clippy::indexing_slicing)] + { + round_nums[0] = round_num; + } Self { - round_num, + depth: 1, + round_nums, is_echo: false, } } + /// Prefixes this round ID (possibly already nested) with a group number. + /// + /// **Warning:** the maximum nesting depth is 8. Panics if this nesting overflows it. + pub(crate) const fn group_under(&self, round_num: u8) -> Self { + if self.depth as usize == ROUND_ID_DEPTH { + panic!("Maximum depth reached"); + } + let mut round_nums = self.round_nums; + + // Would use `expect("Depth within range")` here, but `expect()` in const fns is unstable. + #[allow(clippy::indexing_slicing)] + { + round_nums[self.depth as usize] = round_num; + } + + Self { + depth: self.depth + 1, + round_nums, + is_echo: self.is_echo, + } + } + + /// Removes the top group prefix from this round ID. + /// + /// Returns the `Err` variant if the round ID is not nested. + pub(crate) fn ungroup(&self) -> Result { + if self.depth == 1 { + Err(LocalError::new("This round ID is not in a group")) + } else { + let mut round_nums = self.round_nums; + *round_nums.get_mut(self.depth as usize - 1).expect("Depth within range") = 0; + Ok(Self { + depth: self.depth - 1, + round_nums, + is_echo: self.is_echo, + }) + } + } + /// Returns `true` if this is an ID of an echo broadcast round. pub(crate) fn is_echo(&self) -> bool { self.is_echo @@ -57,7 +127,8 @@ impl RoundId { panic!("This is already an echo round ID"); } Self { - round_num: self.round_num, + depth: self.depth, + round_nums: self.round_nums, is_echo: true, } } @@ -72,7 +143,8 @@ impl RoundId { panic!("This is already an non-echo round ID"); } Self { - round_num: self.round_num, + depth: self.depth, + round_nums: self.round_nums, is_echo: false, } } @@ -299,6 +371,11 @@ pub trait EntryPoint { /// The protocol implemented by the round this entry points returns. type Protocol: Protocol; + /// Returns the ID of the round returned by [`Self::new`]. + fn entry_round() -> RoundId { + RoundId::new(1) + } + /// Creates the round. /// /// `session_id` can be assumed to be the same for each node participating in a session. diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index ad33fc3..b8edb31 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -349,7 +349,7 @@ where } Err(MessageVerificationError::Local(error)) => return Err(error), }; - debug!("{key:?}: Received {message_round_id:?} message from {from:?}"); + debug!("{key:?}: Received {message_round_id} message from {from:?}"); match message_for { MessageFor::ThisRound => { @@ -357,7 +357,7 @@ where Ok(PreprocessOutcome::ToProcess(verified_message)) } MessageFor::NextRound => { - debug!("{key:?}: Caching message from {from:?} for {message_round_id:?}"); + debug!("{key:?}: Caching message from {from:?} for {message_round_id}"); accum.cache_message(verified_message)?; Ok(PreprocessOutcome::Cached) } @@ -734,7 +734,7 @@ where } ReceiveErrorType::Echo(error) => { let (_echo_broadcast, normal_broadcast, _direct_message) = processed.message.into_parts(); - let evidence = Evidence::new_echo_round_error(&from, normal_broadcast, error)?; + let evidence = Evidence::new_echo_round_error(&from, normal_broadcast, *error)?; self.register_provable_error(&from, evidence) } ReceiveErrorType::Local(error) => Err(error), diff --git a/manul/src/testing/run_sync.rs b/manul/src/testing/run_sync.rs index aa3db0c..8d08c87 100644 --- a/manul/src/testing/run_sync.rs +++ b/manul/src/testing/run_sync.rs @@ -45,7 +45,7 @@ where let state = loop { match session.can_finalize(&accum) { CanFinalize::Yes => { - debug!("{:?}: finalizing {:?}", session.verifier(), session.round_id(),); + debug!("{:?}: finalizing {}", session.verifier(), session.round_id()); match session.finalize_round(rng, accum)? { RoundOutcome::Finished(report) => break State::Finished(report), RoundOutcome::AnotherRound { From 600cf8597a2e031c2c2833841770116a179a268f Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Fri, 8 Nov 2024 19:52:39 -0800 Subject: [PATCH 9/9] Use tinyvec instead of a fixed-size array in RoundId --- Cargo.lock | 17 ++++++++++ manul/Cargo.toml | 1 + manul/src/protocol/round.rs | 57 +++++++++------------------------ manul/src/session/evidence.rs | 32 +++++++++--------- manul/src/session/message.rs | 2 +- manul/src/session/session.rs | 18 +++++------ manul/src/session/transcript.rs | 10 +++--- 7 files changed, 65 insertions(+), 72 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c31e3f0..a5dc428 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -492,6 +492,7 @@ dependencies = [ "serde_asn1_der", "serde_json", "signature", + "tinyvec", "tracing", ] @@ -934,6 +935,22 @@ dependencies = [ "serde_json", ] +[[package]] +name = "tinyvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "serde", + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.40.0" diff --git a/manul/Cargo.toml b/manul/Cargo.toml index 893157b..ede1380 100644 --- a/manul/Cargo.toml +++ b/manul/Cargo.toml @@ -20,6 +20,7 @@ rand_core = { version = "0.6.4", default-features = false } tracing = { version = "0.1", default-features = false } displaydoc = { version = "0.2", default-features = false } derive-where = "1" +tinyvec = { version = "1", default-features = false, features = ["alloc", "serde"] } rand = { version = "0.8", default-features = false, optional = true } serde-persistent-deserializer = { version = "0.3", optional = true } diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index 794677c..bf6f6e7 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -3,6 +3,7 @@ use alloc::{ collections::{BTreeMap, BTreeSet}, format, string::String, + vec, vec::Vec, }; use core::{ @@ -12,6 +13,7 @@ use core::{ use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; +use tinyvec::{tiny_vec, TinyVec}; use super::{ errors::{FinalizeError, LocalError, MessageValidationError, ProtocolValidationError, ReceiveError}, @@ -29,24 +31,18 @@ pub enum FinalizeOutcome { Result(P::Result), } -// Maximum depth of group nesting in RoundIds. -// We need this to be limited to allow the nesting to be performed in `const` context -// (since we cannot use heap there). -const ROUND_ID_DEPTH: usize = 8; - /// A round identifier. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct RoundId { - depth: u8, - round_nums: [u8; ROUND_ID_DEPTH], + round_nums: TinyVec<[u8; 4]>, is_echo: bool, } impl Display for RoundId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { write!(f, "Round ")?; - for i in (0..self.depth as usize).rev() { - write!(f, "{}", self.round_nums.get(i).expect("Depth within range"))?; + for (i, round_num) in self.round_nums.iter().enumerate().rev() { + write!(f, "{}", round_num)?; if i != 0 { write!(f, "-")?; } @@ -60,36 +56,18 @@ impl Display for RoundId { impl RoundId { /// Creates a new round identifier. - pub const fn new(round_num: u8) -> Self { - let mut round_nums = [0u8; ROUND_ID_DEPTH]; - #[allow(clippy::indexing_slicing)] - { - round_nums[0] = round_num; - } + pub fn new(round_num: u8) -> Self { Self { - depth: 1, - round_nums, + round_nums: tiny_vec!(round_num, 0, 0, 0), is_echo: false, } } /// Prefixes this round ID (possibly already nested) with a group number. - /// - /// **Warning:** the maximum nesting depth is 8. Panics if this nesting overflows it. - pub(crate) const fn group_under(&self, round_num: u8) -> Self { - if self.depth as usize == ROUND_ID_DEPTH { - panic!("Maximum depth reached"); - } - let mut round_nums = self.round_nums; - - // Would use `expect("Depth within range")` here, but `expect()` in const fns is unstable. - #[allow(clippy::indexing_slicing)] - { - round_nums[self.depth as usize] = round_num; - } - + pub(crate) fn group_under(&self, round_num: u8) -> Self { + let mut round_nums = self.round_nums.clone(); + round_nums.push(round_num); Self { - depth: self.depth + 1, round_nums, is_echo: self.is_echo, } @@ -99,13 +77,12 @@ impl RoundId { /// /// Returns the `Err` variant if the round ID is not nested. pub(crate) fn ungroup(&self) -> Result { - if self.depth == 1 { + if self.round_nums.len() == 1 { Err(LocalError::new("This round ID is not in a group")) } else { - let mut round_nums = self.round_nums; - *round_nums.get_mut(self.depth as usize - 1).expect("Depth within range") = 0; + let mut round_nums = self.round_nums.clone(); + round_nums.pop().expect("vector size greater than 1"); Ok(Self { - depth: self.depth - 1, round_nums, is_echo: self.is_echo, }) @@ -127,8 +104,7 @@ impl RoundId { panic!("This is already an echo round ID"); } Self { - depth: self.depth, - round_nums: self.round_nums, + round_nums: self.round_nums.clone(), is_echo: true, } } @@ -143,8 +119,7 @@ impl RoundId { panic!("This is already an non-echo round ID"); } Self { - depth: self.depth, - round_nums: self.round_nums, + round_nums: self.round_nums.clone(), is_echo: false, } } diff --git a/manul/src/session/evidence.rs b/manul/src/session/evidence.rs index c0fab32..eafd75c 100644 --- a/manul/src/session/evidence.rs +++ b/manul/src/session/evidence.rs @@ -100,8 +100,8 @@ where .iter() .map(|round_id| { transcript - .get_echo_broadcast(*round_id, verifier) - .map(|echo| (*round_id, echo)) + .get_echo_broadcast(round_id.clone(), verifier) + .map(|echo| (round_id.clone(), echo)) }) .collect::, _>>()?; @@ -110,8 +110,8 @@ where .iter() .map(|round_id| { transcript - .get_normal_broadcast(*round_id, verifier) - .map(|bc| (*round_id, bc)) + .get_normal_broadcast(round_id.clone(), verifier) + .map(|bc| (round_id.clone(), bc)) }) .collect::, _>>()?; @@ -120,8 +120,8 @@ where .iter() .map(|round_id| { transcript - .get_direct_message(*round_id, verifier) - .map(|dm| (*round_id, dm)) + .get_direct_message(round_id.clone(), verifier) + .map(|dm| (round_id.clone(), dm)) }) .collect::, _>>()?; @@ -131,7 +131,7 @@ where .map(|round_id| { transcript .get_normal_broadcast(round_id.echo(), verifier) - .map(|dm| (*round_id, dm)) + .map(|dm| (round_id.clone(), dm)) }) .collect::, _>>()?; @@ -470,12 +470,12 @@ where for (round_id, direct_message) in self.direct_messages.iter() { let verified_direct_message = direct_message.clone().verify::(verifier)?; let metadata = verified_direct_message.metadata(); - if metadata.session_id() != session_id || metadata.round_id() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); } - verified_direct_messages.insert(*round_id, verified_direct_message.payload().clone()); + verified_direct_messages.insert(round_id.clone(), verified_direct_message.payload().clone()); } let verified_echo_broadcast = self.echo_broadcast.clone().verify::(verifier)?.payload().clone(); @@ -500,31 +500,31 @@ where for (round_id, echo_broadcast) in self.echo_broadcasts.iter() { let verified_echo_broadcast = echo_broadcast.clone().verify::(verifier)?; let metadata = verified_echo_broadcast.metadata(); - if metadata.session_id() != session_id || metadata.round_id() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); } - verified_echo_broadcasts.insert(*round_id, verified_echo_broadcast.payload().clone()); + verified_echo_broadcasts.insert(round_id.clone(), verified_echo_broadcast.payload().clone()); } let mut verified_normal_broadcasts = BTreeMap::new(); for (round_id, normal_broadcast) in self.normal_broadcasts.iter() { let verified_normal_broadcast = normal_broadcast.clone().verify::(verifier)?; let metadata = verified_normal_broadcast.metadata(); - if metadata.session_id() != session_id || metadata.round_id() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); } - verified_normal_broadcasts.insert(*round_id, verified_normal_broadcast.payload().clone()); + verified_normal_broadcasts.insert(round_id.clone(), verified_normal_broadcast.payload().clone()); } let mut combined_echos = BTreeMap::new(); for (round_id, combined_echo) in self.combined_echos.iter() { let verified_combined_echo = combined_echo.clone().verify::(verifier)?; let metadata = verified_combined_echo.metadata(); - if metadata.session_id() != session_id || metadata.round_id().non_echo() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id().non_echo() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); @@ -537,14 +537,14 @@ where for (other_verifier, echo_broadcast) in echo_set.echo_broadcasts.iter() { let verified_echo_broadcast = echo_broadcast.clone().verify::(other_verifier)?; let metadata = verified_echo_broadcast.metadata(); - if metadata.session_id() != session_id || metadata.round_id() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); } verified_echo_set.push(verified_echo_broadcast.payload().clone()); } - combined_echos.insert(*round_id, verified_echo_set); + combined_echos.insert(round_id.clone(), verified_echo_set); } Ok(self.error.verify_messages_constitute_error( diff --git a/manul/src/session/message.rs b/manul/src/session/message.rs index 51dbc55..199c2f0 100644 --- a/manul/src/session/message.rs +++ b/manul/src/session/message.rs @@ -68,7 +68,7 @@ impl MessageMetadata { } pub fn round_id(&self) -> RoundId { - self.round_id + self.round_id.clone() } } diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index b8edb31..db81183 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -317,7 +317,7 @@ where } MessageFor::ThisRound } else if self.possible_next_rounds.contains(&message_round_id) { - if accum.message_is_cached(from, message_round_id) { + if accum.message_is_cached(from, &message_round_id) { let err = format!("Message for {:?} is already cached", message_round_id); accum.register_unprovable_error(from, RemoteError::new(&err))?; trace!("{key:?} {err}"); @@ -354,7 +354,7 @@ where match message_for { MessageFor::ThisRound => { accum.mark_processing(&verified_message)?; - Ok(PreprocessOutcome::ToProcess(verified_message)) + Ok(PreprocessOutcome::ToProcess(Box::new(verified_message))) } MessageFor::NextRound => { debug!("{key:?}: Caching message from {from:?} for {message_round_id}"); @@ -406,7 +406,7 @@ where ) -> Result, LocalError> { let round_id = self.round_id(); let transcript = self.transcript.update( - round_id, + &round_id, accum.echo_broadcasts, accum.normal_broadcasts, accum.direct_messages, @@ -446,7 +446,7 @@ where let round_id = self.round_id(); let transcript = self.transcript.update( - round_id, + &round_id, accum.echo_broadcasts, accum.normal_broadcasts, accum.direct_messages, @@ -604,9 +604,9 @@ where self.processing.contains(from) } - fn message_is_cached(&self, from: &SP::Verifier, round_id: RoundId) -> bool { + fn message_is_cached(&self, from: &SP::Verifier, round_id: &RoundId) -> bool { if let Some(entry) = self.cached.get(from) { - entry.contains_key(&round_id) + entry.contains_key(round_id) } else { false } @@ -745,7 +745,7 @@ where let from = message.from().clone(); let round_id = message.metadata().round_id(); let cached = self.cached.entry(from.clone()).or_default(); - if cached.insert(round_id, message).is_some() { + if cached.insert(round_id.clone(), message).is_some() { return Err(LocalError::new(format!( "A message from for {:?} has already been cached", round_id @@ -771,7 +771,7 @@ pub struct ProcessedMessage { #[derive(Debug, Clone)] pub enum PreprocessOutcome { /// The message was successfully verified, pass it on to [`Session::process_message`]. - ToProcess(VerifiedMessage), + ToProcess(Box>), /// The message was intended for the next round and was cached. /// /// No action required now, cached messages will be returned on successful [`Session::finalize_round`]. @@ -795,7 +795,7 @@ impl PreprocessOutcome { /// so the user may choose to ignore them if no logging is desired. pub fn ok(self) -> Option> { match self { - Self::ToProcess(message) => Some(message), + Self::ToProcess(message) => Some(*message), _ => None, } } diff --git a/manul/src/session/transcript.rs b/manul/src/session/transcript.rs index 382448d..3f678e5 100644 --- a/manul/src/session/transcript.rs +++ b/manul/src/session/transcript.rs @@ -36,7 +36,7 @@ where #[allow(clippy::too_many_arguments)] pub fn update( self, - round_id: RoundId, + round_id: &RoundId, echo_broadcasts: BTreeMap>, normal_broadcasts: BTreeMap>, direct_messages: BTreeMap>, @@ -45,7 +45,7 @@ where missing_messages: BTreeSet, ) -> Result { let mut all_echo_broadcasts = self.echo_broadcasts; - match all_echo_broadcasts.entry(round_id) { + match all_echo_broadcasts.entry(round_id.clone()) { Entry::Vacant(entry) => entry.insert(echo_broadcasts), Entry::Occupied(_) => { return Err(LocalError::new(format!( @@ -55,7 +55,7 @@ where }; let mut all_normal_broadcasts = self.normal_broadcasts; - match all_normal_broadcasts.entry(round_id) { + match all_normal_broadcasts.entry(round_id.clone()) { Entry::Vacant(entry) => entry.insert(normal_broadcasts), Entry::Occupied(_) => { return Err(LocalError::new(format!( @@ -65,7 +65,7 @@ where }; let mut all_direct_messages = self.direct_messages; - match all_direct_messages.entry(round_id) { + match all_direct_messages.entry(round_id.clone()) { Entry::Vacant(entry) => entry.insert(direct_messages), Entry::Occupied(_) => { return Err(LocalError::new(format!( @@ -93,7 +93,7 @@ where } let mut all_missing_messages = self.missing_messages; - match all_missing_messages.entry(round_id) { + match all_missing_messages.entry(round_id.clone()) { Entry::Vacant(entry) => entry.insert(missing_messages), Entry::Occupied(_) => { return Err(LocalError::new(format!(