diff --git a/proof/correctness/avx2/MLKEM_KEM_avx2_stack.ec b/proof/correctness/avx2/MLKEM_KEM_avx2_stack.ec index 524665b3..ce5da6c1 100644 --- a/proof/correctness/avx2/MLKEM_KEM_avx2_stack.ec +++ b/proof/correctness/avx2/MLKEM_KEM_avx2_stack.ec @@ -110,7 +110,7 @@ rewrite initiE 1:/# /=. rewrite initiE /#. qed. -(* + lemma mlkem_kem_correct_kg : equiv [Jkem_avx2_stack.M.jade_kem_mlkem_mlkem768_amd64_avx2_keypair_derand ~ MLKEM.kg_derand : coins{2}.`1 = Array32.init(fun i => coins{1}.[i]) /\ @@ -392,7 +392,7 @@ rewrite !tP => H H0 rr0 H1 H2;do split. + by smt(). by move => ????r0 r1 ??;do split;smt(Array32.initiE). qed. -*) + lemma verify_correct_h_stack _ctp _ctp1 : hoare [Jkem_avx2_stack.M.__verify : @@ -666,4 +666,79 @@ wp;ecall {1} (cmov_correct_stack shk0{1} (Array32.init (fun (i_0 : int) => kr{1} wp;ecall{1} (shake256_A32_A1120_ph_stack zp_ct{1}). conseq />;1: smt(). -admitted. + +sp 2 0. print J. +seq 3 0 : (#pre /\ + z{2} = Array32.init (fun i => zp_ct{1}.[i]) +). + ++ wp;while {1} (#pre /\ + aux_0{1} = 4 /\ 0<=i{1} <= aux_0{1} /\ + (forall k, 0<= k < 8 * i{1} => + z{2}.[k] = zp_ct{1}.[k])) + (4 - i{1}). + move => &m z;auto => /> &hr. + move => ???????????prev2. + do split;1,2:smt(). + + move => i ibl ibh; rewrite initiE 1:/# /= initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case (8 * i{!hr} <= i && i < 8 * i{!hr} + 8) => *. + + by rewrite WArray2400.get64E pack8bE 1:/# initiE 1:/# /= initiE /#. + by rewrite /get8 initiE; smt(Array32.initiE). + + by smt(). + + auto => /> &1 &2; rewrite !tP. + move => *. + do split; 1: by smt(). + move => il zpct. + rewrite !tP; split; 1: smt(). + by move => *; rewrite initiE 1:/# /= initiE 1:/# /=;smt(Array32.initiE). + ++ wp;while {1} (#pre /\ + aux_0{1} = (3 * 320 + 128) %/ 8 /\ 0<=i{1} <= aux_0{1} /\ + (forall k, 0<= k < min (8 * i{1}) 960 => + cph{2}.`1.[k] = zp_ct{1}.[32+k]) /\ + (forall k, 960 <= k < min (8 * i{1}) (960 + 128) => + cph{2}.`2.[k-960] = zp_ct{1}.[32+k])) + ((3 * 384 + 32) %/ 8 - i{1}). + move => &m z;auto => /> &hr. + + move => &1 zz;auto => /> &2; rewrite !tP. + move => pkv1 pkv2???prev0?? prev1 prev2?. + do split. + + move => k kb; rewrite initiE 1:/# /= initiE 1:/# /= initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((i{hr} + 4) * 8 <= k && k < (i{hr} + 4) * 8 + 8) => *. + + by rewrite WArray1088.get64E pack8bE 1:/# initiE 1:/# /= initiE /#. + by rewrite /get8 initiE; smt(Array32.initiE). + + by smt(). + + by smt(). + + move => k kb kbb; rewrite initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((i{hr} + 4) * 8 <= 32 + k && 32 + k < (i{hr} + 4) * 8 + 8) => *. + + rewrite WArray1088.get64E pack8bE 1:/# initiE 1:/# /=. + by rewrite /get8 initiE; smt(Array960.initiE). + + by rewrite /get8 initiE; smt(Array1120.initiE). + + move => k kb kbb; rewrite initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((i{hr} + 4) * 8 <= 32 + k && 32 + k < (i{hr} + 4) * 8 + 8) => *. + + rewrite WArray1088.get64E pack8bE 1:/# initiE 1:/# /=. + by rewrite /get8 initiE 1:/#; smt(Array128.initiE). + + rewrite /get8 initiE 1:/# /=; smt(Array960.initiE). + + by smt(). + + auto => /> &1 &2; rewrite !tP. + move => ?????????. + do split; 1..2: by smt(). + move => il zpct. + rewrite !tP; split; 1: smt(). + move => ?????? rr Hrr0 Hrr1; do split. + + move => *; rewrite tP => kk kkb;rewrite initiE 1:/# /=. + rewrite Hrr1 1:/# /= /SHAKE_256_1120_32 /= get_of_list 1:/# /=. + congr;congr;congr;congr. + + congr;rewrite tP => jj jjb;smt(Array32.initiE). + + congr;rewrite tP => jj jjb;smt(Array960.initiE). + + rewrite tP => jj jjb;smt(Array128.initiE). + move => Hok;rewrite Hrr0 1:/# tP => jj jjb;smt(Array32.initiE). + +qed.