Skip to content

Commit

Permalink
Move bit_reverse function to backend/cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
Gali-StarkWare committed Dec 10, 2024
1 parent 10a3f69 commit ef3a094
Show file tree
Hide file tree
Showing 14 changed files with 55 additions and 52 deletions.
2 changes: 1 addition & 1 deletion crates/prover/benches/bit_rev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use itertools::Itertools;
use stwo_prover::core::fields::m31::BaseField;

pub fn cpu_bit_rev(c: &mut Criterion) {
use stwo_prover::core::utils::bit_reverse;
use stwo_prover::core::backend::cpu::bit_reverse;
// TODO(andrew): Consider using same size for all.
const SIZE: usize = 1 << 24;
let data = (0..SIZE).map(BaseField::from).collect_vec();
Expand Down
5 changes: 3 additions & 2 deletions crates/prover/src/constraint_framework/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use super::{
};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, Trace};
use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::simd::column::VeryPackedSecureColumnByCoords;
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::very_packed_m31::{VeryPackedBaseField, LOG_N_VERY_PACKED_ELEMS};
Expand All @@ -30,7 +31,7 @@ use crate::core::fields::FieldExpOps;
use crate::core::pcs::{TreeSubspan, TreeVec};
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use crate::core::poly::BitReversedOrder;
use crate::core::{utils, ColumnVec};
use crate::core::ColumnVec;

const CHUNK_SIZE: usize = 1;

Expand Down Expand Up @@ -292,7 +293,7 @@ impl<E: FrameworkEval + Sync> ComponentProver<SimdBackend> for FrameworkComponen
let mut denom_inv = (0..1 << log_expand)
.map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse())
.collect_vec();
utils::bit_reverse(&mut denom_inv);
bit_reverse(&mut denom_inv);

// Accumulator.
let [mut accum] =
Expand Down
3 changes: 2 additions & 1 deletion crates/prover/src/core/backend/cpu/circle.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use num_traits::Zero;

use super::CpuBackend;
use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::{Col, ColumnOps};
use crate::core::circle::{CirclePoint, Coset};
use crate::core::fft::{butterfly, ibutterfly};
Expand All @@ -13,7 +14,7 @@ use crate::core::poly::circle::{
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::utils::{domain_line_twiddles_from_tree, fold};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order};
use crate::core::utils::coset_order_to_circle_domain_order;

impl PolyOps for CpuBackend {
type Twiddles = Vec<BaseField>;
Expand Down
34 changes: 33 additions & 1 deletion crates/prover/src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use super::{Backend, BackendForChannel, Column, ColumnOps, FieldOps};
use crate::core::fields::Field;
use crate::core::lookups::mle::Mle;
use crate::core::poly::circle::{CircleEvaluation, CirclePoly};
use crate::core::utils::bit_reverse;
use crate::core::utils::bit_reverse_index;
use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel;
#[cfg(not(target_arch = "wasm32"))]
use crate::core::vcs::poseidon252_merkle::Poseidon252MerkleChannel;
Expand All @@ -29,6 +29,23 @@ impl BackendForChannel<Blake2sMerkleChannel> for CpuBackend {}
#[cfg(not(target_arch = "wasm32"))]
impl BackendForChannel<Poseidon252MerkleChannel> for CpuBackend {}

/// Performs a naive bit-reversal permutation inplace.
///
/// # Panics
///
/// Panics if the length of the slice is not a power of two.
pub fn bit_reverse<T>(v: &mut [T]) {
let n = v.len();
assert!(n.is_power_of_two());
let log_n = n.ilog2();
for i in 0..n {
let j = bit_reverse_index(i, log_n);
if j > i {
v.swap(i, j);
}
}
}

impl<T: Debug + Clone + Default> ColumnOps<T> for CpuBackend {
type Column = Vec<T>;

Expand Down Expand Up @@ -79,10 +96,25 @@ mod tests {
use rand::prelude::*;
use rand::rngs::SmallRng;

use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::{Column, CpuBackend, FieldOps};
use crate::core::fields::qm31::QM31;
use crate::core::fields::FieldExpOps;

#[test]
fn bit_reverse_works() {
let mut data = [0, 1, 2, 3, 4, 5, 6, 7];
bit_reverse(&mut data);
assert_eq!(data, [0, 4, 2, 6, 1, 5, 3, 7]);
}

#[test]
#[should_panic]
fn bit_reverse_non_power_of_two_size_fails() {
let mut data = [0, 1, 2, 3, 4, 5];
bit_reverse(&mut data);
}

#[test]
fn batch_inverse_test() {
let mut rng = SmallRng::seed_from_u64(0);
Expand Down
5 changes: 3 additions & 2 deletions crates/prover/src/core/backend/simd/bit_reverse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ use rayon::prelude::*;
use super::column::{BaseColumn, SecureColumn};
use super::m31::PackedBaseField;
use super::SimdBackend;
use crate::core::backend::cpu::bit_reverse as cpu_bit_reverse;
use crate::core::backend::simd::utils::UnsafeMut;
use crate::core::backend::ColumnOps;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::utils::{bit_reverse as cpu_bit_reverse, bit_reverse_index};
use crate::core::utils::bit_reverse_index;
use crate::parallel_iter;

const VEC_BITS: u32 = 4;
Expand Down Expand Up @@ -150,12 +151,12 @@ mod tests {
use itertools::Itertools;

use super::{bit_reverse16, bit_reverse_m31, MIN_LOG_SIZE};
use crate::core::backend::cpu::bit_reverse as cpu_bit_reverse;
use crate::core::backend::simd::column::BaseColumn;
use crate::core::backend::simd::m31::{PackedM31, N_LANES};
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, ColumnOps};
use crate::core::fields::m31::BaseField;
use crate::core::utils::bit_reverse as cpu_bit_reverse;

#[test]
fn test_bit_reverse16() {
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/simd/domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ fn test_circle_domain_bit_rev_iterator() {
5,
));
let mut expected = domain.iter().collect::<Vec<_>>();
crate::core::utils::bit_reverse(&mut expected);
crate::core::backend::cpu::bit_reverse(&mut expected);
let actual = CircleDomainBitRevIterator::new(domain)
.flat_map(|c| -> [_; 16] {
std::array::from_fn(|i| CirclePoint {
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/simd/fft/ifft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ use rayon::prelude::*;
use super::{
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
};
use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::utils::UnsafeMut;
use crate::core::circle::Coset;
use crate::core::fields::FieldExpOps;
use crate::core::utils::bit_reverse;
use crate::parallel_iter;

/// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values.
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/simd/fft/rfft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use rayon::prelude::*;
use super::{
compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
};
use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::backend::simd::utils::{UnsafeConst, UnsafeMut};
use crate::core::circle::Coset;
use crate::core::utils::bit_reverse;
use crate::parallel_iter;

/// Performs a Circle Fast Fourier Transform (CFFT) on the given values.
Expand Down
5 changes: 2 additions & 3 deletions crates/prover/src/core/backend/simd/prefix_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@ use std::ops::{AddAssign, Sub};
use itertools::{izip, Itertools};
use num_traits::Zero;

use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::simd::m31::{PackedBaseField, N_LANES};
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Col, Column};
use crate::core::fields::m31::BaseField;
use crate::core::utils::{
bit_reverse, circle_domain_order_to_coset_order, coset_order_to_circle_domain_order,
};
use crate::core::utils::{circle_domain_order_to_coset_order, coset_order_to_circle_domain_order};

/// Performs a inclusive prefix sum on values in `Coset` order when provided
/// with evaluations in bit-reversed `CircleDomain` order.
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/simd/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use super::domain::CircleDomainBitRevIterator;
use super::m31::{PackedBaseField, LOG_N_LANES, N_LANES};
use super::qm31::PackedSecureField;
use super::SimdBackend;
use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::cpu::quotients::{batch_random_coeffs, column_line_coeffs};
use crate::core::backend::{Column, CpuBackend};
use crate::core::fields::m31::BaseField;
Expand All @@ -17,7 +18,6 @@ use crate::core::fields::FieldExpOps;
use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
use crate::core::utils::bit_reverse;

pub struct QuotientConstants {
pub line_coeffs: Vec<Vec<(SecureField, SecureField, SecureField)>>,
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/poly/line.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ use serde::{Deserialize, Serialize};

use super::circle::CircleDomain;
use super::utils::fold;
use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::{ColumnOps, CpuBackend};
use crate::core::circle::{CirclePoint, Coset, CosetIterator};
use crate::core::fft::ibutterfly;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::{ExtensionOf, FieldExpOps, FieldOps};
use crate::core::utils::bit_reverse;

/// Domain comprising of the x-coordinates of points in a [Coset].
///
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,10 @@ impl Deref for Queries {

#[cfg(test)]
mod tests {
use crate::core::backend::cpu::bit_reverse;
use crate::core::channel::Blake2sChannel;
use crate::core::poly::circle::CanonicCoset;
use crate::core::queries::Queries;
use crate::core::utils::bit_reverse;

#[test]
fn test_generate_queries() {
Expand Down
33 changes: 0 additions & 33 deletions crates/prover/src/core/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,24 +129,6 @@ pub const fn coset_index_to_circle_domain_index(coset_index: usize, log_domain_s
}
}

/// Performs a naive bit-reversal permutation inplace.
///
/// # Panics
///
/// Panics if the length of the slice is not a power of two.
// TODO(alont): Move this to the cpu backend.
pub fn bit_reverse<T>(v: &mut [T]) {
let n = v.len();
assert!(n.is_power_of_two());
let log_n = n.ilog2();
for i in 0..n {
let j = bit_reverse_index(i, log_n);
if j > i {
v.swap(i, j);
}
}
}

/// Performs a coset-natural-order to circle-domain-bit-reversed-order permutation in-place.
///
/// # Panics
Expand Down Expand Up @@ -187,23 +169,8 @@ mod tests {
use crate::core::fields::FieldExpOps;
use crate::core::poly::circle::CanonicCoset;
use crate::core::poly::NaturalOrder;
use crate::core::utils::bit_reverse;
use crate::{m31, qm31};

#[test]
fn bit_reverse_works() {
let mut data = [0, 1, 2, 3, 4, 5, 6, 7];
bit_reverse(&mut data);
assert_eq!(data, [0, 4, 2, 6, 1, 5, 3, 7]);
}

#[test]
#[should_panic]
fn bit_reverse_non_power_of_two_size_fails() {
let mut data = [0, 1, 2, 3, 4, 5];
bit_reverse(&mut data);
}

#[test]
fn generate_secure_powers_works() {
let felt = qm31!(1, 2, 3, 4);
Expand Down
8 changes: 5 additions & 3 deletions crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::constraint_framework::{
};
use crate::core::air::accumulation::{DomainEvaluationAccumulator, PointEvaluationAccumulator};
use crate::core::air::{Component, ComponentProver, Trace};
use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::simd::column::{SecureColumn, VeryPackedSecureColumnByCoords};
use crate::core::backend::simd::m31::LOG_N_LANES;
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum;
Expand All @@ -36,7 +37,7 @@ use crate::core::poly::circle::{
};
use crate::core::poly::twiddles::TwiddleTree;
use crate::core::poly::BitReversedOrder;
use crate::core::utils::{self, bit_reverse_index, coset_index_to_circle_domain_index};
use crate::core::utils::{bit_reverse_index, coset_index_to_circle_domain_index};
use crate::core::ColumnVec;

/// Prover component that carries out a univariate IOP for multilinear eval at point.
Expand Down Expand Up @@ -231,7 +232,7 @@ impl<'twiddles, 'oracle, O: MleCoeffColumnOracle> ComponentProver<SimdBackend>
let mut denom_inv = (0..1 << log_expand)
.map(|i| coset_vanishing(trace_domain.coset(), eval_domain.at(i)).inverse())
.collect_vec();
utils::bit_reverse(&mut denom_inv);
bit_reverse(&mut denom_inv);

// Accumulator.
let [mut acc] = accumulator.columns([(eval_domain.log_size(), self.n_constraints())]);
Expand Down Expand Up @@ -752,6 +753,7 @@ mod tests {
};
use crate::constraint_framework::{assert_constraints, EvalAtRow, TraceLocationAllocator};
use crate::core::air::{Component, ComponentProver, Components};
use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::simd::prefix_sum::inclusive_prefix_sum;
use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
Expand All @@ -765,7 +767,7 @@ mod tests {
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation, PolyOps};
use crate::core::poly::BitReversedOrder;
use crate::core::prover::{prove, verify, VerificationError};
use crate::core::utils::{bit_reverse, coset_order_to_circle_domain_order};
use crate::core::utils::coset_order_to_circle_domain_order;
use crate::core::vcs::blake2_merkle::Blake2sMerkleChannel;
use crate::examples::xor::gkr_lookups::accumulation::MIN_LOG_BLOWUP_FACTOR;
use crate::examples::xor::gkr_lookups::mle_eval::eval_step_selector_with_offset;
Expand Down

0 comments on commit ef3a094

Please sign in to comment.