From 15b426ac1fdddf3c66bf79c2f4dc0b549267aa8a Mon Sep 17 00:00:00 2001 From: Markus Pettersson Date: Fri, 22 Dec 2023 13:40:51 +0100 Subject: [PATCH] Cleanup daemon code related to access methods --- mullvad-daemon/src/access_method.rs | 73 +++++++++++++---------------- 1 file changed, 32 insertions(+), 41 deletions(-) diff --git a/mullvad-daemon/src/access_method.rs b/mullvad-daemon/src/access_method.rs index 3bea2c4c7dae..d5c1bd80e497 100644 --- a/mullvad-daemon/src/access_method.rs +++ b/mullvad-daemon/src/access_method.rs @@ -34,18 +34,6 @@ pub enum Error { Settings(#[error(source)] settings::Error), } -/// A tiny datastructure used for signaling whether the daemon should force a -/// rotation of the currently used [`AccessMethodSetting`] or not, and if so: -/// how it should do it. -pub enum Command { - /// There is no need to force a rotation of [`AccessMethodSetting`] - Nothing, - /// Select the next available [`AccessMethodSetting`], whichever that is - Rotate, - /// Select the [`AccessMethodSetting`] with a certain [`access_method::Id`] - Set(access_method::Id), -} - impl Daemon where L: EventListener + Clone + Send + 'static, @@ -79,30 +67,29 @@ where &mut self, access_method: access_method::Id, ) -> Result<(), Error> { - // Make sure that we are not trying to remove a built-in API access - // method - let command = match self.settings.api_access_methods.find_by_id(&access_method) { - Some(api_access_method) => { - if api_access_method.is_builtin() { - Err(Error::RemoveBuiltIn) - } else if api_access_method.get_id() - == self.get_current_access_method().await?.get_id() - { - Ok(Command::Rotate) - } else { - Ok(Command::Nothing) - } + match self.settings.api_access_methods.find_by_id(&access_method) { + // Make sure that we are not trying to remove a built-in API access + // method + Some(api_access_method) if api_access_method.is_builtin() => { + return Err(Error::RemoveBuiltIn) } - None => Ok(Command::Nothing), - }?; + // If the currently active access method is removed, a new access + // method should trigger + Some(api_access_method) + if api_access_method.get_id() + == self.get_current_access_method().await?.get_id() => + { + self.force_api_endpoint_rotation().await?; + } + _ => (), + } self.settings .update(|settings| settings.api_access_methods.remove(&access_method)) .await .map(|did_change| self.notify_on_change(did_change)) - .map_err(Error::Settings)? - .process_command(command) - .await + .map(|_| ()) + .map_err(Error::Settings) } /// Set a [`AccessMethodSetting`] as the current API access method. @@ -119,9 +106,6 @@ where .set_access_method(access_method) .await?; // Force a rotation of Access Methods. - // - // This is not a call to `process_command` due to the restrictions on - // recursively calling async functions. self.force_api_endpoint_rotation().await } @@ -150,7 +134,8 @@ where // If the currently active access method is updated, we need to re-set // it after updating the settings. let current = self.get_current_access_method().await?; - let mut command = Command::Nothing; + // If the currently active access method is updated, we need to re-set it. + let mut refresh = None; let settings_update = |settings: &mut Settings| { let access_methods = &mut settings.api_access_methods; if let Some(access_method) = @@ -158,7 +143,7 @@ where { *access_method = access_method_update; if access_method.get_id() == current.get_id() { - command = Command::Set(access_method.get_id()) + refresh = Some(access_method.get_id()) } // We have to be a bit careful. If we are about to disable the last // remaining enabled access method, we would cause an inconsistent state @@ -185,19 +170,25 @@ where .update(settings_update) .await .map(|did_change| self.notify_on_change(did_change)) - .map_err(Error::Settings)? - .process_command(command) - .await + .map_err(Error::Settings)?; + if let Some(id) = refresh { + self.set_api_access_method(id).await?; + } + Ok(()) } /// Return the [`AccessMethodSetting`] which is currently used to access the /// Mullvad API. pub async fn get_current_access_method(&self) -> Result { - Ok(self.connection_modes_handler.get_access_method().await?) + self.connection_modes_handler + .get_current() + .await + .map(|current| current.setting) + .map_err(Error::ConnectionMode) } - /// Change which [`AccessMethodSetting`] which will be used to figure out - /// the Mullvad API endpoint. + /// Change which [`AccessMethodSetting`] which will be used as the Mullvad + /// API endpoint. async fn force_api_endpoint_rotation(&self) -> Result<(), Error> { self.api_handle .service()