diff --git a/sdk/src/network/mod.rs b/sdk/src/network/mod.rs index 2cf4a429..0f8da609 100644 --- a/sdk/src/network/mod.rs +++ b/sdk/src/network/mod.rs @@ -5,6 +5,8 @@ use dashmap::DashMap; use futures::{Stream, StreamExt}; use serde::{Deserialize, Serialize}; use sp_core::{ecdsa, sha2_256}; +use std::cmp::Reverse; +use std::collections::{BinaryHeap, HashMap}; use std::ops::{Deref, DerefMut}; use std::pin::Pin; use std::sync::Arc; @@ -101,10 +103,49 @@ pub trait Network: Send + Sync + 'static { } } +#[derive(Debug, Serialize, Deserialize)] +struct SequencedMessage { + seq: u64, + payload: Vec, +} + +#[derive(Debug)] +struct PendingMessage { + seq: u64, + message: ProtocolMessage, +} + +impl PartialEq for PendingMessage { + fn eq(&self, other: &Self) -> bool { + self.seq == other.seq + } +} + +impl Eq for PendingMessage {} + +impl PartialOrd for PendingMessage { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PendingMessage { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.seq.cmp(&other.seq) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MultiplexedMessage { + stream_id: StreamKey, + payload: SequencedMessage, +} + pub struct NetworkMultiplexer { to_receiving_streams: ActiveStreams, unclaimed_receiving_streams: Arc>, tx_to_networking_layer: MultiplexedSender, + sequence_numbers: Arc>, } type ActiveStreams = Arc>>; @@ -175,10 +216,13 @@ impl Drop for MultiplexedReceiver { } } -#[derive(Debug, Serialize, Deserialize)] -struct MultiplexedMessage { - payload: Vec, +// Since a single stream can be used for multiple users, and, multiple users assign seq's independently, +// we need to make a key that is unique for each (send->dest) pair and stream. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +struct CompoundStreamKey { stream_id: StreamKey, + send_user_id: UserID, + recv_user_id: Option, } impl NetworkMultiplexer { @@ -190,29 +234,45 @@ impl NetworkMultiplexer { unclaimed_receiving_streams: Arc::new(DashMap::new()), tx_to_networking_layer: MultiplexedSender { inner: tx_to_networking_layer, - stream_id: Default::default(), // Start with an arbitrary stream ID, this won't get used + stream_id: Default::default(), }, + sequence_numbers: Arc::new(DashMap::new()), }; let active_streams = this.to_receiving_streams.clone(); let unclaimed_streams = this.unclaimed_receiving_streams.clone(); let tx_to_networking_layer = this.tx_to_networking_layer.clone(); + let sequence_numbers = this.sequence_numbers.clone(); + drop(tokio::spawn(async move { let network_clone = &network; let task1 = async move { while let Some((stream_id, msg)) = rx_from_substreams.recv().await { - crate::info!( - "Round {}: Sending RAW message from {} to {:?} (id: {})", - msg.identifier_info.round_id, + let compound_key = CompoundStreamKey { + stream_id, + send_user_id: msg.sender.user_id, + recv_user_id: msg.recipient.as_ref().map(|p| p.user_id), + }; + + let mut seq = sequence_numbers.entry(compound_key).or_insert(0); + let current_seq = *seq; + *seq += 1; + + crate::trace!( + "SEND SEQ {current_seq} FROM {} | StreamKey: {:?}", msg.sender.user_id, - msg.recipient.as_ref().map(|p| p.user_id), - msg.identifier_info.message_id, + hex::encode(bincode2::serialize(&compound_key).unwrap()) ); + let multiplexed_message = MultiplexedMessage { - payload: msg.payload, stream_id, + payload: SequencedMessage { + seq: current_seq, + payload: msg.payload, + }, }; + let message = ProtocolMessage { identifier_info: msg.identifier_info, sender: msg.sender, @@ -229,37 +289,97 @@ impl NetworkMultiplexer { }; let task2 = async move { + let mut pending_messages: HashMap< + CompoundStreamKey, + BinaryHeap>, + > = Default::default(); + let mut expected_seqs: HashMap = Default::default(); + while let Some(mut msg) = network_clone.next_message().await { - crate::info!( - "Round {}: Received RAW message from {} to {:?} (id: {})", - msg.identifier_info.round_id, - msg.sender.user_id, - msg.recipient.as_ref().map(|p| p.user_id), - msg.identifier_info.message_id, - ); if let Ok(multiplexed_message) = bincode2::deserialize::(&msg.payload) { let stream_id = multiplexed_message.stream_id; - msg.payload = multiplexed_message.payload; - // Two possibilities: the entry already exists, or, it doesn't and we need to enqueue + let compound_key = CompoundStreamKey { + stream_id, + send_user_id: msg.sender.user_id, + recv_user_id: msg.recipient.as_ref().map(|p| p.user_id), + }; + let seq = multiplexed_message.payload.seq; + msg.payload = multiplexed_message.payload.payload; + + // Get or create the pending heap for this stream + let pending = pending_messages.entry(compound_key).or_default(); + let expected_seq = expected_seqs.entry(compound_key).or_default(); + + let send_user = msg.sender.user_id; + let recv_user = msg + .recipient + .as_ref() + .map(|p| p.user_id as i32) + .unwrap_or(-1); + + let compound_key_hex = + hex::encode(bincode2::serialize(&compound_key).unwrap()); + crate::trace!( + "RECV SEQ {seq} FROM {} as user {:?} | Expecting: {} | StreamKey: {:?}", + send_user, + recv_user, + *expected_seq, + compound_key_hex, + ); + + // Add the message to pending + pending.push(Reverse(PendingMessage { seq, message: msg })); + + // Try to deliver messages in order if let Some(active_receiver) = active_streams.get(&stream_id) { - if let Err(err) = active_receiver.send(msg) { - crate::error!(%err, "Failed to send message to receiver"); - // Delete entry since the receiver is dead - let _ = active_streams.remove(&stream_id); + while let Some(Reverse(PendingMessage { seq, message: _ })) = + pending.peek() + { + if *seq != *expected_seq { + break; + } + + crate::trace!("DELIVERING SEQ {seq} FROM {} as user {:?} | Expecting: {} | StreamKey: {:?}", send_user, recv_user, *expected_seq, compound_key_hex); + + *expected_seq += 1; + + let message = pending.pop().unwrap().0.message; + + if let Err(err) = active_receiver.send(message) { + crate::error!(%err, "Failed to send message to receiver"); + let _ = active_streams.remove(&stream_id); + break; + } } } else { - // Second possibility: the entry does not exist, and another substream is received for this task. - // In this case, reserve an entry locally and store the message in the unclaimed streams. Later, - // when the user attempts to open the substream with the same ID, the message will be sent to the user. let (tx, rx) = Self::create_multiplexed_stream_inner( tx_to_networking_layer.clone(), &active_streams, stream_id, ); - let _ = tx.send(msg); - //let _ = active_streams.insert(stream_id, tx); TX already passed into active_streams above + + // Deliver any pending messages in order + while let Some(Reverse(PendingMessage { seq, message: _ })) = + pending.peek() + { + if *seq != *expected_seq { + break; + } + + crate::trace!("EARLY DELIVERY SEQ {seq} FROM {} as user {:?} | Expecting: {} | StreamKey: {:?}", send_user, recv_user, *expected_seq, compound_key_hex); + + *expected_seq += 1; + + let message = pending.pop().unwrap().0.message; + + if let Err(err) = tx.send(message) { + crate::error!(%err, "Failed to send message to receiver"); + break; + } + } + let _ = unclaimed_streams.insert(stream_id, rx); } } @@ -440,7 +560,6 @@ mod tests { velocity: (u16, u16, u16), } - // NOTE: if you lower the number of nodes to 2, this test passes without issues. const NODE_COUNT: u16 = 10; pub fn setup_log() { @@ -582,7 +701,7 @@ mod tests { assert!( old.is_none(), "Duplicate message from node {}", - msg.sender.user_id + msg.sender.user_id, ); // Break if all messages are received if msgs.len() == usize::from(NODE_COUNT) - 1 { @@ -637,7 +756,7 @@ mod tests { assert!( old.is_none(), "Duplicate message from node {}", - msg.sender.user_id + msg.sender.user_id, ); // Break if all messages are received if msgs.len() == usize::from(NODE_COUNT) - 1 { @@ -685,7 +804,7 @@ mod tests { assert!( old.is_none(), "Duplicate message from node {}", - msg.sender.user_id + msg.sender.user_id, ); // Break if all messages are received if msgs.len() == usize::from(NODE_COUNT) - 1 { @@ -699,18 +818,133 @@ mod tests { Ok(()) } - fn node() -> gossip::GossipHandle { + fn node_with_id() -> (gossip::GossipHandle, ecdsa::Pair) { let identity = libp2p::identity::Keypair::generate_ed25519(); let ecdsa_key = sp_core::ecdsa::Pair::generate().0; let bind_port = 0; - setup::start_p2p_network(setup::NetworkConfig::new_service_network( + let handle = setup::start_p2p_network(setup::NetworkConfig::new_service_network( identity, - ecdsa_key, + ecdsa_key.clone(), Default::default(), bind_port, TOPIC, )) - .unwrap() + .unwrap(); + + (handle, ecdsa_key) + } + + fn node() -> gossip::GossipHandle { + node_with_id().0 + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_stress_test_multiplexer() { + setup_log(); + crate::info!("Starting test_stress_test_multiplexer"); + + let (network0, id0) = node_with_id(); + let (network1, id1) = node_with_id(); + let mut networks = vec![network0, network1]; + + wait_for_nodes_connected(&networks).await; + + let (network0, network1) = (networks.remove(0), networks.remove(0)); + + let public0 = id0.public(); + let public1 = id1.public(); + + let multiplexer0 = NetworkMultiplexer::new(network0); + let multiplexer1 = NetworkMultiplexer::new(network1); + + let stream_key = StreamKey { + task_hash: sha2_256(&[255u8]), + round_id: 100, + }; + + let sub0 = multiplexer0.multiplex(stream_key); + let sub1 = multiplexer1.multiplex(stream_key); + + const MESSAGE_COUNT: u64 = 100; + + #[derive(Serialize, Deserialize)] + struct StressTestPayload { + value: u64, + } + + let handle0 = tokio::spawn(async move { + let sub0 = &sub0; + + let recv_task = async move { + let mut count = 0; + while let Some(msg) = sub0.next_message().await { + assert_eq!(msg.sender.user_id, 1, "Bad sender"); + assert_eq!(msg.recipient.unwrap().user_id, 0, "Bad recipient"); + + let number: StressTestPayload = deserialize(&msg.payload).unwrap(); + assert_eq!(number.value, count, "Bad message order"); + count += 1; + + if count == MESSAGE_COUNT { + break; + } + } + }; + + let send_task = async move { + for i in 0..MESSAGE_COUNT { + let msg = GossipHandle::build_protocol_message( + IdentifierInfo::default(), + 0, + Some(1), + &StressTestPayload { value: i }, + Some(public0), + Some(public1), + ); + sub0.send(msg).unwrap(); + } + }; + + tokio::join!(recv_task, send_task) + }); + + let handle1 = tokio::spawn(async move { + let sub1 = &sub1; + + let recv_task = async move { + let mut count = 0; + while let Some(msg) = sub1.next_message().await { + assert_eq!(msg.sender.user_id, 0, "Bad sender"); + assert_eq!(msg.recipient.unwrap().user_id, 1, "Bad recipient"); + let number: StressTestPayload = deserialize(&msg.payload).unwrap(); + assert_eq!(number.value, count, "Bad message order"); + count += 1; + + if count == MESSAGE_COUNT { + break; + } + } + }; + + let send_task = async move { + for i in 0..MESSAGE_COUNT { + let msg = GossipHandle::build_protocol_message( + IdentifierInfo::default(), + 1, + Some(0), + &StressTestPayload { value: i }, + Some(public1), + Some(public0), + ); + sub1.send(msg).unwrap(); + } + }; + + tokio::join!(recv_task, send_task) + }); + + // Wait for all tasks to complete + tokio::try_join!(handle0, handle1).unwrap(); } #[tokio::test(flavor = "multi_thread")]