Skip to content

Commit

Permalink
Fix compilation issues
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed May 12, 2024
1 parent 8c4a5e0 commit 15e93af
Showing 1 changed file with 18 additions and 29 deletions.
47 changes: 18 additions & 29 deletions crates/prover/src/core/backend/simd/bit_reverse.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::arch::x86_64::{__m512i, _mm512_permutex2var_epi32};
use std::simd::u32x16;

use super::PackedBaseField;
use super::m31::PackedBaseField;
use crate::core::utils::bit_reverse_index;

const VEC_BITS: u32 = 4;
Expand Down Expand Up @@ -74,18 +74,18 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {

/// Bit reverses 256 M31 values, packed in 16 words of 16 elements each.
fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] {
let mut data: [__m512i; 16] = unsafe { std::mem::transmute(data) };
let mut data: [u32x16; 16] = unsafe { std::mem::transmute(data) };
// L is an input to _mm512_permutex2var_epi32, and it is used to
// interleave the first half of a with the first half of b.
const L: __m512i = unsafe {
const _L: u32x16 = unsafe {
core::mem::transmute([
0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100,
0b10100, 0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111,
])
};
// H is an input to _mm512_permutex2var_epi32, and it is used to interleave the second half
// interleave the second half of a with the second half of b.
const H: __m512i = unsafe {
const _H: u32x16 = unsafe {
core::mem::transmute([
0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100,
0b11100, 0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111,
Expand All @@ -111,36 +111,25 @@ fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] {
// 0001:xyz0 (even indices of register 1) <= 0010:0xyz (low half of register2), and
// 0001:xyz1 (odd indices of register 1) <= 0011:0xyz (low half of register 3)
// or 0001:xyzw <= 001w:0xyz.
unsafe {
data = [
_mm512_permutex2var_epi32(data[0], L, data[1]),
_mm512_permutex2var_epi32(data[2], L, data[3]),
_mm512_permutex2var_epi32(data[4], L, data[5]),
_mm512_permutex2var_epi32(data[6], L, data[7]),
_mm512_permutex2var_epi32(data[8], L, data[9]),
_mm512_permutex2var_epi32(data[10], L, data[11]),
_mm512_permutex2var_epi32(data[12], L, data[13]),
_mm512_permutex2var_epi32(data[14], L, data[15]),
_mm512_permutex2var_epi32(data[0], H, data[1]),
_mm512_permutex2var_epi32(data[2], H, data[3]),
_mm512_permutex2var_epi32(data[4], H, data[5]),
_mm512_permutex2var_epi32(data[6], H, data[7]),
_mm512_permutex2var_epi32(data[8], H, data[9]),
_mm512_permutex2var_epi32(data[10], H, data[11]),
_mm512_permutex2var_epi32(data[12], H, data[13]),
_mm512_permutex2var_epi32(data[14], H, data[15]),
];
}
let (d0, d8) = data[0].interleave(data[1]);
let (d1, d9) = data[2].interleave(data[3]);
let (d2, d10) = data[4].interleave(data[5]);
let (d3, d11) = data[6].interleave(data[7]);
let (d4, d12) = data[8].interleave(data[9]);
let (d5, d13) = data[10].interleave(data[11]);
let (d6, d14) = data[12].interleave(data[13]);
let (d7, d15) = data[14].interleave(data[15]);
data = [
d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15,
];
}
unsafe { std::mem::transmute(data) }
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
use super::bit_reverse16;
use crate::core::backend::avx512::bit_reverse::bit_reverse_m31;
use crate::core::backend::avx512::BaseFieldVec;
use super::{bit_reverse16, bit_reverse_m31};
use crate::core::backend::simd::column::BaseFieldVec;
use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::utils::bit_reverse;
Expand Down

0 comments on commit 15e93af

Please sign in to comment.