Skip to content

Commit

Permalink
feat(network): make StreamHashMap notify when a stream is finished (#…
Browse files Browse the repository at this point in the history
…1439)

* feat(network): make StreamHashMap notify when a stream is finished

* fix(network): fix CR comments
  • Loading branch information
AlonLStarkWare authored Oct 27, 2024
1 parent f40da24 commit 50bdb62
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 13 deletions.
2 changes: 1 addition & 1 deletion crates/papyrus_network/src/discovery/flow_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ async fn all_nodes_have_same_bootstrap_peer() {
while connected_peers.len() < NUM_NODES * (NUM_NODES - 1) {
let (peer_id, event) = swarms_stream.next().await.unwrap();

let mixed_event: mixed_behaviour::Event = match event {
let mixed_event: mixed_behaviour::Event = match event.unwrap() {
SwarmEvent::Behaviour(DiscoveryMixedBehaviourEvent::Discovery(event)) => event.into(),
SwarmEvent::Behaviour(DiscoveryMixedBehaviourEvent::Kademlia(event)) => event.into(),
SwarmEvent::Behaviour(DiscoveryMixedBehaviourEvent::Identify(event)) => event.into(),
Expand Down
10 changes: 5 additions & 5 deletions crates/papyrus_network/src/network_manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use futures::channel::mpsc::{Receiver, SendError, Sender};
use futures::channel::oneshot;
use futures::future::{ready, BoxFuture, Ready};
use futures::sink::With;
use futures::stream::{self, FuturesUnordered, Map, Stream};
use futures::stream::{FuturesUnordered, Map, Stream};
use futures::{pin_mut, FutureExt, Sink, SinkExt, StreamExt};
use libp2p::gossipsub::{SubscriptionError, TopicHash};
use libp2p::swarm::SwarmEvent;
Expand Down Expand Up @@ -74,10 +74,10 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
Some(event) = self.swarm.next() => self.handle_swarm_event(event),
Some(res) = self.sqmr_inbound_response_receivers.next() => self.handle_response_for_inbound_query(res),
Some((protocol, client_payload)) = self.sqmr_outbound_payload_receivers.next() => {
self.handle_local_sqmr_payload(protocol, client_payload)
self.handle_local_sqmr_payload(protocol, client_payload.expect("An SQMR client channel should not be terminated."))
}
Some((topic_hash, message)) = self.messages_to_broadcast_receivers.next() => {
self.broadcast_message(message, topic_hash);
self.broadcast_message(message.expect("A broadcast channel should not be terminated."), topic_hash);
}
Some(Some(peer_id)) = self.reported_peer_receivers.next() => self.swarm.report_peer_as_malicious(peer_id),
Some(peer_id) = self.reported_peers_receiver.next() => self.swarm.report_peer_as_malicious(peer_id),
Expand Down Expand Up @@ -404,7 +404,7 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
inbound_session_id,
// Adding a None at the end of the stream so that we will receive a message
// letting us know the stream has ended.
Box::new(responses_receiver.map(Some).chain(stream::once(ready(None)))),
Box::new(responses_receiver),
);

// TODO(shahak): Close the inbound session if the buffer is full.
Expand Down Expand Up @@ -658,7 +658,7 @@ type GenericSender<T> = Box<dyn Sink<T, Error = SendError> + Unpin + Send>;
pub type GenericReceiver<T> = Box<dyn Stream<Item = T> + Unpin + Send>;

type ResponsesSender = GenericSender<Bytes>;
type ResponsesReceiver = GenericReceiver<Option<Bytes>>;
type ResponsesReceiver = GenericReceiver<Bytes>;

type ClientResponsesReceiver<Response> =
GenericReceiver<Result<Response, <Response as TryFrom<Bytes>>::Error>>;
Expand Down
2 changes: 1 addition & 1 deletion crates/papyrus_network/src/sqmr/flow_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async fn collect_events_from_swarms<BehaviourTrait: NetworkBehaviour, T>(
loop {
// Swarms should never finish, so we can unwrap the option.
let (peer_id, event) = swarms_stream.next().await.unwrap();
if let Some((other_peer_id, value)) = map_and_filter_event(peer_id, event) {
if let Some((other_peer_id, value)) = map_and_filter_event(peer_id, event.unwrap()) {
let is_unique = results.insert((peer_id, other_peer_id), value).is_none();
if assert_unique {
assert!(is_unique);
Expand Down
18 changes: 12 additions & 6 deletions crates/papyrus_network/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use core::net::Ipv4Addr;
use std::collections::hash_map::{Keys, ValuesMut};
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::hash::Hash;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
Expand Down Expand Up @@ -48,23 +48,29 @@ impl<K: Unpin + Clone + Eq + Hash, V: Stream + Unpin> StreamHashMap<K, V> {
}

impl<K: Unpin + Clone + Eq + Hash, V: Stream + Unpin> Stream for StreamHashMap<K, V> {
type Item = (K, <V as Stream>::Item);
type Item = (K, Option<<V as Stream>::Item>);

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let unpinned_self = Pin::into_inner(self);
let mut finished_streams = HashSet::new();
let mut finished_stream_key: Option<K> = None;
for (key, stream) in &mut unpinned_self.map {
match stream.poll_next_unpin(cx) {
Poll::Ready(Some(value)) => {
return Poll::Ready(Some((key.clone(), value)));
return Poll::Ready(Some((key.clone(), Some(value))));
}
Poll::Ready(None) => {
finished_streams.insert(key.clone());
finished_stream_key = Some(key.clone());
// breaking and removing the finished stream from the map outside of the loop
// because we can't have two mutable references to the map.
break;
}
Poll::Pending => {}
}
}
HashMap::retain(&mut unpinned_self.map, |key, _| !finished_streams.contains(key));
if let Some(finished_stream_key) = finished_stream_key {
unpinned_self.map.remove(&finished_stream_key);
return Poll::Ready(Some((finished_stream_key, None)));
}
unpinned_self.wakers_waiting_for_new_stream.push(cx.waker().clone());
Poll::Pending
}
Expand Down

0 comments on commit 50bdb62

Please sign in to comment.