Skip to content

Commit

Permalink
fix(gadget-sdk)!: prevent duplicate and self-referential messages (#458)
Browse files Browse the repository at this point in the history
Co-authored-by: Shady Khalifa <[email protected]>
  • Loading branch information
tbraun96 and shekohex authored Nov 8, 2024
1 parent ffe2f06 commit 1494dfa
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 67 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ testcontainers = { version = "0.20.1" }
symbiotic-rs = { version = "0.1.0" }
dashmap = "6.1.0"
bincode2 = "2.0.1"
lru-mem = "0.3.0"

[profile.dev.package.backtrace]
opt-level = 3
Expand Down
38 changes: 20 additions & 18 deletions blueprint-manager/src/sdk/setup.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::collections::BTreeMap;
use std::time::Duration;

use futures::stream::FuturesOrdered;
use futures::StreamExt;
use gadget_io::tokio::task::JoinHandle;
use gadget_sdk::clients::tangle::runtime::TangleRuntimeClient;
use gadget_sdk::network::Network;
use gadget_sdk::prometheus::PrometheusConfig;
use gadget_sdk::store::{ECDSAKeyStore, KeyValueStoreBackend};
use sp_core::{keccak_256, sr25519, Pair};
use std::collections::BTreeMap;
use std::time::Duration;

use crate::sdk::config::SingleGadgetConfig;
pub use gadget_io::KeystoreContainer;
Expand Down Expand Up @@ -110,28 +111,29 @@ pub async fn wait_for_connection_to_bootnodes(

debug!("Waiting for {n_required} peers to show up across {n_networks} networks");

let mut tasks = gadget_io::tokio::task::JoinSet::new();
let mut tasks = FuturesOrdered::new();

// For each network, we start a task that checks if we have enough peers connected
// and then we wait for all of them to finish.

let wait_for_peers = |handle: GossipHandle, n_required| async move {
'inner: loop {
let n_connected = handle.connected_peers();
if n_connected >= n_required {
break 'inner;
}
let topic = handle.topic();
debug!("`{topic}`: We currently have {n_connected}/{n_required} peers connected to network");
gadget_io::tokio::time::sleep(Duration::from_millis(1000)).await;
}
};

for handle in handles.values() {
tasks.spawn(wait_for_peers(handle.clone(), n_required));
tasks.push_back(wait_for_peers(handle, n_required));
}

// Wait for all tasks to finish
while tasks.join_next().await.is_some() {}
tasks.collect::<()>().await;

Ok(())
}

async fn wait_for_peers(handle: &GossipHandle, required: usize) {
loop {
let n_connected = handle.connected_peers();
if n_connected >= required {
return;
}
let topic = handle.topic();
debug!("`{topic}`: We currently have {n_connected}/{required} peers connected to network");
gadget_io::tokio::time::sleep(Duration::from_millis(1000)).await;
}
}
6 changes: 4 additions & 2 deletions blueprint-test-utils/src/test_ext.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,12 @@ pub async fn new_test_ext_blueprint_manager<
pub fn find_open_tcp_bind_port() -> u16 {
let listener = std::net::TcpListener::bind(format!("{LOCAL_BIND_ADDR}:0"))
.expect("Should bind to localhost");
listener
let port = listener
.local_addr()
.expect("Should have a local address")
.port()
.port();
drop(listener);
port
}

pub struct LocalhostTestExt {
Expand Down
Submodule forge-std updated 1 files
+1 −1 package.json
2 changes: 1 addition & 1 deletion blueprints/incredible-squaring/contracts/lib/forge-std
Submodule forge-std updated 1 files
+1 −1 package.json
3 changes: 2 additions & 1 deletion sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ url = { workspace = true, features = ["serde"] }
uuid = { workspace = true }
failure = { workspace = true }
num-bigint = { workspace = true }

# Keystore deps
ed25519-zebra = { workspace = true }
k256 = { workspace = true, features = ["ecdsa", "ecdsa-core", "arithmetic"] }
Expand Down Expand Up @@ -92,6 +91,8 @@ gadget-blueprint-proc-macro = { workspace = true, default-features = false }
gadget-context-derive = { workspace = true, default-features = false }
gadget-blueprint-proc-macro-core = { workspace = true, default-features = false }

lru-mem = { workspace = true }

# Benchmarking deps
sysinfo = { workspace = true }
dashmap = { workspace = true }
Expand Down
54 changes: 36 additions & 18 deletions sdk/src/network/gossip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
clippy::module_name_repetitions,
clippy::exhaustive_enums
)]
use crate::error::Error;
use crate::{error, trace, warn};
use async_trait::async_trait;
use ecdsa::Public;
use gadget_io::tokio::sync::mpsc::UnboundedSender;
Expand All @@ -13,15 +15,13 @@ use libp2p::kad::store::MemoryStore;
use libp2p::{
gossipsub, mdns, request_response, swarm::NetworkBehaviour, swarm::SwarmEvent, PeerId,
};
use lru_mem::LruCache;
use serde::{Deserialize, Serialize};
use sp_core::ecdsa;
use sp_core::{ecdsa, sha2_256};
use std::collections::BTreeMap;
use std::sync::atomic::AtomicU32;
use std::sync::Arc;

use crate::error::Error;
use crate::{error, trace, warn};

use super::{Network, ParticipantInfo, ProtocolMessage};

/// Maximum allowed size for a Signed Message.
Expand All @@ -48,6 +48,7 @@ pub struct NetworkServiceWithoutSwarm<'a> {
pub ecdsa_peer_id_to_libp2p_id: Arc<RwLock<BTreeMap<ecdsa::Public, PeerId>>>,
pub ecdsa_key: &'a ecdsa::Pair,
pub span: tracing::Span,
pub my_id: PeerId,
}

impl<'a> NetworkServiceWithoutSwarm<'a> {
Expand All @@ -61,6 +62,7 @@ impl<'a> NetworkServiceWithoutSwarm<'a> {
ecdsa_peer_id_to_libp2p_id: &self.ecdsa_peer_id_to_libp2p_id,
ecdsa_key: self.ecdsa_key,
span: &self.span,
my_id: self.my_id,
}
}
}
Expand All @@ -71,6 +73,7 @@ pub struct NetworkService<'a> {
pub ecdsa_peer_id_to_libp2p_id: &'a Arc<RwLock<BTreeMap<ecdsa::Public, PeerId>>>,
pub ecdsa_key: &'a ecdsa::Pair,
pub span: &'a tracing::Span,
pub my_id: PeerId,
}

impl NetworkService<'_> {
Expand Down Expand Up @@ -247,13 +250,14 @@ impl NetworkService<'_> {
}
}

#[derive(Clone)]
pub struct GossipHandle {
pub topic: IdentTopic,
pub tx_to_outbound: UnboundedSender<IntraNodePayload>,
pub rx_from_inbound: Arc<Mutex<gadget_io::tokio::sync::mpsc::UnboundedReceiver<Vec<u8>>>>,
pub connected_peers: Arc<AtomicU32>,
pub ecdsa_peer_id_to_libp2p_id: Arc<RwLock<BTreeMap<ecdsa::Public, PeerId>>>,
pub recent_messages: parking_lot::Mutex<LruCache<[u8; 32], ()>>,
pub my_id: PeerId,
}

impl GossipHandle {
Expand Down Expand Up @@ -338,18 +342,29 @@ enum MessageType {
#[async_trait]
impl Network for GossipHandle {
async fn next_message(&self) -> Option<ProtocolMessage> {
let mut lock = self
.rx_from_inbound
.try_lock()
.expect("There should be only a single caller for `next_message`");
loop {
let mut lock = self
.rx_from_inbound
.try_lock()
.expect("There should be only a single caller for `next_message`");

let message = lock.recv().await?;
match bincode::deserialize(&message) {
Ok(message) => Some(message),
Err(e) => {
error!("Failed to deserialize message: {e}");
drop(lock);
Network::next_message(self).await
let message_bytes = lock.recv().await?;
drop(lock);
match bincode::deserialize::<ProtocolMessage>(&message_bytes) {
Ok(message) => {
let hash = sha2_256(&message.payload);
let mut map = self.recent_messages.lock();
if map
.insert(hash, ())
.expect("Should not exceed memory limit (rx)")
.is_none()
{
return Some(message);
}
}
Err(e) => {
error!("Failed to deserialize message: {e}");
}
}
}
}
Expand Down Expand Up @@ -377,14 +392,17 @@ impl Network for GossipHandle {
MessageType::Broadcast
};

let raw_payload = bincode::serialize(&message).map_err(|e| Error::Network {
reason: format!("Failed to serialize message: {e}"),
})?;
let payload_inner = match message_type {
MessageType::Broadcast => GossipOrRequestResponse::Gossip(GossipMessage {
topic: self.topic.to_string(),
raw_payload: bincode::serialize(&message).expect("Should serialize"),
raw_payload,
}),
MessageType::P2P(_) => GossipOrRequestResponse::Request(MyBehaviourRequest::Message {
topic: self.topic.to_string(),
raw_payload: bincode::serialize(&message).expect("Should serialize"),
raw_payload,
}),
};

Expand Down
6 changes: 6 additions & 0 deletions sdk/src/network/handlers/gossip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ impl NetworkService<'_> {
error!("Got message from unknown peer");
return;
};

// Reject messages from self
if origin == self.my_id {
return;
}

trace!("Got message from peer: {origin}");
match bincode::deserialize::<GossipMessage>(&message.data) {
Ok(GossipMessage { topic, raw_payload }) => {
Expand Down
5 changes: 5 additions & 0 deletions sdk/src/network/handlers/p2p.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ impl NetworkService<'_> {
)
}
Message { topic, raw_payload } => {
// Reject messages from self
if peer == self.my_id {
return;
}

let topic = IdentTopic::new(topic);
if let Some((_, tx, _)) = self
.inbound_mapping
Expand Down
8 changes: 4 additions & 4 deletions sdk/src/network/messaging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub trait MessageMetadata {
}

#[async_trait]
pub trait Network {
pub trait MessagingNetwork {
type Message: MessageMetadata + Send + Sync + 'static;

async fn next_message(&self) -> Option<Payload<Self::Message>>;
Expand Down Expand Up @@ -133,7 +133,7 @@ where
M: MessageMetadata + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
B: Backend<M> + Send + Sync + 'static,
L: LocalDelivery<M> + Send + Sync + 'static,
N: Network<Message = M> + Send + Sync + 'static,
N: MessagingNetwork<Message = M> + Send + Sync + 'static,
{
backend: Arc<B>,
local_delivery: Arc<L>,
Expand All @@ -147,7 +147,7 @@ where
M: MessageMetadata + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
B: Backend<M> + Send + Sync + 'static,
L: LocalDelivery<M> + Send + Sync + 'static,
N: Network<Message = M> + Send + Sync + 'static,
N: MessagingNetwork<Message = M> + Send + Sync + 'static,
{
fn clone(&self) -> Self {
Self {
Expand All @@ -165,7 +165,7 @@ where
M: MessageMetadata + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static,
B: Backend<M> + Send + Sync + 'static,
L: LocalDelivery<M> + Send + Sync + 'static,
N: Network<Message = M> + Send + Sync + 'static,
N: MessagingNetwork<Message = M> + Send + Sync + 'static,
{
pub fn new(backend: B, local_delivery: L, network: N) -> Self {
let this = Self {
Expand Down
Loading

0 comments on commit 1494dfa

Please sign in to comment.