From 6501f2ad121e15ed1f697f395dde22e848b7e58f Mon Sep 17 00:00:00 2001 From: Alon Titelman Date: Mon, 18 Nov 2024 18:39:09 +0200 Subject: [PATCH] Exported shared code to temp variables in ExprEvaluator. --- .../prover/src/constraint_framework/expr.rs | 240 +++++------------- .../prover/src/examples/state_machine/mod.rs | 62 +++-- 2 files changed, 91 insertions(+), 211 deletions(-) diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index 92526a6b6..d77d6d335 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -64,7 +64,7 @@ impl Expr { Expr::Sub(a, b) => format!("{} - ({})", a.format_expr(), b.format_expr()), Expr::Mul(a, b) => format!("({}) * ({})", a.format_expr(), b.format_expr()), Expr::Neg(a) => format!("-({})", a.format_expr()), - Expr::Inv(a) => format!("1/({})", a.format_expr()), + Expr::Inv(a) => format!("1 / ({})", a.format_expr()), } } } @@ -248,6 +248,7 @@ pub struct ExprEvaluator { pub cur_var_index: usize, pub constraints: Vec, pub logup: FormalLogupAtRow, + pub intermediates: Vec<(String, Expr)>, } impl ExprEvaluator { @@ -257,8 +258,35 @@ impl ExprEvaluator { cur_var_index: Default::default(), constraints: Default::default(), logup: FormalLogupAtRow::new(INTERACTION_TRACE_IDX, has_partial_sum, log_size), + intermediates: vec![], } } + + pub fn add_temp_var(&mut self, expr: Expr) -> Expr { + let name = format!("temp_{}", self.intermediates.len()); + let temp_var = Expr::Param(name.clone()); + self.intermediates.push((name, expr)); + temp_var + } + + pub fn format_constraints(&self) -> String { + let lets_string = self + .intermediates + .iter() + .map(|(name, expr)| format!("let {} = {};", name, expr.format_expr())) + .collect::>() + .join("\n"); + + let constraints_str = self + .constraints + .iter() + .enumerate() + .map(|(i, c)| format!("let constraint_{i} = ") + &c.format_expr() + ";") + .collect::>() + .join("\n\n"); + + lets_string + "\n\n" + &constraints_str + } } impl EvalAtRow for ExprEvaluator { @@ -286,7 +314,15 @@ impl EvalAtRow for ExprEvaluator { where Self::EF: std::ops::Mul, { - self.constraints.push(Expr::one() * constraint); + match Expr::one() * constraint { + Expr::Mul(one, constraint) => { + assert_eq!(*one, Expr::one()); + self.constraints.push(*constraint); + } + _ => { + unreachable!(); + } + } } fn combine_ef(values: [Self::F; 4]) -> Self::EF { @@ -310,7 +346,8 @@ impl EvalAtRow for ExprEvaluator { multiplicity, values, }| { - Fraction::new(multiplicity.clone(), combine_formal(*relation, values)) + let temp_var = self.add_temp_var(combine_formal(*relation, values)); + Fraction::new(multiplicity.clone(), temp_var) }, ) .collect(); @@ -324,187 +361,34 @@ impl EvalAtRow for ExprEvaluator { mod tests { use num_traits::One; - use crate::constraint_framework::expr::{ColumnExpr, Expr, ExprEvaluator}; - use crate::constraint_framework::{ - relation, EvalAtRow, FrameworkEval, RelationEntry, ORIGINAL_TRACE_IDX, - }; - use crate::core::fields::m31::M31; + use crate::constraint_framework::expr::ExprEvaluator; + use crate::constraint_framework::{relation, EvalAtRow, FrameworkEval, RelationEntry}; use crate::core::fields::FieldExpOps; - #[test] - fn test_expr_eval() { - let test_struct = TestStruct {}; - let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); - assert_eq!(eval.constraints.len(), 2); - assert_eq!( - eval.constraints[0], - Expr::Mul( - Box::new(Expr::one()), - Box::new(Expr::Mul( - Box::new(Expr::Mul( - Box::new(Expr::Mul( - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 0, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 1, - offset: 0 - })) - )), - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 2, - offset: 0 - })) - )), - Box::new(Expr::Inv(Box::new(Expr::Add( - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 0, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: ORIGINAL_TRACE_IDX, - idx: 1, - offset: 0 - })) - )))) - )) - ) - ); - - assert_eq!( - eval.constraints[1], - Expr::Mul( - Box::new(Expr::Const(M31(1))), - Box::new(Expr::Sub( - Box::new(Expr::Mul( - Box::new(Expr::Sub( - Box::new(Expr::Sub( - Box::new(Expr::SecureCol([ - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 4, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 6, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 8, - offset: 0 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 10, - offset: 0 - })) - ])), - Box::new(Expr::Sub( - Box::new(Expr::SecureCol([ - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 5, - offset: -1 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 7, - offset: -1 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 9, - offset: -1 - })), - Box::new(Expr::Col(ColumnExpr { - interaction: 2, - idx: 11, - offset: -1 - })) - ])), - Box::new(Expr::Mul( - Box::new(Expr::Col(ColumnExpr { - interaction: 0, - idx: 3, - offset: 0 - })), - Box::new(Expr::Param("total_sum".into())) - )) - )) - )), - Box::new(Expr::Const(M31(0))) - )), - Box::new(Expr::Sub( - Box::new(Expr::Add( - Box::new(Expr::Add( - Box::new(Expr::Add( - Box::new(Expr::Const(M31(0))), - Box::new(Expr::Mul( - Box::new(Expr::Param( - "TestRelation_alpha0".to_string() - )), - Box::new(Expr::Col(ColumnExpr { - interaction: 1, - idx: 0, - offset: 0 - })) - )) - )), - Box::new(Expr::Mul( - Box::new(Expr::Param("TestRelation_alpha1".to_string())), - Box::new(Expr::Col(ColumnExpr { - interaction: 1, - idx: 1, - offset: 0 - })) - )) - )), - Box::new(Expr::Mul( - Box::new(Expr::Param("TestRelation_alpha2".to_string())), - Box::new(Expr::Col(ColumnExpr { - interaction: 1, - idx: 2, - offset: 0 - })) - )) - )), - Box::new(Expr::Param("TestRelation_z".to_string())) - )) - )), - Box::new(Expr::Const(M31(1))) - )) - ) - ); - } - #[test] fn test_format_expr() { let test_struct = TestStruct {}; let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); - let constraint0_str = "(1) * ((((col_1_0[0]) * (col_1_1[0])) * (col_1_2[0])) * (1/(col_1_0[0] + col_1_1[0])))"; - assert_eq!(eval.constraints[0].format_expr(), constraint0_str); - let constraint1_str = "(1) \ - * ((SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \ - - (SecureCol(\ - col_2_5[-1], \ - col_2_7[-1], \ - col_2_9[-1], \ - col_2_11[-1]\ - ) - ((col_0_3[0]) * (total_sum))) \ - - (0)) \ - * (0 + (TestRelation_alpha0) * (col_1_0[0]) \ - + (TestRelation_alpha1) * (col_1_1[0]) \ - + (TestRelation_alpha2) * (col_1_2[0]) \ - - (TestRelation_z)) \ - - (1))"; - assert_eq!(eval.constraints[1].format_expr(), constraint1_str); + let expected = "let temp_0 = 0 \ + + (TestRelation_alpha0) * (col_1_0[0]) \ + + (TestRelation_alpha1) * (col_1_1[0]) \ + + (TestRelation_alpha2) * (col_1_2[0]) \ + - (TestRelation_z); + +\ + let constraint_0 = \ + (((col_1_0[0]) * (col_1_1[0])) * (col_1_2[0])) * (1 / (col_1_0[0] + col_1_1[0])); + +\ + let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \ + - (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \ + - ((col_0_3[0]) * (total_sum))) \ + - (0)) \ + * (temp_0) \ + - (1);" + .to_string(); + + assert_eq!(eval.format_constraints(), expected); } relation!(TestRelation, 3); diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index b17258d20..62f57dd20 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -300,38 +300,34 @@ mod tests { ); let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true)); - - assert_eq!(eval.constraints.len(), 2); - let constraint0_str = "(1) \ - * ((SecureCol(\ - col_2_5[claimed_sum_offset], \ - col_2_8[claimed_sum_offset], \ - col_2_11[claimed_sum_offset], \ - col_2_14[claimed_sum_offset]\ - ) - (claimed_sum)) \ - * (col_0_2[0]))"; - assert_eq!(eval.constraints[0].format_expr(), constraint0_str); - let constraint1_str = "(1) \ - * ((SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \ - - (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \ - - ((col_0_2[0]) * (total_sum))) \ - - (0)) \ - * ((0 \ - + (StateMachineElements_alpha0) * (col_1_0[0]) \ - + (StateMachineElements_alpha1) * (col_1_1[0]) \ - - (StateMachineElements_z)) \ - * (0 + (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ - + (StateMachineElements_alpha1) * (col_1_1[0]) \ - - (StateMachineElements_z))) \ - - ((0 \ - + (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ - + (StateMachineElements_alpha1) * (col_1_1[0]) \ - - (StateMachineElements_z)) \ - * (1) \ - + (0 + (StateMachineElements_alpha0) * (col_1_0[0]) \ - + (StateMachineElements_alpha1) * (col_1_1[0]) \ - - (StateMachineElements_z)) \ - * (-(1))))"; - assert_eq!(eval.constraints[1].format_expr(), constraint1_str); + let expected = "let temp_0 = 0 \ + + (StateMachineElements_alpha0) * (col_1_0[0]) \ + + (StateMachineElements_alpha1) * (col_1_1[0]) \ + - (StateMachineElements_z); +\ + let temp_1 = 0 \ + + (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ + + (StateMachineElements_alpha1) * (col_1_1[0]) \ + - (StateMachineElements_z); + +\ + let constraint_0 = (SecureCol(\ + col_2_5[claimed_sum_offset], \ + col_2_8[claimed_sum_offset], \ + col_2_11[claimed_sum_offset], \ + col_2_14[claimed_sum_offset]\ + ) - (claimed_sum)) \ + * (col_0_2[0]); + +\ + let constraint_1 = (SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \ + - (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \ + - ((col_0_2[0]) * (total_sum))) \ + - (0)) \ + * ((temp_0) * (temp_1)) \ + - ((temp_1) * (1) + (temp_0) * (-(1)));" + .to_string(); + + assert_eq!(eval.format_constraints(), expected); } }