Skip to content

Commit

Permalink
recovering commented proof
Browse files Browse the repository at this point in the history
  • Loading branch information
mbbarbosa-lectures committed Dec 7, 2024
1 parent bb6d264 commit f07c57b
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 174 deletions.
140 changes: 29 additions & 111 deletions proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
require import AllCore IntDiv List.
from Jasmin require import JModel.
require import Array16 Array25 Array32 Array33 Array128 Array136 Array256 Array768 Array960 Array1088 Array2304.
require import Array4 Array33 Array128 Array16 Array25 Array32 Array33 Array128 Array136 Array256 Array768 Array960 Array1088 Array2304.
require import List_extra.
require import MLKEM_Poly MLKEM_PolyVec MLKEM_InnerPKE.
require import MLKEM_Poly_avx2_proof.
Expand Down Expand Up @@ -30,91 +30,27 @@ import WArray512 WArray256.

(* shake assumptions *)

(*

op SHAKE256_ABSORB4x_33 : W8.t Array33.t -> W8.t Array33.t -> W8.t Array33.t -> W8.t Array33.t -> W256.t Array25.t.
op SHAKE256_SQUEEZENBLOCKS4x : W256.t Array25.t -> W256.t Array25.t * W8.t Array136.t * W8.t Array136.t * W8.t Array136.t * W8.t Array136.t.

axiom shake_absorb4x state seed1 seed2 seed3 seed4 :
phoare [ Jkem_avx2.M(Jkem_avx2.Syscall)._shake256_absorb4x_33 :
arg = (state,seed1,seed2,seed3,seed4) ==>
res = SHAKE256_ABSORB4x_33 seed1 seed2 seed3 seed4 ] = 1%r.

axiom shake_squeezenblocks4x state buf1 buf2 buf3 buf4 :
phoare [ Jkem_avx2.M(Jkem_avx2.Syscall).__shake256_squeezenblocks4x :
arg = (state,buf1,buf2,buf3,buf4) ==>
res = SHAKE256_SQUEEZENBLOCKS4x state ] = 1%r.

axiom shake4x_equiv (sn1 sn2 sn3 sn4: W8.t Array33.t) (s1 s2 s3 s4 : W8.t Array32.t) n1 n2 n3 n4 :
s1 = Array32.init (fun i => sn1.[i]) =>
s2 = Array32.init (fun i => sn2.[i]) =>
s3 = Array32.init (fun i => sn3.[i]) =>
s4 = Array32.init (fun i => sn4.[i]) =>
n1 = sn1.[32] => n2 = sn2.[32] => n3 = sn3.[32] => n4 = sn4.[32] =>
Array128.init ("_.[_]" (SHAKE256_SQUEEZENBLOCKS4x (SHAKE256_ABSORB4x_33 sn1 sn2 sn3 sn4)).`2) = SHAKE256_33_128 s1 n1 /\
Array128.init ("_.[_]" (SHAKE256_SQUEEZENBLOCKS4x (SHAKE256_ABSORB4x_33 sn1 sn2 sn3 sn4)).`3) = SHAKE256_33_128 s2 n2 /\
Array128.init ("_.[_]" (SHAKE256_SQUEEZENBLOCKS4x (SHAKE256_ABSORB4x_33 sn1 sn2 sn3 sn4)).`4) = SHAKE256_33_128 s3 n3 /\
Array128.init ("_.[_]" (SHAKE256_SQUEEZENBLOCKS4x (SHAKE256_ABSORB4x_33 sn1 sn2 sn3 sn4)).`5) = SHAKE256_33_128 s4 n4.

axiom sha3equiv :
equiv [ (* is this in the sha3 paper? *)
Jkem_avx2.M(Jkem_avx2.Syscall)._sha3_512_32 ~Jkem.M(Jkem.Syscall)._sha3512_32 : ={arg} ==> ={res}].

lemma keccakf1600_set_row_ll : islossless M(Syscall).keccakf1600_set_row.
proc. by unroll for ^while; auto. qed.

lemma keccakf1600_rho_offsets_ll : islossless M(Syscall).keccakf1600_rho_offsets.
proc. by unroll for ^while; islossless. qed.

lemma keccakf1600_rhotates_ll : islossless M(Syscall).keccakf1600_rhotates.
proc. by call keccakf1600_rho_offsets_ll; islossless. qed.

lemma keccakf1600_theta_rol_ll : islossless M(Syscall).keccakf1600_theta_rol.
proc. by unroll for ^while; islossless. qed.

lemma keccakf1600_theta_sum_ll : islossless M(Syscall).keccakf1600_theta_sum.
proc. by do 6!(unroll for ^while); islossless. qed.

lemma keccakf1600_rol_sum_ll : islossless M(Syscall).keccakf1600_rol_sum.
proc.
while (x <= 5) (5 - x); auto; last smt().
conseq => /=; call keccakf1600_rhotates_ll; auto => /#.
qed.

lemma keccakf1600_round_ll : islossless Jkem.M(Syscall).keccakf1600_round.
proc; auto.
while (y <= 5) (5 - y); auto.
+ call keccakf1600_set_row_ll.
call keccakf1600_rol_sum_ll.
auto; smt().
call keccakf1600_theta_rol_ll.
call keccakf1600_theta_sum_ll.
auto; smt().
qed.

lemma keccakf1600_ll : islossless Jkem.M(Syscall)._keccakf1600_.
proc; auto.
call (:true); auto.
call (:true); auto.
while (to_uint c <= 24 /\ to_uint c %% 2 = 0) (24 - to_uint c); auto; last by move => /> *; rewrite ultE to_uint_small //= /#.
call keccakf1600_round_ll; auto.
call keccakf1600_round_ll; auto.
move => /> ??; rewrite ultE to_uintD_small to_uint_small //= /#.
qed.

*)
lemma sha3ll : islossless M(Syscall)._shake256_128_33.
admitted.

lemma shake256_4x_128_32 _seed _nonces :
phoare [
Jkem_avx2.M(Jkem_avx2.Syscall)._shake256x4_A128__A32_A1 : arg.`5 = _seed /\ arg.`6 = _nonces ==>
res.`1 =
SHAKE256_33_128 _seed _nonces.[0] /\
res.`2 =
SHAKE256_33_128 _seed _nonces.[1] /\
res.`3 =
SHAKE256_33_128 _seed _nonces.[2] /\
res.`4 =
SHAKE256_33_128 _seed _nonces.[3]
] = 1%r.
admitted.

(*
axiom shake128_equiv_absorb : equiv [ M(Syscall)._shake128_absorb34 ~
Jkem_avx2.M(Jkem_avx2.Syscall)._shake128_absorb34 :
={state, in_0} ==> ={res}].
axiom shake128_equiv_squeezeblock : equiv [ M(Syscall)._shake128_squeezeblock ~
Jkem_avx2.M(Jkem_avx2.Syscall)._shake128_squeezeblock :
={state, out} ==> ={res}].
*)
lemma sha3equiv :
equiv [ Jkem_avx2.M(Jkem_avx2.Syscall)._sha3_512A_A33 ~ M(Syscall)._sha3512_33 :
={arg} ==> ={res} ].
admitted.


equiv genmatrixequiv b :
Expand Down Expand Up @@ -771,20 +707,6 @@ equiv getnoise_4x_split :
GetNoiseAVX2._poly_getnoise_eta1_4x ~ AuxMLKEMAvx2.__poly_getnoise_eta1_4x : ={arg} ==> ={res}.
proc; wp; sp => />. call getnoise_split => />. call getnoise_split => />. call getnoise_split => />. call getnoise_split => />. auto => />. qed.
require import Array4 Array33 Array128.
axiom shake256_4x_128_32 _seed _nonces :
phoare [
Jkem_avx2.M(Jkem_avx2.Syscall)._shake256x4_A128__A32_A1 : arg.`5 = _seed /\ arg.`6 = _nonces ==>
res.`1 =
SHAKE256_33_128 _seed _nonces.[0] /\
res.`2 =
SHAKE256_33_128 _seed _nonces.[1] /\
res.`3 =
SHAKE256_33_128 _seed _nonces.[2] /\
res.`4 =
SHAKE256_33_128 _seed _nonces.[3]
] = 1%r.
equiv getnoiseequiv_avx :
Jkem_avx2.M(Jkem_avx2.Syscall)._poly_getnoise_eta1_4x ~ GetNoiseAVX2._poly_getnoise_eta1_4x : ={arg} ==> ={res}.
Expand Down Expand Up @@ -846,10 +768,6 @@ conseq HHH HH0.
move => *; rewrite /signed_bound_cxq /b16 qE /#.
qed.
axiom sha3equiv :
equiv [ Jkem_avx2.M(Jkem_avx2.Syscall)._sha3_512A_A33
~ M(Syscall)._sha3512_33 :
={arg} ==> ={res} ].
import InnerPKE.
lemma mlkem_correct_kg_avx2 mem _pkp _skp :
Expand Down Expand Up @@ -1385,22 +1303,22 @@ seq 18 17 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\
seq 12 12 : (#{/~publicseed{2}}post /\ ={publicseed}).
wp;sp; conseq />.
call (polyvec_frombytes_equiv).
auto => />. smt().
conseq />. sim.
auto => />;1: smt().
by conseq />;sim.
sp 2 0.
(* swap {1} [11..12] 2. *)
seq 12 20 : (#{/~bp{1}=bp{2}}pre /\
seq 12 20 : (#{/~bp{1}=bp{2}}{~lnoiseseed{1}}pre /\
signed_bound768_cxq sp_0{1} 0 768 1 /\
signed_bound768_cxq ep{1} 0 768 1 /\
signed_bound_cxq epp{1} 0 256 1 /\
signed_bound768_cxq sp_0{2} 0 768 1 /\
signed_bound768_cxq ep{2} 0 768 1 /\
signed_bound_cxq epp{1} 0 256 1).
admit(*
+ conseq />.
transitivity {1} { (sp_0,ep,bp,epp) <@ GetNoiseAVX2.samplenoise_enc(sp_0,ep,bp, epp,noiseseed);} (lnoiseseed{1} = noiseseed{2} /\ ={sp_0,ep,bp,epp} ==> ={sp_0,ep,epp})
transitivity {1} { (sp_0,ep,bp,epp) <@ GetNoiseAVX2.samplenoise_enc(sp_0,ep,bp, epp,noiseseed);}
(s_noiseseed{1} = noiseseed{2} /\ lnoiseseed{1} = noiseseed{2} /\ ={sp_0,ep,bp,epp} ==> ={sp_0,ep,epp})
(
s_noiseseed{1} = noiseseed0{1} /\
lnoiseseed{1} = s_noiseseed{1} /\
Expand Down Expand Up @@ -1429,7 +1347,7 @@ admit(*
signed_bound_cxq epp{1} 0 256 1 /\
signed_bound768_cxq sp_0{2} 0 768 1 /\ signed_bound768_cxq ep{2} 0 768 1 /\ signed_bound_cxq epp{1} 0 256 1
); 1,2:smt().
+ by inline {2} 1;do 2!(wp; call getnoiseequiv_avx);auto => />.
+ inline {2} 1;do 2!(wp; call getnoiseequiv_avx);auto => />.
inline {1} 1. inline GetNoiseAVX2._poly_getnoise_eta1_4x.
wp; call{1} (_: true ==> true); 1: by apply polygetnoise_ll.
do 7!(wp; call getnoiseequiv); auto => />.
Expand Down Expand Up @@ -1468,7 +1386,7 @@ admit(*
case (256 <= x && x < 512); 1: by smt().
move => *; rewrite !initiE //= fun_if.
by smt().
*).
swap {1} 1 2.
seq 1 1 : (#{/~sp_0{1}}{~sp_0{2}}pre /\
lift_array768 sp_0{1} = nttunpackv (lift_array768 sp_0{2}) /\
Expand Down Expand Up @@ -1772,16 +1690,16 @@ seq 18 19 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\ ctp{2} = sctp{2} /\
sp 2 0.
(* swap {1} [11..12] 2. *)
seq 13 20 : (#{/~bp{1}=bp{2}}pre /\
seq 13 20 : (#{/~bp{1}=bp{2}}{~lnoiseseed{1}}pre /\
signed_bound768_cxq sp_0{1} 0 768 1 /\
signed_bound768_cxq ep{1} 0 768 1 /\
signed_bound_cxq epp{1} 0 256 1 /\
signed_bound768_cxq sp_0{2} 0 768 1 /\
signed_bound768_cxq ep{2} 0 768 1 /\
signed_bound_cxq epp{1} 0 256 1).
admit(*
+ conseq />.
transitivity {1} { (sp_0,ep,bp,epp) <@ GetNoiseAVX2.samplenoise_enc(sp_0,ep,bp, epp,noiseseed);} (lnoiseseed{1} = noiseseed{2} /\ ={sp_0,ep,bp,epp} ==> ={sp_0,ep,epp})
transitivity {1} { (sp_0,ep,bp,epp) <@ GetNoiseAVX2.samplenoise_enc(sp_0,ep,bp, epp,noiseseed);}
(s_noiseseed{1} = noiseseed{2} /\ lnoiseseed{1} = noiseseed{2} /\ ={sp_0,ep,bp,epp} ==> ={sp_0,ep,epp})
(
s_noiseseed{1} = noiseseed0{1} /\
lnoiseseed{1} = s_noiseseed{1} /\
Expand Down Expand Up @@ -1847,7 +1765,7 @@ admit(*
case (256 <= x && x < 512); 1: by smt().
move => *; rewrite !initiE //= fun_if.
by smt().
*).
seq 1 1 : (#{/~sp_0{1}}{~sp_0{2}}pre /\
lift_array768 sp_0{1} = nttunpackv (lift_array768 sp_0{2}) /\
pos_bound768_cxq sp_0{1} 0 768 2 /\
Expand Down
60 changes: 0 additions & 60 deletions proof/correctness/avx2/MLKEM_genmatrix_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -384,67 +384,7 @@ import Array536.
require import JWordList EclibExtra MLKEM_keccak_avx2.
require import Mlkem_filters_bridge.
(*
abbrev bufl (buf: W8.t Array536.t) = to_list buf.
op buf_subl (buf: W8.t Array536.t) (first last: int): W8.t list =
take (last-first) (drop first (bufl buf)).
lemma buf_subl0 buf s e:
(e <= s)%Int =>
buf_subl buf s e = [].
proof. by rewrite /buf_subl /#. qed.
lemma size_buf_subl buf bstart bend:
0 <= bstart <= bend <= 536 =>
size (buf_subl buf bstart bend) = bend - bstart.
proof.
by move=> H; rewrite size_take 1:/# size_drop 1:/# size_to_list /#.
qed.
lemma buf_sublE buf (i j: int):
0 <= i <= j <= 536 =>
buf_subl buf i j = sub buf i (j-i).
proof.
move=> H; rewrite /buf_subl /sub.
apply (eq_from_nth witness).
by rewrite size_take 1:/# size_drop 1:/# !size_mkseq /#.
move=> k.
rewrite size_take 1:/# size_drop 1:/# size_to_list => Hk.
by rewrite nth_take 1..2:/# nth_drop 1..2:/# !nth_mkseq /#.
qed.
lemma sub2buf_subl (buf: W8.t Array536.t) (k len: int):
0 <= k <= k+len <= 536 =>
sub buf k len = buf_subl buf k (k+len).
proof.
move=> H; rewrite /buf_subl /sub.
apply (eq_from_nth witness).
rewrite size_take 1:/# size_drop 1:/# !size_mkseq /#.
move=> i; rewrite size_mkseq => Hi.
by rewrite nth_take 1..2:/# nth_drop 1..2:/# !nth_mkseq /#.
qed.
lemma buf_subl_cat buf (o k n:int):
0 <= o <= k <= n =>
buf_subl buf o k ++ buf_subl buf k n = buf_subl buf o n.
proof.
move=> H; rewrite /buf_subl /=.
rewrite -(cat_take_drop (k-o) (take (n-o) _)).
rewrite take_take ifT 1:/#; congr.
rewrite drop_take 1:/#; congr; first smt().
by rewrite drop_drop /#.
qed.
lemma buf_subl_sub buf o k n l:
0 <= o <= k <= n =>
buf_subl buf o n = l =>
buf_subl buf o k = take (k-o) l.
proof.
move=> H; rewrite /buf_subl => <-.
by rewrite take_take ifT 1:/#.
qed.
*)
hoare comp_u64_l_int_and_u64_l_int_h _a _i1 _b _i2:
Jkem_avx2.M(Jkem_avx2.Syscall).comp_u64_l_int_and_u64_l_int
: arg = (_a,_i1,_b,_i2)
Expand Down
12 changes: 9 additions & 3 deletions proof/correctness/avx2/mlkem_filters_bridge.ec
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ from Jasmin require import JModel_x86.
require import Array32 Array536 Array2048.

require import Correctness.
require import Jkem_avx2 Mlkem_filter48 Mlkem_filter48.
require import Jkem_avx2 (* Mlkem_filter48 *).

abbrev bufl (buf: W8.t Array536.t) = to_list buf.

Expand Down Expand Up @@ -73,9 +73,11 @@ op auxdata_ok (load_shuffle mask bounds ones: W256.t)
/\ ones = Jkem_avx2.sample_ones
/\ sst = Jkem_avx2.sample_shuffle_table.

import Mlkem_filter48_bindings.
require import WArray512 Array40 Array256 Array56 WArray536 WArray2048 IntDiv.

(*
import Mlkem_filter48_bindings.

lemma vmov64_ext_256 b :
zeroextu256 (VMOV_64 b) = zextend_64_256 b.
proof.
Expand Down Expand Up @@ -562,6 +564,8 @@ seq 1 1 : (#{/~plist pol{1} (_ctr + to_uint t0_1{2}) = plist _p _ctr ++ mkseq ("
qed.

require import Bindings. import W12.
*)

lemma buf_rejection_filter48_h _pol _ctr _buf _buf_offset:
hoare [
Jkem_avx2.M(Jkem_avx2.Syscall).__gen_matrix_buf_rejection_filter48
Expand All @@ -577,6 +581,8 @@ hoare [
in plist res.`1 (to_uint _ctr + size l)
= plist _pol (to_uint _ctr) ++ l
/\ res.`2 = W64.of_int (to_uint _ctr + size l)].
admitted.
(*
proof.
conseq (bridge48 (to_uint _ctr) (to_uint _buf_offset) _pol)(filter48P (Array56.init (fun i => _buf.[to_uint _buf_offset+i]))).
+ move => &1 [#] ??????;rewrite /auxdata_ok => [#] ->->->->->.
Expand Down Expand Up @@ -701,7 +707,7 @@ have -> := BitEncoding.BitChunking.nth_flatten witness 8 (map W8.w2bits (take 48
rewrite (nth_map witness);1: by smt(size_take size_drop Array536.size_to_list).
rewrite /w2bits nth_mkseq 1:/# /= nth_take 1,2:/# nth_drop /#.
qed.

*)

lemma buf_rejection_filter48_ll:
islossless Jkem_avx2.M(Jkem_avx2.Syscall).__gen_matrix_buf_rejection_filter48
Expand Down

0 comments on commit f07c57b

Please sign in to comment.