From 8ccf31d1de4b7d444b794ca363f7aa4ed9df05d6 Mon Sep 17 00:00:00 2001 From: Hansie Odendaal Date: Fri, 11 Oct 2024 16:15:33 +0200 Subject: [PATCH] Add close RPC connections Added the ability to close RPC connections for a given peer: - When dialing a peer (in establishing a new connection), a drop old connections switch can be included in the dial request from the client so that only only one RPC client connection will ever be active. --- .../src/commands/command/dial_peer.rs | 2 +- .../base_node/sync/block_sync/synchronizer.rs | 2 +- .../sync/header_sync/synchronizer.rs | 2 +- .../sync/horizon_state_sync/synchronizer.rs | 2 +- base_layer/core/tests/tests/node_service.rs | 2 +- .../src/connectivity_service/service.rs | 10 +- .../utxo_scanner_service/utxo_scanner_task.rs | 2 +- .../transaction_service_tests/service.rs | 8 +- base_layer/wallet_ffi/src/lib.rs | 6 +- comms/core/examples/stress/service.rs | 2 +- comms/core/src/builder/tests.rs | 4 +- comms/core/src/connection_manager/dialer.rs | 9 +- comms/core/src/connection_manager/listener.rs | 1 + comms/core/src/connection_manager/manager.rs | 13 +- .../src/connection_manager/peer_connection.rs | 11 +- .../core/src/connection_manager/requester.rs | 16 +- .../tests/listener_dialer.rs | 8 +- comms/core/src/connectivity/manager.rs | 22 +- comms/core/src/connectivity/requester.rs | 11 +- comms/core/src/proto/rpc.proto | 2 + comms/core/src/protocol/messaging/outbound.rs | 2 +- comms/core/src/protocol/rpc/client/mod.rs | 12 +- comms/core/src/protocol/rpc/context.rs | 5 +- comms/core/src/protocol/rpc/handshake.rs | 18 +- comms/core/src/protocol/rpc/server/mod.rs | 2 - .../protocol/rpc/test/comms_integration.rs | 2 +- comms/core/src/protocol/rpc/test/handshake.rs | 22 +- .../test_utils/mocks/connection_manager.rs | 4 +- .../test_utils/mocks/connectivity_manager.rs | 6 +- .../src/test_utils/mocks/peer_connection.rs | 3 + comms/core/tests/tests/rpc.rs | 390 +++++++++++++++++- comms/core/tests/tests/rpc_stress.rs | 2 +- comms/core/tests/tests/substream_stress.rs | 2 +- comms/dht/src/actor.rs | 4 +- .../dht/src/network_discovery/discovering.rs | 2 +- comms/dht/tests/dht.rs | 16 +- 36 files changed, 563 insertions(+), 64 deletions(-) diff --git a/applications/minotari_node/src/commands/command/dial_peer.rs b/applications/minotari_node/src/commands/command/dial_peer.rs index a37feb0bf8b..699cd3f53f9 100644 --- a/applications/minotari_node/src/commands/command/dial_peer.rs +++ b/applications/minotari_node/src/commands/command/dial_peer.rs @@ -53,7 +53,7 @@ impl CommandContext { let start = Instant::now(); println!("☎️ Dialing peer..."); - match connectivity.dial_peer(dest_node_id).await { + match connectivity.dial_peer(dest_node_id, false).await { Ok(connection) => { println!("⚡️ Peer connected in {}ms!", start.elapsed().as_millis()); println!("Connection: {}", connection); diff --git a/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs b/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs index 796337cffa5..98941b0b316 100644 --- a/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs +++ b/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs @@ -217,7 +217,7 @@ impl<'a, B: BlockchainBackend + 'static> BlockSynchronizer<'a, B> { } async fn connect_to_sync_peer(&self, peer: NodeId) -> Result { - let connection = self.connectivity.dial_peer(peer).await?; + let connection = self.connectivity.dial_peer(peer, false).await?; Ok(connection) } diff --git a/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs b/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs index 12514a63bfa..826b4799470 100644 --- a/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs +++ b/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs @@ -230,7 +230,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { async fn dial_sync_peer(&self, node_id: &NodeId) -> Result { let timer = Instant::now(); debug!(target: LOG_TARGET, "Dialing {} sync peer", node_id); - let conn = self.connectivity.dial_peer(node_id.clone()).await?; + let conn = self.connectivity.dial_peer(node_id.clone(), false).await?; info!( target: LOG_TARGET, "Successfully dialed sync peer {} in {:.2?}", diff --git a/base_layer/core/src/base_node/sync/horizon_state_sync/synchronizer.rs b/base_layer/core/src/base_node/sync/horizon_state_sync/synchronizer.rs index ded89d17d77..5980a294aae 100644 --- a/base_layer/core/src/base_node/sync/horizon_state_sync/synchronizer.rs +++ b/base_layer/core/src/base_node/sync/horizon_state_sync/synchronizer.rs @@ -277,7 +277,7 @@ impl<'a, B: BlockchainBackend + 'static> HorizonStateSynchronization<'a, B> { async fn dial_sync_peer(&self, node_id: &NodeId) -> Result { let timer = Instant::now(); debug!(target: LOG_TARGET, "Dialing {} sync peer", node_id); - let conn = self.connectivity.dial_peer(node_id.clone()).await?; + let conn = self.connectivity.dial_peer(node_id.clone(), false).await?; info!( target: LOG_TARGET, "Successfully dialed sync peer {} in {:.2?}", diff --git a/base_layer/core/tests/tests/node_service.rs b/base_layer/core/tests/tests/node_service.rs index f45703d2dd0..5b33ef17eed 100644 --- a/base_layer/core/tests/tests/node_service.rs +++ b/base_layer/core/tests/tests/node_service.rs @@ -420,7 +420,7 @@ async fn propagate_and_forward_invalid_block() { alice_node .comms .connectivity() - .dial_peer(bob_node.node_identity.node_id().clone()) + .dial_peer(bob_node.node_identity.node_id().clone(), false) .await .unwrap(); wait_until_online(&[&alice_node, &bob_node, &carol_node, &dan_node]).await; diff --git a/base_layer/wallet/src/connectivity_service/service.rs b/base_layer/wallet/src/connectivity_service/service.rs index ab5ac039bf2..6705643b8cd 100644 --- a/base_layer/wallet/src/connectivity_service/service.rs +++ b/base_layer/wallet/src/connectivity_service/service.rs @@ -392,7 +392,7 @@ impl WalletConnectivityService { } async fn try_setup_rpc_pool(&mut self, peer_node_id: NodeId) -> Result { - let conn = match self.try_dial_peer(peer_node_id.clone()).await? { + let conn = match self.try_dial_peer(peer_node_id.clone(), true).await? { Some(c) => c, None => { warn!(target: LOG_TARGET, "Could not dial base node peer '{}'", peer_node_id); @@ -413,14 +413,18 @@ impl WalletConnectivityService { Ok(true) } - async fn try_dial_peer(&mut self, peer: NodeId) -> Result, WalletConnectivityError> { + async fn try_dial_peer( + &mut self, + peer: NodeId, + drop_old_connections: bool, + ) -> Result, WalletConnectivityError> { tokio::select! { biased; _ = self.base_node_watch_receiver.changed() => { Ok(None) } - result = self.connectivity.dial_peer(peer) => { + result = self.connectivity.dial_peer(peer, drop_old_connections) => { Ok(Some(result?)) } } diff --git a/base_layer/wallet/src/utxo_scanner_service/utxo_scanner_task.rs b/base_layer/wallet/src/utxo_scanner_service/utxo_scanner_task.rs index 41fb837de38..071c7c549b1 100644 --- a/base_layer/wallet/src/utxo_scanner_service/utxo_scanner_task.rs +++ b/base_layer/wallet/src/utxo_scanner_service/utxo_scanner_task.rs @@ -184,7 +184,7 @@ where target: LOG_TARGET, "Attempting UTXO sync with seed peer {} ({})", self.peer_index, peer, ); - match self.resources.comms_connectivity.dial_peer(peer.clone()).await { + match self.resources.comms_connectivity.dial_peer(peer.clone(), true).await { Ok(conn) => Ok(conn), Err(e) => { self.publish_event(UtxoScannerEvent::ConnectionFailedToBaseNode { diff --git a/base_layer/wallet/tests/transaction_service_tests/service.rs b/base_layer/wallet/tests/transaction_service_tests/service.rs index 15fea49f586..e39e4427b5d 100644 --- a/base_layer/wallet/tests/transaction_service_tests/service.rs +++ b/base_layer/wallet/tests/transaction_service_tests/service.rs @@ -588,7 +588,7 @@ async fn manage_single_transaction() { let _peer_connection = bob_comms .connectivity() - .dial_peer(alice_node_identity.node_id().clone()) + .dial_peer(alice_node_identity.node_id().clone(), false) .await .unwrap(); @@ -753,7 +753,7 @@ async fn large_interactive_transaction() { // Verify that Alice and Bob are connected let _peer_connection = bob_comms .connectivity() - .dial_peer(alice_node_identity.node_id().clone()) + .dial_peer(alice_node_identity.node_id().clone(), false) .await .unwrap(); @@ -2172,7 +2172,7 @@ async fn manage_multiple_transactions() { let _peer_connection = bob_comms .connectivity() - .dial_peer(alice_node_identity.node_id().clone()) + .dial_peer(alice_node_identity.node_id().clone(), false) .await .unwrap(); sleep(Duration::from_secs(3)).await; @@ -2180,7 +2180,7 @@ async fn manage_multiple_transactions() { // Connect alice to carol let _peer_connection = alice_comms .connectivity() - .dial_peer(carol_node_identity.node_id().clone()) + .dial_peer(carol_node_identity.node_id().clone(), false) .await .unwrap(); diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index b522120bb73..9c0ec7d4a6c 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -12731,7 +12731,7 @@ mod test { .block_on( alice_wallet_comms .connectivity() - .dial_peer(bob_node_identity.node_id().clone()), + .dial_peer(bob_node_identity.node_id().clone(), false), ) .is_ok(); } @@ -12740,7 +12740,7 @@ mod test { .block_on( bob_wallet_comms .connectivity() - .dial_peer(alice_node_identity.node_id().clone()), + .dial_peer(alice_node_identity.node_id().clone(), false), ) .is_ok(); } @@ -12811,7 +12811,7 @@ mod test { let bob_comms_dial_peer = bob_wallet_runtime.block_on( bob_wallet_comms .connectivity() - .dial_peer(alice_node_identity.node_id().clone()), + .dial_peer(alice_node_identity.node_id().clone(), false), ); if let Ok(mut connection_to_alice) = bob_comms_dial_peer { if bob_wallet_runtime diff --git a/comms/core/examples/stress/service.rs b/comms/core/examples/stress/service.rs index 2199638f4b0..0f6f84c536c 100644 --- a/comms/core/examples/stress/service.rs +++ b/comms/core/examples/stress/service.rs @@ -227,7 +227,7 @@ impl StressTestService { self.comms_node.peer_manager().add_peer(peer).await?; println!("Dialing peer `{}`...", node_id.short_str()); let start = Instant::now(); - let conn = self.comms_node.connectivity().dial_peer(node_id).await?; + let conn = self.comms_node.connectivity().dial_peer(node_id, false).await?; println!("Dial completed successfully in {:.2?}", start.elapsed()); let outbound_tx = self.outbound_tx.clone(); let inbound_rx = self.inbound_rx.clone(); diff --git a/comms/core/src/builder/tests.rs b/comms/core/src/builder/tests.rs index 02626c75e74..de4f49a3d8d 100644 --- a/comms/core/src/builder/tests.rs +++ b/comms/core/src/builder/tests.rs @@ -166,7 +166,7 @@ async fn peer_to_peer_custom_protocols() { let mut conn_man_events2 = comms_node2.subscribe_connection_manager_events(); let mut conn1 = conn_man_requester1 - .dial_peer(node_identity2.node_id().clone()) + .dial_peer(node_identity2.node_id().clone(), false) .await .unwrap(); @@ -347,7 +347,7 @@ async fn peer_to_peer_messaging_simultaneous() { comms_node1 .connectivity() - .dial_peer(comms_node2.node_identity().node_id().clone()) + .dial_peer(comms_node2.node_identity().node_id().clone(), false) .await .unwrap(); // Simultaneously send messages between the two nodes diff --git a/comms/core/src/connection_manager/dialer.rs b/comms/core/src/connection_manager/dialer.rs index 357491ae225..65db0845e2b 100644 --- a/comms/core/src/connection_manager/dialer.rs +++ b/comms/core/src/connection_manager/dialer.rs @@ -81,6 +81,7 @@ pub(crate) enum DialerRequest { Dial( Box, Option>>, + bool, ), CancelPendingDial(NodeId), NotifyNewInboundConnection(Box), @@ -176,8 +177,8 @@ where debug!(target: LOG_TARGET, "Connection dialer got request: {:?}", request); match request { - Dial(peer, reply_tx) => { - self.handle_dial_peer_request(pending_dials, peer, reply_tx); + Dial(peer, reply_tx, drop_old_connections) => { + self.handle_dial_peer_request(pending_dials, peer, reply_tx, drop_old_connections); }, CancelPendingDial(peer_id) => { self.cancel_dial(&peer_id); @@ -318,6 +319,7 @@ where pending_dials: &mut DialFuturesUnordered, peer: Box, reply_tx: Option>>, + drop_old_connections: bool, ) { if self.is_pending_dial(&peer.node_id) { debug!( @@ -371,6 +373,7 @@ where let result = Self::perform_socket_upgrade_procedure( &peer_manager, &node_identity, + drop_old_connections, socket, addr.clone(), authenticated_public_key, @@ -421,6 +424,7 @@ where async fn perform_socket_upgrade_procedure( peer_manager: &PeerManager, node_identity: &NodeIdentity, + drop_old_connections: bool, mut socket: NoiseSocket, dialed_addr: Multiaddr, authenticated_public_key: CommsPublicKey, @@ -474,6 +478,7 @@ where muxer, dialed_addr, NodeId::from_public_key(&authenticated_public_key), + drop_old_connections, peer_identity.claim.features, CONNECTION_DIRECTION, conn_man_notifier, diff --git a/comms/core/src/connection_manager/listener.rs b/comms/core/src/connection_manager/listener.rs index 42b81440e27..4c37b4b0090 100644 --- a/comms/core/src/connection_manager/listener.rs +++ b/comms/core/src/connection_manager/listener.rs @@ -423,6 +423,7 @@ where muxer, peer_addr, peer.node_id.clone(), + false, peer.features, CONNECTION_DIRECTION, conn_man_notifier, diff --git a/comms/core/src/connection_manager/manager.rs b/comms/core/src/connection_manager/manager.rs index a646a3dd413..d24926e4f7a 100644 --- a/comms/core/src/connection_manager/manager.rs +++ b/comms/core/src/connection_manager/manager.rs @@ -385,11 +385,17 @@ where use ConnectionManagerRequest::{CancelDial, DialPeer, NotifyListening}; trace!(target: LOG_TARGET, "Connection manager got request: {:?}", request); match request { - DialPeer { node_id, reply_tx } => { + DialPeer { + node_id, + reply_tx, + drop_old_connections, + } => { let tracing_id = tracing::Span::current().id(); let span = span!(Level::TRACE, "connection_manager::handle_request"); span.follows_from(tracing_id); - self.dial_peer(node_id, reply_tx).instrument(span).await + self.dial_peer(node_id, reply_tx, drop_old_connections) + .instrument(span) + .await }, CancelDial(node_id) => { if let Err(err) = self.dialer_tx.send(DialerRequest::CancelPendingDial(node_id)).await { @@ -500,10 +506,11 @@ where &mut self, node_id: NodeId, reply: Option>>, + drop_old_connections: bool, ) { match self.peer_manager.find_by_node_id(&node_id).await { Ok(Some(peer)) => { - self.send_dialer_request(DialerRequest::Dial(Box::new(peer), reply)) + self.send_dialer_request(DialerRequest::Dial(Box::new(peer), reply, drop_old_connections)) .await; }, Ok(None) => { diff --git a/comms/core/src/connection_manager/peer_connection.rs b/comms/core/src/connection_manager/peer_connection.rs index 9b51b69129f..cb4edbdd5df 100644 --- a/comms/core/src/connection_manager/peer_connection.rs +++ b/comms/core/src/connection_manager/peer_connection.rs @@ -72,6 +72,7 @@ pub fn create( connection: Yamux, peer_addr: Multiaddr, peer_node_id: NodeId, + drop_old_connections: bool, peer_features: PeerFeatures, direction: ConnectionDirection, event_notifier: mpsc::Sender, @@ -91,6 +92,7 @@ pub fn create( id, peer_tx, peer_node_id.clone(), + drop_old_connections, peer_features, peer_addr, direction, @@ -131,6 +133,7 @@ pub type ConnectionId = usize; pub struct PeerConnection { id: ConnectionId, peer_node_id: NodeId, + drop_old_connections: bool, peer_features: PeerFeatures, request_tx: mpsc::Sender, address: Arc, @@ -148,6 +151,7 @@ impl PeerConnection { id: ConnectionId, request_tx: mpsc::Sender, peer_node_id: NodeId, + drop_old_connections: bool, peer_features: PeerFeatures, address: Multiaddr, direction: ConnectionDirection, @@ -157,6 +161,7 @@ impl PeerConnection { id, request_tx, peer_node_id, + drop_old_connections, peer_features, address: Arc::new(address), direction, @@ -256,15 +261,15 @@ impl PeerConnection { let protocol = ProtocolId::from_static(T::PROTOCOL_NAME); debug!( target: LOG_TARGET, - "Attempting to establish RPC protocol `{}` to peer `{}`", - String::from_utf8_lossy(&protocol), - self.peer_node_id + "Attempting to establish RPC protocol `{}` to peer `{}` (drop_old_connections {})", + String::from_utf8_lossy(&protocol), self.peer_node_id, self.drop_old_connections ); let framed = self.open_framed_substream(&protocol, RPC_MAX_FRAME_SIZE).await?; let rpc_client = builder .with_protocol_id(protocol) .with_node_id(self.peer_node_id.clone()) + .with_drop_old_connections(self.drop_old_connections) .with_terminate_signal(self.drop_notifier.to_signal()) .connect(framed) .await?; diff --git a/comms/core/src/connection_manager/requester.rs b/comms/core/src/connection_manager/requester.rs index 40a09da7f9d..3569dd7e0ce 100644 --- a/comms/core/src/connection_manager/requester.rs +++ b/comms/core/src/connection_manager/requester.rs @@ -22,6 +22,7 @@ use std::sync::Arc; +use log::trace; use tokio::sync::{broadcast, mpsc, oneshot}; use super::{error::ConnectionManagerError, peer_connection::PeerConnection}; @@ -29,6 +30,7 @@ use crate::{ connection_manager::manager::{ConnectionManagerEvent, ListenerInfo}, peer_manager::NodeId, }; +const LOG_TARGET: &str = "comms::connectivity::manager::requester"; /// Requests which are handled by the ConnectionManagerService #[derive(Debug)] @@ -37,6 +39,7 @@ pub enum ConnectionManagerRequest { DialPeer { node_id: NodeId, reply_tx: Option>>, + drop_old_connections: bool, }, /// Cancels a pending dial if one exists CancelDial(NodeId), @@ -77,7 +80,7 @@ impl ConnectionManagerRequester { /// Attempt to connect to a remote peer pub async fn dial_peer(&mut self, node_id: NodeId) -> Result { let (reply_tx, reply_rx) = oneshot::channel(); - self.send_dial_peer(node_id, Some(reply_tx)).await?; + self.send_dial_peer(node_id, Some(reply_tx), false).await?; reply_rx .await .map_err(|_| ConnectionManagerError::ActorRequestCanceled)? @@ -97,9 +100,18 @@ impl ConnectionManagerRequester { &mut self, node_id: NodeId, reply_tx: Option>>, + drop_old_connections: bool, ) -> Result<(), ConnectionManagerError> { + trace!( + target: LOG_TARGET, "send_dial_peer: peer: {}, drop_old_connections: {}", + node_id.short_str(), drop_old_connections + ); self.sender - .send(ConnectionManagerRequest::DialPeer { node_id, reply_tx }) + .send(ConnectionManagerRequest::DialPeer { + node_id, + reply_tx, + drop_old_connections, + }) .await .map_err(|_| ConnectionManagerError::SendToActorFailed)?; Ok(()) diff --git a/comms/core/src/connection_manager/tests/listener_dialer.rs b/comms/core/src/connection_manager/tests/listener_dialer.rs index e73f0523794..3b7193e78f4 100644 --- a/comms/core/src/connection_manager/tests/listener_dialer.rs +++ b/comms/core/src/connection_manager/tests/listener_dialer.rs @@ -130,7 +130,7 @@ async fn smoke() { let (reply_tx, reply_rx) = oneshot::channel(); request_tx - .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx))) + .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx), false)) .await .unwrap(); @@ -238,7 +238,7 @@ async fn banned() { let (reply_tx, reply_rx) = oneshot::channel(); request_tx - .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx))) + .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx), false)) .await .unwrap(); @@ -311,7 +311,7 @@ async fn excluded_yes() { let (reply_tx, reply_rx) = oneshot::channel(); request_tx - .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx))) + .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx), false)) .await .unwrap(); @@ -380,7 +380,7 @@ async fn excluded_no() { let (reply_tx, reply_rx) = oneshot::channel(); request_tx - .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx))) + .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx), false)) .await .unwrap(); diff --git a/comms/core/src/connectivity/manager.rs b/comms/core/src/connectivity/manager.rs index 899a64c0cb3..8e74f82d478 100644 --- a/comms/core/src/connectivity/manager.rs +++ b/comms/core/src/connectivity/manager.rs @@ -221,11 +221,17 @@ impl ConnectivityManagerActor { GetConnectivityStatus(reply) => { let _ = reply.send(self.status); }, - DialPeer { node_id, reply_tx } => { + DialPeer { + node_id, + reply_tx, + drop_old_connections, + } => { let tracing_id = tracing::Span::current().id(); let span = span!(Level::TRACE, "handle_dial_peer"); span.follows_from(tracing_id); - self.handle_dial_peer(node_id.clone(), reply_tx).instrument(span).await; + self.handle_dial_peer(node_id.clone(), reply_tx, drop_old_connections) + .instrument(span) + .await; }, SelectConnections(selection, reply) => { let _result = reply.send(self.select_connections(selection).await); @@ -304,7 +310,12 @@ impl ConnectivityManagerActor { &mut self, node_id: NodeId, reply_tx: Option>>, + drop_old_connections: bool, ) { + trace!( + target: LOG_TARGET,"handle_dial_peer: peer: {}, drop_old_connections: {}", + node_id.short_str(), drop_old_connections + ); match self.peer_manager.is_peer_banned(&node_id).await { Ok(true) => { if let Some(reply) = reply_tx { @@ -323,6 +334,7 @@ impl ConnectivityManagerActor { match self.pool.get(&node_id) { // The connection pool may temporarily contain a connection that is not connected so we need to check this. Some(state) if state.is_connected() => { + trace!(target: LOG_TARGET,"handle_dial_peer: {}, {:?}", node_id, state.status()); if let Some(reply_tx) = reply_tx { let _result = reply_tx.send(Ok(state.connection().cloned().expect("Already checked"))); } @@ -346,7 +358,11 @@ impl ConnectivityManagerActor { }, } - if let Err(err) = self.connection_manager.send_dial_peer(node_id, reply_tx).await { + if let Err(err) = self + .connection_manager + .send_dial_peer(node_id, reply_tx, drop_old_connections) + .await + { error!( target: LOG_TARGET, "Failed to send dial request to connection manager: {:?}", err diff --git a/comms/core/src/connectivity/requester.rs b/comms/core/src/connectivity/requester.rs index 4b5bbf34c1d..9620b7a4e2d 100644 --- a/comms/core/src/connectivity/requester.rs +++ b/comms/core/src/connectivity/requester.rs @@ -90,6 +90,7 @@ pub enum ConnectivityRequest { DialPeer { node_id: NodeId, reply_tx: Option>>, + drop_old_connections: bool, }, GetConnectivityStatus(oneshot::Sender), SelectConnections( @@ -132,7 +133,11 @@ impl ConnectivityRequester { } /// Dial a single peer - pub async fn dial_peer(&self, peer: NodeId) -> Result { + pub async fn dial_peer( + &self, + peer: NodeId, + drop_old_connections: bool, + ) -> Result { let mut num_cancels = 0; loop { let (reply_tx, reply_rx) = oneshot::channel(); @@ -140,6 +145,7 @@ impl ConnectivityRequester { .send(ConnectivityRequest::DialPeer { node_id: peer.clone(), reply_tx: Some(reply_tx), + drop_old_connections, }) .await .map_err(|_| ConnectivityError::ActorDisconnected)?; @@ -168,7 +174,7 @@ impl ConnectivityRequester { ) -> impl Stream> + '_ { peers .into_iter() - .map(|peer| self.dial_peer(peer)) + .map(|peer| self.dial_peer(peer, false)) .collect::>() } @@ -178,6 +184,7 @@ impl ConnectivityRequester { self.sender.send(ConnectivityRequest::DialPeer { node_id: peer, reply_tx: None, + drop_old_connections: false, }) })) .await diff --git a/comms/core/src/proto/rpc.proto b/comms/core/src/proto/rpc.proto index 33e5cd4f0c7..d94b8da3f50 100644 --- a/comms/core/src/proto/rpc.proto +++ b/comms/core/src/proto/rpc.proto @@ -40,6 +40,8 @@ message RpcResponse { message RpcSession { // The RPC versions supported by the client repeated uint32 supported_versions = 1; + // Drop old connections for the incoming peer RPC connection + bool drop_old_connections = 2; } message RpcSessionReply { diff --git a/comms/core/src/protocol/messaging/outbound.rs b/comms/core/src/protocol/messaging/outbound.rs index aacb3115b42..76194c2c194 100644 --- a/comms/core/src/protocol/messaging/outbound.rs +++ b/comms/core/src/protocol/messaging/outbound.rs @@ -166,7 +166,7 @@ impl OutboundMessaging { async fn try_dial_peer(&mut self) -> Result { loop { - match self.connectivity.dial_peer(self.peer_node_id.clone()).await { + match self.connectivity.dial_peer(self.peer_node_id.clone(), false).await { Ok(conn) => break Ok(conn), Err(ConnectivityError::DialCancelled) => { debug!( diff --git a/comms/core/src/protocol/rpc/client/mod.rs b/comms/core/src/protocol/rpc/client/mod.rs index 300c902303d..bc980278f3a 100644 --- a/comms/core/src/protocol/rpc/client/mod.rs +++ b/comms/core/src/protocol/rpc/client/mod.rs @@ -273,6 +273,12 @@ impl RpcClientBuilder { self } + /// Old RPC connections will be dropped when a new connection is established. + pub fn with_drop_old_connections(mut self, drop_old_connections: bool) -> Self { + self.config.drop_old_connections = drop_old_connections; + self + } + /// Set a signal that indicates if this client should be immediately closed pub fn with_terminate_signal(mut self, terminate_signal: OneshotSignal) -> Self { self.terminate_signal = Some(terminate_signal); @@ -306,6 +312,7 @@ pub struct RpcClientConfig { pub deadline: Option, pub deadline_grace_period: Duration, pub handshake_timeout: Duration, + pub drop_old_connections: bool, } impl RpcClientConfig { @@ -326,6 +333,7 @@ impl Default for RpcClientConfig { deadline: Some(Duration::from_secs(120)), deadline_grace_period: Duration::from_secs(60), handshake_timeout: Duration::from_secs(90), + drop_old_connections: false, } } } @@ -464,7 +472,9 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId self.protocol_name() ); let start = Instant::now(); - let mut handshake = Handshake::new(&mut self.framed).with_timeout(self.config.handshake_timeout()); + let mut handshake = Handshake::new(&mut self.framed) + .with_timeout(self.config.handshake_timeout()) + .with_drop_old_connections(self.config.drop_old_connections); match handshake.perform_client_handshake().await { Ok(_) => { let latency = start.elapsed(); diff --git a/comms/core/src/protocol/rpc/context.rs b/comms/core/src/protocol/rpc/context.rs index 47ef988fb66..e68c2f4dbc3 100644 --- a/comms/core/src/protocol/rpc/context.rs +++ b/comms/core/src/protocol/rpc/context.rs @@ -71,7 +71,10 @@ impl RpcCommsProvider for RpcCommsBackend { } async fn dial_peer(&mut self, node_id: &NodeId) -> Result { - self.connectivity.dial_peer(node_id.clone()).await.map_err(Into::into) + self.connectivity + .dial_peer(node_id.clone(), false) + .await + .map_err(Into::into) } async fn select_connections(&mut self, selection: ConnectivitySelection) -> Result, RpcError> { diff --git a/comms/core/src/protocol/rpc/handshake.rs b/comms/core/src/protocol/rpc/handshake.rs index 05eddbda144..252afad9367 100644 --- a/comms/core/src/protocol/rpc/handshake.rs +++ b/comms/core/src/protocol/rpc/handshake.rs @@ -61,6 +61,7 @@ pub enum RpcHandshakeError { pub struct Handshake<'a, T> { framed: &'a mut CanonicalFraming, timeout: Option, + drop_old_connections: bool, } impl<'a, T> Handshake<'a, T> @@ -68,7 +69,11 @@ where T: AsyncRead + AsyncWrite + Unpin { /// Create a Handshake using the given framing and no timeout. To set a timeout, use `with_timeout`. pub fn new(framed: &'a mut CanonicalFraming) -> Self { - Self { framed, timeout: None } + Self { + framed, + timeout: None, + drop_old_connections: false, + } } /// Set the length of time that a client/server should wait for the other side to respond before timing out. @@ -77,8 +82,14 @@ where T: AsyncRead + AsyncWrite + Unpin self } + /// Old RPC connections will be dropped when a new connection is established. + pub fn with_drop_old_connections(mut self, drop_old_connections: bool) -> Self { + self.drop_old_connections = drop_old_connections; + self + } + /// Server-side handshake protocol - pub async fn perform_server_handshake(&mut self) -> Result { + pub async fn perform_server_handshake(&mut self) -> Result<(u32, bool), RpcHandshakeError> { match self.recv_next_frame().await { Ok(Some(Ok(msg))) => { let msg = proto::rpc::RpcSession::decode(&mut msg.freeze())?; @@ -96,7 +107,7 @@ where T: AsyncRead + AsyncWrite + Unpin .send(reply.to_encoded_bytes().into()) .instrument(span) .await?; - return Ok(*version); + return Ok((*version, msg.drop_old_connections)); } let span = span!(Level::INFO, "rpc::server::handshake::send_rejection"); @@ -135,6 +146,7 @@ where T: AsyncRead + AsyncWrite + Unpin pub async fn perform_client_handshake(&mut self) -> Result<(), RpcHandshakeError> { let msg = proto::rpc::RpcSession { supported_versions: SUPPORTED_RPC_VERSIONS.to_vec(), + drop_old_connections: self.drop_old_connections, }; let payload = msg.to_encoded_bytes(); debug!(target: LOG_TARGET, "Sending client handshake ({} bytes)", payload.len()); diff --git a/comms/core/src/protocol/rpc/server/mod.rs b/comms/core/src/protocol/rpc/server/mod.rs index 0cf2fec604a..68a55ab5dc9 100644 --- a/comms/core/src/protocol/rpc/server/mod.rs +++ b/comms/core/src/protocol/rpc/server/mod.rs @@ -492,7 +492,6 @@ where "NEW SESSION for {} ({} currently active) ", node_id, num_sessions ); }, - Err(err) => { handshake .reject_with_reason(HandshakeRejectReason::NoServerSessionsAvailable( @@ -503,7 +502,6 @@ where }, } - let version = handshake.perform_server_handshake().await?; debug!( target: LOG_TARGET, "Server negotiated RPC v{} with client node `{}`", version, node_id diff --git a/comms/core/src/protocol/rpc/test/comms_integration.rs b/comms/core/src/protocol/rpc/test/comms_integration.rs index f2d6f6dae55..93c68ba4f39 100644 --- a/comms/core/src/protocol/rpc/test/comms_integration.rs +++ b/comms/core/src/protocol/rpc/test/comms_integration.rs @@ -74,7 +74,7 @@ async fn run_service() { let mut conn = comms2 .connectivity() - .dial_peer(comms1.node_identity().node_id().clone()) + .dial_peer(comms1.node_identity().node_id().clone(), false) .await .unwrap(); diff --git a/comms/core/src/protocol/rpc/test/handshake.rs b/comms/core/src/protocol/rpc/test/handshake.rs index 17ead8fedc3..7ba4ddec8a4 100644 --- a/comms/core/src/protocol/rpc/test/handshake.rs +++ b/comms/core/src/protocol/rpc/test/handshake.rs @@ -47,8 +47,28 @@ async fn it_performs_the_handshake() { let mut handshake_client = Handshake::new(&mut client_framed); handshake_client.perform_client_handshake().await.unwrap(); - let v = handshake_result.await.unwrap().unwrap(); + let (v, drop_old_connections) = handshake_result.await.unwrap().unwrap(); assert!(SUPPORTED_RPC_VERSIONS.contains(&v)); + assert!(!drop_old_connections); +} + +#[tokio::test] +async fn it_performs_the_handshake_with_drop_old_connections() { + let (client, server) = MemorySocket::new_pair(); + + let handshake_result = task::spawn(async move { + let mut server_framed = framing::canonical(server, 1024); + let mut handshake_server = Handshake::new(&mut server_framed); + handshake_server.perform_server_handshake().await + }); + + let mut client_framed = framing::canonical(client, 1024); + let mut handshake_client = Handshake::new(&mut client_framed).with_drop_old_connections(true); + + handshake_client.perform_client_handshake().await.unwrap(); + let (v, drop_old_connections) = handshake_result.await.unwrap().unwrap(); + assert!(SUPPORTED_RPC_VERSIONS.contains(&v)); + assert!(drop_old_connections); } #[tokio::test] diff --git a/comms/core/src/test_utils/mocks/connection_manager.rs b/comms/core/src/test_utils/mocks/connection_manager.rs index a84a2a65f60..6a3c441632b 100644 --- a/comms/core/src/test_utils/mocks/connection_manager.rs +++ b/comms/core/src/test_utils/mocks/connection_manager.rs @@ -131,7 +131,9 @@ impl ConnectionManagerMock { self.state.inc_call_count(); self.state.add_call(format!("{:?}", req)).await; match req { - DialPeer { node_id, mut reply_tx } => { + DialPeer { + node_id, mut reply_tx, .. + } => { // Send Ok(&mut conn) if we have an active connection, otherwise Err(DialConnectFailedAllAddresses) let result = self .state diff --git a/comms/core/src/test_utils/mocks/connectivity_manager.rs b/comms/core/src/test_utils/mocks/connectivity_manager.rs index b66cdaa523e..e4a1a224a16 100644 --- a/comms/core/src/test_utils/mocks/connectivity_manager.rs +++ b/comms/core/src/test_utils/mocks/connectivity_manager.rs @@ -235,7 +235,11 @@ impl ConnectivityManagerMock { use ConnectivityRequest::*; self.state.add_call(format!("{:?}", req)).await; match req { - DialPeer { node_id, reply_tx } => { + DialPeer { + node_id, + reply_tx, + drop_old_connections: _drop_old_connections, + } => { self.state.add_dialed_peer(node_id.clone()).await; // No reply, no reason to do anything in the mock if reply_tx.is_none() { diff --git a/comms/core/src/test_utils/mocks/peer_connection.rs b/comms/core/src/test_utils/mocks/peer_connection.rs index 074d80b2aff..0bebbe72693 100644 --- a/comms/core/src/test_utils/mocks/peer_connection.rs +++ b/comms/core/src/test_utils/mocks/peer_connection.rs @@ -60,6 +60,7 @@ pub fn create_dummy_peer_connection(node_id: NodeId) -> (PeerConnection, mpsc::R 1, tx, node_id, + false, PeerFeatures::COMMUNICATION_NODE, addr, ConnectionDirection::Inbound, @@ -97,6 +98,7 @@ pub async fn create_peer_connection_mock_pair( ID_COUNTER.fetch_add(1, Ordering::SeqCst), tx1, peer2.node_id, + false, peer2.features, listen_addr.clone(), ConnectionDirection::Inbound, @@ -107,6 +109,7 @@ pub async fn create_peer_connection_mock_pair( ID_COUNTER.fetch_add(1, Ordering::SeqCst), tx2, peer1.node_id, + false, peer1.features, listen_addr, ConnectionDirection::Outbound, diff --git a/comms/core/tests/tests/rpc.rs b/comms/core/tests/tests/rpc.rs index 40bde84a5c5..2b3bb7964ca 100644 --- a/comms/core/tests/tests/rpc.rs +++ b/comms/core/tests/tests/rpc.rs @@ -483,6 +483,394 @@ async fn rpc_server_drop_sessions_when_peer_connection_is_dropped() { ); } +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rpc_server_can_request_drop_sessions() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let (numer_of_clients, drop_old_connections) = (3, false); + let (node1, _node2, _conn1_2, mut rpc_server2, mut clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone(), drop_old_connections) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify all RPC connections are active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 3); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // The RPC server closes all RPC connections + let num_closed = rpc_server2 + .close_all_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_closed, 3); + + // Verify the RPC connections are closed + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 0); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rpc_server_drop_sessions_when_peer_is_disconnected() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let (numer_of_clients, drop_old_connections) = (3, false); + let (node1, _node2, mut conn1_2, mut rpc_server2, mut clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone(), drop_old_connections) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify all RPC connections are active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 3); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // RPC connections are closed when the peer is disconnected + conn1_2.disconnect(Minimized::No).await.unwrap(); + + // Verify the RPC connections are closed + async_assert_eventually!( + rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(), + expect = 0, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rpc_server_drop_sessions_when_peer_connection_is_dropped() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let (numer_of_clients, drop_old_connections) = (3, false); + let (node1, _node2, conn1_2, mut rpc_server2, mut clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone(), drop_old_connections) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify all RPC connections are active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 3); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // RPC connections are closed when the peer connection is dropped + drop(conn1_2); + + // Verify the RPC connections are closed + async_assert_eventually!( + rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(), + expect = 0, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn new_peer_connection_can_request_drop_sessions_with_dial() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let (numer_of_clients, drop_old_connections) = (3, true); + let (node1, _node2, _conn1_2, mut rpc_server2, mut clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone(), drop_old_connections) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify that only the last RPC connection is active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 1); + let clients_len = clients.len(); + for (i, client) in clients.iter_mut().enumerate() { + if i < clients_len - 1 { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } else { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +#[allow(clippy::too_many_lines)] +async fn drop_sessions_with_dial_request_cannot_change_existing_peer_connection() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let (numer_of_clients, drop_old_connections) = (3, false); + let (node1, node2, _conn1_2, mut rpc_server2, mut clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone(), drop_old_connections) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify all RPC connections are active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 3); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // Get a new connection to the same peer (with 'drop_old_connections = true') + let mut conn1_2b = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone(), true) + .await + .unwrap(); + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 3); + + // Establish some new RPC connections + let mut new_clients = Vec::new(); + for _ in 0..numer_of_clients { + new_clients.push(conn1_2b.connect_rpc::().await.unwrap()); + } + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 6); + + // Verify that the old RPC connections are active + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // Verify that the new RPC connections are active + for client in &mut new_clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // The RPC server closes all RPC connections + let num_closed = rpc_server2 + .close_all_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_closed, 6); + + // Verify the RPC connections are closed + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 0); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } +} + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn client_prematurely_ends_session() { let shutdown = Shutdown::new(); @@ -497,7 +885,7 @@ async fn client_prematurely_ends_session() { let mut conn1_2 = node1 .connectivity() - .dial_peer(node2.node_identity().node_id().clone()) + .dial_peer(node2.node_identity().node_id().clone(), false) .await .unwrap(); diff --git a/comms/core/tests/tests/rpc_stress.rs b/comms/core/tests/tests/rpc_stress.rs index 9a445e8f140..239f74d840f 100644 --- a/comms/core/tests/tests/rpc_stress.rs +++ b/comms/core/tests/tests/rpc_stress.rs @@ -100,7 +100,7 @@ async fn run_stress_test(test_params: Params) { let conn1_2 = node1 .connectivity() - .dial_peer(node2.node_identity().node_id().clone()) + .dial_peer(node2.node_identity().node_id().clone(), false) .await .unwrap(); diff --git a/comms/core/tests/tests/substream_stress.rs b/comms/core/tests/tests/substream_stress.rs index 488ec9064cf..d64a5cb5bad 100644 --- a/comms/core/tests/tests/substream_stress.rs +++ b/comms/core/tests/tests/substream_stress.rs @@ -72,7 +72,7 @@ async fn run_stress_test(num_substreams: usize, num_iterations: usize, payload_s let mut conn = node1 .connectivity() - .dial_peer(node2.node_identity().node_id().clone()) + .dial_peer(node2.node_identity().node_id().clone(), false) .await .unwrap(); diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index 6bd1084e989..103b7fa82fb 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -912,7 +912,7 @@ impl DiscoveryDialTask { pub async fn run(&mut self, public_key: CommsPublicKey) -> Result { if self.peer_manager.exists(&public_key).await { let node_id = NodeId::from_public_key(&public_key); - match self.connectivity.dial_peer(node_id).await { + match self.connectivity.dial_peer(node_id, false).await { Ok(conn) => Ok(conn), Err(ConnectivityError::ConnectionFailed(err)) => match err { ConnectionManagerError::ConnectFailedMaximumAttemptsReached | @@ -949,7 +949,7 @@ impl DiscoveryDialTask { node_id, timer.elapsed() ); - let conn = self.connectivity.dial_peer(node_id).await?; + let conn = self.connectivity.dial_peer(node_id, false).await?; Ok(conn) } } diff --git a/comms/dht/src/network_discovery/discovering.rs b/comms/dht/src/network_discovery/discovering.rs index 44006396272..d148eee4cf3 100644 --- a/comms/dht/src/network_discovery/discovering.rs +++ b/comms/dht/src/network_discovery/discovering.rs @@ -298,7 +298,7 @@ impl Discovering { .map(|peer| { let connectivity = self.context.connectivity.clone(); let peer = peer.clone(); - async move { connectivity.dial_peer(peer).await } + async move { connectivity.dial_peer(peer, false).await } }) .collect::>(); diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index b58643ab5f6..80cc91df785 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -141,7 +141,7 @@ async fn test_dht_discover_propagation() { node_D .comms .connectivity() - .dial_peer(node_C.comms.node_identity().node_id().clone()) + .dial_peer(node_C.comms.node_identity().node_id().clone(), false) .await .unwrap(); @@ -328,7 +328,7 @@ async fn test_dht_propagate_dedup() { node1 .comms .connectivity() - .dial_peer(node2.node_identity().node_id().clone()) + .dial_peer(node2.node_identity().node_id().clone(), false) .await .unwrap(); } @@ -463,21 +463,21 @@ async fn test_dht_do_not_store_invalid_message_in_dedup() { node_A .comms .connectivity() - .dial_peer(node_B.node_identity().node_id().clone()) + .dial_peer(node_B.node_identity().node_id().clone(), false) .await .unwrap(); node_A .comms .connectivity() - .dial_peer(node_C.node_identity().node_id().clone()) + .dial_peer(node_C.node_identity().node_id().clone(), false) .await .unwrap(); node_B .comms .connectivity() - .dial_peer(node_C.node_identity().node_id().clone()) + .dial_peer(node_C.node_identity().node_id().clone(), false) .await .unwrap(); @@ -627,7 +627,7 @@ async fn test_dht_repropagate() { node1 .comms .connectivity() - .dial_peer(node2.node_identity().node_id().clone()) + .dial_peer(node2.node_identity().node_id().clone(), false) .await .unwrap(); } @@ -731,7 +731,7 @@ async fn test_dht_propagate_message_contents_not_malleable_ban() { node_A .comms .connectivity() - .dial_peer(node_B.node_identity().node_id().clone()) + .dial_peer(node_B.node_identity().node_id().clone(), false) .await .unwrap(); @@ -836,7 +836,7 @@ async fn test_dht_header_not_malleable() { node_A .comms .connectivity() - .dial_peer(node_B.node_identity().node_id().clone()) + .dial_peer(node_B.node_identity().node_id().clone(), false) .await .unwrap();