diff --git a/sdk/src/event_listener/tangle/mod.rs b/sdk/src/event_listener/tangle/mod.rs index d8ad0990..496bd897 100644 --- a/sdk/src/event_listener/tangle/mod.rs +++ b/sdk/src/event_listener/tangle/mod.rs @@ -176,7 +176,7 @@ impl .filter_map(|r| r.ok().and_then(E::try_decode)) .collect::>(); - crate::info!("Found {} possible events ...", events.len()); + crate::debug!("Found {} possible events ...", events.len()); self.enqueued_events = events; } } diff --git a/sdk/src/network/mod.rs b/sdk/src/network/mod.rs index 55d9a671..2cf4a429 100644 --- a/sdk/src/network/mod.rs +++ b/sdk/src/network/mod.rs @@ -5,7 +5,7 @@ use dashmap::DashMap; use futures::{Stream, StreamExt}; use serde::{Deserialize, Serialize}; use sp_core::{ecdsa, sha2_256}; -use std::ops::Deref; +use std::ops::{Deref, DerefMut}; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; @@ -24,31 +24,15 @@ pub mod setup; #[derive(Debug, Serialize, Deserialize, Clone, Copy, Default)] pub struct IdentifierInfo { - pub block_id: Option, - pub session_id: Option, - pub retry_id: Option, - pub task_id: Option, + pub message_id: u64, + pub round_id: u16, } impl Display for IdentifierInfo { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let block_id = self - .block_id - .map(|id| format!("block_id: {}", id)) - .unwrap_or_default(); - let session_id = self - .session_id - .map(|id| format!("session_id: {}", id)) - .unwrap_or_default(); - let retry_id = self - .retry_id - .map(|id| format!("retry_id: {}", id)) - .unwrap_or_default(); - let task_id = self - .task_id - .map(|id| format!("task_id: {}", id)) - .unwrap_or_default(); - write!(f, "{} {} {} {}", block_id, session_id, retry_id, task_id) + let message_id = format!("message_id: {}", self.message_id); + let round_id = format!("round_id: {}", self.round_id); + write!(f, "{} {}", message_id, round_id) } } @@ -179,6 +163,12 @@ impl Deref for MultiplexedReceiver { } } +impl DerefMut for MultiplexedReceiver { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + impl Drop for MultiplexedReceiver { fn drop(&mut self) { let _ = self.active_streams.remove(&self.stream_id); @@ -211,15 +201,22 @@ impl NetworkMultiplexer { let network_clone = &network; let task1 = async move { - while let Some((stream_id, proto_message)) = rx_from_substreams.recv().await { + while let Some((stream_id, msg)) = rx_from_substreams.recv().await { + crate::info!( + "Round {}: Sending RAW message from {} to {:?} (id: {})", + msg.identifier_info.round_id, + msg.sender.user_id, + msg.recipient.as_ref().map(|p| p.user_id), + msg.identifier_info.message_id, + ); let multiplexed_message = MultiplexedMessage { - payload: proto_message.payload, + payload: msg.payload, stream_id, }; let message = ProtocolMessage { - identifier_info: proto_message.identifier_info, - sender: proto_message.sender, - recipient: proto_message.recipient, + identifier_info: msg.identifier_info, + sender: msg.sender, + recipient: msg.recipient, payload: bincode2::serialize(&multiplexed_message) .expect("Failed to serialize message"), }; @@ -233,6 +230,13 @@ impl NetworkMultiplexer { let task2 = async move { while let Some(mut msg) = network_clone.next_message().await { + crate::info!( + "Round {}: Received RAW message from {} to {:?} (id: {})", + msg.identifier_info.round_id, + msg.sender.user_id, + msg.recipient.as_ref().map(|p| p.user_id), + msg.identifier_info.message_id, + ); if let Ok(multiplexed_message) = bincode2::deserialize::(&msg.payload) { @@ -282,7 +286,7 @@ impl NetworkMultiplexer { tx_to_networking_layer.stream_id = id; return SubNetwork { tx: tx_to_networking_layer, - rx: unclaimed.1.into(), + rx: Some(unclaimed.1.into()), }; } @@ -292,7 +296,42 @@ impl NetworkMultiplexer { id, ); - SubNetwork { tx, rx: rx.into() } + SubNetwork { + tx, + rx: Some(rx.into()), + } + } + + /// Creates a subnetwork, and also forwards all messages to the given channel. The network cannot be used to + /// receive messages since the messages will be forwarded to the provided channel. + pub fn multiplex_with_forwarding( + &self, + id: impl Into, + forward_tx: tokio::sync::mpsc::UnboundedSender, + ) -> SubNetwork { + let mut network = self.multiplex(id); + let rx = network.rx.take().expect("Rx from network should be Some"); + let forwarding_task = async move { + let mut rx = rx.into_inner(); + while let Some(msg) = rx.recv().await { + crate::info!( + "Round {}: Received message from {} to {:?} (id: {})", + msg.identifier_info.round_id, + msg.sender.user_id, + msg.recipient.as_ref().map(|p| p.user_id), + msg.identifier_info.message_id, + ); + if let Err(err) = forward_tx.send(msg) { + crate::error!(%err, "Failed to forward message to network"); + // TODO: Add AtomicBool to make sending stop + break; + } + } + }; + + drop(tokio::spawn(forwarding_task)); + + network } fn create_multiplexed_stream_inner( @@ -327,7 +366,7 @@ impl From for NetworkMultiplexer { pub struct SubNetwork { tx: MultiplexedSender, - rx: Mutex, + rx: Option>, } impl SubNetwork { @@ -336,7 +375,7 @@ impl SubNetwork { } pub async fn recv(&self) -> Option { - self.rx.lock().await.next().await + self.rx.as_ref()?.lock().await.next().await } } @@ -351,14 +390,6 @@ impl Network for SubNetwork { } } -impl Stream for SubNetwork { - type Item = ProtocolMessage; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(self.rx.get_mut()).poll_next(cx) - } -} - pub fn deserialize<'a, T>(data: &'a [u8]) -> Result where T: Deserialize<'a>, @@ -492,17 +523,17 @@ mod tests { // used throughout the program must also use the multiplexer to prevent mixed messages. let multiplexer = NetworkMultiplexer::new(node); - let mut round1_network = multiplexer.multiplex(StreamKey { + let round1_network = multiplexer.multiplex(StreamKey { task_hash, // To differentiate between different instances of a running program (i.e., a task) round_id: 0, // To differentiate between different subsets of a running task }); - let mut round2_network = multiplexer.multiplex(StreamKey { + let round2_network = multiplexer.multiplex(StreamKey { task_hash, // To differentiate between different instances of a running program (i.e., a task) round_id: 1, // To differentiate between different subsets of a running task }); - let mut round3_network = multiplexer.multiplex(StreamKey { + let round3_network = multiplexer.multiplex(StreamKey { task_hash, // To differentiate between different instances of a running program (i.e., a task) round_id: 2, // To differentiate between different subsets of a running task }); @@ -519,10 +550,8 @@ mod tests { GossipHandle::build_protocol_message( IdentifierInfo { - block_id: None, - session_id: None, - retry_id: None, - task_id: None, + message_id: 0, + round_id: 0, }, i, None, @@ -539,7 +568,7 @@ mod tests { // Wait for all other nodes to send their messages let mut msgs = BTreeMap::new(); - while let Some(msg) = round1_network.next().await { + while let Some(msg) = round1_network.recv().await { let m = deserialize::(&msg.payload).unwrap(); crate::debug!(from = %msg.sender.user_id, ?m, "Received message"); // Expecting Round1 message @@ -573,10 +602,8 @@ mod tests { .map(|j| { GossipHandle::build_protocol_message( IdentifierInfo { - block_id: None, - session_id: None, - retry_id: None, - task_id: None, + message_id: 0, + round_id: 0, }, i, Some(j), @@ -596,7 +623,7 @@ mod tests { // Wait for all other nodes to send their messages let mut msgs = BTreeMap::new(); - while let Some(msg) = round2_network.next().await { + while let Some(msg) = round2_network.recv().await { let m = deserialize::(&msg.payload).unwrap(); crate::debug!(from = %msg.sender.user_id, ?m, "Received message"); // Expecting Round2 message @@ -628,10 +655,8 @@ mod tests { }; GossipHandle::build_protocol_message( IdentifierInfo { - block_id: None, - session_id: None, - retry_id: None, - task_id: None, + message_id: 0, + round_id: 0, }, i, None, @@ -646,7 +671,7 @@ mod tests { // Wait for all other nodes to send their messages let mut msgs = BTreeMap::new(); - while let Some(msg) = round3_network.next().await { + while let Some(msg) = round3_network.recv().await { let m = deserialize::(&msg.payload).unwrap(); crate::debug!(from = %msg.sender.user_id, ?m, "Received message"); // Expecting Round3 message diff --git a/sdk/src/network/round_based_compat.rs b/sdk/src/network/round_based_compat.rs index e1ac3c1f..c564dd38 100644 --- a/sdk/src/network/round_based_compat.rs +++ b/sdk/src/network/round_based_compat.rs @@ -5,89 +5,88 @@ use std::collections::{BTreeMap, HashMap, VecDeque}; use std::sync::Arc; use crate::futures::prelude::*; -use crate::network::{self, IdentifierInfo, Network, NetworkMultiplexer, StreamKey, SubNetwork}; +use crate::network::{IdentifierInfo, NetworkMultiplexer, ProtocolMessage, StreamKey, SubNetwork}; use crate::subxt_core::ext::sp_core::ecdsa; -use round_based::{Delivery, Incoming, Outgoing}; -use round_based::{MessageDestination, MessageType, MsgId, PartyIndex}; +use round_based::{Delivery, Incoming, MessageType, Outgoing}; +use round_based::{MessageDestination, MsgId, PartyIndex}; use stream::{SplitSink, SplitStream}; -pub struct NetworkDeliveryWrapper { +use super::ParticipantInfo; + +pub struct NetworkDeliveryWrapper { /// The wrapped network implementation. - network: NetworkWrapper, + network: NetworkWrapper, } -impl NetworkDeliveryWrapper +impl NetworkDeliveryWrapper where - N: Network + Unpin, M: Clone + Send + Unpin + 'static, - M: serde::Serialize, - M: serde::de::DeserializeOwned, + M: serde::Serialize + serde::de::DeserializeOwned, { /// Create a new NetworkDeliveryWrapper over a network implementation with the given party index. pub fn new( - network: N, + mux: Arc, i: PartyIndex, task_hash: [u8; 32], parties: BTreeMap, ) -> Self { - let mux = NetworkMultiplexer::new(network); + let (tx_forward, rx) = tokio::sync::mpsc::unbounded_channel(); // By default, we create 10 substreams for each party. - let sub_streams = (0..10) - .map(|i| { - let key = StreamKey { - // This is a dummy task hash, it should be replaced with the actual task hash - task_hash: [0u8; 32], - round_id: i, - }; - let substream = mux.multiplex(key); - (key, substream) - }) - .collect(); + let mut sub_streams = HashMap::new(); + for x in 0..10 { + let key = StreamKey { + task_hash, + round_id: x, + }; + // Creates a multiplexed subnetwork, and also forwards all messages to the given channel + let _ = sub_streams.insert(key, mux.multiplex_with_forwarding(key, tx_forward.clone())); + } + let network = NetworkWrapper { me: i, mux, incoming_queue: VecDeque::new(), - outgoing_queue: VecDeque::new(), sub_streams, participants: parties, task_hash, + tx_forward, + rx, next_msg_id: Arc::new(NextMessageId::default()), - _network: core::marker::PhantomData, }; + NetworkDeliveryWrapper { network } } } /// A NetworkWrapper wraps a network implementation and implements [`Stream`] and [`Sink`] for /// it. -pub struct NetworkWrapper { +pub struct NetworkWrapper { /// The current party index. me: PartyIndex, /// Our network Multiplexer. - mux: NetworkMultiplexer, + mux: Arc, /// A Map of substreams for each round. - sub_streams: HashMap, + sub_streams: HashMap, //HashMap, /// A queue of incoming messages. + #[allow(dead_code)] incoming_queue: VecDeque>, - /// A queue of outgoing messages. - outgoing_queue: VecDeque>, /// Participants in the network with their corresponding ECDSA public keys. // Note: This is a BTreeMap to ensure that the participants are sorted by their party index. participants: BTreeMap, next_msg_id: Arc, + tx_forward: tokio::sync::mpsc::UnboundedSender, + rx: tokio::sync::mpsc::UnboundedReceiver, task_hash: [u8; 32], - _network: core::marker::PhantomData, } -impl Delivery for NetworkDeliveryWrapper +impl Delivery for NetworkDeliveryWrapper where - N: Network + Unpin, M: Clone + Send + Unpin + 'static, M: serde::Serialize + serde::de::DeserializeOwned, M: round_based::ProtocolMessage, { - type Send = SplitSink, Outgoing>; - type Receive = SplitStream>; + type Send = SplitSink, Outgoing>; + type Receive = SplitStream>; type SendError = crate::Error; type ReceiveError = crate::Error; @@ -97,61 +96,46 @@ where } } -impl Stream for NetworkWrapper +impl Stream for NetworkWrapper where - N: Network + Unpin, M: serde::de::DeserializeOwned + Unpin, M: round_based::ProtocolMessage, { type Item = Result, crate::Error>; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let sub_streams = self.sub_streams.values(); - // pull all substreams - let mut messages = Vec::new(); - for sub_stream in sub_streams { - let p = sub_stream.next_message().poll_unpin(cx); - let m = match p { - Poll::Ready(Some(msg)) => msg, - _ => continue, + let res = ready!(self.get_mut().rx.poll_recv(cx)); + if let Some(res) = res { + let msg_type = if res.recipient.is_some() { + MessageType::P2P + } else { + MessageType::Broadcast }; - let msg = network::deserialize::(&m.payload)?; - messages.push((m.sender.user_id, m.recipient, msg)); - } - // Sort the incoming messages by round. - messages.sort_by_key(|(_, _, msg)| msg.round()); + let id = res.identifier_info.message_id; - let this = self.get_mut(); - // Push all messages to the incoming queue - messages - .into_iter() - .map(|(sender, recipient, msg)| Incoming { - id: this.next_msg_id.next(), - sender, - msg_type: match recipient { - Some(_) => MessageType::P2P, - None => MessageType::Broadcast, - }, + let msg = match crate::network::deserialize(&res.payload) { + Ok(msg) => msg, + Err(err) => { + crate::error!(%err, "Failed to deserialize message"); + return Poll::Ready(Some(Err(crate::Error::Other(err.to_string())))); + } + }; + + Poll::Ready(Some(Ok(Incoming { msg, - }) - .for_each(|m| this.incoming_queue.push_back(m)); - // Reorder the incoming queue by round message. - let maybe_msg = this.incoming_queue.pop_front(); - if let Some(msg) = maybe_msg { - Poll::Ready(Some(Ok(msg))) + sender: res.sender.user_id, + id, + msg_type, + }))) } else { - // No message in the queue, and no message in the substreams. - // Tell the network to wake us up when a new message arrives. - cx.waker().wake_by_ref(); - Poll::Pending + Poll::Ready(None) } } } -impl Sink> for NetworkWrapper +impl Sink> for NetworkWrapper where - N: Network + Unpin, M: Unpin + serde::Serialize, M: round_based::ProtocolMessage, { @@ -161,48 +145,58 @@ where Poll::Ready(Ok(())) } - fn start_send(self: Pin<&mut Self>, msg: Outgoing) -> Result<(), Self::Error> { - self.get_mut().outgoing_queue.push_back(msg); - Ok(()) - } - - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - // Dequeue all messages and send them one by one to the network + fn start_send(self: Pin<&mut Self>, out: Outgoing) -> Result<(), Self::Error> { let this = self.get_mut(); - while let Some(out) = this.outgoing_queue.pop_front() { - // Get the substream to send the message to. - let key = StreamKey { - task_hash: [0u8; 32], // TODO: Use real hash - round_id: i32::from(out.msg.round()), - }; - let substream = this - .sub_streams - .entry(key) - .or_insert_with(|| this.mux.multiplex(key)); - let identifier_info = IdentifierInfo { - block_id: None, - session_id: None, - retry_id: None, - task_id: None, - }; - let (to, to_network_id) = match out.recipient { - MessageDestination::AllParties => (None, None), - MessageDestination::OneParty(p) => (Some(p), this.participants.get(&p).cloned()), - }; - let protocol_message = N::build_protocol_message( - identifier_info, - this.me, - to, - &out.msg, - this.participants.get(&this.me).cloned(), - to_network_id, - ); - let p = substream.send_message(protocol_message).poll_unpin(cx); - match ready!(p) { - Ok(()) => continue, - Err(e) => return Poll::Ready(Err(e)), - } + let id = this.next_msg_id.next(); + + crate::info!( + "Round {}: Sending message from {} to {:?} (id: {})", + out.msg.round(), + this.me, + out.recipient, + id, + ); + + // Get the substream to send the message to. + let key = StreamKey { + task_hash: this.task_hash, + round_id: i32::from(out.msg.round()), + }; + let substream = this.sub_streams.entry(key).or_insert_with(|| { + this.mux + .multiplex_with_forwarding(key, this.tx_forward.clone()) + }); + + let round_id = out.msg.round(); + let identifier_info = IdentifierInfo { + message_id: id, + round_id, + }; + let (to, to_network_id) = match out.recipient { + MessageDestination::AllParties => (None, None), + MessageDestination::OneParty(p) => (Some(p), this.participants.get(&p).cloned()), + }; + + let protocol_message = ProtocolMessage { + identifier_info, + sender: ParticipantInfo { + user_id: this.me, + ecdsa_key: this.participants.get(&this.me).cloned(), + }, + recipient: to.map(|user_id| ParticipantInfo { + user_id, + ecdsa_key: to_network_id, + }), + payload: crate::network::serialize(&out.msg)?, + }; + + match substream.send(protocol_message) { + Ok(()) => Ok(()), + Err(e) => Err(e), } + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { Poll::Ready(Ok(())) } @@ -215,7 +209,7 @@ where struct NextMessageId(AtomicU64); impl NextMessageId { - pub fn next(&self) -> MsgId { - self.0.fetch_add(1, core::sync::atomic::Ordering::Relaxed) + fn next(&self) -> MsgId { + self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed) } }