Skip to content

Commit

Permalink
Remove Mask struct (#569)
Browse files Browse the repository at this point in the history
<!-- Reviewable:start -->
This change is [<img src="https://reviewable.io/review_button.svg" height="34" align="absmiddle" alt="Reviewable"/>](https://reviewable.io/reviews/starkware-libs/stwo/569)
<!-- Reviewable:end -->
  • Loading branch information
shaharsamocha7 authored Apr 15, 2024
2 parents d75d283 + 5f9adce commit f90d64d
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 45 deletions.
91 changes: 91 additions & 0 deletions src/core/air/mask.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
use std::collections::HashSet;
use std::vec;

use itertools::Itertools;

use crate::core::circle::CirclePoint;
use crate::core::fields::qm31::SecureField;
use crate::core::poly::circle::CanonicCoset;
use crate::core::ColumnVec;

/// Mask holds a vector with an entry for each column.
/// Each entry holds a list of mask items, which are the offsets of the mask at that column.
type Mask = ColumnVec<Vec<usize>>;

/// Returns the same point for each mask item.
/// Should be used where all the mask items has no shift from the constraint point.
pub fn fixed_mask_points(
mask: &Mask,
point: CirclePoint<SecureField>,
) -> ColumnVec<Vec<CirclePoint<SecureField>>> {
assert_eq!(
mask.iter()
.flat_map(|mask_entry| mask_entry.iter().collect::<HashSet<_>>())
.collect::<HashSet<&usize>>()
.into_iter()
.collect_vec(),
vec![&0]
);
mask.iter()
.map(|mask_entry| mask_entry.iter().map(|_| point).collect())
.collect()
}

/// For each mask item returns the point shifted by the domain initial point of the column.
/// Should be used where the mask items are shifted from the constraint point.
pub fn shifted_mask_points(
mask: &Mask,
domains: &[CanonicCoset],
point: CirclePoint<SecureField>,
) -> ColumnVec<Vec<CirclePoint<SecureField>>> {
mask.iter()
.zip(domains.iter())
.map(|(mask_entry, domain)| {
mask_entry
.iter()
.map(|mask_item| point + domain.at(*mask_item).into_ef())
.collect()
})
.collect()
}

#[cfg(test)]
mod tests {
use crate::core::air::mask::{fixed_mask_points, shifted_mask_points};
use crate::core::circle::CirclePoint;
use crate::core::poly::circle::CanonicCoset;

#[test]
fn test_mask_fixed_points() {
let mask = vec![vec![0], vec![0]];
let constraint_point = CirclePoint::get_point(1234);

let points = fixed_mask_points(&mask, constraint_point);

assert_eq!(points.len(), 2);
assert_eq!(points[0].len(), 1);
assert_eq!(points[1].len(), 1);
assert_eq!(points[0][0], constraint_point);
assert_eq!(points[1][0], constraint_point);
}

#[test]
fn test_mask_shifted_points() {
let mask = vec![vec![0, 1], vec![0, 1, 2]];
let constraint_point = CirclePoint::get_point(1234);
let domains = (0..mask.len() as u32)
.map(|i| CanonicCoset::new(7 + i))
.collect::<Vec<_>>();

let points = shifted_mask_points(&mask, &domains, constraint_point);

assert_eq!(points.len(), 2);
assert_eq!(points[0].len(), 2);
assert_eq!(points[1].len(), 3);
assert_eq!(points[0][0], constraint_point + domains[0].at(0).into_ef());
assert_eq!(points[0][1], constraint_point + domains[0].at(1).into_ef());
assert_eq!(points[1][0], constraint_point + domains[1].at(0).into_ef());
assert_eq!(points[1][1], constraint_point + domains[1].at(1).into_ef());
assert_eq!(points[1][2], constraint_point + domains[1].at(2).into_ef());
}
}
34 changes: 2 additions & 32 deletions src/core/air/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use std::iter::zip;
use std::ops::Deref;

use self::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use super::backend::Backend;
use super::circle::CirclePoint;
use super::fields::qm31::SecureField;
use super::poly::circle::{CanonicCoset, CirclePoly};
use super::poly::circle::CirclePoly;
use super::ColumnVec;

pub mod accumulation;
mod air_ext;
pub mod mask;

pub use air_ext::AirExt;

Expand All @@ -23,36 +23,6 @@ pub trait Air<B: Backend> {
fn components(&self) -> Vec<&dyn Component<B>>;
}

/// Holds the mask offsets at each column.
/// Holds a vector with an entry for each column. Each entry holds the offsets
/// of the mask at that column.
pub struct Mask(pub ColumnVec<Vec<usize>>);

impl Mask {
pub fn to_points(
&self,
domains: &[CanonicCoset],
point: CirclePoint<SecureField>,
) -> ColumnVec<Vec<CirclePoint<SecureField>>> {
self.iter()
.zip(domains.iter())
.map(|(col, domain)| {
col.iter()
.map(|i| point + domain.at(*i).into_ef())
.collect()
})
.collect()
}
}

impl Deref for Mask {
type Target = ColumnVec<Vec<usize>>;

fn deref(&self) -> &Self::Target {
&self.0
}
}

/// A component is a set of trace columns of various sizes along with a set of
/// constraints on them.
pub trait Component<B: Backend> {
Expand Down
10 changes: 7 additions & 3 deletions src/examples/fibonacci/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use std::ops::Div;
use num_traits::One;

use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentTrace, Mask};
use crate::core::air::mask::shifted_mask_points;
use crate::core::air::{Component, ComponentTrace};
use crate::core::backend::CPUBackend;
use crate::core::circle::{CirclePoint, Coset};
use crate::core::constraints::{coset_vanishing, pair_vanishing};
Expand Down Expand Up @@ -120,8 +121,11 @@ impl Component<CPUBackend> for FibonacciComponent {
&self,
point: CirclePoint<SecureField>,
) -> ColumnVec<Vec<CirclePoint<SecureField>>> {
let fib_mask = Mask(vec![vec![0, 1, 2]]);
fib_mask.to_points(&[CanonicCoset::new(self.log_size)], point)
shifted_mask_points(
&vec![vec![0, 1, 2]],
&[CanonicCoset::new(self.log_size)],
point,
)
}

fn evaluate_constraint_quotients_at_point(
Expand Down
8 changes: 3 additions & 5 deletions src/examples/wide_fibonacci/avx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use num_traits::{One, Zero};

use super::structs::WideFibComponent;
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Air, Component, ComponentTrace, Mask};
use crate::core::air::mask::fixed_mask_points;
use crate::core::air::{Air, Component, ComponentTrace};
use crate::core::backend::avx512::qm31::PackedSecureField;
use crate::core::backend::avx512::{AVX512Backend, BaseFieldVec, PackedBaseField, VECS_LOG_SIZE};
use crate::core::backend::{CPUBackend, Col, Column, ColumnOps};
Expand Down Expand Up @@ -132,10 +133,7 @@ impl Component<AVX512Backend> for WideFibComponent {
&self,
point: CirclePoint<SecureField>,
) -> ColumnVec<Vec<CirclePoint<SecureField>>> {
let mask = Mask(vec![vec![0_usize]; 256]);
mask.iter()
.map(|col| col.iter().map(|_| point).collect())
.collect()
fixed_mask_points(&vec![vec![0_usize]; 256], point)
}

fn evaluate_constraint_quotients_at_point(
Expand Down
8 changes: 3 additions & 5 deletions src/examples/wide_fibonacci/constraint_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use num_traits::{One, Zero};

use super::structs::WideFibComponent;
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentTrace, Mask};
use crate::core::air::mask::fixed_mask_points;
use crate::core::air::{Component, ComponentTrace};
use crate::core::backend::CPUBackend;
use crate::core::circle::CirclePoint;
use crate::core::constraints::coset_vanishing;
Expand Down Expand Up @@ -30,10 +31,7 @@ impl Component<CPUBackend> for WideFibComponent {
&self,
point: CirclePoint<SecureField>,
) -> ColumnVec<Vec<CirclePoint<SecureField>>> {
let mask = Mask(vec![vec![0_usize]; 256]);
mask.iter()
.map(|col| col.iter().map(|_| point).collect())
.collect()
fixed_mask_points(&vec![vec![0_usize]; 256], point)
}

// TODO(ShaharS), precompute random coeff powers.
Expand Down

0 comments on commit f90d64d

Please sign in to comment.