Skip to content

Commit

Permalink
De-duplicate find DaemonEvent function
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkusPettersson98 committed Jan 12, 2024
1 parent 3212fc7 commit 8c1e963
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 41 deletions.
51 changes: 20 additions & 31 deletions test/test-manager/src/tests/account.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ pub async fn test_automatic_wireguard_rotation(
rpc: ServiceClient,
mut mullvad_client: MullvadProxyClient,
) -> Result<(), Error> {
use futures::StreamExt;
// Make note of current WG key
let old_key = mullvad_client
.get_device()
Expand Down Expand Up @@ -340,36 +339,26 @@ pub async fn test_automatic_wireguard_rotation(
// Verify rotation has happened after a minute
const KEY_ROTATION_TIMEOUT: Duration = Duration::from_secs(100);

let mut event_stream = mullvad_client.events_listen().await.unwrap();
let get_pub_key_event = async {
loop {
// TODO(markus): See if this can be refactored. This is exactly the same as helpers:274.
match event_stream.next().await {
Some(Ok(DaemonEvent::Device(device_event))) => {
let pubkey = device_event
.new_state
.into_device()
.expect("Could not get device")
.device
.pubkey;
return Ok(pubkey);
}
Some(Ok(_)) => continue,
Some(Err(status)) => {
break Err(Error::Daemon(format!(
"Failed to get next event: {}",
status
)))
}
None => break Err(Error::Daemon(String::from("Lost daemon event stream"))),
}
}
};

let new_key = tokio::time::timeout(KEY_ROTATION_TIMEOUT, get_pub_key_event)
.await
.unwrap()
.unwrap();
let new_key = tokio::time::timeout(
KEY_ROTATION_TIMEOUT,
helpers::find_daemon_event(
mullvad_client.events_listen().await.unwrap(),
|daemon_event| match daemon_event {
DaemonEvent::Device(device_event) => Some(device_event),
_ => None,
},
),
)
.await
.map_err(|_error| Error::Daemon(String::from("Tunnel event listener timed out")))?
.map(|device_event| {
device_event
.new_state
.into_device()
.expect("Could not get device")
.device
.pubkey
})?;

assert_ne!(old_key, new_key);
Ok(())
Expand Down
26 changes: 16 additions & 10 deletions test/test-manager/src/tests/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,23 +263,29 @@ pub async fn find_next_tunnel_state(
) -> Result<mullvad_types::states::TunnelState, Error> {
tokio::time::timeout(
WAIT_FOR_TUNNEL_STATE_TIMEOUT,
find_next_tunnel_state_inner(stream, accept_state_fn),
find_daemon_event(stream, |daemon_event| match daemon_event {
DaemonEvent::TunnelState(state) if accept_state_fn(&state) => Some(state),
_ => None,
}),
)
.await
.map_err(|_error| Error::Daemon(String::from("Tunnel event listener timed out")))?
}

async fn find_next_tunnel_state_inner(
mut stream: impl futures::Stream<Item = Result<DaemonEvent, mullvad_management_interface::Error>>
pub async fn find_daemon_event<Accept, AcceptedEvent>(
mut event_stream: impl futures::Stream<Item = Result<DaemonEvent, mullvad_management_interface::Error>>
+ Unpin,
accept_state_fn: impl Fn(&mullvad_types::states::TunnelState) -> bool,
) -> Result<mullvad_types::states::TunnelState, Error> {
accept_event: Accept,
) -> Result<AcceptedEvent, Error>
where
Accept: Fn(DaemonEvent) -> Option<AcceptedEvent>,
{
loop {
match stream.next().await {
Some(Ok(DaemonEvent::TunnelState(state))) if accept_state_fn(&state) => {
return Ok(state)
}
Some(Ok(_)) => continue,
match event_stream.next().await {
Some(Ok(daemon_event)) => match accept_event(daemon_event) {
Some(accepted_event) => break Ok(accepted_event),
None => continue,
},
Some(Err(status)) => {
break Err(Error::Daemon(format!(
"Failed to get next event: {}",
Expand Down

0 comments on commit 8c1e963

Please sign in to comment.