diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 18779b545..aa6da0bb4 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -14,8 +14,8 @@ use hyper::http::uri::Scheme; use ipa_core::{ cli::{ playbook::{ - make_clients, playbook_oprf_ipa, run_query_and_validate, validate, validate_dp, - InputSource, + make_clients, make_sharded_clients, playbook_oprf_ipa, run_hybrid_query_and_validate, + run_query_and_validate, validate, validate_dp, HybridQueryResult, InputSource, }, CsvSerializer, IpaQueryResult, Verbosity, }, @@ -60,6 +60,9 @@ struct Args { #[arg(long, value_name = "OUTPUT_FILE")] output_file: Option<PathBuf>, + #[arg(long, default_value_t = 1)] + shard_count: usize, + #[command(subcommand)] action: ReportCollectorCommand, } @@ -178,7 +181,20 @@ async fn main() -> Result<(), Box<dyn Error>> { Scheme::HTTPS }; - let (clients, network) = make_clients(args.network.as_deref(), scheme, args.wait).await; + let (clients, networks) = if args.shard_count == 1 { + let (c, n) = make_clients(args.network.as_deref(), scheme, args.wait).await; + (vec![c], vec![n]) + } else { + make_sharded_clients( + args.network + .as_deref() + .expect("Network.toml is required for sharded queries"), + scheme, + args.wait, + ) + .await + }; + match args.action { ReportCollectorCommand::GenIpaInputs { count, @@ -193,20 +209,20 @@ async fn main() -> Result<(), Box<dyn Error>> { ReportCollectorCommand::SemiHonestOprfIpaTest(config) => { ipa_test( &args, - &network, + &networks[0], IpaSecurityModel::SemiHonest, config, - &clients, + &clients[0], ) .await? } ReportCollectorCommand::MaliciousOprfIpaTest(config) => { ipa_test( &args, - &network, + &networks[0], IpaSecurityModel::Malicious, config, - &clients, + &clients[0], ) .await? } @@ -218,7 +234,7 @@ &args, IpaSecurityModel::Malicious, ipa_query_config, - &clients, + &clients[0], encrypted_inputs, ) .await? 
@@ -231,7 +247,7 @@ async fn main() -> Result<(), Box<dyn Error>> { &args, IpaSecurityModel::SemiHonest, ipa_query_config, - &clients, + &clients[0], encrypted_inputs, ) .await? @@ -239,7 +255,7 @@ ReportCollectorCommand::MaliciousHybrid { ref encrypted_inputs, hybrid_query_config, - } => hybrid(&args, hybrid_query_config, &clients, encrypted_inputs).await?, + } => hybrid(&args, hybrid_query_config, clients, encrypted_inputs).await?, }; Ok(()) @@ -342,10 +358,49 @@ fn write_ipa_output_file( Ok(()) } +fn write_hybrid_output_file( + path: &PathBuf, + query_result: &HybridQueryResult, +) -> Result<(), Box<dyn Error>> { + // it will be sad to lose the results if file already exists. + let path = if Path::is_file(path) { + let mut new_file_name = thread_rng() + .sample_iter(&Alphanumeric) + .take(5) + .map(char::from) + .collect::<String>(); + let file_name = path.file_stem().ok_or("not a file")?; + + new_file_name.insert(0, '-'); + new_file_name.insert_str(0, &file_name.to_string_lossy()); + tracing::warn!( + "{} file exists, renaming to {:?}", + path.display(), + new_file_name + ); + + // it will not be 100% accurate until file_prefix API is stabilized + Cow::Owned( + path.with_file_name(&new_file_name) + .with_extension(path.extension().unwrap_or("".as_ref())), + ) + } else { + Cow::Borrowed(path) + }; + let mut file = File::options() + .write(true) + .create_new(true) + .open(path.deref()) + .map_err(|e| format!("Failed to create output file {}: {e}", path.display()))?; + + write!(file, "{}", serde_json::to_string_pretty(query_result)?)?; + Ok(()) +} + async fn hybrid( args: &Args, hybrid_query_config: HybridQueryParams, - helper_clients: &[IpaHttpClient<Helper>; 3], + helper_clients: Vec<[IpaHttpClient<Helper>; 3]>, encrypted_inputs: &EncryptedInputs, ) -> Result<(), Box<dyn Error>> { let query_type = QueryType::MaliciousHybrid(hybrid_query_config); @@ -365,7 +420,7 @@ async fn hybrid( query_type, }; - let query_id = helper_clients[0] + let query_id = helper_clients[0][0] 
.create_query(query_config) .await .expect("Unable to create query!"); @@ -374,17 +429,17 @@ // the value for histogram values (BA32) must be kept in sync with the server-side // implementation, otherwise a runtime reconstruct error will be generated. // see ipa-core/src/query/executor.rs - let actual = run_query_and_validate::<BA32>( + let actual = run_hybrid_query_and_validate::<BA32>( encrypted_report_streams.streams, encrypted_report_streams.query_size, helper_clients, query_id, - hybrid_query_config.into(), + hybrid_query_config, ) .await; if let Some(ref path) = args.output_file { - write_ipa_output_file(path, &actual)?; + write_hybrid_output_file(path, &actual)?; } else { println!("{}", serde_json::to_string_pretty(&actual)?); } diff --git a/ipa-core/src/bin/test_mpc.rs b/ipa-core/src/bin/test_mpc.rs index dc509485f..baf99a2ca 100644 --- a/ipa-core/src/bin/test_mpc.rs +++ b/ipa-core/src/bin/test_mpc.rs @@ -119,7 +119,7 @@ async fn main() -> Result<(), Box<dyn Error>> { } TestAction::ShardedShuffle => { // we need clients to talk to each individual shard - let clients = make_sharded_clients( + let (clients, _networks) = make_sharded_clients( args.network .as_deref() .expect("network config is required for sharded shuffle"), diff --git a/ipa-core/src/cli/playbook/hybrid.rs b/ipa-core/src/cli/playbook/hybrid.rs new file mode 100644 index 000000000..a8c80cffd --- /dev/null +++ b/ipa-core/src/cli/playbook/hybrid.rs @@ -0,0 +1,135 @@ +#![cfg(all(feature = "web-app", feature = "cli"))] +use std::{ + cmp::min, + time::{Duration, Instant}, +}; + +use futures_util::future::try_join_all; +use serde::{Deserialize, Serialize}; +use tokio::time::sleep; + +use crate::{ + ff::{Serializable, U128Conversions}, + helpers::{ + query::{HybridQueryParams, QueryInput, QuerySize}, + BodyStream, + }, + net::{Helper, IpaHttpClient}, + protocol::QueryId, + query::QueryStatus, + secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue}, + test_fixture::Reconstruct, +}; + +/// # 
Panics +/// if results are invalid +#[allow(clippy::disallowed_methods)] // allow try_join_all +pub async fn run_hybrid_query_and_validate( + inputs: [BodyStream; 3], + query_size: usize, + clients: Vec<[IpaHttpClient; 3]>, + query_id: QueryId, + query_config: HybridQueryParams, +) -> HybridQueryResult +where + HV: SharedValue + U128Conversions, + AdditiveShare: Serializable, +{ + let mpc_time = Instant::now(); + + // for now, submit everything to the leader. TODO: round robin submission + let leader_clients = &clients[0]; + try_join_all( + inputs + .into_iter() + .zip(leader_clients) + .map(|(input_stream, client)| { + client.query_input(QueryInput { + query_id, + input_stream, + }) + }), + ) + .await + .unwrap(); + + let mut delay = Duration::from_millis(125); + loop { + if try_join_all( + leader_clients + .iter() + .map(|client| client.query_status(query_id)), + ) + .await + .unwrap() + .into_iter() + .all(|status| status == QueryStatus::Completed) + { + break; + } + + sleep(delay).await; + delay = min(Duration::from_secs(5), delay * 2); + // TODO: Add a timeout of some sort. Possibly, add some sort of progress indicator to + // the status API so we can check whether the query is making progress. 
+ } + + // wait until helpers have processed the query and get the results from them + let results: [_; 3] = try_join_all( + leader_clients + .iter() + .map(|client| client.query_results(query_id)), + ) + .await + .unwrap() + .try_into() + .unwrap(); + + let results: Vec<HV> = results + .map(|bytes| { + AdditiveShare::<HV>::from_byte_slice(&bytes) + .collect::<Result<Vec<_>, _>>() + .unwrap() + }) + .reconstruct(); + + let lat = mpc_time.elapsed(); + + tracing::info!("Running IPA for {query_size:?} records took {t:?}", t = lat); + let mut breakdowns = vec![0; usize::try_from(query_config.max_breakdown_key).unwrap()]; + for (breakdown_key, trigger_value) in results.into_iter().enumerate() { + // TODO: make the data type used consistent with `ipa_in_the_clear` + // I think using u32 is wrong, we should move to u128 + if query_config.with_dp == 0 { + // otherwise if DP is added trigger_values will not be zero due to noise + assert!( + breakdown_key < query_config.max_breakdown_key.try_into().unwrap() + || trigger_value == HV::ZERO, + "trigger values were attributed to buckets more than max breakdown key" + ); + } + + if breakdown_key < query_config.max_breakdown_key.try_into().unwrap() { + breakdowns[breakdown_key] += u32::try_from(trigger_value.as_u128()).unwrap(); + } + } + + HybridQueryResult { + input_size: QuerySize::try_from(query_size).unwrap(), + config: query_config, + latency: lat, + breakdowns, + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct HybridQueryResult { + pub input_size: QuerySize, + pub config: HybridQueryParams, + #[serde( + serialize_with = "crate::serde::duration::to_secs", + deserialize_with = "crate::serde::duration::from_secs" + )] + pub latency: Duration, + pub breakdowns: Vec<u32>, +} diff --git a/ipa-core/src/cli/playbook/mod.rs b/ipa-core/src/cli/playbook/mod.rs index 8900e04fa..47b228040 100644 --- a/ipa-core/src/cli/playbook/mod.rs +++ b/ipa-core/src/cli/playbook/mod.rs @@ -1,5 +1,6 @@ mod add; mod generator; +mod hybrid; mod input; mod ipa; mod 
multiply; @@ -16,7 +17,10 @@ pub use multiply::secure_mul; pub use sharded_shuffle::secure_shuffle; use tokio::time::sleep; -pub use self::ipa::{playbook_oprf_ipa, run_query_and_validate}; +pub use self::{ + hybrid::{run_hybrid_query_and_validate, HybridQueryResult}, + ipa::{playbook_oprf_ipa, run_query_and_validate}, +}; use crate::{ cli::config_parse::HelperNetworkConfigParseExt, config::{ClientConfig, NetworkConfig, PeerConfig}, @@ -227,11 +231,12 @@ pub async fn make_sharded_clients( network_path: &Path, scheme: Scheme, wait: usize, -) -> Vec<[IpaHttpClient<Helper>; 3]> { +) -> (Vec<[IpaHttpClient<Helper>; 3]>, Vec<NetworkConfig<Helper>>) { let network = NetworkConfig::from_toml_str_sharded(&fs::read_to_string(network_path).unwrap()).unwrap(); let clients = network + .clone() .into_iter() .map(|network| { let network = network.override_scheme(&scheme); @@ -241,7 +246,7 @@ pub async fn make_sharded_clients( wait_for_servers(wait, &clients).await; - clients + (clients, network) } async fn wait_for_servers(mut wait: usize, clients: &[[IpaHttpClient<Helper>; 3]]) { diff --git a/ipa-core/src/helpers/transport/query/hybrid.rs b/ipa-core/src/helpers/transport/query/hybrid.rs index 2e4b71740..63c80160e 100644 --- a/ipa-core/src/helpers/transport/query/hybrid.rs +++ b/ipa-core/src/helpers/transport/query/hybrid.rs @@ -1,8 +1,5 @@ use serde::{Deserialize, Serialize}; -#[cfg(all(feature = "web-app", feature = "cli"))] -use super::IpaQueryConfig; - #[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq)] #[cfg_attr(feature = "clap", derive(clap::Args))] pub struct HybridQueryParams { @@ -30,20 +27,3 @@ impl Default for HybridQueryParams { } } } - -// `IpaQueryConfig` is a super set of this, and to avoid an almost entire duplication of -// `run_query_and_validate`, we just convert this over. the unused fields (`per_user_credit_cap` -// and `attribution_window_seconds`) aren't used in that function. 
-// Once we deprecate OprfIpa, we can simply swap out IpaQueryConfig with HybridQueryParams -#[cfg(all(feature = "web-app", feature = "cli"))] -impl From<HybridQueryParams> for IpaQueryConfig { - fn from(params: HybridQueryParams) -> Self { - Self { - max_breakdown_key: params.max_breakdown_key, - with_dp: params.with_dp, - epsilon: params.epsilon, - plaintext_match_keys: params.plaintext_match_keys, - ..Default::default() - } - } -}