Skip to content

Commit

Permalink
Added eval on exprs.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Dec 1, 2024
1 parent d84b24a commit de080dc
Showing 1 changed file with 126 additions and 7 deletions.
133 changes: 126 additions & 7 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};
use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub};

use num_traits::{One, Zero};

Expand All @@ -17,6 +17,16 @@ pub struct ColumnExpr {
offset: isize,
}

impl From<(usize, usize, isize)> for ColumnExpr {
fn from((interaction, idx, offset): (usize, usize, isize)) -> Self {
Self {
interaction,
idx,
offset,
}
}
}

/// An expression representing a base field value. Can be either:
/// * A column indexed by a `ColumnExpr`.
/// * A base field constant.
Expand Down Expand Up @@ -182,6 +192,34 @@ impl BaseExpr {
pub fn simplify_and_format(&self) -> String {
self.simplify().format_expr()
}

/// Evaluates a base field expression.
/// Takes:
/// * `columns`: A mapping from triplets (interaction, idx, offset) to base field values.
/// * `vars`: A mapping from variable names to base field values.
pub fn eval_expr<E, C, V>(&self, columns: &C, vars: &V) -> E::F
where
C: for<'a> Index<&'a (usize, usize, isize), Output = E::F>,
V: for<'a> Index<&'a String, Output = E::F>,
E: EvalAtRow,
{
match self {
Self::Col(col) => columns[&(col.interaction, col.idx, col.offset)].clone(),
Self::Const(c) => E::F::from(*c),
Self::Param(var) => vars[&var.to_string()].clone(),
Self::Add(a, b) => {
a.eval_expr::<E, C, V>(columns, vars) + b.eval_expr::<E, C, V>(columns, vars)
}
Self::Sub(a, b) => {
a.eval_expr::<E, C, V>(columns, vars) - b.eval_expr::<E, C, V>(columns, vars)
}
Self::Mul(a, b) => {
a.eval_expr::<E, C, V>(columns, vars) * b.eval_expr::<E, C, V>(columns, vars)
}
Self::Neg(a) => -a.eval_expr::<E, C, V>(columns, vars),
Self::Inv(a) => a.eval_expr::<E, C, V>(columns, vars).inverse(),
}
}
}

impl ExtExpr {
Expand Down Expand Up @@ -243,6 +281,44 @@ impl ExtExpr {
pub fn simplify_and_format(&self) -> String {
self.simplify().format_expr()
}

/// Evaluates an extension field expression.
/// Takes:
/// * `columns`: A mapping from triplets (interaction, idx, offset) to base field values.
/// * `vars`: A mapping from variable names to base field values.
/// * `ext_vars`: A mapping from variable names to extension field values.
pub fn eval_expr<E, C, V, EV>(&self, columns: &C, vars: &V, ext_vars: &EV) -> E::EF
where
C: for<'a> Index<&'a (usize, usize, isize), Output = E::F>,
V: for<'a> Index<&'a String, Output = E::F>,
EV: for<'a> Index<&'a String, Output = E::EF>,
E: EvalAtRow,
{
match self {
Self::SecureCol([a, b, c, d]) => {
let a = a.eval_expr::<E, C, V>(columns, vars);
let b = b.eval_expr::<E, C, V>(columns, vars);
let c = c.eval_expr::<E, C, V>(columns, vars);
let d = d.eval_expr::<E, C, V>(columns, vars);
E::combine_ef([a, b, c, d])
}
Self::Const(c) => E::EF::from(*c),
Self::Param(var) => ext_vars[&var.to_string()].clone(),
Self::Add(a, b) => {
a.eval_expr::<E, C, V, EV>(columns, vars, ext_vars)
+ b.eval_expr::<E, C, V, EV>(columns, vars, ext_vars)
}
Self::Sub(a, b) => {
a.eval_expr::<E, C, V, EV>(columns, vars, ext_vars)
- b.eval_expr::<E, C, V, EV>(columns, vars, ext_vars)
}
Self::Mul(a, b) => {
a.eval_expr::<E, C, V, EV>(columns, vars, ext_vars)
* b.eval_expr::<E, C, V, EV>(columns, vars, ext_vars)
}
Self::Neg(a) => -a.eval_expr::<E, C, V, EV>(columns, vars, ext_vars),
}
}
}

impl From<BaseField> for BaseExpr {
Expand Down Expand Up @@ -624,11 +700,7 @@ impl EvalAtRow for ExprEvaluator {
offsets: [isize; N],
) -> [Self::F; N] {
std::array::from_fn(|i| {
let col = ColumnExpr {
interaction,
idx: self.cur_var_index,
offset: offsets[i],
};
let col = ColumnExpr::from((interaction, self.cur_var_index, offsets[i]));
self.cur_var_index += 1;
BaseExpr::Col(col)
})
Expand Down Expand Up @@ -675,12 +747,59 @@ impl EvalAtRow for ExprEvaluator {

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use num_traits::One;

use super::{BaseExpr, ExtExpr};
use crate::constraint_framework::expr::ExprEvaluator;
use crate::constraint_framework::{relation, EvalAtRow, FrameworkEval, RelationEntry};
use crate::constraint_framework::{
relation, AssertEvaluator, EvalAtRow, FrameworkEval, RelationEntry,
};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;

#[test]
fn test_eval_expr() {
let col_1_0_0 = BaseField::from(12);
let col_1_1_0 = BaseField::from(5);
let var_a = BaseField::from(3);
let var_b = BaseField::from(4);
let var_c = SecureField::from_m31_array([
BaseField::from(1),
BaseField::from(2),
BaseField::from(3),
BaseField::from(4),
]);

let columns: HashMap<(usize, usize, isize), BaseField> =
HashMap::from([((1, 0, 0), col_1_0_0), ((1, 1, 0), col_1_1_0)]);
let vars = HashMap::from([("a".to_string(), var_a), ("b".to_string(), var_b)]);
let ext_vars = HashMap::from([("c".to_string(), var_c)]);

let expr = ExtExpr::SecureCol([
Box::new(BaseExpr::Col((1, 0, 0).into()) - BaseExpr::Col((1, 1, 0).into())),
Box::new(BaseExpr::Col((1, 1, 0).into()) * (-BaseExpr::Param("a".to_string()))),
Box::new(BaseExpr::Param("a".to_string()) + BaseExpr::Param("a".to_string()).inverse()),
Box::new(BaseExpr::Param("b".to_string()) * BaseExpr::Const(BaseField::from(7))),
]) + ExtExpr::Param("c".to_string()) * ExtExpr::Param("c".to_string())
- ExtExpr::Const(SecureField::one());

let expected = SecureField::from_m31_array([
col_1_0_0 - col_1_1_0,
col_1_1_0 * (-var_a),
var_a + var_a.inverse(),
var_b * BaseField::from(7),
]) + var_c * var_c
- SecureField::one();

assert_eq!(
expr.eval_expr::<AssertEvaluator<'_>, _, _, _>(&columns, &vars, &ext_vars),
expected
);
}

#[test]
fn test_format_expr() {
let test_struct = TestStruct {};
Expand Down

0 comments on commit de080dc

Please sign in to comment.