From 97cc8633bcfed621e1216653ee28674c9d00958e Mon Sep 17 00:00:00 2001 From: Tiago Oliveira Date: Tue, 9 Apr 2024 17:32:45 +0100 Subject: [PATCH] mlkem768 ref sct: work in progress - recording the state to try a different approach for gen_matrix -- current overhead between 3% and 5% --- code/jasmin/mlkem_ref/Makefile | 3 + code/jasmin/mlkem_ref/fips202.jinc | 95 +++++++++++++-------------- code/jasmin/mlkem_ref/gen_matrix.jinc | 48 +++++++++++--- code/jasmin/mlkem_ref/indcpa.jinc | 35 ++++++---- code/jasmin/mlkem_ref/jkem.jazz | 9 +++ code/jasmin/mlkem_ref/kem.jinc | 24 +++++++ code/jasmin/mlkem_ref/poly.jinc | 10 +-- 7 files changed, 147 insertions(+), 77 deletions(-) diff --git a/code/jasmin/mlkem_ref/Makefile b/code/jasmin/mlkem_ref/Makefile index 86469323..029ea8fb 100644 --- a/code/jasmin/mlkem_ref/Makefile +++ b/code/jasmin/mlkem_ref/Makefile @@ -86,6 +86,9 @@ test/test_polyvec_%: test/test_polyvec_%.c $(HEADERS) $(SOURCES) jpolyvec.s ct: $(JASMINC) -checkCT -infer jkem.jazz +sct: + $(JASMINC) -checkSCT jkem.jazz + clean: -rm -f *.s -rm -f jindcpa.o diff --git a/code/jasmin/mlkem_ref/fips202.jinc b/code/jasmin/mlkem_ref/fips202.jinc index 73e5bc87..da081565 100644 --- a/code/jasmin/mlkem_ref/fips202.jinc +++ b/code/jasmin/mlkem_ref/fips202.jinc @@ -370,7 +370,9 @@ fn ____xtr_bytes( } } +// Note: the following code is not used; Todo: Remove it after double checking why it is here. +/* inline fn ____keccak1600_ref( stack u64 s_out s_outlen, @@ -462,17 +464,16 @@ fn __shake256(reg u64 out outlen in inlen) config[1] = rate; __keccak1600_ref(out, outlen, in, inlen, config); } +*/ -fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[128] +fn _shake256_128_33(#spill_to_mmx reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[128] { stack u64[25] state; reg u8 c; inline int i; - stack ptr u8[128] sout; - - sout = out; + () = #spill(out); state = __st0(state); @@ -485,7 +486,7 @@ fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[12 state = _keccakf1600_(state); - out = sout; + () = #unspill(out); for i = 0 to 128 { c = state[u8 (int) i]; @@ -494,16 +495,15 @@ fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[12 return out; } -fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { +fn _shake256_1120_32(#spill_to_mmx reg u64 out in0 in1) { stack u64[25] state; - stack u64 s_out s_in1; - stack u64 s_in s_ilen s_r8; - reg u64 ilen r8 t64 in; + #spill_to_mmx reg u64 ilen r8; + reg u64 t64; reg u8 t8; inline int i; - s_out = out; - s_in1 = in1; + () = #spill(out); + state = __st0(state); for i = 0 to MLKEM_SYMBYTES/8 { @@ -515,35 +515,34 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { t64 = (u64)[in1 + (i-MLKEM_SYMBYTES/8)*8]; state[u64 i] ^= t64; } - + + () = #spill(in1); + state = _keccakf1600_(state); + () = #unspill(in1); + r8 = SHAKE256_RATE; ilen = MLKEM_CT_LEN - (SHAKE256_RATE - MLKEM_SYMBYTES); - in = s_in1; - in += SHAKE256_RATE - MLKEM_SYMBYTES; + in1 += SHAKE256_RATE - MLKEM_SYMBYTES; while(ilen >= r8) { - state, in, ilen = __add_full_block(state, in, ilen, r8); + state, in1, ilen = __add_full_block(state, in1, ilen, r8); - s_in = in; - s_ilen = ilen; - s_r8 = r8; + () = #spill(in1, ilen, r8); state = _keccakf1600_(state); - in = s_in; - ilen = s_ilen; - r8 = s_r8; + () = #unspill(in1, ilen, r8); } t8 = 0x1f; - state = __add_final_block(state, in, ilen, t8, r8); + state = __add_final_block(state, in1, ilen, t8, r8); state = _keccakf1600_(state); - out = s_out; + () = #unspill(out); for i=0 to MLKEM_SYMBYTES/8 { @@ -554,14 +553,13 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { } -fn _sha3512_32(reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64] +fn _sha3512_32(#spill_to_mmx reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64] { stack u64[25] state; reg u8 c; inline int i; - stack ptr u8[64] s_out; - s_out = out; + () = #spill(out); state = __st0(state); @@ -574,11 +572,13 @@ fn _sha3512_32(reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64] state = _keccakf1600_(state); - out = s_out; + () = #unspill(out); + for i = 0 to 64 { c = state[u8 (int) i]; out[i] = c; } + return out; } @@ -601,17 +601,17 @@ fn _shake128_absorb34(reg ptr u64[25] state, reg const ptr u8[34] in) -> reg ptr } -fn _shake128_squeezeblock(reg ptr u64[25] state, reg ptr u8[SHAKE128_RATE] out) -> reg ptr u64[25], reg ptr u8[SHAKE128_RATE] +fn _shake128_squeezeblock(reg ptr u64[25] state, #spill_to_mmx reg ptr u8[SHAKE128_RATE] out) -> reg ptr u64[25], reg ptr u8[SHAKE128_RATE] { - stack ptr u8[SHAKE128_RATE] s_out; reg u8 c; inline int i; - s_out = out; + () = #spill(out); state = _keccakf1600_(state); - out = s_out; + () = #unspill(out); + for i = 0 to SHAKE128_RATE { // SHAKE128 rate is 168: or 21 u64: TODO: 'compress' this for loop c = state[u8 (int) i]; out[i] = c; @@ -621,16 +621,15 @@ fn _shake128_squeezeblock(reg ptr u64[25] state, reg ptr u8[SHAKE128_RATE] out) #[returnaddress="stack"] -fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] +fn _isha3_256(#spill_to_mmx reg ptr u8[32] out, #spill_to_mmx reg u64 in inlen) -> reg ptr u8[32] { stack u64[25] state; - stack ptr u8[32] s_out; - stack u64 s_in s_ilen s_r8; - reg u64 ilen r8 t64; + #spill_to_mmx reg u64 ilen r8; + reg u64 t64; reg u8 t8; inline int i; - s_out = out; + () = #spill(out); state = __st0(state); @@ -641,15 +640,11 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] { state, in, ilen = __add_full_block(state, in, ilen, r8); - s_in = in; - s_ilen = ilen; - s_r8 = r8; + () = #spill(in, ilen, r8); state = _keccakf1600_(state); - in = s_in; - ilen = s_ilen; - r8 = s_r8; + () = #unspill(in, ilen, r8); } t8 = 0x06; @@ -657,7 +652,7 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] state = _keccakf1600_(state); - out = s_out; + () = #unspill(out); for i=0 to 4 { @@ -669,14 +664,13 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] } #[returnaddress="stack"] -fn _isha3_256_32(reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u8[32] +fn _isha3_256_32(#spill_to_mmx reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u8[32] { stack u64[25] state; - stack ptr u8[32] s_out; reg u64 t64; inline int i; - s_out = out; + () = #spill(out); state = __st0(state); @@ -691,7 +685,7 @@ fn _isha3_256_32(reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u state = _keccakf1600_(state); - out = s_out; + () = #unspill(out); for i=0 to 4 { @@ -703,10 +697,9 @@ fn _isha3_256_32(reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u } #[returnaddress="stack"] -fn _sha3_512_64(reg ptr u8[64] out, reg const ptr u8[64] in) -> stack u8[64] +fn _sha3_512_64(#spill_to_mmx reg ptr u8[64] out, reg const ptr u8[64] in) -> stack u8[64] { stack u64[25] state; - stack ptr u8[64] out_s; reg u64 t64; inline int i; @@ -721,11 +714,11 @@ fn _sha3_512_64(reg ptr u8[64] out, reg const ptr u8[64] in) -> stack u8[64] state[u8 64] ^= 0x06; state[u8 SHA3_512_RATE - 1] ^= 0x80; - out_s = out; + () = #spill(out); state = _keccakf1600_(state); - out = out_s; + () = #unspill(out); for i = 0 to 8 { diff --git a/code/jasmin/mlkem_ref/gen_matrix.jinc b/code/jasmin/mlkem_ref/gen_matrix.jinc index f261b711..f622e02d 100644 --- a/code/jasmin/mlkem_ref/gen_matrix.jinc +++ b/code/jasmin/mlkem_ref/gen_matrix.jinc @@ -8,14 +8,28 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE] reg u16 t; reg u64 pos ctr; + #msf reg u64 ms; + reg bool cond; + + ms = #init_msf(); ctr = offset; pos = 0; - while (pos < SHAKE128_RATE - 2) { - if ctr < MLKEM_N { + while { cond = (pos < SHAKE128_RATE - 2); } (cond) { + + ms = #update_msf(cond, ms); + + cond = ctr < MLKEM_N; + if cond { + ms = #update_msf(cond, ms); + val1 = (16u)buf[pos]; - t = (16u)buf[pos + 1]; + val1 = #protect_16(val1, ms); + + t = (16u)buf[pos + 1]; + t = #protect_16(t, ms); + val2 = t; val2 >>= 4; t &= 0x0F; @@ -23,30 +37,42 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE] val1 |= t; t = (16u)buf[pos + 2]; + t = #protect_16(t, ms); + t <<= 4; val2 |= t; pos += 3; - reg bool cond; - #[declassify] cond = val1 < MLKEM_Q; if cond { + ms = #update_msf(cond, ms); rp[ctr] = val1; ctr += 1; + } else { + ms = #update_msf(!cond, ms); } - #[declassify] cond = val2 < MLKEM_Q; if cond { - if(ctr < MLKEM_N) - { + ms = #update_msf(cond, ms); + + cond = ctr < MLKEM_N; + if cond { + ms = #update_msf(cond, ms); rp[ctr] = val2; ctr += 1; + } else { + ms = #update_msf(!cond, ms); } + } else { + ms = #update_msf(!cond, ms); } + } else { + ms = #update_msf(!cond, ms); pos = SHAKE128_RATE; } + } return ctr, rp; @@ -64,8 +90,8 @@ fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[ reg u8 c; reg u16 t; reg u64 ctr k; - stack u64 sctr; - stack u64 stransposed; + #mmx reg u64 sctr; + #mmx reg u64 stransposed; inline int j i; stransposed = transposed; @@ -81,6 +107,7 @@ fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[ for j = 0 to MLKEM_K { transposed = stransposed; + if(transposed == 0) { extseed[MLKEM_SYMBYTES] = j; @@ -100,6 +127,7 @@ fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[ sctr = ctr; state, buf = _shake128_squeezeblock(state, buf); ctr = sctr; + ctr, poly = __rej_uniform(poly, ctr, buf); } diff --git a/code/jasmin/mlkem_ref/indcpa.jinc b/code/jasmin/mlkem_ref/indcpa.jinc index a0f2a463..c52f1682 100644 --- a/code/jasmin/mlkem_ref/indcpa.jinc +++ b/code/jasmin/mlkem_ref/indcpa.jinc @@ -10,6 +10,8 @@ fn __indcpa_keypair(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES] randomn stack u16[MLKEM_VECN] e pkpv skpv; stack u8[64] buf; stack u8[MLKEM_SYMBYTES] publicseed noiseseed; + reg ptr u8[MLKEM_SYMBYTES] r_noiseseed; + #mmx reg ptr u8[MLKEM_SYMBYTES] s_noiseseed; stack u8[32] inbuf; reg u64 t64; reg u64 zero; @@ -32,28 +34,31 @@ fn __indcpa_keypair(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES] randomn for i=0 to MLKEM_SYMBYTES/8 { - t64 = buf[u64 i]; + #declassify t64 = buf[u64 i]; publicseed[u64 i] = t64; t64 = buf[u64 i + MLKEM_SYMBYTES/8]; noiseseed[u64 i] = t64; } + r_noiseseed = noiseseed; // currently, it is not possible to load stack to mmx, so: first to register, and then to mmx + s_noiseseed = r_noiseseed; + zero = 0; a = __gen_matrix(publicseed, zero); nonce = 0; - skpv[0:MLKEM_N] = _poly_getnoise(skpv[0:MLKEM_N], noiseseed, nonce); + skpv[0:MLKEM_N] = _poly_getnoise(skpv[0:MLKEM_N], s_noiseseed, nonce); nonce = 1; - skpv[MLKEM_N:MLKEM_N] = _poly_getnoise(skpv[MLKEM_N:MLKEM_N], noiseseed, nonce); + skpv[MLKEM_N:MLKEM_N] = _poly_getnoise(skpv[MLKEM_N:MLKEM_N], s_noiseseed, nonce); nonce = 2; - skpv[2*MLKEM_N:MLKEM_N] = _poly_getnoise(skpv[2*MLKEM_N:MLKEM_N], noiseseed, nonce); + skpv[2*MLKEM_N:MLKEM_N] = _poly_getnoise(skpv[2*MLKEM_N:MLKEM_N], s_noiseseed, nonce); nonce = 3; - e[0:MLKEM_N] = _poly_getnoise(e[0:MLKEM_N], noiseseed, nonce); + e[0:MLKEM_N] = _poly_getnoise(e[0:MLKEM_N], s_noiseseed, nonce); nonce = 4; - e[MLKEM_N:MLKEM_N] = _poly_getnoise(e[MLKEM_N:MLKEM_N], noiseseed, nonce); + e[MLKEM_N:MLKEM_N] = _poly_getnoise(e[MLKEM_N:MLKEM_N], s_noiseseed, nonce); nonce = 5; - e[2*MLKEM_N:MLKEM_N] = _poly_getnoise(e[2*MLKEM_N:MLKEM_N], noiseseed, nonce); + e[2*MLKEM_N:MLKEM_N] = _poly_getnoise(e[2*MLKEM_N:MLKEM_N], s_noiseseed, nonce); skpv = __polyvec_ntt(skpv); e = __polyvec_ntt(e); @@ -71,6 +76,8 @@ fn __indcpa_keypair(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES] randomn pkp = spkp; skp = sskp; + _ = #init_msf(); // temporary fix + __polyvec_tobytes(skp, skpv); __polyvec_tobytes(pkp, pkpv); @@ -93,7 +100,7 @@ fn __indcpa_enc(stack u64 sctp, reg ptr u8[32] msgp, reg u64 pkp, reg ptr u8[MLK reg u64 i t64; reg u64 ctp; reg u8 nonce; - stack ptr u8[MLKEM_SYMBYTES] s_noiseseed; + #mmx reg ptr u8[MLKEM_SYMBYTES] s_noiseseed; s_noiseseed = noiseseed; @@ -103,7 +110,7 @@ fn __indcpa_enc(stack u64 sctp, reg ptr u8[32] msgp, reg u64 pkp, reg ptr u8[MLK pkp += MLKEM_POLYVECBYTES; while (i < MLKEM_SYMBYTES/8) { - t64 = (u64)[pkp]; + #declassify t64 = (u64)[pkp]; publicseed.[u64 8 * (int)i] = t64; pkp += 8; i += 1; @@ -152,6 +159,9 @@ fn __indcpa_enc(stack u64 sctp, reg ptr u8[32] msgp, reg u64 pkp, reg ptr u8[MLK v = __poly_reduce(v); ctp = sctp; + + _ = #init_msf(); // temporary fix + __polyvec_compress(ctp, bp); ctp += MLKEM_POLYVECCOMPRESSEDBYTES; v = _poly_compress(ctp, v); @@ -167,7 +177,7 @@ fn __iindcpa_enc(reg ptr u8[MLKEM_CT_LEN] ctp, reg ptr u8[32] msgp, reg u64 pkp, reg u64 i t64; reg u8 nonce; stack ptr u8[MLKEM_CT_LEN] sctp; - stack ptr u8[MLKEM_SYMBYTES] s_noiseseed; + #mmx reg ptr u8[MLKEM_SYMBYTES] s_noiseseed; s_noiseseed = noiseseed; sctp = ctp; @@ -178,7 +188,7 @@ fn __iindcpa_enc(reg ptr u8[MLKEM_CT_LEN] ctp, reg ptr u8[32] msgp, reg u64 pkp, pkp += MLKEM_POLYVECBYTES; while (i < MLKEM_SYMBYTES/8) { - t64 = (u64)[pkp]; + #declassify t64 = (u64)[pkp]; publicseed.[u64 8*(int)i] = t64; pkp += 8; i += 1; @@ -227,6 +237,9 @@ fn __iindcpa_enc(reg ptr u8[MLKEM_CT_LEN] ctp, reg ptr u8[32] msgp, reg u64 pkp, v = __poly_reduce(v); ctp = sctp; + + _ = #init_msf(); + ctp[0:MLKEM_POLYVECCOMPRESSEDBYTES] = __i_polyvec_compress(ctp[0:MLKEM_POLYVECCOMPRESSEDBYTES], bp); ctp[MLKEM_POLYVECCOMPRESSEDBYTES:MLKEM_POLYCOMPRESSEDBYTES], v = _i_poly_compress(ctp[MLKEM_POLYVECCOMPRESSEDBYTES:MLKEM_POLYCOMPRESSEDBYTES], v); diff --git a/code/jasmin/mlkem_ref/jkem.jazz b/code/jasmin/mlkem_ref/jkem.jazz index 10f7ad0e..c7dfd900 100644 --- a/code/jasmin/mlkem_ref/jkem.jazz +++ b/code/jasmin/mlkem_ref/jkem.jazz @@ -7,6 +7,8 @@ export fn jade_kem_mlkem_mlkem768_amd64_ref_keypair_derand(reg u64 public_key se reg ptr u8[MLKEM_SYMBYTES*2] randomnessp; inline int i; + _ = #init_msf(); + public_key = public_key; secret_key = secret_key; @@ -28,6 +30,8 @@ export fn jade_kem_mlkem_mlkem768_amd64_ref_enc_derand(reg u64 ciphertext shared reg ptr u8[MLKEM_SYMBYTES] randomnessp; inline int i; + _ = #init_msf(); + ciphertext = ciphertext; shared_secret = shared_secret; public_key = public_key; @@ -54,6 +58,8 @@ export fn jade_kem_mlkem_mlkem768_amd64_ref_keypair(reg u64 public_key secret_ke randomnessp = randomness; randomnessp = #randombytes(randomnessp); + _ = #init_msf(); + __crypto_kem_keypair_jazz(public_key, secret_key, randomnessp); ?{}, r = #set0(); return r; @@ -71,6 +77,8 @@ export fn jade_kem_mlkem_mlkem768_amd64_ref_enc(reg u64 ciphertext shared_secret randomnessp = randomness; randomnessp = #randombytes(randomnessp); + _ = #init_msf(); + __crypto_kem_enc_jazz(ciphertext, shared_secret, public_key, randomnessp); ?{}, r = #set0(); return r; @@ -79,6 +87,7 @@ export fn jade_kem_mlkem_mlkem768_amd64_ref_enc(reg u64 ciphertext shared_secret export fn jade_kem_mlkem_mlkem768_amd64_ref_dec(reg u64 shared_secret ciphertext secret_key) -> reg u64 { reg u64 r; + _ = #init_msf(); __crypto_kem_dec_jazz(shared_secret, ciphertext, secret_key); ?{}, r = #set0(); return r; diff --git a/code/jasmin/mlkem_ref/kem.jinc b/code/jasmin/mlkem_ref/kem.jinc index ee8c60ea..8bee8f67 100644 --- a/code/jasmin/mlkem_ref/kem.jinc +++ b/code/jasmin/mlkem_ref/kem.jinc @@ -23,6 +23,8 @@ fn __crypto_kem_keypair_jazz(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES skp += MLKEM_POLYVECBYTES; pkp = s_pkp; + _ = #init_msf(); // temporary fix + for i=0 to MLKEM_INDCPA_PUBLICKEYBYTES/8 { t64 = (u64)[pkp + 8*i]; @@ -32,10 +34,15 @@ fn __crypto_kem_keypair_jazz(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES s_skp = skp; pkp = s_pkp; + + _ = #init_msf(); // temporary fix + t64 = MLKEM_POLYVECBYTES + MLKEM_SYMBYTES; h_pk = _isha3_256(h_pk, pkp, t64); skp = s_skp; + _ = #init_msf(); // temporary fix + for i=0 to 4 { t64 = h_pk[u64 i]; @@ -45,6 +52,9 @@ fn __crypto_kem_keypair_jazz(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES randomnessp = s_randomnessp; randomnessp2 = randomnessp[MLKEM_SYMBYTES:MLKEM_SYMBYTES]; + + _ = #init_msf(); // temporary fix + for i=0 to MLKEM_SYMBYTES/8 { t64 = randomnessp2[u64 i]; @@ -74,6 +84,8 @@ fn __crypto_kem_enc_jazz(reg u64 ctp, reg u64 shkp, reg u64 pkp, reg ptr u8[MLKE pkp = s_pkp; + _ = #init_msf(); // temporary fix + t64 = MLKEM_PUBLICKEYBYTES; buf[MLKEM_SYMBYTES:MLKEM_SYMBYTES] = _isha3_256(buf[MLKEM_SYMBYTES:MLKEM_SYMBYTES], pkp, t64); @@ -81,10 +93,14 @@ fn __crypto_kem_enc_jazz(reg u64 ctp, reg u64 shkp, reg u64 pkp, reg ptr u8[MLKE pkp = s_pkp; + _ = #init_msf(); // temporary fix + __indcpa_enc(s_ctp, buf[0:MLKEM_SYMBYTES], pkp, kr[MLKEM_SYMBYTES:MLKEM_SYMBYTES]); shkp = s_shkp; + _ = #init_msf(); // temporary fix + for i=0 to MLKEM_SYMBYTES/8 { t64 = kr[u64 i]; @@ -123,9 +139,14 @@ fn __crypto_kem_dec_jazz(reg u64 shkp, reg u64 ctp, reg u64 skp) pkp = s_skp; pkp += 12 * MLKEM_K * MLKEM_N>>3; + _ = #init_msf(); // temporary fix + ctpc = __iindcpa_enc(ctpc, buf[0:MLKEM_SYMBYTES], pkp, kr[MLKEM_SYMBYTES:MLKEM_SYMBYTES]); ctp = s_ctp; + + _ = #init_msf(); // temporary fix + cnd = __verify(ctp, ctpc); s_cnd = cnd; @@ -139,5 +160,8 @@ fn __crypto_kem_dec_jazz(reg u64 shkp, reg u64 ctp, reg u64 skp) shkp = s_shkp; cnd = s_cnd; + + _ = #init_msf(); // temporary fix + __cmov(shkp, kr[0:MLKEM_SYMBYTES], cnd); } diff --git a/code/jasmin/mlkem_ref/poly.jinc b/code/jasmin/mlkem_ref/poly.jinc index 19cc55bb..c51d8af5 100644 --- a/code/jasmin/mlkem_ref/poly.jinc +++ b/code/jasmin/mlkem_ref/poly.jinc @@ -398,7 +398,7 @@ fn _i_poly_frommsg(reg ptr u16[MLKEM_N] rp, reg ptr u8[32] ap) -> stack u16[MLKE return rp; } -fn _poly_getnoise(reg ptr u16[MLKEM_N] rp, reg ptr u8[MLKEM_SYMBYTES] seed, reg u8 nonce) -> reg ptr u16[MLKEM_N] +fn _poly_getnoise(#spill_to_mmx reg ptr u16[MLKEM_N] rp, #mmx reg ptr u8[MLKEM_SYMBYTES] s_seed, reg u8 nonce) -> reg ptr u16[MLKEM_N] { stack u8[33] extseed; /* 33 = MLKEM_SYMBYTES +1 */ stack u8[128] buf; /* 128 = MLKEM_ETA*MLKEM_N/4 */ @@ -406,10 +406,10 @@ fn _poly_getnoise(reg ptr u16[MLKEM_N] rp, reg ptr u8[MLKEM_SYMBYTES] seed, reg reg u16 t; reg u64 i; inline int k; + reg ptr u8[MLKEM_SYMBYTES] seed; - stack ptr u16[MLKEM_N] srp; - - srp = rp; + () = #spill(rp); + seed = s_seed; for k = 0 to MLKEM_SYMBYTES { @@ -420,7 +420,7 @@ fn _poly_getnoise(reg ptr u16[MLKEM_N] rp, reg ptr u8[MLKEM_SYMBYTES] seed, reg buf = _shake256_128_33(buf, extseed); - rp = srp; + () = #unspill(rp); i = 0; while (i < 128) {