diff --git a/crates/prover/benches/bit_rev.rs b/crates/prover/benches/bit_rev.rs index dac6cf4e1c..219a6d6882 100644 --- a/crates/prover/benches/bit_rev.rs +++ b/crates/prover/benches/bit_rev.rs @@ -1,54 +1,63 @@ #![feature(iter_array_chunks)] -use criterion::Criterion; - -#[cfg(target_arch = "x86_64")] -pub fn cpu_bit_rev(c: &mut criterion::Criterion) { - use stwo_prover::core::fields::m31::BaseField; +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use itertools::Itertools; +use stwo_prover::core::fields::m31::BaseField; +pub fn cpu_bit_rev(c: &mut Criterion) { + use stwo_prover::core::utils::bit_reverse; + // TODO(andrew): Consider using same size for all. const SIZE: usize = 1 << 24; - let mut data: Vec<_> = (0..SIZE as u32) - .map(BaseField::from_u32_unchecked) - .collect(); - + let data = (0..SIZE).map(BaseField::from).collect_vec(); c.bench_function("cpu bit_rev 24bit", |b| { - b.iter(|| { - stwo_prover::core::utils::bit_reverse(&mut data); - }) + b.iter_batched( + || data.clone(), + |mut data| bit_reverse(&mut data), + BatchSize::LargeInput, + ); + }); +} + +pub fn simd_bit_rev(c: &mut Criterion) { + use stwo_prover::core::backend::simd::bit_reverse::bit_reverse_m31; + use stwo_prover::core::backend::simd::column::BaseFieldVec; + const SIZE: usize = 1 << 26; + let data = (0..SIZE).map(BaseField::from).collect::(); + c.bench_function("simd bit_rev 26bit", |b| { + b.iter_batched( + || data.data.clone(), + |mut data| bit_reverse_m31(&mut data), + BatchSize::LargeInput, + ); }); } #[cfg(target_arch = "x86_64")] -pub fn avx512_bit_rev(c: &mut criterion::Criterion) { - use bytemuck::cast_slice_mut; +pub fn avx512_bit_rev(c: &mut Criterion) { use stwo_prover::core::backend::avx512::bit_reverse::bit_reverse_m31; - use stwo_prover::core::backend::avx512::m31::PackedBaseField; - use stwo_prover::core::fields::m31::BaseField; - use stwo_prover::platform; - if !platform::avx512_detected() { + use stwo_prover::core::backend::avx512::BaseFieldVec; + const SIZE: usize = 1 << 26; + if !stwo_prover::platform::avx512_detected() { return; } - - const SIZE: usize = 1 << 26; - let data: Vec<_> = (0..SIZE as u32) - .map(BaseField::from_u32_unchecked) - .collect(); - let mut data: Vec<_> = data - .into_iter() - .array_chunks::<16>() - .map(PackedBaseField::from_array) - .collect(); - + let data = (0..SIZE).map(BaseField::from).collect::(); c.bench_function("avx bit_rev 26bit", |b| { - b.iter(|| { - bit_reverse_m31(cast_slice_mut(&mut data[..])); - }) + b.iter_batched( + || data.data.clone(), + |mut data| bit_reverse_m31(&mut data), + BatchSize::LargeInput, + ); }); } #[cfg(target_arch = "x86_64")] -criterion::criterion_group!( - name=avx_bit_rev; +criterion_group!( + name = bit_rev; + config = Criterion::default().sample_size(10); + targets = avx512_bit_rev, simd_bit_rev, cpu_bit_rev); +#[cfg(not(target_arch = "x86_64"))] +criterion_group!( + name = bit_rev; config = Criterion::default().sample_size(10); - targets=avx512_bit_rev, cpu_bit_rev); -criterion::criterion_main!(avx_bit_rev); + targets = simd_bit_rev, cpu_bit_rev); +criterion_main!(bit_rev);