diff --git a/crates/prover/benches/prefix_sum.rs b/crates/prover/benches/prefix_sum.rs index 89cc7c68f..9546ee707 100644 --- a/crates/prover/benches/prefix_sum.rs +++ b/crates/prover/benches/prefix_sum.rs @@ -1,6 +1,6 @@ use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; use stwo_prover::core::backend::simd::column::BaseFieldVec; -use stwo_prover::core::backend::simd::prefix_sum::inclusive_prefix_sum_simd; +use stwo_prover::core::backend::simd::prefix_sum::inclusive_prefix_sum; use stwo_prover::core::fields::m31::BaseField; pub fn simd_prefix_sum_bench(c: &mut Criterion) { @@ -9,7 +9,7 @@ pub fn simd_prefix_sum_bench(c: &mut Criterion) { c.bench_function(&format!("simd prefix_sum 2^{LOG_SIZE}"), |b| { b.iter_batched( || evals.clone(), - inclusive_prefix_sum_simd, + inclusive_prefix_sum, BatchSize::LargeInput, ); }); diff --git a/crates/prover/src/constraint_framework/assert.rs b/crates/prover/src/constraint_framework/assert.rs index c6aa88784..9e5530bae 100644 --- a/crates/prover/src/constraint_framework/assert.rs +++ b/crates/prover/src/constraint_framework/assert.rs @@ -7,6 +7,7 @@ use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; use crate::core::pcs::TreeVec; use crate::core::poly::circle::{CanonicCoset, CirclePoly}; +use crate::core::utils::circle_domain_order_to_coset_order; /// Evaluates expressions at a trace domain row, and asserts constraints. Mainly used for testing. pub struct AssertEvaluator<'a> { @@ -66,10 +67,13 @@ pub fn assert_constraints( let traces = trace_polys.as_ref().map(|tree| { tree.iter() .map(|poly| { - poly.evaluate(trace_domain.circle_domain()) - .bit_reverse() - .values - .to_cpu() + circle_domain_order_to_coset_order( + &poly + .evaluate(trace_domain.circle_domain()) + .bit_reverse() + .values + .to_cpu(), + ) }) .collect() }); diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index bb753b646..845d6cae8 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -5,6 +5,7 @@ mod info; mod point; mod simd_domain; +use std::array; use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Sub}; @@ -64,6 +65,17 @@ pub trait EvalAtRow { offsets: [isize; N], ) -> [Self::F; N]; + /// Returns the extension mask values of the given offsets for the next extension degree many + /// columns in the interaction. + fn next_extension_interaction_mask( + &mut self, + interaction: usize, + offsets: [isize; N], + ) -> [Self::EF; N] { + let res_col_major = array::from_fn(|_| self.next_interaction_mask(interaction, offsets)); + array::from_fn(|i| Self::combine_ef(res_col_major.map(|c| c[i]))) + } + /// Adds a constraint to the component. fn add_constraint(&mut self, constraint: G) where diff --git a/crates/prover/src/core/backend/simd/prefix_sum.rs b/crates/prover/src/core/backend/simd/prefix_sum.rs index dcb439fca..c3293f97c 100644 --- a/crates/prover/src/core/backend/simd/prefix_sum.rs +++ b/crates/prover/src/core/backend/simd/prefix_sum.rs @@ -17,7 +17,7 @@ use crate::core::utils::{ /// /// Based on parallel Blelloch prefix sum: /// -pub fn inclusive_prefix_sum_simd( +pub fn inclusive_prefix_sum( bit_rev_circle_domain_evals: Col, ) -> Col { if bit_rev_circle_domain_evals.len() < N_LANES * 4 { @@ -145,7 +145,7 @@ mod tests { use rand::{Rng, SeedableRng}; use test_log::test; - use super::inclusive_prefix_sum_simd; + use super::inclusive_prefix_sum; use crate::core::backend::simd::column::BaseFieldVec; use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum_slow; use crate::core::backend::Column; @@ -157,7 +157,7 @@ mod tests { let evals: BaseFieldVec = (0..1 << LOG_N).map(|_| rng.gen()).collect(); let expected = inclusive_prefix_sum_slow(evals.clone()); - let res = inclusive_prefix_sum_simd(evals); + let res = inclusive_prefix_sum(evals); assert_eq!(res.to_cpu(), expected.to_cpu()); } @@ -169,7 +169,7 @@ mod tests { let evals: BaseFieldVec = (0..1 << LOG_N).map(|_| rng.gen()).collect(); let expected = inclusive_prefix_sum_slow(evals.clone()); - let res = inclusive_prefix_sum_simd(evals); + let res = inclusive_prefix_sum(evals); assert_eq!(res.to_cpu(), expected.to_cpu()); } @@ -181,7 +181,7 @@ mod tests { let evals: BaseFieldVec = (0..1 << LOG_N).map(|_| rng.gen()).collect(); let expected = inclusive_prefix_sum_slow(evals.clone()); - let res = inclusive_prefix_sum_simd(evals); + let res = inclusive_prefix_sum(evals); assert_eq!(res.to_cpu(), expected.to_cpu()); } diff --git a/crates/prover/src/core/utils.rs b/crates/prover/src/core/utils.rs index f478ea450..be08c6572 100644 --- a/crates/prover/src/core/utils.rs +++ b/crates/prover/src/core/utils.rs @@ -7,7 +7,7 @@ use super::circle::CirclePoint; use super::constraints::point_vanishing; use super::fields::m31::BaseField; use super::fields::qm31::SecureField; -use super::fields::FieldExpOps; +use super::fields::{Field, FieldExpOps}; use super::poly::circle::CircleDomain; pub trait IteratorMutExt<'a, T: 'a>: Iterator { @@ -108,7 +108,7 @@ pub(crate) fn circle_domain_order_to_coset_order(values: &[BaseField]) -> Vec Vec { +pub(crate) fn coset_order_to_circle_domain_order(values: &[F]) -> Vec { let mut circle_domain_order = Vec::with_capacity(values.len()); let n = values.len(); let half_len = n / 2; diff --git a/crates/prover/src/examples/mod.rs b/crates/prover/src/examples/mod.rs index cec9debd2..45e2f4186 100644 --- a/crates/prover/src/examples/mod.rs +++ b/crates/prover/src/examples/mod.rs @@ -1,3 +1,4 @@ pub mod fibonacci; pub mod poseidon; pub mod wide_fibonacci; +pub mod xor; diff --git a/crates/prover/src/examples/xor/mod.rs b/crates/prover/src/examples/xor/mod.rs new file mode 100644 index 000000000..8cd1686f1 --- /dev/null +++ b/crates/prover/src/examples/xor/mod.rs @@ -0,0 +1 @@ +pub mod prefix_sum_constraints; diff --git a/crates/prover/src/examples/xor/prefix_sum_constraints.rs b/crates/prover/src/examples/xor/prefix_sum_constraints.rs new file mode 100644 index 000000000..c18f7999c --- /dev/null +++ b/crates/prover/src/examples/xor/prefix_sum_constraints.rs @@ -0,0 +1,126 @@ +use crate::constraint_framework::EvalAtRow; +use crate::core::fields::qm31::SecureField; + +/// Inclusive prefix sum constraint. +pub fn inclusive_prefix_sum_check( + eval: &mut E, + row_diff: E::EF, + final_sum: SecureField, + is_first: E::F, + at: &PrefixSumMask, +) { + let prev = at.prev - is_first * final_sum; + eval.add_constraint(at.curr - prev - row_diff); +} + +#[derive(Debug, Clone, Copy)] +pub struct PrefixSumMask { + pub curr: E::EF, + pub prev: E::EF, +} + +impl PrefixSumMask { + pub fn draw(eval: &mut E) -> Self { + let [curr, prev] = eval.next_extension_interaction_mask(TRACE, [0, -1]); + Self { curr, prev } + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use num_traits::One; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + use test_log::test; + + use super::inclusive_prefix_sum_check; + use crate::constraint_framework::{assert_constraints, EvalAtRow}; + use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum; + use crate::core::backend::simd::SimdBackend; + use crate::core::backend::{Col, Column}; + use crate::core::fields::m31::BaseField; + use crate::core::fields::qm31::SecureField; + use crate::core::fields::secure_column::SecureColumn; + use crate::core::pcs::TreeVec; + use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; + use crate::core::poly::BitReversedOrder; + use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order}; + use crate::examples::xor::prefix_sum_constraints::PrefixSumMask; + + const SUM_TRACE: usize = 0; + const CONST_TRACE: usize = 1; + + #[test] + fn inclusive_prefix_sum_constraints_with_log_size_5() { + const LOG_SIZE: u32 = 5; + let mut rng = SmallRng::seed_from_u64(0); + let vals = (0..1 << LOG_SIZE).map(|_| rng.gen()).collect_vec(); + let final_sum = vals.iter().sum(); + let base_trace = gen_base_trace(vals); + let constants_trace = gen_constants_trace(LOG_SIZE); + let traces = TreeVec::new(vec![base_trace, constants_trace]); + let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); + let trace_domain = CanonicCoset::new(LOG_SIZE); + + assert_constraints(&trace_polys, trace_domain, |mut eval| { + let [is_first] = eval.next_interaction_mask(CONST_TRACE, [0]); + let [row_diff] = eval.next_extension_interaction_mask(SUM_TRACE, [0]); + let at_mask = PrefixSumMask::draw::(&mut eval); + inclusive_prefix_sum_check(&mut eval, row_diff, final_sum, is_first, &at_mask); + }); + } + + /// Generates a trace. + /// + /// Trace structure: + /// + /// ```text + /// --------------------------------------------------------- + /// | Values | Values prefix sum | + /// --------------------------------------------------------- + /// | c0 | c1 | c2 | c3 | c0 | c1 | c2 | c3 | + /// --------------------------------------------------------- + /// ``` + fn gen_base_trace( + vals: Vec, + ) -> Vec> { + assert!(vals.len().is_power_of_two()); + + let vals_circle_domain_order = coset_order_to_circle_domain_order(&vals); + let mut vals_bit_rev_circle_domain_order = vals_circle_domain_order; + bit_reverse(&mut vals_bit_rev_circle_domain_order); + let vals_secure_col: SecureColumn = + vals_bit_rev_circle_domain_order.into_iter().collect(); + let [vals_col0, vals_col1, vals_col2, vals_col3] = vals_secure_col.columns; + + let prefix_sum_col0 = inclusive_prefix_sum(vals_col0.clone()); + let prefix_sum_col1 = inclusive_prefix_sum(vals_col1.clone()); + let prefix_sum_col2 = inclusive_prefix_sum(vals_col2.clone()); + let prefix_sum_col3 = inclusive_prefix_sum(vals_col3.clone()); + + let log_size = vals.len().ilog2(); + let trace_domain = CanonicCoset::new(log_size).circle_domain(); + + vec![ + CircleEvaluation::new(trace_domain, vals_col0), + CircleEvaluation::new(trace_domain, vals_col1), + CircleEvaluation::new(trace_domain, vals_col2), + CircleEvaluation::new(trace_domain, vals_col3), + CircleEvaluation::new(trace_domain, prefix_sum_col0), + CircleEvaluation::new(trace_domain, prefix_sum_col1), + CircleEvaluation::new(trace_domain, prefix_sum_col2), + CircleEvaluation::new(trace_domain, prefix_sum_col3), + ] + } + + fn gen_constants_trace( + log_size: u32, + ) -> Vec> { + let trace_domain = CanonicCoset::new(log_size).circle_domain(); + // Column is `1` at the first trace point and `0` on all other trace points. + let mut is_first = Col::::zeros(1 << log_size); + is_first.as_mut_slice()[0] = BaseField::one(); + vec![CircleEvaluation::new(trace_domain, is_first)] + } +}