From 1494dfa2b81a066d4966dd2b18fdd1c8d7e92f51 Mon Sep 17 00:00:00 2001 From: Thomas Braun <38082993+tbraun96@users.noreply.github.com> Date: Fri, 8 Nov 2024 14:32:11 -0500 Subject: [PATCH] fix(gadget-sdk)!: prevent duplicate and self-referential messages (#458) Co-authored-by: Shady Khalifa --- Cargo.lock | 10 ++++ Cargo.toml | 1 + blueprint-manager/src/sdk/setup.rs | 38 ++++++------- blueprint-test-utils/src/test_ext.rs | 6 ++- .../contracts/lib/forge-std | 2 +- .../contracts/lib/forge-std | 2 +- sdk/Cargo.toml | 3 +- sdk/src/network/gossip.rs | 54 ++++++++++++------- sdk/src/network/handlers/gossip.rs | 6 +++ sdk/src/network/handlers/p2p.rs | 5 ++ sdk/src/network/messaging.rs | 8 +-- sdk/src/network/mod.rs | 32 +++++------ sdk/src/network/setup.rs | 10 +++- 13 files changed, 110 insertions(+), 67 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ea306fc8..98c2a9a4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4847,6 +4847,7 @@ dependencies = [ "libp2p", "lock_api", "log", + "lru-mem", "nix 0.29.0", "num-bigint 0.4.6", "parking_lot 0.12.3", @@ -7514,6 +7515,15 @@ dependencies = [ "linked-hash-map", ] +[[package]] +name = "lru-mem" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf5c8c26d903a41c80d4cc171940a57a4d1bc51139ebd6aad87e2f9ae3774780" +dependencies = [ + "hashbrown 0.14.5", +] + [[package]] name = "mach" version = "0.3.2" diff --git a/Cargo.toml b/Cargo.toml index bb2f5fb4..11ca378a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -197,6 +197,7 @@ testcontainers = { version = "0.20.1" } symbiotic-rs = { version = "0.1.0" } dashmap = "6.1.0" bincode2 = "2.0.1" +lru-mem = "0.3.0" [profile.dev.package.backtrace] opt-level = 3 diff --git a/blueprint-manager/src/sdk/setup.rs b/blueprint-manager/src/sdk/setup.rs index fe7ee90a..94755132 100644 --- a/blueprint-manager/src/sdk/setup.rs +++ b/blueprint-manager/src/sdk/setup.rs @@ -1,12 +1,13 @@ -use std::collections::BTreeMap; -use std::time::Duration; - +use futures::stream::FuturesOrdered; +use futures::StreamExt; use gadget_io::tokio::task::JoinHandle; use gadget_sdk::clients::tangle::runtime::TangleRuntimeClient; use gadget_sdk::network::Network; use gadget_sdk::prometheus::PrometheusConfig; use gadget_sdk::store::{ECDSAKeyStore, KeyValueStoreBackend}; use sp_core::{keccak_256, sr25519, Pair}; +use std::collections::BTreeMap; +use std::time::Duration; use crate::sdk::config::SingleGadgetConfig; pub use gadget_io::KeystoreContainer; @@ -110,28 +111,29 @@ pub async fn wait_for_connection_to_bootnodes( debug!("Waiting for {n_required} peers to show up across {n_networks} networks"); - let mut tasks = gadget_io::tokio::task::JoinSet::new(); + let mut tasks = FuturesOrdered::new(); // For each network, we start a task that checks if we have enough peers connected // and then we wait for all of them to finish. - let wait_for_peers = |handle: GossipHandle, n_required| async move { - 'inner: loop { - let n_connected = handle.connected_peers(); - if n_connected >= n_required { - break 'inner; - } - let topic = handle.topic(); - debug!("`{topic}`: We currently have {n_connected}/{n_required} peers connected to network"); - gadget_io::tokio::time::sleep(Duration::from_millis(1000)).await; - } - }; - for handle in handles.values() { - tasks.spawn(wait_for_peers(handle.clone(), n_required)); + tasks.push_back(wait_for_peers(handle, n_required)); } + // Wait for all tasks to finish - while tasks.join_next().await.is_some() {} + tasks.collect::<()>().await; Ok(()) } + +async fn wait_for_peers(handle: &GossipHandle, required: usize) { + loop { + let n_connected = handle.connected_peers(); + if n_connected >= required { + return; + } + let topic = handle.topic(); + debug!("`{topic}`: We currently have {n_connected}/{required} peers connected to network"); + gadget_io::tokio::time::sleep(Duration::from_millis(1000)).await; + } +} diff --git a/blueprint-test-utils/src/test_ext.rs b/blueprint-test-utils/src/test_ext.rs index 8b210c58..1237c2dc 100644 --- a/blueprint-test-utils/src/test_ext.rs +++ b/blueprint-test-utils/src/test_ext.rs @@ -309,10 +309,12 @@ pub async fn new_test_ext_blueprint_manager< pub fn find_open_tcp_bind_port() -> u16 { let listener = std::net::TcpListener::bind(format!("{LOCAL_BIND_ADDR}:0")) .expect("Should bind to localhost"); - listener + let port = listener .local_addr() .expect("Should have a local address") - .port() + .port(); + drop(listener); + port } pub struct LocalhostTestExt { diff --git a/blueprints/incredible-squaring-eigenlayer/contracts/lib/forge-std b/blueprints/incredible-squaring-eigenlayer/contracts/lib/forge-std index 1eea5bae..1de6eecf 160000 --- a/blueprints/incredible-squaring-eigenlayer/contracts/lib/forge-std +++ b/blueprints/incredible-squaring-eigenlayer/contracts/lib/forge-std @@ -1 +1 @@ -Subproject commit 1eea5bae12ae557d589f9f0f0edae2faa47cb262 +Subproject commit 1de6eecf821de7fe2c908cc48d3ab3dced20717f diff --git a/blueprints/incredible-squaring/contracts/lib/forge-std b/blueprints/incredible-squaring/contracts/lib/forge-std index 1eea5bae..1de6eecf 160000 --- a/blueprints/incredible-squaring/contracts/lib/forge-std +++ b/blueprints/incredible-squaring/contracts/lib/forge-std @@ -1 +1 @@ -Subproject commit 1eea5bae12ae557d589f9f0f0edae2faa47cb262 +Subproject commit 1de6eecf821de7fe2c908cc48d3ab3dced20717f diff --git a/sdk/Cargo.toml b/sdk/Cargo.toml index 79afb452..c3a25281 100644 --- a/sdk/Cargo.toml +++ b/sdk/Cargo.toml @@ -29,7 +29,6 @@ url = { workspace = true, features = ["serde"] } uuid = { workspace = true } failure = { workspace = true } num-bigint = { workspace = true } - # Keystore deps ed25519-zebra = { workspace = true } k256 = { workspace = true, features = ["ecdsa", "ecdsa-core", "arithmetic"] } @@ -92,6 +91,8 @@ gadget-blueprint-proc-macro = { workspace = true, default-features = false } gadget-context-derive = { workspace = true, default-features = false } gadget-blueprint-proc-macro-core = { workspace = true, default-features = false } +lru-mem = { workspace = true } + # Benchmarking deps sysinfo = { workspace = true } dashmap = { workspace = true } diff --git a/sdk/src/network/gossip.rs b/sdk/src/network/gossip.rs index 7699bf6c..851f9496 100644 --- a/sdk/src/network/gossip.rs +++ b/sdk/src/network/gossip.rs @@ -4,6 +4,8 @@ clippy::module_name_repetitions, clippy::exhaustive_enums )] +use crate::error::Error; +use crate::{error, trace, warn}; use async_trait::async_trait; use ecdsa::Public; use gadget_io::tokio::sync::mpsc::UnboundedSender; @@ -13,15 +15,13 @@ use libp2p::kad::store::MemoryStore; use libp2p::{ gossipsub, mdns, request_response, swarm::NetworkBehaviour, swarm::SwarmEvent, PeerId, }; +use lru_mem::LruCache; use serde::{Deserialize, Serialize}; -use sp_core::ecdsa; +use sp_core::{ecdsa, sha2_256}; use std::collections::BTreeMap; use std::sync::atomic::AtomicU32; use std::sync::Arc; -use crate::error::Error; -use crate::{error, trace, warn}; - use super::{Network, ParticipantInfo, ProtocolMessage}; /// Maximum allowed size for a Signed Message. @@ -48,6 +48,7 @@ pub struct NetworkServiceWithoutSwarm<'a> { pub ecdsa_peer_id_to_libp2p_id: Arc>>, pub ecdsa_key: &'a ecdsa::Pair, pub span: tracing::Span, + pub my_id: PeerId, } impl<'a> NetworkServiceWithoutSwarm<'a> { @@ -61,6 +62,7 @@ impl<'a> NetworkServiceWithoutSwarm<'a> { ecdsa_peer_id_to_libp2p_id: &self.ecdsa_peer_id_to_libp2p_id, ecdsa_key: self.ecdsa_key, span: &self.span, + my_id: self.my_id, } } } @@ -71,6 +73,7 @@ pub struct NetworkService<'a> { pub ecdsa_peer_id_to_libp2p_id: &'a Arc>>, pub ecdsa_key: &'a ecdsa::Pair, pub span: &'a tracing::Span, + pub my_id: PeerId, } impl NetworkService<'_> { @@ -247,13 +250,14 @@ impl NetworkService<'_> { } } -#[derive(Clone)] pub struct GossipHandle { pub topic: IdentTopic, pub tx_to_outbound: UnboundedSender, pub rx_from_inbound: Arc>>>, pub connected_peers: Arc, pub ecdsa_peer_id_to_libp2p_id: Arc>>, + pub recent_messages: parking_lot::Mutex>, + pub my_id: PeerId, } impl GossipHandle { @@ -338,18 +342,29 @@ enum MessageType { #[async_trait] impl Network for GossipHandle { async fn next_message(&self) -> Option { - let mut lock = self - .rx_from_inbound - .try_lock() - .expect("There should be only a single caller for `next_message`"); + loop { + let mut lock = self + .rx_from_inbound + .try_lock() + .expect("There should be only a single caller for `next_message`"); - let message = lock.recv().await?; - match bincode::deserialize(&message) { - Ok(message) => Some(message), - Err(e) => { - error!("Failed to deserialize message: {e}"); - drop(lock); - Network::next_message(self).await + let message_bytes = lock.recv().await?; + drop(lock); + match bincode::deserialize::(&message_bytes) { + Ok(message) => { + let hash = sha2_256(&message.payload); + let mut map = self.recent_messages.lock(); + if map + .insert(hash, ()) + .expect("Should not exceed memory limit (rx)") + .is_none() + { + return Some(message); + } + } + Err(e) => { + error!("Failed to deserialize message: {e}"); + } } } } @@ -377,14 +392,17 @@ impl Network for GossipHandle { MessageType::Broadcast }; + let raw_payload = bincode::serialize(&message).map_err(|e| Error::Network { + reason: format!("Failed to serialize message: {e}"), + })?; let payload_inner = match message_type { MessageType::Broadcast => GossipOrRequestResponse::Gossip(GossipMessage { topic: self.topic.to_string(), - raw_payload: bincode::serialize(&message).expect("Should serialize"), + raw_payload, }), MessageType::P2P(_) => GossipOrRequestResponse::Request(MyBehaviourRequest::Message { topic: self.topic.to_string(), - raw_payload: bincode::serialize(&message).expect("Should serialize"), + raw_payload, }), }; diff --git a/sdk/src/network/handlers/gossip.rs b/sdk/src/network/handlers/gossip.rs index 89fe3760..e3a9e750 100644 --- a/sdk/src/network/handlers/gossip.rs +++ b/sdk/src/network/handlers/gossip.rs @@ -78,6 +78,12 @@ impl NetworkService<'_> { error!("Got message from unknown peer"); return; }; + + // Reject messages from self + if origin == self.my_id { + return; + } + trace!("Got message from peer: {origin}"); match bincode::deserialize::(&message.data) { Ok(GossipMessage { topic, raw_payload }) => { diff --git a/sdk/src/network/handlers/p2p.rs b/sdk/src/network/handlers/p2p.rs index 38c35b8e..82c978b9 100644 --- a/sdk/src/network/handlers/p2p.rs +++ b/sdk/src/network/handlers/p2p.rs @@ -144,6 +144,11 @@ impl NetworkService<'_> { ) } Message { topic, raw_payload } => { + // Reject messages from self + if peer == self.my_id { + return; + } + let topic = IdentTopic::new(topic); if let Some((_, tx, _)) = self .inbound_mapping diff --git a/sdk/src/network/messaging.rs b/sdk/src/network/messaging.rs index 15ebb480..487cf954 100644 --- a/sdk/src/network/messaging.rs +++ b/sdk/src/network/messaging.rs @@ -37,7 +37,7 @@ pub trait MessageMetadata { } #[async_trait] -pub trait Network { +pub trait MessagingNetwork { type Message: MessageMetadata + Send + Sync + 'static; async fn next_message(&self) -> Option>; @@ -133,7 +133,7 @@ where M: MessageMetadata + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static, B: Backend + Send + Sync + 'static, L: LocalDelivery + Send + Sync + 'static, - N: Network + Send + Sync + 'static, + N: MessagingNetwork + Send + Sync + 'static, { backend: Arc, local_delivery: Arc, @@ -147,7 +147,7 @@ where M: MessageMetadata + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static, B: Backend + Send + Sync + 'static, L: LocalDelivery + Send + Sync + 'static, - N: Network + Send + Sync + 'static, + N: MessagingNetwork + Send + Sync + 'static, { fn clone(&self) -> Self { Self { @@ -165,7 +165,7 @@ where M: MessageMetadata + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static, B: Backend + Send + Sync + 'static, L: LocalDelivery + Send + Sync + 'static, - N: Network + Send + Sync + 'static, + N: MessagingNetwork + Send + Sync + 'static, { pub fn new(backend: B, local_delivery: L, network: N) -> Self { let this = Self { diff --git a/sdk/src/network/mod.rs b/sdk/src/network/mod.rs index efd6923a..b54fed0d 100644 --- a/sdk/src/network/mod.rs +++ b/sdk/src/network/mod.rs @@ -86,6 +86,7 @@ impl Display for ProtocolMessage { } #[async_trait] +#[auto_impl::auto_impl(&, Box, Arc)] pub trait Network: Send + Sync + 'static { async fn next_message(&self) -> Option; async fn send_message(&self, message: ProtocolMessage) -> Result<(), Error>; @@ -376,7 +377,6 @@ mod tests { use serde::{Deserialize, Serialize}; use sp_core::Pair; use std::collections::BTreeMap; - use tokio::sync::Barrier; const TOPIC: &str = "/gadget/test/1.0.0"; @@ -460,7 +460,7 @@ mod tests { } #[tokio::test(flavor = "multi_thread")] - async fn p2p() { + async fn test_p2p() { setup_log(); let nodes = stream::iter(0..NODE_COUNT) .map(|_| node()) @@ -470,10 +470,8 @@ mod tests { wait_for_nodes_connected(&nodes).await; let mut tasks = Vec::new(); - let barrier = Arc::new(Barrier::new(NODE_COUNT as usize)); for (i, node) in nodes.into_iter().enumerate() { - let barrier = barrier.clone(); - let task = tokio::spawn(run_protocol(node, i as u16, barrier)); + let task = tokio::spawn(run_protocol(node, i as u16)); tasks.push(task); } // Wait for all tasks to finish @@ -487,11 +485,7 @@ mod tests { ); } - async fn run_protocol( - node: N, - i: u16, - barrier: Arc, - ) -> Result<(), crate::Error> { + async fn run_protocol(node: N, i: u16) -> Result<(), crate::Error> { let task_hash = [0u8; 32]; // Safety note: We should be passed a NetworkMultiplexer, and all uses of the N: Network // used throughout the program must also use the multiplexer to prevent mixed messages. @@ -554,12 +548,12 @@ mod tests { m, msg.sender.user_id, ); - let _old = msgs.insert(msg.sender.user_id, m); - /*assert!( + let old = msgs.insert(msg.sender.user_id, m); + assert!( old.is_none(), "Duplicate message from node {}", msg.sender.user_id - );*/ + ); // Break if all messages are received if msgs.len() == usize::from(NODE_COUNT) - 1 { break; @@ -611,13 +605,12 @@ mod tests { m, msg.sender.user_id, ); - let _old = msgs.insert(msg.sender.user_id, m); - /* + let old = msgs.insert(msg.sender.user_id, m); assert!( old.is_none(), "Duplicate message from node {}", msg.sender.user_id - );*/ + ); // Break if all messages are received if msgs.len() == usize::from(NODE_COUNT) - 1 { break; @@ -662,20 +655,18 @@ mod tests { m, msg.sender.user_id, ); - let _old = msgs.insert(msg.sender.user_id, m); - /* + let old = msgs.insert(msg.sender.user_id, m); assert!( old.is_none(), "Duplicate message from node {}", msg.sender.user_id - );*/ + ); // Break if all messages are received if msgs.len() == usize::from(NODE_COUNT) - 1 { break; } } crate::debug!("Done r3 w/ {i}"); - let _ = barrier.wait().await; crate::info!(node = i, "Protocol completed"); @@ -761,6 +752,7 @@ mod tests { let received_msg = subnetwork0.recv().await.unwrap(); assert_eq!(received_msg.payload, msg.payload); + tracing::info!("Done nested depth = {cur_depth}/{max_depth}"); Box::pin(nested_multiplex( cur_depth + 1, diff --git a/sdk/src/network/setup.rs b/sdk/src/network/setup.rs index 1e659782..55b47d5e 100644 --- a/sdk/src/network/setup.rs +++ b/sdk/src/network/setup.rs @@ -15,6 +15,7 @@ use gadget_io::tokio::select; use gadget_io::tokio::sync::{Mutex, RwLock}; use gadget_io::tokio::task::{spawn, JoinHandle}; use libp2p::Multiaddr; +use lru_mem::LruCache; use sp_core::ecdsa; use std::collections::BTreeMap; use std::error::Error; @@ -159,6 +160,8 @@ pub fn multiplexed_libp2p_network(config: NetworkConfig) -> NetworkResult { let networks = topics; + let my_id = identity.public().to_peer_id(); + let mut swarm = libp2p::SwarmBuilder::with_existing_identity(identity) .with_tokio() .with_tcp( @@ -266,15 +269,16 @@ pub fn multiplexed_libp2p_network(config: NetworkConfig) -> NetworkResult { tx_to_outbound: tx_to_outbound.clone(), rx_from_inbound: Arc::new(Mutex::new(inbound_rx)), ecdsa_peer_id_to_libp2p_id: ecdsa_peer_id_to_libp2p_id.clone(), + // Each key is 32 bytes, therefore 512 messages hashes can be stored in the set + recent_messages: LruCache::new(16 * 1024).into(), + my_id, }, ); } let ips_to_bind_to = vec![ - IpAddr::from_str("127.0.0.1").unwrap(), IpAddr::from_str("0.0.0.0").unwrap(), IpAddr::from_str("::1").unwrap(), - IpAddr::from_str("::").unwrap(), ]; for addr in ips_to_bind_to { @@ -300,7 +304,9 @@ pub fn multiplexed_libp2p_network(config: NetworkConfig) -> NetworkResult { ecdsa_peer_id_to_libp2p_id, ecdsa_key: &ecdsa_key, span: tracing::debug_span!(parent: &span, "network_service"), + my_id, }; + loop { select! { // Setup outbound channel