From e331cc747b3edfe6306ed9bcca308aa3355e38da Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Fri, 8 Nov 2024 20:37:32 -0800 Subject: [PATCH] Make EntryPoints stateful --- examples/src/simple.rs | 48 ++--- examples/src/simple_chain.rs | 93 +++++++--- examples/src/simple_malicious.rs | 47 +++-- examples/tests/async_runner.rs | 8 +- manul/benches/empty_rounds.rs | 40 ++--- manul/src/combinators/chain.rs | 271 +++++++++++++++-------------- manul/src/combinators/misbehave.rs | 84 +++++---- manul/src/protocol/object_safe.rs | 2 +- manul/src/protocol/round.rs | 10 +- manul/src/session/session.rs | 5 +- manul/src/testing/run_sync.rs | 6 +- manul/src/tests/partial_echo.rs | 70 ++++---- 12 files changed, 359 insertions(+), 325 deletions(-) diff --git a/examples/src/simple.rs b/examples/src/simple.rs index eeab9ab..308216b 100644 --- a/examples/src/simple.rs +++ b/examples/src/simple.rs @@ -10,6 +10,7 @@ use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; use tracing::debug; +#[derive(Debug)] pub struct SimpleProtocol; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -111,11 +112,6 @@ impl Protocol for SimpleProtocol { } } -#[derive(Debug, Clone)] -pub struct Inputs { - pub all_ids: BTreeSet, -} - #[derive(Debug)] pub(crate) struct Context { pub(crate) id: Id, @@ -149,30 +145,40 @@ struct Round1Payload { x: u8, } -impl EntryPoint for Round1 { - type Inputs = Inputs; +#[derive(Debug, Clone)] +pub struct SimpleProtocolEntryPoint { + my_id: Id, + all_ids: BTreeSet, +} + +impl SimpleProtocolEntryPoint { + pub fn new(my_id: Id, all_ids: BTreeSet) -> Self { + Self { my_id, all_ids } + } +} + +impl EntryPoint for SimpleProtocolEntryPoint { type Protocol = SimpleProtocol; - fn new( + fn make_round( + self, _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], - id: Id, - inputs: Self::Inputs, ) -> Result, LocalError> { // Just some numbers associated with IDs to use in the dummy protocol. // They will be the same on each node since IDs are ordered. - let ids_to_positions = inputs + let ids_to_positions = self .all_ids .iter() .enumerate() .map(|(idx, id)| (id.clone(), idx as u8)) .collect::>(); - let mut ids = inputs.all_ids; - ids.remove(&id); + let mut ids = self.all_ids; + ids.remove(&self.my_id); - Ok(BoxedRound::new_dynamic(Self { + Ok(BoxedRound::new_dynamic(Round1 { context: Context { - id, + id: self.my_id, other_ids: ids, ids_to_positions, }, @@ -401,12 +407,12 @@ mod tests { use manul::{ session::{signature::Keypair, SessionOutcome}, - testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner}, }; use rand_core::OsRng; use tracing_subscriber::EnvFilter; - use super::{Inputs, Round1}; + use super::SimpleProtocolEntryPoint; #[test] fn round() { @@ -415,14 +421,12 @@ mod tests { .iter() .map(|signer| signer.verifying_key()) .collect::>(); - let inputs = signers + let entry_points = signers .into_iter() .map(|signer| { ( signer, - Inputs { - all_ids: all_ids.clone(), - }, + SimpleProtocolEntryPoint::new(signer.verifying_key(), all_ids.clone()), ) }) .collect::>(); @@ -431,7 +435,7 @@ mod tests { .with_env_filter(EnvFilter::from_default_env()) .finish(); let reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, inputs).unwrap() + run_sync::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap() }); for (_id, report) in reports { diff --git a/examples/src/simple_chain.rs b/examples/src/simple_chain.rs index e187dc6..d2fd018 100644 --- a/examples/src/simple_chain.rs +++ b/examples/src/simple_chain.rs @@ -1,37 +1,80 @@ +use alloc::collections::BTreeSet; use core::fmt::Debug; +use rand_core::CryptoRngCore; use manul::{ - combinators::chain::{Chained, ChainedEntryPoint}, - protocol::PartyId, + combinators::chain::*, + protocol::{BoxedRound, EntryPoint, LocalError, PartyId, Protocol, RoundId}, }; -use super::simple::{Inputs, Round1}; - -pub struct ChainedSimple; +use super::simple::{SimpleProtocol, SimpleProtocolEntryPoint}; #[derive(Debug)] -pub struct NewInputs(Inputs); +pub struct DoubleSimpleProtocol; + +impl ChainedProtocol for DoubleSimpleProtocol { + type Protocol1 = SimpleProtocol; + type Protocol2 = SimpleProtocol; +} -impl<'a, Id: PartyId> From<&'a NewInputs> for Inputs { - fn from(source: &'a NewInputs) -> Self { - source.0.clone() +pub struct DoubleSimpleEntryPoint { + my_id: Id, + all_ids: BTreeSet, +} + +impl DoubleSimpleEntryPoint { + pub fn new(my_id: Id, all_ids: BTreeSet) -> Self { + Self { my_id, all_ids } } } -impl From<(NewInputs, u8)> for Inputs { - fn from(source: (NewInputs, u8)) -> Self { - let (inputs, _result) = source; - inputs.0 +impl ChainedSplit for DoubleSimpleEntryPoint +where + Id: PartyId, +{ + type EntryPoint = SimpleProtocolEntryPoint; + fn make_entry_point1(self) -> (Self::EntryPoint, impl ChainedJoin) { + ( + SimpleProtocolEntryPoint::new(self.my_id.clone(), self.all_ids.clone()), + DoubleTransition { + my_id: self.my_id, + all_ids: self.all_ids, + }, + ) } } -impl Chained for ChainedSimple { - type Inputs = NewInputs; - type EntryPoint1 = Round1; - type EntryPoint2 = Round1; +#[derive(Debug)] +struct DoubleTransition { + my_id: Id, + all_ids: BTreeSet, +} + +impl ChainedJoin for DoubleTransition +where + Id: PartyId, +{ + type EntryPoint = SimpleProtocolEntryPoint; + fn make_entry_point2(self, _result: ::Result) -> Self::EntryPoint { + SimpleProtocolEntryPoint::new(self.my_id, self.all_ids) + } } -pub type DoubleSimpleEntryPoint = ChainedEntryPoint; +impl EntryPoint for DoubleSimpleEntryPoint { + type Protocol = DoubleSimpleProtocol; + + fn entry_round() -> RoundId { + >::entry_round() + } + + fn make_round( + self, + rng: &mut impl CryptoRngCore, + shared_randomness: &[u8], + ) -> Result, LocalError> { + make_chained_round(self, rng, shared_randomness) + } +} #[cfg(test)] mod tests { @@ -39,13 +82,12 @@ mod tests { use manul::{ session::{signature::Keypair, SessionOutcome}, - testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner}, }; use rand_core::OsRng; use tracing_subscriber::EnvFilter; - use super::{DoubleSimpleEntryPoint, NewInputs}; - use crate::simple::Inputs; + use super::DoubleSimpleEntryPoint; #[test] fn round() { @@ -54,14 +96,12 @@ mod tests { .iter() .map(|signer| signer.verifying_key()) .collect::>(); - let inputs = signers + let entry_points = signers .into_iter() .map(|signer| { ( signer, - NewInputs(Inputs { - all_ids: all_ids.clone(), - }), + DoubleSimpleEntryPoint::new(signer.verifying_key(), all_ids.clone()), ) }) .collect::>(); @@ -70,8 +110,7 @@ mod tests { .with_env_filter(EnvFilter::from_default_env()) .finish(); let reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, inputs) - .unwrap() + run_sync::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap() }); for (_id, report) in reports { diff --git a/examples/src/simple_malicious.rs b/examples/src/simple_malicious.rs index 2f58f77..0f305a4 100644 --- a/examples/src/simple_malicious.rs +++ b/examples/src/simple_malicious.rs @@ -2,18 +2,18 @@ use alloc::collections::BTreeSet; use core::fmt::Debug; use manul::{ - combinators::misbehave::{Misbehaving, MisbehavingEntryPoint, MisbehavingInputs}, + combinators::misbehave::{Misbehaving, MisbehavingEntryPoint}, protocol::{ Artifact, BoxedRound, Deserializer, DirectMessage, EntryPoint, LocalError, PartyId, ProtocolMessagePart, RoundId, Serializer, }, session::signature::Keypair, - testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner}, }; use rand_core::{CryptoRngCore, OsRng}; use tracing_subscriber::EnvFilter; -use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message}; +use crate::simple::{Round1, Round1Message, Round2, Round2Message, SimpleProtocolEntryPoint}; #[derive(Debug, Clone, Copy)] enum Behavior { @@ -25,7 +25,7 @@ enum Behavior { struct MaliciousLogic; impl Misbehaving for MaliciousLogic { - type EntryPoint = Round1; + type EntryPoint = SimpleProtocolEntryPoint; fn modify_direct_message( _rng: &mut impl CryptoRngCore, @@ -78,9 +78,8 @@ fn serialized_garbage() { .iter() .map(|signer| signer.verifying_key()) .collect::>(); - let inputs = Inputs { all_ids }; - let run_inputs = signers + let entry_points = signers .iter() .enumerate() .map(|(idx, signer)| { @@ -90,11 +89,11 @@ fn serialized_garbage() { None }; - let malicious_inputs = MisbehavingInputs { - inner_inputs: inputs.clone(), + let entry_point = MaliciousEntryPoint::new( + SimpleProtocolEntryPoint::new(signer.verifying_key(), all_ids.clone()), behavior, - }; - (*signer, malicious_inputs) + ); + (*signer, entry_point) }) .collect::>(); @@ -102,7 +101,7 @@ fn serialized_garbage() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap() }); let v0 = signers[0].verifying_key(); @@ -124,9 +123,8 @@ fn attributable_failure() { .iter() .map(|signer| signer.verifying_key()) .collect::>(); - let inputs = Inputs { all_ids }; - let run_inputs = signers + let entry_points = signers .iter() .enumerate() .map(|(idx, signer)| { @@ -136,11 +134,11 @@ fn attributable_failure() { None }; - let malicious_inputs = MisbehavingInputs { - inner_inputs: inputs.clone(), + let entry_point = MaliciousEntryPoint::new( + SimpleProtocolEntryPoint::new(signer.verifying_key(), all_ids.clone()), behavior, - }; - (*signer, malicious_inputs) + ); + (*signer, entry_point) }) .collect::>(); @@ -148,7 +146,7 @@ fn attributable_failure() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap() }); let v0 = signers[0].verifying_key(); @@ -170,9 +168,8 @@ fn attributable_failure_round2() { .iter() .map(|signer| signer.verifying_key()) .collect::>(); - let inputs = Inputs { all_ids }; - let run_inputs = signers + let entry_points = signers .iter() .enumerate() .map(|(idx, signer)| { @@ -182,11 +179,11 @@ fn attributable_failure_round2() { None }; - let malicious_inputs = MisbehavingInputs { - inner_inputs: inputs.clone(), + let entry_point = MaliciousEntryPoint::new( + SimpleProtocolEntryPoint::new(signer.verifying_key(), all_ids.clone()), behavior, - }; - (*signer, malicious_inputs) + ); + (*signer, entry_point) }) .collect::>(); @@ -194,7 +191,7 @@ fn attributable_failure_round2() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap() }); let v0 = signers[0].verifying_key(); diff --git a/examples/tests/async_runner.rs b/examples/tests/async_runner.rs index 67d094e..ae0f00f 100644 --- a/examples/tests/async_runner.rs +++ b/examples/tests/async_runner.rs @@ -10,7 +10,7 @@ use manul::{ }, testing::{BinaryFormat, TestSessionParams, TestSigner}, }; -use manul_example::simple::{Inputs, Round1, SimpleProtocol}; +use manul_example::simple::{SimpleProtocol, SimpleProtocolEntryPoint}; use rand::Rng; use rand_core::OsRng; use tokio::{ @@ -256,10 +256,8 @@ async fn async_run() { let sessions = signers .into_iter() .map(|signer| { - let inputs = Inputs { - all_ids: all_ids.clone(), - }; - SimpleSession::new::>(&mut OsRng, session_id.clone(), signer, inputs).unwrap() + let entry_point = SimpleProtocolEntryPoint::new(signer.verifying_key(), all_ids.clone()); + SimpleSession::new(&mut OsRng, session_id.clone(), signer, entry_point).unwrap() }) .collect::>(); diff --git a/manul/benches/empty_rounds.rs b/manul/benches/empty_rounds.rs index 87eefee..3770937 100644 --- a/manul/benches/empty_rounds.rs +++ b/manul/benches/empty_rounds.rs @@ -11,7 +11,7 @@ use manul::{ Serializer, }, session::{signature::Keypair, SessionOutcome}, - testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner}, }; use rand_core::{CryptoRngCore, OsRng}; use serde::{Deserialize, Serialize}; @@ -48,18 +48,16 @@ struct Round1Payload; struct Round1Artifact; -impl EntryPoint for EmptyRound { - type Inputs = Inputs; +impl EntryPoint for Inputs { type Protocol = EmptyProtocol; - fn new( + fn make_round( + self, _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], - _id: Id, - inputs: Self::Inputs, ) -> Result, LocalError> { - Ok(BoxedRound::new_dynamic(Self { + Ok(BoxedRound::new_dynamic(EmptyRound { round_counter: 1, - inputs, + inputs: self, })) } } @@ -170,7 +168,7 @@ fn bench_empty_rounds(c: &mut Criterion) { .map(|signer| signer.verifying_key()) .collect::>(); - let inputs_no_echo = signers + let entry_points_no_echo = signers .iter() .cloned() .map(|signer| { @@ -189,17 +187,16 @@ fn bench_empty_rounds(c: &mut Criterion) { group.bench_function("25 nodes, 5 rounds, no echo", |b| { b.iter(|| { - assert!(run_sync::, TestSessionParams>( - &mut OsRng, - inputs_no_echo.clone() + assert!( + run_sync::<_, TestSessionParams>(&mut OsRng, entry_points_no_echo.clone()) + .unwrap() + .values() + .all(|report| matches!(report.outcome, SessionOutcome::Result(_))) ) - .unwrap() - .values() - .all(|report| matches!(report.outcome, SessionOutcome::Result(_)))) }) }); - let inputs_echo = signers + let entry_points_echo = signers .iter() .cloned() .map(|signer| { @@ -220,13 +217,12 @@ fn bench_empty_rounds(c: &mut Criterion) { group.bench_function("25 nodes, 5 rounds, echo each round", |b| { b.iter(|| { - assert!(run_sync::, TestSessionParams>( - &mut OsRng, - inputs_echo.clone() + assert!( + run_sync::<_, TestSessionParams>(&mut OsRng, entry_points_echo.clone()) + .unwrap() + .values() + .all(|report| matches!(report.outcome, SessionOutcome::Result(_))) ) - .unwrap() - .values() - .all(|report| matches!(report.outcome, SessionOutcome::Result(_)))) }) }); diff --git a/manul/src/combinators/chain.rs b/manul/src/combinators/chain.rs index 32ef8c9..9a36d81 100644 --- a/manul/src/combinators/chain.rs +++ b/manul/src/combinators/chain.rs @@ -4,42 +4,40 @@ executes the two inner protocols in sequence, feeding the result of the first pr into the inputs of the second protocol. For the session level users (that is, the ones executing the protocols) -the new protocol is a single entity with its own [`Protocol`](`crate::protocol::Protocol`) type -and an [`EntryPoint`](`crate::protocol::EntryPoint`) type. +the new protocol is a single entity with its own [`Protocol`](`crate::protocol::Protocol`)-imlementing type +and an [`EntryPoint`](`crate::protocol::EntryPoint`)-implementing type. -For example, imagine we have a `ProtocolA` with an entry point `EntryPointA`, inputs `InputsA`, +For example, imagine we have a `ProtocolA` with an entry point `EntryPointA`, two rounds, `RA1` and `RA2`, and the result `ResultA`; -and similarly a `ProtocolB` with an entry point `EntryPointB`, inputs `InputsB`, +and similarly a `ProtocolB` with an entry point `EntryPointB`, two rounds, `RB1` and `RB2`, and the result `ResultB`. -Then the chained protocol will provide `ProtocolC: Protocol` and `EntryPointC: EntryPoint`, -the user will define `InputsC` for the new protocol, and the execution will look like: -- `InputsA` is created from `InputsC` via the user-defined `From` impl; -- `EntryPointA` is initialized with `InputsA`; +Then the chained protocol will have a `ProtocolC: Protocol` type and an `EntryPointC: EntryPoint` type, +and the execution will look like: +- `EntryPointC` is initialized by the user with whatever constructor it may have; +- Internally, `EntryPointA` is created from `EntryPointC` using the [`ChainedSplit`] implementation + provided by the protocol author; - `RA1` is executed; - `RA2` is executed, producing `ResultA`; -- `InputsB` is created from `ResultA` and `InputsC` via the user-defined `From` impl; +- Internally, `EntryPointB` is created from `ResultA` and the data created in [`ChainedSplit::make_entry_point1`] + using the [`ChainedJoin`] implementation provided by the protocol author; - `RB1` is executed; -- `RB2` is executed, producing `ResultB` (which is also the result of `ChainedProtocol`). +- `RB2` is executed, producing `ResultB` (which is also the result of `ProtocolC`). If the execution happens in a [`Session`](`crate::session::Session`), and there is an error at any point, a regular evidence or correctness proof are created using the corresponding types from the new `ProtocolC`. The usage is as follows. -1. Define an input type for the new joined protocol. - Most likely it will be a union between inputs of the first and the second protocol. +1. Implement [`ChainedProtocol`] for a type of your choice. Usually it will be a ZST. + You will have to specify the two protocol types you want to chain. + This type will then automatically implement [`Protocol`](`crate::protocol::Protocol`). -2. Implement [`Chained`] for a type of your choice. Usually it will be an empty token type. - You will have to specify the entry points of the two protocols, - and the [`From`] conversions from the new input type to the inputs of both entry points - (see the corresponding associated type bounds). +2. Define an entry point type for the new joined protocol. + Most likely it will contain a union between the required data for the entry point + of the first and the second protocol. -3. The entry point for the new protocol will be [`ChainedEntryPoint`] parametrized with - the type implementing [`Chained`] from step 2. - -4. The [`Protocol`](`crate::protocol::Protocol`)-implementing type for the new protocol will be - [`ChainedProtocol`] parametrized with the type implementing [`Chained`] from the step 2. +3. Implement [`ChainedSplit`] and [`ChainedJoin`] for the new entry point. */ use alloc::{ @@ -49,7 +47,7 @@ use alloc::{ string::String, vec::Vec, }; -use core::{fmt::Debug, marker::PhantomData}; +use core::fmt::Debug; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; @@ -57,62 +55,51 @@ use serde::{Deserialize, Serialize}; use crate::protocol::*; /// A trait defining two protocols executed sequentially. -pub trait Chained: 'static -where - Id: PartyId, -{ - /// The inputs of the new chained protocol. - type Inputs: Send + Sync + Debug; - - /// The entry point of the first protocol. - type EntryPoint1: EntryPoint From<&'a Self::Inputs>>; - - /// The entry point of the second protocol. - type EntryPoint2: EntryPoint< - Id, - Inputs: From<( - Self::Inputs, - <>::Protocol as Protocol>::Result, - )>, - >; +pub trait ChainedProtocol: 'static + Debug { + /// The protcol that is executed first. + type Protocol1: Protocol; + + /// The protcol that is executed second. + type Protocol2: Protocol; } /// The protocol error type for the chained protocol. #[derive_where::derive_where(Debug, Clone)] #[derive(Serialize, Deserialize)] #[serde(bound(serialize = " - <>::Protocol as Protocol>::ProtocolError: Serialize, - <>::Protocol as Protocol>::ProtocolError: Serialize, + ::ProtocolError: Serialize, + ::ProtocolError: Serialize, "))] #[serde(bound(deserialize = " - <>::Protocol as Protocol>::ProtocolError: for<'x> Deserialize<'x>, - <>::Protocol as Protocol>::ProtocolError: for<'x> Deserialize<'x>, + ::ProtocolError: for<'x> Deserialize<'x>, + ::ProtocolError: for<'x> Deserialize<'x>, "))] -pub enum ChainedProtocolError> { +pub enum ChainedProtocolError +where + C: ChainedProtocol, +{ /// A protocol error from the first protocol. - Protocol1(<>::Protocol as Protocol>::ProtocolError), + Protocol1(::ProtocolError), /// A protocol error from the second protocol. - Protocol2(<>::Protocol as Protocol>::ProtocolError), + Protocol2(::ProtocolError), } -impl ChainedProtocolError +impl ChainedProtocolError where - Id: PartyId, - C: Chained, + C: ChainedProtocol, { - fn from_protocol1(err: <>::Protocol as Protocol>::ProtocolError) -> Self { + fn from_protocol1(err: ::ProtocolError) -> Self { Self::Protocol1(err) } - fn from_protocol2(err: <>::Protocol as Protocol>::ProtocolError) -> Self { + fn from_protocol2(err: ::ProtocolError) -> Self { Self::Protocol2(err) } } -impl ProtocolError for ChainedProtocolError +impl ProtocolError for ChainedProtocolError where - Id: PartyId, - C: Chained, + C: ChainedProtocol, { fn description(&self) -> String { match self { @@ -229,118 +216,135 @@ where #[derive_where::derive_where(Debug, Clone)] #[derive(Serialize, Deserialize)] #[serde(bound(serialize = " - <>::Protocol as Protocol>::CorrectnessProof: Serialize, - <>::Protocol as Protocol>::CorrectnessProof: Serialize, + ::CorrectnessProof: Serialize, + ::CorrectnessProof: Serialize, "))] #[serde(bound(deserialize = " - <>::Protocol as Protocol>::CorrectnessProof: for<'x> Deserialize<'x>, - <>::Protocol as Protocol>::CorrectnessProof: for<'x> Deserialize<'x>, + ::CorrectnessProof: for<'x> Deserialize<'x>, + ::CorrectnessProof: for<'x> Deserialize<'x>, "))] -pub enum ChainedCorrectnessProof +pub enum ChainedCorrectnessProof where - Id: PartyId, - C: Chained, + C: ChainedProtocol, { /// A correctness proof from the first protocol. - Protocol1(<>::Protocol as Protocol>::CorrectnessProof), + Protocol1(::CorrectnessProof), /// A correctness proof from the second protocol. - Protocol2(<>::Protocol as Protocol>::CorrectnessProof), + Protocol2(::CorrectnessProof), } -impl ChainedCorrectnessProof +impl ChainedCorrectnessProof where - Id: PartyId, - C: Chained, + C: ChainedProtocol, { - fn from_protocol1(proof: <>::Protocol as Protocol>::CorrectnessProof) -> Self { + fn from_protocol1(proof: ::CorrectnessProof) -> Self { Self::Protocol1(proof) } - fn from_protocol2(proof: <>::Protocol as Protocol>::CorrectnessProof) -> Self { + fn from_protocol2(proof: ::CorrectnessProof) -> Self { Self::Protocol2(proof) } } -impl CorrectnessProof for ChainedCorrectnessProof +impl CorrectnessProof for ChainedCorrectnessProof where C: ChainedProtocol {} + +impl Protocol for C where - Id: PartyId, - C: Chained, + C: ChainedProtocol, { + type Result = ::Result; + type ProtocolError = ChainedProtocolError; + type CorrectnessProof = ChainedCorrectnessProof; } -/// The protocol resulting from chaining two sub-protocols as described by `C`. -#[derive(Debug)] -#[allow(clippy::type_complexity)] -pub struct ChainedProtocol>(PhantomData (Id, C)>); +/// A trait defining how the entry point for the whole chained protocol +/// will be split into the entry point for the first protocol, and a piece of data +/// that, along with the first protocol's result, will be used to create the entry point for the second protocol. +pub trait ChainedSplit { + /// The first protocol's entry point. + type EntryPoint: EntryPoint; + + /// Creates the first protocol's entry point and the data for creating the second entry point. + fn make_entry_point1(self) -> (Self::EntryPoint, impl ChainedJoin); + + /// Returns the entry round ID for the chained protocol. + /// + /// This function is supposed to be used in the implementation of [`EntryPoint::entry_round`] + /// for the chained protocol. + /// You don't need to override it. + fn entry_round() -> RoundId { + Self::EntryPoint::entry_round().group_under(1) + } +} + +/// A trait defining how the data created in [`ChainedSplit::make_entry_point1`] +/// will be joined with the result of the first protocol to create an entry point for the second protocol. +pub trait ChainedJoin: 'static + Debug + Send + Sync { + /// The second protocol's entry point. + type EntryPoint: EntryPoint; + + /// Creates the second protocol's entry point using the first protocol's result. + fn make_entry_point2(self, result: ::Result) -> Self::EntryPoint; +} -impl Protocol for ChainedProtocol +/// Creates the first round of the chained protocol. +/// +/// This function is supposed to be used in the implementation of [`EntryPoint::make_round`] +/// for the chained protocol. +pub fn make_chained_round( + entry_point: T, + rng: &mut impl CryptoRngCore, + shared_randomness: &[u8], +) -> Result, LocalError> where Id: PartyId, - C: Chained, + C: ChainedProtocol, + T: ChainedSplit, { - type Result = <>::Protocol as Protocol>::Result; - type ProtocolError = ChainedProtocolError; - type CorrectnessProof = ChainedCorrectnessProof; + let (entry_point, transition) = entry_point.make_entry_point1(); + let round = entry_point.make_round(rng, shared_randomness)?; + let chained_round = ChainedRound { + state: ChainState::Protocol1 { + shared_randomness: shared_randomness.into(), + transition, + round, + }, + }; + Ok(BoxedRound::new_object_safe(chained_round)) } -/// The entry point of the chained protocol. -#[derive_where::derive_where(Debug)] -pub struct ChainedEntryPoint> { - state: ChainState, +#[derive(Debug)] +struct ChainedRound +where + Id: PartyId, + C: ChainedProtocol, + T: ChainedJoin, +{ + state: ChainState, } #[derive_where::derive_where(Debug)] -enum ChainState +enum ChainState where Id: PartyId, - C: Chained, + C: ChainedProtocol, + T: ChainedJoin, { Protocol1 { - round: BoxedRound>::Protocol>, + round: BoxedRound, shared_randomness: Box<[u8]>, - id: Id, - inputs: C::Inputs, + transition: T, }, - Protocol2(BoxedRound>::Protocol>), -} - -impl EntryPoint for ChainedEntryPoint -where - Id: PartyId, - C: Chained, -{ - type Inputs = C::Inputs; - type Protocol = ChainedProtocol; - - fn entry_round() -> RoundId { - >::entry_round().group_under(1) - } - - fn new( - rng: &mut impl CryptoRngCore, - shared_randomness: &[u8], - id: Id, - inputs: Self::Inputs, - ) -> Result, LocalError> { - let round = C::EntryPoint1::new(rng, shared_randomness, id.clone(), (&inputs).into())?; - let round = ChainedEntryPoint { - state: ChainState::Protocol1 { - shared_randomness: shared_randomness.into(), - id, - inputs, - round, - }, - }; - Ok(BoxedRound::new_object_safe(round)) - } + Protocol2(BoxedRound), } -impl ObjectSafeRound for ChainedEntryPoint +impl ObjectSafeRound for ChainedRound where Id: PartyId, - C: Chained, + C: ChainedProtocol, + T: ChainedJoin, { - type Protocol = ChainedProtocol; + type Protocol = C; fn id(&self) -> RoundId { match &self.state { @@ -362,7 +366,7 @@ where // If there are no next rounds, this is the result round. // This means that in the chain the next round will be the entry round of the second protocol. if next_rounds.is_empty() { - next_rounds.insert(C::EntryPoint2::entry_round().group_under(2)); + next_rounds.insert(T::EntryPoint::entry_round().group_under(2)); } next_rounds } @@ -485,27 +489,26 @@ where match self.state { ChainState::Protocol1 { round, - id, - inputs, + transition, shared_randomness, } => match round.into_boxed().finalize(rng, payloads, artifacts) { Ok(FinalizeOutcome::Result(result)) => { let mut boxed_rng = BoxedRng(rng); - let round = C::EntryPoint2::new(&mut boxed_rng, &shared_randomness, id, (inputs, result).into())?; + let entry_point2 = transition.make_entry_point2(result); + let round = entry_point2.make_round(&mut boxed_rng, &shared_randomness)?; Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe( - ChainedEntryPoint:: { + ChainedRound:: { state: ChainState::Protocol2(round), }, ))) } Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound( - BoxedRound::new_object_safe(ChainedEntryPoint:: { + BoxedRound::new_object_safe(ChainedRound:: { state: ChainState::Protocol1 { shared_randomness, - id, - inputs, round, + transition, }, }), )), @@ -517,7 +520,7 @@ where ChainState::Protocol2(round) => match round.into_boxed().finalize(rng, payloads, artifacts) { Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound( - BoxedRound::new_object_safe(ChainedEntryPoint:: { + BoxedRound::new_object_safe(ChainedRound:: { state: ChainState::Protocol2(round), }), )), diff --git a/manul/src/combinators/misbehave.rs b/manul/src/combinators/misbehave.rs index 311995a..08a73b7 100644 --- a/manul/src/combinators/misbehave.rs +++ b/manul/src/combinators/misbehave.rs @@ -6,7 +6,7 @@ The usage is as follows: 1. Define a behavior type, subject to [`Behavior`] bounds. This will represent the possible actions the override may perform. -2. Implement [`Misbehaving`] for a type of your choice. Usually it will be an empty token type. +2. Implement [`Misbehaving`] for a type of your choice. Usually it will be a ZST. You will need to specify the entry point for the unmodified protocol, and some of `modify_*` methods (the blanket implementations simply pass through the original messages). @@ -18,6 +18,9 @@ The usage is as follows: 5. You can get access to the typed `Round` object by using [`BoxedRound::downcast_ref`](`crate::protocol::BoxedRound::downcast_ref`). + +6. Use [`MisbehavingEntryPoint`] parametrized by `Id`, the behavior type from step 1, and the type from step 2 + as the entry point of the new protocol. */ use alloc::{ @@ -39,20 +42,6 @@ pub trait Behavior: 'static + Debug + Send + Sync {} impl Behavior for T {} -/// The new entry point for the misbehaving rounds. -/// -/// Use as an entry point to run the session, with your ID, behavior `B` and the misbehavior definition `M` set. -#[derive_where::derive_where(Debug)] -pub struct MisbehavingEntryPoint -where - Id: PartyId, - B: Behavior, - M: Misbehaving, -{ - round: BoxedRound>::Protocol>, - behavior: Option, -} - /// A trait defining a sequence of misbehaving rounds modifying or replacing the messages sent by some existing ones. /// /// Override one or more optional methods to modify the specific messages. @@ -62,7 +51,7 @@ where B: Behavior, { /// The entry point of the wrapped rounds. - type EntryPoint: EntryPoint; + type EntryPoint: Debug + EntryPoint; /// Called after [`Round::make_echo_broadcast`](`crate::protocol::Round::make_echo_broadcast`) /// and may modify its result. @@ -115,20 +104,30 @@ where } } -/// The inputs for the misbehaving rounds. -#[derive_where::derive_where(Debug; >::Inputs)] -pub struct MisbehavingInputs +/// The new entry point for the misbehaving rounds. +/// +/// Use as an entry point to run the session, with your ID, behavior `B` and the misbehavior definition `M` set. +#[derive_where::derive_where(Debug)] +pub struct MisbehavingEntryPoint +where + Id: PartyId, + B: Behavior, + M: Misbehaving, +{ + entry_point: M::EntryPoint, + behavior: Option, +} + +impl MisbehavingEntryPoint where Id: PartyId, B: Behavior, M: Misbehaving, { - /// The behavior for the rounds starting with these inputs. - /// - /// If `None`, all the changed behavior will be skipped. - pub behavior: Option, - /// The inputs for the wrapped rounds. - pub inner_inputs: >::Inputs, + /// Creates an entry point for the misbehaving protocol using an entry point for the inner protocol. + pub fn new(entry_point: M::EntryPoint, behavior: Option) -> Self { + Self { entry_point, behavior } + } } impl EntryPoint for MisbehavingEntryPoint @@ -137,24 +136,33 @@ where B: Behavior, M: Misbehaving, { - type Inputs = MisbehavingInputs; type Protocol = >::Protocol; - fn new( + fn make_round( + self, rng: &mut impl CryptoRngCore, shared_randomness: &[u8], - id: Id, - inputs: Self::Inputs, - ) -> Result>::Protocol>, LocalError> { - let round = M::EntryPoint::new(rng, shared_randomness, id, inputs.inner_inputs)?; - Ok(BoxedRound::new_object_safe(Self { + ) -> Result, LocalError> { + let round = self.entry_point.make_round(rng, shared_randomness)?; + Ok(BoxedRound::new_object_safe(MisbehavingRound:: { round, - behavior: inputs.behavior, + behavior: self.behavior, })) } } -impl ObjectSafeRound for MisbehavingEntryPoint +#[derive_where::derive_where(Debug)] +struct MisbehavingRound +where + Id: PartyId, + B: Behavior, + M: Misbehaving, +{ + round: BoxedRound>::Protocol>, + behavior: Option, +} + +impl ObjectSafeRound for MisbehavingRound where Id: PartyId, B: Behavior, @@ -284,12 +292,12 @@ where ) -> Result, FinalizeError> { match self.round.into_boxed().finalize(rng, payloads, artifacts) { Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), - Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe( - MisbehavingEntryPoint:: { + Ok(FinalizeOutcome::AnotherRound(round)) => { + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe(Self { round, behavior: self.behavior, - }, - ))), + }))) + } Err(err) => Err(err), } } diff --git a/manul/src/protocol/object_safe.rs b/manul/src/protocol/object_safe.rs index 8678951..dbbcf90 100644 --- a/manul/src/protocol/object_safe.rs +++ b/manul/src/protocol/object_safe.rs @@ -209,7 +209,7 @@ where // We do not want to expose `ObjectSafeRound` to the user, so it is hidden in a struct. /// A wrapped new round that may be returned by [`Round::finalize`] -/// or [`EntryPoint::new`](`crate::protocol::EntryPoint::new`). +/// or [`EntryPoint::make_round`](`crate::protocol::EntryPoint::make_round`). #[derive_where::derive_where(Debug)] pub struct BoxedRound { wrapped: bool, diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index 9a306cc..be171c1 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -340,13 +340,10 @@ impl Artifact { /// This is a round that can be created directly; /// all the others are only reachable throud [`Round::finalize`] by the execution layer. pub trait EntryPoint { - /// Additional inputs for the protocol (besides the mandatory ones in [`new`](`Self::new`)). - type Inputs; - /// The protocol implemented by the round this entry points returns. type Protocol: Protocol; - /// Returns the ID of the round returned by [`Self::new`]. + /// Returns the ID of the round returned by [`Self::make_round`]. fn entry_round() -> RoundId { RoundId::new(1) } @@ -355,11 +352,10 @@ pub trait EntryPoint { /// /// `session_id` can be assumed to be the same for each node participating in a session. /// `id` is the ID of this node. - fn new( + fn make_round( + self, rng: &mut impl CryptoRngCore, shared_randomness: &[u8], - id: Id, - inputs: Self::Inputs, ) -> Result, LocalError>; } diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index ab1a27c..e04251b 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -146,13 +146,12 @@ where rng: &mut impl CryptoRngCore, session_id: SessionId, signer: SP::Signer, - inputs: R::Inputs, + entry_point: R, ) -> Result where R: EntryPoint, { - let verifier = signer.verifying_key(); - let first_round = R::new(rng, session_id.as_ref(), verifier.clone(), inputs)?; + let first_round = entry_point.make_round(rng, session_id.as_ref())?; let serializer = Serializer::new::(); let deserializer = Deserializer::new::(); Self::new_for_next_round( diff --git a/manul/src/testing/run_sync.rs b/manul/src/testing/run_sync.rs index 8d08c87..5d7350e 100644 --- a/manul/src/testing/run_sync.rs +++ b/manul/src/testing/run_sync.rs @@ -91,7 +91,7 @@ where #[allow(clippy::type_complexity)] pub fn run_sync( rng: &mut impl CryptoRngCore, - inputs: Vec<(SP::Signer, R::Inputs)>, + entry_points: Vec<(SP::Signer, R)>, ) -> Result>, LocalError> where R: EntryPoint, @@ -102,9 +102,9 @@ where let mut messages = Vec::new(); let mut states = BTreeMap::new(); - for (signer, inputs) in inputs { + for (signer, entry_point) in entry_points { let verifier = signer.verifying_key(); - let session = Session::<_, SP>::new::(rng, session_id.clone(), signer, inputs)?; + let session = Session::<_, SP>::new(rng, session_id.clone(), signer, entry_point)?; let mut accum = session.make_accumulator(); let destinations = session.message_destinations(); diff --git a/manul/src/tests/partial_echo.rs b/manul/src/tests/partial_echo.rs index 3bd3536..dadc22d 100644 --- a/manul/src/tests/partial_echo.rs +++ b/manul/src/tests/partial_echo.rs @@ -51,17 +51,15 @@ impl ProtocolError for PartialEchoProtocolError { #[derive(Debug, Clone)] struct Inputs { - message_destinations: Vec, - expecting_messages_from: Vec, + id: Id, + message_destinations: BTreeSet, + expecting_messages_from: BTreeSet, echo_round_participation: EchoRoundParticipation, } #[derive(Debug)] struct Round1 { - id: Id, - message_destinations: BTreeSet, - expecting_messages_from: BTreeSet, - echo_round_participation: EchoRoundParticipation, + inputs: Inputs, } #[derive(Debug, Serialize, Deserialize)] @@ -69,23 +67,14 @@ struct Round1Echo { sender: Id, } -impl Deserialize<'de>> EntryPoint for Round1 { - type Inputs = Inputs; +impl Deserialize<'de>> EntryPoint for Inputs { type Protocol = PartialEchoProtocol; - fn new( + fn make_round( + self, _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], - id: Id, - inputs: Self::Inputs, ) -> Result, LocalError> { - let message_destinations = BTreeSet::from_iter(inputs.message_destinations); - let expecting_messages_from = BTreeSet::from_iter(inputs.expecting_messages_from); - Ok(BoxedRound::new_dynamic(Self { - id, - message_destinations, - expecting_messages_from, - echo_round_participation: inputs.echo_round_participation, - })) + Ok(BoxedRound::new_dynamic(Round1 { inputs: self })) } } @@ -101,15 +90,15 @@ impl Deserialize<'de>> Round for Round1 &BTreeSet { - &self.message_destinations + &self.inputs.message_destinations } fn expecting_messages_from(&self) -> &BTreeSet { - &self.expecting_messages_from + &self.inputs.expecting_messages_from } fn echo_round_participation(&self) -> EchoRoundParticipation { - self.echo_round_participation.clone() + self.inputs.echo_round_participation.clone() } fn make_echo_broadcast( @@ -117,13 +106,13 @@ impl Deserialize<'de>> Round for Round1 Result { - if self.message_destinations.is_empty() { + if self.inputs.message_destinations.is_empty() { Ok(EchoBroadcast::none()) } else { EchoBroadcast::new( serializer, Round1Echo { - sender: self.id.clone(), + sender: self.inputs.id.clone(), }, ) } @@ -141,12 +130,12 @@ impl Deserialize<'de>> Round for Round1>(deserializer)?; assert_eq!(&echo.sender, from); - assert!(self.expecting_messages_from.contains(from)); + assert!(self.inputs.expecting_messages_from.contains(from)); } Ok(Payload::new(())) @@ -175,24 +164,27 @@ fn partial_echo() { let node0 = ( signers[0], Inputs { - message_destinations: [ids[1], ids[2], ids[3]].into(), - expecting_messages_from: [].into(), + id: signers[0].verifying_key(), + message_destinations: BTreeSet::from([ids[1], ids[2], ids[3]]), + expecting_messages_from: BTreeSet::new(), echo_round_participation: EchoRoundParticipation::Send, }, ); let node1 = ( signers[1], Inputs { - message_destinations: [ids[2], ids[3]].into(), - expecting_messages_from: [ids[0]].into(), + id: signers[1].verifying_key(), + message_destinations: BTreeSet::from([ids[2], ids[3]]), + expecting_messages_from: BTreeSet::from([ids[0]]), echo_round_participation: EchoRoundParticipation::Default, }, ); let node2 = ( signers[2], Inputs { - message_destinations: [].into(), - expecting_messages_from: [ids[0], ids[1]].into(), + id: signers[2].verifying_key(), + message_destinations: BTreeSet::new(), + expecting_messages_from: BTreeSet::from([ids[0], ids[1]]), echo_round_participation: EchoRoundParticipation::Receive { echo_targets: BTreeSet::from([ids[1], ids[3]]), }, @@ -201,8 +193,9 @@ fn partial_echo() { let node3 = ( signers[3], Inputs { - message_destinations: [].into(), - expecting_messages_from: [ids[0], ids[1]].into(), + id: signers[3].verifying_key(), + message_destinations: BTreeSet::new(), + expecting_messages_from: BTreeSet::from([ids[0], ids[1]]), echo_round_participation: EchoRoundParticipation::Receive { echo_targets: BTreeSet::from([ids[1], ids[2]]), }, @@ -211,19 +204,20 @@ fn partial_echo() { let node4 = ( signers[4], Inputs { - message_destinations: [].into(), - expecting_messages_from: [].into(), + id: signers[4].verifying_key(), + message_destinations: BTreeSet::new(), + expecting_messages_from: BTreeSet::new(), echo_round_participation: EchoRoundParticipation::::Default, }, ); - let inputs = vec![node0, node1, node2, node3, node4]; + let entry_points = vec![node0, node1, node2, node3, node4]; let my_subscriber = tracing_subscriber::fmt() .with_env_filter(EnvFilter::from_default_env()) .finish(); let reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, inputs).unwrap() + run_sync::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap() }); for (id, report) in reports {