diff --git a/applications/minotari_node/src/commands/command/dial_peer.rs b/applications/minotari_node/src/commands/command/dial_peer.rs index a37feb0bf8..699cd3f53f 100644 --- a/applications/minotari_node/src/commands/command/dial_peer.rs +++ b/applications/minotari_node/src/commands/command/dial_peer.rs @@ -53,7 +53,7 @@ impl CommandContext { let start = Instant::now(); println!("☎️ Dialing peer..."); - match connectivity.dial_peer(dest_node_id).await { + match connectivity.dial_peer(dest_node_id, false).await { Ok(connection) => { println!("⚡️ Peer connected in {}ms!", start.elapsed().as_millis()); println!("Connection: {}", connection); diff --git a/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs b/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs index 796337cffa..98941b0b31 100644 --- a/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs +++ b/base_layer/core/src/base_node/sync/block_sync/synchronizer.rs @@ -217,7 +217,7 @@ impl<'a, B: BlockchainBackend + 'static> BlockSynchronizer<'a, B> { } async fn connect_to_sync_peer(&self, peer: NodeId) -> Result { - let connection = self.connectivity.dial_peer(peer).await?; + let connection = self.connectivity.dial_peer(peer, false).await?; Ok(connection) } diff --git a/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs b/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs index 12514a63bf..826b479947 100644 --- a/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs +++ b/base_layer/core/src/base_node/sync/header_sync/synchronizer.rs @@ -230,7 +230,7 @@ impl<'a, B: BlockchainBackend + 'static> HeaderSynchronizer<'a, B> { async fn dial_sync_peer(&self, node_id: &NodeId) -> Result { let timer = Instant::now(); debug!(target: LOG_TARGET, "Dialing {} sync peer", node_id); - let conn = self.connectivity.dial_peer(node_id.clone()).await?; + let conn = self.connectivity.dial_peer(node_id.clone(), false).await?; info!( target: LOG_TARGET, "Successfully dialed sync peer {} in {:.2?}", diff --git a/base_layer/core/src/base_node/sync/horizon_state_sync/synchronizer.rs b/base_layer/core/src/base_node/sync/horizon_state_sync/synchronizer.rs index ded89d17d7..5980a294aa 100644 --- a/base_layer/core/src/base_node/sync/horizon_state_sync/synchronizer.rs +++ b/base_layer/core/src/base_node/sync/horizon_state_sync/synchronizer.rs @@ -277,7 +277,7 @@ impl<'a, B: BlockchainBackend + 'static> HorizonStateSynchronization<'a, B> { async fn dial_sync_peer(&self, node_id: &NodeId) -> Result { let timer = Instant::now(); debug!(target: LOG_TARGET, "Dialing {} sync peer", node_id); - let conn = self.connectivity.dial_peer(node_id.clone()).await?; + let conn = self.connectivity.dial_peer(node_id.clone(), false).await?; info!( target: LOG_TARGET, "Successfully dialed sync peer {} in {:.2?}", diff --git a/base_layer/core/tests/tests/node_service.rs b/base_layer/core/tests/tests/node_service.rs index f45703d2dd..5b33ef17ee 100644 --- a/base_layer/core/tests/tests/node_service.rs +++ b/base_layer/core/tests/tests/node_service.rs @@ -420,7 +420,7 @@ async fn propagate_and_forward_invalid_block() { alice_node .comms .connectivity() - .dial_peer(bob_node.node_identity.node_id().clone()) + .dial_peer(bob_node.node_identity.node_id().clone(), false) .await .unwrap(); wait_until_online(&[&alice_node, &bob_node, &carol_node, &dan_node]).await; diff --git a/base_layer/wallet/src/connectivity_service/service.rs b/base_layer/wallet/src/connectivity_service/service.rs index ffbef54b7c..a2f705f9de 100644 --- a/base_layer/wallet/src/connectivity_service/service.rs +++ b/base_layer/wallet/src/connectivity_service/service.rs @@ -392,7 +392,7 @@ impl WalletConnectivityService { } async fn try_setup_rpc_pool(&mut self, peer_node_id: NodeId) -> Result { - let conn = match self.try_dial_peer(peer_node_id.clone()).await? { + let conn = match self.try_dial_peer(peer_node_id.clone(), true).await? { Some(c) => c, None => { warn!(target: LOG_TARGET, "Could not dial base node peer '{}'", peer_node_id); @@ -413,14 +413,18 @@ impl WalletConnectivityService { Ok(true) } - async fn try_dial_peer(&mut self, peer: NodeId) -> Result, WalletConnectivityError> { + async fn try_dial_peer( + &mut self, + peer: NodeId, + drop_old_connections: bool, + ) -> Result, WalletConnectivityError> { tokio::select! { biased; _ = self.base_node_watch_receiver.changed() => { Ok(None) } - result = self.connectivity.dial_peer(peer) => { + result = self.connectivity.dial_peer(peer, drop_old_connections) => { Ok(Some(result?)) } } diff --git a/base_layer/wallet/src/utxo_scanner_service/utxo_scanner_task.rs b/base_layer/wallet/src/utxo_scanner_service/utxo_scanner_task.rs index 500a6acf27..071c7c549b 100644 --- a/base_layer/wallet/src/utxo_scanner_service/utxo_scanner_task.rs +++ b/base_layer/wallet/src/utxo_scanner_service/utxo_scanner_task.rs @@ -179,12 +179,12 @@ where Ok(()) } - async fn connect_to_peer(&mut self, peer: NodeId) -> Result { + async fn new_connection_to_peer(&mut self, peer: NodeId) -> Result { debug!( target: LOG_TARGET, "Attempting UTXO sync with seed peer {} ({})", self.peer_index, peer, ); - match self.resources.comms_connectivity.dial_peer(peer.clone()).await { + match self.resources.comms_connectivity.dial_peer(peer.clone(), true).await { Ok(conn) => Ok(conn), Err(e) => { self.publish_event(UtxoScannerEvent::ConnectionFailedToBaseNode { @@ -333,7 +333,7 @@ where &mut self, peer: &NodeId, ) -> Result, UtxoScannerError> { - let mut connection = self.connect_to_peer(peer.clone()).await?; + let mut connection = self.new_connection_to_peer(peer.clone()).await?; let client = connection .connect_rpc_using_builder(BaseNodeWalletRpcClient::builder().with_deadline(Duration::from_secs(60))) .await?; diff --git a/base_layer/wallet/tests/transaction_service_tests/service.rs b/base_layer/wallet/tests/transaction_service_tests/service.rs index 15fea49f58..e39e4427b5 100644 --- a/base_layer/wallet/tests/transaction_service_tests/service.rs +++ b/base_layer/wallet/tests/transaction_service_tests/service.rs @@ -588,7 +588,7 @@ async fn manage_single_transaction() { let _peer_connection = bob_comms .connectivity() - .dial_peer(alice_node_identity.node_id().clone()) + .dial_peer(alice_node_identity.node_id().clone(), false) .await .unwrap(); @@ -753,7 +753,7 @@ async fn large_interactive_transaction() { // Verify that Alice and Bob are connected let _peer_connection = bob_comms .connectivity() - .dial_peer(alice_node_identity.node_id().clone()) + .dial_peer(alice_node_identity.node_id().clone(), false) .await .unwrap(); @@ -2172,7 +2172,7 @@ async fn manage_multiple_transactions() { let _peer_connection = bob_comms .connectivity() - .dial_peer(alice_node_identity.node_id().clone()) + .dial_peer(alice_node_identity.node_id().clone(), false) .await .unwrap(); sleep(Duration::from_secs(3)).await; @@ -2180,7 +2180,7 @@ async fn manage_multiple_transactions() { // Connect alice to carol let _peer_connection = alice_comms .connectivity() - .dial_peer(carol_node_identity.node_id().clone()) + .dial_peer(carol_node_identity.node_id().clone(), false) .await .unwrap(); diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index 0da659a513..fdbe148b7c 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -12730,7 +12730,7 @@ mod test { .block_on( alice_wallet_comms .connectivity() - .dial_peer(bob_node_identity.node_id().clone()), + .dial_peer(bob_node_identity.node_id().clone(), false), ) .is_ok(); } @@ -12739,7 +12739,7 @@ mod test { .block_on( bob_wallet_comms .connectivity() - .dial_peer(alice_node_identity.node_id().clone()), + .dial_peer(alice_node_identity.node_id().clone(), false), ) .is_ok(); } @@ -12810,7 +12810,7 @@ mod test { let bob_comms_dial_peer = bob_wallet_runtime.block_on( bob_wallet_comms .connectivity() - .dial_peer(alice_node_identity.node_id().clone()), + .dial_peer(alice_node_identity.node_id().clone(), false), ); if let Ok(mut connection_to_alice) = bob_comms_dial_peer { if bob_wallet_runtime diff --git a/comms/core/examples/stress/service.rs b/comms/core/examples/stress/service.rs index 2199638f4b..0f6f84c536 100644 --- a/comms/core/examples/stress/service.rs +++ b/comms/core/examples/stress/service.rs @@ -227,7 +227,7 @@ impl StressTestService { self.comms_node.peer_manager().add_peer(peer).await?; println!("Dialing peer `{}`...", node_id.short_str()); let start = Instant::now(); - let conn = self.comms_node.connectivity().dial_peer(node_id).await?; + let conn = self.comms_node.connectivity().dial_peer(node_id, false).await?; println!("Dial completed successfully in {:.2?}", start.elapsed()); let outbound_tx = self.outbound_tx.clone(); let inbound_rx = self.inbound_rx.clone(); diff --git a/comms/core/src/builder/tests.rs b/comms/core/src/builder/tests.rs index 02626c75e7..de4f49a3d8 100644 --- a/comms/core/src/builder/tests.rs +++ b/comms/core/src/builder/tests.rs @@ -166,7 +166,7 @@ async fn peer_to_peer_custom_protocols() { let mut conn_man_events2 = comms_node2.subscribe_connection_manager_events(); let mut conn1 = conn_man_requester1 - .dial_peer(node_identity2.node_id().clone()) + .dial_peer(node_identity2.node_id().clone(), false) .await .unwrap(); @@ -347,7 +347,7 @@ async fn peer_to_peer_messaging_simultaneous() { comms_node1 .connectivity() - .dial_peer(comms_node2.node_identity().node_id().clone()) + .dial_peer(comms_node2.node_identity().node_id().clone(), false) .await .unwrap(); // Simultaneously send messages between the two nodes diff --git a/comms/core/src/connection_manager/dialer.rs b/comms/core/src/connection_manager/dialer.rs index 357491ae22..65db0845e2 100644 --- a/comms/core/src/connection_manager/dialer.rs +++ b/comms/core/src/connection_manager/dialer.rs @@ -81,6 +81,7 @@ pub(crate) enum DialerRequest { Dial( Box, Option>>, + bool, ), CancelPendingDial(NodeId), NotifyNewInboundConnection(Box), @@ -176,8 +177,8 @@ where debug!(target: LOG_TARGET, "Connection dialer got request: {:?}", request); match request { - Dial(peer, reply_tx) => { - self.handle_dial_peer_request(pending_dials, peer, reply_tx); + Dial(peer, reply_tx, drop_old_connections) => { + self.handle_dial_peer_request(pending_dials, peer, reply_tx, drop_old_connections); }, CancelPendingDial(peer_id) => { self.cancel_dial(&peer_id); @@ -318,6 +319,7 @@ where pending_dials: &mut DialFuturesUnordered, peer: Box, reply_tx: Option>>, + drop_old_connections: bool, ) { if self.is_pending_dial(&peer.node_id) { debug!( @@ -371,6 +373,7 @@ where let result = Self::perform_socket_upgrade_procedure( &peer_manager, &node_identity, + drop_old_connections, socket, addr.clone(), authenticated_public_key, @@ -421,6 +424,7 @@ where async fn perform_socket_upgrade_procedure( peer_manager: &PeerManager, node_identity: &NodeIdentity, + drop_old_connections: bool, mut socket: NoiseSocket, dialed_addr: Multiaddr, authenticated_public_key: CommsPublicKey, @@ -474,6 +478,7 @@ where muxer, dialed_addr, NodeId::from_public_key(&authenticated_public_key), + drop_old_connections, peer_identity.claim.features, CONNECTION_DIRECTION, conn_man_notifier, diff --git a/comms/core/src/connection_manager/listener.rs b/comms/core/src/connection_manager/listener.rs index 42b81440e2..4c37b4b009 100644 --- a/comms/core/src/connection_manager/listener.rs +++ b/comms/core/src/connection_manager/listener.rs @@ -423,6 +423,7 @@ where muxer, peer_addr, peer.node_id.clone(), + false, peer.features, CONNECTION_DIRECTION, conn_man_notifier, diff --git a/comms/core/src/connection_manager/manager.rs b/comms/core/src/connection_manager/manager.rs index a646a3dd41..d24926e4f7 100644 --- a/comms/core/src/connection_manager/manager.rs +++ b/comms/core/src/connection_manager/manager.rs @@ -385,11 +385,17 @@ where use ConnectionManagerRequest::{CancelDial, DialPeer, NotifyListening}; trace!(target: LOG_TARGET, "Connection manager got request: {:?}", request); match request { - DialPeer { node_id, reply_tx } => { + DialPeer { + node_id, + reply_tx, + drop_old_connections, + } => { let tracing_id = tracing::Span::current().id(); let span = span!(Level::TRACE, "connection_manager::handle_request"); span.follows_from(tracing_id); - self.dial_peer(node_id, reply_tx).instrument(span).await + self.dial_peer(node_id, reply_tx, drop_old_connections) + .instrument(span) + .await }, CancelDial(node_id) => { if let Err(err) = self.dialer_tx.send(DialerRequest::CancelPendingDial(node_id)).await { @@ -500,10 +506,11 @@ where &mut self, node_id: NodeId, reply: Option>>, + drop_old_connections: bool, ) { match self.peer_manager.find_by_node_id(&node_id).await { Ok(Some(peer)) => { - self.send_dialer_request(DialerRequest::Dial(Box::new(peer), reply)) + self.send_dialer_request(DialerRequest::Dial(Box::new(peer), reply, drop_old_connections)) .await; }, Ok(None) => { diff --git a/comms/core/src/connection_manager/peer_connection.rs b/comms/core/src/connection_manager/peer_connection.rs index b1e1435d2c..b8f978daec 100644 --- a/comms/core/src/connection_manager/peer_connection.rs +++ b/comms/core/src/connection_manager/peer_connection.rs @@ -33,6 +33,7 @@ use std::{ use futures::{future::BoxFuture, stream::FuturesUnordered}; use log::*; use multiaddr::Multiaddr; +use tari_shutdown::oneshot_trigger::OneshotTrigger; use tokio::{ sync::{mpsc, oneshot}, time, @@ -71,6 +72,7 @@ pub fn create( connection: Yamux, peer_addr: Multiaddr, peer_node_id: NodeId, + drop_old_connections: bool, peer_features: PeerFeatures, direction: ConnectionDirection, event_notifier: mpsc::Sender, @@ -90,6 +92,7 @@ pub fn create( id, peer_tx, peer_node_id.clone(), + drop_old_connections, peer_features, peer_addr, direction, @@ -130,6 +133,7 @@ pub type ConnectionId = usize; pub struct PeerConnection { id: ConnectionId, peer_node_id: NodeId, + drop_old_connections: bool, peer_features: PeerFeatures, request_tx: mpsc::Sender, address: Arc, @@ -137,6 +141,8 @@ pub struct PeerConnection { started_at: Instant, substream_counter: AtomicRefCounter, handle_counter: Arc<()>, + drop_notifier: OneshotTrigger, + number_of_rpc_clients: Option, } impl PeerConnection { @@ -144,6 +150,7 @@ impl PeerConnection { id: ConnectionId, request_tx: mpsc::Sender, peer_node_id: NodeId, + drop_old_connections: bool, peer_features: PeerFeatures, address: Multiaddr, direction: ConnectionDirection, @@ -153,12 +160,15 @@ impl PeerConnection { id, request_tx, peer_node_id, + drop_old_connections, peer_features, address: Arc::new(address), direction, started_at: Instant::now(), substream_counter, handle_counter: Arc::new(()), + drop_notifier: OneshotTrigger::::new(), + number_of_rpc_clients: None, } } @@ -249,16 +259,21 @@ impl PeerConnection { let protocol = ProtocolId::from_static(T::PROTOCOL_NAME); debug!( target: LOG_TARGET, - "Attempting to establish RPC protocol `{}` to peer `{}`", - String::from_utf8_lossy(&protocol), - self.peer_node_id + "Attempting to establish RPC protocol `{}` to peer `{}` (drop_old_connections {})", + String::from_utf8_lossy(&protocol), self.peer_node_id, self.drop_old_connections ); let framed = self.open_framed_substream(&protocol, RPC_MAX_FRAME_SIZE).await?; - builder + + let rpc_client = builder .with_protocol_id(protocol) .with_node_id(self.peer_node_id.clone()) + .with_drop_old_connections(self.drop_old_connections) + .with_drop_receiver(self.drop_notifier.clone()) .connect(framed) - .await + .await?; + self.number_of_rpc_clients = Some(self.number_of_rpc_clients.unwrap_or(0) + 1); + + Ok(rpc_client) } /// Creates a new RpcClientPool that can be shared between tasks. The client pool will lazily establish up to @@ -298,6 +313,24 @@ impl PeerConnection { } } +impl Drop for PeerConnection { + fn drop(&mut self) { + trace!( + target: LOG_TARGET, + "PeerConnection `{}` drop called, still has {} sub-streams and {} handles open", + self.peer_node_id, self.substream_count(), self.handle_count(), + ); + if let Some(number_of_rpc_clients) = self.number_of_rpc_clients { + self.drop_notifier.broadcast(self.peer_node_id.clone()); + trace!( + target: LOG_TARGET, + "PeerConnection `{}` on drop notified {} RPC clients to drop connection", + self.peer_node_id.clone(), number_of_rpc_clients, + ); + } + } +} + impl fmt::Display for PeerConnection { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { write!( diff --git a/comms/core/src/connection_manager/requester.rs b/comms/core/src/connection_manager/requester.rs index 40a09da7f9..3569dd7e0c 100644 --- a/comms/core/src/connection_manager/requester.rs +++ b/comms/core/src/connection_manager/requester.rs @@ -22,6 +22,7 @@ use std::sync::Arc; +use log::trace; use tokio::sync::{broadcast, mpsc, oneshot}; use super::{error::ConnectionManagerError, peer_connection::PeerConnection}; @@ -29,6 +30,7 @@ use crate::{ connection_manager::manager::{ConnectionManagerEvent, ListenerInfo}, peer_manager::NodeId, }; +const LOG_TARGET: &str = "comms::connectivity::manager::requester"; /// Requests which are handled by the ConnectionManagerService #[derive(Debug)] @@ -37,6 +39,7 @@ pub enum ConnectionManagerRequest { DialPeer { node_id: NodeId, reply_tx: Option>>, + drop_old_connections: bool, }, /// Cancels a pending dial if one exists CancelDial(NodeId), @@ -77,7 +80,7 @@ impl ConnectionManagerRequester { /// Attempt to connect to a remote peer pub async fn dial_peer(&mut self, node_id: NodeId) -> Result { let (reply_tx, reply_rx) = oneshot::channel(); - self.send_dial_peer(node_id, Some(reply_tx)).await?; + self.send_dial_peer(node_id, Some(reply_tx), false).await?; reply_rx .await .map_err(|_| ConnectionManagerError::ActorRequestCanceled)? @@ -97,9 +100,18 @@ impl ConnectionManagerRequester { &mut self, node_id: NodeId, reply_tx: Option>>, + drop_old_connections: bool, ) -> Result<(), ConnectionManagerError> { + trace!( + target: LOG_TARGET, "send_dial_peer: peer: {}, drop_old_connections: {}", + node_id.short_str(), drop_old_connections + ); self.sender - .send(ConnectionManagerRequest::DialPeer { node_id, reply_tx }) + .send(ConnectionManagerRequest::DialPeer { + node_id, + reply_tx, + drop_old_connections, + }) .await .map_err(|_| ConnectionManagerError::SendToActorFailed)?; Ok(()) diff --git a/comms/core/src/connection_manager/tests/listener_dialer.rs b/comms/core/src/connection_manager/tests/listener_dialer.rs index e73f052379..3b7193e78f 100644 --- a/comms/core/src/connection_manager/tests/listener_dialer.rs +++ b/comms/core/src/connection_manager/tests/listener_dialer.rs @@ -130,7 +130,7 @@ async fn smoke() { let (reply_tx, reply_rx) = oneshot::channel(); request_tx - .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx))) + .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx), false)) .await .unwrap(); @@ -238,7 +238,7 @@ async fn banned() { let (reply_tx, reply_rx) = oneshot::channel(); request_tx - .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx))) + .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx), false)) .await .unwrap(); @@ -311,7 +311,7 @@ async fn excluded_yes() { let (reply_tx, reply_rx) = oneshot::channel(); request_tx - .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx))) + .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx), false)) .await .unwrap(); @@ -380,7 +380,7 @@ async fn excluded_no() { let (reply_tx, reply_rx) = oneshot::channel(); request_tx - .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx))) + .send(DialerRequest::Dial(Box::new(peer), Some(reply_tx), false)) .await .unwrap(); diff --git a/comms/core/src/connectivity/manager.rs b/comms/core/src/connectivity/manager.rs index 899a64c0cb..8e74f82d47 100644 --- a/comms/core/src/connectivity/manager.rs +++ b/comms/core/src/connectivity/manager.rs @@ -221,11 +221,17 @@ impl ConnectivityManagerActor { GetConnectivityStatus(reply) => { let _ = reply.send(self.status); }, - DialPeer { node_id, reply_tx } => { + DialPeer { + node_id, + reply_tx, + drop_old_connections, + } => { let tracing_id = tracing::Span::current().id(); let span = span!(Level::TRACE, "handle_dial_peer"); span.follows_from(tracing_id); - self.handle_dial_peer(node_id.clone(), reply_tx).instrument(span).await; + self.handle_dial_peer(node_id.clone(), reply_tx, drop_old_connections) + .instrument(span) + .await; }, SelectConnections(selection, reply) => { let _result = reply.send(self.select_connections(selection).await); @@ -304,7 +310,12 @@ impl ConnectivityManagerActor { &mut self, node_id: NodeId, reply_tx: Option>>, + drop_old_connections: bool, ) { + trace!( + target: LOG_TARGET,"handle_dial_peer: peer: {}, drop_old_connections: {}", + node_id.short_str(), drop_old_connections + ); match self.peer_manager.is_peer_banned(&node_id).await { Ok(true) => { if let Some(reply) = reply_tx { @@ -323,6 +334,7 @@ impl ConnectivityManagerActor { match self.pool.get(&node_id) { // The connection pool may temporarily contain a connection that is not connected so we need to check this. Some(state) if state.is_connected() => { + trace!(target: LOG_TARGET,"handle_dial_peer: {}, {:?}", node_id, state.status()); if let Some(reply_tx) = reply_tx { let _result = reply_tx.send(Ok(state.connection().cloned().expect("Already checked"))); } @@ -346,7 +358,11 @@ impl ConnectivityManagerActor { }, } - if let Err(err) = self.connection_manager.send_dial_peer(node_id, reply_tx).await { + if let Err(err) = self + .connection_manager + .send_dial_peer(node_id, reply_tx, drop_old_connections) + .await + { error!( target: LOG_TARGET, "Failed to send dial request to connection manager: {:?}", err diff --git a/comms/core/src/connectivity/requester.rs b/comms/core/src/connectivity/requester.rs index 4b5bbf34c1..9620b7a4e2 100644 --- a/comms/core/src/connectivity/requester.rs +++ b/comms/core/src/connectivity/requester.rs @@ -90,6 +90,7 @@ pub enum ConnectivityRequest { DialPeer { node_id: NodeId, reply_tx: Option>>, + drop_old_connections: bool, }, GetConnectivityStatus(oneshot::Sender), SelectConnections( @@ -132,7 +133,11 @@ impl ConnectivityRequester { } /// Dial a single peer - pub async fn dial_peer(&self, peer: NodeId) -> Result { + pub async fn dial_peer( + &self, + peer: NodeId, + drop_old_connections: bool, + ) -> Result { let mut num_cancels = 0; loop { let (reply_tx, reply_rx) = oneshot::channel(); @@ -140,6 +145,7 @@ impl ConnectivityRequester { .send(ConnectivityRequest::DialPeer { node_id: peer.clone(), reply_tx: Some(reply_tx), + drop_old_connections, }) .await .map_err(|_| ConnectivityError::ActorDisconnected)?; @@ -168,7 +174,7 @@ impl ConnectivityRequester { ) -> impl Stream> + '_ { peers .into_iter() - .map(|peer| self.dial_peer(peer)) + .map(|peer| self.dial_peer(peer, false)) .collect::>() } @@ -178,6 +184,7 @@ impl ConnectivityRequester { self.sender.send(ConnectivityRequest::DialPeer { node_id: peer, reply_tx: None, + drop_old_connections: false, }) })) .await diff --git a/comms/core/src/proto/rpc.proto b/comms/core/src/proto/rpc.proto index 33e5cd4f0c..d94b8da3f5 100644 --- a/comms/core/src/proto/rpc.proto +++ b/comms/core/src/proto/rpc.proto @@ -40,6 +40,8 @@ message RpcResponse { message RpcSession { // The RPC versions supported by the client repeated uint32 supported_versions = 1; + // Drop old connections for the incoming peer RPC connection + bool drop_old_connections = 2; } message RpcSessionReply { diff --git a/comms/core/src/protocol/messaging/outbound.rs b/comms/core/src/protocol/messaging/outbound.rs index aacb3115b4..76194c2c19 100644 --- a/comms/core/src/protocol/messaging/outbound.rs +++ b/comms/core/src/protocol/messaging/outbound.rs @@ -166,7 +166,7 @@ impl OutboundMessaging { async fn try_dial_peer(&mut self) -> Result { loop { - match self.connectivity.dial_peer(self.peer_node_id.clone()).await { + match self.connectivity.dial_peer(self.peer_node_id.clone(), false).await { Ok(conn) => break Ok(conn), Err(ConnectivityError::DialCancelled) => { debug!( diff --git a/comms/core/src/protocol/rpc/client/mod.rs b/comms/core/src/protocol/rpc/client/mod.rs index 1995715100..8425090058 100644 --- a/comms/core/src/protocol/rpc/client/mod.rs +++ b/comms/core/src/protocol/rpc/client/mod.rs @@ -49,7 +49,11 @@ use futures::{ }; use log::*; use prost::Message; -use tari_shutdown::{Shutdown, ShutdownSignal}; +use tari_shutdown::{ + oneshot_trigger::{OneshotSignal, OneshotTrigger}, + Shutdown, + ShutdownSignal, +}; use tokio::{ io::{AsyncRead, AsyncWrite}, sync::{mpsc, oneshot, watch, Mutex}, @@ -101,10 +105,12 @@ impl RpcClient { node_id: NodeId, framed: CanonicalFraming, protocol_name: ProtocolId, + drop_receiver: Option>, ) -> Result where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId + 'static, { + trace!(target: LOG_TARGET,"connect to {:?} with {:?}", node_id, config); let (request_tx, request_rx) = mpsc::channel(1); let shutdown = Shutdown::new(); let shutdown_signal = shutdown.to_signal(); @@ -112,6 +118,11 @@ impl RpcClient { let connector = ClientConnector::new(request_tx, last_request_latency_rx, shutdown); let (ready_tx, ready_rx) = oneshot::channel(); let tracing_id = tracing::Span::current().id(); + let drop_signal = if let Some(val) = drop_receiver.as_ref() { + val.to_signal() + } else { + OneshotTrigger::::new().to_signal() + }; tokio::spawn({ let span = span!(Level::TRACE, "start_rpc_worker"); span.follows_from(tracing_id); @@ -125,6 +136,7 @@ impl RpcClient { ready_tx, protocol_name, shutdown_signal, + drop_signal, ) .run() .instrument(span) @@ -207,6 +219,7 @@ pub struct RpcClientBuilder { config: RpcClientConfig, protocol_id: Option, node_id: Option, + drop_receiver: Option>, _client: PhantomData, } @@ -216,6 +229,7 @@ impl Default for RpcClientBuilder { config: Default::default(), protocol_id: None, node_id: None, + drop_receiver: None, _client: PhantomData, } } @@ -266,6 +280,18 @@ impl RpcClientBuilder { self.node_id = Some(node_id); self } + + /// Old RPC connections will be dropped when a new connection is established. + pub fn with_drop_old_connections(mut self, drop_old_connections: bool) -> Self { + self.config.drop_old_connections = drop_old_connections; + self + } + + /// Set the drop receiver to be used to trigger the client to close + pub fn with_drop_receiver(mut self, drop_receiver: OneshotTrigger) -> Self { + self.drop_receiver = Some(drop_receiver); + self + } } impl RpcClientBuilder @@ -282,6 +308,7 @@ where TClient: From + NamedProtocolService .as_ref() .cloned() .unwrap_or_else(|| ProtocolId::from_static(TClient::PROTOCOL_NAME)), + self.drop_receiver, ) .await .map(Into::into) @@ -293,6 +320,7 @@ pub struct RpcClientConfig { pub deadline: Option, pub deadline_grace_period: Duration, pub handshake_timeout: Duration, + pub drop_old_connections: bool, } impl RpcClientConfig { @@ -313,6 +341,7 @@ impl Default for RpcClientConfig { deadline: Some(Duration::from_secs(120)), deadline_grace_period: Duration::from_secs(60), handshake_timeout: Duration::from_secs(90), + drop_old_connections: false, } } } @@ -404,6 +433,7 @@ struct RpcClientWorker { ready_tx: Option>>, protocol_id: ProtocolId, shutdown_signal: ShutdownSignal, + drop_signal: OneshotSignal, } impl RpcClientWorker @@ -418,6 +448,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId ready_tx: oneshot::Sender>, protocol_id: ProtocolId, shutdown_signal: ShutdownSignal, + drop_signal: OneshotSignal, ) -> Self { Self { config, @@ -429,6 +460,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId last_request_latency_tx, protocol_id, shutdown_signal, + drop_signal, } } @@ -448,7 +480,9 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId self.protocol_name() ); let start = Instant::now(); - let mut handshake = Handshake::new(&mut self.framed).with_timeout(self.config.handshake_timeout()); + let mut handshake = Handshake::new(&mut self.framed) + .with_timeout(self.config.handshake_timeout()) + .with_drop_old_connections(self.config.drop_old_connections); match handshake.perform_client_handshake().await { Ok(_) => { let latency = start.elapsed(); @@ -486,18 +520,33 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId _ = &mut self.shutdown_signal => { break; } + node_id = &mut self.drop_signal => { + debug!( + target: LOG_TARGET, "(stream={}) Peer '{}' connection has dropped. Worker is terminating.", + self.stream_id(), node_id.unwrap_or_default() + ); + break; + } req = self.request_rx.recv() => { match req { Some(req) => { if let Err(err) = self.handle_request(req).await { #[cfg(feature = "metrics")] metrics::client_errors(&self.node_id, &self.protocol_id).inc(); - error!(target: LOG_TARGET, "(stream={}) Unexpected error: {}. Worker is terminating.", self.stream_id(), err); + error!( + target: LOG_TARGET, + "(stream={}) Unexpected error: {}. Worker is terminating.", + self.stream_id(), err + ); break; } } None => { - debug!(target: LOG_TARGET, "(stream={}) Request channel closed. Worker is terminating.", self.stream_id()); + debug!( + target: LOG_TARGET, + "(stream={}) Request channel closed. Worker is terminating.", + self.stream_id() + ); break }, } diff --git a/comms/core/src/protocol/rpc/context.rs b/comms/core/src/protocol/rpc/context.rs index 47ef988fb6..e68c2f4dbc 100644 --- a/comms/core/src/protocol/rpc/context.rs +++ b/comms/core/src/protocol/rpc/context.rs @@ -71,7 +71,10 @@ impl RpcCommsProvider for RpcCommsBackend { } async fn dial_peer(&mut self, node_id: &NodeId) -> Result { - self.connectivity.dial_peer(node_id.clone()).await.map_err(Into::into) + self.connectivity + .dial_peer(node_id.clone(), false) + .await + .map_err(Into::into) } async fn select_connections(&mut self, selection: ConnectivitySelection) -> Result, RpcError> { diff --git a/comms/core/src/protocol/rpc/handshake.rs b/comms/core/src/protocol/rpc/handshake.rs index fb61f66a11..7ab27b68ca 100644 --- a/comms/core/src/protocol/rpc/handshake.rs +++ b/comms/core/src/protocol/rpc/handshake.rs @@ -61,6 +61,7 @@ pub enum RpcHandshakeError { pub struct Handshake<'a, T> { framed: &'a mut CanonicalFraming, timeout: Option, + drop_old_connections: bool, } impl<'a, T> Handshake<'a, T> @@ -68,7 +69,11 @@ where T: AsyncRead + AsyncWrite + Unpin { /// Create a Handshake using the given framing and no timeout. To set a timeout, use `with_timeout`. pub fn new(framed: &'a mut CanonicalFraming) -> Self { - Self { framed, timeout: None } + Self { + framed, + timeout: None, + drop_old_connections: false, + } } /// Set the length of time that a client/server should wait for the other side to respond before timing out. @@ -77,8 +82,14 @@ where T: AsyncRead + AsyncWrite + Unpin self } + /// Old RPC connections will be dropped when a new connection is established. + pub fn with_drop_old_connections(mut self, drop_old_connections: bool) -> Self { + self.drop_old_connections = drop_old_connections; + self + } + /// Server-side handshake protocol - pub async fn perform_server_handshake(&mut self) -> Result { + pub async fn perform_server_handshake(&mut self) -> Result<(u32, bool), RpcHandshakeError> { match self.recv_next_frame().await { Ok(Some(Ok(msg))) => { let msg = proto::rpc::RpcSession::decode(&mut msg.freeze())?; @@ -96,7 +107,7 @@ where T: AsyncRead + AsyncWrite + Unpin .send(reply.to_encoded_bytes().into()) .instrument(span) .await?; - return Ok(*version); + return Ok((*version, msg.drop_old_connections)); } let span = span!(Level::INFO, "rpc::server::handshake::send_rejection"); @@ -135,6 +146,7 @@ where T: AsyncRead + AsyncWrite + Unpin pub async fn perform_client_handshake(&mut self) -> Result<(), RpcHandshakeError> { let msg = proto::rpc::RpcSession { supported_versions: SUPPORTED_RPC_VERSIONS.to_vec(), + drop_old_connections: self.drop_old_connections, }; let payload = msg.to_encoded_bytes(); debug!(target: LOG_TARGET, "Sending client handshake ({} bytes)", payload.len()); diff --git a/comms/core/src/protocol/rpc/server/handle.rs b/comms/core/src/protocol/rpc/server/handle.rs index 8a82912cb5..8a4bbb7a82 100644 --- a/comms/core/src/protocol/rpc/server/handle.rs +++ b/comms/core/src/protocol/rpc/server/handle.rs @@ -29,6 +29,7 @@ use crate::peer_manager::NodeId; pub enum RpcServerRequest { GetNumActiveSessions(oneshot::Sender), GetNumActiveSessionsForPeer(NodeId, oneshot::Sender), + CloseAllSessionsForPeer(NodeId, oneshot::Sender), } #[derive(Debug, Clone)] @@ -58,4 +59,13 @@ impl RpcServerHandle { .map_err(|_| RpcServerError::RequestCanceled)?; resp.await.map_err(Into::into) } + + pub async fn close_all_sessions_for(&mut self, peer: NodeId) -> Result { + let (req, resp) = oneshot::channel(); + self.sender + .send(RpcServerRequest::CloseAllSessionsForPeer(peer, req)) + .await + .map_err(|_| RpcServerError::RequestCanceled)?; + resp.await.map_err(Into::into) + } } diff --git a/comms/core/src/protocol/rpc/server/mod.rs b/comms/core/src/protocol/rpc/server/mod.rs index 37e969638a..1c0dac2ffe 100644 --- a/comms/core/src/protocol/rpc/server/mod.rs +++ b/comms/core/src/protocol/rpc/server/mod.rs @@ -89,7 +89,7 @@ use crate::{ ProtocolNotification, ProtocolNotificationRx, }, - stream_id::StreamId, + stream_id::{Id, StreamId}, Bytes, Substream, }; @@ -239,8 +239,13 @@ pub(super) struct PeerRpcServer { protocol_notifications: Option>, comms_provider: TCommsProvider, request_rx: mpsc::Receiver, - sessions: HashMap, - tasks: FuturesUnordered>, + sessions: HashMap>, + tasks: FuturesUnordered>, +} + +struct SessionInfo { + pub(crate) peer_watch: tokio::sync::watch::Sender<()>, + pub(crate) stream_id: Id, } impl PeerRpcServer @@ -297,8 +302,8 @@ where } } - Some(Ok(node_id)) = self.tasks.next() => { - self.on_session_complete(&node_id); + Some(Ok((node_id, stream_id))) = self.tasks.next() => { + self.on_session_complete(&node_id, stream_id); }, Some(req) = self.request_rx.recv() => { @@ -315,7 +320,7 @@ where Ok(()) } - async fn handle_request(&self, req: RpcServerRequest) { + async fn handle_request(&mut self, req: RpcServerRequest) { #[allow(clippy::enum_glob_use)] use RpcServerRequest::*; match req { @@ -328,9 +333,13 @@ where let _ = reply.send(num_active); }, GetNumActiveSessionsForPeer(node_id, reply) => { - let num_active = self.sessions.get(&node_id).copied().unwrap_or(0); + let num_active = self.sessions.get(&node_id).map(|v| v.len()).unwrap_or(0); let _ = reply.send(num_active); }, + CloseAllSessionsForPeer(node_id, reply) => { + let num_closed = self.close_all_sessions(&node_id); + let _ = reply.send(num_closed); + }, } } @@ -368,35 +377,60 @@ where Ok(()) } - fn new_session_for(&mut self, node_id: NodeId) -> Result { - let count = self.sessions.entry(node_id.clone()).or_insert(0); + fn new_session_possible_for(&mut self, node_id: &NodeId) -> Result<(), RpcServerError> { match self.config.maximum_sessions_per_client { Some(max) if max > 0 => { - debug_assert!(*count <= max); - if *count >= max { - return Err(RpcServerError::MaxSessionsPerClientReached { - node_id, - max_sessions: max, - }); + if let Some(session_info) = self.sessions.get(node_id) { + if max > session_info.len() { + Ok(()) + } else { + if max <= session_info.len() { + warn!( + target: LOG_TARGET, + "Maximum RPC sessions for peer {} met or exceeded. Max: {}, Current: {}", + node_id, max, session_info.len() + ); + } + Err(RpcServerError::MaxSessionsPerClientReached { + node_id: node_id.clone(), + max_sessions: max, + }) + } + } else { + Ok(()) } }, - Some(_) | None => {}, + Some(_) | None => Ok(()), } + } - *count += 1; - Ok(*count) + fn close_all_sessions(&mut self, node_id: &NodeId) -> usize { + let mut count = 0; + if let Some(session_info) = self.sessions.get_mut(node_id) { + for info in session_info.iter_mut() { + count += 1; + info!(target: LOG_TARGET, "Closing RPC session {} for peer `{}`", info.stream_id, node_id); + let _ = info.peer_watch.send(()); + } + self.sessions.remove(node_id); + } + count } - fn on_session_complete(&mut self, node_id: &NodeId) { - info!(target: LOG_TARGET, "Session complete for {}", node_id); - if let Some(v) = self.sessions.get_mut(node_id) { - *v -= 1; - if *v == 0 { + fn on_session_complete(&mut self, node_id: &NodeId, stream_id: Id) { + if let Some(session_info) = self.sessions.get_mut(node_id) { + if let Some(info) = session_info.iter_mut().find(|info| info.stream_id == stream_id) { + info!(target: LOG_TARGET, "Session complete for {} (stream id {})", node_id, stream_id); + let _ = info.peer_watch.send(()); + }; + session_info.retain(|info| info.stream_id != stream_id); + if session_info.is_empty() { self.sessions.remove(node_id); } } } + #[allow(clippy::too_many_lines)] async fn try_initiate_service( &mut self, protocol: ProtocolId, @@ -437,14 +471,19 @@ where }, }; - match self.new_session_for(node_id.clone()) { - Ok(num_sessions) => { - info!( - target: LOG_TARGET, - "NEW SESSION for {} ({} active) ", node_id, num_sessions - ); - }, + let (version, drop_old_connections) = handshake.perform_server_handshake().await?; + debug!( + target: LOG_TARGET, + "Server negotiated RPC v{} with client node `{}` (drop_old_connections: {})", version, node_id, drop_old_connections + ); + if drop_old_connections { + self.close_all_sessions(node_id); + } + match self.new_session_possible_for(node_id) { + Ok(_) => { + trace!(target: LOG_TARGET, "NEW SESSION can be created for {}", node_id); + }, Err(err) => { handshake .reject_with_reason(HandshakeRejectReason::NoServerSessionsAvailable( @@ -455,12 +494,8 @@ where }, } - let version = handshake.perform_server_handshake().await?; - debug!( - target: LOG_TARGET, - "Server negotiated RPC v{} with client node `{}`", version, node_id - ); - + let stream_id = framed.stream_id(); + let (stop_tx, stop_rx) = tokio::sync::watch::channel(()); let service = ActivePeerRpcService::new( self.config.clone(), protocol, @@ -468,26 +503,48 @@ where service, framed, self.comms_provider.clone(), + stop_rx, ); - let node_id = node_id.clone(); + let node_id_clone = node_id.clone(); let handle = self .executor .try_spawn(async move { #[cfg(feature = "metrics")] - let num_sessions = metrics::num_sessions(&node_id, &service.protocol); + let num_sessions = metrics::num_sessions(&node_id_clone, &service.protocol); #[cfg(feature = "metrics")] num_sessions.inc(); service.start().await; - info!(target: LOG_TARGET, "END OF SESSION for {} ", node_id,); + info!(target: LOG_TARGET, "END OF SESSION for {} ", node_id_clone,); #[cfg(feature = "metrics")] num_sessions.dec(); - node_id + (node_id_clone, stream_id) }) .map_err(|e| RpcServerError::MaximumSessionsReached(format!("{:?}", e)))?; self.tasks.push(handle); + let mut peer_stop = vec![SessionInfo { + peer_watch: stop_tx, + stream_id, + }]; + self.sessions + .entry(node_id.clone()) + .and_modify(|entry| entry.append(&mut peer_stop)) + .or_insert(peer_stop); + if let Some(info) = self.sessions.get(&node_id.clone()) { + info!( + target: LOG_TARGET, + "NEW SESSION created for {} ({} active) ", node_id.clone(), info.len() + ); + // Warn if `stream_id` is already in use + if info.iter().filter(|session| session.stream_id == stream_id).count() > 1 { + warn!( + target: LOG_TARGET, + "Stream ID {} already in use for peer {}. This should not happen.", stream_id, node_id + ); + } + } Ok(()) } @@ -501,6 +558,7 @@ struct ActivePeerRpcService { framed: EarlyClose>, comms_provider: TCommsProvider, logging_context_string: Arc, + stop_rx: tokio::sync::watch::Receiver<()>, } impl ActivePeerRpcService @@ -515,12 +573,14 @@ where service: TSvc, framed: CanonicalFraming, comms_provider: TCommsProvider, + stop_rx: tokio::sync::watch::Receiver<()>, ) -> Self { Self { logging_context_string: Arc::new(format!( - "stream_id: {}, peer: {}, protocol: {}", + "stream_id: {}, peer: {}, protocol: {}, stream_id: {}", framed.stream_id(), node_id, + framed.stream_id(), String::from_utf8_lossy(&protocol) )), @@ -530,6 +590,7 @@ where service, framed: EarlyClose::new(framed), comms_provider, + stop_rx, } } @@ -557,55 +618,64 @@ where } async fn run(&mut self) -> Result<(), RpcServerError> { - while let Some(result) = self.framed.next().await { - match result { - Ok(frame) => { - #[cfg(feature = "metrics")] - metrics::inbound_requests_bytes(&self.node_id, &self.protocol).observe(frame.len() as f64); - - let start = Instant::now(); - - if let Err(err) = self.handle_request(frame.freeze()).await { - if let Err(err) = self.framed.close().await { - let level = err.io().map(err_to_log_level).unwrap_or(log::Level::Error); - - log!( + loop { + tokio::select! { + _ = self.stop_rx.changed() => { + debug!(target: LOG_TARGET, "({}) Stop signal received, closing substream.", self.logging_context_string); + break; + } + result = self.framed.next() => { + match result { + Some(Ok(frame)) => { + #[cfg(feature = "metrics")] + metrics::inbound_requests_bytes(&self.node_id, &self.protocol).observe(frame.len() as f64); + + let start = Instant::now(); + + if let Err(err) = self.handle_request(frame.freeze()).await { + if let Err(err) = self.framed.close().await { + let level = err.io().map(err_to_log_level).unwrap_or(log::Level::Error); + + log!( + target: LOG_TARGET, + level, + "({}) Failed to close substream after socket error: {}", + self.logging_context_string, + err, + ); + } + let level = err.early_close_io().map(err_to_log_level).unwrap_or(log::Level::Error); + log!( + target: LOG_TARGET, + level, + "(peer: {}, protocol: {}) Failed to handle request: {}", + self.node_id, + self.protocol_name(), + err + ); + return Err(err); + } + let elapsed = start.elapsed(); + trace!( target: LOG_TARGET, - level, - "({}) Failed to close substream after socket error: {}", + "({}) RPC request completed in {:.0?}{}", self.logging_context_string, - err, + elapsed, + if elapsed.as_secs() > 5 { " (LONG REQUEST)" } else { "" } ); - } - let level = err.early_close_io().map(err_to_log_level).unwrap_or(log::Level::Error); - log!( - target: LOG_TARGET, - level, - "(peer: {}, protocol: {}) Failed to handle request: {}", - self.node_id, - self.protocol_name(), - err - ); - return Err(err); - } - let elapsed = start.elapsed(); - trace!( - target: LOG_TARGET, - "({}) RPC request completed in {:.0?}{}", - self.logging_context_string, - elapsed, - if elapsed.as_secs() > 5 { " (LONG REQUEST)" } else { "" } - ); - }, - Err(err) => { - if let Err(err) = self.framed.close().await { - error!( - target: LOG_TARGET, - "({}) Failed to close substream after socket error: {}", self.logging_context_string, err - ); + }, + Some(Err(err)) => { + if let Err(err) = self.framed.close().await { + error!( + target: LOG_TARGET, + "({}) Failed to close substream after socket error: {}", self.logging_context_string, err + ); + } + return Err(err.into()); + }, + None => break, } - return Err(err.into()); - }, + } } } diff --git a/comms/core/src/protocol/rpc/test/comms_integration.rs b/comms/core/src/protocol/rpc/test/comms_integration.rs index f2d6f6dae5..93c68ba4f3 100644 --- a/comms/core/src/protocol/rpc/test/comms_integration.rs +++ b/comms/core/src/protocol/rpc/test/comms_integration.rs @@ -74,7 +74,7 @@ async fn run_service() { let mut conn = comms2 .connectivity() - .dial_peer(comms1.node_identity().node_id().clone()) + .dial_peer(comms1.node_identity().node_id().clone(), false) .await .unwrap(); diff --git a/comms/core/src/protocol/rpc/test/greeting_service.rs b/comms/core/src/protocol/rpc/test/greeting_service.rs index 885e2d13aa..802ba98c88 100644 --- a/comms/core/src/protocol/rpc/test/greeting_service.rs +++ b/comms/core/src/protocol/rpc/test/greeting_service.rs @@ -406,6 +406,7 @@ impl GreetingClient { Default::default(), framed, Self::PROTOCOL_NAME.into(), + Default::default(), ) .await?; Ok(Self { inner }) diff --git a/comms/core/src/protocol/rpc/test/handshake.rs b/comms/core/src/protocol/rpc/test/handshake.rs index 17ead8fedc..7ba4ddec8a 100644 --- a/comms/core/src/protocol/rpc/test/handshake.rs +++ b/comms/core/src/protocol/rpc/test/handshake.rs @@ -47,8 +47,28 @@ async fn it_performs_the_handshake() { let mut handshake_client = Handshake::new(&mut client_framed); handshake_client.perform_client_handshake().await.unwrap(); - let v = handshake_result.await.unwrap().unwrap(); + let (v, drop_old_connections) = handshake_result.await.unwrap().unwrap(); assert!(SUPPORTED_RPC_VERSIONS.contains(&v)); + assert!(!drop_old_connections); +} + +#[tokio::test] +async fn it_performs_the_handshake_with_drop_old_connections() { + let (client, server) = MemorySocket::new_pair(); + + let handshake_result = task::spawn(async move { + let mut server_framed = framing::canonical(server, 1024); + let mut handshake_server = Handshake::new(&mut server_framed); + handshake_server.perform_server_handshake().await + }); + + let mut client_framed = framing::canonical(client, 1024); + let mut handshake_client = Handshake::new(&mut client_framed).with_drop_old_connections(true); + + handshake_client.perform_client_handshake().await.unwrap(); + let (v, drop_old_connections) = handshake_result.await.unwrap().unwrap(); + assert!(SUPPORTED_RPC_VERSIONS.contains(&v)); + assert!(drop_old_connections); } #[tokio::test] diff --git a/comms/core/src/test_utils/mocks/connection_manager.rs b/comms/core/src/test_utils/mocks/connection_manager.rs index a84a2a65f6..6a3c441632 100644 --- a/comms/core/src/test_utils/mocks/connection_manager.rs +++ b/comms/core/src/test_utils/mocks/connection_manager.rs @@ -131,7 +131,9 @@ impl ConnectionManagerMock { self.state.inc_call_count(); self.state.add_call(format!("{:?}", req)).await; match req { - DialPeer { node_id, mut reply_tx } => { + DialPeer { + node_id, mut reply_tx, .. + } => { // Send Ok(&mut conn) if we have an active connection, otherwise Err(DialConnectFailedAllAddresses) let result = self .state diff --git a/comms/core/src/test_utils/mocks/connectivity_manager.rs b/comms/core/src/test_utils/mocks/connectivity_manager.rs index b66cdaa523..e4a1a224a1 100644 --- a/comms/core/src/test_utils/mocks/connectivity_manager.rs +++ b/comms/core/src/test_utils/mocks/connectivity_manager.rs @@ -235,7 +235,11 @@ impl ConnectivityManagerMock { use ConnectivityRequest::*; self.state.add_call(format!("{:?}", req)).await; match req { - DialPeer { node_id, reply_tx } => { + DialPeer { + node_id, + reply_tx, + drop_old_connections: _drop_old_connections, + } => { self.state.add_dialed_peer(node_id.clone()).await; // No reply, no reason to do anything in the mock if reply_tx.is_none() { diff --git a/comms/core/src/test_utils/mocks/peer_connection.rs b/comms/core/src/test_utils/mocks/peer_connection.rs index 074d80b2af..0bebbe7269 100644 --- a/comms/core/src/test_utils/mocks/peer_connection.rs +++ b/comms/core/src/test_utils/mocks/peer_connection.rs @@ -60,6 +60,7 @@ pub fn create_dummy_peer_connection(node_id: NodeId) -> (PeerConnection, mpsc::R 1, tx, node_id, + false, PeerFeatures::COMMUNICATION_NODE, addr, ConnectionDirection::Inbound, @@ -97,6 +98,7 @@ pub async fn create_peer_connection_mock_pair( ID_COUNTER.fetch_add(1, Ordering::SeqCst), tx1, peer2.node_id, + false, peer2.features, listen_addr.clone(), ConnectionDirection::Inbound, @@ -107,6 +109,7 @@ pub async fn create_peer_connection_mock_pair( ID_COUNTER.fetch_add(1, Ordering::SeqCst), tx2, peer1.node_id, + false, peer1.features, listen_addr, ConnectionDirection::Outbound, diff --git a/comms/core/tests/tests/rpc.rs b/comms/core/tests/tests/rpc.rs index d4845d226f..a465631bd0 100644 --- a/comms/core/tests/tests/rpc.rs +++ b/comms/core/tests/tests/rpc.rs @@ -27,13 +27,14 @@ use tari_comms::{ protocol::rpc::{RpcServer, RpcServerHandle}, transports::TcpTransport, CommsNode, + Minimized, }; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_test_utils::async_assert_eventually; use tokio::time; use crate::tests::{ - greeting_service::{GreetingClient, GreetingServer, GreetingService, StreamLargeItemsRequest}, + greeting_service::{GreetingClient, GreetingServer, GreetingService, SayHelloRequest, StreamLargeItemsRequest}, helpers::create_comms, }; @@ -62,6 +63,394 @@ async fn spawn_node(signal: ShutdownSignal) -> (CommsNode, RpcServerHandle) { (comms, rpc_server_hnd) } +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rpc_server_can_request_drop_sessions() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let (numer_of_clients, drop_old_connections) = (3, false); + let (node1, _node2, _conn1_2, mut rpc_server2, mut clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone(), drop_old_connections) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify all RPC connections are active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 3); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // The RPC server closes all RPC connections + let num_closed = rpc_server2 + .close_all_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_closed, 3); + + // Verify the RPC connections are closed + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 0); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rpc_server_drop_sessions_when_peer_is_disconnected() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let (numer_of_clients, drop_old_connections) = (3, false); + let (node1, _node2, mut conn1_2, mut rpc_server2, mut clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone(), drop_old_connections) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify all RPC connections are active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 3); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // RPC connections are closed when the peer is disconnected + conn1_2.disconnect(Minimized::No).await.unwrap(); + + // Verify the RPC connections are closed + async_assert_eventually!( + rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(), + expect = 0, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rpc_server_drop_sessions_when_peer_connection_is_dropped() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let (numer_of_clients, drop_old_connections) = (3, false); + let (node1, _node2, conn1_2, mut rpc_server2, mut clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone(), drop_old_connections) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify all RPC connections are active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 3); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // RPC connections are closed when the peer connection is dropped + drop(conn1_2); + + // Verify the RPC connections are closed + async_assert_eventually!( + rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(), + expect = 0, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn new_peer_connection_can_request_drop_sessions_with_dial() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let (numer_of_clients, drop_old_connections) = (3, true); + let (node1, _node2, _conn1_2, mut rpc_server2, mut clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone(), drop_old_connections) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify that only the last RPC connection is active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 1); + let clients_len = clients.len(); + for (i, client) in clients.iter_mut().enumerate() { + if i < clients_len - 1 { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } else { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +#[allow(clippy::too_many_lines)] +async fn drop_sessions_with_dial_request_cannot_change_existing_peer_connection() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let (numer_of_clients, drop_old_connections) = (3, false); + let (node1, node2, _conn1_2, mut rpc_server2, mut clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone(), drop_old_connections) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify all RPC connections are active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 3); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // Get a new connection to the same peer (with 'drop_old_connections = true') + let mut conn1_2b = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone(), true) + .await + .unwrap(); + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 3); + + // Establish some new RPC connections + let mut new_clients = Vec::new(); + for _ in 0..numer_of_clients { + new_clients.push(conn1_2b.connect_rpc::().await.unwrap()); + } + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 6); + + // Verify that the old RPC connections are active + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // Verify that the new RPC connections are active + for client in &mut new_clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // The RPC server closes all RPC connections + let num_closed = rpc_server2 + .close_all_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_closed, 6); + + // Verify the RPC connections are closed + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 0); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } +} + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn client_prematurely_ends_session() { let shutdown = Shutdown::new(); @@ -76,7 +465,7 @@ async fn client_prematurely_ends_session() { let mut conn1_2 = node1 .connectivity() - .dial_peer(node2.node_identity().node_id().clone()) + .dial_peer(node2.node_identity().node_id().clone(), false) .await .unwrap(); diff --git a/comms/core/tests/tests/rpc_stress.rs b/comms/core/tests/tests/rpc_stress.rs index 9a445e8f14..239f74d840 100644 --- a/comms/core/tests/tests/rpc_stress.rs +++ b/comms/core/tests/tests/rpc_stress.rs @@ -100,7 +100,7 @@ async fn run_stress_test(test_params: Params) { let conn1_2 = node1 .connectivity() - .dial_peer(node2.node_identity().node_id().clone()) + .dial_peer(node2.node_identity().node_id().clone(), false) .await .unwrap(); diff --git a/comms/core/tests/tests/substream_stress.rs b/comms/core/tests/tests/substream_stress.rs index 488ec9064c..d64a5cb5ba 100644 --- a/comms/core/tests/tests/substream_stress.rs +++ b/comms/core/tests/tests/substream_stress.rs @@ -72,7 +72,7 @@ async fn run_stress_test(num_substreams: usize, num_iterations: usize, payload_s let mut conn = node1 .connectivity() - .dial_peer(node2.node_identity().node_id().clone()) + .dial_peer(node2.node_identity().node_id().clone(), false) .await .unwrap(); diff --git a/comms/dht/src/actor.rs b/comms/dht/src/actor.rs index 6bd1084e98..103b7fa82f 100644 --- a/comms/dht/src/actor.rs +++ b/comms/dht/src/actor.rs @@ -912,7 +912,7 @@ impl DiscoveryDialTask { pub async fn run(&mut self, public_key: CommsPublicKey) -> Result { if self.peer_manager.exists(&public_key).await { let node_id = NodeId::from_public_key(&public_key); - match self.connectivity.dial_peer(node_id).await { + match self.connectivity.dial_peer(node_id, false).await { Ok(conn) => Ok(conn), Err(ConnectivityError::ConnectionFailed(err)) => match err { ConnectionManagerError::ConnectFailedMaximumAttemptsReached | @@ -949,7 +949,7 @@ impl DiscoveryDialTask { node_id, timer.elapsed() ); - let conn = self.connectivity.dial_peer(node_id).await?; + let conn = self.connectivity.dial_peer(node_id, false).await?; Ok(conn) } } diff --git a/comms/dht/src/network_discovery/discovering.rs b/comms/dht/src/network_discovery/discovering.rs index 4400639627..d148eee4cf 100644 --- a/comms/dht/src/network_discovery/discovering.rs +++ b/comms/dht/src/network_discovery/discovering.rs @@ -298,7 +298,7 @@ impl Discovering { .map(|peer| { let connectivity = self.context.connectivity.clone(); let peer = peer.clone(); - async move { connectivity.dial_peer(peer).await } + async move { connectivity.dial_peer(peer, false).await } }) .collect::>(); diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index b58643ab5f..80cc91df78 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -141,7 +141,7 @@ async fn test_dht_discover_propagation() { node_D .comms .connectivity() - .dial_peer(node_C.comms.node_identity().node_id().clone()) + .dial_peer(node_C.comms.node_identity().node_id().clone(), false) .await .unwrap(); @@ -328,7 +328,7 @@ async fn test_dht_propagate_dedup() { node1 .comms .connectivity() - .dial_peer(node2.node_identity().node_id().clone()) + .dial_peer(node2.node_identity().node_id().clone(), false) .await .unwrap(); } @@ -463,21 +463,21 @@ async fn test_dht_do_not_store_invalid_message_in_dedup() { node_A .comms .connectivity() - .dial_peer(node_B.node_identity().node_id().clone()) + .dial_peer(node_B.node_identity().node_id().clone(), false) .await .unwrap(); node_A .comms .connectivity() - .dial_peer(node_C.node_identity().node_id().clone()) + .dial_peer(node_C.node_identity().node_id().clone(), false) .await .unwrap(); node_B .comms .connectivity() - .dial_peer(node_C.node_identity().node_id().clone()) + .dial_peer(node_C.node_identity().node_id().clone(), false) .await .unwrap(); @@ -627,7 +627,7 @@ async fn test_dht_repropagate() { node1 .comms .connectivity() - .dial_peer(node2.node_identity().node_id().clone()) + .dial_peer(node2.node_identity().node_id().clone(), false) .await .unwrap(); } @@ -731,7 +731,7 @@ async fn test_dht_propagate_message_contents_not_malleable_ban() { node_A .comms .connectivity() - .dial_peer(node_B.node_identity().node_id().clone()) + .dial_peer(node_B.node_identity().node_id().clone(), false) .await .unwrap(); @@ -836,7 +836,7 @@ async fn test_dht_header_not_malleable() { node_A .comms .connectivity() - .dial_peer(node_B.node_identity().node_id().clone()) + .dial_peer(node_B.node_identity().node_id().clone(), false) .await .unwrap(); diff --git a/comms/rpc_macros/src/generator.rs b/comms/rpc_macros/src/generator.rs index 6d8dd28a4e..ec5cd264a1 100644 --- a/comms/rpc_macros/src/generator.rs +++ b/comms/rpc_macros/src/generator.rs @@ -198,7 +198,7 @@ impl RpcCodeGenerator { pub async fn connect(framed: #dep_mod::CanonicalFraming) -> Result where TSubstream: #dep_mod::AsyncRead + #dep_mod::AsyncWrite + Unpin + Send + #dep_mod::StreamId + 'static { use #dep_mod::NamedProtocolService; - let inner = #dep_mod::RpcClient::connect(Default::default(), Default::default(), framed, Self::PROTOCOL_NAME.into()).await?; + let inner = #dep_mod::RpcClient::connect(Default::default(), Default::default(), framed, Self::PROTOCOL_NAME.into(), Default::default()).await?; Ok(Self { inner }) }