From 58aa362790fa6dc7ddfa5a5e33657d14d5d679a6 Mon Sep 17 00:00:00 2001 From: JBA Date: Mon, 11 Nov 2024 01:07:15 +0000 Subject: [PATCH] recompiles on current bdep branch --- proof/correctness/MLKEM_InnerPKE.ec | 14 +- proof/correctness/NTTAlgebra.ec | 4 +- proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec | 4407 +++++++++-------- proof/correctness/avx2/MLKEM_avx2_encdec.ec | 52 +- proof/security/FO_MLKEM.ec | 6 +- proof/security/FO_TT.ec | 2 +- proof/security/FO_UU.ec | 4 +- proof/security/MLWE.ec | 16 +- proof/security/MLWE_PKE_Hash.ec | 2 +- proof/spec/MLKEMSecurity.ec | 50 +- 10 files changed, 2388 insertions(+), 2169 deletions(-) diff --git a/proof/correctness/MLKEM_InnerPKE.ec b/proof/correctness/MLKEM_InnerPKE.ec index fe58c7a6..08f37ce1 100644 --- a/proof/correctness/MLKEM_InnerPKE.ec +++ b/proof/correctness/MLKEM_InnerPKE.ec @@ -853,15 +853,15 @@ equiv sample_noise_good2 _key : proc => /=. unroll for {2} 7; unroll for {2} 5. seq 2 2 : (#pre); 1:by auto. -seq 3 4 : (#pre /\ lift_array256 (subarray256 noise1{1} 0) = (noise1{2}.[0])%Vector /\ _N{2}=0 /\ i{2} = 0 /\ +seq 3 3 : (#pre /\ lift_array256 (subarray256 noise1{1} 0) = (noise1{2}.[0])%Vector /\ _N{2}=0 /\ (forall k, 0<=k<256 => -5 < to_sint noise1{1}.[k] < 5)). -wp; call (get_noise_sample_noise); auto => /> &1 &2 *; split. +wp; call (get_noise_sample_noise); auto => /> *; split. + rewrite /lift_array256 /subarray256 /= tP => k kb. rewrite !mapiE //= initiE //= initiE 1:/# /= kb /=. by rewrite /= setvE /= offunvE 1:/# /= !mapiE /#. by move => k kbl kbh;rewrite !initiE 1:/# /= kbl kbh /= /#. -seq 3 4 : (#{/~_N{2}}{~i{2}}{~noise1{1}}pre /\ _N{2}=1 /\ i{2} = 1 /\ +seq 3 3 : (#{/~_N{2}}{~i{2}}{~noise1{1}}pre /\ _N{2}=1 /\ lift_array256 (subarray256 noise1{1} 0) = (noise1{2}.[0])%Vector /\ lift_array256 (subarray256 noise1{1} 1) = (noise1{2}.[1])%Vector /\ (forall k, 0<=k<512 => -5 < to_sint noise1{1}.[k] < 5)). @@ -872,7 +872,7 @@ wp; call (get_noise_sample_noise); auto => /> &1 &2 H *; do split. by move : (H k kb);rewrite !mapiE //= !initiE //= !initiE //= /#. by move => k kbl kbh;rewrite !initiE 1:/# /= /#. -seq 3 4 : (#{/~_N{2}}{~i{2}}{~noise1{1}}pre /\ _N{2}=2 /\ i{2} = 2 /\ +seq 3 3 : (#{/~_N{2}}{~i{2}}{~noise1{1}}pre /\ _N{2}=2 /\ lift_polyvec noise1{1} = noise1{2} /\ (forall k, 0<=k<768 => -5 < to_sint noise1{1}.[k] < 5)). wp; call (get_noise_sample_noise); auto => /> &1 &2 H0 H1 *; split; @@ -888,14 +888,14 @@ case (r = 1). move => *; have -> /= : 2 = r by smt(). by rewrite !mapiE //= !initiE //= !initiE //= /#. -seq 3 5 : (#{/~_N{2}}{~i{2}}pre /\ lift_array256 (subarray256 noise2{1} 0) = (noise2{2}.[0])%Vector /\ _N{2}=3 /\ i{2} = 0 /\ +seq 3 4 : (#{/~_N{2}}{~i{2}}pre /\ lift_array256 (subarray256 noise2{1} 0) = (noise2{2}.[0])%Vector /\ _N{2}=3 /\ (forall k, 0<=k<256 => -5 < to_sint noise2{1}.[k] < 5)). wp; call (get_noise_sample_noise); auto => /> &1 &2 *; split. + rewrite /lift_array256 /subarray256 setvE /= offunvE 1:/# /= tP => k kb. by rewrite !mapiE //= initiE //= initiE 1:/# /= kb /=. by move => k kbl kbh;rewrite !initiE 1:/# /= kbl kbh /= /#. -seq 3 4 : (#{/~_N{2}}{~i{2}}{~noise2{1}}pre /\ _N{2}=4 /\ i{2} = 1 /\ +seq 3 3 : (#{/~_N{2}}{~i{2}}{~noise2{1}}pre /\ _N{2}=4 /\ lift_array256 (subarray256 noise2{1} 0) = (noise2{2}.[0])%Vector /\ lift_array256 (subarray256 noise2{1} 1) = (noise2{2}.[1])%Vector /\ (forall k, 0<=k<512 => -5 < to_sint noise2{1}.[k] < 5)). @@ -906,7 +906,7 @@ wp; call (get_noise_sample_noise); auto => /> &1 &2 ?H *; do split. by move : (H k kb);rewrite !mapiE //= !initiE //= !initiE //= /#. by move => k kbl kbh;rewrite !initiE 1:/# /= /#. -seq 3 4 : (#{/~_N{2}}{~i{2}}{~noise2{1}}pre /\ _N{2}=5 /\ i{2} = 2 /\ +seq 3 3 : (#{/~_N{2}}{~i{2}}{~noise2{1}}pre /\ _N{2}=5 /\ lift_polyvec noise2{1} = noise2{2} /\ (forall k, 0<=k<768 => -5 < to_sint noise2{1}.[k] < 5)). wp; call (get_noise_sample_noise); auto => /> &1 &2 ?H0 H1 *; split; diff --git a/proof/correctness/NTTAlgebra.ec b/proof/correctness/NTTAlgebra.ec index e90db24c..5da9b3f6 100644 --- a/proof/correctness/NTTAlgebra.ec +++ b/proof/correctness/NTTAlgebra.ec @@ -1684,7 +1684,7 @@ theory NTTequiv. rewrite (divz_pow_sub_range _ _ _ mem_kl_range) //=. rewrite (exprSr_range _ _ _ mem_kl_range) //=. rewrite (exprD_nneg_sub_add_range _ _ _ mem_kl_range) //=. - move => [[? _] _] {r1 r2 r} [r1 r2] [_] [r] /= [->> ->>]. + move => [[? _] _] {r1 r2 r} [r1 r2] [r] /= [->> ->>]. rewrite -mulrSl -ltz_NdivNLR; [by rewrite expr_gt0|]. rewrite (NdivzN_pow_sub_range _ _ _ mem_kl_range) //=. split => //=; rewrite (exprSr_sub_range _ _ _ mem_kl_range) //=. @@ -2497,7 +2497,7 @@ theory NTTequiv. rewrite (divz_pow_add_range _ _ _ mem_kl_range) //=. rewrite (exprSr_sub_range _ _ _ mem_kl_range) //=. rewrite (exprD_nneg_add_sub_range _ _ _ mem_kl_range) //=. - move => [[? _] _] {r1 r2 r} [r1 r2] [_] [r] /= [->> ->>]. + move => [[? _] _] {r1 r2 r} [r1 r2] [r] /= [->> ->>]. rewrite -mulrSl -ltz_NdivNLR; [by rewrite expr_gt0|]. rewrite (NdivzN_pow_add_range _ _ _ mem_kl_range) //=. split => //=; rewrite (exprSr_add_range _ _ _ mem_kl_range) //=. diff --git a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec index a29cbba5..cd3aacbb 100644 --- a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec +++ b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec @@ -1,6 +1,6 @@ require import AllCore IntDiv List. from Jasmin require import JModel. -require import Array16 Array25 Array32 Array33 Array128 Array136 Array256 Array768 Array960 Array1088 Array2304. +require import Array4 Array16 Array25 Array32 Array33 Array128 Array136 Array256 Array768 Array960 Array1088 Array2304. require import List_extra. require import MLKEM_Poly MLKEM_PolyVec MLKEM_InnerPKE. require import MLKEM_Poly_avx2_proof. @@ -25,2341 +25,2523 @@ import MLKEM_PolyVec. import MLKEM_PolyvecAVX. import MLKEM_PolyAVXVec. import NTT_Avx2. -import WArray136 WArray32 WArray128. +import WArray136 WArray32 WArray128 WArray960 WArray1536. import WArray512 WArray256. -(* shake assumptions *) - -(* -op SHAKE256_ABSORB4x_33 : W8.t Array33.t -> W8.t Array33.t -> W8.t Array33.t -> W8.t Array33.t -> W256.t Array25.t. -op SHAKE256_SQUEEZENBLOCKS4x : W256.t Array25.t -> W256.t Array25.t * W8.t Array136.t * W8.t Array136.t * W8.t Array136.t * W8.t Array136.t. - -axiom shake_absorb4x state seed1 seed2 seed3 seed4 : - phoare [ Jkem_avx2.M(Jkem_avx2.Syscall)._shake256_absorb4x_33 : - arg = (state,seed1,seed2,seed3,seed4) ==> - res = SHAKE256_ABSORB4x_33 seed1 seed2 seed3 seed4 ] = 1%r. - -axiom shake_squeezenblocks4x state buf1 buf2 buf3 buf4 : - phoare [ Jkem_avx2.M(Jkem_avx2.Syscall).__shake256_squeezenblocks4x : ` - arg = (state,buf1,buf2,buf3,buf4) ==> - res = SHAKE256_SQUEEZENBLOCKS4x state ] = 1%r. - -axiom shake4x_equiv (sn1 sn2 sn3 sn4: W8.t Array33.t) (s1 s2 s3 s4 : W8.t Array32.t) n1 n2 n3 n4 : - s1 = Array32.init (fun i => sn1.[i]) => - s2 = Array32.init (fun i => sn2.[i]) => - s3 = Array32.init (fun i => sn3.[i]) => - s4 = Array32.init (fun i => sn4.[i]) => - n1 = sn1.[32] => n2 = sn2.[32] => n3 = sn3.[32] => n4 = sn4.[32] => - Array128.init ("_.[_]" (SHAKE256_SQUEEZENBLOCKS4x (SHAKE256_ABSORB4x_33 sn1 sn2 sn3 sn4)).`2) = SHAKE256_33_128 s1 n1 /\ - Array128.init ("_.[_]" (SHAKE256_SQUEEZENBLOCKS4x (SHAKE256_ABSORB4x_33 sn1 sn2 sn3 sn4)).`3) = SHAKE256_33_128 s2 n2 /\ - Array128.init ("_.[_]" (SHAKE256_SQUEEZENBLOCKS4x (SHAKE256_ABSORB4x_33 sn1 sn2 sn3 sn4)).`4) = SHAKE256_33_128 s3 n3 /\ - Array128.init ("_.[_]" (SHAKE256_SQUEEZENBLOCKS4x (SHAKE256_ABSORB4x_33 sn1 sn2 sn3 sn4)).`5) = SHAKE256_33_128 s4 n4. - -axiom sha3equiv : - equiv [ (* is this in the sha3 paper? *) -Jkem_avx2.M(Jkem_avx2.Syscall)._sha3_512_32 ~Jkem.M(Jkem.Syscall)._sha3512_32 : ={arg} ==> ={res}]. - -lemma keccakf1600_set_row_ll : islossless M(Syscall).keccakf1600_set_row. -proc. by unroll for ^while; auto. qed. +(* MOVE THIS ELSEWHERE? *) -lemma keccakf1600_rho_offsets_ll : islossless M(Syscall).keccakf1600_rho_offsets. -proc. by unroll for ^while; islossless. qed. +require import BitEncoding. +require import StdOrder. +import IntOrder. +import BS2Int. +print BS2Int. +lemma to_sint_pos (w: W32.t): + to_uint w < 2^(W32.size - 1) => + to_sint w = to_uint w. +proof. +rewrite to_sintE /smod /= => H. +by rewrite ifF 1:/#. +qed. -lemma keccakf1600_rhotates_ll : islossless M(Syscall).keccakf1600_rhotates. -proc. by call keccakf1600_rho_offsets_ll; islossless. qed. +lemma W32_get_ule (w: W32.t) k: + w.[k] => 2^k <= W32.to_uint w. +proof. +rewrite get_to_uint => [[Hk H]]. +smt(W32.to_uint_cmp gt0_pow2). +qed. -lemma keccakf1600_theta_rol_ll : islossless M(Syscall).keccakf1600_theta_rol. -proc. by unroll for ^while; islossless. qed. +lemma W32_msbE (w: W32.t): + msb w <=> w.[W32.size-1]. +proof. +rewrite /msb; split; last by apply W32_get_ule. +rewrite get_to_uint /= => H. +have L0: 0 < to_uint w %/ 2147483648. + smt(W32.to_uint_cmp). +have: to_uint w \in range (1 * 2147483648) ((1 + 1) * 2147483648). +(* REMARK: seg-faults with rewrite /= *) + rewrite /range /= mem_iota; split => //=. + by move: (W32.to_uint_cmp w) => /= /#. +by rewrite -eq_div_range // /#. +qed. + +lemma W32_min_sintE: + W32.of_int W32.min_sint + = W32.zero.[W32.size-1 <- true]. +proof. +rewrite to_uint_eq of_uintK /= to_uint_bits. +rewrite /bits /mkseq -iotaredE /=. +have:= bs2int_pow2 31 0. +rewrite nseq0 /max -mkseq_nseq /mkseq -iotaredE /= /#. +qed. -lemma keccakf1600_theta_sum_ll : islossless M(Syscall).keccakf1600_theta_sum. -proc. by do 6!(unroll for ^while); islossless. qed. +lemma W32_msb_sar (w: W32.t) k: + 0 <= k => + msb (sar w k) = msb w. +proof. +move=> Hk. +by rewrite !W32_msbE /(`|>>`) /sar /= /#. +qed. -lemma keccakf1600_rol_sum_ll : islossless M(Syscall).keccakf1600_rol_sum. -proc. -while (x <= 5) (5 - x); auto; last smt(). -conseq => /=; call keccakf1600_rhotates_ll; auto => /#. +lemma W32_msb_sign (w: W32.t): + msb w <=> to_sint w < 0. +proof. +rewrite /msb to_sintE /smod. +by move: (W32.to_uint_cmp w); smt(). qed. -lemma keccakf1600_round_ll : islossless Jkem.M(Syscall).keccakf1600_round. -proc; auto. -while (y <= 5) (5 - y); auto. -+ call keccakf1600_set_row_ll. - call keccakf1600_rol_sum_ll. - auto; smt(). -call keccakf1600_theta_rol_ll. -call keccakf1600_theta_sum_ll. -auto; smt(). +lemma W32_to_sint_neg (w: W32.t): + msb w => + to_sint w = to_uint w - W32.modulus. +proof. +rewrite /msb to_sintE /smod. +by move: (W32.to_uint_cmp w); smt(). qed. -lemma keccakf1600_ll : islossless Jkem.M(Syscall)._keccakf1600_. -proc; auto. -call (:true); auto. -call (:true); auto. -while (to_uint c <= 24 /\ to_uint c %% 2 = 0) (24 - to_uint c); auto; last by move => /> *; rewrite ultE to_uint_small //= /#. -call keccakf1600_round_ll; auto. -call keccakf1600_round_ll; auto. -move => /> ??; rewrite ultE to_uintD_small to_uint_small //= /#. +lemma W32_to_sint_pos (w: W32.t): + !msb w => + to_sint w = to_uint w. +proof. +rewrite /msb to_sintE /smod. +by move: (W32.to_uint_cmp w); smt(). qed. -lemma sha3ll : islossless Jkem.M(Jkem.Syscall)._shake256_128_33. +lemma W32_sar_pos (w: W32.t) k: + 0 <= k => + !msb w => + sar w k = w `>>>` k. proof. -proc. -unroll for 9; wp; conseq => /=. -call keccakf1600_ll; auto. -conseq => /=. -unroll for ^while; auto. -conseq => /=. -inline *; unroll for ^while; auto. +rewrite W32_msbE /= => Hk Hpos. +apply W32.ext_eq => i Hi. +rewrite /(`|>>`) /sar initiE //=. +rewrite /(`>>`) /(`>>>`) Hi /=. +case: (31 < (i + k)) => E. + by rewrite lez_minl 1:/# eq_sym get_out /#. +by rewrite lez_minr 1:/#. qed. (* -axiom shake128_equiv_absorb : equiv [ M(Syscall)._shake128_absorb34 ~ - Jkem_avx2.M(Jkem_avx2.Syscall)._shake128_absorb34 : - ={state, in_0} ==> ={res}]. -axiom shake128_equiv_squeezeblock : equiv [ M(Syscall)._shake128_squeezeblock ~ - Jkem_avx2.M(Jkem_avx2.Syscall)._shake128_squeezeblock : - ={state, out} ==> ={res}]. +| (* IntDiv.modz_pow2_div *) +| lemma modz_pow2_div: +| forall (n p m : int), +| 0 <= p && p <= n => +| m %% 2 ^ n %/ 2 ^ p = m %/ 2 ^ p %% 2 ^ (n - p). +| lemma dvdNdiv: forall (x y : int), x <> 0 => x %| y => (-y) %/ x = - y %/ x. +| lemma divNz: +| forall (m d : int), 0 < m => 0 < d => (-m) %/ d = - ((m - 1) %/ d + 1). +| lemma lez_NdivNLR: +| forall (d m n : int), 0 < d => d %| n => m <= n %\ d <=> m * d <= n. +| lemma divzDr: forall (m n d : int), d %| n => (m + n) %/ d = m %/ d + n %/ d. *) +lemma W32_shl_onew k: + 0 < k => + W32.onew `>>>` k + = W32.of_int (2^(max 0 (W32.size - k)) -1). +proof. +move => Hk. +apply W32.ext_eq => i Hi. +by rewrite /(`>>>`) initiE //= /#. +qed. -equiv genmatrixequiv b : - Jkem_avx2.M(Jkem_avx2.Syscall)._gen_matrix_avx2 ~ Jkem.M(Jkem.Syscall).__gen_matrix : - arg{1}.`2 = arg{2}.`1 /\ arg{1}.`3= (W64.of_int (b2i b)) /\ arg{2}.`2 = (W64.of_int (b2i b)) ==> - res{1} = nttunpackm res{2} /\ - pos_bound2304_cxq res{1} 0 2304 2 /\ - pos_bound2304_cxq res{2} 0 2304 2. -proc* => /=. -transitivity {2} { r <@ AuxMLKEM.__gen_matrix(seed,b); } - ( rho{1} = seed{2} /\ transposed{1} = (of_int (b2i b))%W64 /\ transposed{2} = (of_int (b2i b))%W64 ==> r{1} = nttunpackm r{2} /\ pos_bound2304_cxq r{1} 0 2304 2 /\ pos_bound2304_cxq r{2} 0 2304 2 ) - ( seed{1} = seed{2} /\ transposed{1} = (of_int (b2i b))%W64 /\ transposed{2} = (of_int (b2i b))%W64==> ={r});1,2:smt(). - + call (genmatrixequiv_aux b); 1: by auto => />. - by symmetry;call (auxgenmatrix_good); auto => /> /#. +lemma W32_sarE_neg (w: W32.t) k: + 0 <= k => + msb w => + sar w k + = (w `>>>` k) `|` invw (W32.onew `>>>` k). +proof. +rewrite W32_msbE /= => Hk Hmsb. +apply W32.ext_eq => i Hi. +rewrite /(`|>>`) /sar initiE //=. +rewrite /(`>>`) /(`>>>`) !Hi //=. +have ->/=: 0 <= i + k by smt(). +case: (i + k < 32) => C. + by rewrite lez_minr 1:/#. +by rewrite lez_minl 1:/# Hmsb. +qed. + +lemma W32_to_uint_sar_neg (w: W32.t) k: + 0 <= k => + msb w => + to_uint (sar w k) + = W32.modulus + (to_sint w %/ 2^k). +proof. +move=> Hk Hmbs. +have [Hw0 Hw1]:= W32.to_uint_cmp w. +rewrite W32_sarE_neg // W32.to_uint_orw_disjoint. + apply W32.ext_eq => i Hi. + rewrite /(`>>`) /(`>>>`) /= !initiE Hi //=. + have ->/=: 0 <= i + k by smt(). + case: (i + k < 32) => C //=. + by rewrite get_out /#. +case: (k = 0) => Ek0. + rewrite !Ek0 /= to_uint_shr 1:/#. + rewrite to_uint_invw to_uint_shr 1:/#. + by rewrite W32_to_sint_neg // to_uint_onew /#. +case: (32 <= k) => Ek32. + rewrite to_uint_shr 1:/# divz_small /=. + smt(ler_weexpn2l). + have ->: to_sint w %/ 2 ^ k = -1. + rewrite W32_to_sint_neg //. + have ->: to_uint w - W32.modulus + = - (W32.modulus - to_uint w) by ring. + rewrite divNz. + by rewrite subr_gt0; apply Hw1. + smt(expr_gt0). + rewrite divz_small //. + apply bound_abs; split. + rewrite /= in Hw1. + by rewrite /= /#. + move=> ?. + apply (ltr_le_trans W32.modulus); first smt(). + by apply ler_weexpn2l => /#. + rewrite to_uint_invw to_uint_shr 1:/# to_uint_onew divz_small //. + apply bound_abs; split; first done. + move=> ?; apply (ltr_le_trans W32.modulus); first smt(). + by apply ler_weexpn2l. +rewrite to_uint_shr; first smt(to_uint_cmp). +rewrite W32_shl_onew; first smt(to_uint_cmp). +rewrite lez_maxr 1:/# to_uint_invw. +rewrite of_uintK modz_small. + apply bound_abs; split; first smt(@IntDiv). + move => ?. + have /=?: 2 ^ (32 - k) <= W32.modulus. + by apply ler_weexpn2l => // /#. + by rewrite /= /#. +have Hkk: W32.modulus = 2^(32-k) * 2^k. + rewrite -exprD_nneg // 1:/#. + by rewrite -addzA /=. +have ->: W32.max_uint - (2 ^ (32 - k) - 1) + = W32.modulus - W32.modulus %/ 2^k. + rewrite {3}Hkk mulzK /#. +rewrite addzC -addzA; congr. +rewrite addzC W32_to_sint_neg //. +have ->: (to_uint w - W32.modulus) + = to_uint w + (- 2^(32-k)) * 2^k. + by rewrite Hkk; ring. +rewrite divzMDr 1:/# addzC; congr. +have ->: W32.modulus = 2^(32 - k + k). + by congr; smt(). +smt(). qed. -module GetNoiseAVX2 = { - proc _poly_getnoise_eta1_4x(aux3 aux2 aux1 aux0 : W16.t Array256.t, - noiseseed : W8.t Array32.t, - nonce : W8.t) : - W16.t Array256.t * W16.t Array256.t * W16.t Array256.t * W16.t Array256.t = { - var n3, n2, n1, n0 : W8.t; - var aux_3, aux_2, aux_1, aux_0 : W16.t Array256.t; - n0 <- nonce + W8.of_int 3; - n1 <- nonce + W8.of_int 2; - n2 <- nonce + W8.of_int 1; - n3 <- nonce; - aux_3 <@Jkem.M(Jkem.Syscall)._poly_getnoise(aux3,noiseseed,n3); - aux_2 <@Jkem.M(Jkem.Syscall)._poly_getnoise(aux2,noiseseed,n2); - aux_1 <@Jkem.M(Jkem.Syscall)._poly_getnoise(aux1,noiseseed,n1); - aux_0 <@Jkem.M(Jkem.Syscall)._poly_getnoise(aux0,noiseseed,n0); - return (aux_3, aux_2, aux_1, aux_0); - } +lemma W32_sar_div (w1 : W32.t) k: + 0 <= k => + to_sint (sar w1 k) + = to_sint w1 %/ 2 ^ k. +proof. +case: (msb w1) => Hk C. + rewrite W32_to_sint_neg 1:W32_msb_sar //. + by rewrite W32_to_uint_sar_neg //. +rewrite W32_to_sint_pos 1:W32_msb_sar //. +by rewrite W32_sar_pos // to_uint_shr /#. +qed. - proc sample_noise_kg(skpv pkpv e : W16.t Array768.t, noiseseed:W8.t Array32.t) : W16.t Array768.t * W16.t Array768.t ={ - var nonce : W8.t; - var aux_3, aux_2, aux_1, aux_0 : W16.t Array256.t; - nonce <- (W8.of_int 0); - (aux_3, aux_2, aux_1, - aux_0) <@ _poly_getnoise_eta1_4x ((Array256.init (fun i_0 => skpv.[0 + i_0])), - (Array256.init (fun i_0 => skpv.[256 + i_0])), - (Array256.init (fun i_0 => skpv.[(2 * 256) + i_0])), - (Array256.init (fun i_0 => e.[0 + i_0])), noiseseed, nonce); - skpv <- Array768.init - (fun i_0 => if 0 <= i_0 < 0 + 256 then aux_3.[i_0-0] - else skpv.[i_0]); - skpv <- Array768.init - (fun i_0 => if 256 <= i_0 < 256 + 256 - then aux_2.[i_0-256] else skpv.[i_0]); - skpv <- Array768.init - (fun i_0 => if (2 * 256) <= i_0 < (2 * 256) + 256 - then aux_1.[i_0-(2 * 256)] else skpv.[i_0]); - e <- Array768.init - (fun i_0 => if 0 <= i_0 < 0 + 256 then aux_0.[i_0-0] - else e.[i_0]); - nonce <- (W8.of_int 4); - (aux_3, aux_2, aux_1, - aux_0) <@ _poly_getnoise_eta1_4x ((Array256.init (fun i_0 => e.[256 + i_0])), - (Array256.init (fun i_0 => e.[(2 * 256) + i_0])), - (Array256.init (fun i_0 => pkpv.[0 + i_0])), - (Array256.init (fun i_0 => pkpv.[256 + i_0])), noiseseed, - nonce); - e <- Array768.init - (fun i_0 => if 256 <= i_0 < 256 + 256 - then aux_3.[i_0-256] else e.[i_0]); - e <- Array768.init - (fun i_0 => if (2 * 256) <= i_0 < (2 * 256) + 256 - then aux_2.[i_0-(2 * 256)] else e.[i_0]); - return (skpv,e); - } +require import Bindings. - proc samplenoise_enc(sp_0 ep bp : W16.t Array768.t, epp : W16.t Array256.t, noiseseed:W8.t Array32.t) : W16.t Array768.t * W16.t Array768.t * W16.t Array768.t * W16.t Array256.t = { - var nonce : W8.t; - var aux_2, aux_1, aux_0, aux : W16.t Array256.t; - nonce <- (W8.of_int 0); - (aux_2, aux_1, aux_0, - aux) <@ _poly_getnoise_eta1_4x ((Array256.init (fun i_0 => sp_0.[0 + i_0])), - (Array256.init (fun i_0 => sp_0.[256 + i_0])), - (Array256.init (fun i_0 => sp_0.[(2 * 256) + i_0])), - (Array256.init (fun i_0 => ep.[0 + i_0])), noiseseed, - nonce); - sp_0 <- Array768.init - (fun i_0 => if 0 <= i_0 < 0 + 256 then aux_2.[i_0-0] - else sp_0.[i_0]); - sp_0 <- Array768.init - (fun i_0 => if 256 <= i_0 < 256 + 256 - then aux_1.[i_0-256] else sp_0.[i_0]); - sp_0 <- Array768.init - (fun i_0 => if (2 * 256) <= i_0 < (2 * 256) + 256 - then aux_0.[i_0-(2 * 256)] else sp_0.[i_0]); - ep <- Array768.init - (fun i_0 => if 0 <= i_0 < 0 + 256 then aux.[i_0-0] - else ep.[i_0]); - nonce <- (W8.of_int 4); - (aux_2, aux_1, aux_0, - aux) <@ _poly_getnoise_eta1_4x ((Array256.init (fun i_0 => ep.[256 + i_0])), - (Array256.init (fun i_0 => ep.[(2 * 256) + i_0])), epp, - (Array256.init (fun i_0 => bp.[0 + i_0])), noiseseed, - nonce); - ep <- Array768.init - (fun i_0 => if 256 <= i_0 < 256 + 256 - then aux_2.[i_0-256] else ep.[i_0]); - ep <- Array768.init - (fun i_0 => if (2 * 256) <= i_0 < (2 * 256) + 256 - then aux_1.[i_0-(2 * 256)] else ep.[i_0]); - epp <- aux_0; - bp <- Array768.init - (fun i_0 => if 0 <= i_0 < 0 + 256 then aux.[i_0-0] - else bp.[i_0]); - return (sp_0,ep,bp, epp); +op W32_sar_32 (w1 w2 : W32.t) : W32.t = + W32.sar w1 (to_uint w2). - } -}. +bind op [W32.t] W32_sar_32 "ashr". +realize bvashrP. +move=> bv1 bv2; rewrite W32_sar_div //. +smt(W32.to_uint_cmp). +qed. -(* int value of jth noise coeficient *) -op noise_coef (bytes: W8.t Array128.t) (j: int): int = - let b = bytes.[j%/2] in b2i b.[j%%2*4] + b2i b.[j%%2*4+1] - (b2i b.[j%%2*4+2] + b2i b.[j%%2*4+3]). +(* BINDINGS *) -import WArray128. +bind array Array256."_.[_]" Array256."_.[_<-_]" Array256.to_list Array256.of_list Array256.t 256. +realize tolistP by done. +realize get_setP by smt(Array256.get_setE). +realize eqP by smt(Array256.tP). +realize get_out by smt(Array256.get_out). -op B2Ri (bytes: W8.t Array128.t) (j: int): W256.t = - get256 (WArray128.init8 (fun i => bytes.[i])) j. -lemma bytes_getR (bytes: W8.t Array128.t) (k: int): - 0 <= k && k < 128 => - bytes.[k] = B2Ri bytes (k %/ 32) \bits8 (k %% 32). -proof. -move=> Hk; rewrite /B2Ri /get256_direct pack32bE 1:/# initiE 1:/# /=. -by rewrite mulrC -divz_eq /init8 initiE. -qed. +bind array Array768."_.[_]" Array768."_.[_<-_]" Array768.to_list Array768.of_list Array768.t 768. +realize tolistP by done. +realize get_setP by smt(Array768.get_setE). +realize eqP by smt(Array768.tP). +realize get_out by smt(Array768.get_out). -abbrev mask55u256 = VPBROADCAST_8u32 (W32.of_int 1431655765). -abbrev mask33u256 = VPBROADCAST_8u32 (W32.of_int 858993459). -abbrev mask03u256 = VPBROADCAST_8u32 (W32.of_int 50529027). -abbrev mask0Fu256 = VPBROADCAST_8u32 (W32.of_int 252645135). +bind array Array32."_.[_]" Array32."_.[_<-_]" Array32.to_list Array32.of_list Array32.t 32. +realize tolistP by done. +realize get_setP by smt(Array32.get_setE). +realize eqP by smt(Array32.tP). +realize get_out by smt(Array32.get_out). -abbrev mask55u16 = W16.of_int 21845. (* 21845 = 0x5555 *) -abbrev mask33u16 = W16.of_int 13107. (* 13107 = 0x3333 *) -abbrev mask03u16 = W16.of_int 771. (* 771 = 0x0303 *) -abbrev mask0Fu16 = W16.of_int 3855. (* 3855 = 0x0F0F *) +bind array Array960."_.[_]" Array960."_.[_<-_]" Array960.to_list Array960.of_list Array960.t 960. +realize tolistP by done. +realize get_setP by smt(Array960.get_setE). +realize eqP by smt(Array960.tP). +realize get_out by smt(Array960.get_out). -abbrev mask55u8 = W8.of_int 85. (* 85 = 0x55 *) -abbrev mask33u8 = W8.of_int 51. (* 51 = 0x33 *) -abbrev mask03u8 = W8.of_int 3. (* 3 = 0x03 *) -abbrev mask0Fu8 = W8.of_int 15. (* 15 = 0x0F *) +bind array Array1088."_.[_]" Array1088."_.[_<-_]" Array1088.to_list Array1088.of_list Array1088.t 1088. +realize tolistP by done. +realize get_setP by smt(Array1088.get_setE). +realize eqP by smt(Array1088.tP). +realize get_out by smt(Array1088.get_out). -lemma mask55_bits16 k: - 0 <= k < 16 => - mask55u256 \bits16 k = mask55u16. -proof. -move=> Hk. -rewrite /VPBROADCAST_8u32. -rewrite bits16_W8u32 Hk //= get_of_list 1:/# /=. -rewrite (nth_map 0) /=; first smt(size_iota). -have: (k%%2 \in iota_ 0 2) by smt(mem_iota). -by move: (k%%2); rewrite -allP -iotaredE /= W2u16.bits16_div //. -qed. +bind array Array4."_.[_]" Array4."_.[_<-_]" Array4.to_list Array4.of_list Array4.t 4. +realize tolistP by done. +realize get_setP by smt(Array4.get_setE). +realize eqP by smt(Array4.tP). +realize get_out by smt(Array4.get_out). -lemma mask55_bits8 k: - 0 <= k < 32 => - mask55u256 \bits8 k = mask55u8. -proof. -move=> Hk. -rewrite /VPBROADCAST_8u32. -rewrite bits8_W8u32 Hk //= get_of_list 1:/# /=. -rewrite (nth_map 0) /=; first smt(size_iota). -have: (k%%4 \in iota_ 0 4) by smt(mem_iota). -by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div //. -qed. -lemma mask33_bits16 k: - 0 <= k < 16 => - mask33u256 \bits16 k = mask33u16. -proof. -move=> Hk. -rewrite /VPBROADCAST_8u32. -rewrite bits16_W8u32 Hk //= get_of_list 1:/# /=. -rewrite (nth_map 0) /=; first smt(size_iota). -have: (k%%2 \in iota_ 0 2) by smt(mem_iota). -by move: (k%%2); rewrite -allP -iotaredE /= W2u16.bits16_div //. -qed. +op init_256_16 (f: int -> W16.t) : W16.t Array256.t = Array256.init f. -lemma mask33_bits8 k: - 0 <= k < 32 => - mask33u256 \bits8 k = mask33u8. +bind op [W16.t & Array256.t] init_256_16 "ainit". +realize bvainitP. proof. -move=> Hk. -rewrite /VPBROADCAST_8u32. -rewrite bits8_W8u32 Hk //= get_of_list 1:/# /=. -rewrite (nth_map 0) /=; first smt(size_iota). -have: (k%%4 \in iota_ 0 4) by smt(mem_iota). -by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div //. +rewrite /init_256_16 => f. +rewrite BVA_Top_Array256_Array256_t.tolistP. +apply eq_in_mkseq => i i_bnd; +smt(Array256.initE). qed. -lemma mask03_bits8 k: - 0 <= k < 32 => - mask03u256 \bits8 k = mask03u8. -proof. -move=> Hk. -rewrite /VPBROADCAST_8u32. -rewrite bits8_W8u32 Hk //= get_of_list 1:/# /=. -rewrite (nth_map 0) /=; first smt(size_iota). -have: (k%%4 \in iota_ 0 4) by smt(mem_iota). -by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div //. -qed. +op init_768_16 (f: int -> W16.t) : W16.t Array768.t = Array768.init f. -lemma mask0F_bits16 k: - 0 <= k < 16 => - mask0Fu256 \bits16 k = mask0Fu16. -proof. -move=> Hk. -rewrite /VPBROADCAST_8u32. -rewrite bits16_W8u32 Hk //= get_of_list 1:/# /=. -rewrite (nth_map 0) /=; first smt(size_iota). -have: (k%%2 \in iota_ 0 2) by smt(mem_iota). -by move: (k%%2); rewrite -allP -iotaredE /= W2u16.bits16_div //. -qed. +print Array768.initE. -lemma mask0F_bits8 k: - 0 <= k < 32 => - mask0Fu256 \bits8 k = mask0Fu8. -proof. -move=> Hk. -rewrite /VPBROADCAST_8u32. -rewrite bits8_W8u32 Hk //= get_of_list 1:/# /=. -rewrite (nth_map 0) /=; first smt(size_iota). -have: (k%%4 \in iota_ 0 4) by smt(mem_iota). -by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div //. +bind op [W16.t & Array768.t] init_768_16 "ainit". +realize bvainitP. +rewrite /init_768_16 => f. +rewrite BVA_Top_Array768_Array768_t.tolistP. +apply eq_in_mkseq => i i_bnd; smt(Array768.initE). qed. -lemma VPSRL1_ANDmask55 w k: - 0 <= k < 32 => - mask55u256 `&` (VPSRL_16u16 w (W8.of_int 1)) \bits8 k - = mask55u8 `&` ((w \bits8 k) `>>` (W8.of_int 1)). -proof. -move=> Hk. -rewrite {1}(_:k=2*(k%/2) + (k%%2)); first smt(divz_eq). -rewrite -W256_bits16_bits8 1:/# andb16E /VPSRL_16u16 mapbE 1:/# /=. -rewrite W256_bits16_bits8 1:/# mask55_bits8 1:/#. -apply W8extra.wordP_red. rewrite -allP /=. -have: (k\in iota_ 0 32) by smt(mem_iota). -by move: {Hk} k; rewrite -allP -iotaredE /= !W16.shrwE !W8.shrwE /int_bit /=. -qed. +op init_4_64 (f: int -> W64.t) : W64.t Array4.t = Array4.init f. -lemma VPSRL2_ANDmask33 w k: - 0 <= k < 32 => - mask33u256 `&` (VPSRL_16u16 w (W8.of_int 2)) \bits8 k - = mask33u8 `&` ((w \bits8 k) `>>` (W8.of_int 2)). +bind op [W64.t & Array4.t] init_4_64 "ainit". +realize bvainitP. proof. -move=> Hk. -rewrite {1}(_:k=2*(k%/2) + (k%%2)); first smt(divz_eq). -rewrite -W256_bits16_bits8 1:/# andb16E /VPSRL_16u16 mapbE 1:/# /=. -rewrite W256_bits16_bits8 1:/# mask33_bits8 1:/#. -apply W8extra.wordP_red. rewrite -allP /=. -have: (k\in iota_ 0 32) by smt(mem_iota). -by move: {Hk} k; rewrite -allP -iotaredE /= !W16.shrwE !W8.shrwE /int_bit /=. +rewrite /init_4_64 => f. +rewrite BVA_Top_Array4_Array4_t.tolistP. +apply eq_in_mkseq => i i_bnd; smt(Array4.initE). qed. -lemma VPSRL4_ANDmask0F w k: - 0 <= k < 32 => - VPAND_256 mask0Fu256 (VPSRL_16u16 w (W8.of_int 4)) \bits8 k - = mask0Fu8 `&` ((w \bits8 k) `>>` (W8.of_int 4)). +op init_960_8 (f: int -> W8.t) : W8.t Array960.t = Array960.init f. + +bind op [W8.t & Array960.t] init_960_8 "ainit". +realize bvainitP. proof. -move=> Hk. -rewrite {1}(_:k=2*(k%/2) + (k%%2)); first smt(divz_eq). -rewrite -W256_bits16_bits8 1:/# andb16E /VPSRL_16u16 mapbE 1:/# /=. -rewrite W256_bits16_bits8 1:/# mask0F_bits8 1:/#. -apply W8extra.wordP_red. rewrite -allP /=. -have: (k\in iota_ 0 32) by smt(mem_iota). -by move: {Hk} k; rewrite -allP -iotaredE /= !W16.shrwE !W8.shrwE /int_bit /=. +rewrite /init_960_8 => f. +rewrite BVA_Top_Array960_Array960_t.tolistP. +apply eq_in_mkseq => i i_bnd; smt(Array960.initE). qed. -lemma to_uint_mask33 (w:W8.t): - to_uint (mask33u8 `&` w) - = to_uint w %% 4 + to_uint w %/ 16 %% 4 * 16. +op init_1088_8 (f: int -> W8.t) : W8.t Array1088.t = Array1088.init f. + +bind op [W8.t & Array1088.t] init_1088_8 "ainit". +realize bvainitP. proof. -have ->: mask33u8 = (mask03u8 `<<<` 4) `|` mask03u8. - apply W8.wordP => k; rewrite -mem_range /range /=. - by move: k; apply/List.allP; rewrite -iotaredE /int_bit /=. -rewrite andwC andw_orwDr orw_disjoint. - apply W8.wordP => k; rewrite -mem_range /range /=. - by move: k; apply/List.allP; rewrite -iotaredE /int_bit /=. -have ->: w `&` (mask03u8 `<<<` 4) - = ((w `>>>` 4) `&` W8.masklsb (6-4)) `<<<` 4. -rewrite -shlw_andmask // shrl_andmaskN // -andwA /=. -congr. -rewrite /max /=. - apply W8.wordP => k; rewrite -mem_range /range /=. - by move: k; apply/List.allP; rewrite -iotaredE /int_bit /=. -have E1: to_uint (w `&` mask03u8) = to_uint w %% 4. - by rewrite (W8.to_uint_and_mod 2) //. -have /= E2: to_uint ((w `>>>` 4) `&` (masklsb (6-4))%W8 `<<<` 4) = to_uint w %/ 16 %% 4 * 16. - rewrite /max /= to_uint_shl // (W8.to_uint_and_mod 2) //. - by rewrite to_uint_shr //= modz_small /#. -rewrite to_uintD_small /=. - by rewrite E1 E2 /#. -by rewrite E1 E2 /#. +rewrite /init_1088_8 => f. +rewrite BVA_Top_Array1088_Array1088_t.tolistP. +apply eq_in_mkseq => i i_bnd; smt(Array1088.initE). qed. -lemma aux_coef_pos b: - to_uint (mask33u8 `&` (mask55u8 `&` b + mask55u8 `&` (b `>>` ru_ones_s))) - = b2i b.[0] + b2i b.[1] + 16 * (b2i b.[4] + b2i b.[5]). +op sliceget256_16_256 (arr: W16.t Array256.t) (offset: int) : W256.t = + if 8 %| offset then + get256_direct ((init16 (fun (i_0 : int) => arr.[i_0])))%WArray512 (offset %/ 8) + else W256.bits2w (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))). + + +(* +lemma flatten_take_drop_16 (l : W16.t list) (csize offset bit : int) : + 0 <= offset => + offset + csize <= 16 * size l => + 0 <= bit < csize => + nth false (take csize (drop offset (flatten (map W16.w2bits l)))) bit = + (nth witness l ((offset + bit) %/ 16)).[(offset + bit) %% 16]. proof. -rewrite addrC -(mask85_sum b 0) // -(mask85_sum b 2) //= !(W8.andwC mask55u8). -by rewrite to_uint_mask33 /(`>>`) to_uint_shr //= to_uint_shr //= /#. +move => *. +rewrite nth_take 1,2:/#. +rewrite nth_drop 1,2:/#. +rewrite (BitEncoding.BitChunking.nth_flatten false 16). ++ rewrite allP => i; rewrite mapP => He;elim He;smt(W16.size_w2bits). +rewrite -get_w2bits;congr. +by rewrite (nth_map witness) 1:/#. qed. -lemma aux_coef_neg b: - to_uint (mask33u8 `&` ((mask55u8 `&` b + mask55u8 `&` (b `>>` ru_ones_s)) `>>` W8.of_int 2)) - = b2i b.[2] + b2i b.[3] + 16 * (b2i b.[6] + b2i b.[7]). -proof. -rewrite to_uint_mask33 to_uint_shr // -divz_mul //= !(W8.andwC mask55u8). -rewrite {1}(_:4=2^2) // (_:64=2^6) // -(mask85_sum b 1) // -(mask85_sum b 3) //=. -by rewrite to_uint_shr //= to_uint_shr //= /#. + +lemma aligned_get256_16_256 arr offset : +0 <= offset <= 16*256 - 256 => +256 %| offset => +sliceget256_16_256 arr offset = +WArray512.get256 (WArray512.init16 (fun (i_0 : int) => arr.[i_0])) (offset %/ 256). +move => Ho1 Ho2; rewrite /sliceget256_16_256. +have sz : size (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))) = 256 by rewrite size_take 1:/# size_drop 1:/# /max /=;smt(Array256.size_to_list size_flatten_W16_w2bits). +rewrite wordP => i ib; rewrite get_bits2w //. +rewrite flatten_take_drop_16;1..3:smt(Array256.size_to_list). +rewrite nth_mkseq 1:/# /=. +rewrite /get256_direct /pack32_t initiE 1:/# /= initiE 1:/# /= initiE 1:/# /=. +rewrite get_bits8 1:/#. +smt(@IntDiv). qed. -lemma noise_coef_avx2_aux bytes j: - 3 + noise_coef bytes j - = let b = bytes.[j%/2] in - let x = mask55u8 `&` b + mask55u8 `&` (b `>>` W8.one) in - let y = mask33u8 `&` x + mask33u8 - mask33u8 `&` (x `>>` (W8.of_int 2))in - to_uint y %/ 2^(j%%2*4) %% 16. -proof. -have LL1: forall (x y z:int), (x + z*y) %% z = x %% z. - by move=> x1 x2 x3; rewrite -modzDm modzMr /= modz_mod. -have LL2: forall (x y z:int), (x - z*y) %% z = x %% z. - by move=> x1 x2 x3; rewrite -modzDm -modzNm modzMr /= modz_mod. -move=> /=. -pose b:= bytes.[j %/ 2]. -pose x:= mask55u8 `&` b + mask55u8 `&` (b `>>` ru_ones_s). -case: (j %% 2 = 0) => C. - rewrite C /=. - rewrite -addrA to_uintD /= modz_dvd 1:/#. - rewrite aux_coef_pos W8.to_uintB. - by rewrite ule_andw. - rewrite -modzDm LL1 modzDm aux_coef_neg. - rewrite Ring.IntID.opprD !addzA LL2 /=. - by rewrite -modzDml -(modzDmr _ 51) /= modzDml modz_small /#. -have ->/=: j%%2 = 1 by smt(). -rewrite -addrA to_uintD. -rewrite (_:16=2^4) // modz_pow_div //= modz_mod. -rewrite aux_coef_pos W8.to_uintB. - by rewrite ule_andw. -rewrite aux_coef_neg /= (divz_eq 51 16). -pose X:= (b2i _ + _ + _ + _)%W8. -have /=->: X - = b2i b.[0] + b2i b.[1] + (51 %% 16) - (b2i b.[2] + b2i b.[3]) - + 16 * (b2i b.[4] + b2i b.[5] + (51 %/ 16) - (b2i b.[6] + b2i b.[7])). - by rewrite /X /=; ring. -by rewrite mulzC divzMDr // divz_small //= /#. +*) +bind op [W16.t & W256.t & Array256.t] sliceget256_16_256 "asliceget". +realize bvaslicegetP. +move => /= arr offset; rewrite /sliceget256_16_256 /= => H k kb. +case (8%| offset) => /= *; last by smt(W256.get_bits2w). +rewrite /get256_direct pack32E initiE 1:/# /= initiE 1:/# /= initiE 1:/# /= bits8E initiE 1:/# /=. +rewrite nth_take 1,2:/# nth_drop 1,2:/#. +rewrite (BitEncoding.BitChunking.nth_flatten false 16 _). ++ rewrite allP => x /=; rewrite mapP => He; elim He;smt(W16.size_w2bits). +rewrite (nth_map W16.zero []); 1: smt(Array256.size_to_list). +by rewrite nth_mkseq /#. qed. -lemma noise_coef_avx2 bytes j: - noise_coef bytes j - = let b = bytes.[j%/2] in - let x = mask55u8 `&` b + mask55u8 `&` (b `>>` W8.one) in - let y = mask33u8 `&` x + mask33u8 - mask33u8 `&` (x `>>` (W8.of_int 2)) in - if j%%2 = 0 - then to_sint (mask0Fu8 `&` y - mask03u8) - else to_sint (mask0Fu8 `&` (y `>>` (W8.of_int 4)) - mask03u8). -proof. -have L1: forall x, W8.to_uint x < 128 => W8.to_sint x = to_uint x. - by move=> x; rewrite to_sintE /smod /= /#. -rewrite /noise_coef /=. -pose b:= bytes.[j %/ 2]. -pose x:= b `&` mask55u8 + (b `>>` ru_ones_s) `&` mask55u8. -pose y:= x `&` mask33u8 + mask33u8 - (x `>>` (W8.of_int 2)) `&` mask33u8. -case: (j %% 2 = 0) => C. - rewrite C /= andwC W8_to_sintB_small. - by rewrite !to_sintE (W8.to_uint_and_mod 4) /smod //= /#. - rewrite L1 (W8.to_uint_and_mod 4) //= /smod /= 1:/#. - move: (noise_coef_avx2_aux bytes j) => /=. - by rewrite C to_sintE /smod => <- /#. -have C': j %% 2 = 1 by smt(). -rewrite C' /= andwC W8_to_sintB_small. - by rewrite !to_sintE (W8.to_uint_and_mod 4) /smod //= /#. -rewrite L1 (W8.to_uint_and_mod 4) //= /smod /= 1:/#. -move: (noise_coef_avx2_aux bytes j) => /=. -by rewrite /noise_coef C' to_sintE /smod to_uint_shr //= => <- /#. +import BitEncoding BS2Int BitChunking. + +op sliceset256_16_256 (arr: W16.t Array256.t) (offset: int) (bv: W256.t) : W16.t Array256.t = + if 8 %| offset + then (init (fun (i3 : int) => get16 (set256_direct ((init16 (fun (i_0 : int) => arr.[i_0])))%WArray512 (offset %/ 8) bv) i3))%Array256 + else Array256.of_list witness (map W16.bits2w (chunk 16 (take offset (flatten (map W16.w2bits (to_list arr))) ++ w2bits bv ++ + drop (offset + 256) (flatten (map W16.w2bits (to_list arr)))))). + +(* +lemma aligned_set256_16_256 arr offset bv : +0 <= offset <= 16*256 - 256 => +256 %| offset => +sliceset256_16_256 arr offset bv = +Array256.init (fun (i3 : int) => get16 (set256 ((init16 (fun (i_0 : int) => arr.[i_0])))%WArray512 (offset %/ 256) bv) i3). +rewrite /sliceset256_16_256 tP /= => ?? i ib. +rewrite !initiE 1,2:/# /=. +rewrite get16_set256E 1,2:/# /= (nth_map []). ++ rewrite size_chunk // !size_cat !size_take 1:/# !size_drop 1:/# /max /=. + by smt(Array256.size_to_list size_flatten_W16_w2bits). +rewrite JWordList.nth_chunk //= 1:/#. +rewrite !size_cat !size_take 1:/# !size_drop 1:/# /max /=. + by smt(Array256.size_to_list size_flatten_W16_w2bits). +case (32 * (offset %/ 256) <= 2 * i);last first. ++ move => ? /=. have ? : 16*i < offset. smt(). + rewrite get16_init16 1:/# -catA drop_cat ifT;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite take_cat_le ifT;1: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + have -> : offset = 16 * (offset %/ 16) by smt(). + rewrite take_flatten_ctt; 1: by smt(mapP W16.size_w2bits). + rewrite -map_take. + rewrite -(W16.w2bitsK arr.[i]);congr. + apply (eq_from_nth false). + + rewrite size_w2bits size_take // size_drop 1:/# /= /max /=;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + move => k kb; rewrite flatten_take_drop_16 1:/#. + + rewrite size_take 1:/# size_to_list //= 1:/#. + by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite nth_take 1:/#. smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite get_w2bits;congr; rewrite ?get_to_list;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). +case (2 * i < 32 * (offset %/ 256 + 1));last first. ++ move => ? /=. have ? : offset + 256 <= 16*i . smt(). + rewrite get16_init16 1:/# -catA drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite size_take 1:/# size_flatten_W16_w2bits size_to_list /= ifT 1:/#. + have -> : offset + 256 = 16 * ((offset + 256) %/ 16) by smt(). + rewrite drop_flatten_ctt; 1: by smt(mapP W16.size_w2bits). + have -> : 16 * i - offset - 256 = 16 * (i - offset %/ 16 - 16) by smt(). + rewrite drop_flatten_ctt; 1: by smt(mapP W16.size_w2bits mem_drop). + rewrite drop_drop 1,2:/# /= => ?. + rewrite -(W16.w2bitsK arr.[i]);congr. + apply (eq_from_nth false). + + rewrite -map_drop size_take // size_flatten_W16_w2bits size_drop 1:/#; smt(Array256.size_to_list W16.size_w2bits). + move => k kb. + have -> : i - offset %/ 16 - 16 + (offset + 256) %/ 16 = i by smt(). + rewrite -(drop_flatten_ctt 16); 1: smt(mapP W16.size_w2bits). + rewrite flatten_take_drop_16; 1..3: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite get_w2bits;congr; rewrite ?get_to_list;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + ++ move => ?? /=. have ? : offset <= 16*i < offset + 256. smt(). + rewrite -!catA drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite !drop_cat ifT;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite size_take 1:/# size_flatten_W16_w2bits size_to_list /= ifT 1:/#. + rewrite take_cat_le ifT;1: by rewrite size_drop 1:/# size_w2bits /= /max ifT /#. + rewrite -(W16.w2bitsK ((bv \bits16 i - 16 * (offset %/ 256))));congr. + apply (eq_from_nth false). + + rewrite size_take // size_drop 1:/#; smt(Array256.size_to_list W16.size_w2bits). + move => k kb. + rewrite nth_take; 1,2: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite nth_drop; 1,2: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite !get_w2bits get_bits16;by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). qed. +*) -lemma to_sint8_mod x: - W8.to_sint x %% W8.modulus = to_uint x. +lemma size_flatten_W16_w2bits (a : W16.t list) : + (size (flatten (map W16.w2bits (a)))) = 16 * size a. proof. -rewrite /to_sint /smod. -case: (2 ^ (8 - 1) <= to_uint x) => C. - rewrite -modzDm -modzNm modzz /= modz_mod. - rewrite modz_small //. - by apply JUtils.bound_abs; apply W8.to_uint_cmp. -rewrite modz_small //. -by apply JUtils.bound_abs; apply W8.to_uint_cmp. + rewrite size_flatten -map_comp /(\o) /=. + rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + rewrite StdBigop.Bigint.big_constz count_predT /#. qed. -lemma to_sint8K (x:W8.t): W8.of_int (to_sint x) = x. -proof. by rewrite -of_int_mod to_sint8_mod to_uintK. qed. +bind op [W16.t & W256.t & Array256.t] sliceset256_16_256 "asliceset". +realize bvaslicesetP. +move => arr offset bv H /= k kb; rewrite /sliceset256_16_256 /=. +case (8 %| offset) => /= *; last first. ++ rewrite of_listK; 1: by rewrite size_map size_chunk // !size_cat size_take; + by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + rewrite -(map_comp W16.w2bits W16.bits2w) /(\o). + have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W16)) idfun (chunk 16 + (take offset (flatten (map W16.w2bits (to_list arr))) ++ w2bits bv ++ + drop (offset + 256) (flatten (map W16.w2bits (to_list arr))))). + rewrite iffE => [#] -> * /=; 1: by smt(in_chunk_size W16.bits2wK). + rewrite map_id /= chunkK //;1: by rewrite !size_cat size_take; + by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + by rewrite !nth_cat !size_cat /=; + smt(nth_take nth_drop size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). +rewrite (nth_flatten _ 16); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W16.size_w2bits). +rewrite (nth_map W16.zero []); 1: smt(Array256.size_to_list). +rewrite nth_mkseq 1:/# /= initiE 1:/# /= get16E pack2E initiE 1:/# /= initiE 1:/# /= /set256_direct. +rewrite initiE 1:/# /=. +case (offset <= k && k < offset + 256) => *; 1: by + rewrite ifT 1:/# get_bits8 /= 1,2:/# initiE // initiE //. +rewrite ifF 1:/# initiE 1:/# /=. +rewrite (nth_flatten _ 16); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W16.size_w2bits). +rewrite (nth_map W16.zero []); 1: smt(Array256.size_to_list). +rewrite nth_mkseq 1:/# /= bits8E /= initiE /# /=. +qed. -lemma truncateu128_bits128 (w:W256.t): - truncateu128 w = w \bits128 0. -proof. by rewrite /truncateu128 to_uint_eq of_uintK bits128_div // of_uintK. qed. +op sliceget32_8_256 (arr: W8.t Array32.t) (offset: int) : W256.t = +if 8 %| offset then + get256_direct (WArray32.init8 (fun (i_0 : int) => arr.[i_0])) (offset %/ 8) + else W256.bits2w (take 256 (drop offset (flatten (map W8.w2bits (to_list arr))))). -hoare cbd2_avx2_h _bytes: - Jkem_avx2.M(Jkem_avx2.Syscall).__cbd2: buf=_bytes ==> res = Array256.init (fun k => W16.of_int (noise_coef _bytes k)). +bind op [W8.t & W256.t & Array32.t] sliceget32_8_256 "asliceget". +realize bvaslicegetP. +move => /= arr offset; rewrite /sliceget32_8_256 /= => H k kb. +case (8%| offset) => /= *; last by smt(W256.get_bits2w). +rewrite /get256_direct pack32E initiE 1:/# /= initiE 1:/# /= initiE 1:/# /=. +rewrite nth_take 1,2:/# nth_drop 1,2:/#. +rewrite (BitEncoding.BitChunking.nth_flatten false 8 _). ++ rewrite allP => x /=; rewrite mapP => He; elim He;smt(W8.size_w2bits). +rewrite (nth_map W8.zero []); 1: smt(Array32.size_to_list). +by rewrite nth_mkseq /#. +qed. + +op sliceget768_16_256 (arr: W16.t Array768.t) (offset: int) : W256.t = +if 8 %| offset then + get256_direct (WArray1536.init16 (fun (i_0 : int) => arr.[i_0])) (offset %/ 8) + else W256.bits2w (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))). + + +bind op [W16.t & W256.t & Array768.t] sliceget768_16_256 "asliceget". +realize bvaslicegetP. +move => /= arr offset; rewrite /sliceget768_16_256 /= => H k kb. +case (8%| offset) => /= *; last by smt(W256.get_bits2w). +rewrite /get256_direct pack32E initiE 1:/# /= initiE 1:/# /= initiE 1:/# /= bits8E initiE 1:/# /=. +rewrite nth_take 1,2:/# nth_drop 1,2:/#. +rewrite (BitEncoding.BitChunking.nth_flatten false 16 _). ++ rewrite allP => x /=; rewrite mapP => He; elim He;smt(W16.size_w2bits). +rewrite (nth_map W16.zero []); 1: smt(Array768.size_to_list). +by rewrite nth_mkseq /#. +qed. + +op sliceset960_8_128 (arr: W8.t Array960.t) (offset: int) (bv: W128.t) : W8.t Array960.t = + if 8 %| offset + then Array960.init (fun (i3 : int) => get8 (set128_direct ((init8 (fun (i_0 : int) => arr.[i_0])))%WArray960 (offset %/ 8) bv) i3) + else Array960.of_list witness (map W8.bits2w (chunk 8 (take offset (flatten (map W8.w2bits (to_list arr))) ++ w2bits bv ++ + drop (offset + 128) (flatten (map W8.w2bits (to_list arr)))))). + +lemma size_flatten_W8_w2bits (a : W8.t list) : + (size (flatten (map W8.w2bits (a)))) = 8 * size a. proof. -proc. -sp; simplify. -while (0 <= i <= 4 /\ #{~i}pre /\ List.all (fun k => rp.[k]=W16.of_int (noise_coef _bytes k)) (iota_ 0 (64*i))). - seq 15: (#pre /\ - all (fun k=> if k%%2 = 0 - then to_sint (f0 \bits8 (k %/ 2)) = noise_coef _bytes (64*i+k) - else to_sint (f1 \bits8 (k %/ 2)) = noise_coef _bytes (64*i+k)) - (iota_ 0 64)). - auto => &m |> ?_ /List.allP H ?; apply/List.allP => k; rewrite mem_iota /= => *. - case: (k%%2=0) => C1. - move: (noise_coef_avx2 buf{m} (64*i{m}+k)). - have ->: (64 * i{m} + k) %% 2 = 0 by smt(). - rewrite /= => ->. - have:= (bytes_getR buf{m} ((64*i{m}+k)%/2) _); first smt(). - rewrite /B2Ri /= -!divz_mul //=. - have ->: (64 * i{m} + k) %/ 64 = i{m}. - by rewrite (mulzC 64) divzMDl // (divz_small _ 64) 1:/# /=. - have ->: (64 * i{m} + k) %/ 2 %% 32 = k %/ 2. - rewrite -(modz_pow_div 2 6 1) //=. - by rewrite (mulzC 64) modzMDl modz_small /#. - move => Eb. - rewrite map2bE 1:/# /= mask0F_bits8 1:/# /=. - rewrite map2bE 1:/#. beta. - rewrite VPSRL2_ANDmask33 1:/#. - rewrite map2bE 1:/#; beta. - rewrite map2bE 1:/#; beta. - rewrite VPSRL1_ANDmask55 1:/#. - rewrite mask33_bits8 1:/# /=. - rewrite map2bE 1:/#; beta. - rewrite VPSRL1_ANDmask55 1:/#. - rewrite mask33_bits8 1:/# /=. - rewrite mask55_bits8 1:/# /=. - rewrite mask03_bits8 1:/# -!Eb. - by congr. - have C2: k %% 2 = 1 by smt(). - move: (noise_coef_avx2 buf{m} (64*i{m}+k)). - have ->: (64 * i{m} + k) %% 2 = 1 by smt(). - rewrite /= => ->. - have:= (bytes_getR buf{m} ((64*i{m}+k)%/2) _); first smt(). - rewrite /B2Ri /= -!divz_mul //=. - have ->: (64 * i{m} + k) %/ 64 = i{m}. - by rewrite (mulzC 64) divzMDl // (divz_small _ 64) 1:/# /=. - have ->: (64 * i{m} + k) %/ 2 %% 32 = k %/ 2. - rewrite -(modz_pow_div 2 6 1) //=. - by rewrite (mulzC 64) modzMDl modz_small /#. - move => Eb. - rewrite map2bE 1:/#; beta. - rewrite VPSRL4_ANDmask0F 1:/#. - rewrite map2bE 1:/#; beta. - rewrite map2bE 1:/#; beta. - rewrite VPSRL2_ANDmask33 1:/#. - rewrite map2bE 1:/#; beta. - rewrite VPSRL1_ANDmask55 1:/#. - rewrite mask33_bits8 1:/# /=. - rewrite map2bE 1:/#; beta. - rewrite VPSRL1_ANDmask55 1:/#. - rewrite mask33_bits8 1:/# /=. - rewrite mask55_bits8 1:/# /=. - rewrite mask03_bits8 1:/# -!Eb. - by congr. - seq 10: (#[/:-2]pre /\ - all (fun (k : int) => - if k %/ 16 = 0 - then f0 \bits16 k%%16 = W16.of_int (noise_coef _bytes (64*i+k)) - else if k %/ 16 = 1 - then f2 \bits16 k%%16 = W16.of_int (noise_coef _bytes (64*i+k)) - else if k %/ 16 = 2 - then f1 \bits16 k%%16 = W16.of_int (noise_coef _bytes (64*i+k)) - else f3 \bits16 k%%16 = W16.of_int (noise_coef _bytes (64*i+k))) - (iota_ 0 64)). - auto => &m |> ?_ /List.allP IH ?. - rewrite -{1}iotaredE /= => |> *. - rewrite -iotaredE /=. - do 32! (split; first by - rewrite /VPMOVSX_16u8_16u16 /VPUNPCKL_32u8 /VPUNPCKL_16u8 /VPUNPCKH_32u8 /VPUNPCKH_16u8 /MOVSX_u16s8 truncateu128_bits128 /interleave_gen /get_lo_2u64 /get_hi_2u64 /= /#). - do 31! (split; first by - rewrite /VPMOVSX_16u8_16u16 /VPUNPCKL_32u8 /VPUNPCKL_16u8 /VPUNPCKH_32u8 /VPUNPCKH_16u8 /MOVSX_u16s8 /VEXTRACTI128 /interleave_gen /get_lo_2u64 /get_hi_2u64 /b2i /= /int_bit /= /#). - by rewrite /VPMOVSX_16u8_16u16 /VPUNPCKL_32u8 /VPUNPCKL_16u8 /VPUNPCKH_32u8 /VPUNPCKH_16u8 /MOVSX_u16s8 /VEXTRACTI128 /interleave_gen /get_lo_2u64 /get_hi_2u64 /b2i /= /int_bit /= /#. - auto => |> &m ? _ /List.allP IH ? /List.allP H. - split; first smt(). - rewrite -!NTT_AVX_Fq.PURE 1..4:/#. - apply/List.allP => k; rewrite mem_iota /= => |> *. - rewrite !NTT_AVX_Fq.PUR_get 1..8:/#. - case: (k %/ 16 = 4 * i{m} + 3) => C1. - move: (H (k %% 64) _) => /=; first smt(mem_iota). - rewrite (modz_pow_div 2 6 4) //= C1 (mulzC 4) modzMDl /=. - rewrite (modz_dvd_pow 4 6 _ 2) //. - have ->: 64 * i{m} + k %% 64 = k by smt(). - by rewrite /R2C /= Array16.initiE /#. - case: (k %/ 16 = 4 * i{m} + 2) => C2. - move: (H (k %% 64) _) => /=; first smt(mem_iota). - rewrite (modz_pow_div 2 6 4) //= C2 (mulzC 4) modzMDl /=. - rewrite (modz_dvd_pow 4 6 _ 2) //. - have ->: 64 * i{m} + k %% 64 = k by smt(). - by rewrite /R2C /= Array16.initiE /#. - case: (k %/ 16 = 4 * i{m} + 1) => C3. - move: (H (k %% 64) _) => /=; first smt(mem_iota). - rewrite (modz_pow_div 2 6 4) //= C3 (mulzC 4) modzMDl /=. - rewrite (modz_dvd_pow 4 6 _ 2) //. - have ->: 64 * i{m} + k %% 64 = k by smt(). - by rewrite /R2C /= Array16.initiE /#. - case: (k %/ 16 = 4 * i{m}) => C4. - move: (H (k %% 64) _) => /=; first smt(mem_iota). - rewrite (modz_pow_div 2 6 4) //= C4 modzMr. - rewrite (modz_dvd_pow 4 6 _ 2) //. - have ->: 64 * i{m} + k %% 64 = k by smt(). - by rewrite /R2C /= Array16.initiE /#. - have ?: k < 64*i{m} by smt(). - by move: (IH k _) => /=; first smt(mem_iota). -auto => &m |> *. -split; first by rewrite iota0. -move => i rp ???; rewrite (_:i=4) 1:/# /=. -move => /List.allP H. -rewrite tP => k Hk; rewrite (H k _); first smt(mem_iota). -by rewrite initiE /#. + rewrite size_flatten -map_comp /(\o) /=. + rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. + rewrite StdBigop.Bigint.big_constz count_predT /#. qed. -lemma cbd2_ll : islossless Jkem_avx2.M(Jkem_avx2.Syscall).__cbd2. -proc. inline *. sp; wp. while (true) (4-i). move => z. -auto => /> &hr H. smt(). -auto => />i. smt(). qed. - -phoare cbd2_avx2_ph _bytes: - [Jkem_avx2.M(Jkem_avx2.Syscall).__cbd2: buf=_bytes ==> res = Array256.init (fun k => W16.of_int (noise_coef _bytes k))] = 1%r. -conseq cbd2_ll (cbd2_avx2_h _bytes) => />. qed. +bind op [W8.t & W128.t & Array960.t] sliceset960_8_128 "asliceset". +realize bvaslicesetP. +move => arr offset bv H /= k kb; rewrite /sliceset960_8_128 /=. +case (8 %| offset) => /= *; last first. ++ rewrite of_listK; 1: by rewrite size_map size_chunk // !size_cat size_take; + by smt(size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0). + rewrite -(map_comp W8.w2bits W8.bits2w) /(\o). + have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W8)) idfun (chunk 8 + (take offset (flatten (map W8.w2bits (to_list arr))) ++ w2bits bv ++ + drop (offset + 128) (flatten (map W8.w2bits (to_list arr))))). + rewrite iffE => [#] -> * /=; 1: by smt(in_chunk_size W8.bits2wK). + rewrite map_id /= chunkK //;1: by rewrite !size_cat size_take; + by smt(size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0). + by rewrite !nth_cat !size_cat /=; + smt(nth_take nth_drop size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0). +rewrite (nth_flatten _ 8); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W8.size_w2bits). +rewrite (nth_map W8.zero []); 1: smt(Array960.size_to_list). +rewrite nth_mkseq 1:/# /= initiE 1:/# /= /get8 /set128_direct. +rewrite initiE 1:/# /=. +case (offset <= k && k < offset + 128) => *; 1: by + rewrite ifT 1:/# get_bits8 /= 1,2:/# initiE // initiE //. +rewrite ifF 1:/# initiE 1:/# /=. +rewrite (nth_flatten _ 8); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W8.size_w2bits). +rewrite (nth_map W8.zero []); 1: smt(Array960.size_to_list). +rewrite nth_mkseq /#. +qed. -module AuxMLKEMAvx2 = { - proc cbd2_ref (rp:W16.t Array256.t,buf:W8.t Array128.t) : W16.t Array256.t = { - var k: int; - var a, b, c: W8.t; - var i: W64.t; - var t: W16.t; - i <- (W64.of_int 0); - while ((i \ult (W64.of_int 128))) { - c <- buf.[(W64.to_uint i)]; - a <- c; - a <- (a `&` (W8.of_int 85)); - c <- (c `>>` (W8.of_int 1)); - c <- (c `&` (W8.of_int 85)); - c <- (c + a); - a <- c; - a <- (a `&` (W8.of_int 3)); - b <- c; - b <- (b `>>` (W8.of_int 2)); - b <- (b `&` (W8.of_int 3)); - a <- (a - b); - t <- (sigextu16 a); - rp.[W64.to_uint (W64.of_int 2 * i)] <- t; - a <- c; - a <- (a `>>` (W8.of_int 4)); - a <- (a `&` (W8.of_int 3)); - b <- (c `>>` (W8.of_int 6)); - b <- (b `&` (W8.of_int 3)); - a <- (a - b); - t <- (sigextu16 a); - rp.[W64.to_uint (W64.of_int 2 * i + W64.one)] <- t; - i <- (i + (W64.of_int 1)); - } - return (rp); - } - proc _poly_getnoise (rp:W16.t Array256.t, seed:W8.t Array32.t,nonce:W8.t) : W16.t Array256.t = { - var buf:W8.t Array128.t; - var r; +op sliceset960_8_32 (arr: W8.t Array960.t) (offset: int) (bv: W32.t) : W8.t Array960.t = + if 8 %| offset + then Array960.init + (WArray960.get8 + (set32_direct (WArray960.init8 (fun (i_0 : int) => arr.[i_0])) ( + offset %/ 8) bv)) + else Array960.of_list witness (map W8.bits2w (chunk 8 (take offset (flatten (map W8.w2bits (to_list arr))) ++ w2bits bv ++ + drop (offset + 32) (flatten (map W8.w2bits (to_list arr)))))). - buf <- witness; - buf <@ M(Syscall)._shake256_128_33 (buf,Array33.init (fun i => if i=32 then nonce else seed.[i])); - r <@ cbd2_ref(rp,buf); - return r; - } - proc __poly_getnoise_eta1_4x(aux3 aux2 aux1 aux0 : W16.t Array256.t, - noiseseed : W8.t Array32.t, - nonce : W8.t) : - W16.t Array256.t * W16.t Array256.t * W16.t Array256.t * W16.t Array256.t = { - var n3, n2, n1, n0 : W8.t; - var aux_3, aux_2, aux_1, aux_0 : W16.t Array256.t; - n0 <- nonce + W8.of_int 3; - n1 <- nonce + W8.of_int 2; - n2 <- nonce + W8.of_int 1; - n3 <- nonce; - aux_3 <@ _poly_getnoise(aux3,noiseseed,n3); - aux_2 <@ _poly_getnoise(aux2,noiseseed,n2); - aux_1 <@ _poly_getnoise(aux1,noiseseed,n1); - aux_0 <@ _poly_getnoise(aux0,noiseseed,n0); - return (aux_3, aux_2, aux_1, aux_0); - } -}. -hoare cbd2_ref_h _bytes: - AuxMLKEMAvx2.cbd2_ref: buf=_bytes ==> res = Array256.init (fun k => W16.of_int (noise_coef _bytes k)). -proof. -proc. -while (to_uint i <= 128 /\ #pre /\ List.all (fun k => rp.[k]=W16.of_int (noise_coef _bytes k)) (iota_ 0 (2 * to_uint i))). - auto => &m |>; rewrite /(\ult) => _ /List.allP IH /= Hi. - rewrite to_uintD_small /= 1:/#. - split; first smt(). - apply/List.allP => k; rewrite mem_iota /=; move => [? Hk]. - rewrite to_uintD_small !to_uintM_small /= 1..3:/#. - case: (k = 2 * to_uint i{m}) => C1. - rewrite /noise_coef !get_setE 1..2:/# C1 /= ifF 1:/#. - have ->/=: 2 * to_uint i{m} %/ 2 = to_uint i{m} by smt(). - rewrite -to_sint_eq sigextu16_to_sint (_: 3 = 2^2 -1) // !and_mod //= W8_of_sintK_signed /=; 1: smt(). - have -> /= : 2 * to_uint i{m} %% 2 = 0 by smt(). - by rewrite -parallel_noisesum_low smod_small // /#. - case: (k = 2 * to_uint i{m}+1) => C2. - rewrite /noise_coef !get_setE 1..2:/# C2 /=. - have ->/=: (2 * to_uint i{m} + 1) %/ 2 = to_uint i{m} by smt(). - rewrite -to_sint_eq sigextu16_to_sint (_: 3 = 2^2 -1) // !and_mod //= W8_of_sintK_signed /=; 1: smt(). - have -> /= : (2 * to_uint i{m}+1) %% 2 = 1 by smt(). - by rewrite -parallel_noisesum_high smod_small // /#. - rewrite !get_setE 1..2:/# C1 C2 /=; apply IH. - smt(mem_iota). -auto => &m |> *. -split; first by rewrite iota0. -move=> i rp; rewrite /(\ult) => |> ??. -have ->/=: to_uint i = 128 by smt(). -rewrite tP => /List.allP H k Hk. -rewrite (H k _) /=. - smt(mem_iota). -by rewrite initiE //. +bind op [W8.t & W32.t & Array960.t] sliceset960_8_32 "asliceset". +realize bvaslicesetP. +move => arr offset bv H /= k kb; rewrite /sliceset960_8_32 /=. +case (8 %| offset) => /= *; last first. ++ rewrite of_listK; 1: by rewrite size_map size_chunk // !size_cat size_take; + by smt(size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0). + rewrite -(map_comp W8.w2bits W8.bits2w) /(\o). + have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W8)) idfun (chunk 8 + (take offset (flatten (map W8.w2bits (to_list arr))) ++ w2bits bv ++ + drop (offset + 32) (flatten (map W8.w2bits (to_list arr))))). + rewrite iffE => [#] -> * /=; 1: by smt(in_chunk_size W8.bits2wK). + rewrite map_id /= chunkK //;1: by rewrite !size_cat size_take; + by smt(size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0). + by rewrite !nth_cat !size_cat /=; + smt(nth_take nth_drop size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0). +rewrite (nth_flatten _ 8); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W8.size_w2bits). +rewrite (nth_map W8.zero []); 1: smt(Array960.size_to_list). +rewrite nth_mkseq 1:/# /= initiE 1:/# /= /get8 /set32_direct. +rewrite initiE 1:/# /=. +case (offset <= k && k < offset + 32) => *; 1: by + rewrite ifT 1:/# get_bits8 /= 1,2:/# initiE // initiE //. +rewrite ifF 1:/# initiE 1:/# /=. +rewrite (nth_flatten _ 8); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W8.size_w2bits). +rewrite (nth_map W8.zero []); 1: smt(Array960.size_to_list). +rewrite nth_mkseq /#. qed. -lemma cbd2_ref_ll : islossless AuxMLKEMAvx2.cbd2_ref. -proc. inline*. sp; wp. while (true) (128 - W64.to_uint i). move => z. -auto => /> &hr. rewrite ultE of_uintK //= => H. rewrite to_uintD_small //= /#. -auto => />i H. rewrite ultE of_uintK //= 1:/#. qed. - -phoare cbd2_ref_ph _bytes: - [AuxMLKEMAvx2.cbd2_ref: buf=_bytes ==> res = Array256.init (fun k => W16.of_int (noise_coef _bytes k))] = 1%r. -conseq cbd2_ref_ll (cbd2_ref_h _bytes) => />. qed. +theory W10. +abbrev [-printing] size = 10. +clone include BitWordSH with op size <- size +rename "_XX" as "_10" +proof gt0_size by done, +size_le_256 by done. -equiv getnoise_split : - M(Syscall)._poly_getnoise ~ AuxMLKEMAvx2._poly_getnoise : ={arg} ==> ={res}. -proc; wp; sp => />. -seq 2 0 : ( ={buf,rp,seed,nonce} /\ extseed{1}=Array33.init (fun i => if i=32 then nonce{1} else seed{1}.[i]) ). -wp. while{1} (0 <= k{1} <= 32 /\ (forall i, 0 <= i < k{1} => extseed{1}.[i]=seed{1}.[i])) (32-k{1}). -auto => /> &m H1 H2 H3 H4. split. split. smt(). move => i Hi1 Hi2. rewrite get_setE 1:/#. smt(). smt(). -auto => /> &m. split. move => i Hi1 Hi2. rewrite !get_out /#. move => extseed k. split. smt(). move => H1 H2 H3 H4. rewrite tP => i Hi. rewrite !initiE 1:/# => />. rewrite get_setE 1:/#. smt(). -seq 1 1 : (#pre). -call (_:true). auto => />. sim. auto => />. -inline *. sim. qed. +end W10. export W10 W10.ALU W10.SHIFT. -equiv getnoise_1x_equiv_avx : - Jkem_avx2.M(Jkem_avx2.Syscall).__poly_cbd_eta1 ~ AuxMLKEMAvx2.cbd2_ref : ={arg} ==> ={res}. -proc*. inline Jkem_avx2.M(Jkem_avx2.Syscall).__poly_cbd_eta1. rcondt{1} 3. auto => />. sp;wp. -ecall{1} (cbd2_avx2_ph buf{1}) => />. -ecall{2} (cbd2_ref_ph buf{2}) => />. -auto => /> &2. rewrite tP => i Hi. rewrite initiE //=. qed. +bind bitstring W10.w2bits W10.bits2w W10.to_uint W10.to_sint W10.of_int W10.t 10. +realize size_tolist by auto. +realize tolistP by auto. +realize oflistP by smt(W10.bits2wK). +realize ofintP by move => *;rewrite /of_int int2bs_mod. +realize tosintP. move => bv /=;rewrite /to_sint /smod /BVA_Top_W10_t.msb. +have -> /=: nth false (w2bits bv) (10 - 1) = 2 ^ (10 - 1) <= to_uint bv; last by smt(). +rewrite /to_uint. +rewrite -{2}(cat_take_drop 9 (w2bits bv)). +rewrite bs2int_cat size_take // W10.size_w2bits /=. +rewrite -bs2int_div //= get_to_uint //=. +rewrite -bs2int_mod // /= /to_uint. +have ? : 2^10 = 1024 by rewrite /=. +by smt(bs2int_range mem_range W10.size_w2bits). +qed. +realize touintP by smt(). -equiv getnoise_4x_split : - GetNoiseAVX2._poly_getnoise_eta1_4x ~ AuxMLKEMAvx2.__poly_getnoise_eta1_4x : ={arg} ==> ={res}. -proc; wp; sp => />. call getnoise_split => />. call getnoise_split => />. call getnoise_split => />. call getnoise_split => />. auto => />. qed. +op truncate64_10 (bw: W64.t) : W10.t = W10.bits2w (W64.w2bits bw). -equiv getnoiseequiv_avx : - Jkem_avx2.M(Jkem_avx2.Syscall)._poly_getnoise_eta1_4x ~ GetNoiseAVX2._poly_getnoise_eta1_4x : ={arg} ==> ={res}. -proc*. -transitivity{2} { r <@ AuxMLKEMAvx2.__poly_getnoise_eta1_4x(aux3,aux2,aux1,aux0,noiseseed,nonce); } ((r0{1}, r1{1}, r2{1}, r3{1}, seed{1}, nonce{1}) = (aux3{2}, aux2{2}, aux1{2}, aux0{2}, noiseseed{2}, nonce{2}) ==> ={r}) (={aux3,aux2,aux1,aux0,noiseseed,nonce} ==> ={r}); last first. -symmetry. call getnoise_4x_split => />. auto => />. smt(). smt(). -(*main proof*) -inline Jkem_avx2.M(Jkem_avx2.Syscall)._poly_getnoise_eta1_4x AuxMLKEMAvx2.__poly_getnoise_eta1_4x AuxMLKEMAvx2._poly_getnoise. swap{2} [30..31] 5. swap{2} [23..24] 10. swap{2} [16..17] 15. -seq 25 30 : ( - r00{1}=rp{2} /\ Array128.init (fun (i : int) => buf0{1}.[i]) =buf{2} - /\ r10{1}=rp0{2} /\ Array128.init (fun (i : int) => buf1{1}.[i]) =buf0{2} - /\ r20{1}=rp1{2} /\ Array128.init (fun (i : int) => buf2{1}.[i]) =buf1{2} - /\ r30{1}=rp2{2} /\ Array128.init (fun (i : int) => buf3{1}.[i]) =buf2{2} -). -sp => />. -ecall{2} (shake256_33_128 buf2{2} (Array33.init (fun i => if i = 32 then nonce4{2} else seed2{2}.[i]))); wp => />. -ecall{2} (shake256_33_128 buf1{2} (Array33.init (fun i => if i = 32 then nonce3{2} else seed1{2}.[i]))); wp => />. -ecall{2} (shake256_33_128 buf0{2} (Array33.init (fun i => if i = 32 then nonce2{2} else seed0{2}.[i]))); wp => />. -ecall{2} (shake256_33_128 buf{2} (Array33.init (fun i=> if i = 32 then nonce1{2} else seed{2}.[i]))); wp => />. -ecall{1} (shake_squeezenblocks4x state{1} buf0{1} buf1{1} buf2{1} buf3{1}); wp => />. -ecall{1} (shake_absorb4x state{1} (Array33.init (fun i => buf0{1}.[i])) (Array33.init (fun i => buf1{1}.[i])) (Array33.init (fun i => buf2{1}.[i])) (Array33.init (fun i => buf3{1}.[i])) ); wp => />. -auto => /> &2. rewrite shake4x_equiv => />. -rewrite tP => k Hk; rewrite !initiE //= 1..3:/#; rewrite ifF 1:/#; rewrite /get8 /init8 set_neqiE 1:/#; rewrite initiE //= 1:/#; rewrite initiE //= 1:/#; rewrite set256E initiE //= 1:/#; rewrite ifT //; rewrite /get256_direct bits8_W32u8 //=; rewrite ifT //; rewrite initiE //=; rewrite initiE //=. -rewrite tP => k Hk; rewrite !initiE //= 1..3:/#; rewrite ifF 1:/#; rewrite /get8 /init8 set_neqiE 1:/#; rewrite initiE //= 1:/#; rewrite initiE //= 1:/#; rewrite set256E initiE //= 1:/#; rewrite ifT //; rewrite /get256_direct bits8_W32u8 //=; rewrite ifT //; rewrite initiE //=; rewrite initiE //=. -rewrite tP => k Hk; rewrite !initiE //= 1..3:/#; rewrite ifF 1:/#; rewrite /get8 /init8 set_neqiE 1:/#; rewrite initiE //= 1:/#; rewrite initiE //= 1:/#; rewrite set256E initiE //= 1:/#; rewrite ifT //; rewrite /get256_direct bits8_W32u8 //=; rewrite ifT //; rewrite initiE //=; rewrite initiE //=. -rewrite tP => k Hk; rewrite !initiE //= 1..3:/#; rewrite ifF 1:/#; rewrite /get8 /init8 set_neqiE 1:/#; rewrite initiE //= 1:/#; rewrite initiE //= 1:/#; rewrite set256E initiE //= 1:/#; rewrite ifT //; rewrite /get256_direct bits8_W32u8 //=; rewrite ifT //; rewrite initiE //=; rewrite initiE //=. -wp. call getnoise_1x_equiv_avx => />. -wp. call getnoise_1x_equiv_avx => />. -wp. call getnoise_1x_equiv_avx => />. -wp. call getnoise_1x_equiv_avx => />. -auto => />. qed. +bind op [W64.t & W10.t] truncate64_10 "truncate". +realize bvtruncateP. +move => mv. rewrite /truncate64_10 /W64.w2bits take_mkseq //= /w2bits. +apply (eq_from_nth witness);1: by smt(size_mkseq). +move => i; rewrite size_mkseq /= /max /= => ib. +by rewrite !nth_mkseq // /bits2w initiE //= nth_mkseq /#. +qed. -lemma polygetnoise_ll : islossless Jkem.M(Jkem.Syscall)._poly_getnoise. -proc. -while (0 <= to_uint i <= 128) (128 - to_uint i); - 1: by move => z; auto => />;rewrite ultE /= => &hr ???; rewrite !to_uintD_small /=; smt(to_uint_cmp). -wp; call sha3ll; wp; while (0<=k<=32) (32 -k); 1: by move => z; auto=> /> /#. -auto => /> *; do split; 1:smt(). -by move => *; rewrite ultE /=; smt(). +bind op [W64.t & W8.t] W8u8.truncateu8 "truncate". +realize bvtruncateP. (* generalize *) +move => mv; rewrite /truncateu8 /W64.w2bits take_mkseq //= /w2bits. +apply (eq_from_nth witness);1: by smt(size_mkseq). +move => i; rewrite size_mkseq /= /max /= => ib. +rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w // + nth_mkseq //= get_to_uint //= /to_uint /=. +have -> /=: (0 <= i && i < 64) by smt(). +pose a := bs2int (w2bits mv). +rewrite {1}(divz_eq a (2^(8-i)*2^i)) !mulrA divzMDl; + 1: by smt(StdOrder.IntOrder.expr_gt0). +rewrite dvdz_modzDl; 1: by + have -> : 2^(8-i) = 2^((8-i-1)+1); [ by smt() | + rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)]. +by have -> : (2 ^ (8 - i) * 2 ^ i) = 256; + [ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg + 1,2:/# /= -!addrA /= | done ]. qed. -equiv getnoiseequiv : - Jkem.M(Jkem.Syscall)._poly_getnoise ~Jkem.M(Jkem.Syscall)._poly_getnoise : - ={arg} ==> ={res} /\ - signed_bound_cxq res{1} 0 256 1. -have H : forall &m a, - Pr[Jkem.M(Jkem.Syscall)._poly_getnoise(a) @ &m : forall k, 0<=k<256 => -5 < to_sint res.[k] < 5] = 1%r. -+ move => &m a. - have -> : 1%r = Pr [ CBD2.sample(PRF a.`2 a.`3) @ &m : true]. - + byphoare => //. - proc; inline *; while (0<=i<=128) (128-i); 1: by move => z; auto => /> /#. - by auto => /> /#. - by byequiv get_noise_sample_noise => //. -have HH0 : hoare [Jkem.M(Jkem.Syscall)._poly_getnoise : true ==> forall k, 0<=k<256 => -5 < to_sint res.[k] < 5]. -+ hoare; bypr => //= &m; rewrite Pr[mu_not]. - have -> : Pr[Jkem.M(Jkem.Syscall)._poly_getnoise(rp{m}, s_seed{m}, nonce{m}) @ &m : true] = 1%r. - + by byphoare => //; apply polygetnoise_ll. - smt(). -have HHH : equiv [ Jkem.M(Jkem.Syscall)._poly_getnoise ~Jkem.M(Jkem.Syscall)._poly_getnoise : ={arg} ==> ={res} ] by sim. -conseq HHH HH0. -move => *; rewrite /signed_bound_cxq /b16 qE /#. +bind op [W16.t & W8.t] W2u8.truncateu8 "truncate". +realize bvtruncateP. +move => mv; rewrite /truncateu8 /W16.w2bits take_mkseq //= /w2bits. +apply (eq_from_nth witness);1: by smt(size_mkseq). +move => i; rewrite size_mkseq /= /max /= => ib. +rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w // + nth_mkseq //= get_to_uint //= /to_uint /=. +have -> /=: (0 <= i && i < 16) by smt(). +pose a := bs2int (w2bits mv). +rewrite {1}(divz_eq a (2^(8-i)*2^i)) !mulrA divzMDl; + 1: by smt(StdOrder.IntOrder.expr_gt0). +rewrite dvdz_modzDl; 1: by + have -> : 2^(8-i) = 2^((8-i-1)+1); [ by smt() | + rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)]. +by have -> : (2 ^ (8 - i) * 2 ^ i) = 256; + [ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg + 1,2:/# /= -!addrA /= | done ]. qed. -import InnerPKE. -lemma mlkem_correct_kg_avx2 mem _pkp _skp : - equiv [Jkem_avx2.M(Jkem_avx2.Syscall).__indcpa_keypair ~ InnerPKE.kg_derand : - Glob.mem{1} = mem /\ to_uint pkp{1} = _pkp /\ to_uint skp{1} = _skp /\ - randomnessp{1} = coins{2} /\ - valid_disj_reg _pkp (384*3+32) _skp (384*3) - ==> - touches2 Glob.mem{1} mem _pkp (384*3+32) _skp (384*3) /\ - let (pk,sk) = res{2} in let (t,rho) = pk in - sk = load_array1152 Glob.mem{1} _skp /\ - t = load_array1152 Glob.mem{1} _pkp /\ - rho = load_array32 Glob.mem{1} (_pkp+1152)]. -proc*. -transitivity {1} {Jkem.M(Jkem.Syscall).__indcpa_keypair(pkp, skp, randomnessp);} -(={Glob.mem,pkp,skp,randomnessp} /\ - Glob.mem{1} = mem /\ - to_uint pkp{1} = _pkp /\ - to_uint skp{1} = _skp /\ - randomnessp{1} = randomnessp{2} /\ - valid_disj_reg _pkp (384 * 3 + 32) _skp (384 * 3) ==> ={Glob.mem}) -( Glob.mem{1} = mem /\ to_uint pkp{1} = _pkp /\ to_uint skp{1} = _skp /\ - randomnessp{1} = coins{2} /\ - valid_disj_reg _pkp (384*3+32) _skp (384*3) - ==> - touches2 Glob.mem{1} mem _pkp (384*3+32) _skp (384*3) /\ - let (pk, sk) = r{2} in - let (t, rho) = pk in - sk = load_array1152 Glob.mem{1} _skp /\ - t = load_array1152 Glob.mem{1} _pkp /\ - rho = load_array32 Glob.mem{1} (_pkp + 1152)); 1,2: smt(); - last by call(mlkem_correct_kg mem _pkp _skp); auto => />. +bind op [W16.t & W64.t] W4u16.zeroextu64 "zextend". +realize bvzextendP + by move => bv; rewrite /zeroextu64 /= of_uintK /=; smt(W16.to_uint_cmp pow2_16). -inline{1} 1; inline {2} 1. sim 40 62. +bind op [W64.t & W16.t] W4u16.truncateu16 "truncate". +realize bvtruncateP. +move => mv; rewrite /truncateu16 /W64.w2bits take_mkseq //= /w2bits. +apply (eq_from_nth witness);1: by smt(size_mkseq). +move => i; rewrite size_mkseq /= /max /= => ib. +rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w // + nth_mkseq //= get_to_uint //= /to_uint /=. +have -> /=: (0 <= i && i < 64) by smt(). +pose a := bs2int (w2bits mv). +rewrite {1}(divz_eq a (2^(16-i)*2^i)) !mulrA divzMDl; + 1: by smt(StdOrder.IntOrder.expr_gt0). +rewrite dvdz_modzDl; 1: by + have -> : 2^(16-i) = 2^((16-i-1)+1); [ by smt() | + rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)]. +by have -> : (2 ^ (16 - i) * 2 ^ i) = 65536; + [ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg + 1,2:/# /= -!addrA /= | done ]. +qed. -call (polyvec_tobytes_equiv _pkp). -call (polyvec_tobytes_equiv _skp). -wp;conseq />;1:smt(). -ecall (polyvec_reduce_equiv (lift_array768 pkpv{2})). -have H := polyvec_add2_equiv 2 2 _ _ => //. -ecall (H (lift_array768 pkpv{2}) (lift_array768 e{2})); clear H. -unroll for {1} 36. +op sll_64 (w1 w2 : W64.t) : W64.t = + if (64 <= to_uint w2) then W64.zero else w1 `<<` (truncateu8 w2). -sp 3 3. +bind op [W64.t] sll_64 "shl". +realize bvshlP. +proof. +rewrite /sll_64 => bv1 bv2. +case : (64 <= to_uint bv2); last first. ++ rewrite /(`<<`) W64.to_uint_shl; 1: by smt(W8.to_uint_cmp). + rewrite /truncateu8 => bv2bnd />. + do 2! (rewrite (pmod_small (to_uint bv2) _);smt(W64.to_uint_cmp)). +move => *. +have -> : to_uint bv2 = (to_uint bv2 - 64) + 64 by ring. +by rewrite exprD_nneg 1,2:/# /= /#. +qed. -seq 15 17 : (#pre /\ ={publicseed, noiseseed,e,skpv,pkpv} /\ sskp{2} = skp{1} /\ spkp{2} = pkp{1}); 1: by - sp; conseq />; sim 2 2; call( sha3equiv); conseq />; sim. +bind op [W32.t & W16.t] W2u16.truncateu16 "truncate". +realize bvtruncateP. +move => mv; rewrite /truncateu16 /W32.w2bits take_mkseq //= /w2bits. +apply (eq_from_nth witness);1: by smt(size_mkseq). +move => i; rewrite size_mkseq /= /max /= => ib. +rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w // + nth_mkseq //= get_to_uint //= /to_uint /=. +have -> /=: (0 <= i && i < 32) by smt(). +pose a := bs2int (w2bits mv). +rewrite {1}(divz_eq a (2^(16-i)*2^i)) !mulrA divzMDl; + 1: by smt(StdOrder.IntOrder.expr_gt0). +rewrite dvdz_modzDl; 1: by + have -> : 2^(16-i) = 2^((16-i-1)+1); [ by smt() | + rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)]. +by have -> : (2 ^ (16 - i) * 2 ^ i) = 65536; + [ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg + 1,2:/# /= -!addrA /= | done ]. +qed. -sp 0 2. -seq 2 2 : (#pre /\ aa{1} = nttunpackm a{2} /\ - pos_bound2304_cxq aa{1} 0 2304 2 /\ - pos_bound2304_cxq a{2} 0 2304 2); 1: by - conseq />; call (genmatrixequiv false); auto => />. -swap {1} [11..12] 2. +bind op [W16.t & W32.t] sigextu32 "sextend". +realize bvsextendP. +move => bv;rewrite /sigextu32 /to_sint /smod /= !of_uintK /=. +case (32768 <= to_uint bv); 2: smt(W16.to_uint_cmp). +move =>?;rewrite -{2}(oppzK (to_uint bv - 65536)) modNz /=; smt(W16.to_uint_cmp pow2_16). +qed. -seq 10 18 : (#pre /\ - signed_bound768_cxq skpv{1} 0 768 1 /\ - signed_bound768_cxq e{1} 0 768 1 /\ - signed_bound768_cxq skpv{2} 0 768 1 /\ - signed_bound768_cxq e{2} 0 768 1). -+ conseq />. - transitivity {1} { (skpv,e) <@ GetNoiseAVX2.sample_noise_kg(skpv,pkpv,e,noiseseed);} (={noiseseed,skpv,pkpv,e} ==> ={skpv,e}) - ((r_noiseseed{2} = noiseseed{2} /\ - s_noiseseed{2} = r_noiseseed{2} /\ - (spkp{2} = pkp{2} /\ - sskp{2} = skp{2} /\ - randomnessp0{2} = randomnessp{2} /\ - pkp0{1} = pkp{1} /\ - skp0{1} = skp{1} /\ - randomnessp0{1} = randomnessp{1} /\ - ={Glob.mem, pkp, skp, randomnessp} /\ - Glob.mem{1} = mem /\ - to_uint pkp{1} = _pkp /\ - to_uint skp{1} = _skp /\ ={randomnessp} /\ valid_disj_reg _pkp (384 * 3 + 32) _skp (384 * 3)) /\ - ={publicseed, noiseseed, e, skpv, pkpv} /\ sskp{2} = skp{1} /\ spkp{2} = pkp{1}) /\ - aa{1} = nttunpackm a{2} /\ pos_bound2304_cxq aa{1} 0 2304 2 /\ pos_bound2304_cxq a{2} 0 2304 2 +bind op [W32.t & W8.t] W4u8.truncateu8 "truncate". +realize bvtruncateP. +move => mv; rewrite /truncateu8 /W32.w2bits take_mkseq //= /w2bits. +apply (eq_from_nth witness);1: by smt(size_mkseq). +move => i; rewrite size_mkseq /= /max /= => ib. +rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w // + nth_mkseq //= get_to_uint //= /to_uint /=. +have -> /=: (0 <= i && i < 32) by smt(). +pose a := bs2int (w2bits mv). +rewrite {1}(divz_eq a (2^(8-i)*2^i)) !mulrA divzMDl; + 1: by smt(StdOrder.IntOrder.expr_gt0). +rewrite dvdz_modzDl; 1: by + have -> : 2^(8-i) = 2^((8-i-1)+1); [ by smt() | + rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)]. +by have -> : (2 ^ (8 - i) * 2 ^ i) = 256; + [ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg + 1,2:/# /= -!addrA /= | done ]. +qed. - ==> - ={skpv, e} /\ - signed_bound768_cxq skpv{1} 0 768 1 /\ - signed_bound768_cxq e{1} 0 768 1 /\ signed_bound768_cxq skpv{2} 0 768 1 /\ signed_bound768_cxq e{2} 0 768 1 - ); 1,2:smt(). - + by inline {2} 1;do 2!(wp; call getnoiseequiv_avx);auto => />. - inline {1} 1. inline GetNoiseAVX2._poly_getnoise_eta1_4x. - wp; do 2!(call{1} (_: true ==> true); 1: by apply polygetnoise_ll). - do 6!(wp; call getnoiseequiv); auto => />. - move => &1 &2 ??????R?; split. - + by rewrite tP => k kb; rewrite !initiE //= initiE /#. - move => ?R0?; split. - + rewrite tP => k kb; rewrite !initiE //= initiE 1:/# /= initiE 1:/# /= /#. - move => ?R1?????; split. - + rewrite tP => k kb; rewrite !initiE //= initiE 1:/# /= initiE 1:/# /= initiE 1:/# /= /#. - move => ?R2?; do split. - + rewrite /signed_bound768_cxq => x xb /=. - rewrite !initiE //= fun_if. - case (512 <= x && x < 768); 1: by smt(). - move => *; rewrite !initiE //= fun_if. - case (256 <= x && x < 512); 1: by smt(). - move => *; rewrite !initiE //= fun_if. - by smt(). - + rewrite /signed_bound768_cxq => x xb /=. - rewrite !initiE //= fun_if. - case (512 <= x && x < 768); 1: by smt(). - move => *; rewrite !initiE //= fun_if. - case (256 <= x && x < 512); 1: by smt(). - move => *; rewrite !initiE //= fun_if. - by smt(). - + rewrite /signed_bound768_cxq => x xb /=. - rewrite !initiE //= fun_if. - case (512 <= x && x < 768); 1: by smt(). - move => *; rewrite !initiE //= fun_if. - case (256 <= x && x < 512); 1: by smt(). - move => *; rewrite !initiE //= fun_if. - by smt(). - rewrite /signed_bound768_cxq => x xb /=. - rewrite !initiE //= fun_if. - case (512 <= x && x < 768); 1: by smt(). - move => *; rewrite !initiE //= fun_if. - case (256 <= x && x < 512); 1: by smt(). - move => *; rewrite !initiE //= fun_if. - by smt(). -seq 2 2 : (#{/~skpv{1}}{~e{1}}{~skpv{2}}{~e{2}}pre /\ - lift_array768 skpv{1} = nttunpackv (lift_array768 skpv{2}) /\ - lift_array768 e{1} = nttunpackv (lift_array768 e{2}) /\ - pos_bound768_cxq skpv{1} 0 768 2 /\ - pos_bound768_cxq skpv{2} 0 768 2 /\ - pos_bound768_cxq e{1} 0 768 2 /\ - pos_bound768_cxq e{2} 0 768 2); 1: - by conseq />; call (nttequiv); call (nttequiv); auto => /> /#. +bind circuit VPBROADCAST_8u32 "VPBROADCAST_8u32". +bind circuit VPBROADCAST_4u64 "VPBROADCAST_4u64". -(* First ip *) -seq 8 4: (#{/~pkpv{2}}pre /\ - lift_array256 (subarray256 pkpv{1} 0) = nttunpack (lift_array256 (subarray256 pkpv{2} 0)) /\ - signed_bound768_cxq pkpv{1} 0 256 2 /\ - signed_bound768_cxq pkpv{2} 0 256 2 /\ i{1} = 1). -wp; call frommontequiv; wp; call pointwiseequiv; auto => />. -move => &1 &2 H2 H3 H4 H5 H6 H7 H8 H9 H10 H11 H12 H13 H14; do split. -+ rewrite -lift768_nttunpack. congr. - rewrite /nttunpackm /nttunpackv tP /= => k kb. - rewrite !initiE // 1:/# /= kb /= initiE //=. -+ rewrite /signed_bound768_cxq => k kb; rewrite initiE //=. -+ rewrite /unpackm /unpackv /=. - rewrite !initiE // 1:/# /= kb /= initiE //=. - rewrite fun_if. - case (0<=k<256). - + move => kbb;rewrite /subarray256. - move : (nttunpack_pred (Array256.init (fun (k0 : int) => (subarray768 a{2} 0).[256 * 0 + k0])) (fun x => -2*q <= W16.to_sint x < 2*q)). - rewrite !allP; move => /= [h0 h1]; rewrite h1. move => *. rewrite initiE //=. smt(Array768.initiE). smt(). - case (256<=k<512). - + move => kbb;rewrite /subarray256. - move : (nttunpack_pred (Array256.init (fun (k0 : int) => (subarray768 a{2} 0).[256 * 1 + k0])) (fun x => -2*q <= W16.to_sint x < 2*q)). - rewrite !allP; move => /= [h0 h1]; rewrite h1. move => *. rewrite initiE //=. smt(Array768.initiE). smt(). auto. - case (512<=k<768). - + move => kbbb;rewrite /subarray256. - move : (nttunpack_pred (Array256.init (fun (k0 : int) => (subarray768 a{2} 0).[256 * 2 + k0])) (fun x => -2*q <= W16.to_sint x < 2*q)). - rewrite !allP; move => /= [h0 h1]; rewrite h1. move => *. rewrite initiE //=. smt(Array768.initiE). smt(). auto. - by smt(). -+ move : H10; rewrite /pos_bound768_cxq /signed_bound_768_cxq /#. -+ move : H8; rewrite /pos_bound768_cxq /signed_bound_768_cxq; smt(Array768.initiE). -+ move : H10; rewrite /pos_bound768_cxq /signed_bound_768_cxq; smt(Array768.initiE). +bind circuit VPMADDWD_256 "VPMADDWD_16u16". -move => H15 H16 H17 H18 H19 r1 r2 H20 H21 H22;do split. -+ rewrite tP /= => k kb. - rewrite /lift_array256 /nttunpack !initiE //=. - pose a:=nttunpack_idx.[k]. - rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; 1..3: smt(nttunpack_bnd Array256.allP). - rewrite kb /=. - have -> /= : 0 <= a && a < 256 by smt(nttunpack_bnd Array256.allP). - move : H20; rewrite /lift_array256 tP => H20. - move : (H20 k kb). - rewrite /nttunpack initiE //= -/a !mapiE //=;smt(nttunpack_bnd Array256.allP). +bind circuit VPSLLV_8u32 "VPSLLV_8u32". -move => H23 r3 r4 H24 H25 H26;do split. -+ rewrite tP /= => k kb. - rewrite /lift_array256 /nttunpack !initiE //=. - pose a:=nttunpack_idx.[k]. - rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite /subarray256 /=. - rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). - rewrite kb /=. - have -> /= : 0 <= a && a < 256 by smt(nttunpack_bnd Array256.allP). - move : H24; rewrite /lift_array256 tP => H24. - move : (H24 k kb). - rewrite /nttunpack initiE //= -/a !mapiE //=;smt(nttunpack_bnd Array256.allP). +bind circuit VPSRL_4u64 "VPSRL_4u64". +bind circuit VPSHUFB_256 "VPSHUFB_256". -+ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). - rewrite kb /=. smt(). +bind circuit VEXTRACTI128 "VEXTRACTI128". -+ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). - rewrite kb /=. smt(). +bind circuit VPBLENDW_128 "VPBLEND_8u16". -(* Second ip *) +bind circuit VPEXTR_32 "VEXTRACTI32_256". -seq 5 4: (#{/~i{1}}pre /\ lift_array256 (subarray256 pkpv{1} 1) = nttunpack (lift_array256 (subarray256 pkpv{2} 1)) /\ - signed_bound768_cxq pkpv{1} 256 512 2 /\ - signed_bound768_cxq pkpv{2} 256 512 2 /\ i{1} = 2). -wp; call frommontequiv; wp; call pointwiseequiv; auto => />. -move => &1 &2 H2 H3 H4 H5 H6 H7 H8 H9 H10 H11 H12 H13 H14 H15 H16 H17; do split. -+ rewrite -lift768_nttunpack. congr. - rewrite /nttunpackm /nttunpackv tP /= => k kb. - rewrite !initiE //= initiE //= 1:/# ifF 1:/# ifT 1:/# initiE //=. -+ by rewrite /signed_bound768_cxq => k kb; rewrite initiE //= /#. -+ by rewrite /pos_bound768_cxq /signed_bound_768_cxq /#. -+ by rewrite /signed_bound768_cxq => k kb; rewrite initiE //= /#. -+ by rewrite /pos_bound768_cxq /signed_bound_768_cxq /#. +bind circuit W4u32.VPEXTR_32 "VEXTRACTI32_128". -move => H18 H19 H20 H21 H22 r1 r2 H23 H24 H25;do split. -+ rewrite tP /= => k kb. - rewrite /lift_array256 /nttunpack !initiE //=. - pose a:=nttunpack_idx.[k]. - rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). - rewrite !ifT /=; 1,2: smt(nttunpack_bnd Array256.allP). - move : H23; rewrite /lift_array256 tP => H23. - move : (H23 k kb). - rewrite /nttunpack initiE //= -/a !mapiE //=;smt(nttunpack_bnd Array256.allP). +bind op [W256.t & W128.t] truncateu128 "truncate". +realize bvtruncateP. +move => mv; rewrite /truncateu128 /W256.w2bits take_mkseq //= /w2bits. +apply (eq_from_nth witness);1: by smt(size_mkseq). +move => i; rewrite size_mkseq /= /max /= => ib. +rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w // + nth_mkseq //= get_to_uint //= /to_uint /=. +have -> /=: (0 <= i && i < 256) by smt(). +pose a := bs2int (w2bits mv). +rewrite {1}(divz_eq a (2^(128-i)*2^i)) !mulrA divzMDl; + 1: by smt(StdOrder.IntOrder.expr_gt0). +rewrite dvdz_modzDl; 1: by + have -> : 2^(128-i) = 2^((128-i-1)+1); [ by smt() | + rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)]. +by have -> : (2 ^ (128 - i) * 2 ^ i) = 340282366920938463463374607431768211456; + [ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg + 1,2:/# /= -!addrA /= | done ]. +qed. -move => H26 r3 r4 H27 H28 H29;do split. -+ rewrite tP /= => k kb. - rewrite /lift_array256 /nttunpack !initiE //=. - pose a:=nttunpack_idx.[k]. - rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite /subarray256 /=. - rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). - rewrite !ifF /=; 1..2: smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). - rewrite !ifF /=; 1..2: smt(nttunpack_bnd Array256.allP). - move : H15; rewrite /lift_array256 /subarray256 tP => H15. - move : (H15 k kb). - rewrite /nttunpack initiE //= -/a !mapiE //=; 1:smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; smt(nttunpack_bnd Array256.allP). +op srl_16 (w1 w2 : W16.t) : W16.t = + if 16 <= (to_uint w2) then W16.zero else + w1 `>>` (truncateu8 w2). -+ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). - rewrite !ifF /=. smt(). - rewrite !initiE //=. smt(). smt(). +bind op [W16.t] srl_16 "shr". +realize bvshrP. +rewrite /srl_16 => bv1 bv2. +case : (16 <= to_uint bv2); last first. ++ rewrite /(`>>`) W16.to_uint_shr; 1: by smt(W8.to_uint_cmp). + rewrite /truncateu8 => bv2bnd />. + do 2! (rewrite (pmod_small (to_uint bv2) _);smt(W16.to_uint_cmp)). +move => *. +have -> : to_uint bv2 = (to_uint bv2 - 16) + 16 by ring. +rewrite exprD_nneg 1,2:/# /=. +smt(StdOrder.IntOrder.expr_gt0 W16.to_uint_cmp pow2_16). +qed. -+ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). - rewrite !ifF /=. smt(). - rewrite !initiE //=. smt(). smt(). -+ rewrite tP /= => k kb. - rewrite /lift_array256 /nttunpack !initiE //=. - pose a:=nttunpack_idx.[k]. - rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite /subarray256 /=. - rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). - rewrite !ifT /=; 1..2: smt(nttunpack_bnd Array256.allP). - move : H27; rewrite /lift_array256 /subarray256 tP => H27. - move : (H27 k kb). - rewrite /nttunpack initiE //= -/a !mapiE //=; smt(nttunpack_bnd Array256.allP). +op sll_16 (w1 w2 : W16.t) : W16.t = + if (16 <= to_uint w2) then W16.zero else w1 `<<` (truncateu8 w2). -+ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). - rewrite !initiE //=. smt(). smt(). +bind op [W16.t] sll_16 "shl". +realize bvshlP. +rewrite /sll_16 => bv1 bv2. +case : (16 <= to_uint bv2); last first. ++ rewrite /(`<<`) W16.to_uint_shl; 1: by smt(W8.to_uint_cmp). + rewrite /truncateu8 => bv2bnd />. + do 2! (rewrite (pmod_small (to_uint bv2) _);smt(W16.to_uint_cmp pow2_16)). +move => *. +have -> : to_uint bv2 = (to_uint bv2 - 16) + 16 by ring. +by rewrite exprD_nneg 1,2:/# /= /#. +qed. -+ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). - rewrite !initiE //=. smt(). smt(). +op srl_64 (w1 w2 : W64.t) : W64.t = + if (64 <= to_uint w2) then W64.zero else w1 `>>` (truncateu8 w2). -(* Third ip *) +bind op [W64.t] srl_64 "shr". +realize bvshrP. +rewrite /srl_64 => bv1 bv2. +case : (64 <= to_uint bv2); last first. ++ rewrite /(`>>`) W64.to_uint_shr; 1: by smt(W8.to_uint_cmp). + rewrite /truncateu8 => bv2bnd />. + do 2! (rewrite (pmod_small (to_uint bv2) _);smt(W64.to_uint_cmp)). +move => *. +have -> : to_uint bv2 = (to_uint bv2 - 64) + 64 by ring. +rewrite exprD_nneg 1,2:/# /=. +smt(StdOrder.IntOrder.expr_gt0 W64.to_uint_cmp pow2_64). +qed. -seq 5 4: (#{/~i{1}}pre /\ lift_array256 (subarray256 pkpv{1} 2) = nttunpack (lift_array256 (subarray256 pkpv{2} 2)) /\ - signed_bound768_cxq pkpv{1} 512 768 2 /\ - signed_bound768_cxq pkpv{2} 512 768 2). -wp; call frommontequiv; wp; call pointwiseequiv; auto => />. -move => &1 &2 H2 H3 H4 H5 H6 H7 H8 H9 H10 H11 H12 H13 H14 H15 H16 H17 H18 H19 H20; do split. -+ rewrite -lift768_nttunpack. congr. - rewrite /nttunpackm /nttunpackv tP /= => k kb. - rewrite !initiE //= initiE //= 1:/# ifF 1:/# ifF 1:/# initiE //=. -+ by rewrite /signed_bound768_cxq => k kb; rewrite initiE //= /#. -+ by rewrite /pos_bound768_cxq /signed_bound_768_cxq /#. -+ by rewrite /signed_bound768_cxq => k kb; rewrite initiE //= /#. -+ by rewrite /pos_bound768_cxq /signed_bound_768_cxq /#. -move => H21 H22 H23 H24 H25 r1 r2 H26 H27 H28;do split. -+ rewrite tP /= => k kb. - rewrite /lift_array256 /nttunpack !initiE //=. - pose a:=nttunpack_idx.[k]. - rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). - rewrite !ifT /=; 1,2: smt(nttunpack_bnd Array256.allP). - move : H26; rewrite /lift_array256 tP => H26. - move : (H26 k kb). - rewrite /nttunpack initiE //= -/a !mapiE //=;smt(nttunpack_bnd Array256.allP). +(* shake assumptions *) - move => H29 r3 r4 H30 H31 H32;do split. -+ rewrite tP /= => k kb. - rewrite /lift_array256 /nttunpack !initiE //=. - pose a:=nttunpack_idx.[k]. - rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite /subarray256 /=. - rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). - rewrite !ifF /=; 1..2: smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). - rewrite !ifF /=; 1..2: smt(nttunpack_bnd Array256.allP). - move : H15; rewrite /lift_array256 /subarray256 tP => H15. - move : (H15 k kb). - rewrite /nttunpack initiE //= -/a !mapiE //=; 1:smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; smt(nttunpack_bnd Array256.allP). +op SHAKE256_ABSORB4x_33 : W8.t Array33.t -> W8.t Array33.t -> W8.t Array33.t -> W8.t Array33.t -> W256.t Array25.t. +op SHAKE256_SQUEEZENBLOCKS4x : W256.t Array25.t -> W256.t Array25.t * W8.t Array136.t * W8.t Array136.t * W8.t Array136.t * W8.t Array136.t. -+ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). - rewrite !ifF /=. smt(). - rewrite !initiE //=. smt(). smt(). +axiom shake_absorb4x state seed1 seed2 seed3 seed4 : + phoare [ Jkem_avx2.M(Jkem_avx2.Syscall)._shake256_absorb4x_33 : + arg = (state,seed1,seed2,seed3,seed4) ==> + res = SHAKE256_ABSORB4x_33 seed1 seed2 seed3 seed4 ] = 1%r. -+ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). - rewrite !ifF /=. smt(). - rewrite !initiE //=. smt(). smt(). +axiom shake_squeezenblocks4x state buf1 buf2 buf3 buf4 : + phoare [ Jkem_avx2.M(Jkem_avx2.Syscall).__shake256_squeezenblocks4x : + arg = (state,buf1,buf2,buf3,buf4) ==> + res = SHAKE256_SQUEEZENBLOCKS4x state ] = 1%r. -+ rewrite tP /= => k kb. - rewrite /lift_array256 /nttunpack !initiE //=. - pose a:=nttunpack_idx.[k]. - rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite /subarray256 /=. - rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). - rewrite !ifF /=; 1..2: smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). - rewrite !ifF /=; 1..2: smt(nttunpack_bnd Array256.allP). - move : H18; rewrite /lift_array256 /subarray256 tP => H18. - move : (H18 k kb). - rewrite /nttunpack initiE //= -/a !mapiE //=; 1:smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; smt(nttunpack_bnd Array256.allP). +axiom shake4x_equiv (sn1 sn2 sn3 sn4: W8.t Array33.t) (s1 s2 s3 s4 : W8.t Array32.t) n1 n2 n3 n4 : + s1 = Array32.init (fun i => sn1.[i]) => + s2 = Array32.init (fun i => sn2.[i]) => + s3 = Array32.init (fun i => sn3.[i]) => + s4 = Array32.init (fun i => sn4.[i]) => + n1 = sn1.[32] => n2 = sn2.[32] => n3 = sn3.[32] => n4 = sn4.[32] => + Array128.init ("_.[_]" (SHAKE256_SQUEEZENBLOCKS4x (SHAKE256_ABSORB4x_33 sn1 sn2 sn3 sn4)).`2) = SHAKE256_33_128 s1 n1 /\ + Array128.init ("_.[_]" (SHAKE256_SQUEEZENBLOCKS4x (SHAKE256_ABSORB4x_33 sn1 sn2 sn3 sn4)).`3) = SHAKE256_33_128 s2 n2 /\ + Array128.init ("_.[_]" (SHAKE256_SQUEEZENBLOCKS4x (SHAKE256_ABSORB4x_33 sn1 sn2 sn3 sn4)).`4) = SHAKE256_33_128 s3 n3 /\ + Array128.init ("_.[_]" (SHAKE256_SQUEEZENBLOCKS4x (SHAKE256_ABSORB4x_33 sn1 sn2 sn3 sn4)).`5) = SHAKE256_33_128 s4 n4. -+ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). - rewrite !ifF /=. smt(). - rewrite !initiE //=. smt(). smt(). +axiom sha3equiv : + equiv [ (* is this in the sha3 paper? *) +Jkem_avx2.M(Jkem_avx2.Syscall)._sha3_512_32 ~Jkem.M(Jkem.Syscall)._sha3512_32 : ={arg} ==> ={res}]. -+ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). - rewrite !ifF /=. smt(). - rewrite !initiE //=. smt(). smt(). +lemma keccakf1600_set_row_ll : islossless M(Syscall).keccakf1600_set_row. +proc. by unroll for ^while; auto. qed. -+ rewrite tP /= => k kb. - rewrite /lift_array256 /nttunpack !initiE //=. - pose a:=nttunpack_idx.[k]. - rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite /subarray256 /=. - rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). - rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). - rewrite !ifT /=; 1..2: smt(nttunpack_bnd Array256.allP). - move : H30; rewrite /lift_array256 /subarray256 tP => H30. - move : (H30 k kb). - rewrite /nttunpack initiE //= -/a !mapiE //=; smt(nttunpack_bnd Array256.allP). +lemma keccakf1600_rho_offsets_ll : islossless M(Syscall).keccakf1600_rho_offsets. +proc. by unroll for ^while; islossless. qed. -+ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). - rewrite !initiE //=. smt(). smt(). +lemma keccakf1600_rhotates_ll : islossless M(Syscall).keccakf1600_rhotates. +proc. by call keccakf1600_rho_offsets_ll; islossless. qed. -+ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). - rewrite !initiE //=. smt(). smt(). +lemma keccakf1600_theta_rol_ll : islossless M(Syscall).keccakf1600_theta_rol. +proc. by unroll for ^while; islossless. qed. +lemma keccakf1600_theta_sum_ll : islossless M(Syscall).keccakf1600_theta_sum. +proc. by do 6!(unroll for ^while); islossless. qed. -auto => />. +lemma keccakf1600_rol_sum_ll : islossless M(Syscall).keccakf1600_rol_sum. +proc. +while (x <= 5) (5 - x); auto; last smt(). +conseq => /=; call keccakf1600_rhotates_ll; auto => /#. +qed. -move => &1 &2 ?????????????H1??H2??H3??. -do split. -+ smt(). -+ smt(). -+ rewrite /lift_array256 /subarray256 tP in H1. - rewrite /lift_array256 /subarray256 tP in H2. - rewrite /lift_array256 /subarray256 tP in H3. - rewrite /nttpackv tP => k kb. - rewrite initiE //=. - case (0 <= k && k < 256). - + move => kbb. - move : (H1 (nttpack_idx.[k]) _); 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - pose a:=nttpack_idx.[k]. - rewrite !mapiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite /nttunpack !initiE //=; 1..2: smt(nttpack_bnd nttunpack_bnd Array256.allP). - pose b:=nttunpack_idx.[a]. - rewrite !mapiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite /lift_array768 /subarray256 /=. - pose c := nttpack_idx.[k]. - rewrite /nttunpack initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite /nttpack initiE //=. - rewrite initiE //=. smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite !mapiE //=; 1: smt(nttpack_bnd Array256.allP). - have -> : b = k. move : nttunpack_idxK; rewrite /b /a allP; smt(mem_iota). - smt(). - case (256 <= k && k < 512). - + move => kbb ?. - move : (H2 (nttpack_idx.[k-256]) _); 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - pose a:=nttpack_idx.[k-256]. - rewrite !mapiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite /nttunpack initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - pose b:=nttunpack_idx.[a]. - rewrite !mapiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite /lift_array768 /subarray256 /=. - pose c := nttpack_idx.[k-256]. - rewrite /nttpack initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite !mapiE //=; 1: smt(nttpack_bnd Array256.allP). - have -> : b = k-256. move : nttunpack_idxK; rewrite /b /a allP; smt(mem_iota). - smt(). +lemma keccakf1600_round_ll : islossless Jkem.M(Syscall).keccakf1600_round. +proc; auto. +while (y <= 5) (5 - y); auto. ++ call keccakf1600_set_row_ll. + call keccakf1600_rol_sum_ll. + auto; smt(). +call keccakf1600_theta_rol_ll. +call keccakf1600_theta_sum_ll. +auto; smt(). +qed. - + move => *. - move : (H3 (nttpack_idx.[k-512]) _); 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - pose a:=nttpack_idx.[k-512]. - rewrite !mapiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite /nttunpack initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - pose b:=nttunpack_idx.[a]. - rewrite !mapiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite /lift_array768 /subarray256 /=. - pose c := nttpack_idx.[k-512]. - rewrite /nttpack initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). - rewrite !mapiE //=; 1: smt(nttpack_bnd Array256.allP). - have -> : b = k-512. move : nttunpack_idxK; rewrite /b /a allP; smt(mem_iota). - smt(). +lemma keccakf1600_ll : islossless Jkem.M(Syscall)._keccakf1600_. +proc; auto. +call (:true); auto. +call (:true); auto. +while (to_uint c <= 24 /\ to_uint c %% 2 = 0) (24 - to_uint c); auto; last by move => /> *; rewrite ultE to_uint_small //= /#. +call keccakf1600_round_ll; auto. +call keccakf1600_round_ll; auto. +move => /> ??; rewrite ultE to_uintD_small to_uint_small //= /#. +qed. -+ smt(unpackvK). -+ smt(). -+ smt(). -by smt(unpackvK). +lemma sha3ll : islossless Jkem.M(Jkem.Syscall)._shake256_128_33. +proof. +proc. +unroll for 9; wp; conseq => /=. +call keccakf1600_ll; auto. +conseq => /=. +unroll for ^while; auto. +conseq => /=. +inline *; unroll for ^while; auto. qed. -(***************************************************) -*) +equiv genmatrixequiv b : + Jkem_avx2.M(Jkem_avx2.Syscall)._gen_matrix_avx2 ~ Jkem.M(Jkem.Syscall).__gen_matrix : + arg{1}.`2 = arg{2}.`1 /\ arg{1}.`3= (W64.of_int (b2i b)) /\ arg{2}.`2 = (W64.of_int (b2i b)) ==> + res{1} = nttunpackm res{2} /\ + pos_bound2304_cxq res{1} 0 2304 2 /\ + pos_bound2304_cxq res{2} 0 2304 2. +proc* => /=. +transitivity {2} { r <@ AuxMLKEM.__gen_matrix(seed,b); } + ( rho{1} = seed{2} /\ transposed{1} = (of_int (b2i b))%W64 /\ transposed{2} = (of_int (b2i b))%W64 ==> r{1} = nttunpackm r{2} /\ pos_bound2304_cxq r{1} 0 2304 2 /\ pos_bound2304_cxq r{2} 0 2304 2 ) + ( seed{1} = seed{2} /\ transposed{1} = (of_int (b2i b))%W64 /\ transposed{2} = (of_int (b2i b))%W64==> ={r});1,2:smt(). + + call (genmatrixequiv_aux b); 1: by auto => />. + by symmetry;call (auxgenmatrix_good); auto => /> /#. +qed. - import WArray960 WArray1536 Array4. +module GetNoiseAVX2 = { + proc _poly_getnoise_eta1_4x(aux3 aux2 aux1 aux0 : W16.t Array256.t, + noiseseed : W8.t Array32.t, + nonce : W8.t) : + W16.t Array256.t * W16.t Array256.t * W16.t Array256.t * W16.t Array256.t = { + var n3, n2, n1, n0 : W8.t; + var aux_3, aux_2, aux_1, aux_0 : W16.t Array256.t; + n0 <- nonce + W8.of_int 3; + n1 <- nonce + W8.of_int 2; + n2 <- nonce + W8.of_int 1; + n3 <- nonce; + aux_3 <@Jkem.M(Jkem.Syscall)._poly_getnoise(aux3,noiseseed,n3); + aux_2 <@Jkem.M(Jkem.Syscall)._poly_getnoise(aux2,noiseseed,n2); + aux_1 <@Jkem.M(Jkem.Syscall)._poly_getnoise(aux1,noiseseed,n1); + aux_0 <@Jkem.M(Jkem.Syscall)._poly_getnoise(aux0,noiseseed,n0); + return (aux_3, aux_2, aux_1, aux_0); + } -module AuxPolyVecCompress10 = { - proc avx2_orig(ctp : W64.t, bp : W16.t Array768.t) : W8.t Array960.t = { - bp <@ Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_reduce_sig(bp); - Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_compress(ctp, bp); - return witness; - } - - proc avx2_orig_i(ctp : W8.t Array1088.t, bp : W16.t Array768.t) : W8.t Array960.t = { - var rr : W8.t Array960.t; - bp <@ Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_reduce_sig(bp); - rr <@ Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_compress_1(Array960.init (fun (i_0 : int) => ctp.[0 + i_0]),bp); - return rr; - } + proc sample_noise_kg(skpv pkpv e : W16.t Array768.t, noiseseed:W8.t Array32.t) : W16.t Array768.t * W16.t Array768.t ={ + var nonce : W8.t; + var aux_3, aux_2, aux_1, aux_0 : W16.t Array256.t; + nonce <- (W8.of_int 0); + (aux_3, aux_2, aux_1, + aux_0) <@ _poly_getnoise_eta1_4x ((Array256.init (fun i_0 => skpv.[0 + i_0])), + (Array256.init (fun i_0 => skpv.[256 + i_0])), + (Array256.init (fun i_0 => skpv.[(2 * 256) + i_0])), + (Array256.init (fun i_0 => e.[0 + i_0])), noiseseed, nonce); + skpv <- Array768.init + (fun i_0 => if 0 <= i_0 < 0 + 256 then aux_3.[i_0-0] + else skpv.[i_0]); + skpv <- Array768.init + (fun i_0 => if 256 <= i_0 < 256 + 256 + then aux_2.[i_0-256] else skpv.[i_0]); + skpv <- Array768.init + (fun i_0 => if (2 * 256) <= i_0 < (2 * 256) + 256 + then aux_1.[i_0-(2 * 256)] else skpv.[i_0]); + e <- Array768.init + (fun i_0 => if 0 <= i_0 < 0 + 256 then aux_0.[i_0-0] + else e.[i_0]); + nonce <- (W8.of_int 4); + (aux_3, aux_2, aux_1, + aux_0) <@ _poly_getnoise_eta1_4x ((Array256.init (fun i_0 => e.[256 + i_0])), + (Array256.init (fun i_0 => e.[(2 * 256) + i_0])), + (Array256.init (fun i_0 => pkpv.[0 + i_0])), + (Array256.init (fun i_0 => pkpv.[256 + i_0])), noiseseed, + nonce); + e <- Array768.init + (fun i_0 => if 256 <= i_0 < 256 + 256 + then aux_3.[i_0-256] else e.[i_0]); + e <- Array768.init + (fun i_0 => if (2 * 256) <= i_0 < (2 * 256) + 256 + then aux_2.[i_0-(2 * 256)] else e.[i_0]); + return (skpv,e); + } + proc samplenoise_enc(sp_0 ep bp : W16.t Array768.t, epp : W16.t Array256.t, noiseseed:W8.t Array32.t) : W16.t Array768.t * W16.t Array768.t * W16.t Array768.t * W16.t Array256.t = { + var nonce : W8.t; + var aux_2, aux_1, aux_0, aux : W16.t Array256.t; + nonce <- (W8.of_int 0); + (aux_2, aux_1, aux_0, + aux) <@ _poly_getnoise_eta1_4x ((Array256.init (fun i_0 => sp_0.[0 + i_0])), + (Array256.init (fun i_0 => sp_0.[256 + i_0])), + (Array256.init (fun i_0 => sp_0.[(2 * 256) + i_0])), + (Array256.init (fun i_0 => ep.[0 + i_0])), noiseseed, + nonce); + sp_0 <- Array768.init + (fun i_0 => if 0 <= i_0 < 0 + 256 then aux_2.[i_0-0] + else sp_0.[i_0]); + sp_0 <- Array768.init + (fun i_0 => if 256 <= i_0 < 256 + 256 + then aux_1.[i_0-256] else sp_0.[i_0]); + sp_0 <- Array768.init + (fun i_0 => if (2 * 256) <= i_0 < (2 * 256) + 256 + then aux_0.[i_0-(2 * 256)] else sp_0.[i_0]); + ep <- Array768.init + (fun i_0 => if 0 <= i_0 < 0 + 256 then aux.[i_0-0] + else ep.[i_0]); + nonce <- (W8.of_int 4); + (aux_2, aux_1, aux_0, + aux) <@ _poly_getnoise_eta1_4x ((Array256.init (fun i_0 => ep.[256 + i_0])), + (Array256.init (fun i_0 => ep.[(2 * 256) + i_0])), epp, + (Array256.init (fun i_0 => bp.[0 + i_0])), noiseseed, + nonce); + ep <- Array768.init + (fun i_0 => if 256 <= i_0 < 256 + 256 + then aux_2.[i_0-256] else ep.[i_0]); + ep <- Array768.init + (fun i_0 => if (2 * 256) <= i_0 < (2 * 256) + 256 + then aux_1.[i_0-(2 * 256)] else ep.[i_0]); + epp <- aux_0; + bp <- Array768.init + (fun i_0 => if 0 <= i_0 < 0 + 256 then aux.[i_0-0] + else bp.[i_0]); + return (sp_0,ep,bp, epp); -proc __polyvec_compress_avx2(ctp : W8.t Array1088.t, a : W16.t Array768.t) : W8.t Array960.t = { - var aux : int; - var b0 : W256.t; - var b1 : W256.t; - var b2 : W256.t; - var mask10 : W256.t; - var shift : W256.t; - var sllv_indx : W256.t; - var shuffle : W256.t; - var i : int; - var a0 : W256.t; - var lo : W128.t; - var hi : W128.t; - var rp : W8.t Array960.t <- (init (fun (i_0 : int) => ctp.[0 + i_0]))%Array960; - - b0 <- VPBROADCAST_16u16 compress10_b0; - b1 <- VPBROADCAST_16u16 compress10_b1; - b2 <- VPBROADCAST_16u16 pc_shift1_s; - mask10 <- VPBROADCAST_16u16 pvc_mask_s; - shift <- VPBROADCAST_8u32 compress10_shift; - sllv_indx <- VPBROADCAST_4u64 pvc_sllvdidx_s; - shuffle <- get256 ((init8 (fun (i_0 : int) => pvc_shufbidx_s.[i_0])))%WArray32 0; - aux <- 3 * 256 %/ 16; - i <- 0; - while (i < aux){ - a0 <- (get256 ((WArray1536.init16 (fun (i_0 : int) => a.[i_0]))) i); - a0 <@ Jkem_avx2.M(Syscall).compress10_16x16_inline(a0, b0, b1, b2, mask10); - (lo, hi) <@ Jkem_avx2.M(Syscall).pack10_16x16(a0, shift, sllv_indx, shuffle); -rp <- - (init - (get8 - (set128_direct - ((init8 (fun (i_0 : int) => rp.[i_0])))%WArray960 - (20 * i) lo)))%Array960 ; - - rp <- - (init - (get8 - (set32_direct - ((init8 (fun (i_0 : int) => rp.[i_0])))%WArray960 - (20 * i + 16) (VPEXTR_32 hi W8.zero))))%Array960 ; - (* - Glob.mem <- storeW128 Glob.mem (to_uint (r + (of_int (i * 20 + 0))%W64)) lo; - Glob.mem <- storeW32 Glob.mem (to_uint (r + (of_int (i * 20 + 16))%W64)) (VPEXTR_32 hi W8.zero); -*) - i <- i + 1; - } - - return rp; } +}. - proc avx2(bp : W16.t Array768.t) : W8.t Array960.t = { - var rr : W8.t Array960.t; - var ctp : W8.t Array1088.t <- (init (fun (i_0 : int) => W8.zero))%Array1088; - bp <@ Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_reduce_sig(bp); - rr <@ __polyvec_compress_avx2(ctp,bp); - return rr; - } - - proc ref_orig(ctp : W64.t, bp : W16.t Array768.t) : W8.t Array960.t = { - bp <@ Jkem.M(Syscall).__polyvec_reduce(bp); - Jkem.M(Syscall).__polyvec_compress(ctp, bp); - return witness; - } +(* int value of jth noise coeficient *) +op noise_coef (bytes: W8.t Array128.t) (j: int): int = + let b = bytes.[j%/2] in b2i b.[j%%2*4] + b2i b.[j%%2*4+1] - (b2i b.[j%%2*4+2] + b2i b.[j%%2*4+3]). - proc ref_orig_i(ctp : W8.t Array1088.t, bp : W16.t Array768.t) : W8.t Array960.t = { - var rr : W8.t Array960.t; - bp <@ M(Syscall).__polyvec_reduce(bp) ; - rr <@ M(Syscall).__i_polyvec_compress(Array960.init (fun (i_0 : int) => ctp.[0 + i_0]),bp); - return rr; - } +import WArray128. -proc _poly_csubq_ref(rp : W16.t Array256.t) : W16.t Array256.t = { - var i : int; - var t : W16.t; - var b : W16.t; - - i <- 0; - while (i < 256){ - t <- rp.[i]; - t <- t - qlocal; - b <- t; - b <- b `|>>` (of_int 15)%W8; - b <- b `&` qlocal; - t <- t + b; - rp.[i] <- t; - i <- i + 1; - } - - return rp; - } - proc __polyvec_csubq_ref(r : W16.t Array768.t) : W16.t Array768.t = { - var aux : W16.t Array256.t; - - aux <@ _poly_csubq_ref((init (fun (i : int) => r.[0 + i]))%Array256); - r <- (init (fun (i : int) => if 0 <= i && i < 0 + 256 then aux.[i - 0] else r.[i]))%Array768; - aux <@ _poly_csubq_ref((init (fun (i : int) => r.[256 + i]))%Array256); - r <- (init (fun (i : int) => if 256 <= i && i < 256 + 256 then aux.[i - 256] else r.[i]))%Array768; - aux <@ _poly_csubq_ref((init (fun (i : int) => r.[2 * 256 + i]))%Array256); - r <- (init (fun (i : int) => if 2 * 256 <= i && i < 2 * 256 + 256 then aux.[i - 2 * 256] else r.[i]))%Array768; - - return r; - } +op B2Ri (bytes: W8.t Array128.t) (j: int): W256.t = + get256 (WArray128.init8 (fun i => bytes.[i])) j. -(* proc __polyvec_compress_ref(a : W16.t Array768.t) : W8.t Array960.t = { - var aux : int; - var i : int; - var j : int; - var aa : W16.t Array768.t; - var k : int; - var t : W64.t Array4.t; - var c : W16.t; - var b : W16.t; - var rr : W8.t Array960.t <- Array960.init(fun _ => W8.zero); - - aa <- witness; - t <- Array4.init(fun _ => W64.zero); - i <- 0; - j <- 0; - aa <@ __polyvec_csubq_ref(a); - while (i < (3 * 256 - 3)){ - k <- 0; - while (k < 4){ - t.[k] <- zeroextu64 aa.[i]; - i <- i + 1; - t.[k] <- (t.[k]) `<<` (of_int 10)%W8; - t.[k] <- (t.[k]) + (of_int 1665)%W64; - t.[k] <- (t.[k]) * (of_int 1290167)%W64; - t.[k] <- (t.[k]) `>>` (of_int 32)%W8; - t.[k] <- (t.[k]) `&` (of_int 1023)%W64; - k <- k + 1; - } - c <- truncateu16 (t.[0]); - c <- c `&` (of_int 255)%W16; - rr.[j] <- (truncateu8 c); - (* - Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 c); - *) - j <- j + 1; - b <- truncateu16 (t.[0]); - b <- b `>>` (of_int 8)%W8; - c <- truncateu16 (t.[1]); - c <- c `<<` (of_int 2)%W8; - c <- c `|` b; - rr.[j] <- (truncateu8 c); - (* - Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 c); - *) - j <- j + 1; - b <- truncateu16 (t.[1]); - b <- b `>>` (of_int 6)%W8; - c <- truncateu16 (t.[2]); - c <- c `<<` (of_int 4)%W8; - c <- c `|` b; - rr.[j] <- (truncateu8 c); - (* - Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 c); - *) - j <- j + 1; - b <- truncateu16 (t.[2]); - b <- b `>>` (of_int 4)%W8; - c <- truncateu16 (t.[3]); - c <- c `<<` (of_int 6)%W8; - c <- c `|` b; - rr.[j] <- (truncateu8 c); - (* - Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 c); - *) - j <- j + 1; - t.[3] <- (t.[3]) `>>` (of_int 2)%W8; - rr.[j] <- (truncateu8 (t.[3])); - (* - Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 (t.[3])); - *) - j <- j + 1; - } - - return rr; - } -*) +lemma bytes_getR (bytes: W8.t Array128.t) (k: int): + 0 <= k && k < 128 => + bytes.[k] = B2Ri bytes (k %/ 32) \bits8 (k %% 32). +proof. +move=> Hk; rewrite /B2Ri /get256_direct pack32bE 1:/# initiE 1:/# /=. +by rewrite mulrC -divz_eq /init8 initiE. +qed. - proc __i_polyvec_compress_ref(a : W16.t Array768.t) : W8.t Array960.t = { - var aux : int; - var i : int; - var j : int; - var aa : W16.t Array768.t; - var t : W64.t t; - var c : W16.t; - var b : W16.t; - - var rp : W8.t Array960.t <- Array960.init(fun _ => W8.zero); - t <- Array4.init(fun _ => W64.zero); - aa <@ __polyvec_csubq_ref(a); - j <- 0; - i <- 0; - while (i < (3 * 256 - 3)){ - (* k = 0 *) - t.[0] <- zeroextu64 aa.[i+0]; - t.[0] <- t.[0] `<<` (of_int 10)%W8; - t.[0] <- t.[0] + (of_int 1665)%W64; - t.[0] <- t.[0] * (of_int 1290167)%W64; - t.[0] <- t.[0] `>>` (of_int 32)%W8; - t.[0] <- t.[0] `&` (of_int 1023)%W64; - (* k = 1 *) - t.[1] <- zeroextu64 aa.[i+1]; - t.[1] <- t.[1] `<<` (of_int 10)%W8; - t.[1] <- t.[1] + (of_int 1665)%W64; - t.[1] <- t.[1] * (of_int 1290167)%W64; - t.[1] <- t.[1] `>>` (of_int 32)%W8; - t.[1] <- t.[1] `&` (of_int 1023)%W64; - (* k = 2 *) - t.[2] <- zeroextu64 aa.[i+2]; - t.[2] <- t.[2] `<<` (of_int 10)%W8; - t.[2] <- t.[2] + (of_int 1665)%W64; - t.[2] <- t.[2] * (of_int 1290167)%W64; - t.[2] <- t.[2] `>>` (of_int 32)%W8; - t.[2] <- t.[2] `&` (of_int 1023)%W64; - (* k = 3 *) - t.[3] <- zeroextu64 aa.[i+3]; - t.[3] <- t.[3] `<<` (of_int 10)%W8; - t.[3] <- t.[3] + (of_int 1665)%W64; - t.[3] <- t.[3] * (of_int 1290167)%W64; - t.[3] <- t.[3] `>>` (of_int 32)%W8; - t.[3] <- t.[3] `&` (of_int 1023)%W64; - c <- truncateu16 t.[0]; - c <- c `&` (of_int 255)%W16; - rp.[j] <- truncateu8 c; - j <- j + 1; - b <- truncateu16 t.[0]; - b <- b `>>` (of_int 8)%W8; - c <- truncateu16 t.[1]; - c <- c `<<` (of_int 2)%W8; - c <- c `|` b; - rp.[j] <- truncateu8 c; - j <- j + 1; - b <- truncateu16 t.[1]; - b <- b `>>` (of_int 6)%W8; - c <- truncateu16 t.[2]; - c <- c `<<` (of_int 4)%W8; - c <- c `|` b; - rp.[j] <- truncateu8 c; - j <- j + 1; - b <- truncateu16 t.[2]; - b <- b `>>` (of_int 4)%W8; - c <- truncateu16 t.[3]; - c <- c `<<` (of_int 6)%W8; - c <- c `|` b; - rp.[j] <- truncateu8 c; - j <- j + 1; - t.[3] <- t.[3] `>>` (of_int 2)%W8; - rp.[j] <- truncateu8 t.[3]; - j <- j + 1; - i <- i + 4; - } - - return rp; - } +abbrev mask55u256 = VPBROADCAST_8u32 (W32.of_int 1431655765). +abbrev mask33u256 = VPBROADCAST_8u32 (W32.of_int 858993459). +abbrev mask03u256 = VPBROADCAST_8u32 (W32.of_int 50529027). +abbrev mask0Fu256 = VPBROADCAST_8u32 (W32.of_int 252645135). -proc __poly_reduce(rp : W16.t Array256.t) : W16.t Array256.t = { - var j : int; - var t : W16.t; - - j <- 0; - while (j < 256){ - t <- rp.[j]; - t <@ M(Syscall).__barrett_reduce(t); - rp.[j] <- t; - j <- j + 1; - } - - return rp; - } +abbrev mask55u16 = W16.of_int 21845. (* 21845 = 0x5555 *) +abbrev mask33u16 = W16.of_int 13107. (* 13107 = 0x3333 *) +abbrev mask03u16 = W16.of_int 771. (* 771 = 0x0303 *) +abbrev mask0Fu16 = W16.of_int 3855. (* 3855 = 0x0F0F *) - proc __polyvec_reduce(r : W16.t Array768.t) : W16.t Array768.t = { - var aux : W16.t Array256.t; - - aux <@ __poly_reduce((init (fun (i : int) => r.[0 + i]))%Array256); - r <- (init (fun (i : int) => if 0 <= i && i < 0 + 256 then aux.[i - 0] else r.[i]))%Array768; - aux <@ __poly_reduce((init (fun (i : int) => r.[256 + i]))%Array256); - r <- (init (fun (i : int) => if 256 <= i && i < 256 + 256 then aux.[i - 256] else r.[i]))%Array768; - aux <@ __poly_reduce((init (fun (i : int) => r.[2 * 256 + i]))%Array256); - r <- (init (fun (i : int) => if 2 * 256 <= i && i < 2 * 256 + 256 then aux.[i - 2 * 256] else r.[i]))%Array768; - - return r; - } +abbrev mask55u8 = W8.of_int 85. (* 85 = 0x55 *) +abbrev mask33u8 = W8.of_int 51. (* 51 = 0x33 *) +abbrev mask03u8 = W8.of_int 3. (* 3 = 0x03 *) +abbrev mask0Fu8 = W8.of_int 15. (* 15 = 0x0F *) - proc ref(bp : W16.t Array768.t) : W8.t Array960.t = { - var rr : W8.t Array960.t; - bp <@ __polyvec_reduce(bp); - rr <@ __i_polyvec_compress_ref(bp); - return rr; - } +lemma mask55_bits16 k: + 0 <= k < 16 => + mask55u256 \bits16 k = mask55u16. +proof. +move=> Hk. +rewrite /VPBROADCAST_8u32. +rewrite bits16_W8u32 Hk //= get_of_list 1:/# /=. +rewrite (nth_map 0) /=; first smt(size_iota). +have: (k%%2 \in iota_ 0 2) by smt(mem_iota). +by move: (k%%2); rewrite -allP -iotaredE /= W2u16.bits16_div //. +qed. +lemma mask55_bits8 k: + 0 <= k < 32 => + mask55u256 \bits8 k = mask55u8. +proof. +move=> Hk. +rewrite /VPBROADCAST_8u32. +rewrite bits8_W8u32 Hk //= get_of_list 1:/# /=. +rewrite (nth_map 0) /=; first smt(size_iota). +have: (k%%4 \in iota_ 0 4) by smt(mem_iota). +by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div //. +qed. -}. +lemma mask33_bits16 k: + 0 <= k < 16 => + mask33u256 \bits16 k = mask33u16. +proof. +move=> Hk. +rewrite /VPBROADCAST_8u32. +rewrite bits16_W8u32 Hk //= get_of_list 1:/# /=. +rewrite (nth_map 0) /=; first smt(size_iota). +have: (k%%2 \in iota_ 0 2) by smt(mem_iota). +by move: (k%%2); rewrite -allP -iotaredE /= W2u16.bits16_div //. +qed. -lemma compress10_equiv_avx2mem _ctp _mem : - equiv [ AuxPolyVecCompress10.avx2_orig ~ AuxPolyVecCompress10.avx2 : - ={bp} /\ ctp{1} = _ctp /\ Glob.mem{1} = _mem /\ valid_ptr (to_uint ctp{1}) (128 + 3 * 320) ==> - Glob.mem{1} = stores _mem (to_uint _ctp) (to_list res{2}) ]. -proc => /=. -swap {2} 2 -1;seq 1 1 : #pre; 1: by conseq />;inline *;sim. -inline {1} 1; inline {2} 2. -wp. -while (Glob.mem{1} = stores _mem (to_uint _ctp) (take (i{2}*20) (to_list rp{2})) /\ aux{1} = 48 /\ - valid_ptr (to_uint r{1}) (128 + 3 * 320) /\ r{1} = _ctp /\ - ={i,a,aux,sllv_indx, shuffle, shift, mask10, b2, b1, b0} /\ 0 <= i{2} <= 48); last - by auto => />;smt(Array960.size_to_list List.take_size List.take0 storesE iota0). +lemma mask33_bits8 k: + 0 <= k < 32 => + mask33u256 \bits8 k = mask33u8. +proof. +move=> Hk. +rewrite /VPBROADCAST_8u32. +rewrite bits8_W8u32 Hk //= get_of_list 1:/# /=. +rewrite (nth_map 0) /=; first smt(size_iota). +have: (k%%4 \in iota_ 0 4) by smt(mem_iota). +by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div //. +qed. -seq 3 3 : (#pre /\ ={lo,hi}); - 1: by conseq />; sim. -auto => /> &1 &2 ????;split;last by smt(). -rewrite /storeW32 /storeW128. -apply mem_eq_ext => add. -rewrite !get_storesE !to_uintD_small /= !of_uintK /= 1,2:/# !modz_small 1..2:/#. -rewrite !size_take 1,2:/# /= !size_to_list. -case ((to_uint _ctp <= add && add < to_uint _ctp + MIN ((i{1} + 1) * 20) 960)); last by smt(). -move => *. -case ((to_uint _ctp + MIN (i{1} * 20) 960) <= add && add < to_uint _ctp + MIN (i{1} * 20 + 16) 960). -+ move => *; rewrite ifF 1:/# ifT 1:/# mulrDl /= takeD 1,2:/# nth_cat !size_take 1:/# size_to_list . - have -> /= : add - to_uint _ctp < MIN (i{1} * 20) 960 = false by smt(). - rewrite /to_list drop_mkseq 1:/# take_mkseq 1:/# /= /(\o) /= /mkseq (nth_map witness) /=;1:smt(size_iota). - rewrite nth_iota 1:/# initiE 1:/# get8_set32_directE 1,2:/# /= /get8 initiE 1:/# /= -/WArray960.get8 initiE 1:/# get8_set128_directE /#. -case ((to_uint _ctp + MIN (i{1} * 20+16) 960) <= add && add < to_uint _ctp + MIN (i{1} * 20 + 20) 960). -+ move => *; rewrite ifT 1:/# mulrDl /= takeD 1,2:/# nth_cat !size_take 1:/# size_to_list . - have -> /= : add - to_uint _ctp < MIN (i{1} * 20) 960 = false by smt(). - rewrite /to_list drop_mkseq 1:/# take_mkseq 1:/# /= /(\o) /= /mkseq (nth_map witness) /=;1:smt(size_iota). - rewrite nth_iota 1:/# initiE 1:/# get8_set32_directE 1,2:/# /= /get8 initiE 1:/# /= -/WArray960.get8 initiE 1:/# get8_set128_directE /#. -case (to_uint _ctp <= add && add < to_uint _ctp + MIN (i{1} * 20) 960); last by smt(). -move => *; rewrite ifF 1:/# ifF 1:/# mulrDl /= /to_list !take_mkseq 1,2:/# /= /mkseq !(nth_map witness); 1,2: smt(size_iota). -rewrite !nth_iota 1,2:/# initiE 1:/# get8_set32_directE 1,2:/# /get8 !initiE 1,2:/# /= -/WArray960.get8 get8_set128_directE 1,2:/# /get8 initiE /#. +lemma mask03_bits8 k: + 0 <= k < 32 => + mask03u256 \bits8 k = mask03u8. +proof. +move=> Hk. +rewrite /VPBROADCAST_8u32. +rewrite bits8_W8u32 Hk //= get_of_list 1:/# /=. +rewrite (nth_map 0) /=; first smt(size_iota). +have: (k%%4 \in iota_ 0 4) by smt(mem_iota). +by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div //. qed. -lemma poly_reduce_noloops : - equiv [ AuxPolyVecCompress10.__poly_reduce ~ M(Syscall).__poly_reduce : - ={arg} ==> ={res} ]. -proc => /=. -while (#pre /\ 0<=j{1} <= 256 /\ j{1} = to_uint j{2}); last by auto. -inline *;auto => /> &2; rewrite !W64.ultE /= => *;do split;1..2:smt();by rewrite to_uintD_small /=; smt(). +lemma mask0F_bits16 k: + 0 <= k < 16 => + mask0Fu256 \bits16 k = mask0Fu16. +proof. +move=> Hk. +rewrite /VPBROADCAST_8u32. +rewrite bits16_W8u32 Hk //= get_of_list 1:/# /=. +rewrite (nth_map 0) /=; first smt(size_iota). +have: (k%%2 \in iota_ 0 2) by smt(mem_iota). +by move: (k%%2); rewrite -allP -iotaredE /= W2u16.bits16_div //. qed. -lemma polyvec_reduce_noloops : - equiv [ AuxPolyVecCompress10.__polyvec_reduce ~ M(Syscall).__polyvec_reduce : - ={arg} ==> ={res} ]. -proc => /=. -by do 3!(wp;call poly_reduce_noloops);auto => />. +lemma mask0F_bits8 k: + 0 <= k < 32 => + mask0Fu256 \bits8 k = mask0Fu8. +proof. +move=> Hk. +rewrite /VPBROADCAST_8u32. +rewrite bits8_W8u32 Hk //= get_of_list 1:/# /=. +rewrite (nth_map 0) /=; first smt(size_iota). +have: (k%%4 \in iota_ 0 4) by smt(mem_iota). +by move: (k%%4); rewrite -allP -iotaredE /= W4u8.bits8_div //. qed. -lemma poly_csubq_noloops : - equiv [ AuxPolyVecCompress10._poly_csubq_ref ~ M(Syscall)._poly_csubq : - ={arg} ==> ={res} ]. -proc => /=. -while (#pre /\ 0<=i{1} <= 256 /\ i{1} = to_uint i{2}); last by auto. -inline *;auto => /> &2; rewrite !W64.ultE /= => *;do split;1..2:smt();by rewrite to_uintD_small /=; smt(). +lemma VPSRL1_ANDmask55 w k: + 0 <= k < 32 => + mask55u256 `&` (VPSRL_16u16 w (W8.of_int 1)) \bits8 k + = mask55u8 `&` ((w \bits8 k) `>>` (W8.of_int 1)). +proof. +move=> Hk. +rewrite {1}(_:k=2*(k%/2) + (k%%2)); first smt(divz_eq). +rewrite -W256_bits16_bits8 1:/# andb16E /VPSRL_16u16 mapbE 1:/# /=. +rewrite W256_bits16_bits8 1:/# mask55_bits8 1:/#. +apply W8extra.wordP_red. rewrite -allP /=. +have: (k\in iota_ 0 32) by smt(mem_iota). +by move: {Hk} k; rewrite -allP -iotaredE /= !W16.shrwE !W8.shrwE /int_bit /=. qed. -lemma polyvec_csubq_noloops : - equiv [ AuxPolyVecCompress10.__polyvec_csubq_ref ~ M(Syscall).__polyvec_csubq : - ={arg} ==> ={res} ]. -proc => /=. -by do 3!(wp;call poly_csubq_noloops);auto => />. +lemma VPSRL2_ANDmask33 w k: + 0 <= k < 32 => + mask33u256 `&` (VPSRL_16u16 w (W8.of_int 2)) \bits8 k + = mask33u8 `&` ((w \bits8 k) `>>` (W8.of_int 2)). +proof. +move=> Hk. +rewrite {1}(_:k=2*(k%/2) + (k%%2)); first smt(divz_eq). +rewrite -W256_bits16_bits8 1:/# andb16E /VPSRL_16u16 mapbE 1:/# /=. +rewrite W256_bits16_bits8 1:/# mask33_bits8 1:/#. +apply W8extra.wordP_red. rewrite -allP /=. +have: (k\in iota_ 0 32) by smt(mem_iota). +by move: {Hk} k; rewrite -allP -iotaredE /= !W16.shrwE !W8.shrwE /int_bit /=. qed. -lemma compress10_equiv_refmem _ctp _mem : - equiv [ AuxPolyVecCompress10.ref ~ AuxPolyVecCompress10.ref_orig : - ={bp} /\ ctp{2} = _ctp /\ Glob.mem{2} = _mem /\ valid_ptr (to_uint ctp{2}) (128 + 3 * 320) ==> - Glob.mem{2} = stores _mem (to_uint _ctp) (to_list res{1}) ]. - proc => /=. -seq 1 1 : #pre; 1: by call polyvec_reduce_noloops => />. -inline {1} 1; inline {2} 1. -swap {1} 3 -1; swap {1} 4 -2; swap {2} [2..3] -1; swap {2} 7 -4. -seq 2 3 : (#pre /\ ={aa});1:by call polyvec_csubq_noloops;auto => />. -wp;while (0 <= i{1} <= 768 /\ i{1} = to_uint i{2} /\ valid_ptr (to_uint rp{2}) (128 + 3 * 320) /\ - 0 <= j{1} <= 960 /\ j{1} = to_uint j{2} /\ rp{2} = _ctp /\ - j{1} *4 = i{1} * 5 /\ ={aa} /\ - Glob.mem{2} = stores _mem (to_uint _ctp) (take j{1} (to_list rp{1}))); last - by auto => />; smt(Array960.size_to_list List.take_size List.take0 storesE iota0). -unroll for {2} 2;auto => /> &1 &2;rewrite !ultE /= => ?????????;do split;1,2,4,5:smt();1..3,5..:by rewrite ?to_uintD_small;smt(). -rewrite /storeW8 /=. -apply mem_eq_ext => adr. - rewrite !to_uintD_small /= 1..16:/# !addrA. -rewrite !get_storesE. -case (to_uint _ctp + to_uint j{2} <= adr < to_uint _ctp + to_uint j{2} + 5); last first. -case (to_uint _ctp <= adr < to_uint _ctp + to_uint j{2}); last first. -+ move => *;do 5!(rewrite get_set_neqE_s 1:/#). - rewrite !size_take 1:/# size_to_list /= ifF 1:/# get_storesE /= size_take 1:/# size_to_list /#. -+ move => *. - move => *;do 5!(rewrite get_set_neqE_s 1:/#). - rewrite !size_take 1:/# size_to_list /= ifT 1:/# nth_take 1,2:/# /to_list nth_mkseq 1:/# /= get_storesE size_take 1:/# size_mkseq /= ifT 1:/#. - by rewrite nth_take 1,2:/# nth_mkseq 1:/# /=; smt(Array960.get_setE). -move => *; rewrite size_take 1:/# size_to_list ifT 1:/# nth_take 1,2:/# /to_list nth_mkseq 1:/# /=. -by smt(Array960.get_setE get_set_neqE_s get_set_eqE_s). +lemma VPSRL4_ANDmask0F w k: + 0 <= k < 32 => + VPAND_256 mask0Fu256 (VPSRL_16u16 w (W8.of_int 4)) \bits8 k + = mask0Fu8 `&` ((w \bits8 k) `>>` (W8.of_int 4)). +proof. +move=> Hk. +rewrite {1}(_:k=2*(k%/2) + (k%%2)); first smt(divz_eq). +rewrite -W256_bits16_bits8 1:/# andb16E /VPSRL_16u16 mapbE 1:/# /=. +rewrite W256_bits16_bits8 1:/# mask0F_bits8 1:/#. +apply W8extra.wordP_red. rewrite -allP /=. +have: (k\in iota_ 0 32) by smt(mem_iota). +by move: {Hk} k; rewrite -allP -iotaredE /= !W16.shrwE !W8.shrwE /int_bit /=. +qed. + +lemma to_uint_mask33 (w:W8.t): + to_uint (mask33u8 `&` w) + = to_uint w %% 4 + to_uint w %/ 16 %% 4 * 16. +proof. +have ->: mask33u8 = (mask03u8 `<<<` 4) `|` mask03u8. + apply W8.wordP => k; rewrite -mem_range /range /=. + by move: k; apply/List.allP; rewrite -iotaredE /int_bit /=. +rewrite andwC andw_orwDr orw_disjoint. + apply W8.wordP => k; rewrite -mem_range /range /=. + by move: k; apply/List.allP; rewrite -iotaredE /int_bit /=. +have ->: w `&` (mask03u8 `<<<` 4) + = ((w `>>>` 4) `&` W8.masklsb (6-4)) `<<<` 4. +rewrite -shlw_andmask // shrl_andmaskN // -andwA /=. +congr. +rewrite /max /=. + apply W8.wordP => k; rewrite -mem_range /range /=. + by move: k; apply/List.allP; rewrite -iotaredE /int_bit /=. +have E1: to_uint (w `&` mask03u8) = to_uint w %% 4. + by rewrite (W8.to_uint_and_mod 2) //. +have /= E2: to_uint ((w `>>>` 4) `&` (masklsb (6-4))%W8 `<<<` 4) = to_uint w %/ 16 %% 4 * 16. + rewrite /max /= to_uint_shl // (W8.to_uint_and_mod 2) //. + by rewrite to_uint_shr //= modz_small /#. +rewrite to_uintD_small /=. + by rewrite E1 E2 /#. +by rewrite E1 E2 /#. qed. -lemma compress10_equiv_avx2i : - equiv [ AuxPolyVecCompress10.avx2_orig_i ~ AuxPolyVecCompress10.avx2 : - ={bp} ==> ={res} ]. -proc => /=. -swap {2} 2 -1. -seq 1 1 : #pre; 1: by sim. -inline *;wp. -while (={i,a,bp,b0,b1,b2,mask10,shift,sllv_indx,shuffle,aux} /\ 0<=i{1} <= 48 /\ aux{1}=48 /\ (forall k, 0<=k rp{1}.[k] = rp{2}.[k])); - last by auto => /> *; split;[ smt() | move => *; rewrite tP => *;smt()]. -auto => /> &1 &2 *;split;1:smt(). -move => k kbl kbh; rewrite !initiE 1,2:/# /=. -rewrite !get8_set32_directE 1..4:/#. -case (0<=k *; rewrite !ifF 1,2:/# /get8 !initiE 1..4:/# /=. - rewrite -/WArray960.get8 !get8_set128_directE 1..4:/# !ifF 1,2:/#. - by rewrite /get8 !initiE /#. -move => *;case (i{2}*20<=k *; rewrite !ifF 1,2:/# /get8 !initiE 1..4:/# /=. - by rewrite -/WArray960.get8 !get8_set128_directE 1..4:/# !ifT /#. -by smt(). +lemma aux_coef_pos b: + to_uint (mask33u8 `&` (mask55u8 `&` b + mask55u8 `&` (b `>>` ru_ones_s))) + = b2i b.[0] + b2i b.[1] + 16 * (b2i b.[4] + b2i b.[5]). +proof. +rewrite addrC -(mask85_sum b 0) // -(mask85_sum b 2) //= !(W8.andwC mask55u8). +by rewrite to_uint_mask33 /(`>>`) to_uint_shr //= to_uint_shr //= /#. qed. -lemma compress10_equiv_refi : - equiv [ AuxPolyVecCompress10.ref ~ AuxPolyVecCompress10.ref_orig_i : - ={bp} ==> ={res} ]. -proc. -seq 1 1 : #pre; 1: by call polyvec_reduce_noloops => />. -inline {1} 1; inline {2} 1. -swap {1} 4 -2;swap {2} [2..3] -1; swap {2} 7 -4. -seq 2 3 : (#pre /\ ={aa});1:by call polyvec_csubq_noloops;auto => />. -wp;while (0 <= i{1} <= 768 /\ i{1} = to_uint i{2} /\ - 0 <= j{1} <= 960 /\ j{1} = to_uint j{2} /\ - j{1} *4 = i{1} * 5 /\ ={aa} /\ - (forall kk, 0 <= kk < j{1} => rp{1}.[kk] = rp{2}.[kk])); last by auto => />*; split;[ smt() | move => *; rewrite tP => *; smt()]. -unroll for {2} 2;auto => /> &1 &2;rewrite !ultE /= => ????????;do split; 1,2,4,5:smt();1..3,5..:by rewrite ?to_uintD_small;smt(). -move => kk kkb ?. -rewrite !to_uintD_small /=;1..7:smt(). -case (kk < to_uint j{2}); by smt(Array960.get_setE). +lemma aux_coef_neg b: + to_uint (mask33u8 `&` ((mask55u8 `&` b + mask55u8 `&` (b `>>` ru_ones_s)) `>>` W8.of_int 2)) + = b2i b.[2] + b2i b.[3] + 16 * (b2i b.[6] + b2i b.[7]). +proof. +rewrite to_uint_mask33 to_uint_shr // -divz_mul //= !(W8.andwC mask55u8). +rewrite {1}(_:4=2^2) // (_:64=2^6) // -(mask85_sum b 1) // -(mask85_sum b 3) //=. +by rewrite to_uint_shr //= to_uint_shr //= /#. qed. -(*****************************************************************) +lemma noise_coef_avx2_aux bytes j: + 3 + noise_coef bytes j + = let b = bytes.[j%/2] in + let x = mask55u8 `&` b + mask55u8 `&` (b `>>` W8.one) in + let y = mask33u8 `&` x + mask33u8 - mask33u8 `&` (x `>>` (W8.of_int 2))in + to_uint y %/ 2^(j%%2*4) %% 16. +proof. +have LL1: forall (x y z:int), (x + z*y) %% z = x %% z. + by move=> x1 x2 x3; rewrite -modzDm modzMr /= modz_mod. +have LL2: forall (x y z:int), (x - z*y) %% z = x %% z. + by move=> x1 x2 x3; rewrite -modzDm -modzNm modzMr /= modz_mod. +move=> /=. +pose b:= bytes.[j %/ 2]. +pose x:= mask55u8 `&` b + mask55u8 `&` (b `>>` ru_ones_s). +case: (j %% 2 = 0) => C. + rewrite C /=. + rewrite -addrA to_uintD /= modz_dvd 1:/#. + rewrite aux_coef_pos W8.to_uintB. + by rewrite ule_andw. + rewrite -modzDm LL1 modzDm aux_coef_neg. + rewrite Ring.IntID.opprD !addzA LL2 /=. + by rewrite -modzDml -(modzDmr _ 51) /= modzDml modz_small /#. +have ->/=: j%%2 = 1 by smt(). +rewrite -addrA to_uintD. +rewrite (_:16=2^4) // modz_pow_div //= modz_mod. +rewrite aux_coef_pos W8.to_uintB. + by rewrite ule_andw. +rewrite aux_coef_neg /= (divz_eq 51 16). +pose X:= (b2i _ + _ + _ + _)%W8. +have /=->: X + = b2i b.[0] + b2i b.[1] + (51 %% 16) - (b2i b.[2] + b2i b.[3]) + + 16 * (b2i b.[4] + b2i b.[5] + (51 %/ 16) - (b2i b.[6] + b2i b.[7])). + by rewrite /X /=; ring. +by rewrite mulzC divzMDr // divz_small //= /#. +qed. +lemma noise_coef_avx2 bytes j: + noise_coef bytes j + = let b = bytes.[j%/2] in + let x = mask55u8 `&` b + mask55u8 `&` (b `>>` W8.one) in + let y = mask33u8 `&` x + mask33u8 - mask33u8 `&` (x `>>` (W8.of_int 2)) in + if j%%2 = 0 + then to_sint (mask0Fu8 `&` y - mask03u8) + else to_sint (mask0Fu8 `&` (y `>>` (W8.of_int 4)) - mask03u8). +proof. +have L1: forall x, W8.to_uint x < 128 => W8.to_sint x = to_uint x. + by move=> x; rewrite to_sintE /smod /= /#. +rewrite /noise_coef /=. +pose b:= bytes.[j %/ 2]. +pose x:= b `&` mask55u8 + (b `>>` ru_ones_s) `&` mask55u8. +pose y:= x `&` mask33u8 + mask33u8 - (x `>>` (W8.of_int 2)) `&` mask33u8. +case: (j %% 2 = 0) => C. + rewrite C /= andwC W8_to_sintB_small. + by rewrite !to_sintE (W8.to_uint_and_mod 4) /smod //= /#. + rewrite L1 (W8.to_uint_and_mod 4) //= /smod /= 1:/#. + move: (noise_coef_avx2_aux bytes j) => /=. + by rewrite C to_sintE /smod => <- /#. +have C': j %% 2 = 1 by smt(). +rewrite C' /= andwC W8_to_sintB_small. + by rewrite !to_sintE (W8.to_uint_and_mod 4) /smod //= /#. +rewrite L1 (W8.to_uint_and_mod 4) //= /smod /= 1:/#. +move: (noise_coef_avx2_aux bytes j) => /=. +by rewrite /noise_coef C' to_sintE /smod to_uint_shr //= => <- /#. +qed. -require import Bindings. -(* BINDINGS *) +lemma to_sint8_mod x: + W8.to_sint x %% W8.modulus = to_uint x. +proof. +rewrite /to_sint /smod. +case: (2 ^ (8 - 1) <= to_uint x) => C. + rewrite -modzDm -modzNm modzz /= modz_mod. + rewrite modz_small //. + by apply JUtils.bound_abs; apply W8.to_uint_cmp. +rewrite modz_small //. +by apply JUtils.bound_abs; apply W8.to_uint_cmp. +qed. + +lemma to_sint8K (x:W8.t): W8.of_int (to_sint x) = x. +proof. by rewrite -of_int_mod to_sint8_mod to_uintK. qed. + +lemma truncateu128_bits128 (w:W256.t): + truncateu128 w = w \bits128 0. +proof. by rewrite /truncateu128 to_uint_eq of_uintK bits128_div // of_uintK. qed. + +hoare cbd2_avx2_h _bytes: + Jkem_avx2.M(Jkem_avx2.Syscall).__cbd2: buf=_bytes ==> res = Array256.init (fun k => W16.of_int (noise_coef _bytes k)). +proof. +proc. +sp; simplify. +while (0 <= i <= 4 /\ #{~i}pre /\ List.all (fun k => rp.[k]=W16.of_int (noise_coef _bytes k)) (iota_ 0 (64*i))). + seq 15: (#pre /\ + all (fun k=> if k%%2 = 0 + then to_sint (f0 \bits8 (k %/ 2)) = noise_coef _bytes (64*i+k) + else to_sint (f1 \bits8 (k %/ 2)) = noise_coef _bytes (64*i+k)) + (iota_ 0 64)). + auto => &m |> ?_ /List.allP H ?; apply/List.allP => k; rewrite mem_iota /= => *. + case: (k%%2=0) => C1. + move: (noise_coef_avx2 buf{m} (64*i{m}+k)). + have ->: (64 * i{m} + k) %% 2 = 0 by smt(). + rewrite /= => ->. + have:= (bytes_getR buf{m} ((64*i{m}+k)%/2) _); first smt(). + rewrite /B2Ri /= -!divz_mul //=. + have ->: (64 * i{m} + k) %/ 64 = i{m}. + by rewrite (mulzC 64) divzMDl // (divz_small _ 64) 1:/# /=. + have ->: (64 * i{m} + k) %/ 2 %% 32 = k %/ 2. + rewrite -(modz_pow_div 2 6 1) //=. + by rewrite (mulzC 64) modzMDl modz_small /#. + move => Eb. + rewrite map2bE 1:/# /= mask0F_bits8 1:/# /=. + rewrite map2bE 1:/#. beta. + rewrite VPSRL2_ANDmask33 1:/#. + rewrite map2bE 1:/#; beta. + rewrite map2bE 1:/#; beta. + rewrite VPSRL1_ANDmask55 1:/#. + rewrite mask33_bits8 1:/# /=. + rewrite map2bE 1:/#; beta. + rewrite VPSRL1_ANDmask55 1:/#. + rewrite mask33_bits8 1:/# /=. + rewrite mask55_bits8 1:/# /=. + rewrite mask03_bits8 1:/# -!Eb. + by congr. + have C2: k %% 2 = 1 by smt(). + move: (noise_coef_avx2 buf{m} (64*i{m}+k)). + have ->: (64 * i{m} + k) %% 2 = 1 by smt(). + rewrite /= => ->. + have:= (bytes_getR buf{m} ((64*i{m}+k)%/2) _); first smt(). + rewrite /B2Ri /= -!divz_mul //=. + have ->: (64 * i{m} + k) %/ 64 = i{m}. + by rewrite (mulzC 64) divzMDl // (divz_small _ 64) 1:/# /=. + have ->: (64 * i{m} + k) %/ 2 %% 32 = k %/ 2. + rewrite -(modz_pow_div 2 6 1) //=. + by rewrite (mulzC 64) modzMDl modz_small /#. + move => Eb. + rewrite map2bE 1:/#; beta. + rewrite VPSRL4_ANDmask0F 1:/#. + rewrite map2bE 1:/#; beta. + rewrite map2bE 1:/#; beta. + rewrite VPSRL2_ANDmask33 1:/#. + rewrite map2bE 1:/#; beta. + rewrite VPSRL1_ANDmask55 1:/#. + rewrite mask33_bits8 1:/# /=. + rewrite map2bE 1:/#; beta. + rewrite VPSRL1_ANDmask55 1:/#. + rewrite mask33_bits8 1:/# /=. + rewrite mask55_bits8 1:/# /=. + rewrite mask03_bits8 1:/# -!Eb. + by congr. + seq 10: (#[/:-2]pre /\ + all (fun (k : int) => + if k %/ 16 = 0 + then f0 \bits16 k%%16 = W16.of_int (noise_coef _bytes (64*i+k)) + else if k %/ 16 = 1 + then f2 \bits16 k%%16 = W16.of_int (noise_coef _bytes (64*i+k)) + else if k %/ 16 = 2 + then f1 \bits16 k%%16 = W16.of_int (noise_coef _bytes (64*i+k)) + else f3 \bits16 k%%16 = W16.of_int (noise_coef _bytes (64*i+k))) + (iota_ 0 64)). + auto => &m |> ?_ /List.allP IH ?. + rewrite -{1}iotaredE /= => |> *. + rewrite -iotaredE /=. + do 32! (split; first by + rewrite /VPMOVSX_16u8_16u16 /VPUNPCKL_32u8 /VPUNPCKL_16u8 /VPUNPCKH_32u8 /VPUNPCKH_16u8 /MOVSX_u16s8 truncateu128_bits128 /interleave_gen /get_lo_2u64 /get_hi_2u64 /= /#). + do 31! (split; first by + rewrite /VPMOVSX_16u8_16u16 /VPUNPCKL_32u8 /VPUNPCKL_16u8 /VPUNPCKH_32u8 /VPUNPCKH_16u8 /MOVSX_u16s8 /VEXTRACTI128 /interleave_gen /get_lo_2u64 /get_hi_2u64 /b2i /= /int_bit /= /#). + by rewrite /VPMOVSX_16u8_16u16 /VPUNPCKL_32u8 /VPUNPCKL_16u8 /VPUNPCKH_32u8 /VPUNPCKH_16u8 /MOVSX_u16s8 /VEXTRACTI128 /interleave_gen /get_lo_2u64 /get_hi_2u64 /b2i /= /int_bit /= /#. + auto => |> &m ? _ /List.allP IH ? /List.allP H. + split; first smt(). + rewrite -!NTT_AVX_Fq.PURE 1..4:/#. + apply/List.allP => k; rewrite mem_iota /= => |> *. + rewrite !NTT_AVX_Fq.PUR_get 1..8:/#. + case: (k %/ 16 = 4 * i{m} + 3) => C1. + move: (H (k %% 64) _) => /=; first smt(mem_iota). + rewrite (modz_pow_div 2 6 4) //= C1 (mulzC 4) modzMDl /=. + rewrite (modz_dvd_pow 4 6 _ 2) //. + have ->: 64 * i{m} + k %% 64 = k by smt(). + by rewrite /R2C /= Array16.initiE /#. + case: (k %/ 16 = 4 * i{m} + 2) => C2. + move: (H (k %% 64) _) => /=; first smt(mem_iota). + rewrite (modz_pow_div 2 6 4) //= C2 (mulzC 4) modzMDl /=. + rewrite (modz_dvd_pow 4 6 _ 2) //. + have ->: 64 * i{m} + k %% 64 = k by smt(). + by rewrite /R2C /= Array16.initiE /#. + case: (k %/ 16 = 4 * i{m} + 1) => C3. + move: (H (k %% 64) _) => /=; first smt(mem_iota). + rewrite (modz_pow_div 2 6 4) //= C3 (mulzC 4) modzMDl /=. + rewrite (modz_dvd_pow 4 6 _ 2) //. + have ->: 64 * i{m} + k %% 64 = k by smt(). + by rewrite /R2C /= Array16.initiE /#. + case: (k %/ 16 = 4 * i{m}) => C4. + move: (H (k %% 64) _) => /=; first smt(mem_iota). + rewrite (modz_pow_div 2 6 4) //= C4 modzMr. + rewrite (modz_dvd_pow 4 6 _ 2) //. + have ->: 64 * i{m} + k %% 64 = k by smt(). + by rewrite /R2C /= Array16.initiE /#. + have ?: k < 64*i{m} by smt(). + by move: (IH k _) => /=; first smt(mem_iota). +auto => &m |> *. +split; first by rewrite iota0. +move => i rp ???; rewrite (_:i=4) 1:/# /=. +move => /List.allP H. +rewrite tP => k Hk; rewrite (H k _); first smt(mem_iota). +by rewrite initiE /#. +qed. -bind array Array256."_.[_]" Array256."_.[_<-_]" Array256.to_list Array256.of_list Array256.t 256. -realize tolistP by done. -realize get_setP by smt(Array256.get_setE). -realize eqP by smt(Array256.tP). -realize get_out by smt(Array256.get_out). +lemma cbd2_ll : islossless Jkem_avx2.M(Jkem_avx2.Syscall).__cbd2. +proc. inline *. sp; wp. while (true) (4-i). move => z. +auto => /> &hr H. smt(). +auto => />i. smt(). qed. +phoare cbd2_avx2_ph _bytes: + [Jkem_avx2.M(Jkem_avx2.Syscall).__cbd2: buf=_bytes ==> res = Array256.init (fun k => W16.of_int (noise_coef _bytes k))] = 1%r. +conseq cbd2_ll (cbd2_avx2_h _bytes) => />. qed. -bind array Array768."_.[_]" Array768."_.[_<-_]" Array768.to_list Array768.of_list Array768.t 768. -realize tolistP by done. -realize get_setP by smt(Array768.get_setE). -realize eqP by smt(Array768.tP). -realize get_out by smt(Array768.get_out). +module AuxMLKEMAvx2 = { + proc cbd2_ref (rp:W16.t Array256.t,buf:W8.t Array128.t) : W16.t Array256.t = { + var k: int; + var a, b, c: W8.t; + var i: W64.t; + var t: W16.t; + i <- (W64.of_int 0); -bind array Array32."_.[_]" Array32."_.[_<-_]" Array32.to_list Array32.of_list Array32.t 32. -realize tolistP by done. -realize get_setP by smt(Array32.get_setE). -realize eqP by smt(Array32.tP). -realize get_out by smt(Array32.get_out). + while ((i \ult (W64.of_int 128))) { + c <- buf.[(W64.to_uint i)]; + a <- c; + a <- (a `&` (W8.of_int 85)); + c <- (c `>>` (W8.of_int 1)); + c <- (c `&` (W8.of_int 85)); + c <- (c + a); + a <- c; + a <- (a `&` (W8.of_int 3)); + b <- c; + b <- (b `>>` (W8.of_int 2)); + b <- (b `&` (W8.of_int 3)); + a <- (a - b); + t <- (sigextu16 a); + rp.[W64.to_uint (W64.of_int 2 * i)] <- t; + a <- c; + a <- (a `>>` (W8.of_int 4)); + a <- (a `&` (W8.of_int 3)); + b <- (c `>>` (W8.of_int 6)); + b <- (b `&` (W8.of_int 3)); + a <- (a - b); + t <- (sigextu16 a); + rp.[W64.to_uint (W64.of_int 2 * i + W64.one)] <- t; + i <- (i + (W64.of_int 1)); + } + return (rp); + } + proc _poly_getnoise (rp:W16.t Array256.t, seed:W8.t Array32.t,nonce:W8.t) : W16.t Array256.t = { + var buf:W8.t Array128.t; + var r; -bind array Array960."_.[_]" Array960."_.[_<-_]" Array960.to_list Array960.of_list Array960.t 960. -realize tolistP by done. -realize get_setP by smt(Array960.get_setE). -realize eqP by smt(Array960.tP). -realize get_out by smt(Array960.get_out). + buf <- witness; + buf <@ M(Syscall)._shake256_128_33 (buf,Array33.init (fun i => if i=32 then nonce else seed.[i])); + r <@ cbd2_ref(rp,buf); + return r; + } + proc __poly_getnoise_eta1_4x(aux3 aux2 aux1 aux0 : W16.t Array256.t, + noiseseed : W8.t Array32.t, + nonce : W8.t) : + W16.t Array256.t * W16.t Array256.t * W16.t Array256.t * W16.t Array256.t = { + var n3, n2, n1, n0 : W8.t; + var aux_3, aux_2, aux_1, aux_0 : W16.t Array256.t; + n0 <- nonce + W8.of_int 3; + n1 <- nonce + W8.of_int 2; + n2 <- nonce + W8.of_int 1; + n3 <- nonce; + aux_3 <@ _poly_getnoise(aux3,noiseseed,n3); + aux_2 <@ _poly_getnoise(aux2,noiseseed,n2); + aux_1 <@ _poly_getnoise(aux1,noiseseed,n1); + aux_0 <@ _poly_getnoise(aux0,noiseseed,n0); + return (aux_3, aux_2, aux_1, aux_0); + } +}. -bind array Array1088."_.[_]" Array1088."_.[_<-_]" Array1088.to_list Array1088.of_list Array1088.t 1088. -realize tolistP by done. -realize get_setP by smt(Array1088.get_setE). -realize eqP by smt(Array1088.tP). -realize get_out by smt(Array1088.get_out). +hoare cbd2_ref_h _bytes: + AuxMLKEMAvx2.cbd2_ref: buf=_bytes ==> res = Array256.init (fun k => W16.of_int (noise_coef _bytes k)). +proof. +proc. +while (to_uint i <= 128 /\ #pre /\ List.all (fun k => rp.[k]=W16.of_int (noise_coef _bytes k)) (iota_ 0 (2 * to_uint i))). + auto => &m |>; rewrite /(\ult) => _ /List.allP IH /= Hi. + rewrite to_uintD_small /= 1:/#. + split; first smt(). + apply/List.allP => k; rewrite mem_iota /=; move => [? Hk]. + rewrite to_uintD_small !to_uintM_small /= 1..3:/#. + case: (k = 2 * to_uint i{m}) => C1. + rewrite /noise_coef !get_setE 1..2:/# C1 /= ifF 1:/#. + have ->/=: 2 * to_uint i{m} %/ 2 = to_uint i{m} by smt(). + rewrite -to_sint_eq sigextu16_to_sint (_: 3 = 2^2 -1) // !and_mod //= W8_of_sintK_signed /=; 1: smt(). + have -> /= : 2 * to_uint i{m} %% 2 = 0 by smt(). + by rewrite -parallel_noisesum_low smod_small // /#. + case: (k = 2 * to_uint i{m}+1) => C2. + rewrite /noise_coef !get_setE 1..2:/# C2 /=. + have ->/=: (2 * to_uint i{m} + 1) %/ 2 = to_uint i{m} by smt(). + rewrite -to_sint_eq sigextu16_to_sint (_: 3 = 2^2 -1) // !and_mod //= W8_of_sintK_signed /=; 1: smt(). + have -> /= : (2 * to_uint i{m}+1) %% 2 = 1 by smt(). + by rewrite -parallel_noisesum_high smod_small // /#. + rewrite !get_setE 1..2:/# C1 C2 /=; apply IH. + smt(mem_iota). +auto => &m |> *. +split; first by rewrite iota0. +move=> i rp; rewrite /(\ult) => |> ??. +have ->/=: to_uint i = 128 by smt(). +rewrite tP => /List.allP H k Hk. +rewrite (H k _) /=. + smt(mem_iota). +by rewrite initiE //. +qed. -bind array Array4."_.[_]" Array4."_.[_<-_]" Array4.to_list Array4.of_list Array4.t 4. -realize tolistP by done. -realize get_setP by smt(Array4.get_setE). -realize eqP by smt(Array4.tP). -realize get_out by smt(Array4.get_out). +lemma cbd2_ref_ll : islossless AuxMLKEMAvx2.cbd2_ref. +proc. inline*. sp; wp. while (true) (128 - W64.to_uint i). move => z. +auto => /> &hr. rewrite ultE of_uintK //= => H. rewrite to_uintD_small //= /#. +auto => />i H. rewrite ultE of_uintK //= 1:/#. qed. +phoare cbd2_ref_ph _bytes: + [AuxMLKEMAvx2.cbd2_ref: buf=_bytes ==> res = Array256.init (fun k => W16.of_int (noise_coef _bytes k))] = 1%r. +conseq cbd2_ref_ll (cbd2_ref_h _bytes) => />. qed. -op init_256_16 (f: int -> W16.t) : W16.t Array256.t = Array256.init f. +equiv getnoise_split : + M(Syscall)._poly_getnoise ~ AuxMLKEMAvx2._poly_getnoise : ={arg} ==> ={res}. +proc; wp; sp => />. +seq 2 0 : ( ={buf,rp,seed,nonce} /\ extseed{1}=Array33.init (fun i => if i=32 then nonce{1} else seed{1}.[i]) ). +wp. while{1} (0 <= k{1} <= 32 /\ (forall i, 0 <= i < k{1} => extseed{1}.[i]=seed{1}.[i])) (32-k{1}). +auto => /> &m H1 H2 H3 H4. split. split. smt(). move => i Hi1 Hi2. rewrite get_setE 1:/#. smt(). smt(). +auto => /> &m. split. move => i Hi1 Hi2. rewrite !get_out /#. move => extseed k. split. smt(). move => H1 H2 H3 H4. rewrite tP => i Hi. rewrite !initiE 1:/# => />. rewrite get_setE 1:/#. smt(). +seq 1 1 : (#pre). +call (_:true). auto => />. sim. auto => />. +inline *. sim. qed. -bind op [W16.t & Array256.t] init_256_16 "ainit". -realize bvainitP. -proof. -rewrite /init_256_16 => f. -rewrite BVA_Top_Array256_Array256_t.tolistP. -apply eq_in_mkseq => i i_bnd; -smt(Array256.initE). -qed. +equiv getnoise_1x_equiv_avx : + Jkem_avx2.M(Jkem_avx2.Syscall).__poly_cbd_eta1 ~ AuxMLKEMAvx2.cbd2_ref : ={arg} ==> ={res}. +proc*. inline Jkem_avx2.M(Jkem_avx2.Syscall).__poly_cbd_eta1. rcondt{1} 3. auto => />. sp;wp. +ecall{1} (cbd2_avx2_ph buf{1}) => />. +ecall{2} (cbd2_ref_ph buf{2}) => />. +auto => /> &2. rewrite tP => i Hi. rewrite initiE //=. qed. -op init_768_16 (f: int -> W16.t) : W16.t Array768.t = Array768.init f. +equiv getnoise_4x_split : + GetNoiseAVX2._poly_getnoise_eta1_4x ~ AuxMLKEMAvx2.__poly_getnoise_eta1_4x : ={arg} ==> ={res}. +proc; wp; sp => />. call getnoise_split => />. call getnoise_split => />. call getnoise_split => />. call getnoise_split => />. auto => />. qed. -print Array768.initE. +equiv getnoiseequiv_avx : + Jkem_avx2.M(Jkem_avx2.Syscall)._poly_getnoise_eta1_4x ~ GetNoiseAVX2._poly_getnoise_eta1_4x : ={arg} ==> ={res}. +proc*. +transitivity{2} { r <@ AuxMLKEMAvx2.__poly_getnoise_eta1_4x(aux3,aux2,aux1,aux0,noiseseed,nonce); } ((r0{1}, r1{1}, r2{1}, r3{1}, seed{1}, nonce{1}) = (aux3{2}, aux2{2}, aux1{2}, aux0{2}, noiseseed{2}, nonce{2}) ==> ={r}) (={aux3,aux2,aux1,aux0,noiseseed,nonce} ==> ={r}); last first. +symmetry. call getnoise_4x_split => />. auto => />. smt(). smt(). +(*main proof*) +inline Jkem_avx2.M(Jkem_avx2.Syscall)._poly_getnoise_eta1_4x AuxMLKEMAvx2.__poly_getnoise_eta1_4x AuxMLKEMAvx2._poly_getnoise. swap{2} [30..31] 5. swap{2} [23..24] 10. swap{2} [16..17] 15. +seq 25 30 : ( + r00{1}=rp{2} /\ Array128.init (fun (i : int) => buf0{1}.[i]) =buf{2} + /\ r10{1}=rp0{2} /\ Array128.init (fun (i : int) => buf1{1}.[i]) =buf0{2} + /\ r20{1}=rp1{2} /\ Array128.init (fun (i : int) => buf2{1}.[i]) =buf1{2} + /\ r30{1}=rp2{2} /\ Array128.init (fun (i : int) => buf3{1}.[i]) =buf2{2} +). +sp => />. +ecall{2} (shake256_33_128 buf2{2} (Array33.init (fun i => if i = 32 then nonce4{2} else seed2{2}.[i]))); wp => />. +ecall{2} (shake256_33_128 buf1{2} (Array33.init (fun i => if i = 32 then nonce3{2} else seed1{2}.[i]))); wp => />. +ecall{2} (shake256_33_128 buf0{2} (Array33.init (fun i => if i = 32 then nonce2{2} else seed0{2}.[i]))); wp => />. +ecall{2} (shake256_33_128 buf{2} (Array33.init (fun i=> if i = 32 then nonce1{2} else seed{2}.[i]))); wp => />. +ecall{1} (shake_squeezenblocks4x state{1} buf0{1} buf1{1} buf2{1} buf3{1}); wp => />. +ecall{1} (shake_absorb4x state{1} (Array33.init (fun i => buf0{1}.[i])) (Array33.init (fun i => buf1{1}.[i])) (Array33.init (fun i => buf2{1}.[i])) (Array33.init (fun i => buf3{1}.[i])) ); wp => />. +auto => /> &2. rewrite shake4x_equiv => />. +rewrite tP => k Hk; rewrite !initiE //= 1..3:/#; rewrite ifF 1:/#; rewrite /get8 /init8 set_neqiE 1:/#; rewrite initiE //= 1:/#; rewrite initiE //= 1:/#; rewrite set256E initiE //= 1:/#; rewrite ifT //; rewrite /get256_direct bits8_W32u8 //=; rewrite ifT //; rewrite initiE //=; rewrite initiE //=. +rewrite tP => k Hk; rewrite !initiE //= 1..3:/#; rewrite ifF 1:/#; rewrite /get8 /init8 set_neqiE 1:/#; rewrite initiE //= 1:/#; rewrite initiE //= 1:/#; rewrite set256E initiE //= 1:/#; rewrite ifT //; rewrite /get256_direct bits8_W32u8 //=; rewrite ifT //; rewrite initiE //=; rewrite initiE //=. +rewrite tP => k Hk; rewrite !initiE //= 1..3:/#; rewrite ifF 1:/#; rewrite /get8 /init8 set_neqiE 1:/#; rewrite initiE //= 1:/#; rewrite initiE //= 1:/#; rewrite set256E initiE //= 1:/#; rewrite ifT //; rewrite /get256_direct bits8_W32u8 //=; rewrite ifT //; rewrite initiE //=; rewrite initiE //=. +rewrite tP => k Hk; rewrite !initiE //= 1..3:/#; rewrite ifF 1:/#; rewrite /get8 /init8 set_neqiE 1:/#; rewrite initiE //= 1:/#; rewrite initiE //= 1:/#; rewrite set256E initiE //= 1:/#; rewrite ifT //; rewrite /get256_direct bits8_W32u8 //=; rewrite ifT //; rewrite initiE //=; rewrite initiE //=. +wp. call getnoise_1x_equiv_avx => />. +wp. call getnoise_1x_equiv_avx => />. +wp. call getnoise_1x_equiv_avx => />. +wp. call getnoise_1x_equiv_avx => />. +auto => />. qed. -bind op [W16.t & Array768.t] init_768_16 "ainit". -realize bvainitP. -rewrite /init_768_16 => f. -rewrite BVA_Top_Array768_Array768_t.tolistP. -apply eq_in_mkseq => i i_bnd; smt(Array768.initE). +lemma polygetnoise_ll : islossless Jkem.M(Jkem.Syscall)._poly_getnoise. +proc. +while (0 <= to_uint i <= 128) (128 - to_uint i); + 1: by move => z; auto => />;rewrite ultE /= => &hr ???; rewrite !to_uintD_small /=; smt(to_uint_cmp). +wp; call sha3ll; wp; while (0<=k<=32) (32 -k); 1: by move => z; auto=> /> /#. +auto => /> *; do split; 1:smt(). +by move => *; rewrite ultE /=; smt(). qed. -op init_4_64 (f: int -> W64.t) : W64.t Array4.t = Array4.init f. - -bind op [W64.t & Array4.t] init_4_64 "ainit". -realize bvainitP. -proof. -rewrite /init_4_64 => f. -rewrite BVA_Top_Array4_Array4_t.tolistP. -apply eq_in_mkseq => i i_bnd; smt(Array4.initE). +equiv getnoiseequiv : + Jkem.M(Jkem.Syscall)._poly_getnoise ~Jkem.M(Jkem.Syscall)._poly_getnoise : + ={arg} ==> ={res} /\ + signed_bound_cxq res{1} 0 256 1. +have H : forall &m a, + Pr[Jkem.M(Jkem.Syscall)._poly_getnoise(a) @ &m : forall k, 0<=k<256 => -5 < to_sint res.[k] < 5] = 1%r. ++ move => &m a. + have -> : 1%r = Pr [ CBD2.sample(PRF a.`2 a.`3) @ &m : true]. + + byphoare => //. + proc; inline *; while (0<=i<=128) (128-i); 1: by move => z; auto => /> /#. + by auto => /> /#. + by byequiv get_noise_sample_noise => //. +have HH0 : hoare [Jkem.M(Jkem.Syscall)._poly_getnoise : true ==> forall k, 0<=k<256 => -5 < to_sint res.[k] < 5]. ++ hoare; bypr => //= &m; rewrite Pr[mu_not]. + have -> : Pr[Jkem.M(Jkem.Syscall)._poly_getnoise(rp{m}, s_seed{m}, nonce{m}) @ &m : true] = 1%r. + + by byphoare => //; apply polygetnoise_ll. + smt(). +have HHH : equiv [ Jkem.M(Jkem.Syscall)._poly_getnoise ~Jkem.M(Jkem.Syscall)._poly_getnoise : ={arg} ==> ={res} ] by sim. +conseq HHH HH0. +move => *; rewrite /signed_bound_cxq /b16 qE /#. qed. -op init_960_8 (f: int -> W8.t) : W8.t Array960.t = Array960.init f. - -bind op [W8.t & Array960.t] init_960_8 "ainit". -realize bvainitP. -proof. -rewrite /init_960_8 => f. -rewrite BVA_Top_Array960_Array960_t.tolistP. -apply eq_in_mkseq => i i_bnd; smt(Array960.initE). -qed. +import InnerPKE. -op init_1088_8 (f: int -> W8.t) : W8.t Array1088.t = Array1088.init f. +lemma mlkem_correct_kg_avx2 mem _pkp _skp : + equiv [Jkem_avx2.M(Jkem_avx2.Syscall).__indcpa_keypair ~ InnerPKE.kg_derand : + Glob.mem{1} = mem /\ to_uint pkp{1} = _pkp /\ to_uint skp{1} = _skp /\ + randomnessp{1} = coins{2} /\ + valid_disj_reg _pkp (384*3+32) _skp (384*3) + ==> + touches2 Glob.mem{1} mem _pkp (384*3+32) _skp (384*3) /\ + let (pk,sk) = res{2} in let (t,rho) = pk in + sk = load_array1152 Glob.mem{1} _skp /\ + t = load_array1152 Glob.mem{1} _pkp /\ + rho = load_array32 Glob.mem{1} (_pkp+1152)]. +proc*. +transitivity {1} {Jkem.M(Jkem.Syscall).__indcpa_keypair(pkp, skp, randomnessp);} +(={Glob.mem,pkp,skp,randomnessp} /\ + Glob.mem{1} = mem /\ + to_uint pkp{1} = _pkp /\ + to_uint skp{1} = _skp /\ + randomnessp{1} = randomnessp{2} /\ + valid_disj_reg _pkp (384 * 3 + 32) _skp (384 * 3) ==> ={Glob.mem}) +( Glob.mem{1} = mem /\ to_uint pkp{1} = _pkp /\ to_uint skp{1} = _skp /\ + randomnessp{1} = coins{2} /\ + valid_disj_reg _pkp (384*3+32) _skp (384*3) + ==> + touches2 Glob.mem{1} mem _pkp (384*3+32) _skp (384*3) /\ + let (pk, sk) = r{2} in + let (t, rho) = pk in + sk = load_array1152 Glob.mem{1} _skp /\ + t = load_array1152 Glob.mem{1} _pkp /\ + rho = load_array32 Glob.mem{1} (_pkp + 1152)); 1,2: smt(); + last by call(mlkem_correct_kg mem _pkp _skp); auto => />. -bind op [W8.t & Array1088.t] init_1088_8 "ainit". -realize bvainitP. -proof. -rewrite /init_1088_8 => f. -rewrite BVA_Top_Array1088_Array1088_t.tolistP. -apply eq_in_mkseq => i i_bnd; smt(Array1088.initE). -qed. +inline{1} 1; inline {2} 1. sim 40 62. -op sliceget256_16_256 (arr: W16.t Array256.t) (offset: int) : W256.t = - if 8 %| offset then - get256_direct ((init16 (fun (i_0 : int) => arr.[i_0])))%WArray512 (offset %/ 8) - else W256.bits2w (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))). +call (polyvec_tobytes_equiv _pkp). +call (polyvec_tobytes_equiv _skp). +wp;conseq />;1:smt(). +ecall (polyvec_reduce_equiv (lift_array768 pkpv{2})). -(* -lemma flatten_take_drop_16 (l : W16.t list) (csize offset bit : int) : - 0 <= offset => - offset + csize <= 16 * size l => - 0 <= bit < csize => - nth false (take csize (drop offset (flatten (map W16.w2bits l)))) bit = - (nth witness l ((offset + bit) %/ 16)).[(offset + bit) %% 16]. -proof. -move => *. -rewrite nth_take 1,2:/#. -rewrite nth_drop 1,2:/#. -rewrite (BitEncoding.BitChunking.nth_flatten false 16). -+ rewrite allP => i; rewrite mapP => He;elim He;smt(W16.size_w2bits). -rewrite -get_w2bits;congr. -by rewrite (nth_map witness) 1:/#. -qed. +have H := polyvec_add2_equiv 2 2 _ _ => //. +ecall (H (lift_array768 pkpv{2}) (lift_array768 e{2})); clear H. +unroll for* {1} 36. +sp 3 3. -lemma aligned_get256_16_256 arr offset : -0 <= offset <= 16*256 - 256 => -256 %| offset => -sliceget256_16_256 arr offset = -WArray512.get256 (WArray512.init16 (fun (i_0 : int) => arr.[i_0])) (offset %/ 256). -move => Ho1 Ho2; rewrite /sliceget256_16_256. -have sz : size (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))) = 256 by rewrite size_take 1:/# size_drop 1:/# /max /=;smt(Array256.size_to_list size_flatten_W16_w2bits). -rewrite wordP => i ib; rewrite get_bits2w //. -rewrite flatten_take_drop_16;1..3:smt(Array256.size_to_list). -rewrite nth_mkseq 1:/# /=. -rewrite /get256_direct /pack32_t initiE 1:/# /= initiE 1:/# /= initiE 1:/# /=. -rewrite get_bits8 1:/#. -smt(@IntDiv). -qed. +seq 15 17 : (#pre /\ ={publicseed, noiseseed,e,skpv,pkpv} /\ sskp{2} = skp{1} /\ spkp{2} = pkp{1}); 1: by + sp; conseq />; sim 2 2; call( sha3equiv); conseq />; sim. -*) -bind op [W16.t & W256.t & Array256.t] sliceget256_16_256 "asliceget". -realize bvaslicegetP. -move => /= arr offset; rewrite /sliceget256_16_256 /= => H k kb. -case (8%| offset) => /= *; last by smt(W256.get_bits2w). -rewrite /get256_direct pack32E initiE 1:/# /= initiE 1:/# /= initiE 1:/# /= bits8E initiE 1:/# /=. -rewrite nth_take 1,2:/# nth_drop 1,2:/#. -rewrite (BitEncoding.BitChunking.nth_flatten false 16 _). -+ rewrite allP => x /=; rewrite mapP => He; elim He;smt(W16.size_w2bits). -rewrite (nth_map W16.zero []); 1: smt(Array256.size_to_list). -by rewrite nth_mkseq /#. -qed. +sp 0 2. +seq 2 2 : (#pre /\ aa{1} = nttunpackm a{2} /\ + pos_bound2304_cxq aa{1} 0 2304 2 /\ + pos_bound2304_cxq a{2} 0 2304 2); 1: by + conseq />; call (genmatrixequiv false); auto => />. -import BitEncoding BS2Int BitChunking. +swap {1} [11..12] 2. -op sliceset256_16_256 (arr: W16.t Array256.t) (offset: int) (bv: W256.t) : W16.t Array256.t = - if 8 %| offset - then (init (fun (i3 : int) => get16 (set256_direct ((init16 (fun (i_0 : int) => arr.[i_0])))%WArray512 (offset %/ 8) bv) i3))%Array256 - else Array256.of_list witness (map W16.bits2w (chunk 16 (take offset (flatten (map W16.w2bits (to_list arr))) ++ w2bits bv ++ - drop (offset + 256) (flatten (map W16.w2bits (to_list arr)))))). +seq 10 18 : (#pre /\ + signed_bound768_cxq skpv{1} 0 768 1 /\ + signed_bound768_cxq e{1} 0 768 1 /\ + signed_bound768_cxq skpv{2} 0 768 1 /\ + signed_bound768_cxq e{2} 0 768 1). ++ conseq />. + transitivity {1} { (skpv,e) <@ GetNoiseAVX2.sample_noise_kg(skpv,pkpv,e,noiseseed);} (={noiseseed,skpv,pkpv,e} ==> ={skpv,e}) + ((r_noiseseed{2} = noiseseed{2} /\ + s_noiseseed{2} = r_noiseseed{2} /\ + (spkp{2} = pkp{2} /\ + sskp{2} = skp{2} /\ + randomnessp0{2} = randomnessp{2} /\ + pkp0{1} = pkp{1} /\ + skp0{1} = skp{1} /\ + randomnessp0{1} = randomnessp{1} /\ + ={Glob.mem, pkp, skp, randomnessp} /\ + Glob.mem{1} = mem /\ + to_uint pkp{1} = _pkp /\ + to_uint skp{1} = _skp /\ ={randomnessp} /\ valid_disj_reg _pkp (384 * 3 + 32) _skp (384 * 3)) /\ + ={publicseed, noiseseed, e, skpv, pkpv} /\ sskp{2} = skp{1} /\ spkp{2} = pkp{1}) /\ + aa{1} = nttunpackm a{2} /\ pos_bound2304_cxq aa{1} 0 2304 2 /\ pos_bound2304_cxq a{2} 0 2304 2 -(* -lemma aligned_set256_16_256 arr offset bv : -0 <= offset <= 16*256 - 256 => -256 %| offset => -sliceset256_16_256 arr offset bv = -Array256.init (fun (i3 : int) => get16 (set256 ((init16 (fun (i_0 : int) => arr.[i_0])))%WArray512 (offset %/ 256) bv) i3). -rewrite /sliceset256_16_256 tP /= => ?? i ib. -rewrite !initiE 1,2:/# /=. -rewrite get16_set256E 1,2:/# /= (nth_map []). -+ rewrite size_chunk // !size_cat !size_take 1:/# !size_drop 1:/# /max /=. - by smt(Array256.size_to_list size_flatten_W16_w2bits). -rewrite JWordList.nth_chunk //= 1:/#. -rewrite !size_cat !size_take 1:/# !size_drop 1:/# /max /=. - by smt(Array256.size_to_list size_flatten_W16_w2bits). -case (32 * (offset %/ 256) <= 2 * i);last first. -+ move => ? /=. have ? : 16*i < offset. smt(). - rewrite get16_init16 1:/# -catA drop_cat ifT;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). - rewrite take_cat_le ifT;1: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). - have -> : offset = 16 * (offset %/ 16) by smt(). - rewrite take_flatten_ctt; 1: by smt(mapP W16.size_w2bits). - rewrite -map_take. - rewrite -(W16.w2bitsK arr.[i]);congr. - apply (eq_from_nth false). - + rewrite size_w2bits size_take // size_drop 1:/# /= /max /=;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). - move => k kb; rewrite flatten_take_drop_16 1:/#. - + rewrite size_take 1:/# size_to_list //= 1:/#. - by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). - rewrite nth_take 1:/#. smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). - rewrite get_w2bits;congr; rewrite ?get_to_list;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). -case (2 * i < 32 * (offset %/ 256 + 1));last first. -+ move => ? /=. have ? : offset + 256 <= 16*i . smt(). - rewrite get16_init16 1:/# -catA drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). - rewrite drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). - rewrite size_take 1:/# size_flatten_W16_w2bits size_to_list /= ifT 1:/#. - have -> : offset + 256 = 16 * ((offset + 256) %/ 16) by smt(). - rewrite drop_flatten_ctt; 1: by smt(mapP W16.size_w2bits). - have -> : 16 * i - offset - 256 = 16 * (i - offset %/ 16 - 16) by smt(). - rewrite drop_flatten_ctt; 1: by smt(mapP W16.size_w2bits mem_drop). - rewrite drop_drop 1,2:/# /= => ?. - rewrite -(W16.w2bitsK arr.[i]);congr. - apply (eq_from_nth false). - + rewrite -map_drop size_take // size_flatten_W16_w2bits size_drop 1:/#; smt(Array256.size_to_list W16.size_w2bits). - move => k kb. - have -> : i - offset %/ 16 - 16 + (offset + 256) %/ 16 = i by smt(). - rewrite -(drop_flatten_ctt 16); 1: smt(mapP W16.size_w2bits). - rewrite flatten_take_drop_16; 1..3: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). - rewrite get_w2bits;congr; rewrite ?get_to_list;smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). + ==> + ={skpv, e} /\ + signed_bound768_cxq skpv{1} 0 768 1 /\ + signed_bound768_cxq e{1} 0 768 1 /\ signed_bound768_cxq skpv{2} 0 768 1 /\ signed_bound768_cxq e{2} 0 768 1 + ); 1,2:smt(). + + by inline {2} 1;do 2!(wp; call getnoiseequiv_avx);auto => />. + inline {1} 1. inline GetNoiseAVX2._poly_getnoise_eta1_4x. + wp; do 2!(call{1} (_: true ==> true); 1: by apply polygetnoise_ll). + do 6!(wp; call getnoiseequiv); auto => />. + move => &1 &2 ??????R?; split. + + by rewrite tP => k kb; rewrite !initiE //= initiE /#. + move => ?R0?; split. + + rewrite tP => k kb; rewrite !initiE //= initiE 1:/# /= initiE 1:/# /= /#. + move => ?R1?????; split. + + rewrite tP => k kb; rewrite !initiE //= initiE 1:/# /= initiE 1:/# /= initiE 1:/# /= /#. + move => ?R2?; do split. + + rewrite /signed_bound768_cxq => x xb /=. + rewrite !initiE //= fun_if. + case (512 <= x && x < 768); 1: by smt(). + move => *; rewrite !initiE //= fun_if. + case (256 <= x && x < 512); 1: by smt(). + move => *; rewrite !initiE //= fun_if. + by smt(). + + rewrite /signed_bound768_cxq => x xb /=. + rewrite !initiE //= fun_if. + case (512 <= x && x < 768); 1: by smt(). + move => *; rewrite !initiE //= fun_if. + case (256 <= x && x < 512); 1: by smt(). + move => *; rewrite !initiE //= fun_if. + by smt(). + + rewrite /signed_bound768_cxq => x xb /=. + rewrite !initiE //= fun_if. + case (512 <= x && x < 768); 1: by smt(). + move => *; rewrite !initiE //= fun_if. + case (256 <= x && x < 512); 1: by smt(). + move => *; rewrite !initiE //= fun_if. + by smt(). + rewrite /signed_bound768_cxq => x xb /=. + rewrite !initiE //= fun_if. + case (512 <= x && x < 768); 1: by smt(). + move => *; rewrite !initiE //= fun_if. + case (256 <= x && x < 512); 1: by smt(). + move => *; rewrite !initiE //= fun_if. + by smt(). -+ move => ?? /=. have ? : offset <= 16*i < offset + 256. smt(). - rewrite -!catA drop_cat ifF;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). - rewrite !drop_cat ifT;1: by smt(size_take W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). - rewrite size_take 1:/# size_flatten_W16_w2bits size_to_list /= ifT 1:/#. - rewrite take_cat_le ifT;1: by rewrite size_drop 1:/# size_w2bits /= /max ifT /#. - rewrite -(W16.w2bitsK ((bv \bits16 i - 16 * (offset %/ 256))));congr. - apply (eq_from_nth false). - + rewrite size_take // size_drop 1:/#; smt(Array256.size_to_list W16.size_w2bits). - move => k kb. - rewrite nth_take; 1,2: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). - rewrite nth_drop; 1,2: by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). - rewrite !get_w2bits get_bits16;by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). -qed. -*) +seq 2 2 : (#{/~skpv{1}}{~e{1}}{~skpv{2}}{~e{2}}pre /\ + lift_array768 skpv{1} = nttunpackv (lift_array768 skpv{2}) /\ + lift_array768 e{1} = nttunpackv (lift_array768 e{2}) /\ + pos_bound768_cxq skpv{1} 0 768 2 /\ + pos_bound768_cxq skpv{2} 0 768 2 /\ + pos_bound768_cxq e{1} 0 768 2 /\ + pos_bound768_cxq e{2} 0 768 2); 1: + by conseq />; call (nttequiv); call (nttequiv); auto => /> /#. -lemma size_flatten_W16_w2bits (a : W16.t list) : - (size (flatten (map W16.w2bits (a)))) = 16 * size a. -proof. - rewrite size_flatten -map_comp /(\o) /=. - rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. - rewrite StdBigop.Bigint.big_constz count_predT /#. -qed. +(* First ip *) +seq 8 4: (#{/~pkpv{2}}pre /\ + lift_array256 (subarray256 pkpv{1} 0) = nttunpack (lift_array256 (subarray256 pkpv{2} 0)) /\ + signed_bound768_cxq pkpv{1} 0 256 2 /\ + signed_bound768_cxq pkpv{2} 0 256 2 /\ i{1} = 1). +wp; call frommontequiv; wp; call pointwiseequiv; auto => />. +move => &1 &2 H2 H3 H4 H5 H6 H7 H8 H9 H10 H11 H12 H13 H14; do split. ++ rewrite -lift768_nttunpack. congr. + rewrite /nttunpackm /nttunpackv tP /= => k kb. + rewrite !initiE // 1:/# /= kb /= initiE //=. ++ rewrite /signed_bound768_cxq => k kb; rewrite initiE //=. ++ rewrite /unpackm /unpackv /=. + rewrite !initiE // 1:/# /= kb /= initiE //=. + rewrite fun_if. + case (0<=k<256). + + move => kbb;rewrite /subarray256. + move : (nttunpack_pred (Array256.init (fun (k0 : int) => (subarray768 a{2} 0).[256 * 0 + k0])) (fun x => -2*q <= W16.to_sint x < 2*q)). + rewrite !allP; move => /= [h0 h1]; rewrite h1. move => *. rewrite initiE //=. smt(Array768.initiE). smt(). + case (256<=k<512). + + move => kbb;rewrite /subarray256. + move : (nttunpack_pred (Array256.init (fun (k0 : int) => (subarray768 a{2} 0).[256 * 1 + k0])) (fun x => -2*q <= W16.to_sint x < 2*q)). + rewrite !allP; move => /= [h0 h1]; rewrite h1. move => *. rewrite initiE //=. smt(Array768.initiE). smt(). auto. + case (512<=k<768). + + move => kbbb;rewrite /subarray256. + move : (nttunpack_pred (Array256.init (fun (k0 : int) => (subarray768 a{2} 0).[256 * 2 + k0])) (fun x => -2*q <= W16.to_sint x < 2*q)). + rewrite !allP; move => /= [h0 h1]; rewrite h1. move => *. rewrite initiE //=. smt(Array768.initiE). smt(). auto. + by smt(). ++ move : H10; rewrite /pos_bound768_cxq /signed_bound_768_cxq /#. ++ move : H8; rewrite /pos_bound768_cxq /signed_bound_768_cxq; smt(Array768.initiE). ++ move : H10; rewrite /pos_bound768_cxq /signed_bound_768_cxq; smt(Array768.initiE). -bind op [W16.t & W256.t & Array256.t] sliceset256_16_256 "asliceset". -realize bvaslicesetP. -move => arr offset bv H /= k kb; rewrite /sliceset256_16_256 /=. -case (8 %| offset) => /= *; last first. -+ rewrite of_listK; 1: by rewrite size_map size_chunk // !size_cat size_take; - by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). - rewrite -(map_comp W16.w2bits W16.bits2w) /(\o). - have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W16)) idfun (chunk 16 - (take offset (flatten (map W16.w2bits (to_list arr))) ++ w2bits bv ++ - drop (offset + 256) (flatten (map W16.w2bits (to_list arr))))). - rewrite iffE => [#] -> * /=; 1: by smt(in_chunk_size W16.bits2wK). - rewrite map_id /= chunkK //;1: by rewrite !size_cat size_take; - by smt(size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). - by rewrite !nth_cat !size_cat /=; - smt(nth_take nth_drop size_take size_drop W16.size_w2bits size_cat Array256.size_to_list size_flatten_W16_w2bits size_ge0). -rewrite (nth_flatten _ 16); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W16.size_w2bits). -rewrite (nth_map W16.zero []); 1: smt(Array256.size_to_list). -rewrite nth_mkseq 1:/# /= initiE 1:/# /= get16E pack2E initiE 1:/# /= initiE 1:/# /= /set256_direct. -rewrite initiE 1:/# /=. -case (offset <= k && k < offset + 256) => *; 1: by - rewrite ifT 1:/# get_bits8 /= 1,2:/# initiE // initiE //. -rewrite ifF 1:/# initiE 1:/# /=. -rewrite (nth_flatten _ 16); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W16.size_w2bits). -rewrite (nth_map W16.zero []); 1: smt(Array256.size_to_list). -rewrite nth_mkseq 1:/# /= bits8E /= initiE /# /=. -qed. +move => H15 H16 H17 H18 H19 r1 r2 H20 H21 H22;do split. ++ rewrite tP /= => k kb. + rewrite /lift_array256 /nttunpack !initiE //=. + pose a:=nttunpack_idx.[k]. + rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; 1..3: smt(nttunpack_bnd Array256.allP). + rewrite kb /=. + have -> /= : 0 <= a && a < 256 by smt(nttunpack_bnd Array256.allP). + move : H20; rewrite /lift_array256 tP => H20. + move : (H20 k kb). + rewrite /nttunpack initiE //= -/a !mapiE //=;smt(nttunpack_bnd Array256.allP). -op sliceget32_8_256 (arr: W8.t Array32.t) (offset: int) : W256.t = -if 8 %| offset then - get256_direct (WArray32.init8 (fun (i_0 : int) => arr.[i_0])) (offset %/ 8) - else W256.bits2w (take 256 (drop offset (flatten (map W8.w2bits (to_list arr))))). +move => H23 r3 r4 H24 H25 H26;do split. ++ rewrite tP /= => k kb. + rewrite /lift_array256 /nttunpack !initiE //=. + pose a:=nttunpack_idx.[k]. + rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite /subarray256 /=. + rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). + rewrite kb /=. + have -> /= : 0 <= a && a < 256 by smt(nttunpack_bnd Array256.allP). + move : H24; rewrite /lift_array256 tP => H24. + move : (H24 k kb). + rewrite /nttunpack initiE //= -/a !mapiE //=;smt(nttunpack_bnd Array256.allP). -bind op [W8.t & W256.t & Array32.t] sliceget32_8_256 "asliceget". -realize bvaslicegetP. -move => /= arr offset; rewrite /sliceget32_8_256 /= => H k kb. -case (8%| offset) => /= *; last by smt(W256.get_bits2w). -rewrite /get256_direct pack32E initiE 1:/# /= initiE 1:/# /= initiE 1:/# /=. -rewrite nth_take 1,2:/# nth_drop 1,2:/#. -rewrite (BitEncoding.BitChunking.nth_flatten false 8 _). -+ rewrite allP => x /=; rewrite mapP => He; elim He;smt(W8.size_w2bits). -rewrite (nth_map W8.zero []); 1: smt(Array32.size_to_list). -by rewrite nth_mkseq /#. -qed. -op sliceget768_16_256 (arr: W16.t Array768.t) (offset: int) : W256.t = -if 8 %| offset then - get256_direct (WArray1536.init16 (fun (i_0 : int) => arr.[i_0])) (offset %/ 8) - else W256.bits2w (take 256 (drop offset (flatten (map W16.w2bits (to_list arr))))). ++ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). + rewrite kb /=. smt(). ++ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). + rewrite kb /=. smt(). -bind op [W16.t & W256.t & Array768.t] sliceget768_16_256 "asliceget". -realize bvaslicegetP. -move => /= arr offset; rewrite /sliceget768_16_256 /= => H k kb. -case (8%| offset) => /= *; last by smt(W256.get_bits2w). -rewrite /get256_direct pack32E initiE 1:/# /= initiE 1:/# /= initiE 1:/# /= bits8E initiE 1:/# /=. -rewrite nth_take 1,2:/# nth_drop 1,2:/#. -rewrite (BitEncoding.BitChunking.nth_flatten false 16 _). -+ rewrite allP => x /=; rewrite mapP => He; elim He;smt(W16.size_w2bits). -rewrite (nth_map W16.zero []); 1: smt(Array768.size_to_list). -by rewrite nth_mkseq /#. -qed. +(* Second ip *) -op sliceset960_8_128 (arr: W8.t Array960.t) (offset: int) (bv: W128.t) : W8.t Array960.t = - if 8 %| offset - then Array960.init (fun (i3 : int) => get8 (set128_direct ((init8 (fun (i_0 : int) => arr.[i_0])))%WArray960 (offset %/ 8) bv) i3) - else Array960.of_list witness (map W8.bits2w (chunk 8 (take offset (flatten (map W8.w2bits (to_list arr))) ++ w2bits bv ++ - drop (offset + 128) (flatten (map W8.w2bits (to_list arr)))))). +seq 5 4: (#{/~i{1}}pre /\ lift_array256 (subarray256 pkpv{1} 1) = nttunpack (lift_array256 (subarray256 pkpv{2} 1)) /\ + signed_bound768_cxq pkpv{1} 256 512 2 /\ + signed_bound768_cxq pkpv{2} 256 512 2 /\ i{1} = 2). +wp; call frommontequiv; wp; call pointwiseequiv; auto => />. +move => &1 &2 H2 H3 H4 H5 H6 H7 H8 H9 H10 H11 H12 H13 H14 H15 H16 H17; do split. ++ rewrite -lift768_nttunpack. congr. + rewrite /nttunpackm /nttunpackv tP /= => k kb. + rewrite !initiE //= initiE //= 1:/# ifF 1:/# ifT 1:/# initiE //=. ++ by rewrite /signed_bound768_cxq => k kb; rewrite initiE //= /#. ++ by rewrite /pos_bound768_cxq /signed_bound_768_cxq /#. ++ by rewrite /signed_bound768_cxq => k kb; rewrite initiE //= /#. ++ by rewrite /pos_bound768_cxq /signed_bound_768_cxq /#. -lemma size_flatten_W8_w2bits (a : W8.t list) : - (size (flatten (map W8.w2bits (a)))) = 8 * size a. -proof. - rewrite size_flatten -map_comp /(\o) /=. - rewrite StdBigop.Bigint.sumzE /= StdBigop.Bigint.BIA.big_mapT /(\o) /=. - rewrite StdBigop.Bigint.big_constz count_predT /#. -qed. +move => H18 H19 H20 H21 H22 r1 r2 H23 H24 H25;do split. ++ rewrite tP /= => k kb. + rewrite /lift_array256 /nttunpack !initiE //=. + pose a:=nttunpack_idx.[k]. + rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). + rewrite !ifT /=; 1,2: smt(nttunpack_bnd Array256.allP). + move : H23; rewrite /lift_array256 tP => H23. + move : (H23 k kb). + rewrite /nttunpack initiE //= -/a !mapiE //=;smt(nttunpack_bnd Array256.allP). -bind op [W8.t & W128.t & Array960.t] sliceset960_8_128 "asliceset". -realize bvaslicesetP. -move => arr offset bv H /= k kb; rewrite /sliceset960_8_128 /=. -case (8 %| offset) => /= *; last first. -+ rewrite of_listK; 1: by rewrite size_map size_chunk // !size_cat size_take; - by smt(size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0). - rewrite -(map_comp W8.w2bits W8.bits2w) /(\o). - have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W8)) idfun (chunk 8 - (take offset (flatten (map W8.w2bits (to_list arr))) ++ w2bits bv ++ - drop (offset + 128) (flatten (map W8.w2bits (to_list arr))))). - rewrite iffE => [#] -> * /=; 1: by smt(in_chunk_size W8.bits2wK). - rewrite map_id /= chunkK //;1: by rewrite !size_cat size_take; - by smt(size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0). - by rewrite !nth_cat !size_cat /=; - smt(nth_take nth_drop size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0). -rewrite (nth_flatten _ 8); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W8.size_w2bits). -rewrite (nth_map W8.zero []); 1: smt(Array960.size_to_list). -rewrite nth_mkseq 1:/# /= initiE 1:/# /= /get8 /set128_direct. -rewrite initiE 1:/# /=. -case (offset <= k && k < offset + 128) => *; 1: by - rewrite ifT 1:/# get_bits8 /= 1,2:/# initiE // initiE //. -rewrite ifF 1:/# initiE 1:/# /=. -rewrite (nth_flatten _ 8); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W8.size_w2bits). -rewrite (nth_map W8.zero []); 1: smt(Array960.size_to_list). -rewrite nth_mkseq /#. -qed. +move => H26 r3 r4 H27 H28 H29;do split. ++ rewrite tP /= => k kb. + rewrite /lift_array256 /nttunpack !initiE //=. + pose a:=nttunpack_idx.[k]. + rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite /subarray256 /=. + rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). + rewrite !ifF /=; 1..2: smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). + rewrite !ifF /=; 1..2: smt(nttunpack_bnd Array256.allP). + move : H15; rewrite /lift_array256 /subarray256 tP => H15. + move : (H15 k kb). + rewrite /nttunpack initiE //= -/a !mapiE //=; 1:smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; smt(nttunpack_bnd Array256.allP). ++ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). + rewrite !ifF /=. smt(). + rewrite !initiE //=. smt(). smt(). -op sliceset960_8_32 (arr: W8.t Array960.t) (offset: int) (bv: W32.t) : W8.t Array960.t = - if 8 %| offset - then Array960.init - (WArray960.get8 - (set32_direct (WArray960.init8 (fun (i_0 : int) => arr.[i_0])) ( - offset %/ 8) bv)) - else Array960.of_list witness (map W8.bits2w (chunk 8 (take offset (flatten (map W8.w2bits (to_list arr))) ++ w2bits bv ++ - drop (offset + 32) (flatten (map W8.w2bits (to_list arr)))))). ++ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). + rewrite !ifF /=. smt(). + rewrite !initiE //=. smt(). smt(). ++ rewrite tP /= => k kb. + rewrite /lift_array256 /nttunpack !initiE //=. + pose a:=nttunpack_idx.[k]. + rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite /subarray256 /=. + rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). + rewrite !ifT /=; 1..2: smt(nttunpack_bnd Array256.allP). + move : H27; rewrite /lift_array256 /subarray256 tP => H27. + move : (H27 k kb). + rewrite /nttunpack initiE //= -/a !mapiE //=; smt(nttunpack_bnd Array256.allP). -bind op [W8.t & W32.t & Array960.t] sliceset960_8_32 "asliceset". -realize bvaslicesetP. -move => arr offset bv H /= k kb; rewrite /sliceset960_8_32 /=. -case (8 %| offset) => /= *; last first. -+ rewrite of_listK; 1: by rewrite size_map size_chunk // !size_cat size_take; - by smt(size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0). - rewrite -(map_comp W8.w2bits W8.bits2w) /(\o). - have := eq_in_map ((fun (x : bool list) => w2bits ((bits2w x))%W8)) idfun (chunk 8 - (take offset (flatten (map W8.w2bits (to_list arr))) ++ w2bits bv ++ - drop (offset + 32) (flatten (map W8.w2bits (to_list arr))))). - rewrite iffE => [#] -> * /=; 1: by smt(in_chunk_size W8.bits2wK). - rewrite map_id /= chunkK //;1: by rewrite !size_cat size_take; - by smt(size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0). - by rewrite !nth_cat !size_cat /=; - smt(nth_take nth_drop size_take size_drop W8.size_w2bits size_cat Array960.size_to_list size_flatten_W8_w2bits size_ge0). -rewrite (nth_flatten _ 8); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W8.size_w2bits). -rewrite (nth_map W8.zero []); 1: smt(Array960.size_to_list). -rewrite nth_mkseq 1:/# /= initiE 1:/# /= /get8 /set32_direct. -rewrite initiE 1:/# /=. -case (offset <= k && k < offset + 32) => *; 1: by - rewrite ifT 1:/# get_bits8 /= 1,2:/# initiE // initiE //. -rewrite ifF 1:/# initiE 1:/# /=. -rewrite (nth_flatten _ 8); 1: by rewrite allP => i;rewrite mapP => He; elim He;smt(W8.size_w2bits). -rewrite (nth_map W8.zero []); 1: smt(Array960.size_to_list). -rewrite nth_mkseq /#. -qed. ++ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). + rewrite !initiE //=. smt(). smt(). -theory W10. -abbrev [-printing] size = 10. -clone include BitWordSH with op size <- size -rename "_XX" as "_10" -proof gt0_size by done, -size_le_256 by done. ++ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). + rewrite !initiE //=. smt(). smt(). -end W10. export W10 W10.ALU W10.SHIFT. +(* Third ip *) -bind bitstring W10.w2bits W10.bits2w W10.to_uint W10.to_sint W10.of_int W10.t 10. -realize size_tolist by auto. -realize tolistP by auto. -realize oflistP by smt(W10.bits2wK). -realize ofintP by move => *;rewrite /of_int int2bs_mod. -realize tosintP. move => bv /=;rewrite /to_sint /smod /BVA_Top_W10_t.msb. -have -> /=: nth false (w2bits bv) (10 - 1) = 2 ^ (10 - 1) <= to_uint bv; last by smt(). -rewrite /to_uint. -rewrite -{2}(cat_take_drop 9 (w2bits bv)). -rewrite bs2int_cat size_take // W10.size_w2bits /=. -rewrite -bs2int_div //= get_to_uint //=. -rewrite -bs2int_mod // /= /to_uint. -have ? : 2^10 = 1024 by rewrite /=. -by smt(bs2int_range mem_range W10.size_w2bits). -qed. -realize touintP by smt(). +seq 5 4: (#{/~i{1}}pre /\ lift_array256 (subarray256 pkpv{1} 2) = nttunpack (lift_array256 (subarray256 pkpv{2} 2)) /\ + signed_bound768_cxq pkpv{1} 512 768 2 /\ + signed_bound768_cxq pkpv{2} 512 768 2). +wp; call frommontequiv; wp; call pointwiseequiv; auto => />. +move => &1 &2 H2 H3 H4 H5 H6 H7 H8 H9 H10 H11 H12 H13 H14 H15 H16 H17 H18 H19 H20; do split. ++ rewrite -lift768_nttunpack. congr. + rewrite /nttunpackm /nttunpackv tP /= => k kb. + rewrite !initiE //= initiE //= 1:/# ifF 1:/# ifF 1:/# initiE //=. ++ by rewrite /signed_bound768_cxq => k kb; rewrite initiE //= /#. ++ by rewrite /pos_bound768_cxq /signed_bound_768_cxq /#. ++ by rewrite /signed_bound768_cxq => k kb; rewrite initiE //= /#. ++ by rewrite /pos_bound768_cxq /signed_bound_768_cxq /#. -op truncate64_10 (bw: W64.t) : W10.t = W10.bits2w (W64.w2bits bw). +move => H21 H22 H23 H24 H25 r1 r2 H26 H27 H28;do split. ++ rewrite tP /= => k kb. + rewrite /lift_array256 /nttunpack !initiE //=. + pose a:=nttunpack_idx.[k]. + rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). + rewrite !ifT /=; 1,2: smt(nttunpack_bnd Array256.allP). + move : H26; rewrite /lift_array256 tP => H26. + move : (H26 k kb). + rewrite /nttunpack initiE //= -/a !mapiE //=;smt(nttunpack_bnd Array256.allP). -bind op [W64.t & W10.t] truncate64_10 "truncate". -realize bvtruncateP. -move => mv. rewrite /truncate64_10 /W64.w2bits take_mkseq //= /w2bits. -apply (eq_from_nth witness);1: by smt(size_mkseq). -move => i; rewrite size_mkseq /= /max /= => ib. -by rewrite !nth_mkseq // /bits2w initiE //= nth_mkseq /#. -qed. + move => H29 r3 r4 H30 H31 H32;do split. ++ rewrite tP /= => k kb. + rewrite /lift_array256 /nttunpack !initiE //=. + pose a:=nttunpack_idx.[k]. + rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite /subarray256 /=. + rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). + rewrite !ifF /=; 1..2: smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). + rewrite !ifF /=; 1..2: smt(nttunpack_bnd Array256.allP). + move : H15; rewrite /lift_array256 /subarray256 tP => H15. + move : (H15 k kb). + rewrite /nttunpack initiE //= -/a !mapiE //=; 1:smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; smt(nttunpack_bnd Array256.allP). -bind op [W64.t & W8.t] W8u8.truncateu8 "truncate". -realize bvtruncateP. (* generalize *) -move => mv; rewrite /truncateu8 /W64.w2bits take_mkseq //= /w2bits. -apply (eq_from_nth witness);1: by smt(size_mkseq). -move => i; rewrite size_mkseq /= /max /= => ib. -rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w // - nth_mkseq //= get_to_uint //= /to_uint /=. -have -> /=: (0 <= i && i < 64) by smt(). -pose a := bs2int (w2bits mv). -rewrite {1}(divz_eq a (2^(8-i)*2^i)) !mulrA divzMDl; - 1: by smt(StdOrder.IntOrder.expr_gt0). -rewrite dvdz_modzDl; 1: by - have -> : 2^(8-i) = 2^((8-i-1)+1); [ by smt() | - rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)]. -by have -> : (2 ^ (8 - i) * 2 ^ i) = 256; - [ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg - 1,2:/# /= -!addrA /= | done ]. -qed. ++ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). + rewrite !ifF /=. smt(). + rewrite !initiE //=. smt(). smt(). -bind op [W16.t & W8.t] W2u8.truncateu8 "truncate". -realize bvtruncateP. -move => mv; rewrite /truncateu8 /W16.w2bits take_mkseq //= /w2bits. -apply (eq_from_nth witness);1: by smt(size_mkseq). -move => i; rewrite size_mkseq /= /max /= => ib. -rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w // - nth_mkseq //= get_to_uint //= /to_uint /=. -have -> /=: (0 <= i && i < 16) by smt(). -pose a := bs2int (w2bits mv). -rewrite {1}(divz_eq a (2^(8-i)*2^i)) !mulrA divzMDl; - 1: by smt(StdOrder.IntOrder.expr_gt0). -rewrite dvdz_modzDl; 1: by - have -> : 2^(8-i) = 2^((8-i-1)+1); [ by smt() | - rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)]. -by have -> : (2 ^ (8 - i) * 2 ^ i) = 256; - [ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg - 1,2:/# /= -!addrA /= | done ]. -qed. ++ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). + rewrite !ifF /=. smt(). + rewrite !initiE //=. smt(). smt(). ++ rewrite tP /= => k kb. + rewrite /lift_array256 /nttunpack !initiE //=. + pose a:=nttunpack_idx.[k]. + rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite /subarray256 /=. + rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). + rewrite !ifF /=; 1..2: smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). + rewrite !ifF /=; 1..2: smt(nttunpack_bnd Array256.allP). + move : H18; rewrite /lift_array256 /subarray256 tP => H18. + move : (H18 k kb). + rewrite /nttunpack initiE //= -/a !mapiE //=; 1:smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; smt(nttunpack_bnd Array256.allP). -bind op [W16.t & W64.t] W4u16.zeroextu64 "zextend". -realize bvzextendP - by move => bv; rewrite /zeroextu64 /= of_uintK /=; smt(W16.to_uint_cmp pow2_16). ++ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). + rewrite !ifF /=. smt(). + rewrite !initiE //=. smt(). smt(). -bind op [W64.t & W16.t] W4u16.truncateu16 "truncate". -realize bvtruncateP. -move => mv; rewrite /truncateu16 /W64.w2bits take_mkseq //= /w2bits. -apply (eq_from_nth witness);1: by smt(size_mkseq). -move => i; rewrite size_mkseq /= /max /= => ib. -rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w // - nth_mkseq //= get_to_uint //= /to_uint /=. -have -> /=: (0 <= i && i < 64) by smt(). -pose a := bs2int (w2bits mv). -rewrite {1}(divz_eq a (2^(16-i)*2^i)) !mulrA divzMDl; - 1: by smt(StdOrder.IntOrder.expr_gt0). -rewrite dvdz_modzDl; 1: by - have -> : 2^(16-i) = 2^((16-i-1)+1); [ by smt() | - rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)]. -by have -> : (2 ^ (16 - i) * 2 ^ i) = 65536; - [ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg - 1,2:/# /= -!addrA /= | done ]. -qed. ++ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). + rewrite !ifF /=. smt(). + rewrite !initiE //=. smt(). smt(). ++ rewrite tP /= => k kb. + rewrite /lift_array256 /nttunpack !initiE //=. + pose a:=nttunpack_idx.[k]. + rewrite !mapiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite /subarray256 /=. + rewrite !initiE //=; 1: smt(nttunpack_bnd Array256.allP). + rewrite !initiE //=; 1..2: smt(nttunpack_bnd Array256.allP). + rewrite !ifT /=; 1..2: smt(nttunpack_bnd Array256.allP). + move : H30; rewrite /lift_array256 /subarray256 tP => H30. + move : (H30 k kb). + rewrite /nttunpack initiE //= -/a !mapiE //=; smt(nttunpack_bnd Array256.allP). -op sll_64 (w1 w2 : W64.t) : W64.t = - if (64 <= to_uint w2) then W64.zero else w1 `<<` (truncateu8 w2). ++ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). + rewrite !initiE //=. smt(). smt(). -bind op [W64.t] sll_64 "shl". -realize bvshlP. -proof. -rewrite /sll_64 => bv1 bv2. -case : (64 <= to_uint bv2); last first. -+ rewrite /(`<<`) W64.to_uint_shl; 1: by smt(W8.to_uint_cmp). - rewrite /truncateu8 => bv2bnd />. - do 2! (rewrite (pmod_small (to_uint bv2) _);smt(W64.to_uint_cmp)). -move => *. -have -> : to_uint bv2 = (to_uint bv2 - 64) + 64 by ring. -by rewrite exprD_nneg 1,2:/# /= /#. -qed. ++ rewrite /signed_bound768_cxq /= => k kb; rewrite !initiE //=. smt(). + rewrite !initiE //=. smt(). smt(). -bind op [W32.t & W16.t] W2u16.truncateu16 "truncate". -realize bvtruncateP. -move => mv; rewrite /truncateu16 /W32.w2bits take_mkseq //= /w2bits. -apply (eq_from_nth witness);1: by smt(size_mkseq). -move => i; rewrite size_mkseq /= /max /= => ib. -rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w // - nth_mkseq //= get_to_uint //= /to_uint /=. -have -> /=: (0 <= i && i < 32) by smt(). -pose a := bs2int (w2bits mv). -rewrite {1}(divz_eq a (2^(16-i)*2^i)) !mulrA divzMDl; - 1: by smt(StdOrder.IntOrder.expr_gt0). -rewrite dvdz_modzDl; 1: by - have -> : 2^(16-i) = 2^((16-i-1)+1); [ by smt() | - rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)]. -by have -> : (2 ^ (16 - i) * 2 ^ i) = 65536; - [ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg - 1,2:/# /= -!addrA /= | done ]. -qed. +auto => />. -bind op [W16.t & W32.t] sigextu32 "sextend". -realize bvsextendP. -move => bv;rewrite /sigextu32 /to_sint /smod /= !of_uintK /=. -case (32768 <= to_uint bv); 2: smt(W16.to_uint_cmp). -move =>?;rewrite -{2}(oppzK (to_uint bv - 65536)) modNz /=; smt(W16.to_uint_cmp pow2_16). -qed. +move => &1 &2 ?????????????H1??H2??H3??. +do split. ++ smt(). ++ smt(). ++ rewrite /lift_array256 /subarray256 tP in H1. + rewrite /lift_array256 /subarray256 tP in H2. + rewrite /lift_array256 /subarray256 tP in H3. + rewrite /nttpackv tP => k kb. + rewrite initiE //=. + case (0 <= k && k < 256). + + move => kbb. + move : (H1 (nttpack_idx.[k]) _); 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + pose a:=nttpack_idx.[k]. + rewrite !mapiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite /nttunpack !initiE //=; 1..2: smt(nttpack_bnd nttunpack_bnd Array256.allP). + pose b:=nttunpack_idx.[a]. + rewrite !mapiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite /lift_array768 /subarray256 /=. + pose c := nttpack_idx.[k]. + rewrite /nttunpack initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite /nttpack initiE //=. + rewrite initiE //=. smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite !mapiE //=; 1: smt(nttpack_bnd Array256.allP). + have -> : b = k. move : nttunpack_idxK; rewrite /b /a allP; smt(mem_iota). + smt(). + case (256 <= k && k < 512). + + move => kbb ?. + move : (H2 (nttpack_idx.[k-256]) _); 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + pose a:=nttpack_idx.[k-256]. + rewrite !mapiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite /nttunpack initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + pose b:=nttunpack_idx.[a]. + rewrite !mapiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite /lift_array768 /subarray256 /=. + pose c := nttpack_idx.[k-256]. + rewrite /nttpack initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite !mapiE //=; 1: smt(nttpack_bnd Array256.allP). + have -> : b = k-256. move : nttunpack_idxK; rewrite /b /a allP; smt(mem_iota). + smt(). -bind op [W32.t & W8.t] W4u8.truncateu8 "truncate". -realize bvtruncateP. -move => mv; rewrite /truncateu8 /W32.w2bits take_mkseq //= /w2bits. -apply (eq_from_nth witness);1: by smt(size_mkseq). -move => i; rewrite size_mkseq /= /max /= => ib. -rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w // - nth_mkseq //= get_to_uint //= /to_uint /=. -have -> /=: (0 <= i && i < 32) by smt(). -pose a := bs2int (w2bits mv). -rewrite {1}(divz_eq a (2^(8-i)*2^i)) !mulrA divzMDl; - 1: by smt(StdOrder.IntOrder.expr_gt0). -rewrite dvdz_modzDl; 1: by - have -> : 2^(8-i) = 2^((8-i-1)+1); [ by smt() | - rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)]. -by have -> : (2 ^ (8 - i) * 2 ^ i) = 256; - [ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg - 1,2:/# /= -!addrA /= | done ]. + + move => *. + move : (H3 (nttpack_idx.[k-512]) _); 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + pose a:=nttpack_idx.[k-512]. + rewrite !mapiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite /nttunpack initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + pose b:=nttunpack_idx.[a]. + rewrite !mapiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite /lift_array768 /subarray256 /=. + pose c := nttpack_idx.[k-512]. + rewrite /nttpack initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite initiE //=; 1: smt(nttpack_bnd nttunpack_bnd Array256.allP). + rewrite !mapiE //=; 1: smt(nttpack_bnd Array256.allP). + have -> : b = k-512. move : nttunpack_idxK; rewrite /b /a allP; smt(mem_iota). + smt(). + ++ smt(unpackvK). ++ smt(). ++ smt(). +by smt(unpackvK). qed. +(***************************************************) -bind circuit VPBROADCAST_8u32 "VPBROADCAST_8u32". -bind circuit VPBROADCAST_4u64 "VPBROADCAST_4u64". + import WArray960 WArray1536 Array4. -bind circuit VPMADDWD_256 "VPMADDWD_16u16". +module AuxPolyVecCompress10 = { + proc avx2_orig(ctp : W64.t, bp : W16.t Array768.t) : W8.t Array960.t = { + bp <@ Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_reduce_sig(bp); + Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_compress(ctp, bp); + return witness; + } -bind circuit VPSLLV_8u32 "VPSLLV_8u32". + proc avx2_orig_i(ctp : W8.t Array1088.t, bp : W16.t Array768.t) : W8.t Array960.t = { + var rr : W8.t Array960.t; + bp <@ Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_reduce_sig(bp); + rr <@ Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_compress_1(Array960.init (fun (i_0 : int) => ctp.[0 + i_0]),bp); + return rr; + } -bind circuit VPSRL_4u64 "VPSRL_4u64". -bind circuit VPSHUFB_256 "VPSHUFB_256". +proc __polyvec_compress_avx2(ctp : W8.t Array1088.t, a : W16.t Array768.t) : W8.t Array960.t = { + var aux : int; + var b0 : W256.t; + var b1 : W256.t; + var b2 : W256.t; + var mask10 : W256.t; + var shift : W256.t; + var sllv_indx : W256.t; + var shuffle : W256.t; + var i : int; + var a0 : W256.t; + var lo : W128.t; + var hi : W128.t; + var rp : W8.t Array960.t <- (init (fun (i_0 : int) => ctp.[0 + i_0]))%Array960; + + b0 <- VPBROADCAST_16u16 compress10_b0; + b1 <- VPBROADCAST_16u16 compress10_b1; + b2 <- VPBROADCAST_16u16 pc_shift1_s; + mask10 <- VPBROADCAST_16u16 pvc_mask_s; + shift <- VPBROADCAST_8u32 compress10_shift; + sllv_indx <- VPBROADCAST_4u64 pvc_sllvdidx_s; + shuffle <- get256 ((init8 (fun (i_0 : int) => pvc_shufbidx_s.[i_0])))%WArray32 0; + aux <- 3 * 256 %/ 16; + i <- 0; + while (i < aux){ + a0 <- (get256 ((WArray1536.init16 (fun (i_0 : int) => a.[i_0]))) i); + a0 <@ Jkem_avx2.M(Syscall).compress10_16x16_inline(a0, b0, b1, b2, mask10); + (lo, hi) <@ Jkem_avx2.M(Syscall).pack10_16x16(a0, shift, sllv_indx, shuffle); +rp <- + (init + (get8 + (set128_direct + ((init8 (fun (i_0 : int) => rp.[i_0])))%WArray960 + (20 * i) lo)))%Array960 ; + + rp <- + (init + (get8 + (set32_direct + ((init8 (fun (i_0 : int) => rp.[i_0])))%WArray960 + (20 * i + 16) (VPEXTR_32 hi W8.zero))))%Array960 ; + (* + Glob.mem <- storeW128 Glob.mem (to_uint (r + (of_int (i * 20 + 0))%W64)) lo; + Glob.mem <- storeW32 Glob.mem (to_uint (r + (of_int (i * 20 + 16))%W64)) (VPEXTR_32 hi W8.zero); +*) + i <- i + 1; + } + + return rp; + } -bind circuit VEXTRACTI128 "VEXTRACTI128". + proc avx2(bp : W16.t Array768.t) : W8.t Array960.t = { + var rr : W8.t Array960.t; + var ctp : W8.t Array1088.t <- (init (fun (i_0 : int) => W8.zero))%Array1088; + bp <@ Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_reduce_sig(bp); + rr <@ __polyvec_compress_avx2(ctp,bp); + return rr; + } + + proc ref_orig(ctp : W64.t, bp : W16.t Array768.t) : W8.t Array960.t = { + bp <@ Jkem.M(Syscall).__polyvec_reduce(bp); + Jkem.M(Syscall).__polyvec_compress(ctp, bp); + return witness; + } + + proc ref_orig_i(ctp : W8.t Array1088.t, bp : W16.t Array768.t) : W8.t Array960.t = { + var rr : W8.t Array960.t; + bp <@ M(Syscall).__polyvec_reduce(bp) ; + rr <@ M(Syscall).__i_polyvec_compress(Array960.init (fun (i_0 : int) => ctp.[0 + i_0]),bp); + return rr; + } + +proc _poly_csubq_ref(rp : W16.t Array256.t) : W16.t Array256.t = { + var i : int; + var t : W16.t; + var b : W16.t; + + i <- 0; + while (i < 256){ + t <- rp.[i]; + t <- t - qlocal; + b <- t; + b <- b `|>>` (of_int 15)%W8; + b <- b `&` qlocal; + t <- t + b; + rp.[i] <- t; + i <- i + 1; + } + + return rp; + } + proc __polyvec_csubq_ref(r : W16.t Array768.t) : W16.t Array768.t = { + var aux : W16.t Array256.t; + + aux <@ _poly_csubq_ref((init (fun (i : int) => r.[0 + i]))%Array256); + r <- (init (fun (i : int) => if 0 <= i && i < 0 + 256 then aux.[i - 0] else r.[i]))%Array768; + aux <@ _poly_csubq_ref((init (fun (i : int) => r.[256 + i]))%Array256); + r <- (init (fun (i : int) => if 256 <= i && i < 256 + 256 then aux.[i - 256] else r.[i]))%Array768; + aux <@ _poly_csubq_ref((init (fun (i : int) => r.[2 * 256 + i]))%Array256); + r <- (init (fun (i : int) => if 2 * 256 <= i && i < 2 * 256 + 256 then aux.[i - 2 * 256] else r.[i]))%Array768; + + return r; + } + +(* proc __polyvec_compress_ref(a : W16.t Array768.t) : W8.t Array960.t = { + var aux : int; + var i : int; + var j : int; + var aa : W16.t Array768.t; + var k : int; + var t : W64.t Array4.t; + var c : W16.t; + var b : W16.t; + var rr : W8.t Array960.t <- Array960.init(fun _ => W8.zero); + + aa <- witness; + t <- Array4.init(fun _ => W64.zero); + i <- 0; + j <- 0; + aa <@ __polyvec_csubq_ref(a); + while (i < (3 * 256 - 3)){ + k <- 0; + while (k < 4){ + t.[k] <- zeroextu64 aa.[i]; + i <- i + 1; + t.[k] <- (t.[k]) `<<` (of_int 10)%W8; + t.[k] <- (t.[k]) + (of_int 1665)%W64; + t.[k] <- (t.[k]) * (of_int 1290167)%W64; + t.[k] <- (t.[k]) `>>` (of_int 32)%W8; + t.[k] <- (t.[k]) `&` (of_int 1023)%W64; + k <- k + 1; + } + c <- truncateu16 (t.[0]); + c <- c `&` (of_int 255)%W16; + rr.[j] <- (truncateu8 c); + (* + Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 c); + *) + j <- j + 1; + b <- truncateu16 (t.[0]); + b <- b `>>` (of_int 8)%W8; + c <- truncateu16 (t.[1]); + c <- c `<<` (of_int 2)%W8; + c <- c `|` b; + rr.[j] <- (truncateu8 c); + (* + Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 c); + *) + j <- j + 1; + b <- truncateu16 (t.[1]); + b <- b `>>` (of_int 6)%W8; + c <- truncateu16 (t.[2]); + c <- c `<<` (of_int 4)%W8; + c <- c `|` b; + rr.[j] <- (truncateu8 c); + (* + Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 c); + *) + j <- j + 1; + b <- truncateu16 (t.[2]); + b <- b `>>` (of_int 4)%W8; + c <- truncateu16 (t.[3]); + c <- c `<<` (of_int 6)%W8; + c <- c `|` b; + rr.[j] <- (truncateu8 c); + (* + Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 c); + *) + j <- j + 1; + t.[3] <- (t.[3]) `>>` (of_int 2)%W8; + rr.[j] <- (truncateu8 (t.[3])); + (* + Glob.mem <- storeW8 Glob.mem (to_uint (rp + j)) (truncateu8 (t.[3])); + *) + j <- j + 1; + } + + return rr; + } +*) + + proc __i_polyvec_compress_ref(a : W16.t Array768.t) : W8.t Array960.t = { + var aux : int; + var i : int; + var j : int; + var aa : W16.t Array768.t; + var t : W64.t t; + var c : W16.t; + var b : W16.t; + + var rp : W8.t Array960.t <- Array960.init(fun _ => W8.zero); + t <- Array4.init(fun _ => W64.zero); + aa <@ __polyvec_csubq_ref(a); + j <- 0; + i <- 0; + while (i < (3 * 256 - 3)){ + (* k = 0 *) + t.[0] <- zeroextu64 aa.[i+0]; + t.[0] <- t.[0] `<<` (of_int 10)%W8; + t.[0] <- t.[0] + (of_int 1665)%W64; + t.[0] <- t.[0] * (of_int 1290167)%W64; + t.[0] <- t.[0] `>>` (of_int 32)%W8; + t.[0] <- t.[0] `&` (of_int 1023)%W64; + (* k = 1 *) + t.[1] <- zeroextu64 aa.[i+1]; + t.[1] <- t.[1] `<<` (of_int 10)%W8; + t.[1] <- t.[1] + (of_int 1665)%W64; + t.[1] <- t.[1] * (of_int 1290167)%W64; + t.[1] <- t.[1] `>>` (of_int 32)%W8; + t.[1] <- t.[1] `&` (of_int 1023)%W64; + (* k = 2 *) + t.[2] <- zeroextu64 aa.[i+2]; + t.[2] <- t.[2] `<<` (of_int 10)%W8; + t.[2] <- t.[2] + (of_int 1665)%W64; + t.[2] <- t.[2] * (of_int 1290167)%W64; + t.[2] <- t.[2] `>>` (of_int 32)%W8; + t.[2] <- t.[2] `&` (of_int 1023)%W64; + (* k = 3 *) + t.[3] <- zeroextu64 aa.[i+3]; + t.[3] <- t.[3] `<<` (of_int 10)%W8; + t.[3] <- t.[3] + (of_int 1665)%W64; + t.[3] <- t.[3] * (of_int 1290167)%W64; + t.[3] <- t.[3] `>>` (of_int 32)%W8; + t.[3] <- t.[3] `&` (of_int 1023)%W64; + c <- truncateu16 t.[0]; + c <- c `&` (of_int 255)%W16; + rp.[j] <- truncateu8 c; + j <- j + 1; + b <- truncateu16 t.[0]; + b <- b `>>` (of_int 8)%W8; + c <- truncateu16 t.[1]; + c <- c `<<` (of_int 2)%W8; + c <- c `|` b; + rp.[j] <- truncateu8 c; + j <- j + 1; + b <- truncateu16 t.[1]; + b <- b `>>` (of_int 6)%W8; + c <- truncateu16 t.[2]; + c <- c `<<` (of_int 4)%W8; + c <- c `|` b; + rp.[j] <- truncateu8 c; + j <- j + 1; + b <- truncateu16 t.[2]; + b <- b `>>` (of_int 4)%W8; + c <- truncateu16 t.[3]; + c <- c `<<` (of_int 6)%W8; + c <- c `|` b; + rp.[j] <- truncateu8 c; + j <- j + 1; + t.[3] <- t.[3] `>>` (of_int 2)%W8; + rp.[j] <- truncateu8 t.[3]; + j <- j + 1; + i <- i + 4; + } + + return rp; + } -bind circuit VPBLENDW_128 "VPBLEND_8u16". +proc __poly_reduce(rp : W16.t Array256.t) : W16.t Array256.t = { + var j : int; + var t : W16.t; + + j <- 0; + while (j < 256){ + t <- rp.[j]; + t <@ M(Syscall).__barrett_reduce(t); + rp.[j] <- t; + j <- j + 1; + } + + return rp; + } -bind circuit VPEXTR_32 "VEXTRACTI32_256". + proc __polyvec_reduce(r : W16.t Array768.t) : W16.t Array768.t = { + var aux : W16.t Array256.t; + + aux <@ __poly_reduce((init (fun (i : int) => r.[0 + i]))%Array256); + r <- (init (fun (i : int) => if 0 <= i && i < 0 + 256 then aux.[i - 0] else r.[i]))%Array768; + aux <@ __poly_reduce((init (fun (i : int) => r.[256 + i]))%Array256); + r <- (init (fun (i : int) => if 256 <= i && i < 256 + 256 then aux.[i - 256] else r.[i]))%Array768; + aux <@ __poly_reduce((init (fun (i : int) => r.[2 * 256 + i]))%Array256); + r <- (init (fun (i : int) => if 2 * 256 <= i && i < 2 * 256 + 256 then aux.[i - 2 * 256] else r.[i]))%Array768; + + return r; + } -bind circuit W4u32.VPEXTR_32 "VEXTRACTI32_128". + proc ref(bp : W16.t Array768.t) : W8.t Array960.t = { + var rr : W8.t Array960.t; + bp <@ __polyvec_reduce(bp); + rr <@ __i_polyvec_compress_ref(bp); + return rr; + } -bind op [W256.t & W128.t] truncateu128 "truncate". -realize bvtruncateP. -move => mv; rewrite /truncateu128 /W256.w2bits take_mkseq //= /w2bits. -apply (eq_from_nth witness);1: by smt(size_mkseq). -move => i; rewrite size_mkseq /= /max /= => ib. -rewrite !nth_mkseq // /of_int /to_uint /= get_bits2w // - nth_mkseq //= get_to_uint //= /to_uint /=. -have -> /=: (0 <= i && i < 256) by smt(). -pose a := bs2int (w2bits mv). -rewrite {1}(divz_eq a (2^(128-i)*2^i)) !mulrA divzMDl; - 1: by smt(StdOrder.IntOrder.expr_gt0). -rewrite dvdz_modzDl; 1: by - have -> : 2^(128-i) = 2^((128-i-1)+1); [ by smt() | - rewrite exprS 1:/#; smt(dvdz_mull dvdz_mulr)]. -by have -> : (2 ^ (128 - i) * 2 ^ i) = 340282366920938463463374607431768211456; - [ rewrite -StdBigop.Bigint.Num.Domain.exprD_nneg - 1,2:/# /= -!addrA /= | done ]. -qed. +}. -op sra_32 (w1 w2 : W32.t) : W32.t = - if (32 <= to_uint w2) then W32.zero else w1 `|>>` (truncateu8 w2). +lemma compress10_equiv_avx2mem _ctp _mem : + equiv [ AuxPolyVecCompress10.avx2_orig ~ AuxPolyVecCompress10.avx2 : + ={bp} /\ ctp{1} = _ctp /\ Glob.mem{1} = _mem /\ valid_ptr (to_uint ctp{1}) (128 + 3 * 320) ==> + Glob.mem{1} = stores _mem (to_uint _ctp) (to_list res{2}) ]. +proc => /=. +swap {2} 2 -1;seq 1 1 : #pre; 1: by conseq />;inline *;sim. +inline {1} 1; inline {2} 2. +wp. +while (Glob.mem{1} = stores _mem (to_uint _ctp) (take (i{2}*20) (to_list rp{2})) /\ aux{1} = 48 /\ + valid_ptr (to_uint r{1}) (128 + 3 * 320) /\ r{1} = _ctp /\ + ={i,a,aux,sllv_indx, shuffle, shift, mask10, b2, b1, b0} /\ 0 <= i{2} <= 48); last + by auto => />;smt(Array960.size_to_list List.take_size List.take0 storesE iota0). -bind op [W32.t] sra_32 "ashr". -realize bvashrP. -rewrite /sra_32 => bv1 bv2. -case : (32 <= to_uint bv2 < 32); last by admit. +seq 3 3 : (#pre /\ ={lo,hi}); + 1: by conseq />; sim. +auto => /> &1 &2 ????;split;last by smt(). +rewrite /storeW32 /storeW128. +apply mem_eq_ext => add. +rewrite !get_storesE !to_uintD_small /= !of_uintK /= 1,2:/# !modz_small 1..2:/#. +rewrite !size_take 1,2:/# /= !size_to_list. +case ((to_uint _ctp <= add && add < to_uint _ctp + MIN ((i{1} + 1) * 20) 960)); last by smt(). move => *. -have -> : to_uint bv2 = (to_uint bv2 - 32) + 32 by ring. -rewrite exprD_nneg 1,2:/# /= mulrC. by admit. +case ((to_uint _ctp + MIN (i{1} * 20) 960) <= add && add < to_uint _ctp + MIN (i{1} * 20 + 16) 960). ++ move => *; rewrite ifF 1:/# ifT 1:/# mulrDl /= takeD 1,2:/# nth_cat !size_take 1:/# size_to_list . + have -> /= : add - to_uint _ctp < MIN (i{1} * 20) 960 = false by smt(). + rewrite /to_list drop_mkseq 1:/# take_mkseq 1:/# /= /(\o) /= /mkseq (nth_map witness) /=;1:smt(size_iota). + rewrite nth_iota 1:/# initiE 1:/# get8_set32_directE 1,2:/# /= /get8 initiE 1:/# /= -/WArray960.get8 initiE 1:/# get8_set128_directE /#. +case ((to_uint _ctp + MIN (i{1} * 20+16) 960) <= add && add < to_uint _ctp + MIN (i{1} * 20 + 20) 960). ++ move => *; rewrite ifT 1:/# mulrDl /= takeD 1,2:/# nth_cat !size_take 1:/# size_to_list . + have -> /= : add - to_uint _ctp < MIN (i{1} * 20) 960 = false by smt(). + rewrite /to_list drop_mkseq 1:/# take_mkseq 1:/# /= /(\o) /= /mkseq (nth_map witness) /=;1:smt(size_iota). + rewrite nth_iota 1:/# initiE 1:/# get8_set32_directE 1,2:/# /= /get8 initiE 1:/# /= -/WArray960.get8 initiE 1:/# get8_set128_directE /#. +case (to_uint _ctp <= add && add < to_uint _ctp + MIN (i{1} * 20) 960); last by smt(). +move => *; rewrite ifF 1:/# ifF 1:/# mulrDl /= /to_list !take_mkseq 1,2:/# /= /mkseq !(nth_map witness); 1,2: smt(size_iota). +rewrite !nth_iota 1,2:/# initiE 1:/# get8_set32_directE 1,2:/# /get8 !initiE 1,2:/# /= -/WArray960.get8 get8_set128_directE 1,2:/# /get8 initiE /#. qed. -op sra_16 (w1 w2 : W16.t) : W16.t = -if (16 <= to_uint w2) then W16.zero else w1 `|>>` (truncateu8 w2). - -bind op [W16.t] sra_16 "ashr". -realize bvashrP. -rewrite /sra_16 => bv1 bv2. -case : (16 <= to_uint bv2); last by admit. -move => *. -have -> : to_uint bv2 = (to_uint bv2 - 16) + 16 by ring. -rewrite exprD_nneg 1,2:/# /= mulrC. by admit. +lemma poly_reduce_noloops : + equiv [ AuxPolyVecCompress10.__poly_reduce ~ M(Syscall).__poly_reduce : + ={arg} ==> ={res} ]. +proc => /=. +while (#pre /\ 0<=j{1} <= 256 /\ j{1} = to_uint j{2}); last by auto. +inline *;auto => /> &2; rewrite !W64.ultE /= => *;do split;1..2:smt();by rewrite to_uintD_small /=; smt(). qed. +lemma polyvec_reduce_noloops : + equiv [ AuxPolyVecCompress10.__polyvec_reduce ~ M(Syscall).__polyvec_reduce : + ={arg} ==> ={res} ]. +proc => /=. +by do 3!(wp;call poly_reduce_noloops);auto => />. +qed. -op srl_16 (w1 w2 : W16.t) : W16.t = - if 16 <= (to_uint w2) then W16.zero else - w1 `>>` (truncateu8 w2). +lemma poly_csubq_noloops : + equiv [ AuxPolyVecCompress10._poly_csubq_ref ~ M(Syscall)._poly_csubq : + ={arg} ==> ={res} ]. +proc => /=. +while (#pre /\ 0<=i{1} <= 256 /\ i{1} = to_uint i{2}); last by auto. +inline *;auto => /> &2; rewrite !W64.ultE /= => *;do split;1..2:smt();by rewrite to_uintD_small /=; smt(). +qed. -bind op [W16.t] srl_16 "shr". -realize bvshrP. -rewrite /srl_16 => bv1 bv2. -case : (16 <= to_uint bv2); last first. -+ rewrite /(`>>`) W16.to_uint_shr; 1: by smt(W8.to_uint_cmp). - rewrite /truncateu8 => bv2bnd />. - do 2! (rewrite (pmod_small (to_uint bv2) _);smt(W16.to_uint_cmp)). -move => *. -have -> : to_uint bv2 = (to_uint bv2 - 16) + 16 by ring. -rewrite exprD_nneg 1,2:/# /=. -smt(StdOrder.IntOrder.expr_gt0 W16.to_uint_cmp pow2_16). +lemma polyvec_csubq_noloops : + equiv [ AuxPolyVecCompress10.__polyvec_csubq_ref ~ M(Syscall).__polyvec_csubq : + ={arg} ==> ={res} ]. +proc => /=. +by do 3!(wp;call poly_csubq_noloops);auto => />. qed. +lemma compress10_equiv_refmem _ctp _mem : + equiv [ AuxPolyVecCompress10.ref ~ AuxPolyVecCompress10.ref_orig : + ={bp} /\ ctp{2} = _ctp /\ Glob.mem{2} = _mem /\ valid_ptr (to_uint ctp{2}) (128 + 3 * 320) ==> + Glob.mem{2} = stores _mem (to_uint _ctp) (to_list res{1}) ]. + proc => /=. +seq 1 1 : #pre; 1: by call polyvec_reduce_noloops => />. +inline {1} 1; inline {2} 1. +swap {1} 3 -1; swap {1} 4 -2; swap {2} [2..3] -1; swap {2} 7 -4. +seq 2 3 : (#pre /\ ={aa});1:by call polyvec_csubq_noloops;auto => />. +wp;while (0 <= i{1} <= 768 /\ i{1} = to_uint i{2} /\ valid_ptr (to_uint rp{2}) (128 + 3 * 320) /\ + 0 <= j{1} <= 960 /\ j{1} = to_uint j{2} /\ rp{2} = _ctp /\ + j{1} *4 = i{1} * 5 /\ ={aa} /\ + Glob.mem{2} = stores _mem (to_uint _ctp) (take j{1} (to_list rp{1}))); last + by auto => />; smt(Array960.size_to_list List.take_size List.take0 storesE iota0). +unroll for* {2} 2;auto => /> &1 &2;rewrite !ultE /= => ?????????;do split;1,2,4,5:smt();1..3,5..:by rewrite ?to_uintD_small;smt(). +rewrite /storeW8 /=. +apply mem_eq_ext => adr. + rewrite !to_uintD_small /= 1..16:/# !addrA. +rewrite !get_storesE. +case (to_uint _ctp + to_uint j{2} <= adr < to_uint _ctp + to_uint j{2} + 5); last first. +case (to_uint _ctp <= adr < to_uint _ctp + to_uint j{2}); last first. ++ move => *;do 5!(rewrite get_set_neqE_s 1:/#). + rewrite !size_take 1:/# size_to_list /= ifF 1:/# get_storesE /= size_take 1:/# size_to_list /#. ++ move => *. + move => *;do 5!(rewrite get_set_neqE_s 1:/#). + rewrite !size_take 1:/# size_to_list /= ifT 1:/# nth_take 1,2:/# /to_list nth_mkseq 1:/# /= get_storesE size_take 1:/# size_mkseq /= ifT 1:/#. + by rewrite nth_take 1,2:/# nth_mkseq 1:/# /=; smt(Array960.get_setE). +move => *; rewrite size_take 1:/# size_to_list ifT 1:/# nth_take 1,2:/# /to_list nth_mkseq 1:/# /=. +by smt(Array960.get_setE get_set_neqE_s get_set_eqE_s). +qed. -op sll_16 (w1 w2 : W16.t) : W16.t = - if (16 <= to_uint w2) then W16.zero else w1 `<<` (truncateu8 w2). +lemma compress10_equiv_avx2i : + equiv [ AuxPolyVecCompress10.avx2_orig_i ~ AuxPolyVecCompress10.avx2 : + ={bp} ==> ={res} ]. +proc => /=. +swap {2} 2 -1. +seq 1 1 : #pre; 1: by sim. +inline *;wp. +while (={i,a,bp,b0,b1,b2,mask10,shift,sllv_indx,shuffle,aux} /\ 0<=i{1} <= 48 /\ aux{1}=48 /\ (forall k, 0<=k rp{1}.[k] = rp{2}.[k])); + last by auto => /> *; split;[ smt() | move => *; rewrite tP => *;smt()]. +auto => /> &1 &2 *;split;1:smt(). +move => k kbl kbh; rewrite !initiE 1,2:/# /=. +rewrite !get8_set32_directE 1..4:/#. +case (0<=k *; rewrite !ifF 1,2:/# /get8 !initiE 1..4:/# /=. + rewrite -/WArray960.get8 !get8_set128_directE 1..4:/# !ifF 1,2:/#. + by rewrite /get8 !initiE /#. +move => *;case (i{2}*20<=k *; rewrite !ifF 1,2:/# /get8 !initiE 1..4:/# /=. + by rewrite -/WArray960.get8 !get8_set128_directE 1..4:/# !ifT /#. +by smt(). +qed. -bind op [W16.t] sll_16 "shl". -realize bvshlP. -rewrite /sll_16 => bv1 bv2. -case : (16 <= to_uint bv2); last first. -+ rewrite /(`<<`) W16.to_uint_shl; 1: by smt(W8.to_uint_cmp). - rewrite /truncateu8 => bv2bnd />. - do 2! (rewrite (pmod_small (to_uint bv2) _);smt(W16.to_uint_cmp pow2_16)). -move => *. -have -> : to_uint bv2 = (to_uint bv2 - 16) + 16 by ring. -by rewrite exprD_nneg 1,2:/# /= /#. +lemma compress10_equiv_refi : + equiv [ AuxPolyVecCompress10.ref ~ AuxPolyVecCompress10.ref_orig_i : + ={bp} ==> ={res} ]. +proc. +seq 1 1 : #pre; 1: by call polyvec_reduce_noloops => />. +inline {1} 1; inline {2} 1. +swap {1} 4 -2;swap {2} [2..3] -1; swap {2} 7 -4. +seq 2 3 : (#pre /\ ={aa});1:by call polyvec_csubq_noloops;auto => />. +wp;while (0 <= i{1} <= 768 /\ i{1} = to_uint i{2} /\ + 0 <= j{1} <= 960 /\ j{1} = to_uint j{2} /\ + j{1} *4 = i{1} * 5 /\ ={aa} /\ + (forall kk, 0 <= kk < j{1} => rp{1}.[kk] = rp{2}.[kk])); last by auto => />*; split;[ smt() | move => *; rewrite tP => *; smt()]. +unroll for* {2} 2;auto => /> &1 &2;rewrite !ultE /= => ????????;do split; 1,2,4,5:smt();1..3,5..:by rewrite ?to_uintD_small;smt(). +move => kk kkb ?. +rewrite !to_uintD_small /=;1..7:smt(). +case (kk < to_uint j{2}); by smt(Array960.get_setE). qed. -op srl_64 (w1 w2 : W64.t) : W64.t = - if (64 <= to_uint w2) then W64.zero else w1 `>>` (truncateu8 w2). +(*****************************************************************) -bind op [W64.t] srl_64 "shr". -realize bvshrP. -rewrite /srl_64 => bv1 bv2. -case : (64 <= to_uint bv2); last first. -+ rewrite /(`>>`) W64.to_uint_shr; 1: by smt(W8.to_uint_cmp). - rewrite /truncateu8 => bv2bnd />. - do 2! (rewrite (pmod_small (to_uint bv2) _);smt(W64.to_uint_cmp)). -move => *. -have -> : to_uint bv2 = (to_uint bv2 - 64) + 64 by ring. -rewrite exprD_nneg 1,2:/# /=. -smt(StdOrder.IntOrder.expr_gt0 W64.to_uint_cmp pow2_64). -qed. op lane_func_reduce(c : W16.t) : W16.t = let t = (sigextu32 c) * (W32.of_int 20159) in @@ -2545,17 +2727,52 @@ bdep 16 10 [_bp] [bp] [rp] lane_polyvec_redcomp10 pcond_all. + by smt(). qed. -lemma ref_correctness_p (_bp : W16.t Array768.t) : phoare [ AuxPolyVecCompress10.ref : -_bp = bp ==> -map lane_polyvec_redcomp10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _bp))]))) = - map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))])) +lemma ref_correctness_p (_bp : W16.t Array768.t) : +phoare [ AuxPolyVecCompress10.ref : + _bp = bp ==> + map lane_polyvec_redcomp10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _bp))]))) = + map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))])) ] = 1%r. -admitted. - -lemma avx_correctness_p (_bp : W16.t Array768.t) : phoare [ AuxPolyVecCompress10.avx2 : _bp = bp ==> -map lane_polyvec_redcomp10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _bp))]))) = - map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))]))] = 1%r. -admitted. +proof. +have Hll: islossless AuxPolyVecCompress10.ref. + proc; inline*. + wp; while (true) (3*256-i). + by move=> *; auto => /> /#. + wp; while (true) (256-i2). + by move=> *; auto => /> /#. + wp; while (true) (256-i1). + by move=> *; auto => /> /#. + wp; while (true) (256-i0). + by move=> *; auto => /> /#. + wp; while (true) (256-j2). + by move=> *; auto => /> /#. + wp; while (true) (256-j1). + by move=> *; auto => /> /#. + wp; while (true) (256-j0). + by move=> *; auto => /> /#. + by auto => /> /#. +by conseq Hll (ref_correctness _bp). +qed. + +lemma avx_correctness_p (_bp : W16.t Array768.t) : + phoare [ AuxPolyVecCompress10.avx2 + : _bp = bp ==> + map lane_polyvec_redcomp10 (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list _bp))]))) = + map W10.bits2w (chunk 10 (flatten [flatten (map W8.w2bits (to_list res))]))] = 1%r. +proof. +have Hll: islossless AuxPolyVecCompress10.avx2. + proc; inline*. + wp; while (true) (aux0-i). + by move=> *; auto => /> /#. + wp; while (true) (16-i2). + by move=> *; auto => /> /#. + wp; while (true) (16-i1). + by move=> *; auto => /> /#. + wp; while (true) (16-i0). + by move=> *; auto => /> /#. + auto => /> /#. +by conseq Hll (avx_correctness _bp). +qed. lemma unbits16 (bpl : W16.t list) : (map W16.bits2w (chunk 16 (flatten (map W16.w2bits bpl)))) = bpl. @@ -2762,7 +2979,7 @@ have H := polyvec_add2_equiv_noperm 2 2 _ _ => //. ecall (H (lift_array768 bp{2}) (lift_array768 ep{2})); clear H. -unroll for {1} 39. +unroll for* {1} 39. swap {1} 3 -2; swap {2} 3 -2; seq 1 1: (#pre /\ ={pkp0} /\ pkp0{2}=pkp{1}); 1: by auto. sp 3 3. @@ -3147,7 +3364,7 @@ have H := polyvec_add2_equiv_noperm 2 2 _ _ => //. ecall (H (lift_array768 bp{2}) (lift_array768 ep{2})); clear H. -unroll for {1} 39. +unroll for* {1} 39. swap {1} 3 -2; swap {2} 3 -2; seq 1 1: (#pre /\ ={pkp0} /\ pkp0{2} = pkp{1}); 1: by auto. sp 3 3. diff --git a/proof/correctness/avx2/MLKEM_avx2_encdec.ec b/proof/correctness/avx2/MLKEM_avx2_encdec.ec index 0ea7cfd2..074090a0 100644 --- a/proof/correctness/avx2/MLKEM_avx2_encdec.ec +++ b/proof/correctness/avx2/MLKEM_avx2_encdec.ec @@ -346,10 +346,10 @@ equiv eq_encode1 : EncDec_AVX2.encode1 ~ EncDec.encode1 : ={a} /\ all (fun (x : int) => 0 <= x && x < 2) a{1} ==> ={res}. proc. -unroll for {1} ^while. +unroll for* {1} ^while. do 8!(unroll for {1} ^while). -unroll for {2} ^while. -do 32!(unroll for {2} ^while). +unroll for* {2} ^while. +do 32!(unroll for* {2} ^while). auto => /> &m. rewrite (initSet _ (fun i => W8.init (fun (j : int) => W8.int_bit a{m}.[8*i + j] 0) )) //=. rewrite (initSet _ (fun i => W8.of_int (((((((a{m}.[8*i+0] %% 256 + a{m}.[8*i+1] * 2) %% 256 + a{m}.[8*i+2] * 4) %% 256 + a{m}.[8*i+3] * 8) %% 256 + a{m}.[8*i+4] * 16) %% 256 + a{m}.[8*i+5] * 32) %% 256 + a{m}.[8*i+6] * 64) %% 256 + a{m}.[8*i+7] * 128) )) //=. @@ -400,9 +400,9 @@ equiv eq_decode1_opt : EncDec_AVX2.decode1_opt ~ EncDec.decode1 : ={a} ==> ={res}. proc. -unroll for {1} ^while. -do 8!(unroll for {1} ^while). -unroll for {2} ^while. +unroll for* {1} ^while. +do 8!(unroll for* {1} ^while). +unroll for* {2} ^while. auto=> /= &1 &2 ->; rewrite -!fillCE /=. by do 32! ((do 8! rewrite fillCSm 1://); rewrite /= fillC0). qed. @@ -412,10 +412,10 @@ equiv decode10_vec_corr: proof. proc. swap 2 1. - unroll for {1} ^while. - do 3!(unroll for {1} ^while). + unroll for* {1} ^while. + do 3!(unroll for* {1} ^while). swap {2} 2 1. - unroll for {2} ^while. + unroll for* {2} ^while. by auto => />. qed. @@ -425,7 +425,7 @@ proof. proc. while (#pre /\ ={k} /\ 0 <= k{1} <= 3 /\ j{2} = 320*k{2} /\ (forall j, 0 <= j < 256*k{1} => c{1}.[j] = c{2}.[j])). - unroll for {2} 2. + unroll for* {2} 2. wp; skip; auto => />. move => &1 &2 [#] k_lb k_ub c_eq k_tub />. rewrite (mulzDr 320 _ _) (mulzDr 256 _ _) mulz1 /=. @@ -465,10 +465,10 @@ equiv encode10_vec_corr: proof. proc. swap 2 1. - unroll for {1} ^while. - do 48!(unroll for {1} ^while). + unroll for* {1} ^while. + do 48!(unroll for* {1} ^while). swap {2} 2 1. - unroll for {2} ^while. + unroll for* {2} ^while. by auto => />. qed. @@ -478,7 +478,7 @@ proof. proc. while (#pre /\ ={i} /\ 0 <= i{1} <= 48 /\ j{2} = 20*i{2} /\ (forall k, 0 <= k < 20*i{1} => c{1}.[k] = c{2}.[k])). - unroll for {2} 2. + unroll for* {2} 2. wp; skip; auto => />. move => &1 &2 [#] i_lb i_ub c_eq i_tub />. rewrite (mulzDr 20 _ _) mulz1 //=. @@ -517,7 +517,7 @@ equiv encode12_avx2_corr: EncDec_AVX2.encode12 ~ EncDec.encode12: ={a} ==> ={res}. proof. proc. - unroll for {1} ^while. + unroll for* {1} ^while. splitwhile {2} 4: (i < 128). wp. while (0<=k{1}<=64 /\ 128<=i{2}<=256 /\ i{2} = 2*k{1} + 128 /\ j{2} = 192 * i{1} + 3 * k{1} /\ i{1} = 1 /\ ={r,a}). @@ -533,7 +533,7 @@ proof. proc. while (#pre /\ i{1} = i{2} /\ 0 <= i{1} <= 2 /\ (forall k, 0 <= k < 192 * i{1} => r{1}.[k] = r{2}.[k])). - unroll for {2} 2. + unroll for* {2} 2. wp; skip; auto => />. move => &1 &2 [#] i_lb i_ub r1_eq_r2 i_tub />. split. @@ -587,9 +587,9 @@ equiv decode12_avx2_corr: EncDec_AVX2.decode12 ~ EncDec.decode12: ={a} ==> ={res}. proof. proc. - unroll for {1} ^while. - do 2!(unroll for {1} ^while). - unroll for {2} ^while. + unroll for* {1} ^while. + do 2!(unroll for* {1} ^while). + unroll for* {2} ^while. by auto => />. qed. @@ -599,7 +599,7 @@ proof. proc. while (#pre /\ i{1} = i{2} /\ 0 <= i{1} <= 2 /\ (forall k, 0 <= k < 128 * i{1} => r{1}.[k] = r{2}.[k])). - unroll for {2} 2. + unroll for* {2} 2. wp; skip; auto => />. move => &1 &2 [#] i_lb i_ub r1_eq_r2 i_tub />. split. @@ -650,9 +650,9 @@ qed. equiv eq_decode4: EncDec_AVX2.decode4 ~ EncDec.decode4: ={a} ==> ={res}. proc. -unroll for {1} ^while. -do 16!(unroll for {1} ^while). -unroll for {2} ^while. +unroll for* {1} ^while. +do 16!(unroll for* {1} ^while). +unroll for* {2} ^while. by auto => />. qed. @@ -660,9 +660,9 @@ equiv eq_encode4: EncDec_AVX2.encode4 ~ EncDec.encode4: ={p} ==> ={res}. proof. proc. - unroll for {1} ^while. - do 4!(unroll for {1} ^while). - unroll for {2} ^while. + unroll for* {1} ^while. + do 4!(unroll for* {1} ^while). + unroll for* {2} ^while. by auto => />. qed. diff --git a/proof/security/FO_MLKEM.ec b/proof/security/FO_MLKEM.ec index 698c599f..bb0f505f 100644 --- a/proof/security/FO_MLKEM.ec +++ b/proof/security/FO_MLKEM.ec @@ -1,4 +1,4 @@ -require import AllCore Distr List Real SmtMap FSet DInterval FinType KEM_ROM. +require import AllCore Distr List Real FMap FSet DInterval FinType KEM_ROM. require (****) PKE_ROM PlugAndPray Hybrid FelTactic. require FO_UU. @@ -238,7 +238,7 @@ seq 6 5 : (#pre /\ m0{1} = m{2} /\ pk1{1} = pk0{2} /\ dkey `*` randd; last by smt(). rewrite dprod_dlet; congr;apply fun_ext => r. by rewrite dlet_dunit. - by auto => />;smt(get_setE mem_set mem_empty @SmtMap @FSet). + by auto => />;smt(get_setE mem_set mem_empty @FMap @FSet). seq 1 1 : (#pre /\ c1{1} = c0{2}); 1: by auto => /#. @@ -484,7 +484,7 @@ call(: ={glob CCA} /\ B1x2._pk{1} = CCA.sk{2}.`1.`1 /\ by auto => />;smt(get_setE). rcondf{1} 2;1: by move => *;inline *;auto => />. - by inline *;auto => /> *; do split;move => *;do split;move => *;1:do split;smt(@SmtMap). + by inline *;auto => /> *; do split;move => *;do split;move => *;1:do split;smt(@FMap). inline *. swap {1} 14 -13. swap {1} 20 -18. swap {2} 11 -10. diff --git a/proof/security/FO_TT.ec b/proof/security/FO_TT.ec index 48636a7d..5e29186f 100644 --- a/proof/security/FO_TT.ec +++ b/proof/security/FO_TT.ec @@ -1,4 +1,4 @@ -require import AllCore Distr List Real SmtMap FSet DInterval. +require import AllCore Distr List Real FMap FSet DInterval. require (****) FinType PKE_ROM PlugAndPray Hybrid FelTactic. (******************************************************************) diff --git a/proof/security/FO_UU.ec b/proof/security/FO_UU.ec index 781b13c6..b6952d80 100644 --- a/proof/security/FO_UU.ec +++ b/proof/security/FO_UU.ec @@ -1,4 +1,4 @@ -require import AllCore Distr List Real SmtMap FSet DInterval FinType KEM_ROM. +require import AllCore Distr List Real FMap FSet DInterval FinType KEM_ROM. require (****) PKE_ROM PlugAndPray Hybrid FelTactic. (* This will be the underlying scheme resulting @@ -964,7 +964,7 @@ seq 2 2 : (={glob A,k1} /\ k0{1} = k2{2}); 1: by auto. (!H1.bad{2} <=> Some H2.mtgt{2} = dec CCA.sk{2}.`1.`2 (oget CCA.cstar{2}))). auto => /> &2 f Hf kpair Hkpair b Hb m Hm. - smt(mem_empty get_setE fdom_set @SmtMap @FSet @List). + smt(mem_empty get_setE fdom_set @FMap @FSet @List). case (H1.bad{1}). rnd;wp;call(:H1.bad,false,CCA.cstar{2} <> None /\ diff --git a/proof/security/MLWE.ec b/proof/security/MLWE.ec index d1b805a0..1d8a6752 100644 --- a/proof/security/MLWE.ec +++ b/proof/security/MLWE.ec @@ -1,4 +1,5 @@ -require import AllCore Ring SmtMap Distr PROM. +require import AllCore Ring Distr FMap PROM. + require (****) Matrix. clone import Matrix as Matrix_. @@ -258,10 +259,11 @@ swap {2} 11 -8. swap {2} 14 -13. swap {2} [15..16] -10. swap {1} 5 -2. + seq 3 6 : (#pre /\ ={b,_A,sd} /\ (RO.m{1}.[B._sd{2}] = Some B.__A{2}) /\ B.__A{2} = _A{2} /\ B._sd{2} = sd{2} /\ (forall x, x <> B._sd{2} => RO.m{1}.[x] = RO.m{2}.[x])); - first by inline *; auto => />; smt(@SmtMap). + first by inline *; auto => />; smt(@FMap). wp;call(: (RO.m{1}.[B._sd{2}] = Some B.__A{2}) /\ forall x, x <> B._sd{2} => RO.m{1}.[x] = RO.m{2}.[x]). proc;inline *. @@ -290,7 +292,7 @@ seq 3 6 : (#pre /\ ={b,sd} /\ _A{1} = m_transpose _A{2} /\ Bt.__A{2} = m_transpose _A{2} /\ Bt._sd{2} = sd{2} /\ (forall x, x <> Bt._sd{2} => RO.m{1}.[x] = RO.m{2}.[x])). + inline *; wp; rnd (fun m => m_transpose m) (fun m => m_transpose m). - by auto => />; smt(@SmtMap trmxK duni_matrix_funi). + by auto => />; smt(@FMap trmxK duni_matrix_funi). wp;call(: (RO.m{1}.[Bt._sd{2}] = Some Bt.__A{2}) /\ forall x, x <> Bt._sd{2} => RO.m{1}.[x] = RO.m{2}.[x]). proc;inline *. @@ -658,7 +660,7 @@ seq 4 0 : (#pre /\ (_A = oget RO.m.[sd]){1}); 1: seq 1 1 : #pre; last by auto => /> &1 ?; rewrite /dout duni_matrix_ll. exists* _A{1}, sd{1}; elim * => _A1 sd1. call(: ={RO.m} /\ sd1 \in RO.m{1} /\ _A1 = oget RO.m{1}.[sd1]). -+ proc; auto => /> &2 rl ??; 1:smt(@SmtMap). ++ proc; auto => /> &2 rl ??; 1:smt(@FMap). by auto => />. qed. @@ -694,7 +696,7 @@ seq 4 0 : (#pre /\ (_A = oget RO.m.[sd]){1}); 1: seq 1 1 : #pre; last by auto => /> &1 ?; rewrite /dout duni_matrix_ll. exists* _A{1}, sd{1}; elim * => _A1 sd1. call(: ={RO.m} /\ sd1 \in RO.m{1} /\ _A1 = oget RO.m{1}.[sd1]). -+ proc; auto => /> &2 rl ??; 1:smt(@SmtMap). ++ proc; auto => /> &2 rl ??; 1:smt(@FMap). by auto => />. qed. @@ -1060,7 +1062,7 @@ seq 2 2 : #pre; last by auto => /> &1 ?; rewrite /dout duni_matrix_ll. exists* _A{1}, sd{1}; elim * => _A1 sd1. call(: ={RO.m, glob Sim} /\ sd1 \in RO.m{1} /\ _A1 = oget RO.m{1}.[sd1]); last by call(_:true); auto => />. proc*;call(: ={RO.m} /\ sd1 \in RO.m{1} /\ _A1 = oget RO.m{1}.[sd1]); last by auto => />. -by proc; inline *; auto => />; smt(@SmtMap). +by proc; inline *; auto => />; smt(@FMap). qed. lemma MLWE_SMP_equiv_t _b &m (S <: Sampler {-LRO, -RO, -FRO, -RO_SMP.LRO, -D}) @@ -1122,7 +1124,7 @@ seq 2 2 : #pre; last by auto => /> &1 ?; rewrite /dout duni_matrix_ll. exists* _A{1}, sd{1}; elim * => _A1 sd1. call(: ={RO.m, glob Sim} /\ sd1 \in RO.m{1} /\ _A1 = oget RO.m{1}.[sd1]); last by call(_: true);auto => />. proc*;call(: ={RO.m} /\ sd1 \in RO.m{1} /\ _A1 = oget RO.m{1}.[sd1]); last by auto => />. -by proc; inline *; auto => />; smt(@SmtMap). +by proc; inline *; auto => />; smt(@FMap). qed. end SMP_vs_ROM_IND. diff --git a/proof/security/MLWE_PKE_Hash.ec b/proof/security/MLWE_PKE_Hash.ec index c86707cc..53ffa20b 100644 --- a/proof/security/MLWE_PKE_Hash.ec +++ b/proof/security/MLWE_PKE_Hash.ec @@ -1,4 +1,4 @@ -require import AllCore Distr List SmtMap Dexcepted PKE_ROM StdOrder. +require import AllCore Distr List FMap Dexcepted PKE_ROM StdOrder. require (**RndExcept **) MLWE FLPRG. theory MLWE_PKE_Hash. diff --git a/proof/spec/MLKEMSecurity.ec b/proof/spec/MLKEMSecurity.ec index f90c6971..1976b488 100644 --- a/proof/spec/MLKEMSecurity.ec +++ b/proof/spec/MLKEMSecurity.ec @@ -1170,8 +1170,8 @@ transitivity {2} { rho <$ srand; noise1 <@ CBD2rnd.sample_vec_real(); noise2 <@ by symmetry;wp; do 2!call(CBD2rnd_vec_equiv); auto => />. seq 1 5: (_N{2} = 0 /\ ={rho} /\ - forall (x:W8.t), SmtMap.dom RF.m{2} x => W8.to_uint x < _N{2}). - + inline *; swap {2} 1 1; auto;conseq (: _ ==> ={rho}); 1:by smt(SmtMap.mem_empty). + forall (x:W8.t), FMap.dom RF.m{2} x => W8.to_uint x < _N{2}). + + inline *; swap {2} 1 1; auto;conseq (: _ ==> ={rho}); 1:by smt(FMap.mem_empty). rndsem*{2} 0; auto => />. have -> : (dfst dRO) = srand; last by smt(). apply eq_distr => x;rewrite dmap1E. @@ -1181,17 +1181,17 @@ transitivity {2} { rho <$ srand; noise1 <@ CBD2rnd.sample_vec_real(); noise2 <@ by rewrite srand_ll /=. seq 1 2: (={noise1,rho} /\ _N{2} = 3 /\ - forall (x:W8.t), SmtMap.dom RF.m{2} x => W8.to_uint x < _N{2}). + forall (x:W8.t), FMap.dom RF.m{2} x => W8.to_uint x < _N{2}). inline*; wp. while (i0{1} = i{2} /\ 0 <= i{2} <= kvec /\ _N{2}=i{2} /\ (forall k, 0 <= k < i{2} => (v{1}.[k]=noise1{2}.[k])%PolyVec) /\ - forall (x:W8.t), SmtMap.dom RF.m{2} x => W8.to_uint x < _N{2}). + forall (x:W8.t), FMap.dom RF.m{2} x => W8.to_uint x < _N{2}). rcondt {2} 4. + move=> *; wp; skip => &hr /> ??? Hm ?. rewrite -implybF => H. by move: (Hm _ H); rewrite implybF of_uintK /#. wp; while (#[/:4,7:]pre /\ ={bytes} /\ i1{1} = i0{2} /\ 0 <= i0{2} <= 128 /\ j{2} = i0{2}*2 /\ - (forall (x1 : W8.t), SmtMap.dom RF.m{2} x1 => to_uint x1 <= _N{2}) /\ + (forall (x1 : W8.t), FMap.dom RF.m{2} x1 => to_uint x1 <= _N{2}) /\ forall k, 0 <= k < j{2} => p0{1}.[k] = rr{2}.[k]). auto => />. move => &1 &2 *; do split; 1..3:smt(). move=> k ?? /=. @@ -1201,10 +1201,10 @@ transitivity {2} { rho <$ srand; noise1 <@ CBD2rnd.sample_vec_real(); noise2 <@ by rewrite set_eqiE 1..2:/# set_eqiE /#. by rewrite set_neqiE 1..2:/# set_neqiE 1..2:/# set_neqiE 1..2:/# set_neqiE /#. wp; rnd; wp; skip => /> &1 &2; rewrite !setvE !getvE => ?????????; split. - split; 1: by rewrite SmtMap.get_set_sameE. + split; 1: by rewrite FMap.get_set_sameE. move=> x; case: (x=W8.of_int i{2}) => E. by move=> _; rewrite E of_uintK modz_small /#. - rewrite SmtMap.domE SmtMap.get_set_neqE 1:// => H. + rewrite FMap.domE FMap.get_set_neqE 1:// => H. by apply StdOrder.IntOrder.ltrW; smt(). move => p1 i0 p2 ?????? H; split; first smt(). have EE: p1 = p2. @@ -1217,17 +1217,17 @@ transitivity {2} { rho <$ srand; noise1 <@ CBD2rnd.sample_vec_real(); noise2 <@ move=> v1 m i v2 => ??????; split; last smt(). apply eq_vectorP => k kb;smt(setvE getvE). wp; seq 2 2: (={rho,noise1,noise2} /\ _N{2} = 6 /\ - forall (x:W8.t), SmtMap.dom RF.m{2} x => W8.to_uint x < _N{2}). + forall (x:W8.t), FMap.dom RF.m{2} x => W8.to_uint x < _N{2}). inline*; wp. while (i0{1} = i{2} /\ 0 <= i{2} <= kvec /\ _N{2}=3+i{2} /\ noise1{1}=noise1{2} /\ (forall k, 0 <= k < i{2} => (v{1}.[k]=noise2{2}.[k])%PolyVec) /\ - forall (x:W8.t), SmtMap.dom RF.m{2} x => W8.to_uint x < _N{2}). + forall (x:W8.t), FMap.dom RF.m{2} x => W8.to_uint x < _N{2}). rcondt {2} 4. move=> *; wp; skip => &hr /> ??? Hm ?. rewrite -implybF => H. by move: (Hm _ H); rewrite implybF of_uintK /#. wp; while (#[/:5,8:]pre /\ bytes{1}=bytes{2} /\ i1{1}=i0{2} /\ 0 <= i0{2} <= 128 /\ j{2} = i0{2}*2 /\ - (forall (x1 : W8.t), SmtMap.dom RF.m{2} x1 => to_uint x1 <= _N{2}) /\ + (forall (x1 : W8.t), FMap.dom RF.m{2} x1 => to_uint x1 <= _N{2}) /\ forall k, 0 <= k < j{2} => p0{1}.[k] = rr{2}.[k]). wp; skip => /> &1&2 *; split; first smt(). split; first smt(). @@ -1238,10 +1238,10 @@ transitivity {2} { rho <$ srand; noise1 <@ CBD2rnd.sample_vec_real(); noise2 <@ by rewrite set_eqiE 1..2:/# set_eqiE /#. by rewrite set_neqiE 1..2:/# set_neqiE 1..2:/# set_neqiE 1..2:/# set_neqiE /#. wp; rnd; wp; skip => /> &1 &2; rewrite !getvE !setvE => ?????????; split. - split; 1: by rewrite SmtMap.get_set_sameE. + split; 1: by rewrite FMap.get_set_sameE. move=> x; case: (x=W8.of_int (3+i{2})) => E. by move=> _; rewrite E of_uintK modz_small /#. - rewrite SmtMap.domE SmtMap.get_set_neqE 1:// => H. + rewrite FMap.domE FMap.get_set_neqE 1:// => H. by apply StdOrder.IntOrder.ltrW; smt(). move => p1 i1 p2 ?????? H; split; first smt(). have EE: p1 = p2. @@ -1309,21 +1309,21 @@ transitivity {2} { noise1 <@ CBD2rnd.sample_vec_real(); noise2 <@ CBD2rnd.sample by symmetry;wp; call(CBD2rnd_equiv); do 2!call(CBD2rnd_vec_equiv); auto => />. seq 0 4: (_N{2} = 0 /\ - forall (x:W8.t), SmtMap.dom RF.m{2} x => W8.to_uint x < _N{2}). - + by auto => />;smt(SmtMap.mem_empty). + forall (x:W8.t), FMap.dom RF.m{2} x => W8.to_uint x < _N{2}). + + by auto => />;smt(FMap.mem_empty). seq 1 2: (={noise1} /\ _N{2} = 3 /\ - forall (x:W8.t), SmtMap.dom RF.m{2} x => W8.to_uint x < _N{2}). + forall (x:W8.t), FMap.dom RF.m{2} x => W8.to_uint x < _N{2}). inline*; wp. while (i0{1} = i{2} /\ 0 <= i{2} <= kvec /\ _N{2}=i{2} /\ (forall k, 0 <= k < i{2} => (v{1}.[k]=noise1{2}.[k])%PolyVec) /\ - forall (x:W8.t), SmtMap.dom RF.m{2} x => W8.to_uint x < _N{2}). + forall (x:W8.t), FMap.dom RF.m{2} x => W8.to_uint x < _N{2}). rcondt {2} 4. + move=> *; wp; skip => &hr /> ??? Hm ?. rewrite -implybF => H. by move: (Hm _ H); rewrite implybF of_uintK /#. wp; while (#[/:4,7:]pre /\ ={bytes} /\ i1{1} = i0{2} /\ 0 <= i0{2} <= 128 /\ j{2} = i0{2}*2 /\ - (forall (x1 : W8.t), SmtMap.dom RF.m{2} x1 => to_uint x1 <= _N{2}) /\ + (forall (x1 : W8.t), FMap.dom RF.m{2} x1 => to_uint x1 <= _N{2}) /\ forall k, 0 <= k < j{2} => p0{1}.[k] = rr{2}.[k]). auto => />. move => &1 &2 *; do split; 1..3:smt(). move=> k ?? /=. @@ -1333,10 +1333,10 @@ transitivity {2} { noise1 <@ CBD2rnd.sample_vec_real(); noise2 <@ CBD2rnd.sample by rewrite set_eqiE 1..2:/# set_eqiE /#. by rewrite set_neqiE 1..2:/# set_neqiE 1..2:/# set_neqiE 1..2:/# set_neqiE /#. wp; rnd; wp; skip => /> &1 &2; rewrite !setvE !getvE => ?????????; split. - split; 1: by rewrite SmtMap.get_set_sameE. + split; 1: by rewrite FMap.get_set_sameE. move=> x; case: (x=W8.of_int i{2}) => E. by move=> _; rewrite E of_uintK modz_small /#. - rewrite SmtMap.domE SmtMap.get_set_neqE 1:// => H. + rewrite FMap.domE FMap.get_set_neqE 1:// => H. by apply StdOrder.IntOrder.ltrW; smt(). move => p1 i0 p2 ?????? H; split; first smt(). have EE: p1 = p2. @@ -1349,17 +1349,17 @@ transitivity {2} { noise1 <@ CBD2rnd.sample_vec_real(); noise2 <@ CBD2rnd.sample move=> v1 m i v2 => ??????; split; last smt(). apply eq_vectorP => k kb;smt(setvE getvE). wp; seq 1 2: (={noise1,noise2} /\ _N{2} = 6 /\ - forall (x:W8.t), SmtMap.dom RF.m{2} x => W8.to_uint x < _N{2}). + forall (x:W8.t), FMap.dom RF.m{2} x => W8.to_uint x < _N{2}). inline*; wp. while (i0{1} = i{2} /\ 0 <= i{2} <= kvec /\ _N{2}=3+i{2} /\ noise1{1}=noise1{2} /\ (forall k, 0 <= k < i{2} => (v{1}.[k]=noise2{2}.[k])%PolyVec) /\ - forall (x:W8.t), SmtMap.dom RF.m{2} x => W8.to_uint x < _N{2}). + forall (x:W8.t), FMap.dom RF.m{2} x => W8.to_uint x < _N{2}). rcondt {2} 4. move=> *; wp; skip => &hr /> ??? Hm ?. rewrite -implybF => H. by move: (Hm _ H); rewrite implybF of_uintK /#. wp; while (#[/:5,8:]pre /\ bytes{1}=bytes{2} /\ i1{1}=i0{2} /\ 0 <= i0{2} <= 128 /\ j{2} = i0{2}*2 /\ - (forall (x1 : W8.t), SmtMap.dom RF.m{2} x1 => to_uint x1 <= _N{2}) /\ + (forall (x1 : W8.t), FMap.dom RF.m{2} x1 => to_uint x1 <= _N{2}) /\ forall k, 0 <= k < j{2} => p0{1}.[k] = rr{2}.[k]). wp; skip => /> &1&2 *; split; first smt(). split; first smt(). @@ -1370,10 +1370,10 @@ transitivity {2} { noise1 <@ CBD2rnd.sample_vec_real(); noise2 <@ CBD2rnd.sample by rewrite set_eqiE 1..2:/# set_eqiE /#. by rewrite set_neqiE 1..2:/# set_neqiE 1..2:/# set_neqiE 1..2:/# set_neqiE /#. wp; rnd; wp; skip => /> &1 &2; rewrite !getvE !setvE => ?????????; split. - split; 1: by rewrite SmtMap.get_set_sameE. + split; 1: by rewrite FMap.get_set_sameE. move=> x; case: (x=W8.of_int (3+i{2})) => E. by move=> _; rewrite E of_uintK modz_small /#. - rewrite SmtMap.domE SmtMap.get_set_neqE 1:// => H. + rewrite FMap.domE FMap.get_set_neqE 1:// => H. by apply StdOrder.IntOrder.ltrW; smt(). move => p1 i1 p2 ?????? H; split; first smt(). have EE: p1 = p2. @@ -1403,7 +1403,7 @@ transitivity {2} { noise1 <@ CBD2rnd.sample_vec_real(); noise2 <@ CBD2rnd.sample by rewrite set_eqiE 1..2:/# set_eqiE /#. by rewrite set_neqiE 1..2:/# set_neqiE 1..2:/# set_neqiE 1..2:/# set_neqiE /#. wp; rnd; wp; skip => /> &1 &2 ????; split. - by rewrite SmtMap.get_set_eqE //=. + by rewrite FMap.get_set_eqE //=. move => p1 i1 p2 ????? H. by apply Array256.tP => k kb; apply H; smt(). qed.