diff --git a/CHANGELOG.md b/CHANGELOG.md index 4932d69b987b..7f5664ee773d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,7 @@ Line wrap the file at 100 chars. Th #### macOS - Fix intermittent failures to connect with PQ enabled. - Exclude programs when executed using a relative path from a shell. +- Reduce packet loss when using split tunneling. ## [2024.4] - 2024-07-23 diff --git a/Cargo.lock b/Cargo.lock index 6c25a0ade8de..26f0e8ac3dbd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3903,6 +3903,7 @@ dependencies = [ "subslice", "system-configuration", "talpid-dbus", + "talpid-net", "talpid-openvpn", "talpid-platform-metadata", "talpid-routing", @@ -3945,6 +3946,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "talpid-net" +version = "0.0.0" +dependencies = [ + "libc", + "log", + "socket2", + "talpid-types", +] + [[package]] name = "talpid-openvpn" version = "0.0.0" @@ -4128,6 +4139,7 @@ dependencies = [ "socket2", "surge-ping", "talpid-dbus", + "talpid-net", "talpid-routing", "talpid-tunnel", "talpid-tunnel-config-client", diff --git a/Cargo.toml b/Cargo.toml index c3b1185cbbda..1317951f6549 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ members = [ "talpid-core", "talpid-dbus", "talpid-future", + "talpid-net", "talpid-openvpn", "talpid-openvpn-plugin", "talpid-platform-metadata", diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index 22515b3cc212..20b69e8afca6 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -60,6 +60,7 @@ tun = { version = "0.5.5", features = ["async"] } nix = { version = "0.28", features = ["socket"] } serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } +talpid-net = { path = "../talpid-net" } [target.'cfg(windows)'.dependencies] bitflags = "1.2" diff --git a/talpid-core/src/split_tunnel/macos/mod.rs b/talpid-core/src/split_tunnel/macos/mod.rs index d5c7a508c36c..0ef879842dbd 100644 --- a/talpid-core/src/split_tunnel/macos/mod.rs +++ b/talpid-core/src/split_tunnel/macos/mod.rs @@ -516,6 +516,7 @@ impl State { let result = tun::create_split_tunnel( default_interface, new_vpn_interface.clone(), + route_manager.clone(), Box::new(move |packet| { match states.get_process_status(packet.header.pth_pid as u32) { ExclusionStatus::Excluded => tun::RoutingDecision::DefaultInterface, diff --git a/talpid-core/src/split_tunnel/macos/tun.rs b/talpid-core/src/split_tunnel/macos/tun.rs index f5f7211499c7..f5cc58878630 100644 --- a/talpid-core/src/split_tunnel/macos/tun.rs +++ b/talpid-core/src/split_tunnel/macos/tun.rs @@ -10,6 +10,7 @@ use super::{ }; use futures::{Stream, StreamExt}; use libc::{AF_INET, AF_INET6}; +use nix::net::if_::if_nametoindex; use pcap::PacketCodec; use pnet_packet::{ ethernet::{EtherTypes, MutableEthernetPacket}, @@ -26,6 +27,7 @@ use std::{ net::{Ipv4Addr, Ipv6Addr}, ptr::NonNull, }; +use talpid_routing::RouteManagerHandle; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, sync::broadcast, @@ -114,11 +116,15 @@ pub struct SplitTunnelHandle { /// Task that handles outgoing packets. On completion, it returns a handle for the pktap, as /// well as the function used to classify packets egress_task: tokio::task::JoinHandle>, + /// Task that synchronizes the ST tunnel MTU with the VPN tunnel MTU + mtu_listener: Option>, + route_manager: RouteManagerHandle, } impl SplitTunnelHandle { - pub async fn shutdown(self) -> Result<(), Error> { + pub async fn shutdown(mut self) -> Result<(), Error> { log::debug!("Shutting down split tunnel"); + self.abort_mtu_listener().await; let _ = self.abort_tx.send(()); let _ = self.ingress_task.await.map_err(|_| Error::StopRedirect)?; let _ = self.egress_task.await.map_err(|_| Error::StopRedirect)??; @@ -131,12 +137,14 @@ impl SplitTunnelHandle { } pub async fn set_interfaces( - self, + mut self, default_interface: DefaultInterface, vpn_interface: Option, ) -> Result { let _ = self.abort_tx.send(()); + self.abort_mtu_listener().await; + let st_utun = self.ingress_task.await.map_err(|_| Error::StopRedirect)?; let egress_completion = self.egress_task.await.map_err(|_| Error::StopRedirect)??; @@ -146,9 +154,17 @@ impl SplitTunnelHandle { egress_completion.pktap_stream, default_interface, vpn_interface, + self.route_manager, egress_completion.classify, ) } + + async fn abort_mtu_listener(&mut self) { + if let Some(mtu_listener) = self.mtu_listener.take() { + mtu_listener.abort(); + let _ = mtu_listener.await; + } + } } /// Create split tunnel device and handle all packets using `classify`. Handle any changes to the @@ -161,10 +177,17 @@ impl SplitTunnelHandle { pub async fn create_split_tunnel( default_interface: DefaultInterface, vpn_interface: Option, + route_manager: RouteManagerHandle, classify: ClassifyFn, ) -> Result { let tun_device = create_utun().await?; - redirect_packets(tun_device, default_interface, vpn_interface, classify) + redirect_packets( + tun_device, + default_interface, + vpn_interface, + route_manager, + classify, + ) } /// Create a utun device for split tunneling, and configure its IP addresses. @@ -207,6 +230,7 @@ fn redirect_packets( st_tun_device: tun::AsyncDevice, default_interface: DefaultInterface, vpn_interface: Option, + route_manager: RouteManagerHandle, classify: ClassifyFn, ) -> Result { let pktap_stream = capture_outbound_packets(st_tun_device.get_ref().name())?; @@ -215,6 +239,7 @@ fn redirect_packets( Box::pin(pktap_stream), default_interface, vpn_interface, + route_manager, Box::new(classify), ) } @@ -232,8 +257,17 @@ fn redirect_packets_for_pktap_stream( pktap_stream: PktapStream, default_interface: DefaultInterface, vpn_interface: Option, + route_manager: RouteManagerHandle, classify: ClassifyFn, ) -> Result { + let mtu_listener = vpn_interface.as_ref().map(|vpn_interface| { + tokio::spawn(mtu_updater( + st_tun_device.get_ref().name().to_owned(), + vpn_interface.name.clone(), + route_manager.clone(), + )) + }); + let (default_stream, default_write, read_buffer_size) = open_default_bpf(&default_interface)?; let st_utun_name = st_tun_device.get_ref().name().to_owned(); @@ -265,9 +299,62 @@ fn redirect_packets_for_pktap_stream( abort_tx, ingress_task, egress_task, + mtu_listener, + route_manager, }) } +/// Listen for changes to VPN interface MTU and apply them to the ST utun accordingly +async fn mtu_updater( + st_interface_name: String, + vpn_interface_name: String, + route_manager: RouteManagerHandle, +) { + let vpn_tun_index = match if_nametoindex(vpn_interface_name.as_str()) { + Ok(index) => u16::try_from(index).unwrap(), + Err(error) => { + log::error!("Failed to obtain VPN utun index: {error}"); + return; + } + }; + let mut current_mtu = match talpid_net::unix::get_mtu(&vpn_interface_name) { + Ok(mtu) => mtu, + Err(error) => { + log::error!("Failed to fetch current VPN tunnel MTU: {error}"); + return; + } + }; + + try_update_mtu(&st_interface_name, current_mtu); + + let mut listener = match route_manager.interface_change_listener().await { + Ok(listener) => listener, + Err(error) => { + log::warn!("Failed to start interface listener: {error}"); + return; + } + }; + while let Some(details) = listener.next().await { + if details.interface_index != vpn_tun_index || details.mtu == current_mtu { + continue; + } + current_mtu = details.mtu; + try_update_mtu(&st_interface_name, current_mtu); + } +} + +/// Try to update the MTU of `st_iface_name`, and log if this fails +fn try_update_mtu(st_iface_name: &str, mtu: u16) { + match talpid_net::unix::set_mtu(st_iface_name, mtu) { + Ok(()) => { + log::debug!("ST interface MTU: {mtu}"); + } + Err(error) => { + log::error!("Failed to set MTU of {st_iface_name} to {mtu}: {error}"); + } + } +} + /// Open a BPF device for the specified default interface. Return a read and write half, and the /// buffer size. fn open_default_bpf( @@ -440,6 +527,16 @@ fn classify_and_send( log::error!("dropping invalid IPv4 packet"); return; }; + if let Some(vpn_v4) = vpn_interface.and_then(|iface| iface.0.v4_address) { + let src_ip = ip.get_source(); + if src_ip != vpn_v4 && src_ip != addrs.source_ip { + // Drop packet from invalid source + return; + } + } else if ip.get_source() != addrs.source_ip { + // Drop packet from invalid source + return; + } fix_ipv4_checksums(&mut ip, Some(addrs.source_ip), None); if let Err(error) = default_write.write(packet.frame.packet()) { log::error!("Failed to forward to default device: {error}"); @@ -457,6 +554,16 @@ fn classify_and_send( log::error!("dropping invalid IPv6 packet"); return; }; + if let Some(vpn_v6) = vpn_interface.and_then(|iface| iface.0.v6_address) { + let src_ip = ip.get_source(); + if src_ip != vpn_v6 && src_ip != addrs.source_ip { + // Drop packet from invalid source + return; + } + } else if ip.get_source() != addrs.source_ip { + // Drop packet from invalid source + return; + } fix_ipv6_checksums(&mut ip, Some(addrs.source_ip), None); if let Err(error) = default_write.write(packet.frame.packet()) { log::error!("Failed to forward to default device: {error}"); @@ -480,9 +587,16 @@ fn classify_and_send( log::error!("dropping invalid IPv4 packet"); return; }; + if ip.get_source() != addr { + // Drop packet from invalid source + return; + } fix_ipv4_checksums(&mut ip, Some(addr), None); if let Err(error) = vpn_write.write(packet.frame.payload()) { - log::error!("Failed to forward to tun device: {error}"); + log::trace!( + "Failed to forward to VPN tunnel: {error}, size: {}", + packet.frame.payload().len() + ); } } EtherTypes::Ipv6 => { @@ -494,9 +608,16 @@ fn classify_and_send( log::error!("dropping invalid IPv6 packet"); return; }; + if ip.get_source() != addr { + // Drop packet from invalid source + return; + } fix_ipv6_checksums(&mut ip, Some(addr), None); if let Err(error) = vpn_write.write(packet.frame.payload()) { - log::error!("Failed to forward to tun device: {error}"); + log::trace!( + "Failed to forward to VPN tunnel: {error}, size: {}", + packet.frame.payload().len() + ); } } other => log::error!("unknown ethertype: {other}"), @@ -690,10 +811,12 @@ fn capture_outbound_packets( .open() .map_err(Error::CaptureSplitTunnelDevice)?; - // TODO: This is unsupported on macOS 13 and lower, so we determine the direction using the - // pktap header flags. Once macOS 13 is no longer supported, this can be uncommented. - // cap.direction(pcap::Direction::Out) - // .map_err(Error::SetDirection)?; + // TODO: `Capture::direction` is unsupported on macOS 13 and lower, so we determine the + // direction using the pktap header as well. Once macOS 13 is no longer supported, + // this can be assumed to work. Filtering here appears to be a lot faster. + if let Err(error) = cap.direction(pcap::Direction::Out) { + log::debug!("Failed to set capture direction. Might be on macOS 13: {error}"); + } let cap = cap.setnonblock().map_err(Error::EnableNonblock)?; let stream = cap diff --git a/talpid-net/Cargo.toml b/talpid-net/Cargo.toml new file mode 100644 index 000000000000..aa30ed1b5b6a --- /dev/null +++ b/talpid-net/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "talpid-net" +description = "Networking helpers" +authors.workspace = true +repository.workspace = true +license.workspace = true +edition.workspace = true +rust-version.workspace = true + +[lints] +workspace = true + +[target.'cfg(unix)'.dependencies] +libc = "0.2" +talpid-types = { path = "../talpid-types" } +socket2 = { version = "0.5.3", features = ["all"] } +log = { workspace = true } diff --git a/talpid-net/src/lib.rs b/talpid-net/src/lib.rs new file mode 100644 index 000000000000..b13064bf70d6 --- /dev/null +++ b/talpid-net/src/lib.rs @@ -0,0 +1,2 @@ +#[cfg(unix)] +pub mod unix; diff --git a/talpid-net/src/unix.rs b/talpid-net/src/unix.rs new file mode 100644 index 000000000000..ef2bfdeb276c --- /dev/null +++ b/talpid-net/src/unix.rs @@ -0,0 +1,83 @@ +#![cfg(any(target_os = "linux", target_os = "macos"))] + +use std::{io, os::fd::AsRawFd}; + +use socket2::Domain; +use talpid_types::ErrorExt; + +#[cfg(target_os = "macos")] +const SIOCSIFMTU: u64 = 0x80206934; +#[cfg(target_os = "macos")] +const SIOCGIFMTU: u64 = 0xc0206933; +#[cfg(target_os = "linux")] +const SIOCSIFMTU: u64 = libc::SIOCSIFMTU; +#[cfg(target_os = "linux")] +const SIOCGIFMTU: u64 = libc::SIOCSIFMTU; + +pub fn set_mtu(interface_name: &str, mtu: u16) -> Result<(), io::Error> { + let sock = socket2::Socket::new( + Domain::IPV4, + socket2::Type::STREAM, + Some(socket2::Protocol::TCP), + )?; + + let mut ifr: libc::ifreq = unsafe { std::mem::zeroed() }; + if interface_name.len() >= ifr.ifr_name.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Interface name too long", + )); + } + + // SAFETY: `interface_name.len()` is less than `ifr.ifr_name.len()` + unsafe { + std::ptr::copy_nonoverlapping( + interface_name.as_ptr() as *const libc::c_char, + &mut ifr.ifr_name as *mut _, + interface_name.len(), + ) + }; + ifr.ifr_ifru.ifru_mtu = mtu as i32; + + // SAFETY: SIOCSIFMTU expects an ifreq with an MTU and interface set + if unsafe { libc::ioctl(sock.as_raw_fd(), SIOCSIFMTU, &ifr) } < 0 { + let e = std::io::Error::last_os_error(); + log::error!("{}", e.display_chain_with_msg("SIOCSIFMTU failed")); + return Err(e); + } + Ok(()) +} + +pub fn get_mtu(interface_name: &str) -> Result { + let sock = socket2::Socket::new( + Domain::IPV4, + socket2::Type::STREAM, + Some(socket2::Protocol::TCP), + )?; + + let mut ifr: libc::ifreq = unsafe { std::mem::zeroed() }; + if interface_name.len() >= ifr.ifr_name.len() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "Interface name too long", + )); + } + + // SAFETY: `interface_name.len()` is less than `ifr.ifr_name.len()` + unsafe { + std::ptr::copy_nonoverlapping( + interface_name.as_ptr() as *const libc::c_char, + &mut ifr.ifr_name as *mut _, + interface_name.len(), + ) + }; + + // SAFETY: SIOCGIFMTU expects an ifreq with an interface set + if unsafe { libc::ioctl(sock.as_raw_fd(), SIOCGIFMTU, &ifr) } < 0 { + let e = std::io::Error::last_os_error(); + log::error!("{}", e.display_chain_with_msg("SIOCGIFMTU failed")); + return Err(e); + } + // SAFETY: ifru_mtu is initialized by SIOCGIFMTU + Ok(u16::try_from(unsafe { ifr.ifr_ifru.ifru_mtu }).unwrap()) +} diff --git a/talpid-routing/src/unix/macos/data.rs b/talpid-routing/src/unix/macos/data.rs index a5e8a49efd24..5ac83d08859e 100644 --- a/talpid-routing/src/unix/macos/data.rs +++ b/talpid-routing/src/unix/macos/data.rs @@ -634,6 +634,10 @@ impl Interface { self.header.ifm_index } + pub fn mtu(&self) -> u32 { + self.header.ifm_data.ifi_mtu + } + fn from_byte_buffer(buffer: &[u8]) -> Result { const INTERFACE_MESSAGE_HEADER_SIZE: usize = std::mem::size_of::(); if INTERFACE_MESSAGE_HEADER_SIZE > buffer.len() { diff --git a/talpid-routing/src/unix/macos/mod.rs b/talpid-routing/src/unix/macos/mod.rs index 141b8c06d38d..85a020ba797f 100644 --- a/talpid-routing/src/unix/macos/mod.rs +++ b/talpid-routing/src/unix/macos/mod.rs @@ -90,6 +90,7 @@ pub struct RouteManagerImpl { v6_default_route: Option, update_trigger: BurstGuard, default_route_listeners: Vec>, + interface_change_listeners: Vec>, check_default_routes_restored: Pin + Send>>, unhandled_default_route_changes: bool, primary_interface_monitor: interface::PrimaryInterfaceMonitor, @@ -127,6 +128,7 @@ impl RouteManagerImpl { v6_default_route: None, update_trigger, default_route_listeners: vec![], + interface_change_listeners: vec![], check_default_routes_restored: Box::pin(futures::stream::pending()), unhandled_default_route_changes: false, primary_interface_monitor, @@ -227,6 +229,13 @@ impl RouteManagerImpl { log::error!("Failed to clean up rotues: {err}"); } }, + + Some(RouteManagerCommand::NewInterfaceChangeListener(tx)) => { + let (events_tx, events_rx) = mpsc::unbounded(); + self.interface_change_listeners.push(events_tx); + let _ = tx.send(events_rx); + } + Some(RouteManagerCommand::RefreshRoutes) => { if let Err(error) = self.refresh_routes().await { log::error!("Failed to refresh routes: {error}"); @@ -377,6 +386,20 @@ impl RouteManagerImpl { Ok(RouteSocketMessage::AddAddress(_) | RouteSocketMessage::DeleteAddress(_)) => { self.update_trigger.trigger(); } + Ok(RouteSocketMessage::Interface(iface)) => { + let Ok(mtu) = u16::try_from(iface.mtu()) else { + log::warn!("Invalid mtu for interface: {}", iface.index()); + return; + }; + + self.interface_change_listeners.retain(|tx| { + tx.unbounded_send(super::InterfaceEvent { + interface_index: iface.index(), + mtu, + }) + .is_ok() + }); + } // ignore all other message types Ok(_) => {} Err(err) => { diff --git a/talpid-routing/src/unix/mod.rs b/talpid-routing/src/unix/mod.rs index 044a09433a16..d257140f7e3c 100644 --- a/talpid-routing/src/unix/mod.rs +++ b/talpid-routing/src/unix/mod.rs @@ -118,10 +118,18 @@ pub(crate) enum RouteManagerCommand { RefreshRoutes, NewDefaultRouteListener(oneshot::Sender>), GetDefaultRoutes(oneshot::Sender<(Option, Option)>), + NewInterfaceChangeListener(oneshot::Sender>), /// Return gateway for V4 and V6 GetDefaultGateway(oneshot::Sender<(Option, Option)>), } +/// Event that is sent when interface details may have changed for some interface. +#[cfg(target_os = "macos")] +pub struct InterfaceEvent { + pub interface_index: u16, + pub mtu: u16, +} + /// Event that is sent when a preferred non-tunnel default route is /// added or removed. #[cfg(target_os = "macos")] @@ -227,6 +235,18 @@ impl RouteManagerHandle { response_rx.await.map_err(|_| Error::ManagerChannelDown) } + /// Listen for interface changes. + #[cfg(target_os = "macos")] + pub async fn interface_change_listener( + &self, + ) -> Result, Error> { + let (response_tx, response_rx) = oneshot::channel(); + self.tx + .unbounded_send(RouteManagerCommand::NewInterfaceChangeListener(response_tx)) + .map_err(|_| Error::RouteManagerDown)?; + response_rx.await.map_err(|_| Error::ManagerChannelDown) + } + /// Get default gateway #[cfg(target_os = "macos")] pub async fn get_default_gateway(&self) -> Result<(Option, Option), Error> { diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml index 806cdd144718..3fea8a17c075 100644 --- a/talpid-wireguard/Cargo.toml +++ b/talpid-wireguard/Cargo.toml @@ -45,6 +45,9 @@ tokio-stream = { version = "0.1", features = ["io-util"] } [target.'cfg(unix)'.dependencies] nix = "0.23" +[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dependencies] +talpid-net = { path = "../talpid-net" } + [target.'cfg(target_os = "linux")'.dependencies] rtnetlink = "0.11" netlink-packet-core = "0.4.2" diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index b9a85560ee59..ca59e014ff54 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -49,8 +49,6 @@ mod connectivity_check; mod logging; mod ping_monitor; mod stats; -#[cfg(any(target_os = "linux", target_os = "macos"))] -mod unix; #[cfg(wireguard_go)] mod wireguard_go; #[cfg(target_os = "linux")] diff --git a/talpid-wireguard/src/mtu_detection.rs b/talpid-wireguard/src/mtu_detection.rs index 5132705719a2..11c6625f2cff 100644 --- a/talpid-wireguard/src/mtu_detection.rs +++ b/talpid-wireguard/src/mtu_detection.rs @@ -60,7 +60,7 @@ pub async fn automatic_mtu_correction( log::warn!("Lowering MTU from {} to {verified_mtu}", current_tunnel_mtu); #[cfg(any(target_os = "linux", target_os = "macos"))] - crate::unix::set_mtu(&iface_name, verified_mtu).map_err(Error::SetMtu)?; + talpid_net::unix::set_mtu(&iface_name, verified_mtu).map_err(Error::SetMtu)?; #[cfg(windows)] set_mtu_windows(verified_mtu, iface_name, ipv6).map_err(Error::SetMtu)?; } else { diff --git a/talpid-wireguard/src/unix.rs b/talpid-wireguard/src/unix.rs deleted file mode 100644 index e2d18f8ab0be..000000000000 --- a/talpid-wireguard/src/unix.rs +++ /dev/null @@ -1,46 +0,0 @@ -use std::{io, os::fd::AsRawFd}; - -use socket2::Domain; -use talpid_types::ErrorExt; - -#[cfg(target_os = "macos")] -const SIOCSIFMTU: u64 = 0x80206934; -#[cfg(target_os = "linux")] -const SIOCSIFMTU: u64 = libc::SIOCSIFMTU; - -pub fn set_mtu(interface_name: &str, mtu: u16) -> Result<(), io::Error> { - debug_assert_ne!( - interface_name, "eth0", - "Should be name of mullvad tunnel interface, e.g. 'wg0-mullvad'" - ); - - let sock = socket2::Socket::new( - Domain::IPV4, - socket2::Type::STREAM, - Some(socket2::Protocol::TCP), - )?; - - let mut ifr: libc::ifreq = unsafe { std::mem::zeroed() }; - if interface_name.len() >= ifr.ifr_name.len() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "Interface name too long", - )); - } - - unsafe { - std::ptr::copy_nonoverlapping( - interface_name.as_ptr() as *const libc::c_char, - &mut ifr.ifr_name as *mut _, - interface_name.len(), - ) - }; - ifr.ifr_ifru.ifru_mtu = mtu as i32; - - if unsafe { libc::ioctl(sock.as_raw_fd(), SIOCSIFMTU, &ifr) } < 0 { - let e = std::io::Error::last_os_error(); - log::error!("{}", e.display_chain_with_msg("SIOCSIFMTU failed")); - return Err(e); - } - Ok(()) -}