Skip to content

Commit

Permalink
sumcheck protocol support mixed num_vars monomial form (#235)
Browse files Browse the repository at this point in the history
### Goal
To make sumcheck protocol support different num_vars, aiming for
- [x] minimal change to sumcheck protocol, make verifier remain the
same. no extra meta data passed from prover, and what prover have just
mle and it's eval size

Besides this PR also remove some parallism in verifier since it's
unnecessary for relative low cost.

### design rationale
(Also comments in codebase for reference)
To deal with different num_vars, we exploit a fact that for each product
which num_vars < max_num_vars,
for it evaluation value we need to times 2^(max_num_vars - num_vars)
E.g. Giving multivariate poly $f(X) = f_1(X1) + f_2(X), X1 \in {F}^{n'},
X \in {F}^{n}, |X1| := n', |X| = n, n' <= n$
For i round univariate poly, $f^i(x)$
$f^i[0] = \sum_b f(r, 0, b), b \in [0, 1]^{n-i-1}, r \in {F}^{n-i-1}$
chanllenge get from prev rounds
= $\sum_b f_1(r, 0, b1) + f_2(r, 0, b), |b| >= |b1|, |b| - |b1| = n -
n'$
= $2^{(|b| - |b1|)} * \sum_{b1} f_1(r, 0, b1) + \sum_b f_2(r, 0, b)$
same applied on f^i[1]
It imply that, for every evals in f_1, to compute univariate poly, we
just need to times a factor 2^(|b| - |b1|) for it evaluation value


### benchmark
benchmark with ceno_zkvm `riscv_add`, and gkr `keccak` both remain the
same and no impact.
You might see some redundancy coding style, but this is for retain the
best performance. I tried other variants and it impact benchmark results

### scope
Related to #109 #210 ....
To address #126 #127 
This enhance protocol features potiential can be used for `range
table-circuit`, `init/final-memory`, `cpu-init/cpu-final halt` to make
selector sumcheck support batching different num_instance witin.
  • Loading branch information
hero78119 authored Sep 19, 2024
1 parent 20cb152 commit 43af034
Show file tree
Hide file tree
Showing 10 changed files with 336 additions and 108 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

105 changes: 103 additions & 2 deletions ceno_zkvm/src/virtual_polys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,19 +168,28 @@ impl<'a, E: ExtensionField> VirtualPolynomials<'a, E> {
#[cfg(test)]
mod tests {

use ark_std::test_rng;
use ff_ext::ExtensionField;
use goldilocks::{Goldilocks, GoldilocksExt2};
use itertools::Itertools;
use multilinear_extensions::{mle::IntoMLE, virtual_poly_v2::ArcMultilinearExtension};
use multilinear_extensions::{
mle::IntoMLE,
virtual_poly::VPAuxInfo,
virtual_poly_v2::{ArcMultilinearExtension, VirtualPolynomialV2},
};
use sumcheck::structs::{IOPProverStateV2, IOPVerifierState};
use transcript::Transcript;

use crate::{
circuit_builder::{CircuitBuilder, ConstraintSystem},
expression::{Expression, ToExpr},
virtual_polys::VirtualPolynomials,
};
use ff::Field;
type E = GoldilocksExt2;

#[test]
fn test_add_mle_list_by_expr() {
type E = GoldilocksExt2;
let mut cs = ConstraintSystem::new(|| "test_root");
let mut cb = CircuitBuilder::<E>::new(&mut cs);
let x = cb.create_witin(|| "x").unwrap();
Expand Down Expand Up @@ -218,4 +227,96 @@ mod tests {
assert!(distrinct_zerocheck_terms_set.len() == 1);
assert!(virtual_polys.degree() == 3);
}

#[test]
fn test_sumcheck_different_degree() {
let max_num_vars = 3;
let fn_eval = |fs: &[ArcMultilinearExtension<E>]| -> E {
let base_2 = <E as ExtensionField>::BaseField::from(2);

let evals = fs.iter().fold(
vec![<E as ExtensionField>::BaseField::ONE; 1 << fs[0].num_vars()],
|mut evals, f| {
evals
.iter_mut()
.zip(f.get_base_field_vec())
.for_each(|(e, v)| {
*e *= v;
});
evals
},
);

<<E as ExtensionField>::BaseField as std::convert::Into<E>>::into(
evals.iter().sum::<<E as ExtensionField>::BaseField>()
* base_2.pow([(max_num_vars - fs[0].num_vars()) as u64]),
)
};
let num_threads = 1;
let mut transcript = Transcript::new(b"test");

let mut rng = test_rng();

let f1: [ArcMultilinearExtension<E>; 2] = std::array::from_fn(|_| {
(0..1 << (max_num_vars - 2))
.map(|_| <E as ExtensionField>::BaseField::random(&mut rng))
.collect_vec()
.into_mle()
.into()
});
let f2: [ArcMultilinearExtension<E>; 1] = std::array::from_fn(|_| {
(0..1 << (max_num_vars))
.map(|_| <E as ExtensionField>::BaseField::random(&mut rng))
.collect_vec()
.into_mle()
.into()
});
let f3: [ArcMultilinearExtension<E>; 3] = std::array::from_fn(|_| {
(0..1 << (max_num_vars - 1))
.map(|_| <E as ExtensionField>::BaseField::random(&mut rng))
.collect_vec()
.into_mle()
.into()
});

let mut virtual_polys = VirtualPolynomials::<E>::new(num_threads, max_num_vars);

virtual_polys.add_mle_list(f1.iter().collect(), E::ONE);
virtual_polys.add_mle_list(f2.iter().collect(), E::ONE);
virtual_polys.add_mle_list(f3.iter().collect(), E::ONE);

let (sumcheck_proofs, _) = IOPProverStateV2::prove_batch_polys(
num_threads,
virtual_polys.get_batched_polys(),
&mut transcript,
);

let mut transcript = Transcript::new(b"test");
let subclaim = IOPVerifierState::<E>::verify(
fn_eval(&f1) + fn_eval(&f2) + fn_eval(&f3),
&sumcheck_proofs,
&VPAuxInfo {
max_degree: 3,
num_variables: max_num_vars,
phantom: std::marker::PhantomData,
},
&mut transcript,
);

let mut verifier_poly = VirtualPolynomialV2::new(max_num_vars);
verifier_poly.add_mle_list(f1.to_vec(), E::ONE);
verifier_poly.add_mle_list(f2.to_vec(), E::ONE);
verifier_poly.add_mle_list(f3.to_vec(), E::ONE);
assert!(
verifier_poly.evaluate(
subclaim
.point
.iter()
.map(|c| c.elements)
.collect::<Vec<_>>()
.as_ref()
) == subclaim.expected_evaluation,
"wrong subclaim"
);
}
}
1 change: 1 addition & 0 deletions multilinear_extensions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ tracing = "0.1.40"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
tracing-flame = "0.2.0"
ff_ext = { path = "../ff_ext" }
itertools = "0.12.1"
ark-std.workspace = true
ff.workspace = true
goldilocks.workspace = true
Expand Down
10 changes: 3 additions & 7 deletions multilinear_extensions/src/mle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ pub trait IntoMLEs<T>: Sized {
fn into_mles(self) -> Vec<T>;
}

impl<F: Field, E: ExtensionField<BaseField = F>> IntoMLEs<DenseMultilinearExtension<E>> for Vec<Vec<F>> {
impl<F: Field, E: ExtensionField<BaseField = F>> IntoMLEs<DenseMultilinearExtension<E>>
for Vec<Vec<F>>
{
fn into_mles(self) -> Vec<DenseMultilinearExtension<E>> {
self.into_iter().map(|v| v.into_mle()).collect()
}
Expand Down Expand Up @@ -1000,12 +1002,6 @@ macro_rules! op_mle {
match &$a.evaluations() {
$crate::mle::FieldType::Base(a) => {
let $tmp_a = if let Some((start, offset)) = $a.evaluations_range() {
println!(
"op_mle start {}, offset {}, a.len {}",
start,
offset,
a.len()
);
&a[start..][..offset]
} else {
&a[..]
Expand Down
5 changes: 5 additions & 0 deletions multilinear_extensions/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,8 @@ pub fn create_uninit_vec<T: Sized>(len: usize) -> Vec<MaybeUninit<T>> {
unsafe { vec.set_len(len) };
vec
}

#[inline(always)]
pub fn largest_even_below(n: usize) -> usize {
if n % 2 == 0 { n } else { n.saturating_sub(1) }
}
53 changes: 29 additions & 24 deletions multilinear_extensions/src/virtual_poly_v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::{
};
use ark_std::{end_timer, start_timer};
use ff_ext::ExtensionField;
use itertools::Itertools;
use serde::{Deserialize, Serialize};

pub type ArcMultilinearExtension<'a, E> =
Expand Down Expand Up @@ -55,8 +56,8 @@ pub struct VirtualPolynomialV2<'a, E: ExtensionField> {
pub struct VPAuxInfo<E> {
/// max number of multiplicands in each product
pub max_degree: usize,
/// number of variables of the polynomial
pub num_variables: usize,
/// max number of variables of the polynomial
pub max_num_variables: usize,
/// Associated field
#[doc(hidden)]
pub phantom: PhantomData<E>,
Expand All @@ -69,12 +70,12 @@ impl<E: ExtensionField> AsRef<[u8]> for VPAuxInfo<E> {
}

impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> {
/// Creates an empty virtual polynomial with `num_variables`.
pub fn new(num_variables: usize) -> Self {
/// Creates an empty virtual polynomial with `max_num_variables`.
pub fn new(max_num_variables: usize) -> Self {
VirtualPolynomialV2 {
aux_info: VPAuxInfo {
max_degree: 0,
num_variables,
max_num_variables,
phantom: PhantomData,
},
products: Vec::new(),
Expand All @@ -93,7 +94,7 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> {
aux_info: VPAuxInfo {
// The max degree is the max degree of any individual variable
max_degree: 1,
num_variables: mle.num_vars(),
max_num_variables: mle.num_vars(),
phantom: PhantomData,
},
// here `0` points to the first polynomial of `flattened_ml_extensions`
Expand All @@ -104,8 +105,10 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> {
}

/// Add a product of list of multilinear extensions to self
/// Returns an error if the list is empty, or the MLE has a different
/// `num_vars()` from self.
/// Returns an error if the list is empty.
///
/// mle in mle_list must be in same num_vars() in same product,
/// while different product can have different num_vars()
///
/// The MLEs will be multiplied together, and then multiplied by the scalar
/// `coefficient`.
Expand All @@ -114,18 +117,20 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> {
let mut indexed_product = Vec::with_capacity(mle_list.len());

assert!(!mle_list.is_empty(), "input mle_list is empty");
// sanity check: all mle in mle_list must have same num_vars()
assert!(
mle_list
.iter()
.map(|m| {
assert!(m.num_vars() <= self.aux_info.max_num_variables);
m.num_vars()
})
.all_equal()
);

self.aux_info.max_degree = max(self.aux_info.max_degree, mle_list.len());

for mle in mle_list {
assert_eq!(
mle.num_vars(),
self.aux_info.num_variables,
"product has a multiplicand with wrong number of variables {} vs {}",
mle.num_vars(),
self.aux_info.num_variables
);

let mle_ptr: usize = Arc::as_ptr(&mle) as *const () as usize;
if let Some(index) = self.raw_pointers_lookup_table.get(&mle_ptr) {
indexed_product.push(*index)
Expand Down Expand Up @@ -163,10 +168,10 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> {

assert_eq!(
mle.num_vars(),
self.aux_info.num_variables,
self.aux_info.max_num_variables,
"product has a multiplicand with wrong number of variables {} vs {}",
mle.num_vars(),
self.aux_info.num_variables
self.aux_info.max_num_variables
);

let mle_ptr = Arc::as_ptr(&mle) as *const () as usize;
Expand Down Expand Up @@ -200,17 +205,17 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> {
let start = start_timer!(|| "evaluation");

assert_eq!(
self.aux_info.num_variables,
self.aux_info.max_num_variables,
point.len(),
"wrong number of variables {} vs {}",
self.aux_info.num_variables,
self.aux_info.max_num_variables,
point.len()
);

let evals: Vec<E> = self
.flattened_ml_extensions
.iter()
.map(|x| x.evaluate(point))
.map(|x| x.evaluate(&point[0..x.num_vars()]))
.collect();

let res = self
Expand All @@ -225,11 +230,11 @@ impl<'a, E: ExtensionField> VirtualPolynomialV2<'a, E> {

/// Print out the evaluation map for testing. Panic if the num_vars() > 5.
pub fn print_evals(&self) {
if self.aux_info.num_variables > 5 {
if self.aux_info.max_num_variables > 5 {
panic!("this function is used for testing only. cannot print more than 5 num_vars()")
}
for i in 0..1 << self.aux_info.num_variables {
let point = bit_decompose(i, self.aux_info.num_variables);
for i in 0..1 << self.aux_info.max_num_variables {
let point = bit_decompose(i, self.aux_info.max_num_variables);
let point_fr: Vec<E> = point.iter().map(|&x| E::from(x as u64)).collect();
println!("{} {:?}", i, self.evaluate(point_fr.as_ref()))
}
Expand Down
Loading

0 comments on commit 43af034

Please sign in to comment.