From 1558c1cc51847c3c98dae3cd1b48aba5d75d0aef Mon Sep 17 00:00:00 2001 From: Thomas Braun Date: Thu, 7 Nov 2024 21:37:03 -0500 Subject: [PATCH] chore: add nested multiplexing test --- sdk/src/network/mod.rs | 118 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 105 insertions(+), 13 deletions(-) diff --git a/sdk/src/network/mod.rs b/sdk/src/network/mod.rs index 18d3a538..efd6923a 100644 --- a/sdk/src/network/mod.rs +++ b/sdk/src/network/mod.rs @@ -21,7 +21,7 @@ pub mod matchbox; pub mod messaging; pub mod setup; -#[derive(Debug, Serialize, Deserialize, Clone, Copy)] +#[derive(Debug, Serialize, Deserialize, Clone, Copy, Default)] pub struct IdentifierInfo { pub block_id: Option, pub session_id: Option, @@ -317,6 +317,12 @@ impl NetworkMultiplexer { } } +impl From for NetworkMultiplexer { + fn from(network: N) -> Self { + Self::new(network) + } +} + pub struct SubNetwork { tx: MultiplexedSender, rx: Mutex, @@ -423,18 +429,14 @@ mod tests { .try_init(); } - #[tokio::test(flavor = "multi_thread")] - async fn p2p() { - setup_log(); - let nodes = stream::iter(0..NODE_COUNT) - .map(|_| node()) - .collect::>() - .await; + async fn wait_for_nodes_connected(nodes: &[GossipHandle]) { + let node_count = nodes.len(); + // wait for the nodes to connect to each other - let max_retries = 30 * NODE_COUNT; + let max_retries = 30 * node_count; let mut retry = 0; loop { - crate::debug!(%NODE_COUNT, %max_retries, %retry, "Checking if all nodes are connected to each other"); + crate::debug!(%node_count, %max_retries, %retry, "Checking if all nodes are connected to each other"); let connected = nodes .iter() .map(|node| node.connected_peers()) @@ -444,9 +446,10 @@ mod tests { .iter() .enumerate() .inspect(|(node, peers)| crate::debug!("Node {node} has {peers} connected peers")) - .all(|(_, &peers)| peers == usize::from(NODE_COUNT) - 1); + .all(|(_, &peers)| peers == node_count - 1); if all_connected { - break; + crate::debug!("All nodes are connected to each other"); + return; } tokio::time::sleep(tokio::time::Duration::from_millis(300)).await; retry += 1; @@ -454,7 +457,18 @@ mod tests { panic!("Failed to connect all nodes to each other"); } } - crate::debug!("All nodes are connected to each other"); + } + + #[tokio::test(flavor = "multi_thread")] + async fn p2p() { + setup_log(); + let nodes = stream::iter(0..NODE_COUNT) + .map(|_| node()) + .collect::>() + .await; + + wait_for_nodes_connected(&nodes).await; + let mut tasks = Vec::new(); let barrier = Arc::new(Barrier::new(NODE_COUNT as usize)); for (i, node) in nodes.into_iter().enumerate() { @@ -681,4 +695,82 @@ mod tests { )) .unwrap() } + + #[tokio::test(flavor = "multi_thread")] + async fn test_nested_multiplexer() { + setup_log(); + crate::info!("Starting test_nested_multiplexer"); + let network0 = node(); + let network1 = node(); + + let mut networks = vec![network0, network1]; + + wait_for_nodes_connected(&networks).await; + + let (network0, network1) = (networks.remove(0), networks.remove(0)); + + async fn nested_multiplex( + cur_depth: usize, + max_depth: usize, + network0: N, + network1: N, + ) { + crate::info!("At nested depth = {cur_depth}/{max_depth}"); + + if cur_depth == max_depth { + return; + } + + let multiplexer0 = NetworkMultiplexer::new(network0); + let multiplexer1 = NetworkMultiplexer::new(network1); + + let stream_key = StreamKey { + task_hash: sha2_256(&[(cur_depth % 255) as u8]), + round_id: 0, + }; + + let subnetwork0 = multiplexer0.multiplex(stream_key); + let subnetwork1 = multiplexer1.multiplex(stream_key); + + // Send a message in the subnetwork0 to subnetwork1 and vice versa, assert values of message + let payload = vec![1, 2, 3]; + let msg = GossipHandle::build_protocol_message( + IdentifierInfo::default(), + 0, + Some(1), + &payload, + None, + None, + ); + + subnetwork0.send(msg.clone()).unwrap(); + + let received_msg = subnetwork1.recv().await.unwrap(); + assert_eq!(received_msg.payload, msg.payload); + + let msg = GossipHandle::build_protocol_message( + IdentifierInfo::default(), + 1, + Some(0), + &payload, + None, + None, + ); + + subnetwork1.send(msg.clone()).unwrap(); + + let received_msg = subnetwork0.recv().await.unwrap(); + assert_eq!(received_msg.payload, msg.payload); + + Box::pin(nested_multiplex( + cur_depth + 1, + max_depth, + subnetwork0, + subnetwork1, + )) + .await + } + + nested_multiplex(0, 10, network0, network1).await; + } }