From 66c08cafcfe7d2737d481222f06687e8425a805f Mon Sep 17 00:00:00 2001 From: naure Date: Tue, 24 Sep 2024 08:49:20 +0200 Subject: [PATCH] To monomial form (#272) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _Issue #238_ A method to turn an Expression into a monomial form. The requirement of monomial forms makes really hard not only to write constraints, but also to organize code into functions of expressions. The circuit builder now accepts any expression and turns it into a monomial form automatically. --------- Co-authored-by: Aurélien Nicolas --- ceno_zkvm/src/circuit_builder.rs | 11 +- ceno_zkvm/src/expression.rs | 8 +- ceno_zkvm/src/expression/monomial.rs | 243 +++++++++++++++++++++++++++ 3 files changed, 257 insertions(+), 5 deletions(-) create mode 100644 ceno_zkvm/src/expression/monomial.rs diff --git a/ceno_zkvm/src/circuit_builder.rs b/ceno_zkvm/src/circuit_builder.rs index 6b616b1fa..3495b3c9b 100644 --- a/ceno_zkvm/src/circuit_builder.rs +++ b/ceno_zkvm/src/circuit_builder.rs @@ -284,10 +284,13 @@ impl ConstraintSystem { let path = self.ns.compute_path(name_fn().into()); self.assert_zero_expressions_namespace_map.push(path); } else { - assert!( - assert_zero_expr.is_monomial_form(), - "only support sumcheck in monomial form" - ); + let assert_zero_expr = if assert_zero_expr.is_monomial_form() { + assert_zero_expr + } else { + let e = assert_zero_expr.to_monomial_form(); + assert!(e.is_monomial_form(), "failed to put into monomial form"); + e + }; self.max_non_lc_degree = self.max_non_lc_degree.max(assert_zero_expr.degree()); self.assert_zero_sumcheck_expressions.push(assert_zero_expr); let path = self.ns.compute_path(name_fn().into()); diff --git a/ceno_zkvm/src/expression.rs b/ceno_zkvm/src/expression.rs index 628e4cbce..7d1069767 100644 --- a/ceno_zkvm/src/expression.rs +++ b/ceno_zkvm/src/expression.rs @@ -1,3 +1,5 @@ +mod monomial; + use std::{ cmp::max, mem::MaybeUninit, @@ -14,7 +16,7 @@ use crate::{ structs::{ChallengeId, WitnessId}, }; -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum Expression { /// WitIn(Id) WitIn(WitnessId), @@ -95,6 +97,10 @@ impl Expression { Self::is_monomial_form_inner(MonomialState::SumTerm, self) } + pub fn to_monomial_form(&self) -> Self { + self.to_monomial_form_inner() + } + pub fn unpack_sum(&self) -> Option<(Expression, Expression)> { match self { Expression::Sum(a, b) => Some((a.deref().clone(), b.deref().clone())), diff --git a/ceno_zkvm/src/expression/monomial.rs b/ceno_zkvm/src/expression/monomial.rs new file mode 100644 index 000000000..b5a3d13ca --- /dev/null +++ b/ceno_zkvm/src/expression/monomial.rs @@ -0,0 +1,243 @@ +use ff_ext::ExtensionField; +use goldilocks::SmallField; +use std::cmp::Ordering; + +use super::Expression; +use Expression::*; + +impl Expression { + pub(super) fn to_monomial_form_inner(&self) -> Self { + Self::sum_terms(Self::combine(self.distribute())) + } + + fn distribute(&self) -> Vec> { + match self { + Constant(_) => { + vec![Term { + coeff: self.clone(), + vars: vec![], + }] + } + + Fixed(_) | WitIn(_) | Challenge(..) => { + vec![Term { + coeff: Expression::ONE, + vars: vec![self.clone()], + }] + } + + Sum(a, b) => { + let mut res = a.distribute(); + res.extend(b.distribute()); + res + } + + Product(a, b) => { + let a = a.distribute(); + let b = b.distribute(); + let mut res = vec![]; + for a in a { + for b in &b { + res.push(Term { + coeff: a.coeff.clone() * b.coeff.clone(), + vars: a.vars.iter().chain(b.vars.iter()).cloned().collect(), + }); + } + } + res + } + + ScaledSum(x, a, b) => { + let x = x.distribute(); + let a = a.distribute(); + let mut res = b.distribute(); + for x in x { + for a in &a { + res.push(Term { + coeff: x.coeff.clone() * a.coeff.clone(), + vars: x.vars.iter().chain(a.vars.iter()).cloned().collect(), + }); + } + } + res + } + } + } + + fn combine(terms: Vec>) -> Vec> { + let mut res: Vec> = vec![]; + for mut term in terms { + term.vars.sort(); + + if let Some(res_term) = res.iter_mut().find(|res_term| res_term.vars == term.vars) { + res_term.coeff = res_term.coeff.clone() + term.coeff.clone(); + } else { + res.push(term); + } + } + res + } + + fn sum_terms(terms: Vec>) -> Self { + terms + .into_iter() + .map(|term| term.vars.into_iter().fold(term.coeff, Self::product)) + .reduce(Self::sum) + .unwrap_or(Expression::ZERO) + } + + fn product(a: Self, b: Self) -> Self { + Product(Box::new(a), Box::new(b)) + } + + fn sum(a: Self, b: Self) -> Self { + Sum(Box::new(a), Box::new(b)) + } +} + +#[derive(Clone, Debug)] +struct Term { + coeff: Expression, + vars: Vec>, +} + +// Define a lexicographic order for expressions. It compares the types first, then the arguments left-to-right. +impl Ord for Expression { + fn cmp(&self, other: &Self) -> Ordering { + use Ordering::*; + + match (self, other) { + (Fixed(a), Fixed(b)) => a.cmp(b), + (WitIn(a), WitIn(b)) => a.cmp(b), + (Constant(a), Constant(b)) => cmp_field(a, b), + (Challenge(a, b, c, d), Challenge(e, f, g, h)) => { + let cmp = a.cmp(e); + if cmp == Equal { + let cmp = b.cmp(f); + if cmp == Equal { + let cmp = cmp_ext(c, g); + if cmp == Equal { cmp_ext(d, h) } else { cmp } + } else { + cmp + } + } else { + cmp + } + } + (Sum(a, b), Sum(c, d)) => { + let cmp = a.cmp(c); + if cmp == Equal { b.cmp(d) } else { cmp } + } + (Product(a, b), Product(c, d)) => { + let cmp = a.cmp(c); + if cmp == Equal { b.cmp(d) } else { cmp } + } + (ScaledSum(x, a, b), ScaledSum(y, c, d)) => { + let cmp = x.cmp(y); + if cmp == Equal { + let cmp = a.cmp(c); + if cmp == Equal { b.cmp(d) } else { cmp } + } else { + cmp + } + } + (Fixed(_), _) => Less, + (WitIn(_), _) => Less, + (Constant(_), _) => Less, + (Challenge(..), _) => Less, + (Sum(..), _) => Less, + (Product(..), _) => Less, + (ScaledSum(..), _) => Less, + } + } +} + +impl PartialOrd for Expression { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +fn cmp_field(a: &F, b: &F) -> Ordering { + a.to_canonical_u64().cmp(&b.to_canonical_u64()) +} + +fn cmp_ext(a: &E, b: &E) -> Ordering { + let a = a.as_bases().iter().map(|f| f.to_canonical_u64()); + let b = b.as_bases().iter().map(|f| f.to_canonical_u64()); + a.cmp(b) +} + +#[cfg(test)] +mod tests { + use crate::{expression::Fixed as FixedS, scheme::utils::eval_by_expr_with_fixed}; + + use super::*; + use ff::Field; + use goldilocks::{Goldilocks as F, GoldilocksExt2 as E}; + use rand_chacha::{rand_core::SeedableRng, ChaChaRng}; + + #[test] + fn test_to_monomial_form() { + use Expression::*; + + let eval = make_eval(); + + let a = || Fixed(FixedS(0)); + let b = || Fixed(FixedS(1)); + let c = || Fixed(FixedS(2)); + let x = || WitIn(0); + let y = || WitIn(1); + let z = || WitIn(2); + let n = || Constant(104.into()); + let m = || Constant(-F::from(599)); + let r = || Challenge(0, 1, E::from(1), E::from(0)); + + let test_exprs: &[Expression] = &[ + a() * x() * x(), + a(), + x(), + n(), + r(), + a() + b() + x() + y() + n() + m() + r(), + a() * x() * n() * r(), + x() * y() * z(), + (x() + y() + a()) * b() * (y() + z()) + c(), + (r() * x() + n() + z()) * m() * y(), + (b() + y() + m() * z()) * (x() + y() + c()), + a() * r() * x(), + ]; + + for factored in test_exprs { + let monomials = factored.to_monomial_form_inner(); + assert!(monomials.is_monomial_form()); + + // Check that the two forms are equivalent (Schwartz-Zippel test). + let factored = eval(&factored); + let monomials = eval(&monomials); + assert_eq!(monomials, factored); + } + } + + /// Create an evaluator of expressions. Fixed, witness, and challenge values are pseudo-random. + fn make_eval() -> impl Fn(&Expression) -> E { + // Create a deterministic RNG from a seed. + let mut rng = ChaChaRng::from_seed([12u8; 32]); + let fixed = vec![ + E::random(&mut rng), + E::random(&mut rng), + E::random(&mut rng), + ]; + let witnesses = vec![ + E::random(&mut rng), + E::random(&mut rng), + E::random(&mut rng), + ]; + let challenges = vec![ + E::random(&mut rng), + E::random(&mut rng), + E::random(&mut rng), + ]; + move |expr: &Expression| eval_by_expr_with_fixed(&fixed, &witnesses, &challenges, expr) + } +}