Skip to content

Commit

Permalink
WideFib test with AVX Backend (#492)
Browse files Browse the repository at this point in the history
<!-- Reviewable:start -->
This change is [<img src="https://reviewable.io/review_button.svg" height="34" align="absmiddle" alt="Reviewable"/>](https://reviewable.io/reviews/starkware-libs/stwo/492)
<!-- Reviewable:end -->
  • Loading branch information
spapinistarkware authored Apr 4, 2024
1 parent 90e675a commit 58a3452
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 44 deletions.
1 change: 1 addition & 0 deletions src/commitment_scheme/blake2_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::fmt;
use blake2::{Blake2s256, Digest};

// Wrapper for the blake2s hash type.
// Holds the raw 32-byte digest. The type is forced to 32-byte alignment, which equals its size.
// NOTE(review): presumably the alignment exists so slices of hashes can be reinterpreted as
// SIMD vectors by the AVX512 backend — confirm against the blake2s AVX code.
#[repr(align(32))]
#[derive(Clone, Copy, PartialEq, Default, Eq)]
pub struct Blake2sHash([u8; 32]);

Expand Down
15 changes: 8 additions & 7 deletions src/core/air/accumulation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub struct DomainEvaluationAccumulator<B: Backend> {
/// Each `sub_accumulation` holds the sum over all columns i of that log_size, of
/// `evaluation_i * alpha^(N - 1 - i)`
/// where `N` is the total number of evaluations.
sub_accumulations: Vec<SecureColumn<B>>,
sub_accumulations: Vec<Option<SecureColumn<B>>>,
}

impl<B: Backend> DomainEvaluationAccumulator<B> {
Expand All @@ -62,9 +62,7 @@ impl<B: Backend> DomainEvaluationAccumulator<B> {
let max_log_size = max_log_size as usize;
Self {
random_coeff_powers: generate_secure_powers(random_coeff, total_columns),
sub_accumulations: (0..(max_log_size + 1))
.map(|n| SecureColumn::zeros(1 << n))
.collect(),
sub_accumulations: (0..(max_log_size + 1)).map(|_| None).collect(),
}
}

Expand All @@ -82,13 +80,13 @@ impl<B: Backend> DomainEvaluationAccumulator<B> {
.unwrap_or_else(|e| panic!("invalid log_sizes: {}", e))
.into_iter()
.zip(n_cols_per_size)
.map(|(col, (_, n_cols))| {
.map(|(col, (log_size, n_cols))| {
let random_coeffs = self
.random_coeff_powers
.split_off(self.random_coeff_powers.len() - n_cols);
ColumnAccumulator {
random_coeff_powers: random_coeffs,
col,
col: col.get_or_insert_with(|| SecureColumn::zeros(1 << log_size)),
}
})
.collect_vec()
Expand Down Expand Up @@ -120,6 +118,9 @@ impl<B: Backend> DomainEvaluationAccumulator<B> {
let res_log_size = self.log_size();

for (log_size, values) in self.sub_accumulations.into_iter().enumerate().skip(1) {
let Some(values) = values else {
continue;
};
let coeffs = SecureColumn::<B> {
columns: values.columns.map(|c| {
CircleEvaluation::<B, BaseField, BitReversedOrder>::new(
Expand All @@ -142,7 +143,7 @@ impl<B: Backend> DomainEvaluationAccumulator<B> {
/// A domain accumulator for polynomials of a single size.
pub struct ColumnAccumulator<'a, B: Backend> {
pub random_coeff_powers: Vec<SecureField>,
col: &'a mut SecureColumn<B>,
pub col: &'a mut SecureColumn<B>,
}
impl<'a> ColumnAccumulator<'a, CPUBackend> {
pub fn accumulate(&mut self, index: usize, evaluation: SecureField) {
Expand Down
37 changes: 20 additions & 17 deletions src/core/backend/avx512/blake2s.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::arch::x86_64::__m512i;
use std::arch::x86_64::{__m512i, _mm512_loadu_si512};

use itertools::Itertools;

use super::blake2s_avx::{compress16, set1, transpose_msgs, untranspose_states};
use super::{AVX512Backend, VECS_LOG_SIZE};
use crate::commitment_scheme::blake2_hash::Blake2sHash;
use crate::commitment_scheme::blake2_merkle::Blake2sMerkleHasher;
use crate::commitment_scheme::ops::MerkleOps;
use crate::core::backend::{Col, ColumnOps};
use crate::commitment_scheme::ops::{MerkleHasher, MerkleOps};
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::fields::m31::BaseField;

impl ColumnOps<Blake2sHash> for AVX512Backend {
Expand All @@ -25,19 +25,20 @@ impl MerkleOps<Blake2sMerkleHasher> for AVX512Backend {
columns: &[&Col<AVX512Backend, BaseField>],
) -> Vec<Blake2sHash> {
// Pad prev_layer if too small.
let mut padded_buffer = vec![];
let prev_layer = if log_size < VECS_LOG_SIZE as u32 {
prev_layer.map(|prev_layer| {
padded_buffer = prev_layer
.iter()
.copied()
.chain(std::iter::repeat(Blake2sHash::default()))
.collect_vec();
&padded_buffer
})
} else {
prev_layer
};
if log_size < VECS_LOG_SIZE as u32 {
return (0..(1 << log_size))
.map(|i| {
Blake2sMerkleHasher::hash_node(
prev_layer.map(|prev_layer| (prev_layer[2 * i], prev_layer[2 * i + 1])),
&columns.iter().map(|column| column.at(i)).collect_vec(),
)
})
.collect();
}

if let Some(prev_layer) = prev_layer {
assert_eq!(prev_layer.len(), 1 << (log_size + 1));
}

// Commit to columns.
let mut res = Vec::with_capacity(1 << log_size);
Expand All @@ -46,7 +47,9 @@ impl MerkleOps<Blake2sMerkleHasher> for AVX512Backend {
// Hash prev_layer, if exists.
if let Some(prev_layer) = prev_layer {
let ptr = prev_layer[(i << 5)..((i + 1) << 5)].as_ptr() as *const __m512i;
let msgs: [__m512i; 16] = std::array::from_fn(|j| unsafe { *ptr.add(j) });
let msgs: [__m512i; 16] = std::array::from_fn(|j| unsafe {
_mm512_loadu_si512(ptr.add(j) as *const i32)
});
state = unsafe {
compress16(
state,
Expand Down
182 changes: 182 additions & 0 deletions src/examples/wide_fibonacci/avx.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
use itertools::Itertools;
use num_traits::{One, Zero};

use super::structs::WideFibComponent;
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Air, Component, ComponentTrace, Mask};
use crate::core::backend::avx512::qm31::PackedSecureField;
use crate::core::backend::avx512::{AVX512Backend, BaseFieldVec, PackedBaseField, VECS_LOG_SIZE};
use crate::core::backend::{CPUBackend, Col, Column, ColumnOps};
use crate::core::circle::CirclePoint;
use crate::core::constraints::coset_vanishing;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{FieldExpOps, FieldOps};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;

const N_COLS: usize = 1 << 8;

/// AIR of the wide Fibonacci example, made of a single [`WideFibComponent`].
pub struct WideFibAir {
    // The only component of this AIR.
    component: WideFibComponent,
}
impl Air<AVX512Backend> for WideFibAir {
    /// Returns the single wide Fibonacci component, viewed as an AVX512-backend component.
    fn components(&self) -> Vec<&dyn Component<AVX512Backend>> {
        let wide_fib: &dyn Component<AVX512Backend> = &self.component;
        vec![wide_fib]
    }
}
impl Air<CPUBackend> for WideFibAir {
    /// Returns the single wide Fibonacci component, viewed as a CPU-backend component.
    fn components(&self) -> Vec<&dyn Component<CPUBackend>> {
        let wide_fib: &dyn Component<CPUBackend> = &self.component;
        vec![wide_fib]
    }
}

/// Generates the wide Fibonacci trace on the AVX512 backend.
///
/// The trace has `N_COLS` columns of `1 << log_size` rows each. In every row the first column
/// is `1`, the second column is the row index, and each following column is the sum of the
/// squares of the two columns before it. Rows are filled one packed vector at a time.
///
/// The columns are returned as evaluations over the circle domain of
/// `CanonicCoset::new(log_size)`.
///
/// # Panics
///
/// Panics if `log_size < VECS_LOG_SIZE`, i.e. if a column would not fill a whole packed vector.
pub fn gen_trace(
    log_size: usize,
) -> ColumnVec<CircleEvaluation<AVX512Backend, BaseField, BitReversedOrder>> {
    assert!(
        log_size >= VECS_LOG_SIZE,
        "log_size must be at least VECS_LOG_SIZE"
    );
    let mut trace = (0..N_COLS)
        .map(|_| Col::<AVX512Backend, BaseField>::zeros(1 << log_size))
        .collect_vec();
    for vec_index in 0..(1 << (log_size - VECS_LOG_SIZE)) {
        // Index of the first row held by this packed vector. Derived from `VECS_LOG_SIZE`
        // rather than a literal lane count so it stays in sync with the loop bound above.
        let first_row = vec_index << VECS_LOG_SIZE;
        let mut a = PackedBaseField::one();
        // Row indices are below `1 << log_size`.
        // NOTE(review): `from_u32_unchecked` presumably requires a value below the field
        // modulus, which holds for any practical `log_size` — confirm.
        let mut b = PackedBaseField::from_array(std::array::from_fn(|i| {
            BaseField::from_u32_unchecked((first_row + i) as u32)
        }));
        trace[0].data[vec_index] = a;
        trace[1].data[vec_index] = b;
        // Every further column is the sum of squares of the previous two.
        trace.iter_mut().skip(2).for_each(|col| {
            (a, b) = (b, a.square() + b.square());
            col.data[vec_index] = b;
        });
    }
    let domain = CanonicCoset::new(log_size as u32).circle_domain();
    trace
        .into_iter()
        .map(|eval| CircleEvaluation::<AVX512Backend, _, BitReversedOrder>::new(domain, eval))
        .collect_vec()
}

/// AVX512 implementation of the wide Fibonacci component.
///
/// The component has `N_COLS` trace columns and `N_COLS - 1` constraints:
/// * one constraint forcing the first column to equal `1`, and
/// * `N_COLS - 2` constraints forcing column `i + 2` to equal the sum of the squares of
///   columns `i` and `i + 1`.
impl Component<AVX512Backend> for WideFibComponent {
    /// Number of constraints: one on the first column plus one per column from the third on.
    fn n_constraints(&self) -> usize {
        N_COLS - 1
    }

    /// The constraints are quadratic in the trace, hence one extra bit of log degree.
    fn max_constraint_log_degree_bound(&self) -> u32 {
        self.log_size + 1
    }

    /// Every one of the `N_COLS` trace columns has log degree bound `log_size`.
    fn trace_log_degree_bounds(&self) -> Vec<u32> {
        vec![self.log_size; N_COLS]
    }

    /// Evaluates the random linear combination of all constraint quotients over the evaluation
    /// domain of log size `log_size + 1`, and adds it into `evaluation_accumulator`.
    ///
    /// # Panics
    ///
    /// Panics if `trace` does not hold exactly `N_COLS` columns.
    fn evaluate_constraint_quotients_on_domain(
        &self,
        trace: &ComponentTrace<'_, AVX512Backend>,
        evaluation_accumulator: &mut DomainEvaluationAccumulator<AVX512Backend>,
    ) {
        assert_eq!(trace.columns.len(), N_COLS);
        // TODO(spapini): Steal evaluation from commitment.
        let eval_domain = CanonicCoset::new(self.log_size + 1).circle_domain();
        let trace_eval = trace
            .columns
            .iter()
            .map(|poly| poly.evaluate(eval_domain))
            .collect_vec();

        // Denoms.
        // TODO(spapini): Make this prettier.
        // The denominator of every constraint is the vanishing polynomial of the trace coset,
        // evaluated on the evaluation domain, bit-reversed, then batch-inverted.
        let zero_domain = CanonicCoset::new(self.log_size).coset;
        let mut denoms =
            BaseFieldVec::from_iter(eval_domain.iter().map(|p| coset_vanishing(zero_domain, p)));
        <AVX512Backend as ColumnOps<BaseField>>::bit_reverse_column(&mut denoms);
        let mut denom_inverses = BaseFieldVec::zeros(denoms.len());
        <AVX512Backend as FieldOps<BaseField>>::batch_inverse(&denoms, &mut denom_inverses);

        let constraint_log_degree_bound = self.log_size + 1;
        // One random coefficient power per constraint (`N_COLS - 1` of them).
        let [accum] = evaluation_accumulator.columns([(constraint_log_degree_bound, N_COLS - 1)]);

        for vec_row in 0..(1 << (eval_domain.log_size() - VECS_LOG_SIZE as u32)) {
            // Numerator.
            // First constraint: the first column equals one. It takes the highest coefficient
            // power.
            let a = trace_eval[0].data[vec_row];
            let mut row_res =
                PackedSecureField::from_packed_m31s([
                    a - PackedBaseField::one(),
                    PackedBaseField::zero(),
                    PackedBaseField::zero(),
                    PackedBaseField::zero(),
                ]) * PackedSecureField::broadcast(accum.random_coeff_powers[N_COLS - 2]);

            // Remaining constraints: column `i + 2` equals `col_i^2 + col_{i + 1}^2`. The two
            // previous squares are carried along so each column is squared only once.
            let mut a_sq = a.square();
            let mut b_sq = trace_eval[1].data[vec_row].square();
            // The index drives both the column lookup and the coefficient power lookup, with
            // different offsets, so an iterator rewrite would not be clearer.
            #[allow(clippy::needless_range_loop)]
            for i in 0..(N_COLS - 2) {
                // SAFETY: `trace_eval` has exactly `N_COLS` entries (one per trace column,
                // whose count is asserted above) and `i + 2 < N_COLS`. `vec_row` was already
                // bounds-checked against columns 0 and 1 above.
                // NOTE(review): this relies on every column evaluated on `eval_domain` having
                // the same packed length as columns 0 and 1 — confirm.
                let c = unsafe { *trace_eval.get_unchecked(i + 2).data.get_unchecked(vec_row) };
                row_res += PackedSecureField::broadcast(accum.random_coeff_powers[N_COLS - 3 - i])
                    * (a_sq + b_sq - c);
                (a_sq, b_sq) = (b_sq, c.square());
            }

            // Divide by the vanishing polynomial and add into the accumulated column.
            accum.col.set_packed(
                vec_row,
                accum.col.packed_at(vec_row) + row_res * denom_inverses.data[vec_row],
            );
        }
    }

    /// Returns, for each of the `N_COLS` columns, the points at which it must be opened.
    ///
    /// Every column has a single mask entry, which maps to `point` itself.
    fn mask_points(
        &self,
        point: CirclePoint<SecureField>,
    ) -> ColumnVec<Vec<CirclePoint<SecureField>>> {
        // One mask entry per trace column; tied to `N_COLS` instead of a literal column count.
        let mask = Mask(vec![vec![0_usize]; N_COLS]);
        mask.iter()
            .map(|col| col.iter().map(|_| point).collect())
            .collect()
    }

    /// Evaluates every constraint quotient at `point` from the opened `mask` values and feeds
    /// them to `evaluation_accumulator`: first the constraint on the first column, then the
    /// sum-of-squares constraints in column order.
    fn evaluate_constraint_quotients_at_point(
        &self,
        point: CirclePoint<SecureField>,
        mask: &ColumnVec<Vec<SecureField>>,
        evaluation_accumulator: &mut PointEvaluationAccumulator,
    ) {
        let zero_domain = CanonicCoset::new(self.log_size).coset;
        let denominator = coset_vanishing(zero_domain, point);
        evaluation_accumulator.accumulate((mask[0][0] - SecureField::one()) / denominator);
        for i in 0..(N_COLS - 2) {
            let numerator = mask[i][0].square() + mask[i + 1][0].square() - mask[i + 2][0];
            evaluation_accumulator.accumulate(numerator / denominator);
        }
    }
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
    use crate::commitment_scheme::blake2_hash::Blake2sHasher;
    use crate::commitment_scheme::hasher::Hasher;
    use crate::core::channel::{Blake2sChannel, Channel};
    use crate::core::fields::m31::BaseField;
    use crate::core::fields::IntoSlice;
    use crate::core::prover::{prove, verify};
    use crate::examples::wide_fibonacci::avx::{gen_trace, WideFibAir};
    use crate::examples::wide_fibonacci::structs::WideFibComponent;

    /// Creates a channel seeded with the hash of an empty slice. The prover and the verifier
    /// each start from their own channel built this way.
    fn fresh_channel() -> Blake2sChannel {
        Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[])))
    }

    /// Proves the wide Fibonacci AIR on an AVX-generated trace, then verifies the proof.
    #[test]
    fn test_avx_wide_fib_prove() {
        // Note: For benchmarks, increase to 17, to get 128MB of trace.
        const LOG_SIZE: u32 = 12;
        let air = WideFibAir {
            component: WideFibComponent { log_size: LOG_SIZE },
        };
        let trace = gen_trace(LOG_SIZE as usize);

        let mut prover_channel = fresh_channel();
        let proof = prove(&air, &mut prover_channel, trace).unwrap();

        let mut verifier_channel = fresh_channel();
        verify(proof, &air, &mut verifier_channel).unwrap();
    }
}
35 changes: 18 additions & 17 deletions src/examples/wide_fibonacci/constraint_eval.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use num_traits::Zero;
use num_traits::{One, Zero};

use super::structs::WideFibComponent;
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
Expand Down Expand Up @@ -43,12 +43,13 @@ impl Component<CPUBackend> for WideFibComponent {
trace: &ComponentTrace<'_, CPUBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<CPUBackend>,
) {
let constraint_log_degree = Component::<CPUBackend>::max_constraint_log_degree_bound(self);
let n_constraints = Component::<CPUBackend>::n_constraints(self);
let mut trace_evals = vec![];
// TODO(ShaharS), Share this LDE with the commitment LDE.
for poly_index in 0..64 {
let poly = &trace.columns[poly_index];
let trace_eval_domain =
CanonicCoset::new(self.max_constraint_log_degree_bound()).circle_domain();
let trace_eval_domain = CanonicCoset::new(constraint_log_degree).circle_domain();
trace_evals.push(poly.evaluate(trace_eval_domain).bit_reverse());
}
let zero_domain = CanonicCoset::new(self.log_size).coset;
Expand All @@ -57,13 +58,10 @@ impl Component<CPUBackend> for WideFibComponent {
for point in eval_domain.iter() {
denoms.push(coset_vanishing(zero_domain, point));
}
let mut denom_inverses =
vec![BaseField::zero(); 1 << (self.max_constraint_log_degree_bound())];
let mut denom_inverses = vec![BaseField::zero(); 1 << (constraint_log_degree)];
BaseField::batch_inverse(&denoms, &mut denom_inverses);
let mut numerators =
vec![SecureField::zero(); 1 << (self.max_constraint_log_degree_bound())];
let [mut accum] = evaluation_accumulator
.columns([(self.max_constraint_log_degree_bound(), self.n_constraints())]);
let mut numerators = vec![SecureField::zero(); 1 << constraint_log_degree];
let [mut accum] = evaluation_accumulator.columns([(constraint_log_degree, n_constraints)]);
// TODO (ShaharS) Change to get the correct power of random coeff inside the loop.
let random_coeff = accum.random_coeff_powers[1];
for (i, point_index) in eval_domain.iter_indices().enumerate() {
Expand Down Expand Up @@ -431,19 +429,22 @@ impl Component<CPUBackend> for WideFibComponent {
* trace_evals[62].get_at(point_index))));
}
for (i, (num, denom)) in numerators.iter().zip(denom_inverses.iter()).enumerate() {
accum.accumulate(
bit_reverse_index(i, self.max_constraint_log_degree_bound()),
*num * *denom,
);
accum.accumulate(bit_reverse_index(i, constraint_log_degree), *num * *denom);
}
}

fn evaluate_constraint_quotients_at_point(
&self,
_point: CirclePoint<SecureField>,
_mask: &ColumnVec<Vec<SecureField>>,
_evaluation_accumulator: &mut PointEvaluationAccumulator,
point: CirclePoint<SecureField>,
mask: &ColumnVec<Vec<SecureField>>,
evaluation_accumulator: &mut PointEvaluationAccumulator,
) {
unimplemented!("not implemented")
let zero_domain = CanonicCoset::new(self.log_size).coset;
let denominator = coset_vanishing(zero_domain, point);
evaluation_accumulator.accumulate((mask[0][0] - SecureField::one()) / denominator);
for i in 0..(256 - 2) {
let numerator = mask[i][0].square() + mask[i + 1][0].square() - mask[i + 2][0];
evaluation_accumulator.accumulate(numerator / denominator);
}
}
}
Loading

0 comments on commit 58a3452

Please sign in to comment.