From 32371d5ae3d368349c951dce6a421d6db13c4700 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 30 Sep 2024 11:32:12 +0800 Subject: [PATCH] refactor sumcheck benchmark --- sumcheck/benches/devirgo_sumcheck.rs | 124 ++++++++++++++++----------- 1 file changed, 74 insertions(+), 50 deletions(-) diff --git a/sumcheck/benches/devirgo_sumcheck.rs b/sumcheck/benches/devirgo_sumcheck.rs index 4ca8875ab..c60a6d44b 100644 --- a/sumcheck/benches/devirgo_sumcheck.rs +++ b/sumcheck/benches/devirgo_sumcheck.rs @@ -1,20 +1,21 @@ #![allow(clippy::manual_memcpy)] #![allow(clippy::needless_range_loop)] -use std::sync::Arc; +use std::{array, sync::Arc}; use ark_std::test_rng; use const_env::from_env; use criterion::*; use ff_ext::{ff::Field, ExtensionField}; use itertools::Itertools; -use sumcheck::{structs::IOPProverState, util::ceil_log2}; +use sumcheck::{structs::IOPProverStateV2 as IOPProverState, util::ceil_log2}; use goldilocks::GoldilocksExt2; use multilinear_extensions::{ commutative_op_mle_pair, mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension}, - virtual_poly::VirtualPolynomial, + op_mle, + virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2 as VirtualPolynomial}, }; use transcript::Transcript; @@ -22,59 +23,82 @@ criterion_group!(benches, sumcheck_fn, devirgo_sumcheck_fn,); criterion_main!(benches); const NUM_SAMPLES: usize = 10; +const NUM_DEGREE: usize = 3; +const NV: [usize; 2] = [25, 26]; + +/// transpose 2d vector without clone +pub fn transpose(v: Vec>) -> Vec> { + assert!(!v.is_empty()); + let len = v[0].len(); + let mut iters: Vec<_> = v.into_iter().map(|n| n.into_iter()).collect(); + (0..len) + .map(|_| { + iters + .iter_mut() + .map(|n| n.next().unwrap()) + .collect::>() + }) + .collect() +} -fn prepare_input( +fn prepare_input<'a, E: ExtensionField>( max_thread_id: usize, nv: usize, -) -> (E, VirtualPolynomial, Vec>) { +) -> (E, VirtualPolynomial<'a, E>, Vec>) { let mut rng = test_rng(); let size_log2 = ceil_log2(max_thread_id); - let f1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - let g1: Arc> = - DenseMultilinearExtension::::random(nv, &mut rng).into(); - - let mut virtual_poly_1 = VirtualPolynomial::new_from_mle(f1.clone(), E::BaseField::ONE); - virtual_poly_1.mul_by_mle(g1.clone(), ::BaseField::ONE); - - let mut virtual_poly_f1: Vec> = match &f1.evaluations { - multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations - .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() + let fs: [ArcMultilinearExtension<'a, E>; NUM_DEGREE] = array::from_fn(|_| { + let mle: ArcMultilinearExtension<'a, E> = + DenseMultilinearExtension::::random(nv, &mut rng).into(); + mle + }); + + let mut virtual_poly_v1 = VirtualPolynomial::new(nv); + virtual_poly_v1.add_mle_list(fs.to_vec(), E::ONE); + + // devirgo version + let mut virtual_poly_v2: Vec>> = transpose( + fs.iter() + .map(|f| match &f.evaluations() { + multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations + .chunks((1 << nv) >> size_log2) + .map(|chunk| { + let mle: ArcMultilinearExtension<'a, E> = + DenseMultilinearExtension::::from_evaluations_vec( + nv - size_log2, + chunk.to_vec(), + ) + .into(); + mle + }) + .collect_vec(), + _ => unreachable!(), }) - .map(|mle| VirtualPolynomial::new_from_mle(mle, E::BaseField::ONE)) - .collect_vec(), - _ => unreachable!(), - }; - - let poly_g1: Vec> = match &g1.evaluations { - multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations - .chunks((1 << nv) >> size_log2) - .map(|chunk| { - DenseMultilinearExtension::::from_evaluations_vec(nv - size_log2, chunk.to_vec()) - .into() + .collect(), + ); + let mut virtual_poly_v2: Vec> = virtual_poly_v2 + .into_iter() + .map(|fs| { + let mut virtual_polynomial = VirtualPolynomial::new(fs[0].num_vars()); + virtual_polynomial.add_mle_list(fs, E::ONE); + virtual_polynomial + }) + .collect(); + + let asserted_sum = fs + .iter() + .fold(vec![E::ONE; 1 << nv], |mut acc, f| { + op_mle!(f, |f| { + (0..f.len()).zip(acc.iter_mut()).for_each(|(i, acc)| { + *acc *= f[i]; + }); + acc }) - .collect_vec(), - _ => unreachable!(), - }; - - let asserted_sum = commutative_op_mle_pair!(|f1, g1| { - (0..f1.len()) - .map(|i| f1[i] * g1[i]) - .fold(E::ZERO, |acc, item| acc + item) - }); + }) + .iter() + .sum::(); - virtual_poly_f1 - .iter_mut() - .zip(poly_g1.iter()) - .for_each(|(f1, g1)| f1.mul_by_mle(g1.clone(), E::BaseField::ONE)); - ( - asserted_sum, - virtual_poly_1, - virtual_poly_f1.try_into().unwrap(), - ) + (asserted_sum, virtual_poly_v1, virtual_poly_v2) } #[from_env] @@ -83,7 +107,7 @@ const RAYON_NUM_THREADS: usize = 8; fn sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; - for nv in [13, 14, 15, 16].into_iter() { + for nv in NV.into_iter() { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("sumcheck_nv_{}", nv)); group.sample_size(NUM_SAMPLES); @@ -121,7 +145,7 @@ fn sumcheck_fn(c: &mut Criterion) { fn devirgo_sumcheck_fn(c: &mut Criterion) { type E = GoldilocksExt2; - for nv in [13, 14, 15, 16].into_iter() { + for nv in NV.into_iter() { // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("devirgo_nv_{}", nv)); group.sample_size(NUM_SAMPLES);