diff --git a/src/builder.rs b/src/builder.rs index 886ef3cf8..5fb3a8e32 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -2,6 +2,7 @@ use crate::config::{ Config, BDK_CLIENT_CONCURRENCY, BDK_CLIENT_STOP_GAP, DEFAULT_ESPLORA_SERVER_URL, WALLET_KEYS_SEED_LEN, }; +use crate::connection::ConnectionManager; use crate::event::EventQueue; use crate::fee_estimator::OnchainFeeEstimator; use crate::gossip::GossipSource; @@ -891,6 +892,9 @@ fn build_with_store_internal( liquidity_source.as_ref().map(|l| l.set_peer_manager(Arc::clone(&peer_manager))); + let connection_manager = + Arc::new(ConnectionManager::new(Arc::clone(&peer_manager), Arc::clone(&logger))); + // Init payment info storage let payment_store = match io::utils::read_payments(Arc::clone(&kv_store), Arc::clone(&logger)) { Ok(payments) => { @@ -958,6 +962,7 @@ fn build_with_store_internal( chain_monitor, output_sweeper, peer_manager, + connection_manager, keys_manager, network_graph, gossip_source, diff --git a/src/connection.rs b/src/connection.rs index 7a93c1d8d..4c0f7a47f 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -8,62 +8,141 @@ use bitcoin::secp256k1::PublicKey; use std::net::ToSocketAddrs; use std::ops::Deref; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use std::time::Duration; -pub(crate) async fn connect_peer_if_necessary( - node_id: PublicKey, addr: SocketAddress, peer_manager: Arc, logger: L, -) -> Result<(), Error> +pub(crate) struct ConnectionManager where L::Target: Logger, { - if peer_manager.peer_by_node_id(&node_id).is_some() { - return Ok(()); - } - - do_connect_peer(node_id, addr, peer_manager, logger).await + pending_connections: + Mutex>>)>>, + peer_manager: Arc, + logger: L, } -pub(crate) async fn do_connect_peer( - node_id: PublicKey, addr: SocketAddress, peer_manager: Arc, logger: L, -) -> Result<(), Error> +impl ConnectionManager where L::Target: Logger, { - log_info!(logger, "Connecting to peer: {}@{}", node_id, addr); - - let socket_addr = addr - .to_socket_addrs() - .map_err(|e| { - log_error!(logger, "Failed to resolve network address: {}", e); - Error::InvalidSocketAddress - })? - .next() - .ok_or(Error::ConnectionFailed)?; - - match lightning_net_tokio::connect_outbound(Arc::clone(&peer_manager), node_id, socket_addr) - .await - { - Some(connection_closed_future) => { - let mut connection_closed_future = Box::pin(connection_closed_future); - loop { - tokio::select! { - _ = &mut connection_closed_future => { - log_info!(logger, "Peer connection closed: {}@{}", node_id, addr); - return Err(Error::ConnectionFailed); - }, - _ = tokio::time::sleep(Duration::from_millis(10)) => {}, - }; - - match peer_manager.peer_by_node_id(&node_id) { - Some(_) => return Ok(()), - None => continue, + pub(crate) fn new(peer_manager: Arc, logger: L) -> Self { + let pending_connections = Mutex::new(Vec::new()); + Self { pending_connections, peer_manager, logger } + } + + pub(crate) async fn connect_peer_if_necessary( + &self, node_id: PublicKey, addr: SocketAddress, + ) -> Result<(), Error> { + if self.peer_manager.peer_by_node_id(&node_id).is_some() { + return Ok(()); + } + + self.do_connect_peer(node_id, addr).await + } + + pub(crate) async fn do_connect_peer( + &self, node_id: PublicKey, addr: SocketAddress, + ) -> Result<(), Error> { + // First, we check if there is already an outbound connection in flight, if so, we just + // await on the corresponding watch channel. The task driving the connection future will + // send us the result.. + let pending_ready_receiver_opt = self.register_or_subscribe_pending_connection(&node_id); + if let Some(pending_connection_ready_receiver) = pending_ready_receiver_opt { + return pending_connection_ready_receiver.await.map_err(|e| { + debug_assert!(false, "Failed to receive connection result: {:?}", e); + log_error!(self.logger, "Failed to receive connection result: {:?}", e); + Error::ConnectionFailed + })?; + } + + log_info!(self.logger, "Connecting to peer: {}@{}", node_id, addr); + + let socket_addr = addr + .to_socket_addrs() + .map_err(|e| { + log_error!(self.logger, "Failed to resolve network address: {}", e); + self.propagate_result_to_subscribers(&node_id, Err(Error::InvalidSocketAddress)); + Error::InvalidSocketAddress + })? + .next() + .ok_or_else(|| { + self.propagate_result_to_subscribers(&node_id, Err(Error::ConnectionFailed)); + Error::ConnectionFailed + })?; + + let connection_future = lightning_net_tokio::connect_outbound( + Arc::clone(&self.peer_manager), + node_id, + socket_addr, + ); + + let res = match connection_future.await { + Some(connection_closed_future) => { + let mut connection_closed_future = Box::pin(connection_closed_future); + loop { + tokio::select! { + _ = &mut connection_closed_future => { + log_info!(self.logger, "Peer connection closed: {}@{}", node_id, addr); + break Err(Error::ConnectionFailed); + }, + _ = tokio::time::sleep(Duration::from_millis(10)) => {}, + }; + + match self.peer_manager.peer_by_node_id(&node_id) { + Some(_) => break Ok(()), + None => continue, + } } + }, + None => { + log_error!(self.logger, "Failed to connect to peer: {}@{}", node_id, addr); + Err(Error::ConnectionFailed) + }, + }; + + self.propagate_result_to_subscribers(&node_id, res); + + res + } + + fn register_or_subscribe_pending_connection( + &self, node_id: &PublicKey, + ) -> Option>> { + let mut pending_connections_lock = self.pending_connections.lock().unwrap(); + if let Some((_, connection_ready_senders)) = + pending_connections_lock.iter_mut().find(|(id, _)| id == node_id) + { + let (tx, rx) = tokio::sync::oneshot::channel(); + connection_ready_senders.push(tx); + Some(rx) + } else { + pending_connections_lock.push((*node_id, Vec::new())); + None + } + } + + fn propagate_result_to_subscribers(&self, node_id: &PublicKey, res: Result<(), Error>) { + // Send the result to any other tasks that might be waiting on it by now. + let mut pending_connections_lock = self.pending_connections.lock().unwrap(); + if let Some((_, connection_ready_senders)) = pending_connections_lock + .iter() + .position(|(id, _)| id == node_id) + .map(|i| pending_connections_lock.remove(i)) + { + for sender in connection_ready_senders { + let _ = sender.send(res).map_err(|e| { + debug_assert!( + false, + "Failed to send connection result to subscribers: {:?}", + e + ); + log_error!( + self.logger, + "Failed to send connection result to subscribers: {:?}", + e + ); + }); } - }, - None => { - log_error!(logger, "Failed to connect to peer: {}@{}", node_id, addr); - Err(Error::ConnectionFailed) - }, + } } } diff --git a/src/error.rs b/src/error.rs index 0182b3092..c5234a6d4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ use std::fmt; -#[derive(Debug, PartialEq, Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] /// An error that possibly needs to be handled by the user. pub enum Error { /// Returned when trying to start [`crate::Node`] while it is already running. diff --git a/src/lib.rs b/src/lib.rs index d24b45f9a..00beedc11 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -125,7 +125,7 @@ use config::{ LDK_PAYMENT_RETRY_TIMEOUT, NODE_ANN_BCAST_INTERVAL, PEER_RECONNECTION_INTERVAL, RGS_SYNC_INTERVAL, WALLET_SYNC_INTERVAL_MINIMUM_SECS, }; -use connection::{connect_peer_if_necessary, do_connect_peer}; +use connection::ConnectionManager; use event::{EventHandler, EventQueue}; use gossip::GossipSource; use liquidity::LiquiditySource; @@ -189,6 +189,7 @@ pub struct Node { chain_monitor: Arc, output_sweeper: Arc, peer_manager: Arc, + connection_manager: Arc>>, keys_manager: Arc, network_graph: Arc, gossip_source: Arc, @@ -462,6 +463,7 @@ impl Node { } // Regularly reconnect to persisted peers. + let connect_cm = Arc::clone(&self.connection_manager); let connect_pm = Arc::clone(&self.peer_manager); let connect_logger = Arc::clone(&self.logger); let connect_peer_store = Arc::clone(&self.peer_store); @@ -482,11 +484,9 @@ impl Node { .collect::>(); for peer_info in connect_peer_store.list_peers().iter().filter(|info| !pm_peers.contains(&info.node_id)) { - let res = do_connect_peer( + let res = connect_cm.do_connect_peer( peer_info.node_id, peer_info.address.clone(), - Arc::clone(&connect_pm), - Arc::clone(&connect_logger), ).await; match res { Ok(_) => { @@ -803,14 +803,13 @@ impl Node { let con_node_id = peer_info.node_id; let con_addr = peer_info.address.clone(); - let con_logger = Arc::clone(&self.logger); - let con_pm = Arc::clone(&self.peer_manager); + let con_cm = Arc::clone(&self.connection_manager); // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. tokio::task::block_in_place(move || { runtime.block_on(async move { - connect_peer_if_necessary(con_node_id, con_addr, con_pm, con_logger).await + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await }) })?; @@ -876,14 +875,13 @@ impl Node { let con_node_id = peer_info.node_id; let con_addr = peer_info.address.clone(); - let con_logger = Arc::clone(&self.logger); - let con_pm = Arc::clone(&self.peer_manager); + let con_cm = Arc::clone(&self.connection_manager); // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. tokio::task::block_in_place(move || { runtime.block_on(async move { - connect_peer_if_necessary(con_node_id, con_addr, con_pm, con_logger).await + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await }) })?; @@ -1533,14 +1531,13 @@ impl Node { let con_node_id = peer_info.node_id; let con_addr = peer_info.address.clone(); - let con_logger = Arc::clone(&self.logger); - let con_pm = Arc::clone(&self.peer_manager); + let con_cm = Arc::clone(&self.connection_manager); // We need to use our main runtime here as a local runtime might not be around to poll // connection futures going forward. tokio::task::block_in_place(move || { runtime.block_on(async move { - connect_peer_if_necessary(con_node_id, con_addr, con_pm, con_logger).await + con_cm.connect_peer_if_necessary(con_node_id, con_addr).await }) })?;