From 104d62561aed3e3e401c64316d29d4a2a883856f Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Fri, 1 Nov 2024 14:41:04 -0700 Subject: [PATCH] Add Misbehave combinator --- examples/src/lib.rs | 1 + examples/src/simple_chain.rs | 86 ++++++++ examples/src/simple_malicious.rs | 194 ++++++----------- manul/src/combinators.rs | 5 + manul/src/combinators/chain.rs | 320 +++++++++++++++++++++++++++++ manul/src/combinators/misbehave.rs | 192 +++++++++++++++++ manul/src/lib.rs | 1 + manul/src/protocol.rs | 3 +- manul/src/protocol/errors.rs | 26 +++ manul/src/protocol/object_safe.rs | 15 +- 10 files changed, 706 insertions(+), 137 deletions(-) create mode 100644 examples/src/simple_chain.rs create mode 100644 manul/src/combinators.rs create mode 100644 manul/src/combinators/chain.rs create mode 100644 manul/src/combinators/misbehave.rs diff --git a/examples/src/lib.rs b/examples/src/lib.rs index c3aff76..f24273f 100644 --- a/examples/src/lib.rs +++ b/examples/src/lib.rs @@ -1,6 +1,7 @@ extern crate alloc; pub mod simple; +mod simple_chain; #[cfg(test)] mod simple_malicious; diff --git a/examples/src/simple_chain.rs b/examples/src/simple_chain.rs new file mode 100644 index 0000000..2e8269b --- /dev/null +++ b/examples/src/simple_chain.rs @@ -0,0 +1,86 @@ +use core::fmt::Debug; +use core::marker::PhantomData; + +use manul::{combinators::*, protocol::PartyId}; + +use super::simple::{Inputs, Round1, SimpleProtocol}; + +pub struct ChainedSimple(PhantomData); + +#[derive(Debug)] +pub struct NewInputs(Inputs); + +impl<'a, Id: PartyId> From<&'a NewInputs> for Inputs { + fn from(source: &'a NewInputs) -> Self { + source.0.clone() + } +} + +impl From<(NewInputs, u8)> for Inputs { + fn from(source: (NewInputs, u8)) -> Self { + source.0 .0 + } +} + +impl Chained for ChainedSimple { + type Protocol1 = SimpleProtocol; + type Protocol2 = SimpleProtocol; + type CorrectnessProof = (); + type Inputs = NewInputs; + type EntryPoint1 = Round1; + type EntryPoint2 = Round1; +} + +#[cfg(test)] +mod tests { + use alloc::collections::BTreeSet; + + use manul::{ + combinators::Chain, + session::{signature::Keypair, SessionOutcome}, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, + }; + use rand_core::OsRng; + use tracing_subscriber::EnvFilter; + + use super::{ChainedSimple, NewInputs}; + use crate::simple::Inputs; + + #[test] + fn round() { + let signers = (0..3).map(TestSigner::new).collect::>(); + let all_ids = signers + .iter() + .map(|signer| signer.verifying_key()) + .collect::>(); + let inputs = signers + .into_iter() + .map(|signer| { + ( + signer, + NewInputs(Inputs { + all_ids: all_ids.clone(), + }), + ) + }) + .collect::>(); + + let my_subscriber = tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .finish(); + let reports = tracing::subscriber::with_default(my_subscriber, || { + run_sync::>, TestSessionParams>( + &mut OsRng, inputs, + ) + .unwrap() + }); + + for (_id, report) in reports { + if let SessionOutcome::Result(result) = report.outcome { + assert_eq!(result, 3); // 0 + 1 + 2 + } else { + panic!("Session did not finish successfully"); + } + } + } +} diff --git a/examples/src/simple_malicious.rs b/examples/src/simple_malicious.rs index 950c9f2..3556157 100644 --- a/examples/src/simple_malicious.rs +++ b/examples/src/simple_malicious.rs @@ -1,150 +1,74 @@ -use alloc::collections::{BTreeMap, BTreeSet}; +use alloc::collections::BTreeSet; use core::fmt::Debug; use manul::{ + combinators::{Misbehaving, MisbehavingInputs, MisbehavingRound}, protocol::{ - Artifact, DirectMessage, FinalizeError, FinalizeOutcome, FirstRound, LocalError, PartyId, Payload, - ProtocolMessagePart, Round, Serializer, + Artifact, DirectMessage, LocalError, ObjectSafeRound, PartyId, ProtocolMessagePart, RoundId, Serializer, }, session::signature::Keypair, - testing::{ - round_override, run_sync, BinaryFormat, RoundOverride, RoundWrapper, TestSessionParams, TestSigner, - TestVerifier, - }, + testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier}, }; use rand_core::{CryptoRngCore, OsRng}; use tracing_subscriber::EnvFilter; -use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message}; +use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message, SimpleProtocol}; #[derive(Debug, Clone, Copy)] enum Behavior { - Lawful, SerializedGarbage, AttributableFailure, AttributableFailureRound2, } -struct MaliciousInputs { - inputs: Inputs, - behavior: Behavior, -} - -#[derive(Debug)] -struct MaliciousRound1 { - round: Round1, - behavior: Behavior, -} - -impl RoundWrapper for MaliciousRound1 { - type InnerRound = Round1; - fn inner_round_ref(&self) -> &Self::InnerRound { - &self.round - } - fn inner_round(self) -> Self::InnerRound { - self.round - } -} +struct SimpleMaliciousProtocol; -impl FirstRound for MaliciousRound1 { - type Inputs = MaliciousInputs; - fn new( - rng: &mut impl CryptoRngCore, - shared_randomness: &[u8], - id: Id, - inputs: Self::Inputs, - ) -> Result { - let round = Round1::new(rng, shared_randomness, id, inputs.inputs)?; - Ok(Self { - round, - behavior: inputs.behavior, - }) - } -} +impl Misbehaving for SimpleMaliciousProtocol { + type Protocol = SimpleProtocol; + type FirstRound = Round1; -impl RoundOverride for MaliciousRound1 { - fn make_direct_message( - &self, - rng: &mut impl CryptoRngCore, + fn amend_direct_message( + _rng: &mut impl CryptoRngCore, + round: &dyn ObjectSafeRound, + behavior: &Behavior, serializer: &Serializer, - destination: &Id, + _destination: &Id, + direct_message: DirectMessage, + artifact: Option, ) -> Result<(DirectMessage, Option), LocalError> { - if matches!(self.behavior, Behavior::SerializedGarbage) { - Ok((DirectMessage::new(serializer, [99u8])?, None)) - } else if matches!(self.behavior, Behavior::AttributableFailure) { - let message = Round1Message { - my_position: self.round.context.ids_to_positions[&self.round.context.id], - your_position: self.round.context.ids_to_positions[&self.round.context.id], - }; - Ok((DirectMessage::new(serializer, message)?, None)) - } else { - self.inner_round_ref().make_direct_message(rng, serializer, destination) - } - } - - fn finalize( - self, - rng: &mut impl CryptoRngCore, - payloads: BTreeMap, - artifacts: BTreeMap, - ) -> Result< - FinalizeOutcome>::InnerRound as Round>::Protocol>, - FinalizeError<<>::InnerRound as Round>::Protocol>, - > { - let behavior = self.behavior; - let outcome = self.inner_round().finalize(rng, payloads, artifacts)?; - - Ok(match outcome { - FinalizeOutcome::Result(res) => FinalizeOutcome::Result(res), - FinalizeOutcome::AnotherRound(another_round) => { - let round2 = another_round.downcast::>().map_err(FinalizeError::Local)?; - FinalizeOutcome::another_round(MaliciousRound2 { - round: round2, - behavior, - }) + let dm = if round.id() == RoundId::new(1) { + match behavior { + Behavior::SerializedGarbage => DirectMessage::new(serializer, [99u8])?, + Behavior::AttributableFailure => { + let round1 = round.downcast_ref::>()?; + let message = Round1Message { + my_position: round1.context.ids_to_positions[&round1.context.id], + your_position: round1.context.ids_to_positions[&round1.context.id], + }; + DirectMessage::new(serializer, message)? + } + _ => direct_message, + } + } else if round.id() == RoundId::new(2) { + match behavior { + Behavior::AttributableFailureRound2 => { + let round2 = round.downcast_ref::>()?; + let message = Round2Message { + my_position: round2.context.ids_to_positions[&round2.context.id], + your_position: round2.context.ids_to_positions[&round2.context.id], + }; + DirectMessage::new(serializer, message)? + } + _ => direct_message, } - }) - } -} - -round_override!(MaliciousRound1); - -#[derive(Debug)] -struct MaliciousRound2 { - round: Round2, - behavior: Behavior, -} - -impl RoundWrapper for MaliciousRound2 { - type InnerRound = Round2; - fn inner_round_ref(&self) -> &Self::InnerRound { - &self.round - } - fn inner_round(self) -> Self::InnerRound { - self.round - } -} - -impl RoundOverride for MaliciousRound2 { - fn make_direct_message( - &self, - rng: &mut impl CryptoRngCore, - serializer: &Serializer, - destination: &Id, - ) -> Result<(DirectMessage, Option), LocalError> { - if matches!(self.behavior, Behavior::AttributableFailureRound2) { - let message = Round2Message { - my_position: self.round.context.ids_to_positions[&self.round.context.id], - your_position: self.round.context.ids_to_positions[&self.round.context.id], - }; - Ok((DirectMessage::new(serializer, message)?, None)) } else { - self.inner_round_ref().make_direct_message(rng, serializer, destination) - } + direct_message + }; + Ok((dm, artifact)) } } -round_override!(MaliciousRound2); +type MaliciousEntryPoint = MisbehavingRound; #[test] fn serialized_garbage() { @@ -160,13 +84,13 @@ fn serialized_garbage() { .enumerate() .map(|(idx, signer)| { let behavior = if idx == 0 { - Behavior::SerializedGarbage + Some(Behavior::SerializedGarbage) } else { - Behavior::Lawful + None }; - let malicious_inputs = MaliciousInputs { - inputs: inputs.clone(), + let malicious_inputs = MisbehavingInputs { + inner_inputs: inputs.clone(), behavior, }; (*signer, malicious_inputs) @@ -177,7 +101,7 @@ fn serialized_garbage() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() }); let v0 = signers[0].verifying_key(); @@ -206,13 +130,13 @@ fn attributable_failure() { .enumerate() .map(|(idx, signer)| { let behavior = if idx == 0 { - Behavior::AttributableFailure + Some(Behavior::AttributableFailure) } else { - Behavior::Lawful + None }; - let malicious_inputs = MaliciousInputs { - inputs: inputs.clone(), + let malicious_inputs = MisbehavingInputs { + inner_inputs: inputs.clone(), behavior, }; (*signer, malicious_inputs) @@ -223,7 +147,7 @@ fn attributable_failure() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() }); let v0 = signers[0].verifying_key(); @@ -252,13 +176,13 @@ fn attributable_failure_round2() { .enumerate() .map(|(idx, signer)| { let behavior = if idx == 0 { - Behavior::AttributableFailureRound2 + Some(Behavior::AttributableFailureRound2) } else { - Behavior::Lawful + None }; - let malicious_inputs = MaliciousInputs { - inputs: inputs.clone(), + let malicious_inputs = MisbehavingInputs { + inner_inputs: inputs.clone(), behavior, }; (*signer, malicious_inputs) @@ -269,7 +193,7 @@ fn attributable_failure_round2() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() + run_sync::, TestSessionParams>(&mut OsRng, run_inputs).unwrap() }); let v0 = signers[0].verifying_key(); diff --git a/manul/src/combinators.rs b/manul/src/combinators.rs new file mode 100644 index 0000000..aad3e74 --- /dev/null +++ b/manul/src/combinators.rs @@ -0,0 +1,5 @@ +mod chain; +mod misbehave; + +pub use chain::{Chain, Chained, ChainedProtocol}; +pub use misbehave::{Misbehaving, MisbehavingInputs, MisbehavingRound}; diff --git a/manul/src/combinators/chain.rs b/manul/src/combinators/chain.rs new file mode 100644 index 0000000..becd175 --- /dev/null +++ b/manul/src/combinators/chain.rs @@ -0,0 +1,320 @@ +use alloc::{ + boxed::Box, + collections::{BTreeMap, BTreeSet}, + format, + string::String, + vec::Vec, +}; +use core::{fmt::Debug, marker::PhantomData}; + +use rand_core::CryptoRngCore; +use serde::{Deserialize, Serialize}; + +use crate::protocol::*; + +pub trait Chained { + type Protocol1: Protocol; + type Protocol2: Protocol; + type CorrectnessProof: Send + + Serialize + + for<'de> Deserialize<'de> + + Debug + + From<::CorrectnessProof> + + From<::CorrectnessProof>; + type Inputs: Send + Sync + Debug; + type EntryPoint1: FirstRound From<&'a Self::Inputs>>; + type EntryPoint2: FirstRound< + Id, + Protocol = Self::Protocol2, + Inputs: From<(Self::Inputs, ::Result)>, + >; +} + +#[derive_where::derive_where(Debug, Clone)] +#[derive(Serialize, Deserialize)] +pub enum ChainedProtocolError> { + Protocol1(::ProtocolError), + Protocol2(::ProtocolError), +} + +impl> ChainedProtocolError { + fn from_protocol1(err: ::ProtocolError) -> Self { + Self::Protocol1(err) + } + + fn from_protocol2(err: ::ProtocolError) -> Self { + Self::Protocol2(err) + } +} + +impl> ProtocolError for ChainedProtocolError { + fn description(&self) -> String { + match self { + Self::Protocol1(err) => format!("Protocol1: {}", err.description()), + Self::Protocol2(err) => format!("Protocol2: {}", err.description()), + } + } + + fn required_direct_messages(&self) -> BTreeSet { + match self { + // TODO: map rounds! + Self::Protocol1(err) => err.required_direct_messages(), + Self::Protocol2(err) => err.required_direct_messages(), + } + } + + fn required_echo_broadcasts(&self) -> BTreeSet { + match self { + Self::Protocol1(err) => err.required_echo_broadcasts(), + Self::Protocol2(err) => err.required_echo_broadcasts(), + } + } + + fn required_normal_broadcasts(&self) -> BTreeSet { + match self { + Self::Protocol1(err) => err.required_normal_broadcasts(), + Self::Protocol2(err) => err.required_normal_broadcasts(), + } + } + + fn required_combined_echos(&self) -> BTreeSet { + match self { + Self::Protocol1(err) => err.required_combined_echos(), + Self::Protocol2(err) => err.required_combined_echos(), + } + } + + #[allow(clippy::too_many_arguments)] + fn verify_messages_constitute_error( + &self, + deserializer: &Deserializer, + echo_broadcast: &EchoBroadcast, + normal_broadcast: &NormalBroadcast, + direct_message: &DirectMessage, + echo_broadcasts: &BTreeMap, + normal_broadcasts: &BTreeMap, + direct_messages: &BTreeMap, + combined_echos: &BTreeMap>, + ) -> Result<(), ProtocolValidationError> { + match self { + Self::Protocol1(err) => err.verify_messages_constitute_error( + deserializer, + echo_broadcast, + normal_broadcast, + direct_message, + echo_broadcasts, + normal_broadcasts, + direct_messages, + combined_echos, + ), + Self::Protocol2(err) => err.verify_messages_constitute_error( + deserializer, + echo_broadcast, + normal_broadcast, + direct_message, + echo_broadcasts, + normal_broadcasts, + direct_messages, + combined_echos, + ), + } + } +} + +pub struct ChainedProtocol>(PhantomData (Id, C)>); + +impl Protocol for ChainedProtocol +where + Id: PartyId, + C: 'static + Chained, +{ + type Result = ::Result; + type ProtocolError = ChainedProtocolError; + type CorrectnessProof = C::CorrectnessProof; +} + +#[derive_where::derive_where(Debug)] +pub struct Chain> { + state: ChainState, +} + +#[derive_where::derive_where(Debug)] +enum ChainState> { + Protocol1 { + round: Box>, + shared_randomness: Box<[u8]>, + id: Id, + inputs: C::Inputs, + }, + Protocol2(Box>), +} + +impl FirstRound for Chain +where + Id: PartyId, + C: 'static + Chained, +{ + type Inputs = C::Inputs; + + fn new( + rng: &mut impl CryptoRngCore, + shared_randomness: &[u8], + id: Id, + inputs: Self::Inputs, + ) -> Result { + let round = C::EntryPoint1::new(rng, shared_randomness, id.clone(), (&inputs).into())?; + Ok(Chain { + state: ChainState::Protocol1 { + shared_randomness: shared_randomness.into(), + id, + inputs, + round: Box::new(ObjectSafeRoundWrapper::new(round)), + }, + }) + } +} + +impl Round for Chain +where + Id: PartyId, + C: 'static + Chained, +{ + type Protocol = ChainedProtocol; + + fn id(&self) -> RoundId { + unimplemented!() + } + + fn possible_next_rounds(&self) -> BTreeSet { + unimplemented!() + } + + fn message_destinations(&self) -> &BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => round.message_destinations(), + ChainState::Protocol2(round) => round.message_destinations(), + } + } + + fn make_direct_message( + &self, + rng: &mut impl CryptoRngCore, + serializer: &Serializer, + destination: &Id, + ) -> Result<(DirectMessage, Option), LocalError> { + match &self.state { + ChainState::Protocol1 { round, .. } => round.make_direct_message(rng, serializer, destination), + ChainState::Protocol2(round) => round.make_direct_message(rng, serializer, destination), + } + } + + fn make_echo_broadcast( + &self, + #[allow(unused_variables)] rng: &mut impl CryptoRngCore, + #[allow(unused_variables)] serializer: &Serializer, + ) -> Result { + match &self.state { + ChainState::Protocol1 { round, .. } => round.make_echo_broadcast(rng, serializer), + ChainState::Protocol2(round) => round.make_echo_broadcast(rng, serializer), + } + } + + fn make_normal_broadcast( + &self, + #[allow(unused_variables)] rng: &mut impl CryptoRngCore, + #[allow(unused_variables)] serializer: &Serializer, + ) -> Result { + match &self.state { + ChainState::Protocol1 { round, .. } => round.make_normal_broadcast(rng, serializer), + ChainState::Protocol2(round) => round.make_normal_broadcast(rng, serializer), + } + } + + fn receive_message( + &self, + rng: &mut impl CryptoRngCore, + deserializer: &Deserializer, + from: &Id, + echo_broadcast: EchoBroadcast, + normal_broadcast: NormalBroadcast, + direct_message: DirectMessage, + ) -> Result> { + match &self.state { + ChainState::Protocol1 { round, .. } => match round.receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) { + Ok(payload) => Ok(payload), + Err(err) => Err(err.map(ChainedProtocolError::from_protocol1)), + }, + ChainState::Protocol2(round) => match round.receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) { + Ok(payload) => Ok(payload), + Err(err) => Err(err.map(ChainedProtocolError::from_protocol2)), + }, + } + } + + fn finalize( + self, + rng: &mut impl CryptoRngCore, + payloads: BTreeMap, + artifacts: BTreeMap, + ) -> Result, FinalizeError> { + match self.state { + ChainState::Protocol1 { + round, + id, + inputs, + shared_randomness, + } => match round.finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => { + let round = C::EntryPoint2::new(rng, &shared_randomness, id, (inputs, result).into())?; + + Ok(FinalizeOutcome::another_round(Chain:: { + state: ChainState::Protocol2(Box::new(ObjectSafeRoundWrapper::new(round))), + })) + } + Ok(FinalizeOutcome::AnotherRound(another_round)) => { + Ok(FinalizeOutcome::another_round(Chain:: { + state: ChainState::Protocol1 { + shared_randomness, + id, + inputs, + round: another_round.into_boxed(), + }, + })) + } + Err(FinalizeError::Local(err)) => Err(FinalizeError::Local(err)), + Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable(proof.into())), + }, + ChainState::Protocol2(round) => match round.finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), + Ok(FinalizeOutcome::AnotherRound(another_round)) => { + Ok(FinalizeOutcome::another_round(Chain:: { + state: ChainState::Protocol2(another_round.into_boxed()), + })) + } + Err(FinalizeError::Local(err)) => Err(FinalizeError::Local(err)), + Err(FinalizeError::Unattributable(proof)) => Err(FinalizeError::Unattributable(proof.into())), + }, + } + } + + fn expecting_messages_from(&self) -> &BTreeSet { + match &self.state { + ChainState::Protocol1 { round, .. } => round.expecting_messages_from(), + ChainState::Protocol2(round) => round.expecting_messages_from(), + } + } +} diff --git a/manul/src/combinators/misbehave.rs b/manul/src/combinators/misbehave.rs new file mode 100644 index 0000000..867213e --- /dev/null +++ b/manul/src/combinators/misbehave.rs @@ -0,0 +1,192 @@ +use alloc::{ + boxed::Box, + collections::{BTreeMap, BTreeSet}, +}; +use core::fmt::Debug; + +use rand_core::CryptoRngCore; + +use crate::protocol::*; + +#[derive_where::derive_where(Debug)] +pub struct MisbehavingRound> { + round: Box>, + behavior: Option, +} + +pub trait Misbehaving { + type Protocol: Protocol; + type FirstRound: FirstRound; + + #[allow(unused_variables)] + fn amend_echo_broadcast( + rng: &mut impl CryptoRngCore, + round: &dyn ObjectSafeRound, + behavior: &B, + serializer: &Serializer, + echo_broadcast: EchoBroadcast, + ) -> Result { + Ok(echo_broadcast) + } + + #[allow(unused_variables)] + fn amend_normal_broadcast( + rng: &mut impl CryptoRngCore, + round: &dyn ObjectSafeRound, + behavior: &B, + serializer: &Serializer, + normal_broadcast: NormalBroadcast, + ) -> Result { + Ok(normal_broadcast) + } + + #[allow(unused_variables)] + fn amend_direct_message( + rng: &mut impl CryptoRngCore, + round: &dyn ObjectSafeRound, + behavior: &B, + serializer: &Serializer, + destination: &Id, + direct_message: DirectMessage, + artifact: Option, + ) -> Result<(DirectMessage, Option), LocalError> { + Ok((direct_message, artifact)) + } +} + +pub struct MisbehavingInputs> { + pub behavior: Option, + pub inner_inputs: >::Inputs, +} + +impl FirstRound for MisbehavingRound +where + Id: PartyId, + B: 'static + Debug + Send + Sync, + MP: 'static + Misbehaving, +{ + type Inputs = MisbehavingInputs; + + fn new( + rng: &mut impl CryptoRngCore, + shared_randomness: &[u8], + id: Id, + inputs: Self::Inputs, + ) -> Result { + let inner_round = MP::FirstRound::new(rng, shared_randomness, id, inputs.inner_inputs)?; + Ok(Self { + round: Box::new(ObjectSafeRoundWrapper::new(inner_round)), + behavior: inputs.behavior, + }) + } +} + +impl Round for MisbehavingRound +where + Id: PartyId, + B: 'static + Debug + Send + Sync, + M: 'static + Misbehaving, +{ + type Protocol = M::Protocol; + + fn id(&self) -> RoundId { + self.round.id() + } + + fn possible_next_rounds(&self) -> BTreeSet { + self.round.possible_next_rounds() + } + + fn message_destinations(&self) -> &BTreeSet { + self.round.message_destinations() + } + + fn make_direct_message( + &self, + rng: &mut impl CryptoRngCore, + serializer: &Serializer, + destination: &Id, + ) -> Result<(DirectMessage, Option), LocalError> { + let (direct_message, artifact) = self.round.make_direct_message(rng, serializer, destination)?; + if let Some(behavior) = self.behavior.as_ref() { + M::amend_direct_message( + rng, + self.round.as_ref(), + behavior, + serializer, + destination, + direct_message, + artifact, + ) + } else { + Ok((direct_message, artifact)) + } + } + + fn make_echo_broadcast( + &self, + rng: &mut impl CryptoRngCore, + serializer: &Serializer, + ) -> Result { + let echo_broadcast = self.round.make_echo_broadcast(rng, serializer)?; + if let Some(behavior) = self.behavior.as_ref() { + M::amend_echo_broadcast(rng, self.round.as_ref(), behavior, serializer, echo_broadcast) + } else { + Ok(echo_broadcast) + } + } + + fn make_normal_broadcast( + &self, + rng: &mut impl CryptoRngCore, + serializer: &Serializer, + ) -> Result { + let normal_broadcast = self.round.make_normal_broadcast(rng, serializer)?; + if let Some(behavior) = self.behavior.as_ref() { + M::amend_normal_broadcast(rng, self.round.as_ref(), behavior, serializer, normal_broadcast) + } else { + Ok(normal_broadcast) + } + } + + fn receive_message( + &self, + rng: &mut impl CryptoRngCore, + deserializer: &Deserializer, + from: &Id, + echo_broadcast: EchoBroadcast, + normal_broadcast: NormalBroadcast, + direct_message: DirectMessage, + ) -> Result> { + self.round.receive_message( + rng, + deserializer, + from, + echo_broadcast, + normal_broadcast, + direct_message, + ) + } + + fn finalize( + self, + rng: &mut impl CryptoRngCore, + payloads: BTreeMap, + artifacts: BTreeMap, + ) -> Result, FinalizeError> { + match self.round.finalize(rng, payloads, artifacts) { + Ok(FinalizeOutcome::Result(result)) => Ok(FinalizeOutcome::Result(result)), + Ok(FinalizeOutcome::AnotherRound(another_round)) => { + Ok(FinalizeOutcome::another_round(MisbehavingRound:: { + round: another_round.into_boxed(), + behavior: self.behavior, + })) + } + Err(err) => Err(err), + } + } + + fn expecting_messages_from(&self) -> &BTreeSet { + self.round.expecting_messages_from() + } +} diff --git a/manul/src/lib.rs b/manul/src/lib.rs index 8d3323f..ad55b47 100644 --- a/manul/src/lib.rs +++ b/manul/src/lib.rs @@ -16,6 +16,7 @@ extern crate alloc; +pub mod combinators; pub mod protocol; pub mod session; pub(crate) mod utils; diff --git a/manul/src/protocol.rs b/manul/src/protocol.rs index 0e53744..52c104d 100644 --- a/manul/src/protocol.rs +++ b/manul/src/protocol.rs @@ -28,6 +28,7 @@ pub use round::{ pub use serialization::{Deserializer, Serializer}; pub(crate) use errors::ReceiveErrorType; -pub(crate) use object_safe::{ObjectSafeRound, ObjectSafeRoundWrapper}; +pub use object_safe::ObjectSafeRound; +pub(crate) use object_safe::ObjectSafeRoundWrapper; pub use digest; diff --git a/manul/src/protocol/errors.rs b/manul/src/protocol/errors.rs index a77ecc3..460641d 100644 --- a/manul/src/protocol/errors.rs +++ b/manul/src/protocol/errors.rs @@ -68,6 +68,32 @@ impl ReceiveError { pub fn protocol(error: P::ProtocolError) -> Self { Self(ReceiveErrorType::Protocol(error)) } + + pub(crate) fn map(self, f: F) -> ReceiveError + where + F: Fn(P::ProtocolError) -> T::ProtocolError, + T: Protocol, + { + ReceiveError(self.0.map::(f)) + } +} + +impl ReceiveErrorType { + pub(crate) fn map(self, f: F) -> ReceiveErrorType + where + F: Fn(P::ProtocolError) -> T::ProtocolError, + T: Protocol, + { + match self { + Self::Local(err) => ReceiveErrorType::Local(err), + Self::InvalidDirectMessage(err) => ReceiveErrorType::InvalidDirectMessage(err), + Self::InvalidEchoBroadcast(err) => ReceiveErrorType::InvalidEchoBroadcast(err), + Self::InvalidNormalBroadcast(err) => ReceiveErrorType::InvalidNormalBroadcast(err), + Self::Unprovable(err) => ReceiveErrorType::Unprovable(err), + Self::Echo(err) => ReceiveErrorType::Echo(err), + Self::Protocol(err) => ReceiveErrorType::Protocol(f(err)), + } + } } impl From for ReceiveError diff --git a/manul/src/protocol/object_safe.rs b/manul/src/protocol/object_safe.rs index 6d282be..123b3b0 100644 --- a/manul/src/protocol/object_safe.rs +++ b/manul/src/protocol/object_safe.rs @@ -39,7 +39,7 @@ impl RngCore for BoxedRng<'_> { // Since we want `Round` methods to take `&mut impl CryptoRngCore` arguments // (which is what all cryptographic libraries generally take), it cannot be object-safe. // Thus we have to add this crate-private object-safe layer on top of `Round`. -pub(crate) trait ObjectSafeRound: 'static + Debug + Send + Sync { +pub trait ObjectSafeRound: 'static + Debug + Send + Sync { type Protocol: Protocol; fn id(&self) -> RoundId; @@ -225,4 +225,17 @@ where self.try_downcast() .map_err(|_| LocalError::new(format!("Failed to downcast into type {}", core::any::type_name::()))) } + + pub fn downcast_ref>(&self) -> Result<&T, LocalError> { + if core::any::TypeId::of::>() == self.get_type_id() { + let ptr: *const dyn ObjectSafeRound = self; + // This should be safe since we just checked that we are casting to a correct type. + Ok(unsafe { &*(ptr as *const T) }) + } else { + Err(LocalError::new(format!( + "Failed to downcast into type {}", + core::any::type_name::() + ))) + } + } }