Skip to content

Commit

Permalink
fix poly_decompress: read out of bounds; originally proposed here: fo…
Browse files Browse the repository at this point in the history
…rmosa-crypto/libjade@d05492d

Co-authored-by: Vincent Laporte <[email protected]>
  • Loading branch information
tfaoliveira and vbgl committed Feb 7, 2024
1 parent 82da326 commit bc31973
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 20 deletions.
8 changes: 6 additions & 2 deletions code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -2830,6 +2830,8 @@ module M(SC:Syscall_t) = {
var shift:W256.t;
var f:W256.t;
var i:int;
var h:W128.t;
var sh:W128.t;
x16p <- witness;
x32p <- witness;
x16p <- jqx16;
Expand All @@ -2842,8 +2844,10 @@ module M(SC:Syscall_t) = {
aux <- (256 %/ 16);
i <- 0;
while (i < aux) {
f <-
VPBROADCAST_2u128 (loadW128 Glob.mem (W64.to_uint (ap + (W64.of_int (8 * i)))));
h <-
(zeroextu128 (loadW64 Glob.mem (W64.to_uint (ap + (W64.of_int (8 * i))))));
sh <- h;
f <- VPBROADCAST_2u128 sh;
f <- VPSHUFB_256 f shufbidx;
f <- VPAND_256 f mask;
f <- VPMULL_16u16 f shift;
Expand Down
7 changes: 6 additions & 1 deletion code/jasmin/mlkem_avx2/poly.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,10 @@ fn _poly_decompress(reg ptr u16[MLKEM_N] rp, reg u64 ap) -> stack u16[MLKEM_N]
{
inline int i;
reg u256 f q shufbidx mask shift;
reg u128 h;
reg ptr u16[16] x16p;
reg ptr u8[32] x32p;
stack u128 sh;

x16p = jqx16;
q = x16p[u256 0];
Expand All @@ -316,7 +318,10 @@ fn _poly_decompress(reg ptr u16[MLKEM_N] rp, reg u64 ap) -> stack u16[MLKEM_N]

for i=0 to MLKEM_N/16
{
f = #VPBROADCAST_2u128((u128)[ap + 8*i]);
h = (128u)(u64)[ap + 8*i];
sh = h;
f = #VPBROADCAST_2u128(sh);

f = #VPSHUFB_256(f, shufbidx);
f = #VPAND_256(f, mask);
f = #VPMULL_16u16(f, shift);
Expand Down
4 changes: 3 additions & 1 deletion proof/correctness/avx2/MLKEM_Poly_avx2_prevec.ec
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ module Mprevec = {
var shift:t16u16;
var f_b:t32u8;
var f:t16u16;
var t64:t8u8;
var t:t16u8;
var i:int;

Expand All @@ -607,7 +608,8 @@ module Mprevec = {
i <- 0;

while (i < aux) {
t <@ Ops.iload16u8(Glob.mem, ap + (W64.of_int (8 * i)));
t64 <@ Ops.iload8u8(Glob.mem, ap + (W64.of_int (8 * i)));
t <@ Ops.zeroextu128_t8u8(t64);
f_b <@ Ops.iVPBROADCAST_2u128_32u8(t);
f_b <@ Ops.iVPSHUFB_256(f_b, shufbidx);

Expand Down
33 changes: 22 additions & 11 deletions proof/correctness/avx2/MLKEM_Poly_avx2_proof.ec
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,15 @@ lemma poly_frommont_corr ap:
map W16.to_sint res = map (fun x => SREDC (x * ((Ring.IntID.(^) SignedReductions.R 2) %% q))) ap] = 1%r
by conseq poly_frommont_ll (poly_frommont_corr_h ap) => />.

lemma array32_of_list16 (m : W8.t list) :
Array32.init (fun (x : int) => (Array16.of_list witness m).[x %% 16]) = Array32.init (fun (x : int) => nth witness m (x %% 16)).
proof.
apply: Array32.ext_eq => x x_range.
rewrite !initiE // /=.
apply: Array16.get_of_list.
smt().
qed.

lemma poly_decompress_corr mem _p (_a : W8.t Array128.t):
equiv [ Mprevec.poly_decompress ~ EncDec_AVX2.decode4 :
valid_ptr _p 128 /\
Expand Down Expand Up @@ -1068,19 +1077,21 @@ proof.
rewrite -decompress_alt_decompress 1,2://= /decompress_alt k_ub /=.
case (16 * i{2} <= k) => k_tlb />.
rewrite /loadW8.
rewrite (_: (Array32.init (fun (i0 : int) =>
(Array16.init (fun (i1 : int) =>
mem.[W64.to_uint (ap{1} + (W64.of_int (8 * i{2}))) + i1])).[i0 %% 16])) =
(Array32.init (fun (i0 : int) =>
mem.[W64.to_uint (ap{1} + (W64.of_int (8 * i{2}))) + i0 %% 16]))).
apply Array32.ext_eq.
move => x x_i.
do (rewrite Array32.initiE 1:x_i /=).
rewrite Array16.initiE 1:/# //=.
rewrite array32_of_list16.
rewrite (_: (Array32.init (fun (i1 : int) =>
if 128 <= to_uint shufbidx{1}.[i1] then W8.zero
else (Array32.init (fun (i0 : int) =>
mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + i0 %% 16])).[16 * (i1 %/ 16) + to_uint shufbidx{1}.[i1] %% 16])) =
else (Array32.init (fun (x3 : int) =>
nth witness
[mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64)];
mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + 1];
mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + 2];
mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + 3];
mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + 4];
mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + 5];
mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + 6];
mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + 7];
W8.zero; W8.zero; W8.zero; W8.zero; W8.zero; W8.zero; W8.zero; W8.zero]
(x3 %% 16))).[16 * (i1 %/ 16) + to_uint shufbidx{1}.[i1] %% 16])) =
(Array32.init (fun (i0 : int) => mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + i0 %/ 4]))).
apply Array32.ext_eq => x x_i.
do (rewrite (Array32.initiE _ x) 1:x_i /=).
Expand Down
11 changes: 6 additions & 5 deletions proof/correctness/avx2/MLKEM_Poly_avx2_vec.ec
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,8 @@ module Mvec = {
aux <- (256 %/ 16);
i <- 0;
while (i < aux) {
f <@
OpsV.iVPBROADCAST_2u128_32u8(loadW128 Glob.mem (W64.to_uint (ap + (W64.of_int (8 * i)))));
t <@ OpsV.zeroextu128_t8u8(loadW64 Glob.mem (W64.to_uint (ap + (W64.of_int (8 * i)))));
f <@ OpsV.iVPBROADCAST_2u128_32u8(t);
f <@ OpsV.iVPSHUFB_256(f, shufbidx);
f <@ OpsV.iVPAND_16u16(f, mask);
f <@ OpsV.iVPMULL_16u16(f, shift);
Expand Down Expand Up @@ -1100,13 +1100,14 @@ proof.
call eq_iVPMULHRS_256; call eq_iVPMULL_16u16; call eq_iVPAND_16u16.
wp.
call eq_iVPSHUFB_256.
inline Ops.iload16u8.
inline Ops.iload8u8.
sp.
call eq_iVPBROADCAST_2u128_32u8.
call eq_zeroextu128_t8u8.
wp; skip; rewrite /is32u8 /is16u16 /is16u8 => />. move => &1 &2 rp_l i_lb i_ub i_tub.
split.
+ rewrite /loadW128 /loadW8 /=.
apply W16u8.allP => //=.
+ rewrite /loadW64 /loadW8 /=.
apply W8u8.allP => //=.
move => p_eq res_l0.
split.
+ rewrite /f32u8_t16u16 initiE //=.
Expand Down

0 comments on commit bc31973

Please sign in to comment.