diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index fff593ac2..ab48c456a 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -183,6 +183,90 @@ impl AddAssign for Expr { } } +pub fn simplify(expr: Expr) -> Expr { + match expr { + Expr::Add(a, b) => { + let a = simplify(*a); + let b = simplify(*b); + if let (Expr::Const(a), Expr::Const(b)) = (a.clone(), b.clone()) { + Expr::Const(a + b) + } else if a == Expr::zero() { + b + } else if b == Expr::zero() { + a + } else if let Expr::Neg(a) = a { + if let Expr::Neg(b) = b { + -(*a + *b) + } else { + b - *a + } + } else if let Expr::Neg(b) = b { + a - *b + } else { + a + b + } + } + Expr::Sub(a, b) => { + let a = simplify(*a); + let b = simplify(*b); + if a == Expr::zero() { + -b + } else if b == Expr::zero() { + a + } else if a == b { + Expr::zero() + } else { + a - b + } + } + Expr::Mul(a, b) => { + let a = simplify(*a); + let b = simplify(*b); + if let (Expr::Const(a), Expr::Const(b)) = (a.clone(), b.clone()) { + Expr::Const(a - b) + } else if a == Expr::zero() || b == Expr::zero() { + Expr::zero() + } else if a == Expr::one() { + b + } else if b == Expr::one() { + a + } else if a == -Expr::one() { + -b + } else if b == -Expr::one() { + -a + } else { + a * b + } + } + Expr::Col(colexpr) => Expr::Col(colexpr), + Expr::SecureCol([a, b, c, d]) => Expr::SecureCol([ + Box::new(simplify(*a)), + Box::new(simplify(*b)), + Box::new(simplify(*c)), + Box::new(simplify(*d)), + ]), + Expr::Const(c) => Expr::Const(c), + Expr::Param(x) => Expr::Param(x), + Expr::Neg(a) => { + let a = simplify(*a); + match a { + Expr::Neg(b) => *b, + Expr::Const(c) => Expr::Const(-c), + Expr::Sub(a, b) => Expr::Sub(b, a), + _ => -a, + } + } + Expr::Inv(a) => { + let a = simplify(*a); + match a { + Expr::Inv(b) => *b, + Expr::Const(c) => Expr::Const(c.inverse()), + _ => Expr::Inv(Box::new(a)), + } + } + } +} + /// Returns the expression /// `value[0] * _alpha0 + value[1] * _alpha1 + ... - _z.` fn combine_formal>(relation: &R, values: &[Expr]) -> Expr { @@ -269,7 +353,7 @@ impl ExprEvaluator { let lets_string = self .temp_vars .iter() - .map(|(name, expr)| format!("let {} = {};", name, expr.format_expr())) + .map(|(name, expr)| format!("let {} = {};", name, simplify(expr.clone()).format_expr())) .collect::>() .join("\n"); @@ -277,7 +361,9 @@ impl ExprEvaluator { .constraints .iter() .enumerate() - .map(|(i, c)| format!("let constraint_{i} = ") + &c.format_expr() + ";") + .map(|(i, c)| { + format!("let constraint_{i} = ") + &simplify(c.clone()).format_expr() + ";" + }) .collect::>() .join("\n\n"); @@ -362,8 +448,7 @@ mod tests { fn test_format_expr() { let test_struct = TestStruct {}; let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); - let expected = "let temp_0 = 0 \ - + (TestRelation_alpha0) * (col_1_0[0]) \ + let expected = "let temp_0 = (TestRelation_alpha0) * (col_1_0[0]) \ + (TestRelation_alpha1) * (col_1_1[0]) \ + (TestRelation_alpha2) * (col_1_2[0]) \ - (TestRelation_z); @@ -375,8 +460,8 @@ mod tests { \ let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \ - (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \ - - ((col_0_3[0]) * (total_sum))) \ - - (0)) \ + - ((col_0_3[0]) * (total_sum)))\ + ) \ * (temp_0) \ - (1);" .to_string(); diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index aab5ebb1b..1529ed107 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -293,13 +293,11 @@ mod tests { ); let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true)); - let expected = "let temp_0 = 0 \ - + (StateMachineElements_alpha0) * (col_1_0[0]) \ + let expected = "let temp_0 = (StateMachineElements_alpha0) * (col_1_0[0]) \ + (StateMachineElements_alpha1) * (col_1_1[0]) \ - (StateMachineElements_z); \ - let temp_1 = 0 \ - + (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ + let temp_1 = (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ + (StateMachineElements_alpha1) * (col_1_1[0]) \ - (StateMachineElements_z); @@ -315,10 +313,10 @@ mod tests { \ let constraint_1 = (SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \ - (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \ - - ((col_0_2[0]) * (total_sum))) \ - - (0)) \ + - ((col_0_2[0]) * (total_sum)))\ + ) \ * ((temp_0) * (temp_1)) \ - - ((temp_1) * (1) + (temp_0) * (-(1)));" + - (temp_1 + (temp_0) * (2147483646));" .to_string(); assert_eq!(eval.format_constraints(), expected);