Skip to content

Commit

Permalink
Add Misbehave combinator
Browse files Browse the repository at this point in the history
  • Loading branch information
fjarri committed Nov 3, 2024
1 parent aed78d0 commit 104d625
Show file tree
Hide file tree
Showing 10 changed files with 706 additions and 137 deletions.
1 change: 1 addition & 0 deletions examples/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
extern crate alloc;

pub mod simple;
mod simple_chain;

#[cfg(test)]
mod simple_malicious;
86 changes: 86 additions & 0 deletions examples/src/simple_chain.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use core::fmt::Debug;
use core::marker::PhantomData;

use manul::{combinators::*, protocol::PartyId};

use super::simple::{Inputs, Round1, SimpleProtocol};

pub struct ChainedSimple<Id>(PhantomData<Id>);

#[derive(Debug)]
pub struct NewInputs<Id>(Inputs<Id>);

impl<'a, Id: PartyId> From<&'a NewInputs<Id>> for Inputs<Id> {
fn from(source: &'a NewInputs<Id>) -> Self {
source.0.clone()
}
}

impl<Id: PartyId> From<(NewInputs<Id>, u8)> for Inputs<Id> {
fn from(source: (NewInputs<Id>, u8)) -> Self {
source.0 .0
}
}

impl<Id: PartyId> Chained<Id> for ChainedSimple<Id> {
type Protocol1 = SimpleProtocol;
type Protocol2 = SimpleProtocol;
type CorrectnessProof = ();
type Inputs = NewInputs<Id>;
type EntryPoint1 = Round1<Id>;
type EntryPoint2 = Round1<Id>;
}

#[cfg(test)]
mod tests {
use alloc::collections::BTreeSet;

use manul::{
combinators::Chain,
session::{signature::Keypair, SessionOutcome},
testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier},
};
use rand_core::OsRng;
use tracing_subscriber::EnvFilter;

use super::{ChainedSimple, NewInputs};
use crate::simple::Inputs;

#[test]
fn round() {
let signers = (0..3).map(TestSigner::new).collect::<Vec<_>>();
let all_ids = signers
.iter()
.map(|signer| signer.verifying_key())
.collect::<BTreeSet<_>>();
let inputs = signers
.into_iter()
.map(|signer| {
(
signer,
NewInputs(Inputs {
all_ids: all_ids.clone(),
}),
)
})
.collect::<Vec<_>>();

let my_subscriber = tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.finish();
let reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<Chain<TestVerifier, ChainedSimple<TestVerifier>>, TestSessionParams<BinaryFormat>>(
&mut OsRng, inputs,
)
.unwrap()
});

for (_id, report) in reports {
if let SessionOutcome::Result(result) = report.outcome {
assert_eq!(result, 3); // 0 + 1 + 2
} else {
panic!("Session did not finish successfully");
}
}
}
}
194 changes: 59 additions & 135 deletions examples/src/simple_malicious.rs
Original file line number Diff line number Diff line change
@@ -1,150 +1,74 @@
use alloc::collections::{BTreeMap, BTreeSet};
use alloc::collections::BTreeSet;
use core::fmt::Debug;

use manul::{
combinators::{Misbehaving, MisbehavingInputs, MisbehavingRound},
protocol::{
Artifact, DirectMessage, FinalizeError, FinalizeOutcome, FirstRound, LocalError, PartyId, Payload,
ProtocolMessagePart, Round, Serializer,
Artifact, DirectMessage, LocalError, ObjectSafeRound, PartyId, ProtocolMessagePart, RoundId, Serializer,
},
session::signature::Keypair,
testing::{
round_override, run_sync, BinaryFormat, RoundOverride, RoundWrapper, TestSessionParams, TestSigner,
TestVerifier,
},
testing::{run_sync, BinaryFormat, TestSessionParams, TestSigner, TestVerifier},
};
use rand_core::{CryptoRngCore, OsRng};
use tracing_subscriber::EnvFilter;

use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message};
use crate::simple::{Inputs, Round1, Round1Message, Round2, Round2Message, SimpleProtocol};

#[derive(Debug, Clone, Copy)]
enum Behavior {
Lawful,
SerializedGarbage,
AttributableFailure,
AttributableFailureRound2,
}

struct MaliciousInputs<Id> {
inputs: Inputs<Id>,
behavior: Behavior,
}

#[derive(Debug)]
struct MaliciousRound1<Id> {
round: Round1<Id>,
behavior: Behavior,
}

impl<Id: PartyId> RoundWrapper<Id> for MaliciousRound1<Id> {
type InnerRound = Round1<Id>;
fn inner_round_ref(&self) -> &Self::InnerRound {
&self.round
}
fn inner_round(self) -> Self::InnerRound {
self.round
}
}
struct SimpleMaliciousProtocol;

impl<Id: PartyId> FirstRound<Id> for MaliciousRound1<Id> {
type Inputs = MaliciousInputs<Id>;
fn new(
rng: &mut impl CryptoRngCore,
shared_randomness: &[u8],
id: Id,
inputs: Self::Inputs,
) -> Result<Self, LocalError> {
let round = Round1::new(rng, shared_randomness, id, inputs.inputs)?;
Ok(Self {
round,
behavior: inputs.behavior,
})
}
}
impl<Id: PartyId> Misbehaving<Id, Behavior> for SimpleMaliciousProtocol {
type Protocol = SimpleProtocol;
type FirstRound = Round1<Id>;

impl<Id: PartyId> RoundOverride<Id> for MaliciousRound1<Id> {
fn make_direct_message(
&self,
rng: &mut impl CryptoRngCore,
fn amend_direct_message(
_rng: &mut impl CryptoRngCore,
round: &dyn ObjectSafeRound<Id, Protocol = Self::Protocol>,
behavior: &Behavior,
serializer: &Serializer,
destination: &Id,
_destination: &Id,
direct_message: DirectMessage,
artifact: Option<Artifact>,
) -> Result<(DirectMessage, Option<Artifact>), LocalError> {
if matches!(self.behavior, Behavior::SerializedGarbage) {
Ok((DirectMessage::new(serializer, [99u8])?, None))
} else if matches!(self.behavior, Behavior::AttributableFailure) {
let message = Round1Message {
my_position: self.round.context.ids_to_positions[&self.round.context.id],
your_position: self.round.context.ids_to_positions[&self.round.context.id],
};
Ok((DirectMessage::new(serializer, message)?, None))
} else {
self.inner_round_ref().make_direct_message(rng, serializer, destination)
}
}

fn finalize(
self,
rng: &mut impl CryptoRngCore,
payloads: BTreeMap<Id, Payload>,
artifacts: BTreeMap<Id, Artifact>,
) -> Result<
FinalizeOutcome<Id, <<Self as RoundWrapper<Id>>::InnerRound as Round<Id>>::Protocol>,
FinalizeError<<<Self as RoundWrapper<Id>>::InnerRound as Round<Id>>::Protocol>,
> {
let behavior = self.behavior;
let outcome = self.inner_round().finalize(rng, payloads, artifacts)?;

Ok(match outcome {
FinalizeOutcome::Result(res) => FinalizeOutcome::Result(res),
FinalizeOutcome::AnotherRound(another_round) => {
let round2 = another_round.downcast::<Round2<Id>>().map_err(FinalizeError::Local)?;
FinalizeOutcome::another_round(MaliciousRound2 {
round: round2,
behavior,
})
let dm = if round.id() == RoundId::new(1) {
match behavior {
Behavior::SerializedGarbage => DirectMessage::new(serializer, [99u8])?,
Behavior::AttributableFailure => {
let round1 = round.downcast_ref::<Round1<Id>>()?;
let message = Round1Message {
my_position: round1.context.ids_to_positions[&round1.context.id],
your_position: round1.context.ids_to_positions[&round1.context.id],
};
DirectMessage::new(serializer, message)?
}
_ => direct_message,
}
} else if round.id() == RoundId::new(2) {
match behavior {
Behavior::AttributableFailureRound2 => {
let round2 = round.downcast_ref::<Round2<Id>>()?;
let message = Round2Message {
my_position: round2.context.ids_to_positions[&round2.context.id],
your_position: round2.context.ids_to_positions[&round2.context.id],
};
DirectMessage::new(serializer, message)?
}
_ => direct_message,
}
})
}
}

round_override!(MaliciousRound1);

#[derive(Debug)]
struct MaliciousRound2<Id> {
round: Round2<Id>,
behavior: Behavior,
}

impl<Id: PartyId> RoundWrapper<Id> for MaliciousRound2<Id> {
type InnerRound = Round2<Id>;
fn inner_round_ref(&self) -> &Self::InnerRound {
&self.round
}
fn inner_round(self) -> Self::InnerRound {
self.round
}
}

impl<Id: PartyId> RoundOverride<Id> for MaliciousRound2<Id> {
fn make_direct_message(
&self,
rng: &mut impl CryptoRngCore,
serializer: &Serializer,
destination: &Id,
) -> Result<(DirectMessage, Option<Artifact>), LocalError> {
if matches!(self.behavior, Behavior::AttributableFailureRound2) {
let message = Round2Message {
my_position: self.round.context.ids_to_positions[&self.round.context.id],
your_position: self.round.context.ids_to_positions[&self.round.context.id],
};
Ok((DirectMessage::new(serializer, message)?, None))
} else {
self.inner_round_ref().make_direct_message(rng, serializer, destination)
}
direct_message
};
Ok((dm, artifact))
}
}

round_override!(MaliciousRound2);
type MaliciousEntryPoint<Id> = MisbehavingRound<Id, Behavior, SimpleMaliciousProtocol>;

#[test]
fn serialized_garbage() {
Expand All @@ -160,13 +84,13 @@ fn serialized_garbage() {
.enumerate()
.map(|(idx, signer)| {
let behavior = if idx == 0 {
Behavior::SerializedGarbage
Some(Behavior::SerializedGarbage)
} else {
Behavior::Lawful
None
};

let malicious_inputs = MaliciousInputs {
inputs: inputs.clone(),
let malicious_inputs = MisbehavingInputs {
inner_inputs: inputs.clone(),
behavior,
};
(*signer, malicious_inputs)
Expand All @@ -177,7 +101,7 @@ fn serialized_garbage() {
.with_env_filter(EnvFilter::from_default_env())
.finish();
let mut reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<MaliciousRound1<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, run_inputs).unwrap()
run_sync::<MaliciousEntryPoint<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, run_inputs).unwrap()
});

let v0 = signers[0].verifying_key();
Expand Down Expand Up @@ -206,13 +130,13 @@ fn attributable_failure() {
.enumerate()
.map(|(idx, signer)| {
let behavior = if idx == 0 {
Behavior::AttributableFailure
Some(Behavior::AttributableFailure)
} else {
Behavior::Lawful
None
};

let malicious_inputs = MaliciousInputs {
inputs: inputs.clone(),
let malicious_inputs = MisbehavingInputs {
inner_inputs: inputs.clone(),
behavior,
};
(*signer, malicious_inputs)
Expand All @@ -223,7 +147,7 @@ fn attributable_failure() {
.with_env_filter(EnvFilter::from_default_env())
.finish();
let mut reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<MaliciousRound1<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, run_inputs).unwrap()
run_sync::<MaliciousEntryPoint<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, run_inputs).unwrap()
});

let v0 = signers[0].verifying_key();
Expand Down Expand Up @@ -252,13 +176,13 @@ fn attributable_failure_round2() {
.enumerate()
.map(|(idx, signer)| {
let behavior = if idx == 0 {
Behavior::AttributableFailureRound2
Some(Behavior::AttributableFailureRound2)
} else {
Behavior::Lawful
None
};

let malicious_inputs = MaliciousInputs {
inputs: inputs.clone(),
let malicious_inputs = MisbehavingInputs {
inner_inputs: inputs.clone(),
behavior,
};
(*signer, malicious_inputs)
Expand All @@ -269,7 +193,7 @@ fn attributable_failure_round2() {
.with_env_filter(EnvFilter::from_default_env())
.finish();
let mut reports = tracing::subscriber::with_default(my_subscriber, || {
run_sync::<MaliciousRound1<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, run_inputs).unwrap()
run_sync::<MaliciousEntryPoint<TestVerifier>, TestSessionParams<BinaryFormat>>(&mut OsRng, run_inputs).unwrap()
});

let v0 = signers[0].verifying_key();
Expand Down
5 changes: 5 additions & 0 deletions manul/src/combinators.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
mod chain;
mod misbehave;

pub use chain::{Chain, Chained, ChainedProtocol};
pub use misbehave::{Misbehaving, MisbehavingInputs, MisbehavingRound};
Loading

0 comments on commit 104d625

Please sign in to comment.