From 0b47beb7d49e553c83e0d1cdbb0efd8cdd60c440 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 5 Jul 2024 11:09:21 +0800 Subject: [PATCH] optimize sumcheck algo --- gkr-graph/src/prover.rs | 1 + gkr/src/prover/phase1_output.rs | 6 +-- gkr/src/prover/phase2.rs | 20 +++++----- gkr/src/prover/phase2_input.rs | 4 +- gkr/src/utils.rs | 9 ++--- multilinear_extensions/src/mle.rs | 6 +-- rustfmt.toml | 5 +-- singer/src/lib.rs | 6 ++- sumcheck/src/prover.rs | 62 ++++++++++++++++--------------- 9 files changed, 56 insertions(+), 63 deletions(-) diff --git a/gkr-graph/src/prover.rs b/gkr-graph/src/prover.rs index 316abd05a..74cbcf0eb 100644 --- a/gkr-graph/src/prover.rs +++ b/gkr-graph/src/prover.rs @@ -94,6 +94,7 @@ impl IOPProverState { .enumerate() .for_each(|(wire_id, (pred_type, point_and_eval))| match pred_type { PredType::Source => { + // sanity check for input poly evaluation if cfg!(debug_assertions) { let input_layer_poly = witness.witness_in_ref()[wire_id] .instances diff --git a/gkr/src/prover/phase1_output.rs b/gkr/src/prover/phase1_output.rs index 385e49abf..3dbce07d2 100644 --- a/gkr/src/prover/phase1_output.rs +++ b/gkr/src/prover/phase1_output.rs @@ -3,14 +3,13 @@ use ff::Field; use ff_ext::ExtensionField; use itertools::{chain, izip, Itertools}; use multilinear_extensions::{ - mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, FieldType}, + mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension}, virtual_poly::{build_eq_x_r_vec, VirtualPolynomial}, }; use std::{iter, mem, sync::Arc}; use transcript::Transcript; use crate::{ - circuit::EvaluateConstant, izip_parallizable, prover::SumcheckState, structs::{Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval}, @@ -22,8 +21,7 @@ use rayon::iter::{IndexedParallelIterator, ParallelIterator}; // Prove the items copied from the output layer to the output witness for data parallel circuits. // \sum_j( \alpha^j * subset[i][j](rt_j || ry_j) ) -// = \sum_y( \sum_j( \alpha^j (eq or copy_to[j] or assert_subset_eq)(ry_j, y) \sum_t( eq(rt_j, -// t) * layers[i](t || y) ) ) ) +// = \sum_y( \sum_j( \alpha^j (eq or copy_to[j] or assert_subset_eq)(ry_j, y) \sum_t( eq(rt_j, t) * layers[i](t || y) ) ) ) impl IOPProverState { /// Sumcheck 1: sigma = \sum_y( \sum_j f1^{(j)}(y) * g1^{(j)}(y) ) /// sigma = \sum_j( \alpha^j * wit_out_eval[j](rt_j || ry_j) ) diff --git a/gkr/src/prover/phase2.rs b/gkr/src/prover/phase2.rs index be8b60be0..a8f786039 100644 --- a/gkr/src/prover/phase2.rs +++ b/gkr/src/prover/phase2.rs @@ -45,8 +45,8 @@ macro_rules! prepare_stepx_g_fn { // The number of terms depends on the gate. // Here is an example of degree 3: // layers[i](rt || ry) = \sum_{s1}( \sum_{s2}( \sum_{s3}( \sum_{x1}( \sum_{x2}( \sum_{x3}( -// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s1 || x1) * layers[i + 1](s2 || x2) -// * layers[i + 1](s3 || x3) ) ) ) ) ) ) + sum_s1( sum_s2( sum_{x1}( sum_{x2}( +// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s1 || x1) * layers[i + 1](s2 || x2) * layers[i + 1](s3 || x3) +// ) ) ) ) ) ) + sum_s1( sum_s2( sum_{x1}( sum_{x2}( // eq(rt, s1, s2) * mul2(ry, x1, x2) * layers[i + 1](s1 || x1) * layers[i + 1](s2 || x2) // ) ) ) ) + \sum_{s1}( \sum_{x1}( // eq(rt, s1) * add(ry, x1) * layers[i + 1](s1 || x1) @@ -54,17 +54,16 @@ macro_rules! 
prepare_stepx_g_fn { // \sum_j eq(rt, s1) paste_from[j](ry, x1) * subset[j][i](s1 || x1) // ) ) + add_const(ry) impl IOPProverState { - /// Sumcheck 1: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * - /// g1'_j(s1 || x1) sigma = layers[i](rt || ry) - add_const(ry), + /// Sumcheck 1: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * g1'_j(s1 || x1) + /// sigma = layers[i](rt || ry) - add_const(ry), /// f1(s1 || x1) = layers[i + 1](s1 || x1) /// g1(s1 || x1) = \sum_{s2}( \sum_{s3}( \sum_{x2}( \sum_{x3}( - /// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i + - /// 1](s3 || x3) ) ) ) ) + \sum_{s2}( \sum_{x2}( + /// eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i + 1](s3 || x3) + /// ) ) ) ) + \sum_{s2}( \sum_{x2}( /// eq(rt, s1, s2) * mul2(ry, x1, x2) * layers[i + 1](s2 || x2) /// ) ) + eq(rt, s1) * add(ry, x1) /// f1'^{(j)}(s1 || x1) = subset[j][i](s1 || x1) /// g1'^{(j)}(s1 || x1) = eq(rt, s1) paste_from[j](ry, x1) - /// s1 || x1 || 0, s1 || x1 || 1 #[tracing::instrument(skip_all, name = "build_phase2_step1_sumcheck_poly")] pub(super) fn build_phase2_step1_sumcheck_poly( eq: &[Vec; 1], @@ -106,8 +105,8 @@ impl IOPProverState { let f1 = phase2_next_layer_polys_v2.clone(); // g1(s1 || x1) = \sum_{s2}( \sum_{s3}( \sum_{x2}( \sum_{x3}( - // eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i + - // 1](s3 || x3) ) ) ) ) + \sum_{s2}( \sum_{x2}( + // eq(rt, s1, s2, s3) * mul3(ry, x1, x2, x3) * layers[i + 1](s2 || x2) * layers[i + 1](s3 || x3) + // ) ) ) ) + \sum_{s2}( \sum_{x2}( // eq(rt, s1, s2) * mul2(ry, x1, x2) * layers[i + 1](s2 || x2) // ) ) + eq(rt, s1) * add(ry, x1) let mut g1 = vec![E::ZERO; 1 << f1.num_vars]; @@ -177,8 +176,7 @@ impl IOPProverState { Vec>, ) = ([vec![f1], f1_j].concat(), [vec![g1], g1_j].concat()); - // sumcheck: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * - // g1'_j(s1 || x1) + // sumcheck: sigma = \sum_{s1 || x1} f1(s1 || x1) * g1(s1 || x1) + \sum_j f1'_j(s1 || x1) * g1'_j(s1 || x1) let mut virtual_poly_1 = VirtualPolynomial::new(f[0].num_vars); for (f, g) in f.into_iter().zip(g.into_iter()) { let mut tmp = VirtualPolynomial::new_from_mle(f, E::BaseField::ONE); diff --git a/gkr/src/prover/phase2_input.rs b/gkr/src/prover/phase2_input.rs index 8a5961fac..350e0c644 100644 --- a/gkr/src/prover/phase2_input.rs +++ b/gkr/src/prover/phase2_input.rs @@ -16,7 +16,6 @@ use crate::{ izip_parallizable, prover::SumcheckState, structs::{Circuit, CircuitWitness, IOPProverState, IOPProverStepMessage, PointAndEval}, - utils::MultilinearExtensionFromVectors, }; // Prove the computation in the current layer for data parallel circuits. @@ -133,8 +132,8 @@ impl IOPProverState { .partition(|(i, _)| i % 2 == 0); let eval_values_f = f_vec .into_iter() - .map(|(_, f)| f) .take(wits_in.len()) + .map(|(_, f)| f) .collect_vec(); self.to_next_phase_point_and_evals = izip!(paste_from_wit_in.iter(), eval_values_f.iter()) @@ -151,7 +150,6 @@ impl IOPProverState { PointAndEval::new_from_ref(&point, &wit_in_eval) }) .collect_vec(); - self.to_next_step_point = [&eval_point, hi_point].concat(); end_timer!(timer); diff --git a/gkr/src/utils.rs b/gkr/src/utils.rs index f9436e6e7..2670c68f6 100644 --- a/gkr/src/utils.rs +++ b/gkr/src/utils.rs @@ -20,8 +20,7 @@ pub fn i64_to_field(x: i64) -> F { /// This is to compute a segment indicator. 
Specifically, it is an MLE of the /// following vector: /// segment_{\mathbf{x}} -/// = \sum_{\mathbf{b}=min_idx + 1}^{2^n - 1} \prod_{i=0}^{n-1} (x_i b_i + (1 - x_i)(1 - -/// b_i)) +/// = \sum_{\mathbf{b}=min_idx + 1}^{2^n - 1} \prod_{i=0}^{n-1} (x_i b_i + (1 - x_i)(1 - b_i)) pub(crate) fn segment_eval_greater_than(min_idx: usize, a: &[E]) -> E { let running_product2 = { let mut running_product = vec![E::ZERO; a.len() + 1]; @@ -51,8 +50,7 @@ pub(crate) fn segment_eval_greater_than(min_idx: usize, a: &[ /// This is to compute a variant of eq(\mathbf{x}, \mathbf{y}) for indices in /// (min_idx, 2^n]. Specifically, it is an MLE of the following vector: /// partial_eq_{\mathbf{x}}(\mathbf{y}) -/// = \sum_{\mathbf{b}=min_idx + 1}^{2^n - 1} \prod_{i=0}^{n-1} (x_i y_i b_i + (1 - x_i)(1 - -/// y_i)(1 - b_i)) +/// = \sum_{\mathbf{b}=min_idx + 1}^{2^n - 1} \prod_{i=0}^{n-1} (x_i y_i b_i + (1 - x_i)(1 - y_i)(1 - b_i)) #[allow(dead_code)] pub(crate) fn eq_eval_greater_than(min_idx: usize, a: &[F], b: &[F]) -> F { assert!(a.len() >= b.len()); @@ -99,8 +97,7 @@ pub(crate) fn eq_eval_greater_than(min_idx: usize, a: &[F], b: &[ /// This is to compute a variant of eq(\mathbf{x}, \mathbf{y}) for indices in /// [0, max_idx]. Specifically, it is an MLE of the following vector: /// partial_eq_{\mathbf{x}}(\mathbf{y}) -/// = \sum_{\mathbf{b}=0}^{max_idx} \prod_{i=0}^{n-1} (x_i y_i b_i + (1 - x_i)(1 - y_i)(1 - -/// b_i)) +/// = \sum_{\mathbf{b}=0}^{max_idx} \prod_{i=0}^{n-1} (x_i y_i b_i + (1 - x_i)(1 - y_i)(1 - b_i)) pub(crate) fn eq_eval_less_or_equal_than(max_idx: usize, a: &[E], b: &[E]) -> E { assert!(a.len() >= b.len()); // Compute running product of ( x_i y_i + (1 - x_i)(1 - y_i) )_{0 <= i <= n} diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 018b782d2..1cefa8269 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -147,8 +147,7 @@ impl DenseMultilinearExtension { let nv = self.num_vars; // evaluate single variable of partial point from left to right for (i, point) in partial_point.iter().enumerate() { - // override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1, - // b2,..bt, 1] in parallel + // override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1,b2,..bt, 1] in parallel match &mut self.evaluations { FieldType::Base(evaluations) => { let evaluations_ext = evaluations @@ -445,8 +444,7 @@ impl DenseMultilinearExtension { // evaluate single variable of partial point from left to right for (i, point) in partial_point.iter().enumerate() { let max_log2_size = nv - i; - // override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1, - // b2,..bt, 1] in parallel + // override buf[b1, b2,..bt, 0] = (1-point) * buf[b1, b2,..bt, 0] + point * buf[b1,b2,..bt, 1] in parallel match &mut self.evaluations { FieldType::Base(evaluations) => { let evaluations_ext = evaluations diff --git a/rustfmt.toml b/rustfmt.toml index 1b31b7d5f..835c6b277 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,10 +1,9 @@ edition = "2021" -comment_width = 100 +wrap_comments = false +comment_width = 300 imports_granularity = "Crate" max_width = 100 newline_style = "Unix" normalize_comments = true reorder_imports = true -wrap_comments = true - diff --git a/singer/src/lib.rs b/singer/src/lib.rs index 69d07cd8b..aa829c07f 100644 --- a/singer/src/lib.rs +++ b/singer/src/lib.rs @@ -23,8 +23,10 @@ mod utils; // Process sketch: // 1. 
Construct instruction circuits and circuit gadgets => circuit gadgets -// 2. (bytecode + input) => Run revm interpreter, generate all wires in 2.1 phase 0 wire in + -// commitment 2.2 phase 1 wire in + commitment 2.3 phase 2 wire in + commitment +// 2. (bytecode + input) => Run revm interpreter, generate all wires in +// 2.1 phase 0 wire in + commitment +// 2.2 phase 1 wire in + commitment +// 2.3 phase 2 wire in + commitment // 3. (circuit gadgets + wires in) => gkr graph + gkr witness // 4. (gkr graph + gkr witness) => (gkr proof + point) // 5. (commitments + point) => pcs proof diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 5af55e749..d32423ee7 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -1,4 +1,4 @@ -use std::{mem, sync::Arc}; +use std::{array, mem, sync::Arc}; use ark_std::{end_timer, start_timer}; use crossbeam_channel::bounded; @@ -8,8 +8,7 @@ use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator}, prelude::{IntoParallelIterator, ParallelIterator}, }; -use transcript::Challenge; -use transcript::{Transcript, TranscriptSyncronized}; +use transcript::{Challenge, Transcript, TranscriptSyncronized}; #[cfg(feature = "non_pow2_rayon_thread")] use crate::local_thread_pool::{create_local_pool_once, LOCAL_THREAD_POOL}; @@ -26,7 +25,8 @@ use crate::{ impl IOPProverState { /// Given a virtual polynomial, generate an IOP proof. /// multi-threads model follow https://arxiv.org/pdf/2210.00264#page=8 "distributed sumcheck" - /// This is experiment features. It's preferable that we move parallel level up more to "bould_poly" so it can be more isolation + /// This is experiment features. It's preferable that we move parallel level up more to + /// "bould_poly" so it can be more isolation #[tracing::instrument(skip_all, name = "sumcheck::prove_batch_polys")] pub fn prove_batch_polys( max_thread_id: usize, @@ -72,7 +72,8 @@ impl IOPProverState { }) .collect::>(); - // spawn extra #(max_thread_id - 1) work threads, whereas the main-thread be the last work thread + // spawn extra #(max_thread_id - 1) work threads, whereas the main-thread be the last work + // thread for thread_id in 0..(max_thread_id - 1) { let mut prover_state = Self::prover_init_with_extrapolation_aux( mem::take(&mut polys[thread_id]), @@ -357,8 +358,9 @@ impl IOPProverState { self.poly .flattened_ml_extensions .iter_mut() - // benchmark result indicate make_mut achieve better performange than get_mut, which can be +5% overhead - // rust docs doen't explain the reason + // benchmark result indicate make_mut achieve better performange than get_mut, + // which can be +5% overhead rust docs doen't explain the + // reason .map(Arc::make_mut) .for_each(|f| { f.fix_variables_in_place(&[r.elements]); @@ -382,16 +384,16 @@ impl IOPProverState { 1 => { let f = &self.poly.flattened_ml_extensions[products[0]]; op_mle! 
{ - |f| (0..f.len()) - .into_iter() - .step_by(2) - .map(|b| { - AdditiveArray([ - f[b], - f[b + 1] - ]) - }) - .sum::>(), + |f| { + (0..f.len()) + .into_iter() + .step_by(2) + .fold(AdditiveArray::(array::from_fn(|_| 0.into())), |mut acc, b| { + acc.0[0] += f[b]; + acc.0[1] += f[b+1]; + acc + }) + }, |sum| AdditiveArray(sum.0.map(E::from)) } .to_vec() @@ -402,17 +404,16 @@ impl IOPProverState { &self.poly.flattened_ml_extensions[products[1]], ); commutative_op_mle_pair!( - |f, g| (0..f.len()) - .into_iter() - .step_by(2) - .map(|b| { - AdditiveArray([ - f[b] * g[b], - f[b + 1] * g[b + 1], - (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]), - ]) - }) - .sum::>(), + |f, g| (0..f.len()).into_iter().step_by(2).fold( + AdditiveArray::(array::from_fn(|_| 0.into())), + |mut acc, b| { + acc.0[0] += f[b] * g[b]; + acc.0[1] += f[b + 1] * g[b + 1]; + acc.0[2] += + (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]); + acc + } + ), |sum| AdditiveArray(sum.0.map(E::from)) ) .to_vec() @@ -623,8 +624,9 @@ impl IOPProverState { self.poly .flattened_ml_extensions .par_iter_mut() - // benchmark result indicate make_mut achieve better performange than get_mut, which can be +5% overhead - // rust docs doen't explain the reason + // benchmark result indicate make_mut achieve better performange than get_mut, + // which can be +5% overhead rust docs doen't explain the + // reason .map(Arc::make_mut) .for_each(|f| { f.fix_variables_in_place_parallel(&[r.elements]);
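
The comment in multilinear_extensions/src/mle.rs above ("override buf[b1, b2,..bt, 0] = (1-point) * buf[b1,b2,..bt, 0] + point * buf[b1,b2,..bt, 1] in parallel") describes folding one variable of the evaluation table per round. Below is a minimal standalone sketch of that single step, with plain f64 standing in for field elements; the function name is hypothetical and this is not the crate's API.

// Illustrative only: fix the lowest variable of an MLE evaluation table at r, in place.
// Adjacent pairs (buf[2b], buf[2b+1]) differ only in that variable, so the new entry is
// (1 - r) * buf[2b] + r * buf[2b+1], and the table shrinks to half its length.
fn fix_low_variable_in_place(evals: &mut Vec<f64>, r: f64) {
    debug_assert_eq!(evals.len() % 2, 0);
    let half = evals.len() / 2;
    for b in 0..half {
        evals[b] = (1.0 - r) * evals[2 * b] + r * evals[2 * b + 1];
    }
    evals.truncate(half);
}

fn main() {
    // evaluations of f over {0,1}^2, indexed with the fixed variable as the low bit
    let mut evals = vec![1.0, 3.0, 5.0, 7.0];
    fix_low_variable_in_place(&mut evals, 0.5);
    println!("{:?}", evals); // [2.0, 6.0]
}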
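
The degree-2 branch in sumcheck/src/prover.rs now accumulates the round message's evaluations at X = 0, 1, 2 in a single fold instead of building a temporary AdditiveArray per pair and summing afterwards. Below is a minimal standalone sketch of that accumulation pattern, with i64 standing in for the extension field and a hypothetical function name; it is not the crate's op_mle!/AdditiveArray machinery.

// Illustrative only: evaluations of the degree-2 round polynomial at X = 0, 1, 2,
// accumulated in one pass over adjacent evaluation pairs of f and g.
fn round_message_deg2(f: &[i64], g: &[i64]) -> [i64; 3] {
    assert_eq!(f.len(), g.len());
    assert_eq!(f.len() % 2, 0);
    (0..f.len()).step_by(2).fold([0i64; 3], |mut acc, b| {
        acc[0] += f[b] * g[b];         // X = 0: the "low" evaluations
        acc[1] += f[b + 1] * g[b + 1]; // X = 1: the "high" evaluations
        // X = 2 by linear extension per pair: f(2) = 2 * f(1) - f(0), likewise for g
        acc[2] += (f[b + 1] + f[b + 1] - f[b]) * (g[b + 1] + g[b + 1] - g[b]);
        acc
    })
}

fn main() {
    let (f, g) = ([3i64, 5, 2, 7], [1i64, 4, 6, 2]);
    println!("{:?}", round_message_deg2(&f, &g)); // [15, 34, 25]
}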