Skip to content

Commit

Permalink
finished mr compress proof
Browse files Browse the repository at this point in the history
  • Loading branch information
mbbarbosa-lectures committed Sep 11, 2024
1 parent 5c518b2 commit 26dccbf
Showing 1 changed file with 148 additions and 51 deletions.
199 changes: 148 additions & 51 deletions proof/correctness/avx2/MLKEM_Poly_avx2_proof_mr.ec
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,26 @@ op pre16_compress(x : W16.t) : bool = W16.zero \sle x && x \slt W16.of_int (332

lemma post_commutes_compress x : pre16_compress x =>
to_uint (lane_func_compress x) = compress_alt 4 (incoeff (to_sint x)).
admitted.

require import EncDecCorrectness.
rewrite /pre16_compress /lane_func_compress /compress_alt sleE sltE /(`<<`) /(`>>`) /= /smod /= /to_sint /= /smod /= /sliceget_4_64 /= W4.to_uintE incoeffK qE /= => [# Hl Hh].
have Hx : 0 <= to_uint x < 3329 by smt(W16.to_uint_cmp pow2_16).
rewrite ifF 1:/#.
pose intv := (to_uint x %% 3329 * 16 + 1665) * 80635 %/ 268435456 %% 16.
pose wdv := (((zeroextu64 x `<<<` 4) + (of_int 1665)%W64) * (of_int 80635)%W64 `>>>` 28).
rewrite -{1}(cat_take_drop 4 (W64.w2bits wdv)).
+ have -> : W4.bits2w (take 4 (w2bits wdv) ++ drop 4 (w2bits wdv)) =
W4.bits2w (take 4 (w2bits wdv)).
+ rewrite W4.wordP => i ib.
by rewrite !get_bits2w // nth_cat ifT; by smt(size_take W64.size_w2bits).

have -> : take 4 (w2bits wdv) = BitEncoding.BS2Int.int2bs 4 intv; last first.
rewrite bits2wK; 1: by rewrite BitEncoding.BS2Int.size_int2bs;smt(W64.size_w2bits).
by rewrite BitEncoding.BS2Int.int2bsK 1,2:/# //.

apply BitEncoding.BS2Int.inj_bs2int_eqsize;1:by smt(size_take W64.size_w2bits BitEncoding.BS2Int.size_int2bs).
rewrite BitEncoding.BS2Int.int2bsK //= 1:/# JWordList.bs2int_take //=.
rewrite -W64.to_uintE /wdv.
by rewrite to_uint_shr //= to_uintM_small /= to_uintD_small /= to_uint_shl //= to_uint_zeroextu64 /= /#.
qed.

op encode (l : int, ints : int list) : bool list =
flatten (map (BitEncoding.BS2Int.int2bs l) ints).
Expand All @@ -369,7 +386,7 @@ op sem_encode4 (a : int Array256.t) : W8.t Array128.t = Array128.of_list W8.zero
lemma sem_encode4_corr: sem_encode4 = encode4.
proof.
rewrite /sem_encode4 /encode4 fun_ext => x /=.
admitted.
admitted. (* This belong in EncDec but was not proved before *)

lemma poly_compress_1_corr_mr_h _a mem :
hoare [ Jkem_avx2.M(Syscall)._poly_compress_1 :
Expand Down Expand Up @@ -435,6 +452,7 @@ rewrite /sem_encode4 /=.
have HH : forall k, nth witness (map lane_func_compress (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _aw))])))) k =
nth witness (map W4.bits2w (chunk 4 (flatten [flatten (map W8.w2bits (to_list newrp0))]))) k by smt().
rewrite Array128.tP => k kb; rewrite get_of_list // /BitsToBytes /encode.

move : (HH (2*k)).
rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size).
rewrite get_vs_bits_256u16 1:/#.
Expand All @@ -450,16 +468,50 @@ rewrite flatten1 (nth_map witness) /=.
rewrite size_map Array128.size_to_list => ib.
rewrite (nth_map witness) /=;1: by smt(Array128.size_to_list).
by rewrite (nth_mkseq) /#.
rewrite nth_iota /=. admit.
rewrite nth_iota /=.
+ split;1: by smt().
move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
by rewrite StdBigop.Bigint.big_constz /= count_predT size_to_list /= /#.

have -> : 4 * (2 * k) = 8 * k by ring.
rewrite drop_flatten_ctt. admit.
rewrite drop_flatten_ctt.
+ move => x. rewrite mapP => Hx;elim Hx => x0.
by smt(W8.size_w2bits).

rewrite -map_drop.

have /= <- := take_take ((flatten (map W8.w2bits (drop k (to_list newrp0))))) 4 8.

have -> : (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0))))) =
W8.w2bits newrp0.[k]. admit.
W8.w2bits newrp0.[k].
+ apply (eq_from_nth witness).
rewrite size_w2bits size_take //.
move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
rewrite StdBigop.Bigint.big_constz /= count_predT.
by smt(size_drop Array128.size_to_list).

+ move => ii HHa.
have ? : size (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0))))) = 8.
rewrite size_take.
have : 8 <= size (flatten (map W8.w2bits (drop k (to_list newrp0)))); last by smt().
+ move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
by rewrite StdBigop.Bigint.big_constz /= count_predT size_drop /= 1:/# size_to_list /= /#.
have : 8 <= size (flatten (map W8.w2bits (drop k (to_list newrp0)))); last by smt().
+ move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
by rewrite StdBigop.Bigint.big_constz /= count_predT size_drop /= 1:/# size_to_list /= /#.
rewrite -map_drop nth_take // 1:/# -map_comp /(\o) /= drop_iota // 1:/# /=.
rewrite (nth_flatten witness 8).
+ rewrite allP => x /=.
rewrite mapP => HHx;elim HHx => x0.
rewrite mem_iota /=;smt(W8.size_w2bits).
rewrite (nth_map witness) /=;1:smt(size_iota).
by rewrite nth_iota /#.

move => HH2k.

move : (HH (2*k+1)).

move : (HH (2*k+1)).

rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size).
rewrite get_vs_bits_256u16 1:/#.
rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size).
Expand All @@ -474,56 +526,101 @@ rewrite flatten1 (nth_map witness) /=.
rewrite size_map Array128.size_to_list => ib.
rewrite (nth_map witness) /=;1: by smt(Array128.size_to_list).
by rewrite (nth_mkseq) /#.
rewrite nth_iota /=. admit.
have -> : 4 * (2 * k + 1) = 4 + 8*k by ring.
rewrite -drop_drop 1,2:/#.
rewrite drop_flatten_ctt. admit.
rewrite -map_drop.
have -> : (drop 4 (flatten (map W8.w2bits (drop k (to_list newrp0))))) =
(drop 4 (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0)))))).
+ apply (eq_from_nth witness). admit. admit.
have -> : (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0))))) =
W8.w2bits newrp0.[k]. admit.
rewrite {1}(:4 = size (drop 4 (w2bits newrp0.[k]))). admit.
rewrite take_size.
move => HH2k1.

rewrite (nth_map witness).
+ rewrite size_chunk // size_flatten // -map_comp /(\o).
rewrite (eq_map _ (fun _ => 4)) => //=;1: smt(BitEncoding.BS2Int.size_int2bs).
have -> : map (fun (_ : int) => 4) (to_list (map (compress_alt 4) (lift_array256 _aw))) =
mkseq (fun _ => 4) 256; last
by rewrite /mkseq -iotaredE /= /sumz /=.
apply (eq_from_nth witness).
+ by rewrite size_map Array256.size_to_list size_mkseq /#.
move => i.
rewrite size_map Array256.size_to_list => ib.
rewrite (nth_map witness) /=;1: by smt(Array256.size_to_list).
by rewrite (nth_mkseq) /#.
rewrite (nth_map witness) /=.
+ rewrite size_iota size_flatten -map_comp /(\o) /=.
have -> : (map (fun (x : int) => size ((BitEncoding.BS2Int.int2bs 4 x)))
(to_list (map (compress_alt 4) (lift_array256 _aw)))) =
mkseq (fun _ => 4) 256; last by
rewrite /mkseq -iotaredE /= /sumz /= /#.
apply (eq_from_nth witness).
+ by rewrite size_map Array256.size_to_list size_mkseq /#.
move => i.
rewrite size_map Array256.size_to_list => ib.
rewrite (nth_map witness) /=;1: by smt(Array256.size_to_list).
by rewrite (nth_mkseq);smt(BitEncoding.BS2Int.size_int2bs).
rewrite nth_iota /=. admit.
admit.
qed.


rewrite nth_iota /=.
+ split;1: by smt().
move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
by rewrite StdBigop.Bigint.big_constz /= count_predT size_to_list /= /#.

have -> : 4 * (2 * k + 1) = 4 + 8 * k by ring.
rewrite -drop_drop 1,2:/#.
rewrite drop_flatten_ctt.
+ move => x. rewrite mapP => Hx;elim Hx => x0.
by smt(W8.size_w2bits).

rewrite -map_drop.

have -> : (take 4 (drop 4 (flatten (map W8.w2bits (drop k (to_list newrp0))))))%W4 = (take 4 (drop 4 (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0)))))))%W4.
+ apply (eq_from_nth witness).
+ rewrite !size_take // !size_drop // size_take 1:/#.
have : 8 <= size (flatten (map W8.w2bits (drop k (to_list newrp0)))); last by smt().
+ move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
by rewrite StdBigop.Bigint.big_constz /= count_predT size_drop /= 1:/# size_to_list /= /#.
move => ii ib.
have ? : size (take 4 (drop 4 (flatten (map W8.w2bits (drop k (to_list newrp0)))))) = 4.
+ rewrite !size_take // !size_drop //.
rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
by rewrite StdBigop.Bigint.big_constz /= count_predT size_drop 1:/# size_to_list /= /#.


by rewrite !nth_take 1..4:/# !nth_drop 1..4:/# nth_take 1,2:/#.

have -> : (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0))))) =
W8.w2bits newrp0.[k].
+ apply (eq_from_nth witness).
rewrite size_w2bits size_take //.
move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
rewrite StdBigop.Bigint.big_constz /= count_predT.
by smt(size_drop Array128.size_to_list).

+ move => ii HHa.
have ? : size (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0))))) = 8.
rewrite size_take.
have : 8 <= size (flatten (map W8.w2bits (drop k (to_list newrp0)))); last by smt().
+ move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
by rewrite StdBigop.Bigint.big_constz /= count_predT size_drop /= 1:/# size_to_list /= /#.
have : 8 <= size (flatten (map W8.w2bits (drop k (to_list newrp0)))); last by smt().
+ move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
by rewrite StdBigop.Bigint.big_constz /= count_predT size_drop /= 1:/# size_to_list /= /#.
rewrite -map_drop nth_take // 1:/# -map_comp /(\o) /= drop_iota // 1:/# /=.
rewrite (nth_flatten witness 8).
+ rewrite allP => x /=.
rewrite mapP => HHx;elim HHx => x0.
rewrite mem_iota /=;smt(W8.size_w2bits).
rewrite (nth_map witness) /=;1:smt(size_iota).
by rewrite nth_iota /#.

move => HH2k1.

rewrite (nth_map []).
+ split; 1:smt().
move => *;rewrite size_chunk 1:/# .
rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
rewrite (StdBigop.Bigint.BIA.eq_bigr _ _ (fun _ => 4)); 1: by smt(BitEncoding.BS2Int.size_int2bs).
by rewrite StdBigop.Bigint.big_constz /= count_predT /= size_to_list /= /#.
rewrite JWordList.nth_chunk // 1:/#.
+ rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=.
rewrite (StdBigop.Bigint.BIA.eq_bigr _ _ (fun _ => 4)); 1: by smt(BitEncoding.BS2Int.size_int2bs).
by rewrite StdBigop.Bigint.big_constz /= count_predT /= size_to_list /= /#.

apply W8.wordP => jj jb.
rewrite -W8.get_w2bits.
rewrite W8.get_bits2w 1:/#.
rewrite nth_take // 1:/#.
rewrite nth_drop 1,2:/#.
have -> := nth_flatten false 4(map ((BitEncoding.BS2Int.int2bs 4)) (to_list (map (compress_alt 4) (lift_array256 _aw)))) (8 * k + jj) _.
+ rewrite allP => x.
by rewrite mapP => HHx;elim HHx;smt(BitEncoding.BS2Int.size_int2bs).

rewrite (nth_map witness);1: by smt(Array256.size_to_list).
rewrite (Array256.get_to_list) mapiE // 1:/#.
rewrite /lisft_array256 mapiE // 1:/# /=.
rewrite -post_commutes_compress.
+ by rewrite /pre16_compress sleE sltE /= /smod /= /#.

case (0<=jj<4).
+ move => ?.
have -> : (8 * k + jj) %/ 4 = 2*k by smt().
rewrite HH2k to_uintE bits2wK;1: smt(W8.size_w2bits size_take).
rewrite {1}(:4 = size((take 4 (w2bits newrp0.[k]))));1: by smt(W8.size_w2bits size_take).
by rewrite BitEncoding.BS2Int.bs2intK nth_take // 1:/# get_w2bits /#.

move => ?.
have -> : (8 * k + jj) %/ 4 = 2*k + 1 by smt().
rewrite HH2k1 to_uintE bits2wK;1: smt(W8.size_w2bits size_take size_drop).
rewrite {1}(:4 = size((take 4 (drop 4 (w2bits newrp0.[k])))));1: by smt(size_drop W8.size_w2bits size_take).
by rewrite BitEncoding.BS2Int.bs2intK nth_take // 1:/# nth_drop // 1:/# get_w2bits /#.

qed.



Expand Down

0 comments on commit 26dccbf

Please sign in to comment.