Skip to content

Commit

Permalink
Factored out MR goal
Browse files Browse the repository at this point in the history
  • Loading branch information
mbbarbosa-lectures committed Oct 24, 2024
1 parent 07e47fa commit 5ed7bbe
Showing 1 changed file with 208 additions and 4 deletions.
212 changes: 208 additions & 4 deletions proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,7 @@ move => *; rewrite /signed_bound_cxq /b16 qE /#.
qed.

import InnerPKE.

lemma mlkem_correct_kg_avx2 mem _pkp _skp :
equiv [Jkem_avx2.M(Jkem_avx2.Syscall).__indcpa_keypair ~ InnerPKE.kg_derand :
Glob.mem{1} = mem /\ to_uint pkp{1} = _pkp /\ to_uint skp{1} = _skp /\
Expand Down Expand Up @@ -1266,6 +1267,189 @@ qed.

(***************************************************)

require import WArray1088 WArray1536 Array4.

module AuxPolyVecCompress10 = {
proc avx2_orig(ctp : W64.t, bp : W16.t Array768.t) : WArray1088.t = {
bp <@ Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_reduce_sig(bp);
Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_compress(ctp, bp);
return witness;
}

proc __polyvec_compress_avx2(a : W16.t Array768.t) : WArray1088.t = {
var aux : int;
var b0 : W256.t;
var b1 : W256.t;
var b2 : W256.t;
var mask10 : W256.t;
var shift : W256.t;
var sllv_indx : W256.t;
var shuffle : W256.t;
var i : int;
var a0 : W256.t;
var lo : W128.t;
var hi : W128.t;
var rr : WArray1088.t <- witness;

b0 <- VPBROADCAST_16u16 compress10_b0;
b1 <- VPBROADCAST_16u16 compress10_b1;
b2 <- VPBROADCAST_16u16 pc_shift1_s;
mask10 <- VPBROADCAST_16u16 pvc_mask_s;
shift <- VPBROADCAST_8u32 compress10_shift;
sllv_indx <- VPBROADCAST_4u64 pvc_sllvdidx_s;
shuffle <- get256 ((init8 (fun (i_0 : int) => pvc_shufbidx_s.[i_0])))%WArray32 0;
aux <- 3 * 256 %/ 16;
i <- 0;
while (i < aux){
a0 <- (get256 ((WArray1536.init16 (fun (i_0 : int) => a.[i_0]))) i);
a0 <@ Jkem_avx2.M(Syscall).compress10_16x16_inline(a0, b0, b1, b2, mask10);
(lo, hi) <@ Jkem_avx2.M(Syscall).pack10_16x16(a0, shift, sllv_indx, shuffle);
rr <- WArray1088.set128_direct rr (i*20+0) lo;
rr <- WArray1088.set32_direct rr (i*20+16) (VPEXTR_32 hi W8.zero);
(*
Glob.mem <- storeW128 Glob.mem (to_uint (r + (of_int (i * 20 + 0))%W64)) lo;
Glob.mem <- storeW32 Glob.mem (to_uint (r + (of_int (i * 20 + 16))%W64)) (VPEXTR_32 hi W8.zero);
*)
i <- i + 1;
}

return rr;
}

proc avx2(bp : W16.t Array768.t) : WArray1088.t = {
var rr : WArray1088.t;
bp <@ Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_reduce_sig(bp);
rr <@ __polyvec_compress_avx2(bp);
return rr;
}

proc ref_orig(ctp : W64.t, bp : W16.t Array768.t) : WArray1088.t = {
bp <@ Jkem.M(Syscall).__polyvec_reduce(bp);
Jkem.M(Syscall).__polyvec_compress(ctp, bp);
return witness;
}

proc __polyvec_compress_ref(a : W16.t Array768.t) : WArray1088.t = {
var aux : int;
var i : int;
var j : int;
var aa : W16.t Array768.t;
var k : int;
var t : W64.t Array4.t;
var c : W16.t;
var b : W16.t;
var rr : WArray1088.t <- witness;

aa <- witness;
t <- witness;
i <- 0;
j <- 0;
aa <@ M(Syscall).__polyvec_csubq(a);
while (i < (3 * 256 - 3)){
k <- 0;
while (k < 4){
t.[k] <- zeroextu64 aa.[i];
i <- i + 1;
t.[k] <- (t.[k]) `<<` (of_int 10)%W8;
t.[k] <- (t.[k]) + (of_int 1665)%W64;
t.[k] <- (t.[k]) * (of_int 1290167)%W64;
t.[k] <- (t.[k]) `>>` (of_int 32)%W8;
t.[k] <- (t.[k]) `&` (of_int 1023)%W64;
k <- k + 1;
}
c <- truncateu16 (t.[0]);
c <- c `&` (of_int 255)%W16;
rr.[j] <- (truncateu8 c);
(*
Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 c);
*)
j <- j + 1;
b <- truncateu16 (t.[0]);
b <- b `>>` (of_int 8)%W8;
c <- truncateu16 (t.[1]);
c <- c `<<` (of_int 2)%W8;
c <- c `|` b;
rr.[j] <- (truncateu8 c);
(*
Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 c);
*)
j <- j + 1;
b <- truncateu16 (t.[1]);
b <- b `>>` (of_int 6)%W8;
c <- truncateu16 (t.[2]);
c <- c `<<` (of_int 4)%W8;
c <- c `|` b;
rr.[j] <- (truncateu8 c);
(*
Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 c);
*)
j <- j + 1;
b <- truncateu16 (t.[2]);
b <- b `>>` (of_int 4)%W8;
c <- truncateu16 (t.[3]);
c <- c `<<` (of_int 6)%W8;
c <- c `|` b;
rr.[j] <- (truncateu8 c);
(*
Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 c);
*)
j <- j + 1;
t.[3] <- (t.[3]) `>>` (of_int 2)%W8;
rr.[j] <- (truncateu8 (t.[3]));
(*
Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 (t.[3]));
*)
j <- j + 1;
}

return rr;
}

proc ref(bp : W16.t Array768.t) : WArray1088.t = {
var rr : WArray1088.t;
bp <@ Jkem.M(Jkem_avx2.Syscall).__polyvec_reduce(bp);
rr <@ __polyvec_compress_ref(bp);
return rr;
}

}.

lemma compress10_equiv_avx2mem _ctp _mem :
equiv [ AuxPolyVecCompress10.avx2_orig ~ AuxPolyVecCompress10.avx2 :
={bp} /\ ctp{1} = _ctp /\ Glob.mem{1} = _mem /\ valid_ptr (to_uint ctp{1}) (128 + 3 * 320) ==>
Glob.mem{1} = stores _mem (to_uint _ctp) (to_list res{2}) ].
admitted.


lemma compress10_equiv_refmem _ctp _mem :
equiv [ AuxPolyVecCompress10.ref ~ AuxPolyVecCompress10.ref_orig :
={bp} /\ ctp{2} = _ctp /\ Glob.mem{2} = _mem /\ valid_ptr (to_uint ctp{2}) (128 + 3 * 320) ==>
Glob.mem{2} = stores _mem (to_uint _ctp) (to_list res{1}) ].
admitted.

(* MAP REDUCE GOAL *)
lemma compress10_mr :
equiv [AuxPolyVecCompress10.avx2 ~ AuxPolyVecCompress10.ref : lift_array768 bp{1} = lift_array768 bp{2}==> ={res}].
admitted.

lemma compress10_equiv :
equiv [ AuxPolyVecCompress10.avx2_orig ~ AuxPolyVecCompress10.ref_orig :
lift_array768 bp{1} = lift_array768 bp{2} /\ valid_ptr (to_uint ctp{1}) (128 + 3 * 320) /\ ={ctp,Glob.mem} ==> ={Glob.mem}].
proof.
proc* => /=.
exlim Glob.mem{1}, ctp{1} => _mem _ctp.
transitivity {1} { r <@ AuxPolyVecCompress10.avx2(bp); }
(={bp} /\ ctp{1} = _ctp /\ Glob.mem{1} = _mem /\ valid_ptr (to_uint ctp{1}) (128 + 3 * 320) ==>
Glob.mem{1} = stores _mem (to_uint _ctp) (to_list r{2}))
(lift_array768 bp{1} = lift_array768 bp{2} /\ ctp{2} = _ctp /\ Glob.mem{2} = _mem /\ valid_ptr (to_uint ctp{2}) (128 + 3 * 320) ==> Glob.mem{2} = stores _mem (to_uint _ctp) (to_list r{1}));
[ by smt() | by smt() | by call (compress10_equiv_avx2mem _ctp _mem); auto => /> |].
transitivity {2} { r <@ AuxPolyVecCompress10.ref(bp); }
(lift_array768 bp{1} = lift_array768 bp{2} ==> ={r})
(={bp} /\ ctp{2} = _ctp /\ Glob.mem{2} = _mem /\ valid_ptr (to_uint ctp{2}) (128 + 3 * 320) ==> Glob.mem{2} = stores _mem (to_uint _ctp) (to_list r{1}));
[ by smt() | by smt() | | by call (compress10_equiv_refmem _ctp _mem); auto => />].
by call compress10_mr; auto => />.
qed.

lemma mlkem_correct_enc_0_avx2 mem _ctp _pkp :
equiv [Jkem_avx2.M(Jkem_avx2.Syscall).__indcpa_enc_0 ~ InnerPKE.enc_derand:
valid_ptr _pkp (384*3 + 32) /\
Expand Down Expand Up @@ -1328,8 +1512,16 @@ seq 47 56 : (={ctp,Glob.mem} /\
valid_ptr (to_uint ctp{1}) (128+3*320)); last first.
wp; conseq (: _ ==> ={Glob.mem}).
+ auto => /> &1 &2 *; do split;1,2:
smt (W64.to_uintD_small W64.of_uintK W64.to_uint_cmp pow2_64).
admit. (* MAP REDUCE PROOF GOAL *)
smt (W64.to_uintD_small W64.of_uintK W64.to_uint_cmp pow2_64).
transitivity {1} { AuxPolyVecCompress10.avx2_orig(ctp,bp); }
(={bp,ctp,Glob.mem} ==> ={Glob.mem})
( lift_array768 bp{1} = lift_array768 bp{2} /\ valid_ptr (to_uint ctp{1}) (128 + 3 * 320) /\ ={ctp,Glob.mem} ==> ={Glob.mem});
[ smt() | smt() | by inline*;sim |].
transitivity {2} { AuxPolyVecCompress10.ref_orig(ctp,bp); }
( lift_array768 bp{1} = lift_array768 bp{2} /\ valid_ptr (to_uint ctp{1}) (128 + 3 * 320) /\ ={ctp,Glob.mem} ==> ={Glob.mem})
(={bp,ctp,Glob.mem} ==> ={Glob.mem});
[ smt() | smt() | | by inline *;sim].
by call compress10_equiv;auto => />.

wp;conseq />.
call (reduceequiv_noperm).
Expand Down Expand Up @@ -1707,8 +1899,20 @@ seq 46 58 : (={ctp0,Glob.mem} /\ Glob.mem{1} = mem /\
lift_array256 v{1} = lift_array256 v{2} /\
lift_array768 bp{1} = lift_array768 bp{2}); last first.
+ wp; conseq />; conseq (: _ ==> aux_4{1} = aux_0{2}); 1: by smt().
admit. (* MAP REDUCE GOAL *)

admit.
(*
+ auto => /> &1 &2 *; do split;1,2:
smt (W64.to_uintD_small W64.of_uintK W64.to_uint_cmp pow2_64).
transitivity {1} { AuxPolyVecCompress10.avx2_orig(ctp,bp); }
(={bp,ctp,Glob.mem} ==> ={Glob.mem})
( lift_array768 bp{1} = lift_array768 bp{2} /\ valid_ptr (to_uint ctp{1}) (128 + 3 * 320) /\ ={ctp,Glob.mem} ==> ={Glob.mem});
[ smt() | smt() | by inline*;sim |].
transitivity {2} { AuxPolyVecCompress10.ref_orig(ctp,bp); }
( lift_array768 bp{1} = lift_array768 bp{2} /\ valid_ptr (to_uint ctp{1}) (128 + 3 * 320) /\ ={ctp,Glob.mem} ==> ={Glob.mem})
(={bp,ctp,Glob.mem} ==> ={Glob.mem});
[ smt() | smt() | | by inline *;sim].
by call compress10_equiv;auto => />.
*)
wp;conseq />.
call (reduceequiv_noperm).
call (addequiv_noperm 4 2 _ _) => //.
Expand Down

0 comments on commit 5ed7bbe

Please sign in to comment.