fix poly commit
zhenfeizhang committed Jan 7, 2025
1 parent 72279a5 commit 31d32a5
Showing 9 changed files with 121 additions and 86 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

6 changes: 6 additions & 0 deletions arith/polynomials/src/ref_mle.rs
@@ -9,6 +9,7 @@ pub trait MultilinearExtension<F: Field>: Index<usize, Output = F> {

fn num_vars(&self) -> usize;

#[inline]
fn hypercube_size(&self) -> usize {
1 << self.num_vars()
}
@@ -18,6 +19,11 @@ pub trait MultilinearExtension<F: Field>: Index<usize, Output = F> {
fn hypercube_basis_ref(&self) -> &Vec<F>;

fn interpolate_over_hypercube(&self) -> Vec<F>;

#[inline]
fn serialized_size(&self) -> usize {
self.hypercube_size() * F::SIZE
}
}

#[derive(Debug, Clone)]
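A minimal, self-contained sketch of the two default methods added above. `MleSize`, `ToyMle`, and `ELEM_BYTES` are made-up stand-ins for the crate's `MultilinearExtension` trait and `F::SIZE`: a polynomial over n variables has 2^n hypercube evaluations, and its serialized size is that count times the per-element byte size.

```rust
// Hypothetical stand-in trait; mirrors the default-method pattern only.
trait MleSize {
    const ELEM_BYTES: usize; // stands in for F::SIZE
    fn num_vars(&self) -> usize;

    #[inline]
    fn hypercube_size(&self) -> usize {
        1 << self.num_vars()
    }

    #[inline]
    fn serialized_size(&self) -> usize {
        self.hypercube_size() * Self::ELEM_BYTES
    }
}

struct ToyMle {
    num_vars: usize,
}

impl MleSize for ToyMle {
    const ELEM_BYTES: usize = 32; // e.g. a 256-bit field element
    fn num_vars(&self) -> usize {
        self.num_vars
    }
}

fn main() {
    let mle = ToyMle { num_vars: 3 };
    assert_eq!(mle.hypercube_size(), 8);       // 2^3 evaluations
    assert_eq!(mle.serialized_size(), 8 * 32); // 256 bytes on the wire
}
```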
17 changes: 13 additions & 4 deletions config/mpi_config/src/mpi_config.rs
@@ -135,10 +135,7 @@ impl MPIConfig {
pub fn read_all_field_flat<F: Field>(&self, start: usize, end: usize) -> Vec<F> {
let data = self.read_all(start, end);
data.iter()
.flat_map(|x| {
x.chunks(F::SIZE)
.map(|y| F::deserialize_from(y).unwrap())
})
.flat_map(|x| x.chunks(F::SIZE).map(|y| F::deserialize_from(y).unwrap()))
.collect()
}

@@ -190,4 +187,16 @@ impl MPIConfig {
result
}
}

#[inline]
/// Finalize function does nothing except for a minimal sanity check
/// that all threads have the same amount of data
pub fn finalize(&self) {
let len = self.threads[0].size();
self.threads.iter().skip(1).for_each(|t| {
assert_eq!(t.size(), len);
});

// do nothing
}
}
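A standalone sketch of the check that the new `finalize()` performs. `check_equal_sizes` and the literal sizes are hypothetical test data standing in for `self.threads[i].size()`: every thread must have appended the same amount of data to its local buffer.

```rust
// Hypothetical stand-in for MPIConfig::finalize's sanity check.
fn check_equal_sizes(thread_sizes: &[usize]) {
    let len = thread_sizes[0];
    thread_sizes.iter().skip(1).for_each(|&s| {
        assert_eq!(s, len, "threads appended different amounts of data");
    });
}

fn main() {
    check_equal_sizes(&[128, 128, 128, 128]); // all threads in sync: passes
    // check_equal_sizes(&[128, 64]);         // mismatched buffers would panic
}
```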
4 changes: 4 additions & 0 deletions poly_commit/Cargo.toml
@@ -13,3 +13,7 @@ transcript = { path = "../transcript" }
ethnum.workspace = true
rand.workspace = true
rayon.workspace = true


[dev-dependencies]
serial_test.workspace = true
12 changes: 8 additions & 4 deletions poly_commit/src/raw.rs
@@ -197,8 +197,13 @@ impl<C: GKRFieldConfig, T: Transcript<C::ChallengeField>> PCSForExpanderGKR<C, T
poly.hypercube_basis()
} else {
// append the local hypercube basis, then read the matching window from every thread
let end = mpi_config.threads[0].size();
let start = end - poly.hypercube_size();
let start = mpi_config.current_thread().size();
let end = start + poly.serialized_size();

poly.hypercube_basis()
.iter()
.for_each(|f| mpi_config.append_local_field(f));

let payloads = mpi_config.read_all(start, end); // read #thread payloads

payloads
@@ -228,15 +233,14 @@ impl<C: GKRFieldConfig, T: Transcript<C::ChallengeField>> PCSForExpanderGKR<C, T

fn verify(
_params: &Self::Params,
mpi_config: &MPIConfig,
_mpi_config: &MPIConfig,
_verifying_key: &<Self::SRS as StructuredReferenceString>::VKey,
commitment: &Self::Commitment,
x: &ExpanderGKRChallenge<C>,
v: C::ChallengeField,
_transcript: &mut T,
_opening: &Self::Opening,
) -> bool {
assert!(mpi_config.is_root()); // Only the root will verify
let ExpanderGKRChallenge::<C> { x, x_simd, x_mpi } = x;
Self::eval(&commitment.evals, x, x_simd, x_mpi) == v
}
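A hedged toy model of the new commit flow in the hunk above, using a single process and plain `Vec`s instead of `MPIConfig` (the function name `commit_window` and the sample payloads are invented for illustration): each thread notes its current buffer size, appends its serialized local polynomial, and then reads the same `[start, end)` window from every thread's buffer, yielding one payload per thread.

```rust
// Toy stand-in: buffers[i] plays the role of thread i's local byte buffer.
fn commit_window(buffers: &mut [Vec<u8>], local_payloads: &[Vec<u8>]) -> Vec<Vec<u8>> {
    let start = buffers[0].len();              // current_thread().size()
    let end = start + local_payloads[0].len(); // start + poly.serialized_size()

    // append_local_field, per thread
    for (buf, payload) in buffers.iter_mut().zip(local_payloads) {
        buf.extend_from_slice(payload);
    }

    // read_all(start, end): one payload per thread
    buffers.iter().map(|b| b[start..end].to_vec()).collect()
}

fn main() {
    let mut buffers = vec![vec![0u8; 4]; 2];   // two threads, 4 bytes already buffered
    let payloads = vec![vec![1u8, 2], vec![3u8, 4]];
    let gathered = commit_window(&mut buffers, &payloads);
    assert_eq!(gathered, vec![vec![1, 2], vec![3, 4]]);
}
```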
6 changes: 3 additions & 3 deletions poly_commit/src/traits.rs
@@ -7,8 +7,8 @@ use std::fmt::Debug;
use transcript::Transcript;

pub trait StructuredReferenceString {
type PKey: Clone + Debug + FieldSerde + Send;
type VKey: Clone + Debug + FieldSerde + Send;
type PKey: Clone + Debug + FieldSerde + Send + Sync;
type VKey: Clone + Debug + FieldSerde + Send + Sync;

/// Convert the SRS into proving and verifying keys.
/// Consuming self by default.
@@ -72,7 +72,7 @@ pub struct ExpanderGKRChallenge<C: GKRFieldConfig> {
pub trait PCSForExpanderGKR<C: GKRFieldConfig, T: Transcript<C::ChallengeField>> {
const NAME: &'static str;

type Params: Clone + Debug + Default + Send;
type Params: Clone + Debug + Default + Send + Sync;
type ScratchPad: Clone + Debug + Default + Send;

type SRS: Clone + Debug + Default + FieldSerde + StructuredReferenceString;
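A small sketch of why the extra `Sync` bounds are needed (`Params` and `run_per_thread` are stand-ins, not the crate's types): the reworked test below fans work out over rayon threads, and a `&Params` captured by every worker closure only compiles if the shared type is `Sync`.

```rust
use rayon::prelude::*;

#[derive(Clone, Debug, Default)]
struct Params {
    n_local_vars: usize,
}

fn run_per_thread(params: &Params, num_threads: usize) {
    (0..num_threads).into_par_iter().for_each(|t| {
        // Each rayon worker borrows `params`; this requires Params: Sync.
        let _ = (t, params.n_local_vars);
    });
}

fn main() {
    run_per_thread(&Params { n_local_vars: 8 }, 4);
}
```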
93 changes: 59 additions & 34 deletions poly_commit/tests/common.rs
@@ -5,8 +5,9 @@ use poly_commit::raw::RawExpanderGKR;
use poly_commit::{
ExpanderGKRChallenge, PCSForExpanderGKR, PolynomialCommitmentScheme, StructuredReferenceString,
};
use polynomials::MultilinearExtension;
use polynomials::{MultilinearExtension, RefMultiLinearPoly};
use rand::thread_rng;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use transcript::Transcript;

pub fn test_pcs<F: Field, P: PolynomialCommitmentScheme<F>>(
@@ -39,55 +40,79 @@ pub fn test_gkr_pcs<
T: Transcript<C::ChallengeField>,
P: PCSForExpanderGKR<C, T>,
>(
params: &P::Params,
mpi_config: &MPIConfig,
transcript: &mut T,
poly: &impl MultilinearExtension<C::SimdCircuitField>,
xs: &[ExpanderGKRChallenge<C>],
n_local_vars: usize,
) {
let mut rng = thread_rng();
let srs = P::gen_srs_for_testing(params, mpi_config, &mut rng);
let params = P::gen_params(n_local_vars);

let srs = P::gen_srs_for_testing(&params, mpi_config, &mut rng);
let (proving_key, verification_key) = srs.into_keys();
let mut scratch_pad = P::init_scratch_pad(params, mpi_config);

let commitment = P::commit(params, mpi_config, &proving_key, poly, &mut scratch_pad);
let num_threads = rayon::current_num_threads();

// PCSForExpanderGKR does not require an evaluation value for the opening function
// We use RawExpanderGKR as the golden standard for the evaluation value
// Note this test will almost always pass for RawExpanderGKR, so make sure it is correct
let mut coeffs_gathered = if mpi_config.is_root() {
vec![C::SimdCircuitField::ZERO; poly.hypercube_size() * mpi_config.world_size()]
} else {
vec![]
};
mpi_config.gather_vec(poly.hypercube_basis_ref(), &mut coeffs_gathered);
(0..num_threads).into_par_iter().for_each(|_| {
let mut rng = thread_rng();
let hypercube_basis = (0..(1 << n_local_vars))
.map(|_| C::SimdCircuitField::random_unsafe(&mut rng))
.collect();
let poly = RefMultiLinearPoly::from_ref(&hypercube_basis);

for xx in xs {
let ExpanderGKRChallenge { x, x_simd, x_mpi } = xx;
let opening = P::open(
params,
mpi_config,
&proving_key,
poly,
xx,
transcript,
&mut scratch_pad,
);
let xs = (0..2)
.map(|_| ExpanderGKRChallenge::<C> {
x: (0..n_local_vars)
.map(|_| C::ChallengeField::random_unsafe(&mut rng))
.collect::<Vec<C::ChallengeField>>(),
x_simd: (0..C::get_field_pack_size().trailing_zeros())
.map(|_| C::ChallengeField::random_unsafe(&mut rng))
.collect::<Vec<C::ChallengeField>>(),
x_mpi: (0..mpi_config.world_size().trailing_zeros())
.map(|_| C::ChallengeField::random_unsafe(&mut rng))
.collect::<Vec<C::ChallengeField>>(),
})
.collect::<Vec<ExpanderGKRChallenge<C>>>();

let mut scratch_pad = P::init_scratch_pad(&params, mpi_config);

let commitment = P::commit(&params, mpi_config, &proving_key, &poly, &mut scratch_pad);
let mut transcript = T::new();

// PCSForExpanderGKR does not require an evaluation value for the opening function
// We use RawExpanderGKR as the golden standard for the evaluation value
// Note this test will almost always pass for RawExpanderGKR, so make sure it is correct
let start = mpi_config.current_size();
let end = start + poly.serialized_size();

poly.hypercube_basis_ref()
.iter()
.for_each(|f| mpi_config.append_local_field(f));
let coeffs_gathered = mpi_config.read_all_field_flat(start, end);

for xx in xs {
let ExpanderGKRChallenge { x, x_simd, x_mpi } = &xx;
let opening = P::open(
&params,
mpi_config,
&proving_key,
&poly,
&xx,
&mut transcript,
&mut scratch_pad,
);

if mpi_config.is_root() {
// this will always pass for RawExpanderGKR, so make sure it is correct
let v = RawExpanderGKR::<C, T>::eval(&coeffs_gathered, x, x_simd, x_mpi);
let v = RawExpanderGKR::<C, T>::eval(&coeffs_gathered, &x, &x_simd, &x_mpi);

assert!(P::verify(
params,
&params,
mpi_config,
&verification_key,
&commitment,
xx,
&xx,
v,
transcript,
&mut transcript,
&opening
));
}
}
});
}
54 changes: 17 additions & 37 deletions poly_commit/tests/test_raw.rs
@@ -1,15 +1,16 @@
mod common;

use std::sync::Arc;

use arith::{BN254Fr, Field};
use common::test_gkr_pcs;
use gkr_field_config::{BN254Config, GF2ExtConfig, GKRFieldConfig, M31ExtConfig};
use mpi_config::MPIConfig;
use poly_commit::{
raw::{RawExpanderGKR, RawExpanderGKRParams, RawMultiLinearPCS, RawMultiLinearParams},
ExpanderGKRChallenge,
};
use polynomials::{MultiLinearPoly, RefMultiLinearPoly};
use poly_commit::raw::{RawExpanderGKR, RawMultiLinearPCS, RawMultiLinearParams};
use polynomials::MultiLinearPoly;
use rand::thread_rng;
use transcript::{BytesHashTranscript, SHA256hasher, Transcript};
use serial_test::serial;
use transcript::{BytesHashTranscript, SHA256hasher};

#[test]
fn test_raw() {
@@ -27,45 +28,24 @@ fn test_raw() {
common::test_pcs::<BN254Fr, RawMultiLinearPCS>(&params, &poly, &xs);
}

fn test_raw_gkr_helper<C: GKRFieldConfig, T: Transcript<C::ChallengeField>>(
mpi_config: &MPIConfig,
transcript: &mut T,
) {
let params = RawExpanderGKRParams { n_local_vars: 8 };
let mut rng = thread_rng();
let hypercube_basis = (0..(1 << params.n_local_vars))
.map(|_| C::SimdCircuitField::random_unsafe(&mut rng))
.collect();
let poly = RefMultiLinearPoly::from_ref(&hypercube_basis);
let xs = (0..100)
.map(|_| ExpanderGKRChallenge::<C> {
x: (0..params.n_local_vars)
.map(|_| C::ChallengeField::random_unsafe(&mut rng))
.collect::<Vec<C::ChallengeField>>(),
x_simd: (0..C::get_field_pack_size().trailing_zeros())
.map(|_| C::ChallengeField::random_unsafe(&mut rng))
.collect::<Vec<C::ChallengeField>>(),
x_mpi: (0..mpi_config.world_size().trailing_zeros())
.map(|_| C::ChallengeField::random_unsafe(&mut rng))
.collect::<Vec<C::ChallengeField>>(),
})
.collect::<Vec<ExpanderGKRChallenge<C>>>();
common::test_gkr_pcs::<C, T, RawExpanderGKR<C, T>>(&params, mpi_config, transcript, &poly, &xs);
}

#[test]
#[serial]
fn test_raw_gkr() {
let mpi_config = MPIConfig::new();
let global_data: Arc<[u8]> = Arc::from((0..1024).map(|i| i as u8).collect::<Vec<u8>>());
let num_threads = rayon::current_num_threads();

// Create configs for all threads
let mpi_config = MPIConfig::new(num_threads as i32, global_data, 1024 * 1024);

type TM31 = BytesHashTranscript<<M31ExtConfig as GKRFieldConfig>::ChallengeField, SHA256hasher>;
test_raw_gkr_helper::<M31ExtConfig, TM31>(&mpi_config, &mut TM31::new());
test_gkr_pcs::<M31ExtConfig, TM31, RawExpanderGKR<_, _>>(&mpi_config, 8);

type TGF2 = BytesHashTranscript<<GF2ExtConfig as GKRFieldConfig>::ChallengeField, SHA256hasher>;
test_raw_gkr_helper::<GF2ExtConfig, TGF2>(&mpi_config, &mut TGF2::new());
test_gkr_pcs::<GF2ExtConfig, TGF2, RawExpanderGKR<_, _>>(&mpi_config, 8);

type TBN254 =
BytesHashTranscript<<BN254Config as GKRFieldConfig>::ChallengeField, SHA256hasher>;
test_raw_gkr_helper::<BN254Config, TBN254>(&mpi_config, &mut TBN254::new());
test_gkr_pcs::<BN254Config, TBN254, RawExpanderGKR<_, _>>(&mpi_config, 8);

MPIConfig::finalize();
mpi_config.finalize();
}
14 changes: 10 additions & 4 deletions sumcheck/src/prover_helper/sumcheck_gkr_vanilla.rs
@@ -329,14 +329,20 @@ impl<'a, C: GKRFieldConfig> SumcheckGkrVanillaHelper<'a, C> {
#[inline]
pub(crate) fn prepare_mpi_var_vals(&mut self) {
let start = self.mpi_config.current_size();
self.mpi_config.append_local_field(&self.sp.simd_var_v_evals[0]);
self.mpi_config
.append_local_field(&self.sp.simd_var_v_evals[0]);
let end = self.mpi_config.current_size();
self.sp.mpi_var_v_evals = self.mpi_config.read_all_field_flat::<C::ChallengeField>(start, end);
self.sp.mpi_var_v_evals = self
.mpi_config
.read_all_field_flat::<C::ChallengeField>(start, end);

let start = self.mpi_config.current_size();
self.mpi_config.append_local_field(&(self.sp.simd_var_hg_evals[0] * self.sp.eq_evals_at_r_simd0[0]));
self.mpi_config
.append_local_field(&(self.sp.simd_var_hg_evals[0] * self.sp.eq_evals_at_r_simd0[0]));
let end = self.mpi_config.current_size();
self.sp.mpi_var_hg_evals = self.mpi_config.read_all_field_flat::<C::ChallengeField>(start, end);
self.sp.mpi_var_hg_evals = self
.mpi_config
.read_all_field_flat::<C::ChallengeField>(start, end);
}

#[inline]
