From 616ca3aa348bda90e4c0eab93c7b0c939ce7f572 Mon Sep 17 00:00:00 2001 From: Hansie Odendaal Date: Fri, 11 Oct 2024 16:15:33 +0200 Subject: [PATCH] Add close RPC connections Added the ability to close RPC connections for a given peer: - The RPC server can request sessions to be dropped. - RPC session count on the server are atomically managed with RPC server sessions. - When a client peer connection (final clone) drops, as a safety precaution, all client RPC connections will be signalled to drop, resulting in the server RPC sessions closing as well. - Improved management of wallet connectivity to base nodes. --- .../src/monero_fail.rs | 2 +- applications/minotari_node/src/bootstrap.rs | 1 + base_layer/contacts/tests/contacts_service.rs | 1 + .../core/src/consensus/consensus_constants.rs | 2 +- base_layer/core/tests/helpers/nodes.rs | 1 + base_layer/p2p/src/config.rs | 5 + .../wallet/src/base_node_service/monitor.rs | 3 + .../wallet/src/connectivity_service/handle.rs | 8 + .../src/connectivity_service/interface.rs | 2 + .../wallet/src/connectivity_service/mock.rs | 4 + .../src/connectivity_service/service.rs | 90 ++-- .../utxo_scanner_service/utxo_scanner_task.rs | 4 +- base_layer/wallet/src/wallet.rs | 17 +- .../output_manager_service_tests/service.rs | 1 + base_layer/wallet_ffi/src/lib.rs | 1 + common/config/presets/c_base_node_c.toml | 4 + common/config/presets/d_console_wallet.toml | 5 + .../src/connection_manager/peer_connection.rs | 51 ++- comms/core/src/protocol/rpc/client/mod.rs | 43 +- comms/core/src/protocol/rpc/client/pool.rs | 68 ++- comms/core/src/protocol/rpc/handshake.rs | 4 +- comms/core/src/protocol/rpc/server/handle.rs | 10 + comms/core/src/protocol/rpc/server/mod.rs | 236 ++++++---- .../src/protocol/rpc/test/greeting_service.rs | 1 + comms/core/src/protocol/rpc/test/smoke.rs | 3 +- comms/core/tests/tests/rpc.rs | 423 +++++++++++++++++- comms/dht/src/store_forward/service.rs | 7 + comms/rpc_macros/src/generator.rs | 2 +- 28 files changed, 850 insertions(+), 149 deletions(-) diff --git a/applications/minotari_merge_mining_proxy/src/monero_fail.rs b/applications/minotari_merge_mining_proxy/src/monero_fail.rs index 0d6c24f18cd..793b92c9cb3 100644 --- a/applications/minotari_merge_mining_proxy/src/monero_fail.rs +++ b/applications/minotari_merge_mining_proxy/src/monero_fail.rs @@ -478,7 +478,7 @@ mod test { } println!("{}: {:?}", i, entry); } - assert_eq!(ordered_entries.len(), 2); + assert!(ordered_entries.len() <= 2); } #[tokio::test] diff --git a/applications/minotari_node/src/bootstrap.rs b/applications/minotari_node/src/bootstrap.rs index f06ddb37e9a..fc41b77cfdb 100644 --- a/applications/minotari_node/src/bootstrap.rs +++ b/applications/minotari_node/src/bootstrap.rs @@ -230,6 +230,7 @@ where B: BlockchainBackend + 'static let rpc_server = RpcServer::builder() .with_maximum_simultaneous_sessions(config.rpc_max_simultaneous_sessions) .with_maximum_sessions_per_client(config.rpc_max_sessions_per_peer) + .with_cull_oldest_peer_rpc_connection_on_full(config.cull_oldest_peer_rpc_connection_on_full) .finish(); // Add your RPC services here ‍🏴‍☠️️☮️🌊 diff --git a/base_layer/contacts/tests/contacts_service.rs b/base_layer/contacts/tests/contacts_service.rs index 73acc051ae8..7b8f171267e 100644 --- a/base_layer/contacts/tests/contacts_service.rs +++ b/base_layer/contacts/tests/contacts_service.rs @@ -96,6 +96,7 @@ pub fn setup_contacts_service( rpc_max_simultaneous_sessions: 0, rpc_max_sessions_per_peer: 0, listener_self_liveness_check_interval: None, + cull_oldest_peer_rpc_connection_on_full: true, }; let peer_message_subscription_factory = Arc::new(subscription_factory); let shutdown = Shutdown::new(); diff --git a/base_layer/core/src/consensus/consensus_constants.rs b/base_layer/core/src/consensus/consensus_constants.rs index 472bbe59e08..7762f94dc2d 100644 --- a/base_layer/core/src/consensus/consensus_constants.rs +++ b/base_layer/core/src/consensus/consensus_constants.rs @@ -656,7 +656,7 @@ impl ConsensusConstants { let consensus_constants = vec![con_1, con_2]; #[cfg(any(test, debug_assertions))] - assert_hybrid_pow_constants(&consensus_constants, &[120], &[50], &[50]); + assert_hybrid_pow_constants(&consensus_constants, &[120, 120], &[50, 50], &[50, 50]); consensus_constants } diff --git a/base_layer/core/tests/helpers/nodes.rs b/base_layer/core/tests/helpers/nodes.rs index 7b5bcfd0f7c..42db66ad40b 100644 --- a/base_layer/core/tests/helpers/nodes.rs +++ b/base_layer/core/tests/helpers/nodes.rs @@ -383,6 +383,7 @@ async fn setup_base_node_services( let rpc_server = RpcServer::builder() .with_maximum_simultaneous_sessions(p2p_config.rpc_max_simultaneous_sessions) .with_maximum_sessions_per_client(p2p_config.rpc_max_sessions_per_peer) + .with_cull_oldest_peer_rpc_connection_on_full(p2p_config.cull_oldest_peer_rpc_connection_on_full) .finish(); let rpc_server = rpc_server.add_service(base_node::create_base_node_sync_rpc_service( blockchain_db.clone().into(), diff --git a/base_layer/p2p/src/config.rs b/base_layer/p2p/src/config.rs index 753c23076ff..57ef8f687ec 100644 --- a/base_layer/p2p/src/config.rs +++ b/base_layer/p2p/src/config.rs @@ -139,6 +139,10 @@ pub struct P2pConfig { /// The maximum allowed RPC sessions per peer. /// Default: 10 pub rpc_max_sessions_per_peer: usize, + /// If true, and the maximum per peer RPC sessions is reached, the RPC server will close an old session and replace + /// it with a new session. If false, the RPC server will reject the new session and preserve the older session. + /// (default value = true). + pub cull_oldest_peer_rpc_connection_on_full: bool, } impl Default for P2pConfig { @@ -163,6 +167,7 @@ impl Default for P2pConfig { auxiliary_tcp_listener_address: None, rpc_max_simultaneous_sessions: 100, rpc_max_sessions_per_peer: 10, + cull_oldest_peer_rpc_connection_on_full: true, } } } diff --git a/base_layer/wallet/src/base_node_service/monitor.rs b/base_layer/wallet/src/base_node_service/monitor.rs index 30164ebf003..22d1cbeb888 100644 --- a/base_layer/wallet/src/base_node_service/monitor.rs +++ b/base_layer/wallet/src/base_node_service/monitor.rs @@ -106,6 +106,9 @@ where latency: None, }) .await; + if let Some(node_id) = self.wallet_connectivity.get_current_base_node_peer_node_id() { + self.wallet_connectivity.disconnect_base_node(node_id).await; + } continue; }, Err(e @ BaseNodeMonitorError::InvalidBaseNodeResponse(_)) | diff --git a/base_layer/wallet/src/connectivity_service/handle.rs b/base_layer/wallet/src/connectivity_service/handle.rs index 876ad089d3e..ada839eb54c 100644 --- a/base_layer/wallet/src/connectivity_service/handle.rs +++ b/base_layer/wallet/src/connectivity_service/handle.rs @@ -37,6 +37,7 @@ use crate::{ pub enum WalletConnectivityRequest { ObtainBaseNodeWalletRpcClient(oneshot::Sender>), ObtainBaseNodeSyncRpcClient(oneshot::Sender>), + DisconnectBaseNode(NodeId), } #[derive(Clone)] @@ -118,6 +119,13 @@ impl WalletConnectivityInterface for WalletConnectivityHandle { reply_rx.await.ok() } + async fn disconnect_base_node(&mut self, node_id: NodeId) { + let _unused = self + .sender + .send(WalletConnectivityRequest::DisconnectBaseNode(node_id)) + .await; + } + fn get_connectivity_status(&mut self) -> OnlineStatus { *self.online_status_rx.borrow() } diff --git a/base_layer/wallet/src/connectivity_service/interface.rs b/base_layer/wallet/src/connectivity_service/interface.rs index e974df59eb6..65c3757be04 100644 --- a/base_layer/wallet/src/connectivity_service/interface.rs +++ b/base_layer/wallet/src/connectivity_service/interface.rs @@ -64,6 +64,8 @@ pub trait WalletConnectivityInterface: Clone + Send + Sync + 'static { /// BaseNodeSyncRpcClient RPC session. async fn obtain_base_node_sync_rpc_client(&mut self) -> Option>; + async fn disconnect_base_node(&mut self, node_id: NodeId); + fn get_connectivity_status(&mut self) -> OnlineStatus; fn get_connectivity_status_watch(&self) -> watch::Receiver; diff --git a/base_layer/wallet/src/connectivity_service/mock.rs b/base_layer/wallet/src/connectivity_service/mock.rs index e34f4eca477..5928f652211 100644 --- a/base_layer/wallet/src/connectivity_service/mock.rs +++ b/base_layer/wallet/src/connectivity_service/mock.rs @@ -116,6 +116,10 @@ impl WalletConnectivityInterface for WalletConnectivityMock { borrow.as_ref().cloned() } + async fn disconnect_base_node(&mut self, _node_id: NodeId) { + self.send_shutdown(); + } + fn get_connectivity_status(&mut self) -> OnlineStatus { *self.online_status_watch.borrow() } diff --git a/base_layer/wallet/src/connectivity_service/service.rs b/base_layer/wallet/src/connectivity_service/service.rs index ffbef54b7ca..ab5ac039bf2 100644 --- a/base_layer/wallet/src/connectivity_service/service.rs +++ b/base_layer/wallet/src/connectivity_service/service.rs @@ -45,7 +45,6 @@ use crate::{ const LOG_TARGET: &str = "wallet::connectivity"; pub(crate) const CONNECTIVITY_WAIT: u64 = 5; -pub(crate) const COOL_OFF_PERIOD: u64 = 60; /// Connection status of the Base Node #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -161,7 +160,11 @@ impl WalletConnectivityService { } async fn handle_request(&mut self, request: WalletConnectivityRequest) { - use WalletConnectivityRequest::{ObtainBaseNodeSyncRpcClient, ObtainBaseNodeWalletRpcClient}; + use WalletConnectivityRequest::{ + DisconnectBaseNode, + ObtainBaseNodeSyncRpcClient, + ObtainBaseNodeWalletRpcClient, + }; match request { ObtainBaseNodeWalletRpcClient(reply) => { self.handle_pool_request(reply.into()).await; @@ -169,6 +172,9 @@ impl WalletConnectivityService { ObtainBaseNodeSyncRpcClient(reply) => { self.handle_pool_request(reply.into()).await; }, + DisconnectBaseNode(node_id) => { + self.disconnect_base_node(node_id).await; + }, } } @@ -195,18 +201,17 @@ impl WalletConnectivityService { match self.pools.get(&node_id) { Some(pools) => match pools.base_node_wallet_rpc_client.get().await { Ok(client) => { + debug!(target: LOG_TARGET, "Obtained pool RPC 'wallet' connection to base node '{}'", node_id); let _result = reply.send(client); }, Err(e) => { warn!( target: LOG_TARGET, - "Base node '{}' wallet RPC pool connection failed ({}). Reconnecting...", + "Base node '{}' pool RPC 'wallet' connection failed ({}). Reconnecting...", node_id, e ); - if let Some(node_id) = self.current_base_node() { - self.disconnect_base_node(node_id).await; - }; + self.disconnect_base_node(node_id).await; self.pending_requests.push(reply.into()); }, }, @@ -237,18 +242,17 @@ impl WalletConnectivityService { match self.pools.get(&node_id) { Some(pools) => match pools.base_node_sync_rpc_client.get().await { Ok(client) => { + debug!(target: LOG_TARGET, "Obtained pool RPC 'sync' connection to base node '{}'", node_id); let _result = reply.send(client); }, Err(e) => { warn!( target: LOG_TARGET, - "Base node '{}' sync RPC pool connection failed ({}). Reconnecting...", + "Base node '{}' pool RPC 'sync' connection failed ({}). Reconnecting...", node_id, e ); - if let Some(node_id) = self.current_base_node() { - self.disconnect_base_node(node_id).await; - }; + self.disconnect_base_node(node_id).await; self.pending_requests.push(reply.into()); }, }, @@ -282,6 +286,8 @@ impl WalletConnectivityService { Err(e) => error!(target: LOG_TARGET, "Failed to disconnect base node: {}", e), } self.pools.remove(&node_id); + // We want to ensure any active RPC clients are dropped when this connection (a clone) is dropped + connection.set_force_disconnect_rpc_clients_when_clone_drops(); }; } @@ -292,16 +298,17 @@ impl WalletConnectivityService { return; }; loop { - let node_id = if let Some(time) = peer_manager.time_since_last_connection_attempt() { - if time < Duration::from_secs(COOL_OFF_PERIOD) { - if peer_manager.get_current_peer().node_id == peer_manager.get_next_peer().node_id { - // If we only have one peer in the list, wait a bit before retrying - time::sleep(Duration::from_secs(CONNECTIVITY_WAIT)).await; - } - peer_manager.get_current_peer().node_id - } else { - peer_manager.get_current_peer().node_id + let node_id = if let Some(_time) = peer_manager.time_since_last_connection_attempt() { + if peer_manager.get_current_peer().node_id == peer_manager.get_next_peer().node_id { + // If we only have one peer in the list, wait a bit before retrying + debug!(target: LOG_TARGET, + "Retrying after {}s ...", + Duration::from_secs(CONNECTIVITY_WAIT).as_secs() + ); + time::sleep(Duration::from_secs(CONNECTIVITY_WAIT)).await; } + // If 'peer_manager.get_next_peer()' is called, 'current_peer' is advanced to the next peer + peer_manager.get_current_peer().node_id } else { peer_manager.get_current_peer().node_id }; @@ -325,14 +332,13 @@ impl WalletConnectivityService { break; } self.base_node_watch.send(Some(peer_manager.clone())); - if let Err(e) = self.notify_pending_requests().await { - warn!(target: LOG_TARGET, "Error notifying pending RPC requests: {}", e); + if let Ok(true) = self.notify_pending_requests().await { + self.set_online_status(OnlineStatus::Online); + debug!( + target: LOG_TARGET, + "Wallet is ONLINE and connected to base node '{}'", node_id + ); } - self.set_online_status(OnlineStatus::Online); - debug!( - target: LOG_TARGET, - "Wallet is ONLINE and connected to base node '{}'", node_id - ); break; }, Ok(false) => { @@ -340,21 +346,15 @@ impl WalletConnectivityService { target: LOG_TARGET, "The peer has changed while connecting. Attempting to connect to new base node." ); + self.disconnect_base_node(node_id).await; }, Err(WalletConnectivityError::ConnectivityError(ConnectivityError::DialCancelled)) => { - debug!( - target: LOG_TARGET, - "Dial was cancelled. Retrying after {}s ...", - Duration::from_secs(CONNECTIVITY_WAIT).as_secs() - ); - time::sleep(Duration::from_secs(CONNECTIVITY_WAIT)).await; + debug!(target: LOG_TARGET, "Dial was cancelled."); + self.disconnect_base_node(node_id).await; }, Err(e) => { warn!(target: LOG_TARGET, "{}", e); - if self.current_base_node().as_ref() == Some(&node_id) { - self.disconnect_base_node(node_id).await; - time::sleep(Duration::from_secs(CONNECTIVITY_WAIT)).await; - } + self.disconnect_base_node(node_id).await; }, } if self.peer_list_change_detected(&peer_manager) { @@ -401,7 +401,7 @@ impl WalletConnectivityService { }; debug!( target: LOG_TARGET, - "Successfully established peer connection to base node '{}'", + "Established peer connection to base node '{}'", conn.peer_node_id() ); self.pools.insert(peer_node_id.clone(), ClientPoolContainer { @@ -409,7 +409,7 @@ impl WalletConnectivityService { base_node_wallet_rpc_client: conn .create_rpc_client_pool(self.config.base_node_rpc_pool_size, Default::default()), }); - debug!(target: LOG_TARGET, "Successfully established RPC connection to base node '{}'", peer_node_id); + trace!(target: LOG_TARGET, "Created RPC pools for '{}'", peer_node_id); Ok(true) } @@ -426,16 +426,24 @@ impl WalletConnectivityService { } } - async fn notify_pending_requests(&mut self) -> Result<(), WalletConnectivityError> { + async fn notify_pending_requests(&mut self) -> Result { let current_pending = mem::take(&mut self.pending_requests); + let mut count = 0; + let current_pending_len = current_pending.len(); for reply in current_pending { if reply.is_canceled() { continue; } - + count += 1; + trace!(target: LOG_TARGET, "Handle {} of {} pending RPC pool requests", count, current_pending_len); self.handle_pool_request(reply).await; } - Ok(()) + if self.pending_requests.is_empty() { + Ok(true) + } else { + warn!(target: LOG_TARGET, "{} of {} pending RPC pool requests not handled", count, current_pending_len); + Ok(false) + } } } diff --git a/base_layer/wallet/src/utxo_scanner_service/utxo_scanner_task.rs b/base_layer/wallet/src/utxo_scanner_service/utxo_scanner_task.rs index 500a6acf277..41fb837de38 100644 --- a/base_layer/wallet/src/utxo_scanner_service/utxo_scanner_task.rs +++ b/base_layer/wallet/src/utxo_scanner_service/utxo_scanner_task.rs @@ -179,7 +179,7 @@ where Ok(()) } - async fn connect_to_peer(&mut self, peer: NodeId) -> Result { + async fn new_connection_to_peer(&mut self, peer: NodeId) -> Result { debug!( target: LOG_TARGET, "Attempting UTXO sync with seed peer {} ({})", self.peer_index, peer, @@ -333,7 +333,7 @@ where &mut self, peer: &NodeId, ) -> Result, UtxoScannerError> { - let mut connection = self.connect_to_peer(peer.clone()).await?; + let mut connection = self.new_connection_to_peer(peer.clone()).await?; let client = connection .connect_rpc_using_builder(BaseNodeWalletRpcClient::builder().with_deadline(Duration::from_secs(60))) .await?; diff --git a/base_layer/wallet/src/wallet.rs b/base_layer/wallet/src/wallet.rs index dcd779007d0..bbbfade4da8 100644 --- a/base_layer/wallet/src/wallet.rs +++ b/base_layer/wallet/src/wallet.rs @@ -412,14 +412,12 @@ where peer_manager.add_peer(current_peer.clone()).await?; } } - connectivity - .add_peer_to_allow_list(current_peer.node_id.clone()) - .await?; let mut peer_list = vec![current_peer]; if let Some(pos) = backup_peers.iter().position(|p| p.public_key == public_key) { backup_peers.remove(pos); } peer_list.append(&mut backup_peers); + self.update_allow_list(&peer_list).await?; self.wallet_connectivity .set_base_node(BaseNodePeerManager::new(0, peer_list)?); } else { @@ -451,6 +449,7 @@ where backup_peers.remove(pos); } peer_list.append(&mut backup_peers); + self.update_allow_list(&peer_list).await?; self.wallet_connectivity .set_base_node(BaseNodePeerManager::new(0, peer_list)?); } @@ -458,6 +457,18 @@ where Ok(()) } + async fn update_allow_list(&mut self, peer_list: &[Peer]) -> Result<(), WalletError> { + let mut connectivity = self.comms.connectivity(); + let current_allow_list = connectivity.get_allow_list().await?; + for peer in ¤t_allow_list { + connectivity.remove_peer_from_allow_list(peer.clone()).await?; + } + for peer in peer_list { + connectivity.add_peer_to_allow_list(peer.node_id.clone()).await?; + } + Ok(()) + } + pub async fn get_base_node_peer(&mut self) -> Option { self.wallet_connectivity.get_current_base_node_peer() } diff --git a/base_layer/wallet/tests/output_manager_service_tests/service.rs b/base_layer/wallet/tests/output_manager_service_tests/service.rs index e788e093d0e..8d821462570 100644 --- a/base_layer/wallet/tests/output_manager_service_tests/service.rs +++ b/base_layer/wallet/tests/output_manager_service_tests/service.rs @@ -1927,6 +1927,7 @@ async fn test_txo_validation() { #[tokio::test] #[allow(clippy::too_many_lines)] async fn test_txo_revalidation() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` let (connection, _tempdir) = get_temp_sqlite_database_connection(); let backend = OutputManagerSqliteDatabase::new(connection.clone()); diff --git a/base_layer/wallet_ffi/src/lib.rs b/base_layer/wallet_ffi/src/lib.rs index 0da659a5139..b522120bb73 100644 --- a/base_layer/wallet_ffi/src/lib.rs +++ b/base_layer/wallet_ffi/src/lib.rs @@ -5452,6 +5452,7 @@ pub unsafe extern "C" fn comms_config_create( rpc_max_simultaneous_sessions: 0, rpc_max_sessions_per_peer: 0, listener_self_liveness_check_interval: None, + cull_oldest_peer_rpc_connection_on_full: true, }; Box::into_raw(Box::new(config)) diff --git a/common/config/presets/c_base_node_c.toml b/common/config/presets/c_base_node_c.toml index d1c81f3dea4..8aeeec7c23a 100644 --- a/common/config/presets/c_base_node_c.toml +++ b/common/config/presets/c_base_node_c.toml @@ -162,6 +162,10 @@ listener_self_liveness_check_interval = 15 #rpc_max_simultaneous_sessions = 100 # The maximum comms RPC sessions allowed per peer (default value = 10). #rpc_max_sessions_per_peer = 10 +# If true, and the maximum per peer RPC sessions is reached, the RPC server will close an old session and replace it +# with a new session. If false, the RPC server will reject the new session and preserve the older session. +# (default value = true). +#pub cull_oldest_peer_rpc_connection_on_full = true [base_node.p2p.transport] # -------------- Transport configuration -------------- diff --git a/common/config/presets/d_console_wallet.toml b/common/config/presets/d_console_wallet.toml index cee79ca337a..6fdef9e5161 100644 --- a/common/config/presets/d_console_wallet.toml +++ b/common/config/presets/d_console_wallet.toml @@ -208,6 +208,11 @@ event_channel_size = 3500 #rpc_max_simultaneous_sessions = 100 # The maximum comms RPC sessions allowed per peer (default value = 10). #rpc_max_sessions_per_peer = 10 +#rpc_max_sessions_per_peer = 10 +# If true, and the maximum per peer RPC sessions is reached, the RPC server will close an old session and replace it +# with a new session. If false, the RPC server will reject the new session and preserve the older session. +# (default value = true). +#pub cull_oldest_peer_rpc_connection_on_full = true [wallet.p2p.transport] # -------------- Transport configuration -------------- diff --git a/comms/core/src/connection_manager/peer_connection.rs b/comms/core/src/connection_manager/peer_connection.rs index b1e1435d2c0..9b51b69129f 100644 --- a/comms/core/src/connection_manager/peer_connection.rs +++ b/comms/core/src/connection_manager/peer_connection.rs @@ -24,7 +24,7 @@ use std::{ fmt, future::Future, sync::{ - atomic::{AtomicUsize, Ordering}, + atomic::{AtomicBool, AtomicUsize, Ordering}, Arc, }, time::{Duration, Instant}, @@ -33,6 +33,7 @@ use std::{ use futures::{future::BoxFuture, stream::FuturesUnordered}; use log::*; use multiaddr::Multiaddr; +use tari_shutdown::oneshot_trigger::OneshotTrigger; use tokio::{ sync::{mpsc, oneshot}, time, @@ -137,6 +138,9 @@ pub struct PeerConnection { started_at: Instant, substream_counter: AtomicRefCounter, handle_counter: Arc<()>, + drop_notifier: OneshotTrigger, + number_of_rpc_clients: Arc, + force_disconnect_rpc_clients_when_clone_drops: Arc, } impl PeerConnection { @@ -159,6 +163,9 @@ impl PeerConnection { started_at: Instant::now(), substream_counter, handle_counter: Arc::new(()), + drop_notifier: OneshotTrigger::::new(), + number_of_rpc_clients: Arc::new(AtomicUsize::new(0)), + force_disconnect_rpc_clients_when_clone_drops: Arc::new(Default::default()), } } @@ -254,11 +261,16 @@ impl PeerConnection { self.peer_node_id ); let framed = self.open_framed_substream(&protocol, RPC_MAX_FRAME_SIZE).await?; - builder + + let rpc_client = builder .with_protocol_id(protocol) .with_node_id(self.peer_node_id.clone()) + .with_terminate_signal(self.drop_notifier.to_signal()) .connect(framed) - .await + .await?; + self.number_of_rpc_clients.fetch_add(1, Ordering::Relaxed); + + Ok(rpc_client) } /// Creates a new RpcClientPool that can be shared between tasks. The client pool will lazily establish up to @@ -296,6 +308,39 @@ impl PeerConnection { .await .map_err(|_| PeerConnectionError::InternalReplyCancelled)? } + + /// Forcefully disconnect all RPC clients when any clone is dropped - if not set (the default behaviour) all RPC + /// clients will be disconnected when the last instance is dropped. i.e. when `self.handle_counter == 1` + pub fn set_force_disconnect_rpc_clients_when_clone_drops(&mut self) { + self.force_disconnect_rpc_clients_when_clone_drops + .store(true, Ordering::Relaxed); + } +} + +impl Drop for PeerConnection { + fn drop(&mut self) { + if self.handle_count() <= 1 || + self.force_disconnect_rpc_clients_when_clone_drops + .load(Ordering::Relaxed) + { + let number_of_rpc_clients = self.number_of_rpc_clients.load(Ordering::Relaxed); + if number_of_rpc_clients > 0 { + self.drop_notifier.broadcast(self.peer_node_id.clone()); + trace!( + target: LOG_TARGET, + "PeerConnection `{}` drop called, open sub-streams: {}, notified {} potential RPC clients to drop \ + connection", + self.peer_node_id.clone(), self.substream_count(), number_of_rpc_clients, + ); + } else { + trace!( + target: LOG_TARGET, + "PeerConnection `{}` drop called, open sub-streams: {}, RPC clients: {}", + self.peer_node_id, self.substream_count(), number_of_rpc_clients + ); + } + } + } } impl fmt::Display for PeerConnection { diff --git a/comms/core/src/protocol/rpc/client/mod.rs b/comms/core/src/protocol/rpc/client/mod.rs index 19957151006..300c902303d 100644 --- a/comms/core/src/protocol/rpc/client/mod.rs +++ b/comms/core/src/protocol/rpc/client/mod.rs @@ -49,7 +49,7 @@ use futures::{ }; use log::*; use prost::Message; -use tari_shutdown::{Shutdown, ShutdownSignal}; +use tari_shutdown::{oneshot_trigger::OneshotSignal, Shutdown, ShutdownSignal}; use tokio::{ io::{AsyncRead, AsyncWrite}, sync::{mpsc, oneshot, watch, Mutex}, @@ -101,10 +101,12 @@ impl RpcClient { node_id: NodeId, framed: CanonicalFraming, protocol_name: ProtocolId, + terminate_signal: Option>, ) -> Result where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId + 'static, { + trace!(target: LOG_TARGET,"connect to {:?} with {:?}", node_id, config); let (request_tx, request_rx) = mpsc::channel(1); let shutdown = Shutdown::new(); let shutdown_signal = shutdown.to_signal(); @@ -112,6 +114,7 @@ impl RpcClient { let connector = ClientConnector::new(request_tx, last_request_latency_rx, shutdown); let (ready_tx, ready_rx) = oneshot::channel(); let tracing_id = tracing::Span::current().id(); + tokio::spawn({ let span = span!(Level::TRACE, "start_rpc_worker"); span.follows_from(tracing_id); @@ -125,6 +128,7 @@ impl RpcClient { ready_tx, protocol_name, shutdown_signal, + terminate_signal, ) .run() .instrument(span) @@ -207,6 +211,7 @@ pub struct RpcClientBuilder { config: RpcClientConfig, protocol_id: Option, node_id: Option, + terminate_signal: Option>, _client: PhantomData, } @@ -216,6 +221,7 @@ impl Default for RpcClientBuilder { config: Default::default(), protocol_id: None, node_id: None, + terminate_signal: None, _client: PhantomData, } } @@ -266,6 +272,12 @@ impl RpcClientBuilder { self.node_id = Some(node_id); self } + + /// Set a signal that indicates if this client should be immediately closed + pub fn with_terminate_signal(mut self, terminate_signal: OneshotSignal) -> Self { + self.terminate_signal = Some(terminate_signal); + self + } } impl RpcClientBuilder @@ -282,6 +294,7 @@ where TClient: From + NamedProtocolService .as_ref() .cloned() .unwrap_or_else(|| ProtocolId::from_static(TClient::PROTOCOL_NAME)), + self.terminate_signal, ) .await .map(Into::into) @@ -404,6 +417,7 @@ struct RpcClientWorker { ready_tx: Option>>, protocol_id: ProtocolId, shutdown_signal: ShutdownSignal, + terminate_signal: Option>, } impl RpcClientWorker @@ -418,6 +432,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId ready_tx: oneshot::Sender>, protocol_id: ProtocolId, shutdown_signal: ShutdownSignal, + terminate_signal: Option>, ) -> Self { Self { config, @@ -429,6 +444,7 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId last_request_latency_tx, protocol_id, shutdown_signal, + terminate_signal, } } @@ -477,6 +493,12 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId }, } + let mut terminate_signal = self + .terminate_signal + .take() + .map(|f| f.boxed()) + .unwrap_or_else(|| future::pending::>().boxed()); + #[cfg(feature = "metrics")] metrics::num_sessions(&self.node_id, &self.protocol_id).inc(); loop { @@ -486,18 +508,33 @@ where TSubstream: AsyncRead + AsyncWrite + Unpin + Send + StreamId _ = &mut self.shutdown_signal => { break; } + node_id = &mut terminate_signal => { + debug!( + target: LOG_TARGET, "(stream={}) Peer '{}' connection has dropped. Worker is terminating.", + self.stream_id(), node_id.unwrap_or_default() + ); + break; + } req = self.request_rx.recv() => { match req { Some(req) => { if let Err(err) = self.handle_request(req).await { #[cfg(feature = "metrics")] metrics::client_errors(&self.node_id, &self.protocol_id).inc(); - error!(target: LOG_TARGET, "(stream={}) Unexpected error: {}. Worker is terminating.", self.stream_id(), err); + error!( + target: LOG_TARGET, + "(stream={}) Unexpected error: {}. Worker is terminating.", + self.stream_id(), err + ); break; } } None => { - debug!(target: LOG_TARGET, "(stream={}) Request channel closed. Worker is terminating.", self.stream_id()); + debug!( + target: LOG_TARGET, + "(stream={}) Request channel closed. Worker is terminating.", + self.stream_id() + ); break }, } diff --git a/comms/core/src/protocol/rpc/client/pool.rs b/comms/core/src/protocol/rpc/client/pool.rs index 56d1a018397..b6dee1d2e7d 100644 --- a/comms/core/src/protocol/rpc/client/pool.rs +++ b/comms/core/src/protocol/rpc/client/pool.rs @@ -90,21 +90,54 @@ where T: RpcPoolClient + From + NamedProtocolService + Clone } pub async fn get_least_used_or_connect(&mut self) -> Result, RpcClientPoolError> { - loop { + { self.check_peer_connection()?; + let peer_node_id = self.connection.peer_node_id().clone(); + let clients_capacity = self.clients.capacity(); + let protocol_id = self.client_config.protocol_id.clone(); let client = match self.get_next_lease() { - Some(c) => c, + Some(c) => { + trace!( + target: LOG_TARGET, + "Used existing RPC client session for connection '{}'", + self.connection.peer_node_id(), + ); + c + }, None => match self.add_new_client_session().await { - Ok(c) => c, + Ok((c, len)) => { + trace!( + target: LOG_TARGET, + "Added new RPC client session for connection '{}' ({} of {}, protocol: {:?})", + peer_node_id, len, clients_capacity, protocol_id, + ); + c + }, // This is an edge case where the remote node does not have any further sessions available. This is // gracefully handled by returning one of the existing used sessions. - Err(RpcClientPoolError::NoMoreRemoteServerRpcSessions(val)) => self - .get_least_used() - .ok_or(RpcClientPoolError::NoMoreRemoteServerRpcSessions(val))?, - Err(RpcClientPoolError::NoMoreRemoteClientRpcSessions(val)) => self - .get_least_used() - .ok_or(RpcClientPoolError::NoMoreRemoteClientRpcSessions(val))?, + Err(RpcClientPoolError::NoMoreRemoteServerRpcSessions(val)) => { + let c = self + .get_least_used() + .ok_or(RpcClientPoolError::NoMoreRemoteServerRpcSessions(val.clone()))?; + trace!( + target: LOG_TARGET, + "Used existing RPC client session for connection '{}', protocol: {:?} ({})", + peer_node_id, protocol_id, RpcClientPoolError::NoMoreRemoteServerRpcSessions(val), + ); + c + }, + Err(RpcClientPoolError::NoMoreRemoteClientRpcSessions(val)) => { + let c = self + .get_least_used() + .ok_or(RpcClientPoolError::NoMoreRemoteClientRpcSessions(val.clone()))?; + trace!( + target: LOG_TARGET, + "used existing RPC client session for connection '{}', protocol: {:?} ({})", + peer_node_id, protocol_id, RpcClientPoolError::NoMoreRemoteClientRpcSessions(val), + ); + c + }, Err(err) => { return Err(err); }, @@ -112,12 +145,17 @@ where T: RpcPoolClient + From + NamedProtocolService + Clone }; if !client.is_connected() { + trace!( + target: LOG_TARGET, + "RPC client for connection '{}' is not connected, pruning", + peer_node_id + ); self.prune(); - continue; + return Err(RpcClientPoolError::CouldNotObtainRpcConnection); } // Clone is what actually takes the lease out (increments the Arc) - return Ok(client.clone()); + Ok(client.clone()) } } @@ -184,7 +222,7 @@ where T: RpcPoolClient + From + NamedProtocolService + Clone self.clients.len() == self.clients.capacity() } - async fn add_new_client_session(&mut self) -> Result<&RpcClientLease, RpcClientPoolError> { + async fn add_new_client_session(&mut self) -> Result<(&RpcClientLease, usize), RpcClientPoolError> { debug_assert!(!self.is_full(), "add_new_client called when pool is full"); let client = self .connection @@ -192,7 +230,7 @@ where T: RpcPoolClient + From + NamedProtocolService + Clone .await?; let client = RpcClientLease::new(client); self.clients.push(client); - Ok(self.clients.last().unwrap()) + Ok((self.clients.last().unwrap(), self.clients.len())) } fn prune(&mut self) { @@ -204,7 +242,7 @@ where T: RpcPoolClient + From + NamedProtocolService + Clone } vec }); - assert_eq!(self.clients.capacity(), cap); + debug_assert_eq!(self.clients.capacity(), cap); debug!( target: LOG_TARGET, "Pruned {} client(s) (total connections: {})", @@ -267,6 +305,8 @@ pub enum RpcClientPoolError { NoMoreRemoteClientRpcSessions(String), #[error("Failed to create client connection: {0}")] FailedToConnect(RpcError), + #[error("Could not obtain RPC connection")] + CouldNotObtainRpcConnection, } impl From for RpcClientPoolError { diff --git a/comms/core/src/protocol/rpc/handshake.rs b/comms/core/src/protocol/rpc/handshake.rs index fb61f66a11f..05eddbda144 100644 --- a/comms/core/src/protocol/rpc/handshake.rs +++ b/comms/core/src/protocol/rpc/handshake.rs @@ -86,7 +86,7 @@ where T: AsyncRead + AsyncWrite + Unpin .iter() .find(|v| msg.supported_versions.contains(v)); if let Some(version) = version { - debug!(target: LOG_TARGET, "Server accepted version: {}", version); + debug!(target: LOG_TARGET, "Local server accepted version: {}", version); let reply = proto::rpc::RpcSessionReply { session_result: Some(proto::rpc::rpc_session_reply::SessionResult::AcceptedVersion(*version)), ..Default::default() @@ -152,7 +152,7 @@ where T: AsyncRead + AsyncWrite + Unpin Ok(Some(Ok(msg))) => { let msg = proto::rpc::RpcSessionReply::decode(&mut msg.freeze())?; let version = msg.result()?; - debug!(target: LOG_TARGET, "Server accepted version {}", version); + debug!(target: LOG_TARGET, "Remote server accepted version {}", version); Ok(()) }, Ok(Some(Err(err))) => { diff --git a/comms/core/src/protocol/rpc/server/handle.rs b/comms/core/src/protocol/rpc/server/handle.rs index 8a82912cb52..8a4bbb7a82b 100644 --- a/comms/core/src/protocol/rpc/server/handle.rs +++ b/comms/core/src/protocol/rpc/server/handle.rs @@ -29,6 +29,7 @@ use crate::peer_manager::NodeId; pub enum RpcServerRequest { GetNumActiveSessions(oneshot::Sender), GetNumActiveSessionsForPeer(NodeId, oneshot::Sender), + CloseAllSessionsForPeer(NodeId, oneshot::Sender), } #[derive(Debug, Clone)] @@ -58,4 +59,13 @@ impl RpcServerHandle { .map_err(|_| RpcServerError::RequestCanceled)?; resp.await.map_err(Into::into) } + + pub async fn close_all_sessions_for(&mut self, peer: NodeId) -> Result { + let (req, resp) = oneshot::channel(); + self.sender + .send(RpcServerRequest::CloseAllSessionsForPeer(peer, req)) + .await + .map_err(|_| RpcServerError::RequestCanceled)?; + resp.await.map_err(Into::into) + } } diff --git a/comms/core/src/protocol/rpc/server/mod.rs b/comms/core/src/protocol/rpc/server/mod.rs index 37e969638a1..0cf2fec604a 100644 --- a/comms/core/src/protocol/rpc/server/mod.rs +++ b/comms/core/src/protocol/rpc/server/mod.rs @@ -89,7 +89,7 @@ use crate::{ ProtocolNotification, ProtocolNotificationRx, }, - stream_id::StreamId, + stream_id::{Id, StreamId}, Bytes, Substream, }; @@ -179,6 +179,7 @@ pub struct RpcServerBuilder { maximum_sessions_per_client: Option, minimum_client_deadline: Duration, handshake_timeout: Duration, + cull_oldest_peer_rpc_connection_on_full: bool, } impl RpcServerBuilder { @@ -201,6 +202,11 @@ impl RpcServerBuilder { self } + pub fn with_cull_oldest_peer_rpc_connection_on_full(mut self, cull: bool) -> Self { + self.cull_oldest_peer_rpc_connection_on_full = cull; + self + } + pub fn with_unlimited_sessions_per_client(mut self) -> Self { self.maximum_sessions_per_client = None; self @@ -228,6 +234,7 @@ impl Default for RpcServerBuilder { maximum_sessions_per_client: None, minimum_client_deadline: Duration::from_secs(1), handshake_timeout: Duration::from_secs(15), + cull_oldest_peer_rpc_connection_on_full: false, } } } @@ -239,8 +246,13 @@ pub(super) struct PeerRpcServer { protocol_notifications: Option>, comms_provider: TCommsProvider, request_rx: mpsc::Receiver, - sessions: HashMap, - tasks: FuturesUnordered>, + sessions: HashMap>, + tasks: FuturesUnordered>, +} + +struct SessionInfo { + pub(crate) peer_watch: tokio::sync::watch::Sender<()>, + pub(crate) stream_id: Id, } impl PeerRpcServer @@ -297,8 +309,8 @@ where } } - Some(Ok(node_id)) = self.tasks.next() => { - self.on_session_complete(&node_id); + Some(Ok((node_id, stream_id))) = self.tasks.next() => { + self.on_session_complete(&node_id, stream_id); }, Some(req) = self.request_rx.recv() => { @@ -315,7 +327,7 @@ where Ok(()) } - async fn handle_request(&self, req: RpcServerRequest) { + async fn handle_request(&mut self, req: RpcServerRequest) { #[allow(clippy::enum_glob_use)] use RpcServerRequest::*; match req { @@ -328,9 +340,13 @@ where let _ = reply.send(num_active); }, GetNumActiveSessionsForPeer(node_id, reply) => { - let num_active = self.sessions.get(&node_id).copied().unwrap_or(0); + let num_active = self.sessions.get(&node_id).map(|v| v.len()).unwrap_or(0); let _ = reply.send(num_active); }, + CloseAllSessionsForPeer(node_id, reply) => { + let num_closed = self.close_all_sessions(&node_id); + let _ = reply.send(num_closed); + }, } } @@ -368,35 +384,67 @@ where Ok(()) } - fn new_session_for(&mut self, node_id: NodeId) -> Result { - let count = self.sessions.entry(node_id.clone()).or_insert(0); + fn new_session_possible_for(&mut self, node_id: &NodeId) -> Result { match self.config.maximum_sessions_per_client { Some(max) if max > 0 => { - debug_assert!(*count <= max); - if *count >= max { - return Err(RpcServerError::MaxSessionsPerClientReached { - node_id, - max_sessions: max, - }); + if let Some(session_info) = self.sessions.get_mut(node_id) { + if max > session_info.len() { + Ok(session_info.len()) + } else if self.config.cull_oldest_peer_rpc_connection_on_full { + // Remove the oldest session(s) until we have space for a new one + let num_to_remove = session_info.len() - max + 1; + for _ in 0..num_to_remove { + let info = session_info.remove(0); + info!(target: LOG_TARGET, "Culling oldest RPC session for peer `{}`", node_id); + let _ = info.peer_watch.send(()); + } + Ok(session_info.len()) + } else { + warn!( + target: LOG_TARGET, + "Maximum RPC sessions for peer {} met or exceeded. Max: {}, Current: {}", + node_id, max, session_info.len() + ); + Err(RpcServerError::MaxSessionsPerClientReached { + node_id: node_id.clone(), + max_sessions: max, + }) + } + } else { + Ok(0) } }, - Some(_) | None => {}, + Some(_) | None => Ok(0), } + } - *count += 1; - Ok(*count) + fn close_all_sessions(&mut self, node_id: &NodeId) -> usize { + let mut count = 0; + if let Some(session_info) = self.sessions.get_mut(node_id) { + for info in session_info.iter_mut() { + count += 1; + info!(target: LOG_TARGET, "Closing RPC session {} for peer `{}`", info.stream_id, node_id); + let _ = info.peer_watch.send(()); + } + self.sessions.remove(node_id); + } + count } - fn on_session_complete(&mut self, node_id: &NodeId) { - info!(target: LOG_TARGET, "Session complete for {}", node_id); - if let Some(v) = self.sessions.get_mut(node_id) { - *v -= 1; - if *v == 0 { + fn on_session_complete(&mut self, node_id: &NodeId, stream_id: Id) { + if let Some(session_info) = self.sessions.get_mut(node_id) { + if let Some(info) = session_info.iter_mut().find(|info| info.stream_id == stream_id) { + info!(target: LOG_TARGET, "Session complete for {} (stream id {})", node_id, stream_id); + let _ = info.peer_watch.send(()); + }; + session_info.retain(|info| info.stream_id != stream_id); + if session_info.is_empty() { self.sessions.remove(node_id); } } } + #[allow(clippy::too_many_lines)] async fn try_initiate_service( &mut self, protocol: ProtocolId, @@ -437,11 +485,11 @@ where }, }; - match self.new_session_for(node_id.clone()) { + match self.new_session_possible_for(node_id) { Ok(num_sessions) => { info!( target: LOG_TARGET, - "NEW SESSION for {} ({} active) ", node_id, num_sessions + "NEW SESSION for {} ({} currently active) ", node_id, num_sessions ); }, @@ -460,7 +508,8 @@ where target: LOG_TARGET, "Server negotiated RPC v{} with client node `{}`", version, node_id ); - + let stream_id = framed.stream_id(); + let (stop_tx, stop_rx) = tokio::sync::watch::channel(()); let service = ActivePeerRpcService::new( self.config.clone(), protocol, @@ -468,26 +517,48 @@ where service, framed, self.comms_provider.clone(), + stop_rx, ); - let node_id = node_id.clone(); + let node_id_clone = node_id.clone(); let handle = self .executor .try_spawn(async move { #[cfg(feature = "metrics")] - let num_sessions = metrics::num_sessions(&node_id, &service.protocol); + let num_sessions = metrics::num_sessions(&node_id_clone, &service.protocol); #[cfg(feature = "metrics")] num_sessions.inc(); service.start().await; - info!(target: LOG_TARGET, "END OF SESSION for {} ", node_id,); + info!(target: LOG_TARGET, "END OF SESSION for {} ", node_id_clone,); #[cfg(feature = "metrics")] num_sessions.dec(); - node_id + (node_id_clone, stream_id) }) .map_err(|e| RpcServerError::MaximumSessionsReached(format!("{:?}", e)))?; self.tasks.push(handle); + let mut peer_stop = vec![SessionInfo { + peer_watch: stop_tx, + stream_id, + }]; + self.sessions + .entry(node_id.clone()) + .and_modify(|entry| entry.append(&mut peer_stop)) + .or_insert(peer_stop); + if let Some(info) = self.sessions.get(&node_id.clone()) { + info!( + target: LOG_TARGET, + "NEW SESSION created for {} ({} active) ", node_id.clone(), info.len() + ); + // Warn if `stream_id` is already in use + if info.iter().filter(|session| session.stream_id == stream_id).count() > 1 { + warn!( + target: LOG_TARGET, + "Stream ID {} already in use for peer {}. This should not happen.", stream_id, node_id + ); + } + } Ok(()) } @@ -501,6 +572,7 @@ struct ActivePeerRpcService { framed: EarlyClose>, comms_provider: TCommsProvider, logging_context_string: Arc, + stop_rx: tokio::sync::watch::Receiver<()>, } impl ActivePeerRpcService @@ -515,12 +587,14 @@ where service: TSvc, framed: CanonicalFraming, comms_provider: TCommsProvider, + stop_rx: tokio::sync::watch::Receiver<()>, ) -> Self { Self { logging_context_string: Arc::new(format!( - "stream_id: {}, peer: {}, protocol: {}", + "stream_id: {}, peer: {}, protocol: {}, stream_id: {}", framed.stream_id(), node_id, + framed.stream_id(), String::from_utf8_lossy(&protocol) )), @@ -530,6 +604,7 @@ where service, framed: EarlyClose::new(framed), comms_provider, + stop_rx, } } @@ -557,55 +632,64 @@ where } async fn run(&mut self) -> Result<(), RpcServerError> { - while let Some(result) = self.framed.next().await { - match result { - Ok(frame) => { - #[cfg(feature = "metrics")] - metrics::inbound_requests_bytes(&self.node_id, &self.protocol).observe(frame.len() as f64); - - let start = Instant::now(); - - if let Err(err) = self.handle_request(frame.freeze()).await { - if let Err(err) = self.framed.close().await { - let level = err.io().map(err_to_log_level).unwrap_or(log::Level::Error); - - log!( + loop { + tokio::select! { + _ = self.stop_rx.changed() => { + debug!(target: LOG_TARGET, "({}) Stop signal received, closing substream.", self.logging_context_string); + break; + } + result = self.framed.next() => { + match result { + Some(Ok(frame)) => { + #[cfg(feature = "metrics")] + metrics::inbound_requests_bytes(&self.node_id, &self.protocol).observe(frame.len() as f64); + + let start = Instant::now(); + + if let Err(err) = self.handle_request(frame.freeze()).await { + if let Err(err) = self.framed.close().await { + let level = err.io().map(err_to_log_level).unwrap_or(log::Level::Error); + + log!( + target: LOG_TARGET, + level, + "({}) Failed to close substream after socket error: {}", + self.logging_context_string, + err, + ); + } + let level = err.early_close_io().map(err_to_log_level).unwrap_or(log::Level::Error); + log!( + target: LOG_TARGET, + level, + "(peer: {}, protocol: {}) Failed to handle request: {}", + self.node_id, + self.protocol_name(), + err + ); + return Err(err); + } + let elapsed = start.elapsed(); + trace!( target: LOG_TARGET, - level, - "({}) Failed to close substream after socket error: {}", + "({}) RPC request completed in {:.0?}{}", self.logging_context_string, - err, + elapsed, + if elapsed.as_secs() > 5 { " (LONG REQUEST)" } else { "" } ); - } - let level = err.early_close_io().map(err_to_log_level).unwrap_or(log::Level::Error); - log!( - target: LOG_TARGET, - level, - "(peer: {}, protocol: {}) Failed to handle request: {}", - self.node_id, - self.protocol_name(), - err - ); - return Err(err); - } - let elapsed = start.elapsed(); - trace!( - target: LOG_TARGET, - "({}) RPC request completed in {:.0?}{}", - self.logging_context_string, - elapsed, - if elapsed.as_secs() > 5 { " (LONG REQUEST)" } else { "" } - ); - }, - Err(err) => { - if let Err(err) = self.framed.close().await { - error!( - target: LOG_TARGET, - "({}) Failed to close substream after socket error: {}", self.logging_context_string, err - ); + }, + Some(Err(err)) => { + if let Err(err) = self.framed.close().await { + error!( + target: LOG_TARGET, + "({}) Failed to close substream after socket error: {}", self.logging_context_string, err + ); + } + return Err(err.into()); + }, + None => break, } - return Err(err.into()); - }, + } } } diff --git a/comms/core/src/protocol/rpc/test/greeting_service.rs b/comms/core/src/protocol/rpc/test/greeting_service.rs index 885e2d13aa9..802ba98c889 100644 --- a/comms/core/src/protocol/rpc/test/greeting_service.rs +++ b/comms/core/src/protocol/rpc/test/greeting_service.rs @@ -406,6 +406,7 @@ impl GreetingClient { Default::default(), framed, Self::PROTOCOL_NAME.into(), + Default::default(), ) .await?; Ok(Self { inner }) diff --git a/comms/core/src/protocol/rpc/test/smoke.rs b/comms/core/src/protocol/rpc/test/smoke.rs index f43190a392e..367c4459798 100644 --- a/comms/core/src/protocol/rpc/test/smoke.rs +++ b/comms/core/src/protocol/rpc/test/smoke.rs @@ -541,7 +541,8 @@ async fn max_global_sessions() { async fn max_per_client_sessions() { let builder = RpcServer::builder() .with_maximum_simultaneous_sessions(3) - .with_maximum_sessions_per_client(1); + .with_maximum_sessions_per_client(1) + .with_cull_oldest_peer_rpc_connection_on_full(false); let (muxer, _outbound, context, _shutdown) = setup_service_with_builder(GreetingService::default(), builder).await; let (_, inbound, outbound) = build_multiplexed_connections().await; diff --git a/comms/core/tests/tests/rpc.rs b/comms/core/tests/tests/rpc.rs index d4845d226fc..40bde84a5c5 100644 --- a/comms/core/tests/tests/rpc.rs +++ b/comms/core/tests/tests/rpc.rs @@ -27,13 +27,14 @@ use tari_comms::{ protocol::rpc::{RpcServer, RpcServerHandle}, transports::TcpTransport, CommsNode, + Minimized, }; use tari_shutdown::{Shutdown, ShutdownSignal}; use tari_test_utils::async_assert_eventually; use tokio::time; use crate::tests::{ - greeting_service::{GreetingClient, GreetingServer, GreetingService, StreamLargeItemsRequest}, + greeting_service::{GreetingClient, GreetingServer, GreetingService, SayHelloRequest, StreamLargeItemsRequest}, helpers::create_comms, }; @@ -62,6 +63,426 @@ async fn spawn_node(signal: ShutdownSignal) -> (CommsNode, RpcServerHandle) { (comms, rpc_server_hnd) } +async fn spawn_culling_node(signal: ShutdownSignal, sessions: usize, culling: bool) -> (CommsNode, RpcServerHandle) { + let rpc_server = RpcServer::builder() + .with_maximum_sessions_per_client(sessions) + .with_cull_oldest_peer_rpc_connection_on_full(culling) + .finish() + .add_service(GreetingServer::new(GreetingService::default())); + + let rpc_server_hnd = rpc_server.get_handle(); + let mut comms = create_comms(signal) + .add_rpc_server(rpc_server) + .spawn_with_transport(TcpTransport::new()) + .await + .unwrap(); + + let address = comms + .connection_manager_requester() + .wait_until_listening() + .await + .unwrap(); + comms + .node_identity() + .set_public_addresses(vec![address.bind_address().clone()]); + + (comms, rpc_server_hnd) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rpc_server_can_request_drop_sessions() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let numer_of_clients = 3; + let (node1, _node2, _conn1_2, mut rpc_server2, mut clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone()) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify all RPC connections are active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 3); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // The RPC server closes all RPC connections + let num_closed = rpc_server2 + .close_all_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_closed, 3); + + // Verify the RPC connections are closed + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 0); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rpc_server_can_prioritize_new_connections() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let (numer_of_clients, maximum_sessions, cull_oldest) = (3, 2, true); + let (node1, _node2, _conn1_2, mut rpc_server2, mut clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_culling_node(shutdown.to_signal(), maximum_sessions, cull_oldest).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone()) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify only the latest RPC connections are active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 2); + assert!(clients[0] + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + for client in clients.iter_mut().skip(1) { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rpc_server_can_prioritize_old_connections() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let (numer_of_clients, maximum_sessions, cull_oldest) = (3, 2, false); + let (node1, _node2, _conn1_2, mut rpc_server2, clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_culling_node(shutdown.to_signal(), maximum_sessions, cull_oldest).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone()) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify only the initial RPC connections are active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 2); + for (i, mut client_result) in clients.into_iter().enumerate() { + match client_result { + Ok(ref mut client) => { + assert!(i < 2); + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + }, + Err(e) => { + assert_eq!(i, 2); + assert_eq!( + e.to_string(), + "Handshake error: RPC handshake was explicitly rejected: no more RPC server sessions available: \ + session limit reached" + .to_string() + ) + }, + } + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rpc_server_drop_sessions_when_peer_is_disconnected() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let numer_of_clients = 3; + let (node1, _node2, mut conn1_2, mut rpc_server2, mut clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone()) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify all RPC connections are active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 3); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // RPC connections are closed when the peer is disconnected + conn1_2.disconnect(Minimized::No).await.unwrap(); + + // Verify the RPC connections are closed + async_assert_eventually!( + rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(), + expect = 0, + max_attempts = 20, + interval = Duration::from_millis(1000) + ); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rpc_server_drop_sessions_when_peer_connection_clone_is_dropped() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let numer_of_clients = 3; + let (node1, _node2, mut conn1_2, mut rpc_server2, mut clients) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone()) + .await + .unwrap(); + + let mut clients = Vec::new(); + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + (node1, node2, conn1_2, rpc_server2, clients) + }; + + // Verify all RPC connections are active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 3); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + + // RPC connections are closed when the first peer connection clone is dropped + conn1_2.set_force_disconnect_rpc_clients_when_clone_drops(); + assert!(conn1_2.handle_count() > 1); + drop(conn1_2); + + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert!(num_sessions >= 1); + + // Verify the RPC connections are closed eventually + async_assert_eventually!( + rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(), + expect = 0, + max_attempts = 10, + interval = Duration::from_millis(1000) + ); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_err()); + } +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn rpc_server_drop_sessions_when_peer_connection_is_dropped() { + // env_logger::init(); // Set `$env:RUST_LOG = "trace"` + let shutdown = Shutdown::new(); + let numer_of_clients = 3; + let (node1, node2, mut rpc_server2) = { + let (node1, _rpc_server1) = spawn_node(shutdown.to_signal()).await; + let (node2, rpc_server2) = spawn_node(shutdown.to_signal()).await; + + node1 + .peer_manager() + .add_peer(node2.node_identity().to_peer()) + .await + .unwrap(); + + (node1, node2, rpc_server2) + }; + + // Some peer connection clones still exist at the end of this scope, but they are eventually dropped + { + let mut clients = Vec::new(); + let mut conn1_2 = node1 + .connectivity() + .dial_peer(node2.node_identity().node_id().clone()) + .await + .unwrap(); + + for _ in 0..numer_of_clients { + clients.push(conn1_2.connect_rpc::().await.unwrap()); + } + + // Verify all RPC connections are active + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert_eq!(num_sessions, 3); + for client in &mut clients { + assert!(client + .say_hello(SayHelloRequest { + name: "Bob".to_string(), + language: 0 + }) + .await + .is_ok()); + } + assert!(conn1_2.handle_count() > 1); + } + let num_sessions = rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(); + assert!(num_sessions >= 1); + + // Verify the RPC connections are eventually closed + async_assert_eventually!( + rpc_server2 + .get_num_active_sessions_for(node1.node_identity().node_id().clone()) + .await + .unwrap(), + expect = 0, + max_attempts = 10, + interval = Duration::from_millis(1000) + ); +} + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn client_prematurely_ends_session() { let shutdown = Shutdown::new(); diff --git a/comms/dht/src/store_forward/service.rs b/comms/dht/src/store_forward/service.rs index 587261570db..b2ab496626d 100644 --- a/comms/dht/src/store_forward/service.rs +++ b/comms/dht/src/store_forward/service.rs @@ -264,6 +264,13 @@ impl StoreAndForwardService { warn!(target: LOG_TARGET, "Failed to get the minimize connections threshold: {:?}", err); None }); + // If the comms node is configured to minimize connections, delay the SAF connectivity to prioritize other + // higher priority connections + if self.ignore_saf_threshold.is_some() { + let delay = Duration::from_secs(60); + tokio::time::sleep(delay).await; + debug!(target: LOG_TARGET, "SAF connectivity starting after delayed for {:.0?}", delay); + } self.node_id = self.connectivity.get_node_identity().await.map_or_else( |err| { warn!(target: LOG_TARGET, "Failed to get the node identity: {:?}", err); diff --git a/comms/rpc_macros/src/generator.rs b/comms/rpc_macros/src/generator.rs index 6d8dd28a4e4..bc27ad2f054 100644 --- a/comms/rpc_macros/src/generator.rs +++ b/comms/rpc_macros/src/generator.rs @@ -198,7 +198,7 @@ impl RpcCodeGenerator { pub async fn connect(framed: #dep_mod::CanonicalFraming) -> Result where TSubstream: #dep_mod::AsyncRead + #dep_mod::AsyncWrite + Unpin + Send + #dep_mod::StreamId + 'static { use #dep_mod::NamedProtocolService; - let inner = #dep_mod::RpcClient::connect(Default::default(), Default::default(), framed, Self::PROTOCOL_NAME.into()).await?; + let inner = #dep_mod::RpcClient::connect(Default::default(), Default::default(), framed, Self::PROTOCOL_NAME.into(), None).await?; Ok(Self { inner }) }