Skip to content

Commit

Permalink
Switch to supplying a (de)serializer externally
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Oct 19, 2024
1 parent 9cadf8b commit ddd45da
Show file tree
Hide file tree
Showing 16 changed files with 392 additions and 257 deletions.
3 changes: 3 additions & 0 deletions example/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extern crate alloc;

mod serializer;
pub mod simple;

#[cfg(test)]
mod simple_malicious;

pub use serializer::BincodeSerializer;
17 changes: 17 additions & 0 deletions example/src/serializer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
use manul::protocol::{DeserializationError, LocalError, Serializer};
use serde::{Deserialize, Serialize};

pub struct BincodeSerializer;

impl Serializer for BincodeSerializer {
fn serialize<T: Serialize>(value: T) -> Result<Box<[u8]>, LocalError> {
bincode::serde::encode_to_vec(value, bincode::config::standard())
.map(|vec| vec.into())
.map_err(|err| LocalError::new(err.to_string()))
}

fn deserialize<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> Result<T, DeserializationError> {
bincode::serde::decode_borrowed_from_slice(bytes, bincode::config::standard())
.map_err(|err| DeserializationError::new(err.to_string()))
}
}
65 changes: 37 additions & 28 deletions example/src/simple.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use alloc::collections::{BTreeMap, BTreeSet};
use core::fmt::Debug;
use core::{fmt::Debug, marker::PhantomData};

use manul::protocol::*;
use rand_core::CryptoRngCore;
Expand Down Expand Up @@ -35,7 +35,7 @@ impl ProtocolError for SimpleProtocolError {
}
}

fn verify_messages_constitute_error(
fn verify_messages_constitute_error<S: Serializer>(
&self,
_echo_broadcast: &Option<EchoBroadcast>,
direct_message: &DirectMessage,
Expand All @@ -45,20 +45,20 @@ impl ProtocolError for SimpleProtocolError {
) -> Result<(), ProtocolValidationError> {
match self {
SimpleProtocolError::Round1InvalidPosition => {
let _message = direct_message.deserialize::<SimpleProtocol, Round1Message>()?;
let _message = direct_message.deserialize::<S, Round1Message>()?;
// Message contents would be checked here
Ok(())
}
SimpleProtocolError::Round2InvalidPosition => {
let _r1_message = direct_message.deserialize::<SimpleProtocol, Round1Message>()?;
let _r1_message = direct_message.deserialize::<S, Round1Message>()?;
let r1_echos_serialized = combined_echos
.get(&RoundId::new(1))
.ok_or_else(|| LocalError::new("Could not find combined echos for Round 1"))?;

// Deserialize the echos
let _r1_echos = r1_echos_serialized
.iter()
.map(|echo| echo.deserialize::<SimpleProtocol, Round1Echo>())
.map(|echo| echo.deserialize::<S, Round1Echo>())
.collect::<Result<Vec<_>, _>>()?;

// Message contents would be checked here
Expand All @@ -75,23 +75,12 @@ impl Protocol for SimpleProtocol {

type Digest = Sha3_256;

fn serialize<T: Serialize>(value: T) -> Result<Box<[u8]>, LocalError> {
bincode::serde::encode_to_vec(value, bincode::config::standard())
.map(|vec| vec.into())
.map_err(|err| LocalError::new(err.to_string()))
}

fn deserialize<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> Result<T, DeserializationError> {
bincode::serde::decode_borrowed_from_slice(bytes, bincode::config::standard())
.map_err(|err| DeserializationError::new(err.to_string()))
}

fn verify_direct_message_is_invalid(
fn verify_direct_message_is_invalid<S: Serializer>(
round_id: RoundId,
message: &DirectMessage,
) -> Result<(), MessageValidationError> {
if round_id == RoundId::new(1) {
return message.verify_is_invalid::<Self, Round1Message>();
return message.verify_is_invalid::<S, Round1Message>();
}
Err(MessageValidationError::InvalidEvidence("Invalid round number".into()))?
}
Expand All @@ -108,8 +97,9 @@ pub(crate) struct Context<Id> {
pub(crate) ids_to_positions: BTreeMap<Id, u8>,
}

pub struct Round1<Id> {
pub struct Round1<Id, S> {
pub(crate) context: Context<Id>,
phantom: PhantomData<fn(S) -> S>,
}

#[derive(Debug, Serialize, Deserialize)]
Expand All @@ -127,7 +117,11 @@ struct Round1Payload {
x: u8,
}

impl<Id: 'static + Debug + Clone + Ord + Send + Sync> FirstRound<Id> for Round1<Id> {
impl<Id, S> FirstRound<Id, S> for Round1<Id, S>
where
Id: 'static + Debug + Clone + Ord + Send + Sync,
S: 'static + Serializer,
{
type Inputs = Inputs<Id>;
fn new(
_rng: &mut impl CryptoRngCore,
Expand All @@ -153,11 +147,16 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> FirstRound<Id> for Round1<
other_ids: ids,
ids_to_positions,
},
phantom: PhantomData,
})
}
}

impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round1<Id> {
impl<Id, S> Round<Id, S> for Round1<Id, S>
where
Id: 'static + Debug + Clone + Ord + Send + Sync,
S: 'static + Serializer,
{
type Protocol = SimpleProtocol;

fn id(&self) -> RoundId {
Expand Down Expand Up @@ -207,7 +206,7 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round1<Id> {
) -> Result<Payload, ReceiveError<Id, Self::Protocol>> {
debug!("{:?}: receiving message from {:?}", self.context.id, from);

let message = direct_message.deserialize::<SimpleProtocol, Round1Message>()?;
let message = direct_message.deserialize::<S, Round1Message>()?;

debug!("{:?}: received message: {:?}", self.context.id, message);

Expand All @@ -223,7 +222,7 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round1<Id> {
_rng: &mut impl CryptoRngCore,
payloads: BTreeMap<Id, Payload>,
_artifacts: BTreeMap<Id, Artifact>,
) -> Result<FinalizeOutcome<Id, Self::Protocol>, FinalizeError<Self::Protocol>> {
) -> Result<FinalizeOutcome<Id, Self::Protocol, S>, FinalizeError<Self::Protocol>> {
debug!(
"{:?}: finalizing with messages from {:?}",
self.context.id,
Expand All @@ -240,6 +239,7 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round1<Id> {
let round2 = Round2 {
round1_sum: sum,
context: self.context,
phantom: PhantomData,
};
Ok(FinalizeOutcome::another_round(round2))
}
Expand All @@ -249,9 +249,10 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round1<Id> {
}
}

pub(crate) struct Round2<Id> {
pub(crate) struct Round2<Id, S> {
round1_sum: u8,
pub(crate) context: Context<Id>,
phantom: PhantomData<fn(S) -> S>,
}

#[derive(Debug, Serialize, Deserialize)]
Expand All @@ -260,7 +261,11 @@ pub(crate) struct Round2Message {
pub(crate) your_position: u8,
}

impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round2<Id> {
impl<Id, S> Round<Id, S> for Round2<Id, S>
where
Id: 'static + Debug + Clone + Ord + Send + Sync,
S: 'static + Serializer,
{
type Protocol = SimpleProtocol;

fn id(&self) -> RoundId {
Expand Down Expand Up @@ -310,7 +315,7 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round2<Id> {
) -> Result<Payload, ReceiveError<Id, Self::Protocol>> {
debug!("{:?}: receiving message from {:?}", self.context.id, from);

let message = direct_message.deserialize::<SimpleProtocol, Round1Message>()?;
let message = direct_message.deserialize::<S, Round1Message>()?;

debug!("{:?}: received message: {:?}", self.context.id, message);

Expand All @@ -326,7 +331,7 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round2<Id> {
_rng: &mut impl CryptoRngCore,
payloads: BTreeMap<Id, Payload>,
_artifacts: BTreeMap<Id, Artifact>,
) -> Result<FinalizeOutcome<Id, Self::Protocol>, FinalizeError<Self::Protocol>> {
) -> Result<FinalizeOutcome<Id, Self::Protocol, S>, FinalizeError<Self::Protocol>> {
debug!(
"{:?}: finalizing with messages from {:?}",
self.context.id,
Expand Down Expand Up @@ -365,6 +370,7 @@ mod tests {
use tracing_subscriber::EnvFilter;

use super::{Inputs, Round1};
use crate::BincodeSerializer;

#[test]
fn round() {
Expand All @@ -389,7 +395,10 @@ mod tests {
.with_env_filter(EnvFilter::from_default_env())
.finish();
let reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<Round1<Verifier>, Signer, Verifier, Signature>(&mut OsRng, inputs).unwrap()
run_sync::<Round1<Verifier, BincodeSerializer>, BincodeSerializer, Signer, Verifier, Signature>(
&mut OsRng, inputs,
)
.unwrap()
});

for (_id, report) in reports {
Expand Down
Loading

0 comments on commit ddd45da

Please sign in to comment.