Skip to content

Commit

Permalink
feat(network): change broadcast api to allow serialization of Broadca…
Browse files Browse the repository at this point in the history
…stedMessageManager (#695)

* feat(network): change broadcast api to allow serialization of BroadcastedMessageManager

* feat(network): fix CR comments
  • Loading branch information
AlonLStarkWare authored Sep 5, 2024
1 parent 35a3320 commit 13235ae
Show file tree
Hide file tree
Showing 10 changed files with 206 additions and 110 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/papyrus_network/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ testing = []

[dependencies]
async-stream.workspace = true
async-trait.workspace = true
bytes.workspace = true
derive_more.workspace = true
futures.workspace = true
Expand Down
16 changes: 8 additions & 8 deletions crates/papyrus_network/src/e2e_broadcast_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,20 @@ async fn broadcast_subscriber_end_to_end_test() {
tokio::time::sleep(Duration::from_secs(1)).await;
let number1 = Number(1);
let number2 = Number(2);
let mut broadcasted_messages_receiver2_1 =
subscriber_channels2_1.broadcasted_messages_receiver;
let mut broadcasted_messages_receiver2_2 =
subscriber_channels2_2.broadcasted_messages_receiver;
let mut broadcast_client2_1 =
subscriber_channels2_1.broadcast_client_channels;
let mut broadcast_client2_2 =
subscriber_channels2_2.broadcast_client_channels;
subscriber_channels1_1.messages_to_broadcast_sender.send(number1).await.unwrap();
subscriber_channels1_2.messages_to_broadcast_sender.send(number2).await.unwrap();
let (received_number1, _report_callback) =
broadcasted_messages_receiver2_1.next().await.unwrap();
broadcast_client2_1.next().await.unwrap();
let (received_number2, _report_callback) =
broadcasted_messages_receiver2_2.next().await.unwrap();
broadcast_client2_2.next().await.unwrap();
assert_eq!(received_number1.unwrap(), number1);
assert_eq!(received_number2.unwrap(), number2);
assert!(broadcasted_messages_receiver2_1.next().now_or_never().is_none());
assert!(broadcasted_messages_receiver2_2.next().now_or_never().is_none());
assert!(broadcast_client2_1.next().now_or_never().is_none());
assert!(broadcast_client2_2.next().now_or_never().is_none());
}
) => {
result.unwrap()
Expand Down
94 changes: 74 additions & 20 deletions crates/papyrus_network/src/network_manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::collections::HashMap;
use std::pin::Pin;
use std::task::{Context, Poll};

use async_trait::async_trait;
use futures::channel::mpsc::{Receiver, SendError, Sender};
use futures::channel::oneshot;
use futures::future::{ready, BoxFuture, Ready};
Expand Down Expand Up @@ -38,12 +39,14 @@ pub enum NetworkError {
DialError(#[from] libp2p::swarm::DialError),
}

// TODO: Understand whats the correct thing to do here.
const MESSAGE_MANAGER_BUFFER_SIZE: usize = 100000;

pub struct GenericNetworkManager<SwarmT: SwarmTrait> {
swarm: SwarmT,
inbound_protocol_to_buffer_size: HashMap<StreamProtocol, usize>,
sqmr_inbound_response_receivers: StreamHashMap<InboundSessionId, ResponsesReceiver>,
sqmr_inbound_payload_senders: HashMap<StreamProtocol, SqmrServerSender>,

sqmr_outbound_payload_receivers: StreamHashMap<StreamProtocol, SqmrClientReceiver>,
sqmr_outbound_response_senders: HashMap<OutboundSessionId, ResponsesSender>,
sqmr_outbound_report_receivers_awaiting_assignment: HashMap<OutboundSessionId, ReportReceiver>,
Expand All @@ -54,6 +57,10 @@ pub struct GenericNetworkManager<SwarmT: SwarmTrait> {
broadcasted_messages_senders: HashMap<TopicHash, Sender<(Bytes, BroadcastedMessageManager)>>,
reported_peer_receivers: FuturesUnordered<BoxFuture<'static, Option<PeerId>>>,
advertised_multiaddr: Option<Multiaddr>,
reported_peers_receiver: Receiver<PeerId>,
reported_peers_sender: Sender<PeerId>,
continue_propagation_sender: Sender<BroadcastedMessageManager>,
continue_propagation_receiver: Receiver<BroadcastedMessageManager>,
// Fields for metrics
num_active_inbound_sessions: usize,
num_active_outbound_sessions: usize,
Expand All @@ -72,6 +79,10 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
self.broadcast_message(message, topic_hash);
}
Some(Some(peer_id)) = self.reported_peer_receivers.next() => self.swarm.report_peer(peer_id),
Some(peer_id) = self.reported_peers_receiver.next() => self.swarm.report_peer(peer_id),
Some(broadcasted_message_manager) = self.continue_propagation_receiver.next() => {
self.swarm.continue_propagation(broadcasted_message_manager);
}
}
}
}
Expand All @@ -85,6 +96,10 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
if let Some(address) = advertised_multiaddr.clone() {
swarm.add_external_address(address);
}
let (reported_peers_sender, reported_peers_receiver) =
futures::channel::mpsc::channel(MESSAGE_MANAGER_BUFFER_SIZE);
let (continue_propagation_sender, continue_propagation_receiver) =
futures::channel::mpsc::channel(MESSAGE_MANAGER_BUFFER_SIZE);
Self {
swarm,
inbound_protocol_to_buffer_size: HashMap::new(),
Expand All @@ -97,6 +112,10 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
broadcasted_messages_senders: HashMap::new(),
reported_peer_receivers,
advertised_multiaddr,
reported_peers_receiver,
reported_peers_sender,
continue_propagation_sender,
continue_propagation_receiver,
num_active_inbound_sessions: 0,
num_active_outbound_sessions: 0,
}
Expand Down Expand Up @@ -168,6 +187,8 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {

/// Register a new subscriber for broadcasting and receiving broadcasts for a given topic.
/// Panics if this topic is already subscribed.
// TODO: consider splitting into register_broadcast_topic_client and
// register_broadcast_topic_server
pub fn register_broadcast_topic<T>(
&mut self,
topic: Topic,
Expand Down Expand Up @@ -195,7 +216,7 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {

let insert_result = self
.broadcasted_messages_senders
.insert(topic_hash.clone(), broadcasted_messages_sender);
.insert(topic_hash.clone(), broadcasted_messages_sender.clone());
if insert_result.is_some() {
panic!("Topic '{}' has already been registered.", topic);
}
Expand All @@ -210,7 +231,17 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
let broadcasted_messages_receiver =
broadcasted_messages_receiver.map(broadcasted_messages_fn);

Ok(BroadcastTopicChannels { messages_to_broadcast_sender, broadcasted_messages_receiver })
let reported_messages_sender = self
.reported_peers_sender
.clone()
.with(|manager: BroadcastedMessageManager| ready(Ok(manager.peer_id)));
let continue_propagation_sender = self.continue_propagation_sender.clone();
let broadcast_client_channels = BroadcastClientChannels {
broadcasted_messages_receiver,
reported_messages_sender: Box::new(reported_messages_sender),
continue_propagation_sender: Box::new(continue_propagation_sender),
};
Ok(BroadcastTopicChannels { messages_to_broadcast_sender, broadcast_client_channels })
}

fn handle_swarm_event(&mut self, event: SwarmEvent<mixed_behaviour::Event>) {
Expand Down Expand Up @@ -445,9 +476,7 @@ impl<SwarmT: SwarmTrait> GenericNetworkManager<SwarmT> {
fn handle_gossipsub_behaviour_event(&mut self, event: gossipsub_impl::ExternalEvent) {
let gossipsub_impl::ExternalEvent::Received { originated_peer_id, message, topic_hash } =
event;
let (report_sender, report_receiver) = oneshot::channel::<()>();
let broadcasted_message_manager = BroadcastedMessageManager { report_sender };
self.handle_new_report_receiver(originated_peer_id, report_receiver);
let broadcasted_message_manager = BroadcastedMessageManager { peer_id: originated_peer_id };
let Some(sender) = self.broadcasted_messages_senders.get_mut(&topic_hash) else {
error!(
"Received a message from a topic we're not subscribed to with hash {topic_hash:?}"
Expand Down Expand Up @@ -807,20 +836,10 @@ pub type BroadcastTopicSender<T> = With<
fn(T) -> Ready<Result<Bytes, SendError>>,
>;

// TODO(eitan): consider adding the message to the struct
// TODO(alonl): remove clone
#[derive(Clone)]
pub struct BroadcastedMessageManager {
report_sender: ReportSender,
}
impl BroadcastedMessageManager {
pub fn report_peer(self) {
warn!("Reporting peer");
if let Err(e) = self.report_sender.send(()) {
error!("Failed to report peer. Error: {e:?}");
}
}

// TODO(eitan): implement
pub fn continue_propogation(&mut self) {}
peer_id: PeerId,
}

pub type BroadcastTopicReceiver<T> =
Expand All @@ -833,5 +852,40 @@ type BroadcastReceivedMessagesConverterFn<T> =

pub struct BroadcastTopicChannels<T: TryFrom<Bytes>> {
pub messages_to_broadcast_sender: BroadcastTopicSender<T>,
pub broadcasted_messages_receiver: BroadcastTopicReceiver<T>,
pub broadcast_client_channels: BroadcastClientChannels<T>,
}

pub struct BroadcastClientChannels<T: TryFrom<Bytes>> {
broadcasted_messages_receiver: BroadcastTopicReceiver<T>,
reported_messages_sender: GenericSender<BroadcastedMessageManager>,
continue_propagation_sender: GenericSender<BroadcastedMessageManager>,
}

impl<T: TryFrom<Bytes>> Stream for BroadcastClientChannels<T> {
type Item = (Result<T, <T as TryFrom<Bytes>>::Error>, BroadcastedMessageManager);
fn poll_next(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<(Result<T, <T as TryFrom<Bytes>>::Error>, BroadcastedMessageManager)>> {
self.broadcasted_messages_receiver.poll_next_unpin(cx)
}
}

#[async_trait]
pub trait BroadcastClientTrait<T: TryFrom<Bytes>>:
Stream<Item = (Result<T, <T as TryFrom<Bytes>>::Error>, BroadcastedMessageManager)> + Unpin
{
async fn report_message(&mut self, message_manager: BroadcastedMessageManager);
async fn continue_propagation(&mut self, message_manager: &BroadcastedMessageManager);
}

#[async_trait]
impl<T: TryFrom<Bytes>> BroadcastClientTrait<T> for BroadcastClientChannels<T> {
async fn report_message(&mut self, message_manager: BroadcastedMessageManager) {
let _ = self.reported_messages_sender.send(message_manager).await;
}

async fn continue_propagation(&mut self, message_manager: &BroadcastedMessageManager) {
let _ = self.continue_propagation_sender.send(message_manager.clone()).await;
}
}
6 changes: 6 additions & 0 deletions crates/papyrus_network/src/network_manager/swarm_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use libp2p::swarm::{DialError, NetworkBehaviour, SwarmEvent};
use libp2p::{Multiaddr, PeerId, StreamProtocol, Swarm};
use tracing::{error, info};

use super::BroadcastedMessageManager;
use crate::gossipsub_impl::Topic;
use crate::mixed_behaviour;
use crate::peer_manager::ReputationModifier;
Expand Down Expand Up @@ -52,6 +53,8 @@ pub trait SwarmTrait: Stream<Item = Event> + Unpin {
fn report_peer(&mut self, peer_id: PeerId);

fn add_new_supported_inbound_protocol(&mut self, protocol_name: StreamProtocol);

fn continue_propagation(&mut self, message_manager: BroadcastedMessageManager);
}

impl SwarmTrait for Swarm<mixed_behaviour::MixedBehaviour> {
Expand Down Expand Up @@ -129,4 +132,7 @@ impl SwarmTrait for Swarm<mixed_behaviour::MixedBehaviour> {
fn add_new_supported_inbound_protocol(&mut self, protocol: StreamProtocol) {
self.behaviour_mut().sqmr.add_new_supported_inbound_protocol(protocol);
}

// TODO(shahak): Implement this function.
fn continue_propagation(&mut self, _message_manager: BroadcastedMessageManager) {}
}
34 changes: 21 additions & 13 deletions crates/papyrus_network/src/network_manager/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ use tokio::sync::Mutex;
use tokio::time::sleep;

use super::swarm_trait::{Event, SwarmTrait};
use super::GenericNetworkManager;
use super::{BroadcastTopicChannels, GenericNetworkManager};
use crate::gossipsub_impl::{self, Topic};
use crate::mixed_behaviour;
use crate::network_manager::ServerQueryManager;
use crate::network_manager::{BroadcastClientChannels, ServerQueryManager};
use crate::sqmr::behaviour::{PeerNotConnected, SessionIdNotFoundError};
use crate::sqmr::{Bytes, GenericEvent, InboundSessionId, OutboundSessionId};

Expand Down Expand Up @@ -202,6 +202,11 @@ impl SwarmTrait for MockSwarm {
) -> Result<PeerId, SessionIdNotFoundError> {
Ok(PeerId::random())
}

// TODO (shahak): Add test for continue propagation.
fn continue_propagation(&mut self, _message_manager: super::BroadcastedMessageManager) {
unimplemented!()
}
}

const BUFFER_SIZE: usize = 100;
Expand Down Expand Up @@ -353,22 +358,25 @@ async fn receive_broadcasted_message_and_report_it() {

let mut network_manager = GenericNetworkManager::generic_new(mock_swarm, None);

let mut broadcasted_messages_receiver = network_manager
.register_broadcast_topic::<Bytes>(topic.clone(), BUFFER_SIZE)
.unwrap()
.broadcasted_messages_receiver;
let BroadcastTopicChannels { broadcast_client_channels, .. } =
network_manager.register_broadcast_topic::<Bytes>(topic.clone(), BUFFER_SIZE).unwrap();
let BroadcastClientChannels {
mut reported_messages_sender,
mut broadcasted_messages_receiver,
..
} = broadcast_client_channels;

tokio::select! {
_ = network_manager.run() => panic!("network manager ended"),
// We need to do the entire calculation in the future here so that the network will keep
// running while we call report_callback.
reported_peer_result = tokio::time::timeout(TIMEOUT, broadcasted_messages_receiver.next())
.then(|result| {
let (message_result, broadcasted_message_manager) = result.unwrap().unwrap();
assert_eq!(message, message_result.unwrap());
broadcasted_message_manager.report_peer();
tokio::time::timeout(TIMEOUT, reported_peer_receiver.next())
}) => {
reported_peer_result = tokio::time::timeout(TIMEOUT, async {
let result = broadcasted_messages_receiver.next().await;
let (message_result, broadcasted_message_manager) = result.unwrap();
assert_eq!(message, message_result.unwrap());
reported_messages_sender.send(broadcasted_message_manager).await.unwrap();
reported_peer_receiver.next().await
}) => {
assert_eq!(originated_peer_id, reported_peer_result.unwrap().unwrap());
}
}
Expand Down
48 changes: 45 additions & 3 deletions crates/papyrus_network/src/network_manager/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ use futures::sink::With;
use futures::stream::Map;
use futures::{FutureExt, SinkExt, StreamExt};
use libp2p::gossipsub::SubscriptionError;
use libp2p::PeerId;

use super::{
BroadcastClientChannels,
BroadcastedMessageManager,
GenericReceiver,
ReportReceiver,
Expand Down Expand Up @@ -74,8 +76,36 @@ where
}

pub fn create_test_broadcasted_message_manager() -> BroadcastedMessageManager {
let (report_sender, _report_receiver) = oneshot::channel::<()>();
BroadcastedMessageManager { report_sender }
BroadcastedMessageManager { peer_id: PeerId::random() }
}

// TODO: remove either this method or the one below
// TODO: also return reported_messages_receiver and continue_propagation_receiver, possibly wrapped
// in a struct
pub fn create_test_broadcast_client_channels<T>()
-> (Sender<(Bytes, BroadcastedMessageManager)>, BroadcastClientChannels<T>)
where
T: TryFrom<Bytes>,
{
let (broadcasted_messages_sender, broadcasted_messages_receiver) =
futures::channel::mpsc::channel(CHANNEL_BUFFER_SIZE);
let (reported_messages_sender, _mock_reported_messages_receiver) =
futures::channel::mpsc::channel(CHANNEL_BUFFER_SIZE);
let (continue_propagation_sender, _mock_continue_propagation_receiver) =
futures::channel::mpsc::channel(CHANNEL_BUFFER_SIZE);

let broadcasted_messages_fn: BroadcastReceivedMessagesConverterFn<T> =
|(x, broadcasted_message_manager)| (T::try_from(x), broadcasted_message_manager);
let broadcasted_messages_receiver = broadcasted_messages_receiver.map(broadcasted_messages_fn);

(
broadcasted_messages_sender,
BroadcastClientChannels {
broadcasted_messages_receiver,
reported_messages_sender: Box::new(reported_messages_sender),
continue_propagation_sender: Box::new(continue_propagation_sender),
},
)
}

const CHANNEL_BUFFER_SIZE: usize = 10000;
Expand All @@ -98,8 +128,20 @@ where
|(x, report_sender)| (T::try_from(x), report_sender);
let broadcasted_messages_receiver = broadcasted_messages_receiver.map(broadcasted_messages_fn);

let (reported_messages_sender, _mock_reported_messages_receiver) =
futures::channel::mpsc::channel(CHANNEL_BUFFER_SIZE);

let (continue_propagation_sender, _mock_continue_propagation_receiver) =
futures::channel::mpsc::channel(CHANNEL_BUFFER_SIZE);

let broadcast_client_channels = BroadcastClientChannels {
broadcasted_messages_receiver,
reported_messages_sender: Box::new(reported_messages_sender),
continue_propagation_sender: Box::new(continue_propagation_sender),
};

let subscriber_channels =
BroadcastTopicChannels { messages_to_broadcast_sender, broadcasted_messages_receiver };
BroadcastTopicChannels { messages_to_broadcast_sender, broadcast_client_channels };

let mock_broadcasted_messages_fn: MockBroadcastedMessagesFn<T> =
|(x, report_call_back)| ready(Ok((Bytes::from(x), report_call_back)));
Expand Down
Loading

0 comments on commit 13235ae

Please sign in to comment.