Skip to content

Commit

Permalink
Allow AccessModeSelector to resolve api endpoints
Browse files Browse the repository at this point in the history
Until now, `AccessModeSelector` has not been able to resolve API
endpoints on it's own. This has happened at some later stage, for
example in the `mullvad-api` crate. However, for testing the `Direct`
access method, it is very useful to be able to resolve the actual
endpoint without involving the daemon's "API runtime".
  • Loading branch information
MarkusPettersson98 committed Jan 3, 2024
1 parent 8d8c478 commit fb35fc2
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 48 deletions.
122 changes: 81 additions & 41 deletions mullvad-daemon/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ use crate::DaemonCommand;
use crate::DaemonEventSender;
use futures::{
channel::{mpsc, oneshot},
stream::unfold,
Stream, StreamExt,
};
use mullvad_api::{
availability::ApiAvailabilityHandle,
proxy::{ApiConnectionMode, ProxyConfig},
AddressCache,
};
use mullvad_relay_selector::RelaySelector;
use mullvad_types::access_method::{self, AccessMethod, AccessMethodSetting, BuiltInAccessMethod};
Expand All @@ -24,10 +24,12 @@ use talpid_types::net::{
};

pub enum Message {
Get(ResponseTx<AccessMethodSetting>),
Get(ResponseTx<ResolvedConnectionMode>),
Set(ResponseTx<()>, AccessMethodSetting),
Next(ResponseTx<ApiConnectionMode>),
Update(ResponseTx<()>, Vec<AccessMethodSetting>),
Resolve(ResponseTx<ResolvedConnectionMode>, AccessMethodSetting),
}

/// A [`NewAccessMethodEvent`] is emitted when the active access method changes.
/// Which access method that should be active at any given time is decided by
Expand Down Expand Up @@ -103,6 +105,23 @@ impl NewAccessMethodEvent {
update_finished_rx
}
}

/// This struct represent a concrete API endpoint (in the form of an
/// [`ApiConnectionMode`] and [`AllowedEndpoint`]) which has been derived from
/// some [`AccessMethodSetting`] (most likely the currently active access
/// method). These logically related values are sometimes useful to group
/// together into one value, which is encoded by [`ResolvedConnectionMode`].
#[derive(Clone)]
pub struct ResolvedConnectionMode {
/// The connection strategy to be used by the `mullvad-api` crate when
/// initializing API requests.
pub connection_mode: ApiConnectionMode,
/// The actual endpoint of the Mullvad API and which clients should be
/// allowed to initialize a connection to this endpoint.
pub endpoint: AllowedEndpoint,
/// This is the [`AccessMethodSetting`] which resolved into
/// `connection_mode` and `endpoint`.
pub setting: AccessMethodSetting,
}

#[derive(err_derive::Error, Debug)]
Expand All @@ -124,6 +143,7 @@ impl std::fmt::Display for Message {
Message::Set(_, _) => f.write_str("Set"),
Message::Next(_) => f.write_str("Next"),
Message::Update(_, _) => f.write_str("Update"),
Message::Resolve(_, _) => f.write_str("Resolve"),
}
}
}
Expand Down Expand Up @@ -159,7 +179,7 @@ impl AccessModeSelectorHandle {
rx.await.map_err(Error::NotRunning)?
}

pub async fn get_access_method(&self) -> Result<AccessMethodSetting> {
pub async fn get_current(&self) -> Result<ResolvedConnectionMode> {
self.send_command(Message::Get).await.map_err(|err| {
log::debug!("Failed to get current access method!");
err
Expand All @@ -184,6 +204,15 @@ impl AccessModeSelectorHandle {
})
}

pub async fn resolve(&self, setting: AccessMethodSetting) -> Result<ResolvedConnectionMode> {
self.send_command(|tx| Message::Resolve(tx, setting))
.await
.map_err(|err| {
log::error!("Failed to update new access methods!");
err
})
}

pub async fn next(&self) -> Result<ApiConnectionMode> {
self.send_command(Message::Next).await.map_err(|err| {
log::debug!("Failed while getting the next access method");
Expand All @@ -194,10 +223,10 @@ impl AccessModeSelectorHandle {
/// Convert this handle to a [`Stream`] of [`ApiConnectionMode`] from the
/// associated [`AccessModeSelector`].
///
/// Practically converts the handle to a listener for when the
/// currently valid connection modes changes.
/// Calling `next` on this stream will poll for the next access method,
/// which will be lazily produced (on-demand rather than speculatively).
pub fn into_stream(self) -> impl Stream<Item = ApiConnectionMode> {
unfold(self, |handle| async move {
futures::stream::unfold(self, |handle| async move {
match handle.next().await {
Ok(connection_mode) => Some((connection_mode, handle)),
// End this stream in case of failure in `next`. `next` should
Expand Down Expand Up @@ -231,7 +260,7 @@ pub struct AccessModeSelector {
}

impl AccessModeSelector {
pub fn spawn(
pub(crate) async fn spawn(
cache_dir: PathBuf,
relay_selector: RelaySelector,
connection_modes: Vec<AccessMethodSetting>,
Expand All @@ -240,7 +269,7 @@ impl AccessModeSelector {
) -> Result<AccessModeSelectorHandle> {
let (cmd_tx, cmd_rx) = mpsc::unbounded();

let connection_modes = match ConnectionModesIterator::new(connection_modes) {
let mut connection_modes = match ConnectionModesIterator::new(connection_modes) {
Ok(provider) => provider,
Err(Error::NoAccessMethods) | Err(_) => {
// No settings seem to have been found. Default to using the the
Expand All @@ -251,6 +280,12 @@ impl AccessModeSelector {
)
}
};

let initial_connection_mode = {
let next = connection_modes.next().ok_or(Error::NoAccessMethods)?;
Self::resolve_inner(next, &relay_selector, &address_cache).await
};

let selector = AccessModeSelector {
cmd_rx,
cache_dir,
Expand All @@ -260,8 +295,10 @@ impl AccessModeSelector {
access_method_event_sender,
current: initial_connection_mode,
};

tokio::spawn(selector.into_future());
AccessModeSelectorHandle { cmd_tx }

Ok(AccessModeSelectorHandle { cmd_tx })
}

async fn into_future(mut self) {
Expand All @@ -270,8 +307,9 @@ impl AccessModeSelector {
let execution = match cmd {
Message::Get(tx) => self.on_get_access_method(tx),
Message::Set(tx, value) => self.on_set_access_method(tx, value),
Message::Next(tx) => self.on_next_connection_mode(tx),
Message::Next(tx) => self.on_next_connection_mode(tx).await,
Message::Update(tx, values) => self.on_update_access_methods(tx, values),
Message::Resolve(tx, setting) => self.on_resolve_access_method(tx, setting).await,
};
match execution {
Ok(_) => (),
Expand All @@ -291,13 +329,8 @@ impl AccessModeSelector {
Ok(())
}

fn on_get_access_method(&mut self, tx: ResponseTx<AccessMethodSetting>) -> Result<()> {
let value = self.get_access_method();
self.reply(tx, value)
}

fn get_access_method(&mut self) -> AccessMethodSetting {
self.connection_modes.peek()
fn on_get_access_method(&mut self, tx: ResponseTx<ResolvedConnectionMode>) -> Result<()> {
self.reply(tx, self.current.clone())
}

fn on_set_access_method(
Expand All @@ -315,36 +348,48 @@ impl AccessModeSelector {
self.connection_modes.set_access_method(value);
}

fn on_next_connection_mode(&mut self, tx: ResponseTx<ApiConnectionMode>) -> Result<()> {
let next = self.next_connection_mode();
async fn on_next_connection_mode(&mut self, tx: ResponseTx<ApiConnectionMode>) -> Result<()> {
let next = self.next_connection_mode().await?;
self.reply(tx, next)
}

async fn next_connection_mode(&mut self) -> Result<ApiConnectionMode> {
let access_method = self.connection_modes.next().ok_or(Error::NoAccessMethods)?;
log::info!(
"A new API access method has been selected: {name}",
name = access_method.name
);
let next = {
let resolved = self.resolve(access_method).await;
// Note: If the daemon is busy waiting for a call to this function
// to complete while we wait for the daemon to fully handle this
// `NewAccessMethodEvent`, then we find ourselves in a deadlock.
// This can happen during daemon startup when spawning a new
// `MullvadRestHandle`, which will call and await `next` on a Stream
// created from this `AccessModeSelector` instance. As such, the
// completion channel is discarded in this instance.
let _completion =
NewAccessMethodEvent::new(resolved.setting.clone(), resolved.endpoint.clone())
.send(&self.access_method_event_sender);
self.current = resolved.clone();
resolved
};

// Save the new connection mode to cache!
{
let cache_dir = self.cache_dir.clone();
let next = next.clone();
let new_connection_mode = next.connection_mode.clone();
tokio::spawn(async move {
if next.save(&cache_dir).await.is_err() {
if new_connection_mode.save(&cache_dir).await.is_err() {
log::warn!(
"Failed to save {connection_mode} to cache",
connection_mode = next
connection_mode = new_connection_mode
)
}
});
}
self.reply(tx, next)
Ok(next.connection_mode)
}

fn next_connection_mode(&mut self) -> ApiConnectionMode {
let access_method = self
.connection_modes
.next()
.map(|access_method_setting| access_method_setting.access_method)
.unwrap_or(AccessMethod::from(BuiltInAccessMethod::Direct));

let connection_mode = self.from(access_method);
log::info!("New API connection mode selected: {connection_mode}");
connection_mode
}

fn on_update_access_methods(
&mut self,
tx: ResponseTx<()>,
Expand Down Expand Up @@ -481,11 +526,6 @@ impl ConnectionModesIterator {
Ok(Box::new(access_methods.into_iter().cycle()))
}
}

/// Look at the currently active [`AccessMethod`]
pub fn peek(&self) -> AccessMethodSetting {
self.current.clone()
}
}

impl Iterator for ConnectionModesIterator {
Expand Down
18 changes: 11 additions & 7 deletions mullvad-daemon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,14 +696,17 @@ where
});

let connection_modes = settings.api_access_methods.collect_enabled();
let connection_modes_address_cache = api_runtime.address_cache.clone();

let connection_modes_handler = api::AccessModeSelector::spawn(
cache_dir.clone(),
relay_selector.clone(),
connection_modes,
internal_event_tx.to_specialized_sender(),
connection_modes_address_cache.clone(),
)
);
.await
.map_err(Error::ApiConnectionModeError)?;

let api_handle = api_runtime
.mullvad_rest_handle(Box::pin(connection_modes_handler.clone().into_stream()))
Expand Down Expand Up @@ -758,11 +761,6 @@ where
vec![]
};

let initial_api_endpoint =
api::get_allowed_endpoint(talpid_types::net::Endpoint::from_socket_address(
api_runtime.address_cache.get_address().await,
talpid_types::net::TransportProtocol::Tcp,
));
let parameters_generator = tunnel::ParametersGenerator::new(
account_manager.clone(),
relay_selector.clone(),
Expand All @@ -780,6 +778,11 @@ where
let _ = param_gen_tx.unbounded_send(settings.tunnel_options.to_owned());
});

let initial_api_endpoint = connection_modes_handler
.get_current()
.await
.map_err(Error::ApiConnectionModeError)?
.endpoint;
let (offline_state_tx, offline_state_rx) = mpsc::unbounded();
#[cfg(target_os = "windows")]
let (volume_update_tx, volume_update_rx) = mpsc::unbounded();
Expand Down Expand Up @@ -2403,8 +2406,9 @@ where
let handle = self.connection_modes_handler.clone();
tokio::spawn(async move {
let result = handle
.get_access_method()
.get_current()
.await
.map(|current| current.setting)
.map_err(Error::ApiConnectionModeError);
Self::oneshot_send(tx, result, "get_current_api_access_method response");
});
Expand Down

0 comments on commit fb35fc2

Please sign in to comment.