Skip to content

Commit

Permalink
Add bit_reverse for SIMD backend (#594)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson authored May 15, 2024
1 parent d13da83 commit 24935cf
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 65 deletions.
139 changes: 91 additions & 48 deletions crates/prover/src/core/backend/simd/bit_reverse.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,43 @@
use std::simd::u32x16;
use std::array;

use super::column::{BaseFieldVec, SecureFieldVec};
use super::m31::PackedBaseField;
use crate::core::utils::bit_reverse_index;
use super::SimdBackend;
use crate::core::backend::ColumnOps;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::utils::{bit_reverse as cpu_bit_reverse, bit_reverse_index};

const VEC_BITS: u32 = 4;

const W_BITS: u32 = 3;

pub const MIN_LOG_SIZE: u32 = 2 * W_BITS + VEC_BITS;

/// Bit reverses packed M31 values.
impl ColumnOps<BaseField> for SimdBackend {
type Column = BaseFieldVec;

fn bit_reverse_column(column: &mut Self::Column) {
// Fallback to cpu bit_reverse.
if column.data.len().ilog2() < MIN_LOG_SIZE {
cpu_bit_reverse(column.as_mut_slice());
return;
}

bit_reverse_m31(&mut column.data);
}
}

impl ColumnOps<SecureField> for SimdBackend {
type Column = SecureFieldVec;

fn bit_reverse_column(_column: &mut SecureFieldVec) {
todo!()
}
}

/// Bit reverses M31 values.
///
/// Given an array `A[0..2^n)`, computes `B[i] = A[bit_reverse(i)]`.
pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {
assert!(data.len().is_power_of_two());
Expand All @@ -17,16 +47,16 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {
// |v_h| = |v_l| = VEC_BITS, |w_h| = |w_l| = W_BITS, |a| = n - 2*W_BITS - VEC_BITS.
// The loops go over a, w_l, w_h, and then swaps the 16 by 16 values at:
// * w_h a w_l * <-> * rev(w_h a w_l) *.
// These are 1 or 2 chunks of 2^W_BITS contiguous AVX512 vectors.
// These are 1 or 2 chunks of 2^W_BITS contiguous `u32x16` vectors.

let log_size = data.len().ilog2();
let a_bits = log_size - 2 * W_BITS - VEC_BITS;

// TODO(spapini): when doing multithreading, do it over a.
for a in 0u32..(1 << a_bits) {
for w_l in 0u32..(1 << W_BITS) {
let w_l_rev = w_l.reverse_bits() >> (32 - W_BITS);
for w_h in 0u32..(w_l_rev + 1) {
for a in 0u32..1 << a_bits {
for w_l in 0u32..1 << W_BITS {
let w_l_rev = w_l.reverse_bits() >> (u32::BITS - W_BITS);
for w_h in 0..w_l_rev + 1 {
let idx = ((((w_h << a_bits) | a) << W_BITS) | w_l) as usize;
let idx_rev = bit_reverse_index(idx, log_size - VEC_BITS);

Expand All @@ -37,7 +67,7 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {

// Read first chunk.
// TODO(spapini): Think about optimizing a_bits.
let chunk0 = std::array::from_fn(|i| unsafe {
let chunk0 = array::from_fn(|i| unsafe {
*data.get_unchecked(idx + (i << (2 * W_BITS + a_bits)))
});
let values0 = bit_reverse16(chunk0);
Expand All @@ -55,7 +85,7 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {
}

// Read bit reversed chunk.
let chunk1 = std::array::from_fn(|i| unsafe {
let chunk1 = array::from_fn(|i| unsafe {
*data.get_unchecked(idx_rev + (i << (2 * W_BITS + a_bits)))
});
let values1 = bit_reverse16(chunk1);
Expand All @@ -73,25 +103,7 @@ pub fn bit_reverse_m31(data: &mut [PackedBaseField]) {
}

/// Bit reverses 256 M31 values, packed in 16 words of 16 elements each.
fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] {
let mut data: [u32x16; 16] = unsafe { std::mem::transmute(data) };
// L is an input to _mm512_permutex2var_epi32, and it is used to
// interleave the first half of a with the first half of b.
const _L: u32x16 = unsafe {
core::mem::transmute([
0b00000, 0b10000, 0b00001, 0b10001, 0b00010, 0b10010, 0b00011, 0b10011, 0b00100,
0b10100, 0b00101, 0b10101, 0b00110, 0b10110, 0b00111, 0b10111,
])
};
// H is an input to _mm512_permutex2var_epi32, and it is used to interleave the second half
// interleave the second half of a with the second half of b.
const _H: u32x16 = unsafe {
core::mem::transmute([
0b01000, 0b11000, 0b01001, 0b11001, 0b01010, 0b11010, 0b01011, 0b11011, 0b01100,
0b11100, 0b01101, 0b11101, 0b01110, 0b11110, 0b01111, 0b11111,
])
};

fn bit_reverse16(mut data: [PackedBaseField; 16]) -> [PackedBaseField; 16] {
// Denote the index of each element in the 16 packed M31 words as abcd:0123,
// where abcd is the index of the packed word and 0123 is the index of the element in the word.
// Bit reversal is achieved by applying the following permutation to the index for 4 times:
Expand All @@ -104,7 +116,7 @@ fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] {
// 3210:dcba
for _ in 0..4 {
// Apply the abcd:0123 => 0abc:123d permutation.
// _mm512_permutex2var_epi32() with L allows us to interleave the first half of 2 words.
// `interleave` allows us to interleave the first half of 2 words.
// For example, the second call interleaves 0010:0xyz (low half of register 2) with
// 0011:0xyz (low half of register 3), and stores the result in register 1 (0001).
// This results in
Expand All @@ -123,38 +135,69 @@ fn bit_reverse16(data: [PackedBaseField; 16]) -> [PackedBaseField; 16] {
d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15,
];
}
unsafe { std::mem::transmute(data) }

data
}

#[cfg(test)]
mod tests {
use super::{bit_reverse16, bit_reverse_m31};
use itertools::Itertools;

use super::{bit_reverse16, bit_reverse_m31, MIN_LOG_SIZE};
use crate::core::backend::simd::column::BaseFieldVec;
use crate::core::backend::Column;
use crate::core::backend::simd::m31::{PackedM31, N_LANES};
use crate::core::backend::simd::SimdBackend;
use crate::core::backend::{Column, ColumnOps};
use crate::core::fields::m31::BaseField;
use crate::core::utils::bit_reverse;
use crate::core::utils::bit_reverse as cpu_bit_reverse;

#[test]
fn test_bit_reverse16() {
let data: [u32; 256] = std::array::from_fn(|i| i as u32);
let expected: [u32; 256] = std::array::from_fn(|i| (i as u32).reverse_bits() >> 24);
unsafe {
let data = bit_reverse16(std::mem::transmute(data));
assert_eq!(std::mem::transmute::<_, [u32; 256]>(data), expected);
}
let values: BaseFieldVec = (0..N_LANES * 16).map(BaseField::from).collect();
let mut expected = values.to_cpu();
cpu_bit_reverse(&mut expected);

let res = bit_reverse16(values.data.try_into().unwrap());

assert_eq!(res.map(PackedM31::to_array).flatten(), expected);
}

#[test]
fn test_bit_reverse() {
fn bit_reverse_m31_works() {
const SIZE: usize = 1 << 15;
let data: Vec<_> = (0..SIZE as u32)
.map(BaseField::from_u32_unchecked)
.collect();
let data: Vec<_> = (0..SIZE).map(BaseField::from).collect();
let mut expected = data.clone();
bit_reverse(&mut expected);
let mut data: BaseFieldVec = data.into_iter().collect();
cpu_bit_reverse(&mut expected);

let mut res: BaseFieldVec = data.into_iter().collect();
bit_reverse_m31(&mut res.data[..]);

assert_eq!(res.to_cpu(), expected);
}

#[test]
fn bit_reverse_small_column_works() {
const LOG_SIZE: u32 = MIN_LOG_SIZE - 1;
let column = (0..1 << LOG_SIZE).map(BaseField::from).collect_vec();
let mut expected = column.clone();
cpu_bit_reverse(&mut expected);

let mut res = column.iter().copied().collect::<BaseFieldVec>();
<SimdBackend as ColumnOps<BaseField>>::bit_reverse_column(&mut res);

assert_eq!(res.to_cpu(), expected);
}

#[test]
fn bit_reverse_large_column_works() {
const LOG_SIZE: u32 = MIN_LOG_SIZE;
let column = (0..1 << LOG_SIZE).map(BaseField::from).collect_vec();
let mut expected = column.clone();
cpu_bit_reverse(&mut expected);

let mut res = column.iter().copied().collect::<BaseFieldVec>();
<SimdBackend as ColumnOps<BaseField>>::bit_reverse_column(&mut res);

bit_reverse_m31(&mut data.data[..]);
assert_eq!(data.to_cpu(), expected);
assert_eq!(res.to_cpu(), expected);
}
}
18 changes: 1 addition & 17 deletions crates/prover/src/core/backend/simd/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,17 @@ use num_traits::Zero;
use super::m31::{PackedBaseField, N_LANES};
use super::qm31::PackedSecureField;
use super::SimdBackend;
use crate::core::backend::{Column, ColumnOps};
use crate::core::backend::Column;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{FieldExpOps, FieldOps};

impl ColumnOps<BaseField> for SimdBackend {
type Column = BaseFieldVec;

fn bit_reverse_column(_column: &mut Self::Column) {
todo!()
}
}

impl FieldOps<BaseField> for SimdBackend {
fn batch_inverse(column: &BaseFieldVec, dst: &mut BaseFieldVec) {
PackedBaseField::batch_inverse(&column.data, &mut dst.data);
}
}

impl ColumnOps<SecureField> for SimdBackend {
type Column = SecureFieldVec;

fn bit_reverse_column(_column: &mut SecureFieldVec) {
todo!()
}
}

impl FieldOps<SecureField> for SimdBackend {
fn batch_inverse(column: &SecureFieldVec, dst: &mut SecureFieldVec) {
PackedSecureField::batch_inverse(&column.data, &mut dst.data);
Expand Down

0 comments on commit 24935cf

Please sign in to comment.