diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index 6b1032f72..0d61e9bb6 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -220,7 +220,6 @@ impl RequestHandler for Inner { data: BodyStream, ) -> Result { let qp = &self.query_processor; - Ok(match req.route { r @ RouteId::Records => { return Err(ApiError::BadRequest( diff --git a/ipa-core/src/bin/helper.rs b/ipa-core/src/bin/helper.rs index 6a9b4537b..85b0c7110 100644 --- a/ipa-core/src/bin/helper.rs +++ b/ipa-core/src/bin/helper.rs @@ -371,7 +371,7 @@ pub async fn main() { let res = match args.command { None => server(args.server, handle).await, Some(HelperCommand::Keygen(args)) => keygen(&args), - Some(HelperCommand::TestSetup(args)) => test_setup(args), + Some(HelperCommand::TestSetup(args)) => test_setup(&args), Some(HelperCommand::Confgen(args)) => client_config_setup(args), Some(HelperCommand::ShardedConfgen(args)) => sharded_client_config_setup(args), }; diff --git a/ipa-core/src/cli/clientconf.rs b/ipa-core/src/cli/clientconf.rs index f44ba7ea3..0b376dad0 100644 --- a/ipa-core/src/cli/clientconf.rs +++ b/ipa-core/src/cli/clientconf.rs @@ -9,6 +9,7 @@ use crate::{ }, error::BoxError, helpers::HelperIdentity, + sharding::ShardIndex, }; #[derive(Debug, Args)] @@ -157,7 +158,7 @@ fn create_sharded_conf_from_files( let mut shard_dir = base_dir.clone(); let id_nr: u8 = id.into(); shard_dir.push(format!("helper{id_nr}")); - shard_dir.push(format!("shard{ix}")); + shard_dir.push(shard_conf_folder(ix)); let host_name = find_file_with_extension(&shard_dir, "pem").unwrap(); let tls_cert_file = shard_dir.join(format!("{host_name}.pem")); @@ -223,3 +224,7 @@ fn gen_conf_from_args( ); Ok(()) } + +pub fn shard_conf_folder>(shard_id: I) -> PathBuf { + format!("shard{}", shard_id.try_into().ok().unwrap()).into() +} diff --git a/ipa-core/src/cli/config_parse.rs b/ipa-core/src/cli/config_parse.rs index 76cb30e96..f282a3a9f 100644 --- a/ipa-core/src/cli/config_parse.rs +++ b/ipa-core/src/cli/config_parse.rs @@ -123,8 +123,8 @@ fn parse_sharded_network_toml(input: &str) -> Result /// Generates client configuration file at the requested destination. The destination must exist /// before this function is called -pub fn gen_client_config( - clients_conf: impl Iterator, +pub fn gen_client_config>( + clients_conf: I, use_http1: bool, conf_file: &mut File, ) -> Result<(), BoxError> { @@ -352,7 +352,7 @@ pub fn sharded_server_from_toml_str( identities: shard_count.iter().collect(), }; Ok((mpc_network, shard_network)) - } else if missing_urls == [0, 1, 2] && shard_count == ShardIndex(1) { + } else if missing_urls == [0, 1, 2] && shard_count == ShardIndex::from(1) { // This is the special case we're dealing with a non-sharded, single ring MPC. // Since the shard network will be of size 1, it can't really communicate with anyone else. // Hence we just create a config where I'm the only shard. We take the MPC configuration diff --git a/ipa-core/src/cli/test_setup.rs b/ipa-core/src/cli/test_setup.rs index 86c663931..7a5e2d8ae 100644 --- a/ipa-core/src/cli/test_setup.rs +++ b/ipa-core/src/cli/test_setup.rs @@ -1,11 +1,12 @@ use std::{ fs::{DirBuilder, File}, iter::zip, - path::PathBuf, + path::{Path, PathBuf}, }; use clap::Args; +use super::clientconf::shard_conf_folder; use crate::{ cli::{ config_parse::{gen_client_config, HelperClientConf}, @@ -66,7 +67,7 @@ impl TestSetupArgs { /// /// # Panics /// If something that shouldn't happen goes wrong. -pub fn test_setup(args: TestSetupArgs) -> Result<(), BoxError> { +pub fn test_setup(args: &TestSetupArgs) -> Result<(), BoxError> { assert_eq!( args.ports.len(), args.shard_ports.len(), @@ -97,96 +98,72 @@ pub fn test_setup(args: TestSetupArgs) -> Result<(), BoxError> { } } -fn sharded_keygen(args: TestSetupArgs) -> Result<(), BoxError> { - let localhost = String::from("localhost"); - let keygen_args: Vec<_> = [1, 2, 3] - .into_iter() - .cycle() - .take(args.ports.len()) - .enumerate() - .map(|(i, id)| { - let shard_dir = args.output_dir.join(format!("shard{i}")); - DirBuilder::new().create(&shard_dir)?; - Ok::<_, BoxError>(if i < 3 { - // Only leader shards generate MK keys - KeygenArgs { - name: localhost.clone(), - tls_cert: shard_dir.helper_tls_cert(id), - tls_key: shard_dir.helper_tls_key(id), - tls_expire_after: 365, - mk_public_key: Some(shard_dir.helper_mk_public_key(id)), - mk_private_key: Some(shard_dir.helper_mk_private_key(id)), - } - } else { - KeygenArgs { - name: localhost.clone(), - tls_cert: shard_dir.helper_tls_cert(id), - tls_key: shard_dir.helper_tls_key(id), - tls_expire_after: 365, - mk_public_key: None, - mk_private_key: None, - } - }) - }) - .collect::>()?; - - for ka in &keygen_args { - keygen(ka)?; - } +fn sharded_keygen(args: &TestSetupArgs) -> Result<(), BoxError> { + const RING_SIZE: usize = 3; + // we split all ports into chunks of 3 (for each MPC ring) and go over + // all of them, creating configuration let clients_config: Vec<_> = zip( - keygen_args.iter(), - zip( - keygen_args.clone().into_iter().take(3).cycle(), - zip(args.ports, args.shard_ports), - ), - ) - .map( - |(keygen, (leader_keygen, (port, shard_port)))| HelperClientConf { - host: localhost.to_string(), - port, - shard_port, - tls_cert_file: keygen.tls_cert.clone(), - mk_public_key_file: leader_keygen.mk_public_key.clone().unwrap(), - }, + args.ports.chunks(RING_SIZE), + args.shard_ports.chunks(RING_SIZE), ) - .collect(); + .enumerate() + .flat_map(|(shard_id, (mpc_ports, shard_ports))| { + let shard_dir = args.output_dir.join(shard_conf_folder(shard_id)); + DirBuilder::new().create(&shard_dir)?; + make_client_configs(mpc_ports, shard_ports, &shard_dir) + }) + .flatten() + .collect::>(); let mut conf_file = File::create(args.output_dir.join("network.toml"))?; - gen_client_config(clients_config.into_iter(), args.use_http1, &mut conf_file) + gen_client_config(clients_config, args.use_http1, &mut conf_file) } /// This generates directories and files needed to run a non-sharded MPC. /// The directory structure is flattened and does not include per-shard configuration. -fn non_sharded_keygen(args: TestSetupArgs) -> Result<(), BoxError> { +fn non_sharded_keygen(args: &TestSetupArgs) -> Result<(), BoxError> { + let client_configs = make_client_configs(&args.ports, &args.shard_ports, &args.output_dir)?; + + let mut conf_file = File::create(args.output_dir.join("network.toml"))?; + gen_client_config(client_configs, args.use_http1, &mut conf_file) +} + +fn make_client_configs( + mpc_ports: &[u16], + shard_ports: &[u16], + config_dir: &Path, +) -> Result, BoxError> { + assert_eq!(shard_ports.len(), mpc_ports.len()); + assert_eq!(3, shard_ports.len()); + let localhost = String::from("localhost"); - let clients_config: [_; 3] = zip([1, 2, 3], zip(args.ports, args.shard_ports)) - .map(|(id, (port, shard_port))| { + zip(mpc_ports.iter(), shard_ports.iter()) + .enumerate() + .map(|(i, (&mpc_port, &shard_port))| { + let id = u8::try_from(i + 1).unwrap(); + + // TODO: only leader shards should generate MK encryptions. let keygen_args = KeygenArgs { name: localhost.clone(), - tls_cert: args.output_dir.helper_tls_cert(id), - tls_key: args.output_dir.helper_tls_key(id), + tls_cert: config_dir.helper_tls_cert(id), + tls_key: config_dir.helper_tls_key(id), tls_expire_after: 365, - mk_public_key: Some(args.output_dir.helper_mk_public_key(id)), - mk_private_key: Some(args.output_dir.helper_mk_private_key(id)), + mk_public_key: Some(config_dir.helper_mk_public_key(id)), + mk_private_key: Some(config_dir.helper_mk_private_key(id)), }; keygen(&keygen_args)?; Ok(HelperClientConf { host: localhost.to_string(), - port, + port: mpc_port, shard_port, tls_cert_file: keygen_args.tls_cert, mk_public_key_file: keygen_args.mk_public_key.unwrap(), }) }) - .collect::, BoxError>>()? - .try_into() - .unwrap(); - - let mut conf_file = File::create(args.output_dir.join("network.toml"))?; - gen_client_config(clients_config.into_iter(), args.use_http1, &mut conf_file) + .collect::>() } #[cfg(test)] @@ -212,7 +189,7 @@ mod tests { ports: vec![3000, 3001, 3002, 3003, 3004, 3005], shard_ports: vec![6000, 6001, 6002, 6003, 6004, 6005], }; - test_setup(args).unwrap(); + test_setup(&args).unwrap(); let network_config_path = outdir.join("network.toml"); let network_config_string = &fs::read_to_string(network_config_path).unwrap(); let _r = sharded_server_from_toml_str( @@ -237,7 +214,7 @@ mod tests { ports: vec![], shard_ports: vec![], }; - test_setup(args).unwrap(); + test_setup(&args).unwrap(); } #[test] @@ -252,6 +229,6 @@ mod tests { ports: vec![3000, 3001], shard_ports: vec![6000, 6001, 6002], }; - test_setup(args).unwrap(); + test_setup(&args).unwrap(); } } diff --git a/ipa-core/src/config.rs b/ipa-core/src/config.rs index 97fc4e78a..132f80a64 100644 --- a/ipa-core/src/config.rs +++ b/ipa-core/src/config.rs @@ -602,6 +602,6 @@ mod tests { let pc1 = PeerConfig::new(uri1, None); let client = ClientConfig::default(); let conf = NetworkConfig::new_shards(vec![pc1.clone()], client); - assert_eq!(conf.peers[ShardIndex(0)].url, pc1.url); + assert_eq!(conf.peers[ShardIndex::FIRST].url, pc1.url); } } diff --git a/ipa-core/src/net/test.rs b/ipa-core/src/net/test.rs index ecb654de6..00441e6bf 100644 --- a/ipa-core/src/net/test.rs +++ b/ipa-core/src/net/test.rs @@ -138,7 +138,7 @@ impl TestNetwork { /// Creates all the shards for a helper and creates a network. fn new_shards(id: HelperIdentity, ports: Vec>, conf: &TestConfigBuilder) -> Self { let servers: Vec<_> = (0..conf.shard_count) - .map(ShardIndex) + .map(ShardIndex::from) .zip(ports) .map(|(ix, p)| { let sid = ShardedHelperIdentity::new(id, ix); @@ -296,7 +296,7 @@ impl TestConfig { /// Creates a new [`TestConfig`] using the provided configuration. fn new(conf: &TestConfigBuilder) -> Self { let rings = (0..conf.shard_count) - .map(ShardIndex) + .map(ShardIndex::from) .map(|s| { let ports = conf.get_ports_for_shard_index(s); TestNetwork::::new_mpc(s, ports, conf) @@ -1049,7 +1049,7 @@ mod tests { let builder = TestConfigBuilder::with_http_and_default_test_ports(); assert_eq!( vec![Some(3000), Some(3001), Some(3002)], - builder.get_ports_for_shard_index(ShardIndex(0)) + builder.get_ports_for_shard_index(ShardIndex::FIRST) ); assert_eq!( vec![Some(6001)], @@ -1060,6 +1060,9 @@ mod tests { #[test] fn get_os_ports() { let builder = TestConfigBuilder::default(); - assert_eq!(3, builder.get_ports_for_shard_index(ShardIndex(0)).len()); + assert_eq!( + 3, + builder.get_ports_for_shard_index(ShardIndex::FIRST).len() + ); } } diff --git a/ipa-core/src/protocol/hybrid/agg.rs b/ipa-core/src/protocol/hybrid/agg.rs index d756b0748..9b2113991 100644 --- a/ipa-core/src/protocol/hybrid/agg.rs +++ b/ipa-core/src/protocol/hybrid/agg.rs @@ -187,6 +187,7 @@ pub mod test { // the inputs are laid out to work with exactly 2 shards // as if it we're resharded by match_key/prf const SHARDS: usize = 2; + const SECOND_SHARD: ShardIndex = ShardIndex::from_u32(1); // we re-use these as the "prf" of the match_key // to avoid needing to actually do the prf here @@ -374,8 +375,8 @@ pub mod test { let results: Vec<[Vec<[AggregateableHybridReport; 2]>; 3]> = world .malicious(records.clone().into_iter(), |ctx, input| { let match_keys = match ctx.shard_id() { - ShardIndex(0) => SHARD1_MKS, - ShardIndex(1) => SHARD2_MKS, + ShardIndex::FIRST => SHARD1_MKS, + SECOND_SHARD => SHARD2_MKS, _ => panic!("invalid shard_id"), }; async move { @@ -446,8 +447,8 @@ pub mod test { let results: Vec<[Vec>; 3]> = world .malicious(records.clone().into_iter(), |ctx, input| { let match_keys = match ctx.shard_id() { - ShardIndex(0) => SHARD1_MKS, - ShardIndex(1) => SHARD2_MKS, + ShardIndex::FIRST => SHARD1_MKS, + SECOND_SHARD => SHARD2_MKS, _ => panic!("invalid shard_id"), }; async move { @@ -572,8 +573,8 @@ pub mod test { let _results: Vec<[Vec>; 3]> = world .malicious(records.clone().into_iter(), |ctx, input| { let match_keys = match ctx.shard_id() { - ShardIndex(0) => SHARD1_MKS, - ShardIndex(1) => SHARD2_MKS, + ShardIndex::FIRST => SHARD1_MKS, + SECOND_SHARD => SHARD2_MKS, _ => panic!("invalid shard_id"), }; async move { diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index 38a75baa5..4b35ab3af 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -864,7 +864,7 @@ mod tests { async fn shard_prepare_error() { fn shard_handle(si: ShardIndex) -> Arc> { create_handler(move |_| async move { - if si == ShardIndex(2) { + if si == ShardIndex::from(2) { Err(ApiError::QueryPrepare(PrepareQueryError::AlreadyRunning)) } else { Ok(HelperResponse::ok()) @@ -889,7 +889,7 @@ mod tests { assert!(r.is_err()); if let Err(e) = r { if let NewQueryError::ShardBroadcastError(be) = e { - assert_eq!(be.failures[0].0, ShardIndex(2)); + assert_eq!(be.failures[0].0, ShardIndex::from(2)); } else { panic!("Unexpected error type"); } @@ -1138,6 +1138,7 @@ mod tests { } mod query_status { + use super::*; use crate::{helpers::query::CompareStatusRequest, protocol::QueryId}; @@ -1150,16 +1151,18 @@ mod tests { #[tokio::test] async fn combined_status_response() { fn shard_handle(si: ShardIndex) -> Arc> { + const FOURTH_SHARD: ShardIndex = ShardIndex::from_u32(3); + const THIRD_SHARD: ShardIndex = ShardIndex::from_u32(2); create_handler(move |_| async move { match si { - ShardIndex(3) => { + FOURTH_SHARD => { Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { query_id: QueryId, my_status: QueryStatus::Completed, other_status: QueryStatus::Preparing, })) } - ShardIndex(2) => { + THIRD_SHARD => { Err(ApiError::QueryStatus(QueryStatusError::DifferentStatus { query_id: QueryId, my_status: QueryStatus::Running, @@ -1208,11 +1211,12 @@ mod tests { async fn status_query_doesnt_exist() { fn shard_handle(si: ShardIndex) -> Arc> { create_handler(move |_| async move { - match si { - ShardIndex(3) => Err(ApiError::QueryStatus(QueryStatusError::NoSuchQuery( + if si == ShardIndex::from(3) { + Err(ApiError::QueryStatus(QueryStatusError::NoSuchQuery( QueryId, - ))), - _ => Ok(HelperResponse::ok()), + ))) + } else { + Ok(HelperResponse::ok()) } }) } diff --git a/ipa-core/src/sharding.rs b/ipa-core/src/sharding.rs index 015aa21ed..d1dc7fc48 100644 --- a/ipa-core/src/sharding.rs +++ b/ipa-core/src/sharding.rs @@ -39,8 +39,11 @@ impl ShardedHelperIdentity { } /// A unique zero-based index of the helper shard. +/// Note to editors - if rustc suggests to make the internal field public, +/// don't. It breaks the encapsulation constraint. Use `from` or other methods +/// to convert from or into this struct's instance. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct ShardIndex(pub u32); +pub struct ShardIndex(u32); impl ShardIndex { pub const FIRST: Self = Self(0); @@ -49,11 +52,18 @@ impl ShardIndex { pub fn iter(self) -> impl Iterator { (0..self.0).map(Self) } + + /// Create a valid shard index from its u32 representation. + /// The reason it exists is because traits don't exist in const context + #[must_use] + pub const fn from_u32(value: u32) -> Self { + Self(value) + } } impl From for ShardIndex { fn from(value: u32) -> Self { - Self(value) + Self::from_u32(value) } } diff --git a/ipa-core/tests/common/mod.rs b/ipa-core/tests/common/mod.rs index d461a2018..8e1c014d7 100644 --- a/ipa-core/tests/common/mod.rs +++ b/ipa-core/tests/common/mod.rs @@ -172,10 +172,6 @@ pub fn spawn_shards( sockets: &[[ShardTcpListeners; 3]], https: bool, ) -> Vec { - if https { - unimplemented!("We haven't implemented HTTPS path yet") - } - let shard_count = sockets.len(); sockets .iter() @@ -189,18 +185,39 @@ pub fn spawn_shards( .args(["-i", &id.to_string()]) .args(["--shard-index", &shard_index.to_string()]) .args(["--shard-count", &shard_count.to_string()]) - .args(["--network".into(), config_path.join("network.toml")]) - .arg("--disable-https") - .silent(); + .args(["--network".into(), config_path.join("network.toml")]); + + if https { + let config_path = config_path.join(format!("shard{shard_index}")); + command + .args(["--tls-cert".into(), config_path.join(format!("h{id}.pem"))]) + .args(["--tls-key".into(), config_path.join(format!("h{id}.key"))]) + .args([ + "--mk-public-key".into(), + config_path.join(format!("h{id}_mk.pub")), + ]) + .args([ + "--mk-private-key".into(), + config_path.join(format!("h{id}_mk.key")), + ]); + } else { + command.arg("--disable-https"); + } command.preserved_fds(vec![mpc.as_raw_fd(), shard.as_raw_fd()]); command.args(["--server-socket-fd", &mpc.as_raw_fd().to_string()]); command.args(["--shard-server-socket-fd", &shard.as_raw_fd().to_string()]); // something went wrong if command is terminated at this point. - let mut child = command.spawn().unwrap(); - if let Ok(Some(status)) = child.try_wait() { - panic!("Helper binary terminated early with status = {status}"); + let mut child = command.silent().spawn().unwrap(); + match child.try_wait() { + Ok(Some(status)) => { + panic!("Helper binary terminated early with status = {status}") + } + Ok(None) => {} + Err(e) => { + panic!("Error while waiting for helper binary: {e}"); + } } child.terminate_on_drop() @@ -305,6 +322,20 @@ pub fn test_network(https: bool) { T::execute(path, https); } +pub fn test_sharded_network>(https: bool) { + let dir = TempDir::new_delete_on_drop(); + let path = dir.path(); + + println!( + "generating configuration for {SHARDS} shards in {}", + path.display() + ); + let sockets = test_sharded_setup::(path); + let _helpers = spawn_shards(path, &sockets, https); + + T::execute(path, https); +} + pub fn test_ipa(mode: IpaSecurityModel, https: bool, encrypted_inputs: bool) { test_ipa_with_config( mode, @@ -440,7 +471,7 @@ pub fn test_ipa_with_config( assert_eq!(INPUT_SIZE, usize::from(output.input_size)); } -pub trait NetworkTest { +pub trait NetworkTest { fn execute(config_path: &Path, https: bool); } @@ -459,3 +490,39 @@ impl NetworkTest for AddInPrimeField { test_add_in_prime_field(config_path, https, N) } } + +pub struct ShardedShuffle; + +impl NetworkTest for ShardedShuffle { + fn execute(config_path: &Path, https: bool) { + let mut command = Command::new(TEST_MPC_BIN); + command + .args(["--network".into(), config_path.join("network.toml")]) + .args(["--wait", "2"]); + + if !https { + command.arg("--disable-https"); + } + + command.arg("sharded-shuffle").stdin(Stdio::piped()); + + let test_mpc = command.silent().spawn().unwrap().terminate_on_drop(); + + // Shuffle numbers from 1 to 10. `test_mpc` binary will check if they were + // permuted correctly. Our job here is to submit input large enough to avoid + // false negatives + test_mpc + .stdin + .as_ref() + .unwrap() + .write_all( + (1..10) + .map(|i| i.to_string()) + .collect::>() + .join("\n") + .as_bytes(), + ) + .unwrap(); + TerminateOnDrop::wait(test_mpc).unwrap_status(); + } +} diff --git a/ipa-core/tests/helper_networks.rs b/ipa-core/tests/helper_networks.rs index fbd1e2433..aa28ea81a 100644 --- a/ipa-core/tests/helper_networks.rs +++ b/ipa-core/tests/helper_networks.rs @@ -1,11 +1,6 @@ mod common; -use std::{ - array, - io::Write, - path::Path, - process::{Command, Stdio}, -}; +use std::{array, path::Path, process::Command}; use common::{ spawn_helpers, tempdir::TempDir, test_ipa, test_multiply, test_network, CommandExt, @@ -14,8 +9,7 @@ use common::{ use ipa_core::{cli::CliPaths, helpers::HelperIdentity, test_fixture::ipa::IpaSecurityModel}; use crate::common::{ - spawn_shards, test_sharded_setup, AddInPrimeField, Multiply, ShardTcpListeners, - TerminateOnDrop, TerminateOnDropExt, TEST_MPC_BIN, + test_sharded_network, AddInPrimeField, Multiply, ShardTcpListeners, ShardedShuffle, }; #[test] @@ -72,40 +66,13 @@ fn https_malicious_ipa() { #[test] #[cfg(all(test, web_test))] fn http_sharded_shuffle_3_shards() { - let dir = TempDir::new_delete_on_drop(); - let path = dir.path(); - - println!("generating configuration in {}", path.display()); - let sockets = test_sharded_setup::<3>(path); - let _helpers = spawn_shards(path, &sockets, false); + test_sharded_network::<3, ShardedShuffle>(false); +} - let mut command = Command::new(TEST_MPC_BIN); - command - .args(["--network".into(), path.join("network.toml")]) - .args(["--wait", "2"]) - .arg("--disable-https"); - - command.arg("sharded-shuffle").stdin(Stdio::piped()); - - let test_mpc = command.spawn().unwrap().terminate_on_drop(); - - // Shuffle numbers from 1 to 10. `test_mpc` binary will check if they were - // permuted correctly. Our job here is to submit input large enough to avoid - // false negatives - test_mpc - .stdin - .as_ref() - .unwrap() - .write_all( - (1..10) - .into_iter() - .map(|i| i.to_string()) - .collect::>() - .join("\n") - .as_bytes(), - ) - .unwrap(); - TerminateOnDrop::wait(test_mpc).unwrap_status(); +#[test] +#[cfg(all(test, web_test))] +fn https_sharded_shuffle_3_shards() { + test_sharded_network::<3, ShardedShuffle>(true); } /// Similar to [`network`] tests, but it uses keygen + confgen CLIs to generate helper client config