From 501b9ead5276b99855c1d02f9259821c8125d845 Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Tue, 30 Apr 2024 23:42:59 -0400 Subject: [PATCH] Add bit_reverse for SIMD backend --- .../src/core/backend/simd/bit_reverse.rs | 139 ++++++++++++------ crates/prover/src/core/backend/simd/column.rs | 18 +-- 2 files changed, 92 insertions(+), 65 deletions(-) diff --git a/crates/prover/src/core/backend/simd/bit_reverse.rs b/crates/prover/src/core/backend/simd/bit_reverse.rs index 5053c172a..37d2a14ed 100644 --- a/crates/prover/src/core/backend/simd/bit_reverse.rs +++ b/crates/prover/src/core/backend/simd/bit_reverse.rs @@ -1,13 +1,43 @@ -use std::simd::u32x16; +use std::array; +use super::column::{BaseFieldVec, SecureFieldVec}; use super::m31::PackedBaseField; -use crate::core::utils::bit_reverse_index; +use super::SimdBackend; +use crate::core::backend::ColumnOps; +use crate::core::fields::m31::BaseField; +use crate::core::fields::qm31::SecureField; +use crate::core::utils::{bit_reverse as cpu_bit_reverse, bit_reverse_index}; const VEC_BITS: u32 = 4; + const W_BITS: u32 = 3; + pub const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS; -/// Bit reverses packed M31 values. +impl ColumnOps for SimdBackend { + type Column = BaseFieldVec; + + fn bit_reverse_column(column: &mut Self::Column) { + // Fallback to cpu bit_reverse. + if column.data.len().ilog2() < MIN_LOG_SIZE { + cpu_bit_reverse(column.as_mut_slice()); + return; + } + + bit_reverse_m31(&mut column.data); + } +} + +impl ColumnOps for SimdBackend { + type Column = SecureFieldVec; + + fn bit_reverse_column(_column: &mut SecureFieldVec) { + todo!() + } +} + +/// Bit reverses M31 values. +/// /// Given an array `A[0..2^n)`, computes `B[i] = A[bit_reverse(i)]`. pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { assert!(data.len().is_power_of_two()); @@ -17,16 +47,16 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { // |v_h| = |v_l| = VEC_BITS, |w_h| = |w_l| = W_BITS, |a| = n - 2*W_BITS - VEC_BITS. // The loops go over a, w_l, w_h, and then swaps the 16 by 16 values at: // * w_h a w_l * <-> * rev(w_h a w_l) *. - // These are 1 or 2 chunks of 2^W_BITS contiguous AVX512 vectors. + // These are 1 or 2 chunks of 2^W_BITS contiguous `u32x16` vectors. let log_size = data.len().ilog2(); let a_bits = log_size - 2 * W_BITS - VEC_BITS; // TODO(spapini): when doing multithreading, do it over a. - for a in 0u32..(1 << a_bits) { - for w_l in 0u32..(1 << W_BITS) { - let w_l_rev = w_l.reverse_bits() >> (32 - W_BITS); - for w_h in 0u32..(w_l_rev + 1) { + for a in 0u32..1 << a_bits { + for w_l in 0u32..1 << W_BITS { + let w_l_rev = w_l.reverse_bits() >> (u32::BITS - W_BITS); + for w_h in 0..w_l_rev + 1 { let idx = ((((w_h << a_bits) | a) << W_BITS) | w_l) as usize; let idx_rev = bit_reverse_index(idx, log_size - VEC_BITS); @@ -37,7 +67,7 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { // Read first chunk. // TODO(spapini): Think about optimizing a_bits. - let chunk0 = std::array::from_fn(|i| unsafe { + let chunk0 = array::from_fn(|i| unsafe { *data.get_unchecked(idx + (i << (2 * W_BITS + a_bits))) }); let values0 = bit_reverse16(chunk0); @@ -55,7 +85,7 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { } // Read bit reversed chunk. - let chunk1 = std::array::from_fn(|i| unsafe { + let chunk1 = array::from_fn(|i| unsafe { *data.get_unchecked(idx_rev + (i << (2 * W_BITS + a_bits))) }); let values1 = bit_reverse16(chunk1); @@ -73,25 +103,7 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) { } /// Bit reverses 256 M31 values, packed in 16 words of 16 elements each. -fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] { - let mut data: [u32x16; 16] = unsafe { std::mem::transmute(data) }; - // L is an input to _mm512_permutex2var_epi32, and it is used to - // interleave the first half of a with the first half of b. - const _L: u32x16 = unsafe { - core::mem::transmute([ - 0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100, - 0b10100, 0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111, - ]) - }; - // H is an input to _mm512_permutex2var_epi32, and it is used to interleave the second half - // interleave the second half of a with the second half of b. - const _H: u32x16 = unsafe { - core::mem::transmute([ - 0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100, - 0b11100, 0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111, - ]) - }; - +fn bit_reverse16(mut data: [PackedBaseField; 16]) -> [PackedBaseField; 16] { // Denote the index of each element in the 16 packed M31 words as abcd:0123, // where abcd is the index of the packed word and 0123 is the index of the element in the word. // Bit reversal is achieved by applying the following permutation to the index for 4 times: @@ -104,7 +116,7 @@ fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] { // 3210:dcba for _ in 0..4 { // Apply the abcd:0123 => 0abc:123d permutation. - // _mm512_permutex2var_epi32() with L allows us to interleave the first half of 2 words. + // `LoLoInterleaveHiLo` allows us to interleave the first half of 2 words. // For example, the second call interleaves 0010:0xyz (low half of register 2) with // 0011:0xyz (low half of register 3), and stores the result in register 1 (0001). // This results in @@ -123,38 +135,69 @@ fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] { d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15, ]; } - unsafe { std::mem::transmute(data) } + + data } #[cfg(test)] mod tests { - use super::{bit_reverse16, bit_reverse_m31}; + use itertools::Itertools; + + use super::{bit_reverse16, bit_reverse_m31, MIN_LOG_SIZE}; use crate::core::backend::simd::column::BaseFieldVec; - use crate::core::backend::Column; + use crate::core::backend::simd::m31::{PackedM31, N_LANES}; + use crate::core::backend::simd::SimdBackend; + use crate::core::backend::{Column, ColumnOps}; use crate::core::fields::m31::BaseField; - use crate::core::utils::bit_reverse; + use crate::core::utils::bit_reverse as cpu_bit_reverse; #[test] fn test_bit_reverse16() { - let data: [u32; 256] = std::array::from_fn(|i| i as u32); - let expected: [u32; 256] = std::array::from_fn(|i| (i as u32).reverse_bits() >> 24); - unsafe { - let data = bit_reverse16(std::mem::transmute(data)); - assert_eq!(std::mem::transmute::<_, [u32; 256]>(data), expected); - } + let values: BaseFieldVec = (0..N_LANES * 16).map(BaseField::from).collect(); + let mut expected = values.to_cpu(); + cpu_bit_reverse(&mut expected); + + let res = bit_reverse16(values.data.try_into().unwrap()); + + assert_eq!(res.map(PackedM31::to_array).flatten(), expected); } #[test] - fn test_bit_reverse() { + fn bit_reverse_m31_works() { const SIZE: usize = 1 << 15; - let data: Vec<_> = (0..SIZE as u32) - .map(BaseField::from_u32_unchecked) - .collect(); + let data: Vec<_> = (0..SIZE).map(BaseField::from).collect(); let mut expected = data.clone(); - bit_reverse(&mut expected); - let mut data: BaseFieldVec = data.into_iter().collect(); + cpu_bit_reverse(&mut expected); + + let mut res: BaseFieldVec = data.into_iter().collect(); + bit_reverse_m31(&mut res.data[..]); + + assert_eq!(res.to_cpu(), expected); + } + + #[test] + fn bit_reverse_small_column_works() { + const LOG_SIZE: u32 = MIN_LOG_SIZE - 1; + let column = (0..1 << LOG_SIZE).map(BaseField::from).collect_vec(); + let mut expected = column.clone(); + cpu_bit_reverse(&mut expected); + + let mut res = column.iter().copied().collect::(); + >::bit_reverse_column(&mut res); + + assert_eq!(res.to_cpu(), expected); + } + + #[test] + fn bit_reverse_large_column_works() { + const LOG_SIZE: u32 = MIN_LOG_SIZE; + let column = (0..1 << LOG_SIZE).map(BaseField::from).collect_vec(); + let mut expected = column.clone(); + cpu_bit_reverse(&mut expected); + + let mut res = column.iter().copied().collect::(); + >::bit_reverse_column(&mut res); - bit_reverse_m31(&mut data.data[..]); - assert_eq!(data.to_cpu(), expected); + assert_eq!(res.to_cpu(), expected); } } diff --git a/crates/prover/src/core/backend/simd/column.rs b/crates/prover/src/core/backend/simd/column.rs index ac113ee25..6c6c92c6d 100644 --- a/crates/prover/src/core/backend/simd/column.rs +++ b/crates/prover/src/core/backend/simd/column.rs @@ -5,33 +5,17 @@ use num_traits::Zero; use super::m31::{PackedBaseField, N_LANES}; use super::qm31::PackedSecureField; use super::SimdBackend; -use crate::core::backend::{Column, ColumnOps}; +use crate::core::backend::Column; use crate::core::fields::m31::BaseField; use crate::core::fields::qm31::SecureField; use crate::core::fields::{FieldExpOps, FieldOps}; -impl ColumnOps for SimdBackend { - type Column = BaseFieldVec; - - fn bit_reverse_column(_column: &mut Self::Column) { - todo!() - } -} - impl FieldOps for SimdBackend { fn batch_inverse(column: &BaseFieldVec, dst: &mut BaseFieldVec) { PackedBaseField::batch_inverse(&column.data, &mut dst.data); } } -impl ColumnOps for SimdBackend { - type Column = SecureFieldVec; - - fn bit_reverse_column(_column: &mut SecureFieldVec) { - todo!() - } -} - impl FieldOps for SimdBackend { fn batch_inverse(column: &SecureFieldVec, dst: &mut SecureFieldVec) { PackedSecureField::batch_inverse(&column.data, &mut dst.data);