diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index 95f9f3342..a1b68f294 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -1,8 +1,12 @@ +use std::collections::{HashMap, HashSet}; use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub}; +use itertools::sorted; use num_traits::{One, Zero}; +use rand::rngs::SmallRng; +use rand::{Rng, SeedableRng}; -use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; +use super::{AssertEvaluator, EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; use crate::core::fields::cm31::CM31; use crate::core::fields::m31::{self, BaseField}; use crate::core::fields::qm31::{SecureField, QM31}; @@ -10,7 +14,7 @@ use crate::core::fields::FieldExpOps; use crate::core::lookups::utils::Fraction; /// A single base field column at index `idx` of interaction `interaction`, at mask offset `offset`. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct ColumnExpr { interaction: usize, idx: usize, @@ -189,8 +193,14 @@ impl BaseExpr { } } + pub fn safe_simplify(&self) -> Self { + let simplified = self.simplify(); + assert_eq!(self.random_eval(), simplified.random_eval()); + simplified + } + pub fn simplify_and_format(&self) -> String { - self.simplify().format_expr() + self.safe_simplify().format_expr() } /// Evaluates a base field expression. @@ -220,6 +230,24 @@ impl BaseExpr { Self::Inv(a) => a.eval_expr::(columns, vars).inverse(), } } + + pub fn collect_variables(&self) -> ExprVariables { + match self { + BaseExpr::Col(col) => ExprVariables::col(col.clone()), + BaseExpr::Const(_) => ExprVariables::default(), + BaseExpr::Param(param) => ExprVariables::param(param.to_string()), + BaseExpr::Add(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(), + BaseExpr::Neg(a) => a.collect_variables(), + BaseExpr::Inv(a) => a.collect_variables(), + } + } + + pub fn random_eval(&self) -> BaseField { + let assignment = self.collect_variables().randomize(); + self.eval_expr::, _, _>(&assignment.0, &assignment.1) + } } impl ExtExpr { @@ -278,8 +306,14 @@ impl ExtExpr { } } + pub fn safe_simplify(&self) -> Self { + let simplified = self.simplify(); + assert_eq!(self.random_eval(), simplified.random_eval()); + simplified + } + pub fn simplify_and_format(&self) -> String { - self.simplify().format_expr() + self.safe_simplify().format_expr() } /// Evaluates an extension field expression. @@ -319,6 +353,96 @@ impl ExtExpr { Self::Neg(a) => -a.eval_expr::(columns, vars, ext_vars), } } + + pub fn collect_variables(&self) -> ExprVariables { + match self { + ExtExpr::SecureCol([a, b, c, d]) => { + a.collect_variables() + + b.collect_variables() + + c.collect_variables() + + d.collect_variables() + } + ExtExpr::Const(_) => ExprVariables::default(), + ExtExpr::Param(param) => ExprVariables::ext_param(param.to_string()), + ExtExpr::Add(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(), + ExtExpr::Neg(a) => a.collect_variables(), + } + } + + pub fn random_eval(&self) -> SecureField { + let assignment = self.collect_variables().randomize(); + self.eval_expr::, _, _, _>(&assignment.0, &assignment.1, &assignment.2) + } +} + +#[derive(Default)] +pub struct ExprVariables { + pub cols: HashSet, + pub params: HashSet, + pub ext_params: HashSet, +} + +pub type ExprVarAssignment = ( + HashMap<(usize, usize, isize), BaseField>, + HashMap, + HashMap, +); + +impl ExprVariables { + pub fn col(col: ColumnExpr) -> Self { + Self { + cols: vec![col].into_iter().collect(), + params: HashSet::new(), + ext_params: HashSet::new(), + } + } + + pub fn param(param: String) -> Self { + Self { + cols: HashSet::new(), + params: vec![param].into_iter().collect(), + ext_params: HashSet::new(), + } + } + + pub fn ext_param(param: String) -> Self { + Self { + cols: HashSet::new(), + params: HashSet::new(), + ext_params: vec![param].into_iter().collect(), + } + } + + pub fn randomize(&self) -> ExprVarAssignment { + let mut rng = SmallRng::seed_from_u64(0); + + let cols = sorted(self.cols.iter()) + .map(|col| ((col.interaction, col.idx, col.offset), rng.gen())) + .collect(); + + let params = sorted(self.params.iter()) + .map(|param| (param.clone(), rng.gen())) + .collect(); + + let ext_params = sorted(self.ext_params.iter()) + .map(|param| (param.clone(), rng.gen())) + .collect(); + + (cols, params, ext_params) + } +} + +impl Add for ExprVariables { + type Output = Self; + fn add(self, rhs: Self) -> Self { + Self { + cols: self.cols.union(&rhs.cols).cloned().collect(), + params: self.params.union(&rhs.params).cloned().collect(), + ext_params: self.ext_params.union(&rhs.ext_params).cloned().collect(), + } + } } impl From for BaseExpr {