diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index e01a03818..7ea345564 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -791,34 +791,35 @@ pub struct ExprEvaluator { pub cur_var_index: usize, pub constraints: Vec, pub logup: FormalLogupAtRow, - pub intermediates: Vec<(String, ExtExpr)>, + pub intermediates: Vec<(String, BaseExpr)>, + pub secure_intermediates: Vec<(String, ExtExpr)>, } impl ExprEvaluator { - #[allow(dead_code)] pub fn new(log_size: u32, has_partial_sum: bool) -> Self { Self { cur_var_index: Default::default(), constraints: Default::default(), logup: FormalLogupAtRow::new(INTERACTION_TRACE_IDX, has_partial_sum, log_size), intermediates: vec![], + secure_intermediates: vec![], } } - pub fn add_intermediate(&mut self, expr: ExtExpr) -> ExtExpr { - let name = format!("intermediate{}", self.intermediates.len()); - let intermediate = ExtExpr::Param(name.clone()); - self.intermediates.push((name, expr)); - intermediate - } - pub fn format_constraints(&self) -> String { let lets_string = self .intermediates .iter() .map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format())) .collect::>() - .join("\n"); + .join("\n\n"); + + let secure_lets_string = self + .secure_intermediates + .iter() + .map(|(name, expr)| format!("let {} = {};", name, expr.simplify_and_format())) + .collect::>() + .join("\n\n"); let constraints_str = self .constraints @@ -828,7 +829,12 @@ impl ExprEvaluator { .collect::>() .join("\n\n"); - lets_string + "\n\n" + &constraints_str + [lets_string, secure_lets_string, constraints_str] + .iter() + .filter(|x| !x.is_empty()) + .cloned() + .collect::>() + .join("\n\n") } } @@ -877,7 +883,8 @@ impl EvalAtRow for ExprEvaluator { multiplicity, values, }| { - let intermediate = self.add_intermediate(combine_formal(*relation, values)); + let intermediate = + self.add_secure_intermediate(combine_formal(*relation, values)); Fraction::new(multiplicity.clone(), intermediate) }, ) @@ -885,6 +892,26 @@ impl EvalAtRow for ExprEvaluator { self.write_logup_frac(fracs.into_iter().sum()); } + fn add_intermediate(&mut self, expr: Self::F) -> Self::F { + let name = format!( + "intermediate{}", + self.intermediates.len() + self.secure_intermediates.len() + ); + let intermediate = BaseExpr::Param(name.clone()); + self.intermediates.push((name, expr)); + intermediate + } + + fn add_secure_intermediate(&mut self, expr: Self::EF) -> Self::EF { + let name = format!( + "intermediate{}", + self.intermediates.len() + self.secure_intermediates.len() + ); + let intermediate = ExtExpr::Param(name.clone()); + self.secure_intermediates.push((name, expr)); + intermediate + } + super::logup_proxy!(); } @@ -1050,21 +1077,22 @@ mod tests { fn test_format_expr() { let test_struct = TestStruct {}; let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); - let expected = "let intermediate0 = (TestRelation_alpha0) * (col_1_0[0]) \ + let expected = "let intermediate0 = (col_1_1[0]) * (col_1_2[0]); + +\ + let intermediate1 = (TestRelation_alpha0) * (col_1_0[0]) \ + (TestRelation_alpha1) * (col_1_1[0]) \ + (TestRelation_alpha2) * (col_1_2[0]) \ - (TestRelation_z); \ - let constraint_0 = \ - (((col_1_0[0]) * (col_1_1[0])) * (col_1_2[0])) * (1 / (col_1_0[0] + col_1_1[0])); + let constraint_0 = ((col_1_0[0]) * (intermediate0)) * (1 / (col_1_0[0] + col_1_1[0])); \ let constraint_1 = (SecureCol(col_2_4[0], col_2_6[0], col_2_8[0], col_2_10[0]) \ - (SecureCol(col_2_5[-1], col_2_7[-1], col_2_9[-1], col_2_11[-1]) \ - - ((total_sum) * (col_0_3[0])))\ - ) \ - * (intermediate0) \ + - ((total_sum) * (col_0_3[0])))) \ + * (intermediate1) \ - (1);" .to_string(); @@ -1085,9 +1113,8 @@ mod tests { let x0 = eval.next_trace_mask(); let x1 = eval.next_trace_mask(); let x2 = eval.next_trace_mask(); - eval.add_constraint( - x0.clone() * x1.clone() * x2.clone() * (x0.clone() + x1.clone()).inverse(), - ); + let intermediate = eval.add_intermediate(x1.clone() * x2.clone()); + eval.add_constraint(x0.clone() * intermediate * (x0.clone() + x1.clone()).inverse()); eval.add_to_relation(&[RelationEntry::new( &TestRelation::dummy(), E::EF::one(), diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 341e24511..7d7bd0068 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -111,6 +111,18 @@ pub trait EvalAtRow { where Self::EF: Mul + From; + /// Adds an intermediate value to the component and returns its value. + /// Does nothing by default. + fn add_intermediate(&mut self, val: Self::F) -> Self::F { + val + } + + /// Adds a secure intermediate value to the component and returns its value. + /// Does nothing by default. + fn add_secure_intermediate(&mut self, val: Self::EF) -> Self::EF { + val + } + /// Combines 4 base field values into a single extension field value. fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF; diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 23973a960..bdb265fff 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -357,6 +357,7 @@ mod tests { let expected = "let intermediate0 = (StateMachineElements_alpha0) * (col_1_0[0]) \ + (StateMachineElements_alpha1) * (col_1_1[0]) \ - (StateMachineElements_z); + \ let intermediate1 = (StateMachineElements_alpha0) * (col_1_0[0] + 1) \ + (StateMachineElements_alpha1) * (col_1_1[0]) \