From e4d833dd5e34875466cc6f8ceff46d0e8dde0937 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Tue, 3 Dec 2024 10:55:49 -0800 Subject: [PATCH] Sharded shuffle http e2e test (#1464) * Some plumbing work for sharded shuffle e2e Sharded shuffle e2e test without HTTP handler * temp * temp2 * For Alex. Unit tests pass * added num_arg * Sharded shuffle HTTP end-to-end test This finalizes the plumbing for HTTP stack and verifies that it is working correctly end-to-end. * Fix compact gate integration tests * Fix compact gate tests * Fix merge issues * Improve documentation a bit --------- Co-authored-by: Christian Berkhoff --- ipa-core/src/bin/test_mpc.rs | 53 ++++- ipa-core/src/cli/config_parse.rs | 19 ++ ipa-core/src/cli/keygen.rs | 2 +- ipa-core/src/cli/playbook/mod.rs | 50 ++++- ipa-core/src/cli/playbook/sharded_shuffle.rs | 97 +++++++++ ipa-core/src/cli/test_setup.rs | 172 +++++++++++++++- ipa-core/src/helpers/transport/routing.rs | 7 + ipa-core/src/net/client/mod.rs | 50 ++--- ipa-core/src/net/error.rs | 2 +- ipa-core/src/net/transport.rs | 9 +- .../src/protocol/ipa_prf/shuffle/sharded.rs | 9 +- ipa-core/src/protocol/ipa_prf/shuffle/step.rs | 7 + ipa-core/src/query/processor.rs | 13 +- ipa-core/tests/common/mod.rs | 186 +++++++++++++----- ipa-core/tests/helper_networks.rs | 68 ++++++- 15 files changed, 627 insertions(+), 117 deletions(-) create mode 100644 ipa-core/src/cli/playbook/sharded_shuffle.rs diff --git a/ipa-core/src/bin/test_mpc.rs b/ipa-core/src/bin/test_mpc.rs index 1f8484ac7..dc509485f 100644 --- a/ipa-core/src/bin/test_mpc.rs +++ b/ipa-core/src/bin/test_mpc.rs @@ -5,13 +5,18 @@ use generic_array::ArrayLength; use hyper::http::uri::Scheme; use ipa_core::{ cli::{ - playbook::{make_clients, secure_add, secure_mul, validate, InputSource}, + playbook::{ + make_clients, make_sharded_clients, secure_add, secure_mul, secure_shuffle, validate, + InputSource, + }, Verbosity, }, - ff::{Field, FieldType, Fp31, Fp32BitPrime, Serializable, U128Conversions}, + ff::{ + boolean_array::BA64, Field, FieldType, Fp31, Fp32BitPrime, Serializable, U128Conversions, + }, helpers::query::{ QueryConfig, - QueryType::{TestAddInPrimeField, TestMultiply}, + QueryType::{TestAddInPrimeField, TestMultiply, TestShardedShuffle}, }, net::{Helper, IpaHttpClient}, secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, @@ -103,11 +108,27 @@ async fn main() -> Result<(), Box> { Scheme::HTTPS }; - let (clients, _) = make_clients(args.network.as_deref(), scheme, args.wait).await; match args.action { - TestAction::Multiply => multiply(&args, &clients).await, - TestAction::AddInPrimeField => add(&args, &clients).await, - TestAction::ShardedShuffle => sharded_shuffle(&args, &clients).await, + TestAction::Multiply => { + let (clients, _) = make_clients(args.network.as_deref(), scheme, args.wait).await; + multiply(&args, &clients).await + } + TestAction::AddInPrimeField => { + let (clients, _) = make_clients(args.network.as_deref(), scheme, args.wait).await; + add(&args, &clients).await + } + TestAction::ShardedShuffle => { + // we need clients to talk to each individual shard + let clients = make_sharded_clients( + args.network + .as_deref() + .expect("network config is required for sharded shuffle"), + scheme, + args.wait, + ) + .await; + sharded_shuffle(&args, clients).await + } }; Ok(()) @@ -166,6 +187,20 @@ async fn add(args: &Args, helper_clients: &[IpaHttpClient; 3]) { }; } -async fn sharded_shuffle(_args: &Args, _helper_clients: &[IpaHttpClient; 3]) { - unimplemented!() +async fn sharded_shuffle(args: &Args, helper_clients: Vec<[IpaHttpClient; 3]>) { + let input = InputSource::from(&args.input); + let input_rows = input + .iter::() + .map(BA64::truncate_from) + .collect::>(); + let query_config = + QueryConfig::new(TestShardedShuffle, args.input.field, input_rows.len()).unwrap(); + let query_id = helper_clients[0][0] + .create_query(query_config) + .await + .unwrap(); + let shuffled = secure_shuffle(input_rows.clone(), &helper_clients, query_id).await; + + assert_eq!(shuffled.len(), input_rows.len()); + assert_ne!(shuffled, input_rows); } diff --git a/ipa-core/src/cli/config_parse.rs b/ipa-core/src/cli/config_parse.rs index ec8c17b54..76cb30e96 100644 --- a/ipa-core/src/cli/config_parse.rs +++ b/ipa-core/src/cli/config_parse.rs @@ -255,6 +255,7 @@ fn assert_hpke_config(expected: &Value, actual: Option<&HpkeClientConfig>) { #[allow(dead_code)] pub trait HelperNetworkConfigParseExt { fn from_toml_str(input: &str) -> Result, Error>; + fn from_toml_str_sharded(input: &str) -> Result>, Error>; } /// Reads config from string. Expects config to be toml format. @@ -274,6 +275,24 @@ impl HelperNetworkConfigParseExt for NetworkConfig { all_network.client.clone(), )) } + fn from_toml_str_sharded(input: &str) -> Result>, Error> { + let all_network = parse_sharded_network_toml(input)?; + // peers are grouped by shard, meaning 0,1,2 describe MPC for shard 0. + // 3,4,5 describe shard 1, etc. + Ok(all_network + .peers + .chunks(3) + .map(|mpc_config| { + NetworkConfig::new_mpc( + mpc_config + .iter() + .map(ShardedPeerConfigToml::to_mpc_peer) + .collect(), + all_network.client.clone(), + ) + }) + .collect()) + } } /// Reads a the config for a specific, single, sharded server from string. Expects config to be diff --git a/ipa-core/src/cli/keygen.rs b/ipa-core/src/cli/keygen.rs index d75e1f27c..1bed8d420 100644 --- a/ipa-core/src/cli/keygen.rs +++ b/ipa-core/src/cli/keygen.rs @@ -15,7 +15,7 @@ use time::{Duration, OffsetDateTime}; use crate::{error::BoxError, hpke::KeyPair}; -#[derive(Debug, Args)] +#[derive(Debug, Clone, Args)] #[clap( name = "keygen", about = "Generate keys used by an MPC helper", diff --git a/ipa-core/src/cli/playbook/mod.rs b/ipa-core/src/cli/playbook/mod.rs index f6abcbe38..8900e04fa 100644 --- a/ipa-core/src/cli/playbook/mod.rs +++ b/ipa-core/src/cli/playbook/mod.rs @@ -3,6 +3,7 @@ mod generator; mod input; mod ipa; mod multiply; +mod sharded_shuffle; use core::fmt::Debug; use std::{fs, path::Path, time::Duration}; @@ -12,6 +13,7 @@ use comfy_table::{Cell, Color, Table}; use hyper::http::uri::Scheme; pub use input::InputSource; pub use multiply::secure_mul; +pub use sharded_shuffle::secure_shuffle; use tokio::time::sleep; pub use self::ipa::{playbook_oprf_ipa, run_query_and_validate}; @@ -196,7 +198,6 @@ pub async fn make_clients( scheme: Scheme, wait: usize, ) -> ([IpaHttpClient; 3], NetworkConfig) { - let mut wait = wait; let network = if let Some(path) = network_path { NetworkConfig::from_toml_str(&fs::read_to_string(path).unwrap()).unwrap() } else { @@ -214,16 +215,51 @@ pub async fn make_clients( // Note: This closure is only called when the selected action uses clients. let clients = IpaHttpClient::from_conf(&IpaRuntime::current(), &network, &ClientIdentity::None); - while wait > 0 && !clients_ready(&clients).await { + wait_for_servers(wait, &[clients.clone()]).await; + (clients, network) +} + +/// Creates enough clients to talk to all shards on MPC helpers. This only supports +/// reading configuration from the `network.toml` file +/// ## Panics +/// If configuration file `network_path` cannot be read from or if it does not conform to toml spec. +pub async fn make_sharded_clients( + network_path: &Path, + scheme: Scheme, + wait: usize, +) -> Vec<[IpaHttpClient; 3]> { + let network = + NetworkConfig::from_toml_str_sharded(&fs::read_to_string(network_path).unwrap()).unwrap(); + + let clients = network + .into_iter() + .map(|network| { + let network = network.override_scheme(&scheme); + IpaHttpClient::from_conf(&IpaRuntime::current(), &network, &ClientIdentity::None) + }) + .collect::>(); + + wait_for_servers(wait, &clients).await; + + clients +} + +async fn wait_for_servers(mut wait: usize, clients: &[[IpaHttpClient; 3]]) { + while wait > 0 && !clients_ready(clients).await { tracing::debug!("waiting for servers to come up"); sleep(Duration::from_secs(1)).await; wait -= 1; } - (clients, network) } -async fn clients_ready(clients: &[IpaHttpClient; 3]) -> bool { - clients[0].echo("").await.is_ok() - && clients[1].echo("").await.is_ok() - && clients[2].echo("").await.is_ok() +#[allow(clippy::disallowed_methods)] +async fn clients_ready(clients: &[[IpaHttpClient; 3]]) -> bool { + let r = futures::future::join_all(clients.iter().map(|clients| async move { + clients[0].echo("").await.is_ok() + && clients[1].echo("").await.is_ok() + && clients[2].echo("").await.is_ok() + })) + .await; + + r.iter().all(|&v| v) } diff --git a/ipa-core/src/cli/playbook/sharded_shuffle.rs b/ipa-core/src/cli/playbook/sharded_shuffle.rs new file mode 100644 index 000000000..0139a8171 --- /dev/null +++ b/ipa-core/src/cli/playbook/sharded_shuffle.rs @@ -0,0 +1,97 @@ +use std::{ + cmp::{max, min}, + ops::Add, + time::Duration, +}; + +use futures_util::future::try_join_all; +use generic_array::ArrayLength; + +use crate::{ + ff::{boolean_array::BooleanArray, Serializable}, + helpers::{query::QueryInput, BodyStream}, + net::{Helper, IpaHttpClient}, + protocol::QueryId, + query::QueryStatus, + secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, + test_fixture::Reconstruct, +}; + +/// Secure sharded shuffle protocol +/// +/// ## Panics +/// If the input size is empty or contains only one row. +#[allow(clippy::disallowed_methods)] // allow try_join_all +pub async fn secure_shuffle( + inputs: Vec, + clients: &[[IpaHttpClient; 3]], + query_id: QueryId, +) -> Vec +where + V: IntoShares>, + ::Size: Add<::Size, Output: ArrayLength>, + V: BooleanArray, +{ + assert!( + inputs.len() > 1, + "Shuffle requires at least two rows to be shuffled" + ); + let chunk_size = max(1, inputs.len() / clients.len()); + let _ = try_join_all( + inputs + .chunks(chunk_size) + .zip(clients) + .map(|(chunk, mpc_clients)| { + let shared = chunk.iter().copied().share(); + try_join_all(mpc_clients.each_ref().iter().zip(shared).map( + |(mpc_client, input)| { + mpc_client.query_input(QueryInput { + query_id, + input_stream: BodyStream::from_serializable_iter(input), + }) + }, + )) + }), + ) + .await + .unwrap(); + let leader_clients = &clients[0]; + + let mut delay = Duration::from_millis(125); + loop { + if try_join_all( + leader_clients + .iter() + .map(|client| client.query_status(query_id)), + ) + .await + .unwrap() + .into_iter() + .all(|status| status == QueryStatus::Completed) + { + break; + } + + tokio::time::sleep(delay).await; + delay = min(Duration::from_secs(5), delay * 2); + } + + let results: [_; 3] = try_join_all( + leader_clients + .iter() + .map(|client| client.query_results(query_id)), + ) + .await + .unwrap() + .try_into() + .unwrap(); + let results: Vec = results + .map(|bytes| { + AdditiveShare::::from_byte_slice(&bytes) + .collect::, _>>() + .unwrap() + }) + .reconstruct(); + + results +} diff --git a/ipa-core/src/cli/test_setup.rs b/ipa-core/src/cli/test_setup.rs index 32ac91885..86c663931 100644 --- a/ipa-core/src/cli/test_setup.rs +++ b/ipa-core/src/cli/test_setup.rs @@ -34,13 +34,31 @@ pub struct TestSetupArgs { #[arg(long, default_value_t = false)] use_http1: bool, - #[arg(short, long, num_args = 3, value_name = "PORT", default_values = vec!["3000", "3001", "3002"])] + /// A list of ALL the MPC ports for all servers. If you have a server with shard count 4, you + /// will have to provide 12 ports. + #[arg(short, long, value_name = "PORT", num_args = 1.., default_values = vec!["3000", "3001", "3002"])] ports: Vec, - #[arg(short, long, num_args = 3, value_name = "SHARD_PORT", default_values = vec!["6000", "6001", "6002"])] + /// A list of ALL the sharding ports for all servers. If you have a server with shard count 4, + /// you will have to provide 12 ports. + #[arg(short, long, value_name = "SHARD_PORT", num_args = 1.., default_values = vec!["6000", "6001", "6002"])] shard_ports: Vec, } +impl TestSetupArgs { + /// Returns number of shards requested for setup. + fn shard_count(&self) -> usize { + self.ports.len() / 3 + } + + /// If the number of shards requested is greater than 1 + /// then we configure a sharded environment, otherwise + /// a fixed 3-host MPC configuration is created + fn is_sharded(&self) -> bool { + self.shard_count() > 1 + } +} + /// Prepare a test network of three helpers. /// /// # Errors @@ -49,6 +67,21 @@ pub struct TestSetupArgs { /// # Panics /// If something that shouldn't happen goes wrong. pub fn test_setup(args: TestSetupArgs) -> Result<(), BoxError> { + assert_eq!( + args.ports.len(), + args.shard_ports.len(), + "number of mpc ports and shard ports don't match" + ); + assert_eq!( + args.ports.len() % 3, + 0, + "Number of ports must be a multiple of 3" + ); + assert!( + !args.ports.is_empty() && !args.shard_ports.is_empty(), + "Please provide a list of ports" + ); + if args.output_dir.exists() { if !args.output_dir.is_dir() || args.output_dir.read_dir()?.next().is_some() { return Err("output directory already exists and is not empty".into()); @@ -57,8 +90,76 @@ pub fn test_setup(args: TestSetupArgs) -> Result<(), BoxError> { DirBuilder::new().create(&args.output_dir)?; } + if args.is_sharded() { + sharded_keygen(args) + } else { + non_sharded_keygen(args) + } +} + +fn sharded_keygen(args: TestSetupArgs) -> Result<(), BoxError> { let localhost = String::from("localhost"); + let keygen_args: Vec<_> = [1, 2, 3] + .into_iter() + .cycle() + .take(args.ports.len()) + .enumerate() + .map(|(i, id)| { + let shard_dir = args.output_dir.join(format!("shard{i}")); + DirBuilder::new().create(&shard_dir)?; + Ok::<_, BoxError>(if i < 3 { + // Only leader shards generate MK keys + KeygenArgs { + name: localhost.clone(), + tls_cert: shard_dir.helper_tls_cert(id), + tls_key: shard_dir.helper_tls_key(id), + tls_expire_after: 365, + mk_public_key: Some(shard_dir.helper_mk_public_key(id)), + mk_private_key: Some(shard_dir.helper_mk_private_key(id)), + } + } else { + KeygenArgs { + name: localhost.clone(), + tls_cert: shard_dir.helper_tls_cert(id), + tls_key: shard_dir.helper_tls_key(id), + tls_expire_after: 365, + mk_public_key: None, + mk_private_key: None, + } + }) + }) + .collect::>()?; + for ka in &keygen_args { + keygen(ka)?; + } + + let clients_config: Vec<_> = zip( + keygen_args.iter(), + zip( + keygen_args.clone().into_iter().take(3).cycle(), + zip(args.ports, args.shard_ports), + ), + ) + .map( + |(keygen, (leader_keygen, (port, shard_port)))| HelperClientConf { + host: localhost.to_string(), + port, + shard_port, + tls_cert_file: keygen.tls_cert.clone(), + mk_public_key_file: leader_keygen.mk_public_key.clone().unwrap(), + }, + ) + .collect(); + + let mut conf_file = File::create(args.output_dir.join("network.toml"))?; + gen_client_config(clients_config.into_iter(), args.use_http1, &mut conf_file) +} + +/// This generates directories and files needed to run a non-sharded MPC. +/// The directory structure is flattened and does not include per-shard configuration. +fn non_sharded_keygen(args: TestSetupArgs) -> Result<(), BoxError> { + let localhost = String::from("localhost"); let clients_config: [_; 3] = zip([1, 2, 3], zip(args.ports, args.shard_ports)) .map(|(id, (port, shard_port))| { let keygen_args = KeygenArgs { @@ -87,3 +188,70 @@ pub fn test_setup(args: TestSetupArgs) -> Result<(), BoxError> { let mut conf_file = File::create(args.output_dir.join("network.toml"))?; gen_client_config(clients_config.into_iter(), args.use_http1, &mut conf_file) } + +#[cfg(test)] +mod tests { + use std::fs; + + use tempfile::TempDir; + + use crate::{ + cli::{sharded_server_from_toml_str, test_setup, TestSetupArgs}, + helpers::HelperIdentity, + sharding::ShardIndex, + }; + + #[test] + fn test_happy_case() { + let temp_dir = TempDir::new().unwrap(); + let outdir = temp_dir.path().to_path_buf(); + let args = TestSetupArgs { + output_dir: outdir.clone(), + disable_https: false, + use_http1: false, + ports: vec![3000, 3001, 3002, 3003, 3004, 3005], + shard_ports: vec![6000, 6001, 6002, 6003, 6004, 6005], + }; + test_setup(args).unwrap(); + let network_config_path = outdir.join("network.toml"); + let network_config_string = &fs::read_to_string(network_config_path).unwrap(); + let _r = sharded_server_from_toml_str( + network_config_string, + HelperIdentity::TWO, + ShardIndex::from(1), + ShardIndex::from(2), + None, + ) + .unwrap(); + } + + #[test] + #[should_panic(expected = "Please provide a list of ports")] + fn test_empty_ports() { + let temp_dir = TempDir::new().unwrap(); + let outdir = temp_dir.path().to_path_buf(); + let args = TestSetupArgs { + output_dir: outdir, + disable_https: false, + use_http1: false, + ports: vec![], + shard_ports: vec![], + }; + test_setup(args).unwrap(); + } + + #[test] + #[should_panic(expected = "number of mpc ports and shard ports don't match")] + fn test_mismatched_ports() { + let temp_dir = TempDir::new().unwrap(); + let outdir = temp_dir.path().to_path_buf(); + let args = TestSetupArgs { + output_dir: outdir, + disable_https: false, + use_http1: false, + ports: vec![3000, 3001], + shard_ports: vec![6000, 6001, 6002], + }; + test_setup(args).unwrap(); + } +} diff --git a/ipa-core/src/helpers/transport/routing.rs b/ipa-core/src/helpers/transport/routing.rs index e851865ce..c935704d4 100644 --- a/ipa-core/src/helpers/transport/routing.rs +++ b/ipa-core/src/helpers/transport/routing.rs @@ -14,6 +14,13 @@ pub enum RouteId { ReceiveQuery, PrepareQuery, QueryInput, + /// To accelerate delivery, we made some compromise here and as a result this API + /// has double-meaning depending on the context. + /// In the context of a shard, it is used to check whether other shards have the + /// same status + /// In the context of an MPC client, it is used to fetch the latest status of a given query. + /// + /// We should've used a different `RouteId` to differentiate those QueryStatus, CompleteQuery, KillQuery, diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index 3c423c33c..d4789b198 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -384,6 +384,31 @@ impl IpaHttpClient { let resp = self.request(req).await?; resp_ok(resp).await } + + /// This API is used by leader shards in MPC to request query status information on peers. + /// If a given peer has status that doesn't match the one provided by the leader, it responds + /// with 412 error and encodes its status inside the response body. Otherwise, 200 is returned. + /// + /// # Errors + /// If the request has illegal arguments, or fails to be delivered + pub async fn status_match(&self, data: CompareStatusRequest) -> Result<(), Error> { + let req = http_serde::query::status_match::try_into_http_request( + &data, + self.scheme.clone(), + self.authority.clone(), + )?; + let resp = self.request(req).await?; + + match resp.status() { + StatusCode::OK => Ok(()), + StatusCode::PRECONDITION_FAILED => { + let bytes = response_to_bytes(resp).await?; + let err = serde_json::from_slice::(&bytes)?; + Err(err.into()) + } + _ => Err(Error::from_failed_resp(resp).await), + } + } } impl IpaHttpClient { @@ -509,31 +534,6 @@ impl IpaHttpClient { }) .collect() } - - /// This API is used by leader shards in MPC to request query status information on peers. - /// If a given peer has status that doesn't match the one provided by the leader, it responds - /// with 412 error and encodes its status inside the response body. Otherwise, 200 is returned. - /// - /// # Errors - /// If the request has illegal arguments, or fails to be delivered - pub async fn status_match(&self, data: CompareStatusRequest) -> Result<(), Error> { - let req = http_serde::query::status_match::try_into_http_request( - &data, - self.scheme.clone(), - self.authority.clone(), - )?; - let resp = self.request(req).await?; - - match resp.status() { - StatusCode::OK => Ok(()), - StatusCode::PRECONDITION_FAILED => { - let bytes = response_to_bytes(resp).await?; - let err = serde_json::from_slice::(&bytes)?; - Err(err.into()) - } - _ => Err(Error::from_failed_resp(resp).await), - } - } } fn make_http_connector() -> HttpConnector { diff --git a/ipa-core/src/net/error.rs b/ipa-core/src/net/error.rs index e5f188158..82426b402 100644 --- a/ipa-core/src/net/error.rs +++ b/ipa-core/src/net/error.rs @@ -182,7 +182,7 @@ impl IntoResponse for Error { StatusCode::PRECONDITION_FAILED, serde_json::to_string(&error).unwrap(), ) - .into_response() + .into_response(); } }; (status_code, self.to_string()).into_response() diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index 173c08831..d8fbfc4b5 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -111,10 +111,11 @@ impl HttpTransport { .expect("query_id is required to call complete query API"); self.clients[client_ix].complete_query(query_id).await } - evt @ (RouteId::QueryInput - | RouteId::ReceiveQuery - | RouteId::QueryStatus - | RouteId::KillQuery) => { + RouteId::QueryStatus => { + let req = serde_json::from_str(route.extra().borrow())?; + self.clients[client_ix].status_match(req).await + } + evt @ (RouteId::QueryInput | RouteId::ReceiveQuery | RouteId::KillQuery) => { unimplemented!( "attempting to send client-specific request {evt:?} to another helper" ) diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs index 165366826..fbaa2f860 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/sharded.rs @@ -29,7 +29,10 @@ use crate::{ helpers::{Direction, Error, Role, TotalRecords}, protocol::{ context::{reshard_iter, ShardedContext}, - ipa_prf::shuffle::{step::ShardedShuffleStep as ShuffleStep, IntermediateShuffleMessages}, + ipa_prf::shuffle::{ + step::{ShardedShufflePermuteStep as PermuteStep, ShardedShuffleStep as ShuffleStep}, + IntermediateShuffleMessages, + }, prss::{FromRandom, SharedRandomness}, RecordId, }, @@ -105,7 +108,7 @@ pub trait ShuffleContext: ShardedContext { { let data = data.into_iter(); async move { - let masking_ctx = self.narrow(&ShuffleStep::Mask); + let masking_ctx = self.narrow(&PermuteStep::Mask); let mut resharded = assert_send(reshard_iter( self.clone(), data.enumerate().map(|(i, item)| { @@ -118,7 +121,7 @@ pub trait ShuffleContext: ShardedContext { )) .await?; - let ctx = self.narrow(&ShuffleStep::LocalShuffle); + let ctx = self.narrow(&PermuteStep::LocalShuffle); resharded.shuffle(&mut match direction { Direction::Left => ctx.prss_rng().0, Direction::Right => ctx.prss_rng().1, diff --git a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs index 07d9c204a..6a4ff2050 100644 --- a/ipa-core/src/protocol/ipa_prf/shuffle/step.rs +++ b/ipa-core/src/protocol/ipa_prf/shuffle/step.rs @@ -32,10 +32,13 @@ pub(crate) enum ShardedShuffleStep { /// Depending on the helper position inside the MPC ring, generate Ã, B̃ or both. PseudoRandomTable, /// Permute the input according to the PRSS shared between H1 and H2. + #[step(child = ShardedShufflePermuteStep)] Permute12, /// Permute the input according to the PRSS shared between H2 and H3. + #[step(child = ShardedShufflePermuteStep)] Permute23, /// Permute the input according to the PRSS shared between H3 and H1. + #[step(child = ShardedShufflePermuteStep)] Permute31, /// Specific to H1 and H2 interaction - H2 informs H1 about |C|. Cardinality, @@ -43,6 +46,10 @@ pub(crate) enum ShardedShuffleStep { TransferXY, /// H2 and H3 interaction - Exchange `C_1` and `C_2`. TransferC, +} + +#[derive(CompactStep)] +pub(crate) enum ShardedShufflePermuteStep { /// Apply a mask to the given set of shares. Masking values come from PRSS. Mask, /// Local per-shard shuffle, where each shard redistributes shares locally according to samples diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index 856751fdc..38a75baa5 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -8,7 +8,7 @@ use serde::Serialize; use super::min_status; use crate::{ - error::{BoxError, Error as ProtocolError}, + error::Error as ProtocolError, executor::IpaRuntime, helpers::{ query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput}, @@ -368,7 +368,8 @@ impl Processor { /// This helper function is used to transform a [`BoxError`] into a /// [`QueryStatusError::DifferentStatus`] and retrieve it's internal state. Returns [`None`] /// if not possible. - fn downcast_state_error(box_error: BoxError) -> Option { + #[cfg(feature = "in-memory-infra")] + fn downcast_state_error(box_error: crate::error::BoxError) -> Option { use crate::helpers::ApiError; let api_error = box_error.downcast::().ok()?; if let ApiError::QueryStatus(QueryStatusError::DifferentStatus { my_status, .. }) = @@ -399,8 +400,8 @@ impl Processor { /// of relying on errors. #[cfg(feature = "real-world-infra")] fn get_state_from_error(shard_error: crate::net::ShardError) -> Option { - if let crate::net::Error::Application { error, .. } = shard_error.source { - return Self::downcast_state_error(error); + if let crate::net::Error::ShardQueryStatusMismatch { error, .. } = shard_error.source { + return Some(error.actual); } None } @@ -431,6 +432,10 @@ impl Processor { let shard_responses = shard_transport.broadcast(shard_query_status_req).await; if let Err(e) = shard_responses { // The following silently ignores the cases where the query isn't found. + // TODO: this code is a ticking bomb - it ignores all errors, not just when + // query is not found. If there is no handler, handler responded with an error, etc. + // Moreover, any error may result in client mistakenly assuming that the status + // is completed. let states: Vec<_> = e .failures .into_iter() diff --git a/ipa-core/tests/common/mod.rs b/ipa-core/tests/common/mod.rs index dae743987..d461a2018 100644 --- a/ipa-core/tests/common/mod.rs +++ b/ipa-core/tests/common/mod.rs @@ -2,7 +2,6 @@ use std::{ array, error::Error, io::{self, Write}, - iter::zip, net::TcpListener, ops::Deref, os::fd::AsRawFd, @@ -44,7 +43,7 @@ impl UnwrapStatusExt for Result { } } -trait TerminateOnDropExt { +pub trait TerminateOnDropExt { fn terminate_on_drop(self) -> TerminateOnDrop; } @@ -61,7 +60,7 @@ impl TerminateOnDrop { self.0.take().unwrap() } - fn wait(self) -> io::Result { + pub fn wait(self) -> io::Result { self.into_inner().wait() } } @@ -109,11 +108,50 @@ impl CommandExt for Command { } } -fn test_setup(config_path: &Path) -> [TcpListener; 6] { - let sockets: [_; 6] = array::from_fn(|_| TcpListener::bind("127.0.0.1:0").unwrap()); - let ports: [u16; 6] = sockets +fn test_setup(config_path: &Path) -> [ShardTcpListeners; 3] { + test_sharded_setup::<1>(config_path) + .into_iter() + .next() + .unwrap() +} + +pub struct ShardTcpListeners { + pub mpc: TcpListener, + pub shard: TcpListener, +} + +impl ShardTcpListeners { + pub fn bind_random() -> Self { + let mpc = TcpListener::bind("127.0.0.1:0").unwrap(); + let shard = TcpListener::bind("127.0.0.1:0").unwrap(); + + Self { mpc, shard } + } +} + +pub fn test_sharded_setup(config_path: &Path) -> Vec<[ShardTcpListeners; 3]> { + let sockets: [_; SHARDS] = array::from_fn(|_| { + let r: [_; 3] = array::from_fn(|_| ShardTcpListeners::bind_random()); + + r + }); + + let (mpc_ports, shard_ports): (Vec<_>, Vec<_>) = sockets .each_ref() - .map(|sock| sock.local_addr().unwrap().port()); + .iter() + .flat_map(|listeners| { + listeners + .each_ref() + .iter() + .map(|l| { + ( + l.mpc.local_addr().unwrap().port(), + l.shard.local_addr().unwrap().port(), + ) + }) + .collect::>() + }) + .unzip(); let mut command = Command::new(HELPER_BIN); command @@ -121,63 +159,103 @@ fn test_setup(config_path: &Path) -> [TcpListener; 6] { .arg("test-setup") .args(["--output-dir".as_ref(), config_path.as_os_str()]) .arg("--ports") - .args(ports.iter().take(3).map(|p| p.to_string())) + .args(mpc_ports.iter().map(|p| p.to_string())) .arg("--shard-ports") - .args(ports.iter().skip(3).take(3).map(|p| p.to_string())); - + .args(shard_ports.iter().map(|p| p.to_string())); command.status().unwrap_status(); + + sockets.into_iter().collect() +} + +pub fn spawn_shards( + config_path: &Path, + sockets: &[[ShardTcpListeners; 3]], + https: bool, +) -> Vec { + if https { + unimplemented!("We haven't implemented HTTPS path yet") + } + + let shard_count = sockets.len(); sockets + .iter() + .enumerate() + .flat_map(|(shard_index, mpc_sockets)| { + (1..=3) + .zip(mpc_sockets.each_ref().iter()) + .map(|(id, ShardTcpListeners { mpc, shard })| { + let mut command = Command::new(HELPER_BIN); + command + .args(["-i", &id.to_string()]) + .args(["--shard-index", &shard_index.to_string()]) + .args(["--shard-count", &shard_count.to_string()]) + .args(["--network".into(), config_path.join("network.toml")]) + .arg("--disable-https") + .silent(); + + command.preserved_fds(vec![mpc.as_raw_fd(), shard.as_raw_fd()]); + command.args(["--server-socket-fd", &mpc.as_raw_fd().to_string()]); + command.args(["--shard-server-socket-fd", &shard.as_raw_fd().to_string()]); + + // something went wrong if command is terminated at this point. + let mut child = command.spawn().unwrap(); + if let Ok(Some(status)) = child.try_wait() { + panic!("Helper binary terminated early with status = {status}"); + } + + child.terminate_on_drop() + }) + .collect::>() + }) + .collect() } pub fn spawn_helpers( config_path: &Path, - sockets: &[TcpListener; 6], + // (mpc port + shard port) for 3 helpers + sockets: &[ShardTcpListeners; 3], https: bool, ) -> Vec { - zip( - [1, 2, 3], - zip(sockets.iter().take(3), sockets.iter().skip(3).take(3)), - ) - .map(|(id, (socket, shard_socket))| { - let mut command = Command::new(HELPER_BIN); - command - .args(["-i", &id.to_string()]) - .args(["--network".into(), config_path.join("network.toml")]) - .silent(); - - if https { + sockets + .iter() + .enumerate() + .map(|(id, ShardTcpListeners { mpc, shard })| { + let id = id + 1; + let mut command = Command::new(HELPER_BIN); command - .args(["--tls-cert".into(), config_path.join(format!("h{id}.pem"))]) - .args(["--tls-key".into(), config_path.join(format!("h{id}.key"))]) - .args([ - "--mk-public-key".into(), - config_path.join(format!("h{id}_mk.pub")), - ]) - .args([ - "--mk-private-key".into(), - config_path.join(format!("h{id}_mk.key")), - ]); - } else { - command.arg("--disable-https"); - } - - command.preserved_fds(vec![socket.as_raw_fd()]); - command.args(["--server-socket-fd", &socket.as_raw_fd().to_string()]); - command.preserved_fds(vec![shard_socket.as_raw_fd()]); - command.args([ - "--shard-server-socket-fd", - &shard_socket.as_raw_fd().to_string(), - ]); - - // something went wrong if command is terminated at this point. - let mut child = command.spawn().unwrap(); - if let Ok(Some(status)) = child.try_wait() { - panic!("Helper binary terminated early with status = {status}"); - } - - child.terminate_on_drop() - }) - .collect::>() + .args(["-i", &id.to_string()]) + .args(["--network".into(), config_path.join("network.toml")]) + .silent(); + + if https { + command + .args(["--tls-cert".into(), config_path.join(format!("h{id}.pem"))]) + .args(["--tls-key".into(), config_path.join(format!("h{id}.key"))]) + .args([ + "--mk-public-key".into(), + config_path.join(format!("h{id}_mk.pub")), + ]) + .args([ + "--mk-private-key".into(), + config_path.join(format!("h{id}_mk.key")), + ]); + } else { + command.arg("--disable-https"); + } + + command.preserved_fds(vec![mpc.as_raw_fd(), shard.as_raw_fd()]); + command.args(["--server-socket-fd", &mpc.as_raw_fd().to_string()]); + command.args(["--shard-server-socket-fd", &shard.as_raw_fd().to_string()]); + + // something went wrong if command is terminated at this point. + let mut child = command.spawn().unwrap(); + if let Ok(Some(status)) = child.try_wait() { + panic!("Helper binary terminated early with status = {status}"); + } + + child.terminate_on_drop() + }) + .collect::>() } pub fn test_multiply(config_dir: &Path, https: bool) { diff --git a/ipa-core/tests/helper_networks.rs b/ipa-core/tests/helper_networks.rs index 06adb56a7..fbd1e2433 100644 --- a/ipa-core/tests/helper_networks.rs +++ b/ipa-core/tests/helper_networks.rs @@ -1,6 +1,11 @@ mod common; -use std::{array, net::TcpListener, path::Path, process::Command}; +use std::{ + array, + io::Write, + path::Path, + process::{Command, Stdio}, +}; use common::{ spawn_helpers, tempdir::TempDir, test_ipa, test_multiply, test_network, CommandExt, @@ -8,7 +13,10 @@ use common::{ }; use ipa_core::{cli::CliPaths, helpers::HelperIdentity, test_fixture::ipa::IpaSecurityModel}; -use crate::common::{AddInPrimeField, Multiply}; +use crate::common::{ + spawn_shards, test_sharded_setup, AddInPrimeField, Multiply, ShardTcpListeners, + TerminateOnDrop, TerminateOnDropExt, TEST_MPC_BIN, +}; #[test] #[cfg(all(test, web_test))] @@ -61,6 +69,45 @@ fn https_malicious_ipa() { test_ipa(IpaSecurityModel::Malicious, true, true); } +#[test] +#[cfg(all(test, web_test))] +fn http_sharded_shuffle_3_shards() { + let dir = TempDir::new_delete_on_drop(); + let path = dir.path(); + + println!("generating configuration in {}", path.display()); + let sockets = test_sharded_setup::<3>(path); + let _helpers = spawn_shards(path, &sockets, false); + + let mut command = Command::new(TEST_MPC_BIN); + command + .args(["--network".into(), path.join("network.toml")]) + .args(["--wait", "2"]) + .arg("--disable-https"); + + command.arg("sharded-shuffle").stdin(Stdio::piped()); + + let test_mpc = command.spawn().unwrap().terminate_on_drop(); + + // Shuffle numbers from 1 to 10. `test_mpc` binary will check if they were + // permuted correctly. Our job here is to submit input large enough to avoid + // false negatives + test_mpc + .stdin + .as_ref() + .unwrap() + .write_all( + (1..10) + .into_iter() + .map(|i| i.to_string()) + .collect::>() + .join("\n") + .as_bytes(), + ) + .unwrap(); + TerminateOnDrop::wait(test_mpc).unwrap_status(); +} + /// Similar to [`network`] tests, but it uses keygen + confgen CLIs to generate helper client config /// and then just runs test multiply to make sure helpers are up and running /// @@ -71,10 +118,17 @@ fn keygen_confgen() { let dir = TempDir::new_delete_on_drop(); let path = dir.path(); - let sockets: [_; 6] = array::from_fn(|_| TcpListener::bind("127.0.0.1:0").unwrap()); - let ports: [u16; 6] = sockets + let sockets: [_; 3] = array::from_fn(|_| ShardTcpListeners::bind_random()); + let (mpc_ports, shard_ports): (Vec<_>, Vec<_>) = sockets .each_ref() - .map(|sock| sock.local_addr().unwrap().port()); + .iter() + .map(|ShardTcpListeners { mpc, shard }| { + ( + mpc.local_addr().unwrap().port(), + shard.local_addr().unwrap().port(), + ) + }) + .unzip(); // closure that generates the client config file (network.toml) let exec_conf_gen = |overwrite| { @@ -85,9 +139,9 @@ fn keygen_confgen() { .args(["--output-dir".as_ref(), path.as_os_str()]) .args(["--keys-dir".as_ref(), path.as_os_str()]) .arg("--ports") - .args(ports.iter().take(3).map(|p| p.to_string())) + .args(mpc_ports.iter().map(|p| p.to_string())) .arg("--shard-ports") - .args(ports.iter().skip(3).take(3).map(|p| p.to_string())) + .args(shard_ports.iter().map(|p| p.to_string())) .arg("--hosts") .args(["localhost", "localhost", "localhost"]); if overwrite {