Skip to content

Commit

Permalink
Fix allow Wireguard-Go tunnel setup to be cancelled
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusPettersson98 committed Nov 20, 2024
1 parent a87c09c commit 96cf64e
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 18 deletions.
2 changes: 2 additions & 0 deletions talpid-wireguard/src/connectivity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ mod mock;
mod monitor;
mod pinger;

#[cfg(target_os = "android")]
pub use check::Cancellable;
pub use check::Check;
pub use error::Error;
pub use monitor::Monitor;
28 changes: 24 additions & 4 deletions talpid-wireguard/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,11 @@ impl WireguardMonitor {
}

let should_negotiate_ephemeral_peer = config.quantum_resistant || config.daita;

let (connectivity_check, pinger_tx) = connectivity::Check::new(config.ipv4_gateway)
.map_err(Error::ConnectivityMonitorError)?
.with_cancellation();

let tunnel = Self::open_wireguard_go_tunnel(
&config,
log_path,
Expand All @@ -436,12 +441,10 @@ impl WireguardMonitor {
// that we only allows traffic to/from the gateway. This is only needed on Android
// since we lack a firewall there.
should_negotiate_ephemeral_peer,
connectivity_check,
)?;
let iface_name = tunnel.get_interface_name();
let (connectivity_check, pinger_tx) = connectivity::Check::new(config.ipv4_gateway)
.map_err(Error::ConnectivityMonitorError)?
.with_cancellation();

let iface_name = tunnel.get_interface_name();
let tunnel = Arc::new(AsyncMutex::new(Some(tunnel)));
let monitor = WireguardMonitor {
runtime: args.runtime.clone(),
Expand Down Expand Up @@ -487,6 +490,18 @@ impl WireguardMonitor {
let metadata = Self::tunnel_metadata(&iface_name, &config);
args.on_event.clone()(TunnelEvent::Up(metadata)).await;

// HACK: The tunnel does not need the connectivity::Check anymore, so lets take it
let connectivity_check = {
let mut tunnel_lock = tunnel.lock().await;
let Some(tunnel) = tunnel_lock.as_mut() else {
log::debug!("Tunnel is no longer running");
return Err::<Infallible, CloseMsg>(CloseMsg::PingErr);
};
tunnel
.take_checker()
.expect("connectivity checker unexpectedly dropped")
};

tokio::task::spawn_blocking(move || {
let tunnel = Arc::downgrade(&tunnel);
if let Err(error) = connectivity::Monitor::init(connectivity_check).run(tunnel) {
Expand Down Expand Up @@ -710,6 +725,9 @@ impl WireguardMonitor {
#[cfg(daita)] resource_dir: &Path,
tun_provider: Arc<Mutex<TunProvider>>,
#[cfg(target_os = "android")] gateway_only: bool,
#[cfg(target_os = "android")] connectivity_check: connectivity::Check<
connectivity::Cancellable,
>,
) -> Result<WgGoTunnel> {
let routes = config
.get_tunnel_destinations()
Expand Down Expand Up @@ -748,6 +766,7 @@ impl WireguardMonitor {
routes,
#[cfg(daita)]
resource_dir,
connectivity_check,
)
.map_err(Error::TunnelError)?
} else {
Expand All @@ -759,6 +778,7 @@ impl WireguardMonitor {
routes,
#[cfg(daita)]
resource_dir,
connectivity_check,
)
.map_err(Error::TunnelError)?
};
Expand Down
52 changes: 38 additions & 14 deletions talpid-wireguard/src/wireguard_go/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ use crate::logging::{clean_up_logging, initialize_logging};
use ipnetwork::IpNetwork;
#[cfg(daita)]
use once_cell::sync::OnceCell;
#[cfg(target_os = "android")]
use std::net::Ipv4Addr;
#[cfg(daita)]
use std::{ffi::CString, fs, path::PathBuf};
use std::{
Expand Down Expand Up @@ -79,7 +77,7 @@ impl WgGoTunnel {
&self.0
}

fn to_state_mut(&mut self) -> &mut WgGoTunnelState {
fn as_state_mut(&mut self) -> &mut WgGoTunnelState {
&mut self.0
}
}
Expand All @@ -100,14 +98,17 @@ impl WgGoTunnel {
}
}

fn to_state_mut(&mut self) -> &mut WgGoTunnelState {
fn as_state_mut(&mut self) -> &mut WgGoTunnelState {
match self {
WgGoTunnel::Multihop(state) => state,
WgGoTunnel::Singlehop(state) => state,
}
}

pub fn set_config(self, config: &Config) -> Result<Self> {
pub fn set_config(mut self, config: &Config) -> Result<Self> {
let connectivity_checker = self
.take_checker()
.expect("connectivity checker unexpectedly dropped");
let state = self.as_state();
let log_path = state._logging_context.path.clone();
let tun_provider = Arc::clone(&state.tun_provider);
Expand All @@ -124,6 +125,7 @@ impl WgGoTunnel {
tun_provider,
routes,
&resource_dir,
connectivity_checker,
)
}
WgGoTunnel::Singlehop(state) if config.is_multihop() => {
Expand All @@ -135,6 +137,7 @@ impl WgGoTunnel {
tun_provider,
routes,
&resource_dir,
connectivity_checker,
)
}
WgGoTunnel::Singlehop(mut state) => {
Expand Down Expand Up @@ -167,6 +170,13 @@ pub(crate) struct WgGoTunnelState {
resource_dir: PathBuf,
#[cfg(daita)]
config: Config,
// HACK: Check is not Clone, so we have to pass this around ..
// This is conceptually the connection between this Tunnel and the currently running
// WireguardMonitor, and it is used to allow WireguardMonitor to cancel the setup of
// a new Tunnel during the "ensure_connectivity" phase. This field should be removed
// as soon as we implement a better way to cancel Check asynchronously.
#[cfg(target_os = "android")]
connectivity_checker: Option<connectivity::Check<connectivity::Cancellable>>,
}

impl WgGoTunnelState {
Expand Down Expand Up @@ -292,6 +302,7 @@ impl WgGoTunnel {
tun_provider: Arc<Mutex<TunProvider>>,
routes: impl Iterator<Item = IpNetwork>,
#[cfg(daita)] resource_dir: &Path,
mut connectivity_check: connectivity::Check<connectivity::Cancellable>,
) -> Result<Self> {
let (mut tunnel_device, tunnel_fd) =
Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?;
Expand All @@ -314,7 +325,7 @@ impl WgGoTunnel {
Self::bypass_tunnel_sockets(&handle, &mut tunnel_device)
.map_err(TunnelError::BypassError)?;

let tunnel = WgGoTunnel::Singlehop(WgGoTunnelState {
let mut tunnel = WgGoTunnel::Singlehop(WgGoTunnelState {
interface_name,
tunnel_handle: handle,
_tunnel_device: tunnel_device,
Expand All @@ -324,9 +335,12 @@ impl WgGoTunnel {
resource_dir: resource_dir.to_owned(),
#[cfg(daita)]
config: config.clone(),
connectivity_checker: None,
});

tunnel.ensure_tunnel_is_running(config.ipv4_gateway)?;
// HACK: Check if the tunnel is working by sending a ping in the tunnel.
tunnel.ensure_tunnel_is_running(&mut connectivity_check)?;
tunnel.as_state_mut().connectivity_checker = Some(connectivity_check);

Ok(tunnel)
}
Expand All @@ -338,6 +352,7 @@ impl WgGoTunnel {
tun_provider: Arc<Mutex<TunProvider>>,
routes: impl Iterator<Item = IpNetwork>,
#[cfg(daita)] resource_dir: &Path,
mut connectivity_check: connectivity::Check<connectivity::Cancellable>,
) -> Result<Self> {
let (mut tunnel_device, tunnel_fd) =
Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?;
Expand Down Expand Up @@ -376,7 +391,7 @@ impl WgGoTunnel {
Self::bypass_tunnel_sockets(&handle, &mut tunnel_device)
.map_err(TunnelError::BypassError)?;

let tunnel = WgGoTunnel::Multihop(WgGoTunnelState {
let mut tunnel = WgGoTunnel::Multihop(WgGoTunnelState {
interface_name,
tunnel_handle: handle,
_tunnel_device: tunnel_device,
Expand All @@ -386,9 +401,12 @@ impl WgGoTunnel {
resource_dir: resource_dir.to_owned(),
#[cfg(daita)]
config: config.clone(),
connectivity_checker: None,
});

tunnel.ensure_tunnel_is_running(config.ipv4_gateway)?;
// HACK: Check if the tunnel is working by sending a ping in the tunnel.
tunnel.ensure_tunnel_is_running(&mut connectivity_check)?;
tunnel.as_state_mut().connectivity_checker = Some(connectivity_check);

Ok(tunnel)
}
Expand All @@ -406,12 +424,18 @@ impl WgGoTunnel {
Ok(())
}

/// There is a breif period of time between setting up a Wireguard-go tunnel and the tunnel being ready to serve
pub fn take_checker(&mut self) -> Option<connectivity::Check<connectivity::Cancellable>> {
self.as_state_mut().connectivity_checker.take()
}

/// There is a brief period of time between setting up a Wireguard-go tunnel and the tunnel being ready to serve
/// traffic. This function blocks until the tunnel starts to serve traffic or until [connectivity::Check] times out.
fn ensure_tunnel_is_running(&self, addr: Ipv4Addr) -> Result<()> {
fn ensure_tunnel_is_running(
&self,
checker: &mut connectivity::Check<connectivity::Cancellable>,
) -> Result<()> {
let connectivity_err = |e| TunnelError::Connectivity(Box::new(e));
let connection_established = connectivity::Check::new(addr)
.map_err(connectivity_err)?
let connection_established = checker
.establish_connectivity(0, self)
.map_err(connectivity_err)?;

Expand Down Expand Up @@ -446,7 +470,7 @@ impl Tunnel for WgGoTunnel {
&mut self,
config: Config,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
Box::pin(async move { self.to_state_mut().set_config(config) })
Box::pin(async move { self.as_state_mut().set_config(config) })
}

#[cfg(daita)]
Expand Down

0 comments on commit 96cf64e

Please sign in to comment.