diff --git a/jasmin b/jasmin index e39717d8..f202ec4d 160000 --- a/jasmin +++ b/jasmin @@ -1 +1 @@ -Subproject commit e39717d8f39397413d6623965ca304ddb695d552 +Subproject commit f202ec4d82df333c036180555a12fe1835bee8da diff --git a/proof/correctness/avx2/MLKEM_genmatrix_avx2.ec b/proof/correctness/avx2/MLKEM_genmatrix_avx2.ec index f3155d3a..2fec58b6 100644 --- a/proof/correctness/avx2/MLKEM_genmatrix_avx2.ec +++ b/proof/correctness/avx2/MLKEM_genmatrix_avx2.ec @@ -1,7 +1,7 @@ -require import AllCore. +require import AllCore IntDiv. from Jasmin require import JModel. -require import Array16 Array25 Array32 Array33 Array128 Array136 Array256 Array768 Array960 Array1088 Array2304. +require import Array16 Array25 Array32 Array33 Array128 Array136 Array256 Array768 Array960 Array1024 Array1088 Array2304. require import MLKEM_InnerPKE NTT_avx2 MLKEMFCLib. @@ -78,6 +78,23 @@ phoare nttunpack_corr a : [ Jkem_avx2.M(Jkem_avx2.Syscall)._nttunpack : arg = a ==> res = nttunpack a] = 1%r. admitted. (* proved in indcpa *) +phoare sample_last _rho : + [ Jkem_avx2.M(Jkem_avx2.Syscall).__gen_matrix_sample_one_polynomial : + rho = _rho /\ rc = W16.of_int (2*256+2) ==> + res.`1 = subarray256 (subarray768 (unlift_matrix (sampleA _rho)) 2) 2 ] = 1%r. +admitted. + +op subarray1024 ['a] (x : 'a Array2304.t) (i : int) : 'a Array1024.t = + Array1024.init (fun (k : int) => x.[1024 * i + k]). + +lemma sample_four _sd _rc b : + (_rc = 0 \/ _rc = 4) => + phoare + [ Jkem_avx2.M(Jkem_avx2.Syscall)._gen_matrix_sample_four_polynomials : + rho = _sd /\ mat_entry = W64.of_int _rc /\ transposed = W64.of_int (b2i b) ==> + res.`1 = subarray1024 (unlift_matrix (if b then trmx (sampleA _sd) else (sampleA _sd))) (_rc %% 3) ] = 1%r. +admitted. + timeout 1. phoare _gen_matrix_avx2_sem _sd b : [ Jkem_avx2.M(Jkem_avx2.Syscall)._gen_matrix_avx2 : arg.`2 = _sd /\ arg.`3 = W64.of_int (b2i b) @@ -85,11 +102,11 @@ phoare _gen_matrix_avx2_sem _sd b : then nttunpackm (unlift_matrix (trmx (sampleA _sd))) else nttunpackm (unlift_matrix (sampleA _sd)) ] = 1%r. proc => /=. -while (0<=i<=3 /\ +while (0<=i<=3 /\ rho = _sd /\ ((forall kk, 0 <= kk < i => subarray768 matrix kk = nttunpackv (subarray768 (unlift_matrix (if b then trmx (sampleA _sd) else (sampleA _sd))) kk))) /\ (forall kk, i <= kk < 3 => subarray768 matrix kk = (subarray768 (unlift_matrix (if b then trmx (sampleA _sd) else (sampleA _sd))) kk))) (kvec-i). + move => *;wp => />;1:smt(). - while (0<=i<3 /\ 0 <= j <= 3 /\ + while (0<=i<3 /\ 0 <= j <= 3 /\ rho = _sd /\ ((forall kk, 0 <= kk < i => subarray768 matrix kk = nttunpackv (subarray768 (unlift_matrix (if b then trmx (sampleA _sd) else (sampleA _sd))) kk))) /\ (forall kk, i+1 <= kk < 3 => subarray768 matrix kk = (subarray768 (unlift_matrix (if b then trmx (sampleA _sd) else (sampleA _sd))) kk)) /\ (forall kk, 0 <= kk < j => subarray256 (subarray768 matrix i) kk = nttunpack (subarray256 (subarray768 (unlift_matrix (if b then trmx (sampleA _sd) else (sampleA _sd))) i) kk)) /\ @@ -169,8 +186,28 @@ case (768 <= k && k < 1536). + by move =>? kbb;rewrite -H0 1:/# /subarray768 initiE 1:/# /=. by move =>? kbb;rewrite -H0 1:/# /subarray768 initiE 1:/# /=. -admitted. +unroll for 9. +wp;call (sample_last _sd). +wp;call (sample_four _sd 4 b _). +wp;call (sample_four _sd 0 b _). +auto => /> &hr a -> a0 -> a1 -> row ??. +congr; rewrite tP => kk ?. +pose xx := (unlift_matrix (if b then trmx (sampleA _sd) else sampleA _sd)).[kk]. +rewrite initiE 1:/# /=. +case (2048 <= kk && kk < 2304). ++ move => lastpos;rewrite /subarray256 /subarray768 initiE 1:/# /= initiE 1:/# /=. + rewrite /xx; case (!(b = true)); 1: by smt(). + move => /= ->;rewrite /unlift_matrix !initiE 1,2:/# /=;congr;congr. + by rewrite /trmx offunmE 1:/# /= /#. + +move => notlast. +rewrite initiE 1:/# /=. +case (1024 <= kk && kk < 2048). ++ by move => middlepos;rewrite /subarray1024 initiE /#. +move => *;rewrite initiE 1:/# /=. +by rewrite /subarray1024 initiE /#. +qed. equiv genmatrixequiv_aux b : Jkem_avx2.M(Jkem_avx2.Syscall)._gen_matrix_avx2 ~ AuxMLKEM.__gen_matrix :