diff --git a/crates/prover/src/constraint_framework/component.rs b/crates/prover/src/constraint_framework/component.rs new file mode 100644 index 000000000..9700d2747 --- /dev/null +++ b/crates/prover/src/constraint_framework/component.rs @@ -0,0 +1,154 @@ +use std::borrow::Cow; + +use itertools::Itertools; +use tracing::{span, Level}; + +use super::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator}; +use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; +use crate::core::air::{Component, ComponentProver, ComponentTrace}; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::backend::simd::SimdBackend; +use crate::core::circle::CirclePoint; +use crate::core::constraints::coset_vanishing; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::fields::FieldExpOps; +use crate::core::pcs::TreeVec; +use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; +use crate::core::poly::BitReversedOrder; +use crate::core::prover::LOG_BLOWUP_FACTOR; +use crate::core::{utils, ColumnVec, InteractionElements, LookupValues}; + +/// A component defined solely in means of the constraints framework. +/// Implementing this trait introduces implementations for [Component] and [ComponentProver] for the +/// SIMD backend. +/// Note that the constraint framework only support components with columns of the same size. +pub trait FrameworkComponent { + fn log_size(&self) -> u32; + fn max_constraint_log_degree_bound(&self) -> u32; + fn evaluate(&self, eval: E) -> E; +} + +impl Component for C { + fn n_constraints(&self) -> usize { + self.evaluate(InfoEvaluator::default()).n_constraints + } + + fn max_constraint_log_degree_bound(&self) -> u32 { + FrameworkComponent::max_constraint_log_degree_bound(self) + } + + fn trace_log_degree_bounds(&self) -> TreeVec> { + TreeVec::new( + self.evaluate(InfoEvaluator::default()) + .mask_offsets + .iter() + .map(|tree_masks| vec![self.log_size(); tree_masks.len()]) + .collect(), + ) + } + + fn mask_points( + &self, + point: CirclePoint, + ) -> TreeVec>>> { + let info = self.evaluate(InfoEvaluator::default()); + let trace_step = CanonicCoset::new(self.log_size()).step(); + info.mask_offsets.map(|tree_mask| { + tree_mask + .iter() + .map(|col_mask| { + col_mask + .iter() + .map(|off| point + trace_step.mul_signed(*off).into_ef()) + .collect() + }) + .collect() + }) + } + + fn evaluate_constraint_quotients_at_point( + &self, + point: CirclePoint, + mask: &TreeVec>>, + evaluation_accumulator: &mut PointEvaluationAccumulator, + _interaction_elements: &InteractionElements, + _lookup_values: &LookupValues, + ) { + self.evaluate(PointEvaluator::new( + mask.as_ref(), + evaluation_accumulator, + coset_vanishing(CanonicCoset::new(self.log_size()).coset, point).inverse(), + )); + } +} + +impl ComponentProver for C { + fn evaluate_constraint_quotients_on_domain( + &self, + trace: &ComponentTrace<'_, SimdBackend>, + evaluation_accumulator: &mut DomainEvaluationAccumulator, + _interaction_elements: &InteractionElements, + _lookup_values: &LookupValues, + ) { + let eval_domain = CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain(); + let trace_domain = CanonicCoset::new(self.log_size()); + + // Extend trace if necessary. + // TODO(spapini): Don't extend when eval_size < committed_size. Instead, pick a good + // subdomain. + let trace: TreeVec< + Vec>>, + > = if eval_domain.log_size() != self.log_size() + LOG_BLOWUP_FACTOR { + let _span = span!(Level::INFO, "Extension").entered(); + let twiddles = SimdBackend::precompute_twiddles(eval_domain.half_coset); + trace + .polys + .as_cols_ref() + .map_cols(|col| Cow::Owned(col.evaluate_with_twiddles(eval_domain, &twiddles))) + } else { + trace.evals.as_cols_ref().map_cols(|c| Cow::Borrowed(*c)) + }; + + // Denom inverses. + let log_expand = eval_domain.log_size() - trace_domain.log_size(); + let mut denom_inv = (0..1 << log_expand) + .map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse()) + .collect_vec(); + utils::bit_reverse(&mut denom_inv); + + // Accumulator. + let [mut accum] = + evaluation_accumulator.columns([(eval_domain.log_size(), self.n_constraints())]); + accum.random_coeff_powers.reverse(); + + let _span = span!(Level::INFO, "Constraint pointwise eval").entered(); + for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES)) { + let trace_cols = trace.as_cols_ref().map_cols(|c| c.as_ref()); + + // Evaluate constrains at row. + let eval = SimdDomainEvaluator::new( + &trace_cols, + vec_row, + &accum.random_coeff_powers, + trace_domain.log_size(), + eval_domain.log_size(), + ); + let row_res = self.evaluate(eval).row_res; + + // Finalize row. + unsafe { + let denom_inv = PackedBaseField::broadcast( + denom_inv[vec_row >> (trace_domain.log_size() - LOG_N_LANES)], + ); + accum + .col + .set_packed(vec_row, accum.col.packed_at(vec_row) + row_res * denom_inv) + } + } + } + + fn lookup_values(&self, _trace: &ComponentTrace<'_, SimdBackend>) -> LookupValues { + LookupValues::default() + } +} diff --git a/crates/prover/src/constraint_framework/mod.rs b/crates/prover/src/constraint_framework/mod.rs index ef47c8e71..29f173e1b 100644 --- a/crates/prover/src/constraint_framework/mod.rs +++ b/crates/prover/src/constraint_framework/mod.rs @@ -1,5 +1,6 @@ /// ! This module contains helpers to express and use constraints for components. mod assert; +mod component; pub mod constant_columns; mod info; pub mod logup; @@ -11,6 +12,7 @@ use std::fmt::Debug; use std::ops::{Add, AddAssign, Mul, Neg, Sub}; pub use assert::{assert_constraints, AssertEvaluator}; +pub use component::FrameworkComponent; pub use info::InfoEvaluator; use num_traits::{One, Zero}; pub use point::PointEvaluator; diff --git a/crates/prover/src/examples/poseidon/mod.rs b/crates/prover/src/examples/poseidon/mod.rs index 350d45cf6..ac17fdf5a 100644 --- a/crates/prover/src/examples/poseidon/mod.rs +++ b/crates/prover/src/examples/poseidon/mod.rs @@ -1,39 +1,31 @@ //! AIR for Poseidon2 hash function from . -use std::array; use std::ops::{Add, AddAssign, Mul, Sub}; use itertools::Itertools; -use num_traits::{One, Zero}; -#[cfg(feature = "parallel")] -use rayon::prelude::*; +use num_traits::One; use tracing::{span, Level}; use crate::constraint_framework::constant_columns::gen_is_first; use crate::constraint_framework::logup::{LogupAtRow, LogupTraceGenerator, LookupElements}; -use crate::constraint_framework::{EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator}; -use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator}; -use crate::core::air::{Air, AirProver, Component, ComponentProver, ComponentTrace}; +use crate::constraint_framework::{EvalAtRow, FrameworkComponent}; +use crate::core::air::{Air, AirProver, Component, ComponentProver}; use crate::core::backend::simd::column::BaseColumn; -use crate::core::backend::simd::m31::{PackedBaseField, PackedM31, LOG_N_LANES}; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::backend::simd::qm31::PackedSecureField; use crate::core::backend::simd::SimdBackend; use crate::core::backend::{Col, Column}; use crate::core::channel::{Blake2sChannel, Channel as _}; -use crate::core::circle::CirclePoint; -use crate::core::constraints::coset_vanishing; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::{FieldExpOps, IntoSlice}; -use crate::core::pcs::{CommitmentSchemeProver, TreeVec}; +use crate::core::pcs::CommitmentSchemeProver; use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps}; use crate::core::poly::BitReversedOrder; use crate::core::prover::{prove, StarkProof, VerificationError, LOG_BLOWUP_FACTOR}; -use crate::core::utils::bit_reverse; use crate::core::vcs::blake2_hash::Blake2sHasher; use crate::core::vcs::hasher::Hasher; use crate::core::{ColumnVec, InteractionElements, LookupValues}; -use crate::trace_generation::{AirTraceGenerator, AirTraceVerifier, ComponentTraceGenerator}; const N_LOG_INSTANCES_PER_ROW: usize = 3; const N_INSTANCES_PER_ROW: usize = 1 << N_LOG_INSTANCES_PER_ROW; @@ -50,23 +42,6 @@ const EXTERNAL_ROUND_CONSTS: [[BaseField; N_STATE]; 2 * N_HALF_FULL_ROUNDS] = const INTERNAL_ROUND_CONSTS: [BaseField; N_PARTIAL_ROUNDS] = [BaseField::from_u32_unchecked(1234); N_PARTIAL_ROUNDS]; -#[derive(Clone)] -pub struct PoseidonComponent { - pub log_n_rows: u32, - pub lookup_elements: LookupElements, - pub claimed_sum: SecureField, -} - -impl PoseidonComponent { - pub fn log_column_size(&self) -> u32 { - self.log_n_rows - } - - pub fn n_columns(&self) -> usize { - N_COLUMNS - } -} - #[derive(Clone)] pub struct PoseidonAir { pub component: PoseidonComponent, @@ -82,103 +57,27 @@ impl Air for PoseidonAir { } } -impl AirTraceVerifier for PoseidonAir { - fn interaction_elements(&self, _channel: &mut Blake2sChannel) -> InteractionElements { - InteractionElements::default() - } -} - -impl AirTraceGenerator for PoseidonAir { - fn interact( - &self, - _trace: &ColumnVec>, - _elements: &InteractionElements, - ) -> Vec> { - vec![] - } - - fn to_air_prover(&self) -> impl AirProver { - self.clone() - } - - fn composition_log_degree_bound(&self) -> u32 { - self.component.max_constraint_log_degree_bound() - } -} - -pub fn poseidon_info() -> InfoEvaluator { - let mut eval = InfoEvaluator::default(); - let [is_first] = eval.next_interaction_mask(2, [0]); - let counter = PoseidonEval { - eval, - lookup_elements: LookupElements { - z: SecureField::one(), - alpha: SecureField::one(), - }, - logup: LogupAtRow::new(1, SecureField::zero(), is_first), - }; - counter.eval() +#[derive(Clone)] +pub struct PoseidonComponent { + pub log_n_rows: u32, + pub lookup_elements: LookupElements, + pub claimed_sum: SecureField, } - -impl Component for PoseidonComponent { - fn n_constraints(&self) -> usize { - poseidon_info().n_constraints +impl FrameworkComponent for PoseidonComponent { + fn log_size(&self) -> u32 { + self.log_n_rows } - fn max_constraint_log_degree_bound(&self) -> u32 { - self.log_column_size() + LOG_EXPAND - } - - fn trace_log_degree_bounds(&self) -> TreeVec> { - TreeVec::new( - poseidon_info() - .mask_offsets - .iter() - .map(|tree_masks| vec![self.log_n_rows; tree_masks.len()]) - .collect(), - ) + self.log_n_rows + LOG_EXPAND } - - fn mask_points( - &self, - point: CirclePoint, - ) -> TreeVec>>> { - let trace_step = CanonicCoset::new(self.log_n_rows).step(); - let counter = poseidon_info(); - counter.mask_offsets.map(|tree_mask| { - tree_mask - .iter() - .map(|col_mask| { - col_mask - .iter() - .map(|off| point + trace_step.mul_signed(*off).into_ef()) - .collect() - }) - .collect() - }) - } - - fn evaluate_constraint_quotients_at_point( - &self, - point: CirclePoint, - mask: &TreeVec>>, - evaluation_accumulator: &mut PointEvaluationAccumulator, - _interaction_elements: &InteractionElements, - _lookup_values: &LookupValues, - ) { - let constraint_zero_domain = CanonicCoset::new(self.log_column_size()).coset; - let denom = coset_vanishing(constraint_zero_domain, point); - let denom_inverse = denom.inverse(); - - let mut eval = PointEvaluator::new(mask.as_ref(), evaluation_accumulator, denom_inverse); + fn evaluate(&self, mut eval: E) -> E { let [is_first] = eval.next_interaction_mask(2, [0]); let poseidon_eval = PoseidonEval { eval, logup: LogupAtRow::new(1, self.claimed_sum, is_first), lookup_elements: self.lookup_elements, }; - let eval = poseidon_eval.eval(); - assert_eq!(eval.col_index[0], N_COLUMNS); + poseidon_eval.eval() } } @@ -448,123 +347,6 @@ pub fn gen_interaction_trace( logup_gen.finalize() } -impl ComponentTraceGenerator for PoseidonComponent { - type Component = Self; - type Inputs = (); - - fn add_inputs(&mut self, _inputs: &Self::Inputs) { - todo!() - } - - fn write_trace( - _component_id: &str, - _registry: &mut crate::trace_generation::registry::ComponentGenerationRegistry, - ) -> ColumnVec> { - todo!() - } - - fn write_interaction_trace( - &self, - _trace: &ColumnVec<&CircleEvaluation>, - _elements: &InteractionElements, - ) -> ColumnVec> { - vec![] - } - - fn component(&self) -> Self::Component { - todo!() - } -} - -impl ComponentProver for PoseidonComponent { - fn evaluate_constraint_quotients_on_domain( - &self, - trace: &ComponentTrace<'_, SimdBackend>, - evaluation_accumulator: &mut DomainEvaluationAccumulator, - _interaction_elements: &InteractionElements, - _lookup_values: &LookupValues, - ) { - assert_eq!(trace.polys[0].len(), self.n_columns()); - let eval_domain = CanonicCoset::new(self.log_column_size() + LOG_EXPAND).circle_domain(); - - // Create a new evaluation. - let span = span!(Level::INFO, "Deg4 eval").entered(); - let twiddles = SimdBackend::precompute_twiddles( - CanonicCoset::new(self.max_constraint_log_degree_bound()) - .circle_domain() - .half_coset, - ); - let trace_eval = trace - .polys - .as_cols_ref() - .map_cols(|col| col.evaluate_with_twiddles(eval_domain, &twiddles)); - let trace_eval_ref = trace_eval.as_ref().map(|t| t.iter().collect_vec()); - span.exit(); - - // Denoms. - let span = span!(Level::INFO, "Constraint eval denominators").entered(); - let zero_domain = CanonicCoset::new(self.log_column_size()).coset; - let denoms_inv: [BaseField; 1 << LOG_EXPAND] = - array::from_fn(|i| coset_vanishing(zero_domain, eval_domain.at(i)).inverse()); - let mut packed_denoms_inv = denoms_inv.map(PackedM31::broadcast); - bit_reverse(&mut packed_denoms_inv); - span.exit(); - - let _span = span!(Level::INFO, "Constraint pointwise eval").entered(); - - let constraint_log_degree_bound = self.max_constraint_log_degree_bound(); - let n_constraints = self.n_constraints(); - let [accum] = - evaluation_accumulator.columns([(constraint_log_degree_bound, n_constraints)]); - let mut pows = accum.random_coeff_powers.clone(); - pows.reverse(); - - const CHUNK_SIZE: usize = 16; - assert_eq!(accum.col.columns[0].length % (CHUNK_SIZE << LOG_N_LANES), 0); - - #[cfg(not(feature = "parallel"))] - let iter = (0..(1 << (eval_domain.log_size() - LOG_N_LANES))) - .step_by(CHUNK_SIZE) - .zip(accum.col.chunks_mut(CHUNK_SIZE)); - - #[cfg(feature = "parallel")] - let iter = (0..(1 << (eval_domain.log_size() - LOG_N_LANES))) - .into_par_iter() - .step_by(CHUNK_SIZE) - .zip(accum.col.chunks_mut(CHUNK_SIZE)); - - iter.for_each(|(chunk_offset, mut col_chunk)| { - for offset in 0..CHUNK_SIZE { - let vec_row = chunk_offset + offset; - let mut eval = SimdDomainEvaluator::new( - &trace_eval_ref, - vec_row, - &pows, - self.log_n_rows, - self.log_n_rows + LOG_EXPAND, - ); - let [is_first] = eval.next_interaction_mask(2, [0]); - let poseidon_eval = PoseidonEval { - eval, - logup: LogupAtRow::new(1, self.claimed_sum, is_first), - lookup_elements: self.lookup_elements, - }; - let eval = poseidon_eval.eval(); - let row_res = eval.row_res; - - let packed_denom_inv = - packed_denoms_inv[vec_row >> (zero_domain.log_size() - LOG_N_LANES)]; - let quotient = row_res * packed_denom_inv; - unsafe { col_chunk.set_packed(offset, col_chunk.packed_at(offset) + quotient) }; - } - }); - } - - fn lookup_values(&self, _trace: &ComponentTrace<'_, SimdBackend>) -> LookupValues { - LookupValues::default() - } -} - pub fn prove_poseidon(log_n_instances: u32) -> (PoseidonAir, StarkProof) { assert!(log_n_instances >= N_LOG_INSTANCES_PER_ROW as u32); let log_n_rows = log_n_instances - N_LOG_INSTANCES_PER_ROW as u32;