diff --git a/examples/src/lib.rs b/examples/src/lib.rs index c3aff76..37b869c 100644 --- a/examples/src/lib.rs +++ b/examples/src/lib.rs @@ -1,6 +1,7 @@ extern crate alloc; pub mod simple; +pub mod simple_chain; #[cfg(test)] mod simple_malicious; diff --git a/examples/src/simple.rs b/examples/src/simple.rs index eeab9ab..aaebbea 100644 --- a/examples/src/simple.rs +++ b/examples/src/simple.rs @@ -152,6 +152,7 @@ struct Round1Payload { impl EntryPoint for Round1 { type Inputs = Inputs; type Protocol = SimpleProtocol; + fn new( _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], diff --git a/examples/src/simple_chain.rs b/examples/src/simple_chain.rs new file mode 100644 index 0000000..a1da175 --- /dev/null +++ b/examples/src/simple_chain.rs @@ -0,0 +1,84 @@ +use core::fmt::Debug; + +use manul::{ + combinators::{ChainEntryPoint, Chained}, + protocol::PartyId, +}; + +use super::simple::{Inputs, Round1}; + +pub struct ChainedSimple; + +#[derive(Debug)] +pub struct NewInputs(Inputs); + +impl<'a, Id: PartyId> From<&'a NewInputs> for Inputs { + fn from(source: &'a NewInputs) -> Self { + source.0.clone() + } +} + +impl From<(NewInputs, u8)> for Inputs { + fn from(source: (NewInputs, u8)) -> Self { + source.0 .0 + } +} + +impl Chained for ChainedSimple { + type Inputs = NewInputs; + type EntryPoint1 = Round1; + type EntryPoint2 = Round1; +} + +pub type DoubleSimpleEntryPoint = ChainEntryPoint; + +#[cfg(test)] +mod tests { + use alloc::collections::BTreeSet; + + use manul::{ + session::{signature::Keypair, SessionOutcome}, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + }; + use rand_core::OsRng; + use tracing_subscriber::EnvFilter; + + use super::{DoubleSimpleEntryPoint, NewInputs}; + use crate::simple::Inputs; + + #[test] + fn round() { + let signers = (0..3).map(TestSigner::new).collect::>(); + let all_ids = signers + .iter() + .map(|signer| signer.verifying_key()) + .collect::>(); + let inputs = signers + .into_iter() + .map(|signer| { + ( + signer, + NewInputs(Inputs { + all_ids: all_ids.clone(), + }), + ) + }) + .collect::>(); + + let my_subscriber = tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .finish(); + let reports = tracing::subscriber::with_default(my_subscriber, || { + run_sync::, TestSessionParams>(&mut OsRng, inputs) + .unwrap() + }); + + for (_id, report) in reports { + if let SessionOutcome::Result(result) = report.outcome { + assert_eq!(result, 3); // 0 + 1 + 2 + } else { + panic!("Session did not finish successfully"); + } + } + } +} diff --git a/manul/benches/empty_rounds.rs b/manul/benches/empty_rounds.rs index 87eefee..9978486 100644 --- a/manul/benches/empty_rounds.rs +++ b/manul/benches/empty_rounds.rs @@ -51,6 +51,7 @@ struct Round1Artifact; impl EntryPoint for EmptyRound { type Inputs = Inputs; type Protocol = EmptyProtocol; + fn new( _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], diff --git a/manul/src/combinators.rs b/manul/src/combinators.rs index 66d2aba..4a7aaa1 100644 --- a/manul/src/combinators.rs +++ b/manul/src/combinators.rs @@ -1,5 +1,7 @@ //! Combinators operating on protocols. +mod chain; mod misbehave; +pub use chain::{ChainEntryPoint, Chained}; pub use misbehave::{Behavior, Misbehaving, MisbehavingEntryPoint, MisbehavingInputs}; diff --git a/manul/src/combinators/chain.rs b/manul/src/combinators/chain.rs new file mode 100644 index 0000000..f160888 --- /dev/null +++ b/manul/src/combinators/chain.rs @@ -0,0 +1,470 @@ +use alloc::{ + boxed::Box, + collections::{BTreeMap, BTreeSet}, + format, + string::String, + vec::Vec, +}; +use core::{fmt::Debug, marker::PhantomData}; + +use rand_core::CryptoRngCore; +use serde::{Deserialize, Serialize}; + +use crate::protocol::*; + +/// A trait defining two protocols executed sequentially. +pub trait Chained: 'static +where + Id: PartyId, +{ + /// The inputs of the new chained protocol. + type Inputs: Send + Sync + Debug; + + /// The entry point of the first protocol. + type EntryPoint1: EntryPoint From<&'a Self::Inputs>>; + + /// The entry point of the second protocol. + type EntryPoint2: EntryPoint< + Id, + Inputs: From<( + Self::Inputs, + <>::Protocol as Protocol>::Result, + )>, + >; +} + +#[derive_where::derive_where(Debug, Clone)] +#[derive(Serialize, Deserialize)] +#[serde(bound(serialize = " + <>::Protocol as Protocol>::ProtocolError: Serialize, + <>::Protocol as Protocol>::ProtocolError: Serialize, +"))] +#[serde(bound(deserialize = " + <>::Protocol as Protocol>::ProtocolError: for<'x> Deserialize<'x>, + <>::Protocol as Protocol>::ProtocolError: for<'x> Deserialize<'x>, +"))] +pub enum ChainedProtocolError> { + Protocol1(<>::Protocol as Protocol>::ProtocolError), + Protocol2(<>::Protocol as Protocol>::ProtocolError), +} + +impl ChainedProtocolError +where + Id: PartyId, + C: Chained, +{ + fn from_protocol1(err: <>::Protocol as Protocol>::ProtocolError) -> Self { + Self::Protocol1(err) + } + + fn from_protocol2(err: <>::Protocol as Protocol>::ProtocolError) -> Self { + Self::Protocol2(err) + } +} + +impl ProtocolError for ChainedProtocolError +where + Id: PartyId, + C: Chained, +{ + fn description(&self) -> String { + match self { + Self::Protocol1(err) => format!("Protocol1: {}", err.description()), + Self::Protocol2(err) => format!("Protocol2: {}", err.description()), + } + } + + fn required_direct_messages(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_direct_messages()), + Self::Protocol2(err) => (2, err.required_direct_messages()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + fn required_echo_broadcasts(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_echo_broadcasts()), + Self::Protocol2(err) => (2, err.required_echo_broadcasts()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + fn required_normal_broadcasts(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_normal_broadcasts()), + Self::Protocol2(err) => (2, err.required_normal_broadcasts()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + fn required_combined_echos(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_combined_echos()), + Self::Protocol2(err) => (2, err.required_combined_echos()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + #[allow(clippy::too_many_arguments)] + fn verify_messages_constitute_error( + &self, + deserializer: &Deserializer, + echo_broadcast: &EchoBroadcast, + normal_broadcast: &NormalBroadcast, + direct_message: &DirectMessage, + echo_broadcasts: &BTreeMap, + normal_broadcasts: &BTreeMap, + direct_messages: &BTreeMap, + combined_echos: &BTreeMap>, + ) -> Result<(), ProtocolValidationError> { + // TODO: the cloning can be avoided if instead we provide a reference to some "transcript API", + // and can replace it here with a proxy that will remove nesting from round ID's. + let echo_broadcasts = echo_broadcasts + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let normal_broadcasts = normal_broadcasts + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let direct_messages = direct_messages + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let combined_echos = combined_echos + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + + match self { + Self::Protocol1(err) => err.verify_messages_constitute_error( + deserializer, + echo_broadcast, + normal_broadcast, + direct_message, + &echo_broadcasts, + &normal_broadcasts, + &direct_messages, + &combined_echos, + ), + Self::Protocol2(err) => err.verify_messages_constitute_error( + deserializer, + echo_broadcast, + normal_broadcast, + direct_message, + &echo_broadcasts, + &normal_broadcasts, + &direct_messages, + &combined_echos, + ), + } + } +} + +#[derive_where::derive_where(Debug, Clone)] +#[derive(Serialize, Deserialize)] +#[serde(bound(serialize = " + <>::Protocol as Protocol>::CorrectnessProof: Serialize, + <>::Protocol as Protocol>::CorrectnessProof: Serialize, +"))] +#[serde(bound(deserialize = " + <>::Protocol as Protocol>::CorrectnessProof: for<'x> Deserialize<'x>, + <>::Protocol as Protocol>::CorrectnessProof: for<'x> Deserialize<'x>, +"))] +pub enum ChainedCorrectnessProof +where + Id: PartyId, + C: Chained, +{ + Protocol1(<>::Protocol as Protocol>::CorrectnessProof), + Protocol2(<>::Protocol as Protocol>::CorrectnessProof), +} + +impl ChainedCorrectnessProof +where + Id: PartyId, + C: Chained, +{ + fn from_protocol1(proof: <>::Protocol as Protocol>::CorrectnessProof) -> Self { + Self::Protocol1(proof) + } + + fn from_protocol2(proof: <>::Protocol as Protocol>::CorrectnessProof) -> Self { + Self::Protocol2(proof) + } +} + +impl CorrectnessProof for ChainedCorrectnessProof +where + Id: PartyId, + C: Chained, +{ +} + +#[derive(Debug)] +#[allow(clippy::type_complexity)] +pub struct ChainedProtocol>(PhantomData (Id, C)>); + +impl Protocol for ChainedProtocol +where + Id: PartyId, + C: Chained, +{ + type Result = <>::Protocol as Protocol>::Result; + type ProtocolError = ChainedProtocolError; + type CorrectnessProof = ChainedCorrectnessProof; +} + +/// The entry point of the chained protocol. +#[derive_where::derive_where(Debug)] +pub struct ChainEntryPoint> { + state: ChainState, +} + +#[derive_where::derive_where(Debug)] +enum ChainState +where + Id: PartyId, + C: Chained, +{ + Protocol1 { + round: BoxedRound>::Protocol>, + shared_randomness: Box<[u8]>, + id: Id, + inputs: C::Inputs, + }, + Protocol2(BoxedRound>::Protocol>), +} + +impl EntryPoint for ChainEntryPoint +where + Id: PartyId, + C: Chained, +{ + type Inputs = C::Inputs; + type Protocol = ChainedProtocol; + const ENTRY_ROUND: RoundId = >::ENTRY_ROUND.group_under(1); + + fn new( + rng: &mut impl CryptoRngCore, + shared_randomness: &[u8], + id: Id, + inputs: Self::Inputs, + ) -> Result, LocalError> { + let round = C::EntryPoint1::new(rng, shared_randomness, id.clone(), (&inputs).into())?; + let round = ChainEntryPoint { + state: ChainState::Protocol1 { + shared_randomness: shared_randomness.into(), + id, + inputs, + round, + }, + }; + Ok(BoxedRound::new_object_safe(round)) + } +} + +impl ObjectSafeRound for ChainEntryPoint +where + Id: PartyId, + C: Chained, +{ + type Protocol = ChainedProtocol; + + fn id(&self) -> RoundId { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().id().group_under(1), + ChainState::Protocol2(round) => round.as_ref().id().group_under(2), + } + } + + fn possible_next_rounds(&self) -> BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => { + let mut next_rounds = round + .as_ref() + .possible_next_rounds() + .into_iter() + .map(|round_id| round_id.group_under(1)) + .collect::>(); + + // If there are no next rounds, this is the result round. + // This means that in the chain the next round will be the entry round of the second protocol. + if next_rounds.is_empty() { + next_rounds.insert(C::EntryPoint2::ENTRY_ROUND.group_under(2)); + } + next_rounds + } + ChainState::Protocol2(round) => round + .as_ref() + .possible_next_rounds() + .into_iter() + .map(|round_id| round_id.group_under(2)) + .collect(), + } + } + + fn message_destinations(&self) -> &BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().message_destinations(), + ChainState::Protocol2(round) => round.as_ref().message_destinations(), + } + } + + fn make_direct_message( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + destination: &Id, + ) -> Result<(DirectMessage, Option), LocalError> { + match &self.state { + ChainState::Protocol1 { round, .. } => { + round + .as_ref() + .make_direct_message(rng, serializer, deserializer, destination) + } + ChainState::Protocol2(round) => { + round + .as_ref() + .make_direct_message(rng, serializer, deserializer, destination) + } + } + } + + fn make_echo_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + ) -> Result { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().make_echo_broadcast(rng, serializer, deserializer), + ChainState::Protocol2(round) => round.as_ref().make_echo_broadcast(rng, serializer, deserializer), + } + } + + fn make_normal_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + ) -> Result { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().make_normal_broadcast(rng, serializer, deserializer), + ChainState::Protocol2(round) => round.as_ref().make_normal_broadcast(rng, serializer, deserializer), + } + } + + fn receive_message( + &self, + rng: &mut dyn CryptoRngCore, + deserializer: &Deserializer, + from: &Id, + echo_broadcast: EchoBroadcast, + normal_broadcast: NormalBroadcast, + direct_message: DirectMessage, + ) -> Result> { + match &self.state { + ChainState::Protocol1 { round, .. } => match round.as_ref().receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) { + Ok(payload) => Ok(payload), + Err(err) => Err(err.map(ChainedProtocolError::from_protocol1)), + }, + ChainState::Protocol2(round) => match round.as_ref().receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) { + Ok(payload) => Ok(payload), + Err(err) => Err(err.map(ChainedProtocolError::from_protocol2)), + }, + } + } + + fn finalize( + self: Box, + rng: &mut dyn CryptoRngCore, + payloads: BTreeMap, + artifacts: BTreeMap, + ) -> Result, FinalizeError> { + match self.state { + ChainState::Protocol1 { + round, + id, + inputs, + shared_randomness, + } => match round.into_boxed().finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => { + let mut boxed_rng = BoxedRng(rng); + let round = C::EntryPoint2::new(&mut boxed_rng, &shared_randomness, id, (inputs, result).into())?; + + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe( + ChainEntryPoint:: { + state: ChainState::Protocol2(round), + }, + ))) + } + Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound( + BoxedRound::new_object_safe(ChainEntryPoint:: { + state: ChainState::Protocol1 { + shared_randomness, + id, + inputs, + round, + }, + }), + )), + Err(FinalizeError::Local(err)) => Err(FinalizeError::Local(err)), + Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable( + ChainedCorrectnessProof::from_protocol1(proof), + )), + }, + ChainState::Protocol2(round) => match round.into_boxed().finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), + Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound( + BoxedRound::new_object_safe(ChainEntryPoint:: { + state: ChainState::Protocol2(round), + }), + )), + Err(FinalizeError::Local(err)) => Err(FinalizeError::Local(err)), + Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable( + ChainedCorrectnessProof::from_protocol2(proof), + )), + }, + } + } + + fn expecting_messages_from(&self) -> &BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().expecting_messages_from(), + ChainState::Protocol2(round) => round.as_ref().expecting_messages_from(), + } + } +} diff --git a/manul/src/protocol/errors.rs b/manul/src/protocol/errors.rs index a77ecc3..c874ea2 100644 --- a/manul/src/protocol/errors.rs +++ b/manul/src/protocol/errors.rs @@ -1,4 +1,4 @@ -use alloc::{format, string::String}; +use alloc::{boxed::Box, format, string::String}; use core::fmt::Debug; use super::round::Protocol; @@ -50,7 +50,7 @@ pub(crate) enum ReceiveErrorType { // so this whole enum is crate-private and the variants are created // via constructors and From impls. /// An echo round error occurred. - Echo(EchoRoundError), + Echo(Box>), } impl ReceiveError { @@ -68,6 +68,32 @@ impl ReceiveError { pub fn protocol(error: P::ProtocolError) -> Self { Self(ReceiveErrorType::Protocol(error)) } + + pub(crate) fn map(self, f: F) -> ReceiveError + where + F: Fn(P::ProtocolError) -> T::ProtocolError, + T: Protocol, + { + ReceiveError(self.0.map::(f)) + } +} + +impl ReceiveErrorType { + pub(crate) fn map(self, f: F) -> ReceiveErrorType + where + F: Fn(P::ProtocolError) -> T::ProtocolError, + T: Protocol, + { + match self { + Self::Local(err) => ReceiveErrorType::Local(err), + Self::InvalidDirectMessage(err) => ReceiveErrorType::InvalidDirectMessage(err), + Self::InvalidEchoBroadcast(err) => ReceiveErrorType::InvalidEchoBroadcast(err), + Self::InvalidNormalBroadcast(err) => ReceiveErrorType::InvalidNormalBroadcast(err), + Self::Unprovable(err) => ReceiveErrorType::Unprovable(err), + Self::Echo(err) => ReceiveErrorType::Echo(err), + Self::Protocol(err) => ReceiveErrorType::Protocol(f(err)), + } + } } impl From for ReceiveError @@ -93,7 +119,7 @@ where P: Protocol, { fn from(error: EchoRoundError) -> Self { - Self(ReceiveErrorType::Echo(error)) + Self(ReceiveErrorType::Echo(Box::new(error))) } } diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index bc58e0a..fcadb70 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -2,10 +2,13 @@ use alloc::{ boxed::Box, collections::{BTreeMap, BTreeSet}, format, - string::String, + string::{String, ToString}, vec::Vec, }; -use core::{any::Any, fmt::Debug}; +use core::{ + any::Any, + fmt::{self, Debug, Display}, +}; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; @@ -29,19 +32,77 @@ pub enum FinalizeOutcome { /// A round identifier. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct RoundId { - round_num: u8, + depth: u8, + round_num: [u8; 8], is_echo: bool, } +impl Display for RoundId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + /*write!(f, "Round ")?; + for i in (0..self.depth as usize).rev() { + write!(f, "{}", self.round_num.get(i).ok_or_else(fmt::Error::)?)?; + if i != 0 { + write!(f, "-")?; + } + } + Ok(())*/ + let full_num = self + .round_num + .get(0..self.depth as usize) + .expect("Depth within range") + .iter() + .rev() + .map(|round| round.to_string()) + .collect::>() + .join("-"); + write!(f, "Round {}", full_num) + } +} + impl RoundId { /// Creates a new round identifier. - pub fn new(round_num: u8) -> Self { + pub const fn new(round_num: u8) -> Self { Self { - round_num, + depth: 1, + round_num: [round_num, 0, 0, 0, 0, 0, 0, 0], is_echo: false, } } + pub(crate) const fn group_under(&self, round_num: u8) -> Self { + if self.depth == 8 { + panic!("Maximum depth reached"); + } + let mut round_nums = self.round_num; + + // Would use `expect("Depth within range")` here, but `expect()` in const fns is unstable. + #[allow(clippy::indexing_slicing)] + { + round_nums[self.depth as usize] = round_num; + } + + Self { + depth: self.depth + 1, + round_num: round_nums, + is_echo: self.is_echo, + } + } + + pub(crate) fn ungroup(&self) -> Result { + if self.depth == 1 { + Err(LocalError::new("This round ID is not in a group")) + } else { + let mut round_nums = self.round_num; + *round_nums.get_mut(self.depth as usize - 1).expect("Depth within range") = 0; + Ok(Self { + depth: self.depth - 1, + round_num: round_nums, + is_echo: self.is_echo, + }) + } + } + /// Returns `true` if this is an ID of an echo broadcast round. pub(crate) fn is_echo(&self) -> bool { self.is_echo @@ -57,6 +118,7 @@ impl RoundId { panic!("This is already an echo round ID"); } Self { + depth: self.depth, round_num: self.round_num, is_echo: true, } @@ -72,6 +134,7 @@ impl RoundId { panic!("This is already an non-echo round ID"); } Self { + depth: self.depth, round_num: self.round_num, is_echo: false, } @@ -299,6 +362,9 @@ pub trait EntryPoint { /// The protocol implemented by the round this entry points returns. type Protocol: Protocol; + /// The ID of the round returned by [`Self::new`]. + const ENTRY_ROUND: RoundId = RoundId::new(1); + /// Creates the round. /// /// `session_id` can be assumed to be the same for each node participating in a session. diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index ad33fc3..7a33283 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -734,7 +734,7 @@ where } ReceiveErrorType::Echo(error) => { let (_echo_broadcast, normal_broadcast, _direct_message) = processed.message.into_parts(); - let evidence = Evidence::new_echo_round_error(&from, normal_broadcast, error)?; + let evidence = Evidence::new_echo_round_error(&from, normal_broadcast, *error)?; self.register_provable_error(&from, evidence) } ReceiveErrorType::Local(error) => Err(error),