diff --git a/CHANGELOG.md b/CHANGELOG.md index c203760..281deb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `SessionId::new()` renamed to `from_seed()`. ([#41]) - `FirstRound::new()` takes a `&[u8]` instead of a `SessionId` object. ([#41]) - The signatures of `Round::make_echo_broadcast()`, `Round::make_direct_message()`, and `Round::receive_message()`, take messages without `Option`s. ([#46]) -- `Round::make_direct_message_with_artifact()` is the method returning an artifact now; `Round::make_direct_message()` is a shortcut for cases where no artifact is returned. ([#46]) - `Artifact::empty()` removed, the user should return `None` instead. ([#46]) - `EchoBroadcast` and `DirectMessage` now use `ProtocolMessagePart` trait for their methods. ([#47]) - Added normal broadcasts support in addition to echo ones; signatures of `Round` methods changed accordingly; added `Round::make_normal_broadcast()`. ([#47]) @@ -22,6 +21,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Renamed `(Verified)MessageBundle` to `(Verified)Message`. Both are now generic over `Verifier`. ([#56]) - `Session::preprocess_message()` now returns a `PreprocessOutcome` instead of just an `Option`. ([#57]) - `Session::terminate_due_to_errors()` replaces `terminate()`; `terminate()` now signals user interrupt. ([#58]) +- Renamed `FirstRound` trait to `EntryPoint`. ([#60]) +- Added `Protocol` type to `EntryPoint`. ([#60]) +- `EntryPoint` and `FinalizeOutcome::AnotherRound` now use a new `BoxedRound` wrapper type. ([#60]) +- `PartyId` and `ProtocolError` are now bound on `Serialize`/`Deserialize`. ([#60]) ### Added @@ -31,6 +34,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Re-export `digest` from the `session` module. ([#56]) - Added `Message::destination()`. ([#56]) - `PartyId` trait alias for the combination of bounds needed for a party identifier. ([#59]) +- An impl of `ProtocolError` for `()` for protocols that don't use errors. ([#60]) +- A dummy `CorrectnessProof` trait. ([#60]) +- A `misbehave` combinator, intended primarily for testing. ([#60]) +- A `chain` combinator for chaining two protocols. ([#60]) +- `EntryPoint::ENTRY_ROUND` constant. ([#60]) [#32]: https://github.com/entropyxyz/manul/pull/32 @@ -45,6 +53,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#57]: https://github.com/entropyxyz/manul/pull/57 [#58]: https://github.com/entropyxyz/manul/pull/58 [#59]: https://github.com/entropyxyz/manul/pull/59 +[#60]: https://github.com/entropyxyz/manul/pull/60 ## [0.0.1] - 2024-10-12 diff --git a/Cargo.lock b/Cargo.lock index c31e3f0..a5dc428 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -492,6 +492,7 @@ dependencies = [ "serde_asn1_der", "serde_json", "signature", + "tinyvec", "tracing", ] @@ -934,6 +935,22 @@ dependencies = [ "serde_json", ] +[[package]] +name = "tinyvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "serde", + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.40.0" diff --git a/examples/src/lib.rs b/examples/src/lib.rs index c3aff76..37b869c 100644 --- a/examples/src/lib.rs +++ b/examples/src/lib.rs @@ -1,6 +1,7 @@ extern crate alloc; pub mod simple; +pub mod simple_chain; #[cfg(test)] mod simple_malicious; diff --git a/examples/src/simple.rs b/examples/src/simple.rs index b947fac..eeab9ab 100644 --- a/examples/src/simple.rs +++ b/examples/src/simple.rs @@ -149,14 +149,15 @@ struct Round1Payload { x: u8, } -impl FirstRound for Round1 { +impl EntryPoint for Round1 { type Inputs = Inputs; + type Protocol = SimpleProtocol; fn new( _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], id: Id, inputs: Self::Inputs, - ) -> Result { + ) -> Result, LocalError> { // Just some numbers associated with IDs to use in the dummy protocol. // They will be the same on each node since IDs are ordered. let ids_to_positions = inputs @@ -169,13 +170,13 @@ impl FirstRound for Round1 { let mut ids = inputs.all_ids; ids.remove(&id); - Ok(Self { + Ok(BoxedRound::new_dynamic(Self { context: Context { id, other_ids: ids, ids_to_positions, }, - }) + })) } } @@ -228,14 +229,15 @@ impl Round for Round1 { _rng: &mut impl CryptoRngCore, serializer: &Serializer, destination: &Id, - ) -> Result { + ) -> Result<(DirectMessage, Option), LocalError> { debug!("{:?}: making direct message for {:?}", self.context.id, destination); let message = Round1Message { my_position: self.context.ids_to_positions[&self.context.id], your_position: self.context.ids_to_positions[destination], }; - DirectMessage::new(serializer, message) + let dm = DirectMessage::new(serializer, message)?; + Ok((dm, None)) } fn receive_message( @@ -281,11 +283,11 @@ impl Round for Round1 { let sum = self.context.ids_to_positions[&self.context.id] + typed_payloads.iter().map(|payload| payload.x).sum::(); - let round2 = Round2 { + let round2 = BoxedRound::new_dynamic(Round2 { round1_sum: sum, context: self.context, - }; - Ok(FinalizeOutcome::another_round(round2)) + }); + Ok(FinalizeOutcome::AnotherRound(round2)) } fn expecting_messages_from(&self) -> &BTreeSet { @@ -325,14 +327,15 @@ impl Round for Round2 { _rng: &mut impl CryptoRngCore, serializer: &Serializer, destination: &Id, - ) -> Result { + ) -> Result<(DirectMessage, Option), LocalError> { debug!("{:?}: making direct message for {:?}", self.context.id, destination); let message = Round2Message { my_position: self.context.ids_to_positions[&self.context.id], your_position: self.context.ids_to_positions[destination], }; - DirectMessage::new(serializer, message) + let dm = DirectMessage::new(serializer, message)?; + Ok((dm, None)) } fn receive_message( diff --git a/examples/src/simple_chain.rs b/examples/src/simple_chain.rs new file mode 100644 index 0000000..e187dc6 --- /dev/null +++ b/examples/src/simple_chain.rs @@ -0,0 +1,85 @@ +use core::fmt::Debug; + +use manul::{ + combinators::chain::{Chained, ChainedEntryPoint}, + protocol::PartyId, +}; + +use super::simple::{Inputs, Round1}; + +pub struct ChainedSimple; + +#[derive(Debug)] +pub struct NewInputs(Inputs); + +impl<'a, Id: PartyId> From<&'a NewInputs> for Inputs { + fn from(source: &'a NewInputs) -> Self { + source.0.clone() + } +} + +impl From<(NewInputs, u8)> for Inputs { + fn from(source: (NewInputs, u8)) -> Self { + let (inputs, _result) = source; + inputs.0 + } +} + +impl Chained for ChainedSimple { + type Inputs = NewInputs; + type EntryPoint1 = Round1; + type EntryPoint2 = Round1; +} + +pub type DoubleSimpleEntryPoint = ChainedEntryPoint; + +#[cfg(test)] +mod tests { + use alloc::collections::BTreeSet; + + use manul::{ + session::{signature::Keypair, SessionOutcome}, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + }; + use rand_core::OsRng; + use tracing_subscriber::EnvFilter; + + use super::{DoubleSimpleEntryPoint, NewInputs}; + use crate::simple::Inputs; + + #[test] + fn round() { + let signers = (0..3).map(TestSigner::new).collect::>(); + let all_ids = signers + .iter() + .map(|signer| signer.verifying_key()) + .collect::>(); + let inputs = signers + .into_iter() + .map(|signer| { + ( + signer, + NewInputs(Inputs { + all_ids: all_ids.clone(), + }), + ) + }) + .collect::>(); + + let my_subscriber = tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .finish(); + let reports = tracing::subscriber::with_default(my_subscriber, || { + run_sync::, TestSessionParams>(&mut OsRng, inputs) + .unwrap() + }); + + for (_id, report) in reports { + if let SessionOutcome::Result(result) = report.outcome { + assert_eq!(result, 3); // 0 + 1 + 2 + } else { + panic!("Session did not finish successfully"); + } + } + } +} diff --git a/examples/src/simple_malicious.rs b/examples/src/simple_malicious.rs index 1ed749d..2f58f77 100644 --- a/examples/src/simple_malicious.rs +++ b/examples/src/simple_malicious.rs @@ -1,16 +1,14 @@ -use alloc::collections::{BTreeMap, BTreeSet}; +use alloc::collections::BTreeSet; use core::fmt::Debug; use manul::{ + combinators::misbehave::{Misbehaving, MisbehavingEntryPoint, MisbehavingInputs}, protocol::{ - Artifact, DirectMessage, FinalizeError, FinalizeOutcome, FirstRound, LocalError, PartyId, Payload, - ProtocolMessagePart, Round, Serializer, + Artifact, BoxedRound, Deserializer, DirectMessage, EntryPoint, LocalError, PartyId, ProtocolMessagePart, + RoundId, Serializer, }, session::signature::Keypair, - testing::{ - round_override, run_sync, BinaryFormat, RoundOverride, RoundWrapper, TestSessionParams, TestSigner, - TestVerifier, - }, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, }; use rand_core::{CryptoRngCore, OsRng}; use tracing_subscriber::EnvFilter; @@ -19,132 +17,59 @@ use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message}; #[derive(Debug, Clone, Copy)] enum Behavior { - Lawful, SerializedGarbage, AttributableFailure, AttributableFailureRound2, } -struct MaliciousInputs { - inputs: Inputs, - behavior: Behavior, -} - -#[derive(Debug)] -struct MaliciousRound1 { - round: Round1, - behavior: Behavior, -} +struct MaliciousLogic; -impl RoundWrapper for MaliciousRound1 { - type InnerRound = Round1; - fn inner_round_ref(&self) -> &Self::InnerRound { - &self.round - } - fn inner_round(self) -> Self::InnerRound { - self.round - } -} +impl Misbehaving for MaliciousLogic { + type EntryPoint = Round1; -impl FirstRound for MaliciousRound1 { - type Inputs = MaliciousInputs; - fn new( - rng: &mut impl CryptoRngCore, - shared_randomness: &[u8], - id: Id, - inputs: Self::Inputs, - ) -> Result { - let round = Round1::new(rng, shared_randomness, id, inputs.inputs)?; - Ok(Self { - round, - behavior: inputs.behavior, - }) - } -} - -impl RoundOverride for MaliciousRound1 { - fn make_direct_message( - &self, - rng: &mut impl CryptoRngCore, + fn modify_direct_message( + _rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + behavior: &Behavior, serializer: &Serializer, - destination: &Id, - ) -> Result { - if matches!(self.behavior, Behavior::SerializedGarbage) { - DirectMessage::new(serializer, [99u8]) - } else if matches!(self.behavior, Behavior::AttributableFailure) { - let message = Round1Message { - my_position: self.round.context.ids_to_positions[&self.round.context.id], - your_position: self.round.context.ids_to_positions[&self.round.context.id], - }; - DirectMessage::new(serializer, message) - } else { - self.inner_round_ref().make_direct_message(rng, serializer, destination) - } - } - - fn finalize( - self, - rng: &mut impl CryptoRngCore, - payloads: BTreeMap, - artifacts: BTreeMap, - ) -> Result< - FinalizeOutcome>::InnerRound as Round>::Protocol>, - FinalizeError<<>::InnerRound as Round>::Protocol>, - > { - let behavior = self.behavior; - let outcome = self.inner_round().finalize(rng, payloads, artifacts)?; - - Ok(match outcome { - FinalizeOutcome::Result(res) => FinalizeOutcome::Result(res), - FinalizeOutcome::AnotherRound(another_round) => { - let round2 = another_round.downcast::>().map_err(FinalizeError::Local)?; - FinalizeOutcome::another_round(MaliciousRound2 { - round: round2, - behavior, - }) + _deserializer: &Deserializer, + _destination: &Id, + direct_message: DirectMessage, + artifact: Option, + ) -> Result<(DirectMessage, Option), LocalError> { + let dm = if round.id() == RoundId::new(1) { + match behavior { + Behavior::SerializedGarbage => DirectMessage::new(serializer, [99u8])?, + Behavior::AttributableFailure => { + let round1 = round.downcast_ref::>()?; + let message = Round1Message { + my_position: round1.context.ids_to_positions[&round1.context.id], + your_position: round1.context.ids_to_positions[&round1.context.id], + }; + DirectMessage::new(serializer, message)? + } + _ => direct_message, + } + } else if round.id() == RoundId::new(2) { + match behavior { + Behavior::AttributableFailureRound2 => { + let round2 = round.downcast_ref::>()?; + let message = Round2Message { + my_position: round2.context.ids_to_positions[&round2.context.id], + your_position: round2.context.ids_to_positions[&round2.context.id], + }; + DirectMessage::new(serializer, message)? + } + _ => direct_message, } - }) - } -} - -round_override!(MaliciousRound1); - -#[derive(Debug)] -struct MaliciousRound2 { - round: Round2, - behavior: Behavior, -} - -impl RoundWrapper for MaliciousRound2 { - type InnerRound = Round2; - fn inner_round_ref(&self) -> &Self::InnerRound { - &self.round - } - fn inner_round(self) -> Self::InnerRound { - self.round - } -} - -impl RoundOverride for MaliciousRound2 { - fn make_direct_message( - &self, - rng: &mut impl CryptoRngCore, - serializer: &Serializer, - destination: &Id, - ) -> Result { - if matches!(self.behavior, Behavior::AttributableFailureRound2) { - let message = Round2Message { - my_position: self.round.context.ids_to_positions[&self.round.context.id], - your_position: self.round.context.ids_to_positions[&self.round.context.id], - }; - DirectMessage::new(serializer, message) } else { - self.inner_round_ref().make_direct_message(rng, serializer, destination) - } + direct_message + }; + Ok((dm, artifact)) } } -round_override!(MaliciousRound2); +type MaliciousEntryPoint = MisbehavingEntryPoint; #[test] fn serialized_garbage() { @@ -160,13 +85,13 @@ fn serialized_garbage() { .enumerate() .map(|(idx, signer)| { let behavior = if idx == 0 { - Behavior::SerializedGarbage + Some(Behavior::SerializedGarbage) } else { - Behavior::Lawful + None }; - let malicious_inputs = MaliciousInputs { - inputs: inputs.clone(), + let malicious_inputs = MisbehavingInputs { + inner_inputs: inputs.clone(), behavior, }; (*signer, malicious_inputs) @@ -177,7 +102,7 @@ fn serialized_garbage() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() }); let v0 = signers[0].verifying_key(); @@ -206,13 +131,13 @@ fn attributable_failure() { .enumerate() .map(|(idx, signer)| { let behavior = if idx == 0 { - Behavior::AttributableFailure + Some(Behavior::AttributableFailure) } else { - Behavior::Lawful + None }; - let malicious_inputs = MaliciousInputs { - inputs: inputs.clone(), + let malicious_inputs = MisbehavingInputs { + inner_inputs: inputs.clone(), behavior, }; (*signer, malicious_inputs) @@ -223,7 +148,7 @@ fn attributable_failure() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() }); let v0 = signers[0].verifying_key(); @@ -252,13 +177,13 @@ fn attributable_failure_round2() { .enumerate() .map(|(idx, signer)| { let behavior = if idx == 0 { - Behavior::AttributableFailureRound2 + Some(Behavior::AttributableFailureRound2) } else { - Behavior::Lawful + None }; - let malicious_inputs = MaliciousInputs { - inputs: inputs.clone(), + let malicious_inputs = MisbehavingInputs { + inner_inputs: inputs.clone(), behavior, }; (*signer, malicious_inputs) @@ -269,7 +194,7 @@ fn attributable_failure_round2() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() }); let v0 = signers[0].verifying_key(); diff --git a/manul/Cargo.toml b/manul/Cargo.toml index 893157b..ede1380 100644 --- a/manul/Cargo.toml +++ b/manul/Cargo.toml @@ -20,6 +20,7 @@ rand_core = { version = "0.6.4", default-features = false } tracing = { version = "0.1", default-features = false } displaydoc = { version = "0.2", default-features = false } derive-where = "1" +tinyvec = { version = "1", default-features = false, features = ["alloc", "serde"] } rand = { version = "0.8", default-features = false, optional = true } serde-persistent-deserializer = { version = "0.3", optional = true } diff --git a/manul/benches/empty_rounds.rs b/manul/benches/empty_rounds.rs index 6ce657b..87eefee 100644 --- a/manul/benches/empty_rounds.rs +++ b/manul/benches/empty_rounds.rs @@ -1,17 +1,14 @@ extern crate alloc; -use alloc::{ - collections::{BTreeMap, BTreeSet}, - string::String, -}; +use alloc::collections::{BTreeMap, BTreeSet}; use core::fmt::Debug; use criterion::{criterion_group, criterion_main, Criterion}; use manul::{ protocol::{ - Artifact, Deserializer, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, FirstRound, LocalError, - NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessagePart, ProtocolValidationError, - ReceiveError, Round, RoundId, Serializer, + Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeError, FinalizeOutcome, + LocalError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolMessagePart, ReceiveError, Round, RoundId, + Serializer, }, session::{signature::Keypair, SessionOutcome}, testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, @@ -22,31 +19,9 @@ use serde::{Deserialize, Serialize}; #[derive(Debug)] pub struct EmptyProtocol; -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct EmptyProtocolError; - -impl ProtocolError for EmptyProtocolError { - fn description(&self) -> String { - unimplemented!() - } - fn verify_messages_constitute_error( - &self, - _deserializer: &Deserializer, - _echo_broadcast: &EchoBroadcast, - _normal_broadcast: &NormalBroadcast, - _direct_message: &DirectMessage, - _echo_broadcasts: &BTreeMap, - _normal_broadcasts: &BTreeMap, - _direct_messages: &BTreeMap, - _combined_echos: &BTreeMap>, - ) -> Result<(), ProtocolValidationError> { - unimplemented!() - } -} - impl Protocol for EmptyProtocol { type Result = (); - type ProtocolError = EmptyProtocolError; + type ProtocolError = (); type CorrectnessProof = (); } @@ -73,18 +48,19 @@ struct Round1Payload; struct Round1Artifact; -impl FirstRound for EmptyRound { +impl EntryPoint for EmptyRound { type Inputs = Inputs; + type Protocol = EmptyProtocol; fn new( _rng: &mut impl CryptoRngCore, _shared_randomness: &[u8], _id: Id, inputs: Self::Inputs, - ) -> Result { - Ok(Self { + ) -> Result, LocalError> { + Ok(BoxedRound::new_dynamic(Self { round_counter: 1, inputs, - }) + })) } } @@ -119,7 +95,7 @@ impl Round for EmptyRound { } } - fn make_direct_message_with_artifact( + fn make_direct_message( &self, _rng: &mut impl CryptoRngCore, serializer: &Serializer, @@ -165,11 +141,11 @@ impl Round for EmptyRound { if self.round_counter == self.inputs.rounds_num { Ok(FinalizeOutcome::Result(())) } else { - let round = EmptyRound { + let round = BoxedRound::new_dynamic(EmptyRound { round_counter: self.round_counter + 1, inputs: self.inputs, - }; - Ok(FinalizeOutcome::another_round(round)) + }); + Ok(FinalizeOutcome::AnotherRound(round)) } } diff --git a/manul/src/combinators.rs b/manul/src/combinators.rs new file mode 100644 index 0000000..48f012b --- /dev/null +++ b/manul/src/combinators.rs @@ -0,0 +1,4 @@ +//! Combinators operating on protocols. + +pub mod chain; +pub mod misbehave; diff --git a/manul/src/combinators/chain.rs b/manul/src/combinators/chain.rs new file mode 100644 index 0000000..13450d9 --- /dev/null +++ b/manul/src/combinators/chain.rs @@ -0,0 +1,524 @@ +/*! +A combinator representing two protocols as a new protocol that, when executed, +executes the two inner protocols in sequence, feeding the result of the first protocol +into the inputs of the second protocol. + +For the session level users (that is, the ones executing the protocols) +the new protocol is a single entity with its own [`Protocol`](`crate::protocol::Protocol`) type +and an [`EntryPoint`](`crate::protocol::EntryPoint`) type. + +For example, imagine we have a `ProtocolA` with an entry point `EntryPointA`, inputs `InputsA`, +two rounds, `RA1` and `RA2`, and the result `ResultA`; +and similarly a `ProtocolB` with an entry point `EntryPointB`, inputs `InputsB`, +two rounds, `RB1` and `RB2`, and the result `ResultB`. + +Then the chained protocol will provide `ProtocolC: Protocol` and `EntryPointC: EntryPoint`, +the user will define `InputsC` for the new protocol, and the execution will look like: +- `InputsA` is created from `InputsC` via the user-defined `From` impl; +- `EntryPointA` is initialized with `InputsA`; +- `RA1` is executed; +- `RA2` is executed, producing `ResultA`; +- `InputsB` is created from `ResultA` and `InputsC` via the user-defined `From` impl; +- `RB1` is executed; +- `RB2` is executed, producing `ResultB` (which is also the result of `ChainedProtocol`). + +If the execution happens in a [`Session`](`crate::session::Session`), and there is an error at any point, +a regular evidence or correctness proof are created using the corresponding types from the new `ProtocolC`. + +The usage is as follows. + +1. Define an input type for the new joined protocol. + Most likely it will be a union between inputs of the first and the second protocol. + +2. Implement [`Chained`] for a type of your choice. Usually it will be an empty token type. + You will have to specify the entry points of the two protocols, + and the [`From`] conversions from the new input type to the inputs of both entry points + (see the corresponding associated type bounds). + +3. The entry point for the new protocol will be [`ChainedEntryPoint`] parametrized with + the type implementing [`Chained`] from step 2. + +4. The [`Protocol`](`crate::protocol::Protocol`)-implementing type for the new protocol will be + [`ChainedProtocol`] parametrized with the type implementing [`Chained`] from the step 2. +*/ + +use alloc::{ + boxed::Box, + collections::{BTreeMap, BTreeSet}, + format, + string::String, + vec::Vec, +}; +use core::{fmt::Debug, marker::PhantomData}; + +use rand_core::CryptoRngCore; +use serde::{Deserialize, Serialize}; + +use crate::protocol::*; + +/// A trait defining two protocols executed sequentially. +pub trait Chained: 'static +where + Id: PartyId, +{ + /// The inputs of the new chained protocol. + type Inputs: Send + Sync + Debug; + + /// The entry point of the first protocol. + type EntryPoint1: EntryPoint From<&'a Self::Inputs>>; + + /// The entry point of the second protocol. + type EntryPoint2: EntryPoint< + Id, + Inputs: From<( + Self::Inputs, + <>::Protocol as Protocol>::Result, + )>, + >; +} + +/// The protocol error type for the chained protocol. +#[derive_where::derive_where(Debug, Clone)] +#[derive(Serialize, Deserialize)] +#[serde(bound(serialize = " + <>::Protocol as Protocol>::ProtocolError: Serialize, + <>::Protocol as Protocol>::ProtocolError: Serialize, +"))] +#[serde(bound(deserialize = " + <>::Protocol as Protocol>::ProtocolError: for<'x> Deserialize<'x>, + <>::Protocol as Protocol>::ProtocolError: for<'x> Deserialize<'x>, +"))] +pub enum ChainedProtocolError> { + /// A protocol error from the first protocol. + Protocol1(<>::Protocol as Protocol>::ProtocolError), + /// A protocol error from the second protocol. + Protocol2(<>::Protocol as Protocol>::ProtocolError), +} + +impl ChainedProtocolError +where + Id: PartyId, + C: Chained, +{ + fn from_protocol1(err: <>::Protocol as Protocol>::ProtocolError) -> Self { + Self::Protocol1(err) + } + + fn from_protocol2(err: <>::Protocol as Protocol>::ProtocolError) -> Self { + Self::Protocol2(err) + } +} + +impl ProtocolError for ChainedProtocolError +where + Id: PartyId, + C: Chained, +{ + fn description(&self) -> String { + match self { + Self::Protocol1(err) => format!("Protocol1: {}", err.description()), + Self::Protocol2(err) => format!("Protocol2: {}", err.description()), + } + } + + fn required_direct_messages(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_direct_messages()), + Self::Protocol2(err) => (2, err.required_direct_messages()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + fn required_echo_broadcasts(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_echo_broadcasts()), + Self::Protocol2(err) => (2, err.required_echo_broadcasts()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + fn required_normal_broadcasts(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_normal_broadcasts()), + Self::Protocol2(err) => (2, err.required_normal_broadcasts()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + fn required_combined_echos(&self) -> BTreeSet { + let (protocol_num, round_ids) = match self { + Self::Protocol1(err) => (1, err.required_combined_echos()), + Self::Protocol2(err) => (2, err.required_combined_echos()), + }; + round_ids + .into_iter() + .map(|round_id| round_id.group_under(protocol_num)) + .collect() + } + + #[allow(clippy::too_many_arguments)] + fn verify_messages_constitute_error( + &self, + deserializer: &Deserializer, + echo_broadcast: &EchoBroadcast, + normal_broadcast: &NormalBroadcast, + direct_message: &DirectMessage, + echo_broadcasts: &BTreeMap, + normal_broadcasts: &BTreeMap, + direct_messages: &BTreeMap, + combined_echos: &BTreeMap>, + ) -> Result<(), ProtocolValidationError> { + // TODO: the cloning can be avoided if instead we provide a reference to some "transcript API", + // and can replace it here with a proxy that will remove nesting from round ID's. + let echo_broadcasts = echo_broadcasts + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let normal_broadcasts = normal_broadcasts + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let direct_messages = direct_messages + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + let combined_echos = combined_echos + .clone() + .into_iter() + .map(|(round_id, v)| round_id.ungroup().map(|round_id| (round_id, v))) + .collect::, _>>()?; + + match self { + Self::Protocol1(err) => err.verify_messages_constitute_error( + deserializer, + echo_broadcast, + normal_broadcast, + direct_message, + &echo_broadcasts, + &normal_broadcasts, + &direct_messages, + &combined_echos, + ), + Self::Protocol2(err) => err.verify_messages_constitute_error( + deserializer, + echo_broadcast, + normal_broadcast, + direct_message, + &echo_broadcasts, + &normal_broadcasts, + &direct_messages, + &combined_echos, + ), + } + } +} + +/// The correctness proof type for the chained protocol. +#[derive_where::derive_where(Debug, Clone)] +#[derive(Serialize, Deserialize)] +#[serde(bound(serialize = " + <>::Protocol as Protocol>::CorrectnessProof: Serialize, + <>::Protocol as Protocol>::CorrectnessProof: Serialize, +"))] +#[serde(bound(deserialize = " + <>::Protocol as Protocol>::CorrectnessProof: for<'x> Deserialize<'x>, + <>::Protocol as Protocol>::CorrectnessProof: for<'x> Deserialize<'x>, +"))] +pub enum ChainedCorrectnessProof +where + Id: PartyId, + C: Chained, +{ + /// A correctness proof from the first protocol. + Protocol1(<>::Protocol as Protocol>::CorrectnessProof), + /// A correctness proof from the second protocol. + Protocol2(<>::Protocol as Protocol>::CorrectnessProof), +} + +impl ChainedCorrectnessProof +where + Id: PartyId, + C: Chained, +{ + fn from_protocol1(proof: <>::Protocol as Protocol>::CorrectnessProof) -> Self { + Self::Protocol1(proof) + } + + fn from_protocol2(proof: <>::Protocol as Protocol>::CorrectnessProof) -> Self { + Self::Protocol2(proof) + } +} + +impl CorrectnessProof for ChainedCorrectnessProof +where + Id: PartyId, + C: Chained, +{ +} + +/// The protocol resulting from chaining two sub-protocols as described by `C`. +#[derive(Debug)] +#[allow(clippy::type_complexity)] +pub struct ChainedProtocol>(PhantomData (Id, C)>); + +impl Protocol for ChainedProtocol +where + Id: PartyId, + C: Chained, +{ + type Result = <>::Protocol as Protocol>::Result; + type ProtocolError = ChainedProtocolError; + type CorrectnessProof = ChainedCorrectnessProof; +} + +/// The entry point of the chained protocol. +#[derive_where::derive_where(Debug)] +pub struct ChainedEntryPoint> { + state: ChainState, +} + +#[derive_where::derive_where(Debug)] +enum ChainState +where + Id: PartyId, + C: Chained, +{ + Protocol1 { + round: BoxedRound>::Protocol>, + shared_randomness: Box<[u8]>, + id: Id, + inputs: C::Inputs, + }, + Protocol2(BoxedRound>::Protocol>), +} + +impl EntryPoint for ChainedEntryPoint +where + Id: PartyId, + C: Chained, +{ + type Inputs = C::Inputs; + type Protocol = ChainedProtocol; + + fn entry_round() -> RoundId { + >::entry_round().group_under(1) + } + + fn new( + rng: &mut impl CryptoRngCore, + shared_randomness: &[u8], + id: Id, + inputs: Self::Inputs, + ) -> Result, LocalError> { + let round = C::EntryPoint1::new(rng, shared_randomness, id.clone(), (&inputs).into())?; + let round = ChainedEntryPoint { + state: ChainState::Protocol1 { + shared_randomness: shared_randomness.into(), + id, + inputs, + round, + }, + }; + Ok(BoxedRound::new_object_safe(round)) + } +} + +impl ObjectSafeRound for ChainedEntryPoint +where + Id: PartyId, + C: Chained, +{ + type Protocol = ChainedProtocol; + + fn id(&self) -> RoundId { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().id().group_under(1), + ChainState::Protocol2(round) => round.as_ref().id().group_under(2), + } + } + + fn possible_next_rounds(&self) -> BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => { + let mut next_rounds = round + .as_ref() + .possible_next_rounds() + .into_iter() + .map(|round_id| round_id.group_under(1)) + .collect::>(); + + // If there are no next rounds, this is the result round. + // This means that in the chain the next round will be the entry round of the second protocol. + if next_rounds.is_empty() { + next_rounds.insert(C::EntryPoint2::entry_round().group_under(2)); + } + next_rounds + } + ChainState::Protocol2(round) => round + .as_ref() + .possible_next_rounds() + .into_iter() + .map(|round_id| round_id.group_under(2)) + .collect(), + } + } + + fn message_destinations(&self) -> &BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().message_destinations(), + ChainState::Protocol2(round) => round.as_ref().message_destinations(), + } + } + + fn make_direct_message( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + destination: &Id, + ) -> Result<(DirectMessage, Option), LocalError> { + match &self.state { + ChainState::Protocol1 { round, .. } => { + round + .as_ref() + .make_direct_message(rng, serializer, deserializer, destination) + } + ChainState::Protocol2(round) => { + round + .as_ref() + .make_direct_message(rng, serializer, deserializer, destination) + } + } + } + + fn make_echo_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + ) -> Result { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().make_echo_broadcast(rng, serializer, deserializer), + ChainState::Protocol2(round) => round.as_ref().make_echo_broadcast(rng, serializer, deserializer), + } + } + + fn make_normal_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + ) -> Result { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().make_normal_broadcast(rng, serializer, deserializer), + ChainState::Protocol2(round) => round.as_ref().make_normal_broadcast(rng, serializer, deserializer), + } + } + + fn receive_message( + &self, + rng: &mut dyn CryptoRngCore, + deserializer: &Deserializer, + from: &Id, + echo_broadcast: EchoBroadcast, + normal_broadcast: NormalBroadcast, + direct_message: DirectMessage, + ) -> Result> { + match &self.state { + ChainState::Protocol1 { round, .. } => match round.as_ref().receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) { + Ok(payload) => Ok(payload), + Err(err) => Err(err.map(ChainedProtocolError::from_protocol1)), + }, + ChainState::Protocol2(round) => match round.as_ref().receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) { + Ok(payload) => Ok(payload), + Err(err) => Err(err.map(ChainedProtocolError::from_protocol2)), + }, + } + } + + fn finalize( + self: Box, + rng: &mut dyn CryptoRngCore, + payloads: BTreeMap, + artifacts: BTreeMap, + ) -> Result, FinalizeError> { + match self.state { + ChainState::Protocol1 { + round, + id, + inputs, + shared_randomness, + } => match round.into_boxed().finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => { + let mut boxed_rng = BoxedRng(rng); + let round = C::EntryPoint2::new(&mut boxed_rng, &shared_randomness, id, (inputs, result).into())?; + + Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe( + ChainedEntryPoint:: { + state: ChainState::Protocol2(round), + }, + ))) + } + Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound( + BoxedRound::new_object_safe(ChainedEntryPoint:: { + state: ChainState::Protocol1 { + shared_randomness, + id, + inputs, + round, + }, + }), + )), + Err(FinalizeError::Local(err)) => Err(FinalizeError::Local(err)), + Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable( + ChainedCorrectnessProof::from_protocol1(proof), + )), + }, + ChainState::Protocol2(round) => match round.into_boxed().finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), + Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound( + BoxedRound::new_object_safe(ChainedEntryPoint:: { + state: ChainState::Protocol2(round), + }), + )), + Err(FinalizeError::Local(err)) => Err(FinalizeError::Local(err)), + Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable( + ChainedCorrectnessProof::from_protocol2(proof), + )), + }, + } + } + + fn expecting_messages_from(&self) -> &BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => round.as_ref().expecting_messages_from(), + ChainState::Protocol2(round) => round.as_ref().expecting_messages_from(), + } + } +} diff --git a/manul/src/combinators/misbehave.rs b/manul/src/combinators/misbehave.rs new file mode 100644 index 0000000..d71ddc2 --- /dev/null +++ b/manul/src/combinators/misbehave.rs @@ -0,0 +1,291 @@ +/*! +A combinator allowing one to intercept outgoing messages from a round, and replace or modify them. + +The usage is as follows: + +1. Define a behavior type, subject to [`Behavior`] bounds. + This will represent the possible actions the override may perform. + +2. Implement [`Misbehaving`] for a type of your choice. Usually it will be an empty token type. + You will need to specify the entry point for the unmodified protocol, + and some of `modify_*` methods (the blanket implementations simply pass through the original messages). + +3. The `modify_*` methods can be called from any round, use [`BoxedRound::id`](`crate::protocol::BoxedRound::id`) + on the `round` argument to determine which round it is. + +4. In the `modify_*` methods, you can get the original typed message using the provided `deserializer` argument, + and create a new one using the `serializer`. + +5. You can get access to the typed `Round` object by using + [`BoxedRound::downcast_ref`](`crate::protocol::BoxedRound::downcast_ref`). +*/ + +use alloc::{ + boxed::Box, + collections::{BTreeMap, BTreeSet}, +}; +use core::fmt::Debug; + +use rand_core::CryptoRngCore; + +use crate::protocol::{ + Artifact, BoxedRng, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeError, + FinalizeOutcome, LocalError, NormalBroadcast, ObjectSafeRound, PartyId, Payload, ReceiveError, RoundId, Serializer, +}; + +/// A trait describing required properties for a behavior type. +pub trait Behavior: 'static + Debug + Send + Sync {} + +impl Behavior for T {} + +/// The new entry point for the misbehaving rounds. +/// +/// Use as an entry point to run the session, with your ID, behavior `B` and the misbehavior definition `M` set. +#[derive_where::derive_where(Debug)] +pub struct MisbehavingEntryPoint +where + Id: PartyId, + B: Behavior, + M: Misbehaving, +{ + round: BoxedRound>::Protocol>, + behavior: Option, +} + +/// A trait defining a sequence of misbehaving rounds modifying or replacing the messages sent by some existing ones. +/// +/// Override one or more optional methods to modify the specific messages. +pub trait Misbehaving: 'static +where + Id: PartyId, + B: Behavior, +{ + /// The entry point of the wrapped rounds. + type EntryPoint: EntryPoint; + + /// Called after [`Round::make_echo_broadcast`](`crate::protocol::Round::make_echo_broadcast`) + /// and may modify its result. + /// + /// The default implementation passes through the original message. + #[allow(unused_variables)] + fn modify_echo_broadcast( + rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + behavior: &B, + serializer: &Serializer, + deserializer: &Deserializer, + echo_broadcast: EchoBroadcast, + ) -> Result { + Ok(echo_broadcast) + } + + /// Called after [`Round::make_normal_broadcast`](`crate::protocol::Round::make_normal_broadcast`) + /// and may modify its result. + /// + /// The default implementation passes through the original message. + #[allow(unused_variables)] + fn modify_normal_broadcast( + rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + behavior: &B, + serializer: &Serializer, + deserializer: &Deserializer, + normal_broadcast: NormalBroadcast, + ) -> Result { + Ok(normal_broadcast) + } + + /// Called after [`Round::make_direct_message`](`crate::protocol::Round::make_direct_message`) + /// and may modify its result. + /// + /// The default implementation passes through the original message. + #[allow(unused_variables, clippy::too_many_arguments)] + fn modify_direct_message( + rng: &mut impl CryptoRngCore, + round: &BoxedRound>::Protocol>, + behavior: &B, + serializer: &Serializer, + deserializer: &Deserializer, + destination: &Id, + direct_message: DirectMessage, + artifact: Option, + ) -> Result<(DirectMessage, Option), LocalError> { + Ok((direct_message, artifact)) + } +} + +/// The inputs for the misbehaving rounds. +#[derive_where::derive_where(Debug; >::Inputs)] +pub struct MisbehavingInputs +where + Id: PartyId, + B: Behavior, + M: Misbehaving, +{ + /// The behavior for the rounds starting with these inputs. + /// + /// If `None`, all the changed behavior will be skipped. + pub behavior: Option, + /// The inputs for the wrapped rounds. + pub inner_inputs: >::Inputs, +} + +impl EntryPoint for MisbehavingEntryPoint +where + Id: PartyId, + B: Behavior, + M: Misbehaving, +{ + type Inputs = MisbehavingInputs; + type Protocol = >::Protocol; + + fn new( + rng: &mut impl CryptoRngCore, + shared_randomness: &[u8], + id: Id, + inputs: Self::Inputs, + ) -> Result>::Protocol>, LocalError> { + let round = M::EntryPoint::new(rng, shared_randomness, id, inputs.inner_inputs)?; + Ok(BoxedRound::new_object_safe(Self { + round, + behavior: inputs.behavior, + })) + } +} + +impl ObjectSafeRound for MisbehavingEntryPoint +where + Id: PartyId, + B: Behavior, + M: Misbehaving, +{ + type Protocol = >::Protocol; + + fn id(&self) -> RoundId { + self.round.as_ref().id() + } + + fn possible_next_rounds(&self) -> BTreeSet { + self.round.as_ref().possible_next_rounds() + } + + fn message_destinations(&self) -> &BTreeSet { + self.round.as_ref().message_destinations() + } + + fn make_direct_message( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + destination: &Id, + ) -> Result<(DirectMessage, Option), LocalError> { + let (direct_message, artifact) = + self.round + .as_ref() + .make_direct_message(rng, serializer, deserializer, destination)?; + if let Some(behavior) = self.behavior.as_ref() { + let mut boxed_rng = BoxedRng(rng); + M::modify_direct_message( + &mut boxed_rng, + &self.round, + behavior, + serializer, + deserializer, + destination, + direct_message, + artifact, + ) + } else { + Ok((direct_message, artifact)) + } + } + + fn make_echo_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + ) -> Result { + let echo_broadcast = self.round.as_ref().make_echo_broadcast(rng, serializer, deserializer)?; + if let Some(behavior) = self.behavior.as_ref() { + let mut boxed_rng = BoxedRng(rng); + M::modify_echo_broadcast( + &mut boxed_rng, + &self.round, + behavior, + serializer, + deserializer, + echo_broadcast, + ) + } else { + Ok(echo_broadcast) + } + } + + fn make_normal_broadcast( + &self, + rng: &mut dyn CryptoRngCore, + serializer: &Serializer, + deserializer: &Deserializer, + ) -> Result { + let normal_broadcast = self + .round + .as_ref() + .make_normal_broadcast(rng, serializer, deserializer)?; + if let Some(behavior) = self.behavior.as_ref() { + let mut boxed_rng = BoxedRng(rng); + M::modify_normal_broadcast( + &mut boxed_rng, + &self.round, + behavior, + serializer, + deserializer, + normal_broadcast, + ) + } else { + Ok(normal_broadcast) + } + } + + fn receive_message( + &self, + rng: &mut dyn CryptoRngCore, + deserializer: &Deserializer, + from: &Id, + echo_broadcast: EchoBroadcast, + normal_broadcast: NormalBroadcast, + direct_message: DirectMessage, + ) -> Result> { + self.round.as_ref().receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) + } + + fn finalize( + self: Box, + rng: &mut dyn CryptoRngCore, + payloads: BTreeMap, + artifacts: BTreeMap, + ) -> Result, FinalizeError> { + match self.round.into_boxed().finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), + Ok(FinalizeOutcome::AnotherRound(round)) => Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_object_safe( + MisbehavingEntryPoint:: { + round, + behavior: self.behavior, + }, + ))), + Err(err) => Err(err), + } + } + + fn expecting_messages_from(&self) -> &BTreeSet { + self.round.as_ref().expecting_messages_from() + } +} diff --git a/manul/src/lib.rs b/manul/src/lib.rs index 8d3323f..ad55b47 100644 --- a/manul/src/lib.rs +++ b/manul/src/lib.rs @@ -16,6 +16,7 @@ extern crate alloc; +pub mod combinators; pub mod protocol; pub mod session; pub(crate) mod utils; diff --git a/manul/src/protocol.rs b/manul/src/protocol.rs index 0e53744..46b356f 100644 --- a/manul/src/protocol.rs +++ b/manul/src/protocol.rs @@ -4,7 +4,7 @@ API for protocol implementors. A protocol is a directed acyclic graph with the nodes being objects of types implementing [`Round`] (to be specific, "acyclic" means that the values returned by [`Round::id`] should not repeat during the protocol execution; the types might). -The starting point is a type that implements [`FirstRound`]. +The starting point is a type that implements [`EntryPoint`]. All the rounds must have their associated type [`Round::Protocol`] set to the same [`Protocol`] instance to be executed by a [`Session`](`crate::session::Session`). @@ -22,12 +22,13 @@ pub use errors::{ NormalBroadcastError, ProtocolValidationError, ReceiveError, RemoteError, }; pub use message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessagePart}; +pub use object_safe::BoxedRound; pub use round::{ - AnotherRound, Artifact, FinalizeOutcome, FirstRound, PartyId, Payload, Protocol, ProtocolError, Round, RoundId, + Artifact, CorrectnessProof, EntryPoint, FinalizeOutcome, PartyId, Payload, Protocol, ProtocolError, Round, RoundId, }; pub use serialization::{Deserializer, Serializer}; pub(crate) use errors::ReceiveErrorType; -pub(crate) use object_safe::{ObjectSafeRound, ObjectSafeRoundWrapper}; +pub(crate) use object_safe::{BoxedRng, ObjectSafeRound}; pub use digest; diff --git a/manul/src/protocol/errors.rs b/manul/src/protocol/errors.rs index a77ecc3..a5820f2 100644 --- a/manul/src/protocol/errors.rs +++ b/manul/src/protocol/errors.rs @@ -1,4 +1,4 @@ -use alloc::{format, string::String}; +use alloc::{boxed::Box, format, string::String}; use core::fmt::Debug; use super::round::Protocol; @@ -50,7 +50,7 @@ pub(crate) enum ReceiveErrorType { // so this whole enum is crate-private and the variants are created // via constructors and From impls. /// An echo round error occurred. - Echo(EchoRoundError), + Echo(Box>), } impl ReceiveError { @@ -68,6 +68,33 @@ impl ReceiveError { pub fn protocol(error: P::ProtocolError) -> Self { Self(ReceiveErrorType::Protocol(error)) } + + /// Maps the error to a different protocol, given the mapping function for protocol errors. + pub(crate) fn map(self, f: F) -> ReceiveError + where + F: Fn(P::ProtocolError) -> T::ProtocolError, + T: Protocol, + { + ReceiveError(self.0.map::(f)) + } +} + +impl ReceiveErrorType { + pub(crate) fn map(self, f: F) -> ReceiveErrorType + where + F: Fn(P::ProtocolError) -> T::ProtocolError, + T: Protocol, + { + match self { + Self::Local(err) => ReceiveErrorType::Local(err), + Self::InvalidDirectMessage(err) => ReceiveErrorType::InvalidDirectMessage(err), + Self::InvalidEchoBroadcast(err) => ReceiveErrorType::InvalidEchoBroadcast(err), + Self::InvalidNormalBroadcast(err) => ReceiveErrorType::InvalidNormalBroadcast(err), + Self::Unprovable(err) => ReceiveErrorType::Unprovable(err), + Self::Echo(err) => ReceiveErrorType::Echo(err), + Self::Protocol(err) => ReceiveErrorType::Protocol(f(err)), + } + } } impl From for ReceiveError @@ -93,7 +120,7 @@ where P: Protocol, { fn from(error: EchoRoundError) -> Self { - Self(ReceiveErrorType::Echo(error)) + Self(ReceiveErrorType::Echo(Box::new(error))) } } diff --git a/manul/src/protocol/object_safe.rs b/manul/src/protocol/object_safe.rs index fd65f95..910abc9 100644 --- a/manul/src/protocol/object_safe.rs +++ b/manul/src/protocol/object_safe.rs @@ -17,7 +17,7 @@ use super::{ /// Since object-safe trait methods cannot take `impl CryptoRngCore` arguments, /// this structure wraps the dynamic object and exposes a `CryptoRngCore` interface, /// to be passed to statically typed round methods. -struct BoxedRng<'a>(&'a mut dyn CryptoRngCore); +pub(crate) struct BoxedRng<'a>(pub(crate) &'a mut dyn CryptoRngCore); impl CryptoRng for BoxedRng<'_> {} @@ -48,10 +48,11 @@ pub(crate) trait ObjectSafeRound: 'static + Debug + Send + Sync { fn message_destinations(&self) -> &BTreeSet; - fn make_direct_message_with_artifact( + fn make_direct_message( &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + deserializer: &Deserializer, destination: &Id, ) -> Result<(DirectMessage, Option), LocalError>; @@ -59,12 +60,14 @@ pub(crate) trait ObjectSafeRound: 'static + Debug + Send + Sync { &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + deserializer: &Deserializer, ) -> Result; fn make_normal_broadcast( &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + deserializer: &Deserializer, ) -> Result; fn receive_message( @@ -87,7 +90,9 @@ pub(crate) trait ObjectSafeRound: 'static + Debug + Send + Sync { fn expecting_messages_from(&self) -> &BTreeSet; /// Returns the type ID of the implementing type. - fn get_type_id(&self) -> core::any::TypeId; + fn get_type_id(&self) -> core::any::TypeId { + core::any::TypeId::of::() + } } // The `fn(Id) -> Id` bit is so that `ObjectSafeRoundWrapper` didn't require a bound on `Id` to be @@ -130,21 +135,22 @@ where self.round.message_destinations() } - fn make_direct_message_with_artifact( + fn make_direct_message( &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + #[allow(unused_variables)] deserializer: &Deserializer, destination: &Id, ) -> Result<(DirectMessage, Option), LocalError> { let mut boxed_rng = BoxedRng(rng); - self.round - .make_direct_message_with_artifact(&mut boxed_rng, serializer, destination) + self.round.make_direct_message(&mut boxed_rng, serializer, destination) } fn make_echo_broadcast( &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + #[allow(unused_variables)] deserializer: &Deserializer, ) -> Result { let mut boxed_rng = BoxedRng(rng); self.round.make_echo_broadcast(&mut boxed_rng, serializer) @@ -154,6 +160,7 @@ where &self, rng: &mut dyn CryptoRngCore, serializer: &Serializer, + #[allow(unused_variables)] deserializer: &Deserializer, ) -> Result { let mut boxed_rng = BoxedRng(rng); self.round.make_normal_broadcast(&mut boxed_rng, serializer) @@ -192,29 +199,53 @@ where fn expecting_messages_from(&self) -> &BTreeSet { self.round.expecting_messages_from() } +} - fn get_type_id(&self) -> core::any::TypeId { - core::any::TypeId::of::() - } +// We do not want to expose `ObjectSafeRound` to the user, so it is hidden in a struct. +/// A wrapped new round that may be returned by [`Round::finalize`] +/// or [`EntryPoint::new`](`crate::protocol::EntryPoint::new`). +#[derive_where::derive_where(Debug)] +pub struct BoxedRound { + wrapped: bool, + round: Box>, } -// When we are wrapping types implementing Round and overriding `finalize()`, -// we need to unbox the result of `finalize()`, set it as an attribute of the wrapping round, -// and then box the result. -// -// Because of Rust's peculiarities, Box that we return in `finalize()` -// cannot be unboxed into an object of a concrete type with `downcast()`, -// so we have to provide this workaround. -impl dyn ObjectSafeRound -where - Id: PartyId, - P: Protocol, -{ - pub fn try_downcast>(self: Box) -> Result> { - if core::any::TypeId::of::>() == self.get_type_id() { - // This should be safe since we just checked that we are casting to a correct type. +impl BoxedRound { + /// Wraps an object implementing the dynamic round trait ([`Round`](`crate::protocol::Round`)). + pub fn new_dynamic>(round: R) -> Self { + Self { + wrapped: true, + round: Box::new(ObjectSafeRoundWrapper::new(round)), + } + } + + pub(crate) fn new_object_safe>(round: R) -> Self { + Self { + wrapped: false, + round: Box::new(round), + } + } + + pub(crate) fn as_ref(&self) -> &dyn ObjectSafeRound { + self.round.as_ref() + } + + pub(crate) fn into_boxed(self) -> Box> { + self.round + } + + fn boxed_type_is(&self) -> bool { + core::any::TypeId::of::() == self.round.get_type_id() + } + + /// Attempts to extract an object of a concrete type, preserving the original on failure. + pub fn try_downcast>(self) -> Result { + if self.wrapped && self.boxed_type_is::>() { + // Safety: This is safe since we just checked that we are casting to the correct type. let boxed_downcast = unsafe { - Box::>::from_raw(Box::into_raw(self) as *mut ObjectSafeRoundWrapper) + Box::>::from_raw( + Box::into_raw(self.round) as *mut ObjectSafeRoundWrapper + ) }; Ok(boxed_downcast.round) } else { @@ -222,8 +253,32 @@ where } } - pub fn downcast>(self: Box) -> Result { + /// Attempts to extract an object of a concrete type. + /// + /// Fails if the wrapped type is not `T`. + pub fn downcast>(self) -> Result { self.try_downcast() .map_err(|_| LocalError::new(format!("Failed to downcast into type {}", core::any::type_name::()))) } + + /// Attempts to provide a reference to an object of a concrete type. + /// + /// Fails if the wrapped type is not `T`. + pub fn downcast_ref>(&self) -> Result<&T, LocalError> { + if self.wrapped && self.boxed_type_is::>() { + let ptr: *const dyn ObjectSafeRound = self.round.as_ref(); + // Safety: This is safe since we just checked that we are casting to the correct type. + Ok(unsafe { &*(ptr as *const T) }) + } else { + Err(LocalError::new(format!( + "Failed to downcast into type {}", + core::any::type_name::() + ))) + } + } + + /// Returns the round's ID. + pub fn id(&self) -> RoundId { + self.round.id() + } } diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index e2f113b..bf6f6e7 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -3,88 +3,92 @@ use alloc::{ collections::{BTreeMap, BTreeSet}, format, string::String, + vec, vec::Vec, }; -use core::{any::Any, fmt::Debug}; +use core::{ + any::Any, + fmt::{self, Debug, Display}, +}; use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; +use tinyvec::{tiny_vec, TinyVec}; use super::{ errors::{FinalizeError, LocalError, MessageValidationError, ProtocolValidationError, ReceiveError}, message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessagePart}, - object_safe::{ObjectSafeRound, ObjectSafeRoundWrapper}, + object_safe::BoxedRound, serialization::{Deserializer, Serializer}, }; /// Possible successful outcomes of [`Round::finalize`]. #[derive(Debug)] -pub enum FinalizeOutcome { +pub enum FinalizeOutcome { /// Transition to a new round. - AnotherRound(AnotherRound), + AnotherRound(BoxedRound), /// The protocol reached a result. Result(P::Result), } -impl FinalizeOutcome -where - Id: PartyId, - P: Protocol, -{ - /// A helper method to create an [`AnotherRound`](`Self::AnotherRound`) variant. - pub fn another_round(round: impl Round) -> Self { - Self::AnotherRound(AnotherRound::new(round)) - } -} - -// We do not want to expose `ObjectSafeRound` to the user, so it is hidden in a struct. -/// A wrapped new round that may be returned by [`Round::finalize`]. -#[derive(Debug)] -pub struct AnotherRound(Box>); - -impl AnotherRound -where - Id: PartyId, - P: Protocol, -{ - /// Wraps an object implementing [`Round`]. - pub fn new(round: impl Round) -> Self { - Self(Box::new(ObjectSafeRoundWrapper::new(round))) - } - - /// Returns the inner boxed type. - /// This is an internal method to be used in `Session`. - pub(crate) fn into_boxed(self) -> Box> { - self.0 - } - - /// Attempts to extract an object of a concrete type. - pub fn downcast>(self) -> Result { - self.0.downcast::() - } - - /// Attempts to extract an object of a concrete type, preserving the original on failure. - pub fn try_downcast>(self) -> Result { - self.0.try_downcast::().map_err(Self) - } -} - /// A round identifier. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct RoundId { - round_num: u8, + round_nums: TinyVec<[u8; 4]>, is_echo: bool, } +impl Display for RoundId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + write!(f, "Round ")?; + for (i, round_num) in self.round_nums.iter().enumerate().rev() { + write!(f, "{}", round_num)?; + if i != 0 { + write!(f, "-")?; + } + } + if self.is_echo { + write!(f, " (echo)")?; + } + Ok(()) + } +} + impl RoundId { /// Creates a new round identifier. pub fn new(round_num: u8) -> Self { Self { - round_num, + round_nums: tiny_vec!(round_num, 0, 0, 0), is_echo: false, } } + /// Prefixes this round ID (possibly already nested) with a group number. + pub(crate) fn group_under(&self, round_num: u8) -> Self { + let mut round_nums = self.round_nums.clone(); + round_nums.push(round_num); + Self { + round_nums, + is_echo: self.is_echo, + } + } + + /// Removes the top group prefix from this round ID. + /// + /// Returns the `Err` variant if the round ID is not nested. + pub(crate) fn ungroup(&self) -> Result { + if self.round_nums.len() == 1 { + Err(LocalError::new("This round ID is not in a group")) + } else { + let mut round_nums = self.round_nums.clone(); + round_nums.pop().expect("vector size greater than 1"); + Ok(Self { + round_nums, + is_echo: self.is_echo, + }) + } + } + /// Returns `true` if this is an ID of an echo broadcast round. pub(crate) fn is_echo(&self) -> bool { self.is_echo @@ -100,7 +104,7 @@ impl RoundId { panic!("This is already an echo round ID"); } Self { - round_num: self.round_num, + round_nums: self.round_nums.clone(), is_echo: true, } } @@ -115,7 +119,7 @@ impl RoundId { panic!("This is already an non-echo round ID"); } Self { - round_num: self.round_num, + round_nums: self.round_nums.clone(), is_echo: false, } } @@ -127,12 +131,12 @@ pub trait Protocol: 'static { type Result: Debug; /// An object of this type will be returned when a provable error happens during [`Round::receive_message`]. - type ProtocolError: ProtocolError + Serialize + for<'de> Deserialize<'de>; + type ProtocolError: ProtocolError; /// An object of this type will be returned when an unattributable error happens during [`Round::finalize`]. /// /// It proves that the node did its job correctly, to be adjudicated by a third party. - type CorrectnessProof: Send + Serialize + for<'de> Deserialize<'de> + Debug; + type CorrectnessProof: CorrectnessProof; /// Returns `Ok(())` if the given direct message cannot be deserialized /// assuming it is a direct message from the round `round_id`. @@ -181,7 +185,7 @@ pub trait Protocol: 'static { /// /// Provable here means that we can create an evidence object entirely of messages signed by some party, /// which, in combination, prove the party's malicious actions. -pub trait ProtocolError: Debug + Clone + Send { +pub trait ProtocolError: Debug + Clone + Send + Serialize + for<'de> Deserialize<'de> { /// A description of the error that will be included in the generated evidence. /// /// Make it short and informative. @@ -244,6 +248,40 @@ pub trait ProtocolError: Debug + Clone + Send { ) -> Result<(), ProtocolValidationError>; } +// A convenience implementation for protocols that don't define any errors. +// Have to do it for `()`, since `!` is unstable. +impl ProtocolError for () { + fn description(&self) -> String { + panic!("Attempt to use an empty error type in an evidence. This is a bug in the protocol implementation.") + } + + fn verify_messages_constitute_error( + &self, + _deserializer: &Deserializer, + _echo_broadcast: &EchoBroadcast, + _normal_broadcast: &NormalBroadcast, + _direct_message: &DirectMessage, + _echo_broadcasts: &BTreeMap, + _normal_broadcasts: &BTreeMap, + _direct_messages: &BTreeMap, + _combined_echos: &BTreeMap>, + ) -> Result<(), ProtocolValidationError> { + panic!("Attempt to use an empty error type in an evidence. This is a bug in the protocol implementation.") + } +} + +/// Describes unattributable errors originating during protocol execution. +/// +/// In the situations where no specific message can be blamed for an error, +/// each node must generate a correctness proof proving that they performed their duties correctly, +/// and the collection of proofs is verified by a third party. +/// One of the proofs will necessarily be missing or invalid. +pub trait CorrectnessProof: Debug + Clone + Send + Serialize + for<'de> Deserialize<'de> {} + +// A convenience implementation for protocols that don't define any errors. +// Have to do it for `()`, since `!` is unstable. +impl CorrectnessProof for () {} + /// Message payload created in [`Round::receive_message`]. #[derive(Debug)] pub struct Payload(pub Box); @@ -301,10 +339,18 @@ impl Artifact { /// /// This is a round that can be created directly; /// all the others are only reachable throud [`Round::finalize`] by the execution layer. -pub trait FirstRound: Round + Sized { +pub trait EntryPoint { /// Additional inputs for the protocol (besides the mandatory ones in [`new`](`Self::new`)). type Inputs; + /// The protocol implemented by the round this entry points returns. + type Protocol: Protocol; + + /// Returns the ID of the round returned by [`Self::new`]. + fn entry_round() -> RoundId { + RoundId::new(1) + } + /// Creates the round. /// /// `session_id` can be assumed to be the same for each node participating in a session. @@ -314,13 +360,13 @@ pub trait FirstRound: Round + Sized { shared_randomness: &[u8], id: Id, inputs: Self::Inputs, - ) -> Result; + ) -> Result, LocalError>; } /// A trait alias for the combination of traits needed for a party identifier. -pub trait PartyId: 'static + Debug + Clone + Ord + Send + Sync {} +pub trait PartyId: 'static + Debug + Clone + Ord + Send + Sync + Serialize + for<'de> Deserialize<'de> {} -impl PartyId for T where T: 'static + Debug + Clone + Ord + Send + Sync {} +impl PartyId for T where T: 'static + Debug + Clone + Ord + Send + Sync + Serialize + for<'de> Deserialize<'de> {} /** A type representing a single round of a protocol. @@ -358,35 +404,16 @@ pub trait Round: 'static + Debug + Send + Sync { /// /// Return [`DirectMessage::none`] if this round does not send direct messages. /// - /// Falls back to [`make_direct_message`](`Self::make_direct_message`) if not implemented. - /// This is the method that will be called by the upper layer when creating direct messages. - /// /// In some protocols, when a message to another node is created, there is some associated information /// that needs to be retained for later (randomness, proofs of knowledge, and so on). /// These should be put in an [`Artifact`] and will be available at the time of [`finalize`](`Self::finalize`). - fn make_direct_message_with_artifact( - &self, - rng: &mut impl CryptoRngCore, - serializer: &Serializer, - destination: &Id, - ) -> Result<(DirectMessage, Option), LocalError> { - Ok((self.make_direct_message(rng, serializer, destination)?, None)) - } - - /// Returns the direct message to the given destination. - /// - /// This method will not be called by the upper layer directly, - /// only via [`make_direct_message_with_artifact`](`Self::make_direct_message_with_artifact`). - /// - /// Return [`DirectMessage::none`] if this round does not send direct messages. - /// This is also the blanket implementation. fn make_direct_message( &self, #[allow(unused_variables)] rng: &mut impl CryptoRngCore, #[allow(unused_variables)] serializer: &Serializer, #[allow(unused_variables)] destination: &Id, - ) -> Result { - Ok(DirectMessage::none()) + ) -> Result<(DirectMessage, Option), LocalError> { + Ok((DirectMessage::none(), None)) } /// Returns the echo broadcast for this round. @@ -438,7 +465,7 @@ pub trait Round: 'static + Debug + Send + Sync { /// /// `payloads` here are the ones previously generated by [`receive_message`](`Self::receive_message`), /// and `artifacts` are the ones previously generated by - /// [`make_direct_message_with_artifact`](`Self::make_direct_message_with_artifact`). + /// [`make_direct_message`](`Self::make_direct_message`). fn finalize( self, rng: &mut impl CryptoRngCore, diff --git a/manul/src/session/echo.rs b/manul/src/session/echo.rs index 5bbd124..c4fbcb9 100644 --- a/manul/src/session/echo.rs +++ b/manul/src/session/echo.rs @@ -1,5 +1,4 @@ use alloc::{ - boxed::Box, collections::{BTreeMap, BTreeSet}, format, string::String, @@ -18,8 +17,8 @@ use super::{ }; use crate::{ protocol::{ - Artifact, Deserializer, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, MessageValidationError, - NormalBroadcast, ObjectSafeRound, Payload, Protocol, ProtocolMessagePart, ReceiveError, Round, RoundId, + Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, + MessageValidationError, NormalBroadcast, Payload, Protocol, ProtocolMessagePart, ReceiveError, Round, RoundId, Serializer, }, utils::SerializableMap, @@ -75,12 +74,12 @@ pub(crate) struct EchoRoundMessage { /// participants. The execution layer of the protocol guarantees that all participants have received /// the messages. #[derive_where::derive_where(Debug)] -pub struct EchoRound { +pub struct EchoRound { verifier: SP::Verifier, echo_broadcasts: BTreeMap>, destinations: BTreeSet, expected_echos: BTreeSet, - main_round: Box>, + main_round: BoxedRound, payloads: BTreeMap, artifacts: BTreeMap, } @@ -94,7 +93,7 @@ where verifier: SP::Verifier, my_echo_broadcast: SignedMessagePart, echo_broadcasts: BTreeMap>, - main_round: Box>, + main_round: BoxedRound, payloads: BTreeMap, artifacts: BTreeMap, ) -> Self { @@ -151,7 +150,7 @@ where } fn possible_next_rounds(&self) -> BTreeSet { - self.main_round.possible_next_rounds() + self.main_round.as_ref().possible_next_rounds() } fn message_destinations(&self) -> &BTreeSet { @@ -297,6 +296,8 @@ where _payloads: BTreeMap, _artifacts: BTreeMap, ) -> Result, FinalizeError> { - self.main_round.finalize(rng, self.payloads, self.artifacts) + self.main_round + .into_boxed() + .finalize(rng, self.payloads, self.artifacts) } } diff --git a/manul/src/session/evidence.rs b/manul/src/session/evidence.rs index c0fab32..eafd75c 100644 --- a/manul/src/session/evidence.rs +++ b/manul/src/session/evidence.rs @@ -100,8 +100,8 @@ where .iter() .map(|round_id| { transcript - .get_echo_broadcast(*round_id, verifier) - .map(|echo| (*round_id, echo)) + .get_echo_broadcast(round_id.clone(), verifier) + .map(|echo| (round_id.clone(), echo)) }) .collect::, _>>()?; @@ -110,8 +110,8 @@ where .iter() .map(|round_id| { transcript - .get_normal_broadcast(*round_id, verifier) - .map(|bc| (*round_id, bc)) + .get_normal_broadcast(round_id.clone(), verifier) + .map(|bc| (round_id.clone(), bc)) }) .collect::, _>>()?; @@ -120,8 +120,8 @@ where .iter() .map(|round_id| { transcript - .get_direct_message(*round_id, verifier) - .map(|dm| (*round_id, dm)) + .get_direct_message(round_id.clone(), verifier) + .map(|dm| (round_id.clone(), dm)) }) .collect::, _>>()?; @@ -131,7 +131,7 @@ where .map(|round_id| { transcript .get_normal_broadcast(round_id.echo(), verifier) - .map(|dm| (*round_id, dm)) + .map(|dm| (round_id.clone(), dm)) }) .collect::, _>>()?; @@ -470,12 +470,12 @@ where for (round_id, direct_message) in self.direct_messages.iter() { let verified_direct_message = direct_message.clone().verify::(verifier)?; let metadata = verified_direct_message.metadata(); - if metadata.session_id() != session_id || metadata.round_id() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); } - verified_direct_messages.insert(*round_id, verified_direct_message.payload().clone()); + verified_direct_messages.insert(round_id.clone(), verified_direct_message.payload().clone()); } let verified_echo_broadcast = self.echo_broadcast.clone().verify::(verifier)?.payload().clone(); @@ -500,31 +500,31 @@ where for (round_id, echo_broadcast) in self.echo_broadcasts.iter() { let verified_echo_broadcast = echo_broadcast.clone().verify::(verifier)?; let metadata = verified_echo_broadcast.metadata(); - if metadata.session_id() != session_id || metadata.round_id() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); } - verified_echo_broadcasts.insert(*round_id, verified_echo_broadcast.payload().clone()); + verified_echo_broadcasts.insert(round_id.clone(), verified_echo_broadcast.payload().clone()); } let mut verified_normal_broadcasts = BTreeMap::new(); for (round_id, normal_broadcast) in self.normal_broadcasts.iter() { let verified_normal_broadcast = normal_broadcast.clone().verify::(verifier)?; let metadata = verified_normal_broadcast.metadata(); - if metadata.session_id() != session_id || metadata.round_id() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); } - verified_normal_broadcasts.insert(*round_id, verified_normal_broadcast.payload().clone()); + verified_normal_broadcasts.insert(round_id.clone(), verified_normal_broadcast.payload().clone()); } let mut combined_echos = BTreeMap::new(); for (round_id, combined_echo) in self.combined_echos.iter() { let verified_combined_echo = combined_echo.clone().verify::(verifier)?; let metadata = verified_combined_echo.metadata(); - if metadata.session_id() != session_id || metadata.round_id().non_echo() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id().non_echo() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); @@ -537,14 +537,14 @@ where for (other_verifier, echo_broadcast) in echo_set.echo_broadcasts.iter() { let verified_echo_broadcast = echo_broadcast.clone().verify::(other_verifier)?; let metadata = verified_echo_broadcast.metadata(); - if metadata.session_id() != session_id || metadata.round_id() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); } verified_echo_set.push(verified_echo_broadcast.payload().clone()); } - combined_echos.insert(*round_id, verified_echo_set); + combined_echos.insert(round_id.clone(), verified_echo_set); } Ok(self.error.verify_messages_constitute_error( diff --git a/manul/src/session/message.rs b/manul/src/session/message.rs index 51dbc55..199c2f0 100644 --- a/manul/src/session/message.rs +++ b/manul/src/session/message.rs @@ -68,7 +68,7 @@ impl MessageMetadata { } pub fn round_id(&self) -> RoundId { - self.round_id + self.round_id.clone() } } diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index 5c1845e..db81183 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -23,9 +23,9 @@ use super::{ LocalError, RemoteError, }; use crate::protocol::{ - Artifact, Deserializer, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, FirstRound, NormalBroadcast, - ObjectSafeRound, ObjectSafeRoundWrapper, PartyId, Payload, Protocol, ProtocolMessagePart, ReceiveError, - ReceiveErrorType, Round, RoundId, Serializer, + Artifact, BoxedRound, Deserializer, DirectMessage, EchoBroadcast, EntryPoint, FinalizeError, FinalizeOutcome, + NormalBroadcast, PartyId, Payload, Protocol, ProtocolMessagePart, ReceiveError, ReceiveErrorType, RoundId, + Serializer, }; /// A set of types needed to execute a session. @@ -106,7 +106,7 @@ pub struct Session { verifier: SP::Verifier, serializer: Serializer, deserializer: Deserializer, - round: Box>, + round: BoxedRound, message_destinations: BTreeSet, echo_broadcast: SignedMessagePart, normal_broadcast: SignedMessagePart, @@ -141,15 +141,10 @@ where inputs: R::Inputs, ) -> Result where - R: FirstRound + Round, + R: EntryPoint, { let verifier = signer.verifying_key(); - let first_round = Box::new(ObjectSafeRoundWrapper::new(R::new( - rng, - session_id.as_ref(), - verifier.clone(), - inputs, - )?)); + let first_round = R::new(rng, session_id.as_ref(), verifier.clone(), inputs)?; let serializer = Serializer::new::(); let deserializer = Deserializer::new::(); Self::new_for_next_round( @@ -169,21 +164,21 @@ where signer: SP::Signer, serializer: Serializer, deserializer: Deserializer, - round: Box>, + round: BoxedRound, transcript: Transcript, ) -> Result { let verifier = signer.verifying_key(); - let echo = round.make_echo_broadcast(rng, &serializer)?; + let echo = round.as_ref().make_echo_broadcast(rng, &serializer, &deserializer)?; let echo_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, round.id(), echo)?; - let normal = round.make_normal_broadcast(rng, &serializer)?; + let normal = round.as_ref().make_normal_broadcast(rng, &serializer, &deserializer)?; let normal_broadcast = SignedMessagePart::new::(rng, &signer, &session_id, round.id(), normal)?; - let message_destinations = round.message_destinations().clone(); + let message_destinations = round.as_ref().message_destinations().clone(); let possible_next_rounds = if echo_broadcast.payload().is_none() { - round.possible_next_rounds() + round.as_ref().possible_next_rounds() } else { BTreeSet::from([round.id().echo()]) }; @@ -228,7 +223,8 @@ where ) -> Result<(Message, ProcessedArtifact), LocalError> { let (direct_message, artifact) = self.round - .make_direct_message_with_artifact(rng, &self.serializer, destination)?; + .as_ref() + .make_direct_message(rng, &self.serializer, &self.deserializer, destination)?; let message = Message::new::( rng, @@ -321,7 +317,7 @@ where } MessageFor::ThisRound } else if self.possible_next_rounds.contains(&message_round_id) { - if accum.message_is_cached(from, message_round_id) { + if accum.message_is_cached(from, &message_round_id) { let err = format!("Message for {:?} is already cached", message_round_id); accum.register_unprovable_error(from, RemoteError::new(&err))?; trace!("{key:?} {err}"); @@ -353,15 +349,15 @@ where } Err(MessageVerificationError::Local(error)) => return Err(error), }; - debug!("{key:?}: Received {message_round_id:?} message from {from:?}"); + debug!("{key:?}: Received {message_round_id} message from {from:?}"); match message_for { MessageFor::ThisRound => { accum.mark_processing(&verified_message)?; - Ok(PreprocessOutcome::ToProcess(verified_message)) + Ok(PreprocessOutcome::ToProcess(Box::new(verified_message))) } MessageFor::NextRound => { - debug!("{key:?}: Caching message from {from:?} for {message_round_id:?}"); + debug!("{key:?}: Caching message from {from:?} for {message_round_id}"); accum.cache_message(verified_message)?; Ok(PreprocessOutcome::Cached) } @@ -376,7 +372,7 @@ where rng: &mut impl CryptoRngCore, message: VerifiedMessage, ) -> ProcessedMessage { - let processed = self.round.receive_message( + let processed = self.round.as_ref().receive_message( rng, &self.deserializer, message.from(), @@ -400,7 +396,7 @@ where /// Makes an accumulator for a new round. pub fn make_accumulator(&self) -> RoundAccumulator { - RoundAccumulator::new(self.round.expecting_messages_from()) + RoundAccumulator::new(self.round.as_ref().expecting_messages_from()) } fn terminate_inner( @@ -410,7 +406,7 @@ where ) -> Result, LocalError> { let round_id = self.round_id(); let transcript = self.transcript.update( - round_id, + &round_id, accum.echo_broadcasts, accum.normal_broadcasts, accum.direct_messages, @@ -450,7 +446,7 @@ where let round_id = self.round_id(); let transcript = self.transcript.update( - round_id, + &round_id, accum.echo_broadcasts, accum.normal_broadcasts, accum.direct_messages, @@ -462,14 +458,14 @@ where let echo_round_needed = !self.echo_broadcast.payload().is_none(); if echo_round_needed { - let round = Box::new(ObjectSafeRoundWrapper::new(EchoRound::::new( + let round = BoxedRound::new_dynamic(EchoRound::::new( verifier, self.echo_broadcast, transcript.echo_broadcasts(round_id)?, self.round, accum.payloads, accum.artifacts, - ))); + )); let cached_messages = filter_messages(accum.cached, round.id()); let session = Session::new_for_next_round( rng, @@ -486,14 +482,12 @@ where }); } - match self.round.finalize(rng, accum.payloads, accum.artifacts) { + match self.round.into_boxed().finalize(rng, accum.payloads, accum.artifacts) { Ok(result) => Ok(match result { FinalizeOutcome::Result(result) => { RoundOutcome::Finished(SessionReport::new(SessionOutcome::Result(result), transcript)) } - FinalizeOutcome::AnotherRound(another_round) => { - let round = another_round.into_boxed(); - + FinalizeOutcome::AnotherRound(round) => { // Protecting against common bugs if !self.possible_next_rounds.contains(&round.id()) { return Err(LocalError::new(format!("Unexpected next round id: {:?}", round.id()))); @@ -610,9 +604,9 @@ where self.processing.contains(from) } - fn message_is_cached(&self, from: &SP::Verifier, round_id: RoundId) -> bool { + fn message_is_cached(&self, from: &SP::Verifier, round_id: &RoundId) -> bool { if let Some(entry) = self.cached.get(from) { - entry.contains_key(&round_id) + entry.contains_key(round_id) } else { false } @@ -740,7 +734,7 @@ where } ReceiveErrorType::Echo(error) => { let (_echo_broadcast, normal_broadcast, _direct_message) = processed.message.into_parts(); - let evidence = Evidence::new_echo_round_error(&from, normal_broadcast, error)?; + let evidence = Evidence::new_echo_round_error(&from, normal_broadcast, *error)?; self.register_provable_error(&from, evidence) } ReceiveErrorType::Local(error) => Err(error), @@ -751,7 +745,7 @@ where let from = message.from().clone(); let round_id = message.metadata().round_id(); let cached = self.cached.entry(from.clone()).or_default(); - if cached.insert(round_id, message).is_some() { + if cached.insert(round_id.clone(), message).is_some() { return Err(LocalError::new(format!( "A message from for {:?} has already been cached", round_id @@ -777,7 +771,7 @@ pub struct ProcessedMessage { #[derive(Debug, Clone)] pub enum PreprocessOutcome { /// The message was successfully verified, pass it on to [`Session::process_message`]. - ToProcess(VerifiedMessage), + ToProcess(Box>), /// The message was intended for the next round and was cached. /// /// No action required now, cached messages will be returned on successful [`Session::finalize_round`]. @@ -801,7 +795,7 @@ impl PreprocessOutcome { /// so the user may choose to ignore them if no logging is desired. pub fn ok(self) -> Option> { match self { - Self::ToProcess(message) => Some(message), + Self::ToProcess(message) => Some(*message), _ => None, } } @@ -819,17 +813,11 @@ fn filter_messages( #[cfg(test)] mod tests { - use alloc::{collections::BTreeMap, string::String, vec::Vec}; - use impls::impls; - use serde::{Deserialize, Serialize}; use super::{Message, ProcessedArtifact, ProcessedMessage, Session, VerifiedMessage}; use crate::{ - protocol::{ - Deserializer, DirectMessage, EchoBroadcast, NormalBroadcast, Protocol, ProtocolError, - ProtocolValidationError, RoundId, - }, + protocol::Protocol, testing::{BinaryFormat, TestSessionParams, TestVerifier}, }; @@ -845,32 +833,9 @@ mod tests { struct DummyProtocol; - #[derive(Debug, Clone, Serialize, Deserialize)] - struct DummyProtocolError; - - impl ProtocolError for DummyProtocolError { - fn description(&self) -> String { - unimplemented!() - } - - fn verify_messages_constitute_error( - &self, - _deserializer: &Deserializer, - _echo_broadcast: &EchoBroadcast, - _normal_broadcast: &NormalBroadcast, - _direct_message: &DirectMessage, - _echo_broadcasts: &BTreeMap, - _normal_broadcasts: &BTreeMap, - _direct_messages: &BTreeMap, - _combined_echos: &BTreeMap>, - ) -> Result<(), ProtocolValidationError> { - unimplemented!() - } - } - impl Protocol for DummyProtocol { type Result = (); - type ProtocolError = DummyProtocolError; + type ProtocolError = (); type CorrectnessProof = (); } diff --git a/manul/src/session/transcript.rs b/manul/src/session/transcript.rs index 382448d..3f678e5 100644 --- a/manul/src/session/transcript.rs +++ b/manul/src/session/transcript.rs @@ -36,7 +36,7 @@ where #[allow(clippy::too_many_arguments)] pub fn update( self, - round_id: RoundId, + round_id: &RoundId, echo_broadcasts: BTreeMap>, normal_broadcasts: BTreeMap>, direct_messages: BTreeMap>, @@ -45,7 +45,7 @@ where missing_messages: BTreeSet, ) -> Result { let mut all_echo_broadcasts = self.echo_broadcasts; - match all_echo_broadcasts.entry(round_id) { + match all_echo_broadcasts.entry(round_id.clone()) { Entry::Vacant(entry) => entry.insert(echo_broadcasts), Entry::Occupied(_) => { return Err(LocalError::new(format!( @@ -55,7 +55,7 @@ where }; let mut all_normal_broadcasts = self.normal_broadcasts; - match all_normal_broadcasts.entry(round_id) { + match all_normal_broadcasts.entry(round_id.clone()) { Entry::Vacant(entry) => entry.insert(normal_broadcasts), Entry::Occupied(_) => { return Err(LocalError::new(format!( @@ -65,7 +65,7 @@ where }; let mut all_direct_messages = self.direct_messages; - match all_direct_messages.entry(round_id) { + match all_direct_messages.entry(round_id.clone()) { Entry::Vacant(entry) => entry.insert(direct_messages), Entry::Occupied(_) => { return Err(LocalError::new(format!( @@ -93,7 +93,7 @@ where } let mut all_missing_messages = self.missing_messages; - match all_missing_messages.entry(round_id) { + match all_missing_messages.entry(round_id.clone()) { Entry::Vacant(entry) => entry.insert(missing_messages), Entry::Occupied(_) => { return Err(LocalError::new(format!( diff --git a/manul/src/testing.rs b/manul/src/testing.rs index 1427c84..8846eaa 100644 --- a/manul/src/testing.rs +++ b/manul/src/testing.rs @@ -1,10 +1,6 @@ /*! Utilities for testing protocols. -When testing round based protocols it can be complicated to "inject" the proper faults into the -process, e.g. to emulate a malicious participant. This module provides facilities to make this -easier, by providing a [`RoundOverride`] type along with a [`round_override`] macro. - The [`TestSessionParams`] provides an implementation of the [`SessionParameters`](crate::session::SessionParameters) trait, which in turn is used to setup [`Session`](crate::session::Session)s to drive the protocol. @@ -12,12 +8,10 @@ which in turn is used to setup [`Session`](crate::session::Session)s to drive th The [`run_sync()`] method is helpful to execute a protocol synchronously and collect the outcomes. */ -mod macros; mod run_sync; mod session_parameters; mod wire_format; -pub use macros::{round_override, RoundOverride, RoundWrapper}; pub use run_sync::run_sync; pub use session_parameters::{TestHasher, TestSessionParams, TestSignature, TestSigner, TestVerifier}; pub use wire_format::{BinaryFormat, HumanReadableFormat}; diff --git a/manul/src/testing/macros.rs b/manul/src/testing/macros.rs deleted file mode 100644 index d9c5d70..0000000 --- a/manul/src/testing/macros.rs +++ /dev/null @@ -1,177 +0,0 @@ -use alloc::collections::BTreeMap; - -use rand_core::CryptoRngCore; - -use crate::protocol::{ - Artifact, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, LocalError, NormalBroadcast, PartyId, - Payload, Round, Serializer, -}; - -/// A trait defining a wrapper around an existing type implementing [`Round`]. -pub trait RoundWrapper: 'static + Sized + Send + Sync { - /// The inner round type. - type InnerRound: Round; - - /// Returns a reference to the inner round. - fn inner_round_ref(&self) -> &Self::InnerRound; - - /// Returns the inner round by value. - fn inner_round(self) -> Self::InnerRound; -} - -/// This trait defines overrides of some methods of [`RoundWrapper::InnerRound`]. -/// -/// Intended to be used with the [`round_override`] macro to generate the [`Round`] implementation. -/// -/// The blanket implementations delegate to the methods of the wrapped round. -pub trait RoundOverride: RoundWrapper { - /// An override for [`Round::make_direct_message_with_artifact`]. - fn make_direct_message_with_artifact( - &self, - rng: &mut impl CryptoRngCore, - serializer: &Serializer, - destination: &Id, - ) -> Result<(DirectMessage, Option), LocalError> { - let dm = self.make_direct_message(rng, serializer, destination)?; - Ok((dm, None)) - } - - /// An override for [`Round::make_direct_message`]. - fn make_direct_message( - &self, - rng: &mut impl CryptoRngCore, - serializer: &Serializer, - destination: &Id, - ) -> Result { - self.inner_round_ref().make_direct_message(rng, serializer, destination) - } - - /// An override for [`Round::make_echo_broadcast`]. - fn make_echo_broadcast( - &self, - rng: &mut impl CryptoRngCore, - serializer: &Serializer, - ) -> Result { - self.inner_round_ref().make_echo_broadcast(rng, serializer) - } - - /// An override for [`Round::make_normal_broadcast`]. - fn make_normal_broadcast( - &self, - rng: &mut impl CryptoRngCore, - serializer: &Serializer, - ) -> Result { - self.inner_round_ref().make_normal_broadcast(rng, serializer) - } - - /// An override for [`Round::finalize`]. - #[allow(clippy::type_complexity)] - fn finalize( - self, - rng: &mut impl CryptoRngCore, - payloads: BTreeMap, - artifacts: BTreeMap, - ) -> Result< - FinalizeOutcome>::InnerRound as Round>::Protocol>, - FinalizeError<<>::InnerRound as Round>::Protocol>, - > { - self.inner_round().finalize(rng, payloads, artifacts) - } -} - -/// A macro for "inheriting" from a [`Round`]-implementing type, and overriding some of its behavior. -/// -/// The given `$round` must implement [`RoundOverride`], and is generally some type -/// with one of its fields implementing [`Round`]. -/// Then, the macro will implement the [`Round`] trait for `$round` by delegating non-overridden methods to -/// the internal [`RoundWrapper::InnerRound`]. -#[macro_export] -macro_rules! round_override { - ($round: ident) => { - impl Round for $round - where - Id: $crate::protocol::PartyId, - $round: $crate::testing::RoundOverride, - { - type Protocol = - <<$round as $crate::testing::RoundWrapper>::InnerRound as $crate::protocol::Round>::Protocol; - - fn id(&self) -> $crate::protocol::RoundId { - self.inner_round_ref().id() - } - - fn possible_next_rounds(&self) -> ::alloc::collections::BTreeSet<$crate::protocol::RoundId> { - self.inner_round_ref().possible_next_rounds() - } - - fn message_destinations(&self) -> &::alloc::collections::BTreeSet { - self.inner_round_ref().message_destinations() - } - - fn make_direct_message( - &self, - rng: &mut impl CryptoRngCore, - serializer: &$crate::protocol::Serializer, - destination: &Id, - ) -> Result<$crate::protocol::DirectMessage, $crate::protocol::LocalError> { - >::make_direct_message(self, rng, serializer, destination) - } - - fn make_direct_message_with_artifact( - &self, - rng: &mut impl CryptoRngCore, - serializer: &$crate::protocol::Serializer, - destination: &Id, - ) -> Result<($crate::protocol::DirectMessage, Option<$crate::protocol::Artifact>), $crate::protocol::LocalError> { - >::make_direct_message_with_artifact(self, rng, serializer, destination) - } - - fn make_echo_broadcast( - &self, - rng: &mut impl CryptoRngCore, - serializer: &$crate::protocol::Serializer, - ) -> Result<$crate::protocol::EchoBroadcast, $crate::protocol::LocalError> { - >::make_echo_broadcast(self, rng, serializer) - } - - fn make_normal_broadcast( - &self, - rng: &mut impl CryptoRngCore, - serializer: &$crate::protocol::Serializer, - ) -> Result<$crate::protocol::NormalBroadcast, $crate::protocol::LocalError> { - >::make_normal_broadcast(self, rng, serializer) - } - - fn receive_message( - &self, - rng: &mut impl CryptoRngCore, - deserializer: &$crate::protocol::Deserializer, - from: &Id, - echo_broadcast: $crate::protocol::EchoBroadcast, - normal_broadcast: $crate::protocol::NormalBroadcast, - direct_message: $crate::protocol::DirectMessage, - ) -> Result<$crate::protocol::Payload, $crate::protocol::ReceiveError> { - self.inner_round_ref() - .receive_message(rng, deserializer, from, echo_broadcast, normal_broadcast, direct_message) - } - - fn finalize( - self, - rng: &mut impl CryptoRngCore, - payloads: ::alloc::collections::BTreeMap, - artifacts: ::alloc::collections::BTreeMap, - ) -> Result< - $crate::protocol::FinalizeOutcome, - $crate::protocol::FinalizeError - > { - >::finalize(self, rng, payloads, artifacts) - } - - fn expecting_messages_from(&self) -> &BTreeSet { - self.inner_round_ref().expecting_messages_from() - } - } - }; -} - -pub use round_override; diff --git a/manul/src/testing/run_sync.rs b/manul/src/testing/run_sync.rs index 83d24bc..8d08c87 100644 --- a/manul/src/testing/run_sync.rs +++ b/manul/src/testing/run_sync.rs @@ -6,7 +6,7 @@ use signature::Keypair; use tracing::debug; use crate::{ - protocol::{FirstRound, Protocol}, + protocol::{EntryPoint, Protocol}, session::{ CanFinalize, LocalError, Message, RoundAccumulator, RoundOutcome, Session, SessionId, SessionParameters, SessionReport, @@ -45,7 +45,7 @@ where let state = loop { match session.can_finalize(&accum) { CanFinalize::Yes => { - debug!("{:?}: finalizing {:?}", session.verifier(), session.round_id(),); + debug!("{:?}: finalizing {}", session.verifier(), session.round_id()); match session.finalize_round(rng, accum)? { RoundOutcome::Finished(report) => break State::Finished(report), RoundOutcome::AnotherRound { @@ -94,7 +94,7 @@ pub fn run_sync( inputs: Vec<(SP::Signer, R::Inputs)>, ) -> Result>, LocalError> where - R: FirstRound, + R: EntryPoint, SP: SessionParameters, { let session_id = SessionId::random::(rng); @@ -104,7 +104,7 @@ where for (signer, inputs) in inputs { let verifier = signer.verifying_key(); - let session = Session::::new::(rng, session_id.clone(), signer, inputs)?; + let session = Session::<_, SP>::new::(rng, session_id.clone(), signer, inputs)?; let mut accum = session.make_accumulator(); let destinations = session.message_destinations();