Skip to content

Commit

Permalink
Use type safe Mullvad protobuf client over generated dito
Browse files Browse the repository at this point in the history
Re-write some code in the test framework to prefer the type safe wrapper
around the Mullvad app gRPC client instead of its auto-generated dito.

`ManagementServiceClient` is automatically generated from the protobuf
definitions found in `management_interface.proto`, and contains some
very crude types. The `MullvadProxyClient` is a type-safe wrapper around
`ManagementServiceClient` which performs conversions & validation of the
data types from the gRPC server (the daemon) to their respective
mappings in the `talpid-*` and `mullvad-*` crates. These types are more
ergonomic to work with, and since we already have the conversions in
place we should prefer those.
  • Loading branch information
MarkusPettersson98 committed Jan 12, 2024
1 parent 2f05885 commit 7425298
Show file tree
Hide file tree
Showing 14 changed files with 274 additions and 352 deletions.
3 changes: 3 additions & 0 deletions test/clippy.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
disallowed-types = [
{ path = "mullvad_management_interface::ManagementServiceClient", reason = "use `mullvad_management_interface::MullvadProxyClient` instead" },
]
2 changes: 1 addition & 1 deletion test/test-manager/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ with the `#[test_function]` attribute
#[test_function]
pub async fn test(
rpc: ServiceClient,
mut mullvad_client: mullvad_management_interface::ManagementServiceClient,
mut mullvad_client: mullvad_management_interface::MullvadProxyClient,
) -> Result<(), Error> {
Ok(())
}
Expand Down
7 changes: 4 additions & 3 deletions test/test-manager/src/mullvad_daemon.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#![allow(clippy::disallowed_types)]
use std::{io, time::Duration};

use futures::{channel::mpsc, future::BoxFuture, pin_mut, FutureExt, SinkExt, StreamExt};
use mullvad_management_interface::ManagementServiceClient;
use mullvad_management_interface::{ManagementServiceClient, MullvadProxyClient};
use test_rpc::{
mullvad_daemon::MullvadClientVersion,
transport::{ConnectionHandle, GrpcForwarder},
Expand Down Expand Up @@ -61,7 +62,7 @@ impl RpcClientProvider {
}
}

pub async fn new_client(&self) -> ManagementServiceClient {
pub async fn new_client(&self) -> MullvadProxyClient {
// FIXME: Ugly workaround to ensure that we don't receive stuff from a
// previous RPC session.
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
Expand All @@ -72,7 +73,7 @@ impl RpcClientProvider {
.await
.unwrap();

ManagementServiceClient::new(channel)
MullvadProxyClient::from_rpc_client(ManagementServiceClient::new(channel))
}
}

Expand Down
4 changes: 2 additions & 2 deletions test/test-manager/src/run_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
};
use anyhow::{Context, Result};
use futures::FutureExt;
use mullvad_management_interface::ManagementServiceClient;
use mullvad_management_interface::MullvadProxyClient;
use std::future::Future;
use std::panic;
use std::time::Duration;
Expand Down Expand Up @@ -84,7 +84,7 @@ pub async fn run(
.as_type(test.mullvad_client_version)
.await;

if let Some(client) = mclient.downcast_mut::<ManagementServiceClient>() {
if let Some(client) = mclient.downcast_mut::<MullvadProxyClient>() {
crate::tests::init_default_settings(client).await;
}

Expand Down
124 changes: 60 additions & 64 deletions test/test-manager/src/tests/account.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use super::config::TEST_CONFIG;
use super::{helpers, ui, Error, TestContext};
use mullvad_api::DevicesProxy;
use mullvad_management_interface::{types, Code, ManagementServiceClient};
use mullvad_types::device::Device;
use mullvad_management_interface::{self, client::DaemonEvent, Code, MullvadProxyClient};
use mullvad_types::device::{Device, DeviceState};
use mullvad_types::states::TunnelState;
use std::net::ToSocketAddrs;
use std::time::Duration;
Expand All @@ -17,7 +17,7 @@ const THROTTLE_RETRY_DELAY: Duration = Duration::from_secs(120);
pub async fn test_login(
_: TestContext,
_rpc: ServiceClient,
mut mullvad_client: ManagementServiceClient,
mut mullvad_client: MullvadProxyClient,
) -> Result<(), Error> {
//
// Instruct daemon to log in
Expand All @@ -33,7 +33,7 @@ pub async fn test_login(
.expect("login failed");

// Wait for the relay list to be updated
helpers::ensure_updated_relay_list(&mut mullvad_client).await;
helpers::ensure_updated_relay_list(&mut mullvad_client).await?;

Ok(())
}
Expand All @@ -44,12 +44,12 @@ pub async fn test_login(
pub async fn test_logout(
_: TestContext,
_rpc: ServiceClient,
mut mullvad_client: ManagementServiceClient,
mut mullvad_client: MullvadProxyClient,
) -> Result<(), Error> {
log::info!("Removing device");

mullvad_client
.logout_account(())
.logout_account()
.await
.expect("logout failed");

Expand All @@ -61,7 +61,7 @@ pub async fn test_logout(
pub async fn test_too_many_devices(
_: TestContext,
rpc: ServiceClient,
mut mullvad_client: ManagementServiceClient,
mut mullvad_client: MullvadProxyClient,
) -> Result<(), Error> {
log::info!("Using up all devices");

Expand Down Expand Up @@ -97,7 +97,9 @@ pub async fn test_too_many_devices(
log::info!("Log in with too many devices");
let login_result = login_with_retries(&mut mullvad_client).await;

assert!(matches!(login_result, Err(status) if status.code() == Code::ResourceExhausted));
assert!(
matches!(login_result, Err(mullvad_management_interface::Error::Rpc(status)) if status.code() == Code::ResourceExhausted)
);

// Run UI test
let ui_result = ui::run_test_env(
Expand Down Expand Up @@ -128,22 +130,20 @@ pub async fn test_too_many_devices(
pub async fn test_revoked_device(
_: TestContext,
rpc: ServiceClient,
mut mullvad_client: ManagementServiceClient,
mut mullvad_client: MullvadProxyClient,
) -> Result<(), Error> {
log::info!("Logging in/generating device");
login_with_retries(&mut mullvad_client)
.await
.expect("login failed");

let device_id = mullvad_client
.get_device(())
.get_device()
.await
.expect("failed to get device data")
.into_inner()
.device
.into_device()
.unwrap()
.device
.unwrap()
.id;

helpers::connect_and_wait(&mut mullvad_client).await?;
Expand All @@ -165,16 +165,16 @@ pub async fn test_revoked_device(
// Begin listening to tunnel state changes first, so that we catch changes due to
// `update_device`.
let events = mullvad_client
.events_listen(())
.events_listen()
.await
.expect("failed to begin listening for state changes")
.into_inner();
.expect("failed to begin listening for state changes");
let next_state =
helpers::find_next_tunnel_state(events, |state| matches!(state, TunnelState::Error(..),));

log::debug!("Update device state");

let _update_status = mullvad_client.update_device(()).await;
// Update the device status, which performs a device validation.
let _ = mullvad_client.update_device().await;

// Ensure that the tunnel state transitions to "error". Fail if it transitions to some other
// state.
Expand All @@ -186,12 +186,11 @@ pub async fn test_revoked_device(

// Verify that the device state is `Revoked`.
let device_state = mullvad_client
.get_device(())
.get_device()
.await
.expect("failed to get device data");
assert_eq!(
device_state.into_inner().state,
i32::from(types::device_state::State::Revoked),
assert!(
matches!(device_state, DeviceState::Revoked),
"expected device to be revoked"
);

Expand Down Expand Up @@ -244,28 +243,26 @@ pub async fn new_device_client() -> DevicesProxy {

/// Log in and retry if it fails due to throttling
pub async fn login_with_retries(
mullvad_client: &mut ManagementServiceClient,
) -> Result<(), mullvad_management_interface::Status> {
mullvad_client: &mut MullvadProxyClient,
) -> Result<(), mullvad_management_interface::Error> {
loop {
let result = mullvad_client
match mullvad_client
.login_account(TEST_CONFIG.account_number.clone())
.await;
.await
{
Err(mullvad_management_interface::Error::Rpc(status))
if status.message().to_uppercase().contains("THROTTLED") =>
{
// Work around throttling errors by sleeping
log::debug!(
"Login failed due to throttling. Sleeping for {} seconds",
THROTTLE_RETRY_DELAY.as_secs()
);

if let Err(error) = result {
if !error.message().contains("THROTTLED") {
return Err(error);
tokio::time::sleep(THROTTLE_RETRY_DELAY).await;
}

// Work around throttling errors by sleeping

log::debug!(
"Login failed due to throttling. Sleeping for {} seconds",
THROTTLE_RETRY_DELAY.as_secs()
);

tokio::time::sleep(THROTTLE_RETRY_DELAY).await;
} else {
break Ok(());
Err(err) => return Err(err),
Ok(_) => break Ok(()),
}
}
}
Expand Down Expand Up @@ -306,18 +303,17 @@ pub async fn retry_if_throttled<
pub async fn test_automatic_wireguard_rotation(
ctx: TestContext,
rpc: ServiceClient,
mut mullvad_client: ManagementServiceClient,
mut mullvad_client: MullvadProxyClient,
) -> Result<(), Error> {
use futures::StreamExt;
// Make note of current WG key
let old_key = mullvad_client
.get_device(())
.get_device()
.await
.expect("Could not get device")
.into_inner()
.device
.unwrap()
.into_device()
.expect("Could not get device")
.device
.unwrap()
.pubkey;

// Stop daemon
Expand All @@ -343,29 +339,29 @@ pub async fn test_automatic_wireguard_rotation(
// Verify rotation has happened after a minute
const KEY_ROTATION_TIMEOUT: Duration = Duration::from_secs(100);

let mut event_stream = mullvad_client.events_listen(()).await.unwrap().into_inner();
let mut event_stream = mullvad_client.events_listen().await.unwrap();
let get_pub_key_event = async {
loop {
let message = event_stream.message().await;
if let Ok(Some(event)) = message {
match event.event.unwrap() {
mullvad_management_interface::types::daemon_event::Event::Device(
device_event,
) => {
let pubkey = device_event
.new_state
.unwrap()
.device
.unwrap()
.device
.unwrap()
.pubkey;
return Ok(pubkey);
}
_ => continue,
// TODO(markus): See if this can be refactored. This is exactly the same as helpers:274.
match event_stream.next().await {
Some(Ok(DaemonEvent::Device(device_event))) => {
let pubkey = device_event
.new_state
.into_device()
.expect("Could not get device")
.device
.pubkey;
return Ok(pubkey);
}
Some(Ok(_)) => continue,
Some(Err(status)) => {
break Err(Error::Daemon(format!(
"Failed to get next event: {}",
status
)))
}
None => break Err(Error::Daemon(String::from("Lost daemon event stream"))),
}
return Err(message);
}
};

Expand Down
Loading

0 comments on commit 7425298

Please sign in to comment.