diff --git a/crates/prover/src/core/backend/simd/fft/ifft.rs b/crates/prover/src/core/backend/simd/fft/ifft.rs new file mode 100644 index 000000000..aa5ec1283 --- /dev/null +++ b/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -0,0 +1,733 @@ +//! Inverse fft. + +use std::simd::{simd_swizzle, u32x16, u32x2, u32x4}; + +use itertools::Itertools; + +use super::compute_first_twiddles; +use crate::core::backend::simd::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::circle::Coset; +use crate::core::fields::FieldExpOps; +use crate::core::utils::bit_reverse; + +/// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values. + +/// +/// # Arguments +/// +/// - `values`: A mutable pointer to the values on which the ICFFT is to be performed. +/// - `twiddle_dbl`: A reference to the doubles of the twiddle factors. +/// - `log_n_elements`: The log of the number of elements in the `values` array. +/// +/// # Panics +/// +/// This function will panic if `log_n_elements` is less than `MIN_FFT_LOG_SIZE`. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +pub unsafe fn ifft(values: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) { + assert!(log_n_elements >= MIN_FFT_LOG_SIZE); + let log_n_vecs = log_n_elements - LOG_N_LANES as usize; + if log_n_elements <= CACHED_FFT_LOG_SIZE { + ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements); + return; + } + + let fft_layers_pre_transpose = log_n_vecs.div_ceil(2); + let fft_layers_post_transpose = log_n_vecs / 2; + ifft_lower_with_vecwise( + values, + &twiddle_dbl[..(3 + fft_layers_pre_transpose)], + log_n_elements, + fft_layers_pre_transpose + LOG_N_LANES as usize, + ); + transpose_vecs(values, log_n_vecs); + ifft_lower_without_vecwise( + values, + &twiddle_dbl[(3 + fft_layers_pre_transpose)..], + log_n_elements, + fft_layers_post_transpose, + ); +} + +/// Computes partial ifft on `2^log_size` M31 elements. +/// +/// # Arguments +/// +/// - `values`: Pointer to the entire value array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the ifft. Layer i +/// holds `2^(log_size - 1 - i)` twiddles. +/// - `log_size`: The log of the number of M31 elements in the array. +/// - `fft_layers`: The number of ifft layers to apply, out of `log_size`. +/// +/// # Panics +/// +/// Panics if `log_size` is not at least 5. +/// +/// # Safety +/// +/// `values` must have the same alignment as [`PackedBaseField`]. +/// `fft_layers` must be at least 5.
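To make the calling convention concrete, here is a usage sketch of the `ifft` entry point above, adapted from the `test_ifft_full` test at the bottom of this file. It assumes the crate-internal helpers used there (`BaseFieldVec`, `get_itwiddle_dbls`, `transpose_vecs`, itertools' `collect_vec`) and a `values: Vec<BaseField>` of length `1 << log_size`:

```rust
// Sketch only, mirroring `test_ifft_full` below.
let domain = CanonicCoset::new(log_size as u32).circle_domain();
// One slice per line-twiddle layer; slice j holds 2^(log_size - 2 - j) doubled twiddles.
let twiddle_dbls = get_itwiddle_dbls(domain.half_coset);
let mut res = values.iter().copied().collect::<BaseFieldVec>();
unsafe {
    ifft(
        res.data.as_mut_ptr() as *mut u32,
        &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(),
        log_size,
    );
    // Above CACHED_FFT_LOG_SIZE, `ifft` leaves the vectors transposed and the
    // caller undoes the transposition explicitly.
    transpose_vecs(res.data.as_mut_ptr() as *mut u32, log_size - 4);
}
```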
+pub unsafe fn ifft_lower_with_vecwise( + values: *mut u32, + twiddle_dbl: &[&[u32]], + log_size: usize, + fft_layers: usize, +) { + const VECWISE_FFT_BITS: usize = LOG_N_LANES as usize + 1; + assert!(log_size >= VECWISE_FFT_BITS); + + assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); + + for index_h in 0..(1 << (log_size - fft_layers)) { + ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h); + for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) { + match fft_layers - layer { + 1 => { + ifft1_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); + } + 2 => { + ifft2_loop(values, &twiddle_dbl[(layer - 1)..], layer, index_h); + } + _ => { + ifft3_loop( + values, + &twiddle_dbl[(layer - 1)..], + fft_layers - layer - 3, + layer, + index_h, + ); + } + } + } + } +} + +/// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of +/// the index). +/// +/// # Arguments +/// +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for each layer of the ifft. +/// log_size - The log of the number of M31 elements in the array. +/// fft_layers - The number of ifft layers to apply, out of `log_size - LOG_N_LANES`. +/// +/// # Panics +/// +/// Panics if `log_size` is not at least 4. +/// +/// # Safety +/// +/// `values` must have the same alignment as [`PackedBaseField`]. +/// `fft_layers` must be at least 4. +pub unsafe fn ifft_lower_without_vecwise( + values: *mut u32, + twiddle_dbl: &[&[u32]], + log_size: usize, + fft_layers: usize, +) { + assert!(log_size >= LOG_N_LANES as usize); + + for index_h in 0..(1 << (log_size - fft_layers - LOG_N_LANES as usize)) { + for layer in (0..fft_layers).step_by(3) { + let fixed_layer = layer + LOG_N_LANES as usize; + match fft_layers - layer { + 1 => { + ifft1_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); + } + 2 => { + ifft2_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); + } + _ => { + ifft3_loop( + values, + &twiddle_dbl[layer..], + fft_layers - layer - 3, + fixed_layer, + index_h, + ); + } + } + } + } +} + +/// Runs the first 5 ifft layers across the entire array. +/// +/// # Arguments +/// +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for each of the 5 ifft layers. +/// loop_bits - The number of bits this loop needs to run on. +/// index_h - The higher part of the index, iterated by the caller. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +pub unsafe fn ifft_vecwise_loop( + values: *mut u32, + twiddle_dbl: &[&[u32]], + loop_bits: usize, + index_h: usize, +) { + for index_l in 0..(1 << loop_bits) { + let index = (index_h << loop_bits) + index_l; + let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const()); + let mut val1 = PackedBaseField::load(values.add(index * 32 + 16).cast_const()); + (val0, val1) = vecwise_ibutterflies( + val0, + val1, + std::array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)), + std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)), + std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)), + ); + (val0, val1) = ibutterfly( + val0, + val1, + u32x16::splat(*twiddle_dbl[3].get_unchecked(index)), + ); + val0.store(values.add(index * 32)); + val1.store(values.add(index * 32 + 16)); + } +} + +/// Runs 3 ifft layers across the entire array.
+/// +/// # Arguments +/// +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for each of the 3 ifft layers. +/// loop_bits - The number of bits this loop needs to run on. +/// layer - The layer number of the first ifft layer to apply. +/// The layers `layer`, `layer + 1`, `layer + 2` are applied. +/// index_h - The higher part of the index, iterated by the caller. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +pub unsafe fn ifft3_loop( + values: *mut u32, + twiddle_dbl: &[&[u32]], + loop_bits: usize, + layer: usize, + index_h: usize, +) { + for index_l in 0..(1 << loop_bits) { + let index = (index_h << loop_bits) + index_l; + let offset = index << (layer + 3); + for l in (0..(1 << layer)).step_by(1 << LOG_N_LANES as usize) { + ifft3( + values, + offset + l, + layer, + std::array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1)) + }), + std::array::from_fn(|i| { + *twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1)) + }), + std::array::from_fn(|i| { + *twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1)) + }), + ); + } + } +} + +/// Runs 2 ifft layers across the entire array. +/// +/// # Arguments +/// +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for each of the 2 ifft layers. +/// layer - The layer number of the first ifft layer to apply. +/// The layers `layer`, `layer + 1` are applied. +/// index - The index, iterated by the caller. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +unsafe fn ifft2_loop(values: *mut u32, twiddle_dbl: &[&[u32]], layer: usize, index: usize) { + let offset = index << (layer + 2); + for l in (0..(1 << layer)).step_by(1 << LOG_N_LANES as usize) { + ifft2( + values, + offset + l, + layer, + std::array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1)) + }), + std::array::from_fn(|i| { + *twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1)) + }), + ); + } +} + +/// Runs 1 ifft layer across the entire array. +/// +/// # Arguments +/// +/// values - Pointer to the entire value array, aligned to 64 bytes. +/// twiddle_dbl - The doubles of the twiddle factors for the ifft layer. +/// layer - The layer number of the ifft layer to apply. +/// index - The index, iterated by the caller. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +unsafe fn ifft1_loop(values: *mut u32, twiddle_dbl: &[&[u32]], layer: usize, index: usize) { + let offset = index << (layer + 1); + for l in (0..(1 << layer)).step_by(1 << LOG_N_LANES as usize) { + ifft1( + values, + offset + l, + layer, + std::array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1)) + }), + ); + } +} + +/// Computes the ibutterfly operation for packed M31 elements. +/// +/// Returns `val0 + val1, t (val0 - val1)`. `val0, val1` are packed M31 elements, 16 M31 elements +/// each. Each value is assumed to be in unreduced form, [0, P] including P. `twiddle_dbl` holds 16 +/// values, each is a *double* of a twiddle factor, in unreduced form.
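The doubling is what makes the modular multiplication cheap: with `P = 2^31 - 1` we have `2^31 ≡ 1 (mod P)`, so the 64-bit product by the doubled twiddle splits into two halves that can simply be added. A scalar illustration of the trick shared by all the `_mul_twiddle_*` variants (the helper below is hypothetical, written only for this example):

```rust
const P: u64 = (1 << 31) - 1;

// For a in [0, P] and t in [0, P], with t_dbl = 2 * t:
// prod = a * t_dbl = 2*a*t fits in 63 bits and has a zero low bit, so
//   lo = (prod mod 2^32) >> 1 = a*t mod 2^31,
//   hi = prod >> 32           = floor(a*t / 2^31),
// and a*t = lo + 2^31 * hi ≡ lo + hi (mod P).
fn mul_twiddle_scalar(a: u64, t_dbl: u64) -> u64 {
    let prod = a * t_dbl;
    let lo = (prod & 0xFFFF_FFFF) >> 1;
    let hi = prod >> 32;
    (lo + hi) % P
}
```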
+pub fn ibutterfly( + val0: PackedBaseField, + val1: PackedBaseField, + twiddle_dbl: u32x16, +) -> (PackedBaseField, PackedBaseField) { + let r0 = val0 + val1; + let r1 = val0 - val1; + + let prod = { + // TODO: Come up with a better approach than `cfg`ing on target_feature. + cfg_if::cfg_if! { + if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { + super::_mul_twiddle_neon(r1, twiddle_dbl) + } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { + super::_mul_twiddle_wasm(r1, twiddle_dbl) + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { + super::_mul_twiddle_avx512(r1, twiddle_dbl) + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { + super::_mul_twiddle_avx2(r1, twiddle_dbl) + } else { + super::_mul_twiddle_simd(r1, twiddle_dbl) + } + } + }; + + (r0, prod) +} + +/// Runs ifft on 2 vectors of 16 M31 elements. +/// +/// This amounts to 4 butterfly layers, each with 16 butterflies. +/// Each of the vectors represents a bit reversed evaluation. +/// Each value in a vector is in unreduced form: [0, P] including P. +/// Takes 3 twiddle arrays, one for each layer after the first, holding the double of the +/// corresponding twiddle. +/// The first layer's twiddles (lower bit of the index) are computed from the second layer's +/// twiddles. The second layer takes 8 twiddles. +/// The third layer takes 4 twiddles. +/// The fourth layer takes 2 twiddles. +pub fn vecwise_ibutterflies( + mut val0: PackedBaseField, + mut val1: PackedBaseField, + twiddle1_dbl: [u32; 8], + twiddle2_dbl: [u32; 4], + twiddle3_dbl: [u32; 2], +) -> (PackedBaseField, PackedBaseField) { + // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. + + // Each `ibutterfly` takes 2 512-bit registers, and does 16 butterflies element by element. + // We need to permute the 512-bit registers to get the right order for the butterflies. + // Denote the index of the 16 M31 elements in register i as i:abcd. + // At each layer we apply the following permutation to the index: + // i:abcd => d:iabc + // This is how it looks at each iteration. + // i:abcd + // d:iabc + // ifft on d + // c:diab + // ifft on c + // b:cdia + // ifft on b + // a:bcid + // ifft on a + // i:abcd + + let (t0, t1) = compute_first_twiddles(twiddle1_dbl.into()); + + // Apply the permutation, resulting in indexing d:iabc. + (val0, val1) = val0.deinterleave(val1); + (val0, val1) = ibutterfly(val0, val1, t0); + + // Apply the permutation, resulting in indexing c:diab. + (val0, val1) = val0.deinterleave(val1); + (val0, val1) = ibutterfly(val0, val1, t1); + + // The twiddles for layer 2 are replicated in the following pattern: + // 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 + // let t = _mm512_broadcast_i32x4(transmute(twiddle2_dbl)); + let t = simd_swizzle!( + u32x4::from(twiddle2_dbl), + [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] + ); + // Apply the permutation, resulting in indexing b:cdia. + (val0, val1) = val0.deinterleave(val1); + (val0, val1) = ibutterfly(val0, val1, t); + + // The twiddles for layer 3 are replicated in the following pattern: + // 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 + // let t = _mm512_set1_epi64(transmute(twiddle3_dbl)); + let t = simd_swizzle!( + u32x2::from(twiddle3_dbl), + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + ); + // Apply the permutation, resulting in indexing a:bcid. + (val0, val1) = val0.deinterleave(val1); + (val0, val1) = ibutterfly(val0, val1, t); + + // Apply the permutation, resulting in indexing i:abcd.
+ val0.deinterleave(val1) +} + +/// Returns the line twiddles (x points) for an ifft on a coset. +pub fn get_itwiddle_dbls(mut coset: Coset) -> Vec<Vec<u32>> { + let mut res = vec![]; + for _ in 0..coset.log_size() { + res.push( + coset + .iter() + .take(coset.size() / 2) + .map(|p| p.x.inverse().0 * 2) + .collect_vec(), + ); + bit_reverse(res.last_mut().unwrap()); + coset = coset.double(); + } + + res +} + +/// Applies 3 ibutterfly layers on 8 vectors of 16 M31 elements. +/// +/// Vectorized over the 16 elements of the vectors. +/// Used for radix-8 ifft. +/// Each ibutterfly layer has 4 AVX butterflies. +/// Total of 12 AVX butterflies. +/// +/// # Arguments +/// +/// - `values`: Pointer to the entire value array. +/// - `offset`: The offset of the first value in the array. +/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values +/// that need to be transformed. For layer i this is i - 4. +/// - `twiddles_dbl0/1/2`: The double of the twiddles for the 3 layers of ibutterflies. Each layer +/// has 4/2/1 twiddles. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +pub unsafe fn ifft3( + values: *mut u32, + offset: usize, + log_step: usize, + twiddles_dbl0: [u32; 4], + twiddles_dbl1: [u32; 2], + twiddles_dbl2: [u32; 1], +) { + // Load the 8 AVX vectors from the array. + let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); + let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); + let mut val4 = PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const()); + let mut val5 = PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const()); + let mut val6 = PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const()); + let mut val7 = PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const()); + + // Apply the first layer of ibutterflies. + (val0, val1) = ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); + (val2, val3) = ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1])); + (val4, val5) = ibutterfly(val4, val5, u32x16::splat(twiddles_dbl0[2])); + (val6, val7) = ibutterfly(val6, val7, u32x16::splat(twiddles_dbl0[3])); + + // Apply the second layer of ibutterflies. + (val0, val2) = ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0])); + (val1, val3) = ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0])); + (val4, val6) = ibutterfly(val4, val6, u32x16::splat(twiddles_dbl1[1])); + (val5, val7) = ibutterfly(val5, val7, u32x16::splat(twiddles_dbl1[1])); + + // Apply the third layer of ibutterflies. + (val0, val4) = ibutterfly(val0, val4, u32x16::splat(twiddles_dbl2[0])); + (val1, val5) = ibutterfly(val1, val5, u32x16::splat(twiddles_dbl2[0])); + (val2, val6) = ibutterfly(val2, val6, u32x16::splat(twiddles_dbl2[0])); + (val3, val7) = ibutterfly(val3, val7, u32x16::splat(twiddles_dbl2[0])); + + // Store the 8 AVX vectors back to the array.
+ val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); + val2.store(values.add(offset + (2 << log_step))); + val3.store(values.add(offset + (3 << log_step))); + val4.store(values.add(offset + (4 << log_step))); + val5.store(values.add(offset + (5 << log_step))); + val6.store(values.add(offset + (6 << log_step))); + val7.store(values.add(offset + (7 << log_step))); +} + +/// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements. +/// +/// Vectorized over the 16 elements of the vectors. +/// Used for radix-4 ifft. +/// Each ibutterfly layer has 2 AVX butterflies. +/// Total of 4 AVX butterflies. +/// +/// # Arguments +/// +/// values - Pointer to the entire value array. +/// offset - The offset of the first value in the array. +/// log_step - The log of the distance in the array, in M31 elements, between each pair of +/// values that need to be transformed. For layer i this is i - 4. +/// twiddles_dbl0/1 - The double of the twiddles for the 2 layers of ibutterflies. +/// Each layer has 2/1 twiddles. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +pub unsafe fn ifft2( + values: *mut u32, + offset: usize, + log_step: usize, + twiddles_dbl0: [u32; 2], + twiddles_dbl1: [u32; 1], +) { + // Load the 4 AVX vectors from the array. + let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); + let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); + + // Apply the first layer of butterflies. + (val0, val1) = ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); + (val2, val3) = ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1])); + + // Apply the second layer of butterflies. + (val0, val2) = ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0])); + (val1, val3) = ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0])); + + // Store the 4 AVX vectors back to the array. + val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); + val2.store(values.add(offset + (2 << log_step))); + val3.store(values.add(offset + (3 << log_step))); +} + +/// Applies 1 ibutterfly layer on 2 vectors of 16 M31 elements. +/// +/// Vectorized over the 16 elements of the vectors. +/// +/// # Arguments +/// +/// values - Pointer to the entire value array. +/// offset - The offset of the first value in the array. +/// log_step - The log of the distance in the array, in M31 elements, between each pair of +/// values that need to be transformed. For layer i this is i - 4. +/// twiddles_dbl0 - The double of the twiddles for the ibutterfly layer. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +pub unsafe fn ifft1(values: *mut u32, offset: usize, log_step: usize, twiddles_dbl0: [u32; 1]) { + // Load the 2 AVX vectors from the array. + let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + + (val0, val1) = ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); + + // Store the 2 AVX vectors back to the array.
+ val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); +} + +#[cfg(test)] +mod tests { + use std::mem::transmute; + use std::simd::u32x16; + + use itertools::Itertools; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::{ + get_itwiddle_dbls, ibutterfly, ifft, ifft3, ifft_lower_with_vecwise, vecwise_ibutterflies, + }; + use crate::core::backend::cpu::CPUCircleEvaluation; + use crate::core::backend::simd::column::BaseFieldVec; + use crate::core::backend::simd::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE}; + use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; + use crate::core::backend::Column; + use crate::core::fft::ibutterfly as ground_truth_ibutterfly; + use crate::core::fields::m31::BaseField; + use crate::core::poly::circle::{CanonicCoset, CircleDomain}; + + #[test] + fn test_ibutterfly() { + let mut rng = SmallRng::seed_from_u64(0); + let mut v0: [BaseField; N_LANES] = rng.gen(); + let mut v1: [BaseField; N_LANES] = rng.gen(); + let twiddle: [BaseField; N_LANES] = rng.gen(); + let twiddle_dbl = twiddle.map(|v| v.0 * 2); + + let (r0, r1) = ibutterfly(v0.into(), v1.into(), twiddle_dbl.into()); + + let r0 = r0.to_array(); + let r1 = r1.to_array(); + for i in 0..N_LANES { + ground_truth_ibutterfly(&mut v0[i], &mut v1[i], twiddle[i]); + assert_eq!((v0[i], v1[i]), (r0[i], r1[i]), "mismatch at i={i}"); + } + } + + #[test] + fn test_ifft3() { + let mut rng = SmallRng::seed_from_u64(0); + let values = rng.gen::<[BaseField; 8]>().map(PackedBaseField::broadcast); + let twiddles0: [BaseField; 4] = rng.gen(); + let twiddles1: [BaseField; 2] = rng.gen(); + let twiddles2: [BaseField; 1] = rng.gen(); + let twiddles0_dbl = twiddles0.map(|v| v.0 * 2); + let twiddles1_dbl = twiddles1.map(|v| v.0 * 2); + let twiddles2_dbl = twiddles2.map(|v| v.0 * 2); + + let mut res = values; + unsafe { + ifft3( + transmute(res.as_mut_ptr()), + 0, + LOG_N_LANES as usize, + twiddles0_dbl, + twiddles1_dbl, + twiddles2_dbl, + ) + }; + + let mut expected = values.map(|v| v.to_array()[0]); + for i in 0..8 { + let j = i ^ 1; + if i > j { + continue; + } + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_ibutterfly(&mut v0, &mut v1, twiddles0[i / 2]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + let j = i ^ 2; + if i > j { + continue; + } + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_ibutterfly(&mut v0, &mut v1, twiddles1[i / 4]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + let j = i ^ 4; + if i > j { + continue; + } + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_ibutterfly(&mut v0, &mut v1, twiddles2[0]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + assert_eq!( + res[i].to_array(), + [expected[i]; N_LANES], + "mismatch at i={i}" + ); + } + } + + #[test] + fn test_vecwise_ibutterflies() { + let domain = CanonicCoset::new(5).circle_domain(); + let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); + assert_eq!(twiddle_dbls.len(), 4); + let mut rng = SmallRng::seed_from_u64(0); + let values: [[BaseField; 16]; 2] = rng.gen(); + + let res = { + let (val0, val1) = vecwise_ibutterflies( + values[0].into(), + values[1].into(), + twiddle_dbls[0].clone().try_into().unwrap(), + twiddle_dbls[1].clone().try_into().unwrap(), + twiddle_dbls[2].clone().try_into().unwrap(), + ); + let (val0, val1) = ibutterfly(val0, val1, u32x16::splat(twiddle_dbls[3][0])); + [val0.to_array(), val1.to_array()].concat() + }; + + assert_eq!(res, 
ground_truth_ifft(domain, values.flatten())); + } + + #[test] + fn test_ifft_lower_with_vecwise() { + for log_size in 5..12 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let mut rng = SmallRng::seed_from_u64(0); + let values = (0..domain.size()).map(|_| rng.gen()).collect_vec(); + let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); + + let mut res = values.iter().copied().collect::<BaseFieldVec>(); + unsafe { + ifft_lower_with_vecwise( + transmute(res.data.as_mut_ptr()), + &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), + log_size as usize, + log_size as usize, + ); + } + + assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values)); + } + } + + #[test] + fn test_ifft_full() { + for log_size in CACHED_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 3 { + let domain = CanonicCoset::new(log_size as u32).circle_domain(); + let mut rng = SmallRng::seed_from_u64(0); + let values = (0..domain.size()).map(|_| rng.gen()).collect_vec(); + let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); + + let mut res = values.iter().copied().collect::<BaseFieldVec>(); + unsafe { + ifft( + transmute(res.data.as_mut_ptr()), + &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), + log_size, + ); + transpose_vecs(transmute(res.data.as_mut_ptr()), log_size - 4); + } + + assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values)); + } + } + + fn ground_truth_ifft(domain: CircleDomain, values: &[BaseField]) -> Vec<BaseField> { + let eval = CPUCircleEvaluation::new(domain, values.to_vec()); + let mut res = eval.interpolate().coeffs; + let denorm = BaseField::from(domain.size()); + res.iter_mut().for_each(|v| *v *= denorm); + res + } +} diff --git a/crates/prover/src/core/backend/simd/fft/mod.rs b/crates/prover/src/core/backend/simd/fft/mod.rs new file mode 100644 index 000000000..e18dcdbda --- /dev/null +++ b/crates/prover/src/core/backend/simd/fft/mod.rs @@ -0,0 +1,316 @@ +use std::mem::transmute; +use std::ptr; +use std::simd::{i32x16, simd_swizzle, u32x16, u32x8, Simd, Swizzle}; + +use crate::core::backend::simd::m31::N_LANES; +use crate::core::backend::simd::utils::{LoEvensInterleaveHiEvens, LoOddsInterleaveHiOdds}; +use crate::core::backend::simd::PackedBaseField; + +pub mod ifft; +pub mod rfft; + +pub const CACHED_FFT_LOG_SIZE: usize = 16; + +pub const MIN_FFT_LOG_SIZE: usize = 5; + +// TODO(spapini): FFTs return a redundant representation that can get the value P. This needs to be +// reduced somewhere. + +/// Transposes the AVX vectors in the given array. +/// +/// Swaps the bit index abc <-> cba, where |a|=|c| and |b| = 0 or 1, according to the parity of +/// `log_n_vecs`. When `log_n_vecs` is odd, `b` is the middle bit. +/// +/// # Arguments +/// +/// - `values`: A mutable pointer to the values that are to be transposed. +/// - `log_n_vecs`: The log of the number of AVX vectors in the `values` array. +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`u32x16`].
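As a scalar model of the permutation below, the following hypothetical helper (not part of the module) computes where vector index `i` ends up; applying it twice is the identity:

```rust
// Swap the high and low halves of the vector index; the middle bit
// (present only when `log_n_vecs` is odd) stays in place.
fn transposed_vec_index(i: usize, log_n_vecs: usize) -> usize {
    let half = log_n_vecs / 2;
    let a = i >> (log_n_vecs - half);                    // high `half` bits
    let b = (i >> half) & ((1 << (log_n_vecs & 1)) - 1); // middle bit, if any
    let c = i & ((1 << half) - 1);                       // low `half` bits
    (c << (log_n_vecs - half)) | (b << half) | a
}
```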
+pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) { + let half = log_n_vecs / 2; + for b in 0..(1 << (log_n_vecs & 1)) { + for a in 0..(1 << half) { + for c in 0..(1 << half) { + let i = (a << (log_n_vecs - half)) | (b << half) | c; + let j = (c << (log_n_vecs - half)) | (b << half) | a; + if i >= j { + continue; + } + let val0 = load(values.add(i << 4).cast_const()); + let val1 = load(values.add(j << 4).cast_const()); + store(values.add(i << 4), val1); + store(values.add(j << 4), val0); + } + } + } +} + +/// Computes the twiddles for the first fft layer from the second, and loads both to AVX registers. +/// +/// Returns the twiddles for the first layer and the twiddles for the second layer. +fn compute_first_twiddles(twiddle1_dbl: u32x8) -> (u32x16, u32x16) { + // Start by loading the twiddles for the second layer (layer 1): + // The twiddles for layer 1 are replicated in the following pattern: + // 0 1 2 3 4 5 6 7 0 1 2 3 4 5 6 7 + // let t1 = _mm512_broadcast_i64x4(std::mem::transmute(twiddle1_dbl)); + let t1 = simd_swizzle!( + twiddle1_dbl, + twiddle1_dbl, + [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7] + ); + + // The twiddles for layer 0 can be computed from the twiddles for layer 1. + // Since the twiddles are bit reversed, we consider the circle domain in bit reversed order. + // Each consecutive 4 points in the bit reversed order of a coset form a circle coset of size 4. + // A circle coset of size 4 in bit reversed order looks like this: + // [(x, y), (-x, -y), (y, -x), (-y, x)] + // Note: This is related to the choice of M31_CIRCLE_GEN, and the fact that a quarter rotation + // is (0,-1) and not (0,1). (0,1) would yield another relation. + // The twiddles for layer 0 are the y coordinates: + // [y, -y, -x, x] + // The twiddles for layer 1 in bit reversed order are the x coordinates: + // [x, y] + // Works also for inverse of the twiddles. + + // The twiddles for layer 0 are computed like this: + // t0[4i:4i+3] = [t1[2i+1], -t1[2i+1], -t1[2i], t1[2i]] + struct IndicesFromT1; + + impl Swizzle<16> for IndicesFromT1 { + const INDEX: [usize; 16] = [ + 0b0001, 0b0001, 0b0000, 0b0000, 0b0011, 0b0011, 0b0010, 0b0010, 0b0101, 0b0101, 0b0100, + 0b0100, 0b0111, 0b0111, 0b0110, 0b0110, + ]; + } + + // Xoring a double twiddle with 2^32-2 transforms it to the double of its negation. + // (A double twiddle d = 2v is even, so d ^ (2^32 - 2) = (2^32 - 2) - d = 2(P - v).) + // Note that this keeps the values as a double of a value in the range [0, P].
+ const NEGATION_MASK: u32x16 = unsafe { + transmute(i32x16::from_array([ + 0, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0, 0, -2, -2, 0, + ])) + }; + + let t0 = IndicesFromT1::swizzle(t1) ^ NEGATION_MASK; + + (t0, t1) +} + +unsafe fn load(mem_addr: *const u32) -> u32x16 { + ptr::read(mem_addr as *const u32x16) +} + +unsafe fn store(mem_addr: *mut u32, a: u32x16) { + ptr::write(mem_addr as *mut u32x16, a); +} + +#[cfg(target_arch = "aarch64")] +fn _mul_twiddle_neon(a: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { + use core::arch::aarch64::{uint32x2_t, vmull_u32}; + use std::simd::u32x4; + + let [a0, a1, a2, a3, a4, a5, a6, a7]: [uint32x2_t; 8] = unsafe { transmute(a) }; + let [b0, b1, b2, b3, b4, b5, b6, b7]: [uint32x2_t; 8] = unsafe { transmute(twiddle_dbl) }; + + // Each c_i contains |0|prod_lo|prod_hi|0|0|prod_lo|prod_hi|0| + let c0: u32x4 = unsafe { transmute(vmull_u32(a0, b0)) }; + let c1: u32x4 = unsafe { transmute(vmull_u32(a1, b1)) }; + let c2: u32x4 = unsafe { transmute(vmull_u32(a2, b2)) }; + let c3: u32x4 = unsafe { transmute(vmull_u32(a3, b3)) }; + let c4: u32x4 = unsafe { transmute(vmull_u32(a4, b4)) }; + let c5: u32x4 = unsafe { transmute(vmull_u32(a5, b5)) }; + let c6: u32x4 = unsafe { transmute(vmull_u32(a6, b6)) }; + let c7: u32x4 = unsafe { transmute(vmull_u32(a7, b7)) }; + + // *_lo contain `|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|0|`. + // *_hi contain `|0|prod_hi|0|prod_hi|0|prod_hi|0|prod_hi|`. + let (mut c0_c1_lo, c0_c1_hi) = c0.deinterleave(c1); + let (mut c2_c3_lo, c2_c3_hi) = c2.deinterleave(c3); + let (mut c4_c5_lo, c4_c5_hi) = c4.deinterleave(c5); + let (mut c6_c7_lo, c6_c7_hi) = c6.deinterleave(c7); + + // *_lo contain `|0|prod_lo|0|prod_lo|0|prod_lo|0|prod_lo|`. + c0_c1_lo >>= 1; + c2_c3_lo >>= 1; + c4_c5_lo >>= 1; + c6_c7_lo >>= 1; + + let lo: PackedBaseField = unsafe { transmute([c0_c1_lo, c2_c3_lo, c4_c5_lo, c6_c7_lo]) }; + let hi: PackedBaseField = unsafe { transmute([c0_c1_hi, c2_c3_hi, c4_c5_hi, c6_c7_hi]) }; + + lo + hi +} + +#[cfg(target_arch = "wasm32")] +fn _mul_twiddle_wasm(a: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { + use core::arch::wasm32::{i64x2_extmul_high_u32x4, i64x2_extmul_low_u32x4, v128}; + use std::simd::u32x4; + + let [a0, a1, a2, a3]: [v128; 4] = unsafe { transmute(a) }; + let [b_dbl0, b_dbl1, b_dbl2, b_dbl3]: [v128; 4] = unsafe { transmute(twiddle_dbl) }; + + let c0_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a0, b_dbl0)) }; + let c0_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a0, b_dbl0)) }; + let c1_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a1, b_dbl1)) }; + let c1_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a1, b_dbl1)) }; + let c2_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a2, b_dbl2)) }; + let c2_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a2, b_dbl2)) }; + let c3_lo: u32x4 = unsafe { transmute(i64x2_extmul_low_u32x4(a3, b_dbl3)) }; + let c3_hi: u32x4 = unsafe { transmute(i64x2_extmul_high_u32x4(a3, b_dbl3)) }; + + let (mut c0_even, c0_odd) = c0_lo.deinterleave(c0_hi); + let (mut c1_even, c1_odd) = c1_lo.deinterleave(c1_hi); + let (mut c2_even, c2_odd) = c2_lo.deinterleave(c2_hi); + let (mut c3_even, c3_odd) = c3_lo.deinterleave(c3_hi); + + c0_even >>= 1; + c1_even >>= 1; + c2_even >>= 1; + c3_even >>= 1; + + let even: PackedBaseField = unsafe { transmute([c0_even, c1_even, c2_even, c3_even]) }; + let odd: PackedBaseField = unsafe { transmute([c0_odd, c1_odd, c2_odd, c3_odd]) }; + + even + odd +} + +#[cfg(target_arch = "x86_64")] +fn
_mul_twiddle_avx512(a: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { + use std::arch::x86_64::{__m512i, _mm512_mul_epu32, _mm512_srli_epi64}; + + let a: __m512i = unsafe { transmute(a) }; + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = a; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { _mm512_srli_epi64(a, 32) }; + + let b_dbl = unsafe { transmute(twiddle_dbl) }; + let b_dbl_e = b_dbl; + let b_dbl_o = unsafe { _mm512_srli_epi64(b_dbl, 32) }; + + // To compute prod = a * b start by multiplying a_e/o by b_dbl_e/o. + let prod_dbl_e: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_e, b_dbl_e)) }; + let prod_dbl_o: u32x16 = unsafe { transmute(_mm512_mul_epu32(a_o, b_dbl_o)) }; + + // The result of a multiplication holds a*b in 64 bits. + // Each 64-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = LoEvensInterleaveHiEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = LoOddsInterleaveHiOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + unsafe { PackedBaseField::from_simd(prod_lo) + PackedBaseField::from_simd(prod_hi) } +} + +#[cfg(target_arch = "x86_64")] +fn _mul_twiddle_avx2(a: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { + use std::arch::x86_64::{__m256i, _mm256_mul_epu32, _mm256_srli_epi64}; + + let [a0, a1]: [__m256i; 2] = unsafe { transmute(a) }; + let [b0_dbl, b1_dbl]: [__m256i; 2] = unsafe { transmute(twiddle_dbl) }; + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a0_e = a0; + let a1_e = a1; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a0_o = unsafe { _mm256_srli_epi64(a0, 32) }; + let a1_o = unsafe { _mm256_srli_epi64(a1, 32) }; + + let b0_dbl_e = b0_dbl; + let b1_dbl_e = b1_dbl; + let b0_dbl_o = unsafe { _mm256_srli_epi64(b0_dbl, 32) }; + let b1_dbl_o = unsafe { _mm256_srli_epi64(b1_dbl, 32) }; + + // To compute prod = a * b start by multiplying a0/1_e/o by b0/1_dbl_e/o. + let prod0_dbl_e = unsafe { _mm256_mul_epu32(a0_e, b0_dbl_e) }; + let prod0_dbl_o = unsafe { _mm256_mul_epu32(a0_o, b0_dbl_o) }; + let prod1_dbl_e = unsafe { _mm256_mul_epu32(a1_e, b1_dbl_e) }; + let prod1_dbl_o = unsafe { _mm256_mul_epu32(a1_o, b1_dbl_o) }; + + let prod_dbl_e: u32x16 = unsafe { transmute([prod0_dbl_e, prod1_dbl_e]) }; + let prod_dbl_o: u32x16 = unsafe { transmute([prod0_dbl_o, prod1_dbl_o]) }; + + // The result of a multiplication holds a*b in 64 bits.
+ // Each 64-bit word looks like this: + // 1 31 31 1 + // prod_dbl_e - |0|prod_e_h|prod_e_l|0| + // prod_dbl_o - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_dbl_e with the even words of prod_dbl_o: + let mut prod_lo = LoEvensInterleaveHiEvens::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_lo - |prod_dbl_o_l|0|prod_dbl_e_l|0| + // Divide by 2: + prod_lo >>= 1; + // prod_lo - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_dbl_e with the odd words of prod_dbl_o: + let prod_hi = LoOddsInterleaveHiOdds::concat_swizzle(prod_dbl_e, prod_dbl_o); + // prod_hi - |0|prod_o_h|0|prod_e_h| + + unsafe { PackedBaseField::from_simd(prod_lo) + PackedBaseField::from_simd(prod_hi) } +} + +// Should only be used in the absence of a platform-specific implementation. +fn _mul_twiddle_simd(a: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { + const MASK_EVENS: Simd<u64, { N_LANES / 2 }> = Simd::from_array([0xFFFFFFFF; { N_LANES / 2 }]); + + // Set up a word s.t. the lower half of each 64-bit word has the even 32-bit words of + // the first operand. + let a_e = unsafe { transmute::<_, Simd<u64, { N_LANES / 2 }>>(a.into_simd()) & MASK_EVENS }; + // Set up a word s.t. the lower half of each 64-bit word has the odd 32-bit words of + // the first operand. + let a_o = unsafe { transmute::<_, Simd<u64, { N_LANES / 2 }>>(a) >> 32 }; + + let b_dbl_e = unsafe { transmute::<_, Simd<u64, { N_LANES / 2 }>>(twiddle_dbl) & MASK_EVENS }; + let b_dbl_o = unsafe { transmute::<_, Simd<u64, { N_LANES / 2 }>>(twiddle_dbl) >> 32 }; + + // To compute prod = a * b start by multiplying + // a_e/o by b_dbl_e/o. + let prod_e_dbl = a_e * b_dbl_e; + let prod_o_dbl = a_o * b_dbl_o; + + // The result of a multiplication holds a*b in 64 bits. + // Each 64-bit word looks like this: + // 1 31 31 1 + // prod_e_dbl - |0|prod_e_h|prod_e_l|0| + // prod_o_dbl - |0|prod_o_h|prod_o_l|0| + + // Interleave the even words of prod_e_dbl with the even words of prod_o_dbl: + // let prod_lows = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, + // prod_o_dbl); + // prod_ls - |prod_o_l|0|prod_e_l|0| + let mut prod_lows = LoEvensInterleaveHiEvens::concat_swizzle( + unsafe { transmute::<_, u32x16>(prod_e_dbl) }, + unsafe { transmute::<_, u32x16>(prod_o_dbl) }, + ); + // Divide by 2: + prod_lows >>= 1; + // prod_ls - |0|prod_o_l|0|prod_e_l| + + // Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl: + let prod_highs = LoOddsInterleaveHiOdds::concat_swizzle( + unsafe { transmute::<_, u32x16>(prod_e_dbl) }, + unsafe { transmute::<_, u32x16>(prod_o_dbl) }, + ); + + // prod_hs - |0|prod_o_h|0|prod_e_h| + unsafe { PackedBaseField::from_simd(prod_lows) + PackedBaseField::from_simd(prod_highs) } +} diff --git a/crates/prover/src/core/backend/simd/fft/rfft.rs b/crates/prover/src/core/backend/simd/fft/rfft.rs new file mode 100644 index 000000000..1934a1b2d --- --- /dev/null +++ b/crates/prover/src/core/backend/simd/fft/rfft.rs @@ -0,0 +1,762 @@ +//! Regular (forward) fft. + +use std::array; +use std::simd::{simd_swizzle, u32x16, u32x2, u32x4, u32x8}; + +use itertools::Itertools; + +use super::compute_first_twiddles; +use crate::core::backend::simd::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE}; +use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; +use crate::core::circle::Coset; +use crate::core::utils::bit_reverse; + +/// Performs a Circle Fast Fourier Transform (CFFT) on the given values. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform. +/// - `dst`: A pointer to the destination array.
+/// - `twiddle_dbl`: A reference to the doubles of the twiddle factors. +/// - `log_n_elements`: The log of the number of elements in the `src` array. +/// +/// # Panics +/// +/// This function will panic if `log_n_elements` is less than `MIN_FFT_LOG_SIZE`. +/// +/// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. +pub unsafe fn fft(src: *const u32, dst: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) { + assert!(log_n_elements >= MIN_FFT_LOG_SIZE); + let log_n_vecs = log_n_elements - LOG_N_LANES as usize; + if log_n_elements <= CACHED_FFT_LOG_SIZE { + fft_lower_with_vecwise(src, dst, twiddle_dbl, log_n_elements, log_n_elements); + return; + } + + let fft_layers_pre_transpose = log_n_vecs.div_ceil(2); + let fft_layers_post_transpose = log_n_vecs / 2; + fft_lower_without_vecwise( + src, + dst, + &twiddle_dbl[(3 + fft_layers_pre_transpose)..], + log_n_elements, + fft_layers_post_transpose, + ); + transpose_vecs(dst, log_n_vecs); + fft_lower_with_vecwise( + dst, + dst, + &twiddle_dbl[..(3 + fft_layers_pre_transpose)], + log_n_elements, + fft_layers_pre_transpose + LOG_N_LANES as usize, + ); +} + +/// Computes partial fft on `2^log_size` M31 elements. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the fft. Layer `i` +/// holds `2^(log_size - 1 - i)` twiddles. +/// - `log_size`: The log of the number of M31 elements in the array. +/// - `fft_layers`: The number of fft layers to apply, out of `log_size`. +/// +/// # Panics +/// +/// Panics if `log_size` is not at least 5. +/// +/// # Safety +/// +/// `src` and `dst` must have same alignment as [`PackedBaseField`]. +/// `fft_layers` must be at least 5. +pub unsafe fn fft_lower_with_vecwise( + src: *const u32, + dst: *mut u32, + twiddle_dbl: &[&[u32]], + log_size: usize, + fft_layers: usize, +) { + const VECWISE_FFT_BITS: usize = LOG_N_LANES as usize + 1; + assert!(log_size >= VECWISE_FFT_BITS); + + assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); + + for index_h in 0..(1 << (log_size - fft_layers)) { + let mut src = src; + for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() { + match fft_layers - layer { + 1 => { + fft1_loop(src, dst, &twiddle_dbl[(layer - 1)..], layer, index_h); + } + 2 => { + fft2_loop(src, dst, &twiddle_dbl[(layer - 1)..], layer, index_h); + } + _ => { + fft3_loop( + src, + dst, + &twiddle_dbl[(layer - 1)..], + fft_layers - layer - 3, + layer, + index_h, + ); + } + } + src = dst; + } + fft_vecwise_loop( + src, + dst, + twiddle_dbl, + fft_layers - VECWISE_FFT_BITS, + index_h, + ); + } +} + +/// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of +/// the index). +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the fft. +/// - `log_size`: The log of the number of M31 elements in the array. +/// - `fft_layers`: The number of fft layers to apply, out of `log_size - LOG_N_LANES`. +/// +/// # Panics +/// +/// Panics if `log_size` is not at least 4. +/// +/// # Safety +/// +/// `src` and `dst` must have same alignment as [`PackedBaseField`]. +/// `fft_layers` must be at least 4.
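To see how `fft` above splits the work around the transpose for large inputs, a small worked example (assuming `LOG_N_LANES = 4` and `CACHED_FFT_LOG_SIZE = 16`, as in this crate):

```rust
// For log_n_elements = 20 (> CACHED_FFT_LOG_SIZE):
let log_n_vecs: usize = 20 - 4;   // 16
let pre = log_n_vecs.div_ceil(2); // 8 layers handled after the transpose
let post = log_n_vecs / 2;        // 8 layers handled before it
// `fft_lower_without_vecwise` applies the top `post` = 8 layers, then
// `transpose_vecs` reorders the vectors, and `fft_lower_with_vecwise`
// applies the remaining `pre` + 4 vecwise layers: 8 + (8 + 4) = 20.
assert_eq!(post + pre + 4, 20);
```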
+pub unsafe fn fft_lower_without_vecwise( + src: *const u32, + dst: *mut u32, + twiddle_dbl: &[&[u32]], + log_size: usize, + fft_layers: usize, +) { + assert!(log_size >= LOG_N_LANES as usize); + + for index_h in 0..(1 << (log_size - fft_layers - LOG_N_LANES as usize)) { + let mut src = src; + for layer in (0..fft_layers).step_by(3).rev() { + let fixed_layer = layer + LOG_N_LANES as usize; + match fft_layers - layer { + 1 => { + fft1_loop(src, dst, &twiddle_dbl[layer..], fixed_layer, index_h); + } + 2 => { + fft2_loop(src, dst, &twiddle_dbl[layer..], fixed_layer, index_h); + } + _ => { + fft3_loop( + src, + dst, + &twiddle_dbl[layer..], + fft_layers - layer - 3, + fixed_layer, + index_h, + ); + } + } + src = dst; + } + } +} + +/// Runs the last 5 fft layers across the entire array. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 5 fft layers. +/// - `loop_bits`: The number of bits this loop needs to run on. +/// - `index_h`: The higher part of the index, iterated by the caller. +/// +/// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. +unsafe fn fft_vecwise_loop( + src: *const u32, + dst: *mut u32, + twiddle_dbl: &[&[u32]], + loop_bits: usize, + index_h: usize, +) { + for index_l in 0..(1 << loop_bits) { + let index = (index_h << loop_bits) + index_l; + let mut val0 = PackedBaseField::load(src.add(index * 32)); + let mut val1 = PackedBaseField::load(src.add(index * 32 + 16)); + (val0, val1) = butterfly( + val0, + val1, + u32x16::splat(*twiddle_dbl[3].get_unchecked(index)), + ); + (val0, val1) = vecwise_butterflies( + val0, + val1, + array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)), + array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)), + array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)), + ); + val0.store(dst.add(index * 32)); + val1.store(dst.add(index * 32 + 16)); + } +} + +/// Runs 3 fft layers across the entire array. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 3 fft layers. +/// - `loop_bits`: The number of bits this loop needs to run on. +/// - `layer`: The layer number of the first fft layer to apply. The layers `layer`, `layer + 1`, +/// `layer + 2` are applied. +/// - `index_h`: The higher part of the index, iterated by the caller. +/// +/// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`].
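For intuition about what each `fft3` call issued by this loop does, here is a scalar model of the radix-8 step, mirroring the ground-truth loops in `test_fft3` at the bottom of this file. `butterfly_scalar(v, i, j, t)` is a hypothetical stand-in for the scalar butterfly `(v[i], v[j]) = (v[i] + t * v[j], v[i] - t * v[j])` over M31:

```rust
// Three butterfly layers on 8 scalar values, highest stride first.
fn fft3_scalar(v: &mut [u32; 8], tw0: [u32; 4], tw1: [u32; 2], tw2: [u32; 1]) {
    for i in 0..4 {
        butterfly_scalar(v, i, i ^ 4, tw2[0]); // layer 2: stride-4 pairs
    }
    for i in [0, 1, 4, 5] {
        butterfly_scalar(v, i, i ^ 2, tw1[i / 4]); // layer 1: stride-2 pairs
    }
    for i in [0, 2, 4, 6] {
        butterfly_scalar(v, i, i ^ 1, tw0[i / 2]); // layer 0: stride-1 pairs
    }
}
```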
+unsafe fn fft3_loop( + src: *const u32, + dst: *mut u32, + twiddle_dbl: &[&[u32]], + loop_bits: usize, + layer: usize, + index_h: usize, +) { + for index_l in 0..(1 << loop_bits) { + let index = (index_h << loop_bits) + index_l; + let offset = index << (layer + 3); + for l in (0..(1 << layer)).step_by(1 << LOG_N_LANES as usize) { + fft3( + src, + dst, + offset + l, + layer, + array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1)) + }), + array::from_fn(|i| { + *twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1)) + }), + array::from_fn(|i| { + *twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1)) + }), + ); + } + } +} + +/// Runs 2 fft layers across the entire array. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 2 fft layers. +/// - `layer`: The layer number of the first fft layer to apply. The layers `layer`, `layer + 1` are +/// applied. +/// - `index`: The index, iterated by the caller. +/// +/// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. +unsafe fn fft2_loop( + src: *const u32, + dst: *mut u32, + twiddle_dbl: &[&[u32]], + layer: usize, + index: usize, +) { + let offset = index << (layer + 2); + for l in (0..(1 << layer)).step_by(1 << LOG_N_LANES as usize) { + fft2( + src, + dst, + offset + l, + layer, + array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1)) + }), + array::from_fn(|i| { + *twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1)) + }), + ); + } +} + +/// Runs 1 fft layer across the entire array. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for the fft layer. +/// - `layer`: The layer number of the fft layer to apply. +/// - `index`: The index, iterated by the caller. +/// +/// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. +unsafe fn fft1_loop( + src: *const u32, + dst: *mut u32, + twiddle_dbl: &[&[u32]], + layer: usize, + index: usize, +) { + let offset = index << (layer + 1); + for l in (0..(1 << layer)).step_by(1 << LOG_N_LANES as usize) { + fft1( + src, + dst, + offset + l, + layer, + array::from_fn(|i| { + *twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1)) + }), + ); + } +} + +/// Computes the butterfly operation for packed M31 elements. +/// +/// Returns `val0 + t val1, val0 - t val1`. `val0, val1` are packed M31 elements, 16 M31 elements +/// each. Each value is assumed to be in unreduced form, [0, P] including P. Returned values are in +/// unreduced form, [0, P] including P. `twiddle_dbl` holds 16 values, each is a *double* of a +/// twiddle factor, in unreduced form, [0, 2*P]. +pub fn butterfly( + val0: PackedBaseField, + val1: PackedBaseField, + twiddle_dbl: u32x16, +) -> (PackedBaseField, PackedBaseField) { + let prod = { + // TODO: Come up with a better approach than `cfg`ing on target_feature. + cfg_if::cfg_if!
{ + if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] { + super::_mul_twiddle_neon(val1, twiddle_dbl) + } else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] { + super::_mul_twiddle_wasm(val1, twiddle_dbl) + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] { + super::_mul_twiddle_avx512(val1, twiddle_dbl) + } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] { + super::_mul_twiddle_avx2(val1, twiddle_dbl) + } else { + super::_mul_twiddle_simd(val1, twiddle_dbl) + } + } + }; + + let r0 = val0 + prod; + let r1 = val0 - prod; + + (r0, r1) +} + +/// Runs fft on 2 vectors of 16 M31 elements. +/// +/// This amounts to 4 butterfly layers, each with 16 butterflies. +/// Each of the vectors represents natural-ordered polynomial coefficients. +/// Each value in a vector is in unreduced form: [0, P] including P. +/// Takes 3 twiddle arrays, one for each layer except the last, holding the double of the +/// corresponding twiddle. +/// The first layer (higher bit of the index) takes 2 twiddles. +/// The second layer takes 4 twiddles. +/// etc. +/// The last layer's twiddles (lower bit of the index) are computed from the previous layer's +/// twiddles. +pub fn vecwise_butterflies( + mut val0: PackedBaseField, + mut val1: PackedBaseField, + twiddle1_dbl: [u32; 8], + twiddle2_dbl: [u32; 4], + twiddle3_dbl: [u32; 2], +) -> (PackedBaseField, PackedBaseField) { + // TODO(spapini): Compute twiddle0 from twiddle1. + // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. + // The implementation is the exact reverse of vecwise_ibutterflies(). + // See the comments in its body for more info. + // let t = _mm512_set1_epi64(transmute(twiddle3_dbl)); + let t = simd_swizzle!( + u32x2::from(twiddle3_dbl), + [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + ); + (val0, val1) = val0.interleave(val1); + (val0, val1) = butterfly(val0, val1, t); + + // let t = _mm512_broadcast_i32x4(transmute(twiddle2_dbl)); + let t = simd_swizzle!( + u32x4::from(twiddle2_dbl), + [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] + ); + (val0, val1) = val0.interleave(val1); + (val0, val1) = butterfly(val0, val1, t); + + let (t0, t1) = compute_first_twiddles(u32x8::from(twiddle1_dbl)); + (val0, val1) = val0.interleave(val1); + (val0, val1) = butterfly(val0, val1, t1); + + (val0, val1) = val0.interleave(val1); + (val0, val1) = butterfly(val0, val1, t0); + + val0.interleave(val1) +} + +/// Returns the line twiddles (x points) for an fft on a coset. +pub fn get_twiddle_dbls(mut coset: Coset) -> Vec<Vec<u32>> { + let mut res = vec![]; + for _ in 0..coset.log_size() { + res.push( + coset + .iter() + .take(coset.size() / 2) + .map(|p| p.x.0 * 2) + .collect_vec(), + ); + bit_reverse(res.last_mut().unwrap()); + coset = coset.double(); + } + + res +} + +/// Applies 3 butterfly layers on 8 vectors of 16 M31 elements. +/// +/// Vectorized over the 16 elements of the vectors. +/// Used for radix-8 fft. +/// Each butterfly layer has 4 AVX butterflies. +/// Total of 12 AVX butterflies. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `offset`: The offset of the first value in the array. +/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values +/// that need to be transformed. For layer i this is i - 4. +/// - `twiddles_dbl0/1/2`: The double of the twiddles for the 3 layers of butterflies. Each layer +/// has 4/2/1 twiddles.
+/// +/// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. +pub unsafe fn fft3( + src: *const u32, + dst: *mut u32, + offset: usize, + log_step: usize, + twiddles_dbl0: [u32; 4], + twiddles_dbl1: [u32; 2], + twiddles_dbl2: [u32; 1], +) { + // Load the 8 AVX vectors from the array. + let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); + let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); + let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step))); + let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step))); + let mut val4 = PackedBaseField::load(src.add(offset + (4 << log_step))); + let mut val5 = PackedBaseField::load(src.add(offset + (5 << log_step))); + let mut val6 = PackedBaseField::load(src.add(offset + (6 << log_step))); + let mut val7 = PackedBaseField::load(src.add(offset + (7 << log_step))); + + // Apply the third layer of butterflies. + (val0, val4) = butterfly(val0, val4, u32x16::splat(twiddles_dbl2[0])); + (val1, val5) = butterfly(val1, val5, u32x16::splat(twiddles_dbl2[0])); + (val2, val6) = butterfly(val2, val6, u32x16::splat(twiddles_dbl2[0])); + (val3, val7) = butterfly(val3, val7, u32x16::splat(twiddles_dbl2[0])); + + // Apply the second layer of butterflies. + (val0, val2) = butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0])); + (val1, val3) = butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0])); + (val4, val6) = butterfly(val4, val6, u32x16::splat(twiddles_dbl1[1])); + (val5, val7) = butterfly(val5, val7, u32x16::splat(twiddles_dbl1[1])); + + // Apply the first layer of butterflies. + (val0, val1) = butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); + (val2, val3) = butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1])); + (val4, val5) = butterfly(val4, val5, u32x16::splat(twiddles_dbl0[2])); + (val6, val7) = butterfly(val6, val7, u32x16::splat(twiddles_dbl0[3])); + + // Store the 8 AVX vectors back to the array. + val0.store(dst.add(offset + (0 << log_step))); + val1.store(dst.add(offset + (1 << log_step))); + val2.store(dst.add(offset + (2 << log_step))); + val3.store(dst.add(offset + (3 << log_step))); + val4.store(dst.add(offset + (4 << log_step))); + val5.store(dst.add(offset + (5 << log_step))); + val6.store(dst.add(offset + (6 << log_step))); + val7.store(dst.add(offset + (7 << log_step))); +} + +/// Applies 2 butterfly layers on 4 vectors of 16 M31 elements. +/// +/// Vectorized over the 16 elements of the vectors. +/// Used for radix-4 fft. +/// Each butterfly layer has 2 AVX butterflies. +/// Total of 4 AVX butterflies. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `offset`: The offset of the first value in the array. +/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values +/// that need to be transformed. For layer i this is i - 4. +/// - `twiddles_dbl0/1`: The double of the twiddles for the 2 layers of butterflies. Each layer has +/// 2/1 twiddles. +/// +/// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. +pub unsafe fn fft2( + src: *const u32, + dst: *mut u32, + offset: usize, + log_step: usize, + twiddles_dbl0: [u32; 2], + twiddles_dbl1: [u32; 1], +) { + // Load the 4 AVX vectors from the array.
+ let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); + let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); + let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step))); + let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step))); + + // Apply the second layer of butterflies. + (val0, val2) = butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0])); + (val1, val3) = butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0])); + + // Apply the first layer of butterflies. + (val0, val1) = butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); + (val2, val3) = butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1])); + + // Store the 4 AVX vectors back to the array. + val0.store(dst.add(offset + (0 << log_step))); + val1.store(dst.add(offset + (1 << log_step))); + val2.store(dst.add(offset + (2 << log_step))); + val3.store(dst.add(offset + (3 << log_step))); +} + +/// Applies 1 butterfly layer on 2 vectors of 16 M31 elements. +/// +/// Vectorized over the 16 elements of the vectors. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `offset`: The offset of the first value in the array. +/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values +/// that need to be transformed. For layer i this is i - 4. +/// - `twiddles_dbl0`: The double of the twiddles for the butterfly layer. +/// +/// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. +pub unsafe fn fft1( + src: *const u32, + dst: *mut u32, + offset: usize, + log_step: usize, + twiddles_dbl0: [u32; 1], +) { + // Load the 2 AVX vectors from the array. + let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); + let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); + + (val0, val1) = butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); + + // Store the 2 AVX vectors back to the array.
+ val0.store(dst.add(offset + (0 << log_step))); + val1.store(dst.add(offset + (1 << log_step))); +} + +#[cfg(test)] +mod tests { + use std::mem::transmute; + use std::simd::u32x16; + + use itertools::Itertools; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::{ + butterfly, fft, fft3, fft_lower_with_vecwise, get_twiddle_dbls, vecwise_butterflies, + }; + use crate::core::backend::cpu::CPUCirclePoly; + use crate::core::backend::simd::column::BaseFieldVec; + use crate::core::backend::simd::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE}; + use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; + use crate::core::backend::Column; + use crate::core::fft::butterfly as ground_truth_butterfly; + use crate::core::fields::m31::BaseField; + use crate::core::poly::circle::{CanonicCoset, CircleDomain}; + + #[test] + fn test_butterfly() { + let mut rng = SmallRng::seed_from_u64(0); + let mut v0: [BaseField; N_LANES] = rng.gen(); + let mut v1: [BaseField; N_LANES] = rng.gen(); + let twiddle: [BaseField; N_LANES] = rng.gen(); + let twiddle_dbl = twiddle.map(|v| v.0 * 2); + + let (r0, r1) = butterfly(v0.into(), v1.into(), twiddle_dbl.into()); + + let r0 = r0.to_array(); + let r1 = r1.to_array(); + for i in 0..N_LANES { + ground_truth_butterfly(&mut v0[i], &mut v1[i], twiddle[i]); + assert_eq!((v0[i], v1[i]), (r0[i], r1[i]), "mismatch at i={i}"); + } + } + + #[test] + fn test_fft3() { + let mut rng = SmallRng::seed_from_u64(0); + let values = rng.gen::<[BaseField; 8]>().map(PackedBaseField::broadcast); + let twiddles0: [BaseField; 4] = rng.gen(); + let twiddles1: [BaseField; 2] = rng.gen(); + let twiddles2: [BaseField; 1] = rng.gen(); + let twiddles0_dbl = twiddles0.map(|v| v.0 * 2); + let twiddles1_dbl = twiddles1.map(|v| v.0 * 2); + let twiddles2_dbl = twiddles2.map(|v| v.0 * 2); + + let mut res = values; + unsafe { + fft3( + transmute(res.as_ptr()), + transmute(res.as_mut_ptr()), + 0, + LOG_N_LANES as usize, + twiddles0_dbl, + twiddles1_dbl, + twiddles2_dbl, + ) + }; + + let mut expected = values.map(|v| v.to_array()[0]); + for i in 0..8 { + let j = i ^ 4; + if i > j { + continue; + } + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_butterfly(&mut v0, &mut v1, twiddles2[0]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + let j = i ^ 2; + if i > j { + continue; + } + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_butterfly(&mut v0, &mut v1, twiddles1[i / 4]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + let j = i ^ 1; + if i > j { + continue; + } + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_butterfly(&mut v0, &mut v1, twiddles0[i / 2]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + assert_eq!( + res[i].to_array(), + [expected[i]; N_LANES], + "mismatch at i={i}" + ); + } + } + + #[test] + fn test_vecwise_butterflies() { + let domain = CanonicCoset::new(5).circle_domain(); + let twiddle_dbls = get_twiddle_dbls(domain.half_coset); + assert_eq!(twiddle_dbls.len(), 4); + let mut rng = SmallRng::seed_from_u64(0); + let values: [[BaseField; 16]; 2] = rng.gen(); + + let res = { + let (val0, val1) = butterfly( + values[0].into(), + values[1].into(), + u32x16::splat(twiddle_dbls[3][0]), + ); + let (val0, val1) = vecwise_butterflies( + val0, + val1, + twiddle_dbls[0].clone().try_into().unwrap(), + twiddle_dbls[1].clone().try_into().unwrap(), + twiddle_dbls[2].clone().try_into().unwrap(), + ); + [val0.to_array(), val1.to_array()].concat() + }; + + 
assert_eq!(res, ground_truth_fft(domain, values.flatten())); + } + + #[test] + fn test_fft_lower() { + for log_size in 5..12 { + let domain = CanonicCoset::new(log_size).circle_domain(); + let mut rng = SmallRng::seed_from_u64(0); + let values = (0..domain.size()).map(|_| rng.gen()).collect_vec(); + let twiddle_dbls = get_twiddle_dbls(domain.half_coset); + + let mut res = values.iter().copied().collect::<BaseFieldVec>(); + unsafe { + fft_lower_with_vecwise( + transmute(res.data.as_ptr()), + transmute(res.data.as_mut_ptr()), + &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), + log_size as usize, + log_size as usize, + ) + } + + assert_eq!(res.to_cpu(), ground_truth_fft(domain, &values)); + } + } + + #[test] + fn test_fft_full() { + for log_size in CACHED_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 3 { + let domain = CanonicCoset::new(log_size as u32).circle_domain(); + let mut rng = SmallRng::seed_from_u64(0); + let values = (0..domain.size()).map(|_| rng.gen()).collect_vec(); + let twiddle_dbls = get_twiddle_dbls(domain.half_coset); + + let mut res = values.iter().copied().collect::<BaseFieldVec>(); + unsafe { + transpose_vecs(transmute(res.data.as_mut_ptr()), log_size - 4); + fft( + transmute(res.data.as_ptr()), + transmute(res.data.as_mut_ptr()), + &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), + log_size, + ); + } + + assert_eq!(res.to_cpu(), ground_truth_fft(domain, &values)); + } + } + + fn ground_truth_fft(domain: CircleDomain, values: &[BaseField]) -> Vec<BaseField> { + let poly = CPUCirclePoly::new(values.to_vec()); + poly.evaluate(domain).values + } +} diff --git a/crates/prover/src/core/backend/simd/mod.rs b/crates/prover/src/core/backend/simd/mod.rs index 0ec1810af..e89989200 100644 --- a/crates/prover/src/core/backend/simd/mod.rs +++ b/crates/prover/src/core/backend/simd/mod.rs @@ -12,6 +12,7 @@ pub mod bit_reverse; pub mod blake2s; pub mod cm31; pub mod column; +pub mod fft; pub mod m31; pub mod qm31; mod utils;
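For completeness, a condensed usage sketch of the forward transform, adapted from `test_fft_full` above (same caveats: crate-internal APIs, and a `values: Vec<BaseField>` of length `1 << log_size`):

```rust
// Coefficients -> evaluations over a circle domain, for log_size > CACHED_FFT_LOG_SIZE.
let domain = CanonicCoset::new(log_size as u32).circle_domain();
let twiddle_dbls = get_twiddle_dbls(domain.half_coset);
let mut res = values.iter().copied().collect::<BaseFieldVec>();
unsafe {
    // The forward direction transposes *before* the fft, mirroring how the
    // inverse direction transposes after its vecwise half.
    transpose_vecs(res.data.as_mut_ptr() as *mut u32, log_size - 4);
    fft(
        res.data.as_ptr() as *const u32,
        res.data.as_mut_ptr() as *mut u32,
        &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(),
        log_size,
    );
}
```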