diff --git a/crates/prover/src/constraint_framework/constant_cols.rs b/crates/prover/src/constraint_framework/constant_cols.rs index 29e2f9556..e942d81c7 100644 --- a/crates/prover/src/constraint_framework/constant_cols.rs +++ b/crates/prover/src/constraint_framework/constant_cols.rs @@ -4,9 +4,34 @@ use crate::core::backend::{Backend, Col, Column}; use crate::core::fields::m31::BaseField; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; use crate::core::poly::BitReversedOrder; +use crate::core::utils::bit_reverse_index; pub fn gen_is_first(log_size: u32) -> CircleEvaluation { let mut col = Col::::zeros(1 << log_size); col.set(0, BaseField::one()); CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col) } + +pub fn gen_is_step_multiple( + log_size: u32, + log_step: u32, +) -> CircleEvaluation { + let mut col = Col::::zeros(1 << log_size); + + for i in (0..1 << log_size).step_by(1 << log_step) { + let circle_domain_index = coset_index_to_circle_domain_index(i, log_size); + let circle_domain_index_bit_rev = bit_reverse_index(circle_domain_index, log_size); + col.set(circle_domain_index_bit_rev, BaseField::one()); + } + + CircleEvaluation::new(CanonicCoset::new(log_size).circle_domain(), col) +} + +/// Converts an index within a [`Coset`] to the corresponding index in a [`CircleDomain`]. +pub fn coset_index_to_circle_domain_index(coset_index: usize, log_domain_size: u32) -> usize { + if coset_index % 2 == 0 { + coset_index / 2 + } else { + ((2 << log_domain_size) - coset_index) / 2 + } +} diff --git a/crates/prover/src/constraint_framework/info.rs b/crates/prover/src/constraint_framework/info.rs index 499b0cc0f..b0107a8be 100644 --- a/crates/prover/src/constraint_framework/info.rs +++ b/crates/prover/src/constraint_framework/info.rs @@ -28,7 +28,7 @@ impl EvalAtRow for InfoEvaluator { if self.mask_offsets.len() <= interaction { self.mask_offsets.resize(interaction + 1, vec![]); } - self.mask_offsets[interaction].push(offsets.into_iter().collect()); + self.mask_offsets[interaction].push(offsets.to_vec()); [BaseField::one(); N] } fn add_constraint(&mut self, _constraint: G) diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index 4a5201605..bbdb97705 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -17,6 +17,7 @@ pub use point::PointEvaluator; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; use crate::core::fields::FieldExpOps; /// A trait for evaluating expressions at some point or row. @@ -67,5 +68,5 @@ pub trait EvalAtRow { Self::EF: Mul; /// Combines 4 base field values into a single extension field value. - fn combine_ef(values: [Self::F; 4]) -> Self::EF; + fn combine_ef(values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF; } diff --git a/crates/prover/src/examples/eq_col/mod.rs b/crates/prover/src/examples/eq_col/mod.rs new file mode 100644 index 000000000..92a4c7475 --- /dev/null +++ b/crates/prover/src/examples/eq_col/mod.rs @@ -0,0 +1,254 @@ +#![allow(dead_code)] + +use std::array; + +use num_traits::{One, Zero}; +use tracing::instrument; + +use crate::constraint_framework::constant_cols::{gen_is_first, gen_is_step_multiple}; +use crate::constraint_framework::EvalAtRow; +use crate::core::backend::simd::column::BaseFieldVec; +use crate::core::backend::simd::m31::PackedM31; +use crate::core::backend::simd::SimdBackend; +use crate::core::backend::Backend; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::secure_column::SECURE_EXTENSION_DEGREE; +use crate::core::lookups::gkr_prover::GkrOps; +use crate::core::lookups::utils::eq; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation}; +use crate::core::poly::BitReversedOrder; +use crate::core::ColumnVec; + +const BASE_TRACE: usize = 0; +const CONSTANTS_TRACE: usize = 2; + +#[derive(Debug, Clone, Copy)] +struct PointMeta { + // Point `p`. + p: [SecureField; N_VARIABLES], + // Stores all `eq(0, p_i) / eq(1, p_i)`. + eq_0pi_div_eq_1pi: [SecureField; N_VARIABLES], + // Equals `eq({0}^|p|, p)`. + eq_0_p: SecureField, +} + +impl PointMeta { + /// Creates new metadata from point `p`. + pub fn new(p: [SecureField; N_VARIABLES]) -> Self { + let zero = SecureField::zero(); + let one = SecureField::one(); + + Self { + p, + eq_0pi_div_eq_1pi: array::from_fn(|i| eq(&[zero], &[p[i]]) / eq(&[one], &[p[i]])), + eq_0_p: eq(&[zero; N_VARIABLES], &p), + } + } +} + +/// Constraints to enforce correct [`eq`] evals. +/// +/// See (Section 5.1). +/// +/// The [`eq`] trace evals should appear ordered on a `CircleDomain` rather than a `Coset`. This +/// gives context for why there are separate sets of constraints for each `CircleDomain` coset half. +struct EqEvalsCheck { + eval: E, + point_meta: PointMeta, +} + +impl EqEvalsCheck { + fn eval(self) -> (E, EqEvalsCheckMask) + where + // Need this const generic to get all required mask items. + [(); N_VARIABLES + 1]: Exists, + { + let Self { + mut eval, + point_meta, + } = self; + + let eq_evals_mask = EqEvalsCheckMask::::new(&mut eval); + let EqEvalsCheckMask { at_curr, at_steps } = eq_evals_mask; + + let [is_first, is_last] = eval.next_interaction_mask(CONSTANTS_TRACE, [0, 1]); + eval.add_constraint((at_curr - point_meta.eq_0_p) * is_first); + + let mut at_steps = at_steps.into_iter(); + + // Check last variable first due to ordering difference between `Coset` and `CircleDomain`. + if let Some(at_step) = at_steps.next() { + // Check eval on first point in half_coset0 with last point in half_coset1. + let eq_0pi_div_eq_1pi = point_meta.eq_0pi_div_eq_1pi[N_VARIABLES - 1]; + // TODO: Can avoid taking `is_last` mask item by using is_first and setting first base + // trace mask step to -1 (instead of 1). Constraint only changes slightly. + eval.add_constraint((at_curr * eq_0pi_div_eq_1pi - at_step) * is_last); + } + + // Check all other variables (all except last - see above). + for (variable, at_step) in at_steps.enumerate() { + // Consider adding `is_steps` to `EqEvalsCheckMask`. + let [is_step_half_coset0, is_step_half_coset1] = + eval.next_interaction_mask(CONSTANTS_TRACE, [0, -1]); + let eq_0pi_div_eq_1pi = point_meta.eq_0pi_div_eq_1pi[variable]; + // TODO: Check if it's safe to combine these constraints + eval.add_constraint((at_curr - at_step * eq_0pi_div_eq_1pi) * is_step_half_coset0); + eval.add_constraint((at_curr * eq_0pi_div_eq_1pi - at_step) * is_step_half_coset1); + } + + (eval, eq_evals_mask) + } +} + +#[derive(Debug, Clone, Copy)] +struct EqEvalsCheckMask { + at_curr: E::EF, + at_steps: [E::EF; N_VARIABLES], +} + +impl EqEvalsCheckMask { + pub fn new(eval: &mut E) -> Self + where + // Need this const generic to get all required mask items. + [(); N_VARIABLES + 1]: Exists, + { + let mut mask_offsets = [0; N_VARIABLES + 1]; + // Current. + mask_offsets[0] = 0; + // Variable step offsets. + mask_offsets[1..] + .iter_mut() + .enumerate() + .for_each(|(variable, mask_offset)| { + let variable_step = 1 << variable; + *mask_offset = variable_step; + }); + + let mask_coord_cols = array::from_fn(|_| eval.next_interaction_mask(0, mask_offsets)); + + let mask_items: [E::EF; N_VARIABLES + 1] = + array::from_fn(|i| E::combine_ef(mask_coord_cols.map(|c| c[i]))); + + Self { + at_curr: mask_items[0], + at_steps: mask_items[1..].try_into().unwrap(), + } + } +} + +trait Exists {} + +impl Exists for T {} + +#[instrument(skip_all)] +fn gen_base_trace( + eval_point: [SecureField; N_VARIABLES], +) -> ColumnVec> { + let eq_evals = SimdBackend::gen_eq_evals(&eval_point, SecureField::one()).into_evals(); + + // Currently have SecureField eq_evals. + // Separate into SECURE_EXTENSION_DEGREE many BaseField columns. + let mut eq_evals_cols: [Vec; SECURE_EXTENSION_DEGREE] = + array::from_fn(|_| Vec::new()); + + for secure_vec in &eq_evals.data { + let [v0, v1, v2, v3] = secure_vec.into_packed_m31s(); + eq_evals_cols[0].push(v0); + eq_evals_cols[1].push(v1); + eq_evals_cols[2].push(v2); + eq_evals_cols[3].push(v3); + } + + let domain = CanonicCoset::new(eval_point.len() as u32).circle_domain(); + let length = domain.size(); + eq_evals_cols + .map(|col| BaseFieldVec { data: col, length }) + .map(|col| CircleEvaluation::new(domain, col)) + .into() +} + +#[instrument] +fn gen_constants_trace( +) -> Vec> { + let mut constants_trace = Vec::new(); + + let log_size = N_VARIABLES as u32; + constants_trace.push(gen_is_first(log_size)); + + // TODO: Last constant column actually equal to gen_is_first but makes the prototype easier. + for log_step in 1..N_VARIABLES as u32 { + constants_trace.push(gen_is_step_multiple(log_size, log_step + 1)) + } + + constants_trace +} + +#[cfg(test)] +mod tests { + use std::array; + + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + use test_log::test; + + use super::{EqEvalsCheck, PointMeta}; + use crate::constraint_framework::assert_constraints; + use crate::core::backend::simd::SimdBackend; + use crate::core::fields::qm31::SecureField; + use crate::core::pcs::TreeVec; + use crate::core::poly::circle::CanonicCoset; + use crate::examples::eq_col::{gen_base_trace, gen_constants_trace}; + + #[test] + #[ignore = "SimdBackend `MIN_FFT_LOG_SIZE` is 5"] + fn test_eq_constraints_with_4_variables() { + const N_VARIABLES: usize = 4; + let mut rng = SmallRng::seed_from_u64(0); + let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let base_trace = gen_base_trace(eval_point); + let constants_trace = gen_constants_trace::(); + let traces = TreeVec::new(vec![base_trace, vec![], constants_trace]); + let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); + let trace_domain = CanonicCoset::new(eval_point.len() as u32); + let point_meta = PointMeta::new(eval_point); + + assert_constraints(&trace_polys, trace_domain, |eval| { + EqEvalsCheck { eval, point_meta }.eval(); + }); + } + + #[test] + fn test_eq_constraints_with_5_variables() { + const N_VARIABLES: usize = 5; + let mut rng = SmallRng::seed_from_u64(0); + let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let base_trace = gen_base_trace(eval_point); + let constants_trace = gen_constants_trace::(); + let traces = TreeVec::new(vec![base_trace, vec![], constants_trace]); + let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); + let trace_domain = CanonicCoset::new(eval_point.len() as u32); + let point_meta = PointMeta::new(eval_point); + + assert_constraints(&trace_polys, trace_domain, |eval| { + EqEvalsCheck { eval, point_meta }.eval(); + }); + } + + #[test] + fn test_eq_constraints_with_8_variables() { + const N_VARIABLES: usize = 8; + let mut rng = SmallRng::seed_from_u64(0); + let eval_point: [SecureField; N_VARIABLES] = array::from_fn(|_| rng.gen()); + let base_trace = gen_base_trace(eval_point); + let constants_trace = gen_constants_trace::(); + let traces = TreeVec::new(vec![base_trace, vec![], constants_trace]); + let trace_polys = traces.map(|trace| trace.into_iter().map(|c| c.interpolate()).collect()); + let trace_domain = CanonicCoset::new(eval_point.len() as u32); + let point_meta = PointMeta::new(eval_point); + + assert_constraints(&trace_polys, trace_domain, |eval| { + EqEvalsCheck { eval, point_meta }.eval(); + }); + } +} diff --git a/crates/prover/src/examples/mod.rs b/crates/prover/src/examples/mod.rs index cec9debd2..8698c6de2 100644 --- a/crates/prover/src/examples/mod.rs +++ b/crates/prover/src/examples/mod.rs @@ -1,3 +1,5 @@ +pub mod eq_col; pub mod fibonacci; pub mod poseidon; pub mod wide_fibonacci; +// pub mod xor; diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index f89d9c011..45649de96 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -1,5 +1,5 @@ //! AIR for Poseidon2 hash function from . - +use std::hint::black_box; use std::ops::{Add, AddAssign, Mul, Sub}; use itertools::Itertools; @@ -43,9 +43,9 @@ const N_COLUMNS: usize = N_INSTANCES_PER_ROW * N_COLUMNS_PER_REP; const LOG_EXPAND: u32 = 2; // TODO(spapini): Pick better constants. const EXTERNAL_ROUND_CONSTS: [[BaseField; N_STATE]; 2 * N_HALF_FULL_ROUNDS] = - [[BaseField::from_u32_unchecked(1234); N_STATE]; 2 * N_HALF_FULL_ROUNDS]; + black_box([[BaseField::from_u32_unchecked(1234); N_STATE]; 2 * N_HALF_FULL_ROUNDS]); const INTERNAL_ROUND_CONSTS: [BaseField; N_PARTIAL_ROUNDS] = - [BaseField::from_u32_unchecked(1234); N_PARTIAL_ROUNDS]; + black_box([BaseField::from_u32_unchecked(1234); N_PARTIAL_ROUNDS]); #[derive(Clone)] pub struct PoseidonComponent { diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index d820075ce..cee2ff3b8 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -1,18 +1,21 @@ +#![allow(incomplete_features)] #![feature( array_methods, array_chunks, - iter_array_chunks, + assert_matches, + const_black_box, exact_size_is_empty, - is_sorted, - new_uninit, - slice_group_by, - stdsimd, + generic_const_exprs, get_many_mut, int_roundings, + is_sorted, + iter_array_chunks, + new_uninit, + portable_simd, slice_first_last_chunk, slice_flatten, - assert_matches, - portable_simd + slice_group_by, + stdsimd )] pub mod constraint_framework; pub mod core;