-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add CLI command to run hybrid query #1472
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,14 +14,16 @@ | |
use ipa_core::{ | ||
cli::{ | ||
playbook::{ | ||
make_clients, playbook_oprf_ipa, run_query_and_validate, validate, validate_dp, | ||
InputSource, | ||
make_clients, make_sharded_clients, playbook_oprf_ipa, run_hybrid_query_and_validate, | ||
run_query_and_validate, validate, validate_dp, HybridQueryResult, InputSource, | ||
}, | ||
CsvSerializer, IpaQueryResult, Verbosity, | ||
}, | ||
config::{KeyRegistries, NetworkConfig}, | ||
ff::{boolean_array::BA32, FieldType}, | ||
helpers::query::{DpMechanism, IpaQueryConfig, QueryConfig, QuerySize, QueryType}, | ||
helpers::query::{ | ||
DpMechanism, HybridQueryParams, IpaQueryConfig, QueryConfig, QuerySize, QueryType, | ||
}, | ||
net::{Helper, IpaHttpClient}, | ||
report::{EncryptedOprfReportStreams, DEFAULT_KEY_ID}, | ||
test_fixture::{ | ||
|
@@ -58,6 +60,9 @@ | |
#[arg(long, value_name = "OUTPUT_FILE")] | ||
output_file: Option<PathBuf>, | ||
|
||
#[arg(long, default_value_t = 1)] | ||
shard_count: usize, | ||
|
||
#[command(subcommand)] | ||
action: ReportCollectorCommand, | ||
} | ||
|
@@ -132,6 +137,13 @@ | |
#[clap(flatten)] | ||
ipa_query_config: IpaQueryConfig, | ||
}, | ||
MaliciousHybrid { | ||
#[clap(flatten)] | ||
encrypted_inputs: EncryptedInputs, | ||
|
||
#[clap(flatten)] | ||
hybrid_query_config: HybridQueryParams, | ||
}, | ||
} | ||
|
||
#[derive(Debug, clap::Args)] | ||
|
@@ -169,7 +181,20 @@ | |
Scheme::HTTPS | ||
}; | ||
|
||
let (clients, network) = make_clients(args.network.as_deref(), scheme, args.wait).await; | ||
let (clients, networks) = if args.shard_count == 1 { | ||
let (c, n) = make_clients(args.network.as_deref(), scheme, args.wait).await; | ||
(vec![c], vec![n]) | ||
} else { | ||
make_sharded_clients( | ||
args.network | ||
.as_deref() | ||
.expect("Network.toml is required for sharded queries"), | ||
scheme, | ||
args.wait, | ||
) | ||
.await | ||
}; | ||
|
||
match args.action { | ||
ReportCollectorCommand::GenIpaInputs { | ||
count, | ||
|
@@ -184,20 +209,20 @@ | |
ReportCollectorCommand::SemiHonestOprfIpaTest(config) => { | ||
ipa_test( | ||
&args, | ||
&network, | ||
&networks[0], | ||
IpaSecurityModel::SemiHonest, | ||
config, | ||
&clients, | ||
&clients[0], | ||
) | ||
.await? | ||
} | ||
ReportCollectorCommand::MaliciousOprfIpaTest(config) => { | ||
ipa_test( | ||
&args, | ||
&network, | ||
&networks[0], | ||
IpaSecurityModel::Malicious, | ||
config, | ||
&clients, | ||
&clients[0], | ||
) | ||
.await? | ||
} | ||
|
@@ -209,7 +234,7 @@ | |
&args, | ||
IpaSecurityModel::Malicious, | ||
ipa_query_config, | ||
&clients, | ||
&clients[0], | ||
encrypted_inputs, | ||
) | ||
.await? | ||
|
@@ -222,11 +247,15 @@ | |
&args, | ||
IpaSecurityModel::SemiHonest, | ||
ipa_query_config, | ||
&clients, | ||
&clients[0], | ||
encrypted_inputs, | ||
) | ||
.await? | ||
} | ||
ReportCollectorCommand::MaliciousHybrid { | ||
ref encrypted_inputs, | ||
hybrid_query_config, | ||
} => hybrid(&args, hybrid_query_config, clients, encrypted_inputs).await?, | ||
}; | ||
|
||
Ok(()) | ||
|
@@ -329,6 +358,94 @@ | |
Ok(()) | ||
} | ||
|
||
fn write_hybrid_output_file( | ||
path: &PathBuf, | ||
query_result: &HybridQueryResult, | ||
) -> Result<(), Box<dyn Error>> { | ||
// it will be sad to lose the results if file already exists. | ||
let path = if Path::is_file(path) { | ||
let mut new_file_name = thread_rng() | ||
.sample_iter(&Alphanumeric) | ||
.take(5) | ||
.map(char::from) | ||
.collect::<String>(); | ||
let file_name = path.file_stem().ok_or("not a file")?; | ||
|
||
new_file_name.insert(0, '-'); | ||
new_file_name.insert_str(0, &file_name.to_string_lossy()); | ||
tracing::warn!( | ||
"{} file exists, renaming to {:?}", | ||
path.display(), | ||
new_file_name | ||
); | ||
|
||
// it will not be 100% accurate until file_prefix API is stabilized | ||
Cow::Owned( | ||
path.with_file_name(&new_file_name) | ||
.with_extension(path.extension().unwrap_or("".as_ref())), | ||
) | ||
} else { | ||
Cow::Borrowed(path) | ||
}; | ||
let mut file = File::options() | ||
.write(true) | ||
.create_new(true) | ||
.open(path.deref()) | ||
.map_err(|e| format!("Failed to create output file {}: {e}", path.display()))?; | ||
|
||
write!(file, "{}", serde_json::to_string_pretty(query_result)?)?; | ||
Ok(()) | ||
} | ||
|
||
async fn hybrid( | ||
args: &Args, | ||
hybrid_query_config: HybridQueryParams, | ||
helper_clients: Vec<[IpaHttpClient<Helper>; 3]>, | ||
encrypted_inputs: &EncryptedInputs, | ||
) -> Result<(), Box<dyn Error>> { | ||
let query_type = QueryType::MaliciousHybrid(hybrid_query_config); | ||
|
||
let files = [ | ||
&encrypted_inputs.enc_input_file1, | ||
&encrypted_inputs.enc_input_file2, | ||
&encrypted_inputs.enc_input_file3, | ||
]; | ||
|
||
// despite the name, this is generic enough to work with hybrid | ||
let encrypted_report_streams = EncryptedOprfReportStreams::from(files); | ||
|
||
let query_config = QueryConfig { | ||
size: QuerySize::try_from(encrypted_report_streams.query_size).unwrap(), | ||
field_type: FieldType::Fp32BitPrime, | ||
query_type, | ||
}; | ||
|
||
let query_id = helper_clients[0][0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we first do a query status to make sure something else isn't running? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I believe it panics with "query already running" already, so that's effectively the same, no? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not really. The system does't have something like distributed transactions. Think of the following scenario; helper 1 gets the This problem should be worse now with more states spread around. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's fair, and we have the "kill" endpoint to help clean up that state that sometimes happens. I'm inclined to try and get this merged as is (which is the same as Oprf Ipa, but with shards), and punt on making this better once we can actually run a query. For the immediate testing needs, we'll be restarting the binaries regularly anyways. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't think we have a sharded kill, but we can easily restart the k8s cluster. I you think adding a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. certainly something we should add, but we likely want to do that across all shards, so let's punt for now. |
||
.create_query(query_config) | ||
.await | ||
.expect("Unable to create query!"); | ||
|
||
tracing::info!("Starting query for OPRF"); | ||
// the value for histogram values (BA32) must be kept in sync with the server-side | ||
// implementation, otherwise a runtime reconstruct error will be generated. | ||
// see ipa-core/src/query/executor.rs | ||
let actual = run_hybrid_query_and_validate::<BA32>( | ||
encrypted_report_streams.streams, | ||
encrypted_report_streams.query_size, | ||
helper_clients, | ||
query_id, | ||
hybrid_query_config, | ||
) | ||
.await; | ||
|
||
if let Some(ref path) = args.output_file { | ||
write_hybrid_output_file(path, &actual)?; | ||
} else { | ||
println!("{}", serde_json::to_string_pretty(&actual)?); | ||
} | ||
Ok(()) | ||
} | ||
|
||
async fn ipa( | ||
args: &Args, | ||
security_model: IpaSecurityModel, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
#![cfg(all(feature = "web-app", feature = "cli"))] | ||
use std::{ | ||
cmp::min, | ||
time::{Duration, Instant}, | ||
}; | ||
|
||
use futures_util::future::try_join_all; | ||
use serde::{Deserialize, Serialize}; | ||
use tokio::time::sleep; | ||
|
||
use crate::{ | ||
ff::{Serializable, U128Conversions}, | ||
helpers::{ | ||
query::{HybridQueryParams, QueryInput, QuerySize}, | ||
BodyStream, | ||
}, | ||
net::{Helper, IpaHttpClient}, | ||
protocol::QueryId, | ||
query::QueryStatus, | ||
secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValue}, | ||
test_fixture::Reconstruct, | ||
}; | ||
|
||
/// # Panics | ||
/// if results are invalid | ||
#[allow(clippy::disallowed_methods)] // allow try_join_all | ||
pub async fn run_hybrid_query_and_validate<HV>( | ||
inputs: [BodyStream; 3], | ||
query_size: usize, | ||
clients: Vec<[IpaHttpClient<Helper>; 3]>, | ||
query_id: QueryId, | ||
query_config: HybridQueryParams, | ||
) -> HybridQueryResult | ||
where | ||
HV: SharedValue + U128Conversions, | ||
AdditiveShare<HV>: Serializable, | ||
{ | ||
let mpc_time = Instant::now(); | ||
|
||
// for now, submit everything to the leader. TODO: round robin submission | ||
let leader_clients = &clients[0]; | ||
try_join_all( | ||
inputs | ||
.into_iter() | ||
.zip(leader_clients) | ||
.map(|(input_stream, client)| { | ||
client.query_input(QueryInput { | ||
query_id, | ||
input_stream, | ||
}) | ||
}), | ||
) | ||
.await | ||
.unwrap(); | ||
Comment on lines
+41
to
+54
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @akoshelev I ended up needing to create a new function here that could take |
||
|
||
let mut delay = Duration::from_millis(125); | ||
loop { | ||
if try_join_all( | ||
leader_clients | ||
.iter() | ||
.map(|client| client.query_status(query_id)), | ||
) | ||
.await | ||
.unwrap() | ||
.into_iter() | ||
.all(|status| status == QueryStatus::Completed) | ||
{ | ||
break; | ||
} | ||
|
||
sleep(delay).await; | ||
delay = min(Duration::from_secs(5), delay * 2); | ||
// TODO: Add a timeout of some sort. Possibly, add some sort of progress indicator to | ||
// the status API so we can check whether the query is making progress. | ||
} | ||
|
||
// wait until helpers have processed the query and get the results from them | ||
let results: [_; 3] = try_join_all( | ||
leader_clients | ||
.iter() | ||
.map(|client| client.query_results(query_id)), | ||
) | ||
.await | ||
.unwrap() | ||
.try_into() | ||
.unwrap(); | ||
|
||
let results: Vec<HV> = results | ||
.map(|bytes| { | ||
AdditiveShare::<HV>::from_byte_slice(&bytes) | ||
.collect::<Result<Vec<_>, _>>() | ||
.unwrap() | ||
}) | ||
.reconstruct(); | ||
|
||
let lat = mpc_time.elapsed(); | ||
|
||
tracing::info!("Running IPA for {query_size:?} records took {t:?}", t = lat); | ||
let mut breakdowns = vec![0; usize::try_from(query_config.max_breakdown_key).unwrap()]; | ||
for (breakdown_key, trigger_value) in results.into_iter().enumerate() { | ||
// TODO: make the data type used consistent with `ipa_in_the_clear` | ||
// I think using u32 is wrong, we should move to u128 | ||
if query_config.with_dp == 0 { | ||
// otherwise if DP is added trigger_values will not be zero due to noise | ||
assert!( | ||
breakdown_key < query_config.max_breakdown_key.try_into().unwrap() | ||
|| trigger_value == HV::ZERO, | ||
"trigger values were attributed to buckets more than max breakdown key" | ||
); | ||
} | ||
|
||
if breakdown_key < query_config.max_breakdown_key.try_into().unwrap() { | ||
breakdowns[breakdown_key] += u32::try_from(trigger_value.as_u128()).unwrap(); | ||
} | ||
} | ||
|
||
HybridQueryResult { | ||
input_size: QuerySize::try_from(query_size).unwrap(), | ||
config: query_config, | ||
latency: lat, | ||
breakdowns, | ||
} | ||
} | ||
|
||
/// Result of a hybrid query as reported by the CLI: the input size, the query
/// configuration that produced it, the end-to-end latency, and the aggregated
/// breakdown histogram.
#[derive(Debug, Serialize, Deserialize)]
pub struct HybridQueryResult {
    pub input_size: QuerySize,
    pub config: HybridQueryParams,
    // Serialized in JSON as seconds via the crate's duration serde helpers.
    #[serde(
        serialize_with = "crate::serde::duration::to_secs",
        deserialize_with = "crate::serde::duration::from_secs"
    )]
    pub latency: Duration,
    pub breakdowns: Vec<u32>,
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this looks oddly familiar, is it possible to re-use the same code and avoid copy paste here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same old problem with OPRF and Hybrid: we can either spend time supporting both, or just delete the duplicate when we remove OPRF.