From e4a61be05e04611ba15889c51e5c65c3f2be63b7 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Wed, 13 Nov 2024 10:27:10 -0800 Subject: [PATCH] Add Chain combinator --- CHANGELOG.md | 2 + examples/src/lib.rs | 1 + examples/src/simple_chain.rs | 85 ++++++ manul/src/combinators.rs | 1 + manul/src/combinators/chain.rs | 524 +++++++++++++++++++++++++++++++++ manul/src/protocol/errors.rs | 33 ++- manul/src/protocol/round.rs | 89 +++++- manul/src/session/session.rs | 6 +- manul/src/testing/run_sync.rs | 2 +- 9 files changed, 730 insertions(+), 13 deletions(-) create mode 100644 examples/src/simple_chain.rs create mode 100644 manul/src/combinators/chain.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 3c464d9..281deb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - An impl of `ProtocolError` for `()` for protocols that don't use errors. ([#60]) - A dummy `CorrectnessProof` trait. ([#60]) - A `misbehave` combinator, intended primarily for testing. ([#60]) +- A `chain` combinator for chaining two protocols. ([#60]) +- `EntryPoint::ENTRY_ROUND` constant. ([#60]) [#32]: https://github.com/entropyxyz/manul/pull/32 diff --git a/examples/src/lib.rs b/examples/src/lib.rs index c3aff76..37b869c 100644 --- a/examples/src/lib.rs +++ b/examples/src/lib.rs @@ -1,6 +1,7 @@ extern crate alloc; pub mod simple; +pub mod simple_chain; #[cfg(test)] mod simple_malicious; diff --git a/examples/src/simple_chain.rs b/examples/src/simple_chain.rs new file mode 100644 index 0000000..e187dc6 --- /dev/null +++ b/examples/src/simple_chain.rs @@ -0,0 +1,85 @@ +use core::fmt::Debug; + +use manul::{ + combinators::chain::{Chained, ChainedEntryPoint}, + protocol::PartyId, +}; + +use super::simple::{Inputs, Round1}; + +pub struct ChainedSimple; + +#[derive(Debug)] +pub struct NewInputs(Inputs); + +impl<'a, Id: PartyId> From<&'a NewInputs> for Inputs { + fn from(source: &'a NewInputs) -> Self { + source.0.clone() + } +} + +impl From<(NewInputs, u8)> for Inputs { + fn from(source: (NewInputs, u8)) -> Self { + let (inputs, _result) = source; + inputs.0 + } +} + +impl Chained for ChainedSimple { + type Inputs = NewInputs; + type EntryPoint1 = Round1; + type EntryPoint2 = Round1; +} + +pub type DoubleSimpleEntryPoint = ChainedEntryPoint; + +#[cfg(test)] +mod tests { + use alloc::collections::BTreeSet; + + use manul::{ + session::{signature::Keypair, SessionOutcome}, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + }; + use rand_core::OsRng; + use tracing_subscriber::EnvFilter; + + use super::{DoubleSimpleEntryPoint, NewInputs}; + use crate::simple::Inputs; + + #[test] + fn round() { + let signers = (0..3).map(TestSigner::new).collect::>(); + let all_ids = signers + .iter() + .map(|signer| signer.verifying_key()) + .collect::>(); + let inputs = signers + .into_iter() + .map(|signer| { + ( + signer, + NewInputs(Inputs { + all_ids: all_ids.clone(), + }), + ) + }) + .collect::>(); + + let my_subscriber = tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .finish(); + let reports = tracing::subscriber::with_default(my_subscriber, || { + run_sync::, TestSessionParams>(&mut OsRng, inputs) + .unwrap() + }); + + for (_id, report) in reports { + if let SessionOutcome::Result(result) = report.outcome { + assert_eq!(result, 3); // 0 + 1 + 2 + } else { + panic!("Session did not finish successfully"); + } + } + } +} diff --git a/manul/src/combinators.rs b/manul/src/combinators.rs index 15e2d19..848fee1 100644 --- a/manul/src/combinators.rs +++ b/manul/src/combinators.rs @@ -1,3 +1,4 @@ //! Combinators operating on protocols. pub mod misbehave; +pub mod chain; diff --git a/manul/src/combinators/chain.rs b/manul/src/combinators/chain.rs new file mode 100644 index 0000000..13450d9 --- /dev/null +++ b/manul/src/combinators/chain.rs @@ -0,0 +1,524 @@ +/*! +A combinator representing two protocols as a new protocol that, when executed, +executes the two inner protocols in sequence, feeding the result of the first protocol +into the inputs of the second protocol. + +For the session level users (that is, the ones executing the protocols) +the new protocol is a single entity with its own [`Protocol`](`crate::protocol::Protocol`) type +and an [`EntryPoint`](`crate::protocol::EntryPoint`) type. + +For example, imagine we have a `ProtocolA` with an entry point `EntryPointA`, inputs `InputsA`, +two rounds, `RA1` and `RA2`, and the result `ResultA`; +and similarly a `ProtocolB` with an entry point `EntryPointB`, inputs `InputsB`, +two rounds, `RB1` and `RB2`, and the result `ResultB`. + +Then the chained protocol will provide `ProtocolC: Protocol` and `EntryPointC: EntryPoint`, +the user will define `InputsC` for the new protocol, and the execution will look like: +- `InputsA` is created from `InputsC` via the user-defined `From` impl; +- `EntryPointA` is initialized with `InputsA`; +- `RA1` is executed; +- `RA2` is executed, producing `ResultA`; +- `InputsB` is created from `ResultA` and `InputsC` via the user-defined `From` impl; +- `RB1` is executed; +- `RB2` is executed, producing `ResultB` (which is also the result of `ChainedProtocol`). + +If the execution happens in a [`Session`](`crate::session::Session`), and there is an error at any point, +a regular evidence or correctness proof are created using the corresponding types from the new `ProtocolC`. + +The usage is as follows. + +1. Define an input type for the new joined protocol. + Most likely it will be a union between inputs of the first and the second protocol. + +2. Implement [`Chained`] for a type of your choice. Usually it will be an empty token type. + You will have to specify the entry points of the two protocols, + and the [`From`] conversions from the new input type to the inputs of both entry points + (see the corresponding associated type bounds). + +3. The entry point for the new protocol will be [`ChainedEntryPoint`] parametrized with + the type implementing [`Chained`] from step 2. + +4. The [`Protocol`](`crate::protocol::Protocol`)-implementing type for the new protocol will be + [`ChainedProtocol`] parametrized with the type implementing [`Chained`] from the step 2. +*/ + +use alloc::{ + boxed::Box, + collections::{BTreeMap, BTreeSet}, + format, + string::String, + vec::Vec, +}; +use core::{fmt::Debug, marker::PhantomData}; + +use rand_core::CryptoRngCore; +use serde::{Deserialize, Serialize}; + +use crate::protocol::*; + +/// A trait defining two protocols executed sequentially. +pub trait Chained: 'static +where + Id: PartyId, +{ + /// The inputs of the new chained protocol. + type Inputs: Send + Sync + Debug; + + /// The entry point of the first protocol. + type EntryPoint1: EntryPoint From<&'a Self::Inputs>>; + + /// The entry point of the second protocol. + type EntryPoint2: EntryPoint< + Id, + Inputs: From<( + Self::Inputs, + <>::Protocol as Protocol>::Result, + )>, + >; +} + +/// The protocol error type for the chained protocol. +#[derive_where::derive_where(Debug, Clone)] +#[derive(Serialize, Deserialize)] +#[serde(bound(serialize = " + <>::Protocol as Protocol>::ProtocolError: Serialize, + <>::Protocol as Protocol>::ProtocolError: Serialize, +"))] +#[serde(bound(deserialize = " + <>::Protocol as Protocol>::ProtocolError: for<'x> Deserialize<'x>, + <>::Protocol as Protocol>::ProtocolError: for<'x> Deserialize<'x>, +"))] +pub enum ChainedProtocolError> { + /// A protocol error from the first protocol. + Protocol1(<>::Protocol as Protocol>::ProtocolError), + /// A protocol error from the second protocol. + Protocol2(<>::Protocol as Protocol>::ProtocolError), +} + +impl ChainedProtocolError +where + Id: PartyId, + C: Chained, +{ + fn from_protocol1(err: <>::Protocol as Protocol>::ProtocolError) -> Self { + Self::Protocol1(err) + } + + fn from_protocol2(err: <>::Protocol as Protocol>::ProtocolError) -> Self { + Self::Protocol2(err) + } +} + +impl ProtocolError for ChainedProtocolError +where + Id: PartyId, + C: Chained, +{ + fn description(&self) -> String { + match self { + Self::Protocol1(err) => format!("Protocol1: {}", err.description()), + Self::Protocol2(err) => format!("Protocol2: {}", err.description()), + } + } + + fn required_direct_messages(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_direct_messages()), + Self::Protocol2(err) => (2, err.required_direct_messages()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + fn required_echo_broadcasts(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_echo_broadcasts()), + Self::Protocol2(err) => (2, err.required_echo_broadcasts()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + fn required_normal_broadcasts(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_normal_broadcasts()), + Self::Protocol2(err) => (2, err.required_normal_broadcasts()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + fn required_combined_echos(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_combined_echos()), + Self::Protocol2(err) => (2, err.required_combined_echos()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + #[allow(clippy::too_many_arguments)] + fn verify_messages_constitute_error( + &self, + deserializer: &Deserializer, + echo_broadcast: &EchoBroadcast, + normal_broadcast: &NormalBroadcast, + direct_message: &DirectMessage, + echo_broadcasts: &BTreeMap, + normal_broadcasts: &BTreeMap, + direct_messages: &BTreeMap, + combined_echos: &BTreeMap>, + ) -> Result<(), ProtocolValidationError> { + // TODO: the cloning can be avoided if instead we provide a reference to some "transcript API", + // and can replace it here with a proxy that will remove nesting from round ID's. + let echo_broadcasts = echo_broadcasts + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let normal_broadcasts = normal_broadcasts + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let direct_messages = direct_messages + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let combined_echos = combined_echos + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + + match self { + Self::Protocol1(err) => err.verify_messages_constitute_error( + deserializer, + echo_broadcast, + normal_broadcast, + direct_message, + &echo_broadcasts, + &normal_broadcasts, + &direct_messages, + &combined_echos, + ), + Self::Protocol2(err) => err.verify_messages_constitute_error( + deserializer, + echo_broadcast, + normal_broadcast, + direct_message, + &echo_broadcasts, + &normal_broadcasts, + &direct_messages, + &combined_echos, + ), + } + } +} + +/// The correctness proof type for the chained protocol. +#[derive_where::derive_where(Debug, Clone)] +#[derive(Serialize, Deserialize)] +#[serde(bound(serialize = " + <>::Protocol as Protocol>::CorrectnessProof: Serialize, + <>::Protocol as Protocol>::CorrectnessProof: Serialize, +"))] +#[serde(bound(deserialize = " + <>::Protocol as Protocol>::CorrectnessProof: for<'x> Deserialize<'x>, + <>::Protocol as Protocol>::CorrectnessProof: for<'x> Deserialize<'x>, +"))] +pub enum ChainedCorrectnessProof +where + Id: PartyId, + C: Chained, +{ + /// A correctness proof from the first protocol. + Protocol1(<>::Protocol as Protocol>::CorrectnessProof), + /// A correctness proof from the second protocol. + Protocol2(<>::Protocol as Protocol>::CorrectnessProof), +} + +impl ChainedCorrectnessProof +where + Id: PartyId, + C: Chained, +{ + fn from_protocol1(proof: <>::Protocol as Protocol>::CorrectnessProof) -> Self { + Self::Protocol1(proof) + } + + fn from_protocol2(proof: <>::Protocol as Protocol>::CorrectnessProof) -> Self { + Self::Protocol2(proof) + } +} + +impl CorrectnessProof for ChainedCorrectnessProof +where + Id: PartyId, + C: Chained, +{ +} + +/// The protocol resulting from chaining two sub-protocols as described by `C`. +#[derive(Debug)] +#[allow(clippy::type_complexity)] +pub struct ChainedProtocol>(PhantomData (Id, C)>); + +impl Protocol for ChainedProtocol +where + Id: PartyId, + C: Chained, +{ + type Result = <>::Protocol as Protocol>::Result; + type ProtocolError = ChainedProtocolError; + type CorrectnessProof = ChainedCorrectnessProof; +} + +/// The entry point of the chained protocol. +#[derive_where::derive_where(Debug)] +pub struct ChainedEntryPoint> { + state: ChainState, +} + +#[derive_where::derive_where(Debug)] +enum ChainState +where + Id: PartyId, + C: Chained, +{ + Protocol1 { + round: BoxedRound>::Protocol>, + shared_randomness: Box<[u8]>, + id: Id, + inputs: C::Inputs, + }, + Protocol2(BoxedRound>::Protocol>), +} + +impl EntryPoint for ChainedEntryPoint +where + Id: PartyId, + C: Chained, +{ + type Inputs = C::Inputs; + type Protocol = ChainedProtocol; + + fn entry_round() -> RoundId { + >::entry_round().group_under(1) + } + + fn new( + rng: &mut impl CryptoRngCore, + shared_randomness: &[u8], + id: Id, + inputs: Self::Inputs, + ) -> Result, LocalError> { + let round = C::EntryPoint1::new(rng, shared_randomness, id.clone(), (&inputs).into())?; + let round = ChainedEntryPoint { + state: ChainState::Protocol1 { + shared_randomness: shared_randomness.into(), + id, + inputs, + round, + }, + }; + Ok(BoxedRound::new_object_safe(round)) + } +} + +impl ObjectSafeRound for ChainedEntryPoint +where + Id: PartyId, + C: Chained, +{ + type Protocol = ChainedProtocol; + + fn id(&self) -> RoundId { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().id().group_under(1), + ChainState::Protocol2(round) => round.as_ref().id().group_under(2), + } + } + + fn possible_next_rounds(&self) -> BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => { + let mut next_rounds = round + .as_ref() + .possible_next_rounds() + .into_iter() + .map(|round_id| round_id.group_under(1)) + .collect::>(); + + // If there are no next rounds, this is the result round. + // This means that in the chain the next round will be the entry round of the second protocol. + if next_rounds.is_empty() { + next_rounds.insert(C::EntryPoint2::entry_round().group_under(2)); + } + next_rounds + } + ChainState::Protocol2(round) => round + .as_ref() + .possible_next_rounds() + .into_iter() + .map(|round_id| round_id.group_under(2)) + .collect(), + } + } + + fn message_destinations(&self) -> &BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().message_destinations(), + ChainState::Protocol2(round) => round.as_ref().message_destinations(), + } + } + + fn make_direct_message( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + destination: &Id, + ) -> Result<(DirectMessage, Option), LocalError> { + match &self.state { + ChainState::Protocol1 { round, .. } => { + round + .as_ref() + .make_direct_message(rng, serializer, deserializer, destination) + } + ChainState::Protocol2(round) => { + round + .as_ref() + .make_direct_message(rng, serializer, deserializer, destination) + } + } + } + + fn make_echo_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + ) -> Result { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().make_echo_broadcast(rng, serializer, deserializer), + ChainState::Protocol2(round) => round.as_ref().make_echo_broadcast(rng, serializer, deserializer), + } + } + + fn make_normal_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + ) -> Result { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().make_normal_broadcast(rng, serializer, deserializer), + ChainState::Protocol2(round) => round.as_ref().make_normal_broadcast(rng, serializer, deserializer), + } + } + + fn receive_message( + &self, + rng: &mut dyn CryptoRngCore, + deserializer: &Deserializer, + from: &Id, + echo_broadcast: EchoBroadcast, + normal_broadcast: NormalBroadcast, + direct_message: DirectMessage, + ) -> Result> { + match &self.state { + ChainState::Protocol1 { round, .. } => match round.as_ref().receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) { + Ok(payload) => Ok(payload), + Err(err) => Err(err.map(ChainedProtocolError::from_protocol1)), + }, + ChainState::Protocol2(round) => match round.as_ref().receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) { + Ok(payload) => Ok(payload), + Err(err) => Err(err.map(ChainedProtocolError::from_protocol2)), + }, + } + } + + fn finalize( + self: Box, + rng: &mut dyn CryptoRngCore, + payloads: BTreeMap, + artifacts: BTreeMap, + ) -> Result, FinalizeError> { + match self.state { + ChainState::Protocol1 { + round, + id, + inputs, + shared_randomness, + } => match round.into_boxed().finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => { + let mut boxed_rng = BoxedRng(rng); + let round = C::EntryPoint2::new(&mut boxed_rng, &shared_randomness, id, (inputs, result).into())?; + + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe( + ChainedEntryPoint:: { + state: ChainState::Protocol2(round), + }, + ))) + } + Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound( + BoxedRound::new_object_safe(ChainedEntryPoint:: { + state: ChainState::Protocol1 { + shared_randomness, + id, + inputs, + round, + }, + }), + )), + Err(FinalizeError::Local(err)) => Err(FinalizeError::Local(err)), + Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable( + ChainedCorrectnessProof::from_protocol1(proof), + )), + }, + ChainState::Protocol2(round) => match round.into_boxed().finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), + Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound( + BoxedRound::new_object_safe(ChainedEntryPoint:: { + state: ChainState::Protocol2(round), + }), + )), + Err(FinalizeError::Local(err)) => Err(FinalizeError::Local(err)), + Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable( + ChainedCorrectnessProof::from_protocol2(proof), + )), + }, + } + } + + fn expecting_messages_from(&self) -> &BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().expecting_messages_from(), + ChainState::Protocol2(round) => round.as_ref().expecting_messages_from(), + } + } +} diff --git a/manul/src/protocol/errors.rs b/manul/src/protocol/errors.rs index a77ecc3..a5820f2 100644 --- a/manul/src/protocol/errors.rs +++ b/manul/src/protocol/errors.rs @@ -1,4 +1,4 @@ -use alloc::{format, string::String}; +use alloc::{boxed::Box, format, string::String}; use core::fmt::Debug; use super::round::Protocol; @@ -50,7 +50,7 @@ pub(crate) enum ReceiveErrorType { // so this whole enum is crate-private and the variants are created // via constructors and From impls. /// An echo round error occurred. - Echo(EchoRoundError), + Echo(Box>), } impl ReceiveError { @@ -68,6 +68,33 @@ impl ReceiveError { pub fn protocol(error: P::ProtocolError) -> Self { Self(ReceiveErrorType::Protocol(error)) } + + /// Maps the error to a different protocol, given the mapping function for protocol errors. + pub(crate) fn map(self, f: F) -> ReceiveError + where + F: Fn(P::ProtocolError) -> T::ProtocolError, + T: Protocol, + { + ReceiveError(self.0.map::(f)) + } +} + +impl ReceiveErrorType { + pub(crate) fn map(self, f: F) -> ReceiveErrorType + where + F: Fn(P::ProtocolError) -> T::ProtocolError, + T: Protocol, + { + match self { + Self::Local(err) => ReceiveErrorType::Local(err), + Self::InvalidDirectMessage(err) => ReceiveErrorType::InvalidDirectMessage(err), + Self::InvalidEchoBroadcast(err) => ReceiveErrorType::InvalidEchoBroadcast(err), + Self::InvalidNormalBroadcast(err) => ReceiveErrorType::InvalidNormalBroadcast(err), + Self::Unprovable(err) => ReceiveErrorType::Unprovable(err), + Self::Echo(err) => ReceiveErrorType::Echo(err), + Self::Protocol(err) => ReceiveErrorType::Protocol(f(err)), + } + } } impl From for ReceiveError @@ -93,7 +120,7 @@ where P: Protocol, { fn from(error: EchoRoundError) -> Self { - Self(ReceiveErrorType::Echo(error)) + Self(ReceiveErrorType::Echo(Box::new(error))) } } diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index 1f8c556..794677c 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -5,7 +5,10 @@ use alloc::{ string::String, vec::Vec, }; -use core::{any::Any, fmt::Debug}; +use core::{ + any::Any, + fmt::{self, Debug, Display}, +}; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; @@ -26,22 +29,89 @@ pub enum FinalizeOutcome { Result(P::Result), } +// Maximum depth of group nesting in RoundIds. +// We need this to be limited to allow the nesting to be performed in `const` context +// (since we cannot use heap there). +const ROUND_ID_DEPTH: usize = 8; + /// A round identifier. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct RoundId { - round_num: u8, + depth: u8, + round_nums: [u8; ROUND_ID_DEPTH], is_echo: bool, } +impl Display for RoundId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Round ")?; + for i in (0..self.depth as usize).rev() { + write!(f, "{}", self.round_nums.get(i).expect("Depth within range"))?; + if i != 0 { + write!(f, "-")?; + } + } + if self.is_echo { + write!(f, " (echo)")?; + } + Ok(()) + } +} + impl RoundId { /// Creates a new round identifier. - pub fn new(round_num: u8) -> Self { + pub const fn new(round_num: u8) -> Self { + let mut round_nums = [0u8; ROUND_ID_DEPTH]; + #[allow(clippy::indexing_slicing)] + { + round_nums[0] = round_num; + } Self { - round_num, + depth: 1, + round_nums, is_echo: false, } } + /// Prefixes this round ID (possibly already nested) with a group number. + /// + /// **Warning:** the maximum nesting depth is 8. Panics if this nesting overflows it. + pub(crate) const fn group_under(&self, round_num: u8) -> Self { + if self.depth as usize == ROUND_ID_DEPTH { + panic!("Maximum depth reached"); + } + let mut round_nums = self.round_nums; + + // Would use `expect("Depth within range")` here, but `expect()` in const fns is unstable. + #[allow(clippy::indexing_slicing)] + { + round_nums[self.depth as usize] = round_num; + } + + Self { + depth: self.depth + 1, + round_nums, + is_echo: self.is_echo, + } + } + + /// Removes the top group prefix from this round ID. + /// + /// Returns the `Err` variant if the round ID is not nested. + pub(crate) fn ungroup(&self) -> Result { + if self.depth == 1 { + Err(LocalError::new("This round ID is not in a group")) + } else { + let mut round_nums = self.round_nums; + *round_nums.get_mut(self.depth as usize - 1).expect("Depth within range") = 0; + Ok(Self { + depth: self.depth - 1, + round_nums, + is_echo: self.is_echo, + }) + } + } + /// Returns `true` if this is an ID of an echo broadcast round. pub(crate) fn is_echo(&self) -> bool { self.is_echo @@ -57,7 +127,8 @@ impl RoundId { panic!("This is already an echo round ID"); } Self { - round_num: self.round_num, + depth: self.depth, + round_nums: self.round_nums, is_echo: true, } } @@ -72,7 +143,8 @@ impl RoundId { panic!("This is already an non-echo round ID"); } Self { - round_num: self.round_num, + depth: self.depth, + round_nums: self.round_nums, is_echo: false, } } @@ -299,6 +371,11 @@ pub trait EntryPoint { /// The protocol implemented by the round this entry points returns. type Protocol: Protocol; + /// Returns the ID of the round returned by [`Self::new`]. + fn entry_round() -> RoundId { + RoundId::new(1) + } + /// Creates the round. /// /// `session_id` can be assumed to be the same for each node participating in a session. diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index ad33fc3..b8edb31 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -349,7 +349,7 @@ where } Err(MessageVerificationError::Local(error)) => return Err(error), }; - debug!("{key:?}: Received {message_round_id:?} message from {from:?}"); + debug!("{key:?}: Received {message_round_id} message from {from:?}"); match message_for { MessageFor::ThisRound => { @@ -357,7 +357,7 @@ where Ok(PreprocessOutcome::ToProcess(verified_message)) } MessageFor::NextRound => { - debug!("{key:?}: Caching message from {from:?} for {message_round_id:?}"); + debug!("{key:?}: Caching message from {from:?} for {message_round_id}"); accum.cache_message(verified_message)?; Ok(PreprocessOutcome::Cached) } @@ -734,7 +734,7 @@ where } ReceiveErrorType::Echo(error) => { let (_echo_broadcast, normal_broadcast, _direct_message) = processed.message.into_parts(); - let evidence = Evidence::new_echo_round_error(&from, normal_broadcast, error)?; + let evidence = Evidence::new_echo_round_error(&from, normal_broadcast, *error)?; self.register_provable_error(&from, evidence) } ReceiveErrorType::Local(error) => Err(error), diff --git a/manul/src/testing/run_sync.rs b/manul/src/testing/run_sync.rs index aa3db0c..8d08c87 100644 --- a/manul/src/testing/run_sync.rs +++ b/manul/src/testing/run_sync.rs @@ -45,7 +45,7 @@ where let state = loop { match session.can_finalize(&accum) { CanFinalize::Yes => { - debug!("{:?}: finalizing {:?}", session.verifier(), session.round_id(),); + debug!("{:?}: finalizing {}", session.verifier(), session.round_id()); match session.finalize_round(rng, accum)? { RoundOutcome::Finished(report) => break State::Finished(report), RoundOutcome::AnotherRound {