Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add safe simplify for expressions that compares random assignments before and after. #918

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 152 additions & 9 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
use std::collections::{HashMap, HashSet};
use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub};

use itertools::sorted;
use num_traits::{One, Zero};
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use super::{AssertEvaluator, EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use crate::core::fields::cm31::CM31;
use crate::core::fields::m31::{self, BaseField};
use crate::core::fields::qm31::{SecureField, QM31};
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;

/// A single base field column at index `idx` of interaction `interaction`, at mask offset `offset`.
#[derive(Clone, Debug, PartialEq)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct ColumnExpr {
interaction: usize,
idx: usize,
Expand Down Expand Up @@ -174,11 +178,14 @@ impl BaseExpr {
}
}

pub fn simplify(&self) -> Self {
/// Helper function, use [`simplify`] instead.
///
/// Simplifies an expression by applying basic arithmetic rules.
fn unchecked_simplify(&self) -> Self {
let simple = simplify_arithmetic!(self);
match simple {
Self::Inv(a) => {
let a = a.simplify();
let a = a.unchecked_simplify();
match a {
Self::Inv(inv_a) => *inv_a, // 1 / (1 / a) = a
Self::Const(c) => Self::Const(c.inverse()),
Expand All @@ -189,6 +196,14 @@ impl BaseExpr {
}
}

/// Simplifies an expression by applying basic arithmetic rules and ensures that the result is
/// equivalent to the original expression by assigning random values.
pub fn simplify(&self) -> Self {
let simplified = self.unchecked_simplify();
assert_eq!(self.random_eval(), simplified.random_eval());
simplified
}

pub fn simplify_and_format(&self) -> String {
self.simplify().format_expr()
}
Expand Down Expand Up @@ -220,6 +235,25 @@ impl BaseExpr {
Self::Inv(a) => a.eval_expr::<E, C, V>(columns, vars).inverse(),
}
}

pub fn collect_variables(&self) -> ExprVariables {
match self {
BaseExpr::Col(col) => ExprVariables::col(col.clone()),
BaseExpr::Const(_) => ExprVariables::default(),
BaseExpr::Param(param) => ExprVariables::param(param.to_string()),
BaseExpr::Add(a, b) => a.collect_variables() + b.collect_variables(),
BaseExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(),
BaseExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(),
BaseExpr::Neg(a) => a.collect_variables(),
BaseExpr::Inv(a) => a.collect_variables(),
}
}

pub fn random_eval(&self) -> BaseField {
let assignment = self.collect_variables().random_assignment();
assert!(assignment.2.is_empty());
self.eval_expr::<AssertEvaluator<'_>, _, _>(&assignment.0, &assignment.1)
}
}

impl ExtExpr {
Expand Down Expand Up @@ -256,14 +290,17 @@ impl ExtExpr {
}
}

pub fn simplify(&self) -> Self {
/// Helper function, use [`simplify`] instead.
///
/// Simplifies an expression by applying basic arithmetic rules.
fn unchecked_simplify(&self) -> Self {
let simple = simplify_arithmetic!(self);
match simple {
Self::SecureCol([a, b, c, d]) => {
let a = a.simplify();
let b = b.simplify();
let c = c.simplify();
let d = d.simplify();
let a = a.unchecked_simplify();
let b = b.unchecked_simplify();
let c = c.unchecked_simplify();
let d = d.unchecked_simplify();
match (a.clone(), b.clone(), c.clone(), d.clone()) {
(
BaseExpr::Const(a_val),
Expand All @@ -278,6 +315,14 @@ impl ExtExpr {
}
}

/// Simplifies an expression by applying basic arithmetic rules and ensures that the result is
/// equivalent to the original expression by assigning random values.
pub fn simplify(&self) -> Self {
let simplified = self.unchecked_simplify();
assert_eq!(self.random_eval(), simplified.random_eval());
simplified
}

pub fn simplify_and_format(&self) -> String {
self.simplify().format_expr()
}
Expand Down Expand Up @@ -319,6 +364,104 @@ impl ExtExpr {
Self::Neg(a) => -a.eval_expr::<E, C, V, EV>(columns, vars, ext_vars),
}
}

pub fn collect_variables(&self) -> ExprVariables {
match self {
ExtExpr::SecureCol([a, b, c, d]) => {
a.collect_variables()
+ b.collect_variables()
+ c.collect_variables()
+ d.collect_variables()
}
ExtExpr::Const(_) => ExprVariables::default(),
ExtExpr::Param(param) => ExprVariables::ext_param(param.to_string()),
ExtExpr::Add(a, b) => a.collect_variables() + b.collect_variables(),
ExtExpr::Sub(a, b) => a.collect_variables() + b.collect_variables(),
ExtExpr::Mul(a, b) => a.collect_variables() + b.collect_variables(),
ExtExpr::Neg(a) => a.collect_variables(),
}
}

pub fn random_eval(&self) -> SecureField {
let assignment = self.collect_variables().random_assignment();
self.eval_expr::<AssertEvaluator<'_>, _, _, _>(&assignment.0, &assignment.1, &assignment.2)
}
}

/// An assignment to the variables that may appear in an expression.
pub type ExprVarAssignment = (
HashMap<(usize, usize, isize), BaseField>,
HashMap<String, BaseField>,
HashMap<String, SecureField>,
);

/// Three sets representing all the variables that can appear in an expression:
/// * `cols`: The columns of the AIR.
/// * `params`: The formal parameters to the AIR.
/// * `ext_params`: The extension field parameters to the AIR.
#[derive(Default)]
pub struct ExprVariables {
pub cols: HashSet<ColumnExpr>,
pub params: HashSet<String>,
pub ext_params: HashSet<String>,
}

impl ExprVariables {
pub fn col(col: ColumnExpr) -> Self {
Self {
cols: vec![col].into_iter().collect(),
params: HashSet::new(),
ext_params: HashSet::new(),
}
}

pub fn param(param: String) -> Self {
Self {
cols: HashSet::new(),
params: vec![param].into_iter().collect(),
ext_params: HashSet::new(),
}
}

pub fn ext_param(param: String) -> Self {
Self {
cols: HashSet::new(),
params: HashSet::new(),
ext_params: vec![param].into_iter().collect(),
}
}

/// Generates a random assignment to the variables.
/// Note that the assignment is deterministic in the sets of variables (disregarding their
/// order), and this is required.
pub fn random_assignment(&self) -> ExprVarAssignment {
let mut rng = SmallRng::seed_from_u64(0);

let cols = sorted(self.cols.iter())
.map(|col| ((col.interaction, col.idx, col.offset), rng.gen()))
.collect();

let params = sorted(self.params.iter())
.map(|param| (param.clone(), rng.gen()))
.collect();

let ext_params = sorted(self.ext_params.iter())
.map(|param| (param.clone(), rng.gen()))
.collect();

(cols, params, ext_params)
}
}

impl Add for ExprVariables {
type Output = Self;
fn add(self, rhs: Self) -> Self {
Self {
cols: self.cols.union(&rhs.cols).cloned().collect(),
params: self.params.union(&rhs.params).cloned().collect(),
ext_params: self.ext_params.union(&rhs.ext_params).cloned().collect(),
}
}
}

impl From<BaseField> for BaseExpr {
Expand Down
Loading