Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML-DSA] Inline hash_functions.rs #701

Merged
merged 3 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion libcrux-ml-dsa/benches/bench_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub(crate) fn print_time(label: &str, d: std::time::Duration) {
println!("{label}:{space}{time}");
}

pub(crate) const ITERATIONS: usize = 100_000;
pub(crate) const ITERATIONS: usize = 10_000;
#[allow(unused)]
pub(crate) const WARMUP_ITERATIONS: usize = 1_000;

Expand Down
39 changes: 39 additions & 0 deletions libcrux-ml-dsa/src/hash_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ pub(crate) mod portable {
state3: KeccakState,
}

#[inline(always)]
fn init_absorb(input0: &[u8], input1: &[u8], input2: &[u8], input3: &[u8]) -> Shake128X4 {
let mut state0 = incremental::shake128_init();
incremental::shake128_absorb_final(&mut state0, &input0);
Expand All @@ -112,6 +113,7 @@ pub(crate) mod portable {
}
}

#[inline(always)]
fn squeeze_first_five_blocks(
state: &mut Shake128X4,
out0: &mut [u8; shake128::FIVE_BLOCKS_SIZE],
Expand All @@ -125,6 +127,7 @@ pub(crate) mod portable {
incremental::shake128_squeeze_first_five_blocks(&mut state.state3, out3);
}

#[inline(always)]
fn squeeze_next_block(
state: &mut Shake128X4,
) -> (
Expand All @@ -146,10 +149,12 @@ pub(crate) mod portable {
}

impl shake128::XofX4 for Shake128X4 {
#[inline(always)]
fn init_absorb(input0: &[u8], input1: &[u8], input2: &[u8], input3: &[u8]) -> Self {
init_absorb(input0, input1, input2, input3)
}

#[inline(always)]
fn squeeze_first_five_blocks(
&mut self,
out0: &mut [u8; shake128::FIVE_BLOCKS_SIZE],
Expand All @@ -159,6 +164,8 @@ pub(crate) mod portable {
) {
squeeze_first_five_blocks(self, out0, out1, out2, out3);
}

#[inline(always)]
fn squeeze_next_block(
&mut self,
) -> (
Expand All @@ -175,11 +182,13 @@ pub(crate) mod portable {
#[cfg_attr(hax, hax_lib::opaque_type)]
pub(crate) struct Shake128 {}

#[inline(always)]
fn shake128<const OUTPUT_LENGTH: usize>(input: &[u8], out: &mut [u8; OUTPUT_LENGTH]) {
libcrux_sha3::portable::shake128(out, input);
}

impl shake128::Xof for Shake128 {
#[inline(always)]
fn shake128<const OUTPUT_LENGTH: usize>(input: &[u8], out: &mut [u8; OUTPUT_LENGTH]) {
shake128(input, out);
}
Expand All @@ -191,41 +200,49 @@ pub(crate) mod portable {
state: KeccakState,
}

#[inline(always)]
fn shake256<const OUTPUT_LENGTH: usize>(input: &[u8], out: &mut [u8; OUTPUT_LENGTH]) {
libcrux_sha3::portable::shake256(out, input);
}

#[inline(always)]
fn init_absorb_shake256(input: &[u8]) -> Shake256 {
let mut state = incremental::shake256_init();
incremental::shake256_absorb_final(&mut state, input);
Shake256 { state }
}

#[inline(always)]
fn squeeze_first_block_shake256(state: &mut Shake256) -> [u8; shake256::BLOCK_SIZE] {
let mut out = [0u8; shake256::BLOCK_SIZE];
incremental::shake256_squeeze_first_block(&mut state.state, &mut out);
out
}

#[inline(always)]
fn squeeze_next_block_shake256(state: &mut Shake256) -> [u8; shake256::BLOCK_SIZE] {
let mut out = [0u8; shake256::BLOCK_SIZE];
incremental::shake256_squeeze_next_block(&mut state.state, &mut out);
out
}

impl shake256::Xof for Shake256 {
#[inline(always)]
fn shake256<const OUTPUT_LENGTH: usize>(input: &[u8], out: &mut [u8; OUTPUT_LENGTH]) {
shake256(input, out);
}

#[inline(always)]
fn init_absorb(input: &[u8]) -> Self {
init_absorb_shake256(input)
}

#[inline(always)]
fn squeeze_first_block(&mut self) -> [u8; shake256::BLOCK_SIZE] {
squeeze_first_block_shake256(self)
}

#[inline(always)]
fn squeeze_next_block(&mut self) -> [u8; shake256::BLOCK_SIZE] {
squeeze_next_block_shake256(self)
}
Expand All @@ -241,6 +258,8 @@ pub(crate) mod portable {
state2: libcrux_sha3::portable::KeccakState,
state3: libcrux_sha3::portable::KeccakState,
}

#[inline(always)]
fn init_absorb_x4(input0: &[u8], input1: &[u8], input2: &[u8], input3: &[u8]) -> Shake256X4 {
let mut state0 = incremental::shake256_init();
incremental::shake256_absorb_final(&mut state0, input0);
Expand All @@ -262,6 +281,7 @@ pub(crate) mod portable {
}
}

#[inline(always)]
fn squeeze_first_block_x4(
state: &mut Shake256X4,
) -> (
Expand All @@ -282,6 +302,7 @@ pub(crate) mod portable {
(out0, out1, out2, out3)
}

#[inline(always)]
fn squeeze_next_block_x4(
state: &mut Shake256X4,
) -> (
Expand All @@ -303,10 +324,12 @@ pub(crate) mod portable {
}

impl shake256::XofX4 for Shake256X4 {
#[inline(always)]
fn init_absorb_x4(input0: &[u8], input1: &[u8], input2: &[u8], input3: &[u8]) -> Self {
init_absorb_x4(input0, input1, input2, input3)
}

#[inline(always)]
fn squeeze_first_block_x4(
&mut self,
) -> (
Expand All @@ -318,6 +341,7 @@ pub(crate) mod portable {
squeeze_first_block_x4(self)
}

#[inline(always)]
fn squeeze_next_block_x4(
&mut self,
) -> (
Expand All @@ -329,6 +353,7 @@ pub(crate) mod portable {
squeeze_next_block_x4(self)
}

#[inline(always)]
fn shake256_x4<const OUT_LEN: usize>(
input0: &[u8],
input1: &[u8],
Expand Down Expand Up @@ -358,19 +383,26 @@ pub(crate) mod portable {

use libcrux_sha3::portable::incremental::{XofAbsorb, XofSqueeze};

#[inline(always)]
pub(crate) fn shake256_init() -> Shake256Absorb {
Shake256Absorb {
state: libcrux_sha3::portable::incremental::Shake256Absorb::new(),
}
}

#[inline(always)]
pub(crate) fn shake256_absorb(st: &mut Shake256Absorb, input: &[u8]) {
st.state.absorb(input)
}

#[inline(always)]
pub(crate) fn shake256_absorb_final(st: Shake256Absorb, input: &[u8]) -> Shake256Squeeze {
Shake256Squeeze {
state: st.state.absorb_final(input),
}
}

#[inline(always)]
pub(crate) fn shake256_squeeze(st: &mut Shake256Squeeze, out: &mut [u8]) {
st.state.squeeze(out)
}
Expand All @@ -393,12 +425,14 @@ pub(crate) mod simd256 {
}

/// Init the state and absorb 4 blocks in parallel.
#[inline(always)]
fn init_absorb(input0: &[u8], input1: &[u8], input2: &[u8], input3: &[u8]) -> Shake128x4 {
let mut state = x4::incremental::init();
x4::incremental::shake128_absorb_final(&mut state, &input0, &input1, &input2, &input3);
Shake128x4 { state }
}

#[inline(always)]
fn squeeze_first_five_blocks(
state: &mut Shake128x4,
out0: &mut [u8; shake128::FIVE_BLOCKS_SIZE],
Expand All @@ -415,6 +449,7 @@ pub(crate) mod simd256 {
);
}

#[inline(always)]
fn squeeze_next_block(
state: &mut Shake128x4,
) -> (
Expand Down Expand Up @@ -536,12 +571,14 @@ pub(crate) mod simd256 {
state: libcrux_sha3::avx2::x4::incremental::KeccakState,
}

#[inline(always)]
fn init_absorb_x4(input0: &[u8], input1: &[u8], input2: &[u8], input3: &[u8]) -> Shake256x4 {
let mut state = x4::incremental::init();
x4::incremental::shake256_absorb_final(&mut state, &input0, &input1, &input2, &input3);
Shake256x4 { state }
}

#[inline(always)]
fn squeeze_first_block_x4(
state: &mut Shake256x4,
) -> (
Expand All @@ -565,6 +602,7 @@ pub(crate) mod simd256 {
(out0, out1, out2, out3)
}

#[inline(always)]
fn squeeze_next_block_x4(
state: &mut Shake256x4,
) -> (
Expand All @@ -588,6 +626,7 @@ pub(crate) mod simd256 {
(out0, out1, out2, out3)
}

#[inline(always)]
fn shake256_x4<const OUT_LEN: usize>(
input0: &[u8],
input1: &[u8],
Expand Down
4 changes: 2 additions & 2 deletions libcrux-ml-dsa/src/matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ pub(crate) fn compute_As1_plus_s2<
s2: &[PolynomialRingElement<SIMDUnit>; ROWS_IN_A],
) -> [PolynomialRingElement<SIMDUnit>; ROWS_IN_A] {
let mut result = [PolynomialRingElement::<SIMDUnit>::ZERO(); ROWS_IN_A];
let s1_ntt = s1.map(|s| ntt::<SIMDUnit>(s));

for (i, row) in A_as_ntt.iter().enumerate() {
for (j, ring_element) in row.iter().enumerate() {
let product =
ntt_multiply_montgomery::<SIMDUnit>(ring_element, &ntt::<SIMDUnit>(s1[j]));
let product = ntt_multiply_montgomery::<SIMDUnit>(ring_element, &s1_ntt[j]);
result[i] = PolynomialRingElement::add(&result[i], &product);
}

Expand Down