Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix proofs w.r.t. EC dev #45

Merged
merged 1 commit into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion proof/correctness/MLKEMFCLib.ec
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ lemma initEq16 (f g: int -> 'a) :
(*-----------------------------------------------------------------------------*)
lemma nosmt set_neqiE (t : coeff Array256.t) x y a :
lemma set_neqiE (t : coeff Array256.t) x y a :
y <> x => t.[x <- a].[y] = t.[y].
proof. by rewrite get_set_if => /neqF ->. qed.
Expand Down
5 changes: 3 additions & 2 deletions proof/correctness/MLKEM_Poly.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1733,7 +1733,8 @@ have /= [#] redbl6 redbh6 redv6 :=
have /= [#] redbl7 redbh7 redv7 :=
(SREDCp_corr (to_sint r6 * to_sint (- jzetas.[64 + to_uint i{hr} %/ 4])) hq _).
+ rewrite /R /=; move : (zeta_bound (64 + to_uint i{hr} %/ 4)); rewrite /minimum_residues /bpos16 => zb.
rewrite to_sintN /=; do split; smt().
case/(_ _): zb => *; ~-1:smt().
by rewrite to_sintN /=; do split; smt().

have /= [#] redbl8 redbh8 redv8 :=
(SREDCp_corr (to_sint ap{hr}.[to_uint i{hr}+2] * to_sint bp{hr}.[to_uint i{hr}+2]) hq _).
Expand Down Expand Up @@ -1801,7 +1802,7 @@ case (k < to_uint i{hr} %/4).
+ move => kbb; move: (vprev k _); 1:smt(); rewrite !mapiE /=; 1..12:smt().
rewrite /doublemul /cmplx_mul_169 /=.
move => /> vprev0 vprev1 vprev2 vprev3.
by rewrite !set_neqiE /#.
by rewrite !set_neqiE //#.

move => *; have kval : (k = to_uint i{hr} %/ 4) by smt().
have -> : 4 * k = to_uint i{hr} by smt().
Expand Down
10 changes: 5 additions & 5 deletions proof/correctness/Montgomery.ec
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ op BREDC(a bits : int) =

require import Barrett_mlkem_general.

lemma nosmt BREDCp_corr a bits:
lemma BREDCp_corr a bits:
0 < 2 * q < R %/2 =>
R < 2^bits =>
2 ^ bits %/ q * q < 2 ^ bits =>
Expand Down Expand Up @@ -222,7 +222,7 @@ op SREDC (a: int) : int =
let t = smod (a - u %/ R * q) (R^2)in
smod (t %/ R %% (R^2)) R.

lemma nosmt SREDCp_corr a:
lemma SREDCp_corr a:
0 < q < R %/2 =>
-R %/ 2 * q <= a < R %/2 * q =>
-q <= SREDC a < q /\
Expand Down Expand Up @@ -353,7 +353,7 @@ op REDC' (T: int) : int =
let m = ((T %% R)*_N') %% R
in (T + m*_N) %/ R.

lemma nosmt aux_divR T:
lemma aux_divR T:
let m = ((T %% R)*_N') %% R
in (T + m*_N) %% R = 0.
proof.
Expand All @@ -364,7 +364,7 @@ smt().
qed.


lemma nosmt REDC'_congr T:
lemma REDC'_congr T:
REDC' T %% _N = T * Rinv %% _N.
proof.
pose m := ((T %% R)*_N') %% R.
Expand All @@ -378,7 +378,7 @@ have t_modN: t %% _N = T*Rinv %% _N.
by rewrite /REDC'.
qed.

lemma nosmt REDC'_bnds T n:
lemma REDC'_bnds T n:
0 <= n =>
0 <= T < _N + _N * R^(n+1) =>
0 <= REDC' T < _N + _N*R^n.
Expand Down
58 changes: 29 additions & 29 deletions proof/correctness/Montgomery16.ec
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require import MLKEMFCLib.

(* @JBA: MOVE THIS *)

lemma nosmt modzB_eq0 (x y m:int):
lemma modzB_eq0 (x y m:int):
0 < m => (x-y) %% m = 0 =>
x%%m = y%%m.
proof.
Expand Down Expand Up @@ -53,7 +53,7 @@ case: ( 2 ^ (16 - 1) <= x) => ?//.
by rewrite -modzDmr -modzNm modzz.
qed.

lemma nosmt to_sint_mod x:
lemma to_sint_mod x:
W16.to_sint x %% W16.modulus = to_uint x.
proof.
rewrite /to_sint /smod.
Expand All @@ -69,7 +69,7 @@ lemma to_sintK (w : W16.t) :
W16.of_int (W16.to_sint w) = w.
proof. by rewrite -of_int_mod to_sint_mod to_uintK //. qed.

lemma nosmt to_sint_eq (w1 w2: W16.t):
lemma to_sint_eq (w1 w2: W16.t):
to_sint w1 = to_sint w2 <=> w1=w2.
proof.
rewrite !to_sintE /smod /=.
Expand All @@ -82,7 +82,7 @@ case: (32768 <= to_uint w2) => CB /=.
by rewrite to_uint_eq.
qed.

lemma nosmt smod_small (x: int):
lemma smod_small (x: int):
-2^(16-1) <= x < 2^(16-1) =>
W16.smod (x %% W16.modulus) = x.
proof.
Expand All @@ -92,7 +92,7 @@ rewrite /smod; case: (x < 0) => C.
by have ->/#: ! 2 ^ (16 - 1) <= x %% W16.modulus by smt().
qed.

lemma nosmt modzM_sint (x y: W16.t):
lemma modzM_sint (x y: W16.t):
(to_sint x * to_sint y) %% W16.modulus
= (to_uint x * to_uint y) %% W16.modulus.
proof.
Expand All @@ -110,14 +110,14 @@ by rewrite modzMm.
done.
qed.

lemma nosmt to_sintM (x y: W16.t):
lemma to_sintM (x y: W16.t):
to_sint (x*y)
= W16.smod ((to_sint x * to_sint y) %% W16.modulus).
proof.
by rewrite {1}/W16.to_sint to_uintM modzM_sint.
qed.

lemma nosmt modzD_sint (x y: W16.t):
lemma modzD_sint (x y: W16.t):
(to_sint x + to_sint y) %% W16.modulus
= (to_uint x + to_uint y) %% W16.modulus.
proof.
Expand All @@ -126,22 +126,22 @@ case: (2 ^ (16 - 1) <= to_uint x);
case: (2 ^ (16 - 1) <= to_uint y); smt().
qed.

lemma nosmt to_sintD (x y: W16.t):
lemma to_sintD (x y: W16.t):
to_sint (x+y)
= W16.smod ((to_sint x + to_sint y)%%W16.modulus).
proof.
by rewrite {1}/W16.to_sint to_uintD modzD_sint.
qed.

lemma nosmt modzN_sint (x: W16.t):
lemma modzN_sint (x: W16.t):
(- to_sint x) %% W16.modulus
= (- to_uint x) %% W16.modulus.
proof.
rewrite /to_sint /smod.
case: (2 ^ (16 - 1) <= to_uint x); smt().
qed.

lemma nosmt to_sintN (x: W16.t):
lemma to_sintN (x: W16.t):
to_sint (-x)
= W16.smod ((-to_sint x) %% W16.modulus).
proof.
Expand All @@ -164,7 +164,7 @@ have X: 32768 <= (- to_uint x) %% 65536 by smt().
smt().
qed.

lemma nosmt modzB_sint (x y: W16.t):
lemma modzB_sint (x y: W16.t):
(to_sint x - to_sint y) %% W16.modulus
= (to_uint x - to_uint y) %% W16.modulus.
proof.
Expand All @@ -173,14 +173,14 @@ case: (2 ^ (16 - 1) <= to_uint x);
case: (2 ^ (16 - 1) <= to_uint y); smt().
qed.

lemma nosmt to_sintB (x y: W16.t):
lemma to_sintB (x y: W16.t):
to_sint (x-y)
= W16.smod ((to_sint x - to_sint y)%%W16.modulus).
proof.
rewrite {1}/W16.to_sint to_uintD to_uintN modzB_sint /#.
qed.

lemma nosmt wmulsE (x y: W16.t):
lemma wmulsE (x y: W16.t):
to_sint x * to_sint y
= to_sint (wmulhs x y) * W16.modulus
+ to_uint (x * y).
Expand All @@ -193,7 +193,7 @@ have /=?:= to_sint_cmp y.
smt().
qed.

lemma nosmt to_sint_wmulhs x y:
lemma to_sint_wmulhs x y:
to_sint (W16.wmulhs x y) = to_sint x * to_sint y %/ W16.modulus.
proof.
rewrite wmulsE divzDl 1:/# mulzK 1:/#; ring.
Expand Down Expand Up @@ -221,15 +221,15 @@ lemma sint_bndW (x: W16.t) (xL1 xH1 xL2 xH2: int):
sint_bnd xL2 xH2 x
by smt().

lemma nosmt to_sintD_small (x y: W16.t):
lemma to_sintD_small (x y: W16.t):
W16.min_sint <= to_sint x + to_sint y <= W16.max_sint =>
to_sint (x+y) = to_sint x + to_sint y.
proof.
move=> /=?; rewrite to_sintD smod_small /= /#.
qed.

(* a version of [to_sintD_small] with bounds *)
lemma nosmt to_sintD_small' (xL xH yL yH: int) (x y: W16.t):
lemma to_sintD_small' (xL xH yL yH: int) (x y: W16.t):
sint_bnd xL xH x =>
sint_bnd yL yH y =>
W16.min_sint <= xL+yL =>
Expand All @@ -239,7 +239,7 @@ proof.
by move=> /= *; rewrite to_sintD_small /#.
qed.

lemma nosmt to_sintN_small (x: W16.t):
lemma to_sintN_small (x: W16.t):
W16.min_sint < to_sint x =>
to_sint (-x) = - to_sint x.
proof.
Expand Down Expand Up @@ -384,7 +384,7 @@ qed.
abbrev sint32_bnd xL xH x =
xL <= W32.to_sint x <= xH.

lemma nosmt modzD_sint32 (x y: W32.t):
lemma modzD_sint32 (x y: W32.t):
(to_sint x + to_sint y) %% W32.modulus
= (to_uint x + to_uint y) %% W32.modulus.
proof.
Expand All @@ -393,14 +393,14 @@ case: (2 ^ (32 - 1) <= to_uint x);
case: (2 ^ (32 - 1) <= to_uint y); smt().
qed.

lemma nosmt to_sint32D (x y: W32.t):
lemma to_sint32D (x y: W32.t):
to_sint (x+y)
= W32.smod ((to_sint x + to_sint y)%%W32.modulus).
proof.
by rewrite {1}/W32.to_sint to_uintD modzD_sint32.
qed.

lemma nosmt smod32_small (x: int):
lemma smod32_small (x: int):
-2^(32-1) <= x < 2^(32-1) =>
W32.smod (x %% W32.modulus) = x.
proof.
Expand All @@ -410,7 +410,7 @@ rewrite /smod; case: (x < 0) => C.
by have ->/#: ! 2 ^ (32 - 1) <= x %% W32.modulus by smt().
qed.

lemma nosmt to_sint32D_small (x y: W32.t):
lemma to_sint32D_small (x y: W32.t):
W32.min_sint <= to_sint x + to_sint y <= W32.max_sint =>
to_sint (x+y) = to_sint x + to_sint y.
proof.
Expand All @@ -427,22 +427,22 @@ proof.
by move=> /> *; rewrite to_sint32D_small /#.
qed.
(*
lemma nosmt modzN_sint32 (x: W32.t):
lemma modzN_sint32 (x: W32.t):
(- to_sint x) %% W32.modulus
= (- to_uint x) %% W32.modulus.
proof.
rewrite /to_sint /smod.
case: (2 ^ (32 - 1) <= to_uint x); smt().
qed.

lemma nosmt to_sint32N (x: W32.t):
lemma to_sint32N (x: W32.t):
to_sint (-x)
= W32.smod ((-to_sint x) %% W32.modulus).
proof.
by rewrite {1}/W32.to_sint to_uintN modzN_sint32.
qed.
*)
lemma nosmt modzB_sint32 (x y: W32.t):
lemma modzB_sint32 (x y: W32.t):
(to_sint x - to_sint y) %% W32.modulus
= (to_uint x - to_uint y) %% W32.modulus.
proof.
Expand All @@ -451,7 +451,7 @@ case: (2 ^ (32 - 1) <= to_uint x);
case: (2 ^ (32 - 1) <= to_uint y); smt().
qed.

lemma nosmt to_sint32B (x y: W32.t):
lemma to_sint32B (x y: W32.t):
to_sint (x-y)
= W32.smod ((to_sint x - to_sint y)%%W32.modulus).
proof.
Expand Down Expand Up @@ -512,7 +512,7 @@ rewrite -{1}unpack16K /unpack16 /=; congr.
by rewrite init_of_list -JUtils.iotaredE /=.
qed.

lemma nosmt modz_sint32 (x: W32.t):
lemma modz_sint32 (x: W32.t):
(to_sint x) %% W16.modulus
= (to_uint x) %% W16.modulus.
proof.
Expand Down Expand Up @@ -567,7 +567,7 @@ op REDC16 (xyL xyH: W16.t): W16.t =
in xyH - (wmulhs m (W16.of_int q)).

(* general bounds... *)
lemma nosmt REDC16_correct bL bR (xyL xyH: W16.t):
lemma REDC16_correct bL bR (xyL xyH: W16.t):
W16.min_sint + 1664 <= bL <= 0 =>
0 <= bR <= W16.max_sint - 1665 =>
sint_bnd bL bR xyH =>
Expand Down Expand Up @@ -608,7 +608,7 @@ by rewrite eq_sym -modzDmr -Domain.mulNr -modzMm modzz mod0z.
qed.

(* useful specific case *)
lemma nosmt REDC16_correct_q (xyL xyH: W16.t):
lemma REDC16_correct_q (xyL xyH: W16.t):
sint_bnd (-q%/2) (q%/2 - 1) xyH =>
to_sint (REDC16 xyL xyH) %% q
= (to_sint xyH * R + to_uint xyL) * Rinv %% q
Expand All @@ -625,7 +625,7 @@ abbrev REDCmul16 (x y: W16.t): W16.t = REDC16 (x*y) (wmulhs x y).

(* correctness result for multiplication, for
the specific case of a reduced argument *)
lemma nosmt REDCmul16_correct (x y: W16.t):
lemma REDCmul16_correct (x y: W16.t):
sint_bnd 0 (q-1) y =>
to_sint (REDCmul16 x y) %% q
= to_sint x * to_sint y * Rinv %% q
Expand Down
2 changes: 1 addition & 1 deletion proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,7 @@ move: (noise_coef_avx2_aux bytes j) => /=.
by rewrite /noise_coef C' to_sintE /smod to_uint_shr //= => <- /#.
qed.

lemma nosmt to_sint8_mod x:
lemma to_sint8_mod x:
W8.to_sint x %% W8.modulus = to_uint x.
proof.
rewrite /to_sint /smod.
Expand Down
8 changes: 6 additions & 2 deletions proof/correctness/avx2/NTT_AVX_j.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,11 @@ lemma wmuls16P n x y _x _y:
Iu16_sb n x _x =>
Iu16_sb n y _y =>
sint32_bnd (-n*n*q*q) (n*n*q*q) (wmuls16 x y).
proof. by move => [??] [??]; rewrite to_sint_wmuls16 /#. qed.
proof.
move=> [? [??]] [? [??]]; rewrite to_sint_wmuls16.
have ->: (n * n * q * q) = (n * q) * (n * q) by ring.
by rewrite &(ler_norml) normrM ler_pmul // 1,2:normr_ge0 /#.
qed.
phoare wmul_16u16_ph n _x _y:
[Jkem_avx2.M(Jkem_avx2.Syscall).__wmul_16u16:
Expand Down Expand Up @@ -1506,7 +1510,7 @@ qed.
(** Butterfly *)
lemma nosmt REDCmul16coeff (x y: W16.t):
lemma REDCmul16coeff (x y: W16.t):
sint_bnd 0 (q-1) y =>
incoeffW16 (Montgomery16.REDCmul16 x y)
= incoeffW16 x * incoeffW16 y * incoeff Montgomery16.Rinv
Expand Down