From 355a5b5caca13d8b10f1ad8bdfe1a09d527dde36 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Fri, 15 Nov 2024 11:47:54 -0800 Subject: [PATCH] Add a wrapping type for `run_sync()` output --- examples/src/simple.rs | 9 ++++-- examples/src/simple_chain.rs | 9 ++++-- examples/src/simple_malicious.rs | 12 ++++++-- manul/benches/empty_rounds.rs | 8 ++--- manul/src/session/transcript.rs | 44 ++++++++++++++++++++++++++ manul/src/testing/run_sync.rs | 53 ++++++++++++++++++++++++++------ manul/src/tests/partial_echo.rs | 8 ++--- 7 files changed, 116 insertions(+), 27 deletions(-) diff --git a/examples/src/simple.rs b/examples/src/simple.rs index 777186c..93ea482 100644 --- a/examples/src/simple.rs +++ b/examples/src/simple.rs @@ -429,10 +429,13 @@ mod tests { .map(|signer| (signer, SimpleProtocolEntryPoint::new(all_ids.clone()))) .collect::>(); - let reports = run_sync_with_tracing::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap(); + let results = run_sync_with_tracing::<_, TestSessionParams>(&mut OsRng, entry_points) + .unwrap() + .results() + .unwrap(); - for (_id, report) in reports { - assert_eq!(report.result().unwrap(), 3); // 0 + 1 + 2 + for (_id, result) in results { + assert_eq!(result, 3); // 0 + 1 + 2 } } } diff --git a/examples/src/simple_chain.rs b/examples/src/simple_chain.rs index 0a85c6d..236c40c 100644 --- a/examples/src/simple_chain.rs +++ b/examples/src/simple_chain.rs @@ -87,10 +87,13 @@ mod tests { .map(|signer| (signer, DoubleSimpleEntryPoint::new(all_ids.clone()))) .collect::>(); - let reports = run_sync_with_tracing::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap(); + let results = run_sync_with_tracing::<_, TestSessionParams>(&mut OsRng, entry_points) + .unwrap() + .results() + .unwrap(); - for (_id, report) in reports { - assert_eq!(report.result().unwrap(), 3); // 0 + 1 + 2 + for (_id, result) in results { + assert_eq!(result, 3); // 0 + 1 + 2 } } } diff --git a/examples/src/simple_malicious.rs b/examples/src/simple_malicious.rs index 016f035..efc6761 100644 --- a/examples/src/simple_malicious.rs +++ b/examples/src/simple_malicious.rs @@ -93,7 +93,9 @@ fn serialized_garbage() { }) .collect::>(); - let mut reports = run_sync_with_tracing::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap(); + let mut reports = run_sync_with_tracing::<_, TestSessionParams>(&mut OsRng, entry_points) + .unwrap() + .reports; let v0 = signers[0].verifying_key(); let v1 = signers[1].verifying_key(); @@ -130,7 +132,9 @@ fn attributable_failure() { }) .collect::>(); - let mut reports = run_sync_with_tracing::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap(); + let mut reports = run_sync_with_tracing::<_, TestSessionParams>(&mut OsRng, entry_points) + .unwrap() + .reports; let v0 = signers[0].verifying_key(); let v1 = signers[1].verifying_key(); @@ -167,7 +171,9 @@ fn attributable_failure_round2() { }) .collect::>(); - let mut reports = run_sync_with_tracing::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap(); + let mut reports = run_sync_with_tracing::<_, TestSessionParams>(&mut OsRng, entry_points) + .unwrap() + .reports; let v0 = signers[0].verifying_key(); let v1 = signers[1].verifying_key(); diff --git a/manul/benches/empty_rounds.rs b/manul/benches/empty_rounds.rs index a0ec8bd..3e19da6 100644 --- a/manul/benches/empty_rounds.rs +++ b/manul/benches/empty_rounds.rs @@ -195,8 +195,8 @@ fn bench_empty_rounds(c: &mut Criterion) { assert!( run_sync::<_, TestSessionParams>(&mut OsRng, entry_points_no_echo.clone()) .unwrap() - .into_values() - .all(|report| report.result().is_some()) + .results() + .is_ok() ) }) }); @@ -225,8 +225,8 @@ fn bench_empty_rounds(c: &mut Criterion) { assert!( run_sync::<_, TestSessionParams>(&mut OsRng, entry_points_echo.clone()) .unwrap() - .into_values() - .all(|report| report.result().is_some()) + .results() + .is_ok() ); }) }); diff --git a/manul/src/session/transcript.rs b/manul/src/session/transcript.rs index 6739ee0..17cd4c6 100644 --- a/manul/src/session/transcript.rs +++ b/manul/src/session/transcript.rs @@ -1,6 +1,8 @@ use alloc::{ collections::{btree_map::Entry, BTreeMap, BTreeSet}, format, + string::String, + vec::Vec, }; use core::fmt::Debug; @@ -183,6 +185,21 @@ pub enum SessionOutcome { Terminated, } +impl

SessionOutcome

+where + P: Protocol, +{ + /// Returns a brief description of the outcome. + pub fn brief(&self) -> String { + match self { + Self::Result(result) => format!("Success ({result:?})"), + Self::NotEnoughMessages => "Not enough messages to finalize, terminated".into(), + Self::Terminated => "Terminated by the user".into(), + Self::StalledWithProof(_) => "Unattributable failure during finalization".into(), + } + } +} + /// The report of a session execution. #[derive(Debug)] pub struct SessionReport { @@ -217,4 +234,31 @@ where _ => None, } } + + /// Returns a brief description of report. + pub fn brief(&self) -> String { + let provable_errors = self + .provable_errors + .iter() + .map(|(id, evidence)| format!(" {:?}: {}", id, evidence.description())) + .collect::>(); + let unprovable_errors = self + .unprovable_errors + .iter() + .map(|(id, error)| format!(" {:?}: {}", id, error)) + .collect::>(); + let missing_messages = self + .missing_messages + .iter() + .map(|(id, parties)| format!(" {:?}: {:?}", id, parties)) + .collect::>(); + + format!( + "Result: {}\nProvable errors:\n{}\nUnprovable errors:\n{}\nMissing_messages:\n{}", + self.outcome.brief(), + provable_errors.join("\n"), + unprovable_errors.join("\n"), + missing_messages.join("\n") + ) + } } diff --git a/manul/src/testing/run_sync.rs b/manul/src/testing/run_sync.rs index 431e7ef..6109ef1 100644 --- a/manul/src/testing/run_sync.rs +++ b/manul/src/testing/run_sync.rs @@ -1,4 +1,4 @@ -use alloc::{collections::BTreeMap, vec::Vec}; +use alloc::{collections::BTreeMap, format, string::String, vec::Vec}; use rand::Rng; use rand_core::CryptoRngCore; @@ -9,8 +9,8 @@ use tracing_subscriber::EnvFilter; use crate::{ protocol::{EntryPoint, Protocol}, session::{ - CanFinalize, LocalError, Message, RoundAccumulator, RoundOutcome, Session, SessionId, SessionParameters, - SessionReport, + CanFinalize, LocalError, Message, RoundAccumulator, RoundOutcome, Session, SessionId, SessionOutcome, + SessionParameters, SessionReport, }, }; @@ -93,7 +93,7 @@ where pub fn run_sync( rng: &mut impl CryptoRngCore, entry_points: Vec<(SP::Signer, EP)>, -) -> Result>, LocalError> +) -> Result, LocalError> where EP: EntryPoint, SP: SessionParameters, @@ -156,25 +156,24 @@ where } } - let mut outcomes = BTreeMap::new(); + let mut reports = BTreeMap::new(); for (verifier, state) in states { - let outcome = match state { + let report = match state { State::InProgress { session, accum } => session.terminate(accum)?, State::Finished(report) => report, }; - outcomes.insert(verifier, outcome); + reports.insert(verifier, report); } - Ok(outcomes) + Ok(ExecutionResult { reports }) } /// Same as [`run_sync()`], but enables a [`tracing`] subscriber that prints the tracing events to stdout, /// taking options from the environment variable `RUST_LOG` (see [`mod@tracing_subscriber::fmt`] for details). -#[allow(clippy::type_complexity)] pub fn run_sync_with_tracing( rng: &mut impl CryptoRngCore, entry_points: Vec<(SP::Signer, EP)>, -) -> Result>, LocalError> +) -> Result, LocalError> where EP: EntryPoint, SP: SessionParameters, @@ -185,3 +184,37 @@ where .finish(); tracing::subscriber::with_default(subscriber, || run_sync::(rng, entry_points)) } + +/// The result of a protocol execution on a set of nodes. +#[derive(Debug)] +pub struct ExecutionResult { + pub reports: BTreeMap>, +} + +impl ExecutionResult +where + P: Protocol, + SP: SessionParameters, +{ + pub fn results(self) -> Result, String> { + let mut report_strings = Vec::new(); + let mut results = BTreeMap::new(); + + for (id, report) in self.reports.into_iter() { + match report.outcome { + SessionOutcome::Result(result) => { + results.insert(id, result); + } + _ => { + report_strings.push(format!("Id: {:?}\n{}", id, report.brief())); + } + } + } + + if report_strings.is_empty() { + Ok(results) + } else { + Err(report_strings.join("\n")) + } + } +} diff --git a/manul/src/tests/partial_echo.rs b/manul/src/tests/partial_echo.rs index 4560829..9969a16 100644 --- a/manul/src/tests/partial_echo.rs +++ b/manul/src/tests/partial_echo.rs @@ -213,8 +213,8 @@ fn partial_echo() { let entry_points = vec![node0, node1, node2, node3, node4]; - let reports = run_sync_with_tracing::<_, TestSessionParams>(&mut OsRng, entry_points).unwrap(); - for (_id, report) in reports { - assert!(report.result().is_some()); - } + let _results = run_sync_with_tracing::<_, TestSessionParams>(&mut OsRng, entry_points) + .unwrap() + .results() + .unwrap(); }