Skip to content

Commit

Permalink
add CLI command to run hybrid query with shards (#1472)
Browse files Browse the repository at this point in the history
  • Loading branch information
eriktaubeneck authored Dec 5, 2024
1 parent 341828c commit 26ee498
Show file tree
Hide file tree
Showing 4 changed files with 271 additions and 14 deletions.
137 changes: 127 additions & 10 deletions ipa-core/src/bin/report_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@ use hyper::http::uri::Scheme;
use ipa_core::{
cli::{
playbook::{
make_clients, playbook_oprf_ipa, run_query_and_validate, validate, validate_dp,
InputSource,
make_clients, make_sharded_clients, playbook_oprf_ipa, run_hybrid_query_and_validate,
run_query_and_validate, validate, validate_dp, HybridQueryResult, InputSource,
},
CsvSerializer, IpaQueryResult, Verbosity,
},
config::{KeyRegistries, NetworkConfig},
ff::{boolean_array::BA32, FieldType},
helpers::query::{DpMechanism, IpaQueryConfig, QueryConfig, QuerySize, QueryType},
helpers::query::{
DpMechanism, HybridQueryParams, IpaQueryConfig, QueryConfig, QuerySize, QueryType,
},
net::{Helper, IpaHttpClient},
report::{EncryptedOprfReportStreams, DEFAULT_KEY_ID},
test_fixture::{
Expand Down Expand Up @@ -58,6 +60,9 @@ struct Args {
#[arg(long, value_name = "OUTPUT_FILE")]
output_file: Option<PathBuf>,

#[arg(long, default_value_t = 1)]
shard_count: usize,

#[command(subcommand)]
action: ReportCollectorCommand,
}
Expand Down Expand Up @@ -132,6 +137,13 @@ enum ReportCollectorCommand {
#[clap(flatten)]
ipa_query_config: IpaQueryConfig,
},
MaliciousHybrid {
#[clap(flatten)]
encrypted_inputs: EncryptedInputs,

#[clap(flatten)]
hybrid_query_config: HybridQueryParams,
},
}

#[derive(Debug, clap::Args)]
Expand Down Expand Up @@ -169,7 +181,20 @@ async fn main() -> Result<(), Box<dyn Error>> {
Scheme::HTTPS
};

let (clients, network) = make_clients(args.network.as_deref(), scheme, args.wait).await;
let (clients, networks) = if args.shard_count == 1 {
let (c, n) = make_clients(args.network.as_deref(), scheme, args.wait).await;
(vec![c], vec![n])
} else {
make_sharded_clients(
args.network
.as_deref()
.expect("Network.toml is required for sharded queries"),
scheme,
args.wait,
)
.await
};

match args.action {
ReportCollectorCommand::GenIpaInputs {
count,
Expand All @@ -184,20 +209,20 @@ async fn main() -> Result<(), Box<dyn Error>> {
ReportCollectorCommand::SemiHonestOprfIpaTest(config) => {
ipa_test(
&args,
&network,
&networks[0],
IpaSecurityModel::SemiHonest,
config,
&clients,
&clients[0],
)
.await?
}
ReportCollectorCommand::MaliciousOprfIpaTest(config) => {
ipa_test(
&args,
&network,
&networks[0],
IpaSecurityModel::Malicious,
config,
&clients,
&clients[0],
)
.await?
}
Expand All @@ -209,7 +234,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
&args,
IpaSecurityModel::Malicious,
ipa_query_config,
&clients,
&clients[0],
encrypted_inputs,
)
.await?
Expand All @@ -222,11 +247,15 @@ async fn main() -> Result<(), Box<dyn Error>> {
&args,
IpaSecurityModel::SemiHonest,
ipa_query_config,
&clients,
&clients[0],
encrypted_inputs,
)
.await?
}
ReportCollectorCommand::MaliciousHybrid {
ref encrypted_inputs,
hybrid_query_config,
} => hybrid(&args, hybrid_query_config, clients, encrypted_inputs).await?,
};

Ok(())
Expand Down Expand Up @@ -329,6 +358,94 @@ fn write_ipa_output_file(
Ok(())
}

/// Writes `query_result` as pretty-printed JSON to `path`.
///
/// If `path` already exists, the results are written instead to a sibling
/// file named `<stem>-<5 random alphanumeric chars>.<ext>`, so previous
/// results are never overwritten; a warning is logged with the new name.
///
/// # Errors
/// Returns an error if `path` has no file stem, the output file cannot be
/// created, or JSON serialization fails.
fn write_hybrid_output_file(
    path: &PathBuf,
    query_result: &HybridQueryResult,
) -> Result<(), Box<dyn Error>> {
    // it will be sad to lose the results if file already exists.
    let path = if Path::is_file(path) {
        // random 5-char alphanumeric suffix, e.g. "a3Xk9"
        let mut new_file_name = thread_rng()
            .sample_iter(&Alphanumeric)
            .take(5)
            .map(char::from)
            .collect::<String>();
        let file_name = path.file_stem().ok_or("not a file")?;

        // build "<stem>-<suffix>", keeping the original extension below
        new_file_name.insert(0, '-');
        new_file_name.insert_str(0, &file_name.to_string_lossy());
        tracing::warn!(
            "{} file exists, renaming to {:?}",
            path.display(),
            new_file_name
        );

        // it will not be 100% accurate until file_prefix API is stabilized
        Cow::Owned(
            path.with_file_name(&new_file_name)
                .with_extension(path.extension().unwrap_or("".as_ref())),
        )
    } else {
        Cow::Borrowed(path)
    };
    // create_new(true) also guards against a race where the file appeared
    // between the is_file check above and this open
    let mut file = File::options()
        .write(true)
        .create_new(true)
        .open(path.deref())
        .map_err(|e| format!("Failed to create output file {}: {e}", path.display()))?;

    write!(file, "{}", serde_json::to_string_pretty(query_result)?)?;
    Ok(())
}

/// Creates and runs a malicious hybrid query against the helpers, then
/// writes the [`HybridQueryResult`] to `args.output_file`, or pretty-prints
/// it to stdout when no output file was given.
///
/// `helper_clients` holds one `[client; 3]` array per shard; the query is
/// created on the first helper of the leader shard (index 0).
///
/// # Errors
/// Propagates failures from writing the output file or serializing results.
///
/// # Panics
/// Panics if the input size is not a valid `QuerySize` or the query cannot
/// be created on the leader helper.
async fn hybrid(
    args: &Args,
    hybrid_query_config: HybridQueryParams,
    helper_clients: Vec<[IpaHttpClient<Helper>; 3]>,
    encrypted_inputs: &EncryptedInputs,
) -> Result<(), Box<dyn Error>> {
    let query_type = QueryType::MaliciousHybrid(hybrid_query_config);

    let files = [
        &encrypted_inputs.enc_input_file1,
        &encrypted_inputs.enc_input_file2,
        &encrypted_inputs.enc_input_file3,
    ];

    // despite the name, this is generic enough to work with hybrid
    let encrypted_report_streams = EncryptedOprfReportStreams::from(files);

    let query_config = QueryConfig {
        size: QuerySize::try_from(encrypted_report_streams.query_size).unwrap(),
        field_type: FieldType::Fp32BitPrime,
        query_type,
    };

    // the query is created once, on the leader shard's first helper
    let query_id = helper_clients[0][0]
        .create_query(query_config)
        .await
        .expect("Unable to create query!");

    // fixed: previously logged "Starting query for OPRF" — copy-paste from
    // the OPRF path; this function runs a hybrid query
    tracing::info!("Starting query for hybrid");
    // the value for histogram values (BA32) must be kept in sync with the server-side
    // implementation, otherwise a runtime reconstruct error will be generated.
    // see ipa-core/src/query/executor.rs
    let actual = run_hybrid_query_and_validate::<BA32>(
        encrypted_report_streams.streams,
        encrypted_report_streams.query_size,
        helper_clients,
        query_id,
        hybrid_query_config,
    )
    .await;

    if let Some(ref path) = args.output_file {
        write_hybrid_output_file(path, &actual)?;
    } else {
        println!("{}", serde_json::to_string_pretty(&actual)?);
    }
    Ok(())
}

async fn ipa(
args: &Args,
security_model: IpaSecurityModel,
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/bin/test_mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
}
TestAction::ShardedShuffle => {
// we need clients to talk to each individual shard
let clients = make_sharded_clients(
let (clients, _networks) = make_sharded_clients(
args.network
.as_deref()
.expect("network config is required for sharded shuffle"),
Expand Down
135 changes: 135 additions & 0 deletions ipa-core/src/cli/playbook/hybrid.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#![cfg(all(feature = "web-app", feature = "cli"))]
use std::{
cmp::min,
time::{Duration, Instant},
};

use futures_util::future::try_join_all;
use serde::{Deserialize, Serialize};
use tokio::time::sleep;

use crate::{
ff::{Serializable, U128Conversions},
helpers::{
query::{HybridQueryParams, QueryInput, QuerySize},
BodyStream,
},
net::{Helper, IpaHttpClient},
protocol::QueryId,
query::QueryStatus,
secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue},
test_fixture::Reconstruct,
};

/// # Panics
/// if results are invalid
#[allow(clippy::disallowed_methods)] // allow try_join_all
/// Drives a hybrid query end to end: submits the three input streams to the
/// leader shard, polls every leader client until all report `Completed`,
/// fetches and reconstructs the secret-shared results, and aggregates them
/// into a per-breakdown-key histogram.
///
/// The histogram value type `HV` (e.g. `BA32`) must match the server-side
/// implementation, otherwise reconstruction fails at runtime.
///
/// # Panics
/// if query submission, status polling, or result retrieval fails, or if
/// results are invalid: when DP is disabled (`with_dp == 0`), any non-zero
/// trigger value attributed past `max_breakdown_key` is an error
#[allow(clippy::disallowed_methods)] // allow try_join_all
pub async fn run_hybrid_query_and_validate<HV>(
    inputs: [BodyStream; 3],
    query_size: usize,
    clients: Vec<[IpaHttpClient<Helper>; 3]>,
    query_id: QueryId,
    query_config: HybridQueryParams,
) -> HybridQueryResult
where
    HV: SharedValue + U128Conversions,
    AdditiveShare<HV>: Serializable,
{
    let mpc_time = Instant::now();

    // for now, submit everything to the leader. TODO: round robin submission
    let leader_clients = &clients[0];
    try_join_all(
        inputs
            .into_iter()
            .zip(leader_clients)
            .map(|(input_stream, client)| {
                client.query_input(QueryInput {
                    query_id,
                    input_stream,
                })
            }),
    )
    .await
    .unwrap();

    // poll with exponential backoff, capped at 5s
    let mut delay = Duration::from_millis(125);
    loop {
        if try_join_all(
            leader_clients
                .iter()
                .map(|client| client.query_status(query_id)),
        )
        .await
        .unwrap()
        .into_iter()
        .all(|status| status == QueryStatus::Completed)
        {
            break;
        }

        sleep(delay).await;
        delay = min(Duration::from_secs(5), delay * 2);
        // TODO: Add a timeout of some sort. Possibly, add some sort of progress indicator to
        // the status API so we can check whether the query is making progress.
    }

    // wait until helpers have processed the query and get the results from them
    let results: [_; 3] = try_join_all(
        leader_clients
            .iter()
            .map(|client| client.query_results(query_id)),
    )
    .await
    .unwrap()
    .try_into()
    .unwrap();

    // deserialize each helper's share vector, then reconstruct plaintext values
    let results: Vec<HV> = results
        .map(|bytes| {
            AdditiveShare::<HV>::from_byte_slice(&bytes)
                .collect::<Result<Vec<_>, _>>()
                .unwrap()
        })
        .reconstruct();

    let lat = mpc_time.elapsed();

    // fixed: previously logged "Running IPA ..." — copy-paste from the IPA
    // playbook; this is the hybrid query
    tracing::info!("Running hybrid query for {query_size:?} records took {t:?}", t = lat);
    // hoisted out of the loop: the conversion is loop-invariant
    let max_breakdown_key = usize::try_from(query_config.max_breakdown_key).unwrap();
    let mut breakdowns = vec![0; max_breakdown_key];
    for (breakdown_key, trigger_value) in results.into_iter().enumerate() {
        // TODO: make the data type used consistent with `ipa_in_the_clear`
        // I think using u32 is wrong, we should move to u128
        if query_config.with_dp == 0 {
            // otherwise if DP is added trigger_values will not be zero due to noise
            assert!(
                breakdown_key < max_breakdown_key || trigger_value == HV::ZERO,
                "trigger values were attributed to buckets more than max breakdown key"
            );
        }

        if breakdown_key < max_breakdown_key {
            breakdowns[breakdown_key] += u32::try_from(trigger_value.as_u128()).unwrap();
        }
    }

    HybridQueryResult {
        input_size: QuerySize::try_from(query_size).unwrap(),
        config: query_config,
        latency: lat,
        breakdowns,
    }
}

/// Outcome of a hybrid query, serialized to JSON for the output file
/// (or stdout) by the report collector CLI.
#[derive(Debug, Serialize, Deserialize)]
pub struct HybridQueryResult {
    /// Number of input records submitted with the query.
    pub input_size: QuerySize,
    /// The query parameters this result was produced with.
    pub config: HybridQueryParams,
    /// End-to-end MPC latency; serialized as seconds.
    #[serde(
        serialize_with = "crate::serde::duration::to_secs",
        deserialize_with = "crate::serde::duration::from_secs"
    )]
    pub latency: Duration,
    /// Aggregated trigger values, indexed by breakdown key.
    pub breakdowns: Vec<u32>,
}
11 changes: 8 additions & 3 deletions ipa-core/src/cli/playbook/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod add;
mod generator;
mod hybrid;
mod input;
mod ipa;
mod multiply;
Expand All @@ -16,7 +17,10 @@ pub use multiply::secure_mul;
pub use sharded_shuffle::secure_shuffle;
use tokio::time::sleep;

pub use self::ipa::{playbook_oprf_ipa, run_query_and_validate};
pub use self::{
hybrid::{run_hybrid_query_and_validate, HybridQueryResult},
ipa::{playbook_oprf_ipa, run_query_and_validate},
};
use crate::{
cli::config_parse::HelperNetworkConfigParseExt,
config::{ClientConfig, NetworkConfig, PeerConfig},
Expand Down Expand Up @@ -227,11 +231,12 @@ pub async fn make_sharded_clients(
network_path: &Path,
scheme: Scheme,
wait: usize,
) -> Vec<[IpaHttpClient<Helper>; 3]> {
) -> (Vec<[IpaHttpClient<Helper>; 3]>, Vec<NetworkConfig<Helper>>) {
let network =
NetworkConfig::from_toml_str_sharded(&fs::read_to_string(network_path).unwrap()).unwrap();

let clients = network
.clone()
.into_iter()
.map(|network| {
let network = network.override_scheme(&scheme);
Expand All @@ -241,7 +246,7 @@ pub async fn make_sharded_clients(

wait_for_servers(wait, &clients).await;

clients
(clients, network)
}

async fn wait_for_servers(mut wait: usize, clients: &[[IpaHttpClient<Helper>; 3]]) {
Expand Down

0 comments on commit 26ee498

Please sign in to comment.