diff --git a/crates/prover/src/core/backend/simd/fft/ifft.rs b/crates/prover/src/core/backend/simd/fft/ifft.rs index 526d2d9d3..1e68bea86 100644 --- a/crates/prover/src/core/backend/simd/fft/ifft.rs +++ b/crates/prover/src/core/backend/simd/fft/ifft.rs @@ -2,6 +2,8 @@ use std::simd::{simd_swizzle, u32x16, u32x2, u32x4}; +use itertools::Itertools; + use super::{ compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, }; @@ -10,24 +12,24 @@ use crate::core::circle::Coset; use crate::core::fields::FieldExpOps; use crate::core::utils::bit_reverse; -const VECS_LOG_SIZE: usize = LOG_N_LANES as usize; - /// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values. /// -/// # Safety -/// This function is unsafe because it takes a raw pointer to i32 values. -/// `values` must be aligned to 64 bytes. -/// /// # Arguments -/// * `values`: A mutable pointer to the values on which the ICFFT is to be performed. -/// * `twiddle_dbl`: A reference to the doubles of the twiddle factors. -/// * `log_n_elements`: The log of the number of elements in the `values` array. +/// +/// - `values`: A mutable pointer to the values on which the ICFFT is to be performed. +/// - `twiddle_dbl`: A reference to the doubles of the twiddle factors. +/// - `log_n_elements`: The log of the number of elements in the `values` array. /// /// # Panics +/// /// This function will panic if `log_n_elements` is less than `MIN_FFT_LOG_SIZE`. -pub unsafe fn ifft(values: *mut i32, twiddle_dbl: &[&[i32]], log_n_elements: usize) { +/// +/// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. 
+pub unsafe fn ifft(values: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) { assert!(log_n_elements >= MIN_FFT_LOG_SIZE); - let log_n_vecs = log_n_elements - VECS_LOG_SIZE; + let log_n_vecs = log_n_elements - LOG_N_LANES as usize; if log_n_elements <= CACHED_FFT_LOG_SIZE { ifft_lower_with_vecwise(values, twiddle_dbl, log_n_elements, log_n_elements); return; } @@ -37,9 +39,9 @@ pub unsafe fn ifft(values: *mut i32, twiddle_dbl: &[&[i32]], log_n_elements: usi let fft_layers_post_transpose = log_n_vecs / 2; ifft_lower_with_vecwise( values, - &twiddle_dbl[..(3 + fft_layers_pre_transpose)], + &twiddle_dbl[..3 + fft_layers_pre_transpose], log_n_elements, - fft_layers_pre_transpose + VECS_LOG_SIZE, + fft_layers_pre_transpose + LOG_N_LANES as usize, ); transpose_vecs(values, log_n_vecs); ifft_lower_without_vecwise( @@ -51,28 +53,35 @@ pub unsafe fn ifft(values: *mut i32, twiddle_dbl: &[&[i32]], log_n_elements: usi } /// Computes partial ifft on `2^log_size` M31 elements. -/// Parameters: -/// values - Pointer to the entire value array, aligned to 64 bytes. -/// twiddle_dbl - The doubles of the twiddle factors for each layer of the the ifft. -/// layer i holds 2^(log_size - 1 - i) twiddles. -/// log_size - The log of the number of number of M31 elements in the array. -/// fft_layers - The number of ifft layers to apply, out of log_size. +/// +/// # Arguments +/// +/// - `values`: Pointer to the entire value array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the ifft. Layer i +/// holds `2^(log_size - 1 - i)` twiddles. +/// - `log_size`: The log of the number of M31 elements in the array. +/// - `fft_layers`: The number of ifft layers to apply, out of log_size. +/// +/// # Panics +/// +/// Panics if `log_size` is not at least 5. +/// /// # Safety -/// `values` must be aligned to 64 bytes. -/// `log_size` must be at least 5. +/// +/// `values` must have the same alignment as [`PackedBaseField`]. 
/// `fft_layers` must be at least 5. pub unsafe fn ifft_lower_with_vecwise( - values: *mut i32, - twiddle_dbl: &[&[i32]], + values: *mut u32, + twiddle_dbl: &[&[u32]], log_size: usize, fft_layers: usize, ) { - const VECWISE_FFT_BITS: usize = VECS_LOG_SIZE + 1; + const VECWISE_FFT_BITS: usize = LOG_N_LANES as usize + 1; assert!(log_size >= VECWISE_FFT_BITS); assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); - for index_h in 0..(1 << (log_size - fft_layers)) { + for index_h in 0..1 << (log_size - fft_layers) { ifft_vecwise_loop(values, twiddle_dbl, fft_layers - VECWISE_FFT_BITS, index_h); for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3) { match fft_layers - layer { @@ -96,29 +105,35 @@ pub unsafe fn ifft_lower_with_vecwise( } } -/// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits -/// of the index). -/// Parameters: +/// Computes partial ifft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of +/// the index). +/// +/// # Arguments +/// /// values - Pointer to the entire value array, aligned to 64 bytes. /// twiddle_dbl - The doubles of the twiddle factors for each layer of the the ifft. /// log_size - The log of the number of number of M31 elements in the array. /// fft_layers - The number of ifft layers to apply, out of log_size - VEC_LOG_SIZE. /// +/// # Panics +/// +/// Panics if `log_size` is not at least 4. +/// /// # Safety -/// `values` must be aligned to 64 bytes. -/// `log_size` must be at least 4. +/// +/// `values` must have the same alignment as [`PackedBaseField`]. /// `fft_layers` must be at least 4. 
pub unsafe fn ifft_lower_without_vecwise( - values: *mut i32, - twiddle_dbl: &[&[i32]], + values: *mut u32, + twiddle_dbl: &[&[u32]], log_size: usize, fft_layers: usize, ) { - assert!(log_size >= VECS_LOG_SIZE); + assert!(log_size >= LOG_N_LANES as usize); - for index_h in 0..(1 << (log_size - fft_layers - VECS_LOG_SIZE)) { + for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) { for layer in (0..fft_layers).step_by(3) { - let fixed_layer = layer + VECS_LOG_SIZE; + let fixed_layer = layer + LOG_N_LANES as usize; match fft_layers - layer { 1 => { ifft1_loop(values, &twiddle_dbl[layer..], fixed_layer, index_h); @@ -141,23 +156,27 @@ pub unsafe fn ifft_lower_without_vecwise( } /// Runs the first 5 ifft layers across the entire array. -/// Parameters: +/// +/// # Arguments +/// /// values - Pointer to the entire value array, aligned to 64 bytes. /// twiddle_dbl - The doubles of the twiddle factors for each of the 5 ifft layers. /// high_bits - The number of bits this loops needs to run on. /// index_h - The higher part of the index, iterated by the caller. +/// /// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. 
pub unsafe fn ifft_vecwise_loop( - values: *mut i32, - twiddle_dbl: &[&[i32]], + values: *mut u32, + twiddle_dbl: &[&[u32]], loop_bits: usize, index_h: usize, ) { - for index_l in 0..(1 << loop_bits) { + for index_l in 0..1 << loop_bits { let index = (index_h << loop_bits) + index_l; - let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const() as *const u32); - let mut val1 = - PackedBaseField::load(values.add(index * 32 + 16).cast_const() as *const u32); + let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const()); + let mut val1 = PackedBaseField::load(values.add(index * 32 + 16).cast_const()); (val0, val1) = vecwise_ibutterflies( val0, val1, @@ -165,36 +184,41 @@ pub unsafe fn ifft_vecwise_loop( std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)), std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)), ); - (val0, val1) = avx_ibutterfly( + (val0, val1) = ibutterfly( val0, val1, - u32x16::splat(*twiddle_dbl[3].get_unchecked(index) as u32), + u32x16::splat(*twiddle_dbl[3].get_unchecked(index)), ); - val0.store(values.add(index * 32) as *mut u32); - val1.store(values.add(index * 32 + 16) as *mut u32); + val0.store(values.add(index * 32)); + val1.store(values.add(index * 32 + 16)); } } /// Runs 3 ifft layers across the entire array. -/// Parameters: +/// +/// # Arguments +/// /// values - Pointer to the entire value array, aligned to 64 bytes. /// twiddle_dbl - The doubles of the twiddle factors for each of the 3 ifft layers. /// loop_bits - The number of bits this loops needs to run on. /// layer - The layer number of the first ifft layer to apply. /// The layers `layer`, `layer + 1`, `layer + 2` are applied. /// index_h - The higher part of the index, iterated by the caller. +/// /// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. 
pub unsafe fn ifft3_loop( - values: *mut i32, - twiddle_dbl: &[&[i32]], + values: *mut u32, + twiddle_dbl: &[&[u32]], loop_bits: usize, layer: usize, index_h: usize, ) { - for index_l in 0..(1 << loop_bits) { + for index_l in 0..1 << loop_bits { let index = (index_h << loop_bits) + index_l; let offset = index << (layer + 3); - for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) { + for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { ifft3( values, offset + l, @@ -214,17 +238,22 @@ pub unsafe fn ifft3_loop( } /// Runs 2 ifft layers across the entire array. -/// Parameters: +/// +/// # Arguments +/// /// values - Pointer to the entire value array, aligned to 64 bytes. /// twiddle_dbl - The doubles of the twiddle factors for each of the 2 ifft layers. /// loop_bits - The number of bits this loops needs to run on. /// layer - The layer number of the first ifft layer to apply. /// The layers `layer`, `layer + 1` are applied. /// index - The index, iterated by the caller. +/// /// # Safety -unsafe fn ifft2_loop(values: *mut i32, twiddle_dbl: &[&[i32]], layer: usize, index: usize) { +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +unsafe fn ifft2_loop(values: *mut u32, twiddle_dbl: &[&[u32]], layer: usize, index: usize) { let offset = index << (layer + 2); - for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) { + for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { ifft2( values, offset + l, @@ -240,15 +269,20 @@ unsafe fn ifft2_loop(values: *mut i32, twiddle_dbl: &[&[i32]], layer: usize, ind } /// Runs 1 ifft layer across the entire array. -/// Parameters: +/// +/// # Arguments +/// /// values - Pointer to the entire value array, aligned to 64 bytes. /// twiddle_dbl - The doubles of the twiddle factors for the ifft layer. /// layer - The layer number of the ifft layer to apply. /// index_h - The higher part of the index, iterated by the caller. 
+/// /// # Safety -unsafe fn ifft1_loop(values: *mut i32, twiddle_dbl: &[&[i32]], layer: usize, index: usize) { +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +unsafe fn ifft1_loop(values: *mut u32, twiddle_dbl: &[&[u32]], layer: usize, index: usize) { let offset = index << (layer + 1); - for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) { + for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { ifft1( values, offset + l, @@ -261,13 +295,11 @@ unsafe fn ifft1_loop(values: *mut i32, twiddle_dbl: &[&[i32]], layer: usize, ind } /// Computes the ibutterfly operation for packed M31 elements. -/// val0 + val1, t (val0 - val1). -/// val0, val1 are packed M31 elements. 16 M31 words at each. -/// Each value is assumed to be in unreduced form, [0, P] including P. -/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in unreduced form. -/// # Safety -/// This function is safe. -pub unsafe fn avx_ibutterfly( +/// +/// Returns `val0 + val1, t (val0 - val1)`. `val0, val1` are packed M31 elements. 16 M31 words at +/// each. Each value is assumed to be in unreduced form, [0, P] including P. `twiddle_dbl` holds 16 +/// values, each is a *double* of a twiddle factor, in unreduced form. +pub fn ibutterfly( val0: PackedBaseField, val1: PackedBaseField, twiddle_dbl: u32x16, @@ -279,6 +311,7 @@ pub unsafe fn avx_ibutterfly( } /// Runs ifft on 2 vectors of 16 M31 elements. +/// /// This amounts to 4 butterfly layers, each with 16 butterflies. /// Each of the vectors represents a bit reversed evaluation. /// Each value in a vectors is in unreduced form: [0, P] including P. @@ -288,18 +321,16 @@ pub unsafe fn avx_ibutterfly( /// twiddles. The second layer takes 8 twiddles. /// The third layer takes 4 twiddles. /// The fourth layer takes 2 twiddles. -/// # Safety -/// This function is safe. 
-pub unsafe fn vecwise_ibutterflies( +pub fn vecwise_ibutterflies( mut val0: PackedBaseField, mut val1: PackedBaseField, - twiddle1_dbl: [i32; 8], - twiddle2_dbl: [i32; 4], - twiddle3_dbl: [i32; 2], + twiddle1_dbl: [u32; 8], + twiddle2_dbl: [u32; 4], + twiddle3_dbl: [u32; 2], ) -> (PackedBaseField, PackedBaseField) { // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. - // Each avx_ibutterfly take 2 512-bit registers, and does 16 butterflies element by element. + // Each `ibutterfly` take 2 512-bit registers, and does 16 butterflies element by element. // We need to permute the 512-bit registers to get the right order for the butterflies. // Denote the index of the 16 M31 elements in register i as i:abcd. // At each layer we apply the following permutation to the index: @@ -316,46 +347,46 @@ pub unsafe fn vecwise_ibutterflies( // ifft on a // i:abcd - let (t0, t1) = compute_first_twiddles(twiddle1_dbl); + let (t0, t1) = compute_first_twiddles(twiddle1_dbl.into()); // Apply the permutation, resulting in indexing d:iabc. (val0, val1) = val0.deinterleave(val1); - (val0, val1) = avx_ibutterfly(val0, val1, t0); + (val0, val1) = ibutterfly(val0, val1, t0); // Apply the permutation, resulting in indexing c:diab. (val0, val1) = val0.deinterleave(val1); - (val0, val1) = avx_ibutterfly(val0, val1, t1); + (val0, val1) = ibutterfly(val0, val1, t1); let t = simd_swizzle!( - u32x4::from_array(unsafe { std::mem::transmute(twiddle2_dbl) }), + u32x4::from(twiddle2_dbl), [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] ); // Apply the permutation, resulting in indexing b:cdia. (val0, val1) = val0.deinterleave(val1); - (val0, val1) = avx_ibutterfly(val0, val1, t); + (val0, val1) = ibutterfly(val0, val1, t); let t = simd_swizzle!( - u32x2::from_array(unsafe { std::mem::transmute(twiddle3_dbl) }), + u32x2::from(twiddle3_dbl), [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] ); // Apply the permutation, resulting in indexing a:bcid. 
(val0, val1) = val0.deinterleave(val1); - (val0, val1) = avx_ibutterfly(val0, val1, t); + (val0, val1) = ibutterfly(val0, val1, t); // Apply the permutation, resulting in indexing i:abcd. val0.deinterleave(val1) } /// Returns the line twiddles (x points) for an ifft on a coset. -pub fn get_itwiddle_dbls(mut coset: Coset) -> Vec> { +pub fn get_itwiddle_dbls(mut coset: Coset) -> Vec> { let mut res = vec![]; for _ in 0..coset.log_size() { res.push( coset .iter() .take(coset.size() / 2) - .map(|p| (p.x.inverse().0 * 2) as i32) - .collect::>(), + .map(|p| p.x.inverse().0 * 2) + .collect_vec(), ); bit_reverse(res.last_mut().unwrap()); coset = coset.double(); @@ -365,254 +396,242 @@ pub fn get_itwiddle_dbls(mut coset: Coset) -> Vec> { } /// Applies 3 ibutterfly layers on 8 vectors of 16 M31 elements. +/// /// Vectorized over the 16 elements of the vectors. /// Used for radix-8 ifft. -/// Each butterfly layer, has 3 AVX butterflies. -/// Total of 12 AVX butterflies. -/// Parameters: -/// values - Pointer to the entire value array. -/// offset - The offset of the first value in the array. -/// log_step - The log of the distance in the array, in M31 elements, between each pair of -/// values that need to be transformed. For layer i this is i - 4. -/// twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of ibutterflies. -/// Each layer has 4/2/1 twiddles. +/// Each butterfly layer, has 3 SIMD butterflies. +/// Total of 12 SIMD butterflies. +/// +/// # Arguments +/// +/// - `values`: Pointer to the entire value array. +/// - `offset`: The offset of the first value in the array. +/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values +/// that need to be transformed. For layer i this is i - 4. +/// - `twiddles_dbl0/1/2`: The double of the twiddles for the 3 layers of ibutterflies. Each layer +/// has 4/2/1 twiddles. 
+/// /// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. pub unsafe fn ifft3( - values: *mut i32, + values: *mut u32, offset: usize, log_step: usize, - twiddles_dbl0: [i32; 4], - twiddles_dbl1: [i32; 2], - twiddles_dbl2: [i32; 1], + twiddles_dbl0: [u32; 4], + twiddles_dbl1: [u32; 2], + twiddles_dbl2: [u32; 1], ) { - // Load the 8 AVX vectors from the array. - let mut val0 = - PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const() as *const u32); - let mut val1 = - PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const() as *const u32); - let mut val2 = - PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const() as *const u32); - let mut val3 = - PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const() as *const u32); - let mut val4 = - PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const() as *const u32); - let mut val5 = - PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const() as *const u32); - let mut val6 = - PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const() as *const u32); - let mut val7 = - PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const() as *const u32); + // Load the 8 SIMD vectors from the array. 
+ let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); + let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); + let mut val4 = PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const()); + let mut val5 = PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const()); + let mut val6 = PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const()); + let mut val7 = PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const()); // Apply the first layer of ibutterflies. - (val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); - (val2, val3) = avx_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1] as u32)); - (val4, val5) = avx_ibutterfly(val4, val5, u32x16::splat(twiddles_dbl0[2] as u32)); - (val6, val7) = avx_ibutterfly(val6, val7, u32x16::splat(twiddles_dbl0[3] as u32)); + (val0, val1) = ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); + (val2, val3) = ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1])); + (val4, val5) = ibutterfly(val4, val5, u32x16::splat(twiddles_dbl0[2])); + (val6, val7) = ibutterfly(val6, val7, u32x16::splat(twiddles_dbl0[3])); // Apply the second layer of ibutterflies. 
- (val0, val2) = avx_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0] as u32)); - (val1, val3) = avx_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0] as u32)); - (val4, val6) = avx_ibutterfly(val4, val6, u32x16::splat(twiddles_dbl1[1] as u32)); - (val5, val7) = avx_ibutterfly(val5, val7, u32x16::splat(twiddles_dbl1[1] as u32)); + (val0, val2) = ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0])); + (val1, val3) = ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0])); + (val4, val6) = ibutterfly(val4, val6, u32x16::splat(twiddles_dbl1[1])); + (val5, val7) = ibutterfly(val5, val7, u32x16::splat(twiddles_dbl1[1])); // Apply the third layer of ibutterflies. - (val0, val4) = avx_ibutterfly(val0, val4, u32x16::splat(twiddles_dbl2[0] as u32)); - (val1, val5) = avx_ibutterfly(val1, val5, u32x16::splat(twiddles_dbl2[0] as u32)); - (val2, val6) = avx_ibutterfly(val2, val6, u32x16::splat(twiddles_dbl2[0] as u32)); - (val3, val7) = avx_ibutterfly(val3, val7, u32x16::splat(twiddles_dbl2[0] as u32)); - - // Store the 8 AVX vectors back to the array. - val0.store(values.add(offset + (0 << log_step)) as *mut u32); - val1.store(values.add(offset + (1 << log_step)) as *mut u32); - val2.store(values.add(offset + (2 << log_step)) as *mut u32); - val3.store(values.add(offset + (3 << log_step)) as *mut u32); - val4.store(values.add(offset + (4 << log_step)) as *mut u32); - val5.store(values.add(offset + (5 << log_step)) as *mut u32); - val6.store(values.add(offset + (6 << log_step)) as *mut u32); - val7.store(values.add(offset + (7 << log_step)) as *mut u32); + (val0, val4) = ibutterfly(val0, val4, u32x16::splat(twiddles_dbl2[0])); + (val1, val5) = ibutterfly(val1, val5, u32x16::splat(twiddles_dbl2[0])); + (val2, val6) = ibutterfly(val2, val6, u32x16::splat(twiddles_dbl2[0])); + (val3, val7) = ibutterfly(val3, val7, u32x16::splat(twiddles_dbl2[0])); + + // Store the 8 SIMD vectors back to the array. 
+ val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); + val2.store(values.add(offset + (2 << log_step))); + val3.store(values.add(offset + (3 << log_step))); + val4.store(values.add(offset + (4 << log_step))); + val5.store(values.add(offset + (5 << log_step))); + val6.store(values.add(offset + (6 << log_step))); + val7.store(values.add(offset + (7 << log_step))); } /// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements. +/// /// Vectorized over the 16 elements of the vectors. /// Used for radix-4 ifft. -/// Each ibutterfly layer, has 2 AVX butterflies. -/// Total of 4 AVX butterflies. -/// Parameters: +/// Each ibutterfly layer, has 2 SIMD butterflies. +/// Total of 4 SIMD butterflies. +/// +/// # Arguments +/// /// values - Pointer to the entire value array. /// offset - The offset of the first value in the array. /// log_step - The log of the distance in the array, in M31 elements, between each pair of /// values that need to be transformed. For layer i this is i - 4. /// twiddles_dbl0/1 - The double of the twiddles for the 2 layers of ibutterflies. /// Each layer has 2/1 twiddles. +/// /// # Safety +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. pub unsafe fn ifft2( - values: *mut i32, + values: *mut u32, offset: usize, log_step: usize, - twiddles_dbl0: [i32; 2], - twiddles_dbl1: [i32; 1], + twiddles_dbl0: [u32; 2], + twiddles_dbl1: [u32; 1], ) { - // Load the 4 AVX vectors from the array. 
- let mut val0 = - PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const() as *const u32); - let mut val1 = - PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const() as *const u32); - let mut val2 = - PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const() as *const u32); - let mut val3 = - PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const() as *const u32); + // Load the 4 SIMD vectors from the array. + let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const()); + let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const()); // Apply the first layer of butterflies. - (val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); - (val2, val3) = avx_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1] as u32)); + (val0, val1) = ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); + (val2, val3) = ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1])); // Apply the second layer of butterflies. - (val0, val2) = avx_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0] as u32)); - (val1, val3) = avx_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0] as u32)); - - // Store the 4 AVX vectors back to the array. - val0.store(values.add(offset + (0 << log_step)) as *mut u32); - val1.store(values.add(offset + (1 << log_step)) as *mut u32); - val2.store(values.add(offset + (2 << log_step)) as *mut u32); - val3.store(values.add(offset + (3 << log_step)) as *mut u32); + (val0, val2) = ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0])); + (val1, val3) = ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0])); + + // Store the 4 SIMD vectors back to the array. 
+ val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); + val2.store(values.add(offset + (2 << log_step))); + val3.store(values.add(offset + (3 << log_step))); } /// Applies 1 ibutterfly layers on 2 vectors of 16 M31 elements. +/// /// Vectorized over the 16 elements of the vectors. -/// Parameters: +/// +/// # Arguments +/// /// values - Pointer to the entire value array. /// offset - The offset of the first value in the array. /// log_step - The log of the distance in the array, in M31 elements, between each pair of /// values that need to be transformed. For layer i this is i - 4. /// twiddles_dbl0 - The double of the twiddles for the ibutterfly layer. +/// /// # Safety -pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) { - // Load the 2 AVX vectors from the array. - let mut val0 = - PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const() as *const u32); - let mut val1 = - PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const() as *const u32); - - (val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); - - // Store the 2 AVX vectors back to the array. - val0.store(values.add(offset + (0 << log_step)) as *mut u32); - val1.store(values.add(offset + (1 << log_step)) as *mut u32); +/// +/// Behavior is undefined if `values` does not have the same alignment as [`PackedBaseField`]. +pub unsafe fn ifft1(values: *mut u32, offset: usize, log_step: usize, twiddles_dbl0: [u32; 1]) { + // Load the 2 SIMD vectors from the array. + let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const()); + let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const()); + + (val0, val1) = ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); + + // Store the 2 SIMD vectors back to the array. 
+ val0.store(values.add(offset + (0 << log_step))); + val1.store(values.add(offset + (1 << log_step))); } #[cfg(test)] mod tests { - use super::*; + use std::mem::transmute; + use std::simd::u32x16; + + use itertools::Itertools; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::{ + get_itwiddle_dbls, ibutterfly, ifft, ifft3, ifft_lower_with_vecwise, vecwise_ibutterflies, + }; use crate::core::backend::cpu::CPUCircleEvaluation; use crate::core::backend::simd::column::BaseFieldVec; - use crate::core::backend::simd::m31::PackedBaseField; + use crate::core::backend::simd::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE}; + use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; use crate::core::backend::Column; - use crate::core::fft::ibutterfly; + use crate::core::fft::ibutterfly as ground_truth_ibutterfly; use crate::core::fields::m31::BaseField; use crate::core::poly::circle::{CanonicCoset, CircleDomain}; #[test] fn test_ibutterfly() { - unsafe { - let val0 = PackedBaseField::from_simd_unchecked(u32x16::from_array([ - 2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - ])); - let val1 = PackedBaseField::from_simd_unchecked(u32x16::from_array([ - 3, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - ])); - let twiddle = u32x16::from_array([ - 1177558791, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - ]); - let twiddle_dbl = twiddle + twiddle; - let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl); - - let val0: [BaseField; 16] = std::mem::transmute(val0); - let val1: [BaseField; 16] = std::mem::transmute(val1); - let twiddle: [BaseField; 16] = std::mem::transmute(twiddle); - let r0: [BaseField; 16] = std::mem::transmute(r0); - let r1: [BaseField; 16] = std::mem::transmute(r1); - - for i in 0..16 { - let mut x = val0[i]; - let mut y = val1[i]; - let twiddle = twiddle[i]; - ibutterfly(&mut x, &mut y, twiddle); - assert_eq!(x, r0[i]); - assert_eq!(y, r1[i]); - } + let mut rng = 
SmallRng::seed_from_u64(0); + let mut v0: [BaseField; N_LANES] = rng.gen(); + let mut v1: [BaseField; N_LANES] = rng.gen(); + let twiddle: [BaseField; N_LANES] = rng.gen(); + let twiddle_dbl = twiddle.map(|v| v.0 * 2); + + let (r0, r1) = ibutterfly(v0.into(), v1.into(), twiddle_dbl.into()); + + let r0 = r0.to_array(); + let r1 = r1.to_array(); + for i in 0..N_LANES { + ground_truth_ibutterfly(&mut v0[i], &mut v1[i], twiddle[i]); + assert_eq!((v0[i], v1[i]), (r0[i], r1[i]), "mismatch at i={i}"); } } #[test] fn test_ifft3() { + let mut rng = SmallRng::seed_from_u64(0); + let values = rng.gen::<[BaseField; 8]>().map(PackedBaseField::broadcast); + let twiddles0: [BaseField; 4] = rng.gen(); + let twiddles1: [BaseField; 2] = rng.gen(); + let twiddles2: [BaseField; 1] = rng.gen(); + let twiddles0_dbl = twiddles0.map(|v| v.0 * 2); + let twiddles1_dbl = twiddles1.map(|v| v.0 * 2); + let twiddles2_dbl = twiddles2.map(|v| v.0 * 2); + + let mut res = values; unsafe { - let mut values: Vec = (0..8) - .map(|i| { - PackedBaseField::from_array(std::array::from_fn(|_| { - BaseField::from_u32_unchecked(i) - })) - }) - .collect(); - let twiddles0 = [32, 33, 34, 35]; - let twiddles1 = [36, 37]; - let twiddles2 = [38]; - let twiddles0_dbl = std::array::from_fn(|i| twiddles0[i] * 2); - let twiddles1_dbl = std::array::from_fn(|i| twiddles1[i] * 2); - let twiddles2_dbl = std::array::from_fn(|i| twiddles2[i] * 2); ifft3( - std::mem::transmute(values.as_mut_ptr()), + transmute(res.as_mut_ptr()), 0, - VECS_LOG_SIZE, + LOG_N_LANES as usize, twiddles0_dbl, twiddles1_dbl, twiddles2_dbl, - ); + ) + }; - let expected: [u32; 8] = std::array::from_fn(|i| i as u32); - let mut expected: [BaseField; 8] = std::mem::transmute(expected); - let twiddles0: [BaseField; 4] = std::mem::transmute(twiddles0); - let twiddles1: [BaseField; 2] = std::mem::transmute(twiddles1); - let twiddles2: [BaseField; 1] = std::mem::transmute(twiddles2); - for i in 0..8 { - let j = i ^ 1; - if i > j { - continue; - } - let 
(mut v0, mut v1) = (expected[i], expected[j]); - ibutterfly(&mut v0, &mut v1, twiddles0[i / 2]); - (expected[i], expected[j]) = (v0, v1); + let mut expected = values.map(|v| v.to_array()[0]); + for i in 0..8 { + let j = i ^ 1; + if i > j { + continue; } - for i in 0..8 { - let j = i ^ 2; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - ibutterfly(&mut v0, &mut v1, twiddles1[i / 4]); - (expected[i], expected[j]) = (v0, v1); - } - for i in 0..8 { - let j = i ^ 4; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - ibutterfly(&mut v0, &mut v1, twiddles2[0]); - (expected[i], expected[j]) = (v0, v1); + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_ibutterfly(&mut v0, &mut v1, twiddles0[i / 2]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + let j = i ^ 2; + if i > j { + continue; } - for i in 0..8 { - assert_eq!(values[i].to_array()[0], expected[i]); + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_ibutterfly(&mut v0, &mut v1, twiddles1[i / 4]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + let j = i ^ 4; + if i > j { + continue; } + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_ibutterfly(&mut v0, &mut v1, twiddles2[0]); + (expected[i], expected[j]) = (v0, v1); } - } - - fn ref_ifft(domain: CircleDomain, values: Vec) -> Vec { - let eval = CPUCircleEvaluation::new(domain, values); - let mut expected_coeffs = eval.interpolate().coeffs; - for x in expected_coeffs.iter_mut() { - *x *= BaseField::from_u32_unchecked(domain.size() as u32); + for i in 0..8 { + assert_eq!( + res[i].to_array(), + [expected[i]; N_LANES], + "mismatch at i={i}" + ); } - expected_coeffs } #[test] @@ -620,95 +639,73 @@ mod tests { let domain = CanonicCoset::new(5).circle_domain(); let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); assert_eq!(twiddle_dbls.len(), 4); - let values0: [i32; 16] = std::array::from_fn(|i| i as i32); - let 
values1: [i32; 16] = std::array::from_fn(|i| (i + 16) as i32); - let result: [BaseField; 32] = unsafe { + let mut rng = SmallRng::seed_from_u64(0); + let values: [[BaseField; 16]; 2] = rng.gen(); + + let res = { let (val0, val1) = vecwise_ibutterflies( - std::mem::transmute(values0), - std::mem::transmute(values1), + values[0].into(), + values[1].into(), twiddle_dbls[0].clone().try_into().unwrap(), twiddle_dbls[1].clone().try_into().unwrap(), twiddle_dbls[2].clone().try_into().unwrap(), ); - let (val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddle_dbls[3][0] as u32)); - std::mem::transmute([val0, val1]) + let (val0, val1) = ibutterfly(val0, val1, u32x16::splat(twiddle_dbls[3][0])); + [val0.to_array(), val1.to_array()].concat() }; - // ref. - let mut values = values0.to_vec(); - values.extend_from_slice(&values1); - let expected = ref_ifft(domain, values.into_iter().map(BaseField::from).collect()); - - // Compare. - for i in 0..32 { - assert_eq!(result[i], expected[i]); - } + assert_eq!(res, ground_truth_ifft(domain, values.flatten())); } #[test] fn test_ifft_lower_with_vecwise() { for log_size in 5..12 { let domain = CanonicCoset::new(log_size).circle_domain(); - let values = (0..domain.size()) - .map(|i| BaseField::from_u32_unchecked(i as u32)) - .collect::>(); - let expected_coeffs = ref_ifft(domain, values.clone()); - - // Compute. - let mut values = BaseFieldVec::from_iter(values); + let mut rng = SmallRng::seed_from_u64(0); + let values = (0..domain.size()).map(|_| rng.gen()).collect_vec(); let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); + let mut res = values.iter().copied().collect::(); unsafe { ifft_lower_with_vecwise( - std::mem::transmute(values.data.as_mut_ptr()), - &twiddle_dbls - .iter() - .map(|x| x.as_slice()) - .collect::>(), + transmute(res.data.as_mut_ptr()), + &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, log_size as usize, ); - - // Compare. 
- assert_eq!(values.to_cpu(), expected_coeffs); } + + assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values)); } } - fn run_ifft_full_test(log_size: u32) { - let domain = CanonicCoset::new(log_size).circle_domain(); - let values = (0..domain.size()) - .map(|i| BaseField::from_u32_unchecked(i as u32)) - .collect::>(); - let expected_coeffs = ref_ifft(domain, values.clone()); - - // Compute. - let mut values = BaseFieldVec::from_iter(values); - let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); + #[test] + fn test_ifft_full() { + for log_size in CACHED_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 3 { + let domain = CanonicCoset::new(log_size as u32).circle_domain(); + let mut rng = SmallRng::seed_from_u64(0); + let values = (0..domain.size()).map(|_| rng.gen()).collect_vec(); + let twiddle_dbls = get_itwiddle_dbls(domain.half_coset); - unsafe { - ifft( - std::mem::transmute(values.data.as_mut_ptr()), - &twiddle_dbls - .iter() - .map(|x| x.as_slice()) - .collect::>(), - log_size as usize, - ); - transpose_vecs( - std::mem::transmute(values.data.as_mut_ptr()), - (log_size - 4) as usize, - ); + let mut res = values.iter().copied().collect::(); + unsafe { + ifft( + transmute(res.data.as_mut_ptr()), + &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), + log_size, + ); + transpose_vecs(transmute(res.data.as_mut_ptr()), log_size - 4); + } - // Compare. 
- assert_eq!(values.to_cpu(), expected_coeffs); + assert_eq!(res.to_cpu(), ground_truth_ifft(domain, &values)); } } - #[test] - fn test_ifft_full() { - for i in (CACHED_FFT_LOG_SIZE + 1)..(CACHED_FFT_LOG_SIZE + 3) { - run_ifft_full_test(i as u32); - } + fn ground_truth_ifft(domain: CircleDomain, values: &[BaseField]) -> Vec { + let eval = CPUCircleEvaluation::new(domain, values.to_vec()); + let mut res = eval.interpolate().coeffs; + let denorm = BaseField::from(domain.size()); + res.iter_mut().for_each(|v| *v *= denorm); + res } } diff --git a/crates/prover/src/core/backend/simd/fft/mod.rs b/crates/prover/src/core/backend/simd/fft/mod.rs index 7f7cbb833..70bf5bcbb 100644 --- a/crates/prover/src/core/backend/simd/fft/mod.rs +++ b/crates/prover/src/core/backend/simd/fft/mod.rs @@ -1,4 +1,3 @@ -use std::mem::transmute; use std::simd::{simd_swizzle, u32x16, u32x8}; use super::m31::PackedBaseField; @@ -7,76 +6,54 @@ use crate::core::fields::m31::P; pub mod ifft; pub mod rfft; -/// An input to _mm512_permutex2var_epi32, and is used to interleave the even words of a -/// with the even words of b. -const _EVENS_INTERLEAVE_EVENS: u32x16 = unsafe { - core::mem::transmute([ - 0b00000, 0b10000, 0b00010, 0b10010, 0b00100, 0b10100, 0b00110, 0b10110, 0b01000, 0b11000, - 0b01010, 0b11010, 0b01100, 0b11100, 0b01110, 0b11110, - ]) -}; -/// An input to _mm512_permutex2var_epi32, and is used to interleave the odd words of a -/// with the odd words of b. -const _ODDS_INTERLEAVE_ODDS: u32x16 = unsafe { - core::mem::transmute([ - 0b00001, 0b10001, 0b00011, 0b10011, 0b00101, 0b10101, 0b00111, 0b10111, 0b01001, 0b11001, - 0b01011, 0b11011, 0b01101, 0b11101, 0b01111, 0b11111, - ]) -}; - pub const CACHED_FFT_LOG_SIZE: usize = 16; + pub const MIN_FFT_LOG_SIZE: usize = 5; // TODO(spapini): FFTs return a redundant representation, that can get the value P. need to reduce // it somewhere. -/// Transposes the AVX vectors in the given array. +/// Transposes the SIMD vectors in the given array. 
+/// /// Swaps the bit index abc <-> cba, where |a|=|c| and |b| = 0 or 1, according to the parity of /// `log_n_vecs`. /// When log_n_vecs is odd, transforms the index abc <-> cba, w /// +/// # Arguments +/// +/// - `values`: A mutable pointer to the values that are to be transposed. +/// - `log_n_vecs`: The log of the number of SIMD vectors in the `values` array. +/// /// # Safety -/// This function is unsafe because it takes a raw pointer to i32 values. -/// `values` must be aligned to 64 bytes. /// -/// # Arguments -/// * `values`: A mutable pointer to the values that are to be transposed. -/// * `log_n_vecs`: The log of the number of AVX vectors in the `values` array. -pub unsafe fn transpose_vecs(values: *mut i32, log_n_vecs: usize) { +/// Behavior is undefined if `values` does not have the same alignment as [`u32x16`]. +pub unsafe fn transpose_vecs(values: *mut u32, log_n_vecs: usize) { let half = log_n_vecs / 2; - for b in 0..(1 << (log_n_vecs & 1)) { - for a in 0..(1 << half) { - for c in 0..(1 << half) { + for b in 0..1 << (log_n_vecs & 1) { + for a in 0..1 << half { + for c in 0..1 << half { let i = (a << (log_n_vecs - half)) | (b << half) | c; let j = (c << (log_n_vecs - half)) | (b << half) | a; if i >= j { continue; } - let val0 = _mm512_load_epi32(values.add(i << 4).cast_const()); - let val1 = _mm512_load_epi32(values.add(j << 4).cast_const()); - _mm512_store_epi32(values.add(i << 4), val1); - _mm512_store_epi32(values.add(j << 4), val0); + let val0 = load(values.add(i << 4).cast_const()); + let val1 = load(values.add(j << 4).cast_const()); + store(values.add(i << 4), val1); + store(values.add(j << 4), val0); } } } } -unsafe fn _mm512_load_epi32(mem_addr: *const i32) -> u32x16 { - std::ptr::read(mem_addr as *const u32x16) -} - -unsafe fn _mm512_store_epi32(mem_addr: *mut i32, a: u32x16) { - std::ptr::write(mem_addr as *mut u32x16, a); -} - -/// Computes the twiddles for the first fft layer from the second, and loads both to AVX registers. 
+/// Computes the twiddles for the first fft layer from the second, and loads both to SIMD registers. +/// /// Returns the twiddles for the first layer and the twiddles for the second layer. -/// # Safety -pub unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (u32x16, u32x16) { +fn compute_first_twiddles(twiddle1_dbl: u32x8) -> (u32x16, u32x16) { // Start by loading the twiddles for the second layer (layer 1): let t1 = simd_swizzle!( - unsafe { transmute::<_, u32x8>(twiddle1_dbl) }, - unsafe { transmute::<_, u32x8>(twiddle1_dbl) }, + twiddle1_dbl, + twiddle1_dbl, [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7] ); @@ -110,6 +87,14 @@ pub unsafe fn compute_first_twiddles(twiddle1_dbl: [i32; 8]) -> (u32x16, u32x16) (t0, t1) } +unsafe fn load(mem_addr: *const u32) -> u32x16 { + std::ptr::read(mem_addr as *const u32x16) +} + +unsafe fn store(mem_addr: *mut u32, a: u32x16) { + std::ptr::write(mem_addr as *mut u32x16, a); +} + /// Computes `v * twiddle` fn mul_twiddle(v: PackedBaseField, twiddle_dbl: u32x16) -> PackedBaseField { // TODO: Come up with a better approach than `cfg`ing on target_feature. diff --git a/crates/prover/src/core/backend/simd/fft/rfft.rs b/crates/prover/src/core/backend/simd/fft/rfft.rs index aacbffe24..5fbdb868a 100644 --- a/crates/prover/src/core/backend/simd/fft/rfft.rs +++ b/crates/prover/src/core/backend/simd/fft/rfft.rs @@ -1,6 +1,9 @@ //! Regular (forward) fft. -use std::simd::{simd_swizzle, u32x16, u32x2, u32x4}; +use std::array; +use std::simd::{simd_swizzle, u32x16, u32x2, u32x4, u32x8}; + +use itertools::Itertools; use super::{ compute_first_twiddles, mul_twiddle, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE, @@ -9,25 +12,25 @@ use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES}; use crate::core::circle::Coset; use crate::core::utils::bit_reverse; -const VECS_LOG_SIZE: usize = LOG_N_LANES as usize; - -/// Performs a Circle Fast Fourier Transform (ICFFT) on the given values. 
-/// -/// # Safety -/// This function is unsafe because it takes a raw pointer to i32 values. -/// `values` must be aligned to 64 bytes. +/// Performs a Circle Fast Fourier Transform (CFFT) on the given values. /// /// # Arguments +/// /// * `src`: A pointer to the values to transform. /// * `dst`: A pointer to the destination array. /// * `twiddle_dbl`: A reference to the doubles of the twiddle factors. /// * `log_n_elements`: The log of the number of elements in the `values` array. /// /// # Panics +/// /// This function will panic if `log_n_elements` is less than `MIN_FFT_LOG_SIZE`. -pub unsafe fn fft(src: *const i32, dst: *mut i32, twiddle_dbl: &[&[i32]], log_n_elements: usize) { +/// +/// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. +pub unsafe fn fft(src: *const u32, dst: *mut u32, twiddle_dbl: &[&[u32]], log_n_elements: usize) { assert!(log_n_elements >= MIN_FFT_LOG_SIZE); - let log_n_vecs = log_n_elements - VECS_LOG_SIZE; + let log_n_vecs = log_n_elements - LOG_N_LANES as usize; if log_n_elements <= CACHED_FFT_LOG_SIZE { fft_lower_with_vecwise(src, dst, twiddle_dbl, log_n_elements, log_n_elements); return; @@ -46,37 +49,44 @@ pub unsafe fn fft(src: *const i32, dst: *mut i32, twiddle_dbl: &[&[i32]], log_n_ fft_lower_with_vecwise( dst, dst, - &twiddle_dbl[..(3 + fft_layers_pre_transpose)], + &twiddle_dbl[..3 + fft_layers_pre_transpose], log_n_elements, - fft_layers_pre_transpose + VECS_LOG_SIZE, + fft_layers_pre_transpose + LOG_N_LANES as usize, ); } /// Computes partial fft on `2^log_size` M31 elements. -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// twiddle_dbl - The doubles of the twiddle factors for each layer of the the fft. -/// layer i holds 2^(log_size - 1 - i) twiddles. -/// log_size - The log of the number of number of M31 elements in the array. 
-/// fft_layers - The number of fft layers to apply, out of log_size. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the fft. Layer `i` +/// holds `2^(log_size - 1 - i)` twiddles. +/// - `log_size`: The log of the number of M31 elements in the array. +/// - `fft_layers`: The number of fft layers to apply, out of log_size. +/// +/// # Panics +/// +/// Panics if `log_size` is not at least 5. +/// /// # Safety -/// `values` must be aligned to 64 bytes. -/// `log_size` must be at least 5. +/// +/// `src` and `dst` must have same alignment as [`PackedBaseField`]. /// `fft_layers` must be at least 5. pub unsafe fn fft_lower_with_vecwise( - src: *const i32, - dst: *mut i32, - twiddle_dbl: &[&[i32]], + src: *const u32, + dst: *mut u32, + twiddle_dbl: &[&[u32]], log_size: usize, fft_layers: usize, ) { - const VECWISE_FFT_BITS: usize = VECS_LOG_SIZE + 1; + const VECWISE_FFT_BITS: usize = LOG_N_LANES as usize + 1; assert!(log_size >= VECWISE_FFT_BITS); assert_eq!(twiddle_dbl[0].len(), 1 << (log_size - 2)); - for index_h in 0..(1 << (log_size - fft_layers)) { + for index_h in 0..1 << (log_size - fft_layers) { let mut src = src; for layer in (VECWISE_FFT_BITS..fft_layers).step_by(3).rev() { match fft_layers - layer { @@ -109,32 +119,38 @@ pub unsafe fn fft_lower_with_vecwise( } } -/// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits -/// of the index). -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// twiddle_dbl - The doubles of the twiddle factors for each layer of the the fft. -/// log_size - The log of the number of number of M31 elements in the array. 
-/// fft_layers - The number of fft layers to apply, out of log_size - VEC_LOG_SIZE. +/// Computes partial fft on `2^log_size` M31 elements, skipping the vecwise layers (lower 4 bits of +/// the index). +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each layer of the fft. +/// - `log_size`: The log of the number of M31 elements in the array. +/// - `fft_layers`: The number of fft layers to apply, out of log_size - VEC_LOG_SIZE. +/// +/// # Panics +/// +/// Panics if `log_size` is not at least 4. /// /// # Safety -/// `values` must be aligned to 64 bytes. -/// `log_size` must be at least 4. +/// +/// `src` and `dst` must have same alignment as [`PackedBaseField`]. /// `fft_layers` must be at least 4. pub unsafe fn fft_lower_without_vecwise( - src: *const i32, - dst: *mut i32, - twiddle_dbl: &[&[i32]], + src: *const u32, + dst: *mut u32, + twiddle_dbl: &[&[u32]], log_size: usize, fft_layers: usize, ) { - assert!(log_size >= VECS_LOG_SIZE); + assert!(log_size >= LOG_N_LANES as usize); - for index_h in 0..(1 << (log_size - fft_layers - VECS_LOG_SIZE)) { + for index_h in 0..1 << (log_size - fft_layers - LOG_N_LANES as usize) { let mut src = src; for layer in (0..fft_layers).step_by(3).rev() { - let fixed_layer = layer + VECS_LOG_SIZE; + let fixed_layer = layer + LOG_N_LANES as usize; match fft_layers - layer { 1 => { fft1_loop(src, dst, &twiddle_dbl[layer..], fixed_layer, index_h); @@ -159,75 +175,85 @@ pub unsafe fn fft_lower_without_vecwise( } /// Runs the last 5 fft layers across the entire array. -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// twiddle_dbl - The doubles of the twiddle factors for each of the 5 fft layers. 
-/// high_bits - The number of bits this loops needs to run on. -/// index_h - The higher part of the index, iterated by the caller. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 5 fft layers. +/// - `high_bits`: The number of bits this loops needs to run on. +/// - `index_h`: The higher part of the index, iterated by the caller. +/// /// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. unsafe fn fft_vecwise_loop( - src: *const i32, - dst: *mut i32, - twiddle_dbl: &[&[i32]], + src: *const u32, + dst: *mut u32, + twiddle_dbl: &[&[u32]], loop_bits: usize, index_h: usize, ) { - for index_l in 0..(1 << loop_bits) { + for index_l in 0..1 << loop_bits { let index = (index_h << loop_bits) + index_l; - let mut val0 = PackedBaseField::load(src.add(index * 32) as *const u32); - let mut val1 = PackedBaseField::load(src.add(index * 32 + 16) as *const u32); - (val0, val1) = avx_butterfly( + let mut val0 = PackedBaseField::load(src.add(index * 32)); + let mut val1 = PackedBaseField::load(src.add(index * 32 + 16)); + (val0, val1) = butterfly( val0, val1, - u32x16::splat(*twiddle_dbl[3].get_unchecked(index) as u32), + u32x16::splat(*twiddle_dbl[3].get_unchecked(index)), ); (val0, val1) = vecwise_butterflies( val0, val1, - std::array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)), - std::array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)), - std::array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)), + array::from_fn(|i| *twiddle_dbl[0].get_unchecked(index * 8 + i)), + array::from_fn(|i| *twiddle_dbl[1].get_unchecked(index * 4 + i)), + array::from_fn(|i| *twiddle_dbl[2].get_unchecked(index * 2 + i)), ); - val0.store(dst.add(index * 32) as *mut u32); - val1.store(dst.add(index * 32 + 16) 
as *mut u32); + val0.store(dst.add(index * 32)); + val1.store(dst.add(index * 32 + 16)); } } /// Runs 3 fft layers across the entire array. -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// twiddle_dbl - The doubles of the twiddle factors for each of the 3 fft layers. -/// loop_bits - The number of bits this loops needs to run on. -/// layer - The layer number of the first fft layer to apply. -/// The layers `layer`, `layer + 1`, `layer + 2` are applied. -/// index_h - The higher part of the index, iterated by the caller. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 3 fft layers. +/// - `loop_bits`: The number of bits this loops needs to run on. +/// - `layer`: The layer number of the first fft layer to apply. The layers `layer`, `layer + 1`, +/// `layer + 2` are applied. +/// - `index_h`: The higher part of the index, iterated by the caller. +/// /// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. 
unsafe fn fft3_loop( - src: *const i32, - dst: *mut i32, - twiddle_dbl: &[&[i32]], + src: *const u32, + dst: *mut u32, + twiddle_dbl: &[&[u32]], loop_bits: usize, layer: usize, index_h: usize, ) { - for index_l in 0..(1 << loop_bits) { + for index_l in 0..1 << loop_bits { let index = (index_h << loop_bits) + index_l; let offset = index << (layer + 3); - for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) { + for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { fft3( src, dst, offset + l, layer, - std::array::from_fn(|i| { + array::from_fn(|i| { *twiddle_dbl[0].get_unchecked((index * 4 + i) & (twiddle_dbl[0].len() - 1)) }), - std::array::from_fn(|i| { + array::from_fn(|i| { *twiddle_dbl[1].get_unchecked((index * 2 + i) & (twiddle_dbl[1].len() - 1)) }), - std::array::from_fn(|i| { + array::from_fn(|i| { *twiddle_dbl[2].get_unchecked((index + i) & (twiddle_dbl[2].len() - 1)) }), ); @@ -236,33 +262,38 @@ unsafe fn fft3_loop( } /// Runs 2 fft layers across the entire array. -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// twiddle_dbl - The doubles of the twiddle factors for each of the 2 fft layers. -/// loop_bits - The number of bits this loops needs to run on. -/// layer - The layer number of the first fft layer to apply. -/// The layers `layer`, `layer + 1` are applied. -/// index - The index, iterated by the caller. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for each of the 2 fft layers. +/// - `loop_bits`: The number of bits this loops needs to run on. +/// - `layer`: The layer number of the first fft layer to apply. The layers `layer`, `layer + 1` are +/// applied. +/// - `index`: The index, iterated by the caller. 
+/// /// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. unsafe fn fft2_loop( - src: *const i32, - dst: *mut i32, - twiddle_dbl: &[&[i32]], + src: *const u32, + dst: *mut u32, + twiddle_dbl: &[&[u32]], layer: usize, index: usize, ) { let offset = index << (layer + 2); - for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) { + for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { fft2( src, dst, offset + l, layer, - std::array::from_fn(|i| { + array::from_fn(|i| { *twiddle_dbl[0].get_unchecked((index * 2 + i) & (twiddle_dbl[0].len() - 1)) }), - std::array::from_fn(|i| { + array::from_fn(|i| { *twiddle_dbl[1].get_unchecked((index + i) & (twiddle_dbl[1].len() - 1)) }), ); @@ -270,28 +301,33 @@ unsafe fn fft2_loop( } /// Runs 1 fft layer across the entire array. -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// twiddle_dbl - The doubles of the twiddle factors for the fft layer. -/// layer - The layer number of the fft layer to apply. -/// index_h - The higher part of the index, iterated by the caller. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `twiddle_dbl`: The doubles of the twiddle factors for the fft layer. +/// - `layer`: The layer number of the fft layer to apply. +/// - `index_h`: The higher part of the index, iterated by the caller. +/// /// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. 
unsafe fn fft1_loop( - src: *const i32, - dst: *mut i32, - twiddle_dbl: &[&[i32]], + src: *const u32, + dst: *mut u32, + twiddle_dbl: &[&[u32]], layer: usize, index: usize, ) { let offset = index << (layer + 1); - for l in (0..(1 << layer)).step_by(1 << VECS_LOG_SIZE) { + for l in (0..1 << layer).step_by(1 << LOG_N_LANES as usize) { fft1( src, dst, offset + l, layer, - std::array::from_fn(|i| { + array::from_fn(|i| { *twiddle_dbl[0].get_unchecked((index + i) & (twiddle_dbl[0].len() - 1)) }), ); @@ -299,14 +335,12 @@ unsafe fn fft1_loop( } /// Computes the butterfly operation for packed M31 elements. -/// val0 + t val1, val0 - t val1. -/// val0, val1 are packed M31 elements. 16 M31 words at each. -/// Each value is assumed to be in unreduced form, [0, P] including P. -/// Returned values are in unreduced form, [0, P] including P. -/// twiddle_dbl holds 16 values, each is a *double* of a twiddle factor, in unreduced form. -/// # Safety -/// This function is safe. -pub unsafe fn avx_butterfly( +/// +/// Returns `val0 + t val1, val0 - t val1`. `val0, val1` are packed M31 elements. 16 M31 words at +/// each. Each value is assumed to be in unreduced form, [0, P] including P. Returned values are in +/// unreduced form, [0, P] including P. twiddle_dbl holds 16 values, each is a *double* of a twiddle +/// factor, in unreduced form, [0, 2*P]. +pub fn butterfly( val0: PackedBaseField, val1: PackedBaseField, twiddle_dbl: u32x16, @@ -316,6 +350,7 @@ pub unsafe fn avx_butterfly( } /// Runs fft on 2 vectors of 16 M31 elements. +/// /// This amounts to 4 butterfly layers, each with 16 butterflies. /// Each of the vectors represents natural ordered polynomial coefficeint. /// Each value in a vectors is in unreduced form: [0, P] including P. @@ -323,52 +358,51 @@ pub unsafe fn avx_butterfly( /// The first layer (higher bit of the index) takes 2 twiddles. /// The second layer takes 4 twiddles. /// etc. 
-/// # Safety -pub unsafe fn vecwise_butterflies( +pub fn vecwise_butterflies( mut val0: PackedBaseField, mut val1: PackedBaseField, - twiddle1_dbl: [i32; 8], - twiddle2_dbl: [i32; 4], - twiddle3_dbl: [i32; 2], + twiddle1_dbl: [u32; 8], + twiddle2_dbl: [u32; 4], + twiddle3_dbl: [u32; 2], ) -> (PackedBaseField, PackedBaseField) { // TODO(spapini): Compute twiddle0 from twiddle1. // TODO(spapini): The permute can be fused with the _mm512_srli_epi64 inside the butterfly. // The implementation is the exact reverse of vecwise_ibutterflies(). // See the comments in its body for more info. let t = simd_swizzle!( - u32x2::from_array(unsafe { std::mem::transmute(twiddle3_dbl) }), + u32x2::from(twiddle3_dbl), [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] ); (val0, val1) = val0.interleave(val1); - (val0, val1) = avx_butterfly(val0, val1, t); + (val0, val1) = butterfly(val0, val1, t); let t = simd_swizzle!( - u32x4::from_array(unsafe { std::mem::transmute(twiddle2_dbl) }), + u32x4::from(twiddle2_dbl), [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3] ); (val0, val1) = val0.interleave(val1); - (val0, val1) = avx_butterfly(val0, val1, t); + (val0, val1) = butterfly(val0, val1, t); - let (t0, t1) = compute_first_twiddles(twiddle1_dbl); + let (t0, t1) = compute_first_twiddles(u32x8::from(twiddle1_dbl)); (val0, val1) = val0.interleave(val1); - (val0, val1) = avx_butterfly(val0, val1, t1); + (val0, val1) = butterfly(val0, val1, t1); (val0, val1) = val0.interleave(val1); - (val0, val1) = avx_butterfly(val0, val1, t0); + (val0, val1) = butterfly(val0, val1, t0); val0.interleave(val1) } /// Returns the line twiddles (x points) for an fft on a coset. 
-pub fn get_twiddle_dbls(mut coset: Coset) -> Vec> { +pub fn get_twiddle_dbls(mut coset: Coset) -> Vec> { let mut res = vec![]; for _ in 0..coset.log_size() { res.push( coset .iter() .take(coset.size() / 2) - .map(|p| (p.x.0 * 2) as i32) - .collect::>(), + .map(|p| p.x.0 * 2) + .collect_vec(), ); bit_reverse(res.last_mut().unwrap()); coset = coset.double(); @@ -378,247 +412,254 @@ pub fn get_twiddle_dbls(mut coset: Coset) -> Vec> { } /// Applies 3 butterfly layers on 8 vectors of 16 M31 elements. +/// /// Vectorized over the 16 elements of the vectors. /// Used for radix-8 ifft. -/// Each butterfly layer, has 3 AVX butterflies. -/// Total of 12 AVX butterflies. -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// offset - The offset of the first value in the array. -/// log_step - The log of the distance in the array, in M31 elements, between each pair of -/// values that need to be transformed. For layer i this is i - 4. -/// twiddles_dbl0/1/2 - The double of the twiddles for the 3 layers of butterflies. -/// Each layer has 4/2/1 twiddles. +/// Each butterfly layer, has 3 SIMD butterflies. +/// Total of 12 SIMD butterflies. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `offset`: The offset of the first value in the array. +/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values +/// that need to be transformed. For layer i this is i - 4. +/// - `twiddles_dbl0/1/2`: The double of the twiddles for the 3 layers of butterflies. Each layer +/// has 4/2/1 twiddles. +/// /// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. 
pub unsafe fn fft3( - src: *const i32, - dst: *mut i32, + src: *const u32, + dst: *mut u32, offset: usize, log_step: usize, - twiddles_dbl0: [i32; 4], - twiddles_dbl1: [i32; 2], - twiddles_dbl2: [i32; 1], + twiddles_dbl0: [u32; 4], + twiddles_dbl1: [u32; 2], + twiddles_dbl2: [u32; 1], ) { - // Load the 8 AVX vectors from the array. - let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)) as *const u32); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)) as *const u32); - let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step)) as *const u32); - let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step)) as *const u32); - let mut val4 = PackedBaseField::load(src.add(offset + (4 << log_step)) as *const u32); - let mut val5 = PackedBaseField::load(src.add(offset + (5 << log_step)) as *const u32); - let mut val6 = PackedBaseField::load(src.add(offset + (6 << log_step)) as *const u32); - let mut val7 = PackedBaseField::load(src.add(offset + (7 << log_step)) as *const u32); + // Load the 8 SIMD vectors from the array. + let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); + let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); + let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step))); + let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step))); + let mut val4 = PackedBaseField::load(src.add(offset + (4 << log_step))); + let mut val5 = PackedBaseField::load(src.add(offset + (5 << log_step))); + let mut val6 = PackedBaseField::load(src.add(offset + (6 << log_step))); + let mut val7 = PackedBaseField::load(src.add(offset + (7 << log_step))); // Apply the third layer of butterflies. 
- (val0, val4) = avx_butterfly(val0, val4, u32x16::splat(twiddles_dbl2[0] as u32)); - (val1, val5) = avx_butterfly(val1, val5, u32x16::splat(twiddles_dbl2[0] as u32)); - (val2, val6) = avx_butterfly(val2, val6, u32x16::splat(twiddles_dbl2[0] as u32)); - (val3, val7) = avx_butterfly(val3, val7, u32x16::splat(twiddles_dbl2[0] as u32)); + (val0, val4) = butterfly(val0, val4, u32x16::splat(twiddles_dbl2[0])); + (val1, val5) = butterfly(val1, val5, u32x16::splat(twiddles_dbl2[0])); + (val2, val6) = butterfly(val2, val6, u32x16::splat(twiddles_dbl2[0])); + (val3, val7) = butterfly(val3, val7, u32x16::splat(twiddles_dbl2[0])); // Apply the second layer of butterflies. - (val0, val2) = avx_butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0] as u32)); - (val1, val3) = avx_butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0] as u32)); - (val4, val6) = avx_butterfly(val4, val6, u32x16::splat(twiddles_dbl1[1] as u32)); - (val5, val7) = avx_butterfly(val5, val7, u32x16::splat(twiddles_dbl1[1] as u32)); + (val0, val2) = butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0])); + (val1, val3) = butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0])); + (val4, val6) = butterfly(val4, val6, u32x16::splat(twiddles_dbl1[1])); + (val5, val7) = butterfly(val5, val7, u32x16::splat(twiddles_dbl1[1])); // Apply the first layer of butterflies. - (val0, val1) = avx_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); - (val2, val3) = avx_butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1] as u32)); - (val4, val5) = avx_butterfly(val4, val5, u32x16::splat(twiddles_dbl0[2] as u32)); - (val6, val7) = avx_butterfly(val6, val7, u32x16::splat(twiddles_dbl0[3] as u32)); - - // Store the 8 AVX vectors back to the array. 
- val0.store(dst.add(offset + (0 << log_step)) as *mut u32); - val1.store(dst.add(offset + (1 << log_step)) as *mut u32); - val2.store(dst.add(offset + (2 << log_step)) as *mut u32); - val3.store(dst.add(offset + (3 << log_step)) as *mut u32); - val4.store(dst.add(offset + (4 << log_step)) as *mut u32); - val5.store(dst.add(offset + (5 << log_step)) as *mut u32); - val6.store(dst.add(offset + (6 << log_step)) as *mut u32); - val7.store(dst.add(offset + (7 << log_step)) as *mut u32); + (val0, val1) = butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); + (val2, val3) = butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1])); + (val4, val5) = butterfly(val4, val5, u32x16::splat(twiddles_dbl0[2])); + (val6, val7) = butterfly(val6, val7, u32x16::splat(twiddles_dbl0[3])); + + // Store the 8 SIMD vectors back to the array. + val0.store(dst.add(offset + (0 << log_step))); + val1.store(dst.add(offset + (1 << log_step))); + val2.store(dst.add(offset + (2 << log_step))); + val3.store(dst.add(offset + (3 << log_step))); + val4.store(dst.add(offset + (4 << log_step))); + val5.store(dst.add(offset + (5 << log_step))); + val6.store(dst.add(offset + (6 << log_step))); + val7.store(dst.add(offset + (7 << log_step))); } /// Applies 2 butterfly layers on 4 vectors of 16 M31 elements. +/// /// Vectorized over the 16 elements of the vectors. /// Used for radix-4 fft. -/// Each butterfly layer, has 2 AVX butterflies. -/// Total of 4 AVX butterflies. -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// offset - The offset of the first value in the array. -/// log_step - The log of the distance in the array, in M31 elements, between each pair of -/// values that need to be transformed. For layer i this is i - 4. -/// twiddles_dbl0/1 - The double of the twiddles for the 2 layers of butterflies. -/// Each layer has 2/1 twiddles. +/// Each butterfly layer, has 2 SIMD butterflies. 
+/// Total of 4 SIMD butterflies. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. +/// - `offset`: The offset of the first value in the array. +/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values +/// that need to be transformed. For layer i this is i - 4. +/// - `twiddles_dbl0/1`: The double of the twiddles for the 2 layers of butterflies. Each layer has +/// 2/1 twiddles. +/// /// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. pub unsafe fn fft2( - src: *const i32, - dst: *mut i32, + src: *const u32, + dst: *mut u32, offset: usize, log_step: usize, - twiddles_dbl0: [i32; 2], - twiddles_dbl1: [i32; 1], + twiddles_dbl0: [u32; 2], + twiddles_dbl1: [u32; 1], ) { - // Load the 4 AVX vectors from the array. - let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)) as *const u32); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)) as *const u32); - let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step)) as *const u32); - let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step)) as *const u32); + // Load the 4 SIMD vectors from the array. + let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); + let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); + let mut val2 = PackedBaseField::load(src.add(offset + (2 << log_step))); + let mut val3 = PackedBaseField::load(src.add(offset + (3 << log_step))); // Apply the second layer of butterflies. 
- (val0, val2) = avx_butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0] as u32)); - (val1, val3) = avx_butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0] as u32)); + (val0, val2) = butterfly(val0, val2, u32x16::splat(twiddles_dbl1[0])); + (val1, val3) = butterfly(val1, val3, u32x16::splat(twiddles_dbl1[0])); // Apply the first layer of butterflies. - (val0, val1) = avx_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); - (val2, val3) = avx_butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1] as u32)); - - // Store the 4 AVX vectors back to the array. - val0.store(dst.add(offset + (0 << log_step)) as *mut u32); - val1.store(dst.add(offset + (1 << log_step)) as *mut u32); - val2.store(dst.add(offset + (2 << log_step)) as *mut u32); - val3.store(dst.add(offset + (3 << log_step)) as *mut u32); + (val0, val1) = butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); + (val2, val3) = butterfly(val2, val3, u32x16::splat(twiddles_dbl0[1])); + + // Store the 4 SIMD vectors back to the array. + val0.store(dst.add(offset + (0 << log_step))); + val1.store(dst.add(offset + (1 << log_step))); + val2.store(dst.add(offset + (2 << log_step))); + val3.store(dst.add(offset + (3 << log_step))); } /// Applies 1 butterfly layers on 2 vectors of 16 M31 elements. +/// /// Vectorized over the 16 elements of the vectors. -/// Parameters: -/// src - A pointer to the values to transform, aligned to 64 bytes. -/// dst - A pointer to the destination array, aligned to 64 bytes. -/// offset - The offset of the first value in the array. -/// log_step - The log of the distance in the array, in M31 elements, between each pair of -/// values that need to be transformed. For layer i this is i - 4. -/// twiddles_dbl0 - The double of the twiddles for the butterfly layer. +/// +/// # Arguments +/// +/// - `src`: A pointer to the values to transform, aligned to 64 bytes. +/// - `dst`: A pointer to the destination array, aligned to 64 bytes. 
+/// - `offset`: The offset of the first value in the array. +/// - `log_step`: The log of the distance in the array, in M31 elements, between each pair of values +/// that need to be transformed. For layer i this is i - 4. +/// - `twiddles_dbl0`: The double of the twiddles for the butterfly layer. +/// /// # Safety +/// +/// Behavior is undefined if `src` and `dst` do not have the same alignment as [`PackedBaseField`]. pub unsafe fn fft1( - src: *const i32, - dst: *mut i32, + src: *const u32, + dst: *mut u32, offset: usize, log_step: usize, - twiddles_dbl0: [i32; 1], + twiddles_dbl0: [u32; 1], ) { - // Load the 2 AVX vectors from the array. - let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step)) as *const u32); - let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step)) as *const u32); + // Load the 2 SIMD vectors from the array. + let mut val0 = PackedBaseField::load(src.add(offset + (0 << log_step))); + let mut val1 = PackedBaseField::load(src.add(offset + (1 << log_step))); - (val0, val1) = avx_butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32)); + (val0, val1) = butterfly(val0, val1, u32x16::splat(twiddles_dbl0[0])); - // Store the 2 AVX vectors back to the array. - val0.store(dst.add(offset + (0 << log_step)) as *mut u32); - val1.store(dst.add(offset + (1 << log_step)) as *mut u32); + // Store the 2 SIMD vectors back to the array. 
+ val0.store(dst.add(offset + (0 << log_step))); + val1.store(dst.add(offset + (1 << log_step))); } #[cfg(test)] mod tests { - use super::*; + use std::mem::transmute; + use std::simd::u32x16; + + use itertools::Itertools; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + + use super::{ + butterfly, fft, fft3, fft_lower_with_vecwise, get_twiddle_dbls, vecwise_butterflies, + }; use crate::core::backend::cpu::CPUCirclePoly; use crate::core::backend::simd::column::BaseFieldVec; + use crate::core::backend::simd::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE}; + use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES, N_LANES}; use crate::core::backend::Column; - use crate::core::fft::butterfly; + use crate::core::fft::butterfly as ground_truth_butterfly; use crate::core::fields::m31::BaseField; use crate::core::poly::circle::{CanonicCoset, CircleDomain}; #[test] fn test_butterfly() { - unsafe { - let val0 = PackedBaseField::from_simd_unchecked(u32x16::from_array([ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - ])); - let val1 = PackedBaseField::from_simd_unchecked(u32x16::from_array([ - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, - ])); - let twiddle = u32x16::from_array([ - 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, - ]); - let twiddle_dbl = twiddle + twiddle; - let (r0, r1) = avx_butterfly(val0, val1, twiddle_dbl); - - let val0: [BaseField; 16] = std::mem::transmute(val0); - let val1: [BaseField; 16] = std::mem::transmute(val1); - let twiddle: [BaseField; 16] = std::mem::transmute(twiddle); - let r0: [BaseField; 16] = std::mem::transmute(r0); - let r1: [BaseField; 16] = std::mem::transmute(r1); - - for i in 0..16 { - let mut x = val0[i]; - let mut y = val1[i]; - let twiddle = twiddle[i]; - butterfly(&mut x, &mut y, twiddle); - assert_eq!(x, r0[i]); - assert_eq!(y, r1[i]); - } + let mut rng = SmallRng::seed_from_u64(0); + let mut v0: [BaseField; N_LANES] = rng.gen(); + let mut v1: [BaseField; 
N_LANES] = rng.gen(); + let twiddle: [BaseField; N_LANES] = rng.gen(); + let twiddle_dbl = twiddle.map(|v| v.0 * 2); + + let (r0, r1) = butterfly(v0.into(), v1.into(), twiddle_dbl.into()); + + let r0 = r0.to_array(); + let r1 = r1.to_array(); + for i in 0..N_LANES { + ground_truth_butterfly(&mut v0[i], &mut v1[i], twiddle[i]); + assert_eq!((v0[i], v1[i]), (r0[i], r1[i]), "mismatch at i={i}"); } } #[test] fn test_fft3() { + let mut rng = SmallRng::seed_from_u64(0); + let values = rng.gen::<[BaseField; 8]>().map(PackedBaseField::broadcast); + let twiddles0: [BaseField; 4] = rng.gen(); + let twiddles1: [BaseField; 2] = rng.gen(); + let twiddles2: [BaseField; 1] = rng.gen(); + let twiddles0_dbl = twiddles0.map(|v| v.0 * 2); + let twiddles1_dbl = twiddles1.map(|v| v.0 * 2); + let twiddles2_dbl = twiddles2.map(|v| v.0 * 2); + + let mut res = values; unsafe { - let mut values: Vec = (0..8) - .map(|i| { - PackedBaseField::from_array(std::array::from_fn(|_| { - BaseField::from_u32_unchecked(i) - })) - }) - .collect(); - let twiddles0 = [32, 33, 34, 35]; - let twiddles1 = [36, 37]; - let twiddles2 = [38]; - let twiddles0_dbl = std::array::from_fn(|i| twiddles0[i] * 2); - let twiddles1_dbl = std::array::from_fn(|i| twiddles1[i] * 2); - let twiddles2_dbl = std::array::from_fn(|i| twiddles2[i] * 2); fft3( - std::mem::transmute(values.as_ptr()), - std::mem::transmute(values.as_mut_ptr()), + transmute(res.as_ptr()), + transmute(res.as_mut_ptr()), 0, - VECS_LOG_SIZE, + LOG_N_LANES as usize, twiddles0_dbl, twiddles1_dbl, twiddles2_dbl, - ); + ) + }; - let expected: [u32; 8] = std::array::from_fn(|i| i as u32); - let mut expected: [BaseField; 8] = std::mem::transmute(expected); - let twiddles0: [BaseField; 4] = std::mem::transmute(twiddles0); - let twiddles1: [BaseField; 2] = std::mem::transmute(twiddles1); - let twiddles2: [BaseField; 1] = std::mem::transmute(twiddles2); - for i in 0..8 { - let j = i ^ 4; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], 
expected[j]); - butterfly(&mut v0, &mut v1, twiddles2[0]); - (expected[i], expected[j]) = (v0, v1); - } - for i in 0..8 { - let j = i ^ 2; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - butterfly(&mut v0, &mut v1, twiddles1[i / 4]); - (expected[i], expected[j]) = (v0, v1); + let mut expected = values.map(|v| v.to_array()[0]); + for i in 0..8 { + let j = i ^ 4; + if i > j { + continue; } - for i in 0..8 { - let j = i ^ 1; - if i > j { - continue; - } - let (mut v0, mut v1) = (expected[i], expected[j]); - butterfly(&mut v0, &mut v1, twiddles0[i / 2]); - (expected[i], expected[j]) = (v0, v1); + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_butterfly(&mut v0, &mut v1, twiddles2[0]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + let j = i ^ 2; + if i > j { + continue; } - for i in 0..8 { - assert_eq!(values[i].to_array()[0], expected[i]); + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_butterfly(&mut v0, &mut v1, twiddles1[i / 4]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + let j = i ^ 1; + if i > j { + continue; } + let (mut v0, mut v1) = (expected[i], expected[j]); + ground_truth_butterfly(&mut v0, &mut v1, twiddles0[i / 2]); + (expected[i], expected[j]) = (v0, v1); + } + for i in 0..8 { + assert_eq!( + res[i].to_array(), + [expected[i]; N_LANES], + "mismatch at i={i}" + ); } - } - - fn ref_fft(domain: CircleDomain, values: Vec) -> Vec { - let poly = CPUCirclePoly::new(values); - poly.evaluate(domain).values } #[test] @@ -626,13 +667,14 @@ mod tests { let domain = CanonicCoset::new(5).circle_domain(); let twiddle_dbls = get_twiddle_dbls(domain.half_coset); assert_eq!(twiddle_dbls.len(), 4); - let values0: [i32; 16] = std::array::from_fn(|i| i as i32); - let values1: [i32; 16] = std::array::from_fn(|i| (i + 16) as i32); - let result: [BaseField; 32] = unsafe { - let (val0, val1) = avx_butterfly( - std::mem::transmute(values0), - 
std::mem::transmute(values1), - u32x16::splat(twiddle_dbls[3][0] as u32), + let mut rng = SmallRng::seed_from_u64(0); + let values: [[BaseField; 16]; 2] = rng.gen(); + + let res = { + let (val0, val1) = butterfly( + values[0].into(), + values[1].into(), + u32x16::splat(twiddle_dbls[3][0]), ); let (val0, val1) = vecwise_butterflies( val0, @@ -641,86 +683,60 @@ mod tests { twiddle_dbls[1].clone().try_into().unwrap(), twiddle_dbls[2].clone().try_into().unwrap(), ); - std::mem::transmute([val0, val1]) + [val0.to_array(), val1.to_array()].concat() }; - // ref. - let mut values = values0.to_vec(); - values.extend_from_slice(&values1); - let expected = ref_fft(domain, values.into_iter().map(BaseField::from).collect()); - - // Compare. - for i in 0..32 { - assert_eq!(result[i], expected[i]); - } + assert_eq!(res, ground_truth_fft(domain, values.flatten())); } #[test] fn test_fft_lower() { for log_size in 5..12 { let domain = CanonicCoset::new(log_size).circle_domain(); - let values = (0..domain.size()) - .map(|i| BaseField::from_u32_unchecked(i as u32)) - .collect::>(); - let expected_coeffs = ref_fft(domain, values.clone()); - - // Compute. - let mut values = BaseFieldVec::from_iter(values); + let mut rng = SmallRng::seed_from_u64(0); + let values = (0..domain.size()).map(|_| rng.gen()).collect_vec(); let twiddle_dbls = get_twiddle_dbls(domain.half_coset); + let mut res = values.iter().copied().collect::(); unsafe { fft_lower_with_vecwise( - std::mem::transmute(values.data.as_ptr()), - std::mem::transmute(values.data.as_mut_ptr()), - &twiddle_dbls - .iter() - .map(|x| x.as_slice()) - .collect::>(), + transmute(res.data.as_ptr()), + transmute(res.data.as_mut_ptr()), + &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), log_size as usize, log_size as usize, - ); - - // Compare. 
- assert_eq!(values.to_cpu(), expected_coeffs); + ) } + + assert_eq!(res.to_cpu(), ground_truth_fft(domain, &values)); } } - fn run_fft_full_test(log_size: u32) { - let domain = CanonicCoset::new(log_size).circle_domain(); - let values = (0..domain.size()) - .map(|i| BaseField::from_u32_unchecked(i as u32)) - .collect::>(); - let expected_coeffs = ref_fft(domain, values.clone()); - - // Compute. - let mut values = BaseFieldVec::from_iter(values); - let twiddle_dbls = get_twiddle_dbls(domain.half_coset); + #[test] + fn test_fft_full() { + for log_size in CACHED_FFT_LOG_SIZE + 1..CACHED_FFT_LOG_SIZE + 3 { + let domain = CanonicCoset::new(log_size as u32).circle_domain(); + let mut rng = SmallRng::seed_from_u64(0); + let values = (0..domain.size()).map(|_| rng.gen()).collect_vec(); + let twiddle_dbls = get_twiddle_dbls(domain.half_coset); - unsafe { - transpose_vecs( - std::mem::transmute(values.data.as_mut_ptr()), - (log_size - 4) as usize, - ); - fft( - std::mem::transmute(values.data.as_ptr()), - std::mem::transmute(values.data.as_mut_ptr()), - &twiddle_dbls - .iter() - .map(|x| x.as_slice()) - .collect::>(), - log_size as usize, - ); + let mut res = values.iter().copied().collect::(); + unsafe { + transpose_vecs(transmute(res.data.as_mut_ptr()), log_size - 4); + fft( + transmute(res.data.as_ptr()), + transmute(res.data.as_mut_ptr()), + &twiddle_dbls.iter().map(|x| x.as_slice()).collect_vec(), + log_size, + ); + } - // Compare. - assert_eq!(values.to_cpu(), expected_coeffs); + assert_eq!(res.to_cpu(), ground_truth_fft(domain, &values)); } } - #[test] - fn test_fft_full() { - for i in (CACHED_FFT_LOG_SIZE + 1)..(CACHED_FFT_LOG_SIZE + 3) { - run_fft_full_test(i as u32); - } + fn ground_truth_fft(domain: CircleDomain, values: &[BaseField]) -> Vec { + let poly = CPUCirclePoly::new(values.to_vec()); + poly.evaluate(domain).values } }