diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 126004fcc..aa6da0bb4 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -14,14 +14,16 @@ use hyper::http::uri::Scheme; use ipa_core::{ cli::{ playbook::{ - make_clients, playbook_oprf_ipa, run_query_and_validate, validate, validate_dp, - InputSource, + make_clients, make_sharded_clients, playbook_oprf_ipa, run_hybrid_query_and_validate, + run_query_and_validate, validate, validate_dp, HybridQueryResult, InputSource, }, CsvSerializer, IpaQueryResult, Verbosity, }, config::{KeyRegistries, NetworkConfig}, ff::{boolean_array::BA32, FieldType}, - helpers::query::{DpMechanism, IpaQueryConfig, QueryConfig, QuerySize, QueryType}, + helpers::query::{ + DpMechanism, HybridQueryParams, IpaQueryConfig, QueryConfig, QuerySize, QueryType, + }, net::{Helper, IpaHttpClient}, report::{EncryptedOprfReportStreams, DEFAULT_KEY_ID}, test_fixture::{ @@ -58,6 +60,9 @@ struct Args { #[arg(long, value_name = "OUTPUT_FILE")] output_file: Option, + #[arg(long, default_value_t = 1)] + shard_count: usize, + #[command(subcommand)] action: ReportCollectorCommand, } @@ -132,6 +137,13 @@ enum ReportCollectorCommand { #[clap(flatten)] ipa_query_config: IpaQueryConfig, }, + MaliciousHybrid { + #[clap(flatten)] + encrypted_inputs: EncryptedInputs, + + #[clap(flatten)] + hybrid_query_config: HybridQueryParams, + }, } #[derive(Debug, clap::Args)] @@ -169,7 +181,20 @@ async fn main() -> Result<(), Box> { Scheme::HTTPS }; - let (clients, network) = make_clients(args.network.as_deref(), scheme, args.wait).await; + let (clients, networks) = if args.shard_count == 1 { + let (c, n) = make_clients(args.network.as_deref(), scheme, args.wait).await; + (vec![c], vec![n]) + } else { + make_sharded_clients( + args.network + .as_deref() + .expect("Network.toml is required for sharded queries"), + scheme, + args.wait, + ) + .await + }; + match args.action { ReportCollectorCommand::GenIpaInputs { count, @@ -184,20 +209,20 @@ async fn main() -> Result<(), Box> { ReportCollectorCommand::SemiHonestOprfIpaTest(config) => { ipa_test( &args, - &network, + &networks[0], IpaSecurityModel::SemiHonest, config, - &clients, + &clients[0], ) .await? } ReportCollectorCommand::MaliciousOprfIpaTest(config) => { ipa_test( &args, - &network, + &networks[0], IpaSecurityModel::Malicious, config, - &clients, + &clients[0], ) .await? } @@ -209,7 +234,7 @@ async fn main() -> Result<(), Box> { &args, IpaSecurityModel::Malicious, ipa_query_config, - &clients, + &clients[0], encrypted_inputs, ) .await? @@ -222,11 +247,15 @@ async fn main() -> Result<(), Box> { &args, IpaSecurityModel::SemiHonest, ipa_query_config, - &clients, + &clients[0], encrypted_inputs, ) .await? } + ReportCollectorCommand::MaliciousHybrid { + ref encrypted_inputs, + hybrid_query_config, + } => hybrid(&args, hybrid_query_config, clients, encrypted_inputs).await?, }; Ok(()) @@ -329,6 +358,94 @@ fn write_ipa_output_file( Ok(()) } +fn write_hybrid_output_file( + path: &PathBuf, + query_result: &HybridQueryResult, +) -> Result<(), Box> { + // it will be sad to lose the results if file already exists. + let path = if Path::is_file(path) { + let mut new_file_name = thread_rng() + .sample_iter(&Alphanumeric) + .take(5) + .map(char::from) + .collect::(); + let file_name = path.file_stem().ok_or("not a file")?; + + new_file_name.insert(0, '-'); + new_file_name.insert_str(0, &file_name.to_string_lossy()); + tracing::warn!( + "{} file exists, renaming to {:?}", + path.display(), + new_file_name + ); + + // it will not be 100% accurate until file_prefix API is stabilized + Cow::Owned( + path.with_file_name(&new_file_name) + .with_extension(path.extension().unwrap_or("".as_ref())), + ) + } else { + Cow::Borrowed(path) + }; + let mut file = File::options() + .write(true) + .create_new(true) + .open(path.deref()) + .map_err(|e| format!("Failed to create output file {}: {e}", path.display()))?; + + write!(file, "{}", serde_json::to_string_pretty(query_result)?)?; + Ok(()) +} + +async fn hybrid( + args: &Args, + hybrid_query_config: HybridQueryParams, + helper_clients: Vec<[IpaHttpClient; 3]>, + encrypted_inputs: &EncryptedInputs, +) -> Result<(), Box> { + let query_type = QueryType::MaliciousHybrid(hybrid_query_config); + + let files = [ + &encrypted_inputs.enc_input_file1, + &encrypted_inputs.enc_input_file2, + &encrypted_inputs.enc_input_file3, + ]; + + // despite the name, this is generic enough to work with hybrid + let encrypted_report_streams = EncryptedOprfReportStreams::from(files); + + let query_config = QueryConfig { + size: QuerySize::try_from(encrypted_report_streams.query_size).unwrap(), + field_type: FieldType::Fp32BitPrime, + query_type, + }; + + let query_id = helper_clients[0][0] + .create_query(query_config) + .await + .expect("Unable to create query!"); + + tracing::info!("Starting query for OPRF"); + // the value for histogram values (BA32) must be kept in sync with the server-side + // implementation, otherwise a runtime reconstruct error will be generated. + // see ipa-core/src/query/executor.rs + let actual = run_hybrid_query_and_validate::( + encrypted_report_streams.streams, + encrypted_report_streams.query_size, + helper_clients, + query_id, + hybrid_query_config, + ) + .await; + + if let Some(ref path) = args.output_file { + write_hybrid_output_file(path, &actual)?; + } else { + println!("{}", serde_json::to_string_pretty(&actual)?); + } + Ok(()) +} + async fn ipa( args: &Args, security_model: IpaSecurityModel, diff --git a/ipa-core/src/bin/test_mpc.rs b/ipa-core/src/bin/test_mpc.rs index dc509485f..baf99a2ca 100644 --- a/ipa-core/src/bin/test_mpc.rs +++ b/ipa-core/src/bin/test_mpc.rs @@ -119,7 +119,7 @@ async fn main() -> Result<(), Box> { } TestAction::ShardedShuffle => { // we need clients to talk to each individual shard - let clients = make_sharded_clients( + let (clients, _networks) = make_sharded_clients( args.network .as_deref() .expect("network config is required for sharded shuffle"), diff --git a/ipa-core/src/cli/playbook/hybrid.rs b/ipa-core/src/cli/playbook/hybrid.rs new file mode 100644 index 000000000..a8c80cffd --- /dev/null +++ b/ipa-core/src/cli/playbook/hybrid.rs @@ -0,0 +1,135 @@ +#![cfg(all(feature = "web-app", feature = "cli"))] +use std::{ + cmp::min, + time::{Duration, Instant}, +}; + +use futures_util::future::try_join_all; +use serde::{Deserialize, Serialize}; +use tokio::time::sleep; + +use crate::{ + ff::{Serializable, U128Conversions}, + helpers::{ + query::{HybridQueryParams, QueryInput, QuerySize}, + BodyStream, + }, + net::{Helper, IpaHttpClient}, + protocol::QueryId, + query::QueryStatus, + secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue}, + test_fixture::Reconstruct, +}; + +/// # Panics +/// if results are invalid +#[allow(clippy::disallowed_methods)] // allow try_join_all +pub async fn run_hybrid_query_and_validate( + inputs: [BodyStream; 3], + query_size: usize, + clients: Vec<[IpaHttpClient; 3]>, + query_id: QueryId, + query_config: HybridQueryParams, +) -> HybridQueryResult +where + HV: SharedValue + U128Conversions, + AdditiveShare: Serializable, +{ + let mpc_time = Instant::now(); + + // for now, submit everything to the leader. TODO: round robin submission + let leader_clients = &clients[0]; + try_join_all( + inputs + .into_iter() + .zip(leader_clients) + .map(|(input_stream, client)| { + client.query_input(QueryInput { + query_id, + input_stream, + }) + }), + ) + .await + .unwrap(); + + let mut delay = Duration::from_millis(125); + loop { + if try_join_all( + leader_clients + .iter() + .map(|client| client.query_status(query_id)), + ) + .await + .unwrap() + .into_iter() + .all(|status| status == QueryStatus::Completed) + { + break; + } + + sleep(delay).await; + delay = min(Duration::from_secs(5), delay * 2); + // TODO: Add a timeout of some sort. Possibly, add some sort of progress indicator to + // the status API so we can check whether the query is making progress. + } + + // wait until helpers have processed the query and get the results from them + let results: [_; 3] = try_join_all( + leader_clients + .iter() + .map(|client| client.query_results(query_id)), + ) + .await + .unwrap() + .try_into() + .unwrap(); + + let results: Vec = results + .map(|bytes| { + AdditiveShare::::from_byte_slice(&bytes) + .collect::, _>>() + .unwrap() + }) + .reconstruct(); + + let lat = mpc_time.elapsed(); + + tracing::info!("Running IPA for {query_size:?} records took {t:?}", t = lat); + let mut breakdowns = vec![0; usize::try_from(query_config.max_breakdown_key).unwrap()]; + for (breakdown_key, trigger_value) in results.into_iter().enumerate() { + // TODO: make the data type used consistent with `ipa_in_the_clear` + // I think using u32 is wrong, we should move to u128 + if query_config.with_dp == 0 { + // otherwise if DP is added trigger_values will not be zero due to noise + assert!( + breakdown_key < query_config.max_breakdown_key.try_into().unwrap() + || trigger_value == HV::ZERO, + "trigger values were attributed to buckets more than max breakdown key" + ); + } + + if breakdown_key < query_config.max_breakdown_key.try_into().unwrap() { + breakdowns[breakdown_key] += u32::try_from(trigger_value.as_u128()).unwrap(); + } + } + + HybridQueryResult { + input_size: QuerySize::try_from(query_size).unwrap(), + config: query_config, + latency: lat, + breakdowns, + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct HybridQueryResult { + pub input_size: QuerySize, + pub config: HybridQueryParams, + #[serde( + serialize_with = "crate::serde::duration::to_secs", + deserialize_with = "crate::serde::duration::from_secs" + )] + pub latency: Duration, + pub breakdowns: Vec, +} diff --git a/ipa-core/src/cli/playbook/mod.rs b/ipa-core/src/cli/playbook/mod.rs index 8900e04fa..47b228040 100644 --- a/ipa-core/src/cli/playbook/mod.rs +++ b/ipa-core/src/cli/playbook/mod.rs @@ -1,5 +1,6 @@ mod add; mod generator; +mod hybrid; mod input; mod ipa; mod multiply; @@ -16,7 +17,10 @@ pub use multiply::secure_mul; pub use sharded_shuffle::secure_shuffle; use tokio::time::sleep; -pub use self::ipa::{playbook_oprf_ipa, run_query_and_validate}; +pub use self::{ + hybrid::{run_hybrid_query_and_validate, HybridQueryResult}, + ipa::{playbook_oprf_ipa, run_query_and_validate}, +}; use crate::{ cli::config_parse::HelperNetworkConfigParseExt, config::{ClientConfig, NetworkConfig, PeerConfig}, @@ -227,11 +231,12 @@ pub async fn make_sharded_clients( network_path: &Path, scheme: Scheme, wait: usize, -) -> Vec<[IpaHttpClient; 3]> { +) -> (Vec<[IpaHttpClient; 3]>, Vec>) { let network = NetworkConfig::from_toml_str_sharded(&fs::read_to_string(network_path).unwrap()).unwrap(); let clients = network + .clone() .into_iter() .map(|network| { let network = network.override_scheme(&scheme); @@ -241,7 +246,7 @@ pub async fn make_sharded_clients( wait_for_servers(wait, &clients).await; - clients + (clients, network) } async fn wait_for_servers(mut wait: usize, clients: &[[IpaHttpClient; 3]]) {