From 21cbfebee28b3be0993bb232eb4007735ec498d1 Mon Sep 17 00:00:00 2001 From: Joakim Hulthe Date: Thu, 21 Nov 2024 16:15:20 +0100 Subject: [PATCH] wip: Integrate leak-checker in daemon --- Cargo.lock | 6 + Cargo.toml | 2 + leak-checker/Cargo.toml | 2 +- leak-checker/src/traceroute.rs | 583 ++++-------------- .../src/traceroute/platform/android.rs | 27 + .../src/traceroute/platform/common.rs | 102 +++ leak-checker/src/traceroute/platform/linux.rs | 215 +++++++ leak-checker/src/traceroute/platform/macos.rs | 79 +++ leak-checker/src/traceroute/platform/mod.rs | 85 +++ leak-checker/src/traceroute/platform/unix.rs | 53 ++ .../src/traceroute/platform/windows.rs | 132 ++++ leak-checker/src/util.rs | 4 + mullvad-daemon/Cargo.toml | 5 + mullvad-daemon/src/leak_checker/mod.rs | 244 +++++++- mullvad-daemon/src/lib.rs | 24 +- talpid-core/Cargo.toml | 3 +- talpid-net/Cargo.toml | 2 +- talpid-windows/Cargo.toml | 2 +- talpid-wireguard/Cargo.toml | 2 +- test/test-runner/Cargo.toml | 2 +- 20 files changed, 1079 insertions(+), 495 deletions(-) create mode 100644 leak-checker/src/traceroute/platform/android.rs create mode 100644 leak-checker/src/traceroute/platform/common.rs create mode 100644 leak-checker/src/traceroute/platform/linux.rs create mode 100644 leak-checker/src/traceroute/platform/macos.rs create mode 100644 leak-checker/src/traceroute/platform/mod.rs create mode 100644 leak-checker/src/traceroute/platform/unix.rs create mode 100644 leak-checker/src/traceroute/platform/windows.rs diff --git a/Cargo.lock b/Cargo.lock index 25d542ad83ae..3fd50ad1a1e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2419,6 +2419,7 @@ name = "mullvad-daemon" version = "0.0.0" dependencies = [ "android_logger", + "anyhow", "async-trait", "chrono", "clap", @@ -2428,6 +2429,7 @@ dependencies = [ "fern", "futures", "hickory-resolver", + "leak-checker", "libc", "log", "log-panics", @@ -2445,6 +2447,8 @@ dependencies = [ "serde", "serde_json", "simple-signal", + "socket2", + "surge-ping", "talpid-core", "talpid-dbus", "talpid-future", @@ -3205,6 +3209,8 @@ dependencies = [ [[package]] name = "pfctl" version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44e65c0d3523afa79a600a3964c3ac0fabdabe2d7c68da624b2bb0b441b9d61" dependencies = [ "derive_builder", "ioctl-sys 0.8.0", diff --git a/Cargo.toml b/Cargo.toml index d1fea010a858..9d332e123549 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -115,6 +115,7 @@ hickory-server = { version = "0.24.2", features = ["resolver"] } tokio = { version = "1.8" } parity-tokio-ipc = "0.9" futures = "0.3.15" + # Tonic and related crates tonic = "0.12.3" tonic-build = { version = "0.10.0", default-features = false } @@ -139,6 +140,7 @@ serde = "1.0.204" serde_json = "1.0.122" ipnetwork = "0.20" +socket2 = "0.5.7" # Test dependencies proptest = "1.4" diff --git a/leak-checker/Cargo.toml b/leak-checker/Cargo.toml index 6a24daba0cb3..968878c80cc6 100644 --- a/leak-checker/Cargo.toml +++ b/leak-checker/Cargo.toml @@ -24,7 +24,7 @@ clap = { version = "*", features = ["derive"] } tokio = { workspace = true, features = ["full"] } [target.'cfg(unix)'.dependencies] -nix = { version = "0.29.0", features = ["net"] } +nix = { version = "0.29.0", features = ["net", "socket", "uio"] } [target.'cfg(windows)'.dependencies] windows-sys.workspace = true diff --git a/leak-checker/src/traceroute.rs b/leak-checker/src/traceroute.rs index 59b1e0fc3fd9..a4040f2ad9c1 100644 --- a/leak-checker/src/traceroute.rs +++ b/leak-checker/src/traceroute.rs @@ -1,15 +1,15 @@ use std::{ ascii::escape_default, + convert::Infallible, + future::ready, io, net::{IpAddr, Ipv4Addr}, ops::{Range, RangeFrom}, - os::fd::{FromRawFd, IntoRawFd}, time::Duration, }; use eyre::{bail, ensure, eyre, OptionExt, WrapErr}; -use futures::{future::pending, stream, StreamExt, TryFutureExt, TryStreamExt}; -use match_cfg::match_cfg; +use futures::{future::pending, select, stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use pnet_packet::{ icmp::{ echo_request::EchoRequestPacket, time_exceeded::TimeExceededPacket, IcmpPacket, IcmpTypes, @@ -20,13 +20,15 @@ use pnet_packet::{ Packet, }; use socket2::{Domain, Protocol, Socket, Type}; -use tokio::{ - net::UdpSocket, - select, - time::{sleep, sleep_until, timeout, Instant}, -}; +use tokio::time::{sleep, timeout}; + +use crate::LeakStatus; + +mod platform; -use crate::{LeakInfo, LeakStatus}; +use platform::{ + AsyncIcmpSocket, AsyncIcmpSocketImpl, AsyncUdpSocket, AsyncUdpSocketImpl, Impl, Traceroute, +}; #[derive(Clone, clap::Args)] pub struct TracerouteOpt { @@ -36,14 +38,14 @@ pub struct TracerouteOpt { /// Destination IP of the probe packets #[clap(short, long)] - pub destination: Ipv4Addr, + pub destination: IpAddr, - /// Avoid sending probe packets to this port - #[clap(long)] + /// Avoid sending UDP probe packets to this port. + #[clap(long, conflicts_with = "icmp")] pub exclude_port: Option, - /// Send probe packets only to this port, instead of the default ports. - #[clap(long)] + /// Send UDP probe packets only to this port, instead of the default ports. + #[clap(long, conflicts_with = "icmp")] pub port: Option, /// Use ICMP-Echo for the probe packets instead of UDP. @@ -97,101 +99,75 @@ pub async fn run_leak_test(opt: &TracerouteOpt) -> LeakStatus { /// root/admin priviliges. pub async fn try_run_leak_test(opt: &TracerouteOpt) -> eyre::Result { // create the socket used for receiving the ICMP/TimeExceeded responses - let icmp_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::ICMPV4)) + + // don't ask me why, but this is how it must be. + let icmp_socket_type = if cfg!(target_os = "windows") { + Type::RAW + } else { + Type::DGRAM + }; + + let icmp_socket = Socket::new(Domain::IPV4, icmp_socket_type, Some(Protocol::ICMPV4)) .wrap_err("Failed to open ICMP socket")?; icmp_socket .set_nonblocking(true) .wrap_err("Failed to set icmp_socket to nonblocking")?; - #[cfg(any(target_os = "linux", target_os = "android"))] - { - use std::ffi::c_void; - use std::os::fd::{AsFd, AsRawFd}; - - let n = 1; - unsafe { - setsockopt( - icmp_socket.as_fd().as_raw_fd(), - nix::libc::SOL_IP, - nix::libc::IP_RECVERR, - &n as *const _ as *const std::ffi::c_void, - size_of_val(&n) as u32, - ) - }; - } + Impl::bind_socket_to_interface(&icmp_socket, &opt.interface)?; + Impl::configure_icmp_socket(&icmp_socket, opt)?; - bind_socket_to_interface(&icmp_socket, &opt.interface)?; + let icmp_socket = AsyncIcmpSocketImpl::from_socket2(icmp_socket); - // HACK: Wrap the socket in a tokio::net::UdpSocket to be able to use it async - // SAFETY: `into_raw_fd()` consumes the socket and returns an owned & open file descriptor. - let icmp_socket = unsafe { std::net::UdpSocket::from_raw_fd(icmp_socket.into_raw_fd()) }; - let mut icmp_socket = UdpSocket::from_std(icmp_socket)?; + let send_probes = async { + if opt.icmp { + send_icmp_probes(opt, &icmp_socket).await?; + } else { + // create the socket used for sending the UDP probing packets + let udp_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) + .wrap_err("Failed to open UDP socket")?; - // on Windows, we need to do some additional configuration of the raw socket - #[cfg(target_os = "windows")] - configure_listen_socket(&icmp_socket, interface)?; + Impl::bind_socket_to_interface(&udp_socket, &opt.interface) + .wrap_err("Failed to bind UDP socket to interface")?; - if opt.icmp { - timeout(SEND_TIMEOUT, send_icmp_probes(&mut icmp_socket, opt)) - .map_err(|_timeout| eyre!("Timed out while trying to send probe packet")) - .await??; - } else { - // create the socket used for sending the UDP probing packets - let udp_socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) - .wrap_err("Failed to open UDP socket")?; - bind_socket_to_interface(&udp_socket, &opt.interface) - .wrap_err("Failed to bind UDP socket to interface")?; - udp_socket - .set_nonblocking(true) - .wrap_err("Failed to set udp_socket to nonblocking")?; - - // HACK: Wrap the socket in a tokio::net::UdpSocket to be able to use it async - // SAFETY: `into_raw_fd()` consumes the socket and returns an owned & open file descriptor. - let udp_socket = unsafe { std::net::UdpSocket::from_raw_fd(udp_socket.into_raw_fd()) }; - let mut udp_socket = UdpSocket::from_std(udp_socket)?; - - timeout(SEND_TIMEOUT, send_udp_probes(&mut udp_socket, opt)) - .map_err(|_timeout| eyre!("Timed out while trying to send probe packet")) - .await??; - } + udp_socket + .set_nonblocking(true) + .wrap_err("Failed to set udp_socket to nonblocking")?; - //let recv_task = read_probe_responses(&opt.interface, icmp_socket); - let recv_task = read_probe_responses(&opt.interface, icmp_socket); + let mut udp_socket = AsyncUdpSocketImpl::from_socket2(udp_socket); - // wait until either task exits, or the timeout is reached - let leak_status = select! { - _ = sleep(LEAK_TIMEOUT) => LeakStatus::NoLeak, - result = recv_task => result?, + send_udp_probes(opt, &mut udp_socket).await?; + } + + // Never return + pending::>().await }; - // let send_task = timeout(SEND_TIMEOUT, send_icmp_probes(&mut udp_socket, opt)) - // .map_err(|_timeout| eyre!("Timed out while trying to send probe packet")) - // // never return on success - // .and_then(|_| pending()); - // - // let recv_task = read_probe_responses(&opt.interface, icmp_socket); - // - // wait until either thread exits, or the timeout is reached - // let leak_status = select! { - // _ = sleep(LEAK_TIMEOUT) => LeakStatus::NoLeak, - // result = recv_task => result?, - // result = send_task => result?, - // }; + // error if sending the probes takes longer than SEND_TIMEOUT + let send_probes = timeout(SEND_TIMEOUT, send_probes) + .map_err(|_timeout| eyre!("Timed out while trying to send probe packet")) + .and_then(ready); + + let recv_probe_responses = icmp_socket.recv_ttl_responses(opt); + + // wait until either future returns, or the timeout is reached + // friendship ended with tokio::select. now futures::select is my best friend! + let leak_status = select! { + result = recv_probe_responses.fuse() => result?, + Err(e) = send_probes.fuse() => return Err(e), + _ = sleep(LEAK_TIMEOUT).fuse() => LeakStatus::NoLeak, + }; Ok(leak_status) } -async fn send_icmp_probes(socket: &mut UdpSocket, opt: &TracerouteOpt) -> eyre::Result<()> { +/// Send ICMP/Echo packets with a very low TTL to `opt.destination`. +/// +/// Use [AsyncIcmpSocket::recv_ttl_responses] to receive replies. +async fn send_icmp_probes(opt: &TracerouteOpt, socket: &impl AsyncIcmpSocket) -> eyre::Result<()> { use pnet_packet::icmp::{echo_request::*, *}; - let ports = DEFAULT_PORT_RANGE - // ensure we don't send anything to `opt.exclude_port` - .filter(|&p| Some(p) != opt.exclude_port) - // `opt.port` overrides the default port range - .map(|port| opt.port.unwrap_or(port)); - - for (port, ttl) in ports.zip(DEFAULT_TTL_RANGE) { + for ttl in DEFAULT_TTL_RANGE { log::debug!("sending probe packet (ttl={ttl})"); socket @@ -216,7 +192,7 @@ async fn send_icmp_probes(socket: &mut UdpSocket, opt: &TracerouteOpt) -> eyre:: let result: io::Result<()> = stream::iter(0..number_of_sends) // call `send_to` `number_of_sends` times - .then(|_| socket.send_to(&packet.packet(), (opt.destination, port))) + .then(|_| socket.send_to(packet.packet(), opt.destination)) .map_ok(drop) .try_collect() // abort on the first error .await; @@ -234,7 +210,13 @@ async fn send_icmp_probes(socket: &mut UdpSocket, opt: &TracerouteOpt) -> eyre:: Ok(()) } -async fn send_udp_probes(socket: &mut UdpSocket, opt: &TracerouteOpt) -> eyre::Result<()> { +/// Send UDP packets with a very low TTL to `opt.destination`. +/// +/// Use [Impl::recv_ttl_responses] to receive replies. +async fn send_udp_probes( + opt: &TracerouteOpt, + socket: &mut impl AsyncUdpSocket, +) -> eyre::Result<()> { // ensure we don't send anything to `opt.exclude_port` let ports = DEFAULT_PORT_RANGE // skip the excluded port @@ -272,239 +254,21 @@ async fn send_udp_probes(socket: &mut UdpSocket, opt: &TracerouteOpt) -> eyre::R Ok(()) } -/// Experimental PoC of a linux implementation that doesn't need root. -#[cfg(any(target_os = "linux", target_os = "android"))] -#[allow(dead_code)] -async fn read_probe_responses_no_root( - _interface: &str, - socket: UdpSocket, -) -> eyre::Result { - use nix::libc::{errno::Errno, libc::setsockopt, setsockopt, sock_extended_err}; - use std::ffi::c_void; - use std::mem::transmute; - use std::os::fd::AsRawFd; - - // the list of node IP addresses from which we received a response to our probe packets. - let mut reachable_nodes = vec![]; - - let mut read_buf = vec![0u8; usize::from(u16::MAX)].into_boxed_slice(); - loop { - log::debug!("Reading from ICMP socket"); - - // XXX: only works for ipv4 - let mut msg_name: nix::libc::sockaddr_in = unsafe { std::mem::zeroed() }; - let mut msg_iov = vec![nix::libc::iovec { - iov_base: read_buf.as_mut_ptr() as *mut _, - iov_len: read_buf.len(), - }]; - let mut msg_control = vec![0u8; 2048]; - - let mut msg_header = nix::libc::msghdr { - msg_name: &mut msg_name as *mut _ as *mut c_void, - msg_namelen: size_of_val(&msg_name) as u32, - msg_iov: msg_iov.as_mut_ptr() as *mut _, - msg_iovlen: msg_iov.len(), - msg_control: msg_control.as_mut_ptr() as *mut _, - msg_controllen: msg_control.len(), - msg_flags: 0, - }; - log::debug!("header: {msg_header:?}"); - - // Calling recvmsg with MSG_ERRQUEUE will prompt linux to tell us if we get any ICMP errorr - // replies to our Echos. - let flags = nix::libc::MSG_ERRQUEUE; - let n = loop { - match unsafe { nix::libc::recvmsg(socket.as_raw_fd(), &mut msg_header, flags) } { - ..0 => match nix::errno::Errno::last() { - nix::errno::Errno::EWOULDBLOCK => { - sleep(Duration::from_millis(10)).await; - continue; - } - e => bail!("Faileed to read from socket {e}"), - }, - n => break n as usize, - } - }; - - log::debug!("header after: {msg_header:?}"); - msg_iov.truncate(msg_header.msg_iovlen); - msg_control.truncate(msg_header.msg_controllen); - let _ = msg_header; - - log::debug!("msg_name: {msg_name:?}"); - log::debug!("msg_iov: {msg_iov:?}"); - log::debug!("msg_iov[0]: {:?}", &read_buf[..n]); - log::debug!("msg_control: {msg_control:?}"); - - let source = Ipv4Addr::from_bits(msg_name.sin_addr.s_addr); - //let source = source.ip(); - let (control_header, rest) = msg_control - .split_first_chunk::<{ size_of::() }>() - .ok_or_eyre("Foo")?; - let control_header: nix::libc::cmsghdr = unsafe { transmute(*control_header) }; - let _control_message_len = control_header - .cmsg_len - .saturating_sub(size_of::()); - - debug_assert_eq!(control_header.cmsg_level, nix::libc::IPPROTO_IP); - debug_assert_eq!(control_header.cmsg_type, nix::libc::IP_RECVERR); - - let (control_message, rest) = rest - .split_first_chunk::<{ size_of::() }>() - .ok_or_eyre("ASADAD")?; - //debug_assert_eq!(control_message_len, control_message.len()); - - let control_message: sock_extended_err = unsafe { transmute(*control_message) }; - - let result = parse_icmp_time_exceeded_raw(&rest) - .map_err(|e| eyre!("Ignoring packet (len={n}, ip.src={source}): {e}",)); - - log::debug!("{control_header:?}"); - log::debug!("{control_message:?}"); - log::debug!("rest: {rest:?}"); - log::debug!("{:?}", Errno::from_raw(control_message.ee_errno as i32)); - - let _original_icmp_echo = &read_buf[..n]; - - // contains the source address of the ICMP Time Exceeded packet - let _icmp_source/*: nix::libc::sockaddr */ = rest; - - match result { - Ok(..) => { - log::debug!("Got a probe response, we are leaking!"); - //timeout_at.get_or_insert_with(|| Instant::now() + RECV_TIMEOUT); - //let ip = IpAddr::from(ip); - let ip = IpAddr::from(Ipv4Addr::new(1, 3, 3, 7)); - if !reachable_nodes.contains(&ip) { - reachable_nodes.push(ip); - } - } - - // an error means the packet wasn't the ICMP/TimeExceeded we're listening for. - Err(e) => log::debug!("{e}"), - } - } -} - -async fn read_probe_responses(interface: &str, socket: UdpSocket) -> eyre::Result { - // the list of node IP addresses from which we received a response to our probe packets. - let mut reachable_nodes = vec![]; - - // a time at which this function should exit. this is set when we receive the first probe - // response, and allows us to wait a while to collect any additional probe responses before - // returning. - let mut timeout_at = None; - - let mut read_buf = vec![0u8; usize::from(u16::MAX)].into_boxed_slice(); - loop { - let timer = async { - match timeout_at { - // resolve future at the timeout, if it's set - Some(time) => sleep_until(time).await, - - // otherwise, never resolve - None => pending().await, - } - }; - - log::debug!("Reading from ICMP socket"); - - // let n = socket - // .recv(unsafe { &mut *(&mut read_buf[..] as *mut [u8] as *mut [MaybeUninit]) }) - // .wrap_err("Failed to read from raw socket")?; - - let (n, source) = select! { - result = socket.recv_from(&mut read_buf[..]) => result - .wrap_err("Failed to read from raw socket")?, - - _timeout = timer => { - return Ok(LeakStatus::LeakDetected(LeakInfo::NodeReachableOnInterface { - reachable_nodes, - interface: interface.to_string(), - })); - } - }; - - let source = source.ip(); - let packet = &read_buf[..n]; - let result = parse_ipv4(packet) - .map_err(|e| eyre!("Ignoring packet: (len={n}, ip.src={source}) {e} ({packet:02x?})")) - .and_then(|ip_packet| { - parse_icmp_time_exceeded(&ip_packet).map_err(|e| { - eyre!( - "Ignoring packet (len={n}, ip.src={source}, ip.dest={}): {e}", - ip_packet.get_destination(), - ) - }) - }); - - match result { - Ok(ip) => { - log::debug!("Got a probe response, we are leaking!"); - timeout_at.get_or_insert_with(|| Instant::now() + RECV_TIMEOUT); - let ip = IpAddr::from(ip); - if !reachable_nodes.contains(&ip) { - reachable_nodes.push(ip); - } - } - - // an error means the packet wasn't the ICMP/TimeExceeded we're listening for. - Err(e) => log::debug!("{e}"), - } - } -} - -/// Configure the raw socket we use for listening to ICMP responses. -/// -/// This will bind the socket to an interface, and set the `SIO_RCVALL`-option. -#[cfg(target_os = "windows")] -fn configure_listen_socket(socket: &Socket, interface: &str) -> eyre::Result<()> { - use std::{ffi::c_void, os::windows::io::AsRawSocket, ptr::null_mut}; - use windows_sys::Win32::Networking::WinSock::{ - WSAGetLastError, WSAIoctl, SIO_RCVALL, SOCKET, SOCKET_ERROR, - }; - - bind_socket_to_interface(&socket, interface) - .wrap_err("Failed to bind listen socket to interface")?; - - let j = 1; - let mut _in: u32 = 0; - let result = unsafe { - WSAIoctl( - socket.as_raw_socket() as SOCKET, - SIO_RCVALL, - &j as *const _ as *const c_void, - size_of_val(&j) as u32, - null_mut(), - 0, - &mut _in as *mut u32, - null_mut(), - None, - ) - }; - - if result == SOCKET_ERROR { - let code = unsafe { WSAGetLastError() }; - bail!("Failed to call WSAIoctl(listen_socket, SIO_RCVALL, ...), code = {code}"); - } - - Ok(()) -} - /// Try to parse the bytes as an IPv4 packet. /// /// This only valdiates the IPv4 header, not the payload. fn parse_ipv4(packet: &[u8]) -> eyre::Result> { - let ip_packet = Ipv4Packet::new(packet).ok_or_eyre("Too small")?; + let ip_packet = Ipv4Packet::new(packet).ok_or_else(too_small)?; ensure!(ip_packet.get_version() == 4, "Not IPv4"); eyre::Ok(ip_packet) } /// Try to parse an [Ipv4Packet] as an ICMP/TimeExceeded response to a packet sent by -/// [send_probes]. If successful, returns the [Ipv4Addr] of the packet source. +/// [send_udp_probes] or [send_icmp_probes]. If successful, returns the [Ipv4Addr] of the packet +/// source. /// -/// If the packet fails to parse, or is not a reply to a packet sent by [send_probes], this -/// function returns an error. +/// If the packet fails to parse, or is not a reply to a packet sent by us, this function returns +/// an error. fn parse_icmp_time_exceeded(ip_packet: &Ipv4Packet<'_>) -> eyre::Result { let ip_protocol = ip_packet.get_next_level_protocol(); ensure!(ip_protocol == IpProtocol::Icmp, "Not ICMP"); @@ -512,9 +276,13 @@ fn parse_icmp_time_exceeded(ip_packet: &Ipv4Packet<'_>) -> eyre::Result eyre::Result<()> { let icmp_packet = IcmpPacket::new(bytes).ok_or(eyre!("Too small"))?; - let too_small = || eyre!("Too small"); let correct_type = icmp_packet.get_icmp_type() == IcmpTypes::TimeExceeded; ensure!(correct_type, "Not ICMP/TimeExceeded"); @@ -538,7 +306,7 @@ fn parse_icmp_time_exceeded_raw(bytes: &[u8]) -> eyre::Result<()> { .checked_sub(UdpPacket::minimum_packet_size()) .and_then(|len| original_udp_packet.payload().get(..len)) .ok_or_eyre("Invalid UDP length")?; - if udp_payload != &PROBE_PAYLOAD { + if udp_payload != PROBE_PAYLOAD { let udp_payload: String = udp_payload .iter() .copied() @@ -581,163 +349,40 @@ fn parse_icmp_time_exceeded_raw(bytes: &[u8]) -> eyre::Result<()> { } } -match_cfg! { - #[cfg(any(target_os = "windows", target_os = "android"))] => { - fn bind_socket_to_interface(socket: &Socket, interface: &str) -> eyre::Result<()> { - use crate::util::get_interface_ip; - use std::net::SocketAddr; - - let interface_ip = get_interface_ip(interface)?; - - log::info!("Binding socket to {interface_ip} ({interface:?})"); - - socket.bind(&SocketAddr::new(interface_ip, 0).into()) - .wrap_err("Failed to bind socket to interface address")?; +fn parse_icmp_echo(ip_packet: &Ipv4Packet<'_>) -> eyre::Result<()> { + let ip_protocol = ip_packet.get_next_level_protocol(); - return Ok(()); - } + match ip_protocol { + IpProtocol::Icmp => parse_icmp_echo_raw(ip_packet.payload()), + _ => bail!("Not UDP/ICMP"), } - #[cfg(target_os = "linux")] => { - fn bind_socket_to_interface(socket: &Socket, interface: &str) -> eyre::Result<()> { - log::info!("Binding socket to {interface:?}"); - - socket - .bind_device(Some(interface.as_bytes())) - .wrap_err("Failed to bind socket to interface")?; +} - Ok(()) - } +fn parse_icmp_echo_raw(icmp_bytes: &[u8]) -> eyre::Result<()> { + let echo_packet = EchoRequestPacket::new(icmp_bytes).ok_or_else(too_small)?; + + ensure!( + echo_packet.get_icmp_type() == IcmpTypes::EchoRequest, + "Not ICMP/EchoRequest" + ); + + // check if payload looks right + // some network nodes will strip the payload. + // some network nodes will add a bunch of zeros at the end. + let echo_payload = echo_packet.payload(); + if !echo_payload.is_empty() && !echo_payload.starts_with(&PROBE_PAYLOAD) { + let echo_payload: String = echo_payload + .iter() + .copied() + .flat_map(escape_default) + .map(char::from) + .collect(); + bail!("Wrong ICMP/Echo payload: {echo_payload:?}"); } - #[cfg(target_os = "macos")] => { - fn bind_socket_to_interface(socket: &Socket, interface: &str) -> eyre::Result<()> { - use nix::net::if_::if_nametoindex; - use std::num::NonZero; - - log::info!("Binding socket to {interface:?}"); - - let interface_index = if_nametoindex(interface) - .map_err(eyre::Report::from) - .and_then(|code| NonZero::new(code).ok_or_eyre("Non-zero error code")) - .wrap_err("Failed to get interface index")?; - socket.bind_device_by_index_v4(Some(interface_index))?; - Ok(()) - } - } + Ok(()) } -// OLD ICMP SEND CODE -// -// use talpid_windows::net::{get_ip_address_for_interface, luid_from_alias, AddressFamily}; -// let interface_luid = luid_from_alias(INTERFACE)?; -// let IpAddr::V4(interface_ip) = -// get_ip_address_for_interface(AddressFamily::Ipv4, interface_luid)? -// .ok_or(eyre!("No IP for interface {INTERFACE:?}"))? -// else { -// panic!() -// }; -// -// for ttl in 1..=5 { -// let mut packet = Packet { -// ip: Ipv4Header { -// version_and_ihl: 0x45, -// dscp_and_ecn: 0, // should be fine -// total_length: (size_of::() as u16).to_be_bytes(), -// _stuff: Default::default(), // should be fine -// ttl, -// protocol: 1, // icmp -// header_checksum: Default::default(), -// source_address: interface_ip.octets(), -// destination_address: destination.octets(), -// }, -// icmp: Icmpv4Header { -// icmp_type: 8, // echo -// code: 0, -// checksum: Default::default(), -// }, -// }; -// let icmp = Icmpv4Header { -// icmp_type: 8, // echo -// code: 0, -// checksum: Default::default(), -// }; -// -// packet.ip.header_checksum = checksum(packet.ip.as_bytes()); -// let mut packet = Icmpv4Packet { -// header: icmp, -// payload: Icmpv4EchoPayload { -// identifier: 0u16.to_be_bytes(), -// sequence_number: (ttl as u16).to_be_bytes(), -// data: [0x77; 32], -// }, -// }; -// -// packet.header.checksum = checksum(packet.as_bytes()); -// -// let packet = packet; -// -// listen_socket.set_ttl(ttl).wrap_err("Failed to set TTL")?; -// listen_socket -// .send_to( -// packet.as_bytes(), -// &SocketAddrV4::new(destination, 0u16).into(), -// ) -// .wrap_err("Failed to send on raw socket")?; -// } - -// use talpid_windows::net::{get_ip_address_for_interface, luid_from_alias, AddressFamily}; -// let interface_luid = luid_from_alias(INTERFACE)?; -// let IpAddr::V4(interface_ip) = -// get_ip_address_for_interface(AddressFamily::Ipv4, interface_luid)? -// .ok_or(eyre!("No IP for interface {INTERFACE:?}"))? -// else { -// panic!() -// }; -// -// for ttl in 1..=5 { -// let mut packet = Packet { -// ip: Ipv4Header { -// version_and_ihl: 0x45, -// dscp_and_ecn: 0, // should be fine -// total_length: (size_of::() as u16).to_be_bytes(), -// _stuff: Default::default(), // should be fine -// ttl, -// protocol: 1, // icmp -// header_checksum: Default::default(), -// source_address: interface_ip.octets(), -// destination_address: destination.octets(), -// }, -// icmp: Icmpv4Header { -// icmp_type: 8, // echo -// code: 0, -// checksum: Default::default(), -// }, -// }; -// let icmp = Icmpv4Header { -// icmp_type: 8, // echo -// code: 0, -// checksum: Default::default(), -// }; -// -// packet.ip.header_checksum = checksum(packet.ip.as_bytes()); -// let mut packet = Icmpv4Packet { -// header: icmp, -// payload: Icmpv4EchoPayload { -// identifier: 0u16.to_be_bytes(), -// sequence_number: (ttl as u16).to_be_bytes(), -// data: [0x77; 32], -// }, -// }; -// -// packet.header.checksum = checksum(packet.as_bytes()); -// -// let packet = packet; -// -// listen_socket.set_ttl(ttl).wrap_err("Failed to set TTL")?; -// listen_socket -// .send_to( -// packet.as_bytes(), -// &SocketAddrV4::new(destination, 0u16).into(), -// ) -// .wrap_err("Failed to send on raw socket")?; -// } +fn too_small() -> eyre::Report { + eyre!("Too small") +} diff --git a/leak-checker/src/traceroute/platform/android.rs b/leak-checker/src/traceroute/platform/android.rs new file mode 100644 index 000000000000..ac02c589beac --- /dev/null +++ b/leak-checker/src/traceroute/platform/android.rs @@ -0,0 +1,27 @@ +use std::net::IpAddr; + +use socket2::Socket; + +use crate::traceroute::TracerouteOpt; + +use super::{linux, linux::TracerouteLinux, unix, Traceroute}; + +pub struct TracerouteAndroid; + +impl Traceroute for TracerouteAndroid { + type AsyncIcmpSocket = linux::AsyncIcmpSocketImpl; + type AsyncUdpSocket = unix::AsyncUdpSocketUnix; + + fn bind_socket_to_interface(socket: &Socket, interface: &str) -> eyre::Result<()> { + // can't use the same method as desktop-linux here beacuse reasons + super::common::bind_socket_to_interface(socket, interface) + } + + fn get_interface_ip(interface: &str) -> eyre::Result { + super::unix::get_interface_ip(interface) + } + + fn configure_icmp_socket(socket: &socket2::Socket, opt: &TracerouteOpt) -> eyre::Result<()> { + TracerouteLinux::configure_icmp_socket(socket, opt) + } +} diff --git a/leak-checker/src/traceroute/platform/common.rs b/leak-checker/src/traceroute/platform/common.rs new file mode 100644 index 000000000000..9c7a5c2f18c0 --- /dev/null +++ b/leak-checker/src/traceroute/platform/common.rs @@ -0,0 +1,102 @@ +#![allow(dead_code)] // some code here is not used on some targets. + +use std::{ + future::pending, + net::{IpAddr, SocketAddr}, +}; + +use eyre::{eyre, Context}; +use socket2::Socket; +use tokio::{ + select, + time::{sleep_until, Instant}, +}; + +use crate::{ + traceroute::{parse_icmp_time_exceeded, parse_ipv4, RECV_TIMEOUT}, + LeakInfo, LeakStatus, +}; + +use super::{AsyncIcmpSocket, Impl, Traceroute}; + +pub fn bind_socket_to_interface(socket: &Socket, interface: &str) -> eyre::Result<()> { + let interface_ip = Impl::get_interface_ip(interface)?; + + log::info!("Binding socket to {interface_ip} ({interface:?})"); + + socket + .bind(&SocketAddr::new(interface_ip, 0).into()) + .wrap_err("Failed to bind socket to interface address")?; + + Ok(()) +} + +pub async fn recv_ttl_responses( + socket: &impl AsyncIcmpSocket, + interface: &str, +) -> eyre::Result { + // the list of node IP addresses from which we received a response to our probe packets. + let mut reachable_nodes = vec![]; + + // A time at which this function should exit. This is set when we receive the first probe + // response, and allows us to wait a while to collect any additional probe responses before + // returning. + let mut timeout_at = None; + + let mut read_buf = vec![0u8; usize::from(u16::MAX)].into_boxed_slice(); + loop { + let timer = async { + match timeout_at { + // resolve future at the timeout, if it's set + Some(time) => sleep_until(time).await, + + // otherwise, never resolve + None => pending().await, + } + }; + + log::debug!("Reading from ICMP socket"); + + // let n = socket + // .recv(unsafe { &mut *(&mut read_buf[..] as *mut [u8] as *mut [MaybeUninit]) }) + // .wrap_err("Failed to read from raw socket")?; + + let (n, source) = select! { + result = socket.recv_from(&mut read_buf[..]) => result + .wrap_err("Failed to read from raw socket")?, + + _timeout = timer => { + return Ok(LeakStatus::LeakDetected(LeakInfo::NodeReachableOnInterface { + reachable_nodes, + interface: interface.to_string(), + })); + } + }; + + let packet = &read_buf[..n]; + let result = parse_ipv4(packet) + .map_err(|e| eyre!("Ignoring packet: (len={n}, ip.src={source}) {e} ({packet:02x?})")) + .and_then(|ip_packet| { + parse_icmp_time_exceeded(&ip_packet).map_err(|e| { + eyre!( + "Ignoring packet (len={n}, ip.src={source}, ip.dest={}): {e}", + ip_packet.get_destination(), + ) + }) + }); + + match result { + Ok(ip) => { + log::debug!("Got a probe response, we are leaking!"); + timeout_at.get_or_insert_with(|| Instant::now() + RECV_TIMEOUT); + let ip = IpAddr::from(ip); + if !reachable_nodes.contains(&ip) { + reachable_nodes.push(ip); + } + } + + // an error means the packet wasn't the ICMP/TimeExceeded we're listening for. + Err(e) => log::debug!("{e}"), + } + } +} diff --git a/leak-checker/src/traceroute/platform/linux.rs b/leak-checker/src/traceroute/platform/linux.rs new file mode 100644 index 000000000000..3dd07ab00681 --- /dev/null +++ b/leak-checker/src/traceroute/platform/linux.rs @@ -0,0 +1,215 @@ +use std::io::{self, IoSliceMut}; +use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd}; +use std::{net::IpAddr, time::Duration}; + +use eyre::{bail, WrapErr}; +use nix::errno::Errno; +use nix::sys::socket::sockopt::Ipv4RecvErr; +use nix::sys::socket::{setsockopt, ControlMessageOwned, MsgFlags, SockaddrIn}; +use nix::{cmsg_space, libc}; +use pnet_packet::icmp::time_exceeded::IcmpCodes; +use pnet_packet::icmp::IcmpTypes; +use pnet_packet::icmp::{IcmpCode, IcmpType}; +use socket2::Socket; +use tokio::time::{sleep, Instant}; + +use crate::traceroute::{parse_icmp_echo_raw, TracerouteOpt, RECV_TIMEOUT}; +use crate::{LeakInfo, LeakStatus}; + +use super::{unix, AsyncIcmpSocket, Traceroute}; + +pub struct TracerouteLinux; + +pub struct AsyncIcmpSocketImpl(tokio::net::UdpSocket); + +impl Traceroute for TracerouteLinux { + type AsyncIcmpSocket = AsyncIcmpSocketImpl; + type AsyncUdpSocket = unix::AsyncUdpSocketUnix; + + fn bind_socket_to_interface(socket: &Socket, interface: &str) -> eyre::Result<()> { + bind_socket_to_interface(socket, interface) + } + + fn get_interface_ip(interface: &str) -> eyre::Result { + super::unix::get_interface_ip(interface) + } + + fn configure_icmp_socket(socket: &socket2::Socket, _opt: &TracerouteOpt) -> eyre::Result<()> { + // IP_RECVERR tells Linux to pass any error packets received over ICMP to us through `recvmsg` control messages. + setsockopt(socket, Ipv4RecvErr, &true).wrap_err("Failed to set IP_RECVERR") + } +} + +impl AsyncIcmpSocket for AsyncIcmpSocketImpl { + fn from_socket2(socket: Socket) -> Self { + let raw_socket = socket.into_raw_fd(); + let std_socket = unsafe { std::net::UdpSocket::from_raw_fd(raw_socket) }; + let tokio_socket = tokio::net::UdpSocket::from_std(std_socket).unwrap(); + AsyncIcmpSocketImpl(tokio_socket) + } + + fn set_ttl(&self, ttl: u32) -> eyre::Result<()> { + self.0 + .set_ttl(ttl) + .wrap_err("Failed to set TTL value for socket") + } + + async fn send_to(&self, packet: &[u8], destination: impl Into) -> io::Result { + self.0.send_to(packet, (destination.into(), 0)).await + } + + async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, IpAddr)> { + self.0 + .recv_from(buf) + .await + .map(|(n, source)| (n, source.ip())) + } + + async fn recv_ttl_responses(&self, opt: &TracerouteOpt) -> eyre::Result { + recv_ttl_responses(opt.destination, &opt.interface, &self.0).await + } +} + +fn bind_socket_to_interface(socket: &Socket, interface: &str) -> eyre::Result<()> { + log::info!("Binding socket to {interface:?}"); + + socket + .bind_device(Some(interface.as_bytes())) + .wrap_err("Failed to bind socket to interface")?; + + Ok(()) +} + +/// Try to read ICMP/TimeExceeded error packets from an ICMP socket. +/// +/// This method does not require root, but only works on Linux (including Android). +// TODO: double check if this works on MacOS +async fn recv_ttl_responses( + destination: IpAddr, + interface: &str, + socket: &impl AsRawFd, +) -> eyre::Result { + // the list of node IP addresses from which we received a response to our probe packets. + let mut reachable_nodes = vec![]; + + // A time at which this function should exit. This is set when we receive the first probe + // response, and allows us to wait a while to collect any additional probe responses before + // returning. + let mut timeout_at = None; + + // Allocate buffer for receiving packets. + let mut recv_buf = vec![0u8; usize::from(u16::MAX)].into_boxed_slice(); + let mut io_vec = [IoSliceMut::new(&mut recv_buf)]; + + // Allocate space for EHOSTUNREACH errors caused by ICMP/TimeExceeded packets. + // This is the size of ControlMessageOwned::Ipv4RecvErr(sock_extended_err, sockaddr_in). + // FIXME: sockaddr_in only works for ipv4 + let mut control_buf = cmsg_space!(libc::sock_extended_err, libc::sockaddr_in); + + 'outer: loop { + log::debug!("Reading from ICMP socket"); + + let recv = loop { + if let Some(timeout_at) = timeout_at { + if Instant::now() >= timeout_at { + break 'outer; + } + } + + match nix::sys::socket::recvmsg::( + socket.as_raw_fd(), + &mut io_vec, + Some(&mut control_buf), + // NOTE: MSG_ERRQUEUE asks linux to tell us if we get any ICMP error replies to + // our Echo packets. + MsgFlags::MSG_ERRQUEUE, + ) { + Ok(recv) => break recv, + + // poor-mans async IO :'( + Err(Errno::EWOULDBLOCK) => { + sleep(Duration::from_millis(10)).await; + continue; + } + + Err(e) => bail!("Faileed to read from socket {e}"), + }; + }; + + // NOTE: This should be the IP destination of our ping packets. That does NOT mean the + // packets reached the destination. Instead, if we see an EHOSTUNREACH control message, + // it means the packets was instead dropped along the way. Seeing this address helps us + // identify that this is a response to the ping we sent. + // // FIXME: sockaddr_in only works for ipv4 + let source: SockaddrIn = recv.address.unwrap(); + let source = source.ip(); + debug_assert_eq!(source, destination); + + let mut control_messages = recv + .cmsgs() + .wrap_err("Failed to decode cmsgs from recvmsg")?; + + let error_source = match control_messages.next() { + Some(ControlMessageOwned::Ipv6RecvErr(_socket_error, _source_addr)) => { + bail!("IPv6 not implemented"); + } + Some(ControlMessageOwned::Ipv4RecvErr(socket_error, source_addr)) => { + let libc::sock_extended_err { + ee_errno, // Error Number: Should be EHOSTUNREACH + ee_origin, // Error Origin: 2 = Icmp, 3 = Icmp6. + ee_type, // ICMP Type: 11 = ICMP/TimeExceeded. + ee_code, // ICMP Code. 0 = TTL exceeded in transit. + ee_pad: _, // padding + ee_info: _, // N/A + ee_data: _, // N/A + } = socket_error; + + let errno = Errno::from_raw(ee_errno as i32); + debug_assert_eq!(errno, Errno::EHOSTUNREACH); + debug_assert_eq!(ee_origin, nix::libc::SO_EE_ORIGIN_ICMP); // TODO: or SO_EE_ORIGIN_ICMP6 + + // TODO: Icmp6Types + let icmp_type = IcmpType::new(ee_type); + debug_assert_eq!(icmp_type, IcmpTypes::TimeExceeded); + + let icmp_code = IcmpCode::new(ee_code); + debug_assert_eq!(icmp_code, IcmpCodes::TimeToLiveExceededInTransit); + + // NOTE: This is the IP of the node that dropped the packet due to TTL exceeded. + let error_source = SockaddrIn::from(source_addr.unwrap()); + log::debug!("addr: {error_source}"); + + error_source + } + Some(other_message) => { + // TODO: We might want to not error in this case, and just ignore the cmsg. + // If so, we should loop over the iterator instead of taking the first elem. + bail!("Unhandled control message: {other_message:?}"); + } + None => { + // We're looking for EHOSTUNREACH errors. No errors means skip. + log::debug!("Skipping recvmsg that produced no control messages."); + continue; + } + }; + + let packet = recv.iovs().next().unwrap(); + + // Ensure that this is the original Echo packet that we sent. + // TODO: skip on error + parse_icmp_echo_raw(packet).wrap_err("")?; + + log::debug!("Got a probe response, we are leaking!"); + timeout_at.get_or_insert_with(|| Instant::now() + RECV_TIMEOUT); + reachable_nodes.push(IpAddr::from(error_source.ip())); + } + + debug_assert!(!reachable_nodes.is_empty()); + + Ok(LeakStatus::LeakDetected( + LeakInfo::NodeReachableOnInterface { + reachable_nodes, + interface: interface.to_string(), + }, + )) +} diff --git a/leak-checker/src/traceroute/platform/macos.rs b/leak-checker/src/traceroute/platform/macos.rs new file mode 100644 index 000000000000..520c3681f32d --- /dev/null +++ b/leak-checker/src/traceroute/platform/macos.rs @@ -0,0 +1,79 @@ +use std::io; +use std::net::IpAddr; +use std::os::fd::{FromRawFd, IntoRawFd}; + +use eyre::{OptionExt, WrapErr}; +use socket2::Socket; + +use crate::traceroute::TracerouteOpt; +use crate::LeakStatus; + +use super::{common, unix, AsyncIcmpSocket, Traceroute}; + +pub struct TracerouteMacos; + +pub struct AsyncIcmpSocketImpl(tokio::net::UdpSocket); + +impl Traceroute for TracerouteMacos { + type AsyncIcmpSocket = AsyncIcmpSocketImpl; + type AsyncUdpSocket = unix::AsyncUdpSocketUnix; + + fn bind_socket_to_interface(socket: &Socket, interface: &str) -> eyre::Result<()> { + // can't use the same method as desktop-linux here beacuse reasons + bind_socket_to_interface(socket, interface) + } + + fn get_interface_ip(interface: &str) -> eyre::Result { + super::unix::get_interface_ip(interface) + } + + fn configure_icmp_socket(_socket: &socket2::Socket, _opt: &TracerouteOpt) -> eyre::Result<()> { + Ok(()) + // TODO: not sure if we need to do anything here + } +} + +impl AsyncIcmpSocket for AsyncIcmpSocketImpl { + fn from_socket2(socket: Socket) -> Self { + let raw_socket = socket.into_raw_fd(); + let std_socket = unsafe { std::net::UdpSocket::from_raw_fd(raw_socket) }; + let tokio_socket = tokio::net::UdpSocket::from_std(std_socket).unwrap(); + AsyncIcmpSocketImpl(tokio_socket) + } + + fn set_ttl(&self, ttl: u32) -> eyre::Result<()> { + self.0 + .set_ttl(ttl) + .wrap_err("Failed to set TTL value for socket") + } + + async fn send_to(&self, packet: &[u8], destination: impl Into) -> io::Result { + self.0.send_to(packet, (destination.into(), 0)).await + } + + async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, IpAddr)> { + self.0 + .recv_from(buf) + .await + .map(|(n, source)| (n, source.ip())) + } + + async fn recv_ttl_responses(&self, opt: &TracerouteOpt) -> eyre::Result { + common::recv_ttl_responses(self, &opt.interface).await + } +} + +pub fn bind_socket_to_interface(socket: &Socket, interface: &str) -> eyre::Result<()> { + use nix::net::if_::if_nametoindex; + use std::num::NonZero; + + log::info!("Binding socket to {interface:?}"); + + let interface_index = if_nametoindex(interface) + .map_err(eyre::Report::from) + .and_then(|code| NonZero::new(code).ok_or_eyre("Non-zero error code")) + .wrap_err("Failed to get interface index")?; + + socket.bind_device_by_index_v4(Some(interface_index))?; + Ok(()) +} diff --git a/leak-checker/src/traceroute/platform/mod.rs b/leak-checker/src/traceroute/platform/mod.rs new file mode 100644 index 000000000000..1812726cf15a --- /dev/null +++ b/leak-checker/src/traceroute/platform/mod.rs @@ -0,0 +1,85 @@ +use std::{ + io, + net::{IpAddr, SocketAddr}, +}; + +use crate::LeakStatus; + +use super::TracerouteOpt; + +#[cfg(any(target_os = "linux", target_os = "android"))] +pub mod android; + +#[cfg(any(target_os = "linux", target_os = "android"))] +pub mod linux; + +#[cfg(target_os = "macos")] +pub mod macos; + +#[cfg(target_os = "windows")] +pub mod windows; + +/// Implementations that are applicable to all unix platforms. +#[cfg(unix)] +pub mod unix; + +/// Implementations that are applicable to all platforms. +pub mod common; + +/// Private trait that let's us define the platform-specific implementations and types required for +/// tracerouting. +pub trait Traceroute { + type AsyncIcmpSocket: AsyncIcmpSocket; + type AsyncUdpSocket: AsyncUdpSocket; + + fn get_interface_ip(interface: &str) -> eyre::Result; + + fn bind_socket_to_interface(socket: &socket2::Socket, interface: &str) -> eyre::Result<()>; + + /// Configure an ICMP socket to allow reception of ICMP/TimeExceeded errors. + // TODO: consider moving into AsyncIcmpSocket constructor + fn configure_icmp_socket(socket: &socket2::Socket, opt: &TracerouteOpt) -> eyre::Result<()>; +} + +pub trait AsyncIcmpSocket { + fn from_socket2(socket: socket2::Socket) -> Self; + + fn set_ttl(&self, ttl: u32) -> eyre::Result<()>; + + /// Send an ICMP packet to the destination. + // TODO: eyre? + async fn send_to(&self, packet: &[u8], destination: impl Into) -> io::Result; + + /// Receive an ICMP packet + async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, IpAddr)>; + + /// Try to read ICMP/TimeExceeded error packets. + // TODO: this should be renamed, or not return a LeakStatus + async fn recv_ttl_responses(&self, opt: &TracerouteOpt) -> eyre::Result; +} + +pub trait AsyncUdpSocket { + fn from_socket2(socket: socket2::Socket) -> Self; + + fn set_ttl(&self, ttl: u32) -> eyre::Result<()>; + + /// Send an UDP packet to the destination. + // TODO: eyre? + async fn send_to(&self, packet: &[u8], destination: impl Into) + -> io::Result; +} + +#[cfg(target_os = "android")] +pub type Impl = platform::android::TracerouteAndroid; + +#[cfg(target_os = "linux")] +pub type Impl = linux::TracerouteLinux; + +#[cfg(target_os = "macos")] +pub type Impl = macos::TracerouteMacos; + +#[cfg(target_os = "windows")] +pub type Impl = windows::TracerouteWindows; + +pub type AsyncIcmpSocketImpl = ::AsyncIcmpSocket; +pub type AsyncUdpSocketImpl = ::AsyncUdpSocket; diff --git a/leak-checker/src/traceroute/platform/unix.rs b/leak-checker/src/traceroute/platform/unix.rs new file mode 100644 index 000000000000..b63d8745fa99 --- /dev/null +++ b/leak-checker/src/traceroute/platform/unix.rs @@ -0,0 +1,53 @@ +use std::net::{IpAddr, SocketAddr}; +use std::os::fd::{FromRawFd, IntoRawFd}; + +use eyre::Context; + +use super::AsyncUdpSocket; + +pub fn get_interface_ip(interface: &str) -> eyre::Result { + for interface_address in nix::ifaddrs::getifaddrs()? { + if interface_address.interface_name != interface { + continue; + }; + let Some(address) = interface_address.address else { + continue; + }; + let Some(address) = address.as_sockaddr_in() else { + continue; + }; + + // TODO: ipv6 + //let Some(address) = address.as_sockaddr_in6() else { continue }; + + return Ok(address.ip().into()); + } + + eyre::bail!("Interface {interface:?} has no valid IP to bind to"); +} + +pub struct AsyncUdpSocketUnix(tokio::net::UdpSocket); + +impl AsyncUdpSocket for AsyncUdpSocketUnix { + fn from_socket2(socket: socket2::Socket) -> Self { + // HACK: Wrap the socket in a tokio::net::UdpSocket to be able to use it async + // SAFETY: `into_raw_fd()` consumes the socket and returns an owned & open file descriptor. + let udp_socket = unsafe { std::net::UdpSocket::from_raw_fd(socket.into_raw_fd()) }; + let udp_socket = tokio::net::UdpSocket::from_std(udp_socket).unwrap(); + AsyncUdpSocketUnix(udp_socket) + } + + fn set_ttl(&self, ttl: u32) -> eyre::Result<()> { + self.0 + .set_ttl(ttl) + .wrap_err("Failed to set TTL value for UDP socket") + } + + async fn send_to( + &self, + packet: &[u8], + destination: impl Into, + ) -> std::io::Result { + self.0.send_to(packet, destination.into()).await + } +} diff --git a/leak-checker/src/traceroute/platform/windows.rs b/leak-checker/src/traceroute/platform/windows.rs new file mode 100644 index 000000000000..adcb3b7d7f95 --- /dev/null +++ b/leak-checker/src/traceroute/platform/windows.rs @@ -0,0 +1,132 @@ +use std::{ + ffi::c_void, + io, mem, + net::{IpAddr, SocketAddr}, + os::windows::io::{AsRawSocket, AsSocket, FromRawSocket, IntoRawSocket}, + ptr::null_mut, +}; + +use eyre::{bail, eyre, Context}; +use socket2::Socket; +use talpid_windows::net::{get_ip_address_for_interface, luid_from_alias, AddressFamily}; + +use windows_sys::Win32::Networking::WinSock::{ + WSAGetLastError, WSAIoctl, SIO_RCVALL, SOCKET, SOCKET_ERROR, +}; + +use crate::{traceroute::TracerouteOpt, LeakStatus}; + +use super::{common, AsyncIcmpSocket, AsyncUdpSocket, Traceroute}; + +pub struct TracerouteWindows; + +pub struct AsyncIcmpSocketImpl(tokio::net::UdpSocket); + +pub struct AsyncUdpSocketWindows(tokio::net::UdpSocket); + +impl Traceroute for TracerouteWindows { + type AsyncIcmpSocket = AsyncIcmpSocketImpl; + type AsyncUdpSocket = AsyncUdpSocketWindows; + + fn bind_socket_to_interface(socket: &Socket, interface: &str) -> eyre::Result<()> { + common::bind_socket_to_interface(socket, interface) + } + + fn get_interface_ip(interface: &str) -> eyre::Result { + get_interface_ip(interface) + } + + fn configure_icmp_socket(socket: &socket2::Socket, _opt: &TracerouteOpt) -> eyre::Result<()> { + configure_icmp_socket(socket) + } +} + +impl AsyncIcmpSocket for AsyncIcmpSocketImpl { + fn from_socket2(socket: Socket) -> Self { + let raw_socket = socket.as_socket().as_raw_socket(); + mem::forget(socket); + let std_socket = unsafe { std::net::UdpSocket::from_raw_socket(raw_socket) }; + let tokio_socket = tokio::net::UdpSocket::from_std(std_socket).unwrap(); + AsyncIcmpSocketImpl(tokio_socket) + } + + fn set_ttl(&self, ttl: u32) -> eyre::Result<()> { + self.0 + .set_ttl(ttl) + .wrap_err("Failed to set TTL value for ICMP socket") + } + + async fn send_to(&self, packet: &[u8], destination: impl Into) -> io::Result { + self.0.send_to(packet, (destination.into(), 0)).await + } + + async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, std::net::IpAddr)> { + let (n, source) = self.0.recv_from(buf).await?; + Ok((n, source.ip())) + } + + async fn recv_ttl_responses(&self, opt: &TracerouteOpt) -> eyre::Result { + common::recv_ttl_responses(self, &opt.interface).await + } +} + +impl AsyncUdpSocket for AsyncUdpSocketWindows { + fn from_socket2(socket: socket2::Socket) -> Self { + // HACK: Wrap the socket in a tokio::net::UdpSocket to be able to use it async + let udp_socket = unsafe { std::net::UdpSocket::from_raw_socket(socket.into_raw_socket()) }; + let udp_socket = tokio::net::UdpSocket::from_std(udp_socket).unwrap(); + AsyncUdpSocketWindows(udp_socket) + } + + fn set_ttl(&self, ttl: u32) -> eyre::Result<()> { + self.0 + .set_ttl(ttl) + .wrap_err("Failed to set TTL value for UDP socket") + } + + async fn send_to( + &self, + packet: &[u8], + destination: impl Into, + ) -> std::io::Result { + self.0.send_to(packet, destination.into()).await + } +} + +pub fn get_interface_ip(interface: &str) -> eyre::Result { + let interface_luid = luid_from_alias(interface)?; + + // TODO: ipv6 + let interface_ip = get_ip_address_for_interface(AddressFamily::Ipv4, interface_luid)? + .ok_or(eyre!("No IP for interface {interface:?}"))?; + + Ok(interface_ip) +} + +/// Configure the raw socket we use for listening to ICMP responses. +/// +/// This will set the `SIO_RCVALL`-option. +pub fn configure_icmp_socket(socket: &Socket) -> eyre::Result<()> { + let j = 1; + let mut _in: u32 = 0; + let result = unsafe { + WSAIoctl( + socket.as_raw_socket() as SOCKET, + SIO_RCVALL, + &j as *const _ as *const c_void, + size_of_val(&j) as u32, + null_mut(), + 0, + &mut _in as *mut u32, + null_mut(), + None, + ) + }; + + if result == SOCKET_ERROR { + let code = unsafe { WSAGetLastError() }; + bail!("Failed to call WSAIoctl(listen_socket, SIO_RCVALL, ...), code = {code}"); + } + + Ok(()) +} diff --git a/leak-checker/src/util.rs b/leak-checker/src/util.rs index a7a61febf31b..841a437e41b9 100644 --- a/leak-checker/src/util.rs +++ b/leak-checker/src/util.rs @@ -1,3 +1,5 @@ +// TODO: Remove this file + use match_cfg::match_cfg; #[cfg(any(target_os = "windows", target_os = "macos", target_os = "android"))] @@ -6,6 +8,8 @@ use std::net::IpAddr; match_cfg! { #[cfg(target_os = "windows")] => { pub fn get_interface_ip(interface: &str) -> eyre::Result { + use eyre::eyre; + use talpid_windows::net::{get_ip_address_for_interface, luid_from_alias, AddressFamily}; let interface_luid = luid_from_alias(interface)?; diff --git a/mullvad-daemon/Cargo.toml b/mullvad-daemon/Cargo.toml index d0986fa70f71..19196ad49e02 100644 --- a/mullvad-daemon/Cargo.toml +++ b/mullvad-daemon/Cargo.toml @@ -15,6 +15,8 @@ workspace = true api-override = ["mullvad-api/api-override"] [dependencies] +anyhow = "*" # TODO: do we want this? +surge-ping = "0.8.0" # TODO: workspace dep? chrono = { workspace = true } thiserror = { workspace = true } either = "1.11" @@ -27,6 +29,7 @@ serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } tokio = { workspace = true, features = ["fs", "io-util", "rt-multi-thread", "sync", "time"] } tokio-stream = "0.1" +socket2 = { workspace = true } mullvad-relay-selector = { path = "../mullvad-relay-selector" } mullvad-types = { path = "../mullvad-types" } @@ -41,6 +44,8 @@ talpid-platform-metadata = { path = "../talpid-platform-metadata" } talpid-time = { path = "../talpid-time" } talpid-types = { path = "../talpid-types" } +leak-checker = { path = "../leak-checker" } + clap = { workspace = true } log-panics = "2.0.0" mullvad-management-interface = { path = "../mullvad-management-interface" } diff --git a/mullvad-daemon/src/leak_checker/mod.rs b/mullvad-daemon/src/leak_checker/mod.rs index e3cd57d194a6..dc0824214f89 100644 --- a/mullvad-daemon/src/leak_checker/mod.rs +++ b/mullvad-daemon/src/leak_checker/mod.rs @@ -1,26 +1,234 @@ -pub fn check_for_leaks() { - // TODO: When do we run this? - // After connecting? - // Periodically? - // Whenever something changes? (interface, connection state, dns server, etc) - // All of the above? +use anyhow::anyhow; +use leak_checker::traceroute::TracerouteOpt; +pub use leak_checker::LeakInfo; +use std::net::IpAddr; +use std::ops::ControlFlow; +use talpid_types::tunnel::TunnelStateTransition; +use tokio::runtime::Handle; +use tokio::sync::mpsc; - // TODO: Figure out which interface(s) to bind to +/// An actor that tries to leak traffic outside the tunnel while we are connected. +pub struct LeakChecker { + task_event_tx: mpsc::UnboundedSender, +} + +/// [LeakChecker] internal task state. +struct Task { + events_rx: mpsc::UnboundedReceiver, + callbacks: Vec>, +} + +enum TaskEvent { + NewTunnelState(TunnelStateTransition), + AddCallback(Box), +} + +pub enum CallbackResult { + /// Callback completed successfully + Ok, + + /// Callback is no longer valid and should be dropped. + Drop, +} + +pub trait LeakCheckerCallback: Send + 'static { + fn on_leak(&mut self, info: LeakInfo) -> CallbackResult; +} + +impl LeakChecker { + pub fn new() -> Self { + let (task_event_tx, events_rx) = mpsc::unbounded_channel(); + + let task = Task { + events_rx, + callbacks: vec![], + }; + + // TODO: fix task.run() not being Send + //tokio::task::spawn(task.run()); + tokio::task::spawn_blocking(|| Handle::current().block_on(task.run())); + + LeakChecker { task_event_tx } + } + + /// Call when we transition to a new tunnel state. + pub fn on_tunnel_state_transition(&mut self, tunnel_state: TunnelStateTransition) { + self.send(TaskEvent::NewTunnelState(tunnel_state)) + } - // TODO: get connection check config - // http get https://am.i.mullvad.net/config + pub fn add_leak_callback(&mut self, callback: impl LeakCheckerCallback) { + self.send(TaskEvent::AddCallback(Box::new(callback))) + } - // TODO: For each interface: + fn send(&mut self, event: TaskEvent) { + if self.task_event_tx.send(event).is_err() { + panic!("LeakChecker unexpectedly closed"); + } + } - // TODO: send an ICMP ping (to the relay?) - // TODO: how to see if the pings are actually going outside the tunnel? + ///// Wait until the leak detector detects a leak. + ///// + ///// Ideally, this should never return. + //pub async fn wait_for_leak(&self) -> LeakInfo { + // self.leak_rx + // .recv() + // .await + // .expect("LeakChecker unexpectedly closed") + //} +} + +impl Task { + async fn run(mut self) { + loop { + let Some(event) = self.events_rx.recv().await else { + break; // All LeakChecker handles dropped. + }; + + match event { + TaskEvent::NewTunnelState(s) => { + if self.on_new_tunnel_state(s).await.is_break() { + break; + } + } + TaskEvent::AddCallback(c) => self.on_add_callback(c), + } + } + } + + fn on_add_callback(&mut self, c: Box) { + self.callbacks.push(c); + } + + async fn on_new_tunnel_state( + &mut self, + mut tunnel_state: TunnelStateTransition, + ) -> ControlFlow<()> { + 'leak_test: loop { + //let TunnelStateTransition::Connected(tunnel) = &tunnel_state else { + let TunnelStateTransition::Connected(tunnel) = &tunnel_state else { + return ControlFlow::Continue(()); + }; + + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + + let ping_destination = tunnel.endpoint.address.ip(); + //let ping_destination = Ipv4Addr::new(185, 213, 154, 218); + + let interface = "wlan0"; // TODO - // TODO: send a DNS request to leak check endpoint - // TODO: will the service be able to handle all of the mullvad users constantly doing leak - // checks + let leak_info = match check_for_leaks(interface, ping_destination).await { + Ok(Some(leak_info)) => leak_info, + Ok(None) => { + log::debug!("No leak detected"); + continue; + } + Err(e) => { + log::debug!("Leak check errored: {e:#?}"); + return ControlFlow::Continue(()); + } + }; + + log::debug!("leak detected: {leak_info:?}"); + + // Make sure the tunnel state didn't change while we were doing the leak test. + // If that happened, then our results might be invalid. + while let Ok(event) = self.events_rx.try_recv() { + let new_state = match event { + TaskEvent::NewTunnelState(tunnel_state) => tunnel_state, + TaskEvent::AddCallback(c) => { + self.on_add_callback(c); + continue; + } + }; + + if let TunnelStateTransition::Connected(..) = new_state { + // Still connected, all is well... + } else { + // Tunnel state changed! We have to discard the leak test and try again. + tunnel_state = new_state; + continue 'leak_test; + } + } + + for callback in &mut self.callbacks { + callback.on_leak(leak_info.clone()); + } + return ControlFlow::Continue(()); + } + } +} + +async fn check_for_leaks(interface: &str, destination: IpAddr) -> anyhow::Result> { + leak_checker::traceroute::try_run_leak_test(&TracerouteOpt { + interface: interface.to_string(), + destination, + exclude_port: None, + port: None, + icmp: true, + }) + .await + .map_err(|e| anyhow!("{e:#}")) + .map(|status| match status { + leak_checker::LeakStatus::NoLeak => None, + leak_checker::LeakStatus::LeakDetected(info) => Some(info), + }) +} - // TODO: query DNS leak checker HTTPS endpoint +// async fn check_for_leaks(interface: &str, destination: IpAddr) -> anyhow::Result> { +// use std::mem::ManuallyDrop; +// use std::os::fd::FromRawFd; +// let client = surge_ping::Client::new(&surge_ping::Config { +// sock_type_hint: socket2::Type::DGRAM, +// kind: surge_ping::ICMP::V4, +// +// // On desktop linux, we can bind directly to the interface. +// interface: cfg!(target_os = "linux").then(|| interface.to_string()), +// +// // On other systems, we resord to binding to the interfaces IP address instead. +// bind: cfg!(not(target_os = "linux")).then(|| get_interface_ip(interface)), +// +// ttl: None, +// fib: None, +// }) +// .context("Failed to create ping client")?; +// +// // TODO: additional configuration? +// let socket = client.get_socket(); +// +// // SAFETY: socket.get_native_sock returns an open fd. +// // The socket2 socket is not used after we drop the client. +// // We wrap the socket2 socket in a ManuallyDrop to prevent it from dropping the socket. +// let socket = unsafe { socket2::Socket::from_raw_fd(socket.get_native_sock()) }; +// let socket = ManuallyDrop::new(socket); +// let mut pinger = client.pinger(destination, PingIdentifier(12345)).await; +// +// for ttl in 1..=5u16 { +// let ping_seq = ttl; +// +// socket +// .set_ttl(u32::from(ttl)) +// .context("Failed to set TTL")?; +// +// let (reply, _duration) = pinger +// .ping(PingSequence(ping_seq), b"ABCDEFGHIJKLMNOP") +// .await +// .context("Failed to send ping")?; +// +// println!("icmp_reply: {reply:?}"); +// } +// +// todo!() +// } - // TODO: query https://ipv4.am.i.mullvad.net/ - // TODO: query https://ipv6.am.i.mullvad.net/ +impl LeakCheckerCallback for T +where + T: FnMut(LeakInfo) -> bool + Send + 'static, +{ + fn on_leak(&mut self, info: LeakInfo) -> CallbackResult { + if self(info) { + CallbackResult::Ok + } else { + CallbackResult::Drop + } + } } diff --git a/mullvad-daemon/src/lib.rs b/mullvad-daemon/src/lib.rs index d7a042cf3814..1875167a1770 100644 --- a/mullvad-daemon/src/lib.rs +++ b/mullvad-daemon/src/lib.rs @@ -39,6 +39,7 @@ use futures::{ StreamExt, }; use geoip::GeoIpHandler; +use leak_checker::{LeakChecker, LeakInfo}; use management_interface::ManagementInterfaceServer; use mullvad_relay_selector::{RelaySelector, SelectorConfig}; #[cfg(target_os = "android")] @@ -413,6 +414,8 @@ pub(crate) enum InternalDaemonEvent { /// The split tunnel paths or state were updated. #[cfg(any(windows, target_os = "android", target_os = "macos"))] ExcludedPathsEvent(ExcludedPathsUpdate, oneshot::Sender>), + /// A network leak was detected. + LeakDetected(LeakInfo), } #[cfg(any(windows, target_os = "android", target_os = "macos"))] @@ -587,6 +590,7 @@ pub struct Daemon { #[cfg(target_os = "windows")] volume_update_tx: mpsc::UnboundedSender<()>, location_handler: GeoIpHandler, + leak_checker: LeakChecker, } impl Daemon { @@ -839,6 +843,17 @@ impl Daemon { internal_event_tx.clone().to_specialized_sender(), ); + let leak_checker = { + let mut leak_checker = LeakChecker::new(); + let internal_event_tx = internal_event_tx.clone(); + leak_checker.add_leak_callback(move |info| { + internal_event_tx + .send(InternalDaemonEvent::LeakDetected(info)) + .is_ok() + }); + leak_checker + }; + let daemon = Daemon { tunnel_state: TunnelState::Disconnected { location: None, @@ -869,6 +884,7 @@ impl Daemon { #[cfg(target_os = "windows")] volume_update_tx, location_handler, + leak_checker, }; api_availability.unsuspend(); @@ -967,7 +983,7 @@ impl Daemon { let mut should_stop = false; match event { TunnelStateTransition(transition) => { - self.handle_tunnel_state_transition(transition).await + self.handle_tunnel_state_transition(transition).await; } Command(command) => self.handle_command(command).await, TriggerShutdown(user_init_shutdown) => { @@ -989,6 +1005,9 @@ impl Daemon { } #[cfg(any(windows, target_os = "android", target_os = "macos"))] ExcludedPathsEvent(update, tx) => self.handle_new_excluded_paths(update, tx).await, + LeakDetected(leak_info) => { + log::warn!("LEAK DETECTED! AAAH: {leak_info:?}"); + } } should_stop } @@ -997,6 +1016,9 @@ impl Daemon { &mut self, tunnel_state_transition: TunnelStateTransition, ) { + self.leak_checker + .on_tunnel_state_transition(tunnel_state_transition.clone()); + self.reset_rpc_sockets_on_tunnel_state_transition(&tunnel_state_transition); self.device_checker .handle_state_transition(&tunnel_state_transition); diff --git a/talpid-core/Cargo.toml b/talpid-core/Cargo.toml index e2593313d309..7ec3fd20f288 100644 --- a/talpid-core/Cargo.toml +++ b/talpid-core/Cargo.toml @@ -47,8 +47,7 @@ duct = "0.13" [target.'cfg(target_os = "macos")'.dependencies] async-trait = "0.1" duct = "0.13" -#pfctl = "0.6.1" -pfctl = { path = "../../pfctl-rs" } +pfctl = "0.6.1" subslice = "0.2" system-configuration = "0.5.1" hickory-proto = { workspace = true } diff --git a/talpid-net/Cargo.toml b/talpid-net/Cargo.toml index aa30ed1b5b6a..861e1765cc60 100644 --- a/talpid-net/Cargo.toml +++ b/talpid-net/Cargo.toml @@ -13,5 +13,5 @@ workspace = true [target.'cfg(unix)'.dependencies] libc = "0.2" talpid-types = { path = "../talpid-types" } -socket2 = { version = "0.5.3", features = ["all"] } +socket2 = { workspace = true, features = ["all"] } log = { workspace = true } diff --git a/talpid-windows/Cargo.toml b/talpid-windows/Cargo.toml index a44229b61d07..0b9e1d267217 100644 --- a/talpid-windows/Cargo.toml +++ b/talpid-windows/Cargo.toml @@ -12,7 +12,7 @@ workspace = true [target.'cfg(windows)'.dependencies] thiserror = { workspace = true } -socket2 = { version = "0.5.3" } +socket2 = { workspace = true } futures = { workspace = true } talpid-types = { path = "../talpid-types" } diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml index e02bf874d253..620cbab5cc26 100644 --- a/talpid-wireguard/Cargo.toml +++ b/talpid-wireguard/Cargo.toml @@ -39,7 +39,7 @@ duct = "0.13" [target.'cfg(not(target_os="android"))'.dependencies] byteorder = "1" internet-checksum = "0.2" -socket2 = { version = "0.5.3", features = ["all"] } +socket2 = { workspace = true, features = ["all"] } tokio-stream = { version = "0.1", features = ["io-util"] } [target.'cfg(unix)'.dependencies] diff --git a/test/test-runner/Cargo.toml b/test/test-runner/Cargo.toml index fd53f4b7cb79..af84ef4daeb0 100644 --- a/test/test-runner/Cargo.toml +++ b/test/test-runner/Cargo.toml @@ -33,7 +33,7 @@ test-rpc = { path = "../test-rpc" } mullvad-paths = { path = "../../mullvad-paths" } talpid-platform-metadata = { path = "../../talpid-platform-metadata", default-features = false } -socket2 = { version = "0.5.4", features = ["all"] } +socket2 = { workspace = true, features = ["all"] } [target."cfg(target_os=\"windows\")".dependencies] talpid-windows = { path = "../../talpid-windows" }