From 47ddd44c8f828fad817f8979c87f27ab1dfcc527 Mon Sep 17 00:00:00 2001 From: ohad-starkware Date: Sun, 12 Jan 2025 15:40:32 +0200 Subject: [PATCH] allocating batch inverse --- crates/prover/src/constraint_framework/logup.rs | 2 +- crates/prover/src/core/backend/cpu/circle.rs | 2 +- crates/prover/src/core/backend/cpu/mod.rs | 4 ++-- crates/prover/src/core/backend/cpu/quotients.rs | 7 ++----- crates/prover/src/core/backend/simd/circle.rs | 8 +++----- crates/prover/src/core/backend/simd/quotients.rs | 12 +++--------- .../prover/src/core/backend/simd/very_packed_m31.rs | 2 +- crates/prover/src/core/fields/mod.rs | 13 ++++++++++--- .../prover/src/examples/xor/gkr_lookups/mle_eval.rs | 5 ++--- 9 files changed, 25 insertions(+), 30 deletions(-) diff --git a/crates/prover/src/constraint_framework/logup.rs b/crates/prover/src/constraint_framework/logup.rs index 370987e4c..601085124 100644 --- a/crates/prover/src/constraint_framework/logup.rs +++ b/crates/prover/src/constraint_framework/logup.rs @@ -259,7 +259,7 @@ impl LogupColGenerator<'_> { /// Finalizes generating the column. pub fn finalize_col(mut self) { - FieldExpOps::batch_inverse(&self.gen.denom.data, &mut self.gen.denom_inv.data); + FieldExpOps::batch_inverse_in_place(&self.gen.denom.data, &mut self.gen.denom_inv.data); for vec_row in 0..(1 << (self.gen.log_size - LOG_N_LANES)) { unsafe { diff --git a/crates/prover/src/core/backend/cpu/circle.rs b/crates/prover/src/core/backend/cpu/circle.rs index 21351d164..ce3d3a186 100644 --- a/crates/prover/src/core/backend/cpu/circle.rs +++ b/crates/prover/src/core/backend/cpu/circle.rs @@ -172,7 +172,7 @@ impl PolyOps for CpuBackend { .array_chunks::() .zip(itwiddles.array_chunks_mut::()) .for_each(|(src, dst)| { - BaseField::batch_inverse(src, dst); + BaseField::batch_inverse_in_place(src, dst); }); TwiddleTree { diff --git a/crates/prover/src/core/backend/cpu/mod.rs b/crates/prover/src/core/backend/cpu/mod.rs index 4ae7c3e5c..5baf7e175 100644 --- a/crates/prover/src/core/backend/cpu/mod.rs +++ b/crates/prover/src/core/backend/cpu/mod.rs @@ -107,13 +107,13 @@ mod tests { } #[test] - fn batch_inverse_test() { + fn batch_inverse_in_place_test() { let mut rng = SmallRng::seed_from_u64(0); let column = rng.gen::<[QM31; 16]>().to_vec(); let expected = column.iter().map(|e| e.inverse()).collect_vec(); let mut dst = Vec::zeros(column.len()); - FieldExpOps::batch_inverse(&column, &mut dst); + FieldExpOps::batch_inverse_in_place(&column, &mut dst); assert_eq!(expected, dst); } diff --git a/crates/prover/src/core/backend/cpu/quotients.rs b/crates/prover/src/core/backend/cpu/quotients.rs index 16f0647b6..0aedd765a 100644 --- a/crates/prover/src/core/backend/cpu/quotients.rs +++ b/crates/prover/src/core/backend/cpu/quotients.rs @@ -8,7 +8,7 @@ use crate::core::fields::cm31::CM31; use crate::core::fields::m31::{BaseField, M31}; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; -use crate::core::fields::FieldExpOps; +use crate::core::fields::{batch_inverse, FieldExpOps}; use crate::core::pcs::quotients::{ColumnSampleBatch, PointSample, QuotientOps}; use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation}; use crate::core::poly::BitReversedOrder; @@ -132,10 +132,7 @@ fn denominator_inverses( denominators.push((prx - domain_point.x) * piy - (pry - domain_point.y) * pix); } - let mut denominator_inverses = vec![CM31::zero(); denominators.len()]; - CM31::batch_inverse(&denominators, &mut denominator_inverses); - - denominator_inverses + batch_inverse(&denominators) } pub fn quotient_constants( diff --git a/crates/prover/src/core/backend/simd/circle.rs b/crates/prover/src/core/backend/simd/circle.rs index 61588ffe3..962d13f90 100644 --- a/crates/prover/src/core/backend/simd/circle.rs +++ b/crates/prover/src/core/backend/simd/circle.rs @@ -16,7 +16,7 @@ use crate::core::backend::{Col, Column, CpuBackend}; use crate::core::circle::{CirclePoint, Coset, M31_CIRCLE_LOG_ORDER}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; -use crate::core::fields::{Field, FieldExpOps}; +use crate::core::fields::{batch_inverse, Field, FieldExpOps}; use crate::core::poly::circle::{ CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps, }; @@ -96,8 +96,7 @@ impl SimdBackend { denominators.push(denominators[i - 1] * mappings[i]); } - let mut denom_inverses = vec![F::zero(); denominators.len()]; - F::batch_inverse(&denominators, &mut denom_inverses); + let denom_inverses = batch_inverse(&denominators); let mut steps = vec![mappings[0]]; @@ -311,8 +310,7 @@ impl PolyOps for SimdBackend { remaining_twiddles.try_into().unwrap(), )); - let mut itwiddles = unsafe { BaseColumn::uninitialized(root_coset.size()) }.data; - PackedBaseField::batch_inverse(&twiddles, &mut itwiddles); + let itwiddles = batch_inverse(&twiddles); let dbl_twiddles = twiddles .into_iter() diff --git a/crates/prover/src/core/backend/simd/quotients.rs b/crates/prover/src/core/backend/simd/quotients.rs index f0155ebb1..04d02cdf4 100644 --- a/crates/prover/src/core/backend/simd/quotients.rs +++ b/crates/prover/src/core/backend/simd/quotients.rs @@ -12,11 +12,11 @@ use super::qm31::PackedSecureField; use super::SimdBackend; use crate::core::backend::cpu::bit_reverse; use crate::core::backend::cpu::quotients::{batch_random_coeffs, column_line_coeffs}; -use crate::core::backend::{Column, CpuBackend}; +use crate::core::backend::CpuBackend; +use crate::core::fields::batch_inverse; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE}; -use crate::core::fields::FieldExpOps; use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps}; use crate::core::poly::circle::{CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation}; use crate::core::poly::BitReversedOrder; @@ -243,15 +243,9 @@ fn denominator_inverses( }) .collect(); - let mut flat_denominator_inverses = - unsafe { CM31Column::uninitialized(flat_denominators.len()) }; - FieldExpOps::batch_inverse( - &flat_denominators.data, - &mut flat_denominator_inverses.data[..], - ); + let flat_denominator_inverses = batch_inverse(&flat_denominators.data); flat_denominator_inverses - .data .chunks(domain.size() / N_LANES) .map(|denominator_inverses| denominator_inverses.iter().copied().collect()) .collect() diff --git a/crates/prover/src/core/backend/simd/very_packed_m31.rs b/crates/prover/src/core/backend/simd/very_packed_m31.rs index 781212d6f..8456deb4c 100644 --- a/crates/prover/src/core/backend/simd/very_packed_m31.rs +++ b/crates/prover/src/core/backend/simd/very_packed_m31.rs @@ -247,7 +247,7 @@ impl One for Vectorized { impl FieldExpOps for Vectorized { fn inverse(&self) -> Self { let mut dst = [A::zero(); N]; - A::batch_inverse(&self.0, &mut dst); + A::batch_inverse_in_place(&self.0, &mut dst); dst.into() } } diff --git a/crates/prover/src/core/fields/mod.rs b/crates/prover/src/core/fields/mod.rs index b19ea9bc9..f84645a23 100644 --- a/crates/prover/src/core/fields/mod.rs +++ b/crates/prover/src/core/fields/mod.rs @@ -31,7 +31,7 @@ pub trait FieldExpOps: Mul + MulAssign + Sized + One + Clone { fn inverse(&self) -> Self; /// Inverts a batch of elements using Montgomery's trick. - fn batch_inverse(column: &[Self], dst: &mut [Self]) { + fn batch_inverse_in_place(column: &[Self], dst: &mut [Self]) { const WIDTH: usize = 4; let n = column.len(); debug_assert!(dst.len() >= n); @@ -91,6 +91,13 @@ fn batch_inverse_classic(column: &[T], dst: &mut [T]) { dst[0] = curr_inverse; } +// TODO(Ohad): chunks, parallelize. +pub fn batch_inverse(column: &[T]) -> Vec { + let mut dst = vec![unsafe { std::mem::zeroed() }; column.len()]; + T::batch_inverse_in_place(column, &mut dst); + dst +} + pub trait Field: NumAssign + Neg @@ -470,7 +477,7 @@ mod tests { let expected = elements.iter().map(|e| e.inverse()).collect::>(); let mut dst = [M31::zero(); 16]; - M31::batch_inverse(&elements, &mut dst); + M31::batch_inverse_in_place(&elements, &mut dst); assert_eq!(expected, dst); } @@ -482,6 +489,6 @@ mod tests { let elements: [M31; 16] = rng.gen(); let mut dst = [M31::zero(); 15]; - M31::batch_inverse(&elements, &mut dst); + M31::batch_inverse_in_place(&elements, &mut dst); } } diff --git a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs index 3c0c17d0c..9159ba438 100644 --- a/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs +++ b/crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs @@ -27,7 +27,7 @@ use crate::core::constraints::{coset_vanishing, point_vanishing}; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::secure_column::SecureColumnByCoords; -use crate::core::fields::{Field, FieldExpOps}; +use crate::core::fields::{batch_inverse, Field, FieldExpOps}; use crate::core::lookups::gkr_prover::GkrOps; use crate::core::lookups::mle::Mle; use crate::core::lookups::utils::eq; @@ -687,8 +687,7 @@ fn eval_step_selector(coset: Coset, log_step: u32, p: CirclePoint) vanish_at_log_step.reverse(); // We only need the first `log_step` many values. vanish_at_log_step.truncate(log_step as usize); - let mut vanish_at_log_step_inv = vec![SecureField::zero(); vanish_at_log_step.len()]; - SecureField::batch_inverse(&vanish_at_log_step, &mut vanish_at_log_step_inv); + let vanish_at_log_step_inv = batch_inverse(&vanish_at_log_step); let half_coset_selector_dbl = (vanish_at_log_step[0] * vanish_at_log_step_inv[1]).square(); let vanish_substep_inv_sum = vanish_at_log_step_inv[1..].iter().sum::();