Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert wide_fibbonacci AVX backend example to SIMD backend #615

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions crates/prover/src/examples/wide_fibonacci/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#[cfg(target_arch = "x86_64")]
pub mod avx;
pub mod component;
pub mod constraint_eval;
pub mod simd;
pub mod trace_gen;

#[cfg(test)]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use itertools::Itertools;
use num_traits::One;
use num_traits::{One, Zero};
use tracing::{span, Level};

use super::component::{WideFibAir, WideFibComponent};
use crate::core::air::accumulation::DomainEvaluationAccumulator;
use crate::core::air::{AirProver, Component, ComponentProver, ComponentTrace};
use crate::core::backend::avx512::qm31::PackedSecureField;
use crate::core::backend::avx512::{AVX512Backend, BaseFieldVec, PackedBaseField, VECS_LOG_SIZE};
use crate::core::backend::simd::column::BaseFieldVec;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::constraints::coset_vanishing;
use crate::core::fields::m31::BaseField;
Expand All @@ -16,20 +18,20 @@ use crate::core::poly::BitReversedOrder;
use crate::core::ColumnVec;
use crate::examples::wide_fibonacci::component::N_COLUMNS;

impl AirProver<AVX512Backend> for WideFibAir {
fn prover_components(&self) -> Vec<&dyn ComponentProver<AVX512Backend>> {
impl AirProver<SimdBackend> for WideFibAir {
fn prover_components(&self) -> Vec<&dyn ComponentProver<SimdBackend>> {
vec![&self.component]
}
}

pub fn gen_trace(
log_size: usize,
) -> ColumnVec<CircleEvaluation<AVX512Backend, BaseField, BitReversedOrder>> {
assert!(log_size >= VECS_LOG_SIZE);
log_size: u32,
) -> ColumnVec<CircleEvaluation<SimdBackend, BaseField, BitReversedOrder>> {
assert!(log_size >= LOG_N_LANES);
let mut trace = (0..N_COLUMNS)
.map(|_| Col::<AVX512Backend, BaseField>::zeros(1 << log_size))
.map(|_| Col::<SimdBackend, BaseField>::zeros(1 << log_size))
.collect_vec();
for vec_index in 0..(1 << (log_size - VECS_LOG_SIZE)) {
for vec_index in 0..(1 << (log_size - LOG_N_LANES)) {
let mut a = PackedBaseField::one();
let mut b = PackedBaseField::from_array(std::array::from_fn(|i| {
BaseField::from_u32_unchecked((vec_index * 16 + i) as u32)
Expand All @@ -41,18 +43,18 @@ pub fn gen_trace(
col.data[vec_index] = b;
});
}
let domain = CanonicCoset::new(log_size as u32).circle_domain();
let domain = CanonicCoset::new(log_size).circle_domain();
trace
.into_iter()
.map(|eval| CircleEvaluation::<AVX512Backend, _, BitReversedOrder>::new(domain, eval))
.map(|eval| CircleEvaluation::<SimdBackend, _, BitReversedOrder>::new(domain, eval))
.collect_vec()
}

impl ComponentProver<AVX512Backend> for WideFibComponent {
impl ComponentProver<SimdBackend> for WideFibComponent {
fn evaluate_constraint_quotients_on_domain(
&self,
trace: &ComponentTrace<'_, AVX512Backend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<AVX512Backend>,
trace: &ComponentTrace<'_, SimdBackend>,
evaluation_accumulator: &mut DomainEvaluationAccumulator<SimdBackend>,
) {
assert_eq!(trace.polys.len(), self.n_columns());
// TODO(spapini): Steal evaluation from commitment.
Expand All @@ -65,9 +67,9 @@ impl ComponentProver<AVX512Backend> for WideFibComponent {
let zero_domain = CanonicCoset::new(self.log_column_size()).coset;
let mut denoms =
BaseFieldVec::from_iter(eval_domain.iter().map(|p| coset_vanishing(zero_domain, p)));
<AVX512Backend as ColumnOps<BaseField>>::bit_reverse_column(&mut denoms);
<SimdBackend as ColumnOps<BaseField>>::bit_reverse_column(&mut denoms);
let mut denom_inverses = BaseFieldVec::zeros(denoms.len());
<AVX512Backend as FieldOps<BaseField>>::batch_inverse(&denoms, &mut denom_inverses);
<SimdBackend as FieldOps<BaseField>>::batch_inverse(&denoms, &mut denom_inverses);
span.exit();

let _span = span!(Level::INFO, "Constraint pointwise eval").entered();
Expand All @@ -77,7 +79,7 @@ impl ComponentProver<AVX512Backend> for WideFibComponent {
let [accum] =
evaluation_accumulator.columns([(constraint_log_degree_bound, n_constraints)]);

for vec_row in 0..(1 << (eval_domain.log_size() - VECS_LOG_SIZE as u32)) {
for vec_row in 0..(1 << (eval_domain.log_size() - LOG_N_LANES)) {
// Numerator.
let a = trace_eval[0].data[vec_row];
let mut row_res = PackedSecureField::zero();
Expand All @@ -104,27 +106,26 @@ impl ComponentProver<AVX512Backend> for WideFibComponent {
}
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
use tracing::{span, Level};

use crate::core::backend::avx512::AVX512Backend;
use super::{gen_trace, WideFibAir};
use crate::core::backend::simd::SimdBackend;
use crate::core::channel::{Blake2sChannel, Channel};
use crate::core::fields::m31::BaseField;
use crate::core::fields::IntoSlice;
use crate::core::prover::{prove, verify};
use crate::core::vcs::blake2_hash::Blake2sHasher;
use crate::core::vcs::hasher::Hasher;
use crate::examples::wide_fibonacci::avx::{gen_trace, WideFibAir};
use crate::examples::wide_fibonacci::component::{WideFibComponent, LOG_N_COLUMNS};

#[test_log::test]
fn test_avx_wide_fib_prove() {
fn test_simd_wide_fib_prove() {
// Note: To see time measurement, run test with
// RUST_LOG_SPAN_EVENTS=enter,close RUST_LOG=info RUST_BACKTRACE=1 RUSTFLAGS="
// -C target-cpu=native -C target-feature=+avx512f -C opt-level=2" cargo test
// test_avx_wide_fib_prove -- --nocapture
// test_simd_wide_fib_prove -- --nocapture

// Note: 17 means 128MB of trace.
const LOG_N_ROWS: u32 = 12;
Expand All @@ -133,11 +134,11 @@ mod tests {
log_n_instances: LOG_N_ROWS,
};
let span = span!(Level::INFO, "Trace generation").entered();
let trace = gen_trace(component.log_column_size() as usize);
let trace = gen_trace(component.log_column_size());
span.exit();
let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[])));
let air = WideFibAir { component };
let proof = prove::<AVX512Backend>(&air, channel, trace).unwrap();
let proof = prove::<SimdBackend>(&air, channel, trace).unwrap();

let channel = &mut Blake2sChannel::new(Blake2sHasher::hash(BaseField::into_slice(&[])));
verify(proof, &air, channel).unwrap();
Expand Down
Loading