Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor sumcheck benchmark #295

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 74 additions & 50 deletions sumcheck/benches/devirgo_sumcheck.rs
Original file line number Diff line number Diff line change
@@ -1,80 +1,104 @@
#![allow(clippy::manual_memcpy)]
#![allow(clippy::needless_range_loop)]

use std::sync::Arc;
use std::{array, sync::Arc};

use ark_std::test_rng;
use const_env::from_env;
use criterion::*;
use ff_ext::{ff::Field, ExtensionField};
use itertools::Itertools;
use sumcheck::{structs::IOPProverState, util::ceil_log2};
use sumcheck::{structs::IOPProverStateV2 as IOPProverState, util::ceil_log2};

use goldilocks::GoldilocksExt2;
use multilinear_extensions::{
commutative_op_mle_pair,
mle::{ArcDenseMultilinearExtension, DenseMultilinearExtension, MultilinearExtension},
virtual_poly::VirtualPolynomial,
op_mle,
virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2 as VirtualPolynomial},
};
use transcript::Transcript;

criterion_group!(benches, sumcheck_fn, devirgo_sumcheck_fn,);
criterion_main!(benches);

const NUM_SAMPLES: usize = 10;
const NUM_DEGREE: usize = 3;
const NV: [usize; 2] = [25, 26];

/// transpose 2d vector without clone
pub fn transpose<T>(v: Vec<Vec<T>>) -> Vec<Vec<T>> {
assert!(!v.is_empty());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it's empty, you can just return the original vector, no need to panic.

In any case, we already have multiple implementations of transpose in the code. Why don't you just re-use one of them, instead of going for what looks like a copy-and-paste job? (Feel free to re-organise the existing code, if you need to, to make one of the other implementations available to be referenced from here.)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ceno_zkvm/src/utils.rs has literally the exact same code. Did you copy and paste that?

let len = v[0].len();
let mut iters: Vec<_> = v.into_iter().map(|n| n.into_iter()).collect();
(0..len)
.map(|_| {
iters
.iter_mut()
.map(|n| n.next().unwrap())
.collect::<Vec<T>>()
})
.collect()
}

fn prepare_input<E: ExtensionField>(
fn prepare_input<'a, E: ExtensionField>(
max_thread_id: usize,
nv: usize,
) -> (E, VirtualPolynomial<E>, Vec<VirtualPolynomial<E>>) {
) -> (E, VirtualPolynomial<'a, E>, Vec<VirtualPolynomial<'a, E>>) {
let mut rng = test_rng();
let size_log2 = ceil_log2(max_thread_id);
let f1: Arc<DenseMultilinearExtension<E>> =
DenseMultilinearExtension::<E>::random(nv, &mut rng).into();
let g1: Arc<DenseMultilinearExtension<E>> =
DenseMultilinearExtension::<E>::random(nv, &mut rng).into();

let mut virtual_poly_1 = VirtualPolynomial::new_from_mle(f1.clone(), E::BaseField::ONE);
virtual_poly_1.mul_by_mle(g1.clone(), <E as ff_ext::ExtensionField>::BaseField::ONE);

let mut virtual_poly_f1: Vec<VirtualPolynomial<E>> = match &f1.evaluations {
multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations
.chunks((1 << nv) >> size_log2)
.map(|chunk| {
DenseMultilinearExtension::<E>::from_evaluations_vec(nv - size_log2, chunk.to_vec())
.into()
let fs: [ArcMultilinearExtension<'a, E>; NUM_DEGREE] = array::from_fn(|_| {
let mle: ArcMultilinearExtension<'a, E> =
DenseMultilinearExtension::<E>::random(nv, &mut rng).into();
mle
});

let mut virtual_poly_v1 = VirtualPolynomial::new(nv);
virtual_poly_v1.add_mle_list(fs.to_vec(), E::ONE);

// devirgo version
let mut virtual_poly_v2: Vec<Vec<ArcMultilinearExtension<'a, E>>> = transpose(
fs.iter()
.map(|f| match &f.evaluations() {
multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations
.chunks((1 << nv) >> size_log2)
.map(|chunk| {
let mle: ArcMultilinearExtension<'a, E> =
DenseMultilinearExtension::<E>::from_evaluations_vec(
nv - size_log2,
chunk.to_vec(),
)
.into();
mle
})
.collect_vec(),
_ => unreachable!(),
})
.map(|mle| VirtualPolynomial::new_from_mle(mle, E::BaseField::ONE))
.collect_vec(),
_ => unreachable!(),
};

let poly_g1: Vec<ArcDenseMultilinearExtension<E>> = match &g1.evaluations {
multilinear_extensions::mle::FieldType::Base(evaluations) => evaluations
.chunks((1 << nv) >> size_log2)
.map(|chunk| {
DenseMultilinearExtension::<E>::from_evaluations_vec(nv - size_log2, chunk.to_vec())
.into()
.collect(),
);
let mut virtual_poly_v2: Vec<VirtualPolynomial<E>> = virtual_poly_v2
.into_iter()
.map(|fs| {
let mut virtual_polynomial = VirtualPolynomial::new(fs[0].num_vars());
virtual_polynomial.add_mle_list(fs, E::ONE);
virtual_polynomial
})
.collect();

let asserted_sum = fs
.iter()
.fold(vec![E::ONE; 1 << nv], |mut acc, f| {
op_mle!(f, |f| {
(0..f.len()).zip(acc.iter_mut()).for_each(|(i, acc)| {
*acc *= f[i];
});
acc
})
.collect_vec(),
_ => unreachable!(),
};

let asserted_sum = commutative_op_mle_pair!(|f1, g1| {
(0..f1.len())
.map(|i| f1[i] * g1[i])
.fold(E::ZERO, |acc, item| acc + item)
});
})
.iter()
.sum::<E>();

virtual_poly_f1
.iter_mut()
.zip(poly_g1.iter())
.for_each(|(f1, g1)| f1.mul_by_mle(g1.clone(), E::BaseField::ONE));
(
asserted_sum,
virtual_poly_1,
virtual_poly_f1.try_into().unwrap(),
)
(asserted_sum, virtual_poly_v1, virtual_poly_v2)
}

#[from_env]
Expand All @@ -83,7 +107,7 @@ const RAYON_NUM_THREADS: usize = 8;
fn sumcheck_fn(c: &mut Criterion) {
type E = GoldilocksExt2;

for nv in [13, 14, 15, 16].into_iter() {
for nv in NV.into_iter() {
// expand more input size once runtime is acceptable
let mut group = c.benchmark_group(format!("sumcheck_nv_{}", nv));
group.sample_size(NUM_SAMPLES);
Expand Down Expand Up @@ -121,7 +145,7 @@ fn sumcheck_fn(c: &mut Criterion) {
fn devirgo_sumcheck_fn(c: &mut Criterion) {
type E = GoldilocksExt2;

for nv in [13, 14, 15, 16].into_iter() {
for nv in NV.into_iter() {
// expand more input size once runtime is acceptable
let mut group = c.benchmark_group(format!("devirgo_nv_{}", nv));
group.sample_size(NUM_SAMPLES);
Expand Down
Loading