diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index d83a5250e..92526a6b6 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -2,9 +2,8 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; use num_traits::{One, Zero}; -use super::logup::{LogupAtRow, LogupSums}; use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; -use crate::core::fields::m31::BaseField; +use crate::core::fields::m31::{self, BaseField}; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; use crate::core::lookups::utils::Fraction; @@ -45,7 +44,12 @@ impl Expr { idx, offset, }) => { - format!("col_{interaction}_{idx}[{offset}]") + let offset_str = if *offset == CLAIMED_SUM_DUMMY_OFFSET.try_into().unwrap() { + "claimed_sum_offset".to_string() + } else { + offset.to_string() + }; + format!("col_{interaction}_{idx}[{offset_str}]") } Expr::SecureCol([a, b, c, d]) => format!( "SecureCol({}, {}, {}, {})", @@ -204,20 +208,55 @@ fn combine_formal>(relation: &R, values: &[Expr]) -> Exp - z } +pub struct FormalLogupAtRow { + pub interaction: usize, + pub total_sum: Expr, + pub claimed_sum: Option<(Expr, usize)>, + pub prev_col_cumsum: Expr, + pub cur_frac: Option>, + pub is_finalized: bool, + pub is_first: Expr, + pub log_size: u32, +} + +// P is an offset no column can reach, it signifies the variable +// offset, which is an input to the verifier. +const CLAIMED_SUM_DUMMY_OFFSET: usize = m31::P as usize; + +impl FormalLogupAtRow { + pub fn new(interaction: usize, has_partial_sum: bool, log_size: u32) -> Self { + let total_sum_name = "total_sum".to_string(); + let claimed_sum_name = "claimed_sum".to_string(); + + Self { + interaction, + // TODO(alont): Should these be Expr::SecureField? + total_sum: Expr::Param(total_sum_name), + claimed_sum: has_partial_sum + .then_some((Expr::Param(claimed_sum_name), CLAIMED_SUM_DUMMY_OFFSET)), + prev_col_cumsum: Expr::zero(), + cur_frac: None, + is_finalized: true, + is_first: Expr::zero(), + log_size, + } + } +} + /// An Evaluator that saves all constraint expressions. pub struct ExprEvaluator { pub cur_var_index: usize, pub constraints: Vec, - pub logup: LogupAtRow, + pub logup: FormalLogupAtRow, } impl ExprEvaluator { #[allow(dead_code)] - pub fn new(log_size: u32, logup_sums: LogupSums) -> Self { + pub fn new(log_size: u32, has_partial_sum: bool) -> Self { Self { cur_var_index: Default::default(), constraints: Default::default(), - logup: LogupAtRow::new(INTERACTION_TRACE_IDX, logup_sums.0, logup_sums.1, log_size), + logup: FormalLogupAtRow::new(INTERACTION_TRACE_IDX, has_partial_sum, log_size), } } } @@ -283,20 +322,19 @@ impl EvalAtRow for ExprEvaluator { #[cfg(test)] mod tests { - use num_traits::{One, Zero}; + use num_traits::One; use crate::constraint_framework::expr::{ColumnExpr, Expr, ExprEvaluator}; use crate::constraint_framework::{ relation, EvalAtRow, FrameworkEval, RelationEntry, ORIGINAL_TRACE_IDX, }; use crate::core::fields::m31::M31; - use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; #[test] fn test_expr_eval() { let test_struct = TestStruct {}; - let eval = test_struct.evaluate(ExprEvaluator::new(16, (SecureField::zero(), None))); + let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); assert_eq!(eval.constraints.len(), 2); assert_eq!( eval.constraints[0], @@ -397,12 +435,7 @@ mod tests { idx: 3, offset: 0 })), - Box::new(Expr::SecureCol([ - Box::new(Expr::Const(M31(0))), - Box::new(Expr::Const(M31(0))), - Box::new(Expr::Const(M31(0))), - Box::new(Expr::Const(M31(0))) - ])) + Box::new(Expr::Param("total_sum".into())) )) )) )), @@ -454,7 +487,7 @@ mod tests { #[test] fn test_format_expr() { let test_struct = TestStruct {}; - let eval = test_struct.evaluate(ExprEvaluator::new(16, (SecureField::zero(), None))); + let eval = test_struct.evaluate(ExprEvaluator::new(16, false)); let constraint0_str = "(1) * ((((col_1_0[0]) * (col_1_1[0])) * (col_1_2[0])) * (1/(col_1_0[0] + col_1_1[0])))"; assert_eq!(eval.constraints[0].format_expr(), constraint0_str); let constraint1_str = "(1) \ @@ -464,7 +497,7 @@ mod tests { col_2_7[-1], \ col_2_9[-1], \ col_2_11[-1]\ - ) - ((col_0_3[0]) * (SecureCol(0, 0, 0, 0)))) \ + ) - ((col_0_3[0]) * (total_sum))) \ - (0)) \ * (0 + (TestRelation_alpha0) * (col_1_0[0]) \ + (TestRelation_alpha1) * (col_1_1[0]) \ diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index aa871b5bc..fc7bb91d3 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -169,7 +169,7 @@ macro_rules! logup_proxy { // TODO(ShaharS): remove `claimed_row_index` interaction value and get the shifted // offset from the is_first column when constant columns are supported. - let (cur_cumsum, prev_row_cumsum) = match self.logup.claimed_sum { + let (cur_cumsum, prev_row_cumsum) = match self.logup.claimed_sum.clone() { Some((claimed_sum, claimed_row_index)) => { let [cur_cumsum, prev_row_cumsum, claimed_cumsum] = self .next_extension_interaction_mask( @@ -191,7 +191,7 @@ macro_rules! logup_proxy { }; // Fix `prev_row_cumsum` by subtracting `total_sum` if this is the first row. let fixed_prev_row_cumsum = - prev_row_cumsum - self.logup.is_first.clone() * self.logup.total_sum; + prev_row_cumsum - self.logup.is_first.clone() * self.logup.total_sum.clone(); let diff = cur_cumsum - fixed_prev_row_cumsum - self.logup.prev_col_cumsum.clone(); self.add_constraint(diff * frac.denominator - frac.numerator); diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index f6e72032d..b17258d20 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -299,21 +299,22 @@ mod tests { (total_sum, Some((total_sum, (1 << log_n_rows) - 1))), ); - let eval = component.evaluate(ExprEvaluator::new( - log_n_rows, - (total_sum, Some((total_sum, (1 << log_n_rows) - 1))), - )); + let eval = component.evaluate(ExprEvaluator::new(log_n_rows, true)); assert_eq!(eval.constraints.len(), 2); let constraint0_str = "(1) \ - * ((SecureCol(col_2_5[255], col_2_8[255], col_2_11[255], col_2_14[255]) \ - - (SecureCol(223732908, 22408442, 1020999916, 2109866192))) \ + * ((SecureCol(\ + col_2_5[claimed_sum_offset], \ + col_2_8[claimed_sum_offset], \ + col_2_11[claimed_sum_offset], \ + col_2_14[claimed_sum_offset]\ + ) - (claimed_sum)) \ * (col_0_2[0]))"; assert_eq!(eval.constraints[0].format_expr(), constraint0_str); let constraint1_str = "(1) \ * ((SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) \ - (SecureCol(col_2_4[-1], col_2_7[-1], col_2_10[-1], col_2_13[-1]) \ - - ((col_0_2[0]) * (SecureCol(223732908, 22408442, 1020999916, 2109866192)))) \ + - ((col_0_2[0]) * (total_sum))) \ - (0)) \ * ((0 \ + (StateMachineElements_alpha0) * (col_1_0[0]) \