diff --git a/ipa-core/src/ff/ec_prime_field.rs b/ipa-core/src/ff/ec_prime_field.rs index 893e56007..a2e565ac0 100644 --- a/ipa-core/src/ff/ec_prime_field.rs +++ b/ipa-core/src/ff/ec_prime_field.rs @@ -14,7 +14,7 @@ use crate::{ }, secret_sharing::{ replicated::{malicious::ExtendableField, semi_honest::AdditiveShare}, - Block, FieldVectorizable, SharedValue, StdArray, Vectorizable, + Block, FieldVectorizable, SharedValue, SharedValueArray, StdArray, Vectorizable, }, }; @@ -267,6 +267,30 @@ macro_rules! impl_share_from_random { impl_share_from_random!(PRF_CHUNK); +/// Calls the `batch_invert` function for Scalars from the `curve25519_dalek` crate. /// # Panics /// When N != N (never) pub fn batch_invert<const N: usize>( inputs: &<Fp25519 as Vectorizable<N>>::Array, ) -> <Fp25519 as Vectorizable<N>>::Array where Fp25519: Vectorizable<N>, { // TODO: This can be made more memory-efficient if we can manage to pass // the input as a mutable reference to `Scalar::batch_invert` // We can also create our own version of `Scalar::batch_invert` that avoids Vecs, // but this would require forking the crate. let mut inverted: [Scalar; N] = inputs .clone() .into_iter() .map(|x| x.0) .collect::<Vec<_>>() // Relying on the compiler to optimize this out .try_into() .unwrap(); // Safe, the length will always be N Scalar::batch_invert(&mut inverted); <Fp25519 as Vectorizable<N>>::Array::from_fn(|i| Fp25519(inverted[i])) } + #[cfg(all(test, unit_test))] mod test { use curve25519_dalek::scalar::Scalar; diff --git a/ipa-core/src/protocol/hybrid/breakdown_reveal.rs b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs index cb9e4caf6..be676bea6 100644 --- a/ipa-core/src/protocol/hybrid/breakdown_reveal.rs +++ b/ipa-core/src/protocol/hybrid/breakdown_reveal.rs @@ -106,19 +106,18 @@ where // Any real-world aggregation should be able to complete in two layers (two // iterations of the `while` loop below). Tests with small `TARGET_PROOF_SIZE` // may exceed that. - let mut chunk_counter = 0; let mut depth = 0; let agg_proof_chunk = aggregate_values_proof_chunk(B, usize::try_from(V::BITS).unwrap()); while intermediate_results.len() > 1 { let mut record_ids = [RecordId::FIRST; AGGREGATE_DEPTH]; let mut next_intermediate_results = Vec::new(); - for chunk in intermediate_results.chunks(agg_proof_chunk) { + for (chunk_counter, chunk) in intermediate_results.chunks(agg_proof_chunk).enumerate() { let chunk_len = chunk.len(); let validator = ctx.clone().dzkp_validator( MaliciousProtocolSteps { protocol: &Step::aggregate(depth), - validate: &Step::AggregateValidate, + validate: &Step::aggregate_validate(depth), }, usize::MAX, // See note about batching above. ); @@ -130,17 +129,21 @@ where ) .await?; validator.validate_indexed(chunk_counter).await?; - chunk_counter += 1; next_intermediate_results.push(result); } depth += 1; intermediate_results = next_intermediate_results; } - Ok(intermediate_results + let mut result = intermediate_results .into_iter() .next() - .expect("aggregation input must not be empty")) + .expect("aggregation input must not be empty"); + result.resize( + usize::try_from(HV::BITS).unwrap(), + Replicated::<Boolean, B>::ZERO, + ); + Ok(result) } /// Transforms the Breakdown key from a secret share into a revealed `usize`.
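// Reviewer note (illustration only, not part of the patch): the `batch_invert` helper added
// above defers to `curve25519_dalek::Scalar::batch_invert`, which implements Montgomery's
// batch-inversion trick. The self-contained sketch below shows that trick over a toy prime
// field (p = 2^61 - 1 is an assumption for the example, not the curve25519 scalar field):
// n inversions are traded for one inversion plus O(n) multiplications. Inputs are assumed
// nonzero, as in the real helper.
const P: u64 = (1 << 61) - 1;

fn mul(a: u64, b: u64) -> u64 {
    ((u128::from(a) * u128::from(b)) % u128::from(P)) as u64
}

fn inv(a: u64) -> u64 {
    // Fermat's little theorem: a^(p-2) mod p.
    let (mut base, mut exp, mut acc) = (a, P - 2, 1u64);
    while exp > 0 {
        if exp & 1 == 1 {
            acc = mul(acc, base);
        }
        base = mul(base, base);
        exp >>= 1;
    }
    acc
}

fn toy_batch_invert(values: &mut [u64]) {
    if values.is_empty() {
        return;
    }
    // prefix[i] = values[0] * ... * values[i]
    let mut prefix = Vec::with_capacity(values.len());
    let mut acc = 1u64;
    for &v in values.iter() {
        acc = mul(acc, v);
        prefix.push(acc);
    }
    // One inversion of the running product, then peel off one factor per step.
    let mut inv_acc = inv(acc);
    for i in (1..values.len()).rev() {
        let inverted_i = mul(inv_acc, prefix[i - 1]);
        inv_acc = mul(inv_acc, values[i]);
        values[i] = inverted_i;
    }
    values[0] = inv_acc;
}

fn main() {
    let original = vec![3, 7, 42, 1_000_003];
    let mut inverted = original.clone();
    toy_batch_invert(&mut inverted);
    for (x, ix) in original.iter().zip(&inverted) {
        assert_eq!(mul(*x, *ix), 1);
    }
}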
@@ -266,6 +269,27 @@ pub mod tests { } } + fn inputs_and_expectation( + mut rng: R, + ) -> (Vec, Vec) { + let mut expectation = Vec::new(); + for _ in 0..32 { + expectation.push(rng.gen_range(0u128..256)); + } + let mut inputs = Vec::new(); + for (bk, expected_hv) in expectation.iter().enumerate() { + let mut remainder = *expected_hv; + while remainder > 7 { + let tv = rng.gen_range(0u128..8); + remainder -= tv; + inputs.push(input_row(bk, tv)); + } + inputs.push(input_row(bk, remainder)); + } + inputs.shuffle(&mut rng); + (inputs, expectation) + } + #[test] fn breakdown_reveal_semi_honest_happy_path() { // if shuttle executor is enabled, run this test only once. @@ -276,23 +300,7 @@ pub mod tests { const SHARDS: usize = 2; run_with::<_, _, 3>(|| async { let world = TestWorld::>::with_shards(TestWorldConfig::default()); - let mut rng = world.rng(); - let mut expectation = Vec::new(); - for _ in 0..32 { - expectation.push(rng.gen_range(0u128..256)); - } - let expectation = expectation; // no more mutability for safety - let mut inputs = Vec::new(); - for (bk, expected_hv) in expectation.iter().enumerate() { - let mut remainder = *expected_hv; - while remainder > 7 { - let tv = rng.gen_range(0u128..8); - remainder -= tv; - inputs.push(input_row(bk, tv)); - } - inputs.push(input_row(bk, remainder)); - } - inputs.shuffle(&mut rng); + let (inputs, expectation) = inputs_and_expectation(world.rng()); let result: Vec<_> = world .semi_honest(inputs.into_iter(), |ctx, reports| async move { breakdown_reveal_aggregation::<_, BA5, BA3, HV, 32>( @@ -332,30 +340,7 @@ pub mod tests { const SHARDS: usize = 2; run(|| async { let world = TestWorld::>::with_shards(TestWorldConfig::default()); - let mut rng = world.rng(); - let mut expectation = Vec::new(); - for _ in 0..32 { - expectation.push(rng.gen_range(0u128..512)); - } - // The size of input needed here to get complete coverage (more precisely, - // the size of input to the final aggregation using `aggregate_values`) - // depends on `TARGET_PROOF_SIZE`. - let expectation = expectation; // no more mutability for safety - let mut inputs = Vec::new(); - // Builds out inputs with values for each breakdown_key that add up to - // the expectation. Expectation is ranomg (0..512). Each iteration - // generates a value (0..8) and subtracts from the expectation until a final - // remaninder in (0..8) remains to be added to the vec. 
- for (breakdown_key, expected_value) in expectation.iter().enumerate() { - let mut remainder = *expected_value; - while remainder > 7 { - let value = rng.gen_range(0u128..8); - remainder -= value; - inputs.push(input_row(breakdown_key, value)); - } - inputs.push(input_row(breakdown_key, remainder)); - } - inputs.shuffle(&mut rng); + let (inputs, expectation) = inputs_and_expectation(world.rng()); let result: Vec<_> = world .malicious(inputs.into_iter(), |ctx, reports| async move { @@ -388,4 +373,60 @@ pub mod tests { assert_eq!(result, expectation); }); } + + #[test] + #[cfg(not(feature = "shuttle"))] // too slow + fn breakdown_reveal_malicious_chunk_size_1() { + type HV = BA16; + const SHARDS: usize = 1; + run(|| async { + let world = TestWorld::>::with_shards(TestWorldConfig::default()); + + let mut inputs = vec![ + input_row(1, 1), + input_row(1, 2), + input_row(1, 3), + input_row(1, 4), + ]; + inputs.extend_from_within(..); // 8 + inputs.extend_from_within(..); // 16 + inputs.extend_from_within(..); // 32 + inputs.extend_from_within(..); // 64 + inputs.extend_from_within(..1); // 65 + + let expectation = [ + 0, 161, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + ]; + + let result: Vec<_> = world + .malicious(inputs.into_iter(), |ctx, reports| async move { + breakdown_reveal_aggregation::<_, BA5, BA3, HV, 32>( + ctx, + reports, + &PaddingParameters::no_padding(), + ) + .map_ok(|d: BitDecomposed>| { + Vec::transposed_from(&d).unwrap() + }) + .await + .unwrap() + }) + .await + .reconstruct(); + + let initial = vec![0_u128; 32]; + let result = result + .iter() + .fold(initial, |mut acc, vec: &Vec| { + acc.iter_mut() + .zip(vec) + .for_each(|(a, &b)| *a += b.as_u128()); + acc + }) + .into_iter() + .collect::>(); + assert_eq!(result, expectation); + }); + } } diff --git a/ipa-core/src/protocol/hybrid/mod.rs b/ipa-core/src/protocol/hybrid/mod.rs index 61b4fe8c7..bfeaabf88 100644 --- a/ipa-core/src/protocol/hybrid/mod.rs +++ b/ipa-core/src/protocol/hybrid/mod.rs @@ -8,7 +8,7 @@ use std::convert::Infallible; use tracing::{info_span, Instrument}; use crate::{ - error::Error, + error::{Error, LengthError}, ff::{ boolean::Boolean, boolean_array::BooleanArray, curve_points::RP25519, ec_prime_field::Fp25519, Serializable, U128Conversions, @@ -17,6 +17,7 @@ use crate::{ protocol::{ basics::{BooleanArrayMul, Reveal}, context::{DZKPUpgraded, MacUpgraded, ShardedContext, UpgradableContext}, + dp::dp_for_histogram, hybrid::{ agg::aggregate_reports, breakdown_reveal::breakdown_reveal_aggregation, @@ -63,10 +64,10 @@ use crate::{ /// Propagates errors from config issues or while running the protocol /// # Panics /// Propagates errors from config issues or while running the protocol -pub async fn hybrid_protocol<'ctx, C, BK, V, HV, const B: usize>( +pub async fn hybrid_protocol<'ctx, C, BK, V, HV, const SS_BITS: usize, const B: usize>( ctx: C, input_rows: Vec>, - _dp_params: DpMechanism, + dp_params: DpMechanism, dp_padding_params: PaddingParameters, ) -> Result>, Error> where @@ -87,6 +88,10 @@ where + Reveal, Output = >::Array>, BitDecomposed>: for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, + BitDecomposed>: + for<'a> TransposeFrom<&'a [Replicated; B], Error = Infallible>, + Vec>: + for<'a> TransposeFrom<&'a BitDecomposed>, Error = LengthError>, { if input_rows.is_empty() { return Ok(vec![Replicated::ZERO; B]); @@ -110,11 +115,23 @@ where let aggregated_reports = aggregate_reports::(ctx.clone(), sharded_reports).await?; - let _historgram = 
breakdown_reveal_aggregation::( + let histogram = breakdown_reveal_aggregation::( ctx.clone(), aggregated_reports, &dp_padding_params, - ); + ) + .await?; + + let noisy_histogram = if ctx.is_leader() { + dp_for_histogram::<_, B, HV, SS_BITS>(ctx, histogram, dp_params).await? + } else { + // the following ZERO vec should be the result for + // all follow shards, but this won't work until + // #1446 is merged + // vec![Replicated::::ZERO; B] + // temporary hack, just return the histogram + Vec::transposed_from(&histogram)? + }; - unimplemented!("protocol::hybrid::hybrid_protocol is not fully implemented") + Ok(noisy_histogram) } diff --git a/ipa-core/src/protocol/hybrid/step.rs b/ipa-core/src/protocol/hybrid/step.rs index b4022551d..71306035d 100644 --- a/ipa-core/src/protocol/hybrid/step.rs +++ b/ipa-core/src/protocol/hybrid/step.rs @@ -3,7 +3,7 @@ use ipa_step_derive::CompactStep; #[derive(CompactStep)] pub(crate) enum HybridStep { ReshardByTag, - #[step(child = crate::protocol::ipa_prf::oprf_padding::step::PaddingDpStep, name="padding_dp")] + #[step(child = crate::protocol::ipa_prf::oprf_padding::step::PaddingDpStep, name="report_padding_dp")] PaddingDp, #[step(child = crate::protocol::ipa_prf::shuffle::step::OPRFShuffleStep)] InputShuffle, diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index 00d4c36af..198de6be9 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -174,19 +174,18 @@ where // Any real-world aggregation should be able to complete in two layers (two // iterations of the `while` loop below). Tests with small `TARGET_PROOF_SIZE` // may exceed that. - let mut chunk_counter = 0; let mut depth = 0; let agg_proof_chunk = aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()); while intermediate_results.len() > 1 { let mut record_ids = [RecordId::FIRST; AGGREGATE_DEPTH]; let mut next_intermediate_results = Vec::new(); - for chunk in intermediate_results.chunks(agg_proof_chunk) { + for (chunk_counter, chunk) in intermediate_results.chunks(agg_proof_chunk).enumerate() { let chunk_len = chunk.len(); let validator = ctx.clone().dzkp_validator( MaliciousProtocolSteps { protocol: &Step::aggregate(depth), - validate: &Step::AggregateValidate, + validate: &Step::aggregate_validate(depth), }, usize::MAX, // See note about batching above. ); @@ -198,7 +197,6 @@ where ) .await?; validator.validate_indexed(chunk_counter).await?; - chunk_counter += 1; next_intermediate_results.push(result); } depth += 1; diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs index 8be4fdcd1..9de6af752 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/step.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/step.rs @@ -14,8 +14,8 @@ pub(crate) enum AggregationStep { RevealValidate, // only partly used -- see code #[step(count = 4, child = AggregateChunkStep, name = "chunks")] Aggregate(usize), - #[step(child = crate::protocol::context::step::DzkpValidationProtocolStep)] - AggregateValidate, + #[step(count = 4, child = crate::protocol::context::step::DzkpValidationProtocolStep)] + AggregateValidate(usize), } // The step count here is duplicated as the AGGREGATE_DEPTH constant in the code. 
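// Reviewer note (illustration only, not part of the patch): the `count = 4` on
// `AggregateValidate(usize)` mirrors the depth count on `Aggregate(usize)` and the
// AGGREGATE_DEPTH constant mentioned in the comment above. The sketch below shows the shape
// of the chunked, layered aggregation loop that both breakdown_reveal files now share
// (plain integers stand in for secret-shared vectors; `CHUNK` plays the role of
// `agg_proof_chunk` and is chosen arbitrarily here). With chunk size c, n inputs collapse
// to a single result in roughly log_c(n) layers, and in the real code each
// (depth, chunk_index) pair gets its own validation step and indexed record.
fn aggregate_in_layers(mut values: Vec<u64>) -> u64 {
    const CHUNK: usize = 4;
    let mut depth = 0;
    while values.len() > 1 {
        values = values
            .chunks(CHUNK)
            .enumerate()
            .map(|(chunk_index, chunk)| {
                // The real protocol validates each chunk under
                // `Step::aggregate_validate(depth)` using `chunk_index` as the record
                // index; here we just sum the chunk.
                let _ = (depth, chunk_index);
                chunk.iter().sum::<u64>()
            })
            .collect();
        depth += 1;
    }
    values.into_iter().next().expect("input must not be empty")
}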
diff --git a/ipa-core/src/protocol/ipa_prf/prf_eval.rs b/ipa-core/src/protocol/ipa_prf/prf_eval.rs index 91a75bd99..29751593f 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_eval.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_eval.rs @@ -147,8 +147,9 @@ where .await?; //compute R^(1/z) to u64 - Ok(zip(gr, z) - .map(|(gr, z)| u64::from(gr * z.invert())) + let inv_z = crate::ff::ec_prime_field::batch_invert::<N>(&z); + Ok(zip(gr, inv_z) + .map(|(gr, inv_z)| u64::from(gr * inv_z)) .collect::<Vec<_>>() .try_into() .expect("iteration over arrays")) diff --git a/ipa-core/src/query/runner/hybrid.rs b/ipa-core/src/query/runner/hybrid.rs index 8fcf4f860..c4e40c233 100644 --- a/ipa-core/src/query/runner/hybrid.rs +++ b/ipa-core/src/query/runner/hybrid.rs @@ -1,9 +1,13 @@ -use std::{convert::Into, marker::PhantomData, sync::Arc}; +use std::{ + convert::{Infallible, Into}, + marker::PhantomData, + sync::Arc, +}; use futures::{stream::iter, StreamExt, TryStreamExt}; use crate::{ - error::Error, + error::{Error, LengthError}, ff::{ boolean::Boolean, boolean_array::{BooleanArray, BA3, BA8}, @@ -35,7 +39,10 @@ use crate::{ }, hybrid_info::HybridInfo, }, - secret_sharing::{replicated::semi_honest::AdditiveShare as Replicated, Vectorizable}, + secret_sharing::{ + replicated::semi_honest::AdditiveShare as Replicated, BitDecomposed, TransposeFrom, + Vectorizable, + }, }; #[allow(dead_code)] @@ -75,6 +82,10 @@ where Replicated: BooleanProtocols>, Replicated: BooleanArrayMul> + Reveal, Output = >::Array>, + BitDecomposed>: + for<'b> TransposeFrom<&'b [Replicated; 256], Error = Infallible>, + Vec>: + for<'b> TransposeFrom<&'b BitDecomposed>, Error = LengthError>, { #[tracing::instrument("hybrid_query", skip_all, fields(sz=%query_size))] pub async fn execute( @@ -144,7 +155,7 @@ where #[cfg(not(feature = "relaxed-dp"))] let padding_params = PaddingParameters::default(); - hybrid_protocol::<_, BA8, BA3, HV, 256>( + hybrid_protocol::<_, BA8, BA3, HV, 3, 256>( ctx, indistinguishable_reports, dp_params, @@ -156,7 +167,10 @@ where #[cfg(all(test, unit_test, feature = "in-memory-infra"))] mod tests { - use std::{iter::zip, sync::Arc}; + use std::{ + iter::{repeat, zip}, + sync::Arc, + }; use rand::rngs::StdRng; use rand_core::SeedableRng; @@ -173,45 +187,15 @@ mod tests { hpke::{KeyPair, KeyRegistry}, query::runner::hybrid::Query as HybridQuery, report::{hybrid::HybridReport, hybrid_info::HybridInfo, DEFAULT_KEY_ID}, - secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, + secret_sharing::IntoShares, test_executor::run, test_fixture::{ - flatten3v, hybrid::TestHybridRecord, Reconstruct, RoundRobinInputDistribution, - TestWorld, TestWorldConfig, WithShards, + flatten3v, + hybrid::{build_hybrid_records_and_expectation, TestHybridRecord}, + Reconstruct, RoundRobinInputDistribution, TestWorld, TestWorldConfig, WithShards, }, }; - const EXPECTED: &[u128] = &[0, 8, 5]; - - fn build_records() -> Vec<TestHybridRecord> { - vec![ - TestHybridRecord::TestImpression { - match_key: 12345, - breakdown_key: 2, - }, - TestHybridRecord::TestImpression { - match_key: 68362, - breakdown_key: 1, - }, - TestHybridRecord::TestConversion { - match_key: 12345, - value: 5, - }, - TestHybridRecord::TestConversion { - match_key: 68362, - value: 2, - }, - TestHybridRecord::TestImpression { - match_key: 68362, - breakdown_key: 1, - }, - TestHybridRecord::TestConversion { - match_key: 68362, - value: 7, - }, - ] - } - struct BufferAndKeyRegistry { buffers: [Vec<Vec<u8>>; 3], key_registry: Arc<KeyRegistry<KeyPair>>, @@ -265,17 +249,22 @@ mod tests { } #[test] - // placeholder until the 
protocol is complete. can be updated to make sure we - // get to the unimplemented() call - #[should_panic( - expected = "not implemented: protocol::hybrid::hybrid_protocol is not fully implemented" - )] fn encrypted_hybrid_reports_happy() { // While this test currently checks for an unimplemented panic it is // designed to test for a correct result for a complete implementation. run(|| async { const SHARDS: usize = 2; - let records = build_records(); + let (test_hybrid_records, mut expected) = build_hybrid_records_and_expectation(); + + match expected.len() { + len if len < 256 => { + expected.extend(repeat(0).take(256 - len)); + } + len if len > 256 => { + panic!("no support for more than 256 breakdown_keys"); + } + _ => {} + } let hybrid_info = HybridInfo::new(0, "HELPER_ORIGIN", "meta.com", 1_729_707_432, 5.0, 1.1).unwrap(); @@ -284,7 +273,7 @@ mod tests { buffers, key_registry, query_sizes, - } = build_buffers_from_records(&records, SHARDS, &hybrid_info); + } = build_buffers_from_records(&test_hybrid_records, SHARDS, &hybrid_info); let world = TestWorld::>::with_shards(TestWorldConfig::default()); let contexts = world.malicious_contexts(); @@ -297,7 +286,10 @@ mod tests { .zip(helper_ctxs) .zip(query_sizes.clone()) .map(|((buffer, ctx), query_size)| { - let query_params = HybridQueryParams::default(); + let query_params = HybridQueryParams { + with_dp: 0, + ..Default::default() + }; let input = BodyStream::from(buffer); HybridQuery::<_, BA16, KeyRegistry>::new( @@ -311,7 +303,9 @@ mod tests { )) .await; - let results: Vec<[Vec>; 3]> = results + // TODO: after landing #1446, refactor this to only take the first 3 + // vectors, then validate that all other vectors reconstruct to 0. + let results: Vec = results .chunks(3) .map(|chunk| { [ @@ -319,15 +313,21 @@ mod tests { chunk[1].as_ref().unwrap().clone(), chunk[2].as_ref().unwrap().clone(), ] + .reconstruct() + .iter() + .map(U128Conversions::as_u128) + .collect::>() }) - .collect(); + .fold(([0_u128; 256]).to_vec(), |acc, v| { + acc.into_iter().zip(v).map(|(a, b)| a + b).collect() + }); assert_eq!( - results.into_iter().next().unwrap().reconstruct()[0..3] + results .iter() - .map(U128Conversions::as_u128) - .collect::>(), - EXPECTED + .map(|&x| u32::try_from(x).expect("test values should fit in u32")) + .collect::>(), + expected ); }); } @@ -337,7 +337,7 @@ mod tests { #[should_panic(expected = "DuplicateBytes")] async fn duplicate_encrypted_hybrid_reports() { const SHARDS: usize = 2; - let records = build_records(); + let (test_hybrid_records, _expected) = build_hybrid_records_and_expectation(); let hybrid_info = HybridInfo::new(0, "HELPER_ORIGIN", "meta.com", 1_729_707_432, 5.0, 1.1).unwrap(); @@ -346,7 +346,7 @@ mod tests { mut buffers, key_registry, query_sizes, - } = build_buffers_from_records(&records, SHARDS, &hybrid_info); + } = build_buffers_from_records(&test_hybrid_records, SHARDS, &hybrid_info); // this is double, since we duplicate the data below let query_sizes = query_sizes @@ -406,7 +406,7 @@ mod tests { )] async fn unsupported_plaintext_match_keys_hybrid_query() { const SHARDS: usize = 2; - let records = build_records(); + let (test_hybrid_records, _expected) = build_hybrid_records_and_expectation(); let hybrid_info = HybridInfo::new(0, "HELPER_ORIGIN", "meta.com", 1_729_707_432, 5.0, 1.1).unwrap(); @@ -415,7 +415,7 @@ mod tests { buffers, key_registry, query_sizes, - } = build_buffers_from_records(&records, SHARDS, &hybrid_info); + } = build_buffers_from_records(&test_hybrid_records, SHARDS, &hybrid_info); let world: 
TestWorld<WithShards<SHARDS>> = TestWorld::with_shards(TestWorldConfig::default()); diff --git a/ipa-core/src/test_fixture/hybrid.rs b/ipa-core/src/test_fixture/hybrid.rs index d522089c3..680c60c80 100644 --- a/ipa-core/src/test_fixture/hybrid.rs +++ b/ipa-core/src/test_fixture/hybrid.rs @@ -1,7 +1,4 @@ -use std::{ - collections::{HashMap, HashSet}, - iter::zip, -}; +use std::{collections::HashMap, iter::zip}; use crate::{ ff::{ @@ -23,7 +20,7 @@ pub enum TestHybridRecord { TestConversion { match_key: u64, value: u32 }, } -#[derive(PartialEq, Eq, Debug)] +#[derive(Clone, PartialEq, Eq, Debug)] pub struct TestIndistinguishableHybridReport { pub match_key: MK, pub value: u32, @@ -159,16 +156,40 @@ where } } -struct HashmapEntry { - breakdown_key: u32, - total_value: u32, +enum MatchEntry { + Single(TestHybridRecord), + Pair(TestHybridRecord, TestHybridRecord), + MoreThanTwo, } -impl HashmapEntry { - pub fn new(breakdown_key: u32, value: u32) -> Self { - Self { - breakdown_key, - total_value: value, +impl MatchEntry { + pub fn add_record(&mut self, new_record: TestHybridRecord) { + match self { + Self::Single(old_record) => { + *self = Self::Pair(old_record.clone(), new_record); + } + Self::Pair { .. } | Self::MoreThanTwo => *self = Self::MoreThanTwo, + } + } + + pub fn into_breakdown_key_and_value_tuple(self) -> Option<(u32, u32)> { + match self { + Self::Pair(imp, conv) => match (imp, conv) { + ( + TestHybridRecord::TestImpression { breakdown_key, .. }, + TestHybridRecord::TestConversion { value, .. }, + ) + | ( + TestHybridRecord::TestConversion { value, .. }, + TestHybridRecord::TestImpression { breakdown_key, .. }, + ) => Some((breakdown_key, value)), + ( + TestHybridRecord::TestConversion { value: value1, .. }, + TestHybridRecord::TestConversion { value: value2, .. }, + ) => Some((0, value1 + value2)), + _ => None, }, + _ => None, } } } @@ -177,127 +198,114 @@ impl HashmapEntry { /// It won't, so long as you can convert a u32 to a usize #[must_use] pub fn hybrid_in_the_clear(input_rows: &[TestHybridRecord], max_breakdown: usize) -> Vec<u32> { - let mut conversion_match_keys = HashSet::new(); - let mut impression_match_keys = HashSet::new(); - + let mut attributed_conversions = HashMap::<u64, MatchEntry>::new(); for input in input_rows { match input { - TestHybridRecord::TestImpression { match_key, .. } => { - impression_match_keys.insert(*match_key); - } - TestHybridRecord::TestConversion { match_key, .. } => { - conversion_match_keys.insert(*match_key); + TestHybridRecord::TestConversion { match_key, .. } | TestHybridRecord::TestImpression { match_key, .. 
} => { + attributed_conversions + .entry(*match_key) + .and_modify(|e| e.add_record(input.clone())) + .or_insert(MatchEntry::Single(input.clone())); } } } - // The key is the "match key" and the value stores both the breakdown and total attributed value - let mut attributed_conversions = HashMap::new(); - - for input in input_rows { - match input { - TestHybridRecord::TestImpression { - match_key, - breakdown_key, - } => { - if conversion_match_keys.contains(match_key) { - let v = attributed_conversions - .entry(*match_key) - .or_insert(HashmapEntry::new(*breakdown_key, 0)); - v.breakdown_key = *breakdown_key; - } - } - TestHybridRecord::TestConversion { match_key, value } => { - if impression_match_keys.contains(match_key) { - attributed_conversions - .entry(*match_key) - .and_modify(|e| e.total_value += value) - .or_insert(HashmapEntry::new(0, *value)); - } - } - } - } + let pairs = attributed_conversions + .into_values() + .filter_map(MatchEntry::into_breakdown_key_and_value_tuple) + .collect::<Vec<_>>(); let mut output = vec![0; max_breakdown]; - for (_, entry) in attributed_conversions { - output[usize::try_from(entry.breakdown_key).unwrap()] += entry.total_value; + for (breakdown_key, value) in pairs { + output[usize::try_from(breakdown_key).unwrap()] += value; } output } +#[must_use] +pub fn build_hybrid_records_and_expectation() -> (Vec<TestHybridRecord>, Vec<u32>) { + let test_hybrid_records = vec![ + TestHybridRecord::TestConversion { + match_key: 12345, + value: 2, + }, // malicious client attributed to breakdown 0 + TestHybridRecord::TestConversion { + match_key: 12345, + value: 5, + }, // malicious client attributed to breakdown 0 + TestHybridRecord::TestImpression { + match_key: 23456, + breakdown_key: 4, + }, // attributed + TestHybridRecord::TestConversion { + match_key: 23456, + value: 7, + }, // attributed + TestHybridRecord::TestImpression { + match_key: 34567, + breakdown_key: 1, + }, // no conversion + TestHybridRecord::TestImpression { + match_key: 45678, + breakdown_key: 3, + }, // attributed + TestHybridRecord::TestConversion { + match_key: 45678, + value: 5, + }, // attributed + TestHybridRecord::TestImpression { + match_key: 56789, + breakdown_key: 5, + }, // no conversion + TestHybridRecord::TestConversion { + match_key: 67890, + value: 2, + }, // NOT attributed + TestHybridRecord::TestImpression { + match_key: 78901, + breakdown_key: 2, + }, // too many reports + TestHybridRecord::TestConversion { + match_key: 78901, + value: 3, + }, // not attributed, too many reports + TestHybridRecord::TestConversion { + match_key: 78901, + value: 4, + }, // not attributed, too many reports + TestHybridRecord::TestImpression { + match_key: 89012, + breakdown_key: 4, + }, // attributed + TestHybridRecord::TestConversion { + match_key: 89012, + value: 6, + }, // attributed + ]; + + let expected = vec![ + 7, // two conversions go to bucket 0: 2 + 5 + 0, 0, 5, 13, // 4: 7 + 6 + 0, + ]; + + (test_hybrid_records, expected) + } + #[cfg(all(test, unit_test))] mod tests { use rand::{seq::SliceRandom, thread_rng}; - use super::TestHybridRecord; - use crate::test_fixture::hybrid::hybrid_in_the_clear; + use crate::test_fixture::hybrid::{build_hybrid_records_and_expectation, hybrid_in_the_clear}; #[test] - fn basic() { - let mut test_data = vec![ - TestHybridRecord::TestImpression { - match_key: 12345, - breakdown_key: 2, - }, - TestHybridRecord::TestImpression { - match_key: 23456, - breakdown_key: 4, - }, - TestHybridRecord::TestConversion { - match_key: 23456, - value: 25, - }, // attributed - 
TestHybridRecord::TestImpression { - match_key: 34567, - breakdown_key: 1, - }, - TestHybridRecord::TestImpression { - match_key: 45678, - breakdown_key: 3, - }, - TestHybridRecord::TestConversion { - match_key: 45678, - value: 13, - }, // attributed - TestHybridRecord::TestImpression { - match_key: 56789, - breakdown_key: 5, - }, - TestHybridRecord::TestConversion { - match_key: 67890, - value: 14, - }, // NOT attributed - TestHybridRecord::TestImpression { - match_key: 78901, - breakdown_key: 2, - }, - TestHybridRecord::TestConversion { - match_key: 78901, - value: 12, - }, // attributed - TestHybridRecord::TestConversion { - match_key: 78901, - value: 31, - }, // attributed - TestHybridRecord::TestImpression { - match_key: 89012, - breakdown_key: 4, - }, - TestHybridRecord::TestConversion { - match_key: 89012, - value: 8, - }, // attributed - ]; - + fn hybrid_basic() { + let (mut test_hybrid_records, expected) = build_hybrid_records_and_expectation(); let mut rng = thread_rng(); - test_data.shuffle(&mut rng); - let expected = vec![ - 0, 0, 43, // 12 + 31 - 13, 33, // 25 + 8 - 0, - ]; - let result = hybrid_in_the_clear(&test_data, 6); + test_hybrid_records.shuffle(&mut rng); + let result = hybrid_in_the_clear(&test_hybrid_records, 6); assert_eq!(result, expected); } }
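// Reviewer note (illustration only, not part of the patch): a compact restatement of the
// attribution rules that `MatchEntry` and the new `hybrid_in_the_clear` encode above, over a
// plain enum instead of the crate's report types (names below are invented for the example).
// Exactly two reports per match key are attributable: impression + conversion credits the
// impression's breakdown key, two conversions credit breakdown key 0, and anything else
// (a single report, two impressions, or more than two reports) is dropped.
enum Report {
    Impression { breakdown_key: u32 },
    Conversion { value: u32 },
}

fn attribute(reports: &[Report]) -> Option<(u32, u32)> {
    match reports {
        // Impression + conversion, in either order: credit the impression's breakdown key.
        [Report::Impression { breakdown_key }, Report::Conversion { value }]
        | [Report::Conversion { value }, Report::Impression { breakdown_key }] => {
            Some((*breakdown_key, *value))
        }
        // Two conversions from the same match key: credited to breakdown key 0.
        [Report::Conversion { value: v1 }, Report::Conversion { value: v2 }] => Some((0, v1 + v2)),
        // A single report, two impressions, or more than two reports: dropped.
        _ => None,
    }
}

fn main() {
    assert_eq!(
        attribute(&[
            Report::Conversion { value: 7 },
            Report::Impression { breakdown_key: 4 },
        ]),
        Some((4, 7))
    );
    assert_eq!(attribute(&[Report::Impression { breakdown_key: 1 }]), None);
}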