From a6f35b0d2791280893dc66a34e34718c74ecce3e Mon Sep 17 00:00:00 2001 From: Manuel Barbosa Date: Fri, 20 Sep 2024 18:32:01 +0100 Subject: [PATCH] Going for equivs --- .../avx2/MLKEM_Poly_avx2_proof_mr.ec | 6 - .../correctness/avx2/MLKEM_avx2_equivs_mr.ec | 485 +++++++++++++++--- 2 files changed, 404 insertions(+), 87 deletions(-) diff --git a/proof/correctness/avx2/MLKEM_Poly_avx2_proof_mr.ec b/proof/correctness/avx2/MLKEM_Poly_avx2_proof_mr.ec index 91a0d453..b449df59 100644 --- a/proof/correctness/avx2/MLKEM_Poly_avx2_proof_mr.ec +++ b/proof/correctness/avx2/MLKEM_Poly_avx2_proof_mr.ec @@ -25,12 +25,6 @@ import BitEncoding.BitChunking. timeout 1. -require import AllCore List. -from Jasmin require import JModel. -import BitEncoding.BitChunking. - -require import MLKEMFCLib. -require import Array8 Array256 WArray512 Array16 WArray32 Array128 WArray128. (*************************) (*************************) (* BEGIN BINDINGS *) diff --git a/proof/correctness/avx2/MLKEM_avx2_equivs_mr.ec b/proof/correctness/avx2/MLKEM_avx2_equivs_mr.ec index 187843f8..a99d20f0 100644 --- a/proof/correctness/avx2/MLKEM_avx2_equivs_mr.ec +++ b/proof/correctness/avx2/MLKEM_avx2_equivs_mr.ec @@ -37,6 +37,410 @@ import MLKEM_PolyvecAVX. import MLKEM_PolyAVXVec. import MLKEM_PolyVecAVXVec. + +import BitEncoding.BitChunking. + +require import Array128 WArray128. + + +timeout 1. + +(*************************) +(*************************) +(* BEGIN BINDINGS *) +(*************************) +(*************************) +(*************************) + +require import QFABV. + +theory W4. + abbrev [-printing] size = 4. + clone include BitWordSH with op size <- size + rename "_XX" as "_4" + proof gt0_size by done, + size_le_256 by done. + +end W4. export W4 W4.ALU W4.SHIFT. + +bind bitstring W4.w2bits W4.bits2w W4.t 4. +realize ge0_size by auto. +realize size_tolist by exact W4.size_w2bits. +realize tolistK by exact W4.w2bitsK. +realize oflistK by exact W4.bits2wK. +bind bitstring W256.w2bits W256.bits2w W256.t 256. +realize ge0_size by auto. +realize size_tolist by exact W256.size_w2bits. +realize tolistK by exact W256.w2bitsK. +realize oflistK by exact W256.bits2wK. +bind bitstring W64.w2bits W64.bits2w W64.t 64. +realize ge0_size by auto. +realize size_tolist by exact W64.size_w2bits. +realize tolistK by exact W64.w2bitsK. +realize oflistK by exact W64.bits2wK. +bind bitstring W32.w2bits W32.bits2w W32.t 32. +realize ge0_size by auto. +realize size_tolist by exact W32.size_w2bits. +realize tolistK by exact W32.w2bitsK. +realize oflistK by exact W32.bits2wK. +bind bitstring W16.w2bits W16.bits2w W16.t 16. +realize ge0_size by auto. +realize size_tolist by exact W16.size_w2bits. +realize tolistK by exact W16.w2bitsK. +realize oflistK by exact W16.bits2wK. +bind bitstring W8.w2bits W8.bits2w W8.t 8. +realize ge0_size by auto. +realize size_tolist by exact W8.size_w2bits. +realize tolistK by exact W8.w2bitsK. +realize oflistK by exact W8.bits2wK. +bind circuit VPSUB_16u16 "VPSUB_16u16". +bind circuit VPSRA_16u16 "VPSRA_16u16". +bind circuit VPADD_16u16 "VPADD_16u16". +bind circuit W256.(`&`) "AND_256". +bind circuit VPBROADCAST_16u16 "VPBROADCAST_16u16". +bind circuit VPMULH_16u16 "VPMULH_16u16". +bind circuit VPMULHRS_16u16 "VPMULHRS_16u16". +bind circuit VPACKUS_16u16 "VPACKUS_16u16". +bind circuit VPMADDUBSW_256 "VPMADDUBSW_256". +bind circuit VPERMD "VPERMD". + + +bind op W16.(+) "bvadd". +realize bvaddP. +move => bv1 bv2. +by rewrite /BV_W16_t.toint -!W16.to_uintE !W16.to_uintD W16.to_uintE. +qed. + +bind op W64.(+) "bvadd". +realize bvaddP. +move => bv1 bv2. +by rewrite /BV_W64_t.toint -!W64.to_uintE !W64.to_uintD W64.to_uintE. +qed. + +op sliceget_256_256_16(a : W16.t Array256.t, i : int) : W256.t = + WArray512.get256 (WArray512.init16 (fun (i_0 : int) => a.[i_0])) i. + +lemma sliceget_256_256_16E (a : W16.t Array256.t) (i : int) : + WArray512.get256 (WArray512.init16 (fun (i_0 : int) => a.[i_0])) i = + sliceget_256_256_16 a i by auto. + +op sliceset_256_256_16(a : W16.t Array256.t,i : int, x : W256.t) : W16.t Array256.t = + Array256.init (fun (i0 : int) => WArray512.get16 (WArray512.set256 ((WArray512.init16 (fun (i_0 : int) => a.[i_0]))) i x) i0). + +lemma sliceset_256_256_16E (a : W16.t Array256.t) (i : int) (x : W256.t) : + Array256.init (fun (i0 : int) => WArray512.get16 (WArray512.set256 ((WArray512.init16 (fun (i_0 : int) => a.[i_0]))) i x) i0) = + sliceset_256_256_16 a i x by auto. + +op sliceget_256_16_16(a : W16.t Array16.t, i : int) : W256.t = + WArray32.get256 (WArray32.init16 (fun (i_0 : int) => a.[i_0])) i. + +lemma sliceget_256_16_16E (a : W16.t Array16.t) (i : int) : + WArray32.get256 (WArray32.init16 (fun (i_0 : int) => a.[i_0])) i = + sliceget_256_16_16 a i by auto. + +op sliceget_256_8_32(a : W32.t Array8.t, i : int) : W256.t = + WArray32.get256 (WArray32.init32 (fun (i_0 : int) => a.[i_0])) i. + +lemma sliceget_256_8_32E (a : W32.t Array8.t) (i : int) : + WArray32.get256 (WArray32.init32 (fun (i_0 : int) => a.[i_0])) i = + sliceget_256_8_32 a i by auto. + +op sliceset_256_128_8(a : W8.t Array128.t,i : int, x : W256.t) : W8.t Array128.t = + Array128.init ((WArray128.get8 ((WArray128.set256 ((WArray128.init8 (fun (i_0 : int) => a.[i_0]))) i x)))). + +lemma sliceset_256_128_8E (a : W8.t Array128.t) (i : int) (x : W256.t) : + Array128.init ((WArray128.get8 ((WArray128.set256 ((WArray128.init8 (fun (i_0 : int) => a.[i_0]))) i x)))) = + sliceset_256_128_8 a i x by auto. + + +bind bitstring circuit Array256."_.[_]" Array256."_.[_<-_]" Array256.to_list (W16.t Array256.t) 256. + +bind bitstring circuit Array16."_.[_]" Array16."_.[_<-_]" Array16.to_list (W16.t Array16.t) 16. + +bind bitstring circuit Array128."_.[_]" Array128."_.[_<-_]" Array128.to_list (W8.t Array128.t) 128. + +bind bitstring circuit Array8."_.[_]" Array8."_.[_<-_]" Array8.to_list (W32.t Array8.t) 8. + +(*************************) +(*************************) +(* END BINDINGS *) +(*************************) +(*************************) +(*************************) + +(*************************) +(*************************) +(* begin aux lemmas *) +(*************************) +(*************************) +(*************************) + + +lemma get_vs_bits_256u16_size(wa : W16.t Array256.t) : + size (chunk 16 (flatten [flatten (map W16.w2bits (to_list wa))])) = 256. + rewrite flatten1 size_chunk // size_flatten // -map_comp /(\o). + rewrite (eq_map _ (fun _ => 16)) => //=. + have -> : map (fun (_ : W16.t) => 16) (to_list wa) = + mkseq (fun _ => 16) 256; last + by rewrite /mkseq -iotaredE /= /sumz /=. + apply (eq_from_nth witness). + + by rewrite size_map Array256.size_to_list size_mkseq /#. + move => i. + rewrite size_map Array256.size_to_list => ib. + rewrite (nth_map witness) /=;1: by smt(Array256.size_to_list). + by rewrite (nth_mkseq) /#. +qed. + +lemma get_vs_bits_256u16(wa : W16.t Array256.t) k : + 0 <= k < 256 => + nth witness (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list wa))]))) k = wa.[k]. +move => kb. + rewrite (nth_map witness);1:smt(get_vs_bits_256u16_size size_map). + rewrite (nth_change_dfl [] witness); + 1: by smt(get_vs_bits_256u16_size). + rewrite JWordList.nth_chunk //; move : kb => [#kb0 kb1] //. + rewrite flatten1 size_flatten /= /sumz /=. + + have -> : (map List.size (map W16.w2bits (to_list wa))) = + mkseq (fun _ => 16) 256. + apply (eq_from_nth witness);1: + by rewrite !size_map size_iota. + move => i. rewrite !size_map size_iota /max /= => ib. + rewrite nth_mkseq 1:/# /=. + rewrite (nth_map witness);1:by rewrite size_map Array256.size_to_list /#. + rewrite (nth_map witness); 1:by rewrite Array256.size_to_list /#. + by rewrite size_w2bits. + by rewrite /mkseq -iotaredE /= /#. + rewrite flatten1 drop_flatten_ctt. + by move => x; rewrite mapP => H;elim H => x0; smt(W16.size_w2bits). + rewrite -map_drop. + rewrite (: 16 = 16*1) 1:/# take_flatten_ctt. + by move => x; rewrite mapP => H;elim H => x0; smt(W16.size_w2bits). + rewrite -map_take (drop_take1_nth witness) /=;1:smt(Array256.size_to_list). + by rewrite flatten1 w2bitsK. +qed. + +lemma get_vs_bits_pre_256u16 (wa : W16.t Array256.t) (f : W16.t -> bool) (g : W16.t -> bool) : (forall x, f x => g x) => + (forall (k : int), 0 <= k && k < 256 => f wa.[k]) => + all g (map W16.bits2w (chunk 16 (flatten [flatten (map W16.w2bits (to_list wa))]))). +proof. + move => H H0. + apply (all_nthP _ _ witness) => k kb /=. + by rewrite get_vs_bits_256u16;smt(get_vs_bits_256u16_size size_map). +qed. + +(*************************) +(*************************) +(* end aux lemmas *) +(*************************) +(*************************) +(*************************) + +print Jkem.M. + +module Aux(SC : Jkem.Syscall_t) = { +proc _poly_csubq(rp : W16.t Array256.t) : W16.t Array256.t = { + var i : int; + var t : W16.t; + var b : W16.t; + + i <- 0; + while (i < 256){ + t <- rp.[i]; + t <- t - (of_int 3329)%W16; + b <- t; + b <- b `|>>` (of_int 15)%W8; + b <- b `&` (of_int 3329)%W16; + t <- t + b; + rp.[i] <- t; + i <- i + 1; + } + + return rp; + } + +proc __polyvec_csubq(r : W16.t Array768.t) : W16.t Array768.t = { + var aux : W16.t Array256.t; + + aux <@ _poly_csubq((init (fun (i : int) => r.[0 + i]))%Array256); + r <- (init (fun (i : int) => if 0 <= i && i < 0 + 256 then aux.[i - 0] else r.[i]))%Array768; + aux <@ _poly_csubq((init (fun (i : int) => r.[256 + i]))%Array256); + r <- (init (fun (i : int) => if 256 <= i && i < 256 + 256 then aux.[i - 256] else r.[i]))%Array768; + aux <@ _poly_csubq((init (fun (i : int) => r.[2 * 256 + i]))%Array256); + r <- (init (fun (i : int) => if 2 * 256 <= i && i < 2 * 256 + 256 then aux.[i - 2 * 256] else r.[i]))%Array768; + + return r; + } + + +proc __i_polyvec_compress(rp : W8.t Array960.t, a : W16.t Array768.t) : W8.t Array960.t = { + var aux : int; + var i : int; + var j : int; + var aa : W16.t Array768.t; + var k : int; + var t : W64.t Array4.t; + var c : W16.t; + var b : W16.t; + + aa <- witness; + t <- witness; + i <- 0; + j <- 0; + aa <@ __polyvec_csubq(a); + while (i < (3 * 256 - 3)){ + k <- 0; + while (k < 4){ + t.[k] <- zeroextu64 aa.[i+k]; + t.[k] <- t.[k] `<<` (of_int 10)%W8; + t.[k] <- t.[k] + (of_int 1665)%W64; + t.[k] <- t.[k] * (of_int 1290167)%W64; + t.[k] <- t.[k] `>>` (of_int 32)%W8; + t.[k] <- t.[k] `&` (of_int 1023)%W64; + k <- k + 1; + } + c <- truncateu16 t.[0]; + c <- c `&` (of_int 255)%W16; + rp.[j] <- truncateu8 c; + j <- j + 1; + b <- truncateu16 t.[0]; + b <- b `>>` (of_int 8)%W8; + c <- truncateu16 t.[1]; + c <- c `<<` (of_int 2)%W8; + c <- c `|` b; + rp.[j] <- truncateu8 c; + j <- j + 1; + b <- truncateu16 t.[1]; + b <- b `>>` (of_int 6)%W8; + c <- truncateu16 t.[2]; + c <- c `<<` (of_int 4)%W8; + c <- c `|` b; + rp.[j] <- truncateu8 c; + j <- j + 1; + b <- truncateu16 t.[2]; + b <- b `>>` (of_int 4)%W8; + c <- truncateu16 t.[3]; + c <- c `<<` (of_int 6)%W8; + c <- c `|` b; + rp.[j] <- truncateu8 c; + j <- j + 1; + t.[3] <- t.[3] `>>` (of_int 2)%W8; + rp.[j] <- truncateu8 t.[3]; + j <- j + 1; + i <- i + 4; + } + + return rp; + } + +}. + + +equiv auxcsubq : Jkem.M(Jkem.Syscall)._poly_csubq ~ Aux(Jkem.Syscall)._poly_csubq : ={arg} ==> ={res}. +proc. +while(={rp} /\ to_uint i{1} = i{2} /\ 0 <= i{2} <= 256). ++ auto => /> &1;rewrite !ultE /= => ????. + rewrite !to_uintD_small /=; smt(). +by auto. +qed. + +equiv auxcsubqv : Jkem.M(Jkem.Syscall).__polyvec_csubq ~ Aux(Jkem.Syscall).__polyvec_csubq : ={arg} ==> ={res}. +proc. +do 3!(wp;call auxcsubq). +by auto => />. +qed. + +equiv auxcompress : Jkem.M(Jkem.Syscall).__i_polyvec_compress ~ Aux(Jkem.Syscall).__i_polyvec_compress : ={arg} ==> ={res}. +proc. +while (0<=i{2}<=768 /\ to_uint i{1} = i{2} /\ to_uint j{1} = j{2} /\ to_uint j{1} * 4 = i{2} * 5 /\ ={rp,aa}). ++ unroll for {1} 2; unroll for {2} 2; auto => /> &1. + by rewrite !ultE /= => &2????;rewrite !to_uintD_small /=; smt(). +by call auxcsubqv;auto => />. +qed. + +equiv compressequivvec_1 mem : + Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_compress_1 ~ Jkem.M(Jkem.Syscall).__i_polyvec_compress : + pos_bound768_cxq a{1} 0 768 2 /\ + pos_bound768_cxq a{2} 0 768 2 /\ + lift_array768 a{1} = lift_array768 a{2} /\ + ={Glob.mem} /\ Glob.mem{1} = mem + ==> + ={Glob.mem,res} /\ Glob.mem{1} = mem. +proof. +proc. + transitivity MLKEM_PolyVec_avx2_prevec.Mprevec.polyvec_compress_1 + (={rp, a, Glob.mem} ==> ={res, Glob.mem}) + (pos_bound768_cxq a{1} 0 768 2 /\ + pos_bound768_cxq a{2} 0 768 2 /\ + lift_array768 a{1} = lift_array768 a{2} /\ ={Glob.mem} /\ Glob.mem{1} = mem + ==> + ={Glob.mem, res} /\ Glob.mem{2} = mem); 1,2: smt(). + + symmetry. proc * => /=. call prevec_eq_polyvec_compress_1 => //=. + transitivity EncDec_AVX2.encode10_opt_vec + (a{2} = compress_polyvec 10 (lift_polyvec a{1}) /\ + pos_bound768_cxq a{1} 0 768 2 /\ (forall i, 0<=i<768 => 0 <= a{2}.[i] < q) /\ + Glob.mem{1} = mem /\ ={Glob.mem} ==> + Glob.mem{1} = mem /\ + ={res}) + (pos_bound768_cxq a{2} 0 768 2 /\ (forall i, 0<=i<768 => 0 <= a{1}.[i] < q) /\ + a{1} = compress_polyvec 10 (lift_polyvec a{2}) /\ + Glob.mem{2} = mem /\ ={Glob.mem} ==> + Glob.mem{2} = mem /\ + ={res}). + auto => &1 &2 [#] pos_bound_al pos_bound_ar al_eq_ar />. + exists Glob.mem{2}. + exists (compress_polyvec 10 (lift_polyvec a{1})). + rewrite pos_bound_al pos_bound_ar /=. + do split. + + move => i ib; rewrite /compress_polyvec /lift_polyvec !mapiE //=. + pose x := fromarray256 _ _ _. + move : (compress_rng x.[i] 10 _) => //=; smt (qE). + + move => i ib; rewrite /compress_polyvec /lift_polyvec !mapiE //=. + pose x := fromarray256 _ _ _. + move : (compress_rng x.[i] 10 _) => //=; smt (qE). + + congr; rewrite /lift_polyvec KMatrix.Vector.eq_vectorP => i ib /=. + rewrite !KMatrix.Vector.offunvE /kvec //=. + rewrite /lift_array768 /subarray256 /lift_array256 tP => k kb. + rewrite !mapiE //= !initiE //=. + smt(@Array768). + + smt(). + + proc * => /=. + ecall (polyvec_compress_1_corr (lift_polyvec a{1}) mem) => //=. + symmetry. + transitivity EncDec.encode10_vec + (u{2} = compress_polyvec 10 (lift_polyvec a{1}) /\ + pos_bound768_cxq a{1} 0 768 2 /\ (forall i, 0<=i<768 => 0 <= u{2}.[i] < q) /\ + Glob.mem{1} = mem /\ ={Glob.mem} ==> + Glob.mem{1} = mem /\ ={res}) + ((forall i, 0<=i<768 => 0 <= u{1}.[i] < q) /\ + u{1} = a{2} /\ + ={Glob.mem} ==> ={Glob.mem, res}). + auto => &1 &2 [#] pos_bound_a a2_bnd a1_eq_a2 mem_eq />. + exists Glob.mem{1}. + exists (compress_polyvec 10 (lift_polyvec a{1})). + auto => />;do split. + + move => i ibl ibh; rewrite /compress_polyvec /lift_polyvec !mapiE //=. + pose x := fromarray256 _ _ _. + move : (compress_rng x.[i] 10 _) => //=; smt (qE). + + move => i ibl ibh; rewrite /compress_polyvec /lift_polyvec !mapiE //=. + pose x := fromarray256 _ _ _. + move : (compress_rng x.[i] 10 _) => //=; smt (qE). + + smt(). + smt(). + + proc * => /=. + ecall (MLKEM_PolyVec.i_polyvec_compress_corr (lift_array768 a{1})) => //=. + auto => /> &1 H H0. + + rewrite /compress_polyvec; congr. + rewrite /fromarray256 /lift_polyvec /lift_array768 tP => i ib /=. + rewrite !initiE //= !mapiE //= !getvE !KMatrix.Vector.offunvE //=. + rewrite /subarray256 /lift_array256 /=. + smt(@Array256). + symmetry. + proc * => /=. + call encode10_opt_corr. + auto => />. +qed. + + lemma polyvec_decompress_equiv mem _p : equiv [Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_decompress ~ Jkem.M(Jkem.Syscall).__polyvec_decompress : valid_ptr _p (3*320) /\ @@ -240,87 +644,6 @@ proof. auto => />. qed. -equiv compressequivvec_1 mem : - Jkem_avx2.M(Jkem_avx2.Syscall).__polyvec_compress_1 ~ Jkem.M(Jkem.Syscall).__i_polyvec_compress : - pos_bound768_cxq a{1} 0 768 2 /\ - pos_bound768_cxq a{2} 0 768 2 /\ - lift_array768 a{1} = lift_array768 a{2} /\ - ={Glob.mem} /\ Glob.mem{1} = mem - ==> - ={Glob.mem,res} /\ Glob.mem{1} = mem. -proof. - transitivity MLKEM_PolyVec_avx2_prevec.Mprevec.polyvec_compress_1 - (={rp, a, Glob.mem} ==> ={res, Glob.mem}) - (pos_bound768_cxq a{1} 0 768 2 /\ - pos_bound768_cxq a{2} 0 768 2 /\ - lift_array768 a{1} = lift_array768 a{2} /\ ={Glob.mem} /\ Glob.mem{1} = mem - ==> - ={Glob.mem, res} /\ Glob.mem{2} = mem); 1,2: smt(). - + symmetry. proc * => /=. call prevec_eq_polyvec_compress_1 => //=. - transitivity EncDec_AVX2.encode10_opt_vec - (a{2} = compress_polyvec 10 (lift_polyvec a{1}) /\ - pos_bound768_cxq a{1} 0 768 2 /\ (forall i, 0<=i<768 => 0 <= a{2}.[i] < q) /\ - Glob.mem{1} = mem /\ ={Glob.mem} ==> - Glob.mem{1} = mem /\ - ={res}) - (pos_bound768_cxq a{2} 0 768 2 /\ (forall i, 0<=i<768 => 0 <= a{1}.[i] < q) /\ - a{1} = compress_polyvec 10 (lift_polyvec a{2}) /\ - Glob.mem{2} = mem /\ ={Glob.mem} ==> - Glob.mem{2} = mem /\ - ={res}). - auto => &1 &2 [#] pos_bound_al pos_bound_ar al_eq_ar />. - exists Glob.mem{2}. - exists (compress_polyvec 10 (lift_polyvec a{1})). - rewrite pos_bound_al pos_bound_ar /=. - do split. - + move => i ib; rewrite /compress_polyvec /lift_polyvec !mapiE //=. - pose x := fromarray256 _ _ _. - move : (compress_rng x.[i] 10 _) => //=; smt (qE). - + move => i ib; rewrite /compress_polyvec /lift_polyvec !mapiE //=. - pose x := fromarray256 _ _ _. - move : (compress_rng x.[i] 10 _) => //=; smt (qE). - + congr; rewrite /lift_polyvec KMatrix.Vector.eq_vectorP => i ib /=. - rewrite !KMatrix.Vector.offunvE /kvec //=. - rewrite /lift_array768 /subarray256 /lift_array256 tP => k kb. - rewrite !mapiE //= !initiE //=. - smt(@Array768). - + smt(). - + proc * => /=. - ecall (polyvec_compress_1_corr (lift_polyvec a{1}) mem) => //=. - symmetry. - transitivity EncDec.encode10_vec - (u{2} = compress_polyvec 10 (lift_polyvec a{1}) /\ - pos_bound768_cxq a{1} 0 768 2 /\ (forall i, 0<=i<768 => 0 <= u{2}.[i] < q) /\ - Glob.mem{1} = mem /\ ={Glob.mem} ==> - Glob.mem{1} = mem /\ ={res}) - ((forall i, 0<=i<768 => 0 <= u{1}.[i] < q) /\ - u{1} = a{2} /\ - ={Glob.mem} ==> ={Glob.mem, res}). - auto => &1 &2 [#] pos_bound_a a2_bnd a1_eq_a2 mem_eq />. - exists Glob.mem{1}. - exists (compress_polyvec 10 (lift_polyvec a{1})). - auto => />;do split. - + move => i ibl ibh; rewrite /compress_polyvec /lift_polyvec !mapiE //=. - pose x := fromarray256 _ _ _. - move : (compress_rng x.[i] 10 _) => //=; smt (qE). - + move => i ibl ibh; rewrite /compress_polyvec /lift_polyvec !mapiE //=. - pose x := fromarray256 _ _ _. - move : (compress_rng x.[i] 10 _) => //=; smt (qE). - + smt(). - smt(). - + proc * => /=. - ecall (MLKEM_PolyVec.i_polyvec_compress_corr (lift_array768 a{1})) => //=. - auto => /> &1 H H0. - + rewrite /compress_polyvec; congr. - rewrite /fromarray256 /lift_polyvec /lift_array768 tP => i ib /=. - rewrite !initiE //= !mapiE //= !getvE !KMatrix.Vector.offunvE //=. - rewrite /subarray256 /lift_array256 /=. - smt(@Array256). - symmetry. - proc * => /=. - call encode10_opt_corr. - auto => />. -qed. lemma poly_decompress_equiv mem _p : equiv [Jkem_avx2.M(Jkem_avx2.Syscall)._poly_decompress ~ Jkem.M(Jkem.Syscall)._poly_decompress :