From 42deda164b9ee3718e6f95fc6247b5bb237a2840 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Fri, 18 Oct 2024 16:47:55 -0700 Subject: [PATCH] Switch to supplying a (de)serializer externally --- example/src/lib.rs | 3 ++ example/src/serializer.rs | 17 +++++++ example/src/simple.rs | 65 +++++++++++++---------- example/src/simple_malicious.rs | 85 +++++++++++++++++++++---------- example/tests/async.rs | 20 +++++--- manul/src/protocol.rs | 2 +- manul/src/protocol/object_safe.rs | 37 ++++++++------ manul/src/protocol/round.rs | 77 +++++++++++++++------------- manul/src/session.rs | 3 +- manul/src/session/echo.rs | 22 ++++---- manul/src/session/evidence.rs | 79 ++++++++++++++++------------ manul/src/session/message.rs | 26 ++++++---- manul/src/session/session.rs | 67 +++++++++++++----------- manul/src/testing/macros.rs | 35 +++++++------ manul/src/testing/run_sync.rs | 35 +++++++------ 15 files changed, 343 insertions(+), 230 deletions(-) create mode 100644 example/src/serializer.rs diff --git a/example/src/lib.rs b/example/src/lib.rs index c3aff76..7f0932b 100644 --- a/example/src/lib.rs +++ b/example/src/lib.rs @@ -1,6 +1,9 @@ extern crate alloc; +mod serializer; pub mod simple; #[cfg(test)] mod simple_malicious; + +pub use serializer::BincodeSerializer; diff --git a/example/src/serializer.rs b/example/src/serializer.rs new file mode 100644 index 0000000..cc3e76a --- /dev/null +++ b/example/src/serializer.rs @@ -0,0 +1,17 @@ +use manul::protocol::{DeserializationError, LocalError, Serializer}; +use serde::{Deserialize, Serialize}; + +pub struct BincodeSerializer; + +impl Serializer for BincodeSerializer { + fn serialize(value: T) -> Result, LocalError> { + bincode::serde::encode_to_vec(value, bincode::config::standard()) + .map(|vec| vec.into()) + .map_err(|err| LocalError::new(err.to_string())) + } + + fn deserialize<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> Result { + bincode::serde::decode_borrowed_from_slice(bytes, bincode::config::standard()) + .map_err(|err| DeserializationError::new(err.to_string())) + } +} diff --git a/example/src/simple.rs b/example/src/simple.rs index 4cbaf80..f0f6e33 100644 --- a/example/src/simple.rs +++ b/example/src/simple.rs @@ -1,5 +1,5 @@ use alloc::collections::{BTreeMap, BTreeSet}; -use core::fmt::Debug; +use core::{fmt::Debug, marker::PhantomData}; use manul::protocol::*; use rand_core::CryptoRngCore; @@ -35,7 +35,7 @@ impl ProtocolError for SimpleProtocolError { } } - fn verify_messages_constitute_error( + fn verify_messages_constitute_error( &self, _echo_broadcast: &Option, direct_message: &DirectMessage, @@ -45,12 +45,12 @@ impl ProtocolError for SimpleProtocolError { ) -> Result<(), ProtocolValidationError> { match self { SimpleProtocolError::Round1InvalidPosition => { - let _message = direct_message.deserialize::()?; + let _message = direct_message.deserialize::()?; // Message contents would be checked here Ok(()) } SimpleProtocolError::Round2InvalidPosition => { - let _r1_message = direct_message.deserialize::()?; + let _r1_message = direct_message.deserialize::()?; let r1_echos_serialized = combined_echos .get(&RoundId::new(1)) .ok_or_else(|| LocalError::new("Could not find combined echos for Round 1"))?; @@ -58,7 +58,7 @@ impl ProtocolError for SimpleProtocolError { // Deserialize the echos let _r1_echos = r1_echos_serialized .iter() - .map(|echo| echo.deserialize::()) + .map(|echo| echo.deserialize::()) .collect::, _>>()?; // Message contents would be checked here @@ -75,23 +75,12 @@ impl Protocol for SimpleProtocol { type Digest = Sha3_256; - fn serialize(value: T) -> Result, LocalError> { - bincode::serde::encode_to_vec(value, bincode::config::standard()) - .map(|vec| vec.into()) - .map_err(|err| LocalError::new(err.to_string())) - } - - fn deserialize<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> Result { - bincode::serde::decode_borrowed_from_slice(bytes, bincode::config::standard()) - .map_err(|err| DeserializationError::new(err.to_string())) - } - - fn verify_direct_message_is_invalid( + fn verify_direct_message_is_invalid( round_id: RoundId, message: &DirectMessage, ) -> Result<(), MessageValidationError> { if round_id == RoundId::new(1) { - return message.verify_is_invalid::(); + return message.verify_is_invalid::(); } Err(MessageValidationError::InvalidEvidence("Invalid round number".into()))? } @@ -108,8 +97,9 @@ pub(crate) struct Context { pub(crate) ids_to_positions: BTreeMap, } -pub struct Round1 { +pub struct Round1 { pub(crate) context: Context, + phantom: PhantomData S>, } #[derive(Debug, Serialize, Deserialize)] @@ -127,7 +117,11 @@ struct Round1Payload { x: u8, } -impl FirstRound for Round1 { +impl FirstRound for Round1 +where + Id: 'static + Debug + Clone + Ord + Send + Sync, + S: 'static + Serializer, +{ type Inputs = Inputs; fn new( _rng: &mut impl CryptoRngCore, @@ -153,11 +147,16 @@ impl FirstRound for Round1< other_ids: ids, ids_to_positions, }, + phantom: PhantomData, }) } } -impl Round for Round1 { +impl Round for Round1 +where + Id: 'static + Debug + Clone + Ord + Send + Sync, + S: 'static + Serializer, +{ type Protocol = SimpleProtocol; fn id(&self) -> RoundId { @@ -207,7 +206,7 @@ impl Round for Round1 { ) -> Result> { debug!("{:?}: receiving message from {:?}", self.context.id, from); - let message = direct_message.deserialize::()?; + let message = direct_message.deserialize::()?; debug!("{:?}: received message: {:?}", self.context.id, message); @@ -223,7 +222,7 @@ impl Round for Round1 { _rng: &mut impl CryptoRngCore, payloads: BTreeMap, _artifacts: BTreeMap, - ) -> Result, FinalizeError> { + ) -> Result, FinalizeError> { debug!( "{:?}: finalizing with messages from {:?}", self.context.id, @@ -241,6 +240,7 @@ impl Round for Round1 { let round2 = Round2 { round1_sum: sum, context: self.context, + phantom: PhantomData, }; Ok(FinalizeOutcome::another_round(round2)) } @@ -250,9 +250,10 @@ impl Round for Round1 { } } -pub(crate) struct Round2 { +pub(crate) struct Round2 { round1_sum: u8, pub(crate) context: Context, + phantom: PhantomData S>, } #[derive(Debug, Serialize, Deserialize)] @@ -261,7 +262,11 @@ pub(crate) struct Round2Message { pub(crate) your_position: u8, } -impl Round for Round2 { +impl Round for Round2 +where + Id: 'static + Debug + Clone + Ord + Send + Sync, + S: 'static + Serializer, +{ type Protocol = SimpleProtocol; fn id(&self) -> RoundId { @@ -311,7 +316,7 @@ impl Round for Round2 { ) -> Result> { debug!("{:?}: receiving message from {:?}", self.context.id, from); - let message = direct_message.deserialize::()?; + let message = direct_message.deserialize::()?; debug!("{:?}: received message: {:?}", self.context.id, message); @@ -327,7 +332,7 @@ impl Round for Round2 { _rng: &mut impl CryptoRngCore, payloads: BTreeMap, _artifacts: BTreeMap, - ) -> Result, FinalizeError> { + ) -> Result, FinalizeError> { debug!( "{:?}: finalizing with messages from {:?}", self.context.id, @@ -366,6 +371,7 @@ mod tests { use tracing_subscriber::EnvFilter; use super::{Inputs, Round1}; + use crate::BincodeSerializer; #[test] fn round() { @@ -390,7 +396,10 @@ mod tests { .with_env_filter(EnvFilter::from_default_env()) .finish(); let reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, Signer, Verifier, Signature>(&mut OsRng, inputs).unwrap() + run_sync::, BincodeSerializer, Signer, Verifier, Signature>( + &mut OsRng, inputs, + ) + .unwrap() }); for (_id, report) in reports { diff --git a/example/src/simple_malicious.rs b/example/src/simple_malicious.rs index 46d38b5..6158fa3 100644 --- a/example/src/simple_malicious.rs +++ b/example/src/simple_malicious.rs @@ -3,7 +3,8 @@ use core::fmt::Debug; use manul::{ protocol::{ - Artifact, DirectMessage, FinalizeError, FinalizeOutcome, FirstRound, LocalError, Payload, Round, SessionId, + Artifact, DirectMessage, FinalizeError, FinalizeOutcome, FirstRound, LocalError, Payload, Round, Serializer, + SessionId, }, session::signature::Keypair, testing::{round_override, run_sync, RoundOverride, RoundWrapper, Signature, Signer, Verifier}, @@ -11,7 +12,10 @@ use manul::{ use rand_core::{CryptoRngCore, OsRng}; use tracing_subscriber::EnvFilter; -use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message}; +use crate::{ + simple::{Inputs, Round1, Round1Message, Round2, Round2Message}, + BincodeSerializer, +}; #[derive(Debug, Clone, Copy)] enum Behavior { @@ -26,13 +30,17 @@ struct MaliciousInputs { behavior: Behavior, } -struct MaliciousRound1 { - round: Round1, +struct MaliciousRound1 { + round: Round1, behavior: Behavior, } -impl RoundWrapper for MaliciousRound1 { - type InnerRound = Round1; +impl RoundWrapper for MaliciousRound1 +where + Id: 'static + Debug + Clone + Ord + Send + Sync, + S: 'static + Serializer, +{ + type InnerRound = Round1; fn inner_round_ref(&self) -> &Self::InnerRound { &self.round } @@ -41,7 +49,11 @@ impl RoundWrapper for Malic } } -impl FirstRound for MaliciousRound1 { +impl FirstRound for MaliciousRound1 +where + Id: 'static + Debug + Clone + Ord + Send + Sync, + S: 'static + Serializer, +{ type Inputs = MaliciousInputs; fn new( rng: &mut impl CryptoRngCore, @@ -57,21 +69,25 @@ impl FirstRound for Malicio } } -impl RoundOverride for MaliciousRound1 { +impl RoundOverride for MaliciousRound1 +where + Id: 'static + Debug + Clone + Ord + Send + Sync, + S: 'static + Serializer, +{ fn make_direct_message( &self, rng: &mut impl CryptoRngCore, destination: &Id, ) -> Result<(DirectMessage, Artifact), LocalError> { if matches!(self.behavior, Behavior::SerializedGarbage) { - let dm = DirectMessage::new::<>::Protocol, _>(&[99u8]).unwrap(); + let dm = DirectMessage::new::(&[99u8]).unwrap(); Ok((dm, Artifact::empty())) } else if matches!(self.behavior, Behavior::AttributableFailure) { let message = Round1Message { my_position: self.round.context.ids_to_positions[&self.round.context.id], your_position: self.round.context.ids_to_positions[&self.round.context.id], }; - let dm = DirectMessage::new::<>::Protocol, _>(&message)?; + let dm = DirectMessage::new::(&message)?; Ok((dm, Artifact::empty())) } else { self.inner_round_ref().make_direct_message(rng, destination) @@ -84,8 +100,8 @@ impl RoundOverride for Mali payloads: BTreeMap, artifacts: BTreeMap, ) -> Result< - FinalizeOutcome>::InnerRound as Round>::Protocol>, - FinalizeError<<>::InnerRound as Round>::Protocol>, + FinalizeOutcome>::InnerRound as Round>::Protocol, S>, + FinalizeError<<>::InnerRound as Round>::Protocol>, > { let behavior = self.behavior; let outcome = self.inner_round().finalize(rng, payloads, artifacts)?; @@ -93,7 +109,9 @@ impl RoundOverride for Mali Ok(match outcome { FinalizeOutcome::Result(res) => FinalizeOutcome::Result(res), FinalizeOutcome::AnotherRound(another_round) => { - let round2 = another_round.downcast::>().map_err(FinalizeError::Local)?; + let round2 = another_round + .downcast::>() + .map_err(FinalizeError::Local)?; FinalizeOutcome::another_round(MaliciousRound2 { round: round2, behavior, @@ -105,13 +123,17 @@ impl RoundOverride for Mali round_override!(MaliciousRound1); -struct MaliciousRound2 { - round: Round2, +struct MaliciousRound2 { + round: Round2, behavior: Behavior, } -impl RoundWrapper for MaliciousRound2 { - type InnerRound = Round2; +impl RoundWrapper for MaliciousRound2 +where + Id: 'static + Debug + Clone + Ord + Send + Sync, + S: 'static + Serializer, +{ + type InnerRound = Round2; fn inner_round_ref(&self) -> &Self::InnerRound { &self.round } @@ -120,7 +142,11 @@ impl RoundWrapper for Malic } } -impl RoundOverride for MaliciousRound2 { +impl RoundOverride for MaliciousRound2 +where + Id: 'static + Debug + Clone + Ord + Send + Sync, + S: 'static + Serializer, +{ fn make_direct_message( &self, rng: &mut impl CryptoRngCore, @@ -131,7 +157,7 @@ impl RoundOverride for Mali my_position: self.round.context.ids_to_positions[&self.round.context.id], your_position: self.round.context.ids_to_positions[&self.round.context.id], }; - let dm = DirectMessage::new::<>::Protocol, _>(&message)?; + let dm = DirectMessage::new::(&message)?; Ok((dm, Artifact::empty())) } else { self.inner_round_ref().make_direct_message(rng, destination) @@ -172,7 +198,8 @@ fn serialized_garbage() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, Signer, Verifier, Signature>(&mut OsRng, run_inputs).unwrap() + run_sync::, BincodeSerializer, Signer, Verifier, Signature>(&mut OsRng, run_inputs) + .unwrap() }); let v0 = signers[0].verifying_key(); @@ -183,8 +210,8 @@ fn serialized_garbage() { let report1 = reports.remove(&v1).unwrap(); let report2 = reports.remove(&v2).unwrap(); - assert!(report1.provable_errors[&v0].verify(&v0).is_ok()); - assert!(report2.provable_errors[&v0].verify(&v0).is_ok()); + assert!(report1.provable_errors[&v0].verify::(&v0).is_ok()); + assert!(report2.provable_errors[&v0].verify::(&v0).is_ok()); } #[test] @@ -218,7 +245,8 @@ fn attributable_failure() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, Signer, Verifier, Signature>(&mut OsRng, run_inputs).unwrap() + run_sync::, BincodeSerializer, Signer, Verifier, Signature>(&mut OsRng, run_inputs) + .unwrap() }); let v0 = signers[0].verifying_key(); @@ -229,8 +257,8 @@ fn attributable_failure() { let report1 = reports.remove(&v1).unwrap(); let report2 = reports.remove(&v2).unwrap(); - assert!(report1.provable_errors[&v0].verify(&v0).is_ok()); - assert!(report2.provable_errors[&v0].verify(&v0).is_ok()); + assert!(report1.provable_errors[&v0].verify::(&v0).is_ok()); + assert!(report2.provable_errors[&v0].verify::(&v0).is_ok()); } #[test] @@ -264,7 +292,8 @@ fn attributable_failure_round2() { .with_env_filter(EnvFilter::from_default_env()) .finish(); let mut reports = tracing::subscriber::with_default(my_subscriber, || { - run_sync::, Signer, Verifier, Signature>(&mut OsRng, run_inputs).unwrap() + run_sync::, BincodeSerializer, Signer, Verifier, Signature>(&mut OsRng, run_inputs) + .unwrap() }); let v0 = signers[0].verifying_key(); @@ -275,6 +304,6 @@ fn attributable_failure_round2() { let report1 = reports.remove(&v1).unwrap(); let report2 = reports.remove(&v2).unwrap(); - assert!(report1.provable_errors[&v0].verify(&v0).is_ok()); - assert!(report2.provable_errors[&v0].verify(&v0).is_ok()); + assert!(report1.provable_errors[&v0].verify::(&v0).is_ok()); + assert!(report2.provable_errors[&v0].verify::(&v0).is_ok()); } diff --git a/example/tests/async.rs b/example/tests/async.rs index 7b7951c..633b8ab 100644 --- a/example/tests/async.rs +++ b/example/tests/async.rs @@ -3,13 +3,16 @@ extern crate alloc; use alloc::collections::{BTreeMap, BTreeSet}; use manul::{ - protocol::{Protocol, Round}, + protocol::Protocol, session::{ signature::Keypair, CanFinalize, LocalError, MessageBundle, RoundOutcome, Session, SessionId, SessionReport, }, testing::{Signature, Signer, Verifier}, }; -use manul_example::simple::{Inputs, Round1}; +use manul_example::{ + simple::{Inputs, Round1}, + BincodeSerializer, +}; use rand::Rng; use rand_core::OsRng; use tokio::{ @@ -25,7 +28,7 @@ type MessageIn = (Verifier, MessageBundle); async fn run_session

( tx: mpsc::Sender, rx: mpsc::Receiver, - session: Session, + session: Session, ) -> Result, LocalError> where P: Protocol + 'static, @@ -147,7 +150,7 @@ async fn message_dispatcher(txs: BTreeMap>, rx } async fn run_nodes

( - sessions: Vec>, + sessions: Vec>, ) -> Vec> where P: Protocol + Send + 'static, @@ -204,9 +207,12 @@ async fn async_run() { let inputs = Inputs { all_ids: all_ids.clone(), }; - Session::< as Round>::Protocol, Signer, Verifier, Signature>::new::< - Round1, - >(&mut OsRng, session_id.clone(), signer, inputs) + Session::<_, BincodeSerializer, Signer, Verifier, Signature>::new::>( + &mut OsRng, + session_id.clone(), + signer, + inputs, + ) .unwrap() }) .collect::>(); diff --git a/manul/src/protocol.rs b/manul/src/protocol.rs index d98c256..93c4077 100644 --- a/manul/src/protocol.rs +++ b/manul/src/protocol.rs @@ -20,7 +20,7 @@ pub use error::{LocalError, RemoteError}; pub use round::{ AnotherRound, Artifact, DeserializationError, DirectMessage, DirectMessageError, EchoBroadcast, EchoBroadcastError, FinalizeError, FinalizeOutcome, FirstRound, MessageValidationError, Payload, Protocol, ProtocolError, - ProtocolValidationError, ReceiveError, Round, RoundId, + ProtocolValidationError, ReceiveError, Round, RoundId, Serializer, }; pub(crate) use object_safe::{ObjectSafeRound, ObjectSafeRoundWrapper}; diff --git a/manul/src/protocol/object_safe.rs b/manul/src/protocol/object_safe.rs index c0c4760..b008f7c 100644 --- a/manul/src/protocol/object_safe.rs +++ b/manul/src/protocol/object_safe.rs @@ -11,7 +11,7 @@ use super::{ error::LocalError, round::{ Artifact, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, Payload, Protocol, ReceiveError, Round, - RoundId, + RoundId, Serializer, }, }; @@ -40,7 +40,7 @@ impl<'a> RngCore for BoxedRng<'a> { // Since we want `Round` methods to take `&mut impl CryptoRngCore` arguments // (which is what all cryptographic libraries generally take), it cannot be object-safe. // Thus we have to add this crate-private object-safe layer on top of `Round`. -pub(crate) trait ObjectSafeRound: 'static + Send + Sync { +pub(crate) trait ObjectSafeRound: 'static + Send + Sync { type Protocol: Protocol; fn id(&self) -> RoundId; @@ -70,7 +70,7 @@ pub(crate) trait ObjectSafeRound: 'static + Send + Sync { rng: &mut dyn CryptoRngCore, payloads: BTreeMap, artifacts: BTreeMap, - ) -> Result, FinalizeError>; + ) -> Result, FinalizeError>; fn expecting_messages_from(&self) -> &BTreeSet; @@ -79,12 +79,13 @@ pub(crate) trait ObjectSafeRound: 'static + Send + Sync { } // The `fn(Id) -> Id` bit is so that `ObjectSafeRoundWrapper` didn't require a bound on `Id` to be `Send + Sync`. -pub(crate) struct ObjectSafeRoundWrapper { +pub(crate) struct ObjectSafeRoundWrapper { round: R, - phantom: PhantomData Id>, + #[allow(clippy::type_complexity)] + phantom: PhantomData<(fn(Id) -> Id, fn(S) -> S)>, } -impl> ObjectSafeRoundWrapper { +impl, S: Serializer> ObjectSafeRoundWrapper { pub fn new(round: R) -> Self { Self { round, @@ -93,12 +94,13 @@ impl> ObjectSafeRoundWrapper { } } -impl ObjectSafeRound for ObjectSafeRoundWrapper +impl ObjectSafeRound for ObjectSafeRoundWrapper where Id: 'static, - R: Round, + R: Round, + S: 'static + Serializer, { - type Protocol = >::Protocol; + type Protocol = >::Protocol; fn id(&self) -> RoundId { self.round.id() @@ -143,7 +145,7 @@ where rng: &mut dyn CryptoRngCore, payloads: BTreeMap, artifacts: BTreeMap, - ) -> Result, FinalizeError> { + ) -> Result, FinalizeError> { let mut boxed_rng = BoxedRng(rng); self.round.finalize(&mut boxed_rng, payloads, artifacts) } @@ -164,16 +166,19 @@ where // Because of Rust's peculiarities, Box that we return in `finalize()` // cannot be unboxed into an object of a concrete type with `downcast()`, // so we have to provide this workaround. -impl dyn ObjectSafeRound +impl dyn ObjectSafeRound where Id: 'static, P: 'static + Protocol, + S: 'static + Serializer, { - pub fn try_downcast>(self: Box) -> Result> { - if core::any::TypeId::of::>() == self.get_type_id() { + pub fn try_downcast>(self: Box) -> Result> { + if core::any::TypeId::of::>() == self.get_type_id() { // This should be safe since we just checked that we are casting to a correct type. let boxed_downcast = unsafe { - Box::>::from_raw(Box::into_raw(self) as *mut ObjectSafeRoundWrapper) + Box::>::from_raw( + Box::into_raw(self) as *mut ObjectSafeRoundWrapper + ) }; Ok(boxed_downcast.round) } else { @@ -181,8 +186,8 @@ where } } - pub fn downcast>(self: Box) -> Result { + pub fn downcast>(self: Box) -> Result { self.try_downcast() - .map_err(|_| LocalError::new(format!("Failed to downcast into type {}", core::any::type_name::()))) + .map_err(|_| LocalError::new(format!("Failed to downcast into type {}", core::any::type_name::()))) } } diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index ba7b118..b1c0a15 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -104,51 +104,53 @@ where } /// Possible successful outcomes of [`Round::finalize`]. -pub enum FinalizeOutcome { +pub enum FinalizeOutcome { /// Transition to a new round. - AnotherRound(AnotherRound), + AnotherRound(AnotherRound), /// The protocol reached a result. Result(P::Result), } -impl FinalizeOutcome +impl FinalizeOutcome where Id: 'static, P: 'static + Protocol, + S: 'static + Serializer, { /// A helper method to create an [`AnotherRound`](`Self::AnotherRound`) variant. - pub fn another_round(round: impl Round) -> Self { + pub fn another_round(round: impl Round) -> Self { Self::AnotherRound(AnotherRound::new(round)) } } // We do not want to expose `ObjectSafeRound` to the user, so it is hidden in a struct. /// A wrapped new round that may be returned by [`Round::finalize`]. -pub struct AnotherRound(Box>); +pub struct AnotherRound(Box>); -impl AnotherRound +impl AnotherRound where Id: 'static, P: 'static + Protocol, + S: 'static + Serializer, { /// Wraps an object implementing [`Round`]. - pub fn new(round: impl Round) -> Self { + pub fn new(round: impl Round) -> Self { Self(Box::new(ObjectSafeRoundWrapper::new(round))) } /// Returns the inner boxed type. /// This is an internal method to be used in `Session`. - pub(crate) fn into_boxed(self) -> Box> { + pub(crate) fn into_boxed(self) -> Box> { self.0 } /// Attempts to extract an object of a concrete type. - pub fn downcast>(self) -> Result { + pub fn downcast>(self) -> Result { self.0.downcast::() } /// Attempts to extract an object of a concrete type, preserving the original on failure. - pub fn try_downcast>(self) -> Result { + pub fn try_downcast>(self) -> Result { self.0.try_downcast::().map_err(Self) } } @@ -236,6 +238,15 @@ impl From for MessageValidationError { } } +/// A (de)serializer that will be used for the protocol messages. +pub trait Serializer { + /// Serializes the given object into a bytestring. + fn serialize(value: T) -> Result, LocalError>; + + /// Tries to deserialize the given bytestring as an object of type `T`. + fn deserialize<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> Result; +} + /// A distributed protocol. pub trait Protocol: Debug + Sized { /// The successful result of an execution of this protocol. @@ -254,17 +265,11 @@ pub trait Protocol: Debug + Sized { /// This will be used to generate message signatures. type Digest: Digest; - /// Serializes the given object into a bytestring. - fn serialize(value: T) -> Result, LocalError>; - - /// Tries to deserialize the given bytestring as an object of type `T`. - fn deserialize<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> Result; - /// Returns `Ok(())` if the given direct message cannot be deserialized /// assuming it is a direct message from the round `round_id`. /// /// Normally one would use [`DirectMessage::verify_is_invalid`] when implementing this. - fn verify_direct_message_is_invalid( + fn verify_direct_message_is_invalid( round_id: RoundId, #[allow(unused_variables)] message: &DirectMessage, ) -> Result<(), MessageValidationError> { @@ -277,7 +282,7 @@ pub trait Protocol: Debug + Sized { /// assuming it is an echo broadcast from the round `round_id`. /// /// Normally one would use [`EchoBroadcast::verify_is_invalid`] when implementing this. - fn verify_echo_broadcast_is_invalid( + fn verify_echo_broadcast_is_invalid( round_id: RoundId, #[allow(unused_variables)] message: &EchoBroadcast, ) -> Result<(), MessageValidationError> { @@ -359,7 +364,7 @@ pub trait ProtocolError: Debug + Clone + Send { /// [`required_echo_broadcasts`](`Self::required_echo_broadcasts`). /// `combined_echos` are bundled echos from other parties from the previous rounds, /// as requested by [`required_combined_echos`](`Self::required_combined_echos`). - fn verify_messages_constitute_error( + fn verify_messages_constitute_error( &self, echo_broadcast: &Option, direct_message: &DirectMessage, @@ -385,15 +390,15 @@ pub struct DirectMessage(#[serde(with = "SliceLike::")] Box<[u8]>); impl DirectMessage { /// Creates a new serialized direct message. - pub fn new(message: T) -> Result { - P::serialize(message).map(Self) + pub fn new(message: T) -> Result { + S::serialize(message).map(Self) } /// Returns `Ok(())` if the message cannot be deserialized into `T`. /// /// This is intended to be used in the implementations of [`Protocol::verify_direct_message_is_invalid`]. - pub fn verify_is_invalid Deserialize<'de>>(&self) -> Result<(), MessageValidationError> { - if self.deserialize::().is_err() { + pub fn verify_is_invalid Deserialize<'de>>(&self) -> Result<(), MessageValidationError> { + if self.deserialize::().is_err() { Ok(()) } else { Err(MessageValidationError::InvalidEvidence( @@ -403,8 +408,8 @@ impl DirectMessage { } /// Deserializes the direct message. - pub fn deserialize Deserialize<'de>>(&self) -> Result { - P::deserialize(&self.0).map_err(DirectMessageError) + pub fn deserialize Deserialize<'de>>(&self) -> Result { + S::deserialize(&self.0).map_err(DirectMessageError) } } @@ -414,15 +419,15 @@ pub struct EchoBroadcast(#[serde(with = "SliceLike::")] Box<[u8]>); impl EchoBroadcast { /// Creates a new serialized echo broadcast. - pub fn new(message: T) -> Result { - P::serialize(message).map(Self) + pub fn new(message: T) -> Result { + S::serialize(message).map(Self) } /// Returns `Ok(())` if the message cannot be deserialized into `T`. /// /// This is intended to be used in the implementations of [`Protocol::verify_direct_message_is_invalid`]. - pub fn verify_is_invalid Deserialize<'de>>(&self) -> Result<(), MessageValidationError> { - if self.deserialize::().is_err() { + pub fn verify_is_invalid Deserialize<'de>>(&self) -> Result<(), MessageValidationError> { + if self.deserialize::().is_err() { Ok(()) } else { Err(MessageValidationError::InvalidEvidence( @@ -432,8 +437,8 @@ impl EchoBroadcast { } /// Deserializes the echo broadcast. - pub fn deserialize Deserialize<'de>>(&self) -> Result { - P::deserialize(&self.0).map_err(EchoBroadcastError) + pub fn deserialize Deserialize<'de>>(&self) -> Result { + S::deserialize(&self.0).map_err(EchoBroadcastError) } } @@ -499,7 +504,7 @@ impl Artifact { /// /// This is a round that can be created directly; /// all the others are only reachable throud [`Round::finalize`] by the execution layer. -pub trait FirstRound: Round + Sized { +pub trait FirstRound: Round + Sized { /// Additional inputs for the protocol (besides the mandatory ones in [`new`](`Self::new`)). type Inputs; @@ -524,7 +529,7 @@ The way a round will be used by an external caller: - process received messages from other nodes (by calling [`receive_message`](`Self::receive_message`)); - attempt to finalize (by calling [`finalize`](`Self::finalize`)) to produce the next round, or return a result. */ -pub trait Round: 'static + Send + Sync { +pub trait Round: 'static + Send + Sync { /// The protocol this round is a part of. type Protocol: Protocol; @@ -593,7 +598,7 @@ pub trait Round: 'static + Send + Sync { rng: &mut impl CryptoRngCore, payloads: BTreeMap, artifacts: BTreeMap, - ) -> Result, FinalizeError>; + ) -> Result, FinalizeError>; /// Returns the set of node IDs from which this round expects messages. /// @@ -604,12 +609,12 @@ pub trait Round: 'static + Send + Sync { /// A convenience method to create an [`EchoBroadcast`] object /// to return in [`make_echo_broadcast`](`Self::make_echo_broadcast`). fn serialize_echo_broadcast(message: impl Serialize) -> Result { - EchoBroadcast::new::(message) + EchoBroadcast::new::(message) } /// A convenience method to create a [`DirectMessage`] object /// to return in [`make_direct_message`](`Self::make_direct_message`). fn serialize_direct_message(message: impl Serialize) -> Result { - DirectMessage::new::(message) + DirectMessage::new::(message) } } diff --git a/manul/src/session.rs b/manul/src/session.rs index d03cb5d..0d2df67 100644 --- a/manul/src/session.rs +++ b/manul/src/session.rs @@ -12,10 +12,9 @@ mod session; mod transcript; pub use crate::protocol::{LocalError, RemoteError}; +pub(crate) use echo::EchoRoundError; pub use message::MessageBundle; pub use session::{CanFinalize, RoundAccumulator, RoundOutcome, Session, SessionId}; pub use transcript::{SessionOutcome, SessionReport}; -pub(crate) use echo::EchoRoundError; - pub use signature; diff --git a/manul/src/session/echo.rs b/manul/src/session/echo.rs index 8a9fc4e..4243de1 100644 --- a/manul/src/session/echo.rs +++ b/manul/src/session/echo.rs @@ -17,7 +17,7 @@ use super::{ }; use crate::protocol::{ Artifact, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, ObjectSafeRound, Payload, Protocol, - ReceiveError, Round, RoundId, + ReceiveError, Round, RoundId, Serializer, }; #[derive(Debug)] @@ -31,26 +31,27 @@ pub struct EchoRoundMessage { pub(crate) echo_messages: BTreeMap>, } -pub struct EchoRound { +pub struct EchoRound { verifier: Id, echo_messages: BTreeMap>, destinations: BTreeSet, expected_echos: BTreeSet, - main_round: Box>, + main_round: Box>, payloads: BTreeMap, artifacts: BTreeMap, } -impl EchoRound +impl EchoRound where P: Protocol, Id: Debug + Clone + Ord, + S: Serializer, { pub fn new( verifier: Id, my_echo_message: SignedMessage, echo_messages: BTreeMap>, - main_round: Box>, + main_round: Box>, payloads: BTreeMap, artifacts: BTreeMap, ) -> Self { @@ -76,7 +77,7 @@ where } } -impl Round for EchoRound +impl Round for EchoRound where P: 'static + Protocol, Id: 'static @@ -90,6 +91,7 @@ where + Sync + DigestVerifier, Sig: 'static + Debug + Clone + Serialize + for<'de> Deserialize<'de> + Eq + Send + Sync, + S: 'static + Serializer, { type Protocol = P; @@ -122,7 +124,7 @@ where } let message = EchoRoundMessage { echo_messages }; - let dm = DirectMessage::new::(&message)?; + let dm = DirectMessage::new::(&message)?; Ok((dm, Artifact::empty())) } @@ -139,7 +141,7 @@ where ) -> Result> { debug!("{:?}: received an echo message from {:?}", self.verifier, from); - let message = direct_message.deserialize::>()?; + let message = direct_message.deserialize::>()?; // Check that the received message contains entries from `destinations` sans `from` // It is an unprovable fault. @@ -186,7 +188,7 @@ where continue; } - let verified_echo = match echo.clone().verify::(sender) { + let verified_echo = match echo.clone().verify::(sender) { Ok(echo) => echo, Err(MessageVerificationError::Local(error)) => return Err(error.into()), // This means `from` sent us an incorrectly signed message. @@ -217,7 +219,7 @@ where rng: &mut impl CryptoRngCore, _payloads: BTreeMap, _artifacts: BTreeMap, - ) -> Result, FinalizeError> { + ) -> Result, FinalizeError> { self.main_round.finalize(rng, self.payloads, self.artifacts) } } diff --git a/manul/src/session/evidence.rs b/manul/src/session/evidence.rs index 56b1f26..85f78b1 100644 --- a/manul/src/session/evidence.rs +++ b/manul/src/session/evidence.rs @@ -12,7 +12,7 @@ use super::{ }; use crate::protocol::{ DirectMessage, DirectMessageError, EchoBroadcast, EchoBroadcastError, MessageValidationError, Protocol, - ProtocolError, ProtocolValidationError, RoundId, + ProtocolError, ProtocolValidationError, RoundId, Serializer, }; #[derive(Debug, Clone)] @@ -120,12 +120,15 @@ where }) } - pub(crate) fn new_echo_round_error( + pub(crate) fn new_echo_round_error( verifier: &Verifier, direct_message: SignedMessage, error: EchoRoundError, transcript: &Transcript, - ) -> Result { + ) -> Result + where + S: Serializer, + { let description = format!("{:?}", error); match error { EchoRoundError::InvalidEcho(from) => Ok(Self { @@ -146,7 +149,7 @@ where let deserialized = direct_message .payload() - .deserialize::>() + .deserialize::>() .map_err(|error| { LocalError::new(format!("Failed to deserialize the given direct message: {:?}", error)) })?; @@ -207,13 +210,13 @@ where &self.description } - pub fn verify(&self, party: &Verifier) -> Result<(), EvidenceError> { + pub fn verify(&self, party: &Verifier) -> Result<(), EvidenceError> { match &self.evidence { - EvidenceEnum::Protocol(evidence) => evidence.verify(party), - EvidenceEnum::InvalidDirectMessage(evidence) => evidence.verify(party), - EvidenceEnum::InvalidEchoBroadcast(evidence) => evidence.verify(party), - EvidenceEnum::InvalidEchoPack(evidence) => evidence.verify(party), - EvidenceEnum::MismatchedBroadcasts(evidence) => evidence.verify(party), + EvidenceEnum::Protocol(evidence) => evidence.verify::(party), + EvidenceEnum::InvalidDirectMessage(evidence) => evidence.verify::(party), + EvidenceEnum::InvalidEchoBroadcast(evidence) => evidence.verify::(party), + EvidenceEnum::InvalidEchoPack(evidence) => evidence.verify::(party), + EvidenceEnum::MismatchedBroadcasts(evidence) => evidence.verify::(party), } } } @@ -240,9 +243,12 @@ where Sig: Clone + for<'de> Deserialize<'de>, Verifier: Debug + Clone + Ord + DigestVerifier + for<'de> Deserialize<'de>, { - fn verify(&self, verifier: &Verifier) -> Result<(), EvidenceError> { - let verified = self.direct_message.clone().verify::(verifier)?; - let deserialized = verified.payload().deserialize::>()?; + fn verify(&self, verifier: &Verifier) -> Result<(), EvidenceError> + where + S: Serializer, + { + let verified = self.direct_message.clone().verify::(verifier)?; + let deserialized = verified.payload().deserialize::>()?; let invalid_echo = deserialized .echo_messages .get(&self.invalid_echo_sender) @@ -253,7 +259,7 @@ where )) })?; - let verified_echo = match invalid_echo.clone().verify::(&self.invalid_echo_sender) { + let verified_echo = match invalid_echo.clone().verify::(&self.invalid_echo_sender) { Ok(echo) => echo, Err(MessageVerificationError::Local(error)) => return Err(EvidenceError::Local(error)), // The message was indeed incorrectly signed - fault proven @@ -284,12 +290,13 @@ where P: Protocol, Sig: Clone, { - fn verify(&self, verifier: &Verifier) -> Result<(), EvidenceError> + fn verify(&self, verifier: &Verifier) -> Result<(), EvidenceError> where + S: Serializer, Verifier: Debug + Clone + DigestVerifier, { - let we_received = self.we_received.clone().verify::(verifier)?; - let echoed_to_us = self.echoed_to_us.clone().verify::(verifier)?; + let we_received = self.we_received.clone().verify::(verifier)?; + let echoed_to_us = self.echoed_to_us.clone().verify::(verifier)?; if we_received.metadata() == echoed_to_us.metadata() && we_received.payload() != echoed_to_us.payload() { return Ok(()); @@ -312,12 +319,13 @@ where P: Protocol, Sig: Clone, { - fn verify(&self, verifier: &Verifier) -> Result<(), EvidenceError> + fn verify(&self, verifier: &Verifier) -> Result<(), EvidenceError> where + S: Serializer, Verifier: Debug + Clone + DigestVerifier, { - let verified_direct_message = self.direct_message.clone().verify::(verifier)?; - Ok(P::verify_direct_message_is_invalid( + let verified_direct_message = self.direct_message.clone().verify::(verifier)?; + Ok(P::verify_direct_message_is_invalid::( self.direct_message.metadata().round_id(), verified_direct_message.payload(), )?) @@ -335,12 +343,13 @@ where P: Protocol, Sig: Clone, { - fn verify(&self, verifier: &Verifier) -> Result<(), EvidenceError> + fn verify(&self, verifier: &Verifier) -> Result<(), EvidenceError> where + S: Serializer, Verifier: Debug + Clone + DigestVerifier, { - let verified_echo_broadcast = self.echo_broadcast.clone().verify::(verifier)?; - Ok(P::verify_echo_broadcast_is_invalid( + let verified_echo_broadcast = self.echo_broadcast.clone().verify::(verifier)?; + Ok(P::verify_echo_broadcast_is_invalid::( self.echo_broadcast.metadata().round_id(), verified_echo_broadcast.payload(), )?) @@ -362,17 +371,23 @@ where P: Protocol, Sig: Clone + for<'de> Deserialize<'de>, { - fn verify(&self, verifier: &Verifier) -> Result<(), EvidenceError> + fn verify(&self, verifier: &Verifier) -> Result<(), EvidenceError> where + S: Serializer, Verifier: Debug + Clone + Ord + for<'de> Deserialize<'de> + DigestVerifier, { let session_id = self.direct_message.metadata().session_id(); - let verified_direct_message = self.direct_message.clone().verify::(verifier)?.payload().clone(); + let verified_direct_message = self + .direct_message + .clone() + .verify::(verifier)? + .payload() + .clone(); let mut verified_direct_messages = BTreeMap::new(); for (round_id, direct_message) in self.direct_messages.iter() { - let verified_direct_message = direct_message.clone().verify::(verifier)?; + let verified_direct_message = direct_message.clone().verify::(verifier)?; let metadata = verified_direct_message.metadata(); if metadata.session_id() != session_id || metadata.round_id() != *round_id { return Err(EvidenceError::InvalidEvidence( @@ -389,14 +404,14 @@ where "Invalid attached message metadata".into(), )); } - Some(echo.clone().verify::(verifier)?.payload().clone()) + Some(echo.clone().verify::(verifier)?.payload().clone()) } else { None }; let mut verified_echo_broadcasts = BTreeMap::new(); for (round_id, echo_broadcast) in self.echo_broadcasts.iter() { - let verified_echo_broadcast = echo_broadcast.clone().verify::(verifier)?; + let verified_echo_broadcast = echo_broadcast.clone().verify::(verifier)?; let metadata = verified_echo_broadcast.metadata(); if metadata.session_id() != session_id || metadata.round_id() != *round_id { return Err(EvidenceError::InvalidEvidence( @@ -408,7 +423,7 @@ where let mut combined_echos = BTreeMap::new(); for (round_id, combined_echo) in self.combined_echos.iter() { - let verified_combined_echo = combined_echo.clone().verify::(verifier)?; + let verified_combined_echo = combined_echo.clone().verify::(verifier)?; let metadata = verified_combined_echo.metadata(); if metadata.session_id() != session_id || metadata.round_id().non_echo() != *round_id { return Err(EvidenceError::InvalidEvidence( @@ -416,11 +431,11 @@ where )); } let echo_set = - DirectMessage::deserialize::>(verified_combined_echo.payload())?; + DirectMessage::deserialize::>(verified_combined_echo.payload())?; let mut verified_echo_set = Vec::new(); for (other_verifier, echo_broadcast) in echo_set.echo_messages.iter() { - let verified_echo_broadcast = echo_broadcast.clone().verify::(other_verifier)?; + let verified_echo_broadcast = echo_broadcast.clone().verify::(other_verifier)?; let metadata = verified_echo_broadcast.metadata(); if metadata.session_id() != session_id || metadata.round_id() != *round_id { return Err(EvidenceError::InvalidEvidence( @@ -432,7 +447,7 @@ where combined_echos.insert(*round_id, verified_echo_set); } - Ok(self.error.verify_messages_constitute_error( + Ok(self.error.verify_messages_constitute_error::( &verified_echo_broadcast, &verified_direct_message, &verified_echo_broadcasts, diff --git a/manul/src/session/message.rs b/manul/src/session/message.rs index 3657bb1..548802e 100644 --- a/manul/src/session/message.rs +++ b/manul/src/session/message.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use signature::{DigestVerifier, RandomizedDigestSigner}; use super::{session::SessionId, LocalError}; -use crate::protocol::{DirectMessage, EchoBroadcast, Protocol, RoundId}; +use crate::protocol::{DirectMessage, EchoBroadcast, Protocol, RoundId, Serializer}; #[derive(Debug, Clone)] pub(crate) enum MessageVerificationError { @@ -53,7 +53,7 @@ impl SignedMessage where M: Serialize, { - pub fn new( + pub fn new( rng: &mut impl CryptoRngCore, signer: &Signer, session_id: &SessionId, @@ -62,11 +62,12 @@ where ) -> Result where P: Protocol, + S: Serializer, Signer: RandomizedDigestSigner, { let metadata = MessageMetadata::new(session_id, round_id); let message_with_metadata = MessageWithMetadata { metadata, message }; - let message_bytes = P::serialize(&message_with_metadata)?; + let message_bytes = S::serialize(&message_with_metadata)?; let digest = P::Digest::new_with_prefix(b"SignedMessage").chain_update(message_bytes); let signature = signer .try_sign_digest_with_rng(rng, digest) @@ -85,14 +86,16 @@ where &self.message_with_metadata.message } - pub(crate) fn verify( + pub(crate) fn verify( self, verifier: &Verifier, ) -> Result, MessageVerificationError> where + P: Protocol, + S: Serializer, Verifier: Clone + DigestVerifier, { - let message_bytes = P::serialize(&self.message_with_metadata).map_err(MessageVerificationError::Local)?; + let message_bytes = S::serialize(&self.message_with_metadata).map_err(MessageVerificationError::Local)?; let digest = P::Digest::new_with_prefix(b"SignedMessage").chain_update(message_bytes); if verifier.verify_digest(digest, &self.signature).is_ok() { Ok(VerifiedMessage { @@ -141,7 +144,7 @@ impl MessageBundle where Sig: PartialEq + Clone, { - pub(crate) fn new( + pub(crate) fn new( rng: &mut impl CryptoRngCore, signer: &Signer, session_id: &SessionId, @@ -151,9 +154,10 @@ where ) -> Result where P: Protocol, + S: Serializer, Signer: RandomizedDigestSigner, { - let direct_message = SignedMessage::new::(rng, signer, session_id, round_id, direct_message)?; + let direct_message = SignedMessage::new::(rng, signer, session_id, round_id, direct_message)?; Ok(Self { direct_message, echo_broadcast, @@ -191,17 +195,19 @@ impl CheckedMessageBundle { &self.metadata } - pub fn verify( + pub fn verify( self, verifier: &Verifier, ) -> Result, MessageVerificationError> where + P: Protocol, + S: Serializer, Verifier: Clone + DigestVerifier, { - let direct_message = self.direct_message.verify::(verifier)?; + let direct_message = self.direct_message.verify::(verifier)?; let echo_broadcast = self .echo_broadcast - .map(|echo| echo.verify::(verifier)) + .map(|echo| echo.verify::(verifier)) .transpose()?; Ok(VerifiedMessageBundle { from: verifier.clone(), diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index 7b2be91..96296be 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -21,7 +21,7 @@ use super::{ }; use crate::protocol::{ Artifact, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, FirstRound, ObjectSafeRound, - ObjectSafeRoundWrapper, Payload, Protocol, ReceiveError, ReceiveErrorType, Round, RoundId, + ObjectSafeRoundWrapper, Payload, Protocol, ReceiveError, ReceiveErrorType, Round, RoundId, Serializer, }; /// A session identifier shared between the parties. @@ -55,11 +55,11 @@ impl AsRef<[u8]> for SessionId { /// An object encapsulating the currently active round, transport protocol, /// and the database of messages and errors from the previous rounds. -pub struct Session { +pub struct Session { session_id: SessionId, signer: Signer, verifier: Verifier, - round: Box>, + round: Box>, message_destinations: BTreeSet, echo_message: Option>, possible_next_rounds: BTreeSet, @@ -67,33 +67,34 @@ pub struct Session { } /// Possible non-erroneous results of finalizing a round. -pub enum RoundOutcome { +pub enum RoundOutcome { /// The execution is finished. Finished(SessionReport), /// Transitioned to another round. AnotherRound { /// The session object for the new round. - session: Session, + session: Session, /// The messages intended for the new round cached during the previous round. cached_messages: Vec>, }, } -impl Session +impl Session where - P: Protocol + 'static, + P: 'static + Protocol, + S: 'static + Serializer, Signer: RandomizedDigestSigner + Keypair, - Verifier: Debug + Verifier: 'static + + Debug + Clone + Eq + Ord + DigestVerifier - + 'static + Serialize + for<'de> Deserialize<'de> + Send + Sync, - Sig: Debug + Clone + Eq + 'static + Serialize + for<'de> Deserialize<'de> + Send + Sync, + Sig: 'static + Debug + Clone + Eq + Serialize + for<'de> Deserialize<'de> + Send + Sync, { /// Initializes a new session. pub fn new( @@ -103,7 +104,7 @@ where inputs: R::Inputs, ) -> Result where - R: FirstRound + Round + 'static, + R: 'static + FirstRound + Round, { let verifier = signer.verifying_key(); let first_round = Box::new(ObjectSafeRoundWrapper::new(R::new( @@ -119,14 +120,14 @@ where rng: &mut impl CryptoRngCore, session_id: SessionId, signer: Signer, - round: Box>, + round: Box>, transcript: Transcript, ) -> Result { let verifier = signer.verifying_key(); let echo_message = round .make_echo_broadcast(rng) .transpose()? - .map(|echo| SignedMessage::new::(rng, &signer, &session_id, round.id(), echo)) + .map(|echo| SignedMessage::new::(rng, &signer, &session_id, round.id(), echo)) .transpose()?; let message_destinations = round.message_destinations().clone(); @@ -173,7 +174,7 @@ where ) -> Result<(MessageBundle, ProcessedArtifact), LocalError> { let (direct_message, artifact) = self.round.make_direct_message(rng, destination)?; - let bundle = MessageBundle::new::( + let bundle = MessageBundle::new::( rng, &self.signer, &self.session_id, @@ -261,7 +262,7 @@ where // Verify the signature now - let verified_message = match checked_message.verify::(from) { + let verified_message = match checked_message.verify::(from) { Ok(verified_message) => verified_message, Err(MessageVerificationError::InvalidSignature) => { accum.register_unprovable_error(from, RemoteError::new("Message verification failed"))?; @@ -318,7 +319,7 @@ where accum: &mut RoundAccumulator, processed: ProcessedMessage, ) -> Result<(), LocalError> { - accum.add_processed_message(&self.transcript, processed) + accum.add_processed_message::(&self.transcript, processed) } /// Makes an accumulator for a new round. @@ -348,7 +349,7 @@ where self, rng: &mut impl CryptoRngCore, accum: RoundAccumulator, - ) -> Result, LocalError> { + ) -> Result, LocalError> { let verifier = self.verifier().clone(); let round_id = self.round_id(); @@ -551,11 +552,14 @@ where Ok(()) } - fn add_processed_message( + fn add_processed_message( &mut self, transcript: &Transcript, processed: ProcessedMessage, - ) -> Result<(), LocalError> { + ) -> Result<(), LocalError> + where + S: Serializer, + { if self.payloads.contains_key(processed.message.from()) { return Err(LocalError::new(format!( "A processed message from {:?} has already been recorded", @@ -609,7 +613,7 @@ where } ReceiveErrorType::Echo(error) => { let (_echo_broadcast, direct_message) = processed.message.into_unverified(); - let evidence = Evidence::new_echo_round_error(&from, direct_message, error, transcript)?; + let evidence = Evidence::new_echo_round_error::(&from, direct_message, error, transcript)?; self.register_provable_error(&from, evidence) } ReceiveErrorType::Local(error) => Err(error), @@ -661,7 +665,7 @@ mod tests { use crate::{ protocol::{ DeserializationError, DirectMessage, EchoBroadcast, LocalError, Protocol, ProtocolError, - ProtocolValidationError, RoundId, + ProtocolValidationError, RoundId, Serializer, }, testing::{Hasher, Signature, Signer, Verifier}, }; @@ -683,7 +687,7 @@ mod tests { struct DummyProtocolError; impl ProtocolError for DummyProtocolError { - fn verify_messages_constitute_error( + fn verify_messages_constitute_error( &self, _echo_broadcast: &Option, _direct_message: &DirectMessage, @@ -695,11 +699,9 @@ mod tests { } } - impl Protocol for DummyProtocol { - type Result = (); - type ProtocolError = DummyProtocolError; - type CorrectnessProof = (); - type Digest = Hasher; + struct DummySerializer; + + impl Serializer for DummySerializer { fn serialize(_: T) -> Result, LocalError> where T: Serialize, @@ -714,12 +716,19 @@ mod tests { } } + impl Protocol for DummyProtocol { + type Result = (); + type ProtocolError = DummyProtocolError; + type CorrectnessProof = (); + type Digest = Hasher; + } + // We need `Session` to be `Send` so that we send a `Session` object to a task // to run the loop there. - assert!(impls!(Session: Send)); + assert!(impls!(Session: Send)); // This is needed so that message processing offloaded to a task could use `&Session`. - assert!(impls!(Session: Sync)); + assert!(impls!(Session: Sync)); // These objects are sent to/from message processing tasks assert!(impls!(MessageBundle: Send)); diff --git a/manul/src/testing/macros.rs b/manul/src/testing/macros.rs index 9ce75c2..bf4a9e1 100644 --- a/manul/src/testing/macros.rs +++ b/manul/src/testing/macros.rs @@ -3,13 +3,13 @@ use alloc::collections::BTreeMap; use rand_core::CryptoRngCore; use crate::protocol::{ - Artifact, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, LocalError, Payload, Round, + Artifact, DirectMessage, EchoBroadcast, FinalizeError, FinalizeOutcome, LocalError, Payload, Round, Serializer, }; /// A trait defining a wrapper around an existing type implementing [`Round`]. -pub trait RoundWrapper: 'static + Sized + Send + Sync { +pub trait RoundWrapper: 'static + Sized + Send + Sync { /// The inner round type. - type InnerRound: Round; + type InnerRound: Round; /// Returns a reference to the inner round. fn inner_round_ref(&self) -> &Self::InnerRound; @@ -23,7 +23,7 @@ pub trait RoundWrapper: 'static + Sized + Send + Sync { /// Intended to be used with [`round_override`] to generate the [`Round`] implementation. /// /// The blanket implementations default to the methods of the wrapped round. -pub trait RoundOverride: RoundWrapper { +pub trait RoundOverride: RoundWrapper { /// An override for [`Round::make_direct_message`]. fn make_direct_message( &self, @@ -46,8 +46,8 @@ pub trait RoundOverride: RoundWrapper { payloads: BTreeMap, artifacts: BTreeMap, ) -> Result< - FinalizeOutcome>::InnerRound as Round>::Protocol>, - FinalizeError<<>::InnerRound as Round>::Protocol>, + FinalizeOutcome>::InnerRound as Round>::Protocol, S>, + FinalizeError<<>::InnerRound as Round>::Protocol>, > { self.inner_round().finalize(rng, payloads, artifacts) } @@ -62,12 +62,16 @@ pub trait RoundOverride: RoundWrapper { #[macro_export] macro_rules! round_override { ($round: ident) => { - impl Round for $round + impl Round for $round where - $round: RoundOverride, + $round: RoundOverride, + S: $crate::protocol::Serializer, { type Protocol = - <<$round as $crate::testing::RoundWrapper>::InnerRound as $crate::protocol::Round>::Protocol; + <<$round as $crate::testing::RoundWrapper>::InnerRound as $crate::protocol::Round< + Id, + S, + >>::Protocol; fn id(&self) -> $crate::protocol::RoundId { self.inner_round_ref().id() @@ -85,15 +89,16 @@ macro_rules! round_override { &self, rng: &mut impl CryptoRngCore, destination: &Id, - ) -> Result<($crate::protocol::DirectMessage, $crate::protocol::Artifact), $crate::protocol::LocalError> { - >::make_direct_message(self, rng, destination) + ) -> Result<($crate::protocol::DirectMessage, $crate::protocol::Artifact), $crate::protocol::LocalError> + { + >::make_direct_message(self, rng, destination) } fn make_echo_broadcast( &self, rng: &mut impl CryptoRngCore, ) -> Option> { - >::make_echo_broadcast(self, rng) + >::make_echo_broadcast(self, rng) } fn receive_message( @@ -113,10 +118,10 @@ macro_rules! round_override { payloads: ::alloc::collections::BTreeMap, artifacts: ::alloc::collections::BTreeMap, ) -> Result< - $crate::protocol::FinalizeOutcome, - $crate::protocol::FinalizeError + $crate::protocol::FinalizeOutcome, + $crate::protocol::FinalizeError, > { - >::finalize(self, rng, payloads, artifacts) + >::finalize(self, rng, payloads, artifacts) } fn expecting_messages_from(&self) -> &BTreeSet { diff --git a/manul/src/testing/run_sync.rs b/manul/src/testing/run_sync.rs index 691b31c..8c40405 100644 --- a/manul/src/testing/run_sync.rs +++ b/manul/src/testing/run_sync.rs @@ -8,15 +8,15 @@ use signature::{DigestVerifier, Keypair, RandomizedDigestSigner}; use tracing::debug; use crate::{ - protocol::{FirstRound, Protocol}, + protocol::{FirstRound, Protocol, Serializer}, session::{ CanFinalize, LocalError, MessageBundle, RoundAccumulator, RoundOutcome, Session, SessionId, SessionReport, }, }; -enum State { +enum State { InProgress { - session: Session, + session: Session, accum: RoundAccumulator, }, Finished(SessionReport), @@ -29,25 +29,26 @@ struct Message { } #[allow(clippy::type_complexity)] -fn propagate( +fn propagate( rng: &mut impl CryptoRngCore, - session: Session, + session: Session, accum: RoundAccumulator, -) -> Result<(State, Vec>), LocalError> +) -> Result<(State, Vec>), LocalError> where - P: Protocol + 'static, + P: 'static + Protocol, + S: 'static + Serializer, Signer: RandomizedDigestSigner + Keypair, - Verifier: Debug + Verifier: 'static + + Debug + Clone + Eq + Ord + DigestVerifier - + 'static + Serialize + for<'de> Deserialize<'de> + Send + Sync, - Sig: Debug + Clone + Eq + 'static + Serialize + for<'de> Deserialize<'de> + Send + Sync, + Sig: 'static + Debug + Clone + Eq + Serialize + for<'de> Deserialize<'de> + Send + Sync, { let mut messages = Vec::new(); @@ -101,24 +102,25 @@ where /// Execute sessions for multiple nodes concurrently, given the the inputs /// for the first round `R` and the signer for each node. #[allow(clippy::type_complexity)] -pub fn run_sync( +pub fn run_sync( rng: &mut impl CryptoRngCore, inputs: Vec<(Signer, R::Inputs)>, ) -> Result>, LocalError> where - R: FirstRound + 'static, + R: 'static + FirstRound, + S: 'static + Serializer, Signer: RandomizedDigestSigner<::Digest, Sig> + Keypair, - Verifier: Debug + Verifier: 'static + + Debug + Clone + Eq + Ord + DigestVerifier<::Digest, Sig> - + 'static + Serialize + for<'de> Deserialize<'de> + Send + Sync, - Sig: Debug + Clone + Eq + 'static + Serialize + for<'de> Deserialize<'de> + Send + Sync, + Sig: 'static + Debug + Clone + Eq + Serialize + for<'de> Deserialize<'de> + Send + Sync, { let session_id = SessionId::random(rng); @@ -127,7 +129,8 @@ where for (signer, inputs) in inputs { let verifier = signer.verifying_key(); - let session = Session::::new::(rng, session_id.clone(), signer, inputs)?; + let session = + Session::::new::(rng, session_id.clone(), signer, inputs)?; let mut accum = session.make_accumulator(); let destinations = session.message_destinations();