Skip to content

Commit

Permalink
optimize sumcheck algo
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 committed Jul 8, 2024
1 parent 220e0d4 commit 0b47beb
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 63 deletions.
1 change: 1 addition & 0 deletions gkr-graph/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ impl<E: ExtensionField> IOPProverState<E> {
.enumerate()
.for_each(|(wire_id, (pred_type, point_and_eval))| match pred_type {
PredType::Source => {
// sanity check for input poly evaluation
if cfg!(debug_assertions) {
let input_layer_poly = witness.witness_in_ref()[wire_id]
.instances
Expand Down
6 changes: 2 additions & 4 deletions gkr/src/prover/phase1_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@ use ff::Field;
use ff_ext::ExtensionField;
use itertools::{chain, izip, Itertools};
use multilinear_extensions::{
mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, FieldType},
mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension},
virtual_poly::{build_eq_x_r_vec, VirtualPolynomial},
};
use std::{iter, mem, sync::Arc};
use transcript::Transcript;

use crate::{
circuit::EvaluateConstant,
izip_parallizable,
prover::SumcheckState,
structs::{Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval},
Expand All @@ -22,8 +21,7 @@ use rayon::iter::{IndexedParallelIterator, ParallelIterator};

// Prove the items copied from the output layer to the output witness for data parallel circuits.
// \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) )
// = \sum_y( \sum_j( \alpha^j (eq or copy_to[j] or assert_subset_eq)(ry_j, y) \sum_t( eq(rt_j,
// t) * layers[i](t || y) ) ) )
// = \sum_y( \sum_j( \alpha^j (eq or copy_to[j] or assert_subset_eq)(ry_j, y) \sum_t( eq(rt_j, t) * layers[i](t || y) ) ) )
impl<E: ExtensionField> IOPProverState<E> {
/// Sumcheck 1: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y) )
/// sigma = \sum_j( \alpha^j * wit_out_eval[j](rt_j || ry_j) )
Expand Down
20 changes: 9 additions & 11 deletions gkr/src/prover/phase2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,25 @@ macro_rules! prepare_stepx_g_fn {
// The number of terms depends on the gate.
// Here is an example of degree 3:
// layers[i](rt || ry) = \sum_{s1}( \sum_{s2}( \sum_{s3}( \sum_{x1}( \sum_{x2}( \sum_{x3}(
// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s1 || x1) * layers[i + 1](s2 || x2)
// * layers[i + 1](s3 || x3) ) ) ) ) ) ) + \sum_{s1}( \sum_{s2}( \sum_{x1}( \sum_{x2}(
// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s1 || x1) * layers[i + 1](s2 || x2) * layers[i + 1](s3 || x3)
// ) ) ) ) ) ) + \sum_{s1}( \sum_{s2}( \sum_{x1}( \sum_{x2}(
// eq(rt, s1, s2) * mul2(ry, x1, x2) * layers[i + 1](s1 || x1) * layers[i + 1](s2 || x2)
// ) ) ) ) + \sum_{s1}( \sum_{x1}(
// eq(rt, s1) * add(ry, x1) * layers[i + 1](s1 || x1)
// ) ) + \sum_{s1}( \sum_{x1}(
// \sum_j eq(rt, s1) paste_from[j](ry, x1) * subset[j][i](s1 || x1)
// ) ) + add_const(ry)
impl<E: ExtensionField> IOPProverState<E> {
/// Sumcheck 1: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) *
/// g1'_j(s1 || x1) sigma = layers[i](rt || ry) - add_const(ry),
/// Sumcheck 1: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * g1'_j(s1 || x1)
/// sigma = layers[i](rt || ry) - add_const(ry),
/// f1(s1 || x1) = layers[i + 1](s1 || x1)
/// g1(s1 || x1) = \sum_{s2}( \sum_{s3}( \sum_{x2}( \sum_{x3}(
/// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i +
/// 1](s3 || x3) ) ) ) ) + \sum_{s2}( \sum_{x2}(
/// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i + 1](s3 || x3)
/// ) ) ) ) + \sum_{s2}( \sum_{x2}(
/// eq(rt, s1, s2) * mul2(ry, x1, x2) * layers[i + 1](s2 || x2)
/// ) ) + eq(rt, s1) * add(ry, x1)
/// f1'^{(j)}(s1 || x1) = subset[j][i](s1 || x1)
/// g1'^{(j)}(s1 || x1) = eq(rt, s1) paste_from[j](ry, x1)
/// s1 || x1 || 0, s1 || x1 || 1
#[tracing::instrument(skip_all, name = "build_phase2_step1_sumcheck_poly")]
pub(super) fn build_phase2_step1_sumcheck_poly(
eq: &[Vec<E>; 1],
Expand Down Expand Up @@ -106,8 +105,8 @@ impl<E: ExtensionField> IOPProverState<E> {
let f1 = phase2_next_layer_polys_v2.clone();

// g1(s1 || x1) = \sum_{s2}( \sum_{s3}( \sum_{x2}( \sum_{x3}(
// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i +
// 1](s3 || x3) ) ) ) ) + \sum_{s2}( \sum_{x2}(
// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i + 1](s3 || x3)
// ) ) ) ) + \sum_{s2}( \sum_{x2}(
// eq(rt, s1, s2) * mul2(ry, x1, x2) * layers[i + 1](s2 || x2)
// ) ) + eq(rt, s1) * add(ry, x1)
let mut g1 = vec![E::ZERO; 1 << f1.num_vars];
Expand Down Expand Up @@ -177,8 +176,7 @@ impl<E: ExtensionField> IOPProverState<E> {
Vec<ArcDenseMultilinearExtension<E>>,
) = ([vec![f1], f1_j].concat(), [vec![g1], g1_j].concat());

// sumcheck: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) *
// g1'_j(s1 || x1)
// sumcheck: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * g1'_j(s1 || x1)
let mut virtual_poly_1 = VirtualPolynomial::new(f[0].num_vars);
for (f, g) in f.into_iter().zip(g.into_iter()) {
let mut tmp = VirtualPolynomial::new_from_mle(f, E::BaseField::ONE);
Expand Down
4 changes: 1 addition & 3 deletions gkr/src/prover/phase2_input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use crate::{
izip_parallizable,
prover::SumcheckState,
structs::{Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval},
utils::MultilinearExtensionFromVectors,
};

// Prove the computation in the current layer for data parallel circuits.
Expand Down Expand Up @@ -133,8 +132,8 @@ impl<E: ExtensionField> IOPProverState<E> {
.partition(|(i, _)| i % 2 == 0);
let eval_values_f = f_vec
.into_iter()
.map(|(_, f)| f)
.take(wits_in.len())
.map(|(_, f)| f)
.collect_vec();

self.to_next_phase_point_and_evals = izip!(paste_from_wit_in.iter(), eval_values_f.iter())
Expand All @@ -151,7 +150,6 @@ impl<E: ExtensionField> IOPProverState<E> {
PointAndEval::new_from_ref(&point, &wit_in_eval)
})
.collect_vec();

self.to_next_step_point = [&eval_point, hi_point].concat();

end_timer!(timer);
Expand Down
9 changes: 3 additions & 6 deletions gkr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ pub fn i64_to_field<F: SmallField>(x: i64) -> F {
/// This is to compute a segment indicator. Specifically, it is an MLE of the
/// following vector:
/// segment_{\mathbf{x}}
/// = \sum_{\mathbf{b}=min_idx + 1}^{2^n - 1} \prod_{i=0}^{n-1} (x_i b_i + (1 - x_i)(1 -
/// b_i))
/// = \sum_{\mathbf{b}=min_idx + 1}^{2^n - 1} \prod_{i=0}^{n-1} (x_i b_i + (1 - x_i)(1 - b_i))
pub(crate) fn segment_eval_greater_than<E: ExtensionField>(min_idx: usize, a: &[E]) -> E {
let running_product2 = {
let mut running_product = vec![E::ZERO; a.len() + 1];
Expand Down Expand Up @@ -51,8 +50,7 @@ pub(crate) fn segment_eval_greater_than<E: ExtensionField>(min_idx: usize, a: &[
/// This is to compute a variant of eq(\mathbf{x}, \mathbf{y}) for indices in
/// (min_idx, 2^n]. Specifically, it is an MLE of the following vector:
/// partial_eq_{\mathbf{x}}(\mathbf{y})
/// = \sum_{\mathbf{b}=min_idx + 1}^{2^n - 1} \prod_{i=0}^{n-1} (x_i y_i b_i + (1 - x_i)(1 -
/// y_i)(1 - b_i))
/// = \sum_{\mathbf{b}=min_idx + 1}^{2^n - 1} \prod_{i=0}^{n-1} (x_i y_i b_i + (1 - x_i)(1 - y_i)(1 - b_i))
#[allow(dead_code)]
pub(crate) fn eq_eval_greater_than<F: SmallField>(min_idx: usize, a: &[F], b: &[F]) -> F {
assert!(a.len() >= b.len());
Expand Down Expand Up @@ -99,8 +97,7 @@ pub(crate) fn eq_eval_greater_than<F: SmallField>(min_idx: usize, a: &[F], b: &[
/// This is to compute a variant of eq(\mathbf{x}, \mathbf{y}) for indices in
/// [0, max_idx]. Specifically, it is an MLE of the following vector:
/// partial_eq_{\mathbf{x}}(\mathbf{y})
/// = \sum_{\mathbf{b}=0}^{max_idx} \prod_{i=0}^{n-1} (x_i y_i b_i + (1 - x_i)(1 - y_i)(1 -
/// b_i))
/// = \sum_{\mathbf{b}=0}^{max_idx} \prod_{i=0}^{n-1} (x_i y_i b_i + (1 - x_i)(1 - y_i)(1 - b_i))
pub(crate) fn eq_eval_less_or_equal_than<E: ExtensionField>(max_idx: usize, a: &[E], b: &[E]) -> E {
assert!(a.len() >= b.len());
// Compute running product of ( x_i y_i + (1 - x_i)(1 - y_i) )_{0 <= i <= n}
Expand Down
6 changes: 2 additions & 4 deletions multilinear_extensions/src/mle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,7 @@ impl<E: ExtensionField> DenseMultilinearExtension<E> {
let nv = self.num_vars;
// evaluate single variable of partial point from left to right
for (i, point) in partial_point.iter().enumerate() {

Check warning on line 149 in multilinear_extensions/src/mle.rs

View workflow job for this annotation

GitHub Actions / Run Tests

unused variable: `i`
// override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1,
// b2,..bt, 1] in parallel
// override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1,b2,..bt, 1] in parallel
match &mut self.evaluations {
FieldType::Base(evaluations) => {
let evaluations_ext = evaluations
Expand Down Expand Up @@ -445,8 +444,7 @@ impl<E: ExtensionField> DenseMultilinearExtension<E> {
// evaluate single variable of partial point from left to right
for (i, point) in partial_point.iter().enumerate() {
let max_log2_size = nv - i;
// override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1,
// b2,..bt, 1] in parallel
// override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1,b2,..bt, 1] in parallel
match &mut self.evaluations {
FieldType::Base(evaluations) => {
let evaluations_ext = evaluations
Expand Down
5 changes: 2 additions & 3 deletions rustfmt.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
edition = "2021"

comment_width = 100
wrap_comments = false
comment_width = 300
imports_granularity = "Crate"
max_width = 100
newline_style = "Unix"
normalize_comments = true
reorder_imports = true
wrap_comments = true

6 changes: 4 additions & 2 deletions singer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ mod utils;

// Process sketch:
// 1. Construct instruction circuits and circuit gadgets => circuit gadgets
// 2. (bytecode + input) => Run revm interpreter, generate all wires in 2.1 phase 0 wire in +
// commitment 2.2 phase 1 wire in + commitment 2.3 phase 2 wire in + commitment
// 2. (bytecode + input) => Run revm interpreter, generate all wires in
// 2.1 phase 0 wire in + commitment
// 2.2 phase 1 wire in + commitment
// 2.3 phase 2 wire in + commitment
// 3. (circuit gadgets + wires in) => gkr graph + gkr witness
// 4. (gkr graph + gkr witness) => (gkr proof + point)
// 5. (commitments + point) => pcs proof
Expand Down
62 changes: 32 additions & 30 deletions sumcheck/src/prover.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{mem, sync::Arc};
use std::{array, mem, sync::Arc};

use ark_std::{end_timer, start_timer};
use crossbeam_channel::bounded;
Expand All @@ -8,8 +8,7 @@ use rayon::{
iter::{IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator},
prelude::{IntoParallelIterator, ParallelIterator},
};
use transcript::Challenge;
use transcript::{Transcript, TranscriptSyncronized};
use transcript::{Challenge, Transcript, TranscriptSyncronized};

#[cfg(feature = "non_pow2_rayon_thread")]
use crate::local_thread_pool::{create_local_pool_once, LOCAL_THREAD_POOL};
Expand All @@ -26,7 +25,8 @@ use crate::{
impl<E: ExtensionField> IOPProverState<E> {
/// Given a virtual polynomial, generate an IOP proof.
/// The multi-threaded model follows https://arxiv.org/pdf/2210.00264#page=8 "distributed sumcheck".
/// This is an experimental feature. It would be preferable to move the parallelism up to "bould_poly" so that it is better isolated
/// This is an experimental feature. It would be preferable to move the parallelism up to
/// "bould_poly" so that it is better isolated
#[tracing::instrument(skip_all, name = "sumcheck::prove_batch_polys")]
pub fn prove_batch_polys(
max_thread_id: usize,
Expand Down Expand Up @@ -72,7 +72,8 @@ impl<E: ExtensionField> IOPProverState<E> {
})
.collect::<Vec<_>>();

// spawn (max_thread_id - 1) extra worker threads; the main thread acts as the last worker thread
// spawn (max_thread_id - 1) extra worker threads; the main thread acts as the last worker
// thread
for thread_id in 0..(max_thread_id - 1) {
let mut prover_state = Self::prover_init_with_extrapolation_aux(
mem::take(&mut polys[thread_id]),
Expand Down Expand Up @@ -357,8 +358,9 @@ impl<E: ExtensionField> IOPProverState<E> {
self.poly
.flattened_ml_extensions
.iter_mut()
// benchmark results indicate make_mut achieves better performance than get_mut, which can be +5% overhead
// the Rust docs don't explain the reason
// benchmark results indicate make_mut achieves better performance than get_mut,
// which can be +5% overhead; the Rust docs don't explain the
// reason
.map(Arc::make_mut)
.for_each(|f| {
f.fix_variables_in_place(&[r.elements]);
Expand All @@ -382,16 +384,16 @@ impl<E: ExtensionField> IOPProverState<E> {
1 => {
let f = &self.poly.flattened_ml_extensions[products[0]];
op_mle! {
|f| (0..f.len())
.into_iter()
.step_by(2)
.map(|b| {
AdditiveArray([
f[b],
f[b + 1]
])
})
.sum::<AdditiveArray<_, 2>>(),
|f| {
(0..f.len())
.into_iter()
.step_by(2)
.fold(AdditiveArray::<E, 2>(array::from_fn(|_| 0.into())), |mut acc, b| {
acc.0[0] += f[b];
acc.0[1] += f[b+1];
acc
})
},
|sum| AdditiveArray(sum.0.map(E::from))
}
.to_vec()
Expand All @@ -402,17 +404,16 @@ impl<E: ExtensionField> IOPProverState<E> {
&self.poly.flattened_ml_extensions[products[1]],
);
commutative_op_mle_pair!(
|f, g| (0..f.len())
.into_iter()
.step_by(2)
.map(|b| {
AdditiveArray([
f[b] * g[b],
f[b + 1] * g[b + 1],
(f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]),
])
})
.sum::<AdditiveArray<_, 3>>(),
|f, g| (0..f.len()).into_iter().step_by(2).fold(
AdditiveArray::<E, 3>(array::from_fn(|_| 0.into())),
|mut acc, b| {
acc.0[0] += f[b] * g[b];
acc.0[1] += f[b + 1] * g[b + 1];
acc.0[2] +=
(f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]);
acc
}
),
|sum| AdditiveArray(sum.0.map(E::from))
)
.to_vec()
Expand Down Expand Up @@ -623,8 +624,9 @@ impl<E: ExtensionField> IOPProverState<E> {
self.poly
.flattened_ml_extensions
.par_iter_mut()
// benchmark results indicate make_mut achieves better performance than get_mut, which can be +5% overhead
// the Rust docs don't explain the reason
// benchmark results indicate make_mut achieves better performance than get_mut,
// which can be +5% overhead; the Rust docs don't explain the
// reason
.map(Arc::make_mut)
.for_each(|f| {
f.fix_variables_in_place_parallel(&[r.elements]);

Check warning on line 632 in sumcheck/src/prover.rs

View workflow job for this annotation

GitHub Actions / Run Tests

use of deprecated method `multilinear_extensions::mle::DenseMultilinearExtension::<E>::fix_variables_in_place_parallel`: deprecated parallel version due to syncronizaion overhead
Expand Down

0 comments on commit 0b47beb

Please sign in to comment.