diff --git a/crates/sequencing/papyrus_consensus/src/stream_handler.rs b/crates/sequencing/papyrus_consensus/src/stream_handler.rs index aef64421c7..7101579dce 100644 --- a/crates/sequencing/papyrus_consensus/src/stream_handler.rs +++ b/crates/sequencing/papyrus_consensus/src/stream_handler.rs @@ -1,9 +1,8 @@ //! Stream handler, see StreamManager struct. use std::cmp::Ordering; -use std::collections::btree_map::Entry as BTreeEntry; -use std::collections::hash_map::Entry as HashMapEntry; -use std::collections::{BTreeMap, HashMap}; +use std::collections::hash_map::Entry::{Occupied, Vacant}; +use std::collections::HashMap; use futures::channel::mpsc; use futures::StreamExt; @@ -39,7 +38,7 @@ struct StreamData< max_message_id_received: MessageId, sender: mpsc::Sender, // A buffer for messages that were received out of order. - message_buffer: BTreeMap>, + message_buffer: HashMap>, } impl> + TryFrom, Error = ProtobufConversionError>> StreamData { @@ -49,7 +48,7 @@ impl> + TryFrom, Error = ProtobufConversionError fin_message_id: None, max_message_id_received: 0, sender, - message_buffer: BTreeMap::new(), + message_buffer: HashMap::new(), } } } @@ -232,8 +231,8 @@ impl> + TryFrom, Error = ProtobufConversi let message_id = message.message_id; let data = match self.inbound_stream_data.entry(key.clone()) { - HashMapEntry::Occupied(entry) => entry.into_mut(), - HashMapEntry::Vacant(e) => { + Occupied(entry) => entry.into_mut(), + Vacant(e) => { // If we received a message for a stream that we have not seen before, // we need to create a new receiver for it. let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_LENGTH); @@ -309,10 +308,10 @@ impl> + TryFrom, Error = ProtobufConversi let message_id = message.message_id; match data.message_buffer.entry(message_id) { - BTreeEntry::Vacant(e) => { + Vacant(e) => { e.insert(message); } - BTreeEntry::Occupied(_) => { + Occupied(_) => { // TODO(guyn): replace warnings with more graceful error handling warn!( "Two messages with the same message_id in buffer! key: {:?}, message_id: {}", diff --git a/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs b/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs index 4bd575da8d..c2995108b8 100644 --- a/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs +++ b/crates/sequencing/papyrus_consensus/src/stream_handler_test.rs @@ -34,10 +34,16 @@ mod tests { StreamMessage { message: content, stream_id, message_id } } - // Check if two vectors are the same: - fn do_vecs_match(a: &[T], b: &[T]) -> bool { - let matching = a.iter().zip(b.iter()).filter(|&(a, b)| a == b).count(); - matching == a.len() && matching == b.len() + // Check if two vectors are the same, regardless of ordering + fn do_vecs_match_unordered(a: &Vec, b: &Vec) -> bool + where + T: std::hash::Hash + Eq, + { + let mut a = a.clone(); + a.sort(); + let mut b = b.clone(); + b.sort(); + a == b } async fn send( @@ -183,7 +189,7 @@ mod tests { .message_buffer .into_keys() .collect(); - assert!(do_vecs_match(&keys, &range)); + assert!(do_vecs_match_unordered(&keys, &range)); // Now send the last message: send(&mut network_sender, &inbound_metadata, make_test_message(stream_id, 0, false)).await; @@ -258,7 +264,7 @@ mod tests { ); // We have all message from 1 to 9 buffered. - assert!(do_vecs_match( + assert!(do_vecs_match_unordered( &stream_handler.inbound_stream_data[&(peer_id.clone(), stream_id1)] .message_buffer .clone() @@ -268,7 +274,7 @@ mod tests { )); // We have all message from 1 to 5 buffered. - assert!(do_vecs_match( + assert!(do_vecs_match_unordered( &stream_handler.inbound_stream_data[&(peer_id.clone(), stream_id2)] .message_buffer .clone() @@ -278,7 +284,7 @@ mod tests { )); // We have all message from 1 to 5 buffered. - assert!(do_vecs_match( + assert!(do_vecs_match_unordered( &stream_handler.inbound_stream_data[&(peer_id.clone(), stream_id3)] .message_buffer .clone() @@ -486,7 +492,7 @@ mod tests { vec1.sort(); let mut vec2 = vec![&stream_id1, &stream_id2]; vec2.sort(); - do_vecs_match(&vec1, &vec2); + do_vecs_match_unordered(&vec1, &vec2); assert_eq!(stream_handler.outbound_stream_number[&stream_id2], 1); // Close the first channel.