Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add CLI command to run hybrid query #1472

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 127 additions & 10 deletions ipa-core/src/bin/report_collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
use ipa_core::{
cli::{
playbook::{
make_clients, playbook_oprf_ipa, run_query_and_validate, validate, validate_dp,
InputSource,
make_clients, make_sharded_clients, playbook_oprf_ipa, run_hybrid_query_and_validate,
run_query_and_validate, validate, validate_dp, HybridQueryResult, InputSource,
},
CsvSerializer, IpaQueryResult, Verbosity,
},
config::{KeyRegistries, NetworkConfig},
ff::{boolean_array::BA32, FieldType},
helpers::query::{DpMechanism, IpaQueryConfig, QueryConfig, QuerySize, QueryType},
helpers::query::{
DpMechanism, HybridQueryParams, IpaQueryConfig, QueryConfig, QuerySize, QueryType,
},
net::{Helper, IpaHttpClient},
report::{EncryptedOprfReportStreams, DEFAULT_KEY_ID},
test_fixture::{
Expand Down Expand Up @@ -58,6 +60,9 @@
#[arg(long, value_name = "OUTPUT_FILE")]
output_file: Option<PathBuf>,

#[arg(long, default_value_t = 1)]
shard_count: usize,

Check warning on line 64 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L64

Added line #L64 was not covered by tests

#[command(subcommand)]
action: ReportCollectorCommand,
}
Expand Down Expand Up @@ -132,6 +137,13 @@
#[clap(flatten)]
ipa_query_config: IpaQueryConfig,
},
MaliciousHybrid {
#[clap(flatten)]
encrypted_inputs: EncryptedInputs,

#[clap(flatten)]
hybrid_query_config: HybridQueryParams,
},
}

#[derive(Debug, clap::Args)]
Expand Down Expand Up @@ -169,7 +181,20 @@
Scheme::HTTPS
};

let (clients, network) = make_clients(args.network.as_deref(), scheme, args.wait).await;
let (clients, networks) = if args.shard_count == 1 {
let (c, n) = make_clients(args.network.as_deref(), scheme, args.wait).await;
(vec![c], vec![n])
} else {
make_sharded_clients(
args.network
.as_deref()
.expect("Network.toml is required for sharded queries"),
scheme,
args.wait,
)
.await

Check warning on line 195 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L189-L195

Added lines #L189 - L195 were not covered by tests
};

match args.action {
ReportCollectorCommand::GenIpaInputs {
count,
Expand All @@ -184,20 +209,20 @@
ReportCollectorCommand::SemiHonestOprfIpaTest(config) => {
ipa_test(
&args,
&network,
&networks[0],
IpaSecurityModel::SemiHonest,
config,
&clients,
&clients[0],
)
.await?
}
ReportCollectorCommand::MaliciousOprfIpaTest(config) => {
ipa_test(
&args,
&network,
&networks[0],

Check warning on line 222 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L222

Added line #L222 was not covered by tests
IpaSecurityModel::Malicious,
config,
&clients,
&clients[0],

Check warning on line 225 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L225

Added line #L225 was not covered by tests
)
.await?
}
Expand All @@ -209,7 +234,7 @@
&args,
IpaSecurityModel::Malicious,
ipa_query_config,
&clients,
&clients[0],

Check warning on line 237 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L237

Added line #L237 was not covered by tests
encrypted_inputs,
)
.await?
Expand All @@ -222,11 +247,15 @@
&args,
IpaSecurityModel::SemiHonest,
ipa_query_config,
&clients,
&clients[0],
encrypted_inputs,
)
.await?
}
ReportCollectorCommand::MaliciousHybrid {
ref encrypted_inputs,
hybrid_query_config,
} => hybrid(&args, hybrid_query_config, clients, encrypted_inputs).await?,

Check warning on line 258 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L257-L258

Added lines #L257 - L258 were not covered by tests
};

Ok(())
Expand Down Expand Up @@ -329,6 +358,94 @@
Ok(())
}

fn write_hybrid_output_file(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this looks oddly familiar, is it possible to re-use the same code and avoid copy paste here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same old problem with Oprf and Hybrid. we can either spend time supporting both or just delete the other when we remove Oprf.

path: &PathBuf,
query_result: &HybridQueryResult,
) -> Result<(), Box<dyn Error>> {

Check warning on line 364 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L361-L364

Added lines #L361 - L364 were not covered by tests
// it will be sad to lose the results if file already exists.
let path = if Path::is_file(path) {
let mut new_file_name = thread_rng()
.sample_iter(&Alphanumeric)
.take(5)
.map(char::from)
.collect::<String>();
let file_name = path.file_stem().ok_or("not a file")?;

Check warning on line 372 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L366-L372

Added lines #L366 - L372 were not covered by tests

new_file_name.insert(0, '-');
new_file_name.insert_str(0, &file_name.to_string_lossy());
tracing::warn!(
"{} file exists, renaming to {:?}",
path.display(),

Check warning on line 378 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L374-L378

Added lines #L374 - L378 were not covered by tests
new_file_name
);

// it will not be 100% accurate until file_prefix API is stabilized
Cow::Owned(
path.with_file_name(&new_file_name)
.with_extension(path.extension().unwrap_or("".as_ref())),
)

Check warning on line 386 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L383-L386

Added lines #L383 - L386 were not covered by tests
} else {
Cow::Borrowed(path)

Check warning on line 388 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L388

Added line #L388 was not covered by tests
};
let mut file = File::options()
.write(true)
.create_new(true)
.open(path.deref())
.map_err(|e| format!("Failed to create output file {}: {e}", path.display()))?;

Check warning on line 394 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L390-L394

Added lines #L390 - L394 were not covered by tests

write!(file, "{}", serde_json::to_string_pretty(query_result)?)?;
Ok(())
}

Check warning on line 398 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L396-L398

Added lines #L396 - L398 were not covered by tests

async fn hybrid(
args: &Args,
hybrid_query_config: HybridQueryParams,
helper_clients: Vec<[IpaHttpClient<Helper>; 3]>,
encrypted_inputs: &EncryptedInputs,
) -> Result<(), Box<dyn Error>> {
let query_type = QueryType::MaliciousHybrid(hybrid_query_config);

let files = [
&encrypted_inputs.enc_input_file1,
&encrypted_inputs.enc_input_file2,
&encrypted_inputs.enc_input_file3,
];

// despite the name, this is generic enough to work with hybrid
let encrypted_report_streams = EncryptedOprfReportStreams::from(files);

let query_config = QueryConfig {
size: QuerySize::try_from(encrypted_report_streams.query_size).unwrap(),
field_type: FieldType::Fp32BitPrime,
query_type,
};

Check warning on line 421 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L400-L421

Added lines #L400 - L421 were not covered by tests

let query_id = helper_clients[0][0]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we first do a query status to make sure something else isn't running?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it panics with "query already running" already, so that's effectively the same, no?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. The system does't have something like distributed transactions. Think of the following scenario; helper 1 gets the create_query and sends prepare_query to 2 and 3. Say helper 2 return ok but helper 3 says its running something return error. We don't "rollback" helper2 so it's left awaiting inputs.

This problem should be worse now with more states spread around.

Copy link
Member Author

@eriktaubeneck eriktaubeneck Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fair, and we have the "kill" endpoint to help clean up that state that sometimes happens. I'm inclined to try and get this merged as is (which is the same as Oprf Ipa, but with shards), and punt on making this better once we can actually run a query. For the immediate testing needs, we'll be restarting the binaries regularly anyways.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think we have a sharded kill, but we can easily restart the k8s cluster.

I you think adding a query_status is going to take much time then I agree.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

certainly something we should add, but we likely want to do that across all shards, so let's punt for now.

.create_query(query_config)
.await
.expect("Unable to create query!");

tracing::info!("Starting query for OPRF");

Check warning on line 428 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L423-L428

Added lines #L423 - L428 were not covered by tests
// the value for histogram values (BA32) must be kept in sync with the server-side
// implementation, otherwise a runtime reconstruct error will be generated.
// see ipa-core/src/query/executor.rs
let actual = run_hybrid_query_and_validate::<BA32>(
encrypted_report_streams.streams,
encrypted_report_streams.query_size,
helper_clients,
query_id,
hybrid_query_config,
)
.await;

Check warning on line 439 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L432-L439

Added lines #L432 - L439 were not covered by tests

if let Some(ref path) = args.output_file {
write_hybrid_output_file(path, &actual)?;

Check warning on line 442 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L441-L442

Added lines #L441 - L442 were not covered by tests
} else {
println!("{}", serde_json::to_string_pretty(&actual)?);

Check warning on line 444 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L444

Added line #L444 was not covered by tests
}
Ok(())
}

Check warning on line 447 in ipa-core/src/bin/report_collector.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/report_collector.rs#L446-L447

Added lines #L446 - L447 were not covered by tests

async fn ipa(
args: &Args,
security_model: IpaSecurityModel,
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/bin/test_mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
}
TestAction::ShardedShuffle => {
// we need clients to talk to each individual shard
let clients = make_sharded_clients(
let (clients, _networks) = make_sharded_clients(
args.network
.as_deref()
.expect("network config is required for sharded shuffle"),
Expand Down
135 changes: 135 additions & 0 deletions ipa-core/src/cli/playbook/hybrid.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#![cfg(all(feature = "web-app", feature = "cli"))]
use std::{
cmp::min,
time::{Duration, Instant},
};

use futures_util::future::try_join_all;
use serde::{Deserialize, Serialize};
use tokio::time::sleep;

use crate::{
ff::{Serializable, U128Conversions},
helpers::{
query::{HybridQueryParams, QueryInput, QuerySize},
BodyStream,
},
net::{Helper, IpaHttpClient},
protocol::QueryId,
query::QueryStatus,
secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue},
test_fixture::Reconstruct,
};

/// # Panics
/// if results are invalid
#[allow(clippy::disallowed_methods)] // allow try_join_all
pub async fn run_hybrid_query_and_validate<HV>(
inputs: [BodyStream; 3],
query_size: usize,
clients: Vec<[IpaHttpClient<Helper>; 3]>,
query_id: QueryId,
query_config: HybridQueryParams,
) -> HybridQueryResult
where
HV: SharedValue + U128Conversions,
AdditiveShare<HV>: Serializable,
{
let mpc_time = Instant::now();

// for now, submit everything to the leader. TODO: round robin submission
let leader_clients = &clients[0];
try_join_all(
inputs
.into_iter()
.zip(leader_clients)
.map(|(input_stream, client)| {
client.query_input(QueryInput {
query_id,
input_stream,
})
}),
)
.await
.unwrap();
Comment on lines +41 to +54
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@akoshelev I ended up needing to create a new function here that could take Vec<[IpaHttpClient<Helper>; 3]>. right now it submits everything to the leader, but this seems like the natural place to do the round robin submission.


let mut delay = Duration::from_millis(125);

Check warning on line 56 in ipa-core/src/cli/playbook/hybrid.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/playbook/hybrid.rs#L27-L56

Added lines #L27 - L56 were not covered by tests
loop {
if try_join_all(
leader_clients
.iter()
.map(|client| client.query_status(query_id)),
)
.await
.unwrap()
.into_iter()
.all(|status| status == QueryStatus::Completed)

Check warning on line 66 in ipa-core/src/cli/playbook/hybrid.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/playbook/hybrid.rs#L58-L66

Added lines #L58 - L66 were not covered by tests
{
break;
}

sleep(delay).await;
delay = min(Duration::from_secs(5), delay * 2);

Check warning on line 72 in ipa-core/src/cli/playbook/hybrid.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/playbook/hybrid.rs#L68-L72

Added lines #L68 - L72 were not covered by tests
// TODO: Add a timeout of some sort. Possibly, add some sort of progress indicator to
// the status API so we can check whether the query is making progress.
}

// wait until helpers have processed the query and get the results from them
let results: [_; 3] = try_join_all(
leader_clients
.iter()
.map(|client| client.query_results(query_id)),
)
.await
.unwrap()
.try_into()
.unwrap();

let results: Vec<HV> = results
.map(|bytes| {
AdditiveShare::<HV>::from_byte_slice(&bytes)
.collect::<Result<Vec<_>, _>>()
.unwrap()
})
.reconstruct();

let lat = mpc_time.elapsed();

tracing::info!("Running IPA for {query_size:?} records took {t:?}", t = lat);
let mut breakdowns = vec![0; usize::try_from(query_config.max_breakdown_key).unwrap()];
for (breakdown_key, trigger_value) in results.into_iter().enumerate() {

Check warning on line 100 in ipa-core/src/cli/playbook/hybrid.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/playbook/hybrid.rs#L78-L100

Added lines #L78 - L100 were not covered by tests
// TODO: make the data type used consistent with `ipa_in_the_clear`
// I think using u32 is wrong, we should move to u128
if query_config.with_dp == 0 {

Check warning on line 103 in ipa-core/src/cli/playbook/hybrid.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/playbook/hybrid.rs#L103

Added line #L103 was not covered by tests
// otherwise if DP is added trigger_values will not be zero due to noise
assert!(
breakdown_key < query_config.max_breakdown_key.try_into().unwrap()
|| trigger_value == HV::ZERO,
"trigger values were attributed to buckets more than max breakdown key"

Check warning on line 108 in ipa-core/src/cli/playbook/hybrid.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/playbook/hybrid.rs#L105-L108

Added lines #L105 - L108 were not covered by tests
);
}

Check warning on line 110 in ipa-core/src/cli/playbook/hybrid.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/playbook/hybrid.rs#L110

Added line #L110 was not covered by tests

if breakdown_key < query_config.max_breakdown_key.try_into().unwrap() {
breakdowns[breakdown_key] += u32::try_from(trigger_value.as_u128()).unwrap();
}

Check warning on line 114 in ipa-core/src/cli/playbook/hybrid.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/playbook/hybrid.rs#L112-L114

Added lines #L112 - L114 were not covered by tests
}

HybridQueryResult {
input_size: QuerySize::try_from(query_size).unwrap(),
config: query_config,
latency: lat,
breakdowns,
}
}

Check warning on line 123 in ipa-core/src/cli/playbook/hybrid.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/playbook/hybrid.rs#L117-L123

Added lines #L117 - L123 were not covered by tests

#[derive(Debug, Serialize, Deserialize)]

Check warning on line 125 in ipa-core/src/cli/playbook/hybrid.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/playbook/hybrid.rs#L125

Added line #L125 was not covered by tests
pub struct HybridQueryResult {
pub input_size: QuerySize,
pub config: HybridQueryParams,
#[serde(
serialize_with = "crate::serde::duration::to_secs",
deserialize_with = "crate::serde::duration::from_secs"
)]
pub latency: Duration,
pub breakdowns: Vec<u32>,
}
11 changes: 8 additions & 3 deletions ipa-core/src/cli/playbook/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod add;
mod generator;
mod hybrid;
mod input;
mod ipa;
mod multiply;
Expand All @@ -16,7 +17,10 @@ pub use multiply::secure_mul;
pub use sharded_shuffle::secure_shuffle;
use tokio::time::sleep;

pub use self::ipa::{playbook_oprf_ipa, run_query_and_validate};
pub use self::{
hybrid::{run_hybrid_query_and_validate, HybridQueryResult},
ipa::{playbook_oprf_ipa, run_query_and_validate},
};
use crate::{
cli::config_parse::HelperNetworkConfigParseExt,
config::{ClientConfig, NetworkConfig, PeerConfig},
Expand Down Expand Up @@ -227,11 +231,12 @@ pub async fn make_sharded_clients(
network_path: &Path,
scheme: Scheme,
wait: usize,
) -> Vec<[IpaHttpClient<Helper>; 3]> {
) -> (Vec<[IpaHttpClient<Helper>; 3]>, Vec<NetworkConfig<Helper>>) {
let network =
NetworkConfig::from_toml_str_sharded(&fs::read_to_string(network_path).unwrap()).unwrap();

let clients = network
.clone()
.into_iter()
.map(|network| {
let network = network.override_scheme(&scheme);
Expand All @@ -241,7 +246,7 @@ pub async fn make_sharded_clients(

wait_for_servers(wait, &clients).await;

clients
(clients, network)
}

async fn wait_for_servers(mut wait: usize, clients: &[[IpaHttpClient<Helper>; 3]]) {
Expand Down
Loading