diff --git a/ipa-core/benches/oneshot/ipa.rs b/ipa-core/benches/oneshot/ipa.rs index 4e72dea26..b24d7ea7b 100644 --- a/ipa-core/benches/oneshot/ipa.rs +++ b/ipa-core/benches/oneshot/ipa.rs @@ -116,6 +116,7 @@ async fn run(args: Args) -> Result<(), Error> { ..Default::default() }, initial_gate: Some(Gate::default().narrow(&IpaPrf)), + timeout: None, ..TestWorldConfig::default() }; // Construct TestWorld early to initialize logging. diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 4f0b47096..5cc726369 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -973,7 +973,7 @@ mod tests { seq_join::{seq_join, SeqJoin}, sharding::NotSharded, test_executor::run_random, - test_fixture::{join3v, Reconstruct, Runner, TestWorld}, + test_fixture::{join3v, Reconstruct, Runner, TestWorld, TestWorldConfig}, }; async fn test_select_semi_honest() @@ -1162,30 +1162,35 @@ mod tests { let a: Vec = repeat_with(|| rng.gen()).take(count).collect(); let b: Vec = repeat_with(|| rng.gen()).take(count).collect(); - let [ab0, ab1, ab2]: [Vec>; 3] = TestWorld::default() - .malicious( - zip(bit.clone(), zip(a.clone(), b.clone())), - |ctx, inputs| async move { - let v = ctx - .set_total_records(count) - .dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); - let m_ctx = v.context(); - - v.validated_seq_join(stream::iter(inputs).enumerate().map( - |(i, (bit_share, (a_share, b_share)))| { - let m_ctx = m_ctx.clone(); - async move { - select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share) - .await - } - }, - )) - .try_collect() - .await - }, - ) - .await - .map(Result::unwrap); + // Timeout is 10 seconds plus count * (3 ms). + let config = TestWorldConfig::default() + .with_timeout_secs(10 + 3 * u64::try_from(count).unwrap() / 1000); + + let [ab0, ab1, ab2]: [Vec>; 3] = + TestWorld::::with_config(&config) + .malicious( + zip(bit.clone(), zip(a.clone(), b.clone())), + |ctx, inputs| async move { + let v = ctx + .set_total_records(count) + .dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); + let m_ctx = v.context(); + + v.validated_seq_join(stream::iter(inputs).enumerate().map( + |(i, (bit_share, (a_share, b_share)))| { + let m_ctx = m_ctx.clone(); + async move { + select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share) + .await + } + }, + )) + .try_collect() + .await + }, + ) + .await + .map(Result::unwrap); let ab: Vec = [ab0, ab1, ab2].reconstruct(); @@ -1355,7 +1360,11 @@ mod tests { } } + // This test is much slower in the multi-threading config, perhaps because the + // amount of work it does for each record is very small compared to the overhead of + // spawning tasks. #[tokio::test] + #[cfg(not(feature = "multi-threading"))] async fn large_batch() { multi_select_malicious::(2 * TARGET_PROOF_SIZE, 2 * TARGET_PROOF_SIZE).await; } @@ -1371,7 +1380,10 @@ mod tests { let a: Vec = repeat_with(|| rng.gen()).take(count).collect(); let b: Vec = repeat_with(|| rng.gen()).take(count).collect(); - let [ab0, ab1, ab2]: [Vec>; 3] = TestWorld::default() + let config = TestWorldConfig::default().with_timeout_secs(60); + let world = TestWorld::::with_config(&config); + + let [ab0, ab1, ab2]: [Vec>; 3] = world .malicious( zip(bit.clone(), zip(a.clone(), b.clone())), |ctx, inputs| async move { diff --git a/ipa-core/src/protocol/dp/mod.rs b/ipa-core/src/protocol/dp/mod.rs index 6ec703c0f..fbd8263f4 100644 --- a/ipa-core/src/protocol/dp/mod.rs +++ b/ipa-core/src/protocol/dp/mod.rs @@ -619,6 +619,7 @@ mod test { replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing}, BitDecomposed, SharedValue, TransposeFrom, }, + sharding::NotSharded, telemetry::metrics::BYTES_SENT, test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig}, }; @@ -863,7 +864,8 @@ mod test { if std::env::var("EXEC_SLOW_TESTS").is_err() { return; } - let world = TestWorld::default(); + let config = TestWorldConfig::default().with_timeout_secs(60); + let world = TestWorld::::with_config(&config); let result: [Vec>; 3] = world .dzkp_semi_honest((), |ctx, ()| async move { Vec::transposed_from( @@ -898,7 +900,8 @@ mod test { type OutputValue = BA16; const NUM_BREAKDOWNS: u32 = 32; let num_bernoulli: u32 = 2000; - let world = TestWorld::default(); + let config = TestWorldConfig::default().with_timeout_secs(60); + let world = TestWorld::::with_config(&config); let result: [Vec>; 3] = world .dzkp_semi_honest((), |ctx, ()| async move { Vec::transposed_from( @@ -933,7 +936,8 @@ mod test { type OutputValue = BA16; const NUM_BREAKDOWNS: u32 = 256; let num_bernoulli: u32 = 1000; - let world = TestWorld::default(); + let config = TestWorldConfig::default().with_timeout_secs(60); + let world = TestWorld::::with_config(&config); let result: [Vec>; 3] = world .dzkp_semi_honest((), |ctx, ()| async move { Vec::transposed_from( diff --git a/ipa-core/src/protocol/hybrid/breakdown_reveal.rs b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs index be676bea6..2f96e9536 100644 --- a/ipa-core/src/protocol/hybrid/breakdown_reveal.rs +++ b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs @@ -336,10 +336,13 @@ pub mod tests { #[test] #[cfg(not(feature = "shuttle"))] // too slow fn breakdown_reveal_malicious_happy_path() { + use crate::test_fixture::TestWorldConfig; + type HV = BA16; const SHARDS: usize = 2; run(|| async { - let world = TestWorld::>::with_shards(TestWorldConfig::default()); + let config = TestWorldConfig::default().with_timeout_secs(60); + let world = TestWorld::>::with_shards(&config); let (inputs, expectation) = inputs_and_expectation(world.rng()); let result: Vec<_> = world diff --git a/ipa-core/src/protocol/hybrid/oprf.rs b/ipa-core/src/protocol/hybrid/oprf.rs index da2bf903f..63f2c7b17 100644 --- a/ipa-core/src/protocol/hybrid/oprf.rs +++ b/ipa-core/src/protocol/hybrid/oprf.rs @@ -200,7 +200,10 @@ where #[cfg(all(test, unit_test, feature = "in-memory-infra"))] mod test { - use std::collections::{HashMap, HashSet}; + use std::{ + collections::{HashMap, HashSet}, + time::Duration, + }; use ipa_step::StepNarrow; @@ -218,6 +221,7 @@ mod test { const SHARDS: usize = 2; let world: TestWorld> = TestWorld::with_shards(TestWorldConfig { initial_gate: Some(Gate::default().narrow(&ProtocolStep::Hybrid)), + timeout: Some(Duration::from_secs(60)), ..Default::default() }); diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index 198de6be9..8b799d0ac 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -392,9 +392,12 @@ pub mod tests { #[test] #[cfg(not(feature = "shuttle"))] // too slow fn malicious_happy_path() { + use crate::{sharding::NotSharded, test_fixture::TestWorldConfig}; + type HV = BA16; run(|| async { - let world = TestWorld::default(); + let config = TestWorldConfig::default().with_timeout_secs(60); + let world = TestWorld::::with_config(&config); let mut rng = world.rng(); let mut expectation = Vec::new(); for _ in 0..32 { diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs index 6a9adb345..e8d7631b7 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/mod.rs @@ -536,7 +536,7 @@ pub mod tests { proptest! { #[test] - fn aggregate_proptest( + fn aggregate_values_proptest( input_struct in arb_aggregate_values_inputs(PROP_MAX_INPUT_LEN), seed in any::(), ) { diff --git a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs index ba10d84bd..74955a9e1 100644 --- a/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs +++ b/ipa-core/src/protocol/ipa_prf/boolean_ops/share_conversion_aby.rs @@ -386,8 +386,9 @@ mod tests { rand::thread_rng, secret_sharing::SharedValue, seq_join::{seq_join, SeqJoin}, + sharding::NotSharded, test_executor::run, - test_fixture::{ReconstructArr, Runner, TestWorld}, + test_fixture::{ReconstructArr, Runner, TestWorld, TestWorldConfig}, }; #[test] @@ -457,7 +458,8 @@ mod tests { const COUNT: usize = CONV_CHUNK * PROOF_CHUNK * 2 + 1; const TOTAL_RECORDS: usize = COUNT.div_ceil(CONV_CHUNK); - let world = TestWorld::default(); + let config = TestWorldConfig::default().with_timeout_secs(60); + let world = TestWorld::::with_config(&config); let mut rng = thread_rng(); diff --git a/ipa-core/src/protocol/ipa_prf/mod.rs b/ipa-core/src/protocol/ipa_prf/mod.rs index 63dc34a6f..2659d31d4 100644 --- a/ipa-core/src/protocol/ipa_prf/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/mod.rs @@ -527,8 +527,9 @@ pub mod tests { dp::NoiseParams, ipa_prf::{oprf_ipa, oprf_padding::PaddingParameters}, }, + sharding::NotSharded, test_executor::run, - test_fixture::{ipa::TestRawDataRecord, Reconstruct, Runner, TestWorld}, + test_fixture::{ipa::TestRawDataRecord, Reconstruct, Runner, TestWorld, TestWorldConfig}, }; fn test_input( @@ -660,7 +661,8 @@ pub mod tests { let dp_params = DpMechanism::Binomial { epsilon }; let per_user_credit_cap = 2_f64.powi(i32::try_from(SS_BITS).unwrap()); let padding_params = PaddingParameters::relaxed(); - let world = TestWorld::default(); + let config = TestWorldConfig::default().with_timeout_secs(60); + let world = TestWorld::::with_config(&config); let records: Vec = vec![ test_input(0, 12345, false, 1, 0), diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index 91cfc87ea..bba3547f5 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -4,10 +4,12 @@ use std::{ array::from_fn, borrow::Borrow, fmt::Debug, + future::IntoFuture, io::stdout, iter::{self, zip}, marker::PhantomData, sync::Mutex, + time::Duration, }; use async_trait::async_trait; @@ -108,6 +110,7 @@ pub struct TestWorld { rng: Mutex, gate_vendor: Box, _shard_network: InMemoryShardNetwork, + timeout: Option, } #[derive(Clone)] @@ -155,6 +158,14 @@ pub struct TestWorldConfig { /// [`MaliciousHelper`]: crate::helpers::in_memory_config::MaliciousHelper /// [`passthrough`]: crate::helpers::in_memory_config::passthrough pub stream_interceptor: DynStreamInterceptor, + + /// Timeout for tests run by this `TestWorld`. + /// + /// If `None`, there is no timeout. + /// + /// The timeout is implemented using tokio, so it will only be able to terminate the test if the + /// futures are yielding periodically. + pub timeout: Option, } impl ShardingScheme for NotSharded { @@ -347,6 +358,7 @@ impl TestWorld { rng: Mutex::new(rng), gate_vendor: gate_vendor(config.initial_gate.clone()), _shard_network: shard_network, + timeout: config.timeout, } } @@ -375,6 +387,23 @@ impl TestWorld { // unfortunately take `&self`. StdRng::from_seed(self.rng.lock().unwrap().gen()) } + + async fn with_timeout(&self, fut: F) -> F::Output { + let timeout = if cfg!(feature = "shuttle") { + None + } else { + self.timeout + }; + if let Some(timeout) = timeout { + let Ok(output) = tokio::time::timeout(timeout, fut).await else { + tracing::error!("timed out after {:?}", self.timeout); + panic!("timed out after {:?}", self.timeout); + }; + output + } else { + fut.await + } + } } impl Default for TestWorldConfig { @@ -392,6 +421,7 @@ impl Default for TestWorldConfig { seed: thread_rng().next_u64(), initial_gate: None, stream_interceptor: passthrough(), + timeout: Some(Duration::from_secs(10)), } } } @@ -409,6 +439,18 @@ impl TestWorldConfig { self } + #[must_use] + pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self { + self.timeout = Some(Duration::from_secs(timeout_secs)); + self + } + + #[must_use] + pub fn with_no_timeout(mut self) -> Self { + self.timeout = None; + self + } + #[must_use] pub fn role_assignment(&self) -> &RoleAssignment { const DEFAULT_ASSIGNMENT: RoleAssignment = RoleAssignment::new([ @@ -538,18 +580,20 @@ impl Runner> // No clippy, you're wrong, it is not redundant, it allows shard_fn to be `Copy` #[allow(clippy::redundant_closure)] let shard_fn = |ctx, input| helper_fn(ctx, input); - zip(shards.into_iter(), zip(zip(h1, h2), h3)) - .map(|(shard, ((h1, h2), h3))| { - ShardWorld::>::run_either( - shard.contexts(&gate), - self.metrics_handle.span(), - [h1, h2, h3], - shard_fn, - ) - }) - .collect::>() - .collect::>() - .await + self.with_timeout( + zip(shards.into_iter(), zip(zip(h1, h2), h3)) + .map(|(shard, ((h1, h2), h3))| { + ShardWorld::>::run_either( + shard.contexts(&gate), + self.metrics_handle.span(), + [h1, h2, h3], + shard_fn, + ) + }) + .collect::>() + .collect::>(), + ) + .await } async fn malicious<'a, I, A, O, H, R>(&'a self, input: I, helper_fn: H) -> Vec<[O; 3]> @@ -573,18 +617,20 @@ impl Runner> // No clippy, you're wrong, it is not redundant, it allows shard_fn to be `Copy` #[allow(clippy::redundant_closure)] let shard_fn = |ctx, input| helper_fn(ctx, input); - zip(shards.into_iter(), zip(zip(h1, h2), h3)) - .map(|(shard, ((h1, h2), h3))| { - ShardWorld::>::run_either( - shard.malicious_contexts(&gate), - self.metrics_handle.span(), - [h1, h2, h3], - shard_fn, - ) - }) - .collect::>() - .collect::>() - .await + self.with_timeout( + zip(shards.into_iter(), zip(zip(h1, h2), h3)) + .map(|(shard, ((h1, h2), h3))| { + ShardWorld::>::run_either( + shard.malicious_contexts(&gate), + self.metrics_handle.span(), + [h1, h2, h3], + shard_fn, + ) + }) + .collect::>() + .collect::>(), + ) + .await } async fn upgraded_malicious<'a, F, I, A, M, O, H, R, P>( @@ -645,12 +691,12 @@ impl Runner for TestWorld { H: Fn(Self::SemiHonestContext<'a>, A) -> R + Send + Sync, R: Future + Send, { - ShardWorld::::run_either( + self.with_timeout(ShardWorld::::run_either( self.contexts(), self.metrics_handle.span(), input.share_with(&mut self.rng()), helper_fn, - ) + )) .await } @@ -662,12 +708,12 @@ impl Runner for TestWorld { H: Fn(Self::MaliciousContext<'a>, A) -> R + Send + Sync, R: Future + Send, { - ShardWorld::::run_either( + self.with_timeout(ShardWorld::::run_either( self.malicious_contexts(), self.metrics_handle.span(), input.share_with(&mut self.rng()), helper_fn, - ) + )) .await } @@ -747,7 +793,7 @@ impl Runner for TestWorld { H: Fn(DZKPUpgradedSemiHonestContext<'a, NotSharded>, A) -> R + Send + Sync, R: Future + Send, { - ShardWorld::::run_either( + self.with_timeout(ShardWorld::::run_either( self.contexts(), self.metrics_handle.span(), input.share(), @@ -758,7 +804,7 @@ impl Runner for TestWorld { v.validate().await.unwrap(); m_result }, - ) + )) .await } @@ -852,6 +898,7 @@ impl ShardWorld { })) .instrument(span) .await; + <[_; 3]>::try_from(output).unwrap() }