Skip to content

Commit

Permalink
Convert wide_fibbonacci AVX backend example to SIMD backend
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed May 11, 2024
1 parent 0713b9f commit f55de63
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
3 changes: 1 addition & 2 deletions crates/prover/src/examples/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#[cfg(target_arch = "x86_64")]
pub mod avx;
pub mod component;
pub mod constraint_eval;
pub mod simd;
pub mod trace_gen;

#[cfg(test)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ use tracing::{span, Level};
use super::component::{WideFibAir, WideFibComponent};
use crate::core::air::accumulation::DomainEvaluationAccumulator;
use crate::core::air::{AirProver, ComponentProver, ComponentTrace};
use crate::core::backend::avx512::qm31::PackedSecureField;
use crate::core::backend::avx512::{AVX512Backend, BaseFieldVec, PackedBaseField, VECS_LOG_SIZE};
use crate::core::backend::simd::column::BaseFieldVec;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::constraints::coset_vanishing;
use crate::core::fields::m31::BaseField;
Expand All @@ -16,20 +18,20 @@ use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;
use crate::examples::wide_fibonacci::component::N_COLUMNS;

impl AirProver<AVX512Backend> for WideFibAir {
fn prover_components(&self) -> Vec<&dyn ComponentProver<AVX512Backend>> {
impl AirProver<SimdBackend> for WideFibAir {
fn prover_components(&self) -> Vec<&dyn ComponentProver<SimdBackend>> {
vec![&self.component]
}
}

pub fn gen_trace(
log_size: usize,
) -> ColumnVec<CircleEvaluation<AVX512Backend, BaseField, BitReversedOrder>> {
assert!(log_size >= VECS_LOG_SIZE);
) -> ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
assert!(log_size >= LOG_N_LANES as usize);
let mut trace = (0..N_COLUMNS)
.map(|_| Col::<AVX512Backend, BaseField>::zeros(1 << log_size))
.map(|_| Col::<SimdBackend, BaseField>::zeros(1 << log_size))
.collect_vec();
for vec_index in 0..(1 << (log_size - VECS_LOG_SIZE)) {
for vec_index in 0..(1 << (log_size - LOG_N_LANES as usize)) {
let mut a = PackedBaseField::one();
let mut b = PackedBaseField::from_array(std::array::from_fn(|i| {
BaseField::from_u32_unchecked((vec_index * 16 + i) as u32)
Expand All @@ -44,15 +46,15 @@ pub fn gen_trace(
let domain = CanonicCoset::new(log_size as u32).circle_domain();
trace
.into_iter()
.map(|eval| CircleEvaluation::<AVX512Backend, _, BitReversedOrder>::new(domain, eval))
.map(|eval| CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(domain, eval))
.collect_vec()
}

impl ComponentProver<AVX512Backend> for WideFibComponent {
impl ComponentProver<SimdBackend> for WideFibComponent {
fn evaluate_constraint_quotients_on_domain(
&self,
trace: &ComponentTrace<'_, AVX512Backend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<AVX512Backend>,
trace: &ComponentTrace<'_, SimdBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<SimdBackend>,
) {
assert_eq!(trace.polys.len(), self.n_columns());
// TODO(spapini): Steal evaluation from commitment.
Expand All @@ -65,9 +67,9 @@ impl ComponentProver<AVX512Backend> for WideFibComponent {
let zero_domain = CanonicCoset::new(self.log_column_size()).coset;
let mut denoms =
BaseFieldVec::from_iter(eval_domain.iter().map(|p| coset_vanishing(zero_domain, p)));
<AVX512Backend as ColumnOps<BaseField>>::bit_reverse_column(&mut denoms);
<SimdBackend as ColumnOps<BaseField>>::bit_reverse_column(&mut denoms);
let mut denom_inverses = BaseFieldVec::zeros(denoms.len());
<AVX512Backend as FieldOps<BaseField>>::batch_inverse(&denoms, &mut denom_inverses);
<SimdBackend as FieldOps<BaseField>>::batch_inverse(&denoms, &mut denom_inverses);
span.exit();

let _span = span!(Level::INFO, "Constraint pointwise eval").entered();
Expand All @@ -76,7 +78,7 @@ impl ComponentProver<AVX512Backend> for WideFibComponent {
let [accum] =
evaluation_accumulator.columns([(constraint_log_degree_bound, self.n_columns() - 1)]);

for vec_row in 0..(1 << (eval_domain.log_size() - VECS_LOG_SIZE as u32)) {
for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES)) {
// Numerator.
let a = trace_eval[0].data[vec_row];
let mut row_res =
Expand Down Expand Up @@ -110,18 +112,17 @@ impl ComponentProver<AVX512Backend> for WideFibComponent {
}
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
use tracing::{span, Level};

use super::{gen_trace, WideFibAir};
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::fields::m31::BaseField;
use crate::core::fields::IntoSlice;
use crate::core::prover::{prove, verify};
use crate::core::vcs::blake2_hash::Blake2sHasher;
use crate::core::vcs::hasher::Hasher;
use crate::examples::wide_fibonacci::avx::{gen_trace, WideFibAir};
use crate::examples::wide_fibonacci::component::{WideFibComponent, LOG_N_COLUMNS};

#[test_log::test]
Expand Down

0 comments on commit f55de63

Please sign in to comment.