From 347d4464e4ae791a4b801c4670773b0ab57bfdc6 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Mon, 6 May 2024 15:48:32 -0400 Subject: [PATCH] Fix compilation issues --- .../prover/src/core/backend/simd/fft/ifft.rs | 221 ++++++++------- .../prover/src/core/backend/simd/fft/mod.rs | 265 +++++++++++++++++- .../prover/src/core/backend/simd/fft/rfft.rs | 205 +++++++------- 3 files changed, 458 insertions(+), 233 deletions(-) diff --git a/crates/prover/src/core/backend/simd/fft/ifft.rs b/crates/prover/src/core/backend/simd/fft/ifft.rs index 35a458d02..6c11bb17c 100644 --- a/crates/prover/src/core/backend/simd/fft/ifft.rs +++ b/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -1,17 +1,15 @@ //! Inverse fft. -use std::arch::x86_64::{ - __m512i, _mm512_broadcast_i32x4, _mm512_mul_epu32, _mm512_permutex2var_epi32, - _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64, -}; - -use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS}; -use crate::core::backend::avx512::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; -use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE}; +use std::simd::{simd_swizzle, u32x16, u32x2, u32x4}; + +use super::{compute_first_twiddles, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::circle::Coset; use crate::core::fields::FieldExpOps; use crate::core::utils::bit_reverse; +const VECS_LOG_SIZE: usize = LOG_N_LANES as usize; + /// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values. /// /// # Safety @@ -155,8 +153,9 @@ pub unsafe fn ifft_vecwise_loop( ) { for index_l in 0..(1 << loop_bits) { let index = (index_h << loop_bits) + index_l; - let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const()); - let mut val1 = PackedBaseField::load(values.add(index * 32 + 16).cast_const()); + let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const() as *const u32); + let mut val1 = + PackedBaseField::load(values.add(index * 32 + 16).cast_const() as *const u32); (val0, val1) = vecwise_ibutterflies( val0, val1, @@ -167,10 +166,10 @@ pub unsafe fn ifft_vecwise_loop( (val0, val1) = avx_ibutterfly( val0, val1, - _mm512_set1_epi32(*twiddle_dbl[3].get_unchecked(index)), + u32x16::splat(*twiddle_dbl[3].get_unchecked(index) as u32), ); - val0.store(values.add(index * 32)); - val1.store(values.add(index * 32 + 16)); + val0.store(values.add(index * 32) as *mut u32); + val1.store(values.add(index * 32 + 16) as *mut u32); } } @@ -269,41 +268,28 @@ unsafe fn ifft1_loop(values: *mut i32, twiddle_dbl: &[&[i32]], layer: usize, ind pub unsafe fn avx_ibutterfly( val0: PackedBaseField, val1: PackedBaseField, - twiddle_dbl: __m512i, + twiddle_dbl: u32x16, ) -> (PackedBaseField, PackedBaseField) { let r0 = val0 + val1; let r1 = val0 - val1; - // Extract the even and odd parts of r1 and twiddle_dbl, and spread as 8 64bit values. - let r1_e = r1.0; - let r1_o = _mm512_srli_epi64(r1.0, 32); - let twiddle_dbl_e = twiddle_dbl; - let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32); - - // To compute prod = r1 * twiddle start by multiplying - // r1_e/o by twiddle_dbl_e/o. - let prod_e_dbl = _mm512_mul_epu32(r1_e, twiddle_dbl_e); - let prod_o_dbl = _mm512_mul_epu32(r1_o, twiddle_dbl_o); - - // The result of a multiplication holds r1*twiddle_dbl in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_e_dbl - |0|prod_e_h|prod_e_l|0| - // prod_o_dbl - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: - let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl); - // prod_ls - |prod_o_l|0|prod_e_l|0| - - // Divide by 2: - let prod_ls = _mm512_srli_epi64(prod_ls, 1); - // prod_ls - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: - let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl); - // prod_hs - |0|prod_o_h|0|prod_e_h| - - let prod = PackedBaseField(prod_ls) + PackedBaseField(prod_hs); + let prod = { + // TODO: Come up with a better approach than `cfg`ing on target_feature. + // TODO: Ensure all these branches get tested in the CI. + cfg_if::cfg_if! { + if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { + super::_mul_twiddle_neon(r1, twiddle_dbl) + } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { + super::_mul_twiddle_wasm(r1, twiddle_dbl) + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { + super::_mul_twiddle_avx512(r1, twiddle_dbl) + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2f"))] { + super::_mul_twiddle_avx2(r1, twiddle_dbl) + } else { + super::_mul_twiddle_simd(r1, twiddle_dbl) + } + } + }; (r0, prod) } @@ -349,29 +335,35 @@ pub unsafe fn vecwise_ibutterflies( let (t0, t1) = compute_first_twiddles(twiddle1_dbl); // Apply the permutation, resulting in indexing d:iabc. - (val0, val1) = val0.deinterleave_with(val1); + (val0, val1) = val0.deinterleave(val1); (val0, val1) = avx_ibutterfly(val0, val1, t0); // Apply the permutation, resulting in indexing c:diab. - (val0, val1) = val0.deinterleave_with(val1); + (val0, val1) = val0.deinterleave(val1); (val0, val1) = avx_ibutterfly(val0, val1, t1); // The twiddles for layer 2 are replicated in the following pattern: // 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 - let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl)); + let t = simd_swizzle!( + u32x4::from_array(unsafe { std::mem::transmute(twiddle2_dbl) }), + [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] + ); // Apply the permutation, resulting in indexing b:cdia. - (val0, val1) = val0.deinterleave_with(val1); + (val0, val1) = val0.deinterleave(val1); (val0, val1) = avx_ibutterfly(val0, val1, t); // The twiddles for layer 3 are replicated in the following pattern: // 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 - let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl)); + let t = simd_swizzle!( + u32x2::from_array(unsafe { std::mem::transmute(twiddle3_dbl) }), + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + ); // Apply the permutation, resulting in indexing a:bcid. - (val0, val1) = val0.deinterleave_with(val1); + (val0, val1) = val0.deinterleave(val1); (val0, val1) = avx_ibutterfly(val0, val1, t); // Apply the permutation, resulting in indexing i:abcd. - val0.deinterleave_with(val1) + val0.deinterleave(val1) } /// Returns the line twiddles (x points) for an ifft on a coset. @@ -414,42 +406,50 @@ pub unsafe fn ifft3( twiddles_dbl2: [i32; 1], ) { // Load the 8 AVX vectors from the array. - let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); - let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); - let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); - let mut val4 = PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const()); - let mut val5 = PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const()); - let mut val6 = PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const()); - let mut val7 = PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const()); + let mut val0 = + PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const() as *const u32); + let mut val1 = + PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const() as *const u32); + let mut val2 = + PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const() as *const u32); + let mut val3 = + PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const() as *const u32); + let mut val4 = + PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const() as *const u32); + let mut val5 = + PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const() as *const u32); + let mut val6 = + PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const() as *const u32); + let mut val7 = + PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const() as *const u32); // Apply the first layer of ibutterflies. - (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); - (val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); - (val4, val5) = avx_ibutterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2])); - (val6, val7) = avx_ibutterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3])); + (val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); + (val2, val3) = avx_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1] as u32)); + (val4, val5) = avx_ibutterfly(val4, val5, u32x16::splat(twiddles_dbl0[2] as u32)); + (val6, val7) = avx_ibutterfly(val6, val7, u32x16::splat(twiddles_dbl0[3] as u32)); // Apply the second layer of ibutterflies. - (val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); - (val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); - (val4, val6) = avx_ibutterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1])); - (val5, val7) = avx_ibutterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1])); + (val0, val2) = avx_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0] as u32)); + (val1, val3) = avx_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0] as u32)); + (val4, val6) = avx_ibutterfly(val4, val6, u32x16::splat(twiddles_dbl1[1] as u32)); + (val5, val7) = avx_ibutterfly(val5, val7, u32x16::splat(twiddles_dbl1[1] as u32)); // Apply the third layer of ibutterflies. - (val0, val4) = avx_ibutterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0])); - (val1, val5) = avx_ibutterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0])); - (val2, val6) = avx_ibutterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0])); - (val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0])); + (val0, val4) = avx_ibutterfly(val0, val4, u32x16::splat(twiddles_dbl2[0] as u32)); + (val1, val5) = avx_ibutterfly(val1, val5, u32x16::splat(twiddles_dbl2[0] as u32)); + (val2, val6) = avx_ibutterfly(val2, val6, u32x16::splat(twiddles_dbl2[0] as u32)); + (val3, val7) = avx_ibutterfly(val3, val7, u32x16::splat(twiddles_dbl2[0] as u32)); // Store the 8 AVX vectors back to the array. - val0.store(values.add(offset + (0 << log_step))); - val1.store(values.add(offset + (1 << log_step))); - val2.store(values.add(offset + (2 << log_step))); - val3.store(values.add(offset + (3 << log_step))); - val4.store(values.add(offset + (4 << log_step))); - val5.store(values.add(offset + (5 << log_step))); - val6.store(values.add(offset + (6 << log_step))); - val7.store(values.add(offset + (7 << log_step))); + val0.store(values.add(offset + (0 << log_step)) as *mut u32); + val1.store(values.add(offset + (1 << log_step)) as *mut u32); + val2.store(values.add(offset + (2 << log_step)) as *mut u32); + val3.store(values.add(offset + (3 << log_step)) as *mut u32); + val4.store(values.add(offset + (4 << log_step)) as *mut u32); + val5.store(values.add(offset + (5 << log_step)) as *mut u32); + val6.store(values.add(offset + (6 << log_step)) as *mut u32); + val7.store(values.add(offset + (7 << log_step)) as *mut u32); } /// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements. @@ -473,24 +473,28 @@ pub unsafe fn ifft2( twiddles_dbl1: [i32; 1], ) { // Load the 4 AVX vectors from the array. - let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); - let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); - let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); + let mut val0 = + PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const() as *const u32); + let mut val1 = + PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const() as *const u32); + let mut val2 = + PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const() as *const u32); + let mut val3 = + PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const() as *const u32); // Apply the first layer of butterflies. - (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); - (val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); + (val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); + (val2, val3) = avx_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1] as u32)); // Apply the second layer of butterflies. - (val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); - (val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); + (val0, val2) = avx_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0] as u32)); + (val1, val3) = avx_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0] as u32)); // Store the 4 AVX vectors back to the array. - val0.store(values.add(offset + (0 << log_step))); - val1.store(values.add(offset + (1 << log_step))); - val2.store(values.add(offset + (2 << log_step))); - val3.store(values.add(offset + (3 << log_step))); + val0.store(values.add(offset + (0 << log_step)) as *mut u32); + val1.store(values.add(offset + (1 << log_step)) as *mut u32); + val2.store(values.add(offset + (2 << log_step)) as *mut u32); + val3.store(values.add(offset + (3 << log_step)) as *mut u32); } /// Applies 1 ibutterfly layers on 2 vectors of 16 M31 elements. @@ -504,25 +508,24 @@ pub unsafe fn ifft2( /// # Safety pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) { // Load the 2 AVX vectors from the array. - let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); - let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + let mut val0 = + PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const() as *const u32); + let mut val1 = + PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const() as *const u32); - (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); + (val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); // Store the 2 AVX vectors back to the array. - val0.store(values.add(offset + (0 << log_step))); - val1.store(values.add(offset + (1 << log_step))); + val0.store(values.add(offset + (0 << log_step)) as *mut u32); + val1.store(values.add(offset + (1 << log_step)) as *mut u32); } -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] #[cfg(test)] mod tests { - use std::arch::x86_64::{_mm512_add_epi32, _mm512_setr_epi32}; - use super::*; - use crate::core::backend::avx512::m31::PackedBaseField; - use crate::core::backend::avx512::BaseFieldVec; use crate::core::backend::cpu::CPUCircleEvaluation; + use crate::core::backend::simd::m31::PackedBaseField; + use crate::core::backend::simd::BaseFieldVec; use crate::core::backend::Column; use crate::core::fft::ibutterfly; use crate::core::fields::m31::BaseField; @@ -531,16 +534,16 @@ mod tests { #[test] fn test_ibutterfly() { unsafe { - let val0 = PackedBaseField(_mm512_setr_epi32( + let val0 = PackedBaseField::from_simd_unchecked(u32x16::from_array([ 2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - )); - let val1 = PackedBaseField(_mm512_setr_epi32( + ])); + let val1 = PackedBaseField::from_simd_unchecked(u32x16::from_array([ 3, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - )); - let twiddle = _mm512_setr_epi32( + ])); + let twiddle = u32x16::from_array([ 1177558791, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - ); - let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle); + ]); + let twiddle_dbl = twiddle + twiddle; let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl); let val0: [BaseField; 16] = std::mem::transmute(val0); @@ -647,7 +650,7 @@ mod tests { twiddle_dbls[1].clone().try_into().unwrap(), twiddle_dbls[2].clone().try_into().unwrap(), ); - let (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddle_dbls[3][0])); + let (val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddle_dbls[3][0] as u32)); std::mem::transmute([val0, val1]) }; diff --git a/crates/prover/src/core/backend/simd/fft/mod.rs b/crates/prover/src/core/backend/simd/fft/mod.rs index 37dbbd686..958896c19 100644 --- a/crates/prover/src/core/backend/simd/fft/mod.rs +++ b/crates/prover/src/core/backend/simd/fft/mod.rs @@ -1,14 +1,16 @@ -use std::arch::x86_64::{ - __m512i, _mm512_broadcast_i64x4, _mm512_load_epi32, _mm512_permutexvar_epi32, - _mm512_store_epi32, _mm512_xor_epi32, -}; +use std::mem::transmute; +use std::simd::{i32x16, simd_swizzle, u32x16, u32x8, Simd, Swizzle}; + +use super::m31::PackedBaseField; +use crate::core::backend::simd::m31::N_LANES; +use crate::core::backend::simd::utils::{LoEvensInterleaveHiEvens, LoOddsInterleaveHiOdds}; pub mod ifft; pub mod rfft; /// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a /// with the even words of b. -const EVENS_INTERLEAVE_EVENS: __m512i = unsafe { +const _EVENS_INTERLEAVE_EVENS: u32x16 = unsafe { core::mem::transmute([ 0b00000, 0b10000, 0b00010, 0b10010, 0b00100, 0b10100, 0b00110, 0b10110, 0b01000, 0b11000, 0b01010, 0b11010, 0b01100, 0b11100, 0b01110, 0b11110, @@ -16,7 +18,7 @@ const EVENS_INTERLEAVE_EVENS: __m512i = unsafe { }; /// An input to _mm512_permutex2var_epi32, and is used to interleave the odd words of a /// with the odd words of b. -const ODDS_INTERLEAVE_ODDS: __m512i = unsafe { +const _ODDS_INTERLEAVE_ODDS: u32x16 = unsafe { core::mem::transmute([ 0b00001, 0b10001, 0b00011, 0b10011, 0b00101, 0b10101, 0b00111, 0b10111, 0b01001, 0b11001, 0b01011, 0b11011, 0b01101, 0b11101, 0b01111, 0b11111, @@ -60,14 +62,27 @@ pub unsafe fn transpose_vecs(values: *mut i32, log_n_vecs: usize) { } } +unsafe fn _mm512_load_epi32(mem_addr: *const i32) -> u32x16 { + std::ptr::read(mem_addr as *const u32x16) +} + +unsafe fn _mm512_store_epi32(mem_addr: *mut i32, a: u32x16) { + std::ptr::write(mem_addr as *mut u32x16, a); +} + /// Computes the twiddles for the first fft layer from the second, and loads both to AVX registers. /// Returns the twiddles for the first layer and the twiddles for the second layer. /// # Safety -pub unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (__m512i, __m512i) { +pub unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (u32x16, u32x16) { // Start by loading the twiddles for the second layer (layer 1): // The twiddles for layer 1 are replicated in the following pattern: // 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 - let t1 = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl)); + // let t1 = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl)); + let t1 = simd_swizzle!( + unsafe { transmute::<_, u32x8>(twiddle1_dbl) }, + unsafe { transmute::<_, u32x8>(twiddle1_dbl) }, + [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7] + ); // The twiddles for layer 0 can be computed from the twiddles for layer 1. // Since the twiddles are bit reversed, we consider the circle domain in bit reversed order. @@ -84,17 +99,237 @@ pub unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (__m512i, __m512 // The twiddles for layer 0 are computed like this: // t0[4i:4i+3] = [t1[2i+1], -t1[2i+1], -t1[2i], t1[2i]] - const INDICES_FROM_T1: __m512i = unsafe { - core::mem::transmute([ + struct IndicesFromT1; + + impl Swizzle<16> for IndicesFromT1 { + const INDEX: [usize; 16] = [ 0b0001, 0b0001, 0b0000, 0b0000, 0b0011, 0b0011, 0b0010, 0b0010, 0b0101, 0b0101, 0b0100, 0b0100, 0b0111, 0b0111, 0b0110, 0b0110, - ]) - }; + ]; + } // Xoring a double twiddle with 2^32-2 transforms it to the double of it negation. // Note that this keeps the values as a double of a value in the range [0, P]. - const NEGATION_MASK: __m512i = unsafe { - core::mem::transmute([0i32, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0]) + const NEGATION_MASK: u32x16 = unsafe { + transmute(i32x16::from_array([ + 0, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0, + ])) }; - let t0 = _mm512_xor_epi32(_mm512_permutexvar_epi32(INDICES_FROM_T1, t1), NEGATION_MASK); + let t0 = IndicesFromT1::swizzle(t1) ^ NEGATION_MASK; (t0, t1) } + +#[cfg(target_arch = "aarch64")] +fn _mul_twiddle_neon(a: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { + use core::arch::aarch64::{uint32x2_t, vmull_u32}; + use std::simd::u32x4; + + let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) }; + let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(twiddle_dbl) }; + + // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| + let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) }; + let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) }; + let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) }; + let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) }; + let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) }; + let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) }; + let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) }; + let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) }; + + // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo0|0|prod_lo|0|`. + // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. + let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); + let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); + let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); + let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); + + // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. + c0_c1_lo >>= 1; + c2_c3_lo >>= 1; + c4_c5_lo >>= 1; + c6_c7_lo >>= 1; + + let lo: PackedBaseField = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; + let hi: PackedBaseField = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; + + lo + hi +} + +#[cfg(target_arch = "wasm32")] +fn _mul_twiddle_wasm(a: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { + use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128}; + use std::simd::u32x4; + + let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) }; + let [b_dbl0, b_dbl1, b_dbl2, b_dbl3]: [v128; 4] = unsafe { transmute(twiddle_dbl) }; + + let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_dbl0)) }; + let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_dbl0)) }; + let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_dbl1)) }; + let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_dbl1)) }; + let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_dbl2)) }; + let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_dbl2)) }; + let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_dbl3)) }; + let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_dbl3)) }; + + let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi); + let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi); + let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi); + let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi); + + c0_even >>= 1; + c1_even >>= 1; + c2_even >>= 1; + c3_even >>= 1; + + let even: PackedBaseField = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) }; + let odd: PackedBaseField = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) }; + + even + odd +} + +#[cfg(target_arch = "x86_64")] +fn _mul_twiddle_avx512(a: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { + use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64}; + + let a: __m512i = unsafe { transmute(a) }; + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = a; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { _mm512_srli_epi64(a, 32) }; + + let b_dbl = unsafe { transmute(twiddle_dbl) }; + let b_dbl_e = b_dbl; + let b_dbl_o = unsafe { _mm512_srli_epi64(b_dbl, 32) }; + + // To compute prod = a * b start by multiplying a_e/odd by b_dbl_e/odd. + let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) }; + let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) }; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = LoEvensInterleaveHiEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = LoOddsInterleaveHiOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + unsafe { + PackedBaseField::from_simd_unchecked(prod_lo) + + PackedBaseField::from_simd_unchecked(prod_hi) + } +} + +#[cfg(target_arch = "x86_64")] +fn _mul_twiddle_avx2(a: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { + use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64}; + + let [a0, a1]: [__m256i; 2] = unsafe { transmute(a) }; + let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute(twiddle_dbl) }; + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a0_e = a0; + let a1_e = a1; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a0_o = unsafe { _mm256_srli_epi64(a0, 32) }; + let a1_o = unsafe { _mm256_srli_epi64(a1, 32) }; + + let b0_dbl_e = b0_dbl; + let b1_dbl_e = b1_dbl; + let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) }; + let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) }; + + // To compute prod = a * b start by multiplying a0/1_e/odd by b0/1_e/odd. + let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) }; + let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) }; + let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) }; + let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) }; + + let prod_dbl_e: u32x16 = unsafe { transmute([prod0_dbl_e, prod1_dbl_e]) }; + let prod_dbl_o: u32x16 = unsafe { transmute([prod0_dbl_o, prod1_dbl_o]) }; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = LoEvensInterleaveHiEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = LoOddsInterleaveHiOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + unsafe { + PackedBaseField::from_simd_unchecked(prod_lo) + + PackedBaseField::from_simd_unchecked(prod_hi) + } +} + +// Should only be used in the absence of a platform specific implementation. +fn _mul_twiddle_simd(a: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { + const MASK_EVENS: Simd = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]); + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = unsafe { transmute::<_, Simd>(a.into_simd()) & MASK_EVENS }; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { transmute::<_, Simd>(a) >> 32 }; + + let b_dbl_e = unsafe { transmute::<_, Simd>(twiddle_dbl) & MASK_EVENS }; + let b_dbl_o = unsafe { transmute::<_, Simd>(twiddle_dbl) >> 32 }; + + // To compute prod = a * b start by multiplying + // a_e/o by b_dbl_e/o. + let prod_e_dbl = a_e * b_dbl_e; + let prod_o_dbl = a_o * b_dbl_o; + + // The result of a multiplication holds a*b in as 64-bits. + // Each 64b-bit word looks like this: + // 1 31 31 1 + // prod_e_dbl - |0|prod_e_h|prod_e_l|0| + // prod_o_dbl - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: + // let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, + // prod_o_dbl); + // prod_ls - |prod_o_l|0|prod_e_l|0| + let mut prod_lows = LoEvensInterleaveHiEvens::concat_swizzle( + unsafe { transmute::<_, Simd>(prod_e_dbl) }, + unsafe { transmute::<_, Simd>(prod_o_dbl) }, + ); + // Divide by 2: + prod_lows >>= 1; + // prod_ls - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: + let prod_highs = LoOddsInterleaveHiOdds::concat_swizzle( + unsafe { transmute::<_, Simd>(prod_e_dbl) }, + unsafe { transmute::<_, Simd>(prod_o_dbl) }, + ); + + // prod_hs - |0|prod_o_h|0|prod_e_h| + unsafe { + PackedBaseField::from_simd_unchecked(prod_lows) + + PackedBaseField::from_simd_unchecked(prod_highs) + } +} diff --git a/crates/prover/src/core/backend/simd/fft/rfft.rs b/crates/prover/src/core/backend/simd/fft/rfft.rs index 43fbf7f32..345005ad3 100644 --- a/crates/prover/src/core/backend/simd/fft/rfft.rs +++ b/crates/prover/src/core/backend/simd/fft/rfft.rs @@ -1,16 +1,14 @@ //! Regular (forward) fft. -use std::arch::x86_64::{ - __m512i, _mm512_broadcast_i32x4, _mm512_mul_epu32, _mm512_permutex2var_epi32, - _mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64, -}; - -use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS}; -use crate::core::backend::avx512::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; -use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE}; +use std::simd::{simd_swizzle, u32x16, u32x2, u32x4}; + +use super::{compute_first_twiddles, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::circle::Coset; use crate::core::utils::bit_reverse; +const VECS_LOG_SIZE: usize = LOG_N_LANES as usize; + /// Performs a Circle Fast Fourier Transform (ICFFT) on the given values. /// /// # Safety @@ -175,12 +173,12 @@ unsafe fn fft_vecwise_loop( ) { for index_l in 0..(1 << loop_bits) { let index = (index_h << loop_bits) + index_l; - let mut val0 = PackedBaseField::load(src.add(index * 32)); - let mut val1 = PackedBaseField::load(src.add(index * 32 + 16)); + let mut val0 = PackedBaseField::load(src.add(index * 32) as *const u32); + let mut val1 = PackedBaseField::load(src.add(index * 32 + 16) as *const u32); (val0, val1) = avx_butterfly( val0, val1, - _mm512_set1_epi32(*twiddle_dbl[3].get_unchecked(index)), + u32x16::splat(*twiddle_dbl[3].get_unchecked(index) as u32), ); (val0, val1) = vecwise_butterflies( val0, @@ -189,8 +187,8 @@ unsafe fn fft_vecwise_loop( std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)), std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)), ); - val0.store(dst.add(index * 32)); - val1.store(dst.add(index * 32 + 16)); + val0.store(dst.add(index * 32) as *mut u32); + val1.store(dst.add(index * 32 + 16) as *mut u32); } } @@ -309,39 +307,25 @@ unsafe fn fft1_loop( pub unsafe fn avx_butterfly( val0: PackedBaseField, val1: PackedBaseField, - twiddle_dbl: __m512i, + twiddle_dbl: u32x16, ) -> (PackedBaseField, PackedBaseField) { - // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of val0. - let val1_e = val1.0; - // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of val0. - let val1_o = _mm512_srli_epi64(val1.0, 32); - let twiddle_dbl_e = twiddle_dbl; - let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32); - - // To compute prod = val1 * twiddle start by multiplying - // val1_e/o by twiddle_dbl_e/o. - let prod_e_dbl = _mm512_mul_epu32(val1_e, twiddle_dbl_e); - let prod_o_dbl = _mm512_mul_epu32(val1_o, twiddle_dbl_o); - - // The result of a multiplication holds val1*twiddle_dbl in as 64-bits. - // Each 64b-bit word looks like this: - // 1 31 31 1 - // prod_e_dbl - |0|prod_e_h|prod_e_l|0| - // prod_o_dbl - |0|prod_o_h|prod_o_l|0| - - // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: - let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl); - // prod_ls - |prod_o_l|0|prod_e_l|0| - - // Divide by 2: - let prod_ls = _mm512_srli_epi64(prod_ls, 1); - // prod_ls - |0|prod_o_l|0|prod_e_l| - - // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: - let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl); - // prod_hs - |0|prod_o_h|0|prod_e_h| - - let prod = PackedBaseField(prod_ls) + PackedBaseField(prod_hs); + let prod = { + // TODO: Come up with a better approach than `cfg`ing on target_feature. + // TODO: Ensure all these branches get tested in the CI. + cfg_if::cfg_if! { + if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { + super::_mul_twiddle_neon(val1, twiddle_dbl) + } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { + super::_mul_twiddle_wasm(val1, twiddle_dbl) + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { + super::_mul_twiddle_avx512(val1, twiddle_dbl) + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2f"))] { + super::_mul_twiddle_avx2(val1, twiddle_dbl) + } else { + super::_mul_twiddle_simd(val1, twiddle_dbl) + } + } + }; let r0 = val0 + prod; let r1 = val0 - prod; @@ -369,22 +353,28 @@ pub unsafe fn vecwise_butterflies( // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. // The implementation is the exact reverse of vecwise_ibutterflies(). // See the comments in its body for more info. - let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl)); - (val0, val1) = val0.interleave_with(val1); + let t = simd_swizzle!( + u32x2::from_array(unsafe { std::mem::transmute(twiddle3_dbl) }), + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + ); + (val0, val1) = val0.interleave(val1); (val0, val1) = avx_butterfly(val0, val1, t); - let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl)); - (val0, val1) = val0.interleave_with(val1); + let t = simd_swizzle!( + u32x4::from_array(unsafe { std::mem::transmute(twiddle2_dbl) }), + [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] + ); + (val0, val1) = val0.interleave(val1); (val0, val1) = avx_butterfly(val0, val1, t); let (t0, t1) = compute_first_twiddles(twiddle1_dbl); - (val0, val1) = val0.interleave_with(val1); + (val0, val1) = val0.interleave(val1); (val0, val1) = avx_butterfly(val0, val1, t1); - (val0, val1) = val0.interleave_with(val1); + (val0, val1) = val0.interleave(val1); (val0, val1) = avx_butterfly(val0, val1, t0); - val0.interleave_with(val1) + val0.interleave(val1) } /// Returns the line twiddles (x points) for an fft on a coset. @@ -429,42 +419,42 @@ pub unsafe fn fft3( twiddles_dbl2: [i32; 1], ) { // Load the 8 AVX vectors from the array. - let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); - let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step))); - let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step))); - let mut val4 = PackedBaseField::load(src.add(offset + (4 << log_step))); - let mut val5 = PackedBaseField::load(src.add(offset + (5 << log_step))); - let mut val6 = PackedBaseField::load(src.add(offset + (6 << log_step))); - let mut val7 = PackedBaseField::load(src.add(offset + (7 << log_step))); + let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)) as *const u32); + let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)) as *const u32); + let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step)) as *const u32); + let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step)) as *const u32); + let mut val4 = PackedBaseField::load(src.add(offset + (4 << log_step)) as *const u32); + let mut val5 = PackedBaseField::load(src.add(offset + (5 << log_step)) as *const u32); + let mut val6 = PackedBaseField::load(src.add(offset + (6 << log_step)) as *const u32); + let mut val7 = PackedBaseField::load(src.add(offset + (7 << log_step)) as *const u32); // Apply the third layer of butterflies. - (val0, val4) = avx_butterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0])); - (val1, val5) = avx_butterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0])); - (val2, val6) = avx_butterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0])); - (val3, val7) = avx_butterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0])); + (val0, val4) = avx_butterfly(val0, val4, u32x16::splat(twiddles_dbl2[0] as u32)); + (val1, val5) = avx_butterfly(val1, val5, u32x16::splat(twiddles_dbl2[0] as u32)); + (val2, val6) = avx_butterfly(val2, val6, u32x16::splat(twiddles_dbl2[0] as u32)); + (val3, val7) = avx_butterfly(val3, val7, u32x16::splat(twiddles_dbl2[0] as u32)); // Apply the second layer of butterflies. - (val0, val2) = avx_butterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); - (val1, val3) = avx_butterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); - (val4, val6) = avx_butterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1])); - (val5, val7) = avx_butterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1])); + (val0, val2) = avx_butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0] as u32)); + (val1, val3) = avx_butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0] as u32)); + (val4, val6) = avx_butterfly(val4, val6, u32x16::splat(twiddles_dbl1[1] as u32)); + (val5, val7) = avx_butterfly(val5, val7, u32x16::splat(twiddles_dbl1[1] as u32)); // Apply the first layer of butterflies. - (val0, val1) = avx_butterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); - (val2, val3) = avx_butterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); - (val4, val5) = avx_butterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2])); - (val6, val7) = avx_butterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3])); + (val0, val1) = avx_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); + (val2, val3) = avx_butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1] as u32)); + (val4, val5) = avx_butterfly(val4, val5, u32x16::splat(twiddles_dbl0[2] as u32)); + (val6, val7) = avx_butterfly(val6, val7, u32x16::splat(twiddles_dbl0[3] as u32)); // Store the 8 AVX vectors back to the array. - val0.store(dst.add(offset + (0 << log_step))); - val1.store(dst.add(offset + (1 << log_step))); - val2.store(dst.add(offset + (2 << log_step))); - val3.store(dst.add(offset + (3 << log_step))); - val4.store(dst.add(offset + (4 << log_step))); - val5.store(dst.add(offset + (5 << log_step))); - val6.store(dst.add(offset + (6 << log_step))); - val7.store(dst.add(offset + (7 << log_step))); + val0.store(dst.add(offset + (0 << log_step)) as *mut u32); + val1.store(dst.add(offset + (1 << log_step)) as *mut u32); + val2.store(dst.add(offset + (2 << log_step)) as *mut u32); + val3.store(dst.add(offset + (3 << log_step)) as *mut u32); + val4.store(dst.add(offset + (4 << log_step)) as *mut u32); + val5.store(dst.add(offset + (5 << log_step)) as *mut u32); + val6.store(dst.add(offset + (6 << log_step)) as *mut u32); + val7.store(dst.add(offset + (7 << log_step)) as *mut u32); } /// Applies 2 butterfly layers on 4 vectors of 16 M31 elements. @@ -490,24 +480,24 @@ pub unsafe fn fft2( twiddles_dbl1: [i32; 1], ) { // Load the 4 AVX vectors from the array. - let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); - let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step))); - let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step))); + let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)) as *const u32); + let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)) as *const u32); + let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step)) as *const u32); + let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step)) as *const u32); // Apply the second layer of butterflies. - (val0, val2) = avx_butterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0])); - (val1, val3) = avx_butterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0])); + (val0, val2) = avx_butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0] as u32)); + (val1, val3) = avx_butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0] as u32)); // Apply the first layer of butterflies. - (val0, val1) = avx_butterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); - (val2, val3) = avx_butterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1])); + (val0, val1) = avx_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); + (val2, val3) = avx_butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1] as u32)); // Store the 4 AVX vectors back to the array. - val0.store(dst.add(offset + (0 << log_step))); - val1.store(dst.add(offset + (1 << log_step))); - val2.store(dst.add(offset + (2 << log_step))); - val3.store(dst.add(offset + (3 << log_step))); + val0.store(dst.add(offset + (0 << log_step)) as *mut u32); + val1.store(dst.add(offset + (1 << log_step)) as *mut u32); + val2.store(dst.add(offset + (2 << log_step)) as *mut u32); + val3.store(dst.add(offset + (3 << log_step)) as *mut u32); } /// Applies 1 butterfly layers on 2 vectors of 16 M31 elements. @@ -528,24 +518,21 @@ pub unsafe fn fft1( twiddles_dbl0: [i32; 1], ) { // Load the 2 AVX vectors from the array. - let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); + let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)) as *const u32); + let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)) as *const u32); - (val0, val1) = avx_butterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0])); + (val0, val1) = avx_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); // Store the 2 AVX vectors back to the array. - val0.store(dst.add(offset + (0 << log_step))); - val1.store(dst.add(offset + (1 << log_step))); + val0.store(dst.add(offset + (0 << log_step)) as *mut u32); + val1.store(dst.add(offset + (1 << log_step)) as *mut u32); } -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] #[cfg(test)] mod tests { - use std::arch::x86_64::{_mm512_add_epi32, _mm512_set1_epi32, _mm512_setr_epi32}; - use super::*; - use crate::core::backend::avx512::{BaseFieldVec, PackedBaseField}; use crate::core::backend::cpu::CPUCirclePoly; + use crate::core::backend::simd::{BaseFieldVec, PackedBaseField}; use crate::core::backend::Column; use crate::core::fft::butterfly; use crate::core::fields::m31::BaseField; @@ -554,16 +541,16 @@ mod tests { #[test] fn test_butterfly() { unsafe { - let val0 = PackedBaseField(_mm512_setr_epi32( + let val0 = PackedBaseField::from_simd_unchecked(u32x16::from_array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - )); - let val1 = PackedBaseField(_mm512_setr_epi32( + ])); + let val1 = PackedBaseField::from_simd_unchecked(u32x16::from_array([ 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - )); - let twiddle = _mm512_setr_epi32( + ])); + let twiddle = u32x16::from_array([ 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - ); - let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle); + ]); + let twiddle_dbl = twiddle + twiddle; let (r0, r1) = avx_butterfly(val0, val1, twiddle_dbl); let val0: [BaseField; 16] = std::mem::transmute(val0); @@ -663,7 +650,7 @@ mod tests { let (val0, val1) = avx_butterfly( std::mem::transmute(values0), std::mem::transmute(values1), - _mm512_set1_epi32(twiddle_dbls[3][0]), + u32x16::splat(twiddle_dbls[3][0] as u32), ); let (val0, val1) = vecwise_butterflies( val0,