From 17ea2ec99a0c758eb751518dd0ce818746edf5e0 Mon Sep 17 00:00:00 2001 From: Thomas Braun Date: Fri, 22 Nov 2024 21:40:38 -0500 Subject: [PATCH] fix: add retry mechanism to hole puncher --- citadel_wire/src/udp_traversal/multi/mod.rs | 8 ++-- .../src/udp_traversal/udp_hole_puncher.rs | 38 +++++++++++++++++-- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/citadel_wire/src/udp_traversal/multi/mod.rs b/citadel_wire/src/udp_traversal/multi/mod.rs index 7130d5224..ef05dcc76 100644 --- a/citadel_wire/src/udp_traversal/multi/mod.rs +++ b/citadel_wire/src/udp_traversal/multi/mod.rs @@ -213,7 +213,6 @@ async fn drive( let rebuilder_task = async move { log::trace!(target: "citadel", "*** Will now await post_rebuild_rx ... {} have finished", finished_count.lock()); - let mut count = 0; // Note: if properly implemented, the below should return almost instantly loop { if let Some(current_enqueued) = current_enqueued.lock().await.take() { @@ -230,10 +229,11 @@ async fn drive( None => return Err(anyhow::Error::msg("post_rebuild_rx failed")), Some(None) => { - count += 1; + let mut count = finished_count.lock(); + *count += 1; log::trace!(target: "citadel", "*** [rebuild] So-far, {}/{} have finished", count, hole_puncher_count); - if count == hole_puncher_count { - log::error!(target: "citadel", "This should not happen") + if *count == hole_puncher_count { + return Err(anyhow::Error::msg("All hole-punchers have failed")); } } diff --git a/citadel_wire/src/udp_traversal/udp_hole_puncher.rs b/citadel_wire/src/udp_traversal/udp_hole_puncher.rs index 70ae99861..4a23810fb 100644 --- a/citadel_wire/src/udp_traversal/udp_hole_puncher.rs +++ b/citadel_wire/src/udp_traversal/udp_hole_puncher.rs @@ -16,7 +16,7 @@ pub struct UdpHolePuncher<'a> { driver: Pin> + Send + 'a>>, } -const DEFAULT_TIMEOUT: Duration = Duration::from_millis(6000); +const DEFAULT_TIMEOUT: Duration = Duration::from_millis(5000); impl<'a> UdpHolePuncher<'a> { pub fn new( @@ -32,9 +32,9 @@ impl<'a> 
UdpHolePuncher<'a> {
         timeout: Duration,
     ) -> Self {
         Self {
-            driver: Box::pin(async move {
-                tokio::time::timeout(timeout, driver(conn, encrypted_config_container)).await?
-            }),
+            driver: Box::pin(
+                async move { driver(conn, encrypted_config_container, timeout).await },
+            ),
         }
     }
 }
@@ -47,11 +47,41 @@ impl Future for UdpHolePuncher<'_> {
     }
 }
 
+const MAX_RETRIES: usize = 3;
+/// Retries `driver_inner` (each attempt bounded by `timeout`) up to `MAX_RETRIES` times.
 #[cfg_attr(
     feature = "localhost-testing",
     tracing::instrument(level = "trace", target = "citadel", skip_all, ret, err(Debug))
 )]
 async fn driver(
+    conn: &NetworkEndpoint,
+    encrypted_config_container: HolePunchConfigContainer,
+    timeout: Duration,
+) -> Result {
+    let mut retries = 0;
+    loop {
+        let task = tokio::time::timeout(
+            timeout,
+            driver_inner(conn, encrypted_config_container.clone()),
+        );
+        match task.await {
+            Ok(Ok(res)) => return Ok(res),
+            Ok(Err(err)) => {
+                log::warn!(target: "citadel", "Hole puncher failed: {err:?}");
+            }
+            Err(_) => {
+                log::warn!(target: "citadel", "Hole puncher timed-out");
+            }
+        }
+        // Count every failed attempt (error or timeout) so the loop always terminates.
+        retries += 1;
+        if retries >= MAX_RETRIES {
+            return Err(anyhow::Error::msg("Max retries reached for UDP Traversal"));
+        }
+    }
+}
+
+async fn driver_inner(
     conn: &NetworkEndpoint,
     mut encrypted_config_container: HolePunchConfigContainer,
 ) -> Result {