diff --git a/crates/papyrus_network/src/discovery/flow_test.rs b/crates/papyrus_network/src/discovery/flow_test.rs index 1b7f296d81..7b2b21c907 100644 --- a/crates/papyrus_network/src/discovery/flow_test.rs +++ b/crates/papyrus_network/src/discovery/flow_test.rs @@ -78,7 +78,7 @@ async fn all_nodes_have_same_bootstrap_peer() { while connected_peers.len() < NUM_NODES * (NUM_NODES - 1) { let (peer_id, event) = swarms_stream.next().await.unwrap(); - let mixed_event: mixed_behaviour::Event = match event { + let mixed_event: mixed_behaviour::Event = match event.unwrap() { SwarmEvent::Behaviour(DiscoveryMixedBehaviourEvent::Discovery(event)) => event.into(), SwarmEvent::Behaviour(DiscoveryMixedBehaviourEvent::Kademlia(event)) => event.into(), SwarmEvent::Behaviour(DiscoveryMixedBehaviourEvent::Identify(event)) => event.into(), diff --git a/crates/papyrus_network/src/network_manager/mod.rs b/crates/papyrus_network/src/network_manager/mod.rs index 51c6721d6a..30f3eb332a 100644 --- a/crates/papyrus_network/src/network_manager/mod.rs +++ b/crates/papyrus_network/src/network_manager/mod.rs @@ -14,7 +14,7 @@ use futures::channel::mpsc::{Receiver, SendError, Sender}; use futures::channel::oneshot; use futures::future::{ready, BoxFuture, Ready}; use futures::sink::With; -use futures::stream::{self, FuturesUnordered, Map, Stream}; +use futures::stream::{FuturesUnordered, Map, Stream}; use futures::{pin_mut, FutureExt, Sink, SinkExt, StreamExt}; use libp2p::gossipsub::{SubscriptionError, TopicHash}; use libp2p::swarm::SwarmEvent; @@ -74,10 +74,10 @@ impl GenericNetworkManager { Some(event) = self.swarm.next() => self.handle_swarm_event(event), Some(res) = self.sqmr_inbound_response_receivers.next() => self.handle_response_for_inbound_query(res), Some((protocol, client_payload)) = self.sqmr_outbound_payload_receivers.next() => { - self.handle_local_sqmr_payload(protocol, client_payload) + self.handle_local_sqmr_payload(protocol, client_payload.expect("An SQMR client channel should not be terminated.")) } Some((topic_hash, message)) = self.messages_to_broadcast_receivers.next() => { - self.broadcast_message(message, topic_hash); + self.broadcast_message(message.expect("A broadcast channel should not be terminated."), topic_hash); } Some(Some(peer_id)) = self.reported_peer_receivers.next() => self.swarm.report_peer_as_malicious(peer_id), Some(peer_id) = self.reported_peers_receiver.next() => self.swarm.report_peer_as_malicious(peer_id), @@ -404,7 +404,7 @@ impl GenericNetworkManager { inbound_session_id, // Adding a None at the end of the stream so that we will receive a message // letting us know the stream has ended. - Box::new(responses_receiver.map(Some).chain(stream::once(ready(None)))), + Box::new(responses_receiver), ); // TODO(shahak): Close the inbound session if the buffer is full. @@ -658,7 +658,7 @@ type GenericSender = Box + Unpin + Send>; pub type GenericReceiver = Box + Unpin + Send>; type ResponsesSender = GenericSender; -type ResponsesReceiver = GenericReceiver>; +type ResponsesReceiver = GenericReceiver; type ClientResponsesReceiver = GenericReceiver>::Error>>; diff --git a/crates/papyrus_network/src/sqmr/flow_test.rs b/crates/papyrus_network/src/sqmr/flow_test.rs index 619bdafeb5..99fced9439 100644 --- a/crates/papyrus_network/src/sqmr/flow_test.rs +++ b/crates/papyrus_network/src/sqmr/flow_test.rs @@ -32,7 +32,7 @@ async fn collect_events_from_swarms( loop { // Swarms should never finish, so we can unwrap the option. let (peer_id, event) = swarms_stream.next().await.unwrap(); - if let Some((other_peer_id, value)) = map_and_filter_event(peer_id, event) { + if let Some((other_peer_id, value)) = map_and_filter_event(peer_id, event.unwrap()) { let is_unique = results.insert((peer_id, other_peer_id), value).is_none(); if assert_unique { assert!(is_unique); diff --git a/crates/papyrus_network/src/utils.rs b/crates/papyrus_network/src/utils.rs index bc0a3ab971..72c862b61c 100644 --- a/crates/papyrus_network/src/utils.rs +++ b/crates/papyrus_network/src/utils.rs @@ -1,6 +1,6 @@ use core::net::Ipv4Addr; use std::collections::hash_map::{Keys, ValuesMut}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::hash::Hash; use std::pin::Pin; use std::task::{Context, Poll, Waker}; @@ -48,23 +48,29 @@ impl StreamHashMap { } impl Stream for StreamHashMap { - type Item = (K, ::Item); + type Item = (K, Option<::Item>); fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let unpinned_self = Pin::into_inner(self); - let mut finished_streams = HashSet::new(); + let mut finished_stream_key: Option = None; for (key, stream) in &mut unpinned_self.map { match stream.poll_next_unpin(cx) { Poll::Ready(Some(value)) => { - return Poll::Ready(Some((key.clone(), value))); + return Poll::Ready(Some((key.clone(), Some(value)))); } Poll::Ready(None) => { - finished_streams.insert(key.clone()); + finished_stream_key = Some(key.clone()); + // breaking and removing the finished stream from the map outside of the loop + // because we can't have two mutable references to the map. + break; } Poll::Pending => {} } } - HashMap::retain(&mut unpinned_self.map, |key, _| !finished_streams.contains(key)); + if let Some(finished_stream_key) = finished_stream_key { + unpinned_self.map.remove(&finished_stream_key); + return Poll::Ready(Some((finished_stream_key, None))); + } unpinned_self.wakers_waiting_for_new_stream.push(cx.waker().clone()); Poll::Pending }