Skip to content

Commit

Permalink
Merge from main
Browse files Browse the repository at this point in the history
  • Loading branch information
akoshelev committed Dec 2, 2024
2 parents 0762bb1 + 55d4f78 commit c969aeb
Show file tree
Hide file tree
Showing 9 changed files with 323 additions and 234 deletions.
26 changes: 25 additions & 1 deletion ipa-core/src/ff/ec_prime_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
},
secret_sharing::{
replicated::{malicious::ExtendableField, semi_honest::AdditiveShare},
Block, FieldVectorizable, SharedValue, StdArray, Vectorizable,
Block, FieldVectorizable, SharedValue, SharedValueArray, StdArray, Vectorizable,
},
};

Expand Down Expand Up @@ -267,6 +267,30 @@ macro_rules! impl_share_from_random {

impl_share_from_random!(PRF_CHUNK);

/// Calls the `batch_invert` function for Scalars from the `curve25519_dalek` crate.
/// # Panics
/// When N != N (never)
pub fn batch_invert<const N: usize>(
inputs: &<Fp25519 as Vectorizable<N>>::Array,
) -> <Fp25519 as Vectorizable<N>>::Array
where
Fp25519: Vectorizable<N>,
{
// TODO: This can be made more memory-efficient if we can manage to pass
// the input as a mutable reference to `Scalar::batch_invert`
// We can also create our own version of `Scalar::batch_invert` that avoids Vecs,
// but this would require forking the crate.
let mut inverted: [Scalar; N] = inputs
.clone()
.into_iter()
.map(|x| x.0)
.collect::<Vec<Scalar>>() // Relying on the compiler to optimize this out
.try_into()
.unwrap(); // Safe, the length will always be N
Scalar::batch_invert(&mut inverted);
<Fp25519 as Vectorizable<N>>::Array::from_fn(|i| Fp25519(inverted[i]))
}

#[cfg(all(test, unit_test))]
mod test {
use curve25519_dalek::scalar::Scalar;
Expand Down
135 changes: 88 additions & 47 deletions ipa-core/src/protocol/hybrid/breakdown_reveal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,18 @@ where
// Any real-world aggregation should be able to complete in two layers (two
// iterations of the `while` loop below). Tests with small `TARGET_PROOF_SIZE`
// may exceed that.
let mut chunk_counter = 0;
let mut depth = 0;
let agg_proof_chunk = aggregate_values_proof_chunk(B, usize::try_from(V::BITS).unwrap());

while intermediate_results.len() > 1 {
let mut record_ids = [RecordId::FIRST; AGGREGATE_DEPTH];
let mut next_intermediate_results = Vec::new();
for chunk in intermediate_results.chunks(agg_proof_chunk) {
for (chunk_counter, chunk) in intermediate_results.chunks(agg_proof_chunk).enumerate() {
let chunk_len = chunk.len();
let validator = ctx.clone().dzkp_validator(
MaliciousProtocolSteps {
protocol: &Step::aggregate(depth),
validate: &Step::AggregateValidate,
validate: &Step::aggregate_validate(depth),
},
usize::MAX, // See note about batching above.
);
Expand All @@ -130,17 +129,21 @@ where
)
.await?;
validator.validate_indexed(chunk_counter).await?;
chunk_counter += 1;
next_intermediate_results.push(result);
}
depth += 1;
intermediate_results = next_intermediate_results;
}

Ok(intermediate_results
let mut result = intermediate_results
.into_iter()
.next()
.expect("aggregation input must not be empty"))
.expect("aggregation input must not be empty");
result.resize(
usize::try_from(HV::BITS).unwrap(),
Replicated::<Boolean, B>::ZERO,
);
Ok(result)
}

/// Transforms the Breakdown key from a secret share into a revealed `usize`.
Expand Down Expand Up @@ -266,6 +269,27 @@ pub mod tests {
}
}

fn inputs_and_expectation<R: Rng>(
mut rng: R,
) -> (Vec<TestAggregateableHybridReport>, Vec<u128>) {
let mut expectation = Vec::new();
for _ in 0..32 {
expectation.push(rng.gen_range(0u128..256));
}
let mut inputs = Vec::new();
for (bk, expected_hv) in expectation.iter().enumerate() {
let mut remainder = *expected_hv;
while remainder > 7 {
let tv = rng.gen_range(0u128..8);
remainder -= tv;
inputs.push(input_row(bk, tv));
}
inputs.push(input_row(bk, remainder));
}
inputs.shuffle(&mut rng);
(inputs, expectation)
}

#[test]
fn breakdown_reveal_semi_honest_happy_path() {
// if shuttle executor is enabled, run this test only once.
Expand All @@ -276,23 +300,7 @@ pub mod tests {
const SHARDS: usize = 2;
run_with::<_, _, 3>(|| async {
let world = TestWorld::<WithShards<SHARDS>>::with_shards(TestWorldConfig::default());
let mut rng = world.rng();
let mut expectation = Vec::new();
for _ in 0..32 {
expectation.push(rng.gen_range(0u128..256));
}
let expectation = expectation; // no more mutability for safety
let mut inputs = Vec::new();
for (bk, expected_hv) in expectation.iter().enumerate() {
let mut remainder = *expected_hv;
while remainder > 7 {
let tv = rng.gen_range(0u128..8);
remainder -= tv;
inputs.push(input_row(bk, tv));
}
inputs.push(input_row(bk, remainder));
}
inputs.shuffle(&mut rng);
let (inputs, expectation) = inputs_and_expectation(world.rng());
let result: Vec<_> = world
.semi_honest(inputs.into_iter(), |ctx, reports| async move {
breakdown_reveal_aggregation::<_, BA5, BA3, HV, 32>(
Expand Down Expand Up @@ -332,30 +340,7 @@ pub mod tests {
const SHARDS: usize = 2;
run(|| async {
let world = TestWorld::<WithShards<SHARDS>>::with_shards(TestWorldConfig::default());
let mut rng = world.rng();
let mut expectation = Vec::new();
for _ in 0..32 {
expectation.push(rng.gen_range(0u128..512));
}
// The size of input needed here to get complete coverage (more precisely,
// the size of input to the final aggregation using `aggregate_values`)
// depends on `TARGET_PROOF_SIZE`.
let expectation = expectation; // no more mutability for safety
let mut inputs = Vec::new();
// Builds out inputs with values for each breakdown_key that add up to
// the expectation. Expectation is ranomg (0..512). Each iteration
// generates a value (0..8) and subtracts from the expectation until a final
// remaninder in (0..8) remains to be added to the vec.
for (breakdown_key, expected_value) in expectation.iter().enumerate() {
let mut remainder = *expected_value;
while remainder > 7 {
let value = rng.gen_range(0u128..8);
remainder -= value;
inputs.push(input_row(breakdown_key, value));
}
inputs.push(input_row(breakdown_key, remainder));
}
inputs.shuffle(&mut rng);
let (inputs, expectation) = inputs_and_expectation(world.rng());

let result: Vec<_> = world
.malicious(inputs.into_iter(), |ctx, reports| async move {
Expand Down Expand Up @@ -388,4 +373,60 @@ pub mod tests {
assert_eq!(result, expectation);
});
}

#[test]
#[cfg(not(feature = "shuttle"))] // too slow
fn breakdown_reveal_malicious_chunk_size_1() {
type HV = BA16;
const SHARDS: usize = 1;
run(|| async {
let world = TestWorld::<WithShards<SHARDS>>::with_shards(TestWorldConfig::default());

let mut inputs = vec![
input_row(1, 1),
input_row(1, 2),
input_row(1, 3),
input_row(1, 4),
];
inputs.extend_from_within(..); // 8
inputs.extend_from_within(..); // 16
inputs.extend_from_within(..); // 32
inputs.extend_from_within(..); // 64
inputs.extend_from_within(..1); // 65

let expectation = [
0, 161, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0,
];

let result: Vec<_> = world
.malicious(inputs.into_iter(), |ctx, reports| async move {
breakdown_reveal_aggregation::<_, BA5, BA3, HV, 32>(
ctx,
reports,
&PaddingParameters::no_padding(),
)
.map_ok(|d: BitDecomposed<Replicated<Boolean, 32>>| {
Vec::transposed_from(&d).unwrap()
})
.await
.unwrap()
})
.await
.reconstruct();

let initial = vec![0_u128; 32];
let result = result
.iter()
.fold(initial, |mut acc, vec: &Vec<HV>| {
acc.iter_mut()
.zip(vec)
.for_each(|(a, &b)| *a += b.as_u128());
acc
})
.into_iter()
.collect::<Vec<_>>();
assert_eq!(result, expectation);
});
}
}
29 changes: 23 additions & 6 deletions ipa-core/src/protocol/hybrid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::convert::Infallible;
use tracing::{info_span, Instrument};

use crate::{
error::Error,
error::{Error, LengthError},
ff::{
boolean::Boolean, boolean_array::BooleanArray, curve_points::RP25519,
ec_prime_field::Fp25519, Serializable, U128Conversions,
Expand All @@ -17,6 +17,7 @@ use crate::{
protocol::{
basics::{BooleanArrayMul, Reveal},
context::{DZKPUpgraded, MacUpgraded, ShardedContext, UpgradableContext},
dp::dp_for_histogram,
hybrid::{
agg::aggregate_reports,
breakdown_reveal::breakdown_reveal_aggregation,
Expand Down Expand Up @@ -63,10 +64,10 @@ use crate::{
/// Propagates errors from config issues or while running the protocol
/// # Panics
/// Propagates errors from config issues or while running the protocol
pub async fn hybrid_protocol<'ctx, C, BK, V, HV, const B: usize>(
pub async fn hybrid_protocol<'ctx, C, BK, V, HV, const SS_BITS: usize, const B: usize>(
ctx: C,
input_rows: Vec<IndistinguishableHybridReport<BK, V>>,
_dp_params: DpMechanism,
dp_params: DpMechanism,
dp_padding_params: PaddingParameters,
) -> Result<Vec<Replicated<HV>>, Error>
where
Expand All @@ -87,6 +88,10 @@ where
+ Reveal<DZKPUpgraded<C>, Output = <BK as Vectorizable<1>>::Array>,
BitDecomposed<Replicated<Boolean, B>>:
for<'a> TransposeFrom<&'a [Replicated<V>; B], Error = Infallible>,
BitDecomposed<Replicated<Boolean, B>>:
for<'a> TransposeFrom<&'a [Replicated<HV>; B], Error = Infallible>,
Vec<Replicated<HV>>:
for<'a> TransposeFrom<&'a BitDecomposed<Replicated<Boolean, B>>, Error = LengthError>,
{
if input_rows.is_empty() {
return Ok(vec![Replicated::ZERO; B]);
Expand All @@ -110,11 +115,23 @@ where

let aggregated_reports = aggregate_reports::<BK, V, C>(ctx.clone(), sharded_reports).await?;

let _historgram = breakdown_reveal_aggregation::<C, BK, V, HV, B>(
let histogram = breakdown_reveal_aggregation::<C, BK, V, HV, B>(
ctx.clone(),
aggregated_reports,
&dp_padding_params,
);
)
.await?;

let noisy_histogram = if ctx.is_leader() {
dp_for_histogram::<_, B, HV, SS_BITS>(ctx, histogram, dp_params).await?
} else {
// the following ZERO vec should be the result for
// all follow shards, but this won't work until
// #1446 is merged
// vec![Replicated::<HV>::ZERO; B]
// temporary hack, just return the histogram
Vec::transposed_from(&histogram)?
};

unimplemented!("protocol::hybrid::hybrid_protocol is not fully implemented")
Ok(noisy_histogram)
}
2 changes: 1 addition & 1 deletion ipa-core/src/protocol/hybrid/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use ipa_step_derive::CompactStep;
#[derive(CompactStep)]
pub(crate) enum HybridStep {
ReshardByTag,
#[step(child = crate::protocol::ipa_prf::oprf_padding::step::PaddingDpStep, name="padding_dp")]
#[step(child = crate::protocol::ipa_prf::oprf_padding::step::PaddingDpStep, name="report_padding_dp")]
PaddingDp,
#[step(child = crate::protocol::ipa_prf::shuffle::step::OPRFShuffleStep)]
InputShuffle,
Expand Down
6 changes: 2 additions & 4 deletions ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,19 +174,18 @@ where
// Any real-world aggregation should be able to complete in two layers (two
// iterations of the `while` loop below). Tests with small `TARGET_PROOF_SIZE`
// may exceed that.
let mut chunk_counter = 0;
let mut depth = 0;
let agg_proof_chunk = aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap());

while intermediate_results.len() > 1 {
let mut record_ids = [RecordId::FIRST; AGGREGATE_DEPTH];
let mut next_intermediate_results = Vec::new();
for chunk in intermediate_results.chunks(agg_proof_chunk) {
for (chunk_counter, chunk) in intermediate_results.chunks(agg_proof_chunk).enumerate() {
let chunk_len = chunk.len();
let validator = ctx.clone().dzkp_validator(
MaliciousProtocolSteps {
protocol: &Step::aggregate(depth),
validate: &Step::AggregateValidate,
validate: &Step::aggregate_validate(depth),
},
usize::MAX, // See note about batching above.
);
Expand All @@ -198,7 +197,6 @@ where
)
.await?;
validator.validate_indexed(chunk_counter).await?;
chunk_counter += 1;
next_intermediate_results.push(result);
}
depth += 1;
Expand Down
4 changes: 2 additions & 2 deletions ipa-core/src/protocol/ipa_prf/aggregation/step.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ pub(crate) enum AggregationStep {
RevealValidate, // only partly used -- see code
#[step(count = 4, child = AggregateChunkStep, name = "chunks")]
Aggregate(usize),
#[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)]
AggregateValidate,
#[step(count = 4, child = crate::protocol::context::step::DzkpValidationProtocolStep)]
AggregateValidate(usize),
}

// The step count here is duplicated as the AGGREGATE_DEPTH constant in the code.
Expand Down
5 changes: 3 additions & 2 deletions ipa-core/src/protocol/ipa_prf/prf_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,9 @@ where
.await?;

//compute R^(1/z) to u64
Ok(zip(gr, z)
.map(|(gr, z)| u64::from(gr * z.invert()))
let inv_z = crate::ff::ec_prime_field::batch_invert::<N>(&z);
Ok(zip(gr, inv_z)
.map(|(gr, inv_z)| u64::from(gr * inv_z))
.collect::<Vec<_>>()
.try_into()
.expect("iteration over arrays"))
Expand Down
Loading

0 comments on commit c969aeb

Please sign in to comment.