diff --git a/Cargo.lock b/Cargo.lock index c31e3f0..a5dc428 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -492,6 +492,7 @@ dependencies = [ "serde_asn1_der", "serde_json", "signature", + "tinyvec", "tracing", ] @@ -934,6 +935,22 @@ dependencies = [ "serde_json", ] +[[package]] +name = "tinyvec" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +dependencies = [ + "serde", + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.40.0" diff --git a/manul/Cargo.toml b/manul/Cargo.toml index 893157b..ede1380 100644 --- a/manul/Cargo.toml +++ b/manul/Cargo.toml @@ -20,6 +20,7 @@ rand_core = { version = "0.6.4", default-features = false } tracing = { version = "0.1", default-features = false } displaydoc = { version = "0.2", default-features = false } derive-where = "1" +tinyvec = { version = "1", default-features = false, features = ["alloc", "serde"] } rand = { version = "0.8", default-features = false, optional = true } serde-persistent-deserializer = { version = "0.3", optional = true } diff --git a/manul/src/protocol/round.rs b/manul/src/protocol/round.rs index 794677c..bf6f6e7 100644 --- a/manul/src/protocol/round.rs +++ b/manul/src/protocol/round.rs @@ -3,6 +3,7 @@ use alloc::{ collections::{BTreeMap, BTreeSet}, format, string::String, + vec, vec::Vec, }; use core::{ @@ -12,6 +13,7 @@ use core::{ use rand_core::CryptoRngCore; use serde::{Deserialize, Serialize}; +use tinyvec::{tiny_vec, TinyVec}; use super::{ errors::{FinalizeError, LocalError, MessageValidationError, ProtocolValidationError, ReceiveError}, @@ -29,24 +31,18 @@ pub enum FinalizeOutcome { Result(P::Result), } -// Maximum depth of group nesting in RoundIds. -// We need this to be limited to allow the nesting to be performed in `const` context -// (since we cannot use heap there). -const ROUND_ID_DEPTH: usize = 8; - /// A round identifier. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct RoundId { - depth: u8, - round_nums: [u8; ROUND_ID_DEPTH], + round_nums: TinyVec<[u8; 4]>, is_echo: bool, } impl Display for RoundId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { write!(f, "Round ")?; - for i in (0..self.depth as usize).rev() { - write!(f, "{}", self.round_nums.get(i).expect("Depth within range"))?; + for (i, round_num) in self.round_nums.iter().enumerate().rev() { + write!(f, "{}", round_num)?; if i != 0 { write!(f, "-")?; } @@ -60,36 +56,18 @@ impl Display for RoundId { impl RoundId { /// Creates a new round identifier. - pub const fn new(round_num: u8) -> Self { - let mut round_nums = [0u8; ROUND_ID_DEPTH]; - #[allow(clippy::indexing_slicing)] - { - round_nums[0] = round_num; - } + pub fn new(round_num: u8) -> Self { Self { - depth: 1, - round_nums, + round_nums: tiny_vec!(round_num, 0, 0, 0), is_echo: false, } } /// Prefixes this round ID (possibly already nested) with a group number. - /// - /// **Warning:** the maximum nesting depth is 8. Panics if this nesting overflows it. - pub(crate) const fn group_under(&self, round_num: u8) -> Self { - if self.depth as usize == ROUND_ID_DEPTH { - panic!("Maximum depth reached"); - } - let mut round_nums = self.round_nums; - - // Would use `expect("Depth within range")` here, but `expect()` in const fns is unstable. - #[allow(clippy::indexing_slicing)] - { - round_nums[self.depth as usize] = round_num; - } - + pub(crate) fn group_under(&self, round_num: u8) -> Self { + let mut round_nums = self.round_nums.clone(); + round_nums.push(round_num); Self { - depth: self.depth + 1, round_nums, is_echo: self.is_echo, } @@ -99,13 +77,12 @@ impl RoundId { /// /// Returns the `Err` variant if the round ID is not nested. pub(crate) fn ungroup(&self) -> Result { - if self.depth == 1 { + if self.round_nums.len() == 1 { Err(LocalError::new("This round ID is not in a group")) } else { - let mut round_nums = self.round_nums; - *round_nums.get_mut(self.depth as usize - 1).expect("Depth within range") = 0; + let mut round_nums = self.round_nums.clone(); + round_nums.pop().expect("vector size greater than 1"); Ok(Self { - depth: self.depth - 1, round_nums, is_echo: self.is_echo, }) @@ -127,8 +104,7 @@ impl RoundId { panic!("This is already an echo round ID"); } Self { - depth: self.depth, - round_nums: self.round_nums, + round_nums: self.round_nums.clone(), is_echo: true, } } @@ -143,8 +119,7 @@ impl RoundId { panic!("This is already an non-echo round ID"); } Self { - depth: self.depth, - round_nums: self.round_nums, + round_nums: self.round_nums.clone(), is_echo: false, } } diff --git a/manul/src/session/evidence.rs b/manul/src/session/evidence.rs index c0fab32..eafd75c 100644 --- a/manul/src/session/evidence.rs +++ b/manul/src/session/evidence.rs @@ -100,8 +100,8 @@ where .iter() .map(|round_id| { transcript - .get_echo_broadcast(*round_id, verifier) - .map(|echo| (*round_id, echo)) + .get_echo_broadcast(round_id.clone(), verifier) + .map(|echo| (round_id.clone(), echo)) }) .collect::, _>>()?; @@ -110,8 +110,8 @@ where .iter() .map(|round_id| { transcript - .get_normal_broadcast(*round_id, verifier) - .map(|bc| (*round_id, bc)) + .get_normal_broadcast(round_id.clone(), verifier) + .map(|bc| (round_id.clone(), bc)) }) .collect::, _>>()?; @@ -120,8 +120,8 @@ where .iter() .map(|round_id| { transcript - .get_direct_message(*round_id, verifier) - .map(|dm| (*round_id, dm)) + .get_direct_message(round_id.clone(), verifier) + .map(|dm| (round_id.clone(), dm)) }) .collect::, _>>()?; @@ -131,7 +131,7 @@ where .map(|round_id| { transcript .get_normal_broadcast(round_id.echo(), verifier) - .map(|dm| (*round_id, dm)) + .map(|dm| (round_id.clone(), dm)) }) .collect::, _>>()?; @@ -470,12 +470,12 @@ where for (round_id, direct_message) in self.direct_messages.iter() { let verified_direct_message = direct_message.clone().verify::(verifier)?; let metadata = verified_direct_message.metadata(); - if metadata.session_id() != session_id || metadata.round_id() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); } - verified_direct_messages.insert(*round_id, verified_direct_message.payload().clone()); + verified_direct_messages.insert(round_id.clone(), verified_direct_message.payload().clone()); } let verified_echo_broadcast = self.echo_broadcast.clone().verify::(verifier)?.payload().clone(); @@ -500,31 +500,31 @@ where for (round_id, echo_broadcast) in self.echo_broadcasts.iter() { let verified_echo_broadcast = echo_broadcast.clone().verify::(verifier)?; let metadata = verified_echo_broadcast.metadata(); - if metadata.session_id() != session_id || metadata.round_id() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); } - verified_echo_broadcasts.insert(*round_id, verified_echo_broadcast.payload().clone()); + verified_echo_broadcasts.insert(round_id.clone(), verified_echo_broadcast.payload().clone()); } let mut verified_normal_broadcasts = BTreeMap::new(); for (round_id, normal_broadcast) in self.normal_broadcasts.iter() { let verified_normal_broadcast = normal_broadcast.clone().verify::(verifier)?; let metadata = verified_normal_broadcast.metadata(); - if metadata.session_id() != session_id || metadata.round_id() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); } - verified_normal_broadcasts.insert(*round_id, verified_normal_broadcast.payload().clone()); + verified_normal_broadcasts.insert(round_id.clone(), verified_normal_broadcast.payload().clone()); } let mut combined_echos = BTreeMap::new(); for (round_id, combined_echo) in self.combined_echos.iter() { let verified_combined_echo = combined_echo.clone().verify::(verifier)?; let metadata = verified_combined_echo.metadata(); - if metadata.session_id() != session_id || metadata.round_id().non_echo() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id().non_echo() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); @@ -537,14 +537,14 @@ where for (other_verifier, echo_broadcast) in echo_set.echo_broadcasts.iter() { let verified_echo_broadcast = echo_broadcast.clone().verify::(other_verifier)?; let metadata = verified_echo_broadcast.metadata(); - if metadata.session_id() != session_id || metadata.round_id() != *round_id { + if metadata.session_id() != session_id || &metadata.round_id() != round_id { return Err(EvidenceError::InvalidEvidence( "Invalid attached message metadata".into(), )); } verified_echo_set.push(verified_echo_broadcast.payload().clone()); } - combined_echos.insert(*round_id, verified_echo_set); + combined_echos.insert(round_id.clone(), verified_echo_set); } Ok(self.error.verify_messages_constitute_error( diff --git a/manul/src/session/message.rs b/manul/src/session/message.rs index 51dbc55..199c2f0 100644 --- a/manul/src/session/message.rs +++ b/manul/src/session/message.rs @@ -68,7 +68,7 @@ impl MessageMetadata { } pub fn round_id(&self) -> RoundId { - self.round_id + self.round_id.clone() } } diff --git a/manul/src/session/session.rs b/manul/src/session/session.rs index b8edb31..db81183 100644 --- a/manul/src/session/session.rs +++ b/manul/src/session/session.rs @@ -317,7 +317,7 @@ where } MessageFor::ThisRound } else if self.possible_next_rounds.contains(&message_round_id) { - if accum.message_is_cached(from, message_round_id) { + if accum.message_is_cached(from, &message_round_id) { let err = format!("Message for {:?} is already cached", message_round_id); accum.register_unprovable_error(from, RemoteError::new(&err))?; trace!("{key:?} {err}"); @@ -354,7 +354,7 @@ where match message_for { MessageFor::ThisRound => { accum.mark_processing(&verified_message)?; - Ok(PreprocessOutcome::ToProcess(verified_message)) + Ok(PreprocessOutcome::ToProcess(Box::new(verified_message))) } MessageFor::NextRound => { debug!("{key:?}: Caching message from {from:?} for {message_round_id}"); @@ -406,7 +406,7 @@ where ) -> Result, LocalError> { let round_id = self.round_id(); let transcript = self.transcript.update( - round_id, + &round_id, accum.echo_broadcasts, accum.normal_broadcasts, accum.direct_messages, @@ -446,7 +446,7 @@ where let round_id = self.round_id(); let transcript = self.transcript.update( - round_id, + &round_id, accum.echo_broadcasts, accum.normal_broadcasts, accum.direct_messages, @@ -604,9 +604,9 @@ where self.processing.contains(from) } - fn message_is_cached(&self, from: &SP::Verifier, round_id: RoundId) -> bool { + fn message_is_cached(&self, from: &SP::Verifier, round_id: &RoundId) -> bool { if let Some(entry) = self.cached.get(from) { - entry.contains_key(&round_id) + entry.contains_key(round_id) } else { false } @@ -745,7 +745,7 @@ where let from = message.from().clone(); let round_id = message.metadata().round_id(); let cached = self.cached.entry(from.clone()).or_default(); - if cached.insert(round_id, message).is_some() { + if cached.insert(round_id.clone(), message).is_some() { return Err(LocalError::new(format!( "A message from for {:?} has already been cached", round_id @@ -771,7 +771,7 @@ pub struct ProcessedMessage { #[derive(Debug, Clone)] pub enum PreprocessOutcome { /// The message was successfully verified, pass it on to [`Session::process_message`]. - ToProcess(VerifiedMessage), + ToProcess(Box>), /// The message was intended for the next round and was cached. /// /// No action required now, cached messages will be returned on successful [`Session::finalize_round`]. @@ -795,7 +795,7 @@ impl PreprocessOutcome { /// so the user may choose to ignore them if no logging is desired. pub fn ok(self) -> Option> { match self { - Self::ToProcess(message) => Some(message), + Self::ToProcess(message) => Some(*message), _ => None, } } diff --git a/manul/src/session/transcript.rs b/manul/src/session/transcript.rs index 382448d..3f678e5 100644 --- a/manul/src/session/transcript.rs +++ b/manul/src/session/transcript.rs @@ -36,7 +36,7 @@ where #[allow(clippy::too_many_arguments)] pub fn update( self, - round_id: RoundId, + round_id: &RoundId, echo_broadcasts: BTreeMap>, normal_broadcasts: BTreeMap>, direct_messages: BTreeMap>, @@ -45,7 +45,7 @@ where missing_messages: BTreeSet, ) -> Result { let mut all_echo_broadcasts = self.echo_broadcasts; - match all_echo_broadcasts.entry(round_id) { + match all_echo_broadcasts.entry(round_id.clone()) { Entry::Vacant(entry) => entry.insert(echo_broadcasts), Entry::Occupied(_) => { return Err(LocalError::new(format!( @@ -55,7 +55,7 @@ where }; let mut all_normal_broadcasts = self.normal_broadcasts; - match all_normal_broadcasts.entry(round_id) { + match all_normal_broadcasts.entry(round_id.clone()) { Entry::Vacant(entry) => entry.insert(normal_broadcasts), Entry::Occupied(_) => { return Err(LocalError::new(format!( @@ -65,7 +65,7 @@ where }; let mut all_direct_messages = self.direct_messages; - match all_direct_messages.entry(round_id) { + match all_direct_messages.entry(round_id.clone()) { Entry::Vacant(entry) => entry.insert(direct_messages), Entry::Occupied(_) => { return Err(LocalError::new(format!( @@ -93,7 +93,7 @@ where } let mut all_missing_messages = self.missing_messages; - match all_missing_messages.entry(round_id) { + match all_missing_messages.entry(round_id.clone()) { Entry::Vacant(entry) => entry.insert(missing_messages), Entry::Occupied(_) => { return Err(LocalError::new(format!(