Skip to content

Commit

Permalink
State machine constraint string.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alon-Ti committed Nov 17, 2024
1 parent 36aa69d commit 282c196
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 16 deletions.
14 changes: 7 additions & 7 deletions crates/prover/src/constraint_framework/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@ use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub};
use num_traits::{One, Zero};

use super::logup::{LogupAtRow, LogupSums};
use super::{EvalAtRow, RelationEntry, RelationType, INTERACTION_TRACE_IDX};
use super::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;

/// A single base field column at index `idx` of interaction `interaction`, at mask offset `offset`.
#[derive(Clone, Debug, PartialEq)]
struct ColumnExpr {
pub struct ColumnExpr {
interaction: usize,
idx: usize,
offset: usize,
}

#[derive(Clone, Debug, PartialEq)]
enum Expr {
pub enum Expr {
Col(ColumnExpr),
/// An atomic secure column constructed from 4 expressions.
/// Expressions on the secure column are not reduced, i.e,
Expand Down Expand Up @@ -178,7 +178,7 @@ impl AddAssign<BaseField> for Expr {
}
}

fn combine_formal<R: RelationType<Expr, Expr>>(relation: &R, values: &[Expr]) -> Expr {
fn combine_formal<R: Relation<Expr, Expr>>(relation: &R, values: &[Expr]) -> Expr {
let z = Expr::Var(relation.get_name().to_owned() + "_z");
let alpha_powers = (0..relation.get_size())
.map(|i| Expr::Var(relation.get_name().to_owned() + "_alpha" + &i.to_string()));
Expand All @@ -192,7 +192,7 @@ fn combine_formal<R: RelationType<Expr, Expr>>(relation: &R, values: &[Expr]) ->
}

/// An Evaluator that saves all constraint expressions.
struct ExprEvaluator {
pub struct ExprEvaluator {
pub cur_var_index: usize,
pub constraints: Vec<Expr>,
pub logup: LogupAtRow<Self>,
Expand Down Expand Up @@ -246,9 +246,9 @@ impl EvalAtRow for ExprEvaluator {
])
}

fn add_to_relation<Relation: RelationType<Self::F, Self::EF>>(
fn add_to_relation<R: Relation<Self::F, Self::EF>>(
&mut self,
entries: &[RelationEntry<'_, Self::F, Self::EF, Relation>],
entries: &[RelationEntry<'_, Self::F, Self::EF, R>],
) {
let fracs: Vec<Fraction<Self::EF, Self::EF>> = entries
.iter()
Expand Down
15 changes: 6 additions & 9 deletions crates/prover/src/examples/state_machine/components.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@ use num_traits::{One, Zero};

use crate::constraint_framework::logup::ClaimedPrefixSum;
use crate::constraint_framework::{
relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, Relation,
relation, EvalAtRow, FrameworkComponent, FrameworkEval, InfoEvaluator, RelationEntry,
PREPROCESSED_TRACE_IDX,
};
use crate::core::air::{Component, ComponentProver};
use crate::core::backend::simd::SimdBackend;
use crate::core::channel::Channel;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::{SecureField, QM31};
use crate::core::lookups::utils::Fraction;
use crate::core::pcs::TreeVec;
use crate::core::prover::StarkProof;
use crate::core::vcs::ops::MerkleHasher;
Expand Down Expand Up @@ -44,16 +43,14 @@ impl<const COORDINATE: usize> FrameworkEval for StateTransitionEval<COORDINATE>
}
fn evaluate<E: EvalAtRow>(&self, mut eval: E) -> E {
let input_state: [_; STATE_SIZE] = std::array::from_fn(|_| eval.next_trace_mask());
let input_denom: E::EF = self.lookup_elements.combine(&input_state);

let mut output_state = input_state;
let mut output_state = input_state.clone();
output_state[COORDINATE] += E::F::one();
let output_denom: E::EF = self.lookup_elements.combine(&output_state);

eval.write_logup_frac(
Fraction::new(E::EF::one(), input_denom)
+ Fraction::new(-E::EF::one(), output_denom.clone()),
);
eval.add_to_relation(&[
RelationEntry::new(&self.lookup_elements, E::EF::one(), &input_state),
RelationEntry::new(&self.lookup_elements, -E::EF::one(), &output_state),
]);

eval.finalize_logup();
eval
Expand Down
34 changes: 34 additions & 0 deletions crates/prover/src/examples/state_machine/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ mod tests {
};
use super::gen::{gen_interaction_trace, gen_trace};
use super::{prove_state_machine, verify_state_machine};
use crate::constraint_framework::expr::ExprEvaluator;
use crate::constraint_framework::preprocessed_columns::gen_is_first;
use crate::constraint_framework::{
assert_constraints, FrameworkEval, Relation, TraceLocationAllocator,
Expand Down Expand Up @@ -267,4 +268,37 @@ mod tests {

verify_state_machine(config, verifier_channel, components, proof).unwrap();
}

#[test]
fn test_state_machine_constraint_repr() {
let log_n_rows = 8;
let initial_state = [M31::zero(); STATE_SIZE];

let trace = gen_trace(log_n_rows, initial_state, 0);
let lookup_elements = StateMachineElements::draw(&mut Blake2sChannel::default());

let (_, [total_sum, claimed_sum]) =
gen_interaction_trace(1 << log_n_rows, &trace, 0, &lookup_elements);

assert_eq!(total_sum, claimed_sum);
let component = StateMachineOp0Component::new(
&mut TraceLocationAllocator::default(),
StateTransitionEval {
log_n_rows,
lookup_elements,
total_sum,
claimed_sum: (total_sum, (1 << log_n_rows) - 1),
},
(total_sum, Some((total_sum, (1 << log_n_rows) - 1))),
);

let eval = component.evaluate(ExprEvaluator::new(
log_n_rows,
(total_sum, Some((total_sum, (1 << log_n_rows) - 1))),
));

assert_eq!(eval.constraints.len(), 2);
assert_eq!(eval.constraints[0].format_expr(), "(1) * ((SecureCol(col_2_5[255], col_2_8[255], col_2_11[255], col_2_14[255]) - (SecureCol(223732908, 22408442, 1020999916, 2109866192))) * (col_0_2[0]))");
assert_eq!(eval.constraints[1].format_expr(), "(1) * ((SecureCol(col_2_3[0], col_2_6[0], col_2_9[0], col_2_12[0]) - (SecureCol(col_2_4[18446744073709551615], col_2_7[18446744073709551615], col_2_10[18446744073709551615], col_2_13[18446744073709551615]) - ((col_0_2[0]) * (SecureCol(223732908, 22408442, 1020999916, 2109866192)))) - (0)) * ((0 + (StateMachineElements_alpha0) * (col_1_0[0]) + (StateMachineElements_alpha1) * (col_1_1[0]) - (StateMachineElements_z)) * (0 + (StateMachineElements_alpha0) * (col_1_0[0] + 1) + (StateMachineElements_alpha1) * (col_1_1[0]) - (StateMachineElements_z))) - ((0 + (StateMachineElements_alpha0) * (col_1_0[0] + 1) + (StateMachineElements_alpha1) * (col_1_1[0]) - (StateMachineElements_z)) * (1) + (0 + (StateMachineElements_alpha0) * (col_1_0[0]) + (StateMachineElements_alpha1) * (col_1_1[0]) - (StateMachineElements_z)) * (-(1))))");
}
}

0 comments on commit 282c196

Please sign in to comment.