From 74148298c3a92210831d0706fa08aed3bc41a3e8 Mon Sep 17 00:00:00 2001 From: Alon Titelman Date: Wed, 27 Nov 2024 16:46:12 +0200 Subject: [PATCH] Added test for expression simplifier. --- .../prover/src/constraint_framework/expr.rs | 119 ++++++++++++++++-- 1 file changed, 111 insertions(+), 8 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index c459cee27..95f9f3342 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -750,16 +750,64 @@ mod tests { use std::collections::HashMap; use num_traits::One; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; use super::{BaseExpr, ExtExpr}; use crate::constraint_framework::expr::ExprEvaluator; use crate::constraint_framework::{ relation, AssertEvaluator, EvalAtRow, FrameworkEval, RelationEntry, }; - use crate::core::fields::m31::BaseField; + use crate::core::fields::m31::{self, BaseField}; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; + macro_rules! secure_col { + ($a:expr, $b:expr, $c:expr, $d:expr) => { + ExtExpr::SecureCol([ + Box::new($a.into()), + Box::new($b.into()), + Box::new($c.into()), + Box::new($d.into()), + ]) + }; + } + + macro_rules! col { + ($interaction:expr, $idx:expr, $offset:expr) => { + BaseExpr::Col(($interaction, $idx, $offset).into()) + }; + } + + macro_rules! var { + ($var:expr) => { + BaseExpr::Param($var.to_string()) + }; + } + + macro_rules! qvar { + ($var:expr) => { + ExtExpr::Param($var.to_string()) + }; + } + + macro_rules! felt { + ($val:expr) => { + BaseExpr::Const($val.into()) + }; + } + + macro_rules! qfelt { + ($a:expr, $b:expr, $c:expr, $d:expr) => { + ExtExpr::Const(SecureField::from_m31_array([ + $a.into(), + $b.into(), + $c.into(), + $d.into(), + ])) + }; + } + #[test] fn test_eval_expr() { let col_1_0_0 = BaseField::from(12); @@ -778,13 +826,13 @@ mod tests { let vars = HashMap::from([("a".to_string(), var_a), ("b".to_string(), var_b)]); let ext_vars = HashMap::from([("c".to_string(), var_c)]); - let expr = ExtExpr::SecureCol([ - Box::new(BaseExpr::Col((1, 0, 0).into()) - BaseExpr::Col((1, 1, 0).into())), - Box::new(BaseExpr::Col((1, 1, 0).into()) * (-BaseExpr::Param("a".to_string()))), - Box::new(BaseExpr::Param("a".to_string()) + BaseExpr::Param("a".to_string()).inverse()), - Box::new(BaseExpr::Param("b".to_string()) * BaseExpr::Const(BaseField::from(7))), - ]) + ExtExpr::Param("c".to_string()) * ExtExpr::Param("c".to_string()) - - ExtExpr::Const(SecureField::one()); + let expr = secure_col!( + col!(1, 0, 0) - col!(1, 1, 0), + col!(1, 1, 0) * (-var!("a")), + var!("a") + var!("a").inverse(), + var!("b") * felt!(7) + ) + qvar!("c") * qvar!("c") + - qfelt!(1, 0, 0, 0); let expected = SecureField::from_m31_array([ col_1_0_0 - col_1_1_0, @@ -800,6 +848,61 @@ mod tests { ); } + #[test] + fn test_simplify_expr() { + let c0 = col!(1, 0, 0); + let c1 = col!(1, 1, 0); + let a = var!("a"); + let b = qvar!("b"); + let zero = felt!(0); + let qzero = qfelt!(0, 0, 0, 0); + let one = felt!(1); + let qone = qfelt!(1, 0, 0, 0); + let minus_one = felt!(m31::P - 1); + let qminus_one = qfelt!(m31::P - 1, 0, 0, 0); + + let mut rng = SmallRng::seed_from_u64(0); + let columns: HashMap<(usize, usize, isize), BaseField> = + HashMap::from([((1, 0, 0), rng.gen()), ((1, 1, 0), rng.gen())]); + let vars: HashMap = HashMap::from([("a".to_string(), rng.gen())]); + let ext_vars: HashMap = HashMap::from([("b".to_string(), rng.gen())]); + + let base_expr = (((zero.clone() + c0.clone()) + (a.clone() + zero.clone())) + * ((-c1.clone()) + (-c0.clone())) + + (-(-(a.clone() + a.clone() + c0.clone()))) + - zero.clone()) + + (a.clone() - zero.clone()) + + (-c1.clone() - (a.clone() * a.clone())) + + (a.clone() * zero.clone()) + - (zero.clone() * c1.clone()) + + one.clone() + * a.clone() + * one.clone() + * c1.clone() + * (-a.clone()) + * c1.clone() + * (minus_one.clone() * c0.clone()); + + let expr = (qzero.clone() + + secure_col!( + base_expr.clone(), + base_expr.clone(), + zero.clone(), + one.clone() + ) + - qzero.clone()) + * qone.clone() + * b.clone() + * qminus_one.clone(); + + let full_eval = expr.eval_expr::, _, _, _>(&columns, &vars, &ext_vars); + let simplified_eval = expr + .simplify() + .eval_expr::, _, _, _>(&columns, &vars, &ext_vars); + + assert_eq!(full_eval, simplified_eval); + } + #[test] fn test_format_expr() { let test_struct = TestStruct {};