Skip to content

Commit

Permalink
Supply the format from the session level
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Oct 21, 2024
1 parent 5cfc6c1 commit b486c6e
Show file tree
Hide file tree
Showing 17 changed files with 301 additions and 177 deletions.
20 changes: 20 additions & 0 deletions examples/src/format.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use manul::{
protocol::{DeserializationError, LocalError},
session::Format,
};
use serde::{Deserialize, Serialize};

pub struct Bincode;

impl Format for Bincode {
fn serialize<T: Serialize>(value: T) -> Result<Box<[u8]>, LocalError> {
bincode::serde::encode_to_vec(value, bincode::config::standard())
.map(|vec| vec.into())
.map_err(|err| LocalError::new(err.to_string()))
}

fn deserialize<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> Result<T, DeserializationError> {
bincode::serde::decode_borrowed_from_slice(bytes, bincode::config::standard())
.map_err(|err| DeserializationError::new(err.to_string()))
}
}
3 changes: 3 additions & 0 deletions examples/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
extern crate alloc;

mod format;
pub mod simple;

#[cfg(test)]
mod simple_malicious;

pub use format::Bincode;
52 changes: 28 additions & 24 deletions examples/src/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ impl ProtocolError for SimpleProtocolError {

fn verify_messages_constitute_error(
&self,
deserializer: &Deserializer,
_echo_broadcast: &Option<EchoBroadcast>,
direct_message: &DirectMessage,
_echo_broadcasts: &BTreeMap<RoundId, EchoBroadcast>,
Expand All @@ -44,20 +45,20 @@ impl ProtocolError for SimpleProtocolError {
) -> Result<(), ProtocolValidationError> {
match self {
SimpleProtocolError::Round1InvalidPosition => {
let _message = direct_message.deserialize::<SimpleProtocol, Round1Message>()?;
let _message = direct_message.deserialize::<Round1Message>(deserializer)?;
// Message contents would be checked here
Ok(())
}
SimpleProtocolError::Round2InvalidPosition => {
let _r1_message = direct_message.deserialize::<SimpleProtocol, Round1Message>()?;
let _r1_message = direct_message.deserialize::<Round1Message>(deserializer)?;
let r1_echos_serialized = combined_echos
.get(&RoundId::new(1))
.ok_or_else(|| LocalError::new("Could not find combined echos for Round 1"))?;

// Deserialize the echos
let _r1_echos = r1_echos_serialized
.iter()
.map(|echo| echo.deserialize::<SimpleProtocol, Round1Echo>())
.map(|echo| echo.deserialize::<Round1Echo>(deserializer))
.collect::<Result<Vec<_>, _>>()?;

// Message contents would be checked here
Expand All @@ -72,23 +73,13 @@ impl Protocol for SimpleProtocol {
type ProtocolError = SimpleProtocolError;
type CorrectnessProof = ();

fn serialize<T: Serialize>(value: T) -> Result<Box<[u8]>, LocalError> {
bincode::serde::encode_to_vec(value, bincode::config::standard())
.map(|vec| vec.into())
.map_err(|err| LocalError::new(err.to_string()))
}

fn deserialize<'de, T: Deserialize<'de>>(bytes: &'de [u8]) -> Result<T, DeserializationError> {
bincode::serde::decode_borrowed_from_slice(bytes, bincode::config::standard())
.map_err(|err| DeserializationError::new(err.to_string()))
}

fn verify_direct_message_is_invalid(
deserializer: &Deserializer,
round_id: RoundId,
message: &DirectMessage,
) -> Result<(), MessageValidationError> {
if round_id == RoundId::new(1) {
return message.verify_is_invalid::<Self, Round1Message>();
return message.verify_is_invalid::<Round1Message>(deserializer);
}
Err(MessageValidationError::InvalidEvidence("Invalid round number".into()))?
}
Expand Down Expand Up @@ -169,19 +160,24 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round1<Id> {
&self.context.other_ids
}

fn make_echo_broadcast(&self, _rng: &mut impl CryptoRngCore) -> Option<Result<EchoBroadcast, LocalError>> {
fn make_echo_broadcast(
&self,
_rng: &mut impl CryptoRngCore,
serializer: &Serializer,
) -> Option<Result<EchoBroadcast, LocalError>> {
debug!("{:?}: making echo broadcast", self.context.id);

let message = Round1Echo {
my_position: self.context.ids_to_positions[&self.context.id],
};

Some(Self::serialize_echo_broadcast(message))
Some(EchoBroadcast::new(serializer, message))
}

fn make_direct_message(
&self,
_rng: &mut impl CryptoRngCore,
serializer: &Serializer,
destination: &Id,
) -> Result<(DirectMessage, Artifact), LocalError> {
debug!("{:?}: making direct message for {:?}", self.context.id, destination);
Expand All @@ -190,21 +186,22 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round1<Id> {
my_position: self.context.ids_to_positions[&self.context.id],
your_position: self.context.ids_to_positions[destination],
};
let dm = Self::serialize_direct_message(message)?;
let dm = DirectMessage::new(serializer, message)?;
let artifact = Artifact::empty();
Ok((dm, artifact))
}

fn receive_message(
&self,
_rng: &mut impl CryptoRngCore,
deserializer: &Deserializer,
from: &Id,
_echo_broadcast: Option<EchoBroadcast>,
direct_message: DirectMessage,
) -> Result<Payload, ReceiveError<Id, Self::Protocol>> {
debug!("{:?}: receiving message from {:?}", self.context.id, from);

let message = direct_message.deserialize::<SimpleProtocol, Round1Message>()?;
let message = direct_message.deserialize::<Round1Message>(deserializer)?;

debug!("{:?}: received message: {:?}", self.context.id, message);

Expand Down Expand Up @@ -272,19 +269,24 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round2<Id> {
&self.context.other_ids
}

fn make_echo_broadcast(&self, _rng: &mut impl CryptoRngCore) -> Option<Result<EchoBroadcast, LocalError>> {
fn make_echo_broadcast(
&self,
_rng: &mut impl CryptoRngCore,
serializer: &Serializer,
) -> Option<Result<EchoBroadcast, LocalError>> {
debug!("{:?}: making echo broadcast", self.context.id);

let message = Round1Echo {
my_position: self.context.ids_to_positions[&self.context.id],
};

Some(Self::serialize_echo_broadcast(message))
Some(EchoBroadcast::new(serializer, message))
}

fn make_direct_message(
&self,
_rng: &mut impl CryptoRngCore,
serializer: &Serializer,
destination: &Id,
) -> Result<(DirectMessage, Artifact), LocalError> {
debug!("{:?}: making direct message for {:?}", self.context.id, destination);
Expand All @@ -293,21 +295,22 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> Round<Id> for Round2<Id> {
my_position: self.context.ids_to_positions[&self.context.id],
your_position: self.context.ids_to_positions[destination],
};
let dm = Self::serialize_direct_message(message)?;
let dm = DirectMessage::new(serializer, message)?;
let artifact = Artifact::empty();
Ok((dm, artifact))
}

fn receive_message(
&self,
_rng: &mut impl CryptoRngCore,
deserializer: &Deserializer,
from: &Id,
_echo_broadcast: Option<EchoBroadcast>,
direct_message: DirectMessage,
) -> Result<Payload, ReceiveError<Id, Self::Protocol>> {
debug!("{:?}: receiving message from {:?}", self.context.id, from);

let message = direct_message.deserialize::<SimpleProtocol, Round1Message>()?;
let message = direct_message.deserialize::<Round1Message>(deserializer)?;

debug!("{:?}: received message: {:?}", self.context.id, message);

Expand Down Expand Up @@ -362,6 +365,7 @@ mod tests {
use tracing_subscriber::EnvFilter;

use super::{Inputs, Round1};
use crate::Bincode;

#[test]
fn round() {
Expand All @@ -386,7 +390,7 @@ mod tests {
.with_env_filter(EnvFilter::from_default_env())
.finish();
let reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<Round1<Verifier>, TestingSessionParams>(&mut OsRng, inputs).unwrap()
run_sync::<Round1<Verifier>, TestingSessionParams<Bincode>>(&mut OsRng, inputs).unwrap()
});

for (_id, report) in reports {
Expand Down
25 changes: 15 additions & 10 deletions examples/src/simple_malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ use manul::{
protocol::{
Artifact, DirectMessage, FinalizeError, FinalizeOutcome, FirstRound, LocalError, Payload, Round, SessionId,
},
session::signature::Keypair,
session::{signature::Keypair, Serializer},
testing::{round_override, run_sync, RoundOverride, RoundWrapper, Signer, TestingSessionParams, Verifier},
};
use rand_core::{CryptoRngCore, OsRng};
use tracing_subscriber::EnvFilter;

use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message};
use crate::{
simple::{Inputs, Round1, Round1Message, Round2, Round2Message},
Bincode,
};

#[derive(Debug, Clone, Copy)]
enum Behavior {
Expand Down Expand Up @@ -61,20 +64,21 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> RoundOverride<Id> for Mali
fn make_direct_message(
&self,
rng: &mut impl CryptoRngCore,
serializer: &Serializer,
destination: &Id,
) -> Result<(DirectMessage, Artifact), LocalError> {
if matches!(self.behavior, Behavior::SerializedGarbage) {
let dm = DirectMessage::new::<<Self::InnerRound as Round<Id>>::Protocol, _>(&[99u8]).unwrap();
let dm = DirectMessage::new(serializer, [99u8]).unwrap();
Ok((dm, Artifact::empty()))
} else if matches!(self.behavior, Behavior::AttributableFailure) {
let message = Round1Message {
my_position: self.round.context.ids_to_positions[&self.round.context.id],
your_position: self.round.context.ids_to_positions[&self.round.context.id],
};
let dm = DirectMessage::new::<<Self::InnerRound as Round<Id>>::Protocol, _>(&message)?;
let dm = DirectMessage::new(serializer, &message)?;
Ok((dm, Artifact::empty()))
} else {
self.inner_round_ref().make_direct_message(rng, destination)
self.inner_round_ref().make_direct_message(rng, serializer, destination)
}
}

Expand Down Expand Up @@ -124,17 +128,18 @@ impl<Id: 'static + Debug + Clone + Ord + Send + Sync> RoundOverride<Id> for Mali
fn make_direct_message(
&self,
rng: &mut impl CryptoRngCore,
serializer: &Serializer,
destination: &Id,
) -> Result<(DirectMessage, Artifact), LocalError> {
if matches!(self.behavior, Behavior::AttributableFailureRound2) {
let message = Round2Message {
my_position: self.round.context.ids_to_positions[&self.round.context.id],
your_position: self.round.context.ids_to_positions[&self.round.context.id],
};
let dm = DirectMessage::new::<<Self::InnerRound as Round<Id>>::Protocol, _>(&message)?;
let dm = DirectMessage::new(serializer, &message)?;
Ok((dm, Artifact::empty()))
} else {
self.inner_round_ref().make_direct_message(rng, destination)
self.inner_round_ref().make_direct_message(rng, serializer, destination)
}
}
}
Expand Down Expand Up @@ -172,7 +177,7 @@ fn serialized_garbage() {
.with_env_filter(EnvFilter::from_default_env())
.finish();
let mut reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<MaliciousRound1<Verifier>, TestingSessionParams>(&mut OsRng, run_inputs).unwrap()
run_sync::<MaliciousRound1<Verifier>, TestingSessionParams<Bincode>>(&mut OsRng, run_inputs).unwrap()
});

let v0 = signers[0].verifying_key();
Expand Down Expand Up @@ -218,7 +223,7 @@ fn attributable_failure() {
.with_env_filter(EnvFilter::from_default_env())
.finish();
let mut reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<MaliciousRound1<Verifier>, TestingSessionParams>(&mut OsRng, run_inputs).unwrap()
run_sync::<MaliciousRound1<Verifier>, TestingSessionParams<Bincode>>(&mut OsRng, run_inputs).unwrap()
});

let v0 = signers[0].verifying_key();
Expand Down Expand Up @@ -264,7 +269,7 @@ fn attributable_failure_round2() {
.with_env_filter(EnvFilter::from_default_env())
.finish();
let mut reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<MaliciousRound1<Verifier>, TestingSessionParams>(&mut OsRng, run_inputs).unwrap()
run_sync::<MaliciousRound1<Verifier>, TestingSessionParams<Bincode>>(&mut OsRng, run_inputs).unwrap()
});

let v0 = signers[0].verifying_key();
Expand Down
14 changes: 11 additions & 3 deletions examples/tests/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ use manul::{
},
testing::{Signer, TestingSessionParams, Verifier},
};
use manul_example::simple::{Inputs, Round1};
use manul_example::{
simple::{Inputs, Round1},
Bincode,
};
use rand::Rng;
use rand_core::OsRng;
use tokio::{
Expand Down Expand Up @@ -231,8 +234,13 @@ async fn async_run() {
let inputs = Inputs {
all_ids: all_ids.clone(),
};
Session::<_, TestingSessionParams>::new::<Round1<Verifier>>(&mut OsRng, session_id.clone(), signer, inputs)
.unwrap()
Session::<_, TestingSessionParams<Bincode>>::new::<Round1<Verifier>>(
&mut OsRng,
session_id.clone(),
signer,
inputs,
)
.unwrap()
})
.collect::<Vec<_>>();

Expand Down
1 change: 1 addition & 0 deletions manul/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ rand_core = { version = "0.6.4", default-features = false }
tracing = { version = "0.1", default-features = false }
displaydoc = { version = "0.2", default-features = false }
rand = { version = "0.8", default-features = false, optional = true }
bincode = { version = "2.0.0-rc.3", default-features = false, features = ["alloc", "serde"] }

[dev-dependencies]
impls = "1"
Expand Down
Loading

0 comments on commit b486c6e

Please sign in to comment.