Skip to content

Commit

Permalink
allocating batch inverse
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Jan 12, 2025
1 parent d0a0209 commit 47ddd44
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 30 deletions.
2 changes: 1 addition & 1 deletion crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ impl LogupColGenerator<'_> {

/// Finalizes generating the column.
pub fn finalize_col(mut self) {
FieldExpOps::batch_inverse(&self.gen.denom.data, &mut self.gen.denom_inv.data);
FieldExpOps::batch_inverse_in_place(&self.gen.denom.data, &mut self.gen.denom_inv.data);

for vec_row in 0..(1 << (self.gen.log_size - LOG_N_LANES)) {
unsafe {
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/cpu/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl PolyOps for CpuBackend {
.array_chunks::<CHUNK_SIZE>()
.zip(itwiddles.array_chunks_mut::<CHUNK_SIZE>())
.for_each(|(src, dst)| {
BaseField::batch_inverse(src, dst);
BaseField::batch_inverse_in_place(src, dst);
});

TwiddleTree {
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@ mod tests {
}

#[test]
fn batch_inverse_test() {
fn batch_inverse_in_place_test() {
let mut rng = SmallRng::seed_from_u64(0);
let column = rng.gen::<[QM31; 16]>().to_vec();
let expected = column.iter().map(|e| e.inverse()).collect_vec();
let mut dst = Vec::zeros(column.len());

FieldExpOps::batch_inverse(&column, &mut dst);
FieldExpOps::batch_inverse_in_place(&column, &mut dst);

assert_eq!(expected, dst);
}
Expand Down
7 changes: 2 additions & 5 deletions crates/prover/src/core/backend/cpu/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::core::fields::cm31::CM31;
use crate::core::fields::m31::{BaseField, M31};
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::FieldExpOps;
use crate::core::fields::{batch_inverse, FieldExpOps};
use crate::core::pcs::quotients::{ColumnSampleBatch, PointSample, QuotientOps};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
Expand Down Expand Up @@ -132,10 +132,7 @@ fn denominator_inverses(
denominators.push((prx - domain_point.x) * piy - (pry - domain_point.y) * pix);
}

let mut denominator_inverses = vec![CM31::zero(); denominators.len()];
CM31::batch_inverse(&denominators, &mut denominator_inverses);

denominator_inverses
batch_inverse(&denominators)
}

pub fn quotient_constants(
Expand Down
8 changes: 3 additions & 5 deletions crates/prover/src/core/backend/simd/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::core::backend::{Col, Column, CpuBackend};
use crate::core::circle::{CirclePoint, Coset, M31_CIRCLE_LOG_ORDER};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{Field, FieldExpOps};
use crate::core::fields::{batch_inverse, Field, FieldExpOps};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
};
Expand Down Expand Up @@ -96,8 +96,7 @@ impl SimdBackend {
denominators.push(denominators[i - 1] * mappings[i]);
}

let mut denom_inverses = vec![F::zero(); denominators.len()];
F::batch_inverse(&denominators, &mut denom_inverses);
let denom_inverses = batch_inverse(&denominators);

let mut steps = vec![mappings[0]];

Expand Down Expand Up @@ -311,8 +310,7 @@ impl PolyOps for SimdBackend {
remaining_twiddles.try_into().unwrap(),
));

let mut itwiddles = unsafe { BaseColumn::uninitialized(root_coset.size()) }.data;
PackedBaseField::batch_inverse(&twiddles, &mut itwiddles);
let itwiddles = batch_inverse(&twiddles);

let dbl_twiddles = twiddles
.into_iter()
Expand Down
12 changes: 3 additions & 9 deletions crates/prover/src/core/backend/simd/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ use super::qm31::PackedSecureField;
use super::SimdBackend;
use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::cpu::quotients::{batch_random_coeffs, column_line_coeffs};
use crate::core::backend::{Column, CpuBackend};
use crate::core::backend::CpuBackend;
use crate::core::fields::batch_inverse;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE};
use crate::core::fields::FieldExpOps;
use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
Expand Down Expand Up @@ -243,15 +243,9 @@ fn denominator_inverses(
})
.collect();

let mut flat_denominator_inverses =
unsafe { CM31Column::uninitialized(flat_denominators.len()) };
FieldExpOps::batch_inverse(
&flat_denominators.data,
&mut flat_denominator_inverses.data[..],
);
let flat_denominator_inverses = batch_inverse(&flat_denominators.data);

flat_denominator_inverses
.data
.chunks(domain.size() / N_LANES)
.map(|denominator_inverses| denominator_inverses.iter().copied().collect())
.collect()
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/simd/very_packed_m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ impl<A: One + Copy, const N: usize> One for Vectorized<A, N> {
impl<A: FieldExpOps + Zero + Copy, const N: usize> FieldExpOps for Vectorized<A, N> {
fn inverse(&self) -> Self {
let mut dst = [A::zero(); N];
A::batch_inverse(&self.0, &mut dst);
A::batch_inverse_in_place(&self.0, &mut dst);
dst.into()
}
}
13 changes: 10 additions & 3 deletions crates/prover/src/core/fields/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub trait FieldExpOps: Mul<Output = Self> + MulAssign + Sized + One + Clone {
fn inverse(&self) -> Self;

/// Inverts a batch of elements using Montgomery's trick.
fn batch_inverse(column: &[Self], dst: &mut [Self]) {
fn batch_inverse_in_place(column: &[Self], dst: &mut [Self]) {
const WIDTH: usize = 4;
let n = column.len();
debug_assert!(dst.len() >= n);
Expand Down Expand Up @@ -91,6 +91,13 @@ fn batch_inverse_classic<T: FieldExpOps>(column: &[T], dst: &mut [T]) {
dst[0] = curr_inverse;
}

// TODO(Ohad): chunks, parallelize.
pub fn batch_inverse<T: FieldExpOps>(column: &[T]) -> Vec<T> {
let mut dst = vec![unsafe { std::mem::zeroed() }; column.len()];
T::batch_inverse_in_place(column, &mut dst);
dst
}

pub trait Field:
NumAssign
+ Neg<Output = Self>
Expand Down Expand Up @@ -470,7 +477,7 @@ mod tests {
let expected = elements.iter().map(|e| e.inverse()).collect::<Vec<_>>();
let mut dst = [M31::zero(); 16];

M31::batch_inverse(&elements, &mut dst);
M31::batch_inverse_in_place(&elements, &mut dst);

assert_eq!(expected, dst);
}
Expand All @@ -482,6 +489,6 @@ mod tests {
let elements: [M31; 16] = rng.gen();
let mut dst = [M31::zero(); 15];

M31::batch_inverse(&elements, &mut dst);
M31::batch_inverse_in_place(&elements, &mut dst);
}
}
5 changes: 2 additions & 3 deletions crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::core::constraints::{coset_vanishing, point_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::{Field, FieldExpOps};
use crate::core::fields::{batch_inverse, Field, FieldExpOps};
use crate::core::lookups::gkr_prover::GkrOps;
use crate::core::lookups::mle::Mle;
use crate::core::lookups::utils::eq;
Expand Down Expand Up @@ -687,8 +687,7 @@ fn eval_step_selector(coset: Coset, log_step: u32, p: CirclePoint<SecureField>)
vanish_at_log_step.reverse();
// We only need the first `log_step` many values.
vanish_at_log_step.truncate(log_step as usize);
let mut vanish_at_log_step_inv = vec![SecureField::zero(); vanish_at_log_step.len()];
SecureField::batch_inverse(&vanish_at_log_step, &mut vanish_at_log_step_inv);
let vanish_at_log_step_inv = batch_inverse(&vanish_at_log_step);

let half_coset_selector_dbl = (vanish_at_log_step[0] * vanish_at_log_step_inv[1]).square();
let vanish_substep_inv_sum = vanish_at_log_step_inv[1..].iter().sum::<SecureField>();
Expand Down

0 comments on commit 47ddd44

Please sign in to comment.