Skip to content

Commit

Permalink
saving some progress
Browse files Browse the repository at this point in the history
  • Loading branch information
mbbarbosa-lectures committed Dec 6, 2024
1 parent 30e50e3 commit bb6d264
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 152 deletions.
140 changes: 71 additions & 69 deletions proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,9 @@ import MLKEM_PolyVec.
import MLKEM_PolyvecAVX.
import MLKEM_PolyAVXVec.
import NTT_Avx2.
(*import WArray136 WArray32 WArray128.*)
import WArray32 WArray33 WArray128.
import WArray512 WArray256.



(* shake assumptions *)

(*
Expand Down Expand Up @@ -105,17 +102,10 @@ call keccakf1600_round_ll; auto.
move => /> ??; rewrite ultE to_uintD_small to_uint_small //= /#.
qed.

lemma sha3ll : islossless Jkem.M(Jkem.Syscall)._shake256_128_33.
proof.
proc.
unroll for 9; wp; conseq => /=.
call keccakf1600_ll; auto.
conseq => /=.
unroll for ^while; auto.
conseq => /=.
inline *; unroll for ^while; auto.
qed.
*)
lemma sha3ll : islossless M(Syscall)._shake256_128_33.
admitted.


(*
axiom shake128_equiv_absorb : equiv [ M(Syscall)._shake128_absorb34 ~
Expand Down Expand Up @@ -277,10 +267,10 @@ lemma mask55_bits16 k:
proof.
move=> Hk.
rewrite /VPBROADCAST_8u32.
rewrite bits16_W8u32 Hk //= get_of_list 1:/# /=.
rewrite bits16_W8u32 Hk /= get_of_list 1:/# /=.
rewrite (nth_map 0) /=; first smt(size_iota).
have: (k%%2 \in iota_ 0 2) by smt(mem_iota).
by move: (k%%2); rewrite -allP -iotaredE /= W2u16.bits16_div //.
by move: (k%%2); rewrite -allP -iotaredE /= W2u16.bits16_div .
qed.

lemma mask55_bits8 k:
Expand All @@ -289,10 +279,10 @@ lemma mask55_bits8 k:
proof.
move=> Hk.
rewrite /VPBROADCAST_8u32.
rewrite bits8_W8u32 Hk //= get_of_list 1:/# /=.
rewrite bits8_W8u32 Hk get_of_list 1:/# /=.
rewrite (nth_map 0) /=; first smt(size_iota).
have: (k%%4 \in iota_ 0 4) by smt(mem_iota).
by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div //.
by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div.
qed.

lemma mask33_bits16 k:
Expand All @@ -301,10 +291,10 @@ lemma mask33_bits16 k:
proof.
move=> Hk.
rewrite /VPBROADCAST_8u32.
rewrite bits16_W8u32 Hk //= get_of_list 1:/# /=.
rewrite bits16_W8u32 Hk get_of_list 1:/# /=.
rewrite (nth_map 0) /=; first smt(size_iota).
have: (k%%2 \in iota_ 0 2) by smt(mem_iota).
by move: (k%%2); rewrite -allP -iotaredE /= W2u16.bits16_div //.
by move: (k%%2); rewrite -allP -iotaredE /= W2u16.bits16_div .
qed.

lemma mask33_bits8 k:
Expand All @@ -313,10 +303,10 @@ lemma mask33_bits8 k:
proof.
move=> Hk.
rewrite /VPBROADCAST_8u32.
rewrite bits8_W8u32 Hk //= get_of_list 1:/# /=.
rewrite bits8_W8u32 Hk get_of_list 1:/# /=.
rewrite (nth_map 0) /=; first smt(size_iota).
have: (k%%4 \in iota_ 0 4) by smt(mem_iota).
by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div //.
by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div .
qed.

lemma mask03_bits8 k:
Expand All @@ -325,10 +315,10 @@ lemma mask03_bits8 k:
proof.
move=> Hk.
rewrite /VPBROADCAST_8u32.
rewrite bits8_W8u32 Hk //= get_of_list 1:/# /=.
rewrite bits8_W8u32 Hk get_of_list 1:/# /=.
rewrite (nth_map 0) /=; first smt(size_iota).
have: (k%%4 \in iota_ 0 4) by smt(mem_iota).
by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div //.
by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div .
qed.

lemma mask0F_bits16 k:
Expand All @@ -337,10 +327,10 @@ lemma mask0F_bits16 k:
proof.
move=> Hk.
rewrite /VPBROADCAST_8u32.
rewrite bits16_W8u32 Hk //= get_of_list 1:/# /=.
rewrite bits16_W8u32 Hk get_of_list 1:/# /=.
rewrite (nth_map 0) /=; first smt(size_iota).
have: (k%%2 \in iota_ 0 2) by smt(mem_iota).
by move: (k%%2); rewrite -allP -iotaredE /= W2u16.bits16_div //.
by move: (k%%2); rewrite -allP -iotaredE /= W2u16.bits16_div .
qed.

lemma mask0F_bits8 k:
Expand All @@ -349,10 +339,10 @@ lemma mask0F_bits8 k:
proof.
move=> Hk.
rewrite /VPBROADCAST_8u32.
rewrite bits8_W8u32 Hk //= get_of_list 1:/# /=.
rewrite bits8_W8u32 Hk get_of_list 1:/# /=.
rewrite (nth_map 0) /=; first smt(size_iota).
have: (k%%4 \in iota_ 0 4) by smt(mem_iota).
by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div //.
by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div .
qed.

lemma VPSRL1_ANDmask55 w k:
Expand Down Expand Up @@ -409,16 +399,16 @@ rewrite andwC andw_orwDr orw_disjoint.
by move: k; apply/List.allP; rewrite -iotaredE /int_bit /=.
have ->: w `&` (mask03u8 `<<<` 4)
= ((w `>>>` 4) `&` W8.masklsb (6-4)) `<<<` 4.
rewrite -shlw_andmask // shrl_andmaskN // -andwA /=.
rewrite -shlw_andmask 1:/# shrl_andmaskN 1:/# -andwA /=.
congr.
rewrite /max /=.
apply W8.wordP => k; rewrite -mem_range /range /=.
by move: k; apply/List.allP; rewrite -iotaredE /int_bit /=.
have E1: to_uint (w `&` mask03u8) = to_uint w %% 4.
by rewrite (W8.to_uint_and_mod 2) //.
by rewrite (W8.to_uint_and_mod 2).
have /= E2: to_uint ((w `>>>` 4) `&` (masklsb (6-4))%W8 `<<<` 4) = to_uint w %/ 16 %% 4 * 16.
rewrite /max /= to_uint_shl // (W8.to_uint_and_mod 2) //.
by rewrite to_uint_shr //= modz_small /#.
rewrite /max /= to_uint_shl 1:/# (W8.to_uint_and_mod 2) 1:/#.
by rewrite to_uint_shr 1:/# /= modz_small /#.
rewrite to_uintD_small /=.
by rewrite E1 E2 /#.
by rewrite E1 E2 /#.
Expand Down Expand Up @@ -619,20 +609,20 @@ while (0 <= i <= 4 /\ #{~i}pre /\ List.all (fun k => rp.[k]=W16.of_int (noise_co
by rewrite /R2C /= Array16.initiE /#.
case: (k %/ 16 = 4 * i{m} + 2) => C2.
move: (H (k %% 64) _) => /=; first smt(mem_iota).
rewrite (modz_pow_div 2 6 4) //= C2 (mulzC 4) modzMDl /=.
rewrite (modz_dvd_pow 4 6 _ 2) //.
rewrite (modz_pow_div 2 6 4) 1,2:/# /= C2 (mulzC 4) modzMDl /=.
rewrite (modz_dvd_pow 4 6 _ 2) 1:/#.
have ->: 64 * i{m} + k %% 64 = k by smt().
by rewrite /R2C /= Array16.initiE /#.
case: (k %/ 16 = 4 * i{m} + 1) => C3.
move: (H (k %% 64) _) => /=; first smt(mem_iota).
rewrite (modz_pow_div 2 6 4) //= C3 (mulzC 4) modzMDl /=.
rewrite (modz_dvd_pow 4 6 _ 2) //.
rewrite (modz_pow_div 2 6 4) 1,2:/# /= C3 (mulzC 4) modzMDl /=.
rewrite (modz_dvd_pow 4 6 _ 2) 1:/#.
have ->: 64 * i{m} + k %% 64 = k by smt().
by rewrite /R2C /= Array16.initiE /#.
case: (k %/ 16 = 4 * i{m}) => C4.
move: (H (k %% 64) _) => /=; first smt(mem_iota).
rewrite (modz_pow_div 2 6 4) //= C4 modzMr.
rewrite (modz_dvd_pow 4 6 _ 2) //.
rewrite (modz_pow_div 2 6 4) 1,2:/# /= C4 modzMr.
rewrite (modz_dvd_pow 4 6 _ 2) 1:/#.
have ->: 64 * i{m} + k %% 64 = k by smt().
by rewrite /R2C /= Array16.initiE /#.
have ?: k < 64*i{m} by smt().
Expand Down Expand Up @@ -729,15 +719,15 @@ while (to_uint i <= 128 /\ #pre /\ List.all (fun k => rp.[k]=W16.of_int (noise_c
case: (k = 2 * to_uint i{m}) => C1.
rewrite /noise_coef !get_setE 1..2:/# C1 /= ifF 1:/#.
have ->/=: 2 * to_uint i{m} %/ 2 = to_uint i{m} by smt().
rewrite -to_sint_eq sigextu16_to_sint (_: 3 = 2^2 -1) // !and_mod //= W8_of_sintK_signed /=; 1: smt().
rewrite -to_sint_eq sigextu16_to_sint (_: 3 = 2^2 -1) 1:/# !and_mod 1,2:/# /= W8_of_sintK_signed /=; 1: smt().
have -> /= : 2 * to_uint i{m} %% 2 = 0 by smt().
by rewrite -parallel_noisesum_low smod_small // /#.
by rewrite -parallel_noisesum_low smod_small /#.
case: (k = 2 * to_uint i{m}+1) => C2.
rewrite /noise_coef !get_setE 1..2:/# C2 /=.
have ->/=: (2 * to_uint i{m} + 1) %/ 2 = to_uint i{m} by smt().
rewrite -to_sint_eq sigextu16_to_sint (_: 3 = 2^2 -1) // !and_mod //= W8_of_sintK_signed /=; 1: smt().
rewrite -to_sint_eq sigextu16_to_sint (_: 3 = 2^2 -1) 1:/# !and_mod 1,2:/# /= W8_of_sintK_signed /=; 1: smt().
have -> /= : (2 * to_uint i{m}+1) %% 2 = 1 by smt().
by rewrite -parallel_noisesum_high smod_small // /#.
by rewrite -parallel_noisesum_high smod_small /#.
rewrite !get_setE 1..2:/# C1 C2 /=; apply IH.
smt(mem_iota).
auto => &m |> *.
Expand All @@ -747,13 +737,13 @@ have ->/=: to_uint i = 128 by smt().
rewrite tP => /List.allP H k Hk.
rewrite (H k _) /=.
smt(mem_iota).
by rewrite initiE //.
by rewrite initiE /#.
qed.
lemma cbd2_ref_ll : islossless AuxMLKEMAvx2.cbd2_ref.
proc. inline*. sp; wp. while (true) (128 - W64.to_uint i). move => z.
auto => /> &hr. rewrite ultE of_uintK //= => H. rewrite to_uintD_small //= /#.
auto => />i H. rewrite ultE of_uintK //= 1:/#. qed.
auto => /> &hr. rewrite ultE of_uintK /= => H. rewrite to_uintD_small /= /#.
auto => />i H. rewrite ultE of_uintK /= 1:/#. qed.
phoare cbd2_ref_ph _bytes:
[AuxMLKEMAvx2.cbd2_ref: buf=_bytes ==> res = Array256.init (fun k => W16.of_int (noise_coef _bytes k))] = 1%r.
Expand All @@ -775,55 +765,63 @@ equiv getnoise_1x_equiv_avx :
proc*. inline Jkem_avx2.M(Jkem_avx2.Syscall).__poly_cbd_eta1. rcondt{1} 3. auto => />. sp;wp.
ecall{1} (cbd2_avx2_ph buf{1}) => />.
ecall{2} (cbd2_ref_ph buf{2}) => />.
auto => /> &2. rewrite tP => i Hi. rewrite initiE //=. qed.
auto => /> &2. rewrite tP => i Hi. rewrite initiE /#. qed.
equiv getnoise_4x_split :
GetNoiseAVX2._poly_getnoise_eta1_4x ~ AuxMLKEMAvx2.__poly_getnoise_eta1_4x : ={arg} ==> ={res}.
proc; wp; sp => />. call getnoise_split => />. call getnoise_split => />. call getnoise_split => />. call getnoise_split => />. auto => />. qed.
require import Array4 Array33 Array128.
axiom shake256_4x_128_32 _seed _nonces :
phoare [
Jkem_avx2.M(Jkem_avx2.Syscall)._shake256x4_A128__A32_A1 : arg.`5 = _seed /\ arg.`6 = _nonces ==>
res.`1 =
SHAKE256_33_128 _seed _nonces.[0] /\
res.`2 =
SHAKE256_33_128 _seed _nonces.[1] /\
res.`3 =
SHAKE256_33_128 _seed _nonces.[2] /\
res.`4 =
SHAKE256_33_128 _seed _nonces.[3]
] = 1%r.
equiv getnoiseequiv_avx :
Jkem_avx2.M(Jkem_avx2.Syscall)._poly_getnoise_eta1_4x ~ GetNoiseAVX2._poly_getnoise_eta1_4x : ={arg} ==> ={res}.
proc*.
transitivity{2} { r <@ AuxMLKEMAvx2.__poly_getnoise_eta1_4x(aux3,aux2,aux1,aux0,noiseseed,nonce); } ((r0{1}, r1{1}, r2{1}, r3{1}, seed{1}, nonce{1}) = (aux3{2}, aux2{2}, aux1{2}, aux0{2}, noiseseed{2}, nonce{2}) ==> ={r}) (={aux3,aux2,aux1,aux0,noiseseed,nonce} ==> ={r}); last first.
symmetry. call getnoise_4x_split => />. auto => />. smt(). smt().
(*main proof*)
admit(*
inline Jkem_avx2.M(Jkem_avx2.Syscall)._poly_getnoise_eta1_4x AuxMLKEMAvx2.__poly_getnoise_eta1_4x AuxMLKEMAvx2._poly_getnoise. swap{2} [30..31] 5. swap{2} [23..24] 10. swap{2} [16..17] 15.
seq 25 30 : (
r00{1}=rp{2} /\ Array128.init (fun (i : int) => buf0{1}.[i]) =buf{2}
/\ r10{1}=rp0{2} /\ Array128.init (fun (i : int) => buf1{1}.[i]) =buf0{2}
/\ r20{1}=rp1{2} /\ Array128.init (fun (i : int) => buf2{1}.[i]) =buf1{2}
/\ r30{1}=rp2{2} /\ Array128.init (fun (i : int) => buf3{1}.[i]) =buf2{2}
seq 27 30 : (
r00{1}=rp{2} /\ buf0{1} =buf{2}
/\ r10{1}=rp0{2} /\ buf1{1} =buf0{2}
/\ r20{1}=rp1{2} /\ buf2{1} =buf1{2}
/\ r30{1}=rp2{2} /\ buf3{1} =buf2{2}
).
sp => />.
ecall{2} (shake256_33_128 buf2{2} (Array33.init (fun i => if i = 32 then nonce4{2} else seed2{2}.[i]))); wp => />.
ecall{2} (shake256_33_128 buf1{2} (Array33.init (fun i => if i = 32 then nonce3{2} else seed1{2}.[i]))); wp => />.
ecall{2} (shake256_33_128 buf0{2} (Array33.init (fun i => if i = 32 then nonce2{2} else seed0{2}.[i]))); wp => />.
ecall{2} (shake256_33_128 buf{2} (Array33.init (fun i=> if i = 32 then nonce1{2} else seed{2}.[i]))); wp => />.
ecall{1} (shake_squeezenblocks4x state{1} buf0{1} buf1{1} buf2{1} buf3{1}); wp => />.
ecall{1} (shake_absorb4x state{1} (Array33.init (fun i => buf0{1}.[i])) (Array33.init (fun i => buf1{1}.[i])) (Array33.init (fun i => buf2{1}.[i])) (Array33.init (fun i => buf3{1}.[i])) ); wp => />.
auto => /> &2. rewrite shake4x_equiv => />.
rewrite tP => k Hk; rewrite !initiE //= 1..3:/#; rewrite ifF 1:/#; rewrite /get8 /init8 set_neqiE 1:/#; rewrite initiE //= 1:/#; rewrite initiE //= 1:/#; rewrite set256E initiE //= 1:/#; rewrite ifT //; rewrite /get256_direct bits8_W32u8 //=; rewrite ifT //; rewrite initiE //=; rewrite initiE //=.
rewrite tP => k Hk; rewrite !initiE //= 1..3:/#; rewrite ifF 1:/#; rewrite /get8 /init8 set_neqiE 1:/#; rewrite initiE //= 1:/#; rewrite initiE //= 1:/#; rewrite set256E initiE //= 1:/#; rewrite ifT //; rewrite /get256_direct bits8_W32u8 //=; rewrite ifT //; rewrite initiE //=; rewrite initiE //=.
rewrite tP => k Hk; rewrite !initiE //= 1..3:/#; rewrite ifF 1:/#; rewrite /get8 /init8 set_neqiE 1:/#; rewrite initiE //= 1:/#; rewrite initiE //= 1:/#; rewrite set256E initiE //= 1:/#; rewrite ifT //; rewrite /get256_direct bits8_W32u8 //=; rewrite ifT //; rewrite initiE //=; rewrite initiE //=.
rewrite tP => k Hk; rewrite !initiE //= 1..3:/#; rewrite ifF 1:/#; rewrite /get8 /init8 set_neqiE 1:/#; rewrite initiE //= 1:/#; rewrite initiE //= 1:/#; rewrite set256E initiE //= 1:/#; rewrite ifT //; rewrite /get256_direct bits8_W32u8 //=; rewrite ifT //; rewrite initiE //=; rewrite initiE //=.
ecall{1} (shake256_4x_128_32 seed0{1} nonces{1}).
auto => /> &1 rr ->->->->;do split;congr;
rewrite tP => k kb;rewrite !initiE /# /=.
wp. call getnoise_1x_equiv_avx => />.
wp. call getnoise_1x_equiv_avx => />.
wp. call getnoise_1x_equiv_avx => />.
wp. call getnoise_1x_equiv_avx => />.
auto => />.
*).
by auto.
qed.
lemma polygetnoise_ll : islossless Jkem.M(Jkem.Syscall)._poly_getnoise.
proc.
admit(*
while (0 <= to_uint i <= 128) (128 - to_uint i);
1: by move => z; auto => />;rewrite ultE /= => &hr ???; rewrite !to_uintD_small /=; smt(to_uint_cmp).
wp; call sha3ll; wp; while (0<=k<=32) (32 -k); 1: by move => z; auto=> /> /#.
auto => /> *; do split; 1:smt().
by move => *; rewrite ultE /=; smt().
*).
qed.
equiv getnoiseequiv :
Expand All @@ -834,20 +832,25 @@ have H : forall &m a,
Pr[Jkem.M(Jkem.Syscall)._poly_getnoise(a) @ &m : forall k, 0<=k<256 => -5 < to_sint res.[k] < 5] = 1%r.
+ move => &m a.
have -> : 1%r = Pr [ CBD2.sample(PRF a.`2 a.`3) @ &m : true].
+ byphoare => //.
+ byphoare;2..:smt().
proc; inline *; while (0<=i<=128) (128-i); 1: by move => z; auto => /> /#.
by auto => /> /#.
by byequiv get_noise_sample_noise => //.
have HH0 : hoare [Jkem.M(Jkem.Syscall)._poly_getnoise : true ==> forall k, 0<=k<256 => -5 < to_sint res.[k] < 5].
+ hoare; bypr => //= &m; rewrite Pr[mu_not].
+ hoare; bypr => &m; rewrite Pr[mu_not].
have -> : Pr[Jkem.M(Jkem.Syscall)._poly_getnoise(rp{m}, s_seed{m}, nonce{m}) @ &m : true] = 1%r.
+ by byphoare => //; apply polygetnoise_ll.
+ by byphoare;2..:smt(); apply polygetnoise_ll.
smt().
have HHH : equiv [ Jkem.M(Jkem.Syscall)._poly_getnoise ~Jkem.M(Jkem.Syscall)._poly_getnoise : ={arg} ==> ={res} ] by sim.
conseq HHH HH0.
move => *; rewrite /signed_bound_cxq /b16 qE /#.
qed.
axiom sha3equiv :
equiv [ Jkem_avx2.M(Jkem_avx2.Syscall)._sha3_512A_A33
~ M(Syscall)._sha3512_33 :
={arg} ==> ={res} ].
import InnerPKE.
lemma mlkem_correct_kg_avx2 mem _pkp _skp :
equiv [Jkem_avx2.M(Jkem_avx2.Syscall).__indcpa_keypair ~ InnerPKE.kg_derand :
Expand Down Expand Up @@ -887,15 +890,14 @@ call (polyvec_tobytes_equiv _skp).
wp;conseq />. smt().
ecall (polyvec_reduce_equiv (lift_array768 pkpv{2})).
have H := polyvec_add2_equiv 2 2 _ _ => //.
have H := polyvec_add2_equiv 2 2 _ _;1,2:smt().
ecall (H (lift_array768 pkpv{2}) (lift_array768 e{2})); clear H.
unroll for* {1} 37.
sp 3 3.
seq 16 18 : (#pre /\ ={publicseed, noiseseed,e,skpv,pkpv} /\ sskp{2} = skp{1} /\ spkp{2} = pkp{1}).
admit(* 1: by
sp; conseq />; sim 2 2; call( sha3equiv); conseq />; sim. *).
seq 16 18 : (#pre /\ ={publicseed, noiseseed,e,skpv,pkpv} /\ sskp{2} = skp{1} /\ spkp{2} = pkp{1});
1: by sp; conseq />; sim 3 3;call( sha3equiv); conseq />; sim.
sp 0 2.
seq 2 2 : (#pre /\ aa{1} = nttunpackm a{2} /\
Expand Down
Loading

0 comments on commit bb6d264

Please sign in to comment.