Skip to content

Commit

Permalink
Use PreProcessedColumn Trait Instead of Enum
Browse files Browse the repository at this point in the history
  • Loading branch information
Gali-StarkWare committed Jan 13, 2025
1 parent 01a4251 commit bf8c714
Show file tree
Hide file tree
Showing 19 changed files with 210 additions and 209 deletions.
60 changes: 29 additions & 31 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::fmt::{self, Display, Formatter};
use std::iter::zip;
use std::ops::Deref;
Expand All @@ -11,7 +10,6 @@ use tracing::{span, Level};

use super::cpu_domain::CpuDomainEvaluator;
use super::logup::LogupSums;
use super::preprocessed_columns::PreprocessedColumn;
use super::{
EvalAtRow, InfoEvaluator, PointEvaluator, SimdDomainEvaluator, PREPROCESSED_TRACE_IDX,
};
Expand Down Expand Up @@ -49,7 +47,8 @@ pub struct TraceLocationAllocator {
/// Mapping of tree index to next available column offset.
next_tree_offsets: TreeVec<usize>,
/// Mapping of preprocessed columns to their index.
preprocessed_columns: HashMap<PreprocessedColumn, usize>,
// TODO(Gali): Change Vec type to struct PreProcessedColumnId {pub id: String}.
preprocessed_columns: Vec<String>,
/// Controls whether the preprocessed columns are dynamic or static (default=Dynamic).
preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode,
}
Expand Down Expand Up @@ -81,31 +80,27 @@ impl TraceLocationAllocator {
}

/// Create a new `TraceLocationAllocator` with fixed preprocessed columns setup.
pub fn new_with_preproccessed_columns(preprocessed_columns: &[PreprocessedColumn]) -> Self {
pub fn new_with_preproccessed_columns(preprocessed_columns: &[String]) -> Self {
assert!(
preprocessed_columns.iter().all_unique(),
"preprocessed_columns contains duplicates"
);
Self {
next_tree_offsets: Default::default(),
preprocessed_columns: preprocessed_columns
.iter()
.enumerate()
.map(|(i, &col)| (col, i))
.collect(),
preprocessed_columns: preprocessed_columns.to_vec(),
preprocessed_columns_allocation_mode: PreprocessedColumnsAllocationMode::Static,
}
}

pub const fn preprocessed_columns(&self) -> &HashMap<PreprocessedColumn, usize> {
pub const fn preprocessed_columns(&self) -> &Vec<String> {
&self.preprocessed_columns
}

// validates that `self.preprocessed_columns` is consistent with
// `preprocessed_columns`.
// I.e. preprocessed_columns[i] == self.preprocessed_columns[i].
pub fn validate_preprocessed_columns(&self, preprocessed_columns: &[PreprocessedColumn]) {
assert_eq!(preprocessed_columns.len(), self.preprocessed_columns.len());

for (column, idx) in self.preprocessed_columns.iter() {
assert_eq!(Some(column), preprocessed_columns.get(*idx));
}
pub fn validate_preprocessed_columns(&self, preprocessed_columns: &[String]) {
assert_eq!(self.preprocessed_columns, preprocessed_columns);
}
}

Expand Down Expand Up @@ -144,22 +139,25 @@ impl<E: FrameworkEval> FrameworkComponent<E> {
.iter()
.map(|col| {
let next_column = location_allocator.preprocessed_columns.len();
*location_allocator
if let Some(pos) = location_allocator
.preprocessed_columns
.entry(*col)
.or_insert_with(|| {
if matches!(
location_allocator.preprocessed_columns_allocation_mode,
PreprocessedColumnsAllocationMode::Static
) {
panic!(
"Preprocessed column {:?} is missing from static alloction",
col
);
}

next_column
})
.iter()
.position(|x| x == col)
{
pos
} else {
if matches!(
location_allocator.preprocessed_columns_allocation_mode,
PreprocessedColumnsAllocationMode::Static
) {
panic!(
"Preprocessed column {:?} is missing from static allocation",
col
);
}
location_allocator.preprocessed_columns.push(col.clone());
next_column
}
})
.collect();
Self {
Expand Down
7 changes: 3 additions & 4 deletions crates/prover/src/constraint_framework/expr/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use num_traits::Zero;

use super::{BaseExpr, ExtExpr};
use crate::constraint_framework::expr::ColumnExpr;
use crate::constraint_framework::preprocessed_columns::PreprocessedColumn;
use crate::constraint_framework::{EvalAtRow, Relation, RelationEntry, INTERACTION_TRACE_IDX};
use crate::core::fields::m31;
use crate::core::lookups::utils::Fraction;
Expand Down Expand Up @@ -174,8 +173,8 @@ impl EvalAtRow for ExprEvaluator {
intermediate
}

fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F {
BaseExpr::Param(column.name().to_string())
fn get_preprocessed_column(&mut self, column: String) -> Self::F {
BaseExpr::Param(column)
}

crate::constraint_framework::logup_proxy!();
Expand Down Expand Up @@ -208,7 +207,7 @@ mod tests {
\
let constraint_1 = (QM31Impl::from_partial_evals([trace_2_column_3_offset_0, trace_2_column_4_offset_0, trace_2_column_5_offset_0, trace_2_column_6_offset_0]) \
- (QM31Impl::from_partial_evals([trace_2_column_3_offset_neg_1, trace_2_column_4_offset_neg_1, trace_2_column_5_offset_neg_1, trace_2_column_6_offset_neg_1]) \
- ((total_sum) * (preprocessed_is_first)))) \
- ((total_sum) * (preprocessed_is_first_16)))) \
* (intermediate1) \
- (qm31(1, 0, 0, 0));"
.to_string();
Expand Down
12 changes: 4 additions & 8 deletions crates/prover/src/constraint_framework/info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::rc::Rc;
use num_traits::{One, Zero};

use super::logup::{LogupAtRow, LogupSums};
use super::preprocessed_columns::PreprocessedColumn;
use super::{EvalAtRow, INTERACTION_TRACE_IDX};
use crate::constraint_framework::PREPROCESSED_TRACE_IDX;
use crate::core::fields::m31::BaseField;
Expand All @@ -22,16 +21,13 @@ use crate::core::pcs::TreeVec;
pub struct InfoEvaluator {
pub mask_offsets: TreeVec<Vec<Vec<isize>>>,
pub n_constraints: usize,
pub preprocessed_columns: Vec<PreprocessedColumn>,
// TODO(Gali): Change Vec type to struct PreProcessedColumnId {pub id: String}.
pub preprocessed_columns: Vec<String>,
pub logup: LogupAtRow<Self>,
pub arithmetic_counts: ArithmeticCounts,
}
impl InfoEvaluator {
pub fn new(
log_size: u32,
preprocessed_columns: Vec<PreprocessedColumn>,
logup_sums: LogupSums,
) -> Self {
pub fn new(log_size: u32, preprocessed_columns: Vec<String>, logup_sums: LogupSums) -> Self {
Self {
mask_offsets: Default::default(),
n_constraints: Default::default(),
Expand Down Expand Up @@ -70,7 +66,7 @@ impl EvalAtRow for InfoEvaluator {
array::from_fn(|_| FieldCounter::one())
}

fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F {
fn get_preprocessed_column(&mut self, column: String) -> Self::F {
self.preprocessed_columns.push(column);
FieldCounter::one()
}
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub struct LogupAtRow<E: EvalAtRow> {
pub fracs: Vec<Fraction<E::EF, E::EF>>,
pub is_finalized: bool,
/// The value of the `is_first` constant column at current row.
/// See [`super::preprocessed_columns::gen_is_first()`].
/// See [`super::preprocessed_columns::IsFirst`].
pub is_first: E::F,
pub log_size: u32,
}
Expand Down
8 changes: 4 additions & 4 deletions crates/prover/src/constraint_framework/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ pub use component::{FrameworkComponent, FrameworkEval, TraceLocationAllocator};
pub use info::InfoEvaluator;
use num_traits::{One, Zero};
pub use point::PointEvaluator;
use preprocessed_columns::PreprocessedColumn;
pub use simd_domain::SimdDomainEvaluator;

use crate::core::fields::m31::BaseField;
Expand Down Expand Up @@ -87,7 +86,7 @@ pub trait EvalAtRow {
mask_item
}

fn get_preprocessed_column(&mut self, _column: PreprocessedColumn) -> Self::F {
fn get_preprocessed_column(&mut self, _column: String) -> Self::F {
let [mask_item] = self.next_interaction_mask(PREPROCESSED_TRACE_IDX, [0]);
mask_item
}
Expand Down Expand Up @@ -173,9 +172,10 @@ macro_rules! logup_proxy {
fn write_logup_frac(&mut self, fraction: Fraction<Self::EF, Self::EF>) {
if self.logup.fracs.is_empty() {
self.logup.is_first = self.get_preprocessed_column(
crate::constraint_framework::preprocessed_columns::PreprocessedColumn::IsFirst(
crate::constraint_framework::preprocessed_columns::IsFirst::new(
self.logup.log_size,
),
)
.id(),
);
self.logup.is_finalized = false;
}
Expand Down
33 changes: 4 additions & 29 deletions crates/prover/src/constraint_framework/preprocessed_columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ impl IsFirst {
}
}

// TODO(ilya): Where should this enum be placed?
// TODO(Gali): Consider making it a trait, add documentation for the rest of the variants.
// TODO(Gali): Remove Enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PreprocessedColumn {
/// A column with `1` at the first position, and `0` elsewhere.
Expand Down Expand Up @@ -173,26 +172,15 @@ pub fn gen_preprocessed_columns<'a, B: Backend>(

#[cfg(test)]
mod tests {
use super::IsFirst;
use crate::core::backend::simd::m31::N_LANES;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::fields::m31::{BaseField, M31};
const LOG_SIZE: u32 = 8;

#[test]
fn test_gen_seq() {
let seq = super::gen_seq::<SimdBackend>(LOG_SIZE);

for i in 0..(1 << LOG_SIZE) {
assert_eq!(seq.at(i), BaseField::from_u32_unchecked(i as u32));
}
}

// TODO(Gali): Add packed_at tests for xor_table and plonk.
#[test]
fn test_packed_at_is_first() {
let is_first = super::PreprocessedColumn::IsFirst(LOG_SIZE);
let expected_is_first = super::gen_is_first::<SimdBackend>(LOG_SIZE).to_cpu();
let is_first = IsFirst::new(LOG_SIZE);
let expected_is_first = is_first.gen_column_simd().to_cpu();

for i in 0..(1 << LOG_SIZE) / N_LANES {
assert_eq!(
Expand All @@ -201,17 +189,4 @@ mod tests {
);
}
}

#[test]
fn test_packed_at_seq() {
let seq = super::PreprocessedColumn::Seq(LOG_SIZE);
let expected_seq: [_; 1 << LOG_SIZE] = std::array::from_fn(|i| M31::from(i as u32));

let packed_seq = std::array::from_fn::<_, { (1 << LOG_SIZE) / N_LANES }, _>(|i| {
seq.packed_at(i).to_array()
})
.concat();

assert_eq!(packed_seq, expected_seq);
}
}
5 changes: 0 additions & 5 deletions crates/prover/src/constraint_framework/relation_tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use itertools::Itertools;
use num_traits::Zero;

use super::logup::LogupSums;
use super::preprocessed_columns::PreprocessedColumn;
use super::{
Batching, EvalAtRow, FrameworkEval, InfoEvaluator, Relation, RelationEntry,
TraceLocationAllocator, INTERACTION_TRACE_IDX,
Expand Down Expand Up @@ -146,10 +145,6 @@ impl EvalAtRow for RelationTrackerEvaluator<'_> {
})
}

fn get_preprocessed_column(&mut self, column: PreprocessedColumn) -> Self::F {
column.packed_at(self.vec_row)
}

fn add_constraint<G>(&mut self, _constraint: G) {}

fn combine_ef(_values: [Self::F; SECURE_EXTENSION_DEGREE]) -> Self::EF {
Expand Down
Loading

0 comments on commit bf8c714

Please sign in to comment.