Skip to content

Commit

Permalink
Add a timeout to TestWorld (#1467)
Browse files Browse the repository at this point in the history
  • Loading branch information
andyleiserson authored Dec 3, 2024
1 parent e4d833d commit dca5be7
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 67 deletions.
1 change: 1 addition & 0 deletions ipa-core/benches/oneshot/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ async fn run(args: Args) -> Result<(), Error> {
..Default::default()
},
initial_gate: Some(Gate::default().narrow(&IpaPrf)),
timeout: None,
..TestWorldConfig::default()
};
// Construct TestWorld early to initialize logging.
Expand Down
64 changes: 38 additions & 26 deletions ipa-core/src/protocol/context/dzkp_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,7 @@ mod tests {
seq_join::{seq_join, SeqJoin},
sharding::NotSharded,
test_executor::run_random,
test_fixture::{join3v, Reconstruct, Runner, TestWorld},
test_fixture::{join3v, Reconstruct, Runner, TestWorld, TestWorldConfig},
};

async fn test_select_semi_honest<V>()
Expand Down Expand Up @@ -1162,30 +1162,35 @@ mod tests {
let a: Vec<V> = repeat_with(|| rng.gen()).take(count).collect();
let b: Vec<V> = repeat_with(|| rng.gen()).take(count).collect();

let [ab0, ab1, ab2]: [Vec<Replicated<V>>; 3] = TestWorld::default()
.malicious(
zip(bit.clone(), zip(a.clone(), b.clone())),
|ctx, inputs| async move {
let v = ctx
.set_total_records(count)
.dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate);
let m_ctx = v.context();

v.validated_seq_join(stream::iter(inputs).enumerate().map(
|(i, (bit_share, (a_share, b_share)))| {
let m_ctx = m_ctx.clone();
async move {
select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share)
.await
}
},
))
.try_collect()
.await
},
)
.await
.map(Result::unwrap);
// Timeout is 10 seconds plus count * (3 ms).
let config = TestWorldConfig::default()
.with_timeout_secs(10 + 3 * u64::try_from(count).unwrap() / 1000);

let [ab0, ab1, ab2]: [Vec<Replicated<V>>; 3] =
TestWorld::<NotSharded>::with_config(&config)
.malicious(
zip(bit.clone(), zip(a.clone(), b.clone())),
|ctx, inputs| async move {
let v = ctx
.set_total_records(count)
.dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate);
let m_ctx = v.context();

v.validated_seq_join(stream::iter(inputs).enumerate().map(
|(i, (bit_share, (a_share, b_share)))| {
let m_ctx = m_ctx.clone();
async move {
select(m_ctx, RecordId::from(i), &bit_share, &a_share, &b_share)
.await
}
},
))
.try_collect()
.await
},
)
.await
.map(Result::unwrap);

let ab: Vec<V> = [ab0, ab1, ab2].reconstruct();

Expand Down Expand Up @@ -1355,7 +1360,11 @@ mod tests {
}
}

// This test is much slower in the multi-threading config, perhaps because the
// amount of work it does for each record is very small compared to the overhead of
// spawning tasks.
#[tokio::test]
#[cfg(not(feature = "multi-threading"))]
async fn large_batch() {
multi_select_malicious::<BA8>(2 * TARGET_PROOF_SIZE, 2 * TARGET_PROOF_SIZE).await;
}
Expand All @@ -1371,7 +1380,10 @@ mod tests {
let a: Vec<BA8> = repeat_with(|| rng.gen()).take(count).collect();
let b: Vec<BA8> = repeat_with(|| rng.gen()).take(count).collect();

let [ab0, ab1, ab2]: [Vec<Replicated<BA8>>; 3] = TestWorld::default()
let config = TestWorldConfig::default().with_timeout_secs(60);
let world = TestWorld::<NotSharded>::with_config(&config);

let [ab0, ab1, ab2]: [Vec<Replicated<BA8>>; 3] = world
.malicious(
zip(bit.clone(), zip(a.clone(), b.clone())),
|ctx, inputs| async move {
Expand Down
10 changes: 7 additions & 3 deletions ipa-core/src/protocol/dp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ mod test {
replicated::{semi_honest::AdditiveShare as Replicated, ReplicatedSecretSharing},
BitDecomposed, SharedValue, TransposeFrom,
},
sharding::NotSharded,
telemetry::metrics::BYTES_SENT,
test_fixture::{Reconstruct, Runner, TestWorld, TestWorldConfig},
};
Expand Down Expand Up @@ -863,7 +864,8 @@ mod test {
if std::env::var("EXEC_SLOW_TESTS").is_err() {
return;
}
let world = TestWorld::default();
let config = TestWorldConfig::default().with_timeout_secs(60);
let world = TestWorld::<NotSharded>::with_config(&config);
let result: [Vec<Replicated<OutputValue>>; 3] = world
.dzkp_semi_honest((), |ctx, ()| async move {
Vec::transposed_from(
Expand Down Expand Up @@ -898,7 +900,8 @@ mod test {
type OutputValue = BA16;
const NUM_BREAKDOWNS: u32 = 32;
let num_bernoulli: u32 = 2000;
let world = TestWorld::default();
let config = TestWorldConfig::default().with_timeout_secs(60);
let world = TestWorld::<NotSharded>::with_config(&config);
let result: [Vec<Replicated<OutputValue>>; 3] = world
.dzkp_semi_honest((), |ctx, ()| async move {
Vec::transposed_from(
Expand Down Expand Up @@ -933,7 +936,8 @@ mod test {
type OutputValue = BA16;
const NUM_BREAKDOWNS: u32 = 256;
let num_bernoulli: u32 = 1000;
let world = TestWorld::default();
let config = TestWorldConfig::default().with_timeout_secs(60);
let world = TestWorld::<NotSharded>::with_config(&config);
let result: [Vec<Replicated<OutputValue>>; 3] = world
.dzkp_semi_honest((), |ctx, ()| async move {
Vec::transposed_from(
Expand Down
5 changes: 4 additions & 1 deletion ipa-core/src/protocol/hybrid/breakdown_reveal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,10 +336,13 @@ pub mod tests {
#[test]
#[cfg(not(feature = "shuttle"))] // too slow
fn breakdown_reveal_malicious_happy_path() {
use crate::test_fixture::TestWorldConfig;

type HV = BA16;
const SHARDS: usize = 2;
run(|| async {
let world = TestWorld::<WithShards<SHARDS>>::with_shards(TestWorldConfig::default());
let config = TestWorldConfig::default().with_timeout_secs(60);
let world = TestWorld::<WithShards<SHARDS>>::with_shards(&config);
let (inputs, expectation) = inputs_and_expectation(world.rng());

let result: Vec<_> = world
Expand Down
6 changes: 5 additions & 1 deletion ipa-core/src/protocol/hybrid/oprf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,10 @@ where

#[cfg(all(test, unit_test, feature = "in-memory-infra"))]
mod test {
use std::collections::{HashMap, HashSet};
use std::{
collections::{HashMap, HashSet},
time::Duration,
};

use ipa_step::StepNarrow;

Expand All @@ -218,6 +221,7 @@ mod test {
const SHARDS: usize = 2;
let world: TestWorld<WithShards<SHARDS>> = TestWorld::with_shards(TestWorldConfig {
initial_gate: Some(Gate::default().narrow(&ProtocolStep::Hybrid)),
timeout: Some(Duration::from_secs(60)),
..Default::default()
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,12 @@ pub mod tests {
#[test]
#[cfg(not(feature = "shuttle"))] // too slow
fn malicious_happy_path() {
use crate::{sharding::NotSharded, test_fixture::TestWorldConfig};

type HV = BA16;
run(|| async {
let world = TestWorld::default();
let config = TestWorldConfig::default().with_timeout_secs(60);
let world = TestWorld::<NotSharded>::with_config(&config);
let mut rng = world.rng();
let mut expectation = Vec::new();
for _ in 0..32 {
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/protocol/ipa_prf/aggregation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ pub mod tests {

proptest! {
#[test]
fn aggregate_proptest(
fn aggregate_values_proptest(
input_struct in arb_aggregate_values_inputs(PROP_MAX_INPUT_LEN),
seed in any::<u64>(),
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,9 @@ mod tests {
rand::thread_rng,
secret_sharing::SharedValue,
seq_join::{seq_join, SeqJoin},
sharding::NotSharded,
test_executor::run,
test_fixture::{ReconstructArr, Runner, TestWorld},
test_fixture::{ReconstructArr, Runner, TestWorld, TestWorldConfig},
};

#[test]
Expand Down Expand Up @@ -457,7 +458,8 @@ mod tests {
const COUNT: usize = CONV_CHUNK * PROOF_CHUNK * 2 + 1;
const TOTAL_RECORDS: usize = COUNT.div_ceil(CONV_CHUNK);

let world = TestWorld::default();
let config = TestWorldConfig::default().with_timeout_secs(60);
let world = TestWorld::<NotSharded>::with_config(&config);

let mut rng = thread_rng();

Expand Down
6 changes: 4 additions & 2 deletions ipa-core/src/protocol/ipa_prf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,9 @@ pub mod tests {
dp::NoiseParams,
ipa_prf::{oprf_ipa, oprf_padding::PaddingParameters},
},
sharding::NotSharded,
test_executor::run,
test_fixture::{ipa::TestRawDataRecord, Reconstruct, Runner, TestWorld},
test_fixture::{ipa::TestRawDataRecord, Reconstruct, Runner, TestWorld, TestWorldConfig},
};

fn test_input(
Expand Down Expand Up @@ -660,7 +661,8 @@ pub mod tests {
let dp_params = DpMechanism::Binomial { epsilon };
let per_user_credit_cap = 2_f64.powi(i32::try_from(SS_BITS).unwrap());
let padding_params = PaddingParameters::relaxed();
let world = TestWorld::default();
let config = TestWorldConfig::default().with_timeout_secs(60);
let world = TestWorld::<NotSharded>::with_config(&config);

let records: Vec<TestRawDataRecord> = vec![
test_input(0, 12345, false, 1, 0),
Expand Down
Loading

0 comments on commit dca5be7

Please sign in to comment.