From c1037b3752571de420fe161673629b4b61049169 Mon Sep 17 00:00:00 2001 From: Thomas Braun <38082993+tbraun96@users.noreply.github.com> Date: Wed, 20 Nov 2024 09:34:35 -0500 Subject: [PATCH] fix: race conditions in multiplexer (#486) * fix: network wrapper to use [0u8; 32] task hash * refactor: cleanup round_based_compat fix: don't poll newly created futures * fix: race conditions, out of order delivery * chore: cleanup --- Cargo.lock | 11 - Cargo.toml | 1 - sdk/Cargo.toml | 1 - sdk/src/event_listener/tangle/mod.rs | 2 +- sdk/src/network/mod.rs | 425 +++++++++++++++++++++----- sdk/src/network/round_based_compat.rs | 230 +++++++------- 6 files changed, 459 insertions(+), 211 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 881348d8..f118c458 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1754,16 +1754,6 @@ dependencies = [ "serde", ] -[[package]] -name = "bincode2" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f49f6183038e081170ebbbadee6678966c7d54728938a3e7de7f4e780770318f" -dependencies = [ - "byteorder", - "serde", -] - [[package]] name = "bindgen" version = "0.69.5" @@ -4834,7 +4824,6 @@ dependencies = [ "auto_impl", "backon", "bincode", - "bincode2", "bollard", "clap", "color-eyre", diff --git a/Cargo.toml b/Cargo.toml index ac7a0c67..ad3d0e2b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -196,7 +196,6 @@ testcontainers = { version = "0.20.1" } # Symbiotic symbiotic-rs = { version = "0.1.0" } dashmap = "6.1.0" -bincode2 = "2.0.1" lru-mem = "0.3.0" [profile.dev.package.backtrace] diff --git a/sdk/Cargo.toml b/sdk/Cargo.toml index c2267046..9c111096 100644 --- a/sdk/Cargo.toml +++ b/sdk/Cargo.toml @@ -99,7 +99,6 @@ lru-mem = { workspace = true } sysinfo = { workspace = true } dashmap = { workspace = true } lazy_static = "1.5.0" -bincode2 = { workspace = true } color-eyre = { workspace = true } diff --git a/sdk/src/event_listener/tangle/mod.rs b/sdk/src/event_listener/tangle/mod.rs index d8ad0990..496bd897 100644 --- a/sdk/src/event_listener/tangle/mod.rs +++ b/sdk/src/event_listener/tangle/mod.rs @@ -176,7 +176,7 @@ impl .filter_map(|r| r.ok().and_then(E::try_decode)) .collect::>(); - crate::info!("Found {} possible events ...", events.len()); + crate::debug!("Found {} possible events ...", events.len()); self.enqueued_events = events; } } diff --git a/sdk/src/network/mod.rs b/sdk/src/network/mod.rs index 55d9a671..7066af2e 100644 --- a/sdk/src/network/mod.rs +++ b/sdk/src/network/mod.rs @@ -5,7 +5,9 @@ use dashmap::DashMap; use futures::{Stream, StreamExt}; use serde::{Deserialize, Serialize}; use sp_core::{ecdsa, sha2_256}; -use std::ops::Deref; +use std::cmp::Reverse; +use std::collections::{BinaryHeap, HashMap}; +use std::ops::{Deref, DerefMut}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -24,31 +26,15 @@ pub mod setup; #[derive(Debug, Serialize, Deserialize, Clone, Copy, Default)] pub struct IdentifierInfo { - pub block_id: Option, - pub session_id: Option, - pub retry_id: Option, - pub task_id: Option, + pub message_id: u64, + pub round_id: u16, } impl Display for IdentifierInfo { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let block_id = self - .block_id - .map(|id| format!("block_id: {}", id)) - .unwrap_or_default(); - let session_id = self - .session_id - .map(|id| format!("session_id: {}", id)) - .unwrap_or_default(); - let retry_id = self - .retry_id - .map(|id| format!("retry_id: {}", id)) - .unwrap_or_default(); - let task_id = self - .task_id - .map(|id| 
format!("task_id: {}", id)) - .unwrap_or_default(); - write!(f, "{} {} {} {}", block_id, session_id, retry_id, task_id) + let message_id = format!("message_id: {}", self.message_id); + let round_id = format!("round_id: {}", self.round_id); + write!(f, "{} {}", message_id, round_id) } } @@ -117,10 +103,49 @@ pub trait Network: Send + Sync + 'static { } } +#[derive(Debug, Serialize, Deserialize)] +struct SequencedMessage { + seq: u64, + payload: Vec, +} + +#[derive(Debug)] +struct PendingMessage { + seq: u64, + message: ProtocolMessage, +} + +impl PartialEq for PendingMessage { + fn eq(&self, other: &Self) -> bool { + self.seq == other.seq + } +} + +impl Eq for PendingMessage {} + +impl PartialOrd for PendingMessage { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PendingMessage { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.seq.cmp(&other.seq) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct MultiplexedMessage { + stream_id: StreamKey, + payload: SequencedMessage, +} + pub struct NetworkMultiplexer { to_receiving_streams: ActiveStreams, unclaimed_receiving_streams: Arc>, tx_to_networking_layer: MultiplexedSender, + sequence_numbers: Arc>, } type ActiveStreams = Arc>>; @@ -179,16 +204,25 @@ impl Deref for MultiplexedReceiver { } } +impl DerefMut for MultiplexedReceiver { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + impl Drop for MultiplexedReceiver { fn drop(&mut self) { let _ = self.active_streams.remove(&self.stream_id); } } -#[derive(Debug, Serialize, Deserialize)] -struct MultiplexedMessage { - payload: Vec, +// Since a single stream can be used for multiple users, and, multiple users assign seq's independently, +// we need to make a key that is unique for each (send->dest) pair and stream. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +struct CompoundStreamKey { stream_id: StreamKey, + send_user_id: UserID, + recv_user_id: Option, } impl NetworkMultiplexer { @@ -200,27 +234,50 @@ impl NetworkMultiplexer { unclaimed_receiving_streams: Arc::new(DashMap::new()), tx_to_networking_layer: MultiplexedSender { inner: tx_to_networking_layer, - stream_id: Default::default(), // Start with an arbitrary stream ID, this won't get used + stream_id: Default::default(), }, + sequence_numbers: Arc::new(DashMap::new()), }; let active_streams = this.to_receiving_streams.clone(); let unclaimed_streams = this.unclaimed_receiving_streams.clone(); let tx_to_networking_layer = this.tx_to_networking_layer.clone(); + let sequence_numbers = this.sequence_numbers.clone(); + drop(tokio::spawn(async move { let network_clone = &network; let task1 = async move { - while let Some((stream_id, proto_message)) = rx_from_substreams.recv().await { + while let Some((stream_id, msg)) = rx_from_substreams.recv().await { + let compound_key = CompoundStreamKey { + stream_id, + send_user_id: msg.sender.user_id, + recv_user_id: msg.recipient.as_ref().map(|p| p.user_id), + }; + + let mut seq = sequence_numbers.entry(compound_key).or_insert(0); + let current_seq = *seq; + *seq += 1; + + crate::trace!( + "SEND SEQ {current_seq} FROM {} | StreamKey: {:?}", + msg.sender.user_id, + hex::encode(bincode::serialize(&compound_key).unwrap()) + ); + let multiplexed_message = MultiplexedMessage { - payload: proto_message.payload, stream_id, + payload: SequencedMessage { + seq: current_seq, + payload: msg.payload, + }, }; + let message = ProtocolMessage { - identifier_info: proto_message.identifier_info, - sender: proto_message.sender, - recipient: proto_message.recipient, - payload: bincode2::serialize(&multiplexed_message) + identifier_info: msg.identifier_info, + sender: msg.sender, + recipient: msg.recipient, + payload: bincode::serialize(&multiplexed_message) .expect("Failed to serialize message"), }; @@ -232,32 +289,101 @@ impl NetworkMultiplexer { }; let task2 = async move { + let mut pending_messages: HashMap< + CompoundStreamKey, + BinaryHeap>, + > = Default::default(); + let mut expected_seqs: HashMap = Default::default(); + while let Some(mut msg) = network_clone.next_message().await { if let Ok(multiplexed_message) = - bincode2::deserialize::(&msg.payload) + bincode::deserialize::(&msg.payload) { let stream_id = multiplexed_message.stream_id; - msg.payload = multiplexed_message.payload; - // Two possibilities: the entry already exists, or, it doesn't and we need to enqueue + let compound_key = CompoundStreamKey { + stream_id, + send_user_id: msg.sender.user_id, + recv_user_id: msg.recipient.as_ref().map(|p| p.user_id), + }; + let seq = multiplexed_message.payload.seq; + msg.payload = multiplexed_message.payload.payload; + + // Get or create the pending heap for this stream + let pending = pending_messages.entry(compound_key).or_default(); + let expected_seq = expected_seqs.entry(compound_key).or_default(); + + let send_user = msg.sender.user_id; + let recv_user = msg + .recipient + .as_ref() + .map(|p| p.user_id as i32) + .unwrap_or(-1); + + let compound_key_hex = + hex::encode(bincode::serialize(&compound_key).unwrap()); + crate::trace!( + "RECV SEQ {seq} FROM {} as user {:?} | Expecting: {} | StreamKey: {:?}", + send_user, + recv_user, + *expected_seq, + compound_key_hex, + ); + + // Add the message to pending + pending.push(Reverse(PendingMessage { seq, message: msg })); + + // Try to 
deliver messages in order if let Some(active_receiver) = active_streams.get(&stream_id) { - if let Err(err) = active_receiver.send(msg) { - crate::error!(%err, "Failed to send message to receiver"); - // Delete entry since the receiver is dead - let _ = active_streams.remove(&stream_id); + while let Some(Reverse(PendingMessage { seq, message: _ })) = + pending.peek() + { + if *seq != *expected_seq { + break; + } + + crate::trace!("DELIVERING SEQ {seq} FROM {} as user {:?} | Expecting: {} | StreamKey: {:?}", send_user, recv_user, *expected_seq, compound_key_hex); + + *expected_seq += 1; + + let message = pending.pop().unwrap().0.message; + + if let Err(err) = active_receiver.send(message) { + crate::error!(%err, "Failed to send message to receiver"); + let _ = active_streams.remove(&stream_id); + break; + } } } else { - // Second possibility: the entry does not exist, and another substream is received for this task. - // In this case, reserve an entry locally and store the message in the unclaimed streams. Later, - // when the user attempts to open the substream with the same ID, the message will be sent to the user. let (tx, rx) = Self::create_multiplexed_stream_inner( tx_to_networking_layer.clone(), &active_streams, stream_id, ); - let _ = tx.send(msg); - //let _ = active_streams.insert(stream_id, tx); TX already passed into active_streams above + + // Deliver any pending messages in order + while let Some(Reverse(PendingMessage { seq, message: _ })) = + pending.peek() + { + if *seq != *expected_seq { + break; + } + + crate::warn!("EARLY DELIVERY SEQ {seq} FROM {} as user {:?} | Expecting: {} | StreamKey: {:?}", send_user, recv_user, *expected_seq, compound_key_hex); + + *expected_seq += 1; + + let message = pending.pop().unwrap().0.message; + + if let Err(err) = tx.send(message) { + crate::error!(%err, "Failed to send message to receiver"); + break; + } + } + let _ = unclaimed_streams.insert(stream_id, rx); } + } else { + crate::error!("Failed to deserialize message"); } } }; @@ -282,7 +408,7 @@ impl NetworkMultiplexer { tx_to_networking_layer.stream_id = id; return SubNetwork { tx: tx_to_networking_layer, - rx: unclaimed.1.into(), + rx: Some(unclaimed.1.into()), }; } @@ -292,7 +418,42 @@ impl NetworkMultiplexer { id, ); - SubNetwork { tx, rx: rx.into() } + SubNetwork { + tx, + rx: Some(rx.into()), + } + } + + /// Creates a subnetwork, and also forwards all messages to the given channel. The network cannot be used to + /// receive messages since the messages will be forwarded to the provided channel. 
+ pub fn multiplex_with_forwarding( + &self, + id: impl Into, + forward_tx: tokio::sync::mpsc::UnboundedSender, + ) -> SubNetwork { + let mut network = self.multiplex(id); + let rx = network.rx.take().expect("Rx from network should be Some"); + let forwarding_task = async move { + let mut rx = rx.into_inner(); + while let Some(msg) = rx.recv().await { + crate::info!( + "Round {}: Received message from {} to {:?} (id: {})", + msg.identifier_info.round_id, + msg.sender.user_id, + msg.recipient.as_ref().map(|p| p.user_id), + msg.identifier_info.message_id, + ); + if let Err(err) = forward_tx.send(msg) { + crate::error!(%err, "Failed to forward message to network"); + // TODO: Add AtomicBool to make sending stop + break; + } + } + }; + + drop(tokio::spawn(forwarding_task)); + + network } fn create_multiplexed_stream_inner( @@ -327,7 +488,7 @@ impl From for NetworkMultiplexer { pub struct SubNetwork { tx: MultiplexedSender, - rx: Mutex, + rx: Option>, } impl SubNetwork { @@ -336,7 +497,7 @@ impl SubNetwork { } pub async fn recv(&self) -> Option { - self.rx.lock().await.next().await + self.rx.as_ref()?.lock().await.next().await } } @@ -351,14 +512,6 @@ impl Network for SubNetwork { } } -impl Stream for SubNetwork { - type Item = ProtocolMessage; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(self.rx.get_mut()).poll_next(cx) - } -} - pub fn deserialize<'a, T>(data: &'a [u8]) -> Result where T: Deserialize<'a>, @@ -409,7 +562,6 @@ mod tests { velocity: (u16, u16, u16), } - // NOTE: if you lower the number of nodes to 2, this test passes without issues. const NODE_COUNT: u16 = 10; pub fn setup_log() { @@ -492,17 +644,17 @@ mod tests { // used throughout the program must also use the multiplexer to prevent mixed messages. 
let multiplexer = NetworkMultiplexer::new(node); - let mut round1_network = multiplexer.multiplex(StreamKey { + let round1_network = multiplexer.multiplex(StreamKey { task_hash, // To differentiate between different instances of a running program (i.e., a task) round_id: 0, // To differentiate between different subsets of a running task }); - let mut round2_network = multiplexer.multiplex(StreamKey { + let round2_network = multiplexer.multiplex(StreamKey { task_hash, // To differentiate between different instances of a running program (i.e., a task) round_id: 1, // To differentiate between different subsets of a running task }); - let mut round3_network = multiplexer.multiplex(StreamKey { + let round3_network = multiplexer.multiplex(StreamKey { task_hash, // To differentiate between different instances of a running program (i.e., a task) round_id: 2, // To differentiate between different subsets of a running task }); @@ -519,10 +671,8 @@ mod tests { GossipHandle::build_protocol_message( IdentifierInfo { - block_id: None, - session_id: None, - retry_id: None, - task_id: None, + message_id: 0, + round_id: 0, }, i, None, @@ -539,7 +689,7 @@ mod tests { // Wait for all other nodes to send their messages let mut msgs = BTreeMap::new(); - while let Some(msg) = round1_network.next().await { + while let Some(msg) = round1_network.recv().await { let m = deserialize::(&msg.payload).unwrap(); crate::debug!(from = %msg.sender.user_id, ?m, "Received message"); // Expecting Round1 message @@ -553,7 +703,7 @@ mod tests { assert!( old.is_none(), "Duplicate message from node {}", - msg.sender.user_id + msg.sender.user_id, ); // Break if all messages are received if msgs.len() == usize::from(NODE_COUNT) - 1 { @@ -573,10 +723,8 @@ mod tests { .map(|j| { GossipHandle::build_protocol_message( IdentifierInfo { - block_id: None, - session_id: None, - retry_id: None, - task_id: None, + message_id: 0, + round_id: 0, }, i, Some(j), @@ -596,7 +744,7 @@ mod tests { // Wait for all other nodes to send their messages let mut msgs = BTreeMap::new(); - while let Some(msg) = round2_network.next().await { + while let Some(msg) = round2_network.recv().await { let m = deserialize::(&msg.payload).unwrap(); crate::debug!(from = %msg.sender.user_id, ?m, "Received message"); // Expecting Round2 message @@ -610,7 +758,7 @@ mod tests { assert!( old.is_none(), "Duplicate message from node {}", - msg.sender.user_id + msg.sender.user_id, ); // Break if all messages are received if msgs.len() == usize::from(NODE_COUNT) - 1 { @@ -628,10 +776,8 @@ mod tests { }; GossipHandle::build_protocol_message( IdentifierInfo { - block_id: None, - session_id: None, - retry_id: None, - task_id: None, + message_id: 0, + round_id: 0, }, i, None, @@ -646,7 +792,7 @@ mod tests { // Wait for all other nodes to send their messages let mut msgs = BTreeMap::new(); - while let Some(msg) = round3_network.next().await { + while let Some(msg) = round3_network.recv().await { let m = deserialize::(&msg.payload).unwrap(); crate::debug!(from = %msg.sender.user_id, ?m, "Received message"); // Expecting Round3 message @@ -660,7 +806,7 @@ mod tests { assert!( old.is_none(), "Duplicate message from node {}", - msg.sender.user_id + msg.sender.user_id, ); // Break if all messages are received if msgs.len() == usize::from(NODE_COUNT) - 1 { @@ -674,18 +820,133 @@ mod tests { Ok(()) } - fn node() -> gossip::GossipHandle { + fn node_with_id() -> (gossip::GossipHandle, ecdsa::Pair) { let identity = libp2p::identity::Keypair::generate_ed25519(); let ecdsa_key = 
sp_core::ecdsa::Pair::generate().0; let bind_port = 0; - setup::start_p2p_network(setup::NetworkConfig::new_service_network( + let handle = setup::start_p2p_network(setup::NetworkConfig::new_service_network( identity, - ecdsa_key, + ecdsa_key.clone(), Default::default(), bind_port, TOPIC, )) - .unwrap() + .unwrap(); + + (handle, ecdsa_key) + } + + fn node() -> gossip::GossipHandle { + node_with_id().0 + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_stress_test_multiplexer() { + setup_log(); + crate::info!("Starting test_stress_test_multiplexer"); + + let (network0, id0) = node_with_id(); + let (network1, id1) = node_with_id(); + let mut networks = vec![network0, network1]; + + wait_for_nodes_connected(&networks).await; + + let (network0, network1) = (networks.remove(0), networks.remove(0)); + + let public0 = id0.public(); + let public1 = id1.public(); + + let multiplexer0 = NetworkMultiplexer::new(network0); + let multiplexer1 = NetworkMultiplexer::new(network1); + + let stream_key = StreamKey { + task_hash: sha2_256(&[255u8]), + round_id: 100, + }; + + let sub0 = multiplexer0.multiplex(stream_key); + let sub1 = multiplexer1.multiplex(stream_key); + + const MESSAGE_COUNT: u64 = 100; + + #[derive(Serialize, Deserialize)] + struct StressTestPayload { + value: u64, + } + + let handle0 = tokio::spawn(async move { + let sub0 = &sub0; + + let recv_task = async move { + let mut count = 0; + while let Some(msg) = sub0.next_message().await { + assert_eq!(msg.sender.user_id, 1, "Bad sender"); + assert_eq!(msg.recipient.unwrap().user_id, 0, "Bad recipient"); + + let number: StressTestPayload = deserialize(&msg.payload).unwrap(); + assert_eq!(number.value, count, "Bad message order"); + count += 1; + + if count == MESSAGE_COUNT { + break; + } + } + }; + + let send_task = async move { + for i in 0..MESSAGE_COUNT { + let msg = GossipHandle::build_protocol_message( + IdentifierInfo::default(), + 0, + Some(1), + &StressTestPayload { value: i }, + Some(public0), + Some(public1), + ); + sub0.send(msg).unwrap(); + } + }; + + tokio::join!(recv_task, send_task) + }); + + let handle1 = tokio::spawn(async move { + let sub1 = &sub1; + + let recv_task = async move { + let mut count = 0; + while let Some(msg) = sub1.next_message().await { + assert_eq!(msg.sender.user_id, 0, "Bad sender"); + assert_eq!(msg.recipient.unwrap().user_id, 1, "Bad recipient"); + let number: StressTestPayload = deserialize(&msg.payload).unwrap(); + assert_eq!(number.value, count, "Bad message order"); + count += 1; + + if count == MESSAGE_COUNT { + break; + } + } + }; + + let send_task = async move { + for i in 0..MESSAGE_COUNT { + let msg = GossipHandle::build_protocol_message( + IdentifierInfo::default(), + 1, + Some(0), + &StressTestPayload { value: i }, + Some(public1), + Some(public0), + ); + sub1.send(msg).unwrap(); + } + }; + + tokio::join!(recv_task, send_task) + }); + + // Wait for all tasks to complete + tokio::try_join!(handle0, handle1).unwrap(); } #[tokio::test(flavor = "multi_thread")] diff --git a/sdk/src/network/round_based_compat.rs b/sdk/src/network/round_based_compat.rs index d2872f19..012e6931 100644 --- a/sdk/src/network/round_based_compat.rs +++ b/sdk/src/network/round_based_compat.rs @@ -5,89 +5,88 @@ use std::collections::{BTreeMap, HashMap, VecDeque}; use std::sync::Arc; use crate::futures::prelude::*; -use crate::network::{self, IdentifierInfo, Network, NetworkMultiplexer, StreamKey, SubNetwork}; +use crate::network::{IdentifierInfo, NetworkMultiplexer, ProtocolMessage, StreamKey, SubNetwork}; use 
crate::subxt_core::ext::sp_core::ecdsa; -use round_based::{Delivery, Incoming, Outgoing}; -use round_based::{MessageDestination, MessageType, MsgId, PartyIndex}; +use round_based::{Delivery, Incoming, MessageType, Outgoing}; +use round_based::{MessageDestination, MsgId, PartyIndex}; use stream::{SplitSink, SplitStream}; -pub struct NetworkDeliveryWrapper { +use super::ParticipantInfo; + +pub struct NetworkDeliveryWrapper { /// The wrapped network implementation. - network: NetworkWrapper, + network: NetworkWrapper, } -impl NetworkDeliveryWrapper +impl NetworkDeliveryWrapper where - N: Network + Unpin, M: Clone + Send + Unpin + 'static, - M: serde::Serialize, - M: serde::de::DeserializeOwned, + M: serde::Serialize + serde::de::DeserializeOwned, { /// Create a new NetworkDeliveryWrapper over a network implementation with the given party index. pub fn new( - network: N, + mux: Arc, i: PartyIndex, task_hash: [u8; 32], parties: BTreeMap, ) -> Self { - let mux = NetworkMultiplexer::new(network); - // By default, we create 4 substreams for each party. - let sub_streams = (1..5) - .map(|i| { - let key = StreamKey { - // This is a dummy task hash, it should be replaced with the actual task hash - task_hash: [0u8; 32], - round_id: i, - }; - let substream = mux.multiplex(key); - (key, substream) - }) - .collect(); + let (tx_forward, rx) = tokio::sync::mpsc::unbounded_channel(); + // By default, we create 10 substreams for each party. + let mut sub_streams = HashMap::new(); + for x in 0..10 { + let key = StreamKey { + task_hash, + round_id: x, + }; + // Creates a multiplexed subnetwork, and also forwards all messages to the given channel + let _ = sub_streams.insert(key, mux.multiplex_with_forwarding(key, tx_forward.clone())); + } + let network = NetworkWrapper { me: i, mux, incoming_queue: VecDeque::new(), - outgoing_queue: VecDeque::new(), sub_streams, participants: parties, task_hash, + tx_forward, + rx, next_msg_id: Arc::new(NextMessageId::default()), - _network: core::marker::PhantomData, }; + NetworkDeliveryWrapper { network } } } /// A NetworkWrapper wraps a network implementation and implements [`Stream`] and [`Sink`] for /// it. -pub struct NetworkWrapper { +pub struct NetworkWrapper { /// The current party index. me: PartyIndex, /// Our network Multiplexer. - mux: NetworkMultiplexer, + mux: Arc, /// A Map of substreams for each round. - sub_streams: HashMap, + sub_streams: HashMap, //HashMap, /// A queue of incoming messages. + #[allow(dead_code)] incoming_queue: VecDeque>, - /// A queue of outgoing messages. - outgoing_queue: VecDeque>, /// Participants in the network with their corresponding ECDSA public keys. // Note: This is a BTreeMap to ensure that the participants are sorted by their party index. 
participants: BTreeMap, next_msg_id: Arc, + tx_forward: tokio::sync::mpsc::UnboundedSender, + rx: tokio::sync::mpsc::UnboundedReceiver, task_hash: [u8; 32], - _network: core::marker::PhantomData, } -impl Delivery for NetworkDeliveryWrapper +impl Delivery for NetworkDeliveryWrapper where - N: Network + Unpin, M: Clone + Send + Unpin + 'static, M: serde::Serialize + serde::de::DeserializeOwned, M: round_based::ProtocolMessage, { - type Send = SplitSink, Outgoing>; - type Receive = SplitStream>; + type Send = SplitSink, Outgoing>; + type Receive = SplitStream>; type SendError = crate::Error; type ReceiveError = crate::Error; @@ -97,61 +96,46 @@ where } } -impl Stream for NetworkWrapper +impl Stream for NetworkWrapper where - N: Network + Unpin, M: serde::de::DeserializeOwned + Unpin, M: round_based::ProtocolMessage, { type Item = Result, crate::Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let sub_streams = self.sub_streams.values(); - // pull all substreams - let mut messages = Vec::new(); - for sub_stream in sub_streams { - let p = sub_stream.next_message().poll_unpin(cx); - let m = match p { - Poll::Ready(Some(msg)) => msg, - _ => continue, + let res = ready!(self.get_mut().rx.poll_recv(cx)); + if let Some(res) = res { + let msg_type = if res.recipient.is_some() { + MessageType::P2P + } else { + MessageType::Broadcast }; - let msg = network::deserialize::(&m.payload)?; - messages.push((m.sender.user_id, m.recipient, msg)); - } - // Sort the incoming messages by round. - messages.sort_by_key(|(_, _, msg)| msg.round()); + let id = res.identifier_info.message_id; - let this = self.get_mut(); - // Push all messages to the incoming queue - messages - .into_iter() - .map(|(sender, recipient, msg)| Incoming { - id: this.next_msg_id.next(), - sender, - msg_type: match recipient { - Some(_) => MessageType::P2P, - None => MessageType::Broadcast, - }, + let msg = match bincode::deserialize(&res.payload) { + Ok(msg) => msg, + Err(err) => { + crate::error!(%err, "Failed to deserialize message"); + return Poll::Ready(Some(Err(crate::Error::Other(err.to_string())))); + } + }; + + Poll::Ready(Some(Ok(Incoming { msg, - }) - .for_each(|m| this.incoming_queue.push_back(m)); - // Reorder the incoming queue by round message. - let maybe_msg = this.incoming_queue.pop_front(); - if let Some(msg) = maybe_msg { - Poll::Ready(Some(Ok(msg))) + sender: res.sender.user_id, + id, + msg_type, + }))) } else { - // No message in the queue, and no message in the substreams. - // Tell the network to wake us up when a new message arrives. - cx.waker().wake_by_ref(); - Poll::Pending + Poll::Ready(None) } } } -impl Sink> for NetworkWrapper +impl Sink> for NetworkWrapper where - N: Network + Unpin, M: Unpin + serde::Serialize, M: round_based::ProtocolMessage, { @@ -161,48 +145,64 @@ where Poll::Ready(Ok(())) } - fn start_send(self: Pin<&mut Self>, msg: Outgoing) -> Result<(), Self::Error> { - self.get_mut().outgoing_queue.push_back(msg); - Ok(()) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - // Dequeue all messages and send them one by one to the network + fn start_send(self: Pin<&mut Self>, out: Outgoing) -> Result<(), Self::Error> { let this = self.get_mut(); - while let Some(out) = this.outgoing_queue.pop_front() { - // Get the substream to send the message to. 
- let key = StreamKey { - task_hash: this.task_hash, - round_id: i32::from(out.msg.round()), - }; - let substream = this - .sub_streams - .entry(key) - .or_insert_with(|| this.mux.multiplex(key)); - let identifier_info = IdentifierInfo { - block_id: None, - session_id: None, - retry_id: None, - task_id: None, - }; - let (to, to_network_id) = match out.recipient { - MessageDestination::AllParties => (None, None), - MessageDestination::OneParty(p) => (Some(p), this.participants.get(&p).cloned()), - }; - let protocol_message = N::build_protocol_message( - identifier_info, - this.me, - to, - &out.msg, - this.participants.get(&this.me).cloned(), - to_network_id, - ); - let p = substream.send_message(protocol_message).poll_unpin(cx); - match ready!(p) { - Ok(()) => continue, - Err(e) => return Poll::Ready(Err(e)), - } + let id = this.next_msg_id.next(); + + let round_id = out.msg.round(); + + crate::info!( + "Round {}: Sending message from {} to {:?} (id: {})", + round_id, + this.me, + out.recipient, + id, + ); + + // Get the substream to send the message to. + let key = StreamKey { + task_hash: this.task_hash, + round_id: i32::from(round_id), + }; + let substream = this.sub_streams.entry(key).or_insert_with(|| { + this.mux + .multiplex_with_forwarding(key, this.tx_forward.clone()) + }); + + let identifier_info = IdentifierInfo { + message_id: id, + round_id, + }; + let (to, to_network_id) = match out.recipient { + MessageDestination::AllParties => (None, None), + MessageDestination::OneParty(p) => (Some(p), this.participants.get(&p).cloned()), + }; + + if matches!(out.recipient, MessageDestination::OneParty(_)) && to_network_id.is_none() { + crate::warn!("Recipient not found when required for {:?}", out.recipient); + return Err(crate::Error::Other("Recipient not found".to_string())); } + + let protocol_message = ProtocolMessage { + identifier_info, + sender: ParticipantInfo { + user_id: this.me, + ecdsa_key: this.participants.get(&this.me).cloned(), + }, + recipient: to.map(|user_id| ParticipantInfo { + user_id, + ecdsa_key: to_network_id, + }), + payload: bincode::serialize(&out.msg).expect("Should be able to serialize message"), + }; + + match substream.send(protocol_message) { + Ok(()) => Ok(()), + Err(e) => Err(e), + } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { Poll::Ready(Ok(())) } @@ -215,7 +215,7 @@ where struct NextMessageId(AtomicU64); impl NextMessageId { - pub fn next(&self) -> MsgId { - self.0.fetch_add(1, core::sync::atomic::Ordering::Relaxed) + fn next(&self) -> MsgId { + self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed) } }
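
What follows are a few illustrative sketches of the techniques this patch introduces. They are minimal stand-alone programs using simplified, hypothetical types (CompoundKey, Envelope, Pending, Msg, ChannelStream) rather than the SDK's actual ProtocolMessage, StreamKey, SequencedMessage, or PendingMessage definitions, and they assume only the serde, bincode, tokio, and futures dependencies already present in the workspace.

On the send side, the multiplexer draws a sequence number from a counter keyed by (stream, sending user, optional receiving user), so that different sender/recipient pairs on the same stream never share a sequence space, and wraps the opaque application payload in an envelope that is bincode-serialized into the outgoing network message. A sketch of that send path:

use std::collections::HashMap;

use serde::{Deserialize, Serialize};

// Hypothetical compound key: the patch keys its counters by
// (stream, sending user, optional receiving user) so that different
// sender/recipient pairs on one stream never share a sequence space.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct CompoundKey {
    stream_id: u16, // the real key is a (task_hash, round_id) StreamKey
    send_user_id: u16,
    recv_user_id: Option<u16>,
}

// Simplified envelope in the spirit of SequencedMessage/MultiplexedMessage:
// the application payload stays opaque; the envelope adds the routing key
// and the per-pair sequence number.
#[derive(Debug, Serialize, Deserialize)]
struct Envelope {
    key: CompoundKey,
    seq: u64,
    payload: Vec<u8>,
}

#[derive(Default)]
struct SendSide {
    counters: HashMap<CompoundKey, u64>,
}

impl SendSide {
    // Assign the next sequence number for this (stream, sender, recipient)
    // tuple and produce the bytes handed to the underlying network.
    fn wrap(&mut self, key: CompoundKey, payload: Vec<u8>) -> Vec<u8> {
        let counter = self.counters.entry(key).or_insert(0);
        let seq = *counter;
        *counter += 1;
        let envelope = Envelope { key, seq, payload };
        bincode::serialize(&envelope).expect("envelope serialization should not fail")
    }
}

fn main() {
    let mut sender = SendSide::default();
    let key = CompoundKey { stream_id: 0, send_user_id: 0, recv_user_id: Some(1) };

    let first = sender.wrap(key, b"round 1 broadcast".to_vec());
    let second = sender.wrap(key, b"round 1 follow-up".to_vec());

    let a: Envelope = bincode::deserialize(&first).unwrap();
    let b: Envelope = bincode::deserialize(&second).unwrap();
    assert_eq!((a.seq, b.seq), (0, 1));
    assert_eq!(a.payload, b"round 1 broadcast".to_vec());
}

The patch keeps its counters in a DashMap because the sending task runs concurrently with other users of the multiplexer; a plain HashMap is enough for this single-threaded sketch.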
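
On the receive side, gossip delivery can be out of order, so the multiplexer buffers each arriving message in a per-compound-key BinaryHeap of Reverse-wrapped entries, which behaves as a min-heap on the sequence number, and releases messages only while the smallest buffered sequence number equals the next expected one. A self-contained sketch of that reordering strategy, with Pending standing in for the patch's PendingMessage:

use std::cmp::{Ordering, Reverse};
use std::collections::BinaryHeap;

// Simplified stand-in for the patch's PendingMessage: ordered purely by `seq`,
// which is what makes BinaryHeap<Reverse<Pending>> pop the lowest sequence first.
#[derive(Debug)]
struct Pending {
    seq: u64,
    payload: Vec<u8>,
}

impl PartialEq for Pending {
    fn eq(&self, other: &Self) -> bool {
        self.seq == other.seq
    }
}
impl Eq for Pending {}
impl PartialOrd for Pending {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Pending {
    fn cmp(&self, other: &Self) -> Ordering {
        self.seq.cmp(&other.seq)
    }
}

#[derive(Default)]
struct ReorderBuffer {
    expected_seq: u64,
    pending: BinaryHeap<Reverse<Pending>>,
}

impl ReorderBuffer {
    // Buffer an arriving message and return everything that is now deliverable
    // in order; an empty Vec means there is still a gap before `expected_seq`.
    fn push(&mut self, msg: Pending) -> Vec<Pending> {
        self.pending.push(Reverse(msg));
        let mut ready = Vec::new();
        while let Some(Reverse(head)) = self.pending.peek() {
            if head.seq != self.expected_seq {
                break;
            }
            self.expected_seq += 1;
            ready.push(self.pending.pop().unwrap().0);
        }
        ready
    }
}

fn main() {
    let mut buf = ReorderBuffer::default();
    // Sequence 1 arrives first and is held back: 0 has not been seen yet.
    assert!(buf.push(Pending { seq: 1, payload: vec![1] }).is_empty());
    // Once 0 arrives, both 0 and 1 become deliverable, in order.
    let ready = buf.push(Pending { seq: 0, payload: vec![0] });
    assert_eq!(ready.iter().map(|m| m.seq).collect::<Vec<_>>(), vec![0, 1]);
    assert_eq!(ready[0].payload, vec![0]);
}

If the smallest buffered sequence number is ahead of the expected one, nothing is delivered until the gap is filled, which is how the patch prevents a later message from the same sender on the same stream from overtaking an earlier one.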
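
The new multiplex_with_forwarding helper exists so that NetworkDeliveryWrapper never has to poll many substreams itself: every per-round SubNetwork receiver is drained by a dedicated task into one shared unbounded channel, and ordering within each substream is preserved because a single task forwards it sequentially. A rough sketch of that fan-in pattern with a hypothetical Msg type (the real code forwards ProtocolMessage values):

use tokio::sync::mpsc;

// Hypothetical message type; the real forwarding task moves ProtocolMessage
// values from a SubNetwork's receiver into the shared channel.
#[derive(Debug)]
struct Msg {
    round: u16,
    payload: Vec<u8>,
}

// Drain one per-round receiver into the shared channel. Because a single task
// owns the receiver and forwards sequentially, per-round ordering is preserved.
fn forward_into(
    mut per_round_rx: mpsc::UnboundedReceiver<Msg>,
    shared_tx: mpsc::UnboundedSender<Msg>,
) {
    tokio::spawn(async move {
        while let Some(msg) = per_round_rx.recv().await {
            if shared_tx.send(msg).is_err() {
                break; // the consumer is gone; stop forwarding
            }
        }
    });
}

#[tokio::main]
async fn main() {
    let (shared_tx, mut shared_rx) = mpsc::unbounded_channel();

    // Pre-create a handful of per-round channels, as the wrapper does for its
    // substreams, and forward them all into the single shared channel.
    let mut round_txs = Vec::new();
    for _ in 0..3 {
        let (tx, rx) = mpsc::unbounded_channel();
        forward_into(rx, shared_tx.clone());
        round_txs.push(tx);
    }

    round_txs[1].send(Msg { round: 1, payload: vec![42] }).unwrap();

    // The consumer polls exactly one receiver, no matter how many rounds exist.
    let msg = shared_rx.recv().await.unwrap();
    assert_eq!((msg.round, msg.payload), (1, vec![42]));
}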
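
Finally, the rewritten Stream implementation in round_based_compat.rs no longer creates and polls a fresh next_message() future on every wakeup (the "don't poll newly created futures" part of this fix; the old code had to compensate with wake_by_ref, effectively busy-polling, and risked losing the registered waker); it simply delegates poll_next to the long-lived forwarded receiver. A minimal sketch of that delegation, assuming tokio's UnboundedReceiver::poll_recv:

use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{Stream, StreamExt};
use tokio::sync::mpsc;

// Hypothetical wrapper around a long-lived receiver. The real NetworkWrapper
// holds the forwarded receiver of ProtocolMessage values and deserializes the
// payload before yielding it; that part is omitted here.
struct ChannelStream<T> {
    rx: mpsc::UnboundedReceiver<T>,
}

impl<T> Stream for ChannelStream<T> {
    type Item = T;

    // Delegate to the receiver's own poll_recv: the waker registered here
    // belongs to a receiver that lives as long as the stream, so no wakeups
    // are lost between polls.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        self.get_mut().rx.poll_recv(cx)
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = mpsc::unbounded_channel();
    let mut stream = ChannelStream { rx };

    tx.send(42u32).unwrap();
    drop(tx); // closing the sender ends the stream after buffered items

    assert_eq!(stream.next().await, Some(42));
    assert_eq!(stream.next().await, None);
}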