From 188ba5cb30b72921cccdd1acb31c39defe1e4c60 Mon Sep 17 00:00:00 2001 From: Alon Titelman Date: Tue, 19 Nov 2024 13:48:23 +0200 Subject: [PATCH] Simplify expressions. --- .../prover/src/constraint_framework/expr.rs | 101 ++++++++++++++++-- .../prover/src/examples/state_machine/mod.rs | 12 +-- 2 files changed, 99 insertions(+), 14 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index 5d8013402..30d6f04d4 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -36,7 +36,6 @@ pub enum Expr { } impl Expr { - #[allow(dead_code)] pub fn format_expr(&self) -> String { match self { Expr::Col(ColumnExpr { @@ -67,6 +66,10 @@ impl Expr { Expr::Inv(a) => format!("1 / ({})", a.format_expr()), } } + + pub fn simplify_and_format(&self) -> String { + simplify(self.clone()).format_expr() + } } impl From for Expr { @@ -190,6 +193,91 @@ impl AddAssign for Expr { } } +pub fn simplify(expr: Expr) -> Expr { + match expr { + Expr::Add(a, b) => { + let a = simplify(*a); + let b = simplify(*b); + if let (Expr::Const(a), Expr::Const(b)) = (a.clone(), b.clone()) { + Expr::Const(a + b) + } else if a == Expr::zero() { + b + } else if b == Expr::zero() { + a + } else if let Expr::Neg(a) = a { + if let Expr::Neg(b) = b { + -(*a + *b) + } else { + b - *a + } + } else if let Expr::Neg(b) = b { + a - *b + } else { + a + b + } + } + Expr::Sub(a, b) => { + let a = simplify(*a); + let b = simplify(*b); + if a == Expr::zero() { + -b + } else if b == Expr::zero() { + a + } else if a == b { + Expr::zero() + } else { + a - b + } + } + Expr::Mul(a, b) => { + let a = simplify(*a); + let b = simplify(*b); + let minus_one = Expr::Const(-BaseField::one()); + if let (Expr::Const(a), Expr::Const(b)) = (a.clone(), b.clone()) { + Expr::Const(a * b) + } else if a == Expr::zero() || b == Expr::zero() { + Expr::zero() + } else if a == Expr::one() { + b + } else if b == Expr::one() { + a + } else if a == minus_one { + -b + } else if b == minus_one { + -a + } else { + a * b + } + } + Expr::Col(colexpr) => Expr::Col(colexpr), + Expr::SecureCol([a, b, c, d]) => Expr::SecureCol([ + Box::new(simplify(*a)), + Box::new(simplify(*b)), + Box::new(simplify(*c)), + Box::new(simplify(*d)), + ]), + Expr::Const(c) => Expr::Const(c), + Expr::Param(x) => Expr::Param(x), + Expr::Neg(a) => { + let a = simplify(*a); + match a { + Expr::Neg(b) => *b, + Expr::Const(c) => Expr::Const(-c), + Expr::Sub(a, b) => Expr::Sub(b, a), + _ => -a, + } + } + Expr::Inv(a) => { + let a = simplify(*a); + match a { + Expr::Inv(b) => *b, + Expr::Const(c) => Expr::Const(c.inverse()), + _ => Expr::Inv(Box::new(a)), + } + } + } +} + /// Returns the expression /// `value[0] * _alpha0 + value[1] * _alpha1 + ... - _z.` fn combine_formal>(relation: &R, values: &[Expr]) -> Expr { @@ -273,7 +361,7 @@ impl ExprEvaluator { let lets_string = self .intermediates .iter() - .map(|(name, expr)| format!("let {} = {};", name, expr.format_expr())) + .map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format())) .collect::>() .join("\n"); @@ -281,7 +369,7 @@ impl ExprEvaluator { .constraints .iter() .enumerate() - .map(|(i, c)| format!("let constraint_{i} = ") + &c.format_expr() + ";") + .map(|(i, c)| format!("let constraint_{i} = ") + &c.simplify_and_format() + ";") .collect::>() .join("\n\n"); @@ -369,8 +457,7 @@ mod tests { fn test_format_expr() { let test_struct = TestStruct {}; let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); - let expected = "let intermediate0 = 0 \ - + (TestRelation_alpha0) * (col_1_0[0]) \ + let expected = "let intermediate0 = (TestRelation_alpha0) * (col_1_0[0]) \ + (TestRelation_alpha1) * (col_1_1[0]) \ + (TestRelation_alpha2) * (col_1_2[0]) \ - (TestRelation_z); @@ -382,8 +469,8 @@ mod tests { \ let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \ - (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \ - - ((col_0_3[0]) * (total_sum))) \ - - (0)) \ + - ((col_0_3[0]) * (total_sum)))\ + ) \ * (intermediate0) \ - (1);" .to_string(); diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 787394dd3..2cf8bc2f2 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -300,13 +300,11 @@ mod tests { ); let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true)); - let expected = "let intermediate0 = 0 \ - + (StateMachineElements_alpha0) * (col_1_0[0]) \ + let expected = "let intermediate0 = (StateMachineElements_alpha0) * (col_1_0[0]) \ + (StateMachineElements_alpha1) * (col_1_1[0]) \ - (StateMachineElements_z); \ - let intermediate1 = 0 \ - + (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ + let intermediate1 = (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ + (StateMachineElements_alpha1) * (col_1_1[0]) \ - (StateMachineElements_z); @@ -322,10 +320,10 @@ mod tests { \ let constraint_1 = (SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \ - (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \ - - ((col_0_2[0]) * (total_sum))) \ - - (0)) \ + - ((col_0_2[0]) * (total_sum)))\ + ) \ * ((intermediate0) * (intermediate1)) \ - - ((intermediate1) * (1) + (intermediate0) * (-(1)));" + - (intermediate1 - (intermediate0));" .to_string(); assert_eq!(eval.format_constraints(), expected);