diff --git a/examples/src/simple.rs b/examples/src/simple.rs index eeab9ab..308216b 100644 --- a/examples/src/simple.rs +++ b/examples/src/simple.rs @@ -10,6 +10,7 @@ use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; use tracing::debug; +#[derive(Debug)] pub struct SimpleProtocol; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -111,11 +112,6 @@ impl Protocol for SimpleProtocol { } } -#[derive(Debug, Clone)] -pub struct Inputs { - pub all_ids: BTreeSet, -} - #[derive(Debug)] pub(crate) struct Context { pub(crate) id: Id, @@ -149,30 +145,40 @@ struct Round1Payload { x: u8, } -impl EntryPoint for Round1 { - type Inputs = Inputs; +#[derive(Debug, Clone)] +pub struct SimpleProtocolEntryPoint { + my_id: Id, + all_ids: BTreeSet, +} + +impl SimpleProtocolEntryPoint { + pub fn new(my_id: Id, all_ids: BTreeSet) -> Self { + Self { my_id, all_ids } + } +} + +impl EntryPoint for SimpleProtocolEntryPoint { type Protocol = SimpleProtocol; - fn new( + fn make_round( + self, _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], - id: Id, - inputs: Self::Inputs, ) -> Result, LocalError> { // Just some numbers associated with IDs to use in the dummy protocol. // They will be the same on each node since IDs are ordered. - let ids_to_positions = inputs + let ids_to_positions = self .all_ids .iter() .enumerate() .map(|(idx, id)| (id.clone(), idx as u8)) .collect::>(); - let mut ids = inputs.all_ids; - ids.remove(&id); + let mut ids = self.all_ids; + ids.remove(&self.my_id); - Ok(BoxedRound::new_dynamic(Self { + Ok(BoxedRound::new_dynamic(Round1 { context: Context { - id, + id: self.my_id, other_ids: ids, ids_to_positions, }, @@ -401,12 +407,12 @@ mod tests { use manul::{ session::{signature::Keypair, SessionOutcome}, - testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner}, }; use rand_core::OsRng; use tracing_subscriber::EnvFilter; - use super::{Inputs, Round1}; + use super::SimpleProtocolEntryPoint; #[test] fn round() { @@ -415,14 +421,12 @@ mod tests { .iter() .map(|signer| signer.verifying_key()) .collect::>(); - let inputs = signers + let entry_points = signers .into_iter() .map(|signer| { ( signer, - Inputs { - all_ids: all_ids.clone(), - }, + SimpleProtocolEntryPoint::new(signer.verifying_key(), all_ids.clone()), ) }) .collect::>(); @@ -431,7 +435,7 @@ mod tests { .with_env_filter(EnvFilter::from_default_env()) .finish(); let reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, inputs).unwrap() + run_sync::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap() }); for (_id, report) in reports { diff --git a/examples/src/simple_chain.rs b/examples/src/simple_chain.rs index e187dc6..07e9119 100644 --- a/examples/src/simple_chain.rs +++ b/examples/src/simple_chain.rs @@ -1,37 +1,74 @@ +use alloc::collections::BTreeSet; use core::fmt::Debug; +use rand_core::CryptoRngCore; use manul::{ - combinators::chain::{Chained, ChainedEntryPoint}, - protocol::PartyId, + combinators::chain::*, + protocol::{BoxedRound, EntryPoint, LocalError, PartyId, Protocol, RoundId}, }; -use super::simple::{Inputs, Round1}; +use super::simple::{SimpleProtocol, SimpleProtocolEntryPoint}; -pub struct ChainedSimple; +pub type DoubleSimpleProtocol = ChainedProtocol; -#[derive(Debug)] -pub struct NewInputs(Inputs); +pub struct DoubleSimpleEntryPoint { + my_id: Id, + all_ids: BTreeSet, +} -impl<'a, Id: PartyId> From<&'a NewInputs> for Inputs { - fn from(source: &'a NewInputs) -> Self { - source.0.clone() +impl DoubleSimpleEntryPoint { + pub fn new(my_id: Id, all_ids: BTreeSet) -> Self { + Self { my_id, all_ids } } } -impl From<(NewInputs, u8)> for Inputs { - fn from(source: (NewInputs, u8)) -> Self { - let (inputs, _result) = source; - inputs.0 +impl ChainedSplit for DoubleSimpleEntryPoint +where + Id: PartyId, +{ + type EntryPoint = SimpleProtocolEntryPoint; + fn make_entry_point1(self) -> (Self::EntryPoint, impl ChainedJoin) { + ( + SimpleProtocolEntryPoint::new(self.my_id.clone(), self.all_ids.clone()), + DoubleTransition { + my_id: self.my_id, + all_ids: self.all_ids, + }, + ) } } -impl Chained for ChainedSimple { - type Inputs = NewInputs; - type EntryPoint1 = Round1; - type EntryPoint2 = Round1; +#[derive(Debug)] +struct DoubleTransition { + my_id: Id, + all_ids: BTreeSet, +} + +impl ChainedJoin for DoubleTransition +where + Id: PartyId, +{ + type EntryPoint = SimpleProtocolEntryPoint; + fn make_entry_point2(self, _result: ::Result) -> Self::EntryPoint { + SimpleProtocolEntryPoint::new(self.my_id, self.all_ids) + } } -pub type DoubleSimpleEntryPoint = ChainedEntryPoint; +impl EntryPoint for DoubleSimpleEntryPoint { + type Protocol = DoubleSimpleProtocol; + + fn entry_round() -> RoundId { + >::entry_round_id() + } + + fn make_round( + self, + rng: &mut impl CryptoRngCore, + shared_randomness: &[u8], + ) -> Result, LocalError> { + make_chained_round(self, rng, shared_randomness) + } +} #[cfg(test)] mod tests { @@ -39,13 +76,12 @@ mod tests { use manul::{ session::{signature::Keypair, SessionOutcome}, - testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner}, }; use rand_core::OsRng; use tracing_subscriber::EnvFilter; - use super::{DoubleSimpleEntryPoint, NewInputs}; - use crate::simple::Inputs; + use super::DoubleSimpleEntryPoint; #[test] fn round() { @@ -54,14 +90,12 @@ mod tests { .iter() .map(|signer| signer.verifying_key()) .collect::>(); - let inputs = signers + let entry_points = signers .into_iter() .map(|signer| { ( signer, - NewInputs(Inputs { - all_ids: all_ids.clone(), - }), + DoubleSimpleEntryPoint::new(signer.verifying_key(), all_ids.clone()), ) }) .collect::>(); @@ -70,8 +104,7 @@ mod tests { .with_env_filter(EnvFilter::from_default_env()) .finish(); let reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, inputs) - .unwrap() + run_sync::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap() }); for (_id, report) in reports { diff --git a/examples/src/simple_malicious.rs b/examples/src/simple_malicious.rs index 2f58f77..0f305a4 100644 --- a/examples/src/simple_malicious.rs +++ b/examples/src/simple_malicious.rs @@ -2,18 +2,18 @@ use alloc::collections::BTreeSet; use core::fmt::Debug; use manul::{ - combinators::misbehave::{Misbehaving, MisbehavingEntryPoint, MisbehavingInputs}, + combinators::misbehave::{Misbehaving, MisbehavingEntryPoint}, protocol::{ Artifact, BoxedRound, Deserializer, DirectMessage, EntryPoint, LocalError, PartyId, ProtocolMessagePart, RoundId, Serializer, }, session::signature::Keypair, - testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner}, }; use rand_core::{CryptoRngCore, OsRng}; use tracing_subscriber::EnvFilter; -use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message}; +use crate::simple::{Round1, Round1Message, Round2, Round2Message, SimpleProtocolEntryPoint}; #[derive(Debug, Clone, Copy)] enum Behavior { @@ -25,7 +25,7 @@ enum Behavior { struct MaliciousLogic; impl Misbehaving for MaliciousLogic { - type EntryPoint = Round1; + type EntryPoint = SimpleProtocolEntryPoint; fn modify_direct_message( _rng: &mut impl CryptoRngCore, @@ -78,9 +78,8 @@ fn serialized_garbage() { .iter() .map(|signer| signer.verifying_key()) .collect::>(); - let inputs = Inputs { all_ids }; - let run_inputs = signers + let entry_points = signers .iter() .enumerate() .map(|(idx, signer)| { @@ -90,11 +89,11 @@ fn serialized_garbage() { None }; - let malicious_inputs = MisbehavingInputs { - inner_inputs: inputs.clone(), + let entry_point = MaliciousEntryPoint::new( + SimpleProtocolEntryPoint::new(signer.verifying_key(), all_ids.clone()), behavior, - }; - (*signer, malicious_inputs) + ); + (*signer, entry_point) }) .collect::>(); @@ -102,7 +101,7 @@ fn serialized_garbage() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap() }); let v0 = signers[0].verifying_key(); @@ -124,9 +123,8 @@ fn attributable_failure() { .iter() .map(|signer| signer.verifying_key()) .collect::>(); - let inputs = Inputs { all_ids }; - let run_inputs = signers + let entry_points = signers .iter() .enumerate() .map(|(idx, signer)| { @@ -136,11 +134,11 @@ fn attributable_failure() { None }; - let malicious_inputs = MisbehavingInputs { - inner_inputs: inputs.clone(), + let entry_point = MaliciousEntryPoint::new( + SimpleProtocolEntryPoint::new(signer.verifying_key(), all_ids.clone()), behavior, - }; - (*signer, malicious_inputs) + ); + (*signer, entry_point) }) .collect::>(); @@ -148,7 +146,7 @@ fn attributable_failure() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap() }); let v0 = signers[0].verifying_key(); @@ -170,9 +168,8 @@ fn attributable_failure_round2() { .iter() .map(|signer| signer.verifying_key()) .collect::>(); - let inputs = Inputs { all_ids }; - let run_inputs = signers + let entry_points = signers .iter() .enumerate() .map(|(idx, signer)| { @@ -182,11 +179,11 @@ fn attributable_failure_round2() { None }; - let malicious_inputs = MisbehavingInputs { - inner_inputs: inputs.clone(), + let entry_point = MaliciousEntryPoint::new( + SimpleProtocolEntryPoint::new(signer.verifying_key(), all_ids.clone()), behavior, - }; - (*signer, malicious_inputs) + ); + (*signer, entry_point) }) .collect::>(); @@ -194,7 +191,7 @@ fn attributable_failure_round2() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap() }); let v0 = signers[0].verifying_key(); diff --git a/examples/tests/async_runner.rs b/examples/tests/async_runner.rs index 67d094e..ae0f00f 100644 --- a/examples/tests/async_runner.rs +++ b/examples/tests/async_runner.rs @@ -10,7 +10,7 @@ use manul::{ }, testing::{BinaryFormat, TestSessionParams, TestSigner}, }; -use manul_example::simple::{Inputs, Round1, SimpleProtocol}; +use manul_example::simple::{SimpleProtocol, SimpleProtocolEntryPoint}; use rand::Rng; use rand_core::OsRng; use tokio::{ @@ -256,10 +256,8 @@ async fn async_run() { let sessions = signers .into_iter() .map(|signer| { - let inputs = Inputs { - all_ids: all_ids.clone(), - }; - SimpleSession::new::>(&mut OsRng, session_id.clone(), signer, inputs).unwrap() + let entry_point = SimpleProtocolEntryPoint::new(signer.verifying_key(), all_ids.clone()); + SimpleSession::new(&mut OsRng, session_id.clone(), signer, entry_point).unwrap() }) .collect::>(); diff --git a/manul/src/combinators/chain.rs b/manul/src/combinators/chain.rs index e4abf81..562e501 100644 --- a/manul/src/combinators/chain.rs +++ b/manul/src/combinators/chain.rs @@ -56,63 +56,46 @@ use serde::{Deserialize, Serialize}; use crate::protocol::*; -/// A trait defining two protocols executed sequentially. -pub trait Chained: 'static -where - Id: PartyId, -{ - /// The inputs of the new chained protocol. - type Inputs: Send + Sync + Debug; - - /// The entry point of the first protocol. - type EntryPoint1: EntryPoint From<&'a Self::Inputs>>; - - /// The entry point of the second protocol. - type EntryPoint2: EntryPoint< - Id, - Inputs: From<( - Self::Inputs, - <>::Protocol as Protocol>::Result, - )>, - >; -} - /// The protocol error type for the chained protocol. #[derive_where::derive_where(Debug, Clone)] #[derive(Serialize, Deserialize)] #[serde(bound(serialize = " - <>::Protocol as Protocol>::ProtocolError: Serialize, - <>::Protocol as Protocol>::ProtocolError: Serialize, + ::ProtocolError: Serialize, + ::ProtocolError: Serialize, "))] #[serde(bound(deserialize = " - <>::Protocol as Protocol>::ProtocolError: for<'x> Deserialize<'x>, - <>::Protocol as Protocol>::ProtocolError: for<'x> Deserialize<'x>, + ::ProtocolError: for<'x> Deserialize<'x>, + ::ProtocolError: for<'x> Deserialize<'x>, "))] -pub enum ChainedProtocolError> { +pub enum ChainedProtocolError +where + P1: Protocol, + P2: Protocol, +{ /// A protocol error from the first protocol. - Protocol1(<>::Protocol as Protocol>::ProtocolError), + Protocol1(::ProtocolError), /// A protocol error from the second protocol. - Protocol2(<>::Protocol as Protocol>::ProtocolError), + Protocol2(::ProtocolError), } -impl ChainedProtocolError +impl ChainedProtocolError where - Id: PartyId, - C: Chained, + P1: Protocol, + P2: Protocol, { - fn from_protocol1(err: <>::Protocol as Protocol>::ProtocolError) -> Self { + fn from_protocol1(err: ::ProtocolError) -> Self { Self::Protocol1(err) } - fn from_protocol2(err: <>::Protocol as Protocol>::ProtocolError) -> Self { + fn from_protocol2(err: ::ProtocolError) -> Self { Self::Protocol2(err) } } -impl ProtocolError for ChainedProtocolError +impl ProtocolError for ChainedProtocolError where - Id: PartyId, - C: Chained, + P1: Protocol, + P2: Protocol, { fn description(&self) -> String { match self { @@ -229,118 +212,133 @@ where #[derive_where::derive_where(Debug, Clone)] #[derive(Serialize, Deserialize)] #[serde(bound(serialize = " - <>::Protocol as Protocol>::CorrectnessProof: Serialize, - <>::Protocol as Protocol>::CorrectnessProof: Serialize, + ::CorrectnessProof: Serialize, + ::CorrectnessProof: Serialize, "))] #[serde(bound(deserialize = " - <>::Protocol as Protocol>::CorrectnessProof: for<'x> Deserialize<'x>, - <>::Protocol as Protocol>::CorrectnessProof: for<'x> Deserialize<'x>, + ::CorrectnessProof: for<'x> Deserialize<'x>, + ::CorrectnessProof: for<'x> Deserialize<'x>, "))] -pub enum ChainedCorrectnessProof +pub enum ChainedCorrectnessProof where - Id: PartyId, - C: Chained, + P1: Protocol, + P2: Protocol, { /// A correctness proof from the first protocol. - Protocol1(<>::Protocol as Protocol>::CorrectnessProof), + Protocol1(::CorrectnessProof), /// A correctness proof from the second protocol. - Protocol2(<>::Protocol as Protocol>::CorrectnessProof), + Protocol2(::CorrectnessProof), } -impl ChainedCorrectnessProof +impl ChainedCorrectnessProof where - Id: PartyId, - C: Chained, + P1: Protocol, + P2: Protocol, { - fn from_protocol1(proof: <>::Protocol as Protocol>::CorrectnessProof) -> Self { + fn from_protocol1(proof: ::CorrectnessProof) -> Self { Self::Protocol1(proof) } - fn from_protocol2(proof: <>::Protocol as Protocol>::CorrectnessProof) -> Self { + fn from_protocol2(proof: ::CorrectnessProof) -> Self { Self::Protocol2(proof) } } -impl CorrectnessProof for ChainedCorrectnessProof +impl CorrectnessProof for ChainedCorrectnessProof where - Id: PartyId, - C: Chained, + P1: Protocol, + P2: Protocol, { } /// The protocol resulting from chaining two sub-protocols as described by `C`. #[derive(Debug)] -#[allow(clippy::type_complexity)] -pub struct ChainedProtocol>(PhantomData (Id, C)>); +pub struct ChainedProtocol(PhantomData (P1, P2)>) +where + P1: Protocol, + P2: Protocol; -impl Protocol for ChainedProtocol +impl Protocol for ChainedProtocol where - Id: PartyId, - C: Chained, + P1: Protocol, + P2: Protocol, { - type Result = <>::Protocol as Protocol>::Result; - type ProtocolError = ChainedProtocolError; - type CorrectnessProof = ChainedCorrectnessProof; + type Result = ::Result; + type ProtocolError = ChainedProtocolError; + type CorrectnessProof = ChainedCorrectnessProof; } -/// The entry point of the chained protocol. -#[derive_where::derive_where(Debug)] -pub struct ChainedEntryPoint> { - state: ChainState, +pub trait ChainedSplit { + type EntryPoint: EntryPoint; + fn make_entry_point1(self) -> (Self::EntryPoint, impl ChainedJoin); + fn entry_round_id() -> RoundId { + Self::EntryPoint::entry_round().group_under(1) + } } -#[derive_where::derive_where(Debug)] -enum ChainState +pub trait ChainedJoin: 'static + Debug + Send + Sync { + type EntryPoint: EntryPoint; + fn make_entry_point2(self, result: P1::Result) -> Self::EntryPoint; +} + +pub fn make_chained_round( + entry_point: T, + rng: &mut impl CryptoRngCore, + shared_randomness: &[u8], +) -> Result>, LocalError> where Id: PartyId, - C: Chained, + P1: Debug + Protocol, + P2: Debug + Protocol, + T: ChainedSplit, { - Protocol1 { - round: BoxedRound>::Protocol>, - shared_randomness: Box<[u8]>, - id: Id, - inputs: C::Inputs, - }, - Protocol2(BoxedRound>::Protocol>), + let (entry_point, transition) = entry_point.make_entry_point1(); + let round = entry_point.make_round(rng, shared_randomness)?; + let chained_round = ChainedRound { + state: ChainState::Protocol1 { + shared_randomness: shared_randomness.into(), + transition, + round, + }, + }; + Ok(BoxedRound::new_object_safe(chained_round)) } -impl EntryPoint for ChainedEntryPoint +#[derive(Debug)] +struct ChainedRound where Id: PartyId, - C: Chained, + P1: Protocol, + P2: Protocol, + T: ChainedJoin, { - type Inputs = C::Inputs; - type Protocol = ChainedProtocol; - - fn entry_round() -> RoundId { - >::entry_round().group_under(1) - } + state: ChainState, +} - fn new( - rng: &mut impl CryptoRngCore, - shared_randomness: &[u8], - id: Id, - inputs: Self::Inputs, - ) -> Result, LocalError> { - let round = C::EntryPoint1::new(rng, shared_randomness, id.clone(), (&inputs).into())?; - let round = ChainedEntryPoint { - state: ChainState::Protocol1 { - shared_randomness: shared_randomness.into(), - id, - inputs, - round, - }, - }; - Ok(BoxedRound::new_object_safe(round)) - } +#[derive_where::derive_where(Debug)] +enum ChainState +where + Id: PartyId, + P1: Protocol, + P2: Protocol, + T: ChainedJoin, +{ + Protocol1 { + round: BoxedRound, + shared_randomness: Box<[u8]>, + transition: T, + }, + Protocol2(BoxedRound), } -impl ObjectSafeRound for ChainedEntryPoint +impl ObjectSafeRound for ChainedRound where Id: PartyId, - C: Chained, + P1: Debug + Protocol, + P2: Debug + Protocol, + T: ChainedJoin, { - type Protocol = ChainedProtocol; + type Protocol = ChainedProtocol; fn id(&self) -> RoundId { match &self.state { @@ -362,7 +360,7 @@ where // If there are no next rounds, this is the result round. // This means that in the chain the next round will be the entry round of the second protocol. if next_rounds.is_empty() { - next_rounds.insert(C::EntryPoint2::entry_round().group_under(2)); + next_rounds.insert(T::EntryPoint::entry_round().group_under(2)); } next_rounds } @@ -485,27 +483,26 @@ where match self.state { ChainState::Protocol1 { round, - id, - inputs, + transition, shared_randomness, } => match round.into_boxed().finalize(rng, payloads, artifacts) { Ok(FinalizeOutcome::Result(result)) => { let mut boxed_rng = BoxedRng(rng); - let round = C::EntryPoint2::new(&mut boxed_rng, &shared_randomness, id, (inputs, result).into())?; + let entry_point2 = transition.make_entry_point2(result); + let round = entry_point2.make_round(&mut boxed_rng, &shared_randomness)?; Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe( - ChainedEntryPoint:: { + ChainedRound:: { state: ChainState::Protocol2(round), }, ))) } Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound( - BoxedRound::new_object_safe(ChainedEntryPoint:: { + BoxedRound::new_object_safe(ChainedRound:: { state: ChainState::Protocol1 { shared_randomness, - id, - inputs, round, + transition, }, }), )), @@ -517,7 +514,7 @@ where ChainState::Protocol2(round) => match round.into_boxed().finalize(rng, payloads, artifacts) { Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound( - BoxedRound::new_object_safe(ChainedEntryPoint:: { + BoxedRound::new_object_safe(ChainedRound:: { state: ChainState::Protocol2(round), }), )), diff --git a/manul/src/combinators/misbehave.rs b/manul/src/combinators/misbehave.rs index 311995a..3168104 100644 --- a/manul/src/combinators/misbehave.rs +++ b/manul/src/combinators/misbehave.rs @@ -39,20 +39,6 @@ pub trait Behavior: 'static + Debug + Send + Sync {} impl Behavior for T {} -/// The new entry point for the misbehaving rounds. -/// -/// Use as an entry point to run the session, with your ID, behavior `B` and the misbehavior definition `M` set. -#[derive_where::derive_where(Debug)] -pub struct MisbehavingEntryPoint -where - Id: PartyId, - B: Behavior, - M: Misbehaving, -{ - round: BoxedRound>::Protocol>, - behavior: Option, -} - /// A trait defining a sequence of misbehaving rounds modifying or replacing the messages sent by some existing ones. /// /// Override one or more optional methods to modify the specific messages. @@ -62,7 +48,7 @@ where B: Behavior, { /// The entry point of the wrapped rounds. - type EntryPoint: EntryPoint; + type EntryPoint: Debug + EntryPoint; /// Called after [`Round::make_echo_broadcast`](`crate::protocol::Round::make_echo_broadcast`) /// and may modify its result. @@ -115,20 +101,29 @@ where } } -/// The inputs for the misbehaving rounds. -#[derive_where::derive_where(Debug; >::Inputs)] -pub struct MisbehavingInputs +/// The new entry point for the misbehaving rounds. +/// +/// Use as an entry point to run the session, with your ID, behavior `B` and the misbehavior definition `M` set. +#[derive_where::derive_where(Debug)] +pub struct MisbehavingEntryPoint where Id: PartyId, B: Behavior, M: Misbehaving, { - /// The behavior for the rounds starting with these inputs. - /// - /// If `None`, all the changed behavior will be skipped. - pub behavior: Option, - /// The inputs for the wrapped rounds. - pub inner_inputs: >::Inputs, + entry_point: M::EntryPoint, + behavior: Option, +} + +impl MisbehavingEntryPoint +where + Id: PartyId, + B: Behavior, + M: Misbehaving, +{ + pub fn new(entry_point: M::EntryPoint, behavior: Option) -> Self { + Self { entry_point, behavior } + } } impl EntryPoint for MisbehavingEntryPoint @@ -137,24 +132,33 @@ where B: Behavior, M: Misbehaving, { - type Inputs = MisbehavingInputs; type Protocol = >::Protocol; - fn new( + fn make_round( + self, rng: &mut impl CryptoRngCore, shared_randomness: &[u8], - id: Id, - inputs: Self::Inputs, - ) -> Result>::Protocol>, LocalError> { - let round = M::EntryPoint::new(rng, shared_randomness, id, inputs.inner_inputs)?; - Ok(BoxedRound::new_object_safe(Self { + ) -> Result, LocalError> { + let round = self.entry_point.make_round(rng, shared_randomness)?; + Ok(BoxedRound::new_object_safe(MisbehavingRound:: { round, - behavior: inputs.behavior, + behavior: self.behavior, })) } } -impl ObjectSafeRound for MisbehavingEntryPoint +#[derive_where::derive_where(Debug)] +struct MisbehavingRound +where + Id: PartyId, + B: Behavior, + M: Misbehaving, +{ + round: BoxedRound>::Protocol>, + behavior: Option, +} + +impl ObjectSafeRound for MisbehavingRound where Id: PartyId, B: Behavior, @@ -284,12 +288,12 @@ where ) -> Result, FinalizeError> { match self.round.into_boxed().finalize(rng, payloads, artifacts) { Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), - Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe( - MisbehavingEntryPoint:: { + Ok(FinalizeOutcome::AnotherRound(round)) => { + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe(Self { round, behavior: self.behavior, - }, - ))), + }))) + } Err(err) => Err(err), } } diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index 00d09fa..443ab3b 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -339,9 +339,6 @@ impl Artifact { /// This is a round that can be created directly; /// all the others are only reachable throud [`Round::finalize`] by the execution layer. pub trait EntryPoint { - /// Additional inputs for the protocol (besides the mandatory ones in [`new`](`Self::new`)). - type Inputs; - /// The protocol implemented by the round this entry points returns. type Protocol: Protocol; @@ -354,11 +351,10 @@ pub trait EntryPoint { /// /// `session_id` can be assumed to be the same for each node participating in a session. /// `id` is the ID of this node. - fn new( + fn make_round( + self, rng: &mut impl CryptoRngCore, shared_randomness: &[u8], - id: Id, - inputs: Self::Inputs, ) -> Result, LocalError>; } diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index ab1a27c..e04251b 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -146,13 +146,12 @@ where rng: &mut impl CryptoRngCore, session_id: SessionId, signer: SP::Signer, - inputs: R::Inputs, + entry_point: R, ) -> Result where R: EntryPoint, { - let verifier = signer.verifying_key(); - let first_round = R::new(rng, session_id.as_ref(), verifier.clone(), inputs)?; + let first_round = entry_point.make_round(rng, session_id.as_ref())?; let serializer = Serializer::new::(); let deserializer = Deserializer::new::(); Self::new_for_next_round( diff --git a/manul/src/testing/run_sync.rs b/manul/src/testing/run_sync.rs index 8d08c87..5d7350e 100644 --- a/manul/src/testing/run_sync.rs +++ b/manul/src/testing/run_sync.rs @@ -91,7 +91,7 @@ where #[allow(clippy::type_complexity)] pub fn run_sync( rng: &mut impl CryptoRngCore, - inputs: Vec<(SP::Signer, R::Inputs)>, + entry_points: Vec<(SP::Signer, R)>, ) -> Result>, LocalError> where R: EntryPoint, @@ -102,9 +102,9 @@ where let mut messages = Vec::new(); let mut states = BTreeMap::new(); - for (signer, inputs) in inputs { + for (signer, entry_point) in entry_points { let verifier = signer.verifying_key(); - let session = Session::<_, SP>::new::(rng, session_id.clone(), signer, inputs)?; + let session = Session::<_, SP>::new(rng, session_id.clone(), signer, entry_point)?; let mut accum = session.make_accumulator(); let destinations = session.message_destinations(); diff --git a/manul/src/tests/partial_echo.rs b/manul/src/tests/partial_echo.rs index 3bd3536..dadc22d 100644 --- a/manul/src/tests/partial_echo.rs +++ b/manul/src/tests/partial_echo.rs @@ -51,17 +51,15 @@ impl ProtocolError for PartialEchoProtocolError { #[derive(Debug, Clone)] struct Inputs { - message_destinations: Vec, - expecting_messages_from: Vec, + id: Id, + message_destinations: BTreeSet, + expecting_messages_from: BTreeSet, echo_round_participation: EchoRoundParticipation, } #[derive(Debug)] struct Round1 { - id: Id, - message_destinations: BTreeSet, - expecting_messages_from: BTreeSet, - echo_round_participation: EchoRoundParticipation, + inputs: Inputs, } #[derive(Debug, Serialize, Deserialize)] @@ -69,23 +67,14 @@ struct Round1Echo { sender: Id, } -impl Deserialize<'de>> EntryPoint for Round1 { - type Inputs = Inputs; +impl Deserialize<'de>> EntryPoint for Inputs { type Protocol = PartialEchoProtocol; - fn new( + fn make_round( + self, _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], - id: Id, - inputs: Self::Inputs, ) -> Result, LocalError> { - let message_destinations = BTreeSet::from_iter(inputs.message_destinations); - let expecting_messages_from = BTreeSet::from_iter(inputs.expecting_messages_from); - Ok(BoxedRound::new_dynamic(Self { - id, - message_destinations, - expecting_messages_from, - echo_round_participation: inputs.echo_round_participation, - })) + Ok(BoxedRound::new_dynamic(Round1 { inputs: self })) } } @@ -101,15 +90,15 @@ impl Deserialize<'de>> Round for Round1 &BTreeSet { - &self.message_destinations + &self.inputs.message_destinations } fn expecting_messages_from(&self) -> &BTreeSet { - &self.expecting_messages_from + &self.inputs.expecting_messages_from } fn echo_round_participation(&self) -> EchoRoundParticipation { - self.echo_round_participation.clone() + self.inputs.echo_round_participation.clone() } fn make_echo_broadcast( @@ -117,13 +106,13 @@ impl Deserialize<'de>> Round for Round1 Result { - if self.message_destinations.is_empty() { + if self.inputs.message_destinations.is_empty() { Ok(EchoBroadcast::none()) } else { EchoBroadcast::new( serializer, Round1Echo { - sender: self.id.clone(), + sender: self.inputs.id.clone(), }, ) } @@ -141,12 +130,12 @@ impl Deserialize<'de>> Round for Round1>(deserializer)?; assert_eq!(&echo.sender, from); - assert!(self.expecting_messages_from.contains(from)); + assert!(self.inputs.expecting_messages_from.contains(from)); } Ok(Payload::new(())) @@ -175,24 +164,27 @@ fn partial_echo() { let node0 = ( signers[0], Inputs { - message_destinations: [ids[1], ids[2], ids[3]].into(), - expecting_messages_from: [].into(), + id: signers[0].verifying_key(), + message_destinations: BTreeSet::from([ids[1], ids[2], ids[3]]), + expecting_messages_from: BTreeSet::new(), echo_round_participation: EchoRoundParticipation::Send, }, ); let node1 = ( signers[1], Inputs { - message_destinations: [ids[2], ids[3]].into(), - expecting_messages_from: [ids[0]].into(), + id: signers[1].verifying_key(), + message_destinations: BTreeSet::from([ids[2], ids[3]]), + expecting_messages_from: BTreeSet::from([ids[0]]), echo_round_participation: EchoRoundParticipation::Default, }, ); let node2 = ( signers[2], Inputs { - message_destinations: [].into(), - expecting_messages_from: [ids[0], ids[1]].into(), + id: signers[2].verifying_key(), + message_destinations: BTreeSet::new(), + expecting_messages_from: BTreeSet::from([ids[0], ids[1]]), echo_round_participation: EchoRoundParticipation::Receive { echo_targets: BTreeSet::from([ids[1], ids[3]]), }, @@ -201,8 +193,9 @@ fn partial_echo() { let node3 = ( signers[3], Inputs { - message_destinations: [].into(), - expecting_messages_from: [ids[0], ids[1]].into(), + id: signers[3].verifying_key(), + message_destinations: BTreeSet::new(), + expecting_messages_from: BTreeSet::from([ids[0], ids[1]]), echo_round_participation: EchoRoundParticipation::Receive { echo_targets: BTreeSet::from([ids[1], ids[2]]), }, @@ -211,19 +204,20 @@ fn partial_echo() { let node4 = ( signers[4], Inputs { - message_destinations: [].into(), - expecting_messages_from: [].into(), + id: signers[4].verifying_key(), + message_destinations: BTreeSet::new(), + expecting_messages_from: BTreeSet::new(), echo_round_participation: EchoRoundParticipation::::Default, }, ); - let inputs = vec![node0, node1, node2, node3, node4]; + let entry_points = vec![node0, node1, node2, node3, node4]; let my_subscriber = tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .finish(); let reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, inputs).unwrap() + run_sync::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap() }); for (id, report) in reports {