diff --git a/proof/correctness/avx2/MLKEM_genmatrix_avx2.ec b/proof/correctness/avx2/MLKEM_genmatrix_avx2.ec index 42c6e5b5..0368a8e8 100644 --- a/proof/correctness/avx2/MLKEM_genmatrix_avx2.ec +++ b/proof/correctness/avx2/MLKEM_genmatrix_avx2.ec @@ -16,7 +16,6 @@ import MLKEM_PolyAVXVec. import WArray136 WArray32 WArray128. import WArray512 WArray256. - (********* MOVED HERE TO AVOID CIRCULAR DEPS ************) (* move somewhere else *) @@ -293,18 +292,18 @@ case (trans{1});last first. inline XOF.init; sp;wp. exlim rho{2}, i0{2}, j0{2} => _rho _i _j. call {2} (parse_sem (SHAKE128_ABSORB_34 _rho (W8.of_int _i) (W8.of_int _j)) _rho (W8.of_int _i) (W8.of_int _j)). - auto => /> &1 a1 ?????H H0?;do split;1,2: smt(). + auto => /> &1 &2 a1 ?????H H0?;do split;1,2: smt(). + move => kk jj ????. move : H H0;rewrite !setmE !getmE /= => H H0. rewrite !offunmE /=; 1,2:smt(). - case (kk = i{1} /\ jj = j{1}); 1: by smt(). - move => ?;case (kk < i{1}); + case (kk = i{2} /\ jj = j{2}); 1: by smt(). + move => ?;case (kk < i{2}); 1: by move => ?;move : (H kk jj _ _); smt(). move => ?;move : (H0 jj _); smt(). + move => kk ??. move : H H0;rewrite !setmE !getmE /= => H H0. rewrite !offunmE /=; 1,2:smt(). - case (kk = j{1}); 1: by smt(). + case (kk = j{2}); 1: by smt(). by move => ?;move : (H0 kk _); smt(). rcondt{1} 1; 1: by auto. inline {1} 1;wp. @@ -320,18 +319,18 @@ while (0<=i{1} _rho _i _j. call {2} (parse_sem (SHAKE128_ABSORB_34 _rho (W8.of_int _i) (W8.of_int _j)) _rho (W8.of_int _i) (W8.of_int _j)). - auto => /> &1 a1 ?????H H0?;do split;1,2: smt(). + auto => /> &1 &2 a1 ?????H H0?;do split;1,2: smt(). + move => kk jj ????. move : H H0;rewrite !setmE !getmE /= => H H0. rewrite !offunmE /=; 1,2:smt(). - case (kk = i{1} /\ jj = j{1}); 1: by smt(). - move => ?;case (kk < i{1}); + case (kk = i{2} /\ jj = j{2}); 1: by smt(). + move => ?;case (kk < i{2}); 1: by move => ?;move : (H kk jj _ _); smt(). move => ?;move : (H0 jj _); smt(). + move => kk ??. move : H H0;rewrite !setmE !getmE /= => H H0. rewrite !offunmE /=; 1,2:smt(). - case (kk = j{1}); 1: by smt(). + case (kk = j{2}); 1: by smt(). by move => ?;move : (H0 kk _); smt(). qed. @@ -606,6 +605,72 @@ proof. by conseq buf_rejection_filter24_ll (buf_rejection_filter24_h _pol _ctr _buf _buf_offset). qed. +lemma buf_subl0 buf s e: + (e <= s)%Int => + buf_subl buf s e = []. +proof. by rewrite /buf_subl /#. qed. + +lemma size_plist ['a] (pol : 'a Array256.t) n: + 0 <= n <= 256 => + size (plist pol n) = n. +proof. by move=> Hn; rewrite size_mkseq /#. qed. + +lemma size_buf_subl buf bstart bend: + 0 <= bstart <= bend <= 8*68 => + size (buf_subl buf bstart bend) = bend - bstart. +proof. +move=> H; rewrite size_take 1:/# size_drop 1:/#. +by rewrite size_w64L_to_bytes size_to_list /#. +qed. + +lemma take_rejection16_done n buf buf_o bo: + 0 <= buf_o <= bo <= 504 => + 3 %| buf_o => + 3 %| bo => + size (take n (rejection16 (buf_subl buf buf_o bo))) = n => + take n (rejection16 (buf_subl buf buf_o bo)) + = take n (rejection16 (buf_subl buf buf_o 504)). +proof. +move=> H H1 H2 /size_takel' [Hsz1]. +rewrite size_map => Hsz2. +rewrite -(buf_subl_cat _ _ bo 504) 1:/# /rejection16 rejection_cat. + by rewrite size_buf_subl /#. +rewrite -map_take eq_sym -map_take; congr. +by rewrite take_cat' ifT. +qed. + +lemma pack32_sample_load_shuffle: + pack32 (to_list sample_load_shuffle) + = get256_direct ((init8 ("_.[_]" sample_load_shuffle)))%WArray32 0. +proof. +rewrite get256E; congr. +apply W32u8.Pack.all_eqP. +by rewrite of_listK 1:/# /all_eq /= !initiE /#. +qed. + +lemma size_rejection16_le buf bstart bend1 bend2: + 3 %| bstart => + 3 %| bend1 => + 3 %| bend2 => + 0 <= bstart <= bend1 <= bend2 <= 8*68 => + size (rejection16 (buf_subl buf bstart bend1)) + <= size (rejection16 (buf_subl buf bstart bend2)). +proof. +move=> H1 H2 H3 H. +rewrite /rejection16 !size_map. +rewrite -(buf_subl_cat _ _ bend1 bend2) 1:/#. +rewrite bytes2coeffs_cat. + by rewrite size_buf_subl /#. +by rewrite filter_cat size_cat; smt(size_ge0). +qed. + +lemma rejection16_cat l1 l2: + 3 %| size l1 => + rejection16 (l1++l2) = rejection16 l1 ++ rejection16 l2. +proof. +by move=> H; rewrite /rejection16 rejection_cat // map_cat. +qed. + hoare gen_matrix_buf_rejection_h _pol _ctr _buf _buf_offset: Jkem_avx2.M(Jkem_avx2.Syscall)._gen_matrix_buf_rejection : counter = _ctr @@ -621,9 +686,9 @@ hoare gen_matrix_buf_rejection_h _pol _ctr _buf _buf_offset: /\ res.`2 = W64.of_int (to_uint _ctr + size l). proof. proc; simplify. -while ( buf=_buf /\ - 0 <= to_uint buf_offset <= 504 /\ - 0 <= to_uint counter <= 256 /\ +while ( buf=_buf /\ 24 %| to_uint buf_offset /\ 3 %| to_uint _buf_offset /\ + 0 <= to_uint _buf_offset <= to_uint buf_offset <= 504 /\ + 0 <= to_uint _ctr <= to_uint counter <= 256 /\ auxdata_ok load_shuffle mask bounds ones sst /\ (plist pol (to_uint counter) = plist _pol (to_uint _ctr) @@ -634,56 +699,96 @@ while ( buf=_buf /\ /\ to_uint buf_offset <= 504-24))). ecall (conditionloop_h buf_offset (3 * 168 - 24) counter 256); simplify. wp; ecall (buf_rejection_filter24_h pol counter buf buf_offset). - auto => &m /> Ho1 Ho2 Hctr1 Hctr2 H Hcond1 Hcond2 [p c' o'] /= />. - rewrite !of_uintK => H1. + auto => &m /> Hdvd1 Hdvd2 Ho1 Ho2 Ho3 Hctr1 Hctr2 Hctr3 H Hcond1 Hcond2. + have Hsz: to_uint counter{m} = to_uint _ctr + min (256-to_uint _ctr) (size (rejection16 (buf_subl buf{m} (to_uint _buf_offset) (to_uint buf_offset{m})))). + rewrite -(size_plist pol{m} (to_uint counter{m})) 1:/# H size_cat size_plist 1:/#; congr. + by rewrite size_take_min /#. + move: H; rewrite take_oversize 1:/# => H [p c' ms'] /= />. + rewrite !of_uintK => Hstep. + rewrite !to_uintD_small 1:/# !of_uintK; split; first smt(). split. - rewrite to_uintD_small 1:/#. - by rewrite !of_uintK !modz_small //= /#. + split; first by rewrite !modz_small //= /#. + by move=> *; rewrite !modz_small //= /#. split. - split => *; first smt(size_ge0). - admit (* size filter... *). - admit (* H *). + by rewrite size_take_min 1:/# modz_small; smt(size_ge0). + rewrite modz_small; first smt(size_ge0 size_take_min). + rewrite modz_small 1:/#. + rewrite Hstep H -catA; congr. + rewrite -(buf_subl_cat _ (to_uint _buf_offset) (to_uint buf_offset{m}) (to_uint buf_offset{m} + 24)) 1:/#. + rewrite rejection16_cat. + by rewrite size_buf_subl /#. + by rewrite take_catr 1:/#; congr; congr; smt(). ecall (conditionloop_h buf_offset (3 * 168 - 24) counter 256). wp. -while ( buf=_buf /\ - 0 <= to_uint buf_offset <= 504 /\ - 0 <= to_uint counter <= 256 /\ +while ( buf=_buf /\ 24 %| to_uint buf_offset /\ 3 %| to_uint _buf_offset /\ + 0 <= to_uint _buf_offset <= to_uint buf_offset <= 504 /\ + 0 <= to_uint _ctr <= to_uint counter <= 256 /\ auxdata_ok load_shuffle mask bounds ones sst /\ - (plist pol (to_uint counter) - = plist _pol (to_uint _ctr) - ++ rejection16 (buf_subl _buf (to_uint _buf_offset) (to_uint buf_offset)) - ) /\ + plist pol (to_uint counter) + = plist _pol (to_uint _ctr) + ++ rejection16 (buf_subl _buf (to_uint _buf_offset) (to_uint buf_offset)) /\ + to_uint _ctr + size (rejection16 (buf_subl _buf (to_uint _buf_offset) (to_uint buf_offset))) <= 256 /\ (condition_loop <=> (to_uint counter <= 256-32 /\ to_uint buf_offset <= 504-48))). ecall (conditionloop_h buf_offset (3 * 168 - 48) counter (256-32+1)); simplify. wp; ecall (buf_rejection_filter48_h pol counter buf buf_offset). - auto => &m /> Ho1 Ho2 Hctr1 Hctr2 H Hcond1 Hcond2 [p c' o'] /= />. - rewrite !of_uintK. - split. - rewrite to_uintD_small 1:/#. - by rewrite !of_uintK !modz_small //= /#. + auto => &m /> Hdvd1 Hdvd2 Ho1 Ho2 Ho3 Hctr1 Hctr2 Hctr3 H Hbo Hcond1 Hcond2 [p c'] /= /> Hstep. + rewrite !to_uintD_small 1:/# !of_uintK; split; first smt(). split. - split => *; first smt(size_ge0). - admit (* size filter... *). - split; last smt(). - admit (* H *). + by rewrite !modz_small //= /#. + pose R:= rejection16 _. + have ?: 0 <= size R <= 32. + rewrite /rejection16 size_map; split; first smt(size_ge0). + move=> _; apply (size_rejection_le' 48) => //=. + by rewrite /buf_subl !size_take 1:/# !size_drop /#. + rewrite !modz_small 1..2:/#. + split; first smt(). + rewrite -andaE; split. + rewrite -(buf_subl_cat _ (to_uint _buf_offset) (to_uint buf_offset{m}) (to_uint buf_offset{m} + 48)) 1:/#. + rewrite Hstep H -catA; congr. + rewrite rejection16_cat 2://. + by rewrite size_buf_subl /#. + move => HH. + have : size (plist p (to_uint counter{m} + size R)) <= 256. + by rewrite size_plist /#. + by rewrite HH size_cat size_plist /#. ecall (conditionloop_h buf_offset (3 * 168 - 48) counter (256-32+1)); simplify. -auto => &m /> Hctr1 Hctr2 Hbo; split. +wp; skip => &m /> Hctr1 Hctr2 Hbo; split. + split; first smt(). + split; first smt(). split; first smt(). split; first smt(). split. - admit (* get256... *). + by rewrite pack32_sample_load_shuffle. split. - admit (* buf_subl0 *). - smt(). -move => buf_o cond ctr pol Hcond Hbo1 Hbo2 Hctr3 Hctr4 Hok H Hterm. -split. - admit (* *). -move=> buf_o2 cond2 ctr2 pol2 HC2 Hbo3 Hbo4 Hctr5 Hctr6 HH HHterm. -split. - admit. -admit. + by rewrite buf_subl0 1:/# /rejection16 rejection0 cats0. + split; last smt(). + by rewrite buf_subl0 1:/# /rejection16 rejection0 /#. +move => buf_o cond ctr pol Hcond Hdvd1 Hdvd2 Hbo1 Hbo2 Hbo3 Hctr3 Hctr4 _ H Hsz Hterm; split. +by rewrite take_oversize /#. +move=> buf_o2 cond2 ctr2 pol2 HC2 Hdvd3 Hbo4 Hbo5 Hctr5 Hctr6 HH. +case: (to_uint ctr2 < 256) => /= C1. + move=> C2; move: HH; have ->: to_uint buf_o2 = 504 by smt(). + move => HH; rewrite andbC -andaE to_uint_eq of_uintK modz_small. + split=> *; first smt(size_ge0). + by rewrite size_take /#. + split. + rewrite -(size_plist pol2 (to_uint ctr2)) 1:/#. + by rewrite HH size_cat size_plist 1:/#. + by move => E; move: HH; rewrite E => ->. +have E: to_uint ctr2 = 256 by smt(). +rewrite to_uint_eq andbC -andaE. +have HHsz: 256 = to_uint counter{m} + min (256 - to_uint counter{m}) + (size (rejection16 (buf_subl buf{m} (to_uint buf_offset{m}) (to_uint buf_o2)))). + rewrite -(size_plist pol2 256) 1:/#. + by rewrite -E HH size_cat size_plist 1:/# size_take_min /#. +rewrite size_take_min 1:/#. +rewrite of_uintK modz_small; first smt(size_ge0). +split; first smt(size_rejection16_le). +move => <-; rewrite HH; congr. +apply take_rejection16_done; first 3 smt(). +by rewrite size_take_min /#. qed. lemma gen_matrix_buf_rejection_ll: @@ -746,6 +851,9 @@ proof. by conseq gen_matrix_buf_rejection_ll (gen_matrix_buf_rejection_h _pol _ctr _buf _buf_offset). qed. +abbrev coeff2u16 x = W16.of_int (Zq.asint x). +abbrev coeffL2u16L = List.map coeff2u16. + equiv parse_one_polynomial_eq: Jkem_avx2.M(Jkem_avx2.Syscall).__gen_matrix_sample_one_polynomial ~ ParseFilter.sample @@ -764,9 +872,29 @@ transitivity ParseFilter.sample3buf smt(). + by move=> />. proc; simplify. +seq 4 7: ( buf_subl buf{1} 0 (2*168+200) = buf{2} + /\ stavx2bytes stavx2{1} = buf_subl buf{1} (2*168) (2*168+200) + ). + (* squeeze 3 blocks *) + admit. +while ( to_uint counter{1}=c{2} + /\ stavx2bytes stavx2{1} = buf_subl buf{1} (2*168) (2*168+200) + /\ plist pol{1} c{2} = coeffL2u16L p{2} + ). + admit. +wp. +ecall {1} (gen_matrix_buf_rejection_ph pol{1} counter{1} buf{1} buf_offset{1}); simplify. +auto => /> &m Hst [pol ctr] /= Hpol Hctr; split. + split. + split. + admit. + admit. + admit. +move=> buf c p st p'. admit. qed. + phoare sample_last _rho : [ Jkem_avx2.M(Jkem_avx2.Syscall).__gen_matrix_sample_one_polynomial : rho = _rho /\ rc = W16.of_int (2*256+2) ==>