Skip to content

Commit

Permalink
update clients to support sharding in report_collector bin
Browse files Browse the repository at this point in the history
  • Loading branch information
eriktaubeneck committed Dec 4, 2024
1 parent c662266 commit 7651bd0
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 39 deletions.
85 changes: 70 additions & 15 deletions ipa-core/src/bin/report_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ use hyper::http::uri::Scheme;
use ipa_core::{
cli::{
playbook::{
make_clients, playbook_oprf_ipa, run_query_and_validate, validate, validate_dp,
InputSource,
make_clients, make_sharded_clients, playbook_oprf_ipa, run_hybrid_query_and_validate,
run_query_and_validate, validate, validate_dp, HybridQueryResult, InputSource,
},
CsvSerializer, IpaQueryResult, Verbosity,
},
Expand Down Expand Up @@ -60,6 +60,9 @@ struct Args {
#[arg(long, value_name = "OUTPUT_FILE")]
output_file: Option<PathBuf>,

#[arg(long, default_value_t = 1)]
shard_count: usize,

#[command(subcommand)]
action: ReportCollectorCommand,
}
Expand Down Expand Up @@ -178,7 +181,20 @@ async fn main() -> Result<(), Box<dyn Error>> {
Scheme::HTTPS
};

let (clients, network) = make_clients(args.network.as_deref(), scheme, args.wait).await;
let (clients, networks) = if args.shard_count == 1 {
let (c, n) = make_clients(args.network.as_deref(), scheme, args.wait).await;
(vec![c], vec![n])
} else {
make_sharded_clients(
args.network
.as_deref()
.expect("Network.toml is required for sharded queries"),
scheme,
args.wait,
)
.await
};

match args.action {
ReportCollectorCommand::GenIpaInputs {
count,
Expand All @@ -193,20 +209,20 @@ async fn main() -> Result<(), Box<dyn Error>> {
ReportCollectorCommand::SemiHonestOprfIpaTest(config) => {
ipa_test(
&args,
&network,
&networks[0],
IpaSecurityModel::SemiHonest,
config,
&clients,
&clients[0],
)
.await?
}
ReportCollectorCommand::MaliciousOprfIpaTest(config) => {
ipa_test(
&args,
&network,
&networks[0],
IpaSecurityModel::Malicious,
config,
&clients,
&clients[0],
)
.await?
}
Expand All @@ -218,7 +234,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
&args,
IpaSecurityModel::Malicious,
ipa_query_config,
&clients,
&clients[0],
encrypted_inputs,
)
.await?
Expand All @@ -231,15 +247,15 @@ async fn main() -> Result<(), Box<dyn Error>> {
&args,
IpaSecurityModel::SemiHonest,
ipa_query_config,
&clients,
&clients[0],
encrypted_inputs,
)
.await?
}
ReportCollectorCommand::MaliciousHybrid {
ref encrypted_inputs,
hybrid_query_config,
} => hybrid(&args, hybrid_query_config, &clients, encrypted_inputs).await?,
} => hybrid(&args, hybrid_query_config, clients, encrypted_inputs).await?,
};

Ok(())
Expand Down Expand Up @@ -342,10 +358,49 @@ fn write_ipa_output_file(
Ok(())
}

/// Writes `query_result` as pretty-printed JSON to `path`.
///
/// If `path` already points to an existing file, the output is instead written
/// to a sibling file named `<stem>-<5 random alphanumeric chars>.<ext>` so the
/// previous results are not clobbered.
///
/// # Errors
/// Returns an error if `path` has no file stem, if the output file cannot be
/// created (e.g. it appeared between the existence check and `create_new`),
/// or if JSON serialization / writing fails.
fn write_hybrid_output_file(
    path: &PathBuf,
    query_result: &HybridQueryResult,
) -> Result<(), Box<dyn Error>> {
    // it will be sad to lose the results if file already exists.
    let path = if Path::is_file(path) {
        // Random 5-character suffix to make the fallback name unique.
        let mut new_file_name = thread_rng()
            .sample_iter(&Alphanumeric)
            .take(5)
            .map(char::from)
            .collect::<String>();
        let file_name = path.file_stem().ok_or("not a file")?;

        // Build "<stem>-<suffix>" by prepending in front of the random part.
        new_file_name.insert(0, '-');
        new_file_name.insert_str(0, &file_name.to_string_lossy());
        tracing::warn!(
            "{} file exists, renaming to {:?}",
            path.display(),
            new_file_name
        );

        // it will not be 100% accurate until file_prefix API is stabilized
        Cow::Owned(
            path.with_file_name(&new_file_name)
                .with_extension(path.extension().unwrap_or("".as_ref())),
        )
    } else {
        Cow::Borrowed(path)
    };
    // create_new fails rather than truncates if the file was created
    // concurrently, preserving the "never lose results" guarantee above.
    let mut file = File::options()
        .write(true)
        .create_new(true)
        .open(path.deref())
        .map_err(|e| format!("Failed to create output file {}: {e}", path.display()))?;

    write!(file, "{}", serde_json::to_string_pretty(query_result)?)?;
    Ok(())
}

async fn hybrid(
args: &Args,
hybrid_query_config: HybridQueryParams,
helper_clients: &[IpaHttpClient<Helper>; 3],
helper_clients: Vec<[IpaHttpClient<Helper>; 3]>,
encrypted_inputs: &EncryptedInputs,
) -> Result<(), Box<dyn Error>> {
let query_type = QueryType::MaliciousHybrid(hybrid_query_config);
Expand All @@ -365,7 +420,7 @@ async fn hybrid(
query_type,
};

let query_id = helper_clients[0]
let query_id = helper_clients[0][0]
.create_query(query_config)
.await
.expect("Unable to create query!");
Expand All @@ -374,17 +429,17 @@ async fn hybrid(
// the value for histogram values (BA32) must be kept in sync with the server-side
// implementation, otherwise a runtime reconstruct error will be generated.
// see ipa-core/src/query/executor.rs
let actual = run_query_and_validate::<BA32>(
let actual = run_hybrid_query_and_validate::<BA32>(
encrypted_report_streams.streams,
encrypted_report_streams.query_size,
helper_clients,
query_id,
hybrid_query_config.into(),
hybrid_query_config,
)
.await;

if let Some(ref path) = args.output_file {
write_ipa_output_file(path, &actual)?;
write_hybrid_output_file(path, &actual)?;
} else {
println!("{}", serde_json::to_string_pretty(&actual)?);
}
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/bin/test_mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
}
TestAction::ShardedShuffle => {
// we need clients to talk to each individual shard
let clients = make_sharded_clients(
let (clients, _networks) = make_sharded_clients(
args.network
.as_deref()
.expect("network config is required for sharded shuffle"),
Expand Down
135 changes: 135 additions & 0 deletions ipa-core/src/cli/playbook/hybrid.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#![cfg(all(feature = "web-app", feature = "cli"))]
use std::{
cmp::min,
time::{Duration, Instant},
};

use futures_util::future::try_join_all;
use serde::{Deserialize, Serialize};
use tokio::time::sleep;

use crate::{
ff::{Serializable, U128Conversions},
helpers::{
query::{HybridQueryParams, QueryInput, QuerySize},
BodyStream,
},
net::{Helper, IpaHttpClient},
protocol::QueryId,
query::QueryStatus,
secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue},
test_fixture::Reconstruct,
};

/// Drives a hybrid query end-to-end: submits the three encrypted input
/// streams to the leader shard's helpers, polls until all three report the
/// query completed, then fetches the shares, reconstructs the values, and
/// aggregates them into per-breakdown-key totals.
///
/// `clients` holds one `[IpaHttpClient; 3]` triple per shard; currently only
/// the leader shard (`clients[0]`) is used for submission, status checks,
/// and result retrieval.
///
/// # Panics
/// if results are invalid
#[allow(clippy::disallowed_methods)] // allow try_join_all
pub async fn run_hybrid_query_and_validate<HV>(
    inputs: [BodyStream; 3],
    query_size: usize,
    clients: Vec<[IpaHttpClient<Helper>; 3]>,
    query_id: QueryId,
    query_config: HybridQueryParams,
) -> HybridQueryResult
where
    HV: SharedValue + U128Conversions,
    AdditiveShare<HV>: Serializable,
{
    // Measures total wall-clock latency reported in the result.
    let mpc_time = Instant::now();

    // for now, submit everything to the leader. TODO: round robin submission
    let leader_clients = &clients[0];
    // Send each of the three input streams to its corresponding helper.
    try_join_all(
        inputs
            .into_iter()
            .zip(leader_clients)
            .map(|(input_stream, client)| {
                client.query_input(QueryInput {
                    query_id,
                    input_stream,
                })
            }),
    )
    .await
    .unwrap();

    // Poll with exponential backoff: start at 125ms, double each round,
    // cap at 5 seconds. Loop exits only when all three helpers report Completed.
    let mut delay = Duration::from_millis(125);
    loop {
        if try_join_all(
            leader_clients
                .iter()
                .map(|client| client.query_status(query_id)),
        )
        .await
        .unwrap()
        .into_iter()
        .all(|status| status == QueryStatus::Completed)
        {
            break;
        }

        sleep(delay).await;
        delay = min(Duration::from_secs(5), delay * 2);
        // TODO: Add a timeout of some sort. Possibly, add some sort of progress indicator to
        // the status API so we can check whether the query is making progress.
    }

    // wait until helpers have processed the query and get the results from them
    let results: [_; 3] = try_join_all(
        leader_clients
            .iter()
            .map(|client| client.query_results(query_id)),
    )
    .await
    .unwrap()
    .try_into()
    .unwrap();

    // Deserialize each helper's byte payload into additive shares, then
    // reconstruct the cleartext values from the three share vectors.
    let results: Vec<HV> = results
        .map(|bytes| {
            AdditiveShare::<HV>::from_byte_slice(&bytes)
                .collect::<Result<Vec<_>, _>>()
                .unwrap()
        })
        .reconstruct();

    let lat = mpc_time.elapsed();

    tracing::info!("Running IPA for {query_size:?} records took {t:?}", t = lat);
    // Accumulate reconstructed trigger values into a histogram indexed by
    // breakdown key (the position in `results`).
    let mut breakdowns = vec![0; usize::try_from(query_config.max_breakdown_key).unwrap()];
    for (breakdown_key, trigger_value) in results.into_iter().enumerate() {
        // TODO: make the data type used consistent with `ipa_in_the_clear`
        // I think using u32 is wrong, we should move to u128
        if query_config.with_dp == 0 {
            // otherwise if DP is added trigger_values will not be zero due to noise
            assert!(
                breakdown_key < query_config.max_breakdown_key.try_into().unwrap()
                    || trigger_value == HV::ZERO,
                "trigger values were attributed to buckets more than max breakdown key"
            );
        }

        // Values beyond max_breakdown_key are dropped (only asserted on above
        // when DP is off).
        if breakdown_key < query_config.max_breakdown_key.try_into().unwrap() {
            breakdowns[breakdown_key] += u32::try_from(trigger_value.as_u128()).unwrap();
        }
    }

    HybridQueryResult {
        input_size: QuerySize::try_from(query_size).unwrap(),
        config: query_config,
        latency: lat,
        breakdowns,
    }
}

/// Outcome of a hybrid query as produced by [`run_hybrid_query_and_validate`]:
/// the input size, the query configuration it ran with, total latency, and the
/// aggregated per-breakdown-key histogram.
#[derive(Debug, Serialize, Deserialize)]
pub struct HybridQueryResult {
    /// Number of input records submitted to the query.
    pub input_size: QuerySize,
    /// The parameters this query was executed with.
    pub config: HybridQueryParams,
    /// End-to-end wall-clock time, serialized as (fractional) seconds.
    #[serde(
        serialize_with = "crate::serde::duration::to_secs",
        deserialize_with = "crate::serde::duration::from_secs"
    )]
    pub latency: Duration,
    /// Aggregated trigger values, indexed by breakdown key
    /// (length = `config.max_breakdown_key`).
    pub breakdowns: Vec<u32>,
}
11 changes: 8 additions & 3 deletions ipa-core/src/cli/playbook/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod add;
mod generator;
mod hybrid;
mod input;
mod ipa;
mod multiply;
Expand All @@ -16,7 +17,10 @@ pub use multiply::secure_mul;
pub use sharded_shuffle::secure_shuffle;
use tokio::time::sleep;

pub use self::ipa::{playbook_oprf_ipa, run_query_and_validate};
pub use self::{
hybrid::{run_hybrid_query_and_validate, HybridQueryResult},
ipa::{playbook_oprf_ipa, run_query_and_validate},
};
use crate::{
cli::config_parse::HelperNetworkConfigParseExt,
config::{ClientConfig, NetworkConfig, PeerConfig},
Expand Down Expand Up @@ -227,11 +231,12 @@ pub async fn make_sharded_clients(
network_path: &Path,
scheme: Scheme,
wait: usize,
) -> Vec<[IpaHttpClient<Helper>; 3]> {
) -> (Vec<[IpaHttpClient<Helper>; 3]>, Vec<NetworkConfig<Helper>>) {
let network =
NetworkConfig::from_toml_str_sharded(&fs::read_to_string(network_path).unwrap()).unwrap();

let clients = network
.clone()
.into_iter()
.map(|network| {
let network = network.override_scheme(&scheme);
Expand All @@ -241,7 +246,7 @@ pub async fn make_sharded_clients(

wait_for_servers(wait, &clients).await;

clients
(clients, network)
}

async fn wait_for_servers(mut wait: usize, clients: &[[IpaHttpClient<Helper>; 3]]) {
Expand Down
20 changes: 0 additions & 20 deletions ipa-core/src/helpers/transport/query/hybrid.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use serde::{Deserialize, Serialize};

#[cfg(all(feature = "web-app", feature = "cli"))]
use super::IpaQueryConfig;

#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq)]
#[cfg_attr(feature = "clap", derive(clap::Args))]
pub struct HybridQueryParams {
Expand Down Expand Up @@ -30,20 +27,3 @@ impl Default for HybridQueryParams {
}
}
}

// `IpaQueryConfig` is a super set of this, and to avoid an almost entire duplication of
// `run_query_and_validate`, we just convert this over. the unused fields (`per_user_credit_cap`
// and `attribution_window_seconds`) aren't used in that function.
// Once we deprecate OprfIpa, we can simply swap out IpaQueryConfig with HybridQueryParams
#[cfg(all(feature = "web-app", feature = "cli"))]
impl From<HybridQueryParams> for IpaQueryConfig {
fn from(params: HybridQueryParams) -> Self {
Self {
max_breakdown_key: params.max_breakdown_key,
with_dp: params.with_dp,
epsilon: params.epsilon,
plaintext_match_keys: params.plaintext_match_keys,
..Default::default()
}
}
}

0 comments on commit 7651bd0

Please sign in to comment.