diff --git a/examples/src/simple.rs b/examples/src/simple.rs index bbf4756..eea21d9 100644 --- a/examples/src/simple.rs +++ b/examples/src/simple.rs @@ -149,14 +149,16 @@ struct Round1Payload { x: u8, } -impl FirstRound for Round1 { +impl EntryPoint for Round1 { type Inputs = Inputs; + type Protocol = SimpleProtocol; + const RESULT_ROUND: RoundId = RoundId::new(2); fn new( _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], id: Id, inputs: Self::Inputs, - ) -> Result { + ) -> Result, LocalError> { // Just some numbers associated with IDs to use in the dummy protocol. // They will be the same on each node since IDs are ordered. let ids_to_positions = inputs @@ -169,13 +171,13 @@ impl FirstRound for Round1 { let mut ids = inputs.all_ids; ids.remove(&id); - Ok(Self { + Ok(BoxedRound::new_dynamic(Self { context: Context { id, other_ids: ids, ids_to_positions, }, - }) + })) } } diff --git a/examples/src/simple_chain.rs b/examples/src/simple_chain.rs index 2e8269b..c6dd2d0 100644 --- a/examples/src/simple_chain.rs +++ b/examples/src/simple_chain.rs @@ -3,7 +3,7 @@ use core::marker::PhantomData; use manul::{combinators::*, protocol::PartyId}; -use super::simple::{Inputs, Round1, SimpleProtocol}; +use super::simple::{Inputs, Round1}; pub struct ChainedSimple(PhantomData); @@ -23,9 +23,6 @@ impl From<(NewInputs, u8)> for Inputs { } impl Chained for ChainedSimple { - type Protocol1 = SimpleProtocol; - type Protocol2 = SimpleProtocol; - type CorrectnessProof = (); type Inputs = NewInputs; type EntryPoint1 = Round1; type EntryPoint2 = Round1; diff --git a/examples/src/simple_malicious.rs b/examples/src/simple_malicious.rs index 3556157..4124623 100644 --- a/examples/src/simple_malicious.rs +++ b/examples/src/simple_malicious.rs @@ -4,7 +4,8 @@ use core::fmt::Debug; use manul::{ combinators::{Misbehaving, MisbehavingInputs, MisbehavingRound}, protocol::{ - Artifact, DirectMessage, LocalError, ObjectSafeRound, PartyId, ProtocolMessagePart, RoundId, Serializer, + Artifact, BoxedRound, Deserializer, DirectMessage, EntryPoint, LocalError, PartyId, ProtocolMessagePart, + RoundId, Serializer, }, session::signature::Keypair, testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, @@ -12,7 +13,7 @@ use manul::{ use rand_core::{CryptoRngCore, OsRng}; use tracing_subscriber::EnvFilter; -use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message, SimpleProtocol}; +use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message}; #[derive(Debug, Clone, Copy)] enum Behavior { @@ -24,14 +25,14 @@ enum Behavior { struct SimpleMaliciousProtocol; impl Misbehaving for SimpleMaliciousProtocol { - type Protocol = SimpleProtocol; - type FirstRound = Round1; + type EntryPoint = Round1; fn amend_direct_message( _rng: &mut impl CryptoRngCore, - round: &dyn ObjectSafeRound, + round: &BoxedRound>::Protocol>, behavior: &Behavior, serializer: &Serializer, + _deserializer: &Deserializer, _destination: &Id, direct_message: DirectMessage, artifact: Option, diff --git a/manul/benches/empty_rounds.rs b/manul/benches/empty_rounds.rs index 4420243..211f0f9 100644 --- a/manul/benches/empty_rounds.rs +++ b/manul/benches/empty_rounds.rs @@ -9,7 +9,7 @@ use core::fmt::Debug; use criterion::{criterion_group, criterion_main, Criterion}; use manul::{ protocol::{ - Artifact, Deserializer, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, FirstRound, LocalError, + Artifact, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeError, FinalizeOutcome, LocalError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessagePart, ProtocolValidationError, ReceiveError, Round, RoundId, Serializer, }, @@ -73,7 +73,7 @@ struct Round1Payload; struct Round1Artifact; -impl FirstRound for EmptyRound { +impl EntryPoint for EmptyRound { type Inputs = Inputs; fn new( _rng: &mut impl CryptoRngCore, diff --git a/manul/src/combinators/chain.rs b/manul/src/combinators/chain.rs index becd175..29b627f 100644 --- a/manul/src/combinators/chain.rs +++ b/manul/src/combinators/chain.rs @@ -13,36 +13,38 @@ use serde::{Deserialize, Serialize}; use crate::protocol::*; pub trait Chained { - type Protocol1: Protocol; - type Protocol2: Protocol; - type CorrectnessProof: Send - + Serialize - + for<'de> Deserialize<'de> - + Debug - + From<::CorrectnessProof> - + From<::CorrectnessProof>; type Inputs: Send + Sync + Debug; - type EntryPoint1: FirstRound From<&'a Self::Inputs>>; - type EntryPoint2: FirstRound< + type EntryPoint1: EntryPoint From<&'a Self::Inputs>>; + type EntryPoint2: EntryPoint< Id, - Protocol = Self::Protocol2, - Inputs: From<(Self::Inputs, ::Result)>, + Inputs: From<( + Self::Inputs, + <>::Protocol as Protocol>::Result, + )>, >; } #[derive_where::derive_where(Debug, Clone)] #[derive(Serialize, Deserialize)] +#[serde(bound(serialize = " + <>::Protocol as Protocol>::ProtocolError: Serialize, + <>::Protocol as Protocol>::ProtocolError: Serialize, +"))] +#[serde(bound(deserialize = " + <>::Protocol as Protocol>::ProtocolError: for<'x> Deserialize<'x>, + <>::Protocol as Protocol>::ProtocolError: for<'x> Deserialize<'x>, +"))] pub enum ChainedProtocolError> { - Protocol1(::ProtocolError), - Protocol2(::ProtocolError), + Protocol1(<>::Protocol as Protocol>::ProtocolError), + Protocol2(<>::Protocol as Protocol>::ProtocolError), } impl> ChainedProtocolError { - fn from_protocol1(err: ::ProtocolError) -> Self { + fn from_protocol1(err: <>::Protocol as Protocol>::ProtocolError) -> Self { Self::Protocol1(err) } - fn from_protocol2(err: ::ProtocolError) -> Self { + fn from_protocol2(err: <>::Protocol as Protocol>::ProtocolError) -> Self { Self::Protocol2(err) } } @@ -56,32 +58,47 @@ impl> ProtocolError for ChainedProtocolError } fn required_direct_messages(&self) -> BTreeSet { - match self { - // TODO: map rounds! - Self::Protocol1(err) => err.required_direct_messages(), - Self::Protocol2(err) => err.required_direct_messages(), - } + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_direct_messages()), + Self::Protocol2(err) => (2, err.required_direct_messages()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() } fn required_echo_broadcasts(&self) -> BTreeSet { - match self { - Self::Protocol1(err) => err.required_echo_broadcasts(), - Self::Protocol2(err) => err.required_echo_broadcasts(), - } + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_echo_broadcasts()), + Self::Protocol2(err) => (2, err.required_echo_broadcasts()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() } fn required_normal_broadcasts(&self) -> BTreeSet { - match self { - Self::Protocol1(err) => err.required_normal_broadcasts(), - Self::Protocol2(err) => err.required_normal_broadcasts(), - } + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_normal_broadcasts()), + Self::Protocol2(err) => (2, err.required_normal_broadcasts()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() } fn required_combined_echos(&self) -> BTreeSet { - match self { - Self::Protocol1(err) => err.required_combined_echos(), - Self::Protocol2(err) => err.required_combined_echos(), - } + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_combined_echos()), + Self::Protocol2(err) => (2, err.required_combined_echos()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() } #[allow(clippy::too_many_arguments)] @@ -96,31 +113,81 @@ impl> ProtocolError for ChainedProtocolError direct_messages: &BTreeMap, combined_echos: &BTreeMap>, ) -> Result<(), ProtocolValidationError> { + // TODO: the cloning can be avoided if instead we provide a reference to some "transcript API", + // and can replace it here with a proxy that will remove nesting from round ID's. + let echo_broadcasts = echo_broadcasts + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let normal_broadcasts = normal_broadcasts + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let direct_messages = direct_messages + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let combined_echos = combined_echos + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + match self { Self::Protocol1(err) => err.verify_messages_constitute_error( deserializer, echo_broadcast, normal_broadcast, direct_message, - echo_broadcasts, - normal_broadcasts, - direct_messages, - combined_echos, + &echo_broadcasts, + &normal_broadcasts, + &direct_messages, + &combined_echos, ), Self::Protocol2(err) => err.verify_messages_constitute_error( deserializer, echo_broadcast, normal_broadcast, direct_message, - echo_broadcasts, - normal_broadcasts, - direct_messages, - combined_echos, + &echo_broadcasts, + &normal_broadcasts, + &direct_messages, + &combined_echos, ), } } } +#[derive_where::derive_where(Debug, Clone)] +#[derive(Serialize, Deserialize)] +#[serde(bound(serialize = " + <>::Protocol as Protocol>::CorrectnessProof: Serialize, + <>::Protocol as Protocol>::CorrectnessProof: Serialize, +"))] +#[serde(bound(deserialize = " + <>::Protocol as Protocol>::CorrectnessProof: for<'x> Deserialize<'x>, + <>::Protocol as Protocol>::CorrectnessProof: for<'x> Deserialize<'x>, +"))] +pub enum ChainedCorrectnessProof> { + Protocol1(<>::Protocol as Protocol>::CorrectnessProof), + Protocol2(<>::Protocol as Protocol>::CorrectnessProof), +} + +impl> ChainedCorrectnessProof { + fn from_protocol1(proof: <>::Protocol as Protocol>::CorrectnessProof) -> Self { + Self::Protocol1(proof) + } + + fn from_protocol2(proof: <>::Protocol as Protocol>::CorrectnessProof) -> Self { + Self::Protocol2(proof) + } +} + +impl> CorrectnessProof for ChainedCorrectnessProof {} + pub struct ChainedProtocol>(PhantomData (Id, C)>); impl Protocol for ChainedProtocol @@ -128,9 +195,9 @@ where Id: PartyId, C: 'static + Chained, { - type Result = ::Result; + type Result = <>::Protocol as Protocol>::Result; type ProtocolError = ChainedProtocolError; - type CorrectnessProof = C::CorrectnessProof; + type CorrectnessProof = ChainedCorrectnessProof; } #[derive_where::derive_where(Debug)] @@ -141,40 +208,43 @@ pub struct Chain> { #[derive_where::derive_where(Debug)] enum ChainState> { Protocol1 { - round: Box>, + round: BoxedRound>::Protocol>, shared_randomness: Box<[u8]>, id: Id, inputs: C::Inputs, }, - Protocol2(Box>), + Protocol2(BoxedRound>::Protocol>), } -impl FirstRound for Chain +impl EntryPoint for Chain where Id: PartyId, C: 'static + Chained, { type Inputs = C::Inputs; + type Protocol = ChainedProtocol; + const RESULT_ROUND: RoundId = >::RESULT_ROUND.group_under(2); fn new( rng: &mut impl CryptoRngCore, shared_randomness: &[u8], id: Id, inputs: Self::Inputs, - ) -> Result { + ) -> Result, LocalError> { let round = C::EntryPoint1::new(rng, shared_randomness, id.clone(), (&inputs).into())?; - Ok(Chain { + let round = Chain { state: ChainState::Protocol1 { shared_randomness: shared_randomness.into(), id, inputs, - round: Box::new(ObjectSafeRoundWrapper::new(round)), + round, }, - }) + }; + Ok(BoxedRound::new_object_safe(round)) } } -impl Round for Chain +impl ObjectSafeRound for Chain where Id: PartyId, C: 'static + Chained, @@ -182,57 +252,90 @@ where type Protocol = ChainedProtocol; fn id(&self) -> RoundId { - unimplemented!() + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().id().group_under(1), + ChainState::Protocol2(round) => round.as_ref().id().group_under(2), + } } fn possible_next_rounds(&self) -> BTreeSet { - unimplemented!() + match &self.state { + ChainState::Protocol1 { round, .. } => { + let mut next_rounds = round + .as_ref() + .possible_next_rounds() + .into_iter() + .map(|round_id| round_id.group_under(1)) + .collect::>(); + if round.id() == C::EntryPoint1::RESULT_ROUND { + next_rounds.insert(C::EntryPoint2::ENTRY_ROUND.group_under(2)); + } + next_rounds + } + ChainState::Protocol2(round) => round + .as_ref() + .possible_next_rounds() + .into_iter() + .map(|round_id| round_id.group_under(2)) + .collect(), + } } fn message_destinations(&self) -> &BTreeSet { match &self.state { - ChainState::Protocol1 { round, .. } => round.message_destinations(), - ChainState::Protocol2(round) => round.message_destinations(), + ChainState::Protocol1 { round, .. } => round.as_ref().message_destinations(), + ChainState::Protocol2(round) => round.as_ref().message_destinations(), } } fn make_direct_message( &self, - rng: &mut impl CryptoRngCore, + rng: &mut dyn CryptoRngCore, serializer: &Serializer, + deserializer: &Deserializer, destination: &Id, ) -> Result<(DirectMessage, Option), LocalError> { match &self.state { - ChainState::Protocol1 { round, .. } => round.make_direct_message(rng, serializer, destination), - ChainState::Protocol2(round) => round.make_direct_message(rng, serializer, destination), + ChainState::Protocol1 { round, .. } => { + round + .as_ref() + .make_direct_message(rng, serializer, deserializer, destination) + } + ChainState::Protocol2(round) => { + round + .as_ref() + .make_direct_message(rng, serializer, deserializer, destination) + } } } fn make_echo_broadcast( &self, - #[allow(unused_variables)] rng: &mut impl CryptoRngCore, - #[allow(unused_variables)] serializer: &Serializer, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, ) -> Result { match &self.state { - ChainState::Protocol1 { round, .. } => round.make_echo_broadcast(rng, serializer), - ChainState::Protocol2(round) => round.make_echo_broadcast(rng, serializer), + ChainState::Protocol1 { round, .. } => round.as_ref().make_echo_broadcast(rng, serializer, deserializer), + ChainState::Protocol2(round) => round.as_ref().make_echo_broadcast(rng, serializer, deserializer), } } fn make_normal_broadcast( &self, - #[allow(unused_variables)] rng: &mut impl CryptoRngCore, - #[allow(unused_variables)] serializer: &Serializer, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, ) -> Result { match &self.state { - ChainState::Protocol1 { round, .. } => round.make_normal_broadcast(rng, serializer), - ChainState::Protocol2(round) => round.make_normal_broadcast(rng, serializer), + ChainState::Protocol1 { round, .. } => round.as_ref().make_normal_broadcast(rng, serializer, deserializer), + ChainState::Protocol2(round) => round.as_ref().make_normal_broadcast(rng, serializer, deserializer), } } fn receive_message( &self, - rng: &mut impl CryptoRngCore, + rng: &mut dyn CryptoRngCore, deserializer: &Deserializer, from: &Id, echo_broadcast: EchoBroadcast, @@ -240,7 +343,7 @@ where direct_message: DirectMessage, ) -> Result> { match &self.state { - ChainState::Protocol1 { round, .. } => match round.receive_message( + ChainState::Protocol1 { round, .. } => match round.as_ref().receive_message( rng, deserializer, from, @@ -251,7 +354,7 @@ where Ok(payload) => Ok(payload), Err(err) => Err(err.map(ChainedProtocolError::from_protocol1)), }, - ChainState::Protocol2(round) => match round.receive_message( + ChainState::Protocol2(round) => match round.as_ref().receive_message( rng, deserializer, from, @@ -266,8 +369,8 @@ where } fn finalize( - self, - rng: &mut impl CryptoRngCore, + self: Box, + rng: &mut dyn CryptoRngCore, payloads: BTreeMap, artifacts: BTreeMap, ) -> Result, FinalizeError> { @@ -277,44 +380,58 @@ where id, inputs, shared_randomness, - } => match round.finalize(rng, payloads, artifacts) { + } => match round.into_boxed().finalize(rng, payloads, artifacts) { Ok(FinalizeOutcome::Result(result)) => { - let round = C::EntryPoint2::new(rng, &shared_randomness, id, (inputs, result).into())?; + let mut boxed_rng = BoxedRng(rng); + let round = C::EntryPoint2::new(&mut boxed_rng, &shared_randomness, id, (inputs, result).into())?; - Ok(FinalizeOutcome::another_round(Chain:: { - state: ChainState::Protocol2(Box::new(ObjectSafeRoundWrapper::new(round))), - })) + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe(Chain::< + Id, + C, + > { + state: ChainState::Protocol2(round), + }))) } - Ok(FinalizeOutcome::AnotherRound(another_round)) => { - Ok(FinalizeOutcome::another_round(Chain:: { + Ok(FinalizeOutcome::AnotherRound(round)) => { + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe(Chain::< + Id, + C, + > { state: ChainState::Protocol1 { shared_randomness, id, inputs, - round: another_round.into_boxed(), + round, }, - })) + }))) } Err(FinalizeError::Local(err)) => Err(FinalizeError::Local(err)), - Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable(proof.into())), + Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable( + ChainedCorrectnessProof::from_protocol1(proof), + )), }, - ChainState::Protocol2(round) => match round.finalize(rng, payloads, artifacts) { + ChainState::Protocol2(round) => match round.into_boxed().finalize(rng, payloads, artifacts) { Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), - Ok(FinalizeOutcome::AnotherRound(another_round)) => { - Ok(FinalizeOutcome::another_round(Chain:: { - state: ChainState::Protocol2(another_round.into_boxed()), - })) + Ok(FinalizeOutcome::AnotherRound(round)) => { + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe(Chain::< + Id, + C, + > { + state: ChainState::Protocol2(round), + }))) } Err(FinalizeError::Local(err)) => Err(FinalizeError::Local(err)), - Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable(proof.into())), + Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable( + ChainedCorrectnessProof::from_protocol2(proof), + )), }, } } fn expecting_messages_from(&self) -> &BTreeSet { match &self.state { - ChainState::Protocol1 { round, .. } => round.expecting_messages_from(), - ChainState::Protocol2(round) => round.expecting_messages_from(), + ChainState::Protocol1 { round, .. } => round.as_ref().expecting_messages_from(), + ChainState::Protocol2(round) => round.as_ref().expecting_messages_from(), } } } diff --git a/manul/src/combinators/misbehave.rs b/manul/src/combinators/misbehave.rs index 867213e..c26d391 100644 --- a/manul/src/combinators/misbehave.rs +++ b/manul/src/combinators/misbehave.rs @@ -10,20 +10,20 @@ use crate::protocol::*; #[derive_where::derive_where(Debug)] pub struct MisbehavingRound> { - round: Box>, + round: BoxedRound>::Protocol>, behavior: Option, } pub trait Misbehaving { - type Protocol: Protocol; - type FirstRound: FirstRound; + type EntryPoint: EntryPoint; #[allow(unused_variables)] fn amend_echo_broadcast( rng: &mut impl CryptoRngCore, - round: &dyn ObjectSafeRound, + round: &BoxedRound>::Protocol>, behavior: &B, serializer: &Serializer, + deserializer: &Deserializer, echo_broadcast: EchoBroadcast, ) -> Result { Ok(echo_broadcast) @@ -32,9 +32,10 @@ pub trait Misbehaving { #[allow(unused_variables)] fn amend_normal_broadcast( rng: &mut impl CryptoRngCore, - round: &dyn ObjectSafeRound, + round: &BoxedRound>::Protocol>, behavior: &B, serializer: &Serializer, + deserializer: &Deserializer, normal_broadcast: NormalBroadcast, ) -> Result { Ok(normal_broadcast) @@ -43,9 +44,10 @@ pub trait Misbehaving { #[allow(unused_variables)] fn amend_direct_message( rng: &mut impl CryptoRngCore, - round: &dyn ObjectSafeRound, + round: &BoxedRound>::Protocol>, behavior: &B, serializer: &Serializer, + deserializer: &Deserializer, destination: &Id, direct_message: DirectMessage, artifact: Option, @@ -56,64 +58,72 @@ pub trait Misbehaving { pub struct MisbehavingInputs> { pub behavior: Option, - pub inner_inputs: >::Inputs, + pub inner_inputs: >::Inputs, } -impl FirstRound for MisbehavingRound +impl EntryPoint for MisbehavingRound where Id: PartyId, B: 'static + Debug + Send + Sync, - MP: 'static + Misbehaving, + M: 'static + Misbehaving, { - type Inputs = MisbehavingInputs; + type Inputs = MisbehavingInputs; + type Protocol = >::Protocol; + const RESULT_ROUND: RoundId = >::RESULT_ROUND; fn new( rng: &mut impl CryptoRngCore, shared_randomness: &[u8], id: Id, inputs: Self::Inputs, - ) -> Result { - let inner_round = MP::FirstRound::new(rng, shared_randomness, id, inputs.inner_inputs)?; - Ok(Self { - round: Box::new(ObjectSafeRoundWrapper::new(inner_round)), + ) -> Result>::Protocol>, LocalError> { + let round = M::EntryPoint::new(rng, shared_randomness, id, inputs.inner_inputs)?; + Ok(BoxedRound::new_object_safe(Self { + round, behavior: inputs.behavior, - }) + })) } } -impl Round for MisbehavingRound +impl ObjectSafeRound for MisbehavingRound where Id: PartyId, B: 'static + Debug + Send + Sync, M: 'static + Misbehaving, { - type Protocol = M::Protocol; + type Protocol = >::Protocol; fn id(&self) -> RoundId { - self.round.id() + self.round.as_ref().id() } fn possible_next_rounds(&self) -> BTreeSet { - self.round.possible_next_rounds() + self.round.as_ref().possible_next_rounds() } fn message_destinations(&self) -> &BTreeSet { - self.round.message_destinations() + self.round.as_ref().message_destinations() } fn make_direct_message( &self, - rng: &mut impl CryptoRngCore, + rng: &mut dyn CryptoRngCore, serializer: &Serializer, + deserializer: &Deserializer, destination: &Id, ) -> Result<(DirectMessage, Option), LocalError> { - let (direct_message, artifact) = self.round.make_direct_message(rng, serializer, destination)?; + let (direct_message, artifact) = + self.round + .as_ref() + .make_direct_message(rng, serializer, deserializer, destination)?; if let Some(behavior) = self.behavior.as_ref() { + let mut boxed_rng = BoxedRng(rng); M::amend_direct_message( - rng, - self.round.as_ref(), + &mut boxed_rng, + &self.round, behavior, serializer, + deserializer, destination, direct_message, artifact, @@ -125,12 +135,21 @@ where fn make_echo_broadcast( &self, - rng: &mut impl CryptoRngCore, + rng: &mut dyn CryptoRngCore, serializer: &Serializer, + deserializer: &Deserializer, ) -> Result { - let echo_broadcast = self.round.make_echo_broadcast(rng, serializer)?; + let echo_broadcast = self.round.as_ref().make_echo_broadcast(rng, serializer, deserializer)?; if let Some(behavior) = self.behavior.as_ref() { - M::amend_echo_broadcast(rng, self.round.as_ref(), behavior, serializer, echo_broadcast) + let mut boxed_rng = BoxedRng(rng); + M::amend_echo_broadcast( + &mut boxed_rng, + &self.round, + behavior, + serializer, + deserializer, + echo_broadcast, + ) } else { Ok(echo_broadcast) } @@ -138,12 +157,24 @@ where fn make_normal_broadcast( &self, - rng: &mut impl CryptoRngCore, + rng: &mut dyn CryptoRngCore, serializer: &Serializer, + deserializer: &Deserializer, ) -> Result { - let normal_broadcast = self.round.make_normal_broadcast(rng, serializer)?; + let normal_broadcast = self + .round + .as_ref() + .make_normal_broadcast(rng, serializer, deserializer)?; if let Some(behavior) = self.behavior.as_ref() { - M::amend_normal_broadcast(rng, self.round.as_ref(), behavior, serializer, normal_broadcast) + let mut boxed_rng = BoxedRng(rng); + M::amend_normal_broadcast( + &mut boxed_rng, + &self.round, + behavior, + serializer, + deserializer, + normal_broadcast, + ) } else { Ok(normal_broadcast) } @@ -151,14 +182,14 @@ where fn receive_message( &self, - rng: &mut impl CryptoRngCore, + rng: &mut dyn CryptoRngCore, deserializer: &Deserializer, from: &Id, echo_broadcast: EchoBroadcast, normal_broadcast: NormalBroadcast, direct_message: DirectMessage, ) -> Result> { - self.round.receive_message( + self.round.as_ref().receive_message( rng, deserializer, from, @@ -169,24 +200,24 @@ where } fn finalize( - self, - rng: &mut impl CryptoRngCore, + self: Box, + rng: &mut dyn CryptoRngCore, payloads: BTreeMap, artifacts: BTreeMap, ) -> Result, FinalizeError> { - match self.round.finalize(rng, payloads, artifacts) { + match self.round.into_boxed().finalize(rng, payloads, artifacts) { Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), - Ok(FinalizeOutcome::AnotherRound(another_round)) => { - Ok(FinalizeOutcome::another_round(MisbehavingRound:: { - round: another_round.into_boxed(), + Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe( + MisbehavingRound:: { + round, behavior: self.behavior, - })) - } + }, + ))), Err(err) => Err(err), } } fn expecting_messages_from(&self) -> &BTreeSet { - self.round.expecting_messages_from() + self.round.as_ref().expecting_messages_from() } } diff --git a/manul/src/lib.rs b/manul/src/lib.rs index ad55b47..a408315 100644 --- a/manul/src/lib.rs +++ b/manul/src/lib.rs @@ -5,13 +5,13 @@ clippy::mod_module_files, clippy::unwrap_used, clippy::indexing_slicing, - missing_docs, + //missing_docs, missing_copy_implementations, rust_2018_idioms, trivial_casts, trivial_numeric_casts, unused_qualifications, - missing_debug_implementations + //missing_debug_implementations )] extern crate alloc; diff --git a/manul/src/protocol.rs b/manul/src/protocol.rs index 52c104d..da19cdc 100644 --- a/manul/src/protocol.rs +++ b/manul/src/protocol.rs @@ -4,7 +4,7 @@ API for protocol implementors. A protocol is a directed acyclic graph with the nodes being objects of types implementing [`Round`] (to be specific, "acyclic" means that the values returned by [`Round::id`] should not repeat during the protocol execution; the types might). -The starting point is a type that implements [`FirstRound`]. +The starting point is a type that implements [`EntryPoint`]. All the rounds must have their associated type [`Round::Protocol`] set to the same [`Protocol`] instance to be executed by a [`Session`](`crate::session::Session`). @@ -23,12 +23,12 @@ pub use errors::{ }; pub use message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessagePart}; pub use round::{ - AnotherRound, Artifact, FinalizeOutcome, FirstRound, PartyId, Payload, Protocol, ProtocolError, Round, RoundId, + Artifact, CorrectnessProof, EntryPoint, FinalizeOutcome, PartyId, Payload, Protocol, ProtocolError, Round, RoundId, }; pub use serialization::{Deserializer, Serializer}; pub(crate) use errors::ReceiveErrorType; -pub use object_safe::ObjectSafeRound; -pub(crate) use object_safe::ObjectSafeRoundWrapper; +pub use object_safe::BoxedRound; +pub(crate) use object_safe::{BoxedRng, ObjectSafeRound}; pub use digest; diff --git a/manul/src/protocol/object_safe.rs b/manul/src/protocol/object_safe.rs index 123b3b0..daf2173 100644 --- a/manul/src/protocol/object_safe.rs +++ b/manul/src/protocol/object_safe.rs @@ -17,7 +17,7 @@ use super::{ /// Since object-safe trait methods cannot take `impl CryptoRngCore` arguments, /// this structure wraps the dynamic object and exposes a `CryptoRngCore` interface, /// to be passed to statically typed round methods. -struct BoxedRng<'a>(&'a mut dyn CryptoRngCore); +pub(crate) struct BoxedRng<'a>(pub(crate) &'a mut dyn CryptoRngCore); impl CryptoRng for BoxedRng<'_> {} @@ -52,6 +52,7 @@ pub trait ObjectSafeRound: 'static + Debug + Send + Sync { &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + deserializer: &Deserializer, destination: &Id, ) -> Result<(DirectMessage, Option), LocalError>; @@ -59,12 +60,14 @@ pub trait ObjectSafeRound: 'static + Debug + Send + Sync { &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + deserializer: &Deserializer, ) -> Result; fn make_normal_broadcast( &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + deserializer: &Deserializer, ) -> Result; fn receive_message( @@ -87,7 +90,9 @@ pub trait ObjectSafeRound: 'static + Debug + Send + Sync { fn expecting_messages_from(&self) -> &BTreeSet; /// Returns the type ID of the implementing type. - fn get_type_id(&self) -> core::any::TypeId; + fn get_type_id(&self) -> core::any::TypeId { + core::any::TypeId::of::() + } } // The `fn(Id) -> Id` bit is so that `ObjectSafeRoundWrapper` didn't require a bound on `Id` to be @@ -134,6 +139,7 @@ where &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + #[allow(unused_variables)] deserializer: &Deserializer, destination: &Id, ) -> Result<(DirectMessage, Option), LocalError> { let mut boxed_rng = BoxedRng(rng); @@ -144,6 +150,7 @@ where &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + #[allow(unused_variables)] deserializer: &Deserializer, ) -> Result { let mut boxed_rng = BoxedRng(rng); self.round.make_echo_broadcast(&mut boxed_rng, serializer) @@ -153,6 +160,7 @@ where &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + #[allow(unused_variables)] deserializer: &Deserializer, ) -> Result { let mut boxed_rng = BoxedRng(rng); self.round.make_normal_broadcast(&mut boxed_rng, serializer) @@ -191,10 +199,6 @@ where fn expecting_messages_from(&self) -> &BTreeSet { self.round.expecting_messages_from() } - - fn get_type_id(&self) -> core::any::TypeId { - core::any::TypeId::of::() - } } // When we are wrapping types implementing Round and overriding `finalize()`, @@ -209,8 +213,51 @@ where Id: PartyId, P: Protocol, { +} + +// We do not want to expose `ObjectSafeRound` to the user, so it is hidden in a struct. +/// A wrapped new round that may be returned by [`Round::finalize`]. +#[derive_where::derive_where(Debug)] +pub struct BoxedRound { + wrapped: bool, + round: Box>, +} + +impl BoxedRound { + pub fn new_dynamic>(round: R) -> Self { + Self { + wrapped: true, + round: Box::new(ObjectSafeRoundWrapper::new(round)), + } + } + + pub(crate) fn new_object_safe>(round: R) -> Self { + Self { + wrapped: false, + round: Box::new(round), + } + } + + pub(crate) fn as_ref(&self) -> &dyn ObjectSafeRound { + self.round.as_ref() + } + + pub(crate) fn into_boxed(self) -> Box> { + self.round + } + + /* + pub fn downcast>(self) -> Result { + self.round.downcast::() + } + + fn boxed_type_is(&self) -> bool { + core::any::TypeId::of::() == self.round.get_type_id() + } + + /// Attempts to extract an object of a concrete type, preserving the original on failure. pub fn try_downcast>(self: Box) -> Result> { - if core::any::TypeId::of::>() == self.get_type_id() { + if self.wrapped && self.boxed_type_is::>() { // This should be safe since we just checked that we are casting to a correct type. let boxed_downcast = unsafe { Box::>::from_raw(Box::into_raw(self) as *mut ObjectSafeRoundWrapper) @@ -221,14 +268,20 @@ where } } + /// Attempts to extract an object of a concrete type. pub fn downcast>(self: Box) -> Result { self.try_downcast() .map_err(|_| LocalError::new(format!("Failed to downcast into type {}", core::any::type_name::()))) } + */ + + pub fn id(&self) -> RoundId { + self.round.id() + } pub fn downcast_ref>(&self) -> Result<&T, LocalError> { - if core::any::TypeId::of::>() == self.get_type_id() { - let ptr: *const dyn ObjectSafeRound = self; + if self.wrapped && core::any::TypeId::of::>() == self.round.get_type_id() { + let ptr: *const dyn ObjectSafeRound = self.round.as_ref(); // This should be safe since we just checked that we are casting to a correct type. Ok(unsafe { &*(ptr as *const T) }) } else { diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index 681fdb8..db32c49 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -2,10 +2,13 @@ use alloc::{ boxed::Box, collections::{BTreeMap, BTreeSet}, format, - string::String, + string::{String, ToString}, vec::Vec, }; -use core::{any::Any, fmt::Debug}; +use core::{ + any::Any, + fmt::{self, Debug, Display}, +}; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; @@ -13,15 +16,15 @@ use serde::{Deserialize, Serialize}; use super::{ errors::{FinalizeError, LocalError, MessageValidationError, ProtocolValidationError, ReceiveError}, message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessagePart}, - object_safe::{ObjectSafeRound, ObjectSafeRoundWrapper}, + object_safe::BoxedRound, serialization::{Deserializer, Serializer}, }; /// Possible successful outcomes of [`Round::finalize`]. #[derive(Debug)] -pub enum FinalizeOutcome { +pub enum FinalizeOutcome { /// Transition to a new round. - AnotherRound(AnotherRound), + AnotherRound(BoxedRound), /// The protocol reached a result. Result(P::Result), } @@ -33,58 +36,78 @@ where { /// A helper method to create an [`AnotherRound`](`Self::AnotherRound`) variant. pub fn another_round(round: impl Round) -> Self { - Self::AnotherRound(AnotherRound::new(round)) - } -} - -// We do not want to expose `ObjectSafeRound` to the user, so it is hidden in a struct. -/// A wrapped new round that may be returned by [`Round::finalize`]. -#[derive(Debug)] -pub struct AnotherRound(Box>); - -impl AnotherRound -where - Id: PartyId, - P: Protocol, -{ - /// Wraps an object implementing [`Round`]. - pub fn new(round: impl Round) -> Self { - Self(Box::new(ObjectSafeRoundWrapper::new(round))) - } - - /// Returns the inner boxed type. - /// This is an internal method to be used in `Session`. - pub(crate) fn into_boxed(self) -> Box> { - self.0 - } - - /// Attempts to extract an object of a concrete type. - pub fn downcast>(self) -> Result { - self.0.downcast::() - } - - /// Attempts to extract an object of a concrete type, preserving the original on failure. - pub fn try_downcast>(self) -> Result { - self.0.try_downcast::().map_err(Self) + Self::AnotherRound(BoxedRound::new_dynamic(round)) } } /// A round identifier. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct RoundId { - round_num: u8, + depth: u8, + round_num: [u8; 8], is_echo: bool, } +impl Display for RoundId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + /*write!(f, "Round ")?; + for i in (0..self.depth as usize).rev() { + write!(f, "{}", self.round_num.get(i).ok_or_else(fmt::Error::)?)?; + if i != 0 { + write!(f, "-")?; + } + } + Ok(())*/ + let full_num = self + .round_num + .get(0..self.depth as usize) + .expect("Depth within range") + .iter() + .rev() + .map(|round| round.to_string()) + .collect::>() + .join("-"); + write!(f, "Round {}", full_num) + } +} + impl RoundId { /// Creates a new round identifier. - pub fn new(round_num: u8) -> Self { + pub const fn new(round_num: u8) -> Self { Self { - round_num, + depth: 1, + round_num: [round_num, 0, 0, 0, 0, 0, 0, 0], is_echo: false, } } + pub(crate) const fn group_under(&self, round_num: u8) -> Self { + if self.depth == 8 { + panic!("Maximum depth reached"); + } + let mut round_nums = self.round_num; + round_nums[self.depth as usize] = round_num; + Self { + depth: self.depth + 1, + round_num: round_nums, + is_echo: self.is_echo, + } + } + + pub(crate) fn ungroup(&self) -> Result { + if self.depth == 1 { + Err(LocalError::new("This round ID is not in a group")) + } else { + let mut round_nums = self.round_num; + round_nums[self.depth as usize - 1] = 0; + Ok(Self { + depth: self.depth - 1, + round_num: round_nums, + is_echo: self.is_echo, + }) + } + } + /// Returns `true` if this is an ID of an echo broadcast round. pub(crate) fn is_echo(&self) -> bool { self.is_echo @@ -100,7 +123,8 @@ impl RoundId { panic!("This is already an echo round ID"); } Self { - round_num: self.round_num, + depth: self.depth, + round_num: self.round_num.clone(), is_echo: true, } } @@ -115,7 +139,8 @@ impl RoundId { panic!("This is already an non-echo round ID"); } Self { - round_num: self.round_num, + depth: self.depth, + round_num: self.round_num.clone(), is_echo: false, } } @@ -132,7 +157,7 @@ pub trait Protocol: 'static { /// An object of this type will be returned when an unattributable error happens during [`Round::finalize`]. /// /// It proves that the node did its job correctly, to be adjudicated by a third party. - type CorrectnessProof: Send + Serialize + for<'de> Deserialize<'de> + Debug; + type CorrectnessProof: CorrectnessProof + Serialize + for<'de> Deserialize<'de>; /// Returns `Ok(())` if the given direct message cannot be deserialized /// assuming it is a direct message from the round `round_id`. @@ -244,6 +269,36 @@ pub trait ProtocolError: Debug + Clone + Send { ) -> Result<(), ProtocolValidationError>; } +impl ProtocolError for () { + fn description(&self) -> String { + unimplemented!() + } + + fn verify_messages_constitute_error( + &self, + _deserializer: &Deserializer, + _echo_broadcast: &EchoBroadcast, + _normal_broadcast: &NormalBroadcast, + _direct_message: &DirectMessage, + _echo_broadcasts: &BTreeMap, + _normal_broadcasts: &BTreeMap, + _direct_messages: &BTreeMap, + _combined_echos: &BTreeMap>, + ) -> Result<(), ProtocolValidationError> { + unimplemented!() + } +} + +/// Describes unattributable errors originating during protocol execution. +/// +/// In the situations where no specific message can be blamed for an error, +/// each node must generate a correctness proof proving that they performed their duties correctly, +/// and the collection of proofs is verified by a third party. +/// One of the proofs will be necessarily missing or invalid. +pub trait CorrectnessProof: Debug + Clone + Send {} + +impl CorrectnessProof for () {} + /// Message payload created in [`Round::receive_message`]. #[derive(Debug)] pub struct Payload(pub Box); @@ -301,10 +356,15 @@ impl Artifact { /// /// This is a round that can be created directly; /// all the others are only reachable throud [`Round::finalize`] by the execution layer. -pub trait FirstRound: Round + Sized { +pub trait EntryPoint { /// Additional inputs for the protocol (besides the mandatory ones in [`new`](`Self::new`)). type Inputs; + type Protocol: Protocol; + + const ENTRY_ROUND: RoundId = RoundId::new(1); + const RESULT_ROUND: RoundId; + /// Creates the round. /// /// `session_id` can be assumed to be the same for each node participating in a session. @@ -314,7 +374,7 @@ pub trait FirstRound: Round + Sized { shared_randomness: &[u8], id: Id, inputs: Self::Inputs, - ) -> Result; + ) -> Result, LocalError>; } /// A trait alias for the combination of traits needed for a party identifier. diff --git a/manul/src/session/echo.rs b/manul/src/session/echo.rs index 5bbd124..936df48 100644 --- a/manul/src/session/echo.rs +++ b/manul/src/session/echo.rs @@ -1,5 +1,4 @@ use alloc::{ - boxed::Box, collections::{BTreeMap, BTreeSet}, format, string::String, @@ -18,8 +17,8 @@ use super::{ }; use crate::{ protocol::{ - Artifact, Deserializer, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, MessageValidationError, - NormalBroadcast, ObjectSafeRound, Payload, Protocol, ProtocolMessagePart, ReceiveError, Round, RoundId, + Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, + MessageValidationError, NormalBroadcast, Payload, Protocol, ProtocolMessagePart, ReceiveError, Round, RoundId, Serializer, }, utils::SerializableMap, @@ -75,12 +74,12 @@ pub(crate) struct EchoRoundMessage { /// participants. The execution layer of the protocol guarantees that all participants have received /// the messages. #[derive_where::derive_where(Debug)] -pub struct EchoRound { +pub struct EchoRound { verifier: SP::Verifier, echo_broadcasts: BTreeMap>, destinations: BTreeSet, expected_echos: BTreeSet, - main_round: Box>, + main_round: BoxedRound, payloads: BTreeMap, artifacts: BTreeMap, } @@ -94,7 +93,7 @@ where verifier: SP::Verifier, my_echo_broadcast: SignedMessagePart, echo_broadcasts: BTreeMap>, - main_round: Box>, + main_round: BoxedRound, payloads: BTreeMap, artifacts: BTreeMap, ) -> Self { @@ -147,11 +146,11 @@ where type Protocol = P; fn id(&self) -> RoundId { - self.main_round.id().echo() + self.main_round.as_ref().id().echo() } fn possible_next_rounds(&self) -> BTreeSet { - self.main_round.possible_next_rounds() + self.main_round.as_ref().possible_next_rounds() } fn message_destinations(&self) -> &BTreeSet { @@ -297,6 +296,8 @@ where _payloads: BTreeMap, _artifacts: BTreeMap, ) -> Result, FinalizeError> { - self.main_round.finalize(rng, self.payloads, self.artifacts) + self.main_round + .into_boxed() + .finalize(rng, self.payloads, self.artifacts) } } diff --git a/manul/src/session/message.rs b/manul/src/session/message.rs index 0016e7c..5705e6b 100644 --- a/manul/src/session/message.rs +++ b/manul/src/session/message.rs @@ -56,10 +56,10 @@ pub(crate) struct MessageMetadata { } impl MessageMetadata { - pub fn new(session_id: &SessionId, round_id: RoundId) -> Self { + pub fn new(session_id: &SessionId, round_id: &RoundId) -> Self { Self { session_id: session_id.clone(), - round_id, + round_id: round_id.clone(), } } @@ -103,7 +103,7 @@ where rng: &mut impl CryptoRngCore, signer: &SP::Signer, session_id: &SessionId, - round_id: RoundId, + round_id: &RoundId, message: M, ) -> Result where @@ -193,7 +193,7 @@ where rng: &mut impl CryptoRngCore, signer: &SP::Signer, session_id: &SessionId, - round_id: RoundId, + round_id: &RoundId, destination: &Verifier, direct_message: DirectMessage, echo_broadcast: SignedMessagePart, diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index 5368032..25d0247 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -23,9 +23,9 @@ use super::{ LocalError, RemoteError, }; use crate::protocol::{ - Artifact, Deserializer, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, FirstRound, NormalBroadcast, - ObjectSafeRound, ObjectSafeRoundWrapper, PartyId, Payload, Protocol, ProtocolMessagePart, ReceiveError, - ReceiveErrorType, Round, RoundId, Serializer, + Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeError, FinalizeOutcome, + NormalBroadcast, PartyId, Payload, Protocol, ProtocolMessagePart, ReceiveError, ReceiveErrorType, RoundId, + Serializer, }; /// A set of types needed to execute a session. @@ -106,7 +106,7 @@ pub struct Session { verifier: SP::Verifier, serializer: Serializer, deserializer: Deserializer, - round: Box>, + round: BoxedRound, message_destinations: BTreeSet, echo_broadcast: SignedMessagePart, normal_broadcast: SignedMessagePart, @@ -141,15 +141,10 @@ where inputs: R::Inputs, ) -> Result where - R: FirstRound + Round, + R: EntryPoint, { let verifier = signer.verifying_key(); - let first_round = Box::new(ObjectSafeRoundWrapper::new(R::new( - rng, - session_id.as_ref(), - verifier.clone(), - inputs, - )?)); + let first_round = R::new(rng, session_id.as_ref(), verifier.clone(), inputs)?; let serializer = Serializer::new::(); let deserializer = Deserializer::new::(); Self::new_for_next_round( @@ -169,23 +164,25 @@ where signer: SP::Signer, serializer: Serializer, deserializer: Deserializer, - round: Box>, + round: BoxedRound, transcript: Transcript, ) -> Result { let verifier = signer.verifying_key(); - let echo = round.make_echo_broadcast(rng, &serializer)?; - let echo_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, round.id(), echo)?; + let round_id = round.as_ref().id(); + + let echo = round.as_ref().make_echo_broadcast(rng, &serializer, &deserializer)?; + let echo_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, &round_id, echo)?; - let normal = round.make_normal_broadcast(rng, &serializer)?; - let normal_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, round.id(), normal)?; + let normal = round.as_ref().make_normal_broadcast(rng, &serializer, &deserializer)?; + let normal_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, &round_id, normal)?; - let message_destinations = round.message_destinations().clone(); + let message_destinations = round.as_ref().message_destinations().clone(); let possible_next_rounds = if echo_broadcast.payload().is_none() { - round.possible_next_rounds() + round.as_ref().possible_next_rounds() } else { - BTreeSet::from([round.id().echo()]) + BTreeSet::from([round_id.echo()]) }; Ok(Self { @@ -226,13 +223,16 @@ where rng: &mut impl CryptoRngCore, destination: &SP::Verifier, ) -> Result<(Message, ProcessedArtifact), LocalError> { - let (direct_message, artifact) = self.round.make_direct_message(rng, &self.serializer, destination)?; + let (direct_message, artifact) = + self.round + .as_ref() + .make_direct_message(rng, &self.serializer, &self.deserializer, destination)?; let message = Message::new::( rng, &self.signer, &self.session_id, - self.round.id(), + &self.round.as_ref().id(), destination, direct_message, self.echo_broadcast.clone(), @@ -258,7 +258,7 @@ where /// Returns the ID of the current round. pub fn round_id(&self) -> RoundId { - self.round.id() + self.round.as_ref().id() } /// Performs some preliminary checks on the message to verify its integrity. @@ -374,7 +374,7 @@ where rng: &mut impl CryptoRngCore, message: VerifiedMessage, ) -> ProcessedMessage { - let processed = self.round.receive_message( + let processed = self.round.as_ref().receive_message( rng, &self.deserializer, message.from(), @@ -398,7 +398,7 @@ where /// Makes an accumulator for a new round. pub fn make_accumulator(&self) -> RoundAccumulator { - RoundAccumulator::new(self.round.expecting_messages_from()) + RoundAccumulator::new(self.round.as_ref().expecting_messages_from()) } fn terminate_inner( @@ -460,15 +460,15 @@ where let echo_round_needed = !self.echo_broadcast.payload().is_none(); if echo_round_needed { - let round = Box::new(ObjectSafeRoundWrapper::new(EchoRound::::new( + let round = BoxedRound::new_dynamic(EchoRound::::new( verifier, self.echo_broadcast, transcript.echo_broadcasts(round_id)?, self.round, accum.payloads, accum.artifacts, - ))); - let cached_messages = filter_messages(accum.cached, round.id()); + )); + let cached_messages = filter_messages(accum.cached, round.as_ref().id()); let session = Session::new_for_next_round( rng, self.session_id, @@ -484,24 +484,25 @@ where }); } - match self.round.finalize(rng, accum.payloads, accum.artifacts) { + match self.round.into_boxed().finalize(rng, accum.payloads, accum.artifacts) { Ok(result) => Ok(match result { FinalizeOutcome::Result(result) => { RoundOutcome::Finished(SessionReport::new(SessionOutcome::Result(result), transcript)) } - FinalizeOutcome::AnotherRound(another_round) => { - let round = another_round.into_boxed(); - + FinalizeOutcome::AnotherRound(round) => { // Protecting against common bugs - if !self.possible_next_rounds.contains(&round.id()) { - return Err(LocalError::new(format!("Unexpected next round id: {:?}", round.id()))); + if !self.possible_next_rounds.contains(&round.as_ref().id()) { + return Err(LocalError::new(format!( + "Unexpected next round id: {:?}", + round.as_ref().id() + ))); } // These messages could have been cached before // processing messages from the same node for the current round. // So there might have been some new errors, and we need to check again // if the sender is already banned. - let cached_messages = filter_messages(accum.cached, round.id()) + let cached_messages = filter_messages(accum.cached, round.as_ref().id()) .into_iter() .filter(|message| !transcript.is_banned(message.from())) .collect::>(); @@ -817,17 +818,11 @@ fn filter_messages( #[cfg(test)] mod tests { - use alloc::{collections::BTreeMap, string::String, vec::Vec}; - use impls::impls; - use serde::{Deserialize, Serialize}; use super::{Message, ProcessedArtifact, ProcessedMessage, Session, VerifiedMessage}; use crate::{ - protocol::{ - Deserializer, DirectMessage, EchoBroadcast, NormalBroadcast, Protocol, ProtocolError, - ProtocolValidationError, RoundId, - }, + protocol::Protocol, testing::{BinaryFormat, TestSessionParams, TestVerifier}, }; @@ -843,32 +838,9 @@ mod tests { struct DummyProtocol; - #[derive(Debug, Clone, Serialize, Deserialize)] - struct DummyProtocolError; - - impl ProtocolError for DummyProtocolError { - fn description(&self) -> String { - unimplemented!() - } - - fn verify_messages_constitute_error( - &self, - _deserializer: &Deserializer, - _echo_broadcast: &EchoBroadcast, - _normal_broadcast: &NormalBroadcast, - _direct_message: &DirectMessage, - _echo_broadcasts: &BTreeMap, - _normal_broadcasts: &BTreeMap, - _direct_messages: &BTreeMap, - _combined_echos: &BTreeMap>, - ) -> Result<(), ProtocolValidationError> { - unimplemented!() - } - } - impl Protocol for DummyProtocol { type Result = (); - type ProtocolError = DummyProtocolError; + type ProtocolError = (); type CorrectnessProof = (); } diff --git a/manul/src/testing/run_sync.rs b/manul/src/testing/run_sync.rs index 83d24bc..9e06ac1 100644 --- a/manul/src/testing/run_sync.rs +++ b/manul/src/testing/run_sync.rs @@ -6,7 +6,7 @@ use signature::Keypair; use tracing::debug; use crate::{ - protocol::{FirstRound, Protocol}, + protocol::{EntryPoint, Protocol}, session::{ CanFinalize, LocalError, Message, RoundAccumulator, RoundOutcome, Session, SessionId, SessionParameters, SessionReport, @@ -94,7 +94,7 @@ pub fn run_sync( inputs: Vec<(SP::Signer, R::Inputs)>, ) -> Result>, LocalError> where - R: FirstRound, + R: EntryPoint, SP: SessionParameters, { let session_id = SessionId::random::(rng);