Skip to content

Commit

Permalink
Make EntryPoints stateful
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Nov 9, 2024
1 parent fbbcb8e commit 041ce21
Show file tree
Hide file tree
Showing 10 changed files with 300 additions and 278 deletions.
48 changes: 26 additions & 22 deletions examples/src/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use rand_core::CryptoRngCore;
use serde::{Deserialize, Serialize};
use tracing::debug;

#[derive(Debug)]
pub struct SimpleProtocol;

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -111,11 +112,6 @@ impl Protocol for SimpleProtocol {
}
}

#[derive(Debug, Clone)]
pub struct Inputs<Id> {
pub all_ids: BTreeSet<Id>,
}

#[derive(Debug)]
pub(crate) struct Context<Id> {
pub(crate) id: Id,
Expand Down Expand Up @@ -149,30 +145,40 @@ struct Round1Payload {
x: u8,
}

impl<Id: PartyId> EntryPoint<Id> for Round1<Id> {
type Inputs = Inputs<Id>;
#[derive(Debug, Clone)]
pub struct SimpleProtocolEntryPoint<Id> {
my_id: Id,
all_ids: BTreeSet<Id>,
}

impl<Id: PartyId> SimpleProtocolEntryPoint<Id> {
pub fn new(my_id: Id, all_ids: BTreeSet<Id>) -> Self {
Self { my_id, all_ids }
}
}

impl<Id: PartyId> EntryPoint<Id> for SimpleProtocolEntryPoint<Id> {
type Protocol = SimpleProtocol;
fn new(
fn make_round(
self,
_rng: &mut impl CryptoRngCore,
_shared_randomness: &[u8],
id: Id,
inputs: Self::Inputs,
) -> Result<BoxedRound<Id, Self::Protocol>, LocalError> {
// Just some numbers associated with IDs to use in the dummy protocol.
// They will be the same on each node since IDs are ordered.
let ids_to_positions = inputs
let ids_to_positions = self
.all_ids
.iter()
.enumerate()
.map(|(idx, id)| (id.clone(), idx as u8))
.collect::<BTreeMap<_, _>>();

let mut ids = inputs.all_ids;
ids.remove(&id);
let mut ids = self.all_ids;
ids.remove(&self.my_id);

Ok(BoxedRound::new_dynamic(Self {
Ok(BoxedRound::new_dynamic(Round1 {
context: Context {
id,
id: self.my_id,
other_ids: ids,
ids_to_positions,
},
Expand Down Expand Up @@ -401,12 +407,12 @@ mod tests {

use manul::{
session::{signature::Keypair, SessionOutcome},
testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier},
testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner},
};
use rand_core::OsRng;
use tracing_subscriber::EnvFilter;

use super::{Inputs, Round1};
use super::SimpleProtocolEntryPoint;

#[test]
fn round() {
Expand All @@ -415,14 +421,12 @@ mod tests {
.iter()
.map(|signer| signer.verifying_key())
.collect::<BTreeSet<_>>();
let inputs = signers
let entry_points = signers
.into_iter()
.map(|signer| {
(
signer,
Inputs {
all_ids: all_ids.clone(),
},
SimpleProtocolEntryPoint::new(signer.verifying_key(), all_ids.clone()),
)
})
.collect::<Vec<_>>();
Expand All @@ -431,7 +435,7 @@ mod tests {
.with_env_filter(EnvFilter::from_default_env())
.finish();
let reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<Round1<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, inputs).unwrap()
run_sync::<_, TestSessionParams<BinaryFormat>>(&mut OsRng, entry_points).unwrap()
});

for (_id, report) in reports {
Expand Down
87 changes: 60 additions & 27 deletions examples/src/simple_chain.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,87 @@
use alloc::collections::BTreeSet;
use core::fmt::Debug;
use rand_core::CryptoRngCore;

use manul::{
combinators::chain::{Chained, ChainedEntryPoint},
protocol::PartyId,
combinators::chain::*,
protocol::{BoxedRound, EntryPoint, LocalError, PartyId, Protocol, RoundId},
};

use super::simple::{Inputs, Round1};
use super::simple::{SimpleProtocol, SimpleProtocolEntryPoint};

pub struct ChainedSimple;
pub type DoubleSimpleProtocol = ChainedProtocol<SimpleProtocol, SimpleProtocol>;

#[derive(Debug)]
pub struct NewInputs<Id>(Inputs<Id>);
pub struct DoubleSimpleEntryPoint<Id> {
my_id: Id,
all_ids: BTreeSet<Id>,
}

impl<'a, Id: PartyId> From<&'a NewInputs<Id>> for Inputs<Id> {
fn from(source: &'a NewInputs<Id>) -> Self {
source.0.clone()
impl<Id: PartyId> DoubleSimpleEntryPoint<Id> {
pub fn new(my_id: Id, all_ids: BTreeSet<Id>) -> Self {
Self { my_id, all_ids }
}
}

impl<Id: PartyId> From<(NewInputs<Id>, u8)> for Inputs<Id> {
fn from(source: (NewInputs<Id>, u8)) -> Self {
let (inputs, _result) = source;
inputs.0
impl<Id> ChainedSplit<Id, SimpleProtocol, SimpleProtocol> for DoubleSimpleEntryPoint<Id>
where
Id: PartyId,
{
type EntryPoint = SimpleProtocolEntryPoint<Id>;
fn make_entry_point1(self) -> (Self::EntryPoint, impl ChainedJoin<Id, SimpleProtocol, SimpleProtocol>) {
(
SimpleProtocolEntryPoint::new(self.my_id.clone(), self.all_ids.clone()),
DoubleTransition {
my_id: self.my_id,
all_ids: self.all_ids,
},
)
}
}

impl<Id: PartyId> Chained<Id> for ChainedSimple {
type Inputs = NewInputs<Id>;
type EntryPoint1 = Round1<Id>;
type EntryPoint2 = Round1<Id>;
#[derive(Debug)]
struct DoubleTransition<Id> {
my_id: Id,
all_ids: BTreeSet<Id>,
}

impl<Id> ChainedJoin<Id, SimpleProtocol, SimpleProtocol> for DoubleTransition<Id>
where
Id: PartyId,
{
type EntryPoint = SimpleProtocolEntryPoint<Id>;
fn make_entry_point2(self, _result: <SimpleProtocol as Protocol>::Result) -> Self::EntryPoint {
SimpleProtocolEntryPoint::new(self.my_id, self.all_ids)
}
}

pub type DoubleSimpleEntryPoint<Id> = ChainedEntryPoint<Id, ChainedSimple>;
impl<Id: PartyId> EntryPoint<Id> for DoubleSimpleEntryPoint<Id> {
type Protocol = DoubleSimpleProtocol;

fn entry_round() -> RoundId {
<Self as ChainedSplit<Id, SimpleProtocol, SimpleProtocol>>::entry_round_id()
}

fn make_round(
self,
rng: &mut impl CryptoRngCore,
shared_randomness: &[u8],
) -> Result<BoxedRound<Id, Self::Protocol>, LocalError> {
make_chained_round(self, rng, shared_randomness)
}
}

#[cfg(test)]
mod tests {
use alloc::collections::BTreeSet;

use manul::{
session::{signature::Keypair, SessionOutcome},
testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier},
testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner},
};
use rand_core::OsRng;
use tracing_subscriber::EnvFilter;

use super::{DoubleSimpleEntryPoint, NewInputs};
use crate::simple::Inputs;
use super::DoubleSimpleEntryPoint;

#[test]
fn round() {
Expand All @@ -54,14 +90,12 @@ mod tests {
.iter()
.map(|signer| signer.verifying_key())
.collect::<BTreeSet<_>>();
let inputs = signers
let entry_points = signers
.into_iter()
.map(|signer| {
(
signer,
NewInputs(Inputs {
all_ids: all_ids.clone(),
}),
DoubleSimpleEntryPoint::new(signer.verifying_key(), all_ids.clone()),
)
})
.collect::<Vec<_>>();
Expand All @@ -70,8 +104,7 @@ mod tests {
.with_env_filter(EnvFilter::from_default_env())
.finish();
let reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<DoubleSimpleEntryPoint<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, inputs)
.unwrap()
run_sync::<_, TestSessionParams<BinaryFormat>>(&mut OsRng, entry_points).unwrap()
});

for (_id, report) in reports {
Expand Down
47 changes: 22 additions & 25 deletions examples/src/simple_malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,18 @@ use alloc::collections::BTreeSet;
use core::fmt::Debug;

use manul::{
combinators::misbehave::{Misbehaving, MisbehavingEntryPoint, MisbehavingInputs},
combinators::misbehave::{Misbehaving, MisbehavingEntryPoint},
protocol::{
Artifact, BoxedRound, Deserializer, DirectMessage, EntryPoint, LocalError, PartyId, ProtocolMessagePart,
RoundId, Serializer,
},
session::signature::Keypair,
testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier},
testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner},
};
use rand_core::{CryptoRngCore, OsRng};
use tracing_subscriber::EnvFilter;

use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message};
use crate::simple::{Round1, Round1Message, Round2, Round2Message, SimpleProtocolEntryPoint};

#[derive(Debug, Clone, Copy)]
enum Behavior {
Expand All @@ -25,7 +25,7 @@ enum Behavior {
struct MaliciousLogic;

impl<Id: PartyId> Misbehaving<Id, Behavior> for MaliciousLogic {
type EntryPoint = Round1<Id>;
type EntryPoint = SimpleProtocolEntryPoint<Id>;

fn modify_direct_message(
_rng: &mut impl CryptoRngCore,
Expand Down Expand Up @@ -78,9 +78,8 @@ fn serialized_garbage() {
.iter()
.map(|signer| signer.verifying_key())
.collect::<BTreeSet<_>>();
let inputs = Inputs { all_ids };

let run_inputs = signers
let entry_points = signers
.iter()
.enumerate()
.map(|(idx, signer)| {
Expand All @@ -90,19 +89,19 @@ fn serialized_garbage() {
None
};

let malicious_inputs = MisbehavingInputs {
inner_inputs: inputs.clone(),
let entry_point = MaliciousEntryPoint::new(
SimpleProtocolEntryPoint::new(signer.verifying_key(), all_ids.clone()),
behavior,
};
(*signer, malicious_inputs)
);
(*signer, entry_point)
})
.collect::<Vec<_>>();

let my_subscriber = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.finish();
let mut reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<MaliciousEntryPoint<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, run_inputs).unwrap()
run_sync::<_, TestSessionParams<BinaryFormat>>(&mut OsRng, entry_points).unwrap()
});

let v0 = signers[0].verifying_key();
Expand All @@ -124,9 +123,8 @@ fn attributable_failure() {
.iter()
.map(|signer| signer.verifying_key())
.collect::<BTreeSet<_>>();
let inputs = Inputs { all_ids };

let run_inputs = signers
let entry_points = signers
.iter()
.enumerate()
.map(|(idx, signer)| {
Expand All @@ -136,19 +134,19 @@ fn attributable_failure() {
None
};

let malicious_inputs = MisbehavingInputs {
inner_inputs: inputs.clone(),
let entry_point = MaliciousEntryPoint::new(
SimpleProtocolEntryPoint::new(signer.verifying_key(), all_ids.clone()),
behavior,
};
(*signer, malicious_inputs)
);
(*signer, entry_point)
})
.collect::<Vec<_>>();

let my_subscriber = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.finish();
let mut reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<MaliciousEntryPoint<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, run_inputs).unwrap()
run_sync::<_, TestSessionParams<BinaryFormat>>(&mut OsRng, entry_points).unwrap()
});

let v0 = signers[0].verifying_key();
Expand All @@ -170,9 +168,8 @@ fn attributable_failure_round2() {
.iter()
.map(|signer| signer.verifying_key())
.collect::<BTreeSet<_>>();
let inputs = Inputs { all_ids };

let run_inputs = signers
let entry_points = signers
.iter()
.enumerate()
.map(|(idx, signer)| {
Expand All @@ -182,19 +179,19 @@ fn attributable_failure_round2() {
None
};

let malicious_inputs = MisbehavingInputs {
inner_inputs: inputs.clone(),
let entry_point = MaliciousEntryPoint::new(
SimpleProtocolEntryPoint::new(signer.verifying_key(), all_ids.clone()),
behavior,
};
(*signer, malicious_inputs)
);
(*signer, entry_point)
})
.collect::<Vec<_>>();

let my_subscriber = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.finish();
let mut reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<MaliciousEntryPoint<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, run_inputs).unwrap()
run_sync::<_, TestSessionParams<BinaryFormat>>(&mut OsRng, entry_points).unwrap()
});

let v0 = signers[0].verifying_key();
Expand Down
Loading

0 comments on commit 041ce21

Please sign in to comment.