Skip to content

Commit

Permalink
only matrix sampling to prove
Browse files Browse the repository at this point in the history
  • Loading branch information
mbbarbosa-lectures committed Mar 25, 2024
1 parent ac431fc commit 664fd37
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 29 deletions.
60 changes: 33 additions & 27 deletions proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,14 @@ conseq => /=.
inline *; unroll for ^while; auto.
qed.
(*
axiom shake128_equiv_absorb : equiv [ M(Syscall)._shake128_absorb34 ~
Jkem_avx2.M(Jkem_avx2.Syscall)._shake128_absorb34 :
={state, in_0} ==> ={res}].
axiom shake128_equiv_squeezeblock : equiv [ M(Syscall)._shake128_squeezeblock ~
Jkem_avx2.M(Jkem_avx2.Syscall)._shake128_squeezeblock :
={state, out} ==> ={res}].
*)
lemma rng_iotared i n :
(0 <= i < n) = (i \in iotared 0 n).
Expand Down Expand Up @@ -350,11 +351,13 @@ by smt().
qed.
equiv genmatrixequiv b :
Jkem_avx2.M(Jkem_avx2.Syscall).__gen_matrix ~ Jkem.M(Jkem.Syscall).__gen_matrix :
arg{1}.`1 = arg{2}.`1 /\ arg{1}.`2 = (W64.of_int (b2i b)) /\ arg{2}.`2 = (W64.of_int (b2i b)) ==>
Jkem_avx2.M(Jkem_avx2.Syscall)._gen_matrix_avx2 ~ Jkem.M(Jkem.Syscall).__gen_matrix :
arg{1}.`2 = arg{2}.`1 /\ arg{1}.`3= (W64.of_int (b2i b)) /\ arg{2}.`2 = (W64.of_int (b2i b)) ==>
res{1} = nttunpackm res{2} /\
pos_bound2304_cxq res{1} 0 2304 2 /\
pos_bound2304_cxq res{2} 0 2304 2.
admitted.
(*
symmetry.
have H : equiv [
Jkem.M(Jkem.Syscall).__gen_matrix ~ Jkem_avx2.M(Jkem_avx2.Syscall).__gen_matrix :
Expand Down Expand Up @@ -452,7 +455,7 @@ conseq (: _ ==> res{2} = nttunpackm res{1} /\ pos_bound2304_cxq res{1} 0 2304 2)
move : (nttunpackv_pred (subarray768 r1 2) (fun a => W16extra.bpos16 a (2*q))); rewrite !allP /=; smt(Array768.initiE).
conseq H matrix_bound => //=. smt(@W64).
qed.
*)
lemma lift768_nttunpack (v : W16.t Array768.t):
lift_array768 (nttunpackv v) = nttunpackv (lift_array768 v).
rewrite /lift_array768 /nttunpackv tP => k kb.
Expand Down Expand Up @@ -1210,7 +1213,7 @@ transitivity {1} {Jkem.M(Jkem.Syscall).__indcpa_keypair(pkp, skp, randomnessp);}
rho = load_array32 Glob.mem{1} (_pkp + 1152)); 1,2: smt();
last by call(mlkem_correct_kg mem _pkp _skp); auto => />.

inline{1} 1; inline {2} 1; sim 43 60.
inline{1} 1; inline {2} 1; sim 40 60.

call (polyvec_tobytes_equiv _pkp).
call (polyvec_tobytes_equiv _skp).
Expand All @@ -1219,14 +1222,14 @@ ecall (polyvec_reduce_equiv (lift_array768 pkpv{2})).

have H := polyvec_add2_equiv 2 2 _ _ => //.
ecall (H (lift_array768 pkpv{2}) (lift_array768 e{2})); clear H.
unroll for {1} 37.
unroll for {1} 36.

sp 3 3.

seq 17 17 : (#pre /\ ={publicseed, noiseseed,sskp,spkp,e,skpv,pkpv} /\ sskp{2} = skp{1} /\ spkp{2} = pkp{1}); 1: by
seq 15 17 : (#pre /\ ={publicseed, noiseseed,e,skpv,pkpv} /\ sskp{2} = skp{1} /\ spkp{2} = pkp{1}); 1: by
sp; conseq />; sim 2 2; call( sha3equiv); conseq />; sim.

seq 1 2 : (#pre /\ aa{1} = nttunpackm a{2} /\
seq 2 2 : (#pre /\ aa{1} = nttunpackm a{2} /\
pos_bound2304_cxq aa{1} 0 2304 2 /\
pos_bound2304_cxq a{2} 0 2304 2); 1: by
conseq />; call (genmatrixequiv false); auto => />.
Expand All @@ -1252,7 +1255,7 @@ seq 10 18 : (#pre /\
to_uint pkp{1} = _pkp /\
to_uint skp{1} = _skp /\
valid_disj_reg _pkp (384 * 3 + 32) _skp (384 * 3)) /\
={publicseed, noiseseed, sskp, spkp, skpv, pkpv, e}) /\
={publicseed, noiseseed, skpv, pkpv, e}) /\
aa{1} = nttunpackm a{2} /\ pos_bound2304_cxq aa{1} 0 2304 2 /\ pos_bound2304_cxq a{2} 0 2304 2
==>
={skpv, e} /\
Expand Down Expand Up @@ -1648,14 +1651,14 @@ transitivity {1} {Jkem.M(Jkem.Syscall).__indcpa_enc(sctp,msgp,pkp,noiseseed);}

inline{1} 1; inline {2} 1. wp.

seq 49 59 : (={ctp,Glob.mem} /\
seq 50 59 : (={ctp,Glob.mem} /\
pos_bound256_cxq v{1} 0 256 2 /\
pos_bound256_cxq v{2} 0 256 2 /\
lift_array256 v{1} = lift_array256 v{2} /\
valid_ptr (to_uint ctp{1}) 128); last by
exists *Glob.mem{1}, (to_uint ctp{1}); elim* => memm _p; call (compressequiv memm _p); auto.

seq 47 57 : (={ctp,Glob.mem} /\
seq 48 57 : (={ctp,Glob.mem} /\
pos_bound256_cxq v{1} 0 256 2 /\
pos_bound256_cxq v{2} 0 256 2 /\
pos_bound768_cxq bp{1} 0 768 2 /\
Expand All @@ -1676,13 +1679,12 @@ have H := polyvec_add2_equiv_noperm 2 2 _ _ => //.
ecall (H (lift_array768 bp{2}) (lift_array768 ep{2})); clear H.


unroll for {1} 38.
unroll for {1} 39.

swap {1} 3 -2; swap {2} 3 -2; seq 1 1: (#pre /\ ={pkp0} /\ pkp0{2}=pkp{1}); 1: by auto.
sp 3 3.
swap {1} 18 -1. (* avoid dealing with stack noise seed *)

seq 17 17 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\
swap {1} 17 2.
seq 18 17 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\
s_noiseseed{2} = noiseseed0{2} /\
pos_bound256_cxq k{1} 0 256 1 /\
pos_bound256_cxq k{2} 0 256 1 /\
Expand All @@ -1693,7 +1695,7 @@ seq 17 17 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\
pos_bound2304_cxq aat{1} 0 2304 2 /\
pos_bound2304_cxq aat{2} 0 2304 2).
+ call (genmatrixequiv true).
call frommsgequiv_noperm. conseq />. smt().
wp;call frommsgequiv_noperm. conseq />. smt().
conseq (_: _ ==> lift_array768 pkpv{1} = nttunpackv (lift_array768 pkpv{2}) /\
pos_bound768_cxq pkpv{1} 0 768 2 /\ pos_bound768_cxq pkpv{2} 0 768 2 /\ ={publicseed,pkp0,bp,ep,epp,v,sp_0,Glob.mem} /\ pkp0{2} = pkp{1} /\ s_noiseseed{2} = noiseseed0{2}).
auto => /> &2 ????????? rl rr H H0 H1 ????.
Expand All @@ -1717,7 +1719,7 @@ seq 17 17 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\
sp 2 0.
(* swap {1} [11..12] 2. *)

seq 12 20 : (#{/~bp{1}=bp{2}}pre /\
seq 11 20 : (#{/~bp{1}=bp{2}}pre /\
signed_bound768_cxq sp_0{1} 0 768 1 /\
signed_bound768_cxq ep{1} 0 768 1 /\
signed_bound_cxq epp{1} 0 256 1 /\
Expand Down Expand Up @@ -1794,6 +1796,7 @@ seq 12 20 : (#{/~bp{1}=bp{2}}pre /\
move => *; rewrite !initiE //= fun_if.
by smt().

swap {1} 1 2.
seq 1 1 : (#{/~sp_0{1}}{~sp_0{2}}pre /\
lift_array768 sp_0{1} = nttunpackv (lift_array768 sp_0{2}) /\
pos_bound768_cxq sp_0{1} 0 768 2 /\
Expand All @@ -1802,7 +1805,7 @@ seq 1 1 : (#{/~sp_0{1}}{~sp_0{2}}pre /\

(* First ip *)

seq 4 2 : (#pre /\
seq 5 2 : (#pre /\
lift_array256 (subarray256 bp{1} 0) = nttunpack (lift_array256 (subarray256 bp{2} 0)) /\
signed_bound768_cxq bp{1} 0 256 4 /\
signed_bound768_cxq bp{2} 0 256 2 /\ w{1} = 1).
Expand Down Expand Up @@ -2028,13 +2031,13 @@ transitivity {1} { r <@Jkem.M(Jkem.Syscall).__iindcpa_enc(ctp,msgp,pkp,noiseseed

inline{1} 1; inline {2} 1. wp.

seq 51 61 : (={ctp0,Glob.mem} /\ Glob.mem{1} = mem /\
seq 49 61 : (={ctp0,Glob.mem} /\ Glob.mem{1} = mem /\
pos_bound256_cxq v{1} 0 256 2 /\
pos_bound256_cxq v{2} 0 256 2 /\
lift_array256 v{1} = lift_array256 v{2}); last by
exists *Glob.mem{1}; elim* => memm; call (compressequiv_1 memm); auto => />; smt(Array1088.tP Array1088.initiE).

seq 49 59 : (={ctp0,Glob.mem} /\ Glob.mem{1} = mem /\
seq 47 59 : (={ctp0,Glob.mem} /\ Glob.mem{1} = mem /\
pos_bound256_cxq v{1} 0 256 2 /\
pos_bound256_cxq v{2} 0 256 2 /\
pos_bound768_cxq bp{1} 0 768 2 /\
Expand All @@ -2053,13 +2056,14 @@ have H := polyvec_add2_equiv_noperm 2 2 _ _ => //.
ecall (H (lift_array768 bp{2}) (lift_array768 ep{2})); clear H.


unroll for {1} 40.
unroll for {1} 39.

swap {1} 3 -2; swap {2} 3 -2; seq 1 1: (#pre /\ ={pkp0} /\ pkp0{2} = pkp{1}); 1: by auto.
sp 3 3.
swap {1} 19 1. (* avoid dealing with stack noise seed *)

seq 19 19 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k,sctp} /\
swap {1} 17 2. (* avoid dealing with stack noise seed *)

seq 18 19 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\ ctp{2} = sctp{2} /\
s_noiseseed{2} = noiseseed0{2} /\
pos_bound256_cxq k{1} 0 256 1 /\
pos_bound256_cxq k{2} 0 256 1 /\
Expand All @@ -2071,8 +2075,8 @@ seq 19 19 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k,sctp} /\
pos_bound2304_cxq aat{2} 0 2304 2).
+ call (genmatrixequiv true).
wp;call frommsgequiv_noperm. conseq />. smt().
conseq (_: _ ==> lift_array768 pkpv{1} = nttunpackv (lift_array768 pkpv{2}) /\
pos_bound768_cxq pkpv{1} 0 768 2 /\ pos_bound768_cxq pkpv{2} 0 768 2 /\ ={publicseed,pkp0,bp,ep,epp,v,sp_0,sctp,Glob.mem} /\ pkp0{2} = pkp{1} /\ s_noiseseed{2} = noiseseed0{2}).
conseq (_: _ ==> lift_array768 pkpv{1} = nttunpackv (lift_array768 pkpv{2}) /\ ctp{2} = sctp{2} /\
pos_bound768_cxq pkpv{1} 0 768 2 /\ pos_bound768_cxq pkpv{2} 0 768 2 /\ ={publicseed,pkp0,bp,ep,epp,v,sp_0,Glob.mem} /\ pkp0{2} = pkp{1} /\ s_noiseseed{2} = noiseseed0{2}).
auto => /> &2 ??????? rl rr H H0 H1 ????.
+ rewrite tP => k kb.
move : H; rewrite /lift_array256 tP => H.
Expand All @@ -2085,11 +2089,13 @@ seq 19 19 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k,sctp} /\
rewrite ifF. smt(W16.to_uint_cmp).
rewrite ifF. smt(W16.to_uint_cmp).
smt(W16.to_uint_eq).
seq 14 14 : (#{/~publicseed{2}}post /\ ={publicseed}).

seq 12 14 : (#{/~publicseed{2}}post /\ ={publicseed}).
wp;sp; conseq />.
call (polyvec_frombytes_equiv).
auto => />. smt().
conseq />. sim.

sp 2 0.
(* swap {1} [11..12] 2. *)

Expand Down Expand Up @@ -2357,7 +2363,7 @@ conseq />. call(invnttequiv). auto => />. move => &1 &2 ?????????????H1??H0??H?
seq 1 1 : (#{/~v{2}}{~v{1}}pre /\ lift_array256 v{1} = lift_array256 v{2} /\ signed_bound_cxq v{1} 0 256 2 /\ signed_bound_cxq v{2} 0 256 2).
conseq />. call(polyinvnttequiv). auto => />. smt().

auto => /> /#.
auto => /> /#.
qed.


Expand Down
5 changes: 3 additions & 2 deletions proof/correctness/avx2/MLKEM_KEM_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ lemma mlkem_kem_correct_enc mem _ctp _pkp _kp :
k = load_array32 Glob.mem{1} _kp
].
proc => /=.
seq 15 4 : (#[/1:-2]post
seq 14 4 : (#[/1:-2]post
/\ valid_disj_reg _ctp 1088 _kp 32
/\ to_uint s_shkp{1} = _kp
/\ (forall k, 0<=k<32 => kr{1}.[k]=_K{2}.[k])); last first.
Expand Down Expand Up @@ -368,9 +368,10 @@ seq 15 4 : (#[/1:-2]post
rewrite !WArray64.WArray64.get64E. search pack8_t (\bits8).
by rewrite !pack8bE // !initiE //= /init8 !WArray64.WArray64.initiE /#.
by smt().
auto => /> &1 &2 ?????????;split; 1: by smt().
auto => /> &1 &2 ?????????;split; 1: by smt().
move => mm ii;do split => ???????; 1: smt().
by rewrite /load_array32 tP => kk kkb; smt(Array32.initiE).

wp;call (mlkem_correct_enc_0_avx2 mem _ctp _pkp).
wp;ecall {1} (sha_g_avx2 buf{1} kr{1}).
wp;ecall {1} (pkH_sha_avx2 mem (_pkp) ((Array32.init (fun (i : int) => buf{1}.[32 + i])))).
Expand Down

0 comments on commit 664fd37

Please sign in to comment.