From d706cfca5591befe968303a595ef2258a094d408 Mon Sep 17 00:00:00 2001 From: Tiago Oliveira Date: Mon, 11 Mar 2024 22:22:51 +0000 Subject: [PATCH] sync with https://github.com/formosa-crypto/hakyber/pull/32 --- .../mlkem/mlkem768/amd64/avx2/fips202.jinc | 333 ++----- .../mlkem/mlkem768/amd64/avx2/gen_matrix.jinc | 892 +++++++++++++++--- .../amd64/avx2/gen_matrix_globals.jinc | 287 ++++++ .../mlkem/mlkem768/amd64/avx2/indcpa.jinc | 48 +- .../amd64/avx2/keccak/keccakf1600.jinc | 169 ++++ .../keccak/keccakf1600_4x_avx2_compact.jinc | 331 +++++++ .../amd64/avx2/keccak/keccakf1600_avx2.jinc | 316 +++++++ .../avx2/keccak/keccakf1600_generic.jinc | 64 ++ .../mlkem/mlkem768/amd64/avx2/kem.jinc | 15 +- 9 files changed, 2052 insertions(+), 403 deletions(-) create mode 100644 src/crypto_kem/mlkem/mlkem768/amd64/avx2/gen_matrix_globals.jinc create mode 100644 src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600.jinc create mode 100644 src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600_4x_avx2_compact.jinc create mode 100644 src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600_avx2.jinc create mode 100644 src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600_generic.jinc diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/fips202.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/fips202.jinc index 178c5c02..741edb42 100644 --- a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/fips202.jinc +++ b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/fips202.jinc @@ -1,5 +1,5 @@ require "params.jinc" -require "keccakf1600.jinc" +require "keccak/keccakf1600.jinc" require "fips202_common.jinc" inline @@ -81,178 +81,18 @@ fn __add_final_block( return state; } - -inline -fn __xtr_full_block( - stack u64[25] state, - reg u64 out, - reg u64 outlen, - reg u64 rate -) -> reg u64, reg u64 -{ - reg u64 i t rate64; - - rate64 = rate; - rate64 >>= 3; - i = 0; - while (i < rate64) - { - t = state[(int) i]; - [out + 8 * i] = t; - i = i + 1; - } - - out += rate; - outlen -= rate; - - 
return out, outlen; -} - - -inline -fn __xtr_bytes( - stack u64[25] state, - reg u64 out, - reg u64 outlen -) -{ - reg u64 i t outlen8; - reg u8 c; - - outlen8 = outlen; - outlen8 >>= 3; - i = 0; - while (i < outlen8 ) - { - t = state[(int) i]; - [out + 8 * i] = t; - i = i + 1; - } - i <<= 3; - - while (i < outlen) - { - c = state[u8 (int) i]; - (u8)[out + i] = c; - i = i + 1; - } -} - - -inline -fn __keccak1600_scalar( - stack u64 s_out s_outlen, - reg u64 in inlen, - stack u64 s_trail_byte, - reg u64 rate -) +fn _isha3_256( + #spill_to_mmx reg ptr u8[32] out, + #spill_to_mmx reg u64 in inlen) + -> + reg ptr u8[32] { stack u64[25] state; - stack u64 s_in, s_inlen, s_rate; - reg u64 out, outlen, t; - reg u8 trail_byte; - - state = __st0(state); - - while ( inlen >= rate ) - { - state, in, inlen = __add_full_block(state, in, inlen, rate); - - s_in = in; - s_inlen = inlen; - s_rate = rate; - - state = _keccakf1600_(state); - - inlen = s_inlen; - in = s_in; - rate = s_rate; - } - - t = s_trail_byte; - trail_byte = (8u) t; - state = __add_final_block(state, in, inlen, trail_byte, rate); - - outlen = s_outlen; - - while ( outlen > rate ) - { - s_outlen = outlen; - s_rate = rate; - - state = _keccakf1600_(state); - - out = s_out; - outlen = s_outlen; - rate = s_rate; - - out, outlen = __xtr_full_block(state, out, outlen, rate); - s_outlen = outlen; - s_out = out; - } - - state = _keccakf1600_(state); - out = s_out; - outlen = s_outlen; - - __xtr_bytes(state, out, outlen); -} - - -#[returnaddress="stack"] -fn _shake256(reg u64 out outlen in inlen) -{ - stack u64 ds; - stack u64 rate; - - ds = 0x1f; - rate = SHAKE256_RATE; - - __keccak1600_scalar(out, outlen, in, inlen, ds, rate); -} - - -#[returnaddress="stack"] -fn _sha3_512(reg u64 out in inlen) -{ - reg u64 ds; - reg u64 rate; - reg u64 outlen; - - ds = 0x06; - rate = SHA3_512_RATE; - outlen = 64; - - __keccak1600_scalar(out, outlen, in, inlen, ds, rate); -} - - -#[returnaddress="stack"] -fn _sha3_256(reg u64 out in 
inlen) -{ - reg u64 ds; - reg u64 rate; - reg u64 outlen; - - ds = 0x06; - rate = SHA3_256_RATE; - outlen = 32; - - __keccak1600_scalar(out, outlen, in, inlen, ds, rate); -} - - -#[returnaddress="stack"] -fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] -{ - stack u64[25] state; - stack ptr u8[32] s_out; - stack u64 s_in s_ilen s_r8; - reg u64 ilen r8 t64; + #spill_to_mmx reg u64 ilen r8 t64; reg u8 t8; inline int i; - s_out = out; + () = #spill(out); state = __st0(state); @@ -263,15 +103,11 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] { state, in, ilen = __add_full_block(state, in, ilen, r8); - s_in = in; - s_ilen = ilen; - s_r8 = r8; + () = #spill(in, ilen, r8); state = _keccakf1600_(state); - in = s_in; - ilen = s_ilen; - r8 = s_r8; + () = #unspill(in, ilen, r8); } t8 = 0x06; @@ -279,74 +115,26 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] state = _keccakf1600_(state); - out = s_out; + () = #unspill(out); for i=0 to 4 - { - t64 = state[i]; + { t64 = state[i]; out[u64 i] = t64; } return out; } -inline -fn __isha3_512(reg ptr u8[64] out, reg u64 in, inline int inlen) -> stack u8[64] + +fn _shake256_1120_32(#spill_to_mmx reg u64 out in0 in1) { stack u64[25] state; - stack ptr u8[64] s_out; - stack u64 s_in s_ilen s_r8; - reg u64 ilen r8 t64; + #spill_to_mmx reg u64 ilen r8 t64; reg u8 t8; inline int i; - s_out = out; - - state = __st0(state); - - r8 = SHA3_512_RATE; - ilen = inlen; - - while(ilen >= r8) - { - state, in, ilen = __add_full_block(state, in, ilen, r8); - - s_in = in; - s_ilen = ilen; - s_r8 = r8; - - state = _keccakf1600_(state); - - in = s_in; - ilen = s_ilen; - r8 = s_r8; - } - - t8 = 0x06; - state = __add_final_block(state, in, ilen, t8, r8); - - state = _keccakf1600_(state); - - out = s_out; - - for i=0 to 8 - { - t64 = state[i]; - out[u64 i] = t64; - } - - return out; -} - -fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { - stack u64[25] state; - stack u64 s_out; - 
stack u64 s_in s_ilen s_r8; - reg u64 ilen r8 t64 in; - reg u8 t8; - inline int i; + () = #spill(out); - s_out = out; state = __st0(state); for i = 0 to MLKEM_SYMBYTES/8 { @@ -359,36 +147,33 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { state[u64 i] ^= t64; } - s_in = in1; + () = #spill(in1); state = _keccakf1600_(state); + () = #unspill(in1); + r8 = SHAKE256_RATE; ilen = MLKEM_INDCPA_CIPHERTEXTBYTES - (SHAKE256_RATE - MLKEM_SYMBYTES); - in = s_in; - in += SHAKE256_RATE - MLKEM_SYMBYTES; + in1 += SHAKE256_RATE - MLKEM_SYMBYTES; while(ilen >= r8) { - state, in, ilen = __add_full_block(state, in, ilen, r8); + state, in1, ilen = __add_full_block(state, in1, ilen, r8); - s_in = in; - s_ilen = ilen; - s_r8 = r8; + () = #spill(in1, ilen, r8); state = _keccakf1600_(state); - in = s_in; - ilen = s_ilen; - r8 = s_r8; + () = #unspill(in1, ilen, r8); } t8 = 0x1f; - state = __add_final_block(state, in, ilen, t8, r8); + state = __add_final_block(state, in1, ilen, t8, r8); state = _keccakf1600_(state); - out = s_out; + () = #unspill(out); for i=0 to MLKEM_SYMBYTES/8 { @@ -398,17 +183,18 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { } -#[returnaddress="stack"] -fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[128] +fn _shake256_128_33( + #spill_to_mmx reg ptr u8[128] out, + reg const ptr u8[33] in) + -> + stack u8[128] { stack u64[25] state; reg u64 t64; reg u8 c; inline int i; - stack ptr u8[128] sout; - - sout = out; + () = #spill(out); state = __st0(state); @@ -424,7 +210,7 @@ fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[12 state = _keccakf1600_(state); - out = sout; + () = #unspill(out); for i = 0 to 16 { t64 = state[u64 i]; @@ -434,15 +220,17 @@ fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[12 return out; } -#[returnaddress="stack"] -fn _isha3_256_32(reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u8[32] +fn _isha3_256_32( + #spill_to_mmx reg ptr 
u8[32] out, + reg ptr u8[MLKEM_SYMBYTES] in) + -> + reg ptr u8[32] { stack u64[25] state; - stack ptr u8[32] s_out; reg u64 t64; inline int i; - s_out = out; + () = #spill(out); state = __st0(state); @@ -457,7 +245,7 @@ fn _isha3_256_32(reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u state = _keccakf1600_(state); - out = s_out; + () = #unspill(out); for i=0 to 4 { @@ -468,14 +256,18 @@ fn _isha3_256_32(reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u return out; } -#[returnaddress="stack"] -fn _sha3_512_64(reg ptr u8[64] out, reg const ptr u8[64] in) -> stack u8[64] +fn _sha3_512_64( + #spill_to_mmx reg ptr u8[64] out, + reg const ptr u8[64] in) + -> + reg ptr u8[64] { stack u64[25] state; - stack ptr u8[64] out_s; reg u64 t64; inline int i; + () = #spill(out); + state = __st0(state); for i = 0 to 8 @@ -487,11 +279,9 @@ fn _sha3_512_64(reg ptr u8[64] out, reg const ptr u8[64] in) -> stack u8[64] state[u8 64] ^= 0x06; state[u8 SHA3_512_RATE - 1] ^= 0x80; - out_s = out; - state = _keccakf1600_(state); - out = out_s; + () = #unspill(out); for i = 0 to 8 { @@ -502,14 +292,18 @@ fn _sha3_512_64(reg ptr u8[64] out, reg const ptr u8[64] in) -> stack u8[64] return out; } -#[returnaddress="stack"] -fn _sha3_512_32(reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64] +fn _sha3_512_32( + #spill_to_mmx reg ptr u8[64] out, + reg const ptr u8[32] in) + -> + reg ptr u8[64] { stack u64[25] state; - stack ptr u8[64] out_s; reg u64 t64; inline int i; + () = #spill(out); + state = __st0(state); for i = 0 to 4 @@ -520,12 +314,10 @@ fn _sha3_512_32(reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64] state[u8 32] ^= 0x06; state[u8 SHA3_512_RATE-1] ^= 0x80; - - out_s = out; state = _keccakf1600_(state); - out = out_s; + () = #unspill(out); for i = 0 to 8 { t64 = state[i]; @@ -559,16 +351,21 @@ fn _shake128_absorb34(reg ptr u64[25] state, reg const ptr u8[34] in) -> reg ptr return state; } -#[returnaddress="stack"] -fn 
_shake128_squeezeblock(reg ptr u64[25] state, reg ptr u8[SHAKE128_RATE] out) -> reg ptr u64[25], reg ptr u8[SHAKE128_RATE] +fn _shake128_squeezeblock( + reg ptr u64[25] state, + #spill_to_mmx reg ptr u8[SHAKE128_RATE] out) + -> + reg ptr u64[25], + reg ptr u8[SHAKE128_RATE] { - stack ptr u8[SHAKE128_RATE] out_s; reg u64 t; inline int i; - out_s = out; + () = #spill(out); + state = _keccakf1600_(state); - out = out_s; + + () = #unspill(out); for i = 0 to SHAKE128_RATE/8 { diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/gen_matrix.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/gen_matrix.jinc index 59e3b518..54fe2748 100644 --- a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/gen_matrix.jinc +++ b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/gen_matrix.jinc @@ -1,130 +1,802 @@ +require "keccak/keccakf1600_4x_avx2_compact.jinc" +require "keccak/keccakf1600_avx2.jinc" require "params.jinc" -require "shuffle.jinc" -require "fips202.jinc" -require "params.jinc" +require "gen_matrix_globals.jinc" -inline -fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE] buf) -> reg u64, stack u16[MLKEM_N] +// a < b && c < d +inline fn comp_u64_l_int_and_u64_l_int( + reg u64 a, + inline int b, + reg u64 c, + inline int d) + -> + reg bool { - reg u16 val1 val2; - reg u16 t; - reg u64 pos ctr; - - - ctr = offset; - pos = 0; - - while (pos < SHAKE128_RATE - 2) { - if ctr < MLKEM_N { - val1 = (16u)buf[pos]; - t = (16u)buf[pos + 1]; - val2 = t; - val2 >>= 4; - t &= 0x0F; - t <<= 8; - val1 |= t; - - t = (16u)buf[pos + 2]; - t <<= 4; - val2 |= t; - pos += 3; - - reg bool cond; - #[declassify] - cond = val1 < MLKEM_Q; - if cond { - rp[ctr] = val1; - ctr += 1; - } - - #[declassify] - cond = val2 < MLKEM_Q; - if cond { - if(ctr < MLKEM_N) - { - rp[ctr] = val2; - ctr += 1; - } - } - } else { - pos = SHAKE128_RATE; - } + reg bool c1 c2 c3; + reg u8 bc1 bc2; + + ?{ " if(a if(a bc1 & bc2 == 0 => cond = false + // zf == 0 => bc1 & bc2 == 1 => cond = true + ?{ "!=" = c3 } = 
#TEST_8(bc1, bc2); + + return c3; +} + +// BUF_size per entry: 21(rate) + 21(rate) + 25(keccak_state) + 1(pad) +param int BUF_size = (3 * 21) + 4 + 1; //5; + +// deinterleave u64-lanes of 4 u256 regs +fn _4u64x4_u256x4(reg u256 y0 y1 y2 y3) -> reg u256, reg u256, reg u256, reg u256 { + reg u256 x0, x1, x2, x3; + x0 = #VPERM2I128(y0, y2, 0x20); + x1 = #VPERM2I128(y1, y3, 0x20); + x2 = #VPERM2I128(y0, y2, 0x31); + x3 = #VPERM2I128(y1, y3, 0x31); + + y0 = #VPUNPCKL_4u64(x0, x1); // y0 = l00 l01 l02 l03 + y1 = #VPUNPCKH_4u64(x0, x1); // y1 = l10 l11 l12 l13 + y2 = #VPUNPCKL_4u64(x2, x3); // y2 = l20 l21 l22 l23 + y3 = #VPUNPCKH_4u64(x2, x3); // y3 = l30 l31 l32 l33 + + return y0, y1, y2, y3; +} + +// extracts 4 keccak states (st25) from a 4-way state (st4x) +inline fn __st4x_unpack_at +( reg mut ptr u64[4*BUF_size] buf +, reg const ptr u256[25] st4x +, reg u64 offset // in bytes +) -> reg ptr u64[4*BUF_size] { + inline int i; + reg u256 x0, x1, x2, x3; + reg u64 t0, t1, t2, t3; + for i = 0 to 6 { + x0 = st4x[u256 4*i+0]; + x1 = st4x[u256 4*i+1]; + x2 = st4x[u256 4*i+2]; + x3 = st4x[u256 4*i+3]; + x0, x1, x2, x3 = _4u64x4_u256x4(x0, x1, x2, x3); + buf.[u256 offset + 4*8*i + 0*8*BUF_size] = x0; + buf.[u256 offset + 4*8*i + 1*8*BUF_size] = x1; + buf.[u256 offset + 4*8*i + 2*8*BUF_size] = x2; + buf.[u256 offset + 4*8*i + 3*8*BUF_size] = x3; + } + t0 = st4x[u64 4*24+0]; + t1 = st4x[u64 4*24+1]; + t2 = st4x[u64 4*24+2]; + t3 = st4x[u64 4*24+3]; + buf.[u64 offset + 8*24 + 0*8*BUF_size] = t0; + buf.[u64 offset + 8*24 + 1*8*BUF_size] = t1; + buf.[u64 offset + 8*24 + 2*8*BUF_size] = t2; + buf.[u64 offset + 8*24 + 3*8*BUF_size] = t3; + + return buf; +} + +inline fn __stavx2_pack_at +( reg const ptr u64[BUF_size] st +, reg u64 offset // in bytes +) -> reg u256[7] { + // 3*r256 (evitáveis...) 
+ reg u256[7] state; + reg u256 t256_0 t256_1 t256_2; + reg u128 t128_0, t128_1; + reg u64 r; + + // [ 0 0 0 0 ] + state[0] = #VPBROADCAST_4u64(st.[u64 8*0 + offset]); + // [ 1 2 3 4 ] + state[1] = st.[u256 1*8 + offset]; + // [ 5 - ] + t128_0 = #VMOV(st.[u64 5*8 + offset]); + // [ 6 7 8 9 ] + state[3] = st.[u256 6*8 + offset]; + // [ 10 - ] + t128_1 = #VMOV(st.[u64 10*8 + offset]); + // [ 11 12 13 14 ] + state[4] = st.[u256 11*8 + offset]; + // [ 5 15 ] + r = st.[u64 15*8 + offset]; + t128_0 = #VPINSR_2u64(t128_0, r, 1); + // [ 16 17 18 19 ] + state[5] = st.[u256 16*8 + offset]; + // [ 10 20 ] + r = st.[u64 20*8 + offset]; + t128_1 = #VPINSR_2u64(t128_1, r, 1); + // alternative not currently supported: VPGATHERDQ for filling state[2] + // [ 10 20 5 15 ] + state[2] = (2u128)[t128_0, t128_1]; + // [ 21 22 23 24 ] + state[6] = st.[u256 21*8 + offset]; + + // [ 16 7 8 19 ] + t256_0 = #VPBLEND_8u32(state[3], state[5], (8u1)[1,1,0,0,0,0,1,1]); + // [ 11 22 23 14 ] + t256_1 = #VPBLEND_8u32(state[6], state[4], (8u1)[1,1,0,0,0,0,1,1]); + // [ 6 12 13 9 ] + t256_2 = #VPBLEND_8u32(state[4], state[3], (8u1)[1,1,0,0,0,0,1,1]); + + // [ 16 7 23 14 ] + state[3] = #VPBLEND_8u32(t256_0, t256_1, (8u1)[1,1,1,1,0,0,0,0]); + // [ 11 22 8 19 ] + state[4] = #VPBLEND_8u32(t256_1, t256_0, (8u1)[1,1,1,1,0,0,0,0]); + + // [ 21 17 18 24 ] + t256_0 = #VPBLEND_8u32(state[5], state[6], (8u1)[1,1,0,0,0,0,1,1]); + + // [ 21 17 13 9 ] + state[5] = #VPBLEND_8u32(t256_0, t256_2, (8u1)[1,1,1,1,0,0,0,0]); + // [ 6 12 18 24 ] + state[6] = #VPBLEND_8u32(t256_2, t256_0, (8u1)[1,1,1,1,0,0,0,0]); + + // [ 0 0 0 0 ] + // [ 1 2 3 4 ] + // [ 10 20 5 15 ] + // [ 16 7 23 14 ] + // [ 11 22 8 19 ] + // [ 21 17 13 9 ] + // [ 6 12 18 24 ] + return state; +} + +fn _stavx2_pack_at +( reg const ptr u64[BUF_size] st +, reg u64 offset // in bytes +) -> reg u256[7] { + reg u256[7] stavx2; + stavx2 = __stavx2_pack_at(st, offset); + return stavx2; +} + +inline fn __stavx2_unpack_at +( reg mut ptr u64[BUF_size] buf +, reg 
u64 offset // in bytes +, reg u256[7] state +) -> reg ptr u64[BUF_size] { + // 5*r256 + 2*r128(evitáveis) (+7*r256) + reg u256 t256_0 t256_1 t256_2 t256_3 t256_4; + reg u128 t128_0, t128_1; + + // [ 0, 0 ] + t128_0 = (128u) state[0]; + buf.[u64 0*8 + offset] = #VMOVLPD(t128_0); + // [ 1, 2, 3, 4 ] + buf.[u256 1*8 + offset] = state[1]; + + // [ 16, 7, 8, 19 ] + t256_0 = #VPBLEND_8u32(state[3], state[4], (8u1)[1,1,1,1,0,0,0,0]); + // [ 11, 22, 23, 14 ] + t256_1 = #VPBLEND_8u32(state[4], state[3], (8u1)[1,1,1,1,0,0,0,0]); + // [ 21, 17, 18, 24 ] + t256_2 = #VPBLEND_8u32(state[5], state[6], (8u1)[1,1,1,1,0,0,0,0]); + // [ 6, 12, 13, 9 ] + t256_3 = #VPBLEND_8u32(state[6], state[5], (8u1)[1,1,1,1,0,0,0,0]); + + // [ 5, 15 ] + t128_1 = #VEXTRACTI128(state[2], 1); + buf.[u64 5*8 + offset] = #VMOVLPD(t128_1); + + // [ 6, 7, 8, 9 ] + t256_4 = #VPBLEND_8u32(t256_0, t256_3, (8u1)[1,1,0,0,0,0,1,1]); + buf.[u256 6*8 + offset] = t256_4; + + // [ 10, 20 ] + t128_0 = (128u) state[2]; + buf.[u64 10*8 + offset] = #VMOVLPD(t128_0); + + // [ 11, 12, 13, 14 ] + t256_4 = #VPBLEND_8u32(t256_3, t256_1, (8u1)[1,1,0,0,0,0,1,1]); + buf.[u256 11*8 + offset] = t256_4; + + // [ 15 ] + buf.[u64 15*8 + offset] = #VMOVHPD(t128_1); + + // [ 16, 17, 18, 19 ] + t256_4 = #VPBLEND_8u32(t256_2, t256_0, (8u1)[1,1,0,0,0,0,1,1]); + buf.[u256 16*8 + offset] = t256_4; + + // [ 20 ] + buf.[u64 20*8 + offset] = #VMOVHPD(t128_0); + + // [ 21, 22, 23, 24 ] + t256_4 = #VPBLEND_8u32(t256_1, t256_2, (8u1)[1,1,0,0,0,0,1,1]); + buf.[u256 21*8 + offset] = t256_4; + + return buf; +} + +fn _stavx2_unpack_at +( reg mut ptr u64[BUF_size] buf +, reg u64 offset // in bytes +, reg u256[7] state +) -> reg ptr u64[BUF_size] { + buf = __stavx2_unpack_at(buf, offset, state); + return buf; +} + +u256[1] RATE_BIT_x4 = +{ (4u64)[0x8000000000000000, 0x8000000000000000, 0x8000000000000000, 0x8000000000000000] }; + +// xof related code +inline fn xof_init_x4 +( reg const ptr u8[32] rho +, reg u16[4] indexes +) -> stack u256[25] +{ + 
stack u256[25] state; + stack u256[1] temp; + inline int i; + reg u256 t; + reg u64 r; + + // copy rho to state + for i = 0 to 4 { + t = #VPBROADCAST_4u64(rho[u64 i]); + state[i] = t; + } + + for i=0 to 4 + { r = (64u) indexes[i]; + r |= 0x1F0000; + temp[u64 i] = r; + } + t = temp[0]; + state[4] = t; + + // init to zero + t = #set0_256(); + for i=5 to 25 { state[i] = t; } + + t = RATE_BIT_x4[0]; + t ^= state[20]; + state[20] = t; + + return state; +} + +inline fn xof_init_avx2 +( reg const ptr u8[32] rho +, reg u16 index +) -> reg u256[7] +{ + inline int i; + stack u256[1] temp; + reg u256[7] state; + reg u256 t; + reg u128 t128; + reg u64 r; + + // copy rho to state + state[0] = #VPBROADCAST_4u64(rho[u64 0]); + r = rho[u64 1]; + temp[u64 0] = r; + r = rho[u64 2]; + temp[u64 1] = r; + r = rho[u64 3]; + temp[u64 2] = r; + r = (64u) index; + r |= 0x1F0000; + temp[u64 3] = r; + state[1] = temp[0]; + + t = #set0_256(); + t128 = (128u) t; + r = 0x8000000000000000; + state[2] = (256u) #VPINSR_2u64(t128, r, 1); + + for i=3 to 7 { state[i] = t; } + + return state; +} + +/* +DEFS: +a \lmatch l == l is_prefix_of (to_list a) +bytes2coefs: W8.t list -> int list + == converte lista de bytes em lista de coefs +PARAMS: lpol, offset, lbuf + +@requires: + pol \lmatch lpol + size lpol = to_uint counter + size lpol <= MLKEM_N - 32 + to_uint buf_offset = offset + to_list buf = lbuf + 0 <= offset <= BUF_size - (48 + 8) +@ensures: + let lcoefs = filter ( reg ptr u16[MLKEM_N], reg u64 +{ + reg u256 f0 f1 g0 g1; + reg u256 shuffle_0 shuffle_1 shuffle_t; + reg u128 shuffle_0_1 shuffle_1_1; + reg u64 good t0_0 t0_1 t1_0 t1_1; + + // loads 24 bytes (while touching 32 bytes of memory) into f0 and another + // 24 bytes into f1 while doing some rearrangements: + // - consider that the memory contains the following 32 bytes (in u64s) + // - 0x01aaaaaaaaaaaa08, 0x01bbbbbbbbbbbb08, 0x01cccccccccccc08, 0x01dddddddddddd08 + // - the command given to vpermq is 0x94, or (8u1)[1,0,0,1, 0,1,0,0], or 
(4u2)[2,1,1,0] + // - so the last 8 bytes will be discarded: + // - 0x01aaaaaaaaaaaa08, 0x01bbbbbbbbbbbb08, 0x01bbbbbbbbbbbb08, 0x01cccccccccccc08 + + f0 = #VPERMQ(buf.[u256 (int) buf_offset + 0 ], (4u2)[2,1,1,0]); + f1 = #VPERMQ(buf.[u256 (int) buf_offset + 24], (4u2)[2,1,1,0]); + + // next, the data is shuffled at byte level. For a given state (in u64s): + // - 0xa8a7a6a5a4a3a2a1, 0xb8b7b6b5b4b3b2b1, 0xc8c7c6c5c4c3c2c1, 0xd8d7d6d5d4d3d2d1 + // f's get rearranged into: + // - 0xa6a5a5a4a3a2a2a1, 0xb4b3b3b2b1a8a8a7, 0xd2d1d1c8c7c6c6c5, 0xd8d7d7d6d5d4d4d3 + + f0 = #VPSHUFB_256(f0, load_shuffle); + f1 = #VPSHUFB_256(f1, load_shuffle); + + // next, a shift right by 4 (u16) is performed, for a given state: + // (consider that c's hold the same values as b's ++ some underscores to help the reading) + // + // - 0xa6a5_a5a4_a3a2_a2a1, 0xb4b3_b3b2_b1a8_a8a7, 0xd2d1_d1c8_c7c6_c6c5, 0xd8d7_d7d6_d5d4_d4d3 + // to: + // - 0x0a6a_0a5a_0a3a_0a2a, 0x0b4b_0b3b_0b1a_0a8a, 0x0d2d_0d1c_0c7c_0c6c, 0x0d8d_0d7d_0d5d_0d4d + + g0 = #VPSRL_16u16(f0, 4); + g1 = #VPSRL_16u16(f1, 4); + + // next, blend. + // from: + // - 0xAA (1010 1010 in binary) + // + // bottom top b t b t b t + // 1 0 1 0 1 0 1 0 (same for next 128-bit lane) + // - 0xa6a5_a5a4_a3a2_a2a1, 0xb4b3_b3b2_b1a8_a8a7, 0xd2d1_d1c8_c7c6_c6c5, 0xd8d7_d7d6_d5d4_d4d3 + // - 0x0a6a_0a5a_0a3a_0a2a, 0x0b4b_0b3b_0b1a_0a8a, 0x0d2d_0d1c_0c7c_0c6c, 0x0d8d_0d7d_0d5d_0d4d + // to: + // - 0x0a6a_a5a4_0a3a_a2a1, 0x0b4b_b3b2_0b1a_a8a7, 0x0d2d_d1c8_0c7c_c6c5, 0x0d8d_d7d6_0d5d_d4d3 + + f0 = #VPBLEND_16u16(f0, g0, 0xAA); + f1 = #VPBLEND_16u16(f1, g1, 0xAA); + + // next, mask at 12 bits (0xFFF) + // from: + // - 0x0a6a_a5a4_0a3a_a2a1, 0x0b4b_b3b2_0b1a_a8a7, 0x0d2d_d1c8_0c7c_c6c5, 0x0d8d_d7d6_0d5d_d4d3 + // to: + // - 0x0a6a_05a4_0a3a_02a1, 0x0b4b_03b2_0b1a_08a7, 0x0d2d_01c8_0c7c_06c5, 0x0d8d_07d6_0d5d_04d3 + + f0 = #VPAND_256(f0, mask); + f1 = #VPAND_256(f1, mask); + + // KYBER_Q is 3329 or 0xd01 + // + // bounds: + // - 0x0d01_0d01_0d01_0d01, ... 
+ // + // some input: + // - 0x0a6a_05a4_0a3a_02a1, 0x0b4b_03b2_0b1a_08a7, 0x0d2d_01c8_0c7c_06c5, 0x0d8d_07d6_0d5d_04d3 + // + // output (the 'good' results are highlighted with Fs; what about when equal to 3329?) + // - 0xffff_ffff_ffff_ffff, 0xffff_ffff_ffff_ffff, 0x0000_ffff_ffff_ffff, 0x0000_ffff_0000_ffff + // + // intuitively, for i=0 to 15: if bounds[i] > input[i] then 0xffff else 0x0 + g0 = #VPCMPGT_16u16(bounds, f0); + g1 = #VPCMPGT_16u16(bounds, f1); + + // from Intel intrinsics: "Convert packed signed 16-bit integers from a and b to packed 8-bit integers using signed saturation" + // intuitively, each u16 ffff -> ff and 0000 -> 00 + // g0 = g0[0..7] || g1[0..7] || g0[8..15] || g1[8..15], where each u16 "goes to" u8 + g0 = #VPACKSS_16u16(g0, g1); + + // from Intel intrinsics: "Create mask from the most significant bit of each 8-bit element in a, and store the result in dst." + good = #VPMOVMSKB_u256u64(g0); + + good = #protect(good, ms); + + // at this point, the bit count of good contains the number of 'good' elements + + // g0 + t0_0 = good; + t0_0 &= 0xFF; // g0[0..7] + + shuffle_0 = (256u) #VMOV(sst[u64 (int)t0_0]); + ?{}, t0_0 = #POPCNT_64(t0_0); + t0_0 += counter; + + t0_1 = good; + t0_1 >>= 16; + t0_1 &= 0xFF; // g0[8..15] + shuffle_0_1 = #VMOV(sst[u64 (int)t0_1]); + ?{}, t0_1 = #POPCNT_64(t0_1); + t0_1 += t0_0; + + // g1 + t1_0 = good; + t1_0 >>= 8; + t1_0 &= 0xFF; // g1[0..7] + shuffle_1 = (256u) #VMOV(sst[u64 (int)t1_0]); + ?{}, t1_0 = #POPCNT_64(t1_0); + t1_0 += t0_1; + + t1_1 = good; + t1_1 >>= 24; + t1_1 &= 0xFF; // g1[8..15] + shuffle_1_1 = #VMOV(sst[u64 (int)t1_1]); + ?{}, t1_1 = #POPCNT_64(t1_1); + t1_1 += t1_0; + + // + + shuffle_0 = #VINSERTI128(shuffle_0, shuffle_0_1, 1); + shuffle_1 = #VINSERTI128(shuffle_1, shuffle_1_1, 1); + + // + + shuffle_t = #VPADD_32u8(shuffle_0, ones); + shuffle_0 = #VPUNPCKL_32u8(shuffle_0, shuffle_t); + + shuffle_t = #VPADD_32u8(shuffle_1, ones); + shuffle_1 = #VPUNPCKL_32u8(shuffle_1, shuffle_t); + + f0 = 
#VPSHUFB_256(f0, shuffle_0); + f1 = #VPSHUFB_256(f1, shuffle_1); + + // + + pol.[u128 2*counter] = (128u)f0; + pol.[u128 2*t0_0] = #VEXTRACTI128(f0, 1); + pol.[u128 2*t0_1] = (128u)f1; + pol.[u128 2*t1_0] = #VEXTRACTI128(f1, 1); + + counter = t1_1; + + return pol, counter; +} + +// safe-write (ensured to write inside the array...) +inline fn __write_u128_boundchk +( reg mut ptr u16[MLKEM_N] pol +, reg u64 ctr +, reg u128 data +, #msf reg u64 ms +) -> reg ptr u16[MLKEM_N], reg u64, #msf reg u64 +{ + reg u64 data_u64; + reg bool condition_8 condition_4 condition_2 condition_1; + + condition_8 = (ctr <= MLKEM_N-8); + if ( condition_8 ) { + ms = #update_msf(condition_8, ms); + + pol.[u128 2*(int)ctr] = data; + ctr += 8; + } else + { + ms = #update_msf(!condition_8, ms); + + data_u64 = #MOVV(data); + + condition_4 = (ctr <= MLKEM_N-4); + if ( condition_4 ) { + ms = #update_msf(condition_4, ms); + + pol.[u64 2*(int)ctr] = data_u64; + data_u64 = #VPEXTR_64(data, 1); + ctr += 4; + } else + { ms = #update_msf(!condition_4, ms); } + + condition_2 = (ctr <= MLKEM_N-2); + if ( condition_2 ) { + ms = #update_msf(condition_2, ms); + + pol.[u32 2*(int)ctr] = (32u) data_u64; + data_u64 >>= 32; + ctr += 2; + } else + { ms = #update_msf(!condition_2, ms); } + + condition_1 = (ctr <= MLKEM_N-1); + if ( condition_1 ) { + ms = #update_msf(condition_1, ms); + + pol.[u16 2*(int)ctr] = (16u) data_u64; + ctr += 1; + } else + { ms = #update_msf(!condition_1, ms); } } - return ctr, rp; + return pol, ctr, ms; } -inline -fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[MLKEM_K*MLKEM_VECN] +inline fn __gen_matrix_buf_rejection_filter24 +( reg mut ptr u16[MLKEM_N] pol +, reg u64 counter +, reg const ptr u64[BUF_size] buf +, reg u64 buf_offset // in bytes + +, reg u256 load_shuffle mask bounds +, reg ptr u8[2048] sst +, reg u256 ones +, #msf reg u64 ms +) -> reg ptr u16[MLKEM_N], reg u64, #msf reg u64 +{ + reg u256 f0 g0 g1; + reg u256 shuffle_0 shuffle_t; + reg 
u128 shuffle_0_1 t128; + reg u64 good t0_0 t0_1; + + f0 = #VPERMQ(buf.[u256 (int) buf_offset + 0 ], (4u2)[2,1,1,0]); + f0 = #VPSHUFB_256(f0, load_shuffle); + g0 = #VPSRL_16u16(f0, 4); + f0 = #VPBLEND_16u16(f0, g0, 0xAA); + f0 = #VPAND_256(f0, mask); + g0 = #VPCMPGT_16u16(bounds, f0); + g1 = #set0_256(); + g0 = #VPACKSS_16u16(g0, g1); + good = #VPMOVMSKB_u256u64(g0); + + good = #protect(good, ms); + + // g0 + t0_0 = good; + t0_0 &= 0xFF; // g0[0..7] + shuffle_0 = (256u) #VMOV(sst[u64 (int)t0_0]); + ?{}, t0_0 = #POPCNT_64(t0_0); + t0_0 += counter; + + t0_1 = good; + t0_1 >>= 16; + t0_1 &= 0xFF; // g0[8..15] + shuffle_0_1 = #VMOV(sst[u64 (int)t0_1]); + ?{}, t0_1 = #POPCNT_64(t0_1); + t0_1 += t0_0; + + // + shuffle_0 = #VINSERTI128(shuffle_0, shuffle_0_1, 1); + shuffle_t = #VPADD_32u8(shuffle_0, ones); + shuffle_0 = #VPUNPCKL_32u8(shuffle_0, shuffle_t); + f0 = #VPSHUFB_256(f0, shuffle_0); + // + + t128 = (128u) f0; + pol, counter, ms = __write_u128_boundchk(pol, counter, t128, ms); + + t128 = #VEXTRACTI128(f0, 1); + pol, counter, ms = __write_u128_boundchk(pol, t0_0, t128, ms); + + counter = t0_1; + + return pol, counter, ms; +} + + +fn _gen_matrix_buf_rejection +( reg mut ptr u16[MLKEM_N] pol // polynomial +, reg u64 counter // number of coefs. already sampled +, reg const ptr u64[BUF_size] buf // whole buffer (size=21+21+25 (+1 pad)) +, reg u64 buf_offset // start looking at... (bytes) + +) -> reg ptr u16[MLKEM_N], reg u64 // pol. 
and counter { - stack u8[34] extseed; - stack u8[SHAKE128_RATE] buf; - stack u64[25] state; - stack u16[MLKEM_N] poly; - stack u16[MLKEM_K*MLKEM_VECN] r; - - reg u8 c; - reg u16 t; - reg u64 ctr k; - stack u64 sctr; - stack u64 stransposed; - inline int j i; - - stransposed = transposed; - - for j = 0 to MLKEM_SYMBYTES + reg bool condition_loop; + reg ptr u8[2048] sst; + reg u256 load_shuffle mask bounds ones; + #msf reg u64 ms; + + ms = #init_msf(); + + load_shuffle = sample_load_shuffle[u256 0]; + mask = sample_mask; + bounds = sample_q; + ones = sample_ones; + sst = sample_shuffle_table; + + buf_offset = buf_offset; + + while { condition_loop = comp_u64_l_int_and_u64_l_int(buf_offset, 3*168-48+1, counter, MLKEM_N-32+1); } + ( condition_loop ) + { + ms = #update_msf(condition_loop, ms); + + pol, counter = __gen_matrix_buf_rejection_filter48(pol, counter, buf, buf_offset, load_shuffle, mask, bounds, sst, ones, ms); + buf_offset += 48; + } + ms = #update_msf(!condition_loop, ms); + + while { condition_loop = comp_u64_l_int_and_u64_l_int(buf_offset, 3*168-24+1, counter, MLKEM_N+1); } + ( condition_loop ) { - c = seed[j]; - extseed[j] = c; + ms = #update_msf(condition_loop, ms); + + pol, counter, ms = __gen_matrix_buf_rejection_filter24(pol, counter, buf, buf_offset, load_shuffle, mask, bounds, sst, ones, ms); + + buf_offset += 24; } - for i=0 to MLKEM_K + return pol, counter; +} + + +u16[2*4*2] gen_matrix_indexes = +{ + 0x0000, 0x0001, 0x0002, 0x0100, // (0,0) (0,1) (0,2) (1,0) + 0x0101, 0x0102, 0x0200, 0x0201, // (1,1) (1,2) (2,0) (2,1) + + 0x0000, 0x0100, 0x0200, 0x0001, // previous indexes: swapped for transposed + 0x0101, 0x0201, 0x0002, 0x0102 +}; + +inline fn gen_matrix_get_indexes( + reg u64 b, + reg u64 _t) + -> + reg u16[4] +{ + reg u64 t; + reg u16[4] idx; + reg ptr u16[2*4*2] gmi; + + gmi = gen_matrix_indexes; + + t = _t; t <<= 3; // t * 8 + b += t; + + idx[0] = gmi[(int) b + 0]; + idx[1] = gmi[(int) b + 1]; + idx[2] = gmi[(int) b + 2]; + idx[3] = 
gmi[(int) b + 3]; + + return idx; +} + +fn _gen_matrix_sample_four_polynomials +( reg mut ptr u16[4*MLKEM_N] polx4 +, reg mut ptr u64[4*BUF_size] bufx4 +, reg ptr u8[32] rho +, reg u64 mat_entry +, reg u64 transposed +) -> reg ptr u16[4*MLKEM_N], reg ptr u64[4*BUF_size] +{ + inline int i; + reg ptr u64[BUF_size] buf; + reg ptr u16[MLKEM_N] pol; + stack u256[25] state; + reg ptr u256[25] stx4; + reg u256[7] stavx2; + reg u16[4] indexes; + reg u64 counter buf_offset; + + indexes = gen_matrix_get_indexes(mat_entry, transposed); + + state = xof_init_x4(rho, indexes); + stx4 = state; + buf_offset = 0; + while (buf_offset < 3*168) { + stx4 = _keccakf1600_4x(stx4); + bufx4 = __st4x_unpack_at( bufx4, stx4, buf_offset ); + buf_offset += 168; + } + + for i = 0 to 4 { - for j = 0 to MLKEM_K - { - transposed = stransposed; - if(transposed == 0) - { - extseed[MLKEM_SYMBYTES] = j; - extseed[MLKEM_SYMBYTES+1] = i; - } - else - { - extseed[MLKEM_SYMBYTES] = i; - extseed[MLKEM_SYMBYTES+1] = j; - } - - state = _shake128_absorb34(state, extseed); - - ctr = 0; - while (ctr < MLKEM_N) - { - sctr = ctr; - state, buf = _shake128_squeezeblock(state, buf); - ctr = sctr; - ctr, poly = __rej_uniform(poly, ctr, buf); - } - - k = 0; - reg ptr u16[MLKEM_N] rij; - rij = r[i * MLKEM_VECN + j * MLKEM_N : MLKEM_N]; - while (k < MLKEM_N) - { - t = poly[(int) k]; - rij[k] = t; - k += 1; - } - r[i * MLKEM_VECN + j * MLKEM_N : MLKEM_N] = rij; + buf = bufx4[i*BUF_size:BUF_size]; + buf_offset = 0; + pol = polx4[i*MLKEM_N:MLKEM_N]; + counter = 0; + pol, counter = _gen_matrix_buf_rejection(pol, counter, buf, buf_offset); + buf_offset = 2*168; + while (counter < MLKEM_N) { + stavx2 = _stavx2_pack_at(buf, buf_offset); + stavx2 = _keccakf1600_avx2(stavx2); + buf = _stavx2_unpack_at(buf, buf_offset, stavx2); + pol, counter = _gen_matrix_buf_rejection(pol, counter, buf, buf_offset); } + polx4[i*MLKEM_N:MLKEM_N] = pol; + bufx4[i*BUF_size:BUF_size] = buf; } + + return polx4, bufx4; +} - for i = 0 to MLKEM_K 
+inline fn __gen_matrix_sample_one_polynomial +( reg mut ptr u16[MLKEM_N] pol +, reg mut ptr u64[BUF_size] buf +, reg ptr u8[32] rho +, reg u16 rc +) -> reg ptr u16[MLKEM_N], reg ptr u64[BUF_size] +{ + reg u256[7] stavx2; + reg u64 counter buf_offset; + + stavx2 = xof_init_avx2(rho, rc); + buf_offset = 0; + while (buf_offset < 3*168) { + stavx2 = _keccakf1600_avx2(stavx2); + buf = _stavx2_unpack_at( buf, buf_offset, stavx2 ); + buf_offset += 168; + } + buf_offset = 0; + counter = 0; + pol, counter = _gen_matrix_buf_rejection(pol, counter, buf, buf_offset); + + buf_offset = 2*168; + while (counter < MLKEM_N) { + stavx2 = _stavx2_pack_at(buf, buf_offset); + stavx2 = _keccakf1600_avx2(stavx2); + buf = _stavx2_unpack_at(buf, buf_offset, stavx2); + pol, counter = _gen_matrix_buf_rejection(pol, counter, buf, buf_offset); + } + + return pol, buf; +} + + + +fn _gen_matrix_avx2 +( reg ptr u16[MLKEM_K * MLKEM_K * MLKEM_N] matrix +, reg ptr u8[32] rho +, #spill_to_mmx reg u64 transposed +) -> reg ptr u16[MLKEM_K * MLKEM_K * MLKEM_N] +{ + // local variables + inline int i j; + stack u64[4*BUF_size] bufx4_s; + reg ptr u64[4*BUF_size] bufx4; + reg ptr u64[BUF_size] buf; + reg ptr u16[4*MLKEM_N] polx4; + reg ptr u16[MLKEM_N] pol; + reg u64 mat_entry; + reg u16 rc; + + () = #spill(transposed); + + bufx4 = bufx4_s; + + for i = 0 to 2 { - for j = 0 to MLKEM_K - { - r[i*MLKEM_VECN+j*MLKEM_N:MLKEM_N] = _nttunpack(r[i*MLKEM_VECN+j*MLKEM_N:MLKEM_N]); + mat_entry = 4*i; + polx4 = matrix[4*i*MLKEM_N:4*MLKEM_N]; + () = #unspill(transposed); + polx4, bufx4 = _gen_matrix_sample_four_polynomials(polx4, bufx4, rho, mat_entry, transposed); + matrix[i*4*MLKEM_N:4*MLKEM_N] = polx4; + } + + + // sample the last one, (2,2), using scalar code + buf = bufx4[0:BUF_size]; + pol = matrix[8*MLKEM_N:MLKEM_N]; + rc = 0x0202; + pol, _ = __gen_matrix_sample_one_polynomial(pol, buf, rho, rc); + + matrix[8*MLKEM_N:MLKEM_N] = pol; + bufx4_s = bufx4; + + for i = 0 to MLKEM_K + { for j = 0 to MLKEM_K + { 
matrix[i*MLKEM_VECN+j*MLKEM_N:MLKEM_N] = _nttunpack(matrix[i*MLKEM_VECN+j*MLKEM_N:MLKEM_N]); } } - return r; + return matrix; } + diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/gen_matrix_globals.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/gen_matrix_globals.jinc new file mode 100644 index 00000000..9de20d3e --- /dev/null +++ b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/gen_matrix_globals.jinc @@ -0,0 +1,287 @@ +require "params.jinc" + +u8[32] sample_load_shuffle = { + 0, 1, 1, 2, 3, 4, 4, 5, + 6, 7, 7, 8, 9, 10, 10, 11, + 4, 5, 5, 6, 7, 8, 8, 9, + 10, 11, 11, 12, 13, 14, 14, 15 +}; + +u256 sample_ones = (32u8) +[ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]; + +u256 sample_mask = (16u16) +[ 0x0FFF, 0x0FFF, 0x0FFF, 0x0FFF, 0x0FFF, 0x0FFF, 0x0FFF, 0x0FFF, + 0x0FFF, 0x0FFF, 0x0FFF, 0x0FFF, 0x0FFF, 0x0FFF, 0x0FFF, 0x0FFF]; + +u256 sample_q = (16u16) +[ MLKEM_Q, MLKEM_Q, MLKEM_Q, MLKEM_Q, MLKEM_Q, MLKEM_Q, MLKEM_Q, MLKEM_Q, + MLKEM_Q, MLKEM_Q, MLKEM_Q, MLKEM_Q, MLKEM_Q, MLKEM_Q, MLKEM_Q, MLKEM_Q]; + +u8[256*8] sample_shuffle_table = +{ + -1, -1, -1, -1, -1, -1, -1, -1, // 0 - _0000 -- no good, upper bit 1, set to zero + 0, -1, -1, -1, -1, -1, -1, -1, // 1 - _0001 -- only one good at (byte) offset 0 + 2, -1, -1, -1, -1, -1, -1, -1, // 2 - _0010 -- only one good at (byte) offset 2 + 0, 2, -1, -1, -1, -1, -1, -1, // 3 - _0011 -- two good at (byte) offset 0 and 2 + 4, -1, -1, -1, -1, -1, -1, -1, // 4 - _0100 -- only one good at (byte) offset 4 + 0, 4, -1, -1, -1, -1, -1, -1, // 5 - _0101 -- ... + 2, 4, -1, -1, -1, -1, -1, -1, // 6 - _0110 + 0, 2, 4, -1, -1, -1, -1, -1, // ... 
+ 6, -1, -1, -1, -1, -1, -1, -1, + 0, 6, -1, -1, -1, -1, -1, -1, + 2, 6, -1, -1, -1, -1, -1, -1, + 0, 2, 6, -1, -1, -1, -1, -1, + 4, 6, -1, -1, -1, -1, -1, -1, + 0, 4, 6, -1, -1, -1, -1, -1, + 2, 4, 6, -1, -1, -1, -1, -1, + 0, 2, 4, 6, -1, -1, -1, -1, + 8, -1, -1, -1, -1, -1, -1, -1, + 0, 8, -1, -1, -1, -1, -1, -1, + 2, 8, -1, -1, -1, -1, -1, -1, + 0, 2, 8, -1, -1, -1, -1, -1, + 4, 8, -1, -1, -1, -1, -1, -1, + 0, 4, 8, -1, -1, -1, -1, -1, + 2, 4, 8, -1, -1, -1, -1, -1, + 0, 2, 4, 8, -1, -1, -1, -1, + 6, 8, -1, -1, -1, -1, -1, -1, + 0, 6, 8, -1, -1, -1, -1, -1, + 2, 6, 8, -1, -1, -1, -1, -1, + 0, 2, 6, 8, -1, -1, -1, -1, + 4, 6, 8, -1, -1, -1, -1, -1, + 0, 4, 6, 8, -1, -1, -1, -1, + 2, 4, 6, 8, -1, -1, -1, -1, + 0, 2, 4, 6, 8, -1, -1, -1, + +// + + 10, -1, -1, -1, -1, -1, -1, -1, + 0, 10, -1, -1, -1, -1, -1, -1, + 2, 10, -1, -1, -1, -1, -1, -1, + 0, 2, 10, -1, -1, -1, -1, -1, + 4, 10, -1, -1, -1, -1, -1, -1, + 0, 4, 10, -1, -1, -1, -1, -1, + 2, 4, 10, -1, -1, -1, -1, -1, + 0, 2, 4, 10, -1, -1, -1, -1, + 6, 10, -1, -1, -1, -1, -1, -1, + 0, 6, 10, -1, -1, -1, -1, -1, + 2, 6, 10, -1, -1, -1, -1, -1, + 0, 2, 6, 10, -1, -1, -1, -1, + 4, 6, 10, -1, -1, -1, -1, -1, + 0, 4, 6, 10, -1, -1, -1, -1, + 2, 4, 6, 10, -1, -1, -1, -1, + 0, 2, 4, 6, 10, -1, -1, -1, + 8, 10, -1, -1, -1, -1, -1, -1, + 0, 8, 10, -1, -1, -1, -1, -1, + 2, 8, 10, -1, -1, -1, -1, -1, + 0, 2, 8, 10, -1, -1, -1, -1, + 4, 8, 10, -1, -1, -1, -1, -1, + 0, 4, 8, 10, -1, -1, -1, -1, + 2, 4, 8, 10, -1, -1, -1, -1, + 0, 2, 4, 8, 10, -1, -1, -1, + 6, 8, 10, -1, -1, -1, -1, -1, + 0, 6, 8, 10, -1, -1, -1, -1, + 2, 6, 8, 10, -1, -1, -1, -1, + 0, 2, 6, 8, 10, -1, -1, -1, + 4, 6, 8, 10, -1, -1, -1, -1, + 0, 4, 6, 8, 10, -1, -1, -1, + 2, 4, 6, 8, 10, -1, -1, -1, + 0, 2, 4, 6, 8, 10, -1, -1, + +// + + 12,-1, -1, -1, -1, -1, -1, -1, + 0, 12, -1, -1, -1, -1, -1, -1, + 2, 12, -1, -1, -1, -1, -1, -1, + 0, 2, 12, -1, -1, -1, -1, -1, + 4, 12, -1, -1, -1, -1, -1, -1, + 0, 4, 12, -1, -1, -1, -1, -1, + 2, 4, 12, -1, -1, -1, -1, -1, 
+ 0, 2, 4, 12, -1, -1, -1, -1, + 6, 12, -1, -1, -1, -1, -1, -1, + 0, 6, 12, -1, -1, -1, -1, -1, + 2, 6, 12, -1, -1, -1, -1, -1, + 0, 2, 6, 12, -1, -1, -1, -1, + 4, 6, 12, -1, -1, -1, -1, -1, + 0, 4, 6, 12, -1, -1, -1, -1, + 2, 4, 6, 12, -1, -1, -1, -1, + 0, 2, 4, 6, 12, -1, -1, -1, + 8, 12, -1, -1, -1, -1, -1, -1, + 0, 8, 12, -1, -1, -1, -1, -1, + 2, 8, 12, -1, -1, -1, -1, -1, + 0, 2, 8, 12, -1, -1, -1, -1, + 4, 8, 12, -1, -1, -1, -1, -1, + 0, 4, 8, 12, -1, -1, -1, -1, + 2, 4, 8, 12, -1, -1, -1, -1, + 0, 2, 4, 8, 12, -1, -1, -1, + 6, 8, 12, -1, -1, -1, -1, -1, + 0, 6, 8, 12, -1, -1, -1, -1, + 2, 6, 8, 12, -1, -1, -1, -1, + 0, 2, 6, 8, 12, -1, -1, -1, + 4, 6, 8, 12, -1, -1, -1, -1, + 0, 4, 6, 8, 12, -1, -1, -1, + 2, 4, 6, 8, 12, -1, -1, -1, + 0, 2, 4, 6, 8, 12, -1, -1, + 10, 12, -1, -1, -1, -1, -1, -1, + 0, 10, 12, -1, -1, -1, -1, -1, + 2, 10, 12, -1, -1, -1, -1, -1, + 0, 2, 10, 12, -1, -1, -1, -1, + 4, 10, 12, -1, -1, -1, -1, -1, + 0, 4, 10, 12, -1, -1, -1, -1, + 2, 4, 10, 12, -1, -1, -1, -1, + 0, 2, 4, 10, 12, -1, -1, -1, + 6, 10, 12, -1, -1, -1, -1, -1, + 0, 6, 10, 12, -1, -1, -1, -1, + 2, 6, 10, 12, -1, -1, -1, -1, + 0, 2, 6, 10, 12, -1, -1, -1, + 4, 6, 10, 12, -1, -1, -1, -1, + 0, 4, 6, 10, 12, -1, -1, -1, + 2, 4, 6, 10, 12, -1, -1, -1, + 0, 2, 4, 6, 10, 12, -1, -1, + 8, 10, 12, -1, -1, -1, -1, -1, + 0, 8, 10, 12, -1, -1, -1, -1, + 2, 8, 10, 12, -1, -1, -1, -1, + 0, 2, 8, 10, 12, -1, -1, -1, + 4, 8, 10, 12, -1, -1, -1, -1, + 0, 4, 8, 10, 12, -1, -1, -1, + 2, 4, 8, 10, 12, -1, -1, -1, + 0, 2, 4, 8, 10, 12, -1, -1, + 6, 8, 10, 12, -1, -1, -1, -1, + 0, 6, 8, 10, 12, -1, -1, -1, + 2, 6, 8, 10, 12, -1, -1, -1, + 0, 2, 6, 8, 10, 12, -1, -1, + 4, 6, 8, 10, 12, -1, -1, -1, + 0, 4, 6, 8, 10, 12, -1, -1, + 2, 4, 6, 8, 10, 12, -1, -1, + 0, 2, 4, 6, 8, 10, 12, -1, + 14, -1, -1, -1, -1, -1, -1, -1, + 0, 14, -1, -1, -1, -1, -1, -1, + 2, 14, -1, -1, -1, -1, -1, -1, + 0, 2, 14, -1, -1, -1, -1, -1, + 4, 14, -1, -1, -1, -1, -1, -1, + 0, 4, 14, -1, -1, -1, -1, -1, + 2, 4, 14, -1, 
-1, -1, -1, -1, + 0, 2, 4, 14, -1, -1, -1, -1, + 6, 14, -1, -1, -1, -1, -1, -1, + 0, 6, 14, -1, -1, -1, -1, -1, + 2, 6, 14, -1, -1, -1, -1, -1, + 0, 2, 6, 14, -1, -1, -1, -1, + 4, 6, 14, -1, -1, -1, -1, -1, + 0, 4, 6, 14, -1, -1, -1, -1, + 2, 4, 6, 14, -1, -1, -1, -1, + 0, 2, 4, 6, 14, -1, -1, -1, + 8, 14, -1, -1, -1, -1, -1, -1, + 0, 8, 14, -1, -1, -1, -1, -1, + 2, 8, 14, -1, -1, -1, -1, -1, + 0, 2, 8, 14, -1, -1, -1, -1, + 4, 8, 14, -1, -1, -1, -1, -1, + 0, 4, 8, 14, -1, -1, -1, -1, + 2, 4, 8, 14, -1, -1, -1, -1, + 0, 2, 4, 8, 14, -1, -1, -1, + 6, 8, 14, -1, -1, -1, -1, -1, + 0, 6, 8, 14, -1, -1, -1, -1, + 2, 6, 8, 14, -1, -1, -1, -1, + 0, 2, 6, 8, 14, -1, -1, -1, + 4, 6, 8, 14, -1, -1, -1, -1, + 0, 4, 6, 8, 14, -1, -1, -1, + 2, 4, 6, 8, 14, -1, -1, -1, + 0, 2, 4, 6, 8, 14, -1, -1, + 10, 14, -1, -1, -1, -1, -1, -1, + 0, 10, 14, -1, -1, -1, -1, -1, + 2, 10, 14, -1, -1, -1, -1, -1, + 0, 2, 10, 14, -1, -1, -1, -1, + 4, 10, 14, -1, -1, -1, -1, -1, + 0, 4, 10, 14, -1, -1, -1, -1, + 2, 4, 10, 14, -1, -1, -1, -1, + 0, 2, 4, 10, 14, -1, -1, -1, + 6, 10, 14, -1, -1, -1, -1, -1, + 0, 6, 10, 14, -1, -1, -1, -1, + 2, 6, 10, 14, -1, -1, -1, -1, + 0, 2, 6, 10, 14, -1, -1, -1, + 4, 6, 10, 14, -1, -1, -1, -1, + 0, 4, 6, 10, 14, -1, -1, -1, + 2, 4, 6, 10, 14, -1, -1, -1, + 0, 2, 4, 6, 10, 14, -1, -1, + 8, 10, 14, -1, -1, -1, -1, -1, + 0, 8, 10, 14, -1, -1, -1, -1, + 2, 8, 10, 14, -1, -1, -1, -1, + 0, 2, 8, 10, 14, -1, -1, -1, + 4, 8, 10, 14, -1, -1, -1, -1, + 0, 4, 8, 10, 14, -1, -1, -1, + 2, 4, 8, 10, 14, -1, -1, -1, + 0, 2, 4, 8, 10, 14, -1, -1, + 6, 8, 10, 14, -1, -1, -1, -1, + 0, 6, 8, 10, 14, -1, -1, -1, + 2, 6, 8, 10, 14, -1, -1, -1, + 0, 2, 6, 8, 10, 14, -1, -1, + 4, 6, 8, 10, 14, -1, -1, -1, + 0, 4, 6, 8, 10, 14, -1, -1, + 2, 4, 6, 8, 10, 14, -1, -1, + 0, 2, 4, 6, 8, 10, 14, -1, + 12, 14, -1, -1, -1, -1, -1, -1, + 0, 12, 14, -1, -1, -1, -1, -1, + 2, 12, 14, -1, -1, -1, -1, -1, + 0, 2, 12, 14, -1, -1, -1, -1, + 4, 12, 14, -1, -1, -1, -1, -1, + 0, 4, 12, 14, -1, -1, -1, -1, 
+ 2, 4, 12, 14, -1, -1, -1, -1, + 0, 2, 4, 12, 14, -1, -1, -1, + 6, 12, 14, -1, -1, -1, -1, -1, + 0, 6, 12, 14, -1, -1, -1, -1, + 2, 6, 12, 14, -1, -1, -1, -1, + 0, 2, 6, 12, 14, -1, -1, -1, + 4, 6, 12, 14, -1, -1, -1, -1, + 0, 4, 6, 12, 14, -1, -1, -1, + 2, 4, 6, 12, 14, -1, -1, -1, + 0, 2, 4, 6, 12, 14, -1, -1, + 8, 12, 14, -1, -1, -1, -1, -1, + 0, 8, 12, 14, -1, -1, -1, -1, + 2, 8, 12, 14, -1, -1, -1, -1, + 0, 2, 8, 12, 14, -1, -1, -1, + 4, 8, 12, 14, -1, -1, -1, -1, + 0, 4, 8, 12, 14, -1, -1, -1, + 2, 4, 8, 12, 14, -1, -1, -1, + 0, 2, 4, 8, 12, 14, -1, -1, + 6, 8, 12, 14, -1, -1, -1, -1, + 0, 6, 8, 12, 14, -1, -1, -1, + 2, 6, 8, 12, 14, -1, -1, -1, + 0, 2, 6, 8, 12, 14, -1, -1, + 4, 6, 8, 12, 14, -1, -1, -1, + 0, 4, 6, 8, 12, 14, -1, -1, + 2, 4, 6, 8, 12, 14, -1, -1, + 0, 2, 4, 6, 8, 12, 14, -1, + 10, 12, 14, -1, -1, -1, -1, -1, + 0, 10, 12, 14, -1, -1, -1, -1, + 2, 10, 12, 14, -1, -1, -1, -1, + 0, 2, 10, 12, 14, -1, -1, -1, + 4, 10, 12, 14, -1, -1, -1, -1, + 0, 4, 10, 12, 14, -1, -1, -1, + 2, 4, 10, 12, 14, -1, -1, -1, + 0, 2, 4, 10, 12, 14, -1, -1, + 6, 10, 12, 14, -1, -1, -1, -1, + 0, 6, 10, 12, 14, -1, -1, -1, + 2, 6, 10, 12, 14, -1, -1, -1, + 0, 2, 6, 10, 12, 14, -1, -1, + 4, 6, 10, 12, 14, -1, -1, -1, + 0, 4, 6, 10, 12, 14, -1, -1, + 2, 4, 6, 10, 12, 14, -1, -1, + 0, 2, 4, 6, 10, 12, 14, -1, + 8, 10, 12, 14, -1, -1, -1, -1, + 0, 8, 10, 12, 14, -1, -1, -1, + 2, 8, 10, 12, 14, -1, -1, -1, + 0, 2, 8, 10, 12, 14, -1, -1, + 4, 8, 10, 12, 14, -1, -1, -1, + 0, 4, 8, 10, 12, 14, -1, -1, + 2, 4, 8, 10, 12, 14, -1, -1, + 0, 2, 4, 8, 10, 12, 14, -1, + 6, 8, 10, 12, 14, -1, -1, -1, + 0, 6, 8, 10, 12, 14, -1, -1, + 2, 6, 8, 10, 12, 14, -1, -1, + 0, 2, 6, 8, 10, 12, 14, -1, + 4, 6, 8, 10, 12, 14, -1, -1, + 0, 4, 6, 8, 10, 12, 14, -1, + 2, 4, 6, 8, 10, 12, 14, -1, + 0, 2, 4, 6, 8, 10, 12, 14 +}; + diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/indcpa.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/indcpa.jinc index 382c3a6f..23c6b279 100644 --- 
a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/indcpa.jinc +++ b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/indcpa.jinc @@ -6,7 +6,6 @@ require "gen_matrix.jinc" inline fn __indcpa_keypair(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES] randomnessp) { - stack u64 spkp sskp; stack u16[MLKEM_K*MLKEM_VECN] aa; stack u16[MLKEM_VECN] e pkpv skpv; stack u8[64] buf; @@ -15,9 +14,9 @@ fn __indcpa_keypair(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES] randomn reg u64 t64; reg u8 nonce; inline int i; + reg u64 transposed; - spkp = pkp; - sskp = skp; + () = #spill(pkp, skp); for i=0 to MLKEM_SYMBYTES/8 { @@ -29,13 +28,14 @@ fn __indcpa_keypair(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES] randomn for i=0 to MLKEM_SYMBYTES/8 { - t64 = buf[u64 i]; + #declassify t64 = buf[u64 i]; publicseed[u64 i] = t64; t64 = buf[u64 i + MLKEM_SYMBYTES/8]; noiseseed[u64 i] = t64; } - aa = __gen_matrix(publicseed, 0); + transposed = 0; + aa = _gen_matrix_avx2(aa, publicseed, transposed); nonce = 0; skpv[0:MLKEM_N], skpv[MLKEM_N:MLKEM_N], skpv[2*MLKEM_N:MLKEM_N], e[0:MLKEM_N] = _poly_getnoise_eta1_4x(skpv[0:MLKEM_N], skpv[MLKEM_N:MLKEM_N], skpv[2*MLKEM_N:MLKEM_N], e[0:MLKEM_N], noiseseed, nonce); @@ -56,8 +56,7 @@ fn __indcpa_keypair(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES] randomn pkpv = __polyvec_add2(pkpv, e); pkpv = __polyvec_reduce(pkpv); - pkp = spkp; - skp = sskp; + () = #unspill(pkp, skp); __polyvec_tobytes(skp, skpv); __polyvec_tobytes(pkp, pkpv); @@ -73,17 +72,18 @@ fn __indcpa_keypair(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES] randomn } inline -fn __indcpa_enc_0(stack u64 sctp, reg ptr u8[MLKEM_INDCPA_MSGBYTES] msgp, reg u64 pkp, reg ptr u8[MLKEM_SYMBYTES] noiseseed) +fn __indcpa_enc_0(#mmx reg u64 sctp, reg ptr u8[MLKEM_INDCPA_MSGBYTES] msgp, reg u64 pkp, reg ptr u8[MLKEM_SYMBYTES] noiseseed) { stack u16[MLKEM_VECN] pkpv sp ep bp; stack u16[MLKEM_K*MLKEM_VECN] aat; stack u16[MLKEM_N] k epp v; stack u8[MLKEM_SYMBYTES] publicseed; - stack ptr 
u8[MLKEM_SYMBYTES] s_noiseseed; + #mmx reg ptr u8[MLKEM_SYMBYTES] s_noiseseed; reg ptr u8[MLKEM_SYMBYTES] lnoiseseed; reg u64 i t64 ctp; reg u8 nonce; inline int w; + reg u64 transposed; pkpv = __polyvec_frombytes(pkp); @@ -91,7 +91,7 @@ fn __indcpa_enc_0(stack u64 sctp, reg ptr u8[MLKEM_INDCPA_MSGBYTES] msgp, reg u6 pkp += MLKEM_POLYVECBYTES; while (i < MLKEM_SYMBYTES/8) { - t64 = (u64)[pkp]; + #declassify t64 = (u64)[pkp]; publicseed.[u64 8 * (int)i] = t64; pkp += 8; i += 1; @@ -100,7 +100,10 @@ fn __indcpa_enc_0(stack u64 sctp, reg ptr u8[MLKEM_INDCPA_MSGBYTES] msgp, reg u6 k = _poly_frommsg_1(k, msgp); s_noiseseed = noiseseed; - aat = __gen_matrix(publicseed, 1); + + transposed = 1; + aat = _gen_matrix_avx2(aat, publicseed, transposed); + lnoiseseed = s_noiseseed; nonce = 0; @@ -128,26 +131,31 @@ fn __indcpa_enc_0(stack u64 sctp, reg ptr u8[MLKEM_INDCPA_MSGBYTES] msgp, reg u6 v = __poly_reduce(v); ctp = sctp; + __polyvec_compress(ctp, bp); ctp += MLKEM_POLYVECCOMPRESSEDBYTES; v = _poly_compress(ctp, v); } inline -fn __indcpa_enc_1(reg ptr u8[MLKEM_INDCPA_CIPHERTEXTBYTES] ctp, reg ptr u8[MLKEM_INDCPA_MSGBYTES] msgp, reg u64 pkp, reg ptr u8[MLKEM_SYMBYTES] noiseseed) -> reg ptr u8[MLKEM_INDCPA_CIPHERTEXTBYTES] +fn __indcpa_enc_1( + stack u8[MLKEM_INDCPA_CIPHERTEXTBYTES] ctp, + reg ptr u8[MLKEM_INDCPA_MSGBYTES] msgp, + reg u64 pkp, + reg ptr u8[MLKEM_SYMBYTES] noiseseed) + -> + stack u8[MLKEM_INDCPA_CIPHERTEXTBYTES] { stack u16[MLKEM_VECN] pkpv sp ep bp; stack u16[MLKEM_K*MLKEM_VECN] aat; stack u16[MLKEM_N] k epp v; stack u8[MLKEM_SYMBYTES] publicseed; - stack ptr u8[MLKEM_SYMBYTES] s_noiseseed; + #mmx reg ptr u8[MLKEM_SYMBYTES] s_noiseseed; reg ptr u8[MLKEM_SYMBYTES] lnoiseseed; - stack ptr u8[MLKEM_INDCPA_CIPHERTEXTBYTES] sctp; reg u64 i t64; reg u8 nonce; inline int w; - - sctp = ctp; + reg u64 transposed; pkpv = __polyvec_frombytes(pkp); @@ -155,7 +163,7 @@ fn __indcpa_enc_1(reg ptr u8[MLKEM_INDCPA_CIPHERTEXTBYTES] ctp, reg ptr u8[MLKEM pkp += 
MLKEM_POLYVECBYTES; while (i < MLKEM_SYMBYTES/8) { - t64 = (u64)[pkp]; + #declassify t64 = (u64)[pkp]; publicseed.[u64 8*(int)i] = t64; pkp += 8; i += 1; @@ -164,7 +172,10 @@ fn __indcpa_enc_1(reg ptr u8[MLKEM_INDCPA_CIPHERTEXTBYTES] ctp, reg ptr u8[MLKEM k = _poly_frommsg_1(k, msgp); s_noiseseed = noiseseed; - aat = __gen_matrix(publicseed, 1); + + transposed = 1; + aat = _gen_matrix_avx2(aat, publicseed, transposed); + lnoiseseed = s_noiseseed; nonce = 0; @@ -191,7 +202,6 @@ fn __indcpa_enc_1(reg ptr u8[MLKEM_INDCPA_CIPHERTEXTBYTES] ctp, reg ptr u8[MLKEM bp = __polyvec_reduce(bp); v = __poly_reduce(v); - ctp = sctp; ctp[0:MLKEM_POLYVECCOMPRESSEDBYTES] = __polyvec_compress_1(ctp[0:MLKEM_POLYVECCOMPRESSEDBYTES], bp); ctp[MLKEM_POLYVECCOMPRESSEDBYTES:MLKEM_POLYCOMPRESSEDBYTES], v = _poly_compress_1(ctp[MLKEM_POLYVECCOMPRESSEDBYTES:MLKEM_POLYCOMPRESSEDBYTES], v); diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600.jinc new file mode 100644 index 00000000..7e2b6869 --- /dev/null +++ b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600.jinc @@ -0,0 +1,169 @@ +require "keccakf1600_generic.jinc" + +// C[x] = A[x,0] ^ A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] +inline fn keccakf1600_theta_sum(reg ptr u64[25] a) -> reg u64[5] +{ + inline int x y; + reg u64[5] c; + + // C[x] = A[x, 0] + for x=0 to 5 + { c[x] = a[x + 0]; } + + // C[x] ^= A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] + for y=1 to 5 + { for x=0 to 5 + { c[x] ^= a[x + y*5]; } + } + + return c; +} + +// D[x] = C[x-1] ^ ROT(C[x+1], 1) +inline fn keccakf1600_theta_rol(reg u64[5] c) -> reg u64[5] +{ + inline int x; + reg u64[5] d; + + for x = 0 to 5 + { // D[x] = C[x + 1] + d[x] = c[(x + 1) % 5]; + + // D[x] = ROT(D[x], 1) + _, _, d[x] = #ROL_64(d[x], 1); + + // D[x] ^= C[x-1] + d[x] ^= c[(x - 1 + 5) % 5]; + } + + return d; +} + +// B[x] = ROT( (A[x',y'] ^ D[x']), r[x',y'] ) with (x',y') = M^-1 (x,y) +// +// M = (0 1) M^-1 = (1 3) x' = 1x + 
3y +// (2 3) (1 0) y' = 1x + 0y +// +inline fn keccakf1600_rol_sum( + reg ptr u64[25] a, + reg u64[5] d, + inline int y) + -> + reg u64[5] +{ + inline int r x x_ y_; + reg u64[5] b; + + for x = 0 to 5 + { + x_ = (x + 3*y) % 5; + y_ = x; + r = keccakf1600_rhotates(x_, y_); + + // B[x] = A[x',y'] + b[x] = a[x_ + y_*5]; + + // B[x] ^= D[x']; + b[x] ^= d[x_]; + + // B[x] = ROT( B[x], r[x',y'] ); + if(r != 0) + { _, _, b[x] = #ROL_64(b[x], r); } + + } + + return b; +} + +// E[x, y] = B[x] ^ ( (!B[x+1]) & B[x+2] ) +// -- when x and y are 0: E[0,0] ^= RC[i]; +inline fn keccakf1600_set_row( + reg ptr u64[25] e, + reg u64[5] b, + inline int y, + stack u64 s_rc) + -> + reg ptr u64[25] +{ + inline int x x1 x2; + reg u64 t; + + for x=0 to 5 + { + x1 = (x + 1) % 5; + x2 = (x + 2) % 5; + + t = !b[x1] & b[x2]; // bmi1 + //t = b[x1]; t = !t; t &= b[x2]; + + t ^= b[x]; + if( x==0 && y==0 ){ t ^= s_rc; } + e[x + y*5] = t; + } + + return e; +} + +inline fn keccakf1600_round( + reg ptr u64[25] e, + reg ptr u64[25] a, + reg u64 rc) + -> + reg ptr u64[25] +{ + inline int y; + reg u64[5] b c d; + stack u64 s_rc; + + s_rc = rc; + + c = keccakf1600_theta_sum(a); + d = keccakf1600_theta_rol(c); + + for y = 0 to 5 + { b = keccakf1600_rol_sum(a, d, y); + e = keccakf1600_set_row(e, b, y, s_rc); + } + + return e; +} + +inline fn __keccakf1600(reg ptr u64[25] a) -> reg ptr u64[25] +{ + reg ptr u64[24] RC; + stack u64[25] s_e; + reg ptr u64[25] e; + + reg u64 c rc; + + RC = KECCAK1600_RC; + e = s_e; + + c = 0; + while (c < KECCAK_ROUNDS - 1) + { + rc = RC[(int) c]; + e = keccakf1600_round(e, a, rc); + + rc = RC[(int) c + 1]; + a = keccakf1600_round(a, e, rc); + + c += 2; + } + + return a; +} + +fn _keccakf1600(reg ptr u64[25] a) -> reg ptr u64[25] +{ + a = __keccakf1600(a); + return a; +} + +inline fn _keccakf1600_(reg ptr u64[25] a) -> reg ptr u64[25] +{ + a = a; + a = _keccakf1600(a); + a = a; + return a; +} diff --git 
a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600_4x_avx2_compact.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600_4x_avx2_compact.jinc new file mode 100644 index 00000000..0f0f4f8a --- /dev/null +++ b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600_4x_avx2_compact.jinc @@ -0,0 +1,331 @@ + +require "keccakf1600_generic.jinc" + +u256[24] KECCAK1600_RC_AVX2 = +{ (4u64)[0x0000000000000001, 0x0000000000000001, 0x0000000000000001, 0x0000000000000001], + (4u64)[0x0000000000008082, 0x0000000000008082, 0x0000000000008082, 0x0000000000008082], + (4u64)[0x800000000000808a, 0x800000000000808a, 0x800000000000808a, 0x800000000000808a], + (4u64)[0x8000000080008000, 0x8000000080008000, 0x8000000080008000, 0x8000000080008000], + (4u64)[0x000000000000808b, 0x000000000000808b, 0x000000000000808b, 0x000000000000808b], + (4u64)[0x0000000080000001, 0x0000000080000001, 0x0000000080000001, 0x0000000080000001], + (4u64)[0x8000000080008081, 0x8000000080008081, 0x8000000080008081, 0x8000000080008081], + (4u64)[0x8000000000008009, 0x8000000000008009, 0x8000000000008009, 0x8000000000008009], + (4u64)[0x000000000000008a, 0x000000000000008a, 0x000000000000008a, 0x000000000000008a], + (4u64)[0x0000000000000088, 0x0000000000000088, 0x0000000000000088, 0x0000000000000088], + (4u64)[0x0000000080008009, 0x0000000080008009, 0x0000000080008009, 0x0000000080008009], + (4u64)[0x000000008000000a, 0x000000008000000a, 0x000000008000000a, 0x000000008000000a], + (4u64)[0x000000008000808b, 0x000000008000808b, 0x000000008000808b, 0x000000008000808b], + (4u64)[0x800000000000008b, 0x800000000000008b, 0x800000000000008b, 0x800000000000008b], + (4u64)[0x8000000000008089, 0x8000000000008089, 0x8000000000008089, 0x8000000000008089], + (4u64)[0x8000000000008003, 0x8000000000008003, 0x8000000000008003, 0x8000000000008003], + (4u64)[0x8000000000008002, 0x8000000000008002, 0x8000000000008002, 0x8000000000008002], + (4u64)[0x8000000000000080, 0x8000000000000080, 
0x8000000000000080, 0x8000000000000080], + (4u64)[0x000000000000800a, 0x000000000000800a, 0x000000000000800a, 0x000000000000800a], + (4u64)[0x800000008000000a, 0x800000008000000a, 0x800000008000000a, 0x800000008000000a], + (4u64)[0x8000000080008081, 0x8000000080008081, 0x8000000080008081, 0x8000000080008081], + (4u64)[0x8000000000008080, 0x8000000000008080, 0x8000000000008080, 0x8000000000008080], + (4u64)[0x0000000080000001, 0x0000000080000001, 0x0000000080000001, 0x0000000080000001], + (4u64)[0x8000000080008008, 0x8000000080008008, 0x8000000080008008, 0x8000000080008008] +}; + +u256 ROL56 = 0x181F1E1D1C1B1A191017161514131211080F0E0D0C0B0A090007060504030201; +u256 ROL8 = 0x1E1D1C1B1A19181F16151413121110170E0D0C0B0A09080F0605040302010007; + +// C[x] = A[x,0] ^ A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] +inline fn keccakf1600_4x_theta_sum(reg ptr u256[25] a) -> reg u256[5] +{ + inline int x y; + reg u256[5] c; + + // C[x] = A[x, 0] + for x=0 to 5 + { c[x] = a[x + 0]; } + + // C[x] ^= A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] + for y=1 to 5 + { for x=0 to 5 + { c[x] ^= a[x + y*5]; } + } + + return c; +} + +inline fn keccakf1600_4x_rol(reg u256[5] a, inline int x r, reg u256 r8 r56) -> reg u256[5] +{ + reg u256 t; + + if(r == 8) + { a[x] = #VPSHUFB_256(a[x], r8); } + else { if(r == 56) + { a[x] = #VPSHUFB_256(a[x], r56); } + else + { t = #VPSLL_4u64(a[x], r); + a[x] = #VPSRL_4u64(a[x], 64 - r); + a[x] |= t; } + } + + return a; +} + +// D[x] = C[x-1] ^ ROT(C[x+1], 1) +inline fn keccakf1600_4x_theta_rol(reg u256[5] c, reg u256 r8 r56) -> reg u256[5] +{ + inline int x; + reg u256[5] d; + + for x = 0 to 5 + { // D[x] = C[x + 1] + d[x] = c[(x + 1) % 5]; + + // D[x] = ROT(D[x], 1) + d = keccakf1600_4x_rol(d, x, 1, r8, r56); + + // D[x] ^= C[x-1] + d[x] ^= c[(x - 1 + 5) % 5]; + } + + return d; +} + + +// B[x] = ROT( (A[x',y'] ^ D[x']), r[x',y'] ) with (x',y') = M^-1 (x,y) +// +// M = (0 1) M^-1 = (1 3) x' = 1x + 3y +// (2 3) (1 0) y' = 1x + 0y +// +inline fn keccakf1600_4x_rol_sum( + reg ptr 
u256[25] a, + reg u256[5] d, + inline int y, + reg u256 r8 r56 +) -> reg u256[5] +{ + inline int r x x_ y_; + reg u256[5] b; + + for x = 0 to 5 + { + x_ = (x + 3*y) % 5; + y_ = x; + r = keccakf1600_rhotates(x_, y_); + + // B[x] = A[x',y'] + b[x] = a[x_ + y_*5]; + + // B[x] ^= D[x']; + b[x] ^= d[x_]; + + // B[x] = ROT( B[x], r[x',y'] ); + if(r != 0) + { b = keccakf1600_4x_rol(b, x, r, r8, r56); } + } + + return b; +} + + +// E[x, y] = B[x] ^ ( (!B[x+1]) & B[x+2] ) +// -- when x and y are 0: E[0,0] ^= RC[i]; +inline fn keccakf1600_4x_set_row( + reg ptr u256[25] e, + reg u256[5] b, + inline int y, + reg u256 rc +) -> reg ptr u256[25] +{ + inline int x x1 x2; + reg u256 t; + + for x=0 to 5 + { + x1 = (x + 1) % 5; + x2 = (x + 2) % 5; + + t = #VPANDN_256(b[x1], b[x2]); + + t ^= b[x]; + if( x==0 && y==0 ){ t ^= rc; } + e[x + y*5] = t; + } + + return e; +} + + +fn keccakf1600_4x_round(reg ptr u256[25] e a, reg u256 rc r8 r56) -> reg ptr u256[25] +{ + inline int y; + reg u256[5] b c d; + + c = keccakf1600_4x_theta_sum(a); + d = keccakf1600_4x_theta_rol(c, r8, r56); + + for y = 0 to 5 + { b = keccakf1600_4x_rol_sum(a, d, y, r8, r56); + e = keccakf1600_4x_set_row(e, b, y, rc); + } + + return e; +} + +//////////////////////////////////////////////////////////////////////////////// + +inline fn __keccakf1600_4x(reg ptr u256[25] a) -> reg ptr u256[25] +{ + #mmx reg ptr u256[25] a_s; + + reg ptr u256[24] RC; + + stack u256[25] s_e; + reg ptr u256[25] e; + + reg u256 rc r8 r56; + reg u64 c; + + RC = KECCAK1600_RC_AVX2; + e = s_e; + r8 = ROL8; + r56 = ROL56; + + c = 0; + while(c < (KECCAK_ROUNDS*32)) + { + rc = RC.[(int) c]; + e = keccakf1600_4x_round(e, a, rc, r8, r56); + + // just an expensive pointer swap (#todo request feature) + a_s = a; s_e = e; + a = a_s; e = s_e; + + rc = RC.[(int) c + 32]; + a = keccakf1600_4x_round(a, e, rc, r8, r56); + + // just an expensive pointer swap (#todo request feature) + a_s = a; s_e = e; + a = a_s; e = s_e; + + c += 64; + } + + return a; +} + 
+fn _keccakf1600_4x_(reg ptr u256[25] a) -> reg ptr u256[25] +{ + a = __keccakf1600_4x(a); + return a; +} + +inline fn _keccakf1600_4x(reg ptr u256[25] a) -> reg ptr u256[25] +{ + a = a; + a = _keccakf1600_4x_(a); + a = a; + return a; +} + +// pack 4 keccak states (st25) into a 4-way state (st4x) +inline fn __u256x4_4u64x4 +( reg u256 x0 x1 x2 x3 +) -> reg u256, reg u256, reg u256, reg u256 { + // x0 = l00 l01 l02 l03 + // x1 = l10 l11 l12 l13 + // x2 = l20 l21 l22 l23 + // x3 = l30 l31 l32 l33 + reg u256 y0, y1, y2, y3; + y0 = #VPUNPCKL_4u64(x0, x1); // y0 = l00 l10 l02 l12 + y1 = #VPUNPCKH_4u64(x0, x1); // y1 = l01 l11 l03 l13 + y2 = #VPUNPCKL_4u64(x2, x3); // y2 = l20 l30 l22 l32 + y3 = #VPUNPCKH_4u64(x2, x3); // y3 = l21 l31 l23 l33 + + x0 = #VPERM2I128(y0, y2, 0x20); // x0 = l00 l10 l20 l30 + x1 = #VPERM2I128(y1, y3, 0x20); // x1 = l01 l11 l21 l31 + x2 = #VPERM2I128(y0, y2, 0x31); // x2 = l02 l12 l22 l32 + x3 = #VPERM2I128(y1, y3, 0x31); // x3 = l03 l13 l23 l33 + + return x0, x1, x2, x3; +} + +inline fn __st4x_pack +( reg mut ptr u256[25] st4x +, reg const ptr u64[25] st0 st1 st2 st3 +) -> reg ptr u256[25] { + inline int i; + reg u256 x0, x1, x2, x3; + reg u64 t0, t1, t2, t3; + for i = 0 to 6 { + x0 = st0[u256 i]; + x1 = st1[u256 i]; + x2 = st2[u256 i]; + x3 = st3[u256 i]; + x0, x1, x2, x3 = __u256x4_4u64x4(x0, x1, x2, x3); + st4x[4*i+0] = x0; + st4x[4*i+1] = x1; + st4x[4*i+2] = x2; + st4x[4*i+3] = x3; + } + t0 = st0[24]; + t1 = st1[24]; + t2 = st2[24]; + t3 = st3[24]; + st4x[u64 4*24+0] = t0; + st4x[u64 4*24+1] = t1; + st4x[u64 4*24+2] = t2; + st4x[u64 4*24+3] = t3; + + return st4x; +} + + + +// extracts 4 keccak states (st25) from a 4-way state (st4x) +inline fn __4u64x4_u256x4 +( reg u256 y0 y1 y2 y3 +) -> reg u256, reg u256, reg u256, reg u256 { + // y0 = l00 l10 l20 l30 + // y1 = l01 l11 l21 l31 + // y2 = l02 l12 l22 l32 + // y3 = l03 l13 l23 l33 + reg u256 x0, x1, x2, x3; + x0 = #VPERM2I128(y0, y2, 0x20); // x0 = l00 l10 l02 l12 + x1 = #VPERM2I128(y1, 
y3, 0x20); // x1 = l01 l11 l03 l13 + x2 = #VPERM2I128(y0, y2, 0x31); // x2 = l20 l30 l22 l32 + x3 = #VPERM2I128(y1, y3, 0x31); // x3 = l21 l31 l23 l33 + + y0 = #VPUNPCKL_4u64(x0, x1); // y0 = l00 l01 l02 l03 + y1 = #VPUNPCKH_4u64(x0, x1); // y1 = l10 l11 l12 l13 + y2 = #VPUNPCKL_4u64(x2, x3); // y2 = l20 l21 l22 l23 + y3 = #VPUNPCKH_4u64(x2, x3); // y3 = l30 l31 l32 l33 + + return y0, y1, y2, y3; +} + +inline fn __st4x_unpack +( reg mut ptr u64[25] st0 st1 st2 st3 +, reg const ptr u256[25] st4x +) -> reg ptr u64[25], reg ptr u64[25], reg ptr u64[25], reg ptr u64[25] { + inline int i; + reg u256 x0, x1, x2, x3; + reg u64 t0, t1, t2, t3; + for i = 0 to 6 { + x0 = st4x[u256 4*i+0]; + x1 = st4x[u256 4*i+1]; + x2 = st4x[u256 4*i+2]; + x3 = st4x[u256 4*i+3]; + x0, x1, x2, x3 = __4u64x4_u256x4(x0, x1, x2, x3); + st0.[u256 4*8*i] = x0; + st1.[u256 4*8*i] = x1; + st2.[u256 4*8*i] = x2; + st3.[u256 4*8*i] = x3; + } + t0 = st4x[u64 4*24+0]; + t1 = st4x[u64 4*24+1]; + t2 = st4x[u64 4*24+2]; + t3 = st4x[u64 4*24+3]; + st0.[u64 8*24] = t0; + st1.[u64 8*24] = t1; + st2.[u64 8*24] = t2; + st3.[u64 8*24] = t3; + + return st0, st1, st2, st3; +} diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600_avx2.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600_avx2.jinc new file mode 100644 index 00000000..bbc0d321 --- /dev/null +++ b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600_avx2.jinc @@ -0,0 +1,316 @@ +require "keccakf1600_generic.jinc" + +u256[24] KECCAK_IOTAS = +{ (4u64)[0x0000000000000001, 0x0000000000000001, 0x0000000000000001, 0x0000000000000001] + ,(4u64)[0x0000000000008082, 0x0000000000008082, 0x0000000000008082, 0x0000000000008082] + ,(4u64)[0x800000000000808a, 0x800000000000808a, 0x800000000000808a, 0x800000000000808a] + ,(4u64)[0x8000000080008000, 0x8000000080008000, 0x8000000080008000, 0x8000000080008000] + ,(4u64)[0x000000000000808b, 0x000000000000808b, 0x000000000000808b, 0x000000000000808b] + ,(4u64)[0x0000000080000001, 
0x0000000080000001, 0x0000000080000001, 0x0000000080000001] + ,(4u64)[0x8000000080008081, 0x8000000080008081, 0x8000000080008081, 0x8000000080008081] + ,(4u64)[0x8000000000008009, 0x8000000000008009, 0x8000000000008009, 0x8000000000008009] + ,(4u64)[0x000000000000008a, 0x000000000000008a, 0x000000000000008a, 0x000000000000008a] + ,(4u64)[0x0000000000000088, 0x0000000000000088, 0x0000000000000088, 0x0000000000000088] + ,(4u64)[0x0000000080008009, 0x0000000080008009, 0x0000000080008009, 0x0000000080008009] + ,(4u64)[0x000000008000000a, 0x000000008000000a, 0x000000008000000a, 0x000000008000000a] + ,(4u64)[0x000000008000808b, 0x000000008000808b, 0x000000008000808b, 0x000000008000808b] + ,(4u64)[0x800000000000008b, 0x800000000000008b, 0x800000000000008b, 0x800000000000008b] + ,(4u64)[0x8000000000008089, 0x8000000000008089, 0x8000000000008089, 0x8000000000008089] + ,(4u64)[0x8000000000008003, 0x8000000000008003, 0x8000000000008003, 0x8000000000008003] + ,(4u64)[0x8000000000008002, 0x8000000000008002, 0x8000000000008002, 0x8000000000008002] + ,(4u64)[0x8000000000000080, 0x8000000000000080, 0x8000000000000080, 0x8000000000000080] + ,(4u64)[0x000000000000800a, 0x000000000000800a, 0x000000000000800a, 0x000000000000800a] + ,(4u64)[0x800000008000000a, 0x800000008000000a, 0x800000008000000a, 0x800000008000000a] + ,(4u64)[0x8000000080008081, 0x8000000080008081, 0x8000000080008081, 0x8000000080008081] + ,(4u64)[0x8000000000008080, 0x8000000000008080, 0x8000000000008080, 0x8000000000008080] + ,(4u64)[0x0000000080000001, 0x0000000080000001, 0x0000000080000001, 0x0000000080000001] + ,(4u64)[0x8000000080008008, 0x8000000080008008, 0x8000000080008008, 0x8000000080008008] +}; + + +u256[6] KECCAK_RHOTATES_LEFT = +{ + (4u64)[41, 36, 18, 3], + (4u64)[27, 28, 62, 1], + (4u64)[39, 56, 6, 45], + (4u64)[ 8, 55, 61, 10], + (4u64)[20, 25, 15, 2], + (4u64)[14, 21, 43, 44] +}; + + +u256[6] KECCAK_RHOTATES_RIGHT = +{ + (4u64)[64-41, 64-36, 64-18, 64- 3], + (4u64)[64-27, 64-28, 64-62, 64- 1], + 
(4u64)[64-39, 64-56, 64- 6, 64-45],
+  (4u64)[64- 8, 64-55, 64-61, 64-10],
+  (4u64)[64-20, 64-25, 64-15, 64- 2],
+  (4u64)[64-14, 64-21, 64-43, 64-44]
+};
+// Keccak-f[1600] permutation on a state held in 7 AVX2 registers (presumably the layout built by __stavx2_pack below).
+// Each of the KECCAK_ROUNDS loop iterations performs the Theta, Rho + Pi, Chi and Iota steps marked in the body.
+fn _keccakf1600_avx2(reg u256[7] state) -> reg u256[7]
+{
+  reg u256[9] t;
+  reg u256 c00 c14 d00 d14;
+
+  reg bool zf;
+  reg u64 r iotas_o;
+
+  reg ptr u256[24] iotas_p;
+  reg ptr u256[6] rhotates_left_p;
+  reg ptr u256[6] rhotates_right_p;
+
+  iotas_p = KECCAK_IOTAS;
+  iotas_o = 0; // byte offset of the current round's entry inside KECCAK_IOTAS
+  rhotates_left_p = KECCAK_RHOTATES_LEFT;
+  rhotates_right_p = KECCAK_RHOTATES_RIGHT;
+
+  r = KECCAK_ROUNDS; // rounds still to run; counted down to zero by #DEC_64 at the end of the loop
+  while
+  {
+    //######################################## Theta
+    c00 = #VPSHUFD_256(state[2], (4u2)[1,0,3,2]);
+    c14 = state[5] ^ state[3];
+    t[2] = state[4] ^ state[6];
+    c14 = c14 ^ state[1];
+    c14 = c14 ^ t[2];
+    t[4] = #VPERMQ(c14, (4u2)[2,1,0,3]);
+    c00 = c00 ^ state[2];
+    t[0] = #VPERMQ(c00, (4u2)[1,0,3,2]);
+    t[1] = c14 >>4u64 63;
+    t[2] = c14 +4u64 c14;
+    t[1] = t[1] | t[2]; // (x >> 63) | (x + x): every 64-bit lane of c14 rotated left by 1
+    d14 = #VPERMQ(t[1], (4u2)[0,3,2,1]);
+    d00 = t[1] ^ t[4];
+    d00 = #VPERMQ(d00, (4u2)[0,0,0,0]);
+    c00 = c00 ^ state[0];
+    c00 = c00 ^ t[0];
+    t[0] = c00 >>4u64 63;
+    t[1] = c00 +4u64 c00;
+    t[1] = t[1] | t[0]; // same rotate-left-by-1, now applied to c00
+    state[2] = state[2] ^ d00;
+    state[0] = state[0] ^ d00;
+    d14 = #VPBLEND_8u32(d14, t[1], (8u1)[1,1,0,0,0,0,0,0]);
+    t[4] = #VPBLEND_8u32(t[4], c00, (8u1)[0,0,0,0,0,0,1,1]);
+    d14 = d14 ^ t[4];
+
+    //######################################## Rho + Pi + pre-Chi shuffle
+    t[3] = #VPSLLV_4u64(state[2], rhotates_left_p[0] );
+    state[2] = #VPSRLV_4u64(state[2], rhotates_right_p[0] );
+    state[2] = state[2] | t[3]; // (x << l) | (x >> r) per 64-bit lane: a rotate, assuming r = 64 - l in the two tables
+    state[3] = state[3] ^ d14;
+    t[4] = #VPSLLV_4u64(state[3], rhotates_left_p[2] );
+    state[3] = #VPSRLV_4u64(state[3], rhotates_right_p[2] );
+    state[3] = state[3] | t[4];
+    state[4] = state[4] ^ d14;
+    t[5] = #VPSLLV_4u64(state[4], rhotates_left_p[3] );
+    state[4] = #VPSRLV_4u64(state[4], rhotates_right_p[3] );
+    state[4] = state[4] | t[5];
+    state[5] = state[5] ^ d14;
+    t[6] = #VPSLLV_4u64(state[5], rhotates_left_p[4] );
+    state[5] = #VPSRLV_4u64(state[5], rhotates_right_p[4] );
+    state[5] = state[5] | t[6];
+    state[6] = state[6] ^ d14;
+    t[3] = #VPERMQ(state[2], (4u2)[2,0,3,1]);
+    t[4] = #VPERMQ(state[3], (4u2)[2,0,3,1]);
+    t[7] = #VPSLLV_4u64(state[6], rhotates_left_p[5] );
+    t[1] = #VPSRLV_4u64(state[6], rhotates_right_p[5] );
+    t[1] = t[1] | t[7];
+    state[1] = state[1] ^ d14;
+    t[5] = #VPERMQ(state[4], (4u2)[0,1,2,3]);
+    t[6] = #VPERMQ(state[5], (4u2)[1,3,0,2]);
+    t[8] = #VPSLLV_4u64(state[1], rhotates_left_p[1] );
+    t[2] = #VPSRLV_4u64(state[1], rhotates_right_p[1] );
+    t[2] = t[2] | t[8];
+
+    //######################################## Chi
+    t[7] = #VPSRLDQ_256(t[1], 8);
+    t[0] = !t[1] & t[7]; // (NOT t[1]) AND t[7]; kept in t[0] until it is XORed into state[0] at the end of Chi
+    state[3] = #VPBLEND_8u32(t[2], t[6], (8u1)[0,0,0,0,1,1,0,0]);
+    t[8] = #VPBLEND_8u32(t[4], t[2], (8u1)[0,0,0,0,1,1,0,0]);
+    state[5] = #VPBLEND_8u32(t[3], t[4], (8u1)[0,0,0,0,1,1,0,0]);
+    t[7] = #VPBLEND_8u32(t[2], t[3], (8u1)[0,0,0,0,1,1,0,0]);
+    state[3] = #VPBLEND_8u32(state[3], t[4], (8u1)[0,0,1,1,0,0,0,0]);
+    t[8] = #VPBLEND_8u32(t[8], t[5], (8u1)[0,0,1,1,0,0,0,0]);
+    state[5] = #VPBLEND_8u32(state[5], t[2], (8u1)[0,0,1,1,0,0,0,0]);
+    t[7] = #VPBLEND_8u32(t[7], t[6], (8u1)[0,0,1,1,0,0,0,0]);
+    state[3] = #VPBLEND_8u32(state[3], t[5], (8u1)[1,1,0,0,0,0,0,0]);
+    t[8] = #VPBLEND_8u32(t[8], t[6], (8u1)[1,1,0,0,0,0,0,0]);
+    state[5] = #VPBLEND_8u32(state[5], t[6], (8u1)[1,1,0,0,0,0,0,0]);
+    t[7] = #VPBLEND_8u32(t[7], t[4], (8u1)[1,1,0,0,0,0,0,0]);
+    state[3] = !state[3] & t[8];
+    state[5] = !state[5] & t[7];
+    state[6] = #VPBLEND_8u32(t[5], t[2], (8u1)[0,0,0,0,1,1,0,0]);
+    t[8] = #VPBLEND_8u32(t[3], t[5], (8u1)[0,0,0,0,1,1,0,0]);
+    state[3] = state[3] ^ t[3];
+    state[6] = #VPBLEND_8u32(state[6], t[3], (8u1)[0,0,1,1,0,0,0,0]);
+    t[8] = #VPBLEND_8u32(t[8], t[4], (8u1)[0,0,1,1,0,0,0,0]);
+    state[5] = state[5] ^ t[5];
+    state[6] = #VPBLEND_8u32(state[6], t[4], (8u1)[1,1,0,0,0,0,0,0]);
+    t[8] = #VPBLEND_8u32(t[8], t[2], (8u1)[1,1,0,0,0,0,0,0]);
+    state[6] = !state[6] & t[8];
+    state[6] = state[6] ^ t[6];
+    state[4] = #VPERMQ(t[1], (4u2)[0,1,3,2]);
+    t[8] = #VPBLEND_8u32(state[4], state[0], (8u1)[0,0,1,1,0,0,0,0]);
+    state[1] = #VPERMQ(t[1], (4u2)[0,3,2,1]);
+    state[1] = #VPBLEND_8u32(state[1], state[0], (8u1)[1,1,0,0,0,0,0,0]);
+    state[1] = !state[1] & t[8];
+    state[2] = #VPBLEND_8u32(t[4], t[5], (8u1)[0,0,0,0,1,1,0,0]);
+    t[7] = #VPBLEND_8u32(t[6], t[4], (8u1)[0,0,0,0,1,1,0,0]);
+    state[2] = #VPBLEND_8u32(state[2], t[6], (8u1)[0,0,1,1,0,0,0,0]);
+    t[7] = #VPBLEND_8u32(t[7], t[3], (8u1)[0,0,1,1,0,0,0,0]);
+    state[2] = #VPBLEND_8u32(state[2], t[3], (8u1)[1,1,0,0,0,0,0,0]);
+    t[7] = #VPBLEND_8u32(t[7], t[5], (8u1)[1,1,0,0,0,0,0,0]);
+    state[2] = !state[2] & t[7];
+    state[2] = state[2] ^ t[2];
+    t[0] = #VPERMQ(t[0], (4u2)[0,0,0,0]);
+    state[3] = #VPERMQ(state[3], (4u2)[0,1,2,3]);
+    state[5] = #VPERMQ(state[5], (4u2)[2,0,3,1]);
+    state[6] = #VPERMQ(state[6], (4u2)[1,3,0,2]);
+    state[4] = #VPBLEND_8u32(t[6], t[3], (8u1)[0,0,0,0,1,1,0,0]);
+    t[7] = #VPBLEND_8u32(t[5], t[6], (8u1)[0,0,0,0,1,1,0,0]);
+    state[4] = #VPBLEND_8u32(state[4], t[5], (8u1)[0,0,1,1,0,0,0,0]);
+    t[7] = #VPBLEND_8u32(t[7], t[2], (8u1)[0,0,1,1,0,0,0,0]);
+    state[4] = #VPBLEND_8u32(state[4], t[2], (8u1)[1,1,0,0,0,0,0,0]);
+    t[7] = #VPBLEND_8u32(t[7], t[3], (8u1)[1,1,0,0,0,0,0,0]);
+    state[4] = !state[4] & t[7];
+    state[0] = state[0] ^ t[0];
+    state[1] = state[1] ^ t[1];
+    state[4] = state[4] ^ t[4];
+
+    //######################################## Iota
+    state[0] = state[0] ^ iotas_p.[(int) iotas_o]; // XOR in the u256 entry found at byte offset iotas_o
+    iotas_o += 32; // step to the next u256 entry (32 bytes)
+
+    _,_,_,zf,r = #DEC_64(r); // r -= 1; zf records whether r reached zero
+  }(!zf) // keep looping while r != 0
+
+  return state;
+}
+
+// Converts a plain Keccak state (st: 25 u64 lanes) into the 7-register AVX2 representation; the resulting layout is listed just before the return.
+inline fn __stavx2_pack
+( reg const ptr u64[25] st
+) -> reg u256[7] {
+  // uses 3*r256 temporaries (avoidable...); bracketed comments list the st lane indices held in a register, lowest 64-bit lane first
+  reg u256[7] state;
+  reg u256 t256_0 t256_1 t256_2;
+  reg u128 t128_0, t128_1;
+  reg u64 r;
+
+  // [ 0 0 0 0 ]
+  state[0] = #VPBROADCAST_4u64(st.[u64 8*0]);
+  // [ 1 2 3 4 ]
+  state[1] = st.[u256 1*8];
+  // [ 5 - ]
+  t128_0 = #VMOV(st[5]);
+  // [ 6 7 8 9 ]
+  state[3] = st.[u256 6*8];
+  // [ 10 - ]
+  t128_1 = #VMOV(st[10]);
+  // [ 11 12 13 14 ]
+  state[4] = st.[u256 11*8];
+  // [ 5 15 ]
+  r = st[15];
+  t128_0 = #VPINSR_2u64(t128_0, r, 1);
+  // [ 16 17 18 19 ]
+  state[5] = st.[u256 16*8];
+  // [ 10 20 ]
+  r = st[20];
+  t128_1 = #VPINSR_2u64(t128_1, r, 1);
+  // alternative not currently supported: VPGATHERDQ for filling state[2]
+  // [ 10 20 5 15 ]
+  state[2] = (2u128)[t128_0, t128_1];
+  // [ 21 22 23 24 ]
+  state[6] = st.[u256 21*8];
+
+  // [ 16 7 8 19 ]
+  t256_0 = #VPBLEND_8u32(state[3], state[5], (8u1)[1,1,0,0,0,0,1,1]);
+  // [ 11 22 23 14 ]
+  t256_1 = #VPBLEND_8u32(state[6], state[4], (8u1)[1,1,0,0,0,0,1,1]);
+  // [ 6 12 13 9 ]
+  t256_2 = #VPBLEND_8u32(state[4], state[3], (8u1)[1,1,0,0,0,0,1,1]);
+
+  // [ 16 7 23 14 ]
+  state[3] = #VPBLEND_8u32(t256_0, t256_1, (8u1)[1,1,1,1,0,0,0,0]);
+  // [ 11 22 8 19 ]
+  state[4] = #VPBLEND_8u32(t256_1, t256_0, (8u1)[1,1,1,1,0,0,0,0]);
+
+  // [ 21 17 18 24 ]
+  t256_0 = #VPBLEND_8u32(state[5], state[6], (8u1)[1,1,0,0,0,0,1,1]);
+
+  // [ 21 17 13 9 ]
+  state[5] = #VPBLEND_8u32(t256_0, t256_2, (8u1)[1,1,1,1,0,0,0,0]);
+  // [ 6 12 18 24 ]
+  state[6] = #VPBLEND_8u32(t256_2, t256_0, (8u1)[1,1,1,1,0,0,0,0]);
+
+  // final layout, state[0] .. state[6]:
+  // [ 0 0 0 0 ]
+  // [ 1 2 3 4 ]
+  // [ 10 20 5 15 ]
+  // [ 16 7 23 14 ]
+  // [ 11 22 8 19 ]
+  // [ 21 17 13 9 ]
+  // [ 6 12 18 24 ]
+  return state;
+}
+
+// Recovers a plain Keccak state (25 u64 lanes, written into st and returned) from the 7-register AVX2 representation.
+inline fn __stavx2_unpack
+( reg mut ptr u64[25] st
+, reg u256[7] state
+) -> reg ptr u64[25] {
+  // uses 5*r256 + 2*r128 temporaries (the r128 ones are avoidable), plus 7*r256 (presumably the `state` argument)
+  reg u256 t256_0 t256_1 t256_2 t256_3 t256_4;
+  reg u128 t128_0, t128_1;
+
+  // [ 0, 0 ]
+  t128_0 = (128u) state[0];
+  st[0] = #VMOVLPD(t128_0);
+  // [ 1, 2, 3, 4 ]
+  st.[u256 1*8] = state[1];
+
+  // [ 16, 7, 8, 19 ]
+  t256_0 = #VPBLEND_8u32(state[3], state[4], (8u1)[1,1,1,1,0,0,0,0]);
+  // [ 11, 22, 23, 14 ]
+  t256_1 = #VPBLEND_8u32(state[4], state[3], (8u1)[1,1,1,1,0,0,0,0]);
+  // [ 21, 17, 18, 24 ]
+  t256_2 = #VPBLEND_8u32(state[5], state[6], (8u1)[1,1,1,1,0,0,0,0]);
+  // [ 6, 12, 13, 9 ]
+  t256_3 = #VPBLEND_8u32(state[6], state[5], (8u1)[1,1,1,1,0,0,0,0]);
+
+  // [ 5, 15 ] (upper 128 bits of state[2])
+// NOTE(review): leftover commented-out statement `state[2] = TTT[0];` -- TTT is not declared in this function
+  t128_1 = #VEXTRACTI128(state[2], 1);
+  st[5] = #VMOVLPD(t128_1);
+
+  // [ 6, 7, 8, 9 ]
+  t256_4 = #VPBLEND_8u32(t256_0, t256_3, (8u1)[1,1,0,0,0,0,1,1]);
+  st.[u256 6*8] = t256_4;
+
+  // [ 10, 20 ] (lower 128 bits of state[2])
+  t128_0 = (128u) state[2];
+  st[10] = #VMOVLPD(t128_0);
+
+  // [ 11, 12, 13, 14 ]
+  t256_4 = #VPBLEND_8u32(t256_3, t256_1, (8u1)[1,1,0,0,0,0,1,1]);
+  st.[u256 11*8] = t256_4;
+
+  // [ 15 ]
+  st[15] = #VMOVHPD(t128_1);
+
+  // [ 16, 17, 18, 19 ]
+  t256_4 = #VPBLEND_8u32(t256_2, t256_0, (8u1)[1,1,0,0,0,0,1,1]);
+  st.[u256 16*8] = t256_4;
+
+  // [ 20 ]
+  st[20] = #VMOVHPD(t128_0);
+
+  // [ 21, 22, 23, 24 ]
+  t256_4 = #VPBLEND_8u32(t256_1, t256_2, (8u1)[1,1,0,0,0,0,1,1]);
+  st.[u256 21*8] = t256_4;
+
+  return st;
+}
+
diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600_generic.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600_generic.jinc
new file mode 100644
index 00000000..c11a69b8
--- /dev/null
+++ b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/keccak/keccakf1600_generic.jinc
@@ -0,0 +1,64 @@
+param int KECCAK_ROUNDS = 24; // number of rounds; also the length of KECCAK1600_RC
+// One u64 round constant per round; the values are the Keccak-f[1600] iota-step constants.
+u64[24] KECCAK1600_RC =
+{ 0x0000000000000001
+ ,0x0000000000008082
+ ,0x800000000000808a
+ ,0x8000000080008000
+ ,0x000000000000808b
+ ,0x0000000080000001
+ ,0x8000000080008081
+ ,0x8000000000008009
+ ,0x000000000000008a
+ ,0x0000000000000088
+ ,0x0000000080008009
+ ,0x000000008000000a
+ ,0x000000008000808b
+ ,0x800000000000008b
+ ,0x8000000000008089
+ ,0x8000000000008003
+ ,0x8000000000008002
+ ,0x8000000000000080
+ ,0x000000000000800a
+ ,0x800000008000000a
+ ,0x8000000080008081
+ ,0x8000000000008080
+ ,0x0000000080000001
+ ,0x8000000080008008
+};
+// Flat index of lane (x, y) in a 5x5 state: both coordinates are reduced mod 5 and each row holds 5 lanes.
+inline fn keccakf1600_index(inline int x y) -> inline int
+{
+  inline int r;
+  r = (x % 5) + 5 * (y % 5);
+  return r;
+}
+// Rho rotation amount for the lane with flat index i. The loop walks (x, y) -> (y, (2x + 3y) mod 5) starting from (1, 0);
+// the lane reached at step t gets ((t + 1) * (t + 2) / 2) mod 64, and a lane that is never reached keeps r = 0.
+inline fn keccakf1600_rho_offsets(inline int i) -> inline int
+{
+  inline int r x y z t;
+
+  r = 0;
+  x = 1;
+  y = 0;
+
+  for t = 0 to 24
+  { if (i == x + 5 * y)
+    { r = ((t + 1) * (t + 2) / 2) % 64; }
+    z = (2 * x + 3 * y) % 5;
+    x = y;
+    y = z;
+  }
+
+  return r;
+}
+// Rho rotation amount for the lane at coordinates (x, y): flat index first, then the offset lookup.
+inline fn keccakf1600_rhotates(inline int x y) -> inline int
+{
+  inline int i r;
+  i = keccakf1600_index(x, y);
+  r = keccakf1600_rho_offsets(i);
+  return r;
+}
+
diff --git a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/kem.jinc b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/kem.jinc index 940411f5..7be45bb2 100644 --- a/src/crypto_kem/mlkem/mlkem768/amd64/avx2/kem.jinc +++ b/src/crypto_kem/mlkem/mlkem768/amd64/avx2/kem.jinc @@ -4,11 +4,11 @@ require "verify.jinc" inline fn __crypto_kem_keypair_jazz(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES*2] randomnessp) { - stack ptr u8[MLKEM_SYMBYTES*2] s_randomnessp; + #mmx reg ptr u8[MLKEM_SYMBYTES*2] s_randomnessp; reg ptr u8[MLKEM_SYMBYTES] randomnessp1 randomnessp2; stack u8[32] h_pk; - stack u64 s_skp s_pkp; + #mmx reg u64 s_skp s_pkp; reg u64 t64; inline int i; @@ -33,6 +33,7 @@ fn __crypto_kem_keypair_jazz(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES s_skp = skp; pkp = s_pkp; t64 = MLKEM_PUBLICKEYBYTES; + h_pk = _isha3_256(h_pk, pkp, t64); skp = s_skp; @@ -45,6 +46,7 @@ fn __crypto_kem_keypair_jazz(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES randomnessp = s_randomnessp; randomnessp2 = randomnessp[MLKEM_SYMBYTES:MLKEM_SYMBYTES]; + for i=0 to MLKEM_SYMBYTES/8 { t64 = randomnessp2[u64 i]; @@ -59,7 +61,7 @@ fn __crypto_kem_enc_jazz(reg u64 ctp, reg u64 shkp, reg u64 pkp, reg ptr u8[MLKE inline int i; stack u8[MLKEM_SYMBYTES * 2] buf kr; - stack u64 s_pkp s_ctp s_shkp; + #mmx reg u64 s_pkp s_ctp s_shkp; reg u64 t64; s_pkp = pkp; @@ -72,8 +74,6 @@
fn __crypto_kem_enc_jazz(reg u64 ctp, reg u64 shkp, reg u64 pkp, reg ptr u8[MLKE buf[u64 i] = t64; } - pkp = s_pkp; - t64 = MLKEM_PUBLICKEYBYTES; buf[MLKEM_SYMBYTES:MLKEM_SYMBYTES] = _isha3_256(buf[MLKEM_SYMBYTES:MLKEM_SYMBYTES], pkp, t64); @@ -97,7 +97,7 @@ fn __crypto_kem_dec_jazz(reg u64 shkp, reg u64 ctp, reg u64 skp) { stack u8[MLKEM_INDCPA_CIPHERTEXTBYTES] ctpc; stack u8[2*MLKEM_SYMBYTES] kr buf; - stack u64 s_skp s_ctp s_shkp s_cnd; + #mmx reg u64 s_skp s_ctp s_shkp s_cnd; reg u64 pkp hp zp t64 cnd; inline int i; @@ -126,6 +126,7 @@ fn __crypto_kem_dec_jazz(reg u64 shkp, reg u64 ctp, reg u64 skp) ctpc = __indcpa_enc_1(ctpc, buf[0:MLKEM_INDCPA_MSGBYTES], pkp, kr[MLKEM_SYMBYTES:MLKEM_SYMBYTES]); ctp = s_ctp; + cnd = __verify(ctp, ctpc); s_cnd = cnd; /* avoidable ? */ @@ -136,9 +137,11 @@ fn __crypto_kem_dec_jazz(reg u64 shkp, reg u64 ctp, reg u64 skp) /* fixme: should this be done in memory? */ shkp = s_shkp; + _shake256_1120_32(shkp, zp, ctp); shkp = s_shkp; cnd = s_cnd; + __cmov(shkp, kr[0:MLKEM_SYMBYTES], cnd); }