From 2ed13c827530de6d8fe43cba2777ac2a88dc66cc Mon Sep 17 00:00:00 2001 From: Pierre-Yves Strub Date: Sat, 20 Jul 2024 21:54:04 +0200 Subject: [PATCH] remove [nosmt] tags + fix proofs --- crypto-specs | 2 +- proof/correctness/MLKEMFCLib.ec | 2 +- proof/correctness/MLKEM_Poly.ec | 5 +- proof/correctness/Montgomery.ec | 10 ++-- proof/correctness/Montgomery16.ec | 58 +++++++++---------- proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec | 2 +- proof/correctness/avx2/NTT_AVX_j.ec | 8 ++- 7 files changed, 46 insertions(+), 41 deletions(-) diff --git a/crypto-specs b/crypto-specs index a82b8d51..3710c5ff 160000 --- a/crypto-specs +++ b/crypto-specs @@ -1 +1 @@ -Subproject commit a82b8d510b34d0939543ab340b9b0e3907c14f51 +Subproject commit 3710c5ff92a84a3a150c70a89d8d2fd3486bd735 diff --git a/proof/correctness/MLKEMFCLib.ec b/proof/correctness/MLKEMFCLib.ec index 772028e6..7d5a1b9a 100644 --- a/proof/correctness/MLKEMFCLib.ec +++ b/proof/correctness/MLKEMFCLib.ec @@ -221,7 +221,7 @@ lemma initEq16 (f g: int -> 'a) : (*-----------------------------------------------------------------------------*) -lemma nosmt set_neqiE (t : coeff Array256.t) x y a : +lemma set_neqiE (t : coeff Array256.t) x y a : y <> x => t.[x <- a].[y] = t.[y]. proof. by rewrite get_set_if => /neqF ->. qed. diff --git a/proof/correctness/MLKEM_Poly.ec b/proof/correctness/MLKEM_Poly.ec index 084bf15f..b6942a76 100644 --- a/proof/correctness/MLKEM_Poly.ec +++ b/proof/correctness/MLKEM_Poly.ec @@ -1733,7 +1733,8 @@ have /= [#] redbl6 redbh6 redv6 := have /= [#] redbl7 redbh7 redv7 := (SREDCp_corr (to_sint r6 * to_sint (- jzetas.[64 + to_uint i{hr} %/ 4])) hq _). + rewrite /R /=; move : (zeta_bound (64 + to_uint i{hr} %/ 4)); rewrite /minimum_residues /bpos16 => zb. - rewrite to_sintN /=; do split; smt(). + case/(_ _): zb => *; ~-1:smt(). + by rewrite to_sintN /=; do split; smt(). have /= [#] redbl8 redbh8 redv8 := (SREDCp_corr (to_sint ap{hr}.[to_uint i{hr}+2] * to_sint bp{hr}.[to_uint i{hr}+2]) hq _). @@ -1801,7 +1802,7 @@ case (k < to_uint i{hr} %/4). + move => kbb; move: (vprev k _); 1:smt(); rewrite !mapiE /=; 1..12:smt(). rewrite /doublemul /cmplx_mul_169 /=. move => /> vprev0 vprev1 vprev2 vprev3. - by rewrite !set_neqiE /#. + by rewrite !set_neqiE //#. move => *; have kval : (k = to_uint i{hr} %/ 4) by smt(). have -> : 4 * k = to_uint i{hr} by smt(). diff --git a/proof/correctness/Montgomery.ec b/proof/correctness/Montgomery.ec index a8e40220..ffdf369b 100644 --- a/proof/correctness/Montgomery.ec +++ b/proof/correctness/Montgomery.ec @@ -164,7 +164,7 @@ op BREDC(a bits : int) = require import Barrett_mlkem_general. -lemma nosmt BREDCp_corr a bits: +lemma BREDCp_corr a bits: 0 < 2 * q < R %/2 => R < 2^bits => 2 ^ bits %/ q * q < 2 ^ bits => @@ -222,7 +222,7 @@ op SREDC (a: int) : int = let t = smod (a - u %/ R * q) (R^2)in smod (t %/ R %% (R^2)) R. -lemma nosmt SREDCp_corr a: +lemma SREDCp_corr a: 0 < q < R %/2 => -R %/ 2 * q <= a < R %/2 * q => -q <= SREDC a < q /\ @@ -353,7 +353,7 @@ op REDC' (T: int) : int = let m = ((T %% R)*_N') %% R in (T + m*_N) %/ R. -lemma nosmt aux_divR T: +lemma aux_divR T: let m = ((T %% R)*_N') %% R in (T + m*_N) %% R = 0. proof. @@ -364,7 +364,7 @@ smt(). qed. -lemma nosmt REDC'_congr T: +lemma REDC'_congr T: REDC' T %% _N = T * Rinv %% _N. proof. pose m := ((T %% R)*_N') %% R. @@ -378,7 +378,7 @@ have t_modN: t %% _N = T*Rinv %% _N. by rewrite /REDC'. qed. -lemma nosmt REDC'_bnds T n: +lemma REDC'_bnds T n: 0 <= n => 0 <= T < _N + _N * R^(n+1) => 0 <= REDC' T < _N + _N*R^n. diff --git a/proof/correctness/Montgomery16.ec b/proof/correctness/Montgomery16.ec index 55c58f4d..f1ca7ec9 100644 --- a/proof/correctness/Montgomery16.ec +++ b/proof/correctness/Montgomery16.ec @@ -6,7 +6,7 @@ require import MLKEMFCLib. (* @JBA: MOVE THIS *) -lemma nosmt modzB_eq0 (x y m:int): +lemma modzB_eq0 (x y m:int): 0 < m => (x-y) %% m = 0 => x%%m = y%%m. proof. @@ -53,7 +53,7 @@ case: ( 2 ^ (16 - 1) <= x) => ?//. by rewrite -modzDmr -modzNm modzz. qed. -lemma nosmt to_sint_mod x: +lemma to_sint_mod x: W16.to_sint x %% W16.modulus = to_uint x. proof. rewrite /to_sint /smod. @@ -69,7 +69,7 @@ lemma to_sintK (w : W16.t) : W16.of_int (W16.to_sint w) = w. proof. by rewrite -of_int_mod to_sint_mod to_uintK //. qed. -lemma nosmt to_sint_eq (w1 w2: W16.t): +lemma to_sint_eq (w1 w2: W16.t): to_sint w1 = to_sint w2 <=> w1=w2. proof. rewrite !to_sintE /smod /=. @@ -82,7 +82,7 @@ case: (32768 <= to_uint w2) => CB /=. by rewrite to_uint_eq. qed. -lemma nosmt smod_small (x: int): +lemma smod_small (x: int): -2^(16-1) <= x < 2^(16-1) => W16.smod (x %% W16.modulus) = x. proof. @@ -92,7 +92,7 @@ rewrite /smod; case: (x < 0) => C. by have ->/#: ! 2 ^ (16 - 1) <= x %% W16.modulus by smt(). qed. -lemma nosmt modzM_sint (x y: W16.t): +lemma modzM_sint (x y: W16.t): (to_sint x * to_sint y) %% W16.modulus = (to_uint x * to_uint y) %% W16.modulus. proof. @@ -110,14 +110,14 @@ by rewrite modzMm. done. qed. -lemma nosmt to_sintM (x y: W16.t): +lemma to_sintM (x y: W16.t): to_sint (x*y) = W16.smod ((to_sint x * to_sint y) %% W16.modulus). proof. by rewrite {1}/W16.to_sint to_uintM modzM_sint. qed. -lemma nosmt modzD_sint (x y: W16.t): +lemma modzD_sint (x y: W16.t): (to_sint x + to_sint y) %% W16.modulus = (to_uint x + to_uint y) %% W16.modulus. proof. @@ -126,14 +126,14 @@ case: (2 ^ (16 - 1) <= to_uint x); case: (2 ^ (16 - 1) <= to_uint y); smt(). qed. -lemma nosmt to_sintD (x y: W16.t): +lemma to_sintD (x y: W16.t): to_sint (x+y) = W16.smod ((to_sint x + to_sint y)%%W16.modulus). proof. by rewrite {1}/W16.to_sint to_uintD modzD_sint. qed. -lemma nosmt modzN_sint (x: W16.t): +lemma modzN_sint (x: W16.t): (- to_sint x) %% W16.modulus = (- to_uint x) %% W16.modulus. proof. @@ -141,7 +141,7 @@ rewrite /to_sint /smod. case: (2 ^ (16 - 1) <= to_uint x); smt(). qed. -lemma nosmt to_sintN (x: W16.t): +lemma to_sintN (x: W16.t): to_sint (-x) = W16.smod ((-to_sint x) %% W16.modulus). proof. @@ -164,7 +164,7 @@ have X: 32768 <= (- to_uint x) %% 65536 by smt(). smt(). qed. -lemma nosmt modzB_sint (x y: W16.t): +lemma modzB_sint (x y: W16.t): (to_sint x - to_sint y) %% W16.modulus = (to_uint x - to_uint y) %% W16.modulus. proof. @@ -173,14 +173,14 @@ case: (2 ^ (16 - 1) <= to_uint x); case: (2 ^ (16 - 1) <= to_uint y); smt(). qed. -lemma nosmt to_sintB (x y: W16.t): +lemma to_sintB (x y: W16.t): to_sint (x-y) = W16.smod ((to_sint x - to_sint y)%%W16.modulus). proof. rewrite {1}/W16.to_sint to_uintD to_uintN modzB_sint /#. qed. -lemma nosmt wmulsE (x y: W16.t): +lemma wmulsE (x y: W16.t): to_sint x * to_sint y = to_sint (wmulhs x y) * W16.modulus + to_uint (x * y). @@ -193,7 +193,7 @@ have /=?:= to_sint_cmp y. smt(). qed. -lemma nosmt to_sint_wmulhs x y: +lemma to_sint_wmulhs x y: to_sint (W16.wmulhs x y) = to_sint x * to_sint y %/ W16.modulus. proof. rewrite wmulsE divzDl 1:/# mulzK 1:/#; ring. @@ -221,7 +221,7 @@ lemma sint_bndW (x: W16.t) (xL1 xH1 xL2 xH2: int): sint_bnd xL2 xH2 x by smt(). -lemma nosmt to_sintD_small (x y: W16.t): +lemma to_sintD_small (x y: W16.t): W16.min_sint <= to_sint x + to_sint y <= W16.max_sint => to_sint (x+y) = to_sint x + to_sint y. proof. @@ -229,7 +229,7 @@ move=> /=?; rewrite to_sintD smod_small /= /#. qed. (* a version of [to_sintD_small] with bounds *) -lemma nosmt to_sintD_small' (xL xH yL yH: int) (x y: W16.t): +lemma to_sintD_small' (xL xH yL yH: int) (x y: W16.t): sint_bnd xL xH x => sint_bnd yL yH y => W16.min_sint <= xL+yL => @@ -239,7 +239,7 @@ proof. by move=> /= *; rewrite to_sintD_small /#. qed. -lemma nosmt to_sintN_small (x: W16.t): +lemma to_sintN_small (x: W16.t): W16.min_sint < to_sint x => to_sint (-x) = - to_sint x. proof. @@ -384,7 +384,7 @@ qed. abbrev sint32_bnd xL xH x = xL <= W32.to_sint x <= xH. -lemma nosmt modzD_sint32 (x y: W32.t): +lemma modzD_sint32 (x y: W32.t): (to_sint x + to_sint y) %% W32.modulus = (to_uint x + to_uint y) %% W32.modulus. proof. @@ -393,14 +393,14 @@ case: (2 ^ (32 - 1) <= to_uint x); case: (2 ^ (32 - 1) <= to_uint y); smt(). qed. -lemma nosmt to_sint32D (x y: W32.t): +lemma to_sint32D (x y: W32.t): to_sint (x+y) = W32.smod ((to_sint x + to_sint y)%%W32.modulus). proof. by rewrite {1}/W32.to_sint to_uintD modzD_sint32. qed. -lemma nosmt smod32_small (x: int): +lemma smod32_small (x: int): -2^(32-1) <= x < 2^(32-1) => W32.smod (x %% W32.modulus) = x. proof. @@ -410,7 +410,7 @@ rewrite /smod; case: (x < 0) => C. by have ->/#: ! 2 ^ (32 - 1) <= x %% W32.modulus by smt(). qed. -lemma nosmt to_sint32D_small (x y: W32.t): +lemma to_sint32D_small (x y: W32.t): W32.min_sint <= to_sint x + to_sint y <= W32.max_sint => to_sint (x+y) = to_sint x + to_sint y. proof. @@ -427,7 +427,7 @@ proof. by move=> /> *; rewrite to_sint32D_small /#. qed. (* -lemma nosmt modzN_sint32 (x: W32.t): +lemma modzN_sint32 (x: W32.t): (- to_sint x) %% W32.modulus = (- to_uint x) %% W32.modulus. proof. @@ -435,14 +435,14 @@ rewrite /to_sint /smod. case: (2 ^ (32 - 1) <= to_uint x); smt(). qed. -lemma nosmt to_sint32N (x: W32.t): +lemma to_sint32N (x: W32.t): to_sint (-x) = W32.smod ((-to_sint x) %% W32.modulus). proof. by rewrite {1}/W32.to_sint to_uintN modzN_sint32. qed. *) -lemma nosmt modzB_sint32 (x y: W32.t): +lemma modzB_sint32 (x y: W32.t): (to_sint x - to_sint y) %% W32.modulus = (to_uint x - to_uint y) %% W32.modulus. proof. @@ -451,7 +451,7 @@ case: (2 ^ (32 - 1) <= to_uint x); case: (2 ^ (32 - 1) <= to_uint y); smt(). qed. -lemma nosmt to_sint32B (x y: W32.t): +lemma to_sint32B (x y: W32.t): to_sint (x-y) = W32.smod ((to_sint x - to_sint y)%%W32.modulus). proof. @@ -512,7 +512,7 @@ rewrite -{1}unpack16K /unpack16 /=; congr. by rewrite init_of_list -JUtils.iotaredE /=. qed. -lemma nosmt modz_sint32 (x: W32.t): +lemma modz_sint32 (x: W32.t): (to_sint x) %% W16.modulus = (to_uint x) %% W16.modulus. proof. @@ -567,7 +567,7 @@ op REDC16 (xyL xyH: W16.t): W16.t = in xyH - (wmulhs m (W16.of_int q)). (* general bounds... *) -lemma nosmt REDC16_correct bL bR (xyL xyH: W16.t): +lemma REDC16_correct bL bR (xyL xyH: W16.t): W16.min_sint + 1664 <= bL <= 0 => 0 <= bR <= W16.max_sint - 1665 => sint_bnd bL bR xyH => @@ -608,7 +608,7 @@ by rewrite eq_sym -modzDmr -Domain.mulNr -modzMm modzz mod0z. qed. (* useful specific case *) -lemma nosmt REDC16_correct_q (xyL xyH: W16.t): +lemma REDC16_correct_q (xyL xyH: W16.t): sint_bnd (-q%/2) (q%/2 - 1) xyH => to_sint (REDC16 xyL xyH) %% q = (to_sint xyH * R + to_uint xyL) * Rinv %% q @@ -625,7 +625,7 @@ abbrev REDCmul16 (x y: W16.t): W16.t = REDC16 (x*y) (wmulhs x y). (* correctness result for multiplication, for the specific case of a reduced argument *) -lemma nosmt REDCmul16_correct (x y: W16.t): +lemma REDCmul16_correct (x y: W16.t): sint_bnd 0 (q-1) y => to_sint (REDCmul16 x y) %% q = to_sint x * to_sint y * Rinv %% q diff --git a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec index 9eaad408..91e3775a 100644 --- a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec +++ b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec @@ -841,7 +841,7 @@ move: (noise_coef_avx2_aux bytes j) => /=. by rewrite /noise_coef C' to_sintE /smod to_uint_shr //= => <- /#. qed. -lemma nosmt to_sint8_mod x: +lemma to_sint8_mod x: W8.to_sint x %% W8.modulus = to_uint x. proof. rewrite /to_sint /smod. diff --git a/proof/correctness/avx2/NTT_AVX_j.ec b/proof/correctness/avx2/NTT_AVX_j.ec index d7528c9f..6ff926f5 100644 --- a/proof/correctness/avx2/NTT_AVX_j.ec +++ b/proof/correctness/avx2/NTT_AVX_j.ec @@ -1068,7 +1068,11 @@ lemma wmuls16P n x y _x _y: Iu16_sb n x _x => Iu16_sb n y _y => sint32_bnd (-n*n*q*q) (n*n*q*q) (wmuls16 x y). -proof. by move => [??] [??]; rewrite to_sint_wmuls16 /#. qed. +proof. +move=> [? [??]] [? [??]]; rewrite to_sint_wmuls16. +have ->: (n * n * q * q) = (n * q) * (n * q) by ring. +by rewrite &(ler_norml) normrM ler_pmul // 1,2:normr_ge0 /#. +qed. phoare wmul_16u16_ph n _x _y: [Jkem_avx2.M(Jkem_avx2.Syscall).__wmul_16u16: @@ -1506,7 +1510,7 @@ qed. (** Butterfly *) -lemma nosmt REDCmul16coeff (x y: W16.t): +lemma REDCmul16coeff (x y: W16.t): sint_bnd 0 (q-1) y => incoeffW16 (Montgomery16.REDCmul16 x y) = incoeffW16 x * incoeffW16 y * incoeff Montgomery16.Rinv