diff --git a/ipa-core/benches/oneshot/ipa.rs b/ipa-core/benches/oneshot/ipa.rs index 4e72dea26..b24d7ea7b 100644 --- a/ipa-core/benches/oneshot/ipa.rs +++ b/ipa-core/benches/oneshot/ipa.rs @@ -116,6 +116,7 @@ async fn run(args: Args) -> Result<(), Error> { ..Default::default() }, initial_gate: Some(Gate::default().narrow(&IpaPrf)), + timeout: None, ..TestWorldConfig::default() }; // Construct TestWorld early to initialize logging. diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index 34a8e92e8..993b06dc9 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -1162,9 +1162,9 @@ mod tests { let a: Vec = repeat_with(|| rng.gen()).take(count).collect(); let b: Vec = repeat_with(|| rng.gen()).take(count).collect(); - // Timeout is 10 seconds plus count * (3 ms). + // Timeout is 20 seconds plus count * (5 ms). let config = TestWorldConfig::default() - .with_timeout_secs(10 + 3 * u64::try_from(count).unwrap() / 1000); + .with_timeout_secs(20 + 5 * u64::try_from(count).unwrap() / 1000); let [ab0, ab1, ab2]: [Vec>; 3] = TestWorld::::with_config(&config) diff --git a/ipa-core/src/protocol/hybrid/oprf.rs b/ipa-core/src/protocol/hybrid/oprf.rs index 652c6ac1d..63f2c7b17 100644 --- a/ipa-core/src/protocol/hybrid/oprf.rs +++ b/ipa-core/src/protocol/hybrid/oprf.rs @@ -221,7 +221,7 @@ mod test { const SHARDS: usize = 2; let world: TestWorld> = TestWorld::with_shards(TestWorldConfig { initial_gate: Some(Gate::default().narrow(&ProtocolStep::Hybrid)), - timeout: Duration::from_secs(60), + timeout: Some(Duration::from_secs(60)), ..Default::default() }); diff --git a/ipa-core/src/test_fixture/world.rs b/ipa-core/src/test_fixture/world.rs index d16a1a804..bba3547f5 100644 --- a/ipa-core/src/test_fixture/world.rs +++ b/ipa-core/src/test_fixture/world.rs @@ -110,7 +110,7 @@ pub struct TestWorld { rng: Mutex, gate_vendor: Box, _shard_network: InMemoryShardNetwork, - timeout: Duration, + timeout: Option, } #[derive(Clone)] @@ -161,9 +161,11 @@ pub struct TestWorldConfig { /// Timeout for tests run by this `TestWorld`. /// - /// This timeout is implement using tokio, so it will only be able to terminate the test if the + /// If `None`, there is no timeout. + /// + /// The timeout is implemented using tokio, so it will only be able to terminate the test if the /// futures are yielding periodically. - pub timeout: Duration, + pub timeout: Option, } impl ShardingScheme for NotSharded { @@ -387,14 +389,19 @@ impl TestWorld { } async fn with_timeout(&self, fut: F) -> F::Output { - if cfg!(feature = "shuttle") { - fut.await + let timeout = if cfg!(feature = "shuttle") { + None } else { - let Ok(output) = tokio::time::timeout(self.timeout, fut).await else { + self.timeout + }; + if let Some(timeout) = timeout { + let Ok(output) = tokio::time::timeout(timeout, fut).await else { tracing::error!("timed out after {:?}", self.timeout); panic!("timed out after {:?}", self.timeout); }; output + } else { + fut.await } } } @@ -414,7 +421,7 @@ impl Default for TestWorldConfig { seed: thread_rng().next_u64(), initial_gate: None, stream_interceptor: passthrough(), - timeout: Duration::from_secs(10), + timeout: Some(Duration::from_secs(10)), } } } @@ -434,7 +441,13 @@ impl TestWorldConfig { #[must_use] pub fn with_timeout_secs(mut self, timeout_secs: u64) -> Self { - self.timeout = Duration::from_secs(timeout_secs); + self.timeout = Some(Duration::from_secs(timeout_secs)); + self + } + + #[must_use] + pub fn with_no_timeout(mut self) -> Self { + self.timeout = None; self }