Skip to content

Commit

Permalink
Remove DNS fallback except for conncheck
Browse files Browse the repository at this point in the history
  • Loading branch information
dlon committed Dec 2, 2024
1 parent 068dafd commit a579c72
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 92 deletions.
20 changes: 18 additions & 2 deletions mullvad-api/src/address_cache.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! This module keeps track of the last known good API IP address and reads and stores it on disk.
use super::API;
use crate::DnsResolver;
use async_trait::async_trait;
use std::{io, net::SocketAddr, path::Path, sync::Arc};
use tokio::{
fs,
Expand All @@ -23,6 +25,17 @@ pub enum Error {
Write(#[source] io::Error),
}

/// A DNS resolver which resolves using `AddressCache`.
#[async_trait]
impl DnsResolver for AddressCache {
async fn resolve(&self, host: String) -> Result<Vec<SocketAddr>, io::Error> {
self.resolve_hostname(&host)
.await
.map(|addr| vec![addr])
.ok_or(io::Error::other("host does not match API host"))
}
}

#[derive(Clone)]
pub struct AddressCache {
inner: Arc<Mutex<AddressCacheInner>>,
Expand All @@ -42,7 +55,10 @@ impl AddressCache {
/// Initialize cache using `read_path`, and write changes to `write_path`.
pub async fn from_file(read_path: &Path, write_path: Option<Box<Path>>) -> Result<Self, Error> {
log::debug!("Loading API addresses from {}", read_path.display());
Ok(Self::new_inner(read_address_file(read_path).await?, write_path))
Ok(Self::new_inner(
read_address_file(read_path).await?,
write_path,
))
}

fn new_inner(address: SocketAddr, write_path: Option<Box<Path>>) -> Self {
Expand All @@ -56,7 +72,7 @@ impl AddressCache {
}

/// Returns the address if the hostname equals `API.host`. Otherwise, returns `None`.
pub async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
async fn resolve_hostname(&self, hostname: &str) -> Option<SocketAddr> {
if hostname.eq_ignore_ascii_case(API.host()) {
Some(self.get_address().await)
} else {
Expand Down
6 changes: 2 additions & 4 deletions mullvad-api/src/bin/relay_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
//! Used by the installer artifact packer to bundle the latest available
//! relay list at the time of creating the installer.
use mullvad_api::{
proxy::ApiConnectionMode, rest::Error as RestError, DefaultDnsResolver, RelayListProxy,
};
use mullvad_api::{proxy::ApiConnectionMode, rest::Error as RestError, RelayListProxy};
use std::process;
use talpid_types::ErrorExt;

#[tokio::main]
async fn main() {
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current(), DefaultDnsResolver)
let runtime = mullvad_api::Runtime::new(tokio::runtime::Handle::current())
.expect("Failed to load runtime");

let relay_list_request =
Expand Down
6 changes: 3 additions & 3 deletions mullvad-api/src/ffi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{

use crate::{
rest::{self, MullvadRestHandle},
AccountsProxy, DevicesProxy,
AccountsProxy, DevicesProxy, NullDnsResolver,
};

mod device;
Expand Down Expand Up @@ -209,11 +209,11 @@ impl FfiClient {
}

fn device_proxy(&self) -> DevicesProxy {
crate::DevicesProxy::new(self.rest_handle())
crate::DevicesProxy::new(self.rest_handle(NullDnsResolver))
}

fn accounts_proxy(&self) -> AccountsProxy {
crate::AccountsProxy::new(self.rest_handle())
crate::AccountsProxy::new(self.rest_handle(NullDnsResolver))
}

fn tokio_handle(&self) -> tokio::runtime::Handle {
Expand Down
34 changes: 10 additions & 24 deletions mullvad-api/src/https_client_with_sni.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
abortable_stream::{AbortableStream, AbortableStreamHandle},
proxy::{ApiConnection, ApiConnectionMode, ProxyConfig},
tls_stream::TlsStream,
AddressCache, DnsResolver,
DnsResolver,
};
use futures::{channel::mpsc, future, pin_mut, StreamExt};
#[cfg(target_os = "android")]
Expand Down Expand Up @@ -287,7 +287,6 @@ impl TryFrom<ApiConnectionMode> for InnerConnectionMode {
pub struct HttpsConnectorWithSni {
inner: Arc<Mutex<HttpsConnectorWithSniInner>>,
sni_hostname: Option<String>,
address_cache: AddressCache,
abort_notify: Arc<tokio::sync::Notify>,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
Expand All @@ -305,7 +304,6 @@ pub type SocketBypassRequest = (RawFd, oneshot::Sender<()>);
impl HttpsConnectorWithSni {
pub fn new(
sni_hostname: Option<String>,
address_cache: AddressCache,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> (Self, HttpsConnectorWithSniHandle) {
Expand Down Expand Up @@ -353,7 +351,6 @@ impl HttpsConnectorWithSni {
HttpsConnectorWithSni {
inner,
sni_hostname,
address_cache,
abort_notify,
dns_resolver,
#[cfg(target_os = "android")]
Expand Down Expand Up @@ -390,13 +387,9 @@ impl HttpsConnectorWithSni {
}

/// Resolve the provided `uri` to an IP and port. If the URI contains an IP, that IP will be used.
/// Otherwise `address_cache` will be preferred, and `dns_resolver` will be used as a fallback.
/// Otherwise `dns_resolver` will be used as a fallback.
/// If the URI contains a port, then that port will be used.
async fn resolve_address(
address_cache: AddressCache,
dns_resolver: &dyn DnsResolver,
uri: Uri,
) -> io::Result<SocketAddr> {
async fn resolve_address(dns_resolver: &dyn DnsResolver, uri: Uri) -> io::Result<SocketAddr> {
const DEFAULT_PORT: u16 = 443;

let hostname = uri.host().ok_or_else(|| {
Expand All @@ -407,22 +400,16 @@ impl HttpsConnectorWithSni {
return Ok(SocketAddr::new(addr, port.unwrap_or(DEFAULT_PORT)));
}

// Preferentially, use cached address.
//
if let Some(addr) = address_cache.resolve_hostname(hostname).await {
return Ok(SocketAddr::new(
addr.ip(),
port.unwrap_or_else(|| addr.port()),
));
}

// Use DNS resolution as fallback
//
let addrs = dns_resolver.resolve(hostname.to_owned()).await?;
let addr = addrs
.first()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Empty DNS response"))?;
Ok(SocketAddr::new(*addr, port.unwrap_or(DEFAULT_PORT)))
let port = match (addr.port(), port) {
(_, Some(port)) => port,
(0, None) => DEFAULT_PORT,
(addr_port, None) => addr_port,
};
Ok(SocketAddr::new(addr.ip(), port))
}
}

Expand Down Expand Up @@ -456,7 +443,6 @@ impl Service<Uri> for HttpsConnectorWithSni {
let abort_notify = self.abort_notify.clone();
#[cfg(target_os = "android")]
let socket_bypass_tx = self.socket_bypass_tx.clone();
let address_cache = self.address_cache.clone();
let dns_resolver = self.dns_resolver.clone();

let fut = async move {
Expand All @@ -468,7 +454,7 @@ impl Service<Uri> for HttpsConnectorWithSni {
}

let hostname = sni_hostname?;
let addr = Self::resolve_address(address_cache, &*dns_resolver, uri).await?;
let addr = Self::resolve_address(&*dns_resolver, uri).await?;

// Loop until we have established a connection. This starts over if a new endpoint
// is selected while connecting.
Expand Down
65 changes: 29 additions & 36 deletions mullvad-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,22 +308,22 @@ impl ApiEndpoint {

#[async_trait]
pub trait DnsResolver: 'static + Send + Sync {
async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>>;
async fn resolve(&self, host: String) -> io::Result<Vec<SocketAddr>>;
}

/// DNS resolver that relies on `ToSocketAddrs` (`getaddrinfo`).
pub struct DefaultDnsResolver;

#[async_trait]
impl DnsResolver for DefaultDnsResolver {
async fn resolve(&self, host: String) -> io::Result<Vec<IpAddr>> {
async fn resolve(&self, host: String) -> io::Result<Vec<SocketAddr>> {
use std::net::ToSocketAddrs;
// Spawn a blocking thread, since `to_socket_addrs` relies on `libc::getaddrinfo`, which
// blocks and either has no timeout or a very long one.
let addrs = tokio::task::spawn_blocking(move || (host, 0).to_socket_addrs())
.await
.expect("DNS task panicked")?;
Ok(addrs.map(|addr| addr.ip()).collect())
Ok(addrs.collect())
}
}

Expand All @@ -332,7 +332,7 @@ pub struct NullDnsResolver;

#[async_trait]
impl DnsResolver for NullDnsResolver {
async fn resolve(&self, _host: String) -> io::Result<Vec<IpAddr>> {
async fn resolve(&self, _host: String) -> io::Result<Vec<SocketAddr>> {
Ok(vec![])
}
}
Expand All @@ -342,7 +342,6 @@ pub struct Runtime {
handle: tokio::runtime::Handle,
address_cache: AddressCache,
api_availability: availability::ApiAvailability,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")]
socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
}
Expand All @@ -364,13 +363,9 @@ pub enum Error {

impl Runtime {
/// Create a new `Runtime`.
pub fn new(
handle: tokio::runtime::Handle,
dns_resolver: impl DnsResolver,
) -> Result<Self, Error> {
pub fn new(handle: tokio::runtime::Handle) -> Result<Self, Error> {
Self::new_inner(
handle,
dns_resolver,
#[cfg(target_os = "android")]
None,
)
Expand All @@ -388,14 +383,12 @@ impl Runtime {

fn new_inner(
handle: tokio::runtime::Handle,
dns_resolver: impl DnsResolver,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> Result<Self, Error> {
Ok(Runtime {
handle,
address_cache: AddressCache::new(None)?,
address_cache: AddressCache::new(None),
api_availability: ApiAvailability::default(),
dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
Expand All @@ -404,7 +397,6 @@ impl Runtime {
/// Create a new `Runtime` using the specified directories.
/// Try to use the cache directory first, and fall back on the bundled address otherwise.
pub async fn with_cache(
dns_resolver: impl DnsResolver,
cache_dir: &Path,
write_changes: bool,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
Expand All @@ -415,7 +407,6 @@ impl Runtime {
if API.disable_address_cache {
return Self::new_inner(
handle,
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
);
Expand All @@ -439,7 +430,7 @@ impl Runtime {
)
);
}
AddressCache::new(write_file)?
AddressCache::new(write_file)
}
};

Expand All @@ -449,30 +440,11 @@ impl Runtime {
handle,
address_cache,
api_availability,
dns_resolver: Arc::new(dns_resolver),
#[cfg(target_os = "android")]
socket_bypass_tx,
})
}

/// Creates a new request service and returns a handle to it.
fn new_request_service<T: ConnectionModeProvider + 'static>(
&self,
sni_hostname: Option<String>,
connection_mode_provider: T,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
sni_hostname,
self.api_availability.clone(),
self.address_cache.clone(),
connection_mode_provider,
self.dns_resolver.clone(),
#[cfg(target_os = "android")]
socket_bypass_tx,
)
}

/// Returns a request factory initialized to create requests for the master API
pub fn mullvad_rest_handle<T: ConnectionModeProvider + 'static>(
&self,
Expand All @@ -481,6 +453,7 @@ impl Runtime {
let service = self.new_request_service(
Some(API.host().to_string()),
connection_mode_provider,
Arc::new(self.address_cache.clone()),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
Expand All @@ -495,6 +468,7 @@ impl Runtime {
let service = self.new_request_service(
Some(hostname.clone()),
ApiConnectionMode::Direct.into_provider(),
Arc::new(self.address_cache.clone()),
#[cfg(target_os = "android")]
self.socket_bypass_tx.clone(),
);
Expand All @@ -505,15 +479,34 @@ impl Runtime {
}

/// Returns a new request service handle
pub fn rest_handle(&self) -> rest::RequestServiceHandle {
pub fn rest_handle(&self, dns_resolver: impl DnsResolver) -> rest::RequestServiceHandle {
self.new_request_service(
None,
ApiConnectionMode::Direct.into_provider(),
Arc::new(dns_resolver),
#[cfg(target_os = "android")]
None,
)
}

/// Creates a new request service and returns a handle to it.
fn new_request_service<T: ConnectionModeProvider + 'static>(
&self,
sni_hostname: Option<String>,
connection_mode_provider: T,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> rest::RequestServiceHandle {
rest::RequestService::spawn(
sni_hostname,
self.api_availability.clone(),
connection_mode_provider,
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx,
)
}

pub fn handle(&mut self) -> &mut tokio::runtime::Handle {
&mut self.handle
}
Expand Down
3 changes: 0 additions & 3 deletions mullvad-api/src/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
pub use crate::https_client_with_sni::SocketBypassRequest;
use crate::{
access::AccessTokenStore,
address_cache::AddressCache,
availability::ApiAvailability,
https_client_with_sni::{HttpsConnectorWithSni, HttpsConnectorWithSniHandle},
proxy::ConnectionModeProvider,
Expand Down Expand Up @@ -153,14 +152,12 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
pub fn spawn(
sni_hostname: Option<String>,
api_availability: ApiAvailability,
address_cache: AddressCache,
connection_mode_provider: T,
dns_resolver: Arc<dyn DnsResolver>,
#[cfg(target_os = "android")] socket_bypass_tx: Option<mpsc::Sender<SocketBypassRequest>>,
) -> RequestServiceHandle {
let (connector, connector_handle) = HttpsConnectorWithSni::new(
sni_hostname,
address_cache.clone(),
dns_resolver,
#[cfg(target_os = "android")]
socket_bypass_tx.clone(),
Expand Down
Loading

0 comments on commit a579c72

Please sign in to comment.