Skip to content

Commit

Permalink
Add safe simplify for expressions that compares random assignments
Browse files Browse the repository at this point in the history
before and after.
  • Loading branch information
Alon-Ti committed Dec 2, 2024
1 parent a459c48 commit d13bc18
Showing 1 changed file with 128 additions and 4 deletions.
132 changes: 128 additions & 4 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
use std::collections::{HashMap, HashSet};
use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub};

use itertools::sorted;
use num_traits::{One, Zero};
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use super::{AssertEvaluator, EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use crate::core::fields::cm31::CM31;
use crate::core::fields::m31::{self, BaseField};
use crate::core::fields::qm31::{SecureField, QM31};
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;

/// A single base field column at index `idx` of interaction `interaction`, at mask offset `offset`.
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ColumnExpr {
interaction: usize,
idx: usize,
Expand Down Expand Up @@ -189,8 +193,14 @@ impl BaseExpr {
}
}

pub fn safe_simplify(&self) -> Self {
let simplified = self.simplify();
assert_eq!(self.random_eval(), simplified.random_eval());
simplified
}

pub fn simplify_and_format(&self) -> String {
self.simplify().format_expr()
self.safe_simplify().format_expr()
}

/// Evaluates a base field expression.
Expand Down Expand Up @@ -220,6 +230,24 @@ impl BaseExpr {
Self::Inv(a) => a.eval_expr::<E, C, V>(columns, vars).inverse(),
}
}

pub fn collect_variables(&self) -> ExprVariables {
match self {
BaseExpr::Col(col) => ExprVariables::col(col.clone()),
BaseExpr::Const(_) => ExprVariables::default(),
BaseExpr::Param(param) => ExprVariables::param(param.to_string()),
BaseExpr::Add(a, b) => a.collect_variables() + b.collect_variables(),
BaseExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(),
BaseExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(),
BaseExpr::Neg(a) => a.collect_variables(),
BaseExpr::Inv(a) => a.collect_variables(),
}
}

pub fn random_eval(&self) -> BaseField {
let assignment = self.collect_variables().randomize();
self.eval_expr::<AssertEvaluator<'_>, _, _>(&assignment.0, &assignment.1)
}
}

impl ExtExpr {
Expand Down Expand Up @@ -278,8 +306,14 @@ impl ExtExpr {
}
}

pub fn safe_simplify(&self) -> Self {
let simplified = self.simplify();
assert_eq!(self.random_eval(), simplified.random_eval());
simplified
}

pub fn simplify_and_format(&self) -> String {
self.simplify().format_expr()
self.safe_simplify().format_expr()
}

/// Evaluates an extension field expression.
Expand Down Expand Up @@ -319,6 +353,96 @@ impl ExtExpr {
Self::Neg(a) => -a.eval_expr::<E, C, V, EV>(columns, vars, ext_vars),
}
}

pub fn collect_variables(&self) -> ExprVariables {
match self {
ExtExpr::SecureCol([a, b, c, d]) => {
a.collect_variables()
+ b.collect_variables()
+ c.collect_variables()
+ d.collect_variables()
}
ExtExpr::Const(_) => ExprVariables::default(),
ExtExpr::Param(param) => ExprVariables::ext_param(param.to_string()),
ExtExpr::Add(a, b) => a.collect_variables() + b.collect_variables(),
ExtExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(),
ExtExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(),
ExtExpr::Neg(a) => a.collect_variables(),
}
}

pub fn random_eval(&self) -> SecureField {
let assignment = self.collect_variables().randomize();
self.eval_expr::<AssertEvaluator<'_>, _, _, _>(&assignment.0, &assignment.1, &assignment.2)
}
}

#[derive(Default)]
pub struct ExprVariables {
pub cols: HashSet<ColumnExpr>,
pub params: HashSet<String>,
pub ext_params: HashSet<String>,
}

pub type ExprVarAssignment = (
HashMap<(usize, usize, isize), BaseField>,
HashMap<String, BaseField>,
HashMap<String, SecureField>,
);

impl ExprVariables {
pub fn col(col: ColumnExpr) -> Self {
Self {
cols: vec![col].into_iter().collect(),
params: HashSet::new(),
ext_params: HashSet::new(),
}
}

pub fn param(param: String) -> Self {
Self {
cols: HashSet::new(),
params: vec![param].into_iter().collect(),
ext_params: HashSet::new(),
}
}

pub fn ext_param(param: String) -> Self {
Self {
cols: HashSet::new(),
params: HashSet::new(),
ext_params: vec![param].into_iter().collect(),
}
}

pub fn randomize(&self) -> ExprVarAssignment {
let mut rng = SmallRng::seed_from_u64(0);

let cols = sorted(self.cols.iter())
.map(|col| ((col.interaction, col.idx, col.offset), rng.gen()))
.collect();

let params = sorted(self.params.iter())
.map(|param| (param.clone(), rng.gen()))
.collect();

let ext_params = sorted(self.ext_params.iter())
.map(|param| (param.clone(), rng.gen()))
.collect();

(cols, params, ext_params)
}
}

impl Add for ExprVariables {
type Output = Self;
fn add(self, rhs: Self) -> Self {
Self {
cols: self.cols.union(&rhs.cols).cloned().collect(),
params: self.params.union(&rhs.params).cloned().collect(),
ext_params: self.ext_params.union(&rhs.ext_params).cloned().collect(),
}
}
}

impl From<BaseField> for BaseExpr {
Expand Down

0 comments on commit d13bc18

Please sign in to comment.