Skip to content

Commit

Permalink
Add prefix sum constraints (#740)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson authored Jul 18, 2024
1 parent b72ed52 commit 86ddb92
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 13 deletions.
4 changes: 2 additions & 2 deletions crates/prover/benches/prefix_sum.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use criterion::{criterion_group, criterion_main, BatchSize, Criterion};
use stwo_prover::core::backend::simd::column::BaseFieldVec;
use stwo_prover::core::backend::simd::prefix_sum::inclusive_prefix_sum_simd;
use stwo_prover::core::backend::simd::prefix_sum::inclusive_prefix_sum;
use stwo_prover::core::fields::m31::BaseField;

pub fn simd_prefix_sum_bench(c: &mut Criterion) {
Expand All @@ -9,7 +9,7 @@ pub fn simd_prefix_sum_bench(c: &mut Criterion) {
c.bench_function(&format!("simd prefix_sum 2^{LOG_SIZE}"), |b| {
b.iter_batched(
|| evals.clone(),
inclusive_prefix_sum_simd,
inclusive_prefix_sum,
BatchSize::LargeInput,
);
});
Expand Down
12 changes: 8 additions & 4 deletions crates/prover/src/constraint_framework/assert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CirclePoly};
use crate::core::utils::circle_domain_order_to_coset_order;

/// Evaluates expressions at a trace domain row, and asserts constraints. Mainly used for testing.
pub struct AssertEvaluator<'a> {
Expand Down Expand Up @@ -66,10 +67,13 @@ pub fn assert_constraints<B: Backend>(
let traces = trace_polys.as_ref().map(|tree| {
tree.iter()
.map(|poly| {
poly.evaluate(trace_domain.circle_domain())
.bit_reverse()
.values
.to_cpu()
circle_domain_order_to_coset_order(
&poly
.evaluate(trace_domain.circle_domain())
.bit_reverse()
.values
.to_cpu(),
)
})
.collect()
});
Expand Down
12 changes: 12 additions & 0 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod info;
mod point;
mod simd_domain;

use std::array;
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Mul, Sub};

Expand Down Expand Up @@ -64,6 +65,17 @@ pub trait EvalAtRow {
offsets: [isize; N],
) -> [Self::F; N];

/// Returns the extension mask values of the given offsets for the next extension degree many
/// columns in the interaction.
fn next_extension_interaction_mask<const N: usize>(
&mut self,
interaction: usize,
offsets: [isize; N],
) -> [Self::EF; N] {
let res_col_major = array::from_fn(|_| self.next_interaction_mask(interaction, offsets));
array::from_fn(|i| Self::combine_ef(res_col_major.map(|c| c[i])))
}

/// Adds a constraint to the component.
fn add_constraint<G>(&mut self, constraint: G)
where
Expand Down
10 changes: 5 additions & 5 deletions crates/prover/src/core/backend/simd/prefix_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use crate::core::utils::{
///
/// Based on parallel Blelloch prefix sum:
/// <https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda>
pub fn inclusive_prefix_sum_simd(
pub fn inclusive_prefix_sum(
bit_rev_circle_domain_evals: Col<SimdBackend, BaseField>,
) -> Col<SimdBackend, BaseField> {
if bit_rev_circle_domain_evals.len() < N_LANES * 4 {
Expand Down Expand Up @@ -145,7 +145,7 @@ mod tests {
use rand::{Rng, SeedableRng};
use test_log::test;

use super::inclusive_prefix_sum_simd;
use super::inclusive_prefix_sum;
use crate::core::backend::simd::column::BaseFieldVec;
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum_slow;
use crate::core::backend::Column;
Expand All @@ -157,7 +157,7 @@ mod tests {
let evals: BaseFieldVec = (0..1 << LOG_N).map(|_| rng.gen()).collect();
let expected = inclusive_prefix_sum_slow(evals.clone());

let res = inclusive_prefix_sum_simd(evals);
let res = inclusive_prefix_sum(evals);

assert_eq!(res.to_cpu(), expected.to_cpu());
}
Expand All @@ -169,7 +169,7 @@ mod tests {
let evals: BaseFieldVec = (0..1 << LOG_N).map(|_| rng.gen()).collect();
let expected = inclusive_prefix_sum_slow(evals.clone());

let res = inclusive_prefix_sum_simd(evals);
let res = inclusive_prefix_sum(evals);

assert_eq!(res.to_cpu(), expected.to_cpu());
}
Expand All @@ -181,7 +181,7 @@ mod tests {
let evals: BaseFieldVec = (0..1 << LOG_N).map(|_| rng.gen()).collect();
let expected = inclusive_prefix_sum_slow(evals.clone());

let res = inclusive_prefix_sum_simd(evals);
let res = inclusive_prefix_sum(evals);

assert_eq!(res.to_cpu(), expected.to_cpu());
}
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use super::circle::CirclePoint;
use super::constraints::point_vanishing;
use super::fields::m31::BaseField;
use super::fields::qm31::SecureField;
use super::fields::FieldExpOps;
use super::fields::{Field, FieldExpOps};
use super::poly::circle::CircleDomain;

pub trait IteratorMutExt<'a, T: 'a>: Iterator<Item = &'a mut T> {
Expand Down Expand Up @@ -108,7 +108,7 @@ pub(crate) fn circle_domain_order_to_coset_order(values: &[BaseField]) -> Vec<Ba
coset_order
}

pub(crate) fn coset_order_to_circle_domain_order(values: &[BaseField]) -> Vec<BaseField> {
pub(crate) fn coset_order_to_circle_domain_order<F: Field>(values: &[F]) -> Vec<F> {
let mut circle_domain_order = Vec::with_capacity(values.len());
let n = values.len();
let half_len = n / 2;
Expand Down
1 change: 1 addition & 0 deletions crates/prover/src/examples/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod fibonacci;
pub mod poseidon;
pub mod wide_fibonacci;
pub mod xor;
1 change: 1 addition & 0 deletions crates/prover/src/examples/xor/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod prefix_sum_constraints;
126 changes: 126 additions & 0 deletions crates/prover/src/examples/xor/prefix_sum_constraints.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
use crate::constraint_framework::EvalAtRow;
use crate::core::fields::qm31::SecureField;

/// Inclusive prefix sum constraint.
pub fn inclusive_prefix_sum_check<E: EvalAtRow>(
eval: &mut E,
row_diff: E::EF,
final_sum: SecureField,
is_first: E::F,
at: &PrefixSumMask<E>,
) {
let prev = at.prev - is_first * final_sum;
eval.add_constraint(at.curr - prev - row_diff);
}

#[derive(Debug, Clone, Copy)]
pub struct PrefixSumMask<E: EvalAtRow> {
pub curr: E::EF,
pub prev: E::EF,
}

impl<E: EvalAtRow> PrefixSumMask<E> {
pub fn draw<const TRACE: usize>(eval: &mut E) -> Self {
let [curr, prev] = eval.next_extension_interaction_mask(TRACE, [0, -1]);
Self { curr, prev }
}
}

#[cfg(test)]
mod tests {
use itertools::Itertools;
use num_traits::One;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use test_log::test;

use super::inclusive_prefix_sum_check;
use crate::constraint_framework::{assert_constraints, EvalAtRow};
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumn;
use crate::core::pcs::TreeVec;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order};
use crate::examples::xor::prefix_sum_constraints::PrefixSumMask;

const SUM_TRACE: usize = 0;
const CONST_TRACE: usize = 1;

#[test]
fn inclusive_prefix_sum_constraints_with_log_size_5() {
const LOG_SIZE: u32 = 5;
let mut rng = SmallRng::seed_from_u64(0);
let vals = (0..1 << LOG_SIZE).map(|_| rng.gen()).collect_vec();
let final_sum = vals.iter().sum();
let base_trace = gen_base_trace(vals);
let constants_trace = gen_constants_trace(LOG_SIZE);
let traces = TreeVec::new(vec![base_trace, constants_trace]);
let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect());
let trace_domain = CanonicCoset::new(LOG_SIZE);

assert_constraints(&trace_polys, trace_domain, |mut eval| {
let [is_first] = eval.next_interaction_mask(CONST_TRACE, [0]);
let [row_diff] = eval.next_extension_interaction_mask(SUM_TRACE, [0]);
let at_mask = PrefixSumMask::draw::<SUM_TRACE>(&mut eval);
inclusive_prefix_sum_check(&mut eval, row_diff, final_sum, is_first, &at_mask);
});
}

/// Generates a trace.
///
/// Trace structure:
///
/// ```text
/// ---------------------------------------------------------
/// | Values | Values prefix sum |
/// ---------------------------------------------------------
/// | c0 | c1 | c2 | c3 | c0 | c1 | c2 | c3 |
/// ---------------------------------------------------------
/// ```
fn gen_base_trace(
vals: Vec<SecureField>,
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
assert!(vals.len().is_power_of_two());

let vals_circle_domain_order = coset_order_to_circle_domain_order(&vals);
let mut vals_bit_rev_circle_domain_order = vals_circle_domain_order;
bit_reverse(&mut vals_bit_rev_circle_domain_order);
let vals_secure_col: SecureColumn<SimdBackend> =
vals_bit_rev_circle_domain_order.into_iter().collect();
let [vals_col0, vals_col1, vals_col2, vals_col3] = vals_secure_col.columns;

let prefix_sum_col0 = inclusive_prefix_sum(vals_col0.clone());
let prefix_sum_col1 = inclusive_prefix_sum(vals_col1.clone());
let prefix_sum_col2 = inclusive_prefix_sum(vals_col2.clone());
let prefix_sum_col3 = inclusive_prefix_sum(vals_col3.clone());

let log_size = vals.len().ilog2();
let trace_domain = CanonicCoset::new(log_size).circle_domain();

vec![
CircleEvaluation::new(trace_domain, vals_col0),
CircleEvaluation::new(trace_domain, vals_col1),
CircleEvaluation::new(trace_domain, vals_col2),
CircleEvaluation::new(trace_domain, vals_col3),
CircleEvaluation::new(trace_domain, prefix_sum_col0),
CircleEvaluation::new(trace_domain, prefix_sum_col1),
CircleEvaluation::new(trace_domain, prefix_sum_col2),
CircleEvaluation::new(trace_domain, prefix_sum_col3),
]
}

fn gen_constants_trace(
log_size: u32,
) -> Vec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
let trace_domain = CanonicCoset::new(log_size).circle_domain();
// Column is `1` at the first trace point and `0` on all other trace points.
let mut is_first = Col::<SimdBackend, BaseField>::zeros(1 << log_size);
is_first.as_mut_slice()[0] = BaseField::one();
vec![CircleEvaluation::new(trace_domain, is_first)]
}
}

0 comments on commit 86ddb92

Please sign in to comment.