diff --git a/crates/prover/src/constraint_framework/expr.rs b/crates/prover/src/constraint_framework/expr.rs index ad6e41c41..1d2bd189d 100644 --- a/crates/prover/src/constraint_framework/expr.rs +++ b/crates/prover/src/constraint_framework/expr.rs @@ -3,7 +3,7 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub}; use num_traits::{One, Zero}; use super::logup::{LogupAtRow, LogupSums}; -use super::{EvalAtRow, RelationEntry, RelationType, INTERACTION_TRACE_IDX}; +use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::FieldExpOps; @@ -11,14 +11,14 @@ use crate::core::lookups::utils::Fraction; /// A single base field column at index `idx` of interaction `interaction`, at mask offset `offset`. #[derive(Clone, Debug, PartialEq)] -struct ColumnExpr { +pub struct ColumnExpr { interaction: usize, idx: usize, offset: usize, } #[derive(Clone, Debug, PartialEq)] -enum Expr { +pub enum Expr { Col(ColumnExpr), /// An atomic secure column constructed from 4 expressions. /// Expressions on the secure column are not reduced, i.e, @@ -178,7 +178,7 @@ impl AddAssign for Expr { } } -fn combine_formal>(relation: &R, values: &[Expr]) -> Expr { +fn combine_formal>(relation: &R, values: &[Expr]) -> Expr { let z = Expr::Var(relation.get_name().to_owned() + "_z"); let alpha_powers = (0..relation.get_size()) .map(|i| Expr::Var(relation.get_name().to_owned() + "_alpha" + &i.to_string())); @@ -192,7 +192,7 @@ fn combine_formal>(relation: &R, values: &[Expr]) -> } /// An Evaluator that saves all constraint expressions. -struct ExprEvaluator { +pub struct ExprEvaluator { pub cur_var_index: usize, pub constraints: Vec, pub logup: LogupAtRow, @@ -246,9 +246,9 @@ impl EvalAtRow for ExprEvaluator { ]) } - fn add_to_relation>( + fn add_to_relation>( &mut self, - entries: &[RelationEntry<'_, Self::F, Self::EF, Relation>], + entries: &[RelationEntry<'_, Self::F, Self::EF, R>], ) { let fracs: Vec> = entries .iter() diff --git a/crates/prover/src/examples/state_machine/components.rs b/crates/prover/src/examples/state_machine/components.rs index c1376285d..2451eef23 100644 --- a/crates/prover/src/examples/state_machine/components.rs +++ b/crates/prover/src/examples/state_machine/components.rs @@ -2,7 +2,7 @@ use num_traits::{One, Zero}; use crate::constraint_framework::logup::ClaimedPrefixSum; use crate::constraint_framework::{ - relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, Relation, + relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, RelationEntry, PREPROCESSED_TRACE_IDX, }; use crate::core::air::{Component, ComponentProver}; @@ -10,7 +10,6 @@ use crate::core::backend::simd::SimdBackend; use crate::core::channel::Channel; use crate::core::fields::m31::M31; use crate::core::fields::qm31::{SecureField, QM31}; -use crate::core::lookups::utils::Fraction; use crate::core::pcs::TreeVec; use crate::core::prover::StarkProof; use crate::core::vcs::ops::MerkleHasher; @@ -44,16 +43,14 @@ impl FrameworkEval for StateTransitionEval } fn evaluate(&self, mut eval: E) -> E { let input_state: [_; STATE_SIZE] = std::array::from_fn(|_| eval.next_trace_mask()); - let input_denom: E::EF = self.lookup_elements.combine(&input_state); - let mut output_state = input_state; + let mut output_state = input_state.clone(); output_state[COORDINATE] += E::F::one(); - let output_denom: E::EF = self.lookup_elements.combine(&output_state); - eval.write_logup_frac( - Fraction::new(E::EF::one(), input_denom) - + Fraction::new(-E::EF::one(), output_denom.clone()), - ); + eval.add_to_relation(&[ + RelationEntry::new(&self.lookup_elements, E::EF::one(), &input_state), + RelationEntry::new(&self.lookup_elements, -E::EF::one(), &output_state), + ]); eval.finalize_logup(); eval diff --git a/crates/prover/src/examples/state_machine/mod.rs b/crates/prover/src/examples/state_machine/mod.rs index 5d756331c..59bd540a8 100644 --- a/crates/prover/src/examples/state_machine/mod.rs +++ b/crates/prover/src/examples/state_machine/mod.rs @@ -180,6 +180,7 @@ mod tests { }; use super::gen::{gen_interaction_trace, gen_trace}; use super::{prove_state_machine, verify_state_machine}; + use crate::constraint_framework::expr::ExprEvaluator; use crate::constraint_framework::preprocessed_columns::gen_is_first; use crate::constraint_framework::{ assert_constraints, FrameworkEval, Relation, TraceLocationAllocator, @@ -267,4 +268,37 @@ mod tests { verify_state_machine(config, verifier_channel, components, proof).unwrap(); } + + #[test] + fn test_state_machine_constraint_repr() { + let log_n_rows = 8; + let initial_state = [M31::zero(); STATE_SIZE]; + + let trace = gen_trace(log_n_rows, initial_state, 0); + let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default()); + + let (_, [total_sum, claimed_sum]) = + gen_interaction_trace(1 << log_n_rows, &trace, 0, &lookup_elements); + + assert_eq!(total_sum, claimed_sum); + let component = StateMachineOp0Component::new( + &mut TraceLocationAllocator::default(), + StateTransitionEval { + log_n_rows, + lookup_elements, + total_sum, + claimed_sum: (total_sum, (1 << log_n_rows) - 1), + }, + (total_sum, Some((total_sum, (1 << log_n_rows) - 1))), + ); + + let eval = component.evaluate(ExprEvaluator::new( + log_n_rows, + (total_sum, Some((total_sum, (1 << log_n_rows) - 1))), + )); + + assert_eq!(eval.constraints.len(), 2); + assert_eq!(eval.constraints[0].format_expr(), "(1) * ((SecureCol(col_2_5[255], col_2_8[255], col_2_11[255], col_2_14[255]) - (SecureCol(223732908, 22408442, 1020999916, 2109866192))) * (col_0_2[0]))"); + assert_eq!(eval.constraints[1].format_expr(), "(1) * ((SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) - (SecureCol(col_2_4[18446744073709551615], col_2_7[18446744073709551615], col_2_10[18446744073709551615], col_2_13[18446744073709551615]) - ((col_0_2[0]) * (SecureCol(223732908, 22408442, 1020999916, 2109866192)))) - (0)) * ((0 + (StateMachineElements_alpha0) * (col_1_0[0]) + (StateMachineElements_alpha1) * (col_1_1[0]) - (StateMachineElements_z)) * (0 + (StateMachineElements_alpha0) * (col_1_0[0] + 1) + (StateMachineElements_alpha1) * (col_1_1[0]) - (StateMachineElements_z))) - ((0 + (StateMachineElements_alpha0) * (col_1_0[0] + 1) + (StateMachineElements_alpha1) * (col_1_1[0]) - (StateMachineElements_z)) * (1) + (0 + (StateMachineElements_alpha0) * (col_1_0[0]) + (StateMachineElements_alpha1) * (col_1_1[0]) - (StateMachineElements_z)) * (-(1))))"); + } }