diff --git a/crates/prover/src/core/backend/simd/bit_reverse.rs b/crates/prover/src/core/backend/simd/bit_reverse.rs index f729ca590..5053c172a 100644 --- a/crates/prover/src/core/backend/simd/bit_reverse.rs +++ b/crates/prover/src/core/backend/simd/bit_reverse.rs @@ -1,6 +1,6 @@ -use std::arch::x86_64::{__m512i, _mm512_permutex2var_epi32}; +use std::simd::u32x16; -use super::PackedBaseField; +use super::m31::PackedBaseField; use crate::core::utils::bit_reverse_index; const VEC_BITS: u32 = 4; @@ -74,10 +74,10 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { /// Bit reverses 256 M31 values, packed in 16 words of 16 elements each. fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] { - let mut data: [__m512i; 16] = unsafe { std::mem::transmute(data) }; + let mut data: [u32x16; 16] = unsafe { std::mem::transmute(data) }; // L is an input to _mm512_permutex2var_epi32, and it is used to // interleave the first half of a with the first half of b. - const L: __m512i = unsafe { + const _L: u32x16 = unsafe { core::mem::transmute([ 0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100, 0b10100, 0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111, @@ -85,7 +85,7 @@ fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] { }; // H is an input to _mm512_permutex2var_epi32, and it is used to interleave the second half // interleave the second half of a with the second half of b. - const H: __m512i = unsafe { + const _H: u32x16 = unsafe { core::mem::transmute([ 0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100, 0b11100, 0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111, @@ -111,36 +111,25 @@ fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] { // 0001:xyz0 (even indices of register 1) <= 0010:0xyz (low half of register2), and // 0001:xyz1 (odd indices of register 1) <= 0011:0xyz (low half of register 3) // or 0001:xyzw <= 001w:0xyz. - unsafe { - data = [ - _mm512_permutex2var_epi32(data[0], L, data[1]), - _mm512_permutex2var_epi32(data[2], L, data[3]), - _mm512_permutex2var_epi32(data[4], L, data[5]), - _mm512_permutex2var_epi32(data[6], L, data[7]), - _mm512_permutex2var_epi32(data[8], L, data[9]), - _mm512_permutex2var_epi32(data[10], L, data[11]), - _mm512_permutex2var_epi32(data[12], L, data[13]), - _mm512_permutex2var_epi32(data[14], L, data[15]), - _mm512_permutex2var_epi32(data[0], H, data[1]), - _mm512_permutex2var_epi32(data[2], H, data[3]), - _mm512_permutex2var_epi32(data[4], H, data[5]), - _mm512_permutex2var_epi32(data[6], H, data[7]), - _mm512_permutex2var_epi32(data[8], H, data[9]), - _mm512_permutex2var_epi32(data[10], H, data[11]), - _mm512_permutex2var_epi32(data[12], H, data[13]), - _mm512_permutex2var_epi32(data[14], H, data[15]), - ]; - } + let (d0, d8) = data[0].interleave(data[1]); + let (d1, d9) = data[2].interleave(data[3]); + let (d2, d10) = data[4].interleave(data[5]); + let (d3, d11) = data[6].interleave(data[7]); + let (d4, d12) = data[8].interleave(data[9]); + let (d5, d13) = data[10].interleave(data[11]); + let (d6, d14) = data[12].interleave(data[13]); + let (d7, d15) = data[14].interleave(data[15]); + data = [ + d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15, + ]; } unsafe { std::mem::transmute(data) } } -#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))] #[cfg(test)] mod tests { - use super::bit_reverse16; - use crate::core::backend::avx512::bit_reverse::bit_reverse_m31; - use crate::core::backend::avx512::BaseFieldVec; + use super::{bit_reverse16, bit_reverse_m31}; + use crate::core::backend::simd::column::BaseFieldVec; use crate::core::backend::Column; use crate::core::fields::m31::BaseField; use crate::core::utils::bit_reverse;