From 0a2e9b9e638ef24954e3fe2c4a2174c0fdb51997 Mon Sep 17 00:00:00 2001 From: Gali Michlevich Date: Tue, 10 Dec 2024 13:16:46 +0200 Subject: [PATCH] Add secure powers generation for simd --- crates/prover/src/core/backend/simd/utils.rs | 41 +++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/crates/prover/src/core/backend/simd/utils.rs b/crates/prover/src/core/backend/simd/utils.rs index d5f53a22b..e02846bc6 100644 --- a/crates/prover/src/core/backend/simd/utils.rs +++ b/crates/prover/src/core/backend/simd/utils.rs @@ -1,5 +1,12 @@ use std::simd::Swizzle; +use itertools::Itertools; + +use crate::core::backend::simd::m31::N_LANES; +use crate::core::backend::simd::qm31::PackedSecureField; +use crate::core::fields::qm31::SecureField; +use crate::core::utils::generate_secure_powers; + /// Used with [`Swizzle::concat_swizzle`] to interleave the even values of two vectors. pub struct InterleaveEvens; @@ -51,11 +58,32 @@ impl UnsafeConst { unsafe impl Send for UnsafeConst {} unsafe impl Sync for UnsafeConst {} +// TODO(Gali): Remove #[allow(dead_code)]. +#[allow(dead_code)] +pub fn generate_secure_powers_simd(felt: SecureField, n_powers: usize) -> Vec { + let base_arr = generate_secure_powers(felt, N_LANES).try_into().unwrap(); + let base = PackedSecureField::from_array(base_arr); + let step = PackedSecureField::broadcast(base_arr[N_LANES - 1] * felt); + let size = n_powers.div_ceil(N_LANES); + + (0..size) + .scan(base, |acc, _| { + let res = *acc; + *acc *= step; + Some(res) + }) + .flat_map(|x| x.to_array()) + .take(n_powers) + .collect_vec() +} + #[cfg(test)] mod tests { use std::simd::{u32x4, Swizzle}; - use super::{InterleaveEvens, InterleaveOdds}; + use super::{generate_secure_powers_simd, InterleaveEvens, InterleaveOdds}; + use crate::core::utils::generate_secure_powers; + use crate::qm31; #[test] fn interleave_evens() { @@ -76,4 +104,15 @@ mod tests { assert_eq!(res, u32x4::from_array([1, 5, 3, 7])); } + + #[test] + fn test_generate_secure_powers_simd() { + let felt = qm31!(1, 2, 3, 4); + let n_powers = 10000; + + let cpu_powers = generate_secure_powers(felt, n_powers); + let powers = generate_secure_powers_simd(felt, n_powers); + + assert_eq!(powers, cpu_powers); + } }