diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index ab32829cf..66c9fb5dc 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -1,4 +1,4 @@ -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; +use std::ops::{Add, AddAssign, Index, Mul, MulAssign, Neg, Sub}; use num_traits::{One, Zero}; @@ -17,6 +17,16 @@ pub struct ColumnExpr { offset: isize, } +impl From<(usize, usize, isize)> for ColumnExpr { + fn from((interaction, idx, offset): (usize, usize, isize)) -> Self { + Self { + interaction, + idx, + offset, + } + } +} + /// An expression representing a base field value. Can be either: /// * A column indexed by a `ColumnExpr`. /// * A base field constant. @@ -182,6 +192,34 @@ impl BaseExpr { pub fn simplify_and_format(&self) -> String { self.simplify().format_expr() } + + /// Evaluates a base field expression. + /// Takes: + /// * `columns`: A mapping from triplets (interaction, idx, offset) to base field values. + /// * `vars`: A mapping from variable names to base field values. + pub fn eval_expr(&self, columns: &C, vars: &V) -> E::F + where + C: for<'a> Index<&'a (usize, usize, isize), Output = E::F>, + V: for<'a> Index<&'a String, Output = E::F>, + E: EvalAtRow, + { + match self { + Self::Col(col) => columns[&(col.interaction, col.idx, col.offset)].clone(), + Self::Const(c) => E::F::from(*c), + Self::Param(var) => vars[&var.to_string()].clone(), + Self::Add(a, b) => { + a.eval_expr::(columns, vars) + b.eval_expr::(columns, vars) + } + Self::Sub(a, b) => { + a.eval_expr::(columns, vars) - b.eval_expr::(columns, vars) + } + Self::Mul(a, b) => { + a.eval_expr::(columns, vars) * b.eval_expr::(columns, vars) + } + Self::Neg(a) => -a.eval_expr::(columns, vars), + Self::Inv(a) => a.eval_expr::(columns, vars).inverse(), + } + } } impl ExtExpr { @@ -243,6 +281,44 @@ impl ExtExpr { pub fn simplify_and_format(&self) -> String { self.simplify().format_expr() } + + /// Evaluates an extension field expression. + /// Takes: + /// * `columns`: A mapping from triplets (interaction, idx, offset) to base field values. + /// * `vars`: A mapping from variable names to base field values. + /// * `ext_vars`: A mapping from variable names to extension field values. + pub fn eval_expr(&self, columns: &C, vars: &V, ext_vars: &W) -> E::EF + where + C: for<'a> Index<&'a (usize, usize, isize), Output = E::F>, + V: for<'a> Index<&'a String, Output = E::F>, + W: for<'a> Index<&'a String, Output = E::EF>, + E: EvalAtRow, + { + match self { + Self::SecureCol([a, b, c, d]) => { + let a = a.eval_expr::(columns, vars); + let b = b.eval_expr::(columns, vars); + let c = c.eval_expr::(columns, vars); + let d = d.eval_expr::(columns, vars); + E::combine_ef([a, b, c, d]) + } + Self::Const(c) => E::EF::from(*c), + Self::Param(var) => ext_vars[&var.to_string()].clone(), + Self::Add(a, b) => { + a.eval_expr::(columns, vars, ext_vars) + + b.eval_expr::(columns, vars, ext_vars) + } + Self::Sub(a, b) => { + a.eval_expr::(columns, vars, ext_vars) + - b.eval_expr::(columns, vars, ext_vars) + } + Self::Mul(a, b) => { + a.eval_expr::(columns, vars, ext_vars) + * b.eval_expr::(columns, vars, ext_vars) + } + Self::Neg(a) => -a.eval_expr::(columns, vars, ext_vars), + } + } } impl From for BaseExpr { @@ -624,11 +700,7 @@ impl EvalAtRow for ExprEvaluator { offsets: [isize; N], ) -> [Self::F; N] { std::array::from_fn(|i| { - let col = ColumnExpr { - interaction, - idx: self.cur_var_index, - offset: offsets[i], - }; + let col = ColumnExpr::from((interaction, self.cur_var_index, offsets[i])); self.cur_var_index += 1; BaseExpr::Col(col) }) @@ -675,12 +747,59 @@ impl EvalAtRow for ExprEvaluator { #[cfg(test)] mod tests { + use std::collections::HashMap; + use num_traits::One; + use super::{BaseExpr, ExtExpr}; use crate::constraint_framework::expr::ExprEvaluator; - use crate::constraint_framework::{relation, EvalAtRow, FrameworkEval, RelationEntry}; + use crate::constraint_framework::{ + relation, AssertEvaluator, EvalAtRow, FrameworkEval, RelationEntry, + }; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; + #[test] + fn test_eval_expr() { + let col_1_0_0 = BaseField::from(12); + let col_1_1_0 = BaseField::from(5); + let var_a = BaseField::from(3); + let var_b = BaseField::from(4); + let var_c = SecureField::from_m31_array([ + BaseField::from(1), + BaseField::from(2), + BaseField::from(3), + BaseField::from(4), + ]); + + let columns: HashMap<(usize, usize, isize), BaseField> = + HashMap::from([((1, 0, 0), col_1_0_0), ((1, 1, 0), col_1_1_0)]); + let vars = HashMap::from([("a".to_string(), var_a), ("b".to_string(), var_b)]); + let ext_vars = HashMap::from([("c".to_string(), var_c)]); + + let expr = ExtExpr::SecureCol([ + Box::new(BaseExpr::Col((1, 0, 0).into()) - BaseExpr::Col((1, 1, 0).into())), + Box::new(BaseExpr::Col((1, 1, 0).into()) * (-BaseExpr::Param("a".to_string()))), + Box::new(BaseExpr::Param("a".to_string()) + BaseExpr::Param("a".to_string()).inverse()), + Box::new(BaseExpr::Param("b".to_string()) * BaseExpr::Const(BaseField::from(7))), + ]) + ExtExpr::Param("c".to_string()) * ExtExpr::Param("c".to_string()) + - ExtExpr::Const(SecureField::one()); + + let expected = SecureField::from_m31_array([ + col_1_0_0 - col_1_1_0, + col_1_1_0 * (-var_a), + var_a + var_a.inverse(), + var_b * BaseField::from(7), + ]) + var_c * var_c + - SecureField::one(); + + assert_eq!( + expr.eval_expr::, _, _, _>(&columns, &vars, &ext_vars), + expected + ); + } + #[test] fn test_format_expr() { let test_struct = TestStruct {};