diff --git a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec index 73c38480..8370b710 100644 --- a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec +++ b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec @@ -4415,12 +4415,13 @@ module M(SC:Syscall_t) = { var buf:W8.t Array168.t; var poly:W16.t Array256.t; var k:W64.t; - var l:W64.t; + var rij:W16.t Array256.t; var t:W16.t; buf <- witness; extseed <- witness; poly <- witness; r <- witness; + rij <- witness; state <- witness; stransposed <- transposed; j <- 0; @@ -4451,14 +4452,17 @@ module M(SC:Syscall_t) = { (ctr, poly) <@ __rej_uniform (poly, ctr, buf); } k <- (W64.of_int 0); - l <- (W64.of_int ((i * (3 * 256)) + (j * 256))); + rij <- + (Array256.init (fun i_0 => r.[((i * (3 * 256)) + (j * 256)) + i_0])); while ((k \ult (W64.of_int 256))) { t <- poly.[(W64.to_uint k)]; - r.[(W64.to_uint l)] <- t; + rij.[(W64.to_uint k)] <- t; k <- (k + (W64.of_int 1)); - l <- (l + (W64.of_int 1)); } + r <- Array2304.init + (fun i_0 => if ((i * (3 * 256)) + (j * 256)) <= i_0 < ((i * (3 * 256)) + (j * 256)) + 256 + then rij.[i_0-((i * (3 * 256)) + (j * 256))] else r.[i_0]); j <- j + 1; } i <- i + 1; diff --git a/code/jasmin/mlkem_avx2/gen_matrix.jinc b/code/jasmin/mlkem_avx2/gen_matrix.jinc index 01e6bf4f..3fbcfa20 100644 --- a/code/jasmin/mlkem_avx2/gen_matrix.jinc +++ b/code/jasmin/mlkem_avx2/gen_matrix.jinc @@ -73,7 +73,7 @@ fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[ reg u8 c; reg u16 t; - reg u64 ctr k l; + reg u64 ctr k; stack u64 sctr; stack u64 stransposed; inline int j i; @@ -114,14 +114,15 @@ fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[ } k = 0; - l = i * MLKEM_VECN + j * MLKEM_N; + reg ptr u16[MLKEM_N] rij; + rij = r[i * MLKEM_VECN + j * MLKEM_N : MLKEM_N]; while (k < MLKEM_N) { t = poly[(int) k]; - r[(int) l] = t; + rij[k] = t; k += 1; - l += 1; } + r[i * MLKEM_VECN + j * MLKEM_N : MLKEM_N] = rij; } } diff --git a/code/jasmin/mlkem_ref/extraction/jkem.ec b/code/jasmin/mlkem_ref/extraction/jkem.ec index 4d8087e4..d1d15639 100644 --- a/code/jasmin/mlkem_ref/extraction/jkem.ec +++ b/code/jasmin/mlkem_ref/extraction/jkem.ec @@ -1682,12 +1682,13 @@ module M(SC:Syscall_t) = { var buf:W8.t Array168.t; var poly:W16.t Array256.t; var k:W64.t; - var l:W64.t; + var rij:W16.t Array256.t; var t:W16.t; buf <- witness; extseed <- witness; poly <- witness; r <- witness; + rij <- witness; state <- witness; stransposed <- transposed; j <- 0; @@ -1718,14 +1719,17 @@ module M(SC:Syscall_t) = { (ctr, poly) <@ __rej_uniform (poly, ctr, buf); } k <- (W64.of_int 0); - l <- (W64.of_int ((i * (3 * 256)) + (j * 256))); + rij <- + (Array256.init (fun i_0 => r.[((i * (3 * 256)) + (j * 256)) + i_0])); while ((k \ult (W64.of_int 256))) { t <- poly.[(W64.to_uint k)]; - r.[(W64.to_uint l)] <- t; + rij.[(W64.to_uint k)] <- t; k <- (k + (W64.of_int 1)); - l <- (l + (W64.of_int 1)); } + r <- Array2304.init + (fun i_0 => if ((i * (3 * 256)) + (j * 256)) <= i_0 < ((i * (3 * 256)) + (j * 256)) + 256 + then rij.[i_0-((i * (3 * 256)) + (j * 256))] else r.[i_0]); j <- j + 1; } i <- i + 1; diff --git a/code/jasmin/mlkem_ref/gen_matrix.jinc b/code/jasmin/mlkem_ref/gen_matrix.jinc index 9241437c..5fa706ca 100644 --- a/code/jasmin/mlkem_ref/gen_matrix.jinc +++ b/code/jasmin/mlkem_ref/gen_matrix.jinc @@ -71,7 +71,7 @@ fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[ reg u8 c; reg u16 t; - reg u64 ctr k l; + reg u64 ctr k; stack u64 sctr; stack u64 stransposed; inline int j i; @@ -112,14 +112,15 @@ fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[ } k = 0; - l = i * MLKEM_VECN + j * MLKEM_N; + reg ptr u16[MLKEM_N] rij; + rij = r[i * MLKEM_VECN + j * MLKEM_N : MLKEM_N]; while (k < MLKEM_N) { t = poly[(int) k]; - r[(int) l] = t; + rij[k] = t; k += 1; - l += 1; } + r[i * MLKEM_VECN + j * MLKEM_N : MLKEM_N] = rij; } } return r; diff --git a/proof/correctness/MLKEM_InnerPKE.ec b/proof/correctness/MLKEM_InnerPKE.ec index 7e4be3d6..50d6febe 100644 --- a/proof/correctness/MLKEM_InnerPKE.ec +++ b/proof/correctness/MLKEM_InnerPKE.ec @@ -460,7 +460,7 @@ proc => /=. inline Parse(XOF).sample. inline Jkem.M(Jkem.Syscall).__rej_uniform. -seq 6 1: (={seed} /\ stransposed{1} = (if trans{2} then W64.one else W64.zero)); 1: by auto. +seq 7 1: (={seed} /\ stransposed{1} = (if trans{2} then W64.one else W64.zero)); 1: by auto. seq 2 0 : (#pre /\ forall k, 0<=k<32 => extseed {1}.[k] = seed{2}.[k]). + conseq => />; 1: smt(). @@ -491,39 +491,27 @@ seq 5 4 : (#pre /\ lift_array256 poly{1} = aa{2} /\ forall k, 0 <= k < 256 => bpos16 poly{1}.[k] q); last first. -+ wp;conseq />; 1: smt(). - while{1} ( - 0 <= to_uint k{1} <= 256 /\ to_uint l{1} = i{2} * 768 + j{2} * 256 + to_uint k{1} /\ - 0 <= j{2} < 3 /\ 0 <= i{2} < 3 /\ - (forall (k0 : int), 0 <= k0 < i{2} * 768 + j{2} * 256 => - (a{2}.[k0 %/ 768, k0 %% 768 %/ 256])%Matrix.[k0 %% 256] = incoeff (to_sint r{1}.[k0]) /\ - 0 <= to_sint r{1}.[k0] < q) /\ - (forall (k0 : int), - i{2} * 768 + j{2} * 256 <= k0 < i{2} * 768 + j{2} * 256 + to_uint k{1} => - r{1}.[k0] = poly{1}.[k0 %% 256] /\ - 0 <= to_sint r{1}.[k0] < q) /\ - (forall k, 0 <= k < 256 => bpos16 poly{1}.[k] q) - ) (256 - to_uint k{1}); last first. - - + auto => /> &1 &2 [#] 8?; rewrite /lift_array256 => ? k; do split; 2:smt(). - + by rewrite of_uintK /= modz_small /#. - move => kl ll rl; rewrite ultE /=; split; 1: smt(). - move => *;do split; 1,2: smt(). - move => k0 k0bl k0bh;rewrite setmE /= offunmE /= 1:/#. - case (k0 < i{2} * 768 + j{2} * 256); 1: smt(). - move => *. - have -> : k0 %/ 768 = i{2} by smt(). - have -> /= : k0 %% 768 %/ 256 = j{2} by smt(). - by rewrite mapiE /#. - - move => &2 z; auto => /> &1; rewrite !ultE /= => ?????. - rewrite to_uintD_small /= 1:/# => *; do split;1,2,6:smt(). - + by rewrite to_uintD_small /#. - + by move => kk kkbl kkbh; rewrite set_neqiE /#. - + move => kk kkbl kkbh. - case (kk < i{2} * 768 + j{2} * 256 + to_uint k{1}). - + by move => *; rewrite set_neqiE /#. - by move => *; rewrite Array2304.set_eqiE /#. ++ wp; conseq />. + * move => &1 &2 [] [] [] -> [] ? [] [] -> -> [] ? [] ? [] ? ? ? [] ? ? ? [] [] ? ? H /=. + do split => //; 1,2: smt(). + by move => ??; rewrite -H /#. + while{1} + (0 <= to_uint k{1} <= 256 + /\ forall (k0 : int), 0 <= k0 < to_uint k{1} => rij{1}.[k0] = poly{1}.[k0]) + (256 - to_uint k{1}). + * auto => /> 4?. + rewrite ultE /= to_uintD_small 1:/# to_uint_small //. + smt(Array256.get_setE). + auto => /> &1 &2 [#] 9? hpoly; split; 1: smt(). + move => rij kl; rewrite ultE /=; split; 1: smt(). + move => hkl 2? heq; split; 1: smt(). + move => k ??. + rewrite initiE 1:/# /=. + rewrite /lift_array256 setmE /= offunmE /= 1:/#. + have -> /= : k < i{2} * 768 + j{2} * 256 + 256 by smt(). + have -> : (k %/ 768 = i{2} /\ k %% 768 %/ 256 = j{2}) <=> i{2} * 768 + j{2} * 256 <= k by smt(). + case: (i{2} * 768 + j{2} * 256 <= k); 2: smt(). + by rewrite mapiE /#. conseq />; 1: by smt(). diff --git a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec index 8c90ff5a..740b53ae 100644 --- a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec +++ b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec @@ -331,7 +331,7 @@ symmetry. have H : equiv [ Jkem.M(Jkem.Syscall).__gen_matrix ~ Jkem_avx2.M(Jkem_avx2.Syscall).__gen_matrix : ={arg} ==> res{2} = nttunpackm res{1}]. -proc. seq 10 10 : (={r}). +proc. seq 11 11 : (={r}). sim (M(Syscall)._shake128_absorb34 ~ Jkem_avx2.M(Jkem_avx2.Syscall)._shake128_absorb34 : true) (M(Syscall)._shake128_squeezeblock ~ Jkem_avx2.M(Jkem_avx2.Syscall)._shake128_squeezeblock : true) (Jkem_avx2.M(Jkem_avx2.Syscall).__rej_uniform ~