Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allocating batch inverse #977

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ use crate::core::backend::simd::qm31::PackedSecureField;
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::Column;
use crate::core::channel::Channel;
use crate::core::fields::batch_inverse_in_place;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::FieldExpOps;
use crate::core::lookups::utils::Fraction;
use crate::core::poly::circle::{CanonicCoset, CircleEvaluation};
use crate::core::poly::BitReversedOrder;
Expand Down Expand Up @@ -259,7 +259,7 @@ impl LogupColGenerator<'_> {

/// Finalizes generating the column.
pub fn finalize_col(mut self) {
FieldExpOps::batch_inverse(&self.gen.denom.data, &mut self.gen.denom_inv.data);
batch_inverse_in_place(&self.gen.denom.data, &mut self.gen.denom_inv.data);

for vec_row in 0..(1 << (self.gen.log_size - LOG_N_LANES)) {
unsafe {
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/backend/cpu/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::core::circle::{CirclePoint, Coset};
use crate::core::fft::{butterfly, ibutterfly};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{ExtensionOf, FieldExpOps};
use crate::core::fields::{batch_inverse_in_place, ExtensionOf};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
};
Expand Down Expand Up @@ -172,7 +172,7 @@ impl PolyOps for CpuBackend {
.array_chunks::<CHUNK_SIZE>()
.zip(itwiddles.array_chunks_mut::<CHUNK_SIZE>())
.for_each(|(src, dst)| {
BaseField::batch_inverse(src, dst);
batch_inverse_in_place(src, dst);
});

TwiddleTree {
Expand Down
7 changes: 4 additions & 3 deletions crates/prover/src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ mod tests {
use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::Column;
use crate::core::fields::qm31::QM31;
use crate::core::fields::FieldExpOps;
use crate::core::fields::{batch_inverse_in_place, FieldExpOps};

#[test]
fn bit_reverse_works() {
Expand All @@ -106,14 +106,15 @@ mod tests {
bit_reverse(&mut data);
}

// TODO(Ohad): remove.
#[test]
fn batch_inverse_test() {
fn batch_inverse_in_place_test() {
let mut rng = SmallRng::seed_from_u64(0);
let column = rng.gen::<[QM31; 16]>().to_vec();
let expected = column.iter().map(|e| e.inverse()).collect_vec();
let mut dst = Vec::zeros(column.len());

FieldExpOps::batch_inverse(&column, &mut dst);
batch_inverse_in_place(&column, &mut dst);

assert_eq!(expected, dst);
}
Expand Down
5 changes: 1 addition & 4 deletions crates/prover/src/core/backend/cpu/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,7 @@ fn denominator_inverses(
denominators.push((prx - domain_point.x) * piy - (pry - domain_point.y) * pix);
}

let mut denominator_inverses = vec![CM31::zero(); denominators.len()];
CM31::batch_inverse(&denominators, &mut denominator_inverses);

denominator_inverses
CM31::batch_inverse(&denominators)
}

pub fn quotient_constants(
Expand Down
6 changes: 2 additions & 4 deletions crates/prover/src/core/backend/simd/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ impl SimdBackend {
denominators.push(denominators[i - 1] * mappings[i]);
}

let mut denom_inverses = vec![F::zero(); denominators.len()];
F::batch_inverse(&denominators, &mut denom_inverses);
let denom_inverses = F::batch_inverse(&denominators);

let mut steps = vec![mappings[0]];

Expand Down Expand Up @@ -311,8 +310,7 @@ impl PolyOps for SimdBackend {
remaining_twiddles.try_into().unwrap(),
));

let mut itwiddles = unsafe { BaseColumn::uninitialized(root_coset.size()) }.data;
PackedBaseField::batch_inverse(&twiddles, &mut itwiddles);
let itwiddles = PackedBaseField::batch_inverse(&twiddles);

let dbl_twiddles = twiddles
.into_iter()
Expand Down
10 changes: 2 additions & 8 deletions crates/prover/src/core/backend/simd/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::qm31::PackedSecureField;
use super::SimdBackend;
use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::cpu::quotients::{batch_random_coeffs, column_line_coeffs};
use crate::core::backend::{Column, CpuBackend};
use crate::core::backend::CpuBackend;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE};
Expand Down Expand Up @@ -251,15 +251,9 @@ fn denominator_inverses(
})
.collect();

let mut flat_denominator_inverses =
unsafe { CM31Column::uninitialized(flat_denominators.len()) };
FieldExpOps::batch_inverse(
&flat_denominators.data,
&mut flat_denominator_inverses.data[..],
);
let flat_denominator_inverses = PackedCM31::batch_inverse(&flat_denominators.data);

flat_denominator_inverses
.data
.chunks(domain.size() / N_LANES)
.map(|denominator_inverses| denominator_inverses.iter().copied().collect())
.collect()
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/backend/simd/very_packed_m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use super::qm31::PackedQM31;
use crate::core::fields::cm31::CM31;
use crate::core::fields::m31::M31;
use crate::core::fields::qm31::QM31;
use crate::core::fields::FieldExpOps;
use crate::core::fields::{batch_inverse_in_place, FieldExpOps};

pub const LOG_N_VERY_PACKED_ELEMS: u32 = 1;
pub const N_VERY_PACKED_ELEMS: usize = 1 << LOG_N_VERY_PACKED_ELEMS;
Expand Down Expand Up @@ -247,7 +247,7 @@ impl<A: One + Copy, const N: usize> One for Vectorized<A, N> {
impl<A: FieldExpOps + Zero + Copy, const N: usize> FieldExpOps for Vectorized<A, N> {
fn inverse(&self) -> Self {
let mut dst = [A::zero(); N];
A::batch_inverse(&self.0, &mut dst);
batch_inverse_in_place(&self.0, &mut dst);
dst.into()
}
}
85 changes: 48 additions & 37 deletions crates/prover/src/core/fields/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,37 +30,8 @@ pub trait FieldExpOps: Mul<Output = Self> + MulAssign + Sized + One + Clone {

fn inverse(&self) -> Self;

/// Inverts a batch of elements using Montgomery's trick.
fn batch_inverse(column: &[Self], dst: &mut [Self]) {
const WIDTH: usize = 4;
let n = column.len();
debug_assert!(dst.len() >= n);

if n <= WIDTH || n % WIDTH != 0 {
batch_inverse_classic(column, dst);
return;
}

// First pass. Compute 'WIDTH' cumulative products in an interleaving fashion, reducing
// instruction dependency and allowing better pipelining.
let mut cum_prod: [Self; WIDTH] = std::array::from_fn(|_| Self::one());
dst[..WIDTH].clone_from_slice(&cum_prod);
for i in 0..n {
cum_prod[i % WIDTH] *= column[i].clone();
dst[i] = cum_prod[i % WIDTH].clone();
}

// Inverse cumulative products.
// Use classic batch inversion.
let mut tail_inverses: [Self; WIDTH] = std::array::from_fn(|_| Self::one());
batch_inverse_classic(&dst[n - WIDTH..], &mut tail_inverses);

// Second pass.
for i in (WIDTH..n).rev() {
dst[i] = dst[i - WIDTH].clone() * tail_inverses[i % WIDTH].clone();
tail_inverses[i % WIDTH] *= column[i].clone();
}
dst[0..WIDTH].clone_from_slice(&tail_inverses);
fn batch_inverse(column: &[Self]) -> Vec<Self> {
batch_inverse(column)
}
}

Expand Down Expand Up @@ -91,6 +62,46 @@ fn batch_inverse_classic<T: FieldExpOps>(column: &[T], dst: &mut [T]) {
dst[0] = curr_inverse;
}

/// Inverts a batch of elements using Montgomery's trick.
pub fn batch_inverse_in_place<F: FieldExpOps>(column: &[F], dst: &mut [F]) {
const WIDTH: usize = 4;
let n = column.len();
debug_assert!(dst.len() >= n);

if n <= WIDTH || n % WIDTH != 0 {
batch_inverse_classic(column, dst);
return;
}

// First pass. Compute 'WIDTH' cumulative products in an interleaving fashion, reducing
// instruction dependency and allowing better pipelining.
let mut cum_prod: [F; WIDTH] = std::array::from_fn(|_| F::one());
dst[..WIDTH].clone_from_slice(&cum_prod);
for i in 0..n {
cum_prod[i % WIDTH] *= column[i].clone();
dst[i] = cum_prod[i % WIDTH].clone();
}

// Inverse cumulative products.
// Use classic batch inversion.
let mut tail_inverses: [F; WIDTH] = std::array::from_fn(|_| F::one());
batch_inverse_classic(&dst[n - WIDTH..], &mut tail_inverses);

// Second pass.
for i in (WIDTH..n).rev() {
dst[i] = dst[i - WIDTH].clone() * tail_inverses[i % WIDTH].clone();
tail_inverses[i % WIDTH] *= column[i].clone();
}
dst[0..WIDTH].clone_from_slice(&tail_inverses);
}

// TODO(Ohad): chunks, parallelize.
pub fn batch_inverse<F: FieldExpOps>(column: &[F]) -> Vec<F> {
let mut dst = vec![unsafe { std::mem::zeroed() }; column.len()];
batch_inverse_in_place(column, &mut dst);
dst
}

pub trait Field:
NumAssign
+ Neg<Output = Self>
Expand Down Expand Up @@ -460,19 +471,19 @@ mod tests {
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};

use super::batch_inverse_in_place;
use crate::core::fields::batch_inverse;
use crate::core::fields::m31::M31;
use crate::core::fields::FieldExpOps;

#[test]
fn test_slice_batch_inverse() {
fn test_batch_inverse() {
let mut rng = SmallRng::seed_from_u64(0);
let elements: [M31; 16] = rng.gen();
let expected = elements.iter().map(|e| e.inverse()).collect::<Vec<_>>();
let mut dst = [M31::zero(); 16];

M31::batch_inverse(&elements, &mut dst);
let actual = batch_inverse(&elements);

assert_eq!(expected, dst);
assert_eq!(expected, actual);
}

#[test]
Expand All @@ -482,6 +493,6 @@ mod tests {
let elements: [M31; 16] = rng.gen();
let mut dst = [M31::zero(); 15];

M31::batch_inverse(&elements, &mut dst);
batch_inverse_in_place(&elements, &mut dst);
}
}
3 changes: 1 addition & 2 deletions crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -688,8 +688,7 @@ fn eval_step_selector(coset: Coset, log_step: u32, p: CirclePoint<SecureField>)
vanish_at_log_step.reverse();
// We only need the first `log_step` many values.
vanish_at_log_step.truncate(log_step as usize);
let mut vanish_at_log_step_inv = vec![SecureField::zero(); vanish_at_log_step.len()];
SecureField::batch_inverse(&vanish_at_log_step, &mut vanish_at_log_step_inv);
let vanish_at_log_step_inv = SecureField::batch_inverse(&vanish_at_log_step);

let half_coset_selector_dbl = (vanish_at_log_step[0] * vanish_at_log_step_inv[1]).square();
let vanish_substep_inv_sum = vanish_at_log_step_inv[1..].iter().sum::<SecureField>();
Expand Down
Loading