Skip to content

Commit

Permalink
Add a timeout to TestWorld
Browse files Browse the repository at this point in the history
  • Loading branch information
andyleiserson committed Dec 3, 2024
1 parent d3c7469 commit 1fb83fe
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 32 deletions.
6 changes: 4 additions & 2 deletions ipa-core/src/protocol/ipa_prf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,9 @@ pub mod tests {
dp::NoiseParams,
ipa_prf::{oprf_ipa, oprf_padding::PaddingParameters},
},
sharding::NotSharded,
test_executor::run,
test_fixture::{ipa::TestRawDataRecord, Reconstruct, Runner, TestWorld},
test_fixture::{ipa::TestRawDataRecord, Reconstruct, Runner, TestWorld, TestWorldConfig},
};

fn test_input(
Expand Down Expand Up @@ -660,7 +661,8 @@ pub mod tests {
let dp_params = DpMechanism::Binomial { epsilon };
let per_user_credit_cap = 2_f64.powi(i32::try_from(SS_BITS).unwrap());
let padding_params = PaddingParameters::relaxed();
let world = TestWorld::default();
let config = TestWorldConfig::default().with_timeout_secs(60);
let world = TestWorld::<NotSharded>::with_config(&config);

let records: Vec<TestRawDataRecord> = vec![
test_input(0, 12345, false, 1, 0),
Expand Down
90 changes: 60 additions & 30 deletions ipa-core/src/test_fixture/world.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ use std::{
array::from_fn,
borrow::Borrow,
fmt::Debug,
future::IntoFuture,
io::stdout,
iter::{self, zip},
marker::PhantomData,
sync::Mutex,
time::Duration,
};

use async_trait::async_trait;
Expand Down Expand Up @@ -108,6 +110,7 @@ pub struct TestWorld<S: ShardingScheme = NotSharded> {
rng: Mutex<StdRng>,
gate_vendor: Box<dyn TestGateVendor>,
_shard_network: InMemoryShardNetwork,
timeout: Duration,
}

#[derive(Clone)]
Expand Down Expand Up @@ -155,6 +158,12 @@ pub struct TestWorldConfig {
/// [`MaliciousHelper`]: crate::helpers::in_memory_config::MaliciousHelper
/// [`passthrough`]: crate::helpers::in_memory_config::passthrough
pub stream_interceptor: DynStreamInterceptor,

/// Timeout for tests run by this `TestWorld`.
///
/// This timeout is implement using tokio, so it will only be able to terminate the test if the
/// futures are yielding periodically.
pub timeout: Duration,
}

impl ShardingScheme for NotSharded {
Expand Down Expand Up @@ -347,6 +356,7 @@ impl<S: ShardingScheme> TestWorld<S> {
rng: Mutex::new(rng),
gate_vendor: gate_vendor(config.initial_gate.clone()),
_shard_network: shard_network,
timeout: config.timeout,
}
}

Expand Down Expand Up @@ -375,6 +385,14 @@ impl<S: ShardingScheme> TestWorld<S> {
// unfortunately take `&self`.
StdRng::from_seed(self.rng.lock().unwrap().gen())
}

async fn with_timeout<F: IntoFuture>(&self, fut: F) -> F::Output {
let Ok(output) = tokio::time::timeout(self.timeout, fut).await else {
tracing::error!("timed out after {:?}", self.timeout);
panic!("timed out after {:?}", self.timeout);
};
output
}
}

impl Default for TestWorldConfig {
Expand All @@ -392,6 +410,7 @@ impl Default for TestWorldConfig {
seed: thread_rng().next_u64(),
initial_gate: None,
stream_interceptor: passthrough(),
timeout: Duration::from_secs(10),
}
}
}
Expand All @@ -409,6 +428,12 @@ impl TestWorldConfig {
self
}

#[must_use]
pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self {
self.timeout = Duration::from_secs(timeout_secs);
self
}

#[must_use]
pub fn role_assignment(&self) -> &RoleAssignment {
const DEFAULT_ASSIGNMENT: RoleAssignment = RoleAssignment::new([
Expand Down Expand Up @@ -538,18 +563,20 @@ impl<const SHARDS: usize, D: Distribute> Runner<WithShards<SHARDS, D>>
// No clippy, you're wrong, it is not redundant, it allows shard_fn to be `Copy`
#[allow(clippy::redundant_closure)]
let shard_fn = |ctx, input| helper_fn(ctx, input);
zip(shards.into_iter(), zip(zip(h1, h2), h3))
.map(|(shard, ((h1, h2), h3))| {
ShardWorld::<WithShards<SHARDS, D>>::run_either(
shard.contexts(&gate),
self.metrics_handle.span(),
[h1, h2, h3],
shard_fn,
)
})
.collect::<FuturesOrdered<_>>()
.collect::<Vec<_>>()
.await
self.with_timeout(
zip(shards.into_iter(), zip(zip(h1, h2), h3))
.map(|(shard, ((h1, h2), h3))| {
ShardWorld::<WithShards<SHARDS, D>>::run_either(
shard.contexts(&gate),
self.metrics_handle.span(),
[h1, h2, h3],
shard_fn,
)
})
.collect::<FuturesOrdered<_>>()
.collect::<Vec<_>>(),
)
.await
}

async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> Vec<[O; 3]>
Expand All @@ -573,18 +600,20 @@ impl<const SHARDS: usize, D: Distribute> Runner<WithShards<SHARDS, D>>
// No clippy, you're wrong, it is not redundant, it allows shard_fn to be `Copy`
#[allow(clippy::redundant_closure)]
let shard_fn = |ctx, input| helper_fn(ctx, input);
zip(shards.into_iter(), zip(zip(h1, h2), h3))
.map(|(shard, ((h1, h2), h3))| {
ShardWorld::<WithShards<SHARDS, D>>::run_either(
shard.malicious_contexts(&gate),
self.metrics_handle.span(),
[h1, h2, h3],
shard_fn,
)
})
.collect::<FuturesOrdered<_>>()
.collect::<Vec<_>>()
.await
self.with_timeout(
zip(shards.into_iter(), zip(zip(h1, h2), h3))
.map(|(shard, ((h1, h2), h3))| {
ShardWorld::<WithShards<SHARDS, D>>::run_either(
shard.malicious_contexts(&gate),
self.metrics_handle.span(),
[h1, h2, h3],
shard_fn,
)
})
.collect::<FuturesOrdered<_>>()
.collect::<Vec<_>>(),
)
.await
}

async fn upgraded_malicious<'a, F, I, A, M, O, H, R, P>(
Expand Down Expand Up @@ -645,12 +674,12 @@ impl Runner<NotSharded> for TestWorld<NotSharded> {
H: Fn(Self::SemiHonestContext<'a>, A) -> R + Send + Sync,
R: Future<Output = O> + Send,
{
ShardWorld::<NotSharded>::run_either(
self.with_timeout(ShardWorld::<NotSharded>::run_either(
self.contexts(),
self.metrics_handle.span(),
input.share_with(&mut self.rng()),
helper_fn,
)
))
.await
}

Expand All @@ -662,12 +691,12 @@ impl Runner<NotSharded> for TestWorld<NotSharded> {
H: Fn(Self::MaliciousContext<'a>, A) -> R + Send + Sync,
R: Future<Output = O> + Send,
{
ShardWorld::<NotSharded>::run_either(
self.with_timeout(ShardWorld::<NotSharded>::run_either(
self.malicious_contexts(),
self.metrics_handle.span(),
input.share_with(&mut self.rng()),
helper_fn,
)
))
.await
}

Expand Down Expand Up @@ -747,7 +776,7 @@ impl Runner<NotSharded> for TestWorld<NotSharded> {
H: Fn(DZKPUpgradedSemiHonestContext<'a, NotSharded>, A) -> R + Send + Sync,
R: Future<Output = O> + Send,
{
ShardWorld::<NotSharded>::run_either(
self.with_timeout(ShardWorld::<NotSharded>::run_either(
self.contexts(),
self.metrics_handle.span(),
input.share(),
Expand All @@ -758,7 +787,7 @@ impl Runner<NotSharded> for TestWorld<NotSharded> {
v.validate().await.unwrap();
m_result
},
)
))
.await
}

Expand Down Expand Up @@ -852,6 +881,7 @@ impl<S: ShardingScheme> ShardWorld<S> {
}))
.instrument(span)
.await;

<[_; 3]>::try_from(output).unwrap()
}

Expand Down

0 comments on commit 1fb83fe

Please sign in to comment.