Skip to content

Commit

Permalink
To monomial form (#272)
Browse files Browse the repository at this point in the history
_Issue #238_

A method to turn an Expression into a monomial form.

The requirement of monomial forms makes really hard not only to write
constraints, but also to organize code into functions of expressions.

The circuit builder now accepts any expression and turns it into a
monomial form automatically.

---------

Co-authored-by: Aurélien Nicolas <[email protected]>
  • Loading branch information
naure and Aurélien Nicolas authored Sep 24, 2024
1 parent b871d91 commit 66c08ca
Show file tree
Hide file tree
Showing 3 changed files with 257 additions and 5 deletions.
11 changes: 7 additions & 4 deletions ceno_zkvm/src/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,13 @@ impl<E: ExtensionField> ConstraintSystem<E> {
let path = self.ns.compute_path(name_fn().into());
self.assert_zero_expressions_namespace_map.push(path);
} else {
assert!(
assert_zero_expr.is_monomial_form(),
"only support sumcheck in monomial form"
);
let assert_zero_expr = if assert_zero_expr.is_monomial_form() {
assert_zero_expr
} else {
let e = assert_zero_expr.to_monomial_form();
assert!(e.is_monomial_form(), "failed to put into monomial form");
e
};
self.max_non_lc_degree = self.max_non_lc_degree.max(assert_zero_expr.degree());
self.assert_zero_sumcheck_expressions.push(assert_zero_expr);
let path = self.ns.compute_path(name_fn().into());
Expand Down
8 changes: 7 additions & 1 deletion ceno_zkvm/src/expression.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
mod monomial;

use std::{
cmp::max,
mem::MaybeUninit,
Expand All @@ -14,7 +16,7 @@ use crate::{
structs::{ChallengeId, WitnessId},
};

#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Expression<E: ExtensionField> {
/// WitIn(Id)
WitIn(WitnessId),
Expand Down Expand Up @@ -95,6 +97,10 @@ impl<E: ExtensionField> Expression<E> {
Self::is_monomial_form_inner(MonomialState::SumTerm, self)
}

pub fn to_monomial_form(&self) -> Self {
self.to_monomial_form_inner()
}

pub fn unpack_sum(&self) -> Option<(Expression<E>, Expression<E>)> {
match self {
Expression::Sum(a, b) => Some((a.deref().clone(), b.deref().clone())),
Expand Down
243 changes: 243 additions & 0 deletions ceno_zkvm/src/expression/monomial.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
use ff_ext::ExtensionField;
use goldilocks::SmallField;
use std::cmp::Ordering;

use super::Expression;
use Expression::*;

impl<E: ExtensionField> Expression<E> {
pub(super) fn to_monomial_form_inner(&self) -> Self {
Self::sum_terms(Self::combine(self.distribute()))
}

fn distribute(&self) -> Vec<Term<E>> {
match self {
Constant(_) => {
vec![Term {
coeff: self.clone(),
vars: vec![],
}]
}

Fixed(_) | WitIn(_) | Challenge(..) => {
vec![Term {
coeff: Expression::ONE,
vars: vec![self.clone()],
}]
}

Sum(a, b) => {
let mut res = a.distribute();
res.extend(b.distribute());
res
}

Product(a, b) => {
let a = a.distribute();
let b = b.distribute();
let mut res = vec![];
for a in a {
for b in &b {
res.push(Term {
coeff: a.coeff.clone() * b.coeff.clone(),
vars: a.vars.iter().chain(b.vars.iter()).cloned().collect(),
});
}
}
res
}

ScaledSum(x, a, b) => {
let x = x.distribute();
let a = a.distribute();
let mut res = b.distribute();
for x in x {
for a in &a {
res.push(Term {
coeff: x.coeff.clone() * a.coeff.clone(),
vars: x.vars.iter().chain(a.vars.iter()).cloned().collect(),
});
}
}
res
}
}
}

fn combine(terms: Vec<Term<E>>) -> Vec<Term<E>> {
let mut res: Vec<Term<E>> = vec![];
for mut term in terms {
term.vars.sort();

if let Some(res_term) = res.iter_mut().find(|res_term| res_term.vars == term.vars) {
res_term.coeff = res_term.coeff.clone() + term.coeff.clone();
} else {
res.push(term);
}
}
res
}

fn sum_terms(terms: Vec<Term<E>>) -> Self {
terms
.into_iter()
.map(|term| term.vars.into_iter().fold(term.coeff, Self::product))
.reduce(Self::sum)
.unwrap_or(Expression::ZERO)
}

fn product(a: Self, b: Self) -> Self {
Product(Box::new(a), Box::new(b))
}

fn sum(a: Self, b: Self) -> Self {
Sum(Box::new(a), Box::new(b))
}
}

#[derive(Clone, Debug)]
struct Term<E: ExtensionField> {
coeff: Expression<E>,
vars: Vec<Expression<E>>,
}

// Define a lexicographic order for expressions. It compares the types first, then the arguments left-to-right.
impl<E: ExtensionField> Ord for Expression<E> {
fn cmp(&self, other: &Self) -> Ordering {
use Ordering::*;

match (self, other) {
(Fixed(a), Fixed(b)) => a.cmp(b),
(WitIn(a), WitIn(b)) => a.cmp(b),
(Constant(a), Constant(b)) => cmp_field(a, b),
(Challenge(a, b, c, d), Challenge(e, f, g, h)) => {
let cmp = a.cmp(e);
if cmp == Equal {
let cmp = b.cmp(f);
if cmp == Equal {
let cmp = cmp_ext(c, g);
if cmp == Equal { cmp_ext(d, h) } else { cmp }
} else {
cmp
}
} else {
cmp
}
}
(Sum(a, b), Sum(c, d)) => {
let cmp = a.cmp(c);
if cmp == Equal { b.cmp(d) } else { cmp }
}
(Product(a, b), Product(c, d)) => {
let cmp = a.cmp(c);
if cmp == Equal { b.cmp(d) } else { cmp }
}
(ScaledSum(x, a, b), ScaledSum(y, c, d)) => {
let cmp = x.cmp(y);
if cmp == Equal {
let cmp = a.cmp(c);
if cmp == Equal { b.cmp(d) } else { cmp }
} else {
cmp
}
}
(Fixed(_), _) => Less,
(WitIn(_), _) => Less,
(Constant(_), _) => Less,
(Challenge(..), _) => Less,
(Sum(..), _) => Less,
(Product(..), _) => Less,
(ScaledSum(..), _) => Less,
}
}
}

impl<E: ExtensionField> PartialOrd for Expression<E> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

fn cmp_field<F: SmallField>(a: &F, b: &F) -> Ordering {
a.to_canonical_u64().cmp(&b.to_canonical_u64())
}

fn cmp_ext<E: ExtensionField>(a: &E, b: &E) -> Ordering {
let a = a.as_bases().iter().map(|f| f.to_canonical_u64());
let b = b.as_bases().iter().map(|f| f.to_canonical_u64());
a.cmp(b)
}

#[cfg(test)]
mod tests {
use crate::{expression::Fixed as FixedS, scheme::utils::eval_by_expr_with_fixed};

use super::*;
use ff::Field;
use goldilocks::{Goldilocks as F, GoldilocksExt2 as E};
use rand_chacha::{rand_core::SeedableRng, ChaChaRng};

#[test]
fn test_to_monomial_form() {
use Expression::*;

let eval = make_eval();

let a = || Fixed(FixedS(0));
let b = || Fixed(FixedS(1));
let c = || Fixed(FixedS(2));
let x = || WitIn(0);
let y = || WitIn(1);
let z = || WitIn(2);
let n = || Constant(104.into());
let m = || Constant(-F::from(599));
let r = || Challenge(0, 1, E::from(1), E::from(0));

let test_exprs: &[Expression<E>] = &[
a() * x() * x(),
a(),
x(),
n(),
r(),
a() + b() + x() + y() + n() + m() + r(),
a() * x() * n() * r(),
x() * y() * z(),
(x() + y() + a()) * b() * (y() + z()) + c(),
(r() * x() + n() + z()) * m() * y(),
(b() + y() + m() * z()) * (x() + y() + c()),
a() * r() * x(),
];

for factored in test_exprs {
let monomials = factored.to_monomial_form_inner();
assert!(monomials.is_monomial_form());

// Check that the two forms are equivalent (Schwartz-Zippel test).
let factored = eval(&factored);
let monomials = eval(&monomials);
assert_eq!(monomials, factored);
}
}

/// Create an evaluator of expressions. Fixed, witness, and challenge values are pseudo-random.
fn make_eval() -> impl Fn(&Expression<E>) -> E {
// Create a deterministic RNG from a seed.
let mut rng = ChaChaRng::from_seed([12u8; 32]);
let fixed = vec![
E::random(&mut rng),
E::random(&mut rng),
E::random(&mut rng),
];
let witnesses = vec![
E::random(&mut rng),
E::random(&mut rng),
E::random(&mut rng),
];
let challenges = vec![
E::random(&mut rng),
E::random(&mut rng),
E::random(&mut rng),
];
move |expr: &Expression<E>| eval_by_expr_with_fixed(&fixed, &witnesses, &challenges, expr)
}
}

0 comments on commit 66c08ca

Please sign in to comment.