Skip to content

Commit

Permalink
Gets ugly when non-aligned
Browse files Browse the repository at this point in the history
  • Loading branch information
mbbarbosa-lectures committed Aug 28, 2024
1 parent 8a45ec4 commit 5c518b2
Showing 1 changed file with 130 additions and 52 deletions.
182 changes: 130 additions & 52 deletions proof/correctness/avx2/MLKEM_Poly_avx2_proof_mr.ec
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,13 @@ qed.
op lane_func_csubq(x : W16.t) = if W16.of_int 3329 \sle x then x - W16.of_int 3329 else x.
op pre16_csubq(x : W16.t) : bool = W16.zero \sle x && x \slt W16.of_int (6658).


lemma post_commutes_csubq x : pre16_csubq x =>
incoeff (to_sint x) = incoeff (to_sint (lane_func_csubq x)) /\ bpos16 (lane_func_csubq x) q.
proof.
rewrite /pre16_csubq /lane_func_csubq sltE !sleE /= /smod /= -eq_incoeff.
case (3329 <= to_sint x);last by smt().
by move => *;rewrite to_sintB_small /= /smod /= /#.
qed.

lemma poly_csubq_corr_h ap :
hoare[ Jkem_avx2.M(Syscall)._poly_csubq :
Expand All @@ -235,7 +241,7 @@ lemma poly_csubq_corr_h ap :
==>
ap = lift_array256 res /\
pos_bound256_cxq res 0 256 1 ].
proc.
proc.
proc rewrite 1 sliceget_256_16_16E.
unroll for 3.
proc rewrite 3 sliceget_256_256_16E.
Expand Down Expand Up @@ -283,31 +289,22 @@ proc.
pre16_csubq _ _);
[ by move => x; rewrite /pre16_csubq sleE sltE /= /to_sint /smod /= qE //= | by smt() ].

move => &hr WRONG_POST.
rewrite /lift_array256 /pos_bound256_cxq;move => &hr [#] H H0 rp0new WRONG_POST.
have PRE :
(_rp0 = rp{hr} /\ pos_bound256_cxq _rp0 0 256 2) /\
map lane_func_csubq (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _rp0))]))) =
map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list rp{hr}))])).
map lane_func_csubq (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _rp0))]))) =
map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list rp0new{hr}))])).
apply WRONG_POST.
clear WRONG_POST.
have : forall k, nth witness (map lane_func_csubq (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _rp0))])))) k =
nth witness (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list rp{hr}))]))) k by smt().
move => Hk;rewrite /lift_array256 /pos_bound256_cxq tP;split => i ib /=.
rewrite !mapiE //=.
move : (Hk i).
rewrite get_vs_bits_256u16 //=.
rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size).
rewrite get_vs_bits_256u16 //=.
rewrite /lane_func_csubq sleE /= /smod /= -eq_incoeff.
case (3329 <= to_sint _rp0.[i]);last by smt().
by move => ? <-; rewrite to_sintB_small /= /smod /= /#.
move : (Hk i).
rewrite get_vs_bits_256u16 //=.
rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size).
rewrite get_vs_bits_256u16 //=.
rewrite /lane_func_csubq sleE /= /smod /= qE.
case (3329 <= to_sint _rp0.[i]);last by smt().
by move => ? <-; rewrite to_sintB_small /= /smod /= /#.
have HH : forall k, nth witness (map lane_func_csubq (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _rp0))])))) k = nth witness (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list rp0new))]))) k by smt().
rewrite Array256.tP;split => k kb.
+ rewrite !Array256.mapiE //=;move : (HH k).
rewrite (nth_map witness);1: smt(size_map get_vs_bits_256u16_size).
rewrite !get_vs_bits_256u16 1,2:/#.
by smt(post_commutes_csubq pow2_16 W16.of_uintK W16.to_uintK ).
move : (HH k).
rewrite (nth_map witness) /=;1: smt(size_map get_vs_bits_256u16_size).
have?: 2^(16-1) = 32768 by auto.
by smt(post_commutes_csubq get_vs_bits_256u16 pow2_16 W16.of_uintK W16.to_uintK ).
qed.

lemma poly_csubq_ll : islossless Jkem_avx2.M(Syscall)._poly_csubq.
Expand Down Expand Up @@ -350,13 +347,30 @@ op compress_alt (d : int) (c : coeff) : int =
op sliceget_4_64 (bw: W64.t) (i: int) : W4.t = W4.bits2w (W64.w2bits bw).


op lane_func_compress1(x : W16.t) : W4.t = sliceget_4_64 ((((W4u16.zeroextu64 x) `<<` W8.of_int 4) + W64.of_int (1665)) * (W64.of_int (80635) `>>` W8.of_int 28)) 0.

op lane_func_compress(x : W16.t) : W4.t = sliceget_4_64 (
(((W4u16.zeroextu64 x) `<<` W8.of_int 4) + W64.of_int 1665) * (W64.of_int 80635) `>>` W8.of_int 28) 0.

op pre16_compress(x : W16.t) : bool = W16.zero \sle x && x \slt W16.of_int (3329).


lemma post_commutes_compress x : pre16_compress x =>
to_uint (lane_func_compress x) = compress_alt 4 (incoeff (to_sint x)).
admitted.

require import EncDecCorrectness.

op encode (l : int, ints : int list) : bool list =
flatten (map (BitEncoding.BS2Int.int2bs l) ints).

op BitsToBytes(bits : bool list) : W8.t list = (map W8.bits2w (chunk 8 bits)).

op sem_encode4 (a : int Array256.t) : W8.t Array128.t = Array128.of_list W8.zero (BitsToBytes (encode 4 (to_list a))).

lemma sem_encode4_corr: sem_encode4 = encode4.
proof.
rewrite /sem_encode4 /encode4 fun_ext => x /=.
admitted.

lemma poly_compress_1_corr_mr_h _a mem :
hoare [ Jkem_avx2.M(Syscall)._poly_compress_1 :
pos_bound256_cxq a 0 256 2 /\
Expand All @@ -368,7 +382,6 @@ lemma poly_compress_1_corr_mr_h _a mem :
pos_bound256_cxq res.`2 0 256 1 /\
res.`1 = encode4 (compress_poly 4 _a)].
proof.
rewrite /compress_poly. have -> : (compress 4) = (compress_alt 4). smt(fun_ext compress_alt_compress).
proc.
seq 2 : (pos_bound256_cxq a 0 256 1 /\ lift_array256 a = _a /\ Glob.mem = mem); 1:by call (poly_csubq_corr_h _a); auto => /> /#.
cfold 7.
Expand Down Expand Up @@ -396,10 +409,13 @@ proc rewrite 54 sliceset_256_128_8E.
proc rewrite 78 sliceset_256_128_8E.
proc rewrite 102 sliceset_256_128_8E.
cfold 7.
conseq />.
conseq />.
exists * a ;elim * =>_aw.
conseq (: _aw = a /\ pos_bound256_cxq _aw 0 256 1 ==>
rp = encode4 (map (compress_alt 4) (lift_array256 _aw)));1,2: smt().
rp = sem_encode4 (map (compress_alt 4) (lift_array256 _aw))); 1:smt().
+ move => &hr [#] *.
rewrite -sem_encode4_corr /compress_poly.
by have -> /# : (compress 4) = (compress_alt 4); 1: by rewrite fun_ext => *;smt(compress_alt_compress).

bdep 16 4 [ "_aw" ] [ "a" ] [ "rp" ] lane_func_compress pre16_compress.

Expand All @@ -409,33 +425,95 @@ bdep 16 4 [ "_aw" ] [ "a" ] [ "rp" ] lane_func_compress pre16_compress.
pre16_compress _ _);
[ by move => x; rewrite /pre16_compress sleE sltE /= /to_sint /smod /= qE //= | by smt() ].

move => &hr WRONG_POST.
move => &hr [#] H H0 newrp0 WRONG_POST.
have PRE :
((_aw = a{hr} /\ pos_bound256_cxq _aw 0 256 1) /\
map lane_func_compress (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _aw))]))) =
map W4.bits2w (chunk 4 (flatten [flatten (map W8.w2bits (to_list rp{hr}))]))).
map lane_func_compress (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _aw))]))) =
map W4.bits2w (chunk 4 (flatten [flatten (map W8.w2bits (to_list newrp0))])).
apply WRONG_POST.
clear WRONG_POST.
have : forall k, nth witness (map lane_func_compress (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _aw))])))) k =
nth witness (map W4.bits2w (chunk 4 (flatten [flatten (map W8.w2bits (to_list rp{hr}))]))) k by smt().
move => Hk;rewrite /lift_array256 /pos_bound256_cxq tP => i ib /=.
move : (Hk i).
rewrite /sem_encode4 /=.
have HH : forall k, nth witness (map lane_func_compress (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _aw))])))) k =
nth witness (map W4.bits2w (chunk 4 (flatten [flatten (map W8.w2bits (to_list newrp0))]))) k by smt().
rewrite Array128.tP => k kb; rewrite get_of_list // /BitsToBytes /encode.
move : (HH (2*k)).
rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size).
rewrite get_vs_bits_256u16 1:/#.
rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size).
rewrite flatten1 (nth_map witness) /=.
+ rewrite size_iota size_flatten -map_comp /(\o) /=.
have -> : map (fun (_ : W8.t) => 8) (to_list newrp0) =
mkseq (fun _ => 8) 128; last by
rewrite /mkseq -iotaredE /= /sumz /= /#.
apply (eq_from_nth witness).
+ by rewrite size_map Array128.size_to_list size_mkseq /#.
move => i.
rewrite size_map Array128.size_to_list => ib.
rewrite (nth_map witness) /=;1: by smt(Array128.size_to_list).
by rewrite (nth_mkseq) /#.
rewrite nth_iota /=. admit.
have -> : 4 * (2 * k) = 8 * k by ring.
rewrite drop_flatten_ctt. admit.
rewrite -map_drop.
have /= <- := take_take ((flatten (map W8.w2bits (drop k (to_list newrp0))))) 4 8.
have -> : (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0))))) =
W8.w2bits newrp0.[k]. admit.
move => HH2k.

move : (HH (2*k+1)).
rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size).
rewrite get_vs_bits_256u16 1:/#.
rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size).
rewrite flatten1 (nth_map witness) /=.
+ rewrite size_iota size_flatten -map_comp /(\o) /=.
have -> : map (fun (_ : W8.t) => 8) (to_list newrp0) =
mkseq (fun _ => 8) 128; last by
rewrite /mkseq -iotaredE /= /sumz /= /#.
apply (eq_from_nth witness).
+ by rewrite size_map Array128.size_to_list size_mkseq /#.
move => i.
rewrite size_map Array128.size_to_list => ib.
rewrite (nth_map witness) /=;1: by smt(Array128.size_to_list).
by rewrite (nth_mkseq) /#.
rewrite nth_iota /=. admit.
have -> : 4 * (2 * k + 1) = 4 + 8*k by ring.
rewrite -drop_drop 1,2:/#.
rewrite drop_flatten_ctt. admit.
rewrite -map_drop.
have -> : (drop 4 (flatten (map W8.w2bits (drop k (to_list newrp0))))) =
(drop 4 (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0)))))).
+ apply (eq_from_nth witness). admit. admit.
have -> : (take 8 (flatten (map W8.w2bits (drop k (to_list newrp0))))) =
W8.w2bits newrp0.[k]. admit.
rewrite {1}(:4 = size (drop 4 (w2bits newrp0.[k]))). admit.
rewrite take_size.
move => HH2k1.

rewrite (nth_map witness).
+ rewrite size_chunk // size_flatten // -map_comp /(\o).
rewrite (eq_map _ (fun _ => 4)) => //=;1: smt(BitEncoding.BS2Int.size_int2bs).
have -> : map (fun (_ : int) => 4) (to_list (map (compress_alt 4) (lift_array256 _aw))) =
mkseq (fun _ => 4) 256; last
by rewrite /mkseq -iotaredE /= /sumz /=.
apply (eq_from_nth witness).
+ by rewrite size_map Array256.size_to_list size_mkseq /#.
move => i.
rewrite size_map Array256.size_to_list => ib.
rewrite (nth_map witness) /=;1: by smt(Array256.size_to_list).
by rewrite (nth_mkseq) /#.
rewrite (nth_map witness) /=.
+ rewrite size_iota size_flatten -map_comp /(\o) /=.
have -> : (map (fun (x : int) => size ((BitEncoding.BS2Int.int2bs 4 x)))
(to_list (map (compress_alt 4) (lift_array256 _aw)))) =
mkseq (fun _ => 4) 256; last by
rewrite /mkseq -iotaredE /= /sumz /= /#.
apply (eq_from_nth witness).
+ by rewrite size_map Array256.size_to_list size_mkseq /#.
move => i.
rewrite size_map Array256.size_to_list => ib.
rewrite (nth_map witness) /=;1: by smt(Array256.size_to_list).
by rewrite (nth_mkseq);smt(BitEncoding.BS2Int.size_int2bs).
rewrite nth_iota /=. admit.
admit.
(*
rewrite get_vs_bits_256u16 //=.
rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size).
rewrite get_vs_bits_256u16 //=.
rewrite /lane_func_csubq sleE /= /smod /= -eq_incoeff.
case (3329 <= to_sint _rp0.[i]);last by smt().
by move => ? <-; rewrite to_sintB_small /= /smod /= /#.
move : (Hk i).
rewrite get_vs_bits_256u16 //=.
rewrite (nth_map witness); 1: by smt(size_map get_vs_bits_256u16_size).
rewrite get_vs_bits_256u16 //=.
rewrite /lane_func_csubq sleE /= /smod /= qE.
case (3329 <= to_sint _rp0.[i]);last by smt().
by move => ? <-; rewrite to_sintB_small /= /smod /= /#.
*)
qed.


Expand Down

0 comments on commit 5c518b2

Please sign in to comment.