diff --git a/sn_networking/src/bootstrap.rs b/sn_networking/src/bootstrap.rs index 5643eb9bc6..148e3c76ff 100644 --- a/sn_networking/src/bootstrap.rs +++ b/sn_networking/src/bootstrap.rs @@ -6,7 +6,7 @@ // KIND, either express or implied. Please review the Licences for the specific language governing // permissions and limitations relating to use of the SAFE Network Software. -use crate::SwarmDriver; +use crate::{driver::PendingGetClosestType, SwarmDriver}; use std::time::{Duration, Instant}; use tokio::time::Interval; @@ -53,11 +53,15 @@ impl SwarmDriver { // The query is just to trigger the network discovery, // hence no need to wait for a result. for addr in self.network_discovery_candidates.candidates() { - let _ = self + let query_id = self .swarm .behaviour_mut() .kademlia .get_closest_peers(addr.as_bytes()); + let _ = self.pending_get_closest_peers.insert( + query_id, + (PendingGetClosestType::NetworkDiscovery, Default::default()), + ); } self.network_discovery_candidates .try_generate_new_candidates(); diff --git a/sn_networking/src/cmd.rs b/sn_networking/src/cmd.rs index ce050f738c..f21c245887 100644 --- a/sn_networking/src/cmd.rs +++ b/sn_networking/src/cmd.rs @@ -7,7 +7,7 @@ // permissions and limitations relating to use of the SAFE Network Software. use crate::{ - driver::SwarmDriver, + driver::{PendingGetClosestType, SwarmDriver}, error::{Error, Result}, sort_peers_by_address, GetQuorum, MsgResponder, NetworkEvent, CLOSE_GROUP_SIZE, REPLICATE_RANGE, @@ -512,9 +512,13 @@ impl SwarmDriver { .behaviour_mut() .kademlia .get_closest_peers(key.as_bytes()); - let _ = self - .pending_get_closest_peers - .insert(query_id, (sender, Default::default())); + let _ = self.pending_get_closest_peers.insert( + query_id, + ( + PendingGetClosestType::FunctionCall(sender), + Default::default(), + ), + ); } SwarmCmd::GetAllLocalPeers { sender } => { let _ = sender.send(self.get_all_local_peers()); diff --git a/sn_networking/src/driver.rs b/sn_networking/src/driver.rs index d23d6743f3..a5c3811b01 100644 --- a/sn_networking/src/driver.rs +++ b/sn_networking/src/driver.rs @@ -64,7 +64,15 @@ use tracing::warn; /// List of expected record holders to be verified. pub(super) type ExpectedHoldersList = HashSet; -type PendingGetClosest = HashMap>, HashSet)>; +/// The ways in which the Get Closest queries are used. +pub(crate) enum PendingGetClosestType { + /// The network discovery method is present at the networking layer + /// Thus we can just process the queries made by NetworkDiscovery without using any channels + NetworkDiscovery, + /// These are queries made by a function at the upper layers and contains a channel to send the result back. + FunctionCall(oneshot::Sender>), +} +type PendingGetClosest = HashMap)>; type PendingGetRecord = HashMap< QueryId, ( diff --git a/sn_networking/src/event.rs b/sn_networking/src/event.rs index ce921f8570..e2306ee1e6 100644 --- a/sn_networking/src/event.rs +++ b/sn_networking/src/event.rs @@ -8,7 +8,7 @@ use crate::{ close_group_majority, - driver::{truncate_patch_version, SwarmDriver}, + driver::{truncate_patch_version, PendingGetClosestType, SwarmDriver}, error::{Error, Result}, multiaddr_is_global, multiaddr_strip_p2p, sort_peers_by_address, GetQuorum, CLOSE_GROUP_SIZE, }; @@ -606,7 +606,7 @@ impl SwarmDriver { "Query task {id:?} returned with peers {closest_peers:?}, {stats:?} - {step:?}" ); - let (sender, mut current_closest) = + let (get_closest_type, mut current_closest) = self.pending_get_closest_peers.remove(&id).ok_or_else(|| { trace!( "Can't locate query task {id:?}, it has likely been completed already." @@ -621,13 +621,20 @@ impl SwarmDriver { let new_peers: HashSet = closest_peers.peers.clone().into_iter().collect(); current_closest.extend(new_peers); if current_closest.len() >= usize::from(K_VALUE) || step.last { - sender - .send(current_closest) - .map_err(|_| Error::InternalMsgChannelDropped)?; + match get_closest_type { + PendingGetClosestType::NetworkDiscovery => self + .network_discovery_candidates + .handle_get_closest_query(current_closest), + PendingGetClosestType::FunctionCall(sender) => { + sender + .send(current_closest) + .map_err(|_| Error::InternalMsgChannelDropped)?; + } + } } else { let _ = self .pending_get_closest_peers - .insert(id, (sender, current_closest)); + .insert(id, (get_closest_type, current_closest)); } } // Handle GetClosestPeers timeouts @@ -640,7 +647,7 @@ impl SwarmDriver { event_string = "kad_event::get_closest_peers_err"; error!("GetClosest Query task {id:?} errored with {err:?}, {stats:?} - {step:?}"); - let (sender, mut current_closest) = + let (get_closest_type, mut current_closest) = self.pending_get_closest_peers.remove(&id).ok_or_else(|| { trace!( "Can't locate query task {id:?}, it has likely been completed already." @@ -657,9 +664,16 @@ impl SwarmDriver { } } - sender - .send(current_closest) - .map_err(|_| Error::InternalMsgChannelDropped)?; + match get_closest_type { + PendingGetClosestType::NetworkDiscovery => self + .network_discovery_candidates + .handle_get_closest_query(current_closest), + PendingGetClosestType::FunctionCall(sender) => { + sender + .send(current_closest) + .map_err(|_| Error::InternalMsgChannelDropped)?; + } + } } // For `get_record` returning behaviour: diff --git a/sn_networking/src/network_discovery.rs b/sn_networking/src/network_discovery.rs index 1bf7369238..93781470f7 100644 --- a/sn_networking/src/network_discovery.rs +++ b/sn_networking/src/network_discovery.rs @@ -10,7 +10,7 @@ use libp2p::{kad::KBucketKey, PeerId}; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use sn_protocol::NetworkAddress; use std::{ - collections::{hash_map::Entry, HashMap, VecDeque}, + collections::{hash_map::Entry, HashMap, HashSet, VecDeque}, time::Instant, }; @@ -91,6 +91,36 @@ impl NetworkDiscoveryCandidates { .filter_map(|candidates| candidates.front()) } + pub(crate) fn handle_get_closest_query(&mut self, closest_peers: HashSet) { + let now = Instant::now(); + for peer in closest_peers { + let peer = NetworkAddress::from_peer(peer); + let peer_key = peer.as_kbucket_key(); + if let Some(ilog2_distance) = peer_key.distance(&self.self_key).ilog2() { + match self.candidates.entry(ilog2_distance) { + Entry::Occupied(mut entry) => { + let entry = entry.get_mut(); + // extra check to make sure we don't insert the same peer again + if entry.len() >= MAX_PEERS_PER_BUCKET && !entry.contains(&peer) { + // pop the front (as it might have been already used for querying and insert the new one at the back + let _ = entry.pop_front(); + entry.push_back(peer); + } else { + entry.push_back(peer); + } + } + Entry::Vacant(entry) => { + let _ = entry.insert(VecDeque::from([peer])); + } + } + } + } + trace!( + "It took {:?} to NetworkDiscovery::handle get closest query", + now.elapsed() + ); + } + fn generate_candidates( self_key: &KBucketKey, num_to_generate: usize,