Fix compilation issues
andrewmilson committed May 6, 2024
1 parent 1226157 commit 347d446
Showing 3 changed files with 458 additions and 233 deletions.
221 changes: 112 additions & 109 deletions crates/prover/src/core/backend/simd/fft/ifft.rs
@@ -1,17 +1,15 @@
//! Inverse fft.
use std::arch::x86_64::{
__m512i, _mm512_broadcast_i32x4, _mm512_mul_epu32, _mm512_permutex2var_epi32,
_mm512_set1_epi32, _mm512_set1_epi64, _mm512_srli_epi64,
};

use super::{compute_first_twiddles, EVENS_INTERLEAVE_EVENS, ODDS_INTERLEAVE_ODDS};
use crate::core::backend::avx512::fft::{transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
use crate::core::backend::avx512::{PackedBaseField, VECS_LOG_SIZE};
use std::simd::{simd_swizzle, u32x16, u32x2, u32x4};

use super::{compute_first_twiddles, transpose_vecs, CACHED_FFT_LOG_SIZE, MIN_FFT_LOG_SIZE};
use crate::core::backend::simd::m31::{PackedBaseField, LOG_N_LANES};
use crate::core::circle::Coset;
use crate::core::fields::FieldExpOps;
use crate::core::utils::bit_reverse;

const VECS_LOG_SIZE: usize = LOG_N_LANES as usize;
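
The remainder of this diff is mostly a line-for-line translation of AVX-512 intrinsics into portable std::simd operations. A rough mapping, summarized from the hunks below (not part of the commit):

// _mm512_set1_epi32(x)       -> u32x16::splat(x as u32)
// _mm512_broadcast_i32x4(q)  -> simd_swizzle!(q, [0, 1, 2, 3] repeated 4x)
// _mm512_set1_epi64(p)       -> simd_swizzle!(p, [0, 1] repeated 8x)
// _mm512_mul_epu32/permutes  -> cfg_if-dispatched _mul_twiddle_* helpers
// val0.deinterleave_with(v)  -> val0.deinterleave(v)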

/// Performs an Inverse Circle Fast Fourier Transform (ICFFT) on the given values.
///
/// # Safety
@@ -155,8 +153,9 @@ pub unsafe fn ifft_vecwise_loop(
) {
for index_l in 0..(1 << loop_bits) {
let index = (index_h << loop_bits) + index_l;
let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const());
let mut val1 = PackedBaseField::load(values.add(index * 32 + 16).cast_const());
let mut val0 = PackedBaseField::load(values.add(index * 32).cast_const() as *const u32);
let mut val1 =
PackedBaseField::load(values.add(index * 32 + 16).cast_const() as *const u32);
(val0, val1) = vecwise_ibutterflies(
val0,
val1,
@@ -167,10 +166,10 @@
(val0, val1) = avx_ibutterfly(
val0,
val1,
_mm512_set1_epi32(*twiddle_dbl[3].get_unchecked(index)),
u32x16::splat(*twiddle_dbl[3].get_unchecked(index) as u32),
);
val0.store(values.add(index * 32));
val1.store(values.add(index * 32 + 16));
val0.store(values.add(index * 32) as *mut u32);
val1.store(values.add(index * 32 + 16) as *mut u32);
}
}
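
Most hunks in this file are the same mechanical change as the one above: the simd backend's PackedBaseField::load and store take u32 pointers, while the FFT buffer is still *mut i32, so every access gains a cast. A hypothetical pair of helpers illustrating the pattern (not in the commit; i32 and u32 share a layout, and the buffer only ever holds valid M31 values, so the reinterpretation is sound):

unsafe fn load_packed(values: *const i32, idx: usize) -> PackedBaseField {
    PackedBaseField::load(values.add(idx) as *const u32)
}

unsafe fn store_packed(values: *mut i32, idx: usize, val: PackedBaseField) {
    val.store(values.add(idx) as *mut u32)
}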

@@ -269,41 +268,28 @@ unsafe fn ifft1_loop(values: *mut i32, twiddle_dbl: &[&[i32]], layer: usize, ind
pub unsafe fn avx_ibutterfly(
val0: PackedBaseField,
val1: PackedBaseField,
twiddle_dbl: __m512i,
twiddle_dbl: u32x16,
) -> (PackedBaseField, PackedBaseField) {
let r0 = val0 + val1;
let r1 = val0 - val1;

// Extract the even and odd parts of r1 and twiddle_dbl, and spread as 8 64-bit values.
let r1_e = r1.0;
let r1_o = _mm512_srli_epi64(r1.0, 32);
let twiddle_dbl_e = twiddle_dbl;
let twiddle_dbl_o = _mm512_srli_epi64(twiddle_dbl, 32);

// To compute prod = r1 * twiddle, start by multiplying
// r1_e/o by twiddle_dbl_e/o.
let prod_e_dbl = _mm512_mul_epu32(r1_e, twiddle_dbl_e);
let prod_o_dbl = _mm512_mul_epu32(r1_o, twiddle_dbl_o);

// The result of the multiplication holds r1 * twiddle_dbl as 64-bit values.
// Each 64-bit word looks like this:
// 1 31 31 1
// prod_e_dbl - |0|prod_e_h|prod_e_l|0|
// prod_o_dbl - |0|prod_o_h|prod_o_l|0|

// Interleave the even words of prod_e_dbl with the even words of prod_o_dbl:
let prod_ls = _mm512_permutex2var_epi32(prod_e_dbl, EVENS_INTERLEAVE_EVENS, prod_o_dbl);
// prod_ls - |prod_o_l|0|prod_e_l|0|

// Divide by 2:
let prod_ls = _mm512_srli_epi64(prod_ls, 1);
// prod_ls - |0|prod_o_l|0|prod_e_l|

// Interleave the odd words of prod_e_dbl with the odd words of prod_o_dbl:
let prod_hs = _mm512_permutex2var_epi32(prod_e_dbl, ODDS_INTERLEAVE_ODDS, prod_o_dbl);
// prod_hs - |0|prod_o_h|0|prod_e_h|

let prod = PackedBaseField(prod_ls) + PackedBaseField(prod_hs);
let prod = {
// TODO: Come up with a better approach than `cfg`ing on target_feature.
// TODO: Ensure all these branches get tested in the CI.
cfg_if::cfg_if! {
if #[cfg(all(target_feature = "neon", target_arch = "aarch64"))] {
super::_mul_twiddle_neon(r1, twiddle_dbl)
} else if #[cfg(all(target_feature = "simd128", target_arch = "wasm32"))] {
super::_mul_twiddle_wasm(r1, twiddle_dbl)
} else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] {
super::_mul_twiddle_avx512(r1, twiddle_dbl)
} else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] {
super::_mul_twiddle_avx2(r1, twiddle_dbl)
} else {
super::_mul_twiddle_simd(r1, twiddle_dbl)
}
}
};

(r0, prod)
}
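
Every cfg_if branch above computes the same thing: the M31 product r1 * twiddle, given a twiddle that was doubled ahead of time. A scalar sketch of the trick, reconstructed from the removed comments (not the crate's API):

const P: u32 = (1 << 31) - 1; // M31 modulus

/// Because t_dbl = 2 * t, the 64-bit product r1 * t_dbl = 2 * (r1 * t)
/// splits exactly at bit 32 into the high and low 31-bit halves of
/// r1 * t, and 2^31 ≡ 1 (mod P) turns the reduction into an addition.
fn mul_twiddle_scalar(r1: u32, t_dbl: u32) -> u32 {
    debug_assert!(r1 < P && t_dbl < 2 * P);
    let prod_dbl = r1 as u64 * t_dbl as u64; // = 2 * r1 * t, fits in 63 bits
    let lo = (prod_dbl as u32) >> 1;         // low 31 bits of r1 * t
    let hi = (prod_dbl >> 32) as u32;        // high 31 bits of r1 * t
    let mut sum = lo + hi;                   // < 2^32
    sum = (sum & P) + (sum >> 31);           // fold 2^31 back down to 1
    if sum >= P { sum - P } else { sum }
}

Each branch is a compile-time target_feature check, so exactly one is compiled in; building without the matching -C target-feature flags falls through to the portable _mul_twiddle_simd path.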
@@ -349,29 +335,35 @@ pub unsafe fn vecwise_ibutterflies(
let (t0, t1) = compute_first_twiddles(twiddle1_dbl);

// Apply the permutation, resulting in indexing d:iabc.
(val0, val1) = val0.deinterleave_with(val1);
(val0, val1) = val0.deinterleave(val1);
(val0, val1) = avx_ibutterfly(val0, val1, t0);

// Apply the permutation, resulting in indexing c:diab.
(val0, val1) = val0.deinterleave_with(val1);
(val0, val1) = val0.deinterleave(val1);
(val0, val1) = avx_ibutterfly(val0, val1, t1);

// The twiddles for layer 2 are replicated in the following pattern:
// 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3
let t = _mm512_broadcast_i32x4(std::mem::transmute(twiddle2_dbl));
let t = simd_swizzle!(
u32x4::from_array(unsafe { std::mem::transmute(twiddle2_dbl) }),
[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]
);
// Apply the permutation, resulting in indexing b:cdia.
(val0, val1) = val0.deinterleave_with(val1);
(val0, val1) = val0.deinterleave(val1);
(val0, val1) = avx_ibutterfly(val0, val1, t);

// The twiddles for layer 3 are replicated in the following pattern:
// 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1
let t = _mm512_set1_epi64(std::mem::transmute(twiddle3_dbl));
let t = simd_swizzle!(
u32x2::from_array(unsafe { std::mem::transmute(twiddle3_dbl) }),
[0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
);
// Apply the permutation, resulting in indexing a:bcid.
(val0, val1) = val0.deinterleave_with(val1);
(val0, val1) = val0.deinterleave(val1);
(val0, val1) = avx_ibutterfly(val0, val1, t);

// Apply the permutation, resulting in indexing i:abcd.
val0.deinterleave_with(val1)
val0.deinterleave(val1)
}
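
The two simd_swizzle! calls above replicate a short twiddle vector across all 16 lanes, taking over for _mm512_broadcast_i32x4 and _mm512_set1_epi64. A standalone illustration (an assumed example, not from the commit):

use std::simd::{simd_swizzle, u32x16, u32x2, u32x4};

fn broadcast_quad(q: u32x4) -> u32x16 {
    // Replicates q as 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3.
    simd_swizzle!(q, [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3])
}

fn broadcast_pair(p: u32x2) -> u32x16 {
    // Replicates p as 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1.
    simd_swizzle!(p, [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1])
}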

/// Returns the line twiddles (x points) for an ifft on a coset.
@@ -414,42 +406,50 @@ pub unsafe fn ifft3(
twiddles_dbl2: [i32; 1],
) {
// Load the 8 AVX vectors from the array.
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());
let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const());
let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const());
let mut val4 = PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const());
let mut val5 = PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const());
let mut val6 = PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const());
let mut val7 = PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const());
let mut val0 =
PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const() as *const u32);
let mut val1 =
PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const() as *const u32);
let mut val2 =
PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const() as *const u32);
let mut val3 =
PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const() as *const u32);
let mut val4 =
PackedBaseField::load(values.add(offset + (4 << log_step)).cast_const() as *const u32);
let mut val5 =
PackedBaseField::load(values.add(offset + (5 << log_step)).cast_const() as *const u32);
let mut val6 =
PackedBaseField::load(values.add(offset + (6 << log_step)).cast_const() as *const u32);
let mut val7 =
PackedBaseField::load(values.add(offset + (7 << log_step)).cast_const() as *const u32);

// Apply the first layer of ibutterflies.
(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
(val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1]));
(val4, val5) = avx_ibutterfly(val4, val5, _mm512_set1_epi32(twiddles_dbl0[2]));
(val6, val7) = avx_ibutterfly(val6, val7, _mm512_set1_epi32(twiddles_dbl0[3]));
(val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32));
(val2, val3) = avx_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1] as u32));
(val4, val5) = avx_ibutterfly(val4, val5, u32x16::splat(twiddles_dbl0[2] as u32));
(val6, val7) = avx_ibutterfly(val6, val7, u32x16::splat(twiddles_dbl0[3] as u32));

// Apply the second layer of ibutterflies.
(val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0]));
(val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0]));
(val4, val6) = avx_ibutterfly(val4, val6, _mm512_set1_epi32(twiddles_dbl1[1]));
(val5, val7) = avx_ibutterfly(val5, val7, _mm512_set1_epi32(twiddles_dbl1[1]));
(val0, val2) = avx_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0] as u32));
(val1, val3) = avx_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0] as u32));
(val4, val6) = avx_ibutterfly(val4, val6, u32x16::splat(twiddles_dbl1[1] as u32));
(val5, val7) = avx_ibutterfly(val5, val7, u32x16::splat(twiddles_dbl1[1] as u32));

// Apply the third layer of ibutterflies.
(val0, val4) = avx_ibutterfly(val0, val4, _mm512_set1_epi32(twiddles_dbl2[0]));
(val1, val5) = avx_ibutterfly(val1, val5, _mm512_set1_epi32(twiddles_dbl2[0]));
(val2, val6) = avx_ibutterfly(val2, val6, _mm512_set1_epi32(twiddles_dbl2[0]));
(val3, val7) = avx_ibutterfly(val3, val7, _mm512_set1_epi32(twiddles_dbl2[0]));
(val0, val4) = avx_ibutterfly(val0, val4, u32x16::splat(twiddles_dbl2[0] as u32));
(val1, val5) = avx_ibutterfly(val1, val5, u32x16::splat(twiddles_dbl2[0] as u32));
(val2, val6) = avx_ibutterfly(val2, val6, u32x16::splat(twiddles_dbl2[0] as u32));
(val3, val7) = avx_ibutterfly(val3, val7, u32x16::splat(twiddles_dbl2[0] as u32));

// Store the 8 AVX vectors back to the array.
val0.store(values.add(offset + (0 << log_step)));
val1.store(values.add(offset + (1 << log_step)));
val2.store(values.add(offset + (2 << log_step)));
val3.store(values.add(offset + (3 << log_step)));
val4.store(values.add(offset + (4 << log_step)));
val5.store(values.add(offset + (5 << log_step)));
val6.store(values.add(offset + (6 << log_step)));
val7.store(values.add(offset + (7 << log_step)));
val0.store(values.add(offset + (0 << log_step)) as *mut u32);
val1.store(values.add(offset + (1 << log_step)) as *mut u32);
val2.store(values.add(offset + (2 << log_step)) as *mut u32);
val3.store(values.add(offset + (3 << log_step)) as *mut u32);
val4.store(values.add(offset + (4 << log_step)) as *mut u32);
val5.store(values.add(offset + (5 << log_step)) as *mut u32);
val6.store(values.add(offset + (6 << log_step)) as *mut u32);
val7.store(values.add(offset + (7 << log_step)) as *mut u32);
}
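
The three layers above pair vectors at strides 1, 2 and 4, and the twiddle index halves at each layer. A scalar model of the same network over 8 values (a sketch reusing P and mul_twiddle_scalar from the earlier sketch, with inline M31 add/sub helpers):

fn add_m31(a: u32, b: u32) -> u32 {
    let s = a + b; // a, b < 2^31, so no overflow
    if s >= P { s - P } else { s }
}

fn sub_m31(a: u32, b: u32) -> u32 {
    if a >= b { a - b } else { a + P - b }
}

fn ifft3_scalar(vals: &mut [u32; 8], twiddle_dbls: [&[u32]; 3]) {
    for layer in 0..3 {
        let stride = 1 << layer; // pair distance: 1, 2, then 4
        for i in 0..8 {
            if i & stride == 0 {
                // ibutterfly: (v0, v1) -> (v0 + v1, (v0 - v1) * t)
                let t_dbl = twiddle_dbls[layer][i >> (layer + 1)];
                let (v0, v1) = (vals[i], vals[i + stride]);
                vals[i] = add_m31(v0, v1);
                vals[i + stride] = mul_twiddle_scalar(sub_m31(v0, v1), t_dbl);
            }
        }
    }
}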

/// Applies 2 ibutterfly layers on 4 vectors of 16 M31 elements.
@@ -473,24 +473,28 @@ pub unsafe fn ifft2(
twiddles_dbl1: [i32; 1],
) {
// Load the 4 AVX vectors from the array.
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());
let mut val2 = PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const());
let mut val3 = PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const());
let mut val0 =
PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const() as *const u32);
let mut val1 =
PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const() as *const u32);
let mut val2 =
PackedBaseField::load(values.add(offset + (2 << log_step)).cast_const() as *const u32);
let mut val3 =
PackedBaseField::load(values.add(offset + (3 << log_step)).cast_const() as *const u32);

// Apply the first layer of butterflies.
(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
(val2, val3) = avx_ibutterfly(val2, val3, _mm512_set1_epi32(twiddles_dbl0[1]));
(val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32));
(val2, val3) = avx_ibutterfly(val2, val3, u32x16::splat(twiddles_dbl0[1] as u32));

// Apply the second layer of butterflies.
(val0, val2) = avx_ibutterfly(val0, val2, _mm512_set1_epi32(twiddles_dbl1[0]));
(val1, val3) = avx_ibutterfly(val1, val3, _mm512_set1_epi32(twiddles_dbl1[0]));
(val0, val2) = avx_ibutterfly(val0, val2, u32x16::splat(twiddles_dbl1[0] as u32));
(val1, val3) = avx_ibutterfly(val1, val3, u32x16::splat(twiddles_dbl1[0] as u32));

// Store the 4 AVX vectors back to the array.
val0.store(values.add(offset + (0 << log_step)));
val1.store(values.add(offset + (1 << log_step)));
val2.store(values.add(offset + (2 << log_step)));
val3.store(values.add(offset + (3 << log_step)));
val0.store(values.add(offset + (0 << log_step)) as *mut u32);
val1.store(values.add(offset + (1 << log_step)) as *mut u32);
val2.store(values.add(offset + (2 << log_step)) as *mut u32);
val3.store(values.add(offset + (3 << log_step)) as *mut u32);
}

/// Applies 1 ibutterfly layer on 2 vectors of 16 M31 elements.
@@ -504,25 +508,24 @@ pub unsafe fn ifft2(
/// # Safety
pub unsafe fn ifft1(values: *mut i32, offset: usize, log_step: usize, twiddles_dbl0: [i32; 1]) {
// Load the 2 AVX vectors from the array.
let mut val0 = PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const());
let mut val1 = PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const());
let mut val0 =
PackedBaseField::load(values.add(offset + (0 << log_step)).cast_const() as *const u32);
let mut val1 =
PackedBaseField::load(values.add(offset + (1 << log_step)).cast_const() as *const u32);

(val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddles_dbl0[0]));
(val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddles_dbl0[0] as u32));

// Store the 2 AVX vectors back to the array.
val0.store(values.add(offset + (0 << log_step)));
val1.store(values.add(offset + (1 << log_step)));
val0.store(values.add(offset + (0 << log_step)) as *mut u32);
val1.store(values.add(offset + (1 << log_step)) as *mut u32);
}
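
How the three sizes combine is outside the visible hunks. As an assumption only: the driver presumably consumes layers in chunks of three via ifft3 and finishes a remainder with ifft2 or ifft1, along the lines of this hypothetical dispatch (assumed names and signatures; the real loop lives in the collapsed part of this file):

unsafe fn apply_remaining_layers(
    values: *mut i32,
    offset: usize,
    log_step: usize,
    n_layers_left: usize,
    twiddles_dbl0: [i32; 2],
    twiddles_dbl1: [i32; 1],
) {
    match n_layers_left {
        0 => {} // chunks of three were already handled by ifft3
        1 => ifft1(values, offset, log_step, [twiddles_dbl0[0]]),
        2 => ifft2(values, offset, log_step, twiddles_dbl0, twiddles_dbl1),
        _ => unreachable!("larger chunks go through ifft3"),
    }
}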

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
use std::arch::x86_64::{_mm512_add_epi32, _mm512_setr_epi32};

use super::*;
use crate::core::backend::avx512::m31::PackedBaseField;
use crate::core::backend::avx512::BaseFieldVec;
use crate::core::backend::cpu::CPUCircleEvaluation;
use crate::core::backend::simd::m31::PackedBaseField;
use crate::core::backend::simd::BaseFieldVec;
use crate::core::backend::Column;
use crate::core::fft::ibutterfly;
use crate::core::fields::m31::BaseField;
@@ -531,16 +534,16 @@ mod tests {
#[test]
fn test_ibutterfly() {
unsafe {
let val0 = PackedBaseField(_mm512_setr_epi32(
let val0 = PackedBaseField::from_simd_unchecked(u32x16::from_array([
2, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
));
let val1 = PackedBaseField(_mm512_setr_epi32(
]));
let val1 = PackedBaseField::from_simd_unchecked(u32x16::from_array([
3, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
));
let twiddle = _mm512_setr_epi32(
]));
let twiddle = u32x16::from_array([
1177558791, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
);
let twiddle_dbl = _mm512_add_epi32(twiddle, twiddle);
]);
let twiddle_dbl = twiddle + twiddle;
let (r0, r1) = avx_ibutterfly(val0, val1, twiddle_dbl);

let val0: [BaseField; 16] = std::mem::transmute(val0);
@@ -647,7 +650,7 @@ mod tests {
twiddle_dbls[1].clone().try_into().unwrap(),
twiddle_dbls[2].clone().try_into().unwrap(),
);
let (val0, val1) = avx_ibutterfly(val0, val1, _mm512_set1_epi32(twiddle_dbls[3][0]));
let (val0, val1) = avx_ibutterfly(val0, val1, u32x16::splat(twiddle_dbls[3][0] as u32));
std::mem::transmute([val0, val1])
};

Expand Down
Loading

0 comments on commit 347d446

Please sign in to comment.