From 96cf64e9321d7e41fdb248d1e4bce52e3f1c9eac Mon Sep 17 00:00:00 2001 From: Markus Pettersson Date: Tue, 19 Nov 2024 15:15:53 +0100 Subject: [PATCH] Fix allow Wireguard-Go tunnel setup to be cancelled --- talpid-wireguard/src/connectivity/mod.rs | 2 + talpid-wireguard/src/lib.rs | 28 +++++++++++-- talpid-wireguard/src/wireguard_go/mod.rs | 52 +++++++++++++++++------- 3 files changed, 64 insertions(+), 18 deletions(-) diff --git a/talpid-wireguard/src/connectivity/mod.rs b/talpid-wireguard/src/connectivity/mod.rs index efa8e0ccc5f4..512d8715f17d 100644 --- a/talpid-wireguard/src/connectivity/mod.rs +++ b/talpid-wireguard/src/connectivity/mod.rs @@ -6,6 +6,8 @@ mod mock; mod monitor; mod pinger; +#[cfg(target_os = "android")] +pub use check::Cancellable; pub use check::Check; pub use error::Error; pub use monitor::Monitor; diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index 9c5a2b377bb2..db3c90f8259c 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -427,6 +427,11 @@ impl WireguardMonitor { } let should_negotiate_ephemeral_peer = config.quantum_resistant || config.daita; + + let (connectivity_check, pinger_tx) = connectivity::Check::new(config.ipv4_gateway) + .map_err(Error::ConnectivityMonitorError)? + .with_cancellation(); + let tunnel = Self::open_wireguard_go_tunnel( &config, log_path, @@ -436,12 +441,10 @@ impl WireguardMonitor { // that we only allows traffic to/from the gateway. This is only needed on Android // since we lack a firewall there. should_negotiate_ephemeral_peer, + connectivity_check, )?; - let iface_name = tunnel.get_interface_name(); - let (connectivity_check, pinger_tx) = connectivity::Check::new(config.ipv4_gateway) - .map_err(Error::ConnectivityMonitorError)? - .with_cancellation(); + let iface_name = tunnel.get_interface_name(); let tunnel = Arc::new(AsyncMutex::new(Some(tunnel))); let monitor = WireguardMonitor { runtime: args.runtime.clone(), @@ -487,6 +490,18 @@ impl WireguardMonitor { let metadata = Self::tunnel_metadata(&iface_name, &config); args.on_event.clone()(TunnelEvent::Up(metadata)).await; + // HACK: The tunnel does not need the connectivity::Check anymore, so lets take it + let connectivity_check = { + let mut tunnel_lock = tunnel.lock().await; + let Some(tunnel) = tunnel_lock.as_mut() else { + log::debug!("Tunnel is no longer running"); + return Err::(CloseMsg::PingErr); + }; + tunnel + .take_checker() + .expect("connectivity checker unexpectedly dropped") + }; + tokio::task::spawn_blocking(move || { let tunnel = Arc::downgrade(&tunnel); if let Err(error) = connectivity::Monitor::init(connectivity_check).run(tunnel) { @@ -710,6 +725,9 @@ impl WireguardMonitor { #[cfg(daita)] resource_dir: &Path, tun_provider: Arc>, #[cfg(target_os = "android")] gateway_only: bool, + #[cfg(target_os = "android")] connectivity_check: connectivity::Check< + connectivity::Cancellable, + >, ) -> Result { let routes = config .get_tunnel_destinations() @@ -748,6 +766,7 @@ impl WireguardMonitor { routes, #[cfg(daita)] resource_dir, + connectivity_check, ) .map_err(Error::TunnelError)? } else { @@ -759,6 +778,7 @@ impl WireguardMonitor { routes, #[cfg(daita)] resource_dir, + connectivity_check, ) .map_err(Error::TunnelError)? }; diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs index 9910805e026a..722f0d942261 100644 --- a/talpid-wireguard/src/wireguard_go/mod.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -12,8 +12,6 @@ use crate::logging::{clean_up_logging, initialize_logging}; use ipnetwork::IpNetwork; #[cfg(daita)] use once_cell::sync::OnceCell; -#[cfg(target_os = "android")] -use std::net::Ipv4Addr; #[cfg(daita)] use std::{ffi::CString, fs, path::PathBuf}; use std::{ @@ -79,7 +77,7 @@ impl WgGoTunnel { &self.0 } - fn to_state_mut(&mut self) -> &mut WgGoTunnelState { + fn as_state_mut(&mut self) -> &mut WgGoTunnelState { &mut self.0 } } @@ -100,14 +98,17 @@ impl WgGoTunnel { } } - fn to_state_mut(&mut self) -> &mut WgGoTunnelState { + fn as_state_mut(&mut self) -> &mut WgGoTunnelState { match self { WgGoTunnel::Multihop(state) => state, WgGoTunnel::Singlehop(state) => state, } } - pub fn set_config(self, config: &Config) -> Result { + pub fn set_config(mut self, config: &Config) -> Result { + let connectivity_checker = self + .take_checker() + .expect("connectivity checker unexpectedly dropped"); let state = self.as_state(); let log_path = state._logging_context.path.clone(); let tun_provider = Arc::clone(&state.tun_provider); @@ -124,6 +125,7 @@ impl WgGoTunnel { tun_provider, routes, &resource_dir, + connectivity_checker, ) } WgGoTunnel::Singlehop(state) if config.is_multihop() => { @@ -135,6 +137,7 @@ impl WgGoTunnel { tun_provider, routes, &resource_dir, + connectivity_checker, ) } WgGoTunnel::Singlehop(mut state) => { @@ -167,6 +170,13 @@ pub(crate) struct WgGoTunnelState { resource_dir: PathBuf, #[cfg(daita)] config: Config, + // HACK: Check is not Clone, so we have to pass this around .. + // This is conceptually the connection between this Tunnel and the currently running + // WireguardMonitor, and it is used to allow WireguardMonitor to cancel the setup of + // a new Tunnel during the "ensure_connectivity" phase. This field should be removed + // as soon as we implement a better way to cancel Check asynchronously. + #[cfg(target_os = "android")] + connectivity_checker: Option>, } impl WgGoTunnelState { @@ -292,6 +302,7 @@ impl WgGoTunnel { tun_provider: Arc>, routes: impl Iterator, #[cfg(daita)] resource_dir: &Path, + mut connectivity_check: connectivity::Check, ) -> Result { let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; @@ -314,7 +325,7 @@ impl WgGoTunnel { Self::bypass_tunnel_sockets(&handle, &mut tunnel_device) .map_err(TunnelError::BypassError)?; - let tunnel = WgGoTunnel::Singlehop(WgGoTunnelState { + let mut tunnel = WgGoTunnel::Singlehop(WgGoTunnelState { interface_name, tunnel_handle: handle, _tunnel_device: tunnel_device, @@ -324,9 +335,12 @@ impl WgGoTunnel { resource_dir: resource_dir.to_owned(), #[cfg(daita)] config: config.clone(), + connectivity_checker: None, }); - tunnel.ensure_tunnel_is_running(config.ipv4_gateway)?; + // HACK: Check if the tunnel is working by sending a ping in the tunnel. + tunnel.ensure_tunnel_is_running(&mut connectivity_check)?; + tunnel.as_state_mut().connectivity_checker = Some(connectivity_check); Ok(tunnel) } @@ -338,6 +352,7 @@ impl WgGoTunnel { tun_provider: Arc>, routes: impl Iterator, #[cfg(daita)] resource_dir: &Path, + mut connectivity_check: connectivity::Check, ) -> Result { let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; @@ -376,7 +391,7 @@ impl WgGoTunnel { Self::bypass_tunnel_sockets(&handle, &mut tunnel_device) .map_err(TunnelError::BypassError)?; - let tunnel = WgGoTunnel::Multihop(WgGoTunnelState { + let mut tunnel = WgGoTunnel::Multihop(WgGoTunnelState { interface_name, tunnel_handle: handle, _tunnel_device: tunnel_device, @@ -386,9 +401,12 @@ impl WgGoTunnel { resource_dir: resource_dir.to_owned(), #[cfg(daita)] config: config.clone(), + connectivity_checker: None, }); - tunnel.ensure_tunnel_is_running(config.ipv4_gateway)?; + // HACK: Check if the tunnel is working by sending a ping in the tunnel. + tunnel.ensure_tunnel_is_running(&mut connectivity_check)?; + tunnel.as_state_mut().connectivity_checker = Some(connectivity_check); Ok(tunnel) } @@ -406,12 +424,18 @@ impl WgGoTunnel { Ok(()) } - /// There is a breif period of time between setting up a Wireguard-go tunnel and the tunnel being ready to serve + pub fn take_checker(&mut self) -> Option> { + self.as_state_mut().connectivity_checker.take() + } + + /// There is a brief period of time between setting up a Wireguard-go tunnel and the tunnel being ready to serve /// traffic. This function blocks until the tunnel starts to serve traffic or until [connectivity::Check] times out. - fn ensure_tunnel_is_running(&self, addr: Ipv4Addr) -> Result<()> { + fn ensure_tunnel_is_running( + &self, + checker: &mut connectivity::Check, + ) -> Result<()> { let connectivity_err = |e| TunnelError::Connectivity(Box::new(e)); - let connection_established = connectivity::Check::new(addr) - .map_err(connectivity_err)? + let connection_established = checker .establish_connectivity(0, self) .map_err(connectivity_err)?; @@ -446,7 +470,7 @@ impl Tunnel for WgGoTunnel { &mut self, config: Config, ) -> Pin> + Send + '_>> { - Box::pin(async move { self.to_state_mut().set_config(config) }) + Box::pin(async move { self.as_state_mut().set_config(config) }) } #[cfg(daita)]