diff --git a/examples/src/lib.rs b/examples/src/lib.rs index c3aff76..f24273f 100644 --- a/examples/src/lib.rs +++ b/examples/src/lib.rs @@ -1,6 +1,7 @@ extern crate alloc; pub mod simple; +mod simple_chain; #[cfg(test)] mod simple_malicious; diff --git a/examples/src/simple_chain.rs b/examples/src/simple_chain.rs new file mode 100644 index 0000000..2e8269b --- /dev/null +++ b/examples/src/simple_chain.rs @@ -0,0 +1,86 @@ +use core::fmt::Debug; +use core::marker::PhantomData; + +use manul::{combinators::*, protocol::PartyId}; + +use super::simple::{Inputs, Round1, SimpleProtocol}; + +pub struct ChainedSimple(PhantomData); + +#[derive(Debug)] +pub struct NewInputs(Inputs); + +impl<'a, Id: PartyId> From<&'a NewInputs> for Inputs { + fn from(source: &'a NewInputs) -> Self { + source.0.clone() + } +} + +impl From<(NewInputs, u8)> for Inputs { + fn from(source: (NewInputs, u8)) -> Self { + source.0 .0 + } +} + +impl Chained for ChainedSimple { + type Protocol1 = SimpleProtocol; + type Protocol2 = SimpleProtocol; + type CorrectnessProof = (); + type Inputs = NewInputs; + type EntryPoint1 = Round1; + type EntryPoint2 = Round1; +} + +#[cfg(test)] +mod tests { + use alloc::collections::BTreeSet; + + use manul::{ + combinators::Chain, + session::{signature::Keypair, SessionOutcome}, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + }; + use rand_core::OsRng; + use tracing_subscriber::EnvFilter; + + use super::{ChainedSimple, NewInputs}; + use crate::simple::Inputs; + + #[test] + fn round() { + let signers = (0..3).map(TestSigner::new).collect::>(); + let all_ids = signers + .iter() + .map(|signer| signer.verifying_key()) + .collect::>(); + let inputs = signers + .into_iter() + .map(|signer| { + ( + signer, + NewInputs(Inputs { + all_ids: all_ids.clone(), + }), + ) + }) + .collect::>(); + + let my_subscriber = tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .finish(); + let reports = tracing::subscriber::with_default(my_subscriber, || { + run_sync::>, TestSessionParams>( + &mut OsRng, inputs, + ) + .unwrap() + }); + + for (_id, report) in reports { + if let SessionOutcome::Result(result) = report.outcome { + assert_eq!(result, 3); // 0 + 1 + 2 + } else { + panic!("Session did not finish successfully"); + } + } + } +} diff --git a/manul/src/combinators.rs b/manul/src/combinators.rs index b6b2a9b..aad3e74 100644 --- a/manul/src/combinators.rs +++ b/manul/src/combinators.rs @@ -1,3 +1,5 @@ +mod chain; mod misbehave; +pub use chain::{Chain, Chained, ChainedProtocol}; pub use misbehave::{Misbehaving, MisbehavingInputs, MisbehavingRound}; diff --git a/manul/src/combinators/chain.rs b/manul/src/combinators/chain.rs new file mode 100644 index 0000000..becd175 --- /dev/null +++ b/manul/src/combinators/chain.rs @@ -0,0 +1,320 @@ +use alloc::{ + boxed::Box, + collections::{BTreeMap, BTreeSet}, + format, + string::String, + vec::Vec, +}; +use core::{fmt::Debug, marker::PhantomData}; + +use rand_core::CryptoRngCore; +use serde::{Deserialize, Serialize}; + +use crate::protocol::*; + +pub trait Chained { + type Protocol1: Protocol; + type Protocol2: Protocol; + type CorrectnessProof: Send + + Serialize + + for<'de> Deserialize<'de> + + Debug + + From<::CorrectnessProof> + + From<::CorrectnessProof>; + type Inputs: Send + Sync + Debug; + type EntryPoint1: FirstRound From<&'a Self::Inputs>>; + type EntryPoint2: FirstRound< + Id, + Protocol = Self::Protocol2, + Inputs: From<(Self::Inputs, ::Result)>, + >; +} + +#[derive_where::derive_where(Debug, Clone)] +#[derive(Serialize, Deserialize)] +pub enum ChainedProtocolError> { + Protocol1(::ProtocolError), + Protocol2(::ProtocolError), +} + +impl> ChainedProtocolError { + fn from_protocol1(err: ::ProtocolError) -> Self { + Self::Protocol1(err) + } + + fn from_protocol2(err: ::ProtocolError) -> Self { + Self::Protocol2(err) + } +} + +impl> ProtocolError for ChainedProtocolError { + fn description(&self) -> String { + match self { + Self::Protocol1(err) => format!("Protocol1: {}", err.description()), + Self::Protocol2(err) => format!("Protocol2: {}", err.description()), + } + } + + fn required_direct_messages(&self) -> BTreeSet { + match self { + // TODO: map rounds! + Self::Protocol1(err) => err.required_direct_messages(), + Self::Protocol2(err) => err.required_direct_messages(), + } + } + + fn required_echo_broadcasts(&self) -> BTreeSet { + match self { + Self::Protocol1(err) => err.required_echo_broadcasts(), + Self::Protocol2(err) => err.required_echo_broadcasts(), + } + } + + fn required_normal_broadcasts(&self) -> BTreeSet { + match self { + Self::Protocol1(err) => err.required_normal_broadcasts(), + Self::Protocol2(err) => err.required_normal_broadcasts(), + } + } + + fn required_combined_echos(&self) -> BTreeSet { + match self { + Self::Protocol1(err) => err.required_combined_echos(), + Self::Protocol2(err) => err.required_combined_echos(), + } + } + + #[allow(clippy::too_many_arguments)] + fn verify_messages_constitute_error( + &self, + deserializer: &Deserializer, + echo_broadcast: &EchoBroadcast, + normal_broadcast: &NormalBroadcast, + direct_message: &DirectMessage, + echo_broadcasts: &BTreeMap, + normal_broadcasts: &BTreeMap, + direct_messages: &BTreeMap, + combined_echos: &BTreeMap>, + ) -> Result<(), ProtocolValidationError> { + match self { + Self::Protocol1(err) => err.verify_messages_constitute_error( + deserializer, + echo_broadcast, + normal_broadcast, + direct_message, + echo_broadcasts, + normal_broadcasts, + direct_messages, + combined_echos, + ), + Self::Protocol2(err) => err.verify_messages_constitute_error( + deserializer, + echo_broadcast, + normal_broadcast, + direct_message, + echo_broadcasts, + normal_broadcasts, + direct_messages, + combined_echos, + ), + } + } +} + +pub struct ChainedProtocol>(PhantomData (Id, C)>); + +impl Protocol for ChainedProtocol +where + Id: PartyId, + C: 'static + Chained, +{ + type Result = ::Result; + type ProtocolError = ChainedProtocolError; + type CorrectnessProof = C::CorrectnessProof; +} + +#[derive_where::derive_where(Debug)] +pub struct Chain> { + state: ChainState, +} + +#[derive_where::derive_where(Debug)] +enum ChainState> { + Protocol1 { + round: Box>, + shared_randomness: Box<[u8]>, + id: Id, + inputs: C::Inputs, + }, + Protocol2(Box>), +} + +impl FirstRound for Chain +where + Id: PartyId, + C: 'static + Chained, +{ + type Inputs = C::Inputs; + + fn new( + rng: &mut impl CryptoRngCore, + shared_randomness: &[u8], + id: Id, + inputs: Self::Inputs, + ) -> Result { + let round = C::EntryPoint1::new(rng, shared_randomness, id.clone(), (&inputs).into())?; + Ok(Chain { + state: ChainState::Protocol1 { + shared_randomness: shared_randomness.into(), + id, + inputs, + round: Box::new(ObjectSafeRoundWrapper::new(round)), + }, + }) + } +} + +impl Round for Chain +where + Id: PartyId, + C: 'static + Chained, +{ + type Protocol = ChainedProtocol; + + fn id(&self) -> RoundId { + unimplemented!() + } + + fn possible_next_rounds(&self) -> BTreeSet { + unimplemented!() + } + + fn message_destinations(&self) -> &BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => round.message_destinations(), + ChainState::Protocol2(round) => round.message_destinations(), + } + } + + fn make_direct_message( + &self, + rng: &mut impl CryptoRngCore, + serializer: &Serializer, + destination: &Id, + ) -> Result<(DirectMessage, Option), LocalError> { + match &self.state { + ChainState::Protocol1 { round, .. } => round.make_direct_message(rng, serializer, destination), + ChainState::Protocol2(round) => round.make_direct_message(rng, serializer, destination), + } + } + + fn make_echo_broadcast( + &self, + #[allow(unused_variables)] rng: &mut impl CryptoRngCore, + #[allow(unused_variables)] serializer: &Serializer, + ) -> Result { + match &self.state { + ChainState::Protocol1 { round, .. } => round.make_echo_broadcast(rng, serializer), + ChainState::Protocol2(round) => round.make_echo_broadcast(rng, serializer), + } + } + + fn make_normal_broadcast( + &self, + #[allow(unused_variables)] rng: &mut impl CryptoRngCore, + #[allow(unused_variables)] serializer: &Serializer, + ) -> Result { + match &self.state { + ChainState::Protocol1 { round, .. } => round.make_normal_broadcast(rng, serializer), + ChainState::Protocol2(round) => round.make_normal_broadcast(rng, serializer), + } + } + + fn receive_message( + &self, + rng: &mut impl CryptoRngCore, + deserializer: &Deserializer, + from: &Id, + echo_broadcast: EchoBroadcast, + normal_broadcast: NormalBroadcast, + direct_message: DirectMessage, + ) -> Result> { + match &self.state { + ChainState::Protocol1 { round, .. } => match round.receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) { + Ok(payload) => Ok(payload), + Err(err) => Err(err.map(ChainedProtocolError::from_protocol1)), + }, + ChainState::Protocol2(round) => match round.receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) { + Ok(payload) => Ok(payload), + Err(err) => Err(err.map(ChainedProtocolError::from_protocol2)), + }, + } + } + + fn finalize( + self, + rng: &mut impl CryptoRngCore, + payloads: BTreeMap, + artifacts: BTreeMap, + ) -> Result, FinalizeError> { + match self.state { + ChainState::Protocol1 { + round, + id, + inputs, + shared_randomness, + } => match round.finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => { + let round = C::EntryPoint2::new(rng, &shared_randomness, id, (inputs, result).into())?; + + Ok(FinalizeOutcome::another_round(Chain:: { + state: ChainState::Protocol2(Box::new(ObjectSafeRoundWrapper::new(round))), + })) + } + Ok(FinalizeOutcome::AnotherRound(another_round)) => { + Ok(FinalizeOutcome::another_round(Chain:: { + state: ChainState::Protocol1 { + shared_randomness, + id, + inputs, + round: another_round.into_boxed(), + }, + })) + } + Err(FinalizeError::Local(err)) => Err(FinalizeError::Local(err)), + Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable(proof.into())), + }, + ChainState::Protocol2(round) => match round.finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), + Ok(FinalizeOutcome::AnotherRound(another_round)) => { + Ok(FinalizeOutcome::another_round(Chain:: { + state: ChainState::Protocol2(another_round.into_boxed()), + })) + } + Err(FinalizeError::Local(err)) => Err(FinalizeError::Local(err)), + Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable(proof.into())), + }, + } + } + + fn expecting_messages_from(&self) -> &BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => round.expecting_messages_from(), + ChainState::Protocol2(round) => round.expecting_messages_from(), + } + } +} diff --git a/manul/src/protocol/errors.rs b/manul/src/protocol/errors.rs index a77ecc3..460641d 100644 --- a/manul/src/protocol/errors.rs +++ b/manul/src/protocol/errors.rs @@ -68,6 +68,32 @@ impl ReceiveError { pub fn protocol(error: P::ProtocolError) -> Self { Self(ReceiveErrorType::Protocol(error)) } + + pub(crate) fn map(self, f: F) -> ReceiveError + where + F: Fn(P::ProtocolError) -> T::ProtocolError, + T: Protocol, + { + ReceiveError(self.0.map::(f)) + } +} + +impl ReceiveErrorType { + pub(crate) fn map(self, f: F) -> ReceiveErrorType + where + F: Fn(P::ProtocolError) -> T::ProtocolError, + T: Protocol, + { + match self { + Self::Local(err) => ReceiveErrorType::Local(err), + Self::InvalidDirectMessage(err) => ReceiveErrorType::InvalidDirectMessage(err), + Self::InvalidEchoBroadcast(err) => ReceiveErrorType::InvalidEchoBroadcast(err), + Self::InvalidNormalBroadcast(err) => ReceiveErrorType::InvalidNormalBroadcast(err), + Self::Unprovable(err) => ReceiveErrorType::Unprovable(err), + Self::Echo(err) => ReceiveErrorType::Echo(err), + Self::Protocol(err) => ReceiveErrorType::Protocol(f(err)), + } + } } impl From for ReceiveError