Skip to content

Commit

Permalink
Sharded shuffle http e2e test (#1464)
Browse files Browse the repository at this point in the history
* Some plumbing work for sharded shuffle e2e

Sharded shuffle e2e test without HTTP handler

* temp

* temp2

* For Alex. Unit tests pass

* added num_arg

* Sharded shuffle HTTP end-to-end test

This finalizes the plumbing for HTTP stack and verifies that it is working correctly end-to-end.

* Fix compact gate integration tests

* Fix compact gate tests

* Fix merge issues

* Improve documentation a bit

---------

Co-authored-by: Christian Berkhoff <[email protected]>
  • Loading branch information
akoshelev and cberkhoff authored Dec 3, 2024
1 parent d3c7469 commit e4d833d
Show file tree
Hide file tree
Showing 15 changed files with 627 additions and 117 deletions.
53 changes: 44 additions & 9 deletions ipa-core/src/bin/test_mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@ use generic_array::ArrayLength;
use hyper::http::uri::Scheme;
use ipa_core::{
cli::{
playbook::{make_clients, secure_add, secure_mul, validate, InputSource},
playbook::{
make_clients, make_sharded_clients, secure_add, secure_mul, secure_shuffle, validate,
InputSource,
},
Verbosity,
},
ff::{Field, FieldType, Fp31, Fp32BitPrime, Serializable, U128Conversions},
ff::{
boolean_array::BA64, Field, FieldType, Fp31, Fp32BitPrime, Serializable, U128Conversions,
},
helpers::query::{
QueryConfig,
QueryType::{TestAddInPrimeField, TestMultiply},
QueryType::{TestAddInPrimeField, TestMultiply, TestShardedShuffle},
},
net::{Helper, IpaHttpClient},
secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares},
Expand Down Expand Up @@ -103,11 +108,27 @@ async fn main() -> Result<(), Box<dyn Error>> {
Scheme::HTTPS
};

let (clients, _) = make_clients(args.network.as_deref(), scheme, args.wait).await;
match args.action {
TestAction::Multiply => multiply(&args, &clients).await,
TestAction::AddInPrimeField => add(&args, &clients).await,
TestAction::ShardedShuffle => sharded_shuffle(&args, &clients).await,
TestAction::Multiply => {
let (clients, _) = make_clients(args.network.as_deref(), scheme, args.wait).await;
multiply(&args, &clients).await
}
TestAction::AddInPrimeField => {
let (clients, _) = make_clients(args.network.as_deref(), scheme, args.wait).await;
add(&args, &clients).await
}
TestAction::ShardedShuffle => {
// we need clients to talk to each individual shard
let clients = make_sharded_clients(
args.network
.as_deref()
.expect("network config is required for sharded shuffle"),
scheme,
args.wait,
)
.await;
sharded_shuffle(&args, clients).await
}
};

Ok(())
Expand Down Expand Up @@ -166,6 +187,20 @@ async fn add(args: &Args, helper_clients: &[IpaHttpClient<Helper>; 3]) {
};
}

async fn sharded_shuffle(_args: &Args, _helper_clients: &[IpaHttpClient<Helper>; 3]) {
unimplemented!()
async fn sharded_shuffle(args: &Args, helper_clients: Vec<[IpaHttpClient<Helper>; 3]>) {
let input = InputSource::from(&args.input);
let input_rows = input
.iter::<u64>()
.map(BA64::truncate_from)
.collect::<Vec<_>>();
let query_config =
QueryConfig::new(TestShardedShuffle, args.input.field, input_rows.len()).unwrap();
let query_id = helper_clients[0][0]
.create_query(query_config)
.await
.unwrap();
let shuffled = secure_shuffle(input_rows.clone(), &helper_clients, query_id).await;

assert_eq!(shuffled.len(), input_rows.len());
assert_ne!(shuffled, input_rows);
}
19 changes: 19 additions & 0 deletions ipa-core/src/cli/config_parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,7 @@ fn assert_hpke_config(expected: &Value, actual: Option<&HpkeClientConfig>) {
#[allow(dead_code)]
pub trait HelperNetworkConfigParseExt {
fn from_toml_str(input: &str) -> Result<NetworkConfig<Helper>, Error>;
fn from_toml_str_sharded(input: &str) -> Result<Vec<NetworkConfig<Helper>>, Error>;
}

/// Reads config from string. Expects config to be toml format.
Expand All @@ -274,6 +275,24 @@ impl HelperNetworkConfigParseExt for NetworkConfig<Helper> {
all_network.client.clone(),
))
}
fn from_toml_str_sharded(input: &str) -> Result<Vec<NetworkConfig<Helper>>, Error> {
let all_network = parse_sharded_network_toml(input)?;
// peers are grouped by shard, meaning 0,1,2 describe MPC for shard 0.
// 3,4,5 describe shard 1, etc.
Ok(all_network
.peers
.chunks(3)
.map(|mpc_config| {
NetworkConfig::new_mpc(
mpc_config
.iter()
.map(ShardedPeerConfigToml::to_mpc_peer)
.collect(),
all_network.client.clone(),
)
})
.collect())
}
}

/// Reads a the config for a specific, single, sharded server from string. Expects config to be
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/cli/keygen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use time::{Duration, OffsetDateTime};

use crate::{error::BoxError, hpke::KeyPair};

#[derive(Debug, Args)]
#[derive(Debug, Clone, Args)]
#[clap(
name = "keygen",
about = "Generate keys used by an MPC helper",
Expand Down
50 changes: 43 additions & 7 deletions ipa-core/src/cli/playbook/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod generator;
mod input;
mod ipa;
mod multiply;
mod sharded_shuffle;

use core::fmt::Debug;
use std::{fs, path::Path, time::Duration};
Expand All @@ -12,6 +13,7 @@ use comfy_table::{Cell, Color, Table};
use hyper::http::uri::Scheme;
pub use input::InputSource;
pub use multiply::secure_mul;
pub use sharded_shuffle::secure_shuffle;
use tokio::time::sleep;

pub use self::ipa::{playbook_oprf_ipa, run_query_and_validate};
Expand Down Expand Up @@ -196,7 +198,6 @@ pub async fn make_clients(
scheme: Scheme,
wait: usize,
) -> ([IpaHttpClient<Helper>; 3], NetworkConfig<Helper>) {
let mut wait = wait;
let network = if let Some(path) = network_path {
NetworkConfig::from_toml_str(&fs::read_to_string(path).unwrap()).unwrap()
} else {
Expand All @@ -214,16 +215,51 @@ pub async fn make_clients(
// Note: This closure is only called when the selected action uses clients.

let clients = IpaHttpClient::from_conf(&IpaRuntime::current(), &network, &ClientIdentity::None);
while wait > 0 && !clients_ready(&clients).await {
wait_for_servers(wait, &[clients.clone()]).await;
(clients, network)
}

/// Creates enough clients to talk to all shards on MPC helpers. This only supports
/// reading configuration from the `network.toml` file
/// ## Panics
/// If configuration file `network_path` cannot be read from or if it does not conform to toml spec.
pub async fn make_sharded_clients(
network_path: &Path,
scheme: Scheme,
wait: usize,
) -> Vec<[IpaHttpClient<Helper>; 3]> {
let network =
NetworkConfig::from_toml_str_sharded(&fs::read_to_string(network_path).unwrap()).unwrap();

let clients = network
.into_iter()
.map(|network| {
let network = network.override_scheme(&scheme);
IpaHttpClient::from_conf(&IpaRuntime::current(), &network, &ClientIdentity::None)
})
.collect::<Vec<_>>();

wait_for_servers(wait, &clients).await;

clients
}

async fn wait_for_servers(mut wait: usize, clients: &[[IpaHttpClient<Helper>; 3]]) {
while wait > 0 && !clients_ready(clients).await {
tracing::debug!("waiting for servers to come up");
sleep(Duration::from_secs(1)).await;
wait -= 1;
}
(clients, network)
}

async fn clients_ready(clients: &[IpaHttpClient<Helper>; 3]) -> bool {
clients[0].echo("").await.is_ok()
&& clients[1].echo("").await.is_ok()
&& clients[2].echo("").await.is_ok()
#[allow(clippy::disallowed_methods)]
async fn clients_ready(clients: &[[IpaHttpClient<Helper>; 3]]) -> bool {
let r = futures::future::join_all(clients.iter().map(|clients| async move {
clients[0].echo("").await.is_ok()
&& clients[1].echo("").await.is_ok()
&& clients[2].echo("").await.is_ok()
}))
.await;

r.iter().all(|&v| v)
}
97 changes: 97 additions & 0 deletions ipa-core/src/cli/playbook/sharded_shuffle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use std::{
cmp::{max, min},
ops::Add,
time::Duration,
};

use futures_util::future::try_join_all;
use generic_array::ArrayLength;

use crate::{
ff::{boolean_array::BooleanArray, Serializable},
helpers::{query::QueryInput, BodyStream},
net::{Helper, IpaHttpClient},
protocol::QueryId,
query::QueryStatus,
secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares},
test_fixture::Reconstruct,
};

/// Secure sharded shuffle protocol
///
/// ## Panics
/// If the input size is empty or contains only one row.
#[allow(clippy::disallowed_methods)] // allow try_join_all
pub async fn secure_shuffle<V>(
inputs: Vec<V>,
clients: &[[IpaHttpClient<Helper>; 3]],
query_id: QueryId,
) -> Vec<V>
where
V: IntoShares<AdditiveShare<V>>,
<V as Serializable>::Size: Add<<V as Serializable>::Size, Output: ArrayLength>,
V: BooleanArray,
{
assert!(
inputs.len() > 1,
"Shuffle requires at least two rows to be shuffled"
);
let chunk_size = max(1, inputs.len() / clients.len());
let _ = try_join_all(
inputs
.chunks(chunk_size)
.zip(clients)
.map(|(chunk, mpc_clients)| {
let shared = chunk.iter().copied().share();
try_join_all(mpc_clients.each_ref().iter().zip(shared).map(
|(mpc_client, input)| {
mpc_client.query_input(QueryInput {
query_id,
input_stream: BodyStream::from_serializable_iter(input),
})
},
))
}),
)
.await
.unwrap();
let leader_clients = &clients[0];

let mut delay = Duration::from_millis(125);
loop {
if try_join_all(
leader_clients
.iter()
.map(|client| client.query_status(query_id)),
)
.await
.unwrap()
.into_iter()
.all(|status| status == QueryStatus::Completed)
{
break;
}

tokio::time::sleep(delay).await;
delay = min(Duration::from_secs(5), delay * 2);
}

let results: [_; 3] = try_join_all(
leader_clients
.iter()
.map(|client| client.query_results(query_id)),
)
.await
.unwrap()
.try_into()
.unwrap();
let results: Vec<V> = results
.map(|bytes| {
AdditiveShare::<V>::from_byte_slice(&bytes)
.collect::<Result<Vec<_>, _>>()
.unwrap()
})
.reconstruct();

results
}
Loading

0 comments on commit e4d833d

Please sign in to comment.