Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor address cache in mullvad-api #7248

Merged
merged 3 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions mullvad-api/src/address_cache.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! This module keeps track of the last known good API IP address and reads and stores it on disk.

use super::API;
use crate::DnsResolver;
use async_trait::async_trait;
use std::{io, net::SocketAddr, path::Path, sync::Arc};
use tokio::{
fs,
Expand All @@ -23,6 +25,17 @@ pub enum Error {
Write(#[source] io::Error),
}

/// A DNS resolver which resolves using `AddressCache`.
#[async_trait]
impl DnsResolver for AddressCache {
async fn resolve(&self, host: String) -> Result<Vec<SocketAddr>, io::Error> {
self.resolve_hostname(&host)
.await
.map(|addr| vec![addr])
.ok_or(io::Error::other("host does not match API host"))
}
}

#[derive(Clone)]
pub struct AddressCache {
inner: Arc<Mutex<AddressCacheInner>>,
Expand All @@ -31,34 +44,35 @@ pub struct AddressCache {

impl AddressCache {
/// Initialize cache using the hardcoded address, and write changes to `write_path`.
pub fn new(write_path: Option<Box<Path>>) -> Result<Self, Error> {
pub fn new(write_path: Option<Box<Path>>) -> Self {
Self::new_inner(API.address(), write_path)
}

pub fn with_static_addr(address: SocketAddr) -> Self {
Self::new_inner(address, None)
.expect("Failed to construct an address cache from a static address")
}

/// Initialize cache using `read_path`, and write changes to `write_path`.
pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> {
log::debug!("Loading API addresses from {}", read_path.display());
Self::new_inner(read_address_file(read_path).await?, write_path)
Ok(Self::new_inner(
read_address_file(read_path).await?,
write_path,
))
}

fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Result<Self, Error> {
fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Self {
let cache = AddressCacheInner::from_address(address);
log::debug!("Using API address: {}", cache.address);

let address_cache = Self {
Self {
inner: Arc::new(Mutex::new(cache)),
write_path: write_path.map(Arc::from),
};
Ok(address_cache)
}
}

/// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`.
pub async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
if hostname.eq_ignore_ascii_case(API.host()) {
Some(self.get_address().await)
} else {
Expand Down
6 changes: 2 additions & 4 deletions mullvad-api/src/bin/relay_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
//! Used by the installer artifact packer to bundle the latest available
//! relay list at the time of creating the installer.

use mullvad_api::{
proxy::ApiConnectionMode, rest::Error as RestError, DefaultDnsResolver, RelayListProxy,
};
use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy};
use std::process;
use talpid_types::ErrorExt;

#[tokio::main]
async fn main() {
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), DefaultDnsResolver)
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("Failed to load runtime");

let relay_list_request =
Expand Down
52 changes: 16 additions & 36 deletions mullvad-api/src/https_client_with_sni.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
abortable_stream::{AbortableStream, AbortableStreamHandle},
proxy::{ApiConnection, ApiConnectionMode, ProxyConfig},
tls_stream::TlsStream,
AddressCache, DnsResolver,
DnsResolver,
};
use futures::{channel::mpsc, future, pin_mut, StreamExt};
#[cfg(target_os = "android")]
Expand Down Expand Up @@ -286,8 +286,6 @@ impl TryFrom<ApiConnectionMode> for InnerConnectionMode {
#[derive(Clone)]
pub struct HttpsConnectorWithSni {
inner: Arc<Mutex<HttpsConnectorWithSniInner>>,
sni_hostname: Option<String>,
address_cache: AddressCache,
abort_notify: Arc<tokio::sync::Notify>,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
Expand All @@ -304,8 +302,6 @@ pub type SocketBypassRequest = (RawFd, oneshot::Sender<()>);

impl HttpsConnectorWithSni {
pub fn new(
sni_hostname: Option<String>,
address_cache: AddressCache,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> (Self, HttpsConnectorWithSniHandle) {
Expand Down Expand Up @@ -352,8 +348,6 @@ impl HttpsConnectorWithSni {
(
HttpsConnectorWithSni {
inner,
sni_hostname,
address_cache,
abort_notify,
dns_resolver,
#[cfg(target_os = "android")]
Expand Down Expand Up @@ -390,13 +384,9 @@ impl HttpsConnectorWithSni {
}

/// Resolve the provided `uri` to an IP and port. If the URI contains an IP, that IP will be used.
/// Otherwise `address_cache` will be preferred, and `dns_resolver` will be used as a fallback.
/// Otherwise `dns_resolver` will be used as a fallback.
/// If the URI contains a port, then that port will be used.
async fn resolve_address(
address_cache: AddressCache,
dns_resolver: &dyn DnsResolver,
uri: Uri,
) -> io::Result<SocketAddr> {
async fn resolve_address(dns_resolver: &dyn DnsResolver, uri: Uri) -> io::Result<SocketAddr> {
const DEFAULT_PORT: u16 = 443;

let hostname = uri.host().ok_or_else(|| {
Expand All @@ -407,22 +397,16 @@ impl HttpsConnectorWithSni {
return Ok(SocketAddr::new(addr, port.unwrap_or(DEFAULT_PORT)));
}

// Preferentially, use cached address.
//
if let Some(addr) = address_cache.resolve_hostname(hostname).await {
return Ok(SocketAddr::new(
addr.ip(),
port.unwrap_or_else(|| addr.port()),
));
}

// Use DNS resolution as fallback
//
let addrs = dns_resolver.resolve(hostname.to_owned()).await?;
let addr = addrs
.first()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?;
Ok(SocketAddr::new(*addr, port.unwrap_or(DEFAULT_PORT)))
let port = match (addr.port(), port) {
(_, Some(port)) => port,
(0, None) => DEFAULT_PORT,
(addr_port, None) => addr_port,
};
Ok(SocketAddr::new(addr.ip(), port))
}
}

Expand All @@ -445,18 +429,10 @@ impl Service<Uri> for HttpsConnectorWithSni {
}

fn call(&mut self, uri: Uri) -> Self::Future {
let sni_hostname = self
.sni_hostname
.clone()
.or_else(|| uri.host().map(str::to_owned))
.ok_or_else(|| {
io::Error::new(io::ErrorKind::InvalidInput, "invalid url, missing host")
});
let inner = self.inner.clone();
let abort_notify = self.abort_notify.clone();
#[cfg(target_os = "android")]
let socket_bypass_tx = self.socket_bypass_tx.clone();
let address_cache = self.address_cache.clone();
let dns_resolver = self.dns_resolver.clone();

let fut = async move {
Expand All @@ -466,9 +442,13 @@ impl Service<Uri> for HttpsConnectorWithSni {
"invalid url, not https",
));
}

let hostname = sni_hostname?;
let addr = Self::resolve_address(address_cache, &*dns_resolver, uri).await?;
let Some(hostname) = uri.host().map(str::to_owned) else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid url, missing host",
));
};
let addr = Self::resolve_address(&*dns_resolver, uri).await?;

// Loop until we have established a connection. This starts over if a new endpoint
// is selected while connecting.
Expand Down
67 changes: 27 additions & 40 deletions mullvad-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,22 +308,22 @@ impl ApiEndpoint {

#[async_trait]
pub trait DnsResolver: 'static + Send + Sync {
async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>>;
async fn resolve(&self, host: String) -> io::Result<Vec<SocketAddr>>;
}

/// DNS resolver that relies on `ToSocketAddrs` (`getaddrinfo`).
pub struct DefaultDnsResolver;

#[async_trait]
impl DnsResolver for DefaultDnsResolver {
async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>> {
async fn resolve(&self, host: String) -> io::Result<Vec<SocketAddr>> {
use std::net::ToSocketAddrs;
// Spawn a blocking thread, since `to_socket_addrs` relies on `libc::getaddrinfo`, which
// blocks and either has no timeout or a very long one.
let addrs = tokio::task::spawn_blocking(move || (host, 0).to_socket_addrs())
.await
.expect("DNS task panicked")?;
Ok(addrs.map(|addr| addr.ip()).collect())
Ok(addrs.collect())
}
}

Expand All @@ -332,7 +332,7 @@ pub struct NullDnsResolver;

#[async_trait]
impl DnsResolver for NullDnsResolver {
async fn resolve(&self, _host: String) -> io::Result<Vec<IpAddr>> {
async fn resolve(&self, _host: String) -> io::Result<Vec<SocketAddr>> {
Ok(vec![])
}
}
Expand All @@ -342,7 +342,6 @@ pub struct Runtime {
handle: tokio::runtime::Handle,
address_cache: AddressCache,
api_availability: availability::ApiAvailability,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
Expand All @@ -364,13 +363,9 @@ pub enum Error {

impl Runtime {
/// Create a new `Runtime`.
pub fn new(
handle: tokio::runtime::Handle,
dns_resolver: impl DnsResolver,
) -> Result<Self, Error> {
pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> {
Self::new_inner(
handle,
dns_resolver,
#[cfg(target_os = "android")]
None,
)
Expand All @@ -381,21 +376,18 @@ impl Runtime {
Runtime {
handle,
address_cache: AddressCache::with_static_addr(address),
dns_resolver: Arc::new(NullDnsResolver),
api_availability: ApiAvailability::default(),
}
}

fn new_inner(
handle: tokio::runtime::Handle,
dns_resolver: impl DnsResolver,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Result<Self, Error> {
Ok(Runtime {
handle,
address_cache: AddressCache::new(None)?,
address_cache: AddressCache::new(None),
api_availability: ApiAvailability::default(),
dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
Expand All @@ -404,7 +396,6 @@ impl Runtime {
/// Create a new `Runtime` using the specified directories.
/// Try to use the cache directory first, and fall back on the bundled address otherwise.
pub async fn with_cache(
dns_resolver: impl DnsResolver,
cache_dir: &Path,
write_changes: bool,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
Expand All @@ -415,7 +406,6 @@ impl Runtime {
if API.disable_address_cache {
return Self::new_inner(
handle,
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
);
Expand All @@ -439,7 +429,7 @@ impl Runtime {
)
);
}
AddressCache::new(write_file)?
AddressCache::new(write_file)
}
};

Expand All @@ -449,38 +439,19 @@ impl Runtime {
handle,
address_cache,
api_availability,
dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
}

/// Creates a new request service and returns a handle to it.
fn new_request_service<T: ConnectionModeProvider + 'static>(
&self,
sni_hostname: Option<String>,
connection_mode_provider: T,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
sni_hostname,
self.api_availability.clone(),
self.address_cache.clone(),
connection_mode_provider,
self.dns_resolver.clone(),
#[cfg(target_os = "android")]
socket_bypass_tx,
)
}

/// Returns a request factory initialized to create requests for the master API
pub fn mullvad_rest_handle<T: ConnectionModeProvider + 'static>(
&self,
connection_mode_provider: T,
) -> rest::MullvadRestHandle {
let service = self.new_request_service(
Some(API.host().to_string()),
connection_mode_provider,
Arc::new(self.address_cache.clone()),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
Expand All @@ -493,8 +464,8 @@ impl Runtime {
/// This is only to be used in test code
pub fn static_mullvad_rest_handle(&self, hostname: String) -> rest::MullvadRestHandle {
let service = self.new_request_service(
Some(hostname.clone()),
ApiConnectionMode::Direct.into_provider(),
Arc::new(self.address_cache.clone()),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
Expand All @@ -505,15 +476,31 @@ impl Runtime {
}

/// Returns a new request service handle
pub fn rest_handle(&self) -> rest::RequestServiceHandle {
pub fn rest_handle(&self, dns_resolver: impl DnsResolver) -> rest::RequestServiceHandle {
self.new_request_service(
None,
ApiConnectionMode::Direct.into_provider(),
Arc::new(dns_resolver),
#[cfg(target_os = "android")]
None,
)
}

/// Creates a new request service and returns a handle to it.
fn new_request_service<T: ConnectionModeProvider + 'static>(
&self,
connection_mode_provider: T,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
self.api_availability.clone(),
connection_mode_provider,
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
)
}

pub fn handle(&mut self) -> &mut tokio::runtime::Handle {
&mut self.handle
}
Expand Down
Loading
Loading