From 26dccbf34c19c8c4a9e363479fe9d521e5384aa1 Mon Sep 17 00:00:00 2001 From: Manuel Barbosa Date: Thu, 12 Sep 2024 00:09:05 +0200 Subject: [PATCH] finished mr compress proof --- .../avx2/MLKEM_Poly_avx2_proof_mr.ec | 199 +++++++++++++----- 1 file changed, 148 insertions(+), 51 deletions(-) diff --git a/proof/correctness/avx2/MLKEM_Poly_avx2_proof_mr.ec b/proof/correctness/avx2/MLKEM_Poly_avx2_proof_mr.ec index d0153e3b..05fea3c8 100644 --- a/proof/correctness/avx2/MLKEM_Poly_avx2_proof_mr.ec +++ b/proof/correctness/avx2/MLKEM_Poly_avx2_proof_mr.ec @@ -355,9 +355,26 @@ op pre16_compress(x : W16.t) : bool = W16.zero \sle x && x \slt W16.of_int (332 lemma post_commutes_compress x : pre16_compress x => to_uint (lane_func_compress x) = compress_alt 4 (incoeff (to_sint x)). -admitted. - -require import EncDecCorrectness. +rewrite /pre16_compress /lane_func_compress /compress_alt sleE sltE /(`<<`) /(`>>`) /= /smod /= /to_sint /= /smod /= /sliceget_4_64 /= W4.to_uintE incoeffK qE /= => [# Hl Hh]. +have Hx : 0 <= to_uint x < 3329 by smt(W16.to_uint_cmp pow2_16). +rewrite ifF 1:/#. +pose intv := (to_uint x %% 3329 * 16 + 1665) * 80635 %/ 268435456 %% 16. +pose wdv := (((zeroextu64 x `<<<` 4) + (of_int 1665)%W64) * (of_int 80635)%W64 `>>>` 28). +rewrite -{1}(cat_take_drop 4 (W64.w2bits wdv)). ++ have -> : W4.bits2w (take 4 (w2bits wdv) ++ drop 4 (w2bits wdv)) = + W4.bits2w (take 4 (w2bits wdv)). + + rewrite W4.wordP => i ib. + by rewrite !get_bits2w // nth_cat ifT; by smt(size_take W64.size_w2bits). + +have -> : take 4 (w2bits wdv) = BitEncoding.BS2Int.int2bs 4 intv; last first. + rewrite bits2wK; 1: by rewrite BitEncoding.BS2Int.size_int2bs;smt(W64.size_w2bits). + by rewrite BitEncoding.BS2Int.int2bsK 1,2:/# //. + +apply BitEncoding.BS2Int.inj_bs2int_eqsize;1:by smt(size_take W64.size_w2bits BitEncoding.BS2Int.size_int2bs). +rewrite BitEncoding.BS2Int.int2bsK //= 1:/# JWordList.bs2int_take //=. +rewrite -W64.to_uintE /wdv. +by rewrite to_uint_shr //= to_uintM_small /= to_uintD_small /= to_uint_shl //= to_uint_zeroextu64 /= /#. +qed. op encode (l : int, ints : int list) : bool list = flatten (map (BitEncoding.BS2Int.int2bs l) ints). @@ -369,7 +386,7 @@ op sem_encode4 (a : int Array256.t) : W8.t Array128.t = Array128.of_list W8.zero lemma sem_encode4_corr: sem_encode4 = encode4. proof. rewrite /sem_encode4 /encode4 fun_ext => x /=. -admitted. +admitted. (* This belong in EncDec but was not proved before *) lemma poly_compress_1_corr_mr_h _a mem : hoare [ Jkem_avx2.M(Syscall)._poly_compress_1 : @@ -435,6 +452,7 @@ rewrite /sem_encode4 /=. have HH : forall k, nth witness (map lane_func_compress (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _aw))])))) k = nth witness (map W4.bits2w (chunk 4 (flatten [flatten (map W8.w2bits (to_list newrp0))]))) k by smt(). rewrite Array128.tP => k kb; rewrite get_of_list // /BitsToBytes /encode. + move : (HH (2*k)). rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size). rewrite get_vs_bits_256u16 1:/#. @@ -450,16 +468,50 @@ rewrite flatten1 (nth_map witness) /=. rewrite size_map Array128.size_to_list => ib. rewrite (nth_map witness) /=;1: by smt(Array128.size_to_list). by rewrite (nth_mkseq) /#. -rewrite nth_iota /=. admit. +rewrite nth_iota /=. ++ split;1: by smt(). + move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + by rewrite StdBigop.Bigint.big_constz /= count_predT size_to_list /= /#. + have -> : 4 * (2 * k) = 8 * k by ring. -rewrite drop_flatten_ctt. admit. +rewrite drop_flatten_ctt. ++ move => x. rewrite mapP => Hx;elim Hx => x0. + by smt(W8.size_w2bits). + rewrite -map_drop. + have /= <- := take_take ((flatten (map W8.w2bits (drop k (to_list newrp0))))) 4 8. + have -> : (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0))))) = - W8.w2bits newrp0.[k]. admit. + W8.w2bits newrp0.[k]. ++ apply (eq_from_nth witness). + rewrite size_w2bits size_take //. + move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + rewrite StdBigop.Bigint.big_constz /= count_predT. + by smt(size_drop Array128.size_to_list). + ++ move => ii HHa. + have ? : size (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0))))) = 8. + rewrite size_take. + have : 8 <= size (flatten (map W8.w2bits (drop k (to_list newrp0)))); last by smt(). + + move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + by rewrite StdBigop.Bigint.big_constz /= count_predT size_drop /= 1:/# size_to_list /= /#. + have : 8 <= size (flatten (map W8.w2bits (drop k (to_list newrp0)))); last by smt(). + + move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + by rewrite StdBigop.Bigint.big_constz /= count_predT size_drop /= 1:/# size_to_list /= /#. + rewrite -map_drop nth_take // 1:/# -map_comp /(\o) /= drop_iota // 1:/# /=. + rewrite (nth_flatten witness 8). + + rewrite allP => x /=. + rewrite mapP => HHx;elim HHx => x0. + rewrite mem_iota /=;smt(W8.size_w2bits). + rewrite (nth_map witness) /=;1:smt(size_iota). + by rewrite nth_iota /#. + move => HH2k. -move : (HH (2*k+1)). + + move : (HH (2*k+1)). + rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size). rewrite get_vs_bits_256u16 1:/#. rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size). @@ -474,56 +526,101 @@ rewrite flatten1 (nth_map witness) /=. rewrite size_map Array128.size_to_list => ib. rewrite (nth_map witness) /=;1: by smt(Array128.size_to_list). by rewrite (nth_mkseq) /#. -rewrite nth_iota /=. admit. -have -> : 4 * (2 * k + 1) = 4 + 8*k by ring. -rewrite -drop_drop 1,2:/#. -rewrite drop_flatten_ctt. admit. -rewrite -map_drop. -have -> : (drop 4 (flatten (map W8.w2bits (drop k (to_list newrp0))))) = - (drop 4 (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0)))))). -+ apply (eq_from_nth witness). admit. admit. -have -> : (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0))))) = - W8.w2bits newrp0.[k]. admit. -rewrite {1}(:4 = size (drop 4 (w2bits newrp0.[k]))). admit. -rewrite take_size. -move => HH2k1. - -rewrite (nth_map witness). -+ rewrite size_chunk // size_flatten // -map_comp /(\o). - rewrite (eq_map _ (fun _ => 4)) => //=;1: smt(BitEncoding.BS2Int.size_int2bs). - have -> : map (fun (_ : int) => 4) (to_list (map (compress_alt 4) (lift_array256 _aw))) = - mkseq (fun _ => 4) 256; last - by rewrite /mkseq -iotaredE /= /sumz /=. - apply (eq_from_nth witness). - + by rewrite size_map Array256.size_to_list size_mkseq /#. - move => i. - rewrite size_map Array256.size_to_list => ib. - rewrite (nth_map witness) /=;1: by smt(Array256.size_to_list). - by rewrite (nth_mkseq) /#. -rewrite (nth_map witness) /=. -+ rewrite size_iota size_flatten -map_comp /(\o) /=. - have -> : (map (fun (x : int) => size ((BitEncoding.BS2Int.int2bs 4 x))) - (to_list (map (compress_alt 4) (lift_array256 _aw)))) = - mkseq (fun _ => 4) 256; last by - rewrite /mkseq -iotaredE /= /sumz /= /#. - apply (eq_from_nth witness). - + by rewrite size_map Array256.size_to_list size_mkseq /#. - move => i. - rewrite size_map Array256.size_to_list => ib. - rewrite (nth_map witness) /=;1: by smt(Array256.size_to_list). - by rewrite (nth_mkseq);smt(BitEncoding.BS2Int.size_int2bs). -rewrite nth_iota /=. admit. -admit. -qed. - - +rewrite nth_iota /=. ++ split;1: by smt(). + move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + by rewrite StdBigop.Bigint.big_constz /= count_predT size_to_list /= /#. +have -> : 4 * (2 * k + 1) = 4 + 8 * k by ring. +rewrite -drop_drop 1,2:/#. +rewrite drop_flatten_ctt. ++ move => x. rewrite mapP => Hx;elim Hx => x0. + by smt(W8.size_w2bits). +rewrite -map_drop. +have -> : (take 4 (drop 4 (flatten (map W8.w2bits (drop k (to_list newrp0))))))%W4 = (take 4 (drop 4 (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0)))))))%W4. + + apply (eq_from_nth witness). + + rewrite !size_take // !size_drop // size_take 1:/#. + have : 8 <= size (flatten (map W8.w2bits (drop k (to_list newrp0)))); last by smt(). + + move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + by rewrite StdBigop.Bigint.big_constz /= count_predT size_drop /= 1:/# size_to_list /= /#. + move => ii ib. + have ? : size (take 4 (drop 4 (flatten (map W8.w2bits (drop k (to_list newrp0)))))) = 4. + + rewrite !size_take // !size_drop //. + rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + by rewrite StdBigop.Bigint.big_constz /= count_predT size_drop 1:/# size_to_list /= /#. + + + by rewrite !nth_take 1..4:/# !nth_drop 1..4:/# nth_take 1,2:/#. +have -> : (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0))))) = + W8.w2bits newrp0.[k]. ++ apply (eq_from_nth witness). + rewrite size_w2bits size_take //. + move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + rewrite StdBigop.Bigint.big_constz /= count_predT. + by smt(size_drop Array128.size_to_list). + ++ move => ii HHa. + have ? : size (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0))))) = 8. + rewrite size_take. + have : 8 <= size (flatten (map W8.w2bits (drop k (to_list newrp0)))); last by smt(). + + move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + by rewrite StdBigop.Bigint.big_constz /= count_predT size_drop /= 1:/# size_to_list /= /#. + have : 8 <= size (flatten (map W8.w2bits (drop k (to_list newrp0)))); last by smt(). + + move => *;rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + by rewrite StdBigop.Bigint.big_constz /= count_predT size_drop /= 1:/# size_to_list /= /#. + rewrite -map_drop nth_take // 1:/# -map_comp /(\o) /= drop_iota // 1:/# /=. + rewrite (nth_flatten witness 8). + + rewrite allP => x /=. + rewrite mapP => HHx;elim HHx => x0. + rewrite mem_iota /=;smt(W8.size_w2bits). + rewrite (nth_map witness) /=;1:smt(size_iota). + by rewrite nth_iota /#. +move => HH2k1. +rewrite (nth_map []). ++ split; 1:smt(). + move => *;rewrite size_chunk 1:/# . + rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + rewrite (StdBigop.Bigint.BIA.eq_bigr _ _ (fun _ => 4)); 1: by smt(BitEncoding.BS2Int.size_int2bs). + by rewrite StdBigop.Bigint.big_constz /= count_predT /= size_to_list /= /#. +rewrite JWordList.nth_chunk // 1:/#. ++ rewrite size_flatten StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + rewrite (StdBigop.Bigint.BIA.eq_bigr _ _ (fun _ => 4)); 1: by smt(BitEncoding.BS2Int.size_int2bs). + by rewrite StdBigop.Bigint.big_constz /= count_predT /= size_to_list /= /#. + +apply W8.wordP => jj jb. +rewrite -W8.get_w2bits. +rewrite W8.get_bits2w 1:/#. +rewrite nth_take // 1:/#. +rewrite nth_drop 1,2:/#. +have -> := nth_flatten false 4(map ((BitEncoding.BS2Int.int2bs 4)) (to_list (map (compress_alt 4) (lift_array256 _aw)))) (8 * k + jj) _. ++ rewrite allP => x. + by rewrite mapP => HHx;elim HHx;smt(BitEncoding.BS2Int.size_int2bs). + +rewrite (nth_map witness);1: by smt(Array256.size_to_list). +rewrite (Array256.get_to_list) mapiE // 1:/#. +rewrite /lisft_array256 mapiE // 1:/# /=. +rewrite -post_commutes_compress. ++ by rewrite /pre16_compress sleE sltE /= /smod /= /#. + +case (0<=jj<4). ++ move => ?. + have -> : (8 * k + jj) %/ 4 = 2*k by smt(). + rewrite HH2k to_uintE bits2wK;1: smt(W8.size_w2bits size_take). + rewrite {1}(:4 = size((take 4 (w2bits newrp0.[k]))));1: by smt(W8.size_w2bits size_take). + by rewrite BitEncoding.BS2Int.bs2intK nth_take // 1:/# get_w2bits /#. + +move => ?. +have -> : (8 * k + jj) %/ 4 = 2*k + 1 by smt(). +rewrite HH2k1 to_uintE bits2wK;1: smt(W8.size_w2bits size_take size_drop). + rewrite {1}(:4 = size((take 4 (drop 4 (w2bits newrp0.[k])))));1: by smt(size_drop W8.size_w2bits size_take). + by rewrite BitEncoding.BS2Int.bs2intK nth_take // 1:/# nth_drop // 1:/# get_w2bits /#. +qed.