Skip to content

Commit

Permalink
Add safe simplify for expressions that compares random assignments
Browse files Browse the repository at this point in the history
before and after.
  • Loading branch information
Alon-Ti committed Dec 2, 2024
1 parent df4a642 commit f84268d
Showing 1 changed file with 134 additions and 10 deletions.
144 changes: 134 additions & 10 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
use std::collections::{HashMap, HashSet};
use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub};

use itertools::sorted;
use num_traits::{One, Zero};
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use super::{AssertEvaluator, EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use crate::core::fields::cm31::CM31;
use crate::core::fields::m31::{self, BaseField};
use crate::core::fields::qm31::{SecureField, QM31};
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;

/// A single base field column at index `idx` of interaction `interaction`, at mask offset `offset`.
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ColumnExpr {
interaction: usize,
idx: usize,
Expand Down Expand Up @@ -174,11 +178,11 @@ impl BaseExpr {
}
}

pub fn simplify(&self) -> Self {
fn _simplify(&self) -> Self {
let simple = simplify_arithmetic!(self);
match simple {
Self::Inv(a) => {
let a = a.simplify();
let a = a._simplify();
match a {
Self::Inv(inv_a) => *inv_a, // 1 / (1 / a) = a
Self::Const(c) => Self::Const(c.inverse()),
Expand All @@ -189,6 +193,12 @@ impl BaseExpr {
}
}

pub fn simplify(&self) -> Self {
let simplified = self._simplify();
assert_eq!(self.random_eval(), simplified.random_eval());
simplified
}

pub fn simplify_and_format(&self) -> String {
self.simplify().format_expr()
}
Expand Down Expand Up @@ -220,6 +230,24 @@ impl BaseExpr {
Self::Inv(a) => a.eval_expr::<E, C, V>(columns, vars).inverse(),
}
}

pub fn collect_variables(&self) -> ExprVariables {
match self {
BaseExpr::Col(col) => ExprVariables::col(col.clone()),
BaseExpr::Const(_) => ExprVariables::default(),
BaseExpr::Param(param) => ExprVariables::param(param.to_string()),
BaseExpr::Add(a, b) => a.collect_variables() + b.collect_variables(),
BaseExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(),
BaseExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(),
BaseExpr::Neg(a) => a.collect_variables(),
BaseExpr::Inv(a) => a.collect_variables(),
}
}

pub fn random_eval(&self) -> BaseField {
let assignment = self.collect_variables().randomize();
self.eval_expr::<AssertEvaluator<'_>, _, _>(&assignment.0, &assignment.1)
}
}

impl ExtExpr {
Expand Down Expand Up @@ -256,14 +284,14 @@ impl ExtExpr {
}
}

pub fn simplify(&self) -> Self {
fn _simplify(&self) -> Self {
let simple = simplify_arithmetic!(self);
match simple {
Self::SecureCol([a, b, c, d]) => {
let a = a.simplify();
let b = b.simplify();
let c = c.simplify();
let d = d.simplify();
let a = a._simplify();
let b = b._simplify();
let c = c._simplify();
let d = d._simplify();
match (a.clone(), b.clone(), c.clone(), d.clone()) {
(
BaseExpr::Const(a_val),
Expand All @@ -278,6 +306,12 @@ impl ExtExpr {
}
}

pub fn simplify(&self) -> Self {
let simplified = self._simplify();
assert_eq!(self.random_eval(), simplified.random_eval());
simplified
}

pub fn simplify_and_format(&self) -> String {
self.simplify().format_expr()
}
Expand Down Expand Up @@ -319,6 +353,96 @@ impl ExtExpr {
Self::Neg(a) => -a.eval_expr::<E, C, V, EV>(columns, vars, ext_vars),
}
}

pub fn collect_variables(&self) -> ExprVariables {
match self {
ExtExpr::SecureCol([a, b, c, d]) => {
a.collect_variables()
+ b.collect_variables()
+ c.collect_variables()
+ d.collect_variables()
}
ExtExpr::Const(_) => ExprVariables::default(),
ExtExpr::Param(param) => ExprVariables::ext_param(param.to_string()),
ExtExpr::Add(a, b) => a.collect_variables() + b.collect_variables(),
ExtExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(),
ExtExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(),
ExtExpr::Neg(a) => a.collect_variables(),
}
}

pub fn random_eval(&self) -> SecureField {
let assignment = self.collect_variables().randomize();
self.eval_expr::<AssertEvaluator<'_>, _, _, _>(&assignment.0, &assignment.1, &assignment.2)
}
}

#[derive(Default)]
pub struct ExprVariables {
pub cols: HashSet<ColumnExpr>,
pub params: HashSet<String>,
pub ext_params: HashSet<String>,
}

pub type ExprVarAssignment = (
HashMap<(usize, usize, isize), BaseField>,
HashMap<String, BaseField>,
HashMap<String, SecureField>,
);

impl ExprVariables {
pub fn col(col: ColumnExpr) -> Self {
Self {
cols: vec![col].into_iter().collect(),
params: HashSet::new(),
ext_params: HashSet::new(),
}
}

pub fn param(param: String) -> Self {
Self {
cols: HashSet::new(),
params: vec![param].into_iter().collect(),
ext_params: HashSet::new(),
}
}

pub fn ext_param(param: String) -> Self {
Self {
cols: HashSet::new(),
params: HashSet::new(),
ext_params: vec![param].into_iter().collect(),
}
}

pub fn randomize(&self) -> ExprVarAssignment {
let mut rng = SmallRng::seed_from_u64(0);

let cols = sorted(self.cols.iter())
.map(|col| ((col.interaction, col.idx, col.offset), rng.gen()))
.collect();

let params = sorted(self.params.iter())
.map(|param| (param.clone(), rng.gen()))
.collect();

let ext_params = sorted(self.ext_params.iter())
.map(|param| (param.clone(), rng.gen()))
.collect();

(cols, params, ext_params)
}
}

impl Add for ExprVariables {
type Output = Self;
fn add(self, rhs: Self) -> Self {
Self {
cols: self.cols.union(&rhs.cols).cloned().collect(),
params: self.params.union(&rhs.params).cloned().collect(),
ext_params: self.ext_params.union(&rhs.ext_params).cloned().collect(),
}
}
}

impl From<BaseField> for BaseExpr {
Expand Down Expand Up @@ -897,7 +1021,7 @@ mod tests {

let full_eval = expr.eval_expr::<AssertEvaluator<'_>, _, _, _>(&columns, &vars, &ext_vars);
let simplified_eval = expr
.simplify()
._simplify()
.eval_expr::<AssertEvaluator<'_>, _, _, _>(&columns, &vars, &ext_vars);

assert_eq!(full_eval, simplified_eval);
Expand Down

0 comments on commit f84268d

Please sign in to comment.