diff --git a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec index 8370b710..581d06e7 100644 --- a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec +++ b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec @@ -2830,6 +2830,8 @@ module M(SC:Syscall_t) = { var shift:W256.t; var f:W256.t; var i:int; + var h:W128.t; + var sh:W128.t; x16p <- witness; x32p <- witness; x16p <- jqx16; @@ -2842,8 +2844,10 @@ module M(SC:Syscall_t) = { aux <- (256 %/ 16); i <- 0; while (i < aux) { - f <- - VPBROADCAST_2u128 (loadW128 Glob.mem (W64.to_uint (ap + (W64.of_int (8 * i))))); + h <- + (zeroextu128 (loadW64 Glob.mem (W64.to_uint (ap + (W64.of_int (8 * i)))))); + sh <- h; + f <- VPBROADCAST_2u128 sh; f <- VPSHUFB_256 f shufbidx; f <- VPAND_256 f mask; f <- VPMULL_16u16 f shift; diff --git a/code/jasmin/mlkem_avx2/poly.jinc b/code/jasmin/mlkem_avx2/poly.jinc index 2fa2acd0..f970d0b1 100644 --- a/code/jasmin/mlkem_avx2/poly.jinc +++ b/code/jasmin/mlkem_avx2/poly.jinc @@ -302,8 +302,10 @@ fn _poly_decompress(reg ptr u16[MLKEM_N] rp, reg u64 ap) -> stack u16[MLKEM_N] { inline int i; reg u256 f q shufbidx mask shift; + reg u128 h; reg ptr u16[16] x16p; reg ptr u8[32] x32p; + stack u128 sh; x16p = jqx16; q = x16p[u256 0]; @@ -316,7 +318,10 @@ fn _poly_decompress(reg ptr u16[MLKEM_N] rp, reg u64 ap) -> stack u16[MLKEM_N] for i=0 to MLKEM_N/16 { - f = #VPBROADCAST_2u128((u128)[ap + 8*i]); + h = (128u)(u64)[ap + 8*i]; + sh = h; + f = #VPBROADCAST_2u128(sh); + f = #VPSHUFB_256(f, shufbidx); f = #VPAND_256(f, mask); f = #VPMULL_16u16(f, shift); diff --git a/proof/correctness/avx2/MLKEM_Poly_avx2_prevec.ec b/proof/correctness/avx2/MLKEM_Poly_avx2_prevec.ec index 3152ae0d..70f2085c 100644 --- a/proof/correctness/avx2/MLKEM_Poly_avx2_prevec.ec +++ b/proof/correctness/avx2/MLKEM_Poly_avx2_prevec.ec @@ -590,6 +590,7 @@ module Mprevec = { var shift:t16u16; var f_b:t32u8; var f:t16u16; + var t64:t8u8; var t:t16u8; var i:int; @@ -607,7 +608,8 @@ module Mprevec = { i <- 0; while (i < aux) { - t <@ Ops.iload16u8(Glob.mem, ap + (W64.of_int (8 * i))); + t64 <@ Ops.iload8u8(Glob.mem, ap + (W64.of_int (8 * i))); + t <@ Ops.zeroextu128_t8u8(t64); f_b <@ Ops.iVPBROADCAST_2u128_32u8(t); f_b <@ Ops.iVPSHUFB_256(f_b, shufbidx); diff --git a/proof/correctness/avx2/MLKEM_Poly_avx2_proof.ec b/proof/correctness/avx2/MLKEM_Poly_avx2_proof.ec index cff5e401..f2674488 100644 --- a/proof/correctness/avx2/MLKEM_Poly_avx2_proof.ec +++ b/proof/correctness/avx2/MLKEM_Poly_avx2_proof.ec @@ -993,6 +993,15 @@ lemma poly_frommont_corr ap: map W16.to_sint res = map (fun x => SREDC (x * ((Ring.IntID.(^) SignedReductions.R 2) %% q))) ap] = 1%r by conseq poly_frommont_ll (poly_frommont_corr_h ap) => />. +lemma array32_of_list16 (m : W8.t list) : + Array32.init (fun (x : int) => (Array16.of_list witness m).[x %% 16]) = Array32.init (fun (x : int) => nth witness m (x %% 16)). +proof. +apply: Array32.ext_eq => x x_range. +rewrite !initiE // /=. +apply: Array16.get_of_list. +smt(). +qed. + lemma poly_decompress_corr mem _p (_a : W8.t Array128.t): equiv [ Mprevec.poly_decompress ~ EncDec_AVX2.decode4 : valid_ptr _p 128 /\ @@ -1068,19 +1077,21 @@ proof. rewrite -decompress_alt_decompress 1,2://= /decompress_alt k_ub /=. case (16 * i{2} <= k) => k_tlb />. rewrite /loadW8. - rewrite (_: (Array32.init (fun (i0 : int) => - (Array16.init (fun (i1 : int) => - mem.[W64.to_uint (ap{1} + (W64.of_int (8 * i{2}))) + i1])).[i0 %% 16])) = - (Array32.init (fun (i0 : int) => - mem.[W64.to_uint (ap{1} + (W64.of_int (8 * i{2}))) + i0 %% 16]))). - apply Array32.ext_eq. - move => x x_i. - do (rewrite Array32.initiE 1:x_i /=). - rewrite Array16.initiE 1:/# //=. + rewrite array32_of_list16. rewrite (_: (Array32.init (fun (i1 : int) => if 128 <= to_uint shufbidx{1}.[i1] then W8.zero - else (Array32.init (fun (i0 : int) => - mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + i0 %% 16])).[16 * (i1 %/ 16) + to_uint shufbidx{1}.[i1] %% 16])) = + else (Array32.init (fun (x3 : int) => + nth witness + [mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64)]; + mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + 1]; + mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + 2]; + mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + 3]; + mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + 4]; + mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + 5]; + mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + 6]; + mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + 7]; + W8.zero; W8.zero; W8.zero; W8.zero; W8.zero; W8.zero; W8.zero; W8.zero] + (x3 %% 16))).[16 * (i1 %/ 16) + to_uint shufbidx{1}.[i1] %% 16])) = (Array32.init (fun (i0 : int) => mem.[to_uint (ap{1} + (of_int (8 * i{2}))%W64) + i0 %/ 4]))). apply Array32.ext_eq => x x_i. do (rewrite (Array32.initiE _ x) 1:x_i /=). diff --git a/proof/correctness/avx2/MLKEM_Poly_avx2_vec.ec b/proof/correctness/avx2/MLKEM_Poly_avx2_vec.ec index 79b8fde6..fbcc9c3a 100644 --- a/proof/correctness/avx2/MLKEM_Poly_avx2_vec.ec +++ b/proof/correctness/avx2/MLKEM_Poly_avx2_vec.ec @@ -445,8 +445,8 @@ module Mvec = { aux <- (256 %/ 16); i <- 0; while (i < aux) { - f <@ - OpsV.iVPBROADCAST_2u128_32u8(loadW128 Glob.mem (W64.to_uint (ap + (W64.of_int (8 * i))))); + t <@ OpsV.zeroextu128_t8u8(loadW64 Glob.mem (W64.to_uint (ap + (W64.of_int (8 * i))))); + f <@ OpsV.iVPBROADCAST_2u128_32u8(t); f <@ OpsV.iVPSHUFB_256(f, shufbidx); f <@ OpsV.iVPAND_16u16(f, mask); f <@ OpsV.iVPMULL_16u16(f, shift); @@ -1100,13 +1100,14 @@ proof. call eq_iVPMULHRS_256; call eq_iVPMULL_16u16; call eq_iVPAND_16u16. wp. call eq_iVPSHUFB_256. - inline Ops.iload16u8. + inline Ops.iload8u8. sp. call eq_iVPBROADCAST_2u128_32u8. + call eq_zeroextu128_t8u8. wp; skip; rewrite /is32u8 /is16u16 /is16u8 => />. move => &1 &2 rp_l i_lb i_ub i_tub. split. - + rewrite /loadW128 /loadW8 /=. - apply W16u8.allP => //=. + + rewrite /loadW64 /loadW8 /=. + apply W8u8.allP => //=. move => p_eq res_l0. split. + rewrite /f32u8_t16u16 initiE //=.