diff --git a/Cargo.lock b/Cargo.lock index 092d816e..33e43cd5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1868,6 +1868,7 @@ dependencies = [ "polynomials", "rand", "rayon", + "serial_test", "transcript", ] diff --git a/arith/polynomials/src/ref_mle.rs b/arith/polynomials/src/ref_mle.rs index fab69eab..4bb13eea 100644 --- a/arith/polynomials/src/ref_mle.rs +++ b/arith/polynomials/src/ref_mle.rs @@ -9,6 +9,7 @@ pub trait MultilinearExtension: Index { fn num_vars(&self) -> usize; + #[inline] fn hypercube_size(&self) -> usize { 1 << self.num_vars() } @@ -18,6 +19,11 @@ pub trait MultilinearExtension: Index { fn hypercube_basis_ref(&self) -> &Vec; fn interpolate_over_hypercube(&self) -> Vec; + + #[inline] + fn serialized_size(&self) -> usize { + self.hypercube_size() * F::SIZE + } } #[derive(Debug, Clone)] diff --git a/config/mpi_config/src/mpi_config.rs b/config/mpi_config/src/mpi_config.rs index e4bc7182..41b760d8 100644 --- a/config/mpi_config/src/mpi_config.rs +++ b/config/mpi_config/src/mpi_config.rs @@ -135,10 +135,7 @@ impl MPIConfig { pub fn read_all_field_flat(&self, start: usize, end: usize) -> Vec { let data = self.read_all(start, end); data.iter() - .flat_map(|x| { - x.chunks(F::SIZE) - .map(|y| F::deserialize_from(y).unwrap()) - }) + .flat_map(|x| x.chunks(F::SIZE).map(|y| F::deserialize_from(y).unwrap())) .collect() } @@ -190,4 +187,16 @@ impl MPIConfig { result } } + + #[inline] + /// Finalize function does nothing except for a minimal sanity check + /// that all threads have the same amount of data + pub fn finalize(&self) { + let len = self.threads[0].size(); + self.threads.iter().skip(1).for_each(|t| { + assert_eq!(t.size(), len); + }); + + // do nothing + } } diff --git a/poly_commit/Cargo.toml b/poly_commit/Cargo.toml index 5de553bd..58306dec 100644 --- a/poly_commit/Cargo.toml +++ b/poly_commit/Cargo.toml @@ -13,3 +13,7 @@ transcript = { path = "../transcript" } ethnum.workspace = true rand.workspace = true rayon.workspace = true + + +[dev-dependencies] +serial_test.workspace = true \ No newline at end of file diff --git a/poly_commit/src/raw.rs b/poly_commit/src/raw.rs index 4332aa91..afc77602 100644 --- a/poly_commit/src/raw.rs +++ b/poly_commit/src/raw.rs @@ -197,8 +197,13 @@ impl> PCSForExpanderGKR> PCSForExpanderGKR::VKey, commitment: &Self::Commitment, x: &ExpanderGKRChallenge, @@ -236,7 +241,6 @@ impl> PCSForExpanderGKR bool { - assert!(mpi_config.is_root()); // Only the root will verify let ExpanderGKRChallenge:: { x, x_simd, x_mpi } = x; Self::eval(&commitment.evals, x, x_simd, x_mpi) == v } diff --git a/poly_commit/src/traits.rs b/poly_commit/src/traits.rs index 507ff32b..a698597a 100644 --- a/poly_commit/src/traits.rs +++ b/poly_commit/src/traits.rs @@ -7,8 +7,8 @@ use std::fmt::Debug; use transcript::Transcript; pub trait StructuredReferenceString { - type PKey: Clone + Debug + FieldSerde + Send; - type VKey: Clone + Debug + FieldSerde + Send; + type PKey: Clone + Debug + FieldSerde + Send + Sync; + type VKey: Clone + Debug + FieldSerde + Send + Sync; /// Convert the SRS into proving and verifying keys. /// Comsuming self by default. @@ -72,7 +72,7 @@ pub struct ExpanderGKRChallenge { pub trait PCSForExpanderGKR> { const NAME: &'static str; - type Params: Clone + Debug + Default + Send; + type Params: Clone + Debug + Default + Send + Sync; type ScratchPad: Clone + Debug + Default + Send; type SRS: Clone + Debug + Default + FieldSerde + StructuredReferenceString; diff --git a/poly_commit/tests/common.rs b/poly_commit/tests/common.rs index ada8f1da..9e80aa1b 100644 --- a/poly_commit/tests/common.rs +++ b/poly_commit/tests/common.rs @@ -5,8 +5,9 @@ use poly_commit::raw::RawExpanderGKR; use poly_commit::{ ExpanderGKRChallenge, PCSForExpanderGKR, PolynomialCommitmentScheme, StructuredReferenceString, }; -use polynomials::MultilinearExtension; +use polynomials::{MultilinearExtension, RefMultiLinearPoly}; use rand::thread_rng; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; use transcript::Transcript; pub fn test_pcs>( @@ -39,55 +40,79 @@ pub fn test_gkr_pcs< T: Transcript, P: PCSForExpanderGKR, >( - params: &P::Params, mpi_config: &MPIConfig, - transcript: &mut T, - poly: &impl MultilinearExtension, - xs: &[ExpanderGKRChallenge], + n_local_vars: usize, ) { let mut rng = thread_rng(); - let srs = P::gen_srs_for_testing(params, mpi_config, &mut rng); + let params = P::gen_params(n_local_vars); + + let srs = P::gen_srs_for_testing(¶ms, mpi_config, &mut rng); let (proving_key, verification_key) = srs.into_keys(); - let mut scratch_pad = P::init_scratch_pad(params, mpi_config); - let commitment = P::commit(params, mpi_config, &proving_key, poly, &mut scratch_pad); + let num_threads = rayon::current_num_threads(); - // PCSForExpanderGKR does not require an evaluation value for the opening function - // We use RawExpanderGKR as the golden standard for the evaluation value - // Note this test will almost always pass for RawExpanderGKR, so make sure it is correct - let mut coeffs_gathered = if mpi_config.is_root() { - vec![C::SimdCircuitField::ZERO; poly.hypercube_size() * mpi_config.world_size()] - } else { - vec![] - }; - mpi_config.gather_vec(poly.hypercube_basis_ref(), &mut coeffs_gathered); + (0..num_threads).into_par_iter().for_each(|_| { + let mut rng = thread_rng(); + let hypercube_basis = (0..(1 << n_local_vars)) + .map(|_| C::SimdCircuitField::random_unsafe(&mut rng)) + .collect(); + let poly = RefMultiLinearPoly::from_ref(&hypercube_basis); - for xx in xs { - let ExpanderGKRChallenge { x, x_simd, x_mpi } = xx; - let opening = P::open( - params, - mpi_config, - &proving_key, - poly, - xx, - transcript, - &mut scratch_pad, - ); + let xs = (0..2) + .map(|_| ExpanderGKRChallenge:: { + x: (0..n_local_vars) + .map(|_| C::ChallengeField::random_unsafe(&mut rng)) + .collect::>(), + x_simd: (0..C::get_field_pack_size().trailing_zeros()) + .map(|_| C::ChallengeField::random_unsafe(&mut rng)) + .collect::>(), + x_mpi: (0..mpi_config.world_size().trailing_zeros()) + .map(|_| C::ChallengeField::random_unsafe(&mut rng)) + .collect::>(), + }) + .collect::>>(); + + let mut scratch_pad = P::init_scratch_pad(¶ms, mpi_config); + + let commitment = P::commit(¶ms, mpi_config, &proving_key, &poly, &mut scratch_pad); + let mut transcript = T::new(); + + // PCSForExpanderGKR does not require an evaluation value for the opening function + // We use RawExpanderGKR as the golden standard for the evaluation value + // Note this test will almost always pass for RawExpanderGKR, so make sure it is correct + let start = mpi_config.current_size(); + let end = start + poly.serialized_size(); + + poly.hypercube_basis_ref() + .iter() + .for_each(|f| mpi_config.append_local_field(f)); + let coeffs_gathered = mpi_config.read_all_field_flat(start, end); + + for xx in xs { + let ExpanderGKRChallenge { x, x_simd, x_mpi } = &xx; + let opening = P::open( + ¶ms, + mpi_config, + &proving_key, + &poly, + &xx, + &mut transcript, + &mut scratch_pad, + ); - if mpi_config.is_root() { // this will always pass for RawExpanderGKR, so make sure it is correct - let v = RawExpanderGKR::::eval(&coeffs_gathered, x, x_simd, x_mpi); + let v = RawExpanderGKR::::eval(&coeffs_gathered, &x, &x_simd, &x_mpi); assert!(P::verify( - params, + ¶ms, mpi_config, &verification_key, &commitment, - xx, + &xx, v, - transcript, + &mut transcript, &opening )); } - } + }); } diff --git a/poly_commit/tests/test_raw.rs b/poly_commit/tests/test_raw.rs index 6818c351..ee1b8c0c 100644 --- a/poly_commit/tests/test_raw.rs +++ b/poly_commit/tests/test_raw.rs @@ -1,15 +1,16 @@ mod common; +use std::sync::Arc; + use arith::{BN254Fr, Field}; +use common::test_gkr_pcs; use gkr_field_config::{BN254Config, GF2ExtConfig, GKRFieldConfig, M31ExtConfig}; use mpi_config::MPIConfig; -use poly_commit::{ - raw::{RawExpanderGKR, RawExpanderGKRParams, RawMultiLinearPCS, RawMultiLinearParams}, - ExpanderGKRChallenge, -}; -use polynomials::{MultiLinearPoly, RefMultiLinearPoly}; +use poly_commit::raw::{RawExpanderGKR, RawMultiLinearPCS, RawMultiLinearParams}; +use polynomials::MultiLinearPoly; use rand::thread_rng; -use transcript::{BytesHashTranscript, SHA256hasher, Transcript}; +use serial_test::serial; +use transcript::{BytesHashTranscript, SHA256hasher}; #[test] fn test_raw() { @@ -27,45 +28,24 @@ fn test_raw() { common::test_pcs::(¶ms, &poly, &xs); } -fn test_raw_gkr_helper>( - mpi_config: &MPIConfig, - transcript: &mut T, -) { - let params = RawExpanderGKRParams { n_local_vars: 8 }; - let mut rng = thread_rng(); - let hypercube_basis = (0..(1 << params.n_local_vars)) - .map(|_| C::SimdCircuitField::random_unsafe(&mut rng)) - .collect(); - let poly = RefMultiLinearPoly::from_ref(&hypercube_basis); - let xs = (0..100) - .map(|_| ExpanderGKRChallenge:: { - x: (0..params.n_local_vars) - .map(|_| C::ChallengeField::random_unsafe(&mut rng)) - .collect::>(), - x_simd: (0..C::get_field_pack_size().trailing_zeros()) - .map(|_| C::ChallengeField::random_unsafe(&mut rng)) - .collect::>(), - x_mpi: (0..mpi_config.world_size().trailing_zeros()) - .map(|_| C::ChallengeField::random_unsafe(&mut rng)) - .collect::>(), - }) - .collect::>>(); - common::test_gkr_pcs::>(¶ms, mpi_config, transcript, &poly, &xs); -} - #[test] +#[serial] fn test_raw_gkr() { - let mpi_config = MPIConfig::new(); + let global_data: Arc<[u8]> = Arc::from((0..1024).map(|i| i as u8).collect::>()); + let num_threads = rayon::current_num_threads(); + + // Create configs for all threads + let mpi_config = MPIConfig::new(num_threads as i32, global_data, 1024 * 1024); type TM31 = BytesHashTranscript<::ChallengeField, SHA256hasher>; - test_raw_gkr_helper::(&mpi_config, &mut TM31::new()); + test_gkr_pcs::>(&mpi_config, 8); type TGF2 = BytesHashTranscript<::ChallengeField, SHA256hasher>; - test_raw_gkr_helper::(&mpi_config, &mut TGF2::new()); + test_gkr_pcs::>(&mpi_config, 8); type TBN254 = BytesHashTranscript<::ChallengeField, SHA256hasher>; - test_raw_gkr_helper::(&mpi_config, &mut TBN254::new()); + test_gkr_pcs::>(&mpi_config, 8); - MPIConfig::finalize(); + mpi_config.finalize(); } diff --git a/sumcheck/src/prover_helper/sumcheck_gkr_vanilla.rs b/sumcheck/src/prover_helper/sumcheck_gkr_vanilla.rs index 54fddb0f..eefe2777 100644 --- a/sumcheck/src/prover_helper/sumcheck_gkr_vanilla.rs +++ b/sumcheck/src/prover_helper/sumcheck_gkr_vanilla.rs @@ -329,14 +329,20 @@ impl<'a, C: GKRFieldConfig> SumcheckGkrVanillaHelper<'a, C> { #[inline] pub(crate) fn prepare_mpi_var_vals(&mut self) { let start = self.mpi_config.current_size(); - self.mpi_config.append_local_field(&self.sp.simd_var_v_evals[0]); + self.mpi_config + .append_local_field(&self.sp.simd_var_v_evals[0]); let end = self.mpi_config.current_size(); - self.sp.mpi_var_v_evals = self.mpi_config.read_all_field_flat::(start, end); + self.sp.mpi_var_v_evals = self + .mpi_config + .read_all_field_flat::(start, end); let start = self.mpi_config.current_size(); - self.mpi_config.append_local_field(&(self.sp.simd_var_hg_evals[0] * self.sp.eq_evals_at_r_simd0[0])); + self.mpi_config + .append_local_field(&(self.sp.simd_var_hg_evals[0] * self.sp.eq_evals_at_r_simd0[0])); let end = self.mpi_config.current_size(); - self.sp.mpi_var_hg_evals = self.mpi_config.read_all_field_flat::(start, end); + self.sp.mpi_var_hg_evals = self + .mpi_config + .read_all_field_flat::(start, end); } #[inline]