diff --git a/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs b/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs index b496c14f6..af04aa582 100644 --- a/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs +++ b/ipa-core/src/protocol/basics/mul/dzkp_malicious.rs @@ -80,20 +80,143 @@ impl<'a, F: Field + DZKPCompatibleField, const N: usize> #[cfg(all(test, unit_test))] mod test { + use ipa_step_derive::CompactStep; + use crate::{ error::Error, ff::boolean::Boolean, + helpers::{Role, Role::H1}, protocol::{ - basics::SecureMul, - context::{dzkp_validator::DZKPValidator, Context, DZKPContext, UpgradableContext}, + basics::{ + mul::{dzkp_malicious::Field, semi_honest::multiplication_protocol, Replicated}, + SecureMul, + }, + context::{ + dzkp_field::DZKPCompatibleField, + dzkp_validator::{DZKPValidator, Segment}, + Context, DZKPContext, DZKPUpgradedMaliciousContext, UpgradableContext, + }, RecordId, }, rand::{thread_rng, Rng}, + secret_sharing::{replicated::semi_honest::AdditiveShare, SharedValueArray, Vectorizable}, test_fixture::{Reconstruct, Runner, TestWorld}, }; + /// This function mirrors `zkp_multiply` except that on party cheats. + /// + /// The cheating party flips `prss_left` + /// which causes a flip in `z_left` computed by the cheating party. + /// This manipulated `z_left` is then sent to a different helper + /// and included in the DZKP batch. + pub async fn multiply_with_cheater<'a, F, const N: usize>( + ctx: DZKPUpgradedMaliciousContext<'a>, + record_id: RecordId, + a: &Replicated, + b: &Replicated, + prss: &Replicated, + cheater: Role, + ) -> Result, Error> + where + F: Field + DZKPCompatibleField, + { + let mut prss_left = prss.left_arr().clone(); + if ctx.role() == cheater { + prss_left += <>::Array>::from_fn(|_| F::ONE); + }; + + let z = + multiplication_protocol(&ctx, record_id, a, b, &prss_left, prss.right_arr()).await?; + // create segment + let segment = Segment::from_entries( + F::as_segment_entry(a.left_arr()), + F::as_segment_entry(a.right_arr()), + F::as_segment_entry(b.left_arr()), + F::as_segment_entry(b.right_arr()), + F::as_segment_entry(prss.left_arr()), + F::as_segment_entry(prss.right_arr()), + F::as_segment_entry(z.right_arr()), + ); + + // add segment to the batch that needs to be verified by the dzkp prover and verifiers + ctx.push(record_id, segment); + + Ok(z) + } + fn generate_share_from_three_bits(role: Role, i: usize) -> AdditiveShare { + let (first_bit, second_bit) = match role { + Role::H1 => (i % 2 == 0, (i >> 1) % 2 == 0), + Role::H2 => ((i >> 1) % 2 == 0, (i >> 2) % 2 == 0), + Role::H3 => ((i >> 2) % 2 == 0, i % 2 == 0), + }; + >::from((first_bit.into(), second_bit.into())) + } + + fn all_combination_of_inputs(role: Role, i: usize) -> [AdditiveShare; 3] { + // first three bits + let a = generate_share_from_three_bits(role, i); + // middle three bits + let b = generate_share_from_three_bits(role, i >> 3); + // last three bits + let prss = generate_share_from_three_bits(role, i >> 6); + + [a, b, prss] + } + + #[derive(CompactStep)] + enum TestStep { + #[step(count = 512)] + Counter(usize), + } + + #[tokio::test] + async fn detect_cheating() { + let world = TestWorld::default(); + + for i in 0..512 { + let [(_, s_1), (_, s_2), (v_3, s_3)] = world + .malicious((), |ctx, ()| async move { + let [a, b, prss] = all_combination_of_inputs(ctx.role(), i); + let validator = ctx.narrow(&TestStep::Counter(i)).dzkp_validator(10); + let mctx = validator.context(); + let product = multiply_with_cheater( + mctx.set_total_records(1), + RecordId::FIRST, + &a, + &b, + &prss, + H1, + ) + .await + .unwrap(); + + ( + validator.validate().await, + [ + bool::from(*a.left_arr().first()), + bool::from(*a.right_arr().first()), + bool::from(*b.left_arr().first()), + bool::from(*b.right_arr().first()), + bool::from(*prss.left_arr().first()), + bool::from(*prss.right_arr().first()), + bool::from(*product.left_arr().first()), + bool::from(*product.right_arr().first()), + ], + ) + }) + .await; + + // H1 cheats means H3 fails + // since always verifier on the left of the cheating prover fails + match v_3 { + Ok(()) => panic!("Got a result H1: {s_1:?}, H2: {s_2:?}, H3: {s_3:?}"), + Err(ref err) => assert!(matches!(err, Error::DZKPValidationFailed)), + } + } + } + #[tokio::test] - pub async fn simple() { + async fn simple() { let world = TestWorld::default(); let mut rng = thread_rng();