diff --git a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec index 9eaad408..c5a6c16c 100644 --- a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec +++ b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec @@ -125,13 +125,14 @@ conseq => /=. inline *; unroll for ^while; auto. qed. +(* axiom shake128_equiv_absorb : equiv [ M(Syscall)._shake128_absorb34 ~ Jkem_avx2.M(Jkem_avx2.Syscall)._shake128_absorb34 : ={state, in_0} ==> ={res}]. - axiom shake128_equiv_squeezeblock : equiv [ M(Syscall)._shake128_squeezeblock ~ Jkem_avx2.M(Jkem_avx2.Syscall)._shake128_squeezeblock : ={state, out} ==> ={res}]. +*) lemma rng_iotared i n : (0 <= i < n) = (i \in iotared 0 n). @@ -350,11 +351,13 @@ by smt(). qed. equiv genmatrixequiv b : - Jkem_avx2.M(Jkem_avx2.Syscall).__gen_matrix ~ Jkem.M(Jkem.Syscall).__gen_matrix : - arg{1}.`1 = arg{2}.`1 /\ arg{1}.`2 = (W64.of_int (b2i b)) /\ arg{2}.`2 = (W64.of_int (b2i b)) ==> + Jkem_avx2.M(Jkem_avx2.Syscall)._gen_matrix_avx2 ~ Jkem.M(Jkem.Syscall).__gen_matrix : + arg{1}.`2 = arg{2}.`1 /\ arg{1}.`3= (W64.of_int (b2i b)) /\ arg{2}.`2 = (W64.of_int (b2i b)) ==> res{1} = nttunpackm res{2} /\ pos_bound2304_cxq res{1} 0 2304 2 /\ pos_bound2304_cxq res{2} 0 2304 2. +admitted. +(* symmetry. have H : equiv [ Jkem.M(Jkem.Syscall).__gen_matrix ~ Jkem_avx2.M(Jkem_avx2.Syscall).__gen_matrix : @@ -452,7 +455,7 @@ conseq (: _ ==> res{2} = nttunpackm res{1} /\ pos_bound2304_cxq res{1} 0 2304 2) move : (nttunpackv_pred (subarray768 r1 2) (fun a => W16extra.bpos16 a (2*q))); rewrite !allP /=; smt(Array768.initiE). conseq H matrix_bound => //=. smt(@W64). qed. - +*) lemma lift768_nttunpack (v : W16.t Array768.t): lift_array768 (nttunpackv v) = nttunpackv (lift_array768 v). rewrite /lift_array768 /nttunpackv tP => k kb. @@ -1210,7 +1213,7 @@ transitivity {1} {Jkem.M(Jkem.Syscall).__indcpa_keypair(pkp, skp, randomnessp);} rho = load_array32 Glob.mem{1} (_pkp + 1152)); 1,2: smt(); last by call(mlkem_correct_kg mem _pkp _skp); auto => />. -inline{1} 1; inline {2} 1; sim 43 60. +inline{1} 1; inline {2} 1; sim 40 60. call (polyvec_tobytes_equiv _pkp). call (polyvec_tobytes_equiv _skp). @@ -1219,14 +1222,14 @@ ecall (polyvec_reduce_equiv (lift_array768 pkpv{2})). have H := polyvec_add2_equiv 2 2 _ _ => //. ecall (H (lift_array768 pkpv{2}) (lift_array768 e{2})); clear H. -unroll for {1} 37. +unroll for {1} 36. sp 3 3. -seq 17 17 : (#pre /\ ={publicseed, noiseseed,sskp,spkp,e,skpv,pkpv} /\ sskp{2} = skp{1} /\ spkp{2} = pkp{1}); 1: by +seq 15 17 : (#pre /\ ={publicseed, noiseseed,e,skpv,pkpv} /\ sskp{2} = skp{1} /\ spkp{2} = pkp{1}); 1: by sp; conseq />; sim 2 2; call( sha3equiv); conseq />; sim. -seq 1 2 : (#pre /\ aa{1} = nttunpackm a{2} /\ +seq 2 2 : (#pre /\ aa{1} = nttunpackm a{2} /\ pos_bound2304_cxq aa{1} 0 2304 2 /\ pos_bound2304_cxq a{2} 0 2304 2); 1: by conseq />; call (genmatrixequiv false); auto => />. @@ -1252,7 +1255,7 @@ seq 10 18 : (#pre /\ to_uint pkp{1} = _pkp /\ to_uint skp{1} = _skp /\ valid_disj_reg _pkp (384 * 3 + 32) _skp (384 * 3)) /\ - ={publicseed, noiseseed, sskp, spkp, skpv, pkpv, e}) /\ + ={publicseed, noiseseed, skpv, pkpv, e}) /\ aa{1} = nttunpackm a{2} /\ pos_bound2304_cxq aa{1} 0 2304 2 /\ pos_bound2304_cxq a{2} 0 2304 2 ==> ={skpv, e} /\ @@ -1648,14 +1651,14 @@ transitivity {1} {Jkem.M(Jkem.Syscall).__indcpa_enc(sctp,msgp,pkp,noiseseed);} inline{1} 1; inline {2} 1. wp. -seq 49 59 : (={ctp,Glob.mem} /\ +seq 50 59 : (={ctp,Glob.mem} /\ pos_bound256_cxq v{1} 0 256 2 /\ pos_bound256_cxq v{2} 0 256 2 /\ lift_array256 v{1} = lift_array256 v{2} /\ valid_ptr (to_uint ctp{1}) 128); last by exists *Glob.mem{1}, (to_uint ctp{1}); elim* => memm _p; call (compressequiv memm _p); auto. -seq 47 57 : (={ctp,Glob.mem} /\ +seq 48 57 : (={ctp,Glob.mem} /\ pos_bound256_cxq v{1} 0 256 2 /\ pos_bound256_cxq v{2} 0 256 2 /\ pos_bound768_cxq bp{1} 0 768 2 /\ @@ -1676,13 +1679,12 @@ have H := polyvec_add2_equiv_noperm 2 2 _ _ => //. ecall (H (lift_array768 bp{2}) (lift_array768 ep{2})); clear H. -unroll for {1} 38. +unroll for {1} 39. swap {1} 3 -2; swap {2} 3 -2; seq 1 1: (#pre /\ ={pkp0} /\ pkp0{2}=pkp{1}); 1: by auto. sp 3 3. -swap {1} 18 -1. (* avoid dealing with stack noise seed *) - -seq 17 17 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\ +swap {1} 17 2. +seq 18 17 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\ s_noiseseed{2} = noiseseed0{2} /\ pos_bound256_cxq k{1} 0 256 1 /\ pos_bound256_cxq k{2} 0 256 1 /\ @@ -1693,7 +1695,7 @@ seq 17 17 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\ pos_bound2304_cxq aat{1} 0 2304 2 /\ pos_bound2304_cxq aat{2} 0 2304 2). + call (genmatrixequiv true). - call frommsgequiv_noperm. conseq />. smt(). + wp;call frommsgequiv_noperm. conseq />. smt(). conseq (_: _ ==> lift_array768 pkpv{1} = nttunpackv (lift_array768 pkpv{2}) /\ pos_bound768_cxq pkpv{1} 0 768 2 /\ pos_bound768_cxq pkpv{2} 0 768 2 /\ ={publicseed,pkp0,bp,ep,epp,v,sp_0,Glob.mem} /\ pkp0{2} = pkp{1} /\ s_noiseseed{2} = noiseseed0{2}). auto => /> &2 ????????? rl rr H H0 H1 ????. @@ -1717,7 +1719,7 @@ seq 17 17 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\ sp 2 0. (* swap {1} [11..12] 2. *) -seq 12 20 : (#{/~bp{1}=bp{2}}pre /\ +seq 11 20 : (#{/~bp{1}=bp{2}}pre /\ signed_bound768_cxq sp_0{1} 0 768 1 /\ signed_bound768_cxq ep{1} 0 768 1 /\ signed_bound_cxq epp{1} 0 256 1 /\ @@ -1794,6 +1796,7 @@ seq 12 20 : (#{/~bp{1}=bp{2}}pre /\ move => *; rewrite !initiE //= fun_if. by smt(). +swap {1} 1 2. seq 1 1 : (#{/~sp_0{1}}{~sp_0{2}}pre /\ lift_array768 sp_0{1} = nttunpackv (lift_array768 sp_0{2}) /\ pos_bound768_cxq sp_0{1} 0 768 2 /\ @@ -1802,7 +1805,7 @@ seq 1 1 : (#{/~sp_0{1}}{~sp_0{2}}pre /\ (* First ip *) -seq 4 2 : (#pre /\ +seq 5 2 : (#pre /\ lift_array256 (subarray256 bp{1} 0) = nttunpack (lift_array256 (subarray256 bp{2} 0)) /\ signed_bound768_cxq bp{1} 0 256 4 /\ signed_bound768_cxq bp{2} 0 256 2 /\ w{1} = 1). @@ -2028,13 +2031,13 @@ transitivity {1} { r <@Jkem.M(Jkem.Syscall).__iindcpa_enc(ctp,msgp,pkp,noiseseed inline{1} 1; inline {2} 1. wp. -seq 51 61 : (={ctp0,Glob.mem} /\ Glob.mem{1} = mem /\ +seq 49 61 : (={ctp0,Glob.mem} /\ Glob.mem{1} = mem /\ pos_bound256_cxq v{1} 0 256 2 /\ pos_bound256_cxq v{2} 0 256 2 /\ lift_array256 v{1} = lift_array256 v{2}); last by exists *Glob.mem{1}; elim* => memm; call (compressequiv_1 memm); auto => />; smt(Array1088.tP Array1088.initiE). -seq 49 59 : (={ctp0,Glob.mem} /\ Glob.mem{1} = mem /\ +seq 47 59 : (={ctp0,Glob.mem} /\ Glob.mem{1} = mem /\ pos_bound256_cxq v{1} 0 256 2 /\ pos_bound256_cxq v{2} 0 256 2 /\ pos_bound768_cxq bp{1} 0 768 2 /\ @@ -2053,13 +2056,14 @@ have H := polyvec_add2_equiv_noperm 2 2 _ _ => //. ecall (H (lift_array768 bp{2}) (lift_array768 ep{2})); clear H. -unroll for {1} 40. +unroll for {1} 39. swap {1} 3 -2; swap {2} 3 -2; seq 1 1: (#pre /\ ={pkp0} /\ pkp0{2} = pkp{1}); 1: by auto. sp 3 3. -swap {1} 19 1. (* avoid dealing with stack noise seed *) -seq 19 19 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k,sctp} /\ +swap {1} 17 2. (* avoid dealing with stack noise seed *) + +seq 18 19 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\ ctp{2} = sctp{2} /\ s_noiseseed{2} = noiseseed0{2} /\ pos_bound256_cxq k{1} 0 256 1 /\ pos_bound256_cxq k{2} 0 256 1 /\ @@ -2071,8 +2075,8 @@ seq 19 19 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k,sctp} /\ pos_bound2304_cxq aat{2} 0 2304 2). + call (genmatrixequiv true). wp;call frommsgequiv_noperm. conseq />. smt(). - conseq (_: _ ==> lift_array768 pkpv{1} = nttunpackv (lift_array768 pkpv{2}) /\ - pos_bound768_cxq pkpv{1} 0 768 2 /\ pos_bound768_cxq pkpv{2} 0 768 2 /\ ={publicseed,pkp0,bp,ep,epp,v,sp_0,sctp,Glob.mem} /\ pkp0{2} = pkp{1} /\ s_noiseseed{2} = noiseseed0{2}). + conseq (_: _ ==> lift_array768 pkpv{1} = nttunpackv (lift_array768 pkpv{2}) /\ ctp{2} = sctp{2} /\ + pos_bound768_cxq pkpv{1} 0 768 2 /\ pos_bound768_cxq pkpv{2} 0 768 2 /\ ={publicseed,pkp0,bp,ep,epp,v,sp_0,Glob.mem} /\ pkp0{2} = pkp{1} /\ s_noiseseed{2} = noiseseed0{2}). auto => /> &2 ??????? rl rr H H0 H1 ????. + rewrite tP => k kb. move : H; rewrite /lift_array256 tP => H. @@ -2085,11 +2089,13 @@ seq 19 19 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k,sctp} /\ rewrite ifF. smt(W16.to_uint_cmp). rewrite ifF. smt(W16.to_uint_cmp). smt(W16.to_uint_eq). - seq 14 14 : (#{/~publicseed{2}}post /\ ={publicseed}). + + seq 12 14 : (#{/~publicseed{2}}post /\ ={publicseed}). wp;sp; conseq />. call (polyvec_frombytes_equiv). auto => />. smt(). conseq />. sim. + sp 2 0. (* swap {1} [11..12] 2. *) @@ -2357,7 +2363,7 @@ conseq />. call(invnttequiv). auto => />. move => &1 &2 ?????????????H1??H0??H? seq 1 1 : (#{/~v{2}}{~v{1}}pre /\ lift_array256 v{1} = lift_array256 v{2} /\ signed_bound_cxq v{1} 0 256 2 /\ signed_bound_cxq v{2} 0 256 2). conseq />. call(polyinvnttequiv). auto => />. smt(). -auto => /> /#. +auto => /> /#. qed. diff --git a/proof/correctness/avx2/MLKEM_KEM_avx2.ec b/proof/correctness/avx2/MLKEM_KEM_avx2.ec index ded8b743..82783948 100644 --- a/proof/correctness/avx2/MLKEM_KEM_avx2.ec +++ b/proof/correctness/avx2/MLKEM_KEM_avx2.ec @@ -331,7 +331,7 @@ lemma mlkem_kem_correct_enc mem _ctp _pkp _kp : k = load_array32 Glob.mem{1} _kp ]. proc => /=. -seq 15 4 : (#[/1:-2]post +seq 14 4 : (#[/1:-2]post /\ valid_disj_reg _ctp 1088 _kp 32 /\ to_uint s_shkp{1} = _kp /\ (forall k, 0<=k<32 => kr{1}.[k]=_K{2}.[k])); last first. @@ -368,9 +368,10 @@ seq 15 4 : (#[/1:-2]post rewrite !WArray64.WArray64.get64E. search pack8_t (\bits8). by rewrite !pack8bE // !initiE //= /init8 !WArray64.WArray64.initiE /#. by smt(). - auto => /> &1 &2 ?????????;split; 1: by smt(). + auto => /> &1 &2 ?????????;split; 1: by smt(). move => mm ii;do split => ???????; 1: smt(). by rewrite /load_array32 tP => kk kkb; smt(Array32.initiE). + wp;call (mlkem_correct_enc_0_avx2 mem _ctp _pkp). wp;ecall {1} (sha_g_avx2 buf{1} kr{1}). wp;ecall {1} (pkH_sha_avx2 mem (_pkp) ((Array32.init (fun (i : int) => buf{1}.[32 + i])))).