From e6d8baa099e412a85ae100b8465978ea22c17ba7 Mon Sep 17 00:00:00 2001 From: Tiago Oliveira Date: Thu, 26 Sep 2024 17:10:58 +0100 Subject: [PATCH] gen_matrix: fix --- code/jasmin/mlkem_avx2/gen_matrix.jinc | 52 +++++++++++++------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/code/jasmin/mlkem_avx2/gen_matrix.jinc b/code/jasmin/mlkem_avx2/gen_matrix.jinc index 1e4c64d9..4ecf9e13 100644 --- a/code/jasmin/mlkem_avx2/gen_matrix.jinc +++ b/code/jasmin/mlkem_avx2/gen_matrix.jinc @@ -35,7 +35,7 @@ param int BUF_size = 536; // 168*2+200 (was in u64s: 3*21 + 4 + 1; //544 by // deinterleave u64-lanes of 4 u256 regs -fn _4u64x4_u256x4(reg u256 y0 y1 y2 y3) -> reg u256, reg u256, reg u256, reg u256 { +inline fn _4u64x4_u256x4(reg u256 y0 y1 y2 y3) -> reg u256, reg u256, reg u256, reg u256 { reg u256 x0, x1, x2, x3; x0 = #VPERM2I128(y0, y2, 0x20); x1 = #VPERM2I128(y1, y3, 0x20); @@ -182,7 +182,7 @@ inline fn __stavx2_pack_at return state; } -fn _stavx2_pack_at +inline fn _stavx2_pack_at ( reg const ptr u8[BUF_size] st , reg u64 offset // in bytes ) -> reg u256[7] { @@ -248,7 +248,7 @@ inline fn __stavx2_unpack_at return buf; } -fn _stavx2_unpack_at +inline fn _stavx2_unpack_at ( reg mut ptr u8[BUF_size] buf , reg u64 offset // in bytes , reg u256[7] state @@ -704,7 +704,7 @@ inline fn gen_matrix_get_indexes( return idx; } -inline fn __gen_matrix_fill_polynomial +fn __gen_matrix_fill_polynomial ( reg mut ptr u16[MLKEM_N] pol , reg mut ptr u8[BUF_size] buf ) -> reg ptr u16[MLKEM_N], reg ptr u8[BUF_size] @@ -727,11 +727,11 @@ inline fn __gen_matrix_fill_polynomial fn _gen_matrix_sample_four_polynomials ( reg mut ptr u16[4*MLKEM_N] polx4 -, reg mut ptr u8[BUF_size] buf0 buf1 buf2 buf3 +, reg mut ptr u8[BUF_size * 4] buf , reg ptr u8[32] rho , reg u64 mat_entry , reg u64 transposed -) -> reg ptr u16[4*MLKEM_N], reg ptr u8[BUF_size], reg ptr u8[BUF_size], reg ptr u8[BUF_size], reg ptr u8[BUF_size] +) -> reg ptr u16[4*MLKEM_N], reg ptr u8[BUF_size * 4] { reg u64 buf_offset; reg ptr u16[MLKEM_N] pol; @@ -746,27 +746,37 @@ fn _gen_matrix_sample_four_polynomials buf_offset = 0; while (buf_offset < 3*168) { stx4 = _keccakf1600_4x(stx4); - buf0, buf1, buf2, buf3 = __st4x_unpack_at( buf0, buf1, buf2, buf3, stx4, buf_offset ); + + buf[BUF_size * 0 : BUF_size], + buf[BUF_size * 1 : BUF_size], + buf[BUF_size * 2 : BUF_size], + buf[BUF_size * 3 : BUF_size] = + __st4x_unpack_at(buf[BUF_size * 0 : BUF_size], + buf[BUF_size * 1 : BUF_size], + buf[BUF_size * 2 : BUF_size], + buf[BUF_size * 3 : BUF_size], + stx4, buf_offset ); + buf_offset += 168; } pol = polx4[0*MLKEM_N:MLKEM_N]; - pol, buf0 = __gen_matrix_fill_polynomial(pol, buf0); + pol, buf[BUF_size * 0 : BUF_size] = __gen_matrix_fill_polynomial(pol, buf[BUF_size * 0 : BUF_size]); polx4[0*MLKEM_N:MLKEM_N] = pol; pol = polx4[1*MLKEM_N:MLKEM_N]; - pol, buf1 = __gen_matrix_fill_polynomial(pol, buf1); + pol, buf[BUF_size * 1 : BUF_size] = __gen_matrix_fill_polynomial(pol, buf[BUF_size * 1 : BUF_size]); polx4[1*MLKEM_N:MLKEM_N] = pol; pol = polx4[2*MLKEM_N:MLKEM_N]; - pol, buf2 = __gen_matrix_fill_polynomial(pol, buf2); + pol, buf[BUF_size * 2 : BUF_size] = __gen_matrix_fill_polynomial(pol, buf[BUF_size * 2 : BUF_size]); polx4[2*MLKEM_N:MLKEM_N] = pol; pol = polx4[3*MLKEM_N:MLKEM_N]; - pol, buf3 = __gen_matrix_fill_polynomial(pol, buf3); + pol, buf[BUF_size * 3 : BUF_size] = __gen_matrix_fill_polynomial(pol, buf[BUF_size * 3 : BUF_size]); polx4[3*MLKEM_N:MLKEM_N] = pol; - return polx4, buf0, buf1, buf2, buf3; + return polx4, buf; } inline fn __gen_matrix_sample_one_polynomial @@ -800,8 +810,8 @@ fn _gen_matrix_avx2 { // local variables inline int i j; - stack u8[BUF_size] buf0_s, buf1_s, buf2_s, buf3_s; - reg ptr u8[BUF_size] buf0, buf1, buf2, buf3; + stack u8[BUF_size * 4] buf_s; + reg ptr u8[BUF_size * 4] buf; reg ptr u16[4*MLKEM_N] polx4; reg ptr u16[MLKEM_N] pol; reg u64 mat_entry; @@ -809,18 +819,14 @@ fn _gen_matrix_avx2 () = #spill(transposed); - buf0 = buf0_s; - buf1 = buf1_s; - buf2 = buf2_s; - buf3 = buf3_s; + buf = buf_s; for i = 0 to 2 { mat_entry = 4*i; polx4 = matrix[4*i*MLKEM_N:4*MLKEM_N]; () = #unspill(transposed); - polx4, buf0, buf1, buf2, buf3 - = _gen_matrix_sample_four_polynomials(polx4, buf0, buf1, buf2 ,buf3, rho, mat_entry, transposed); + polx4, buf = _gen_matrix_sample_four_polynomials(polx4, buf, rho, mat_entry, transposed); matrix[i*4*MLKEM_N:4*MLKEM_N] = polx4; } @@ -828,13 +834,9 @@ fn _gen_matrix_avx2 // sample the last one, (2,2), using scalar code pol = matrix[8*MLKEM_N:MLKEM_N]; rc = 0x0202; - pol, buf0 = __gen_matrix_sample_one_polynomial(pol, buf0, rho, rc); + pol, buf[BUF_size * 0 : BUF_size] = __gen_matrix_sample_one_polynomial(pol, buf[BUF_size * 0 : BUF_size], rho, rc); matrix[8*MLKEM_N:MLKEM_N] = pol; - buf0_s = buf0; - buf1_s = buf1; - buf2_s = buf2; - buf3_s = buf3; for i = 0 to MLKEM_K { for j = 0 to MLKEM_K