diff --git a/crates/prover/src/core/backend/simd/fft/ifft.rs b/crates/prover/src/core/backend/simd/fft/ifft.rs index 35a458d02..526d2d9d3 100644 --- a/crates/prover/src/core/backend/simd/fft/ifft.rs +++ b/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -1,17 +1,17 @@ //! Inverse fft. -use std::arch::x86_64::{ - __m512i, _mm512_broadcast_i32x4, _mm512_mul_epu32, _mm512_permutex2var_epi32, - _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64, -}; +use std::simd::{simd_swizzle, u32x16, u32x2, u32x4}; -use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS}; -use crate::core::backend::avx512::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; -use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE}; +use super::{ + compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, +}; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::circle::Coset; use crate::core::fields::FieldExpOps; use crate::core::utils::bit_reverse; +const VECS_LOG_SIZE: usize = LOG_N_LANES as usize; + /// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values. /// /// # Safety @@ -155,8 +155,9 @@ pub unsafe fn ifft_vecwise_loop( ) { for index_l in 0..(1 << loop_bits) { let index = (index_h << loop_bits) + index_l; - let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const()); - let mut val1 = PackedBaseField::load(values.add(index * 32 + 16).cast_const()); + let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const() as *const u32); + let mut val1 = + PackedBaseField::load(values.add(index * 32 + 16).cast_const() as *const u32); (val0, val1) = vecwise_ibutterflies( val0, val1, @@ -167,10 +168,10 @@ pub unsafe fn ifft_vecwise_loop( (val0, val1) = avx_ibutterfly( val0, val1, - _mm512_set1_epi32(*twiddle_dbl[3].get_unchecked(index)), + u32x16::splat(*twiddle_dbl[3].get_unchecked(index) as u32), ); - val0.store(values.add(index * 32)); - val1.store(values.add(index * 32 + 16)); + val0.store(values.add(index * 32) as *mut u32); + val1.store(values.add(index * 32 + 16) as *mut u32); } } @@ -269,42 +270,11 @@ unsafe fn ifft1_loop(values: *mut i32, twiddle_dbl: &[&[i32]], layer: usize, ind pub unsafe fn avx_ibutterfly( val0: PackedBaseField, val1: PackedBaseField, - twiddle_dbl: __m512i, + twiddle_dbl: u32x16, ) -> (PackedBaseField, PackedBaseField) { let r0 = val0 + val1; let r1 = val0 - val1; - - // Extract the even and odd parts of r1 and twiddle_dbl, and spread as 8 64bit values. - let r1_e = r1.0; - let r1_o = _mm512_srli_epi64(r1.0, 32); - let twiddle_dbl_e = twiddle_dbl; - let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32); - - // To compute prod = r1 * twiddle start by multiplying - // r1_e/o by twiddle_dbl_e/o. - let prod_e_dbl = _mm512_mul_epu32(r1_e, twiddle_dbl_e); - let prod_o_dbl = _mm512_mul_epu32(r1_o, twiddle_dbl_o); - - // The result of a multiplication holds r1*twiddle_dbl in as 64-bits. 
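     // In scalar terms, the doubled-twiddle product that the removed comments
     // around this point spell out lane by lane (and that `mul_twiddle` now
     // centralizes) looks like this sketch, assuming P = 2^31 - 1 and
     // t_dbl = 2 * t in [0, 2P]:
     //
     //     fn mul_doubled(a: u64, t_dbl: u64) -> u64 {
     //         const P: u64 = (1 << 31) - 1;
     //         let prod_dbl = a * t_dbl;                 // 2 * a * t, fits in 63 bits.
     //         let lo = ((prod_dbl as u32) >> 1) as u64; // Low 31 bits of a * t.
     //         let hi = prod_dbl >> 32;                  // High 31 bits of a * t.
     //         (lo + hi) % P                             // 2^31 == 1 (mod P).
     //     }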
- // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_e_dbl - |0|prod_e_h|prod_e_l|0| - // prod_o_dbl - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: - let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl); - // prod_ls - |prod_o_l|0|prod_e_l|0| - - // Divide by 2: - let prod_ls = _mm512_srli_epi64(prod_ls, 1); - // prod_ls - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: - let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl); - // prod_hs - |0|prod_o_h|0|prod_e_h| - - let prod = PackedBaseField(prod_ls) + PackedBaseField(prod_hs); - + let prod = mul_twiddle(r1, twiddle_dbl); (r0, prod) } @@ -349,29 +319,31 @@ pub unsafe fn vecwise_ibutterflies( let (t0, t1) = compute_first_twiddles(twiddle1_dbl); // Apply the permutation, resulting in indexing d:iabc. - (val0, val1) = val0.deinterleave_with(val1); + (val0, val1) = val0.deinterleave(val1); (val0, val1) = avx_ibutterfly(val0, val1, t0); // Apply the permutation, resulting in indexing c:diab. - (val0, val1) = val0.deinterleave_with(val1); + (val0, val1) = val0.deinterleave(val1); (val0, val1) = avx_ibutterfly(val0, val1, t1); - // The twiddles for layer 2 are replicated in the following pattern: - // 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 - let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl)); + let t = simd_swizzle!( + u32x4::from_array(unsafe { std::mem::transmute(twiddle2_dbl) }), + [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] + ); // Apply the permutation, resulting in indexing b:cdia. - (val0, val1) = val0.deinterleave_with(val1); + (val0, val1) = val0.deinterleave(val1); (val0, val1) = avx_ibutterfly(val0, val1, t); - // The twiddles for layer 3 are replicated in the following pattern: - // 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 - let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl)); + let t = simd_swizzle!( + u32x2::from_array(unsafe { std::mem::transmute(twiddle3_dbl) }), + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + ); // Apply the permutation, resulting in indexing a:bcid. - (val0, val1) = val0.deinterleave_with(val1); + (val0, val1) = val0.deinterleave(val1); (val0, val1) = avx_ibutterfly(val0, val1, t); // Apply the permutation, resulting in indexing i:abcd. - val0.deinterleave_with(val1) + val0.deinterleave(val1) } /// Returns the line twiddles (x points) for an ifft on a coset. @@ -414,42 +386,50 @@ pub unsafe fn ifft3( twiddles_dbl2: [i32; 1], ) { // Load the 8 AVX vectors from the array. 
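     // `PackedBaseField::load`/`store` now take raw `u32` pointers, hence the
     // casts from the `*mut i32` buffers; both types share the same layout. A
     // sketch of what the two helpers are assumed to do (unaligned 16-lane
     // reads and writes):
     //
     //     unsafe fn load(mem_addr: *const u32) -> u32x16 {
     //         std::ptr::read_unaligned(mem_addr as *const u32x16)
     //     }
     //     unsafe fn store(v: u32x16, mem_addr: *mut u32) {
     //         std::ptr::write_unaligned(mem_addr as *mut u32x16, v);
     //     }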
- let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); - let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); - let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); - let mut val4 = PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const()); - let mut val5 = PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const()); - let mut val6 = PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const()); - let mut val7 = PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const()); + let mut val0 = + PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const() as *const u32); + let mut val1 = + PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const() as *const u32); + let mut val2 = + PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const() as *const u32); + let mut val3 = + PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const() as *const u32); + let mut val4 = + PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const() as *const u32); + let mut val5 = + PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const() as *const u32); + let mut val6 = + PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const() as *const u32); + let mut val7 = + PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const() as *const u32); // Apply the first layer of ibutterflies. - (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); - (val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); - (val4, val5) = avx_ibutterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2])); - (val6, val7) = avx_ibutterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3])); + (val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); + (val2, val3) = avx_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1] as u32)); + (val4, val5) = avx_ibutterfly(val4, val5, u32x16::splat(twiddles_dbl0[2] as u32)); + (val6, val7) = avx_ibutterfly(val6, val7, u32x16::splat(twiddles_dbl0[3] as u32)); // Apply the second layer of ibutterflies. - (val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); - (val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); - (val4, val6) = avx_ibutterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1])); - (val5, val7) = avx_ibutterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1])); + (val0, val2) = avx_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0] as u32)); + (val1, val3) = avx_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0] as u32)); + (val4, val6) = avx_ibutterfly(val4, val6, u32x16::splat(twiddles_dbl1[1] as u32)); + (val5, val7) = avx_ibutterfly(val5, val7, u32x16::splat(twiddles_dbl1[1] as u32)); // Apply the third layer of ibutterflies. 
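     // All three unrolled layers follow one pattern: layer k butterflies the
     // two vectors whose indices differ in bit k, and pairs that share their
     // high bits share a twiddle, so the twiddle arrays halve per layer
     // (4, 2, 1). An equivalent loop form (a sketch, with `val` standing for
     // the eight vectors and `twiddles` for the three twiddle arrays):
     //
     //     for k in 0..3 {
     //         for i in 0..8 {
     //             if i & (1 << k) == 0 {
     //                 (val[i], val[i | (1 << k)]) = avx_ibutterfly(
     //                     val[i],
     //                     val[i | (1 << k)],
     //                     u32x16::splat(twiddles[k][i >> (k + 1)] as u32),
     //                 );
     //             }
     //         }
     //     }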
- (val0, val4) = avx_ibutterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0])); - (val1, val5) = avx_ibutterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0])); - (val2, val6) = avx_ibutterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0])); - (val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0])); + (val0, val4) = avx_ibutterfly(val0, val4, u32x16::splat(twiddles_dbl2[0] as u32)); + (val1, val5) = avx_ibutterfly(val1, val5, u32x16::splat(twiddles_dbl2[0] as u32)); + (val2, val6) = avx_ibutterfly(val2, val6, u32x16::splat(twiddles_dbl2[0] as u32)); + (val3, val7) = avx_ibutterfly(val3, val7, u32x16::splat(twiddles_dbl2[0] as u32)); // Store the 8 AVX vectors back to the array. - val0.store(values.add(offset + (0 << log_step))); - val1.store(values.add(offset + (1 << log_step))); - val2.store(values.add(offset + (2 << log_step))); - val3.store(values.add(offset + (3 << log_step))); - val4.store(values.add(offset + (4 << log_step))); - val5.store(values.add(offset + (5 << log_step))); - val6.store(values.add(offset + (6 << log_step))); - val7.store(values.add(offset + (7 << log_step))); + val0.store(values.add(offset + (0 << log_step)) as *mut u32); + val1.store(values.add(offset + (1 << log_step)) as *mut u32); + val2.store(values.add(offset + (2 << log_step)) as *mut u32); + val3.store(values.add(offset + (3 << log_step)) as *mut u32); + val4.store(values.add(offset + (4 << log_step)) as *mut u32); + val5.store(values.add(offset + (5 << log_step)) as *mut u32); + val6.store(values.add(offset + (6 << log_step)) as *mut u32); + val7.store(values.add(offset + (7 << log_step)) as *mut u32); } /// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements. @@ -473,24 +453,28 @@ pub unsafe fn ifft2( twiddles_dbl1: [i32; 1], ) { // Load the 4 AVX vectors from the array. - let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); - let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); - let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); + let mut val0 = + PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const() as *const u32); + let mut val1 = + PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const() as *const u32); + let mut val2 = + PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const() as *const u32); + let mut val3 = + PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const() as *const u32); // Apply the first layer of butterflies. - (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); - (val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); + (val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); + (val2, val3) = avx_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1] as u32)); // Apply the second layer of butterflies. - (val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); - (val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); + (val0, val2) = avx_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0] as u32)); + (val1, val3) = avx_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0] as u32)); // Store the 4 AVX vectors back to the array. 
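     // Addressing is in `i32` elements: vector i of a block starts at
     // `values.add(offset + (i << log_step))`. With offset = 0 and log_step = 4
     // (hypothetical values), the four stores below would cover elements 0..16,
     // 16..32, 32..48 and 48..64, i.e. four adjacent 16-lane vectors.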
- val0.store(values.add(offset + (0 << log_step))); - val1.store(values.add(offset + (1 << log_step))); - val2.store(values.add(offset + (2 << log_step))); - val3.store(values.add(offset + (3 << log_step))); + val0.store(values.add(offset + (0 << log_step)) as *mut u32); + val1.store(values.add(offset + (1 << log_step)) as *mut u32); + val2.store(values.add(offset + (2 << log_step)) as *mut u32); + val3.store(values.add(offset + (3 << log_step)) as *mut u32); } /// Applies 1 ibutterfly layers on 2 vectors of 16 M31 elements. @@ -504,25 +488,24 @@ pub unsafe fn ifft2( /// # Safety pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) { // Load the 2 AVX vectors from the array. - let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + let mut val0 = + PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const() as *const u32); + let mut val1 = + PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const() as *const u32); - (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); + (val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); // Store the 2 AVX vectors back to the array. - val0.store(values.add(offset + (0 << log_step))); - val1.store(values.add(offset + (1 << log_step))); + val0.store(values.add(offset + (0 << log_step)) as *mut u32); + val1.store(values.add(offset + (1 << log_step)) as *mut u32); } -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] #[cfg(test)] mod tests { - use std::arch::x86_64::{_mm512_add_epi32, _mm512_setr_epi32}; - use super::*; - use crate::core::backend::avx512::m31::PackedBaseField; - use crate::core::backend::avx512::BaseFieldVec; use crate::core::backend::cpu::CPUCircleEvaluation; + use crate::core::backend::simd::column::BaseFieldVec; + use crate::core::backend::simd::m31::PackedBaseField; use crate::core::backend::Column; use crate::core::fft::ibutterfly; use crate::core::fields::m31::BaseField; @@ -531,16 +514,16 @@ mod tests { #[test] fn test_ibutterfly() { unsafe { - let val0 = PackedBaseField(_mm512_setr_epi32( + let val0 = PackedBaseField::from_simd_unchecked(u32x16::from_array([ 2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - )); - let val1 = PackedBaseField(_mm512_setr_epi32( + ])); + let val1 = PackedBaseField::from_simd_unchecked(u32x16::from_array([ 3, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - )); - let twiddle = _mm512_setr_epi32( + ])); + let twiddle = u32x16::from_array([ 1177558791, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - ); - let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle); + ]); + let twiddle_dbl = twiddle + twiddle; let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl); let val0: [BaseField; 16] = std::mem::transmute(val0); @@ -647,7 +630,7 @@ mod tests { twiddle_dbls[1].clone().try_into().unwrap(), twiddle_dbls[2].clone().try_into().unwrap(), ); - let (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddle_dbls[3][0])); + let (val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddle_dbls[3][0] as u32)); std::mem::transmute([val0, val1]) }; diff --git a/crates/prover/src/core/backend/simd/fft/mod.rs b/crates/prover/src/core/backend/simd/fft/mod.rs index 37dbbd686..f99db2db2 100644 --- a/crates/prover/src/core/backend/simd/fft/mod.rs +++ b/crates/prover/src/core/backend/simd/fft/mod.rs @@ -1,14 +1,15 @@ -use 
std::arch::x86_64::{
-    __m512i, _mm512_broadcast_i64x4, _mm512_load_epi32, _mm512_permutexvar_epi32,
-    _mm512_store_epi32, _mm512_xor_epi32,
-};
+use std::mem::transmute;
+use std::simd::{simd_swizzle, u32x16, u32x8};
+
+use super::m31::PackedBaseField;
+use crate::core::fields::m31::P;
 
 pub mod ifft;
 pub mod rfft;
 
 /// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a
 /// with the even words of b.
-const EVENS_INTERLEAVE_EVENS: __m512i = unsafe {
+const _EVENS_INTERLEAVE_EVENS: u32x16 = unsafe {
     core::mem::transmute([
         0b00000, 0b10000, 0b00010, 0b10010, 0b00100, 0b10100, 0b00110, 0b10110, 0b01000, 0b11000,
         0b01010, 0b11010, 0b01100, 0b11100, 0b01110, 0b11110,
@@ -16,7 +17,7 @@ const EVENS_INTERLEAVE_EVENS: __m512i = unsafe {
 };
 /// An input to _mm512_permutex2var_epi32, and is used to interleave the odd words of a
 /// with the odd words of b.
-const ODDS_INTERLEAVE_ODDS: __m512i = unsafe {
+const _ODDS_INTERLEAVE_ODDS: u32x16 = unsafe {
     core::mem::transmute([
         0b00001, 0b10001, 0b00011, 0b10011, 0b00101, 0b10101, 0b00111, 0b10111, 0b01001, 0b11001,
         0b01011, 0b11011, 0b01101, 0b11101, 0b01111, 0b11111,
@@ -60,14 +61,24 @@ pub unsafe fn transpose_vecs(values: *mut i32, log_n_vecs: usize) {
     }
 }
 
+unsafe fn _mm512_load_epi32(mem_addr: *const i32) -> u32x16 {
+    std::ptr::read(mem_addr as *const u32x16)
+}
+
+unsafe fn _mm512_store_epi32(mem_addr: *mut i32, a: u32x16) {
+    std::ptr::write(mem_addr as *mut u32x16, a);
+}
+
 /// Computes the twiddles for the first fft layer from the second, and loads both to AVX registers.
 /// Returns the twiddles for the first layer and the twiddles for the second layer.
 /// # Safety
-pub unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (__m512i, __m512i) {
+pub unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (u32x16, u32x16) {
     // Start by loading the twiddles for the second layer (layer 1):
-    // The twiddles for layer 1 are replicated in the following pattern:
-    //   0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7
-    let t1 = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl));
+    let t1 = simd_swizzle!(
+        unsafe { transmute::<_, u32x8>(twiddle1_dbl) },
+        unsafe { transmute::<_, u32x8>(twiddle1_dbl) },
+        [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]
+    );
 
     // The twiddles for layer 0 can be computed from the twiddles for layer 1.
     // Since the twiddles are bit reversed, we consider the circle domain in bit reversed order.
@@ -84,17 +95,39 @@ pub unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (__m512i, __m512
 
     // The twiddles for layer 0 are computed like this:
     //   t0[4i:4i+3] = [t1[2i+1], -t1[2i+1], -t1[2i], t1[2i]]
-    const INDICES_FROM_T1: __m512i = unsafe {
-        core::mem::transmute([
+    // Xoring a double twiddle with P*2 transforms it to the double of its negation.
+    // Note that this keeps the values as a double of a value in the range [0, P].
+    const P2: u32 = P * 2;
+    const NEGATION_MASK: u32x16 =
+        u32x16::from_array([0, P2, P2, 0, 0, P2, P2, 0, 0, P2, P2, 0, 0, P2, P2, 0]);
+    let t0 = simd_swizzle!(
+        t1,
+        [
             0b0001, 0b0001, 0b0000, 0b0000, 0b0011, 0b0011, 0b0010, 0b0010, 0b0101, 0b0101,
             0b0100, 0b0100, 0b0111, 0b0111, 0b0110, 0b0110,
-        ])
-    };
-    // Xoring a double twiddle with 2^32-2 transforms it to the double of it negation.
-    // Note that this keeps the values as a double of a value in the range [0, P].
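     // A scalar check of the negation trick above: a doubled twiddle d = 2 * t
     // is even, and 2 * P = 0xFFFFFFFE has every bit set except bit 0, so the
     // xor turns d into 2 * P - d:
     //
     //     let (p, t) = ((1u32 << 31) - 1, 12345u32);
     //     let d = 2 * t;
     //     assert_eq!(d ^ (2 * p), 2 * (p - t));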
-    const NEGATION_MASK: __m512i = unsafe {
-        core::mem::transmute([0i32, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0])
-    };
-    let t0 = _mm512_xor_epi32(_mm512_permutexvar_epi32(INDICES_FROM_T1, t1), NEGATION_MASK);
+        ]
+    ) ^ NEGATION_MASK;
 
     (t0, t1)
 }
+
+/// Computes `v * twiddle`.
+fn mul_twiddle(v: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField {
+    // TODO: Come up with a better approach than `cfg`ing on target_feature.
+    // TODO: Ensure all these branches get tested in the CI.
+    cfg_if::cfg_if! {
+        if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] {
+            // TODO: On architectures whose multiplication requires doubled operands, the
+            // twiddles should be precomputed as doubles; on the others, they should be
+            // precomputed without doubling.
+            crate::core::backend::simd::m31::_mul_doubled_neon(v, twiddle_dbl)
+        } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] {
+            crate::core::backend::simd::m31::_mul_doubled_wasm(v, twiddle_dbl)
+        } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] {
+            crate::core::backend::simd::m31::_mul_doubled_avx512(v, twiddle_dbl)
+        } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] {
+            crate::core::backend::simd::m31::_mul_doubled_avx2(v, twiddle_dbl)
+        } else {
+            crate::core::backend::simd::m31::_mul_doubled_simd(v, twiddle_dbl)
+        }
+    }
+}
diff --git a/crates/prover/src/core/backend/simd/fft/rfft.rs b/crates/prover/src/core/backend/simd/fft/rfft.rs
index 43fbf7f32..aacbffe24 100644
--- a/crates/prover/src/core/backend/simd/fft/rfft.rs
+++ b/crates/prover/src/core/backend/simd/fft/rfft.rs
@@ -1,16 +1,16 @@
 //! Regular (forward) fft.
 
-use std::arch::x86_64::{
-    __m512i, _mm512_broadcast_i32x4, _mm512_mul_epu32, _mm512_permutex2var_epi32,
-    _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64,
-};
+use std::simd::{simd_swizzle, u32x16, u32x2, u32x4};
 
-use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS};
-use crate::core::backend::avx512::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
-use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE};
+use super::{
+    compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE,
+};
+use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
 use crate::core::circle::Coset;
 use crate::core::utils::bit_reverse;
 
+const VECS_LOG_SIZE: usize = LOG_N_LANES as usize;
+
 /// Performs a Circle Fast Fourier Transform (CFFT) on the given values.
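 ///
 /// Per lane, each forward butterfly computes, over M31 (a scalar sketch; the
 /// vector code routes the product through `mul_twiddle` with doubled
 /// twiddles):
 ///
 /// ```text
 /// (v0, v1) <- (v0 + t * v1, v0 - t * v1)  (mod P)
 /// ```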
/// /// # Safety @@ -175,12 +175,12 @@ unsafe fn fft_vecwise_loop( ) { for index_l in 0..(1 << loop_bits) { let index = (index_h << loop_bits) + index_l; - let mut val0 = PackedBaseField::load(src.add(index * 32)); - let mut val1 = PackedBaseField::load(src.add(index * 32 + 16)); + let mut val0 = PackedBaseField::load(src.add(index * 32) as *const u32); + let mut val1 = PackedBaseField::load(src.add(index * 32 + 16) as *const u32); (val0, val1) = avx_butterfly( val0, val1, - _mm512_set1_epi32(*twiddle_dbl[3].get_unchecked(index)), + u32x16::splat(*twiddle_dbl[3].get_unchecked(index) as u32), ); (val0, val1) = vecwise_butterflies( val0, @@ -189,8 +189,8 @@ unsafe fn fft_vecwise_loop( std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)), std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)), ); - val0.store(dst.add(index * 32)); - val1.store(dst.add(index * 32 + 16)); + val0.store(dst.add(index * 32) as *mut u32); + val1.store(dst.add(index * 32 + 16) as *mut u32); } } @@ -309,44 +309,10 @@ unsafe fn fft1_loop( pub unsafe fn avx_butterfly( val0: PackedBaseField, val1: PackedBaseField, - twiddle_dbl: __m512i, + twiddle_dbl: u32x16, ) -> (PackedBaseField, PackedBaseField) { - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of val0. - let val1_e = val1.0; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of val0. - let val1_o = _mm512_srli_epi64(val1.0, 32); - let twiddle_dbl_e = twiddle_dbl; - let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32); - - // To compute prod = val1 * twiddle start by multiplying - // val1_e/o by twiddle_dbl_e/o. - let prod_e_dbl = _mm512_mul_epu32(val1_e, twiddle_dbl_e); - let prod_o_dbl = _mm512_mul_epu32(val1_o, twiddle_dbl_o); - - // The result of a multiplication holds val1*twiddle_dbl in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_e_dbl - |0|prod_e_h|prod_e_l|0| - // prod_o_dbl - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: - let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl); - // prod_ls - |prod_o_l|0|prod_e_l|0| - - // Divide by 2: - let prod_ls = _mm512_srli_epi64(prod_ls, 1); - // prod_ls - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: - let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl); - // prod_hs - |0|prod_o_h|0|prod_e_h| - - let prod = PackedBaseField(prod_ls) + PackedBaseField(prod_hs); - - let r0 = val0 + prod; - let r1 = val0 - prod; - - (r0, r1) + let prod = mul_twiddle(val1, twiddle_dbl); + (val0 + prod, val0 - prod) } /// Runs fft on 2 vectors of 16 M31 elements. @@ -369,22 +335,28 @@ pub unsafe fn vecwise_butterflies( // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. // The implementation is the exact reverse of vecwise_ibutterflies(). // See the comments in its body for more info. 
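     // The swizzles just below are the portable stand-ins for the old
     // intrinsics: repeating a u32x2 eight times replaces _mm512_set1_epi64 and
     // repeating a u32x4 four times replaces _mm512_broadcast_i32x4. Standalone
     // form (helper names are illustrative):
     //
     //     fn repeat_8x(x: u32x2) -> u32x16 {
     //         simd_swizzle!(x, [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
     //     }
     //     fn repeat_4x(x: u32x4) -> u32x16 {
     //         simd_swizzle!(x, [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3])
     //     }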
- let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl)); - (val0, val1) = val0.interleave_with(val1); + let t = simd_swizzle!( + u32x2::from_array(unsafe { std::mem::transmute(twiddle3_dbl) }), + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + ); + (val0, val1) = val0.interleave(val1); (val0, val1) = avx_butterfly(val0, val1, t); - let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl)); - (val0, val1) = val0.interleave_with(val1); + let t = simd_swizzle!( + u32x4::from_array(unsafe { std::mem::transmute(twiddle2_dbl) }), + [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] + ); + (val0, val1) = val0.interleave(val1); (val0, val1) = avx_butterfly(val0, val1, t); let (t0, t1) = compute_first_twiddles(twiddle1_dbl); - (val0, val1) = val0.interleave_with(val1); + (val0, val1) = val0.interleave(val1); (val0, val1) = avx_butterfly(val0, val1, t1); - (val0, val1) = val0.interleave_with(val1); + (val0, val1) = val0.interleave(val1); (val0, val1) = avx_butterfly(val0, val1, t0); - val0.interleave_with(val1) + val0.interleave(val1) } /// Returns the line twiddles (x points) for an fft on a coset. @@ -429,42 +401,42 @@ pub unsafe fn fft3( twiddles_dbl2: [i32; 1], ) { // Load the 8 AVX vectors from the array. - let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); - let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step))); - let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step))); - let mut val4 = PackedBaseField::load(src.add(offset + (4 << log_step))); - let mut val5 = PackedBaseField::load(src.add(offset + (5 << log_step))); - let mut val6 = PackedBaseField::load(src.add(offset + (6 << log_step))); - let mut val7 = PackedBaseField::load(src.add(offset + (7 << log_step))); + let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)) as *const u32); + let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)) as *const u32); + let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step)) as *const u32); + let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step)) as *const u32); + let mut val4 = PackedBaseField::load(src.add(offset + (4 << log_step)) as *const u32); + let mut val5 = PackedBaseField::load(src.add(offset + (5 << log_step)) as *const u32); + let mut val6 = PackedBaseField::load(src.add(offset + (6 << log_step)) as *const u32); + let mut val7 = PackedBaseField::load(src.add(offset + (7 << log_step)) as *const u32); // Apply the third layer of butterflies. - (val0, val4) = avx_butterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0])); - (val1, val5) = avx_butterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0])); - (val2, val6) = avx_butterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0])); - (val3, val7) = avx_butterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0])); + (val0, val4) = avx_butterfly(val0, val4, u32x16::splat(twiddles_dbl2[0] as u32)); + (val1, val5) = avx_butterfly(val1, val5, u32x16::splat(twiddles_dbl2[0] as u32)); + (val2, val6) = avx_butterfly(val2, val6, u32x16::splat(twiddles_dbl2[0] as u32)); + (val3, val7) = avx_butterfly(val3, val7, u32x16::splat(twiddles_dbl2[0] as u32)); // Apply the second layer of butterflies. 
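     // fft3 applies the same pairings as ifft3 in reverse order: the third
     // layer ran above, and the second and first follow. The `interleave` calls
     // in vecwise_butterflies are likewise the inverse of `deinterleave`; from
     // the std::simd docs:
     //
     //     let a = Simd::from_array([0, 1, 2, 3]);
     //     let b = Simd::from_array([4, 5, 6, 7]);
     //     let (x, y) = a.interleave(b);
     //     assert_eq!(x.to_array(), [0, 4, 1, 5]);
     //     assert_eq!(y.to_array(), [2, 6, 3, 7]);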
- (val0, val2) = avx_butterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); - (val1, val3) = avx_butterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); - (val4, val6) = avx_butterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1])); - (val5, val7) = avx_butterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1])); + (val0, val2) = avx_butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0] as u32)); + (val1, val3) = avx_butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0] as u32)); + (val4, val6) = avx_butterfly(val4, val6, u32x16::splat(twiddles_dbl1[1] as u32)); + (val5, val7) = avx_butterfly(val5, val7, u32x16::splat(twiddles_dbl1[1] as u32)); // Apply the first layer of butterflies. - (val0, val1) = avx_butterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); - (val2, val3) = avx_butterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); - (val4, val5) = avx_butterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2])); - (val6, val7) = avx_butterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3])); + (val0, val1) = avx_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); + (val2, val3) = avx_butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1] as u32)); + (val4, val5) = avx_butterfly(val4, val5, u32x16::splat(twiddles_dbl0[2] as u32)); + (val6, val7) = avx_butterfly(val6, val7, u32x16::splat(twiddles_dbl0[3] as u32)); // Store the 8 AVX vectors back to the array. - val0.store(dst.add(offset + (0 << log_step))); - val1.store(dst.add(offset + (1 << log_step))); - val2.store(dst.add(offset + (2 << log_step))); - val3.store(dst.add(offset + (3 << log_step))); - val4.store(dst.add(offset + (4 << log_step))); - val5.store(dst.add(offset + (5 << log_step))); - val6.store(dst.add(offset + (6 << log_step))); - val7.store(dst.add(offset + (7 << log_step))); + val0.store(dst.add(offset + (0 << log_step)) as *mut u32); + val1.store(dst.add(offset + (1 << log_step)) as *mut u32); + val2.store(dst.add(offset + (2 << log_step)) as *mut u32); + val3.store(dst.add(offset + (3 << log_step)) as *mut u32); + val4.store(dst.add(offset + (4 << log_step)) as *mut u32); + val5.store(dst.add(offset + (5 << log_step)) as *mut u32); + val6.store(dst.add(offset + (6 << log_step)) as *mut u32); + val7.store(dst.add(offset + (7 << log_step)) as *mut u32); } /// Applies 2 butterfly layers on 4 vectors of 16 M31 elements. @@ -490,24 +462,24 @@ pub unsafe fn fft2( twiddles_dbl1: [i32; 1], ) { // Load the 4 AVX vectors from the array. - let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); - let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step))); - let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step))); + let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)) as *const u32); + let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)) as *const u32); + let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step)) as *const u32); + let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step)) as *const u32); // Apply the second layer of butterflies. - (val0, val2) = avx_butterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); - (val1, val3) = avx_butterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); + (val0, val2) = avx_butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0] as u32)); + (val1, val3) = avx_butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0] as u32)); // Apply the first layer of butterflies. 
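     // The fft3/fft2/fft1 kernels let the driver loops tile any number of
     // remaining layers three at a time; a hedged sketch of the assumed
     // dispatch (loop names taken from the surrounding code):
     //
     //     match fft_layers - layer {
     //         1 => fft1_loop(..),
     //         2 => fft2_loop(..),
     //         _ => fft3_loop(..),
     //     }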
- (val0, val1) = avx_butterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); - (val2, val3) = avx_butterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); + (val0, val1) = avx_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); + (val2, val3) = avx_butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1] as u32)); // Store the 4 AVX vectors back to the array. - val0.store(dst.add(offset + (0 << log_step))); - val1.store(dst.add(offset + (1 << log_step))); - val2.store(dst.add(offset + (2 << log_step))); - val3.store(dst.add(offset + (3 << log_step))); + val0.store(dst.add(offset + (0 << log_step)) as *mut u32); + val1.store(dst.add(offset + (1 << log_step)) as *mut u32); + val2.store(dst.add(offset + (2 << log_step)) as *mut u32); + val3.store(dst.add(offset + (3 << log_step)) as *mut u32); } /// Applies 1 butterfly layers on 2 vectors of 16 M31 elements. @@ -528,24 +500,21 @@ pub unsafe fn fft1( twiddles_dbl0: [i32; 1], ) { // Load the 2 AVX vectors from the array. - let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); + let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)) as *const u32); + let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)) as *const u32); - (val0, val1) = avx_butterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); + (val0, val1) = avx_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); // Store the 2 AVX vectors back to the array. - val0.store(dst.add(offset + (0 << log_step))); - val1.store(dst.add(offset + (1 << log_step))); + val0.store(dst.add(offset + (0 << log_step)) as *mut u32); + val1.store(dst.add(offset + (1 << log_step)) as *mut u32); } -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] #[cfg(test)] mod tests { - use std::arch::x86_64::{_mm512_add_epi32, _mm512_set1_epi32, _mm512_setr_epi32}; - use super::*; - use crate::core::backend::avx512::{BaseFieldVec, PackedBaseField}; use crate::core::backend::cpu::CPUCirclePoly; + use crate::core::backend::simd::column::BaseFieldVec; use crate::core::backend::Column; use crate::core::fft::butterfly; use crate::core::fields::m31::BaseField; @@ -554,16 +523,16 @@ mod tests { #[test] fn test_butterfly() { unsafe { - let val0 = PackedBaseField(_mm512_setr_epi32( + let val0 = PackedBaseField::from_simd_unchecked(u32x16::from_array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - )); - let val1 = PackedBaseField(_mm512_setr_epi32( + ])); + let val1 = PackedBaseField::from_simd_unchecked(u32x16::from_array([ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - )); - let twiddle = _mm512_setr_epi32( + ])); + let twiddle = u32x16::from_array([ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - ); - let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle); + ]); + let twiddle_dbl = twiddle + twiddle; let (r0, r1) = avx_butterfly(val0, val1, twiddle_dbl); let val0: [BaseField; 16] = std::mem::transmute(val0); @@ -663,7 +632,7 @@ mod tests { let (val0, val1) = avx_butterfly( std::mem::transmute(values0), std::mem::transmute(values1), - _mm512_set1_epi32(twiddle_dbls[3][0]), + u32x16::splat(twiddle_dbls[3][0] as u32), ); let (val0, val1) = vecwise_butterflies( val0, diff --git a/crates/prover/src/core/backend/simd/m31.rs b/crates/prover/src/core/backend/simd/m31.rs index 7cd9a6a20..d8d28ee0a 100644 --- a/crates/prover/src/core/backend/simd/m31.rs +++ b/crates/prover/src/core/backend/simd/m31.rs @@ -225,7 +225,7 @@ impl 
Distribution<M31> for Standard {
 /// Returns `a * b`.
 #[cfg(target_arch = "aarch64")]
-fn _mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 {
+pub(crate) fn _mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 {
     use core::arch::aarch64::{int32x2_t, vqdmull_s32};
     use std::simd::u32x4;
 
@@ -261,9 +261,49 @@ fn _mul_neon(a: PackedM31, b: PackedM31) -> PackedM31 {
     lo + hi
 }
 
+/// Returns `a * b`.
+///
+/// `b_double` should be in the range `[0, 2P]`.
+#[cfg(target_arch = "aarch64")]
+pub(crate) fn _mul_doubled_neon(a: PackedM31, b_double: u32x16) -> PackedM31 {
+    use core::arch::aarch64::{uint32x2_t, vmull_u32};
+    use std::simd::u32x4;
+
+    let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) };
+    let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(b_double) };
+
+    // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0|
+    let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) };
+    let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) };
+    let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) };
+    let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) };
+    let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) };
+    let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) };
+    let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) };
+    let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) };
+
+    // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|0|`.
+    // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`.
+    let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1);
+    let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3);
+    let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5);
+    let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7);
+
+    // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`.
+    c0_c1_lo >>= 1;
+    c2_c3_lo >>= 1;
+    c4_c5_lo >>= 1;
+    c6_c7_lo >>= 1;
+
+    let lo: PackedM31 = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) };
+    let hi: PackedM31 = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) };
+
+    lo + hi
+}
+
 /// Returns `a * b`.
 #[cfg(target_arch = "wasm32")]
-fn _mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 {
+pub(crate) fn _mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 {
     _mul_doubled_wasm(a, b.0 + b.0)
 }
 
@@ -271,7 +311,7 @@ fn _mul_wasm(a: PackedM31, b: PackedM31) -> PackedM31 {
 ///
 /// `b_double` should be in the range `[0, 2P]`.
 #[cfg(target_arch = "wasm32")]
-fn _mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 {
+pub(crate) fn _mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 {
     use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128};
     use std::simd::u32x4;
 
@@ -304,7 +344,7 @@ fn _mul_doubled_wasm(a: PackedM31, b_double: u32x16) -> PackedM31 {
 
 /// Returns `a * b`.
 #[cfg(target_arch = "x86_64")]
-fn _mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 {
+pub(crate) fn _mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 {
     _mul_doubled_avx512(a, b.0 + b.0)
 }
 
@@ -312,7 +352,7 @@ fn _mul_avx512(a: PackedM31, b: PackedM31) -> PackedM31 {
 ///
 /// `b_double` should be in the range `[0, 2P]`.
 #[cfg(target_arch = "x86_64")]
-fn _mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 {
+pub(crate) fn _mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 {
     use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64};
 
     let a: __m512i = unsafe { transmute(a) };
@@ -354,7 +394,7 @@ fn _mul_doubled_avx512(a: PackedM31, b_double: u32x16) -> PackedM31 {
 
 /// Returns `a * b`.
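 ///
 /// (Like the other `_mul_*` entry points, implemented just below by doubling
 /// `b` once, `_mul_doubled_avx2(a, b.0 + b.0)`, which keeps the doubled
 /// operand in `[0, 2P]`.)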
 #[cfg(target_arch = "x86_64")]
-fn _mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 {
+pub(crate) fn _mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 {
     _mul_doubled_avx2(a, b.0 + b.0)
 }
 
@@ -362,7 +402,7 @@ fn _mul_avx2(a: PackedM31, b: PackedM31) -> PackedM31 {
 ///
 /// `b_double` should be in the range `[0, 2P]`.
 #[cfg(target_arch = "x86_64")]
-fn _mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 {
+pub(crate) fn _mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 {
     use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64};
 
     let [a0, a1]: [__m256i; 2] = unsafe { transmute(a) };
@@ -414,7 +454,7 @@ fn _mul_doubled_avx2(a: PackedM31, b_double: u32x16) -> PackedM31 {
 
 /// Returns `a * b`.
 ///
 /// Should only be used in the absence of a platform specific implementation.
-fn _mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 {
+pub(crate) fn _mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 {
     _mul_doubled_simd(a, b.0 + b.0)
 }
 
@@ -423,7 +463,7 @@ fn _mul_simd(a: PackedM31, b: PackedM31) -> PackedM31 {
 /// Should only be used in the absence of a platform specific implementation.
 ///
 /// `b_double` should be in the range `[0, 2P]`.
-fn _mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 {
+pub(crate) fn _mul_doubled_simd(a: PackedM31, b_double: u32x16) -> PackedM31 {
     const MASK_EVENS: Simd<u64, { N_LANES / 2 }> = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]);
 
     // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of