Skip to content

Commit

Permalink
Fix compilation issues
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed May 15, 2024
1 parent b24961f commit 9f342e8
Showing 1 changed file with 75 additions and 97 deletions.
172 changes: 75 additions & 97 deletions crates/prover/src/core/backend/simd/blake2s.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
//! An AVX512 implementation of the BLAKE2s compression function.
//! Based on <https://github.com/oconnor663/blake2_simd/blob/master/blake2s/src/avx2.rs>.
use std::arch::x86_64::{
__m512i, _mm512_add_epi32, _mm512_loadu_si512, _mm512_or_si512, _mm512_permutex2var_epi32,
_mm512_set1_epi32, _mm512_slli_epi32, _mm512_srli_epi32, _mm512_xor_si512,
};
use std::simd::u32x16;

use itertools::Itertools;

use super::blake2s_avx::{compress16, set1, transpose_msgs, untranspose_states};
use super::tranpose_utils::{
EVENS_CONCAT_EVENS, HHALF_INTERLEAVE_HHALF, LHALF_INTERLEAVE_LHALF, ODDS_CONCAT_ODDS,
};
use super::{AVX512Backend, VECS_LOG_SIZE};
use super::m31::LOG_N_LANES;
use super::SimdBackend;
use crate::core::backend::{Col, Column, ColumnOps};
use crate::core::fields::m31::BaseField;
use crate::core::vcs::blake2_hash::Blake2sHash;
use crate::core::vcs::blake2_merkle::Blake2sMerkleHasher;
use crate::core::vcs::ops::{MerkleHasher, MerkleOps};

const VECS_LOG_SIZE: usize = LOG_N_LANES as usize;

const IV: [u32; 8] = [
0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
];
Expand All @@ -36,19 +32,19 @@ const SIGMA: [[u8; 16]; 10] = [
[10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0],
];

impl ColumnOps<Blake2sHash> for AVX512Backend {
impl ColumnOps<Blake2sHash> for SimdBackend {
type Column = Vec<Blake2sHash>;

fn bit_reverse_column(_column: &mut Self::Column) {
unimplemented!()
}
}

impl MerkleOps<Blake2sMerkleHasher> for AVX512Backend {
impl MerkleOps<Blake2sMerkleHasher> for SimdBackend {
fn commit_on_layer(
log_size: u32,
prev_layer: Option<&Vec<Blake2sHash>>,
columns: &[&Col<AVX512Backend, BaseField>],
columns: &[&Col<SimdBackend, BaseField>],
) -> Vec<Blake2sHash> {
// Pad prev_layer if too small.
if log_size < VECS_LOG_SIZE as u32 {
Expand All @@ -69,13 +65,11 @@ impl MerkleOps<Blake2sMerkleHasher> for AVX512Backend {
// Commit to columns.
let mut res = Vec::with_capacity(1 << log_size);
for i in 0..(1 << (log_size - VECS_LOG_SIZE as u32)) {
let mut state: [__m512i; 8] = unsafe { std::mem::zeroed() };
let mut state: [u32x16; 8] = unsafe { std::mem::zeroed() };
// Hash prev_layer, if exists.
if let Some(prev_layer) = prev_layer {
let ptr = prev_layer[(i << 5)..((i + 1) << 5)].as_ptr() as *const __m512i;
let msgs: [__m512i; 16] = std::array::from_fn(|j| unsafe {
_mm512_loadu_si512(ptr.add(j) as *const i32)
});
let ptr = prev_layer[(i << 5)..((i + 1) << 5)].as_ptr() as *const u32x16;
let msgs: [u32x16; 16] = std::array::from_fn(|j| unsafe { *ptr.add(j) });
state = unsafe {
compress16(
state,
Expand All @@ -91,7 +85,7 @@ impl MerkleOps<Blake2sMerkleHasher> for AVX512Backend {
// Hash columns in chunks of 16.
let mut col_chunk_iter = columns.array_chunks();
for col_chunk in &mut col_chunk_iter {
let msgs = col_chunk.map(|column| column.data[i].0);
let msgs = col_chunk.map(|column| column.data[i].into_simd());
state = unsafe { compress16(state, msgs, set1(0), set1(0), set1(0), set1(0)) };
}

Expand All @@ -100,7 +94,7 @@ impl MerkleOps<Blake2sMerkleHasher> for AVX512Backend {
if !remainder.is_empty() {
let msgs = remainder
.iter()
.map(|column| column.data[i].0)
.map(|column| column.data[i].into_simd())
.chain(std::iter::repeat(unsafe { set1(0) }))
.take(16)
.collect_vec()
Expand All @@ -118,42 +112,42 @@ impl MerkleOps<Blake2sMerkleHasher> for AVX512Backend {

/// # Safety
#[inline(always)]
pub unsafe fn set1(iv: i32) -> __m512i {
_mm512_set1_epi32(iv)
pub unsafe fn set1(iv: i32) -> u32x16 {
u32x16::splat(iv as u32)
}

#[inline(always)]
unsafe fn add(a: __m512i, b: __m512i) -> __m512i {
_mm512_add_epi32(a, b)
unsafe fn add(a: u32x16, b: u32x16) -> u32x16 {
a + b
}

#[inline(always)]
unsafe fn xor(a: __m512i, b: __m512i) -> __m512i {
_mm512_xor_si512(a, b)
unsafe fn xor(a: u32x16, b: u32x16) -> u32x16 {
a ^ b
}

#[inline(always)]
unsafe fn rot16(x: __m512i) -> __m512i {
_mm512_or_si512(_mm512_srli_epi32(x, 16), _mm512_slli_epi32(x, 32 - 16))
unsafe fn rot16(x: u32x16) -> u32x16 {
(x >> 16) | (x << (32 - 16))
}

#[inline(always)]
unsafe fn rot12(x: __m512i) -> __m512i {
_mm512_or_si512(_mm512_srli_epi32(x, 12), _mm512_slli_epi32(x, 32 - 12))
unsafe fn rot12(x: u32x16) -> u32x16 {
(x >> 12) | (x << (32 - 12))
}

#[inline(always)]
unsafe fn rot8(x: __m512i) -> __m512i {
_mm512_or_si512(_mm512_srli_epi32(x, 8), _mm512_slli_epi32(x, 32 - 8))
unsafe fn rot8(x: u32x16) -> u32x16 {
(x >> 8) | (x << (32 - 8))
}

#[inline(always)]
unsafe fn rot7(x: __m512i) -> __m512i {
_mm512_or_si512(_mm512_srli_epi32(x, 7), _mm512_slli_epi32(x, 32 - 7))
unsafe fn rot7(x: u32x16) -> u32x16 {
(x >> 7) | (x << (32 - 7))
}

#[inline(always)]
unsafe fn round(v: &mut [__m512i; 16], m: [__m512i; 16], r: usize) {
unsafe fn round(v: &mut [u32x16; 16], m: [u32x16; 16], r: usize) {
v[0] = add(v[0], m[SIGMA[r][0] as usize]);
v[1] = add(v[1], m[SIGMA[r][2] as usize]);
v[2] = add(v[2], m[SIGMA[r][4] as usize]);
Expand Down Expand Up @@ -269,10 +263,10 @@ unsafe fn round(v: &mut [__m512i; 16], m: [__m512i; 16], r: usize) {
v[4] = rot7(v[4]);
}

/// Transposes input chunks (16 chunks of 16 u32s each), to get 16 __m512i, each
/// Transposes input chunks (16 chunks of 16 u32s each), to get 16 u32x16, each
/// representing 16 packed instances of a message word.
/// # Safety
pub unsafe fn transpose_msgs(mut data: [__m512i; 16]) -> [__m512i; 16] {
pub unsafe fn transpose_msgs(mut data: [u32x16; 16]) -> [u32x16; 16] {
// Each _m512i chunk contains 16 u32 words.
// Index abcd:xyzw, refers to a specific word in data as follows:
// abcd - chunk index (in base 2)
Expand All @@ -281,56 +275,24 @@ pub unsafe fn transpose_msgs(mut data: [__m512i; 16]) -> [__m512i; 16] {
// abcd:xyzw => wabc:dxyz
// In other words, rotate the index to the right by 1.
for _ in 0..4 {
let (d0, d8) = data[0].deinterleave(data[1]);
let (d1, d9) = data[2].deinterleave(data[3]);
let (d2, d10) = data[4].deinterleave(data[5]);
let (d3, d11) = data[6].deinterleave(data[7]);
let (d4, d12) = data[8].deinterleave(data[9]);
let (d5, d13) = data[10].deinterleave(data[11]);
let (d6, d14) = data[12].deinterleave(data[13]);
let (d7, d15) = data[14].deinterleave(data[15]);
data = [
_mm512_permutex2var_epi32(data[0], EVENS_CONCAT_EVENS, data[1]),
_mm512_permutex2var_epi32(data[2], EVENS_CONCAT_EVENS, data[3]),
_mm512_permutex2var_epi32(data[4], EVENS_CONCAT_EVENS, data[5]),
_mm512_permutex2var_epi32(data[6], EVENS_CONCAT_EVENS, data[7]),
_mm512_permutex2var_epi32(data[8], EVENS_CONCAT_EVENS, data[9]),
_mm512_permutex2var_epi32(data[10], EVENS_CONCAT_EVENS, data[11]),
_mm512_permutex2var_epi32(data[12], EVENS_CONCAT_EVENS, data[13]),
_mm512_permutex2var_epi32(data[14], EVENS_CONCAT_EVENS, data[15]),
_mm512_permutex2var_epi32(data[0], ODDS_CONCAT_ODDS, data[1]),
_mm512_permutex2var_epi32(data[2], ODDS_CONCAT_ODDS, data[3]),
_mm512_permutex2var_epi32(data[4], ODDS_CONCAT_ODDS, data[5]),
_mm512_permutex2var_epi32(data[6], ODDS_CONCAT_ODDS, data[7]),
_mm512_permutex2var_epi32(data[8], ODDS_CONCAT_ODDS, data[9]),
_mm512_permutex2var_epi32(data[10], ODDS_CONCAT_ODDS, data[11]),
_mm512_permutex2var_epi32(data[12], ODDS_CONCAT_ODDS, data[13]),
_mm512_permutex2var_epi32(data[14], ODDS_CONCAT_ODDS, data[15]),
d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10, d11, d12, d13, d14, d15,
];
}
data
}

/// Transposes states, from 8 packed words, to get 16 results, each of size 32B.
/// # Safety
pub unsafe fn transpose_states(mut states: [__m512i; 8]) -> [__m512i; 8] {
// Each _m512i chunk contains 16 u32 words.
// Index abc:xyzw, refers to a specific word in data as follows:
// abc - chunk index (in base 2)
// xyzw - word offset (in base 2)
// Transpose by applying 3 times the index permutation:
// abc:xyzw => wab:cxyz
// In other words, rotate the index to the right by 1.
for _ in 0..3 {
states = [
_mm512_permutex2var_epi32(states[0], EVENS_CONCAT_EVENS, states[1]),
_mm512_permutex2var_epi32(states[2], EVENS_CONCAT_EVENS, states[3]),
_mm512_permutex2var_epi32(states[4], EVENS_CONCAT_EVENS, states[5]),
_mm512_permutex2var_epi32(states[6], EVENS_CONCAT_EVENS, states[7]),
_mm512_permutex2var_epi32(states[0], ODDS_CONCAT_ODDS, states[1]),
_mm512_permutex2var_epi32(states[2], ODDS_CONCAT_ODDS, states[3]),
_mm512_permutex2var_epi32(states[4], ODDS_CONCAT_ODDS, states[5]),
_mm512_permutex2var_epi32(states[6], ODDS_CONCAT_ODDS, states[7]),
];
}
states
}

/// Transposes states, from 8 packed words, to get 16 results, each of size 32B.
/// # Safety
pub unsafe fn untranspose_states(mut states: [__m512i; 8]) -> [__m512i; 8] {
pub unsafe fn untranspose_states(mut states: [u32x16; 8]) -> [u32x16; 8] {
// Each _m512i chunk contains 16 u32 words.
// Index abc:xyzw, refers to a specific word in data as follows:
// abc - chunk index (in base 2)
Expand All @@ -339,30 +301,25 @@ pub unsafe fn untranspose_states(mut states: [__m512i; 8]) -> [__m512i; 8] {
// abc:xyzw => bcx:yzwa
// In other words, rotate the index to the left by 1.
for _ in 0..3 {
states = [
_mm512_permutex2var_epi32(states[0], LHALF_INTERLEAVE_LHALF, states[4]),
_mm512_permutex2var_epi32(states[0], HHALF_INTERLEAVE_HHALF, states[4]),
_mm512_permutex2var_epi32(states[1], LHALF_INTERLEAVE_LHALF, states[5]),
_mm512_permutex2var_epi32(states[1], HHALF_INTERLEAVE_HHALF, states[5]),
_mm512_permutex2var_epi32(states[2], LHALF_INTERLEAVE_LHALF, states[6]),
_mm512_permutex2var_epi32(states[2], HHALF_INTERLEAVE_HHALF, states[6]),
_mm512_permutex2var_epi32(states[3], LHALF_INTERLEAVE_LHALF, states[7]),
_mm512_permutex2var_epi32(states[3], HHALF_INTERLEAVE_HHALF, states[7]),
];
let (d0, d1) = states[0].interleave(states[4]);
let (d2, d3) = states[1].interleave(states[5]);
let (d4, d5) = states[2].interleave(states[6]);
let (d6, d7) = states[3].interleave(states[7]);
states = [d0, d1, d2, d3, d4, d5, d6, d7];
}
states
}

/// Compress 16 blake2s instances.
/// # Safety
pub unsafe fn compress16(
h_vecs: [__m512i; 8],
msg_vecs: [__m512i; 16],
count_low: __m512i,
count_high: __m512i,
lastblock: __m512i,
lastnode: __m512i,
) -> [__m512i; 8] {
h_vecs: [u32x16; 8],
msg_vecs: [u32x16; 16],
count_low: u32x16,
count_high: u32x16,
lastblock: u32x16,
lastnode: u32x16,
) -> [u32x16; 8] {
let mut v = [
h_vecs[0],
h_vecs[1],
Expand Down Expand Up @@ -405,10 +362,11 @@ pub unsafe fn compress16(
]
}

#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
#[cfg(test)]
mod tests {
use super::{compress16, set1, transpose_msgs, transpose_states, untranspose_states};
use std::simd::u32x16;

use super::{compress16, set1, transpose_msgs, untranspose_states};
use crate::core::vcs::blake2s_ref::compress;

#[test]
Expand Down Expand Up @@ -440,4 +398,24 @@ mod tests {

assert_eq!(res_unvectorized, res_vectorized);
}

/// Transposes states, from 8 packed words, to get 16 results, each of size 32B.
/// # Safety
pub unsafe fn transpose_states(mut states: [u32x16; 8]) -> [u32x16; 8] {
// Each _m512i chunk contains 16 u32 words.
// Index abc:xyzw, refers to a specific word in data as follows:
// abc - chunk index (in base 2)
// xyzw - word offset (in base 2)
// Transpose by applying 3 times the index permutation:
// abc:xyzw => wab:cxyz
// In other words, rotate the index to the right by 1.
for _ in 0..3 {
let (s0, s4) = states[0].deinterleave(states[1]);
let (s1, s5) = states[2].deinterleave(states[3]);
let (s2, s6) = states[4].deinterleave(states[5]);
let (s3, s7) = states[6].deinterleave(states[7]);
states = [s0, s1, s2, s3, s4, s5, s6, s7];
}
states
}
}

0 comments on commit 9f342e8

Please sign in to comment.