diff --git a/proof/correctness/avx2/MLKEM_Poly_avx2_proof_mr.ec b/proof/correctness/avx2/MLKEM_Poly_avx2_proof_mr.ec index 06ef49d2..d0153e3b 100644 --- a/proof/correctness/avx2/MLKEM_Poly_avx2_proof_mr.ec +++ b/proof/correctness/avx2/MLKEM_Poly_avx2_proof_mr.ec @@ -226,7 +226,13 @@ qed. op lane_func_csubq(x : W16.t) = if W16.of_int 3329 \sle x then x - W16.of_int 3329 else x. op pre16_csubq(x : W16.t) : bool = W16.zero \sle x && x \slt W16.of_int (6658). - +lemma post_commutes_csubq x : pre16_csubq x => + incoeff (to_sint x) = incoeff (to_sint (lane_func_csubq x)) /\ bpos16 (lane_func_csubq x) q. +proof. +rewrite /pre16_csubq /lane_func_csubq sltE !sleE /= /smod /= -eq_incoeff. +case (3329 <= to_sint x);last by smt(). +by move => *;rewrite to_sintB_small /= /smod /= /#. +qed. lemma poly_csubq_corr_h ap : hoare[ Jkem_avx2.M(Syscall)._poly_csubq : @@ -235,7 +241,7 @@ lemma poly_csubq_corr_h ap : ==> ap = lift_array256 res /\ pos_bound256_cxq res 0 256 1 ]. -proc. +proc. proc rewrite 1 sliceget_256_16_16E. unroll for 3. proc rewrite 3 sliceget_256_256_16E. @@ -283,31 +289,22 @@ proc. pre16_csubq _ _); [ by move => x; rewrite /pre16_csubq sleE sltE /= /to_sint /smod /= qE //= | by smt() ]. -move => &hr WRONG_POST. +rewrite /lift_array256 /pos_bound256_cxq;move => &hr [#] H H0 rp0new WRONG_POST. have PRE : -(_rp0 = rp{hr} /\ pos_bound256_cxq _rp0 0 256 2) /\ - map lane_func_csubq (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _rp0))]))) = - map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list rp{hr}))])). +map lane_func_csubq (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _rp0))]))) = + map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list rp0new{hr}))])). apply WRONG_POST. clear WRONG_POST. -have : forall k, nth witness (map lane_func_csubq (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _rp0))])))) k = - nth witness (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list rp{hr}))]))) k by smt(). -move => Hk;rewrite /lift_array256 /pos_bound256_cxq tP;split => i ib /=. -rewrite !mapiE //=. -move : (Hk i). -rewrite get_vs_bits_256u16 //=. -rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size). -rewrite get_vs_bits_256u16 //=. -rewrite /lane_func_csubq sleE /= /smod /= -eq_incoeff. -case (3329 <= to_sint _rp0.[i]);last by smt(). -by move => ? <-; rewrite to_sintB_small /= /smod /= /#. -move : (Hk i). -rewrite get_vs_bits_256u16 //=. -rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size). -rewrite get_vs_bits_256u16 //=. -rewrite /lane_func_csubq sleE /= /smod /= qE. -case (3329 <= to_sint _rp0.[i]);last by smt(). -by move => ? <-; rewrite to_sintB_small /= /smod /= /#. +have HH : forall k, nth witness (map lane_func_csubq (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _rp0))])))) k = nth witness (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list rp0new))]))) k by smt(). +rewrite Array256.tP;split => k kb. ++ rewrite !Array256.mapiE //=;move : (HH k). + rewrite (nth_map witness);1: smt(size_map get_vs_bits_256u16_size). + rewrite !get_vs_bits_256u16 1,2:/#. + by smt(post_commutes_csubq pow2_16 W16.of_uintK W16.to_uintK ). +move : (HH k). +rewrite (nth_map witness) /=;1: smt(size_map get_vs_bits_256u16_size). +have?: 2^(16-1) = 32768 by auto. +by smt(post_commutes_csubq get_vs_bits_256u16 pow2_16 W16.of_uintK W16.to_uintK ). qed. lemma poly_csubq_ll : islossless Jkem_avx2.M(Syscall)._poly_csubq. @@ -350,13 +347,30 @@ op compress_alt (d : int) (c : coeff) : int = op sliceget_4_64 (bw: W64.t) (i: int) : W4.t = W4.bits2w (W64.w2bits bw). -op lane_func_compress1(x : W16.t) : W4.t = sliceget_4_64 ((((W4u16.zeroextu64 x) `<<` W8.of_int 4) + W64.of_int (1665)) * (W64.of_int (80635) `>>` W8.of_int 28)) 0. - op lane_func_compress(x : W16.t) : W4.t = sliceget_4_64 ( (((W4u16.zeroextu64 x) `<<` W8.of_int 4) + W64.of_int 1665) * (W64.of_int 80635) `>>` W8.of_int 28) 0. op pre16_compress(x : W16.t) : bool = W16.zero \sle x && x \slt W16.of_int (3329). + +lemma post_commutes_compress x : pre16_compress x => + to_uint (lane_func_compress x) = compress_alt 4 (incoeff (to_sint x)). +admitted. + +require import EncDecCorrectness. + +op encode (l : int, ints : int list) : bool list = + flatten (map (BitEncoding.BS2Int.int2bs l) ints). + +op BitsToBytes(bits : bool list) : W8.t list = (map W8.bits2w (chunk 8 bits)). + +op sem_encode4 (a : int Array256.t) : W8.t Array128.t = Array128.of_list W8.zero (BitsToBytes (encode 4 (to_list a))). + +lemma sem_encode4_corr: sem_encode4 = encode4. +proof. +rewrite /sem_encode4 /encode4 fun_ext => x /=. +admitted. + lemma poly_compress_1_corr_mr_h _a mem : hoare [ Jkem_avx2.M(Syscall)._poly_compress_1 : pos_bound256_cxq a 0 256 2 /\ @@ -368,7 +382,6 @@ lemma poly_compress_1_corr_mr_h _a mem : pos_bound256_cxq res.`2 0 256 1 /\ res.`1 = encode4 (compress_poly 4 _a)]. proof. -rewrite /compress_poly. have -> : (compress 4) = (compress_alt 4). smt(fun_ext compress_alt_compress). proc. seq 2 : (pos_bound256_cxq a 0 256 1 /\ lift_array256 a = _a /\ Glob.mem = mem); 1:by call (poly_csubq_corr_h _a); auto => /> /#. cfold 7. @@ -396,10 +409,13 @@ proc rewrite 54 sliceset_256_128_8E. proc rewrite 78 sliceset_256_128_8E. proc rewrite 102 sliceset_256_128_8E. cfold 7. -conseq />. +conseq />. exists * a ;elim * =>_aw. conseq (: _aw = a /\ pos_bound256_cxq _aw 0 256 1 ==> - rp = encode4 (map (compress_alt 4) (lift_array256 _aw)));1,2: smt(). + rp = sem_encode4 (map (compress_alt 4) (lift_array256 _aw))); 1:smt(). ++ move => &hr [#] *. + rewrite -sem_encode4_corr /compress_poly. + by have -> /# : (compress 4) = (compress_alt 4); 1: by rewrite fun_ext => *;smt(compress_alt_compress). bdep 16 4 [ "_aw" ] [ "a" ] [ "rp" ] lane_func_compress pre16_compress. @@ -409,33 +425,95 @@ bdep 16 4 [ "_aw" ] [ "a" ] [ "rp" ] lane_func_compress pre16_compress. pre16_compress _ _); [ by move => x; rewrite /pre16_compress sleE sltE /= /to_sint /smod /= qE //= | by smt() ]. -move => &hr WRONG_POST. +move => &hr [#] H H0 newrp0 WRONG_POST. have PRE : -((_aw = a{hr} /\ pos_bound256_cxq _aw 0 256 1) /\ - map lane_func_compress (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _aw))]))) = - map W4.bits2w (chunk 4 (flatten [flatten (map W8.w2bits (to_list rp{hr}))]))). +map lane_func_compress (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _aw))]))) = +map W4.bits2w (chunk 4 (flatten [flatten (map W8.w2bits (to_list newrp0))])). apply WRONG_POST. clear WRONG_POST. -have : forall k, nth witness (map lane_func_compress (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _aw))])))) k = - nth witness (map W4.bits2w (chunk 4 (flatten [flatten (map W8.w2bits (to_list rp{hr}))]))) k by smt(). -move => Hk;rewrite /lift_array256 /pos_bound256_cxq tP => i ib /=. -move : (Hk i). +rewrite /sem_encode4 /=. +have HH : forall k, nth witness (map lane_func_compress (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _aw))])))) k = + nth witness (map W4.bits2w (chunk 4 (flatten [flatten (map W8.w2bits (to_list newrp0))]))) k by smt(). +rewrite Array128.tP => k kb; rewrite get_of_list // /BitsToBytes /encode. +move : (HH (2*k)). +rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size). +rewrite get_vs_bits_256u16 1:/#. +rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size). +rewrite flatten1 (nth_map witness) /=. ++ rewrite size_iota size_flatten -map_comp /(\o) /=. + have -> : map (fun (_ : W8.t) => 8) (to_list newrp0) = + mkseq (fun _ => 8) 128; last by + rewrite /mkseq -iotaredE /= /sumz /= /#. + apply (eq_from_nth witness). + + by rewrite size_map Array128.size_to_list size_mkseq /#. + move => i. + rewrite size_map Array128.size_to_list => ib. + rewrite (nth_map witness) /=;1: by smt(Array128.size_to_list). + by rewrite (nth_mkseq) /#. +rewrite nth_iota /=. admit. +have -> : 4 * (2 * k) = 8 * k by ring. +rewrite drop_flatten_ctt. admit. +rewrite -map_drop. +have /= <- := take_take ((flatten (map W8.w2bits (drop k (to_list newrp0))))) 4 8. +have -> : (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0))))) = + W8.w2bits newrp0.[k]. admit. +move => HH2k. + +move : (HH (2*k+1)). +rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size). +rewrite get_vs_bits_256u16 1:/#. +rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size). +rewrite flatten1 (nth_map witness) /=. ++ rewrite size_iota size_flatten -map_comp /(\o) /=. + have -> : map (fun (_ : W8.t) => 8) (to_list newrp0) = + mkseq (fun _ => 8) 128; last by + rewrite /mkseq -iotaredE /= /sumz /= /#. + apply (eq_from_nth witness). + + by rewrite size_map Array128.size_to_list size_mkseq /#. + move => i. + rewrite size_map Array128.size_to_list => ib. + rewrite (nth_map witness) /=;1: by smt(Array128.size_to_list). + by rewrite (nth_mkseq) /#. +rewrite nth_iota /=. admit. +have -> : 4 * (2 * k + 1) = 4 + 8*k by ring. +rewrite -drop_drop 1,2:/#. +rewrite drop_flatten_ctt. admit. +rewrite -map_drop. +have -> : (drop 4 (flatten (map W8.w2bits (drop k (to_list newrp0))))) = + (drop 4 (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0)))))). ++ apply (eq_from_nth witness). admit. admit. +have -> : (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0))))) = + W8.w2bits newrp0.[k]. admit. +rewrite {1}(:4 = size (drop 4 (w2bits newrp0.[k]))). admit. +rewrite take_size. +move => HH2k1. + +rewrite (nth_map witness). ++ rewrite size_chunk // size_flatten // -map_comp /(\o). + rewrite (eq_map _ (fun _ => 4)) => //=;1: smt(BitEncoding.BS2Int.size_int2bs). + have -> : map (fun (_ : int) => 4) (to_list (map (compress_alt 4) (lift_array256 _aw))) = + mkseq (fun _ => 4) 256; last + by rewrite /mkseq -iotaredE /= /sumz /=. + apply (eq_from_nth witness). + + by rewrite size_map Array256.size_to_list size_mkseq /#. + move => i. + rewrite size_map Array256.size_to_list => ib. + rewrite (nth_map witness) /=;1: by smt(Array256.size_to_list). + by rewrite (nth_mkseq) /#. +rewrite (nth_map witness) /=. ++ rewrite size_iota size_flatten -map_comp /(\o) /=. + have -> : (map (fun (x : int) => size ((BitEncoding.BS2Int.int2bs 4 x))) + (to_list (map (compress_alt 4) (lift_array256 _aw)))) = + mkseq (fun _ => 4) 256; last by + rewrite /mkseq -iotaredE /= /sumz /= /#. + apply (eq_from_nth witness). + + by rewrite size_map Array256.size_to_list size_mkseq /#. + move => i. + rewrite size_map Array256.size_to_list => ib. + rewrite (nth_map witness) /=;1: by smt(Array256.size_to_list). + by rewrite (nth_mkseq);smt(BitEncoding.BS2Int.size_int2bs). +rewrite nth_iota /=. admit. admit. -(* -rewrite get_vs_bits_256u16 //=. -rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size). -rewrite get_vs_bits_256u16 //=. -rewrite /lane_func_csubq sleE /= /smod /= -eq_incoeff. -case (3329 <= to_sint _rp0.[i]);last by smt(). -by move => ? <-; rewrite to_sintB_small /= /smod /= /#. -move : (Hk i). -rewrite get_vs_bits_256u16 //=. -rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size). -rewrite get_vs_bits_256u16 //=. -rewrite /lane_func_csubq sleE /= /smod /= qE. -case (3329 <= to_sint _rp0.[i]);last by smt(). -by move => ? <-; rewrite to_sintB_small /= /smod /= /#. -*) qed.