From 39a8c6ed58464c5634928f6a692cf9a7b5a543f5 Mon Sep 17 00:00:00 2001 From: Pierre-Yves Strub Date: Mon, 22 Jul 2024 09:46:13 +0200 Subject: [PATCH] MLKEM_avx2_encdec.ec: make some proofs faster --- proof/correctness/avx2/MLKEM_avx2_encdec.ec | 60 ++++++++++++++++----- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/proof/correctness/avx2/MLKEM_avx2_encdec.ec b/proof/correctness/avx2/MLKEM_avx2_encdec.ec index 4827ff4e..0ea7cfd2 100644 --- a/proof/correctness/avx2/MLKEM_avx2_encdec.ec +++ b/proof/correctness/avx2/MLKEM_avx2_encdec.ec @@ -358,6 +358,44 @@ have Hjm: (0 <= j < 0 + 8). rewrite Hj => />. move :Hjm. rewrite -mema_iota -iot rewrite !modzDml -addrA !modzDml -!addrA !modzDml -!addrA !modzDml !modz_dvd //=. do 8!(try (move => Hjm; case Hjm => />); first by smt()). qed. +lemma fill0 f k (t : ipoly): fill f k 0 t = t. +proof. +rewrite /fill &(Array256.tP) /= => i *. +by rewrite !initE /#. +qed. + +lemma fillSm f k n (t : ipoly): 0 < n => + fill f k n t = fill f (k+1) (n-1) t.[k <- f k]. +proof. +move=> ge0_n @/fill; apply/Array256.tP => i rg_i. +by rewrite !initE /= get_set_if /#. +qed. + +lemma fill_swap (i j ki kj : int) f g (t : ipoly) : + i + ki <= j => fill f i ki (fill g j kj t) = fill g j kj (fill f i ki t). +proof. +move=> rg @/fill; apply/Array256.tP => k *. +do rewrite initE /=. smt(). +qed. + +op [opaque] fillC ['a] = Array256.fill<:'a>. + +lemma fillCE ['a] : fillC<:'a> = Array256.fill. +proof. by rewrite /fillC. qed. + +lemma fillC_swap (i j ki kj : int) f g (t : ipoly) : + i + ki <= j => fillC f i ki (fillC g j kj t) = fillC g j kj (fillC f i ki t). +proof. by rewrite !fillCE &(fill_swap). qed. + +lemma fillC0 f k (t : ipoly): fillC f k 0 t = t. +proof. by rewrite fillCE &(fill0). qed. + +lemma fillCSm f k n (t : ipoly): 0 < n => + fillC f k n t = fillC f (k+1) (n - 1) t.[k <- f k]. +proof. by rewrite fillCE &(fillSm). qed. + +hint simplify fillC_swap. + equiv eq_decode1_opt : EncDec_AVX2.decode1_opt ~ EncDec.decode1 : ={a} ==> ={res}. @@ -365,10 +403,8 @@ proc. unroll for {1} ^while. do 8!(unroll for {1} ^while). unroll for {2} ^while. -auto => /> &m. -apply tP_red256. -move => i. -(* FIXME: TOO LONG => *) do 255!(move => Hi; case Hi => |>). +auto=> /= &1 &2 ->; rewrite -!fillCE /=. +by do 32! ((do 8! rewrite fillCSm 1://); rewrite /= fillC0). qed. equiv decode10_vec_corr: @@ -392,23 +428,23 @@ proof. unroll for {2} 2. wp; skip; auto => />. move => &1 &2 [#] k_lb k_ub c_eq k_tub />. - rewrite (mulzDr 320 _ _) (mulzDr 256 _ _) mulz1 //=. + rewrite (mulzDr 320 _ _) (mulzDr 256 _ _) mulz1 /=. split. + move : k_lb k_tub => /#. + move => j j_lb j_ub. - rewrite filliE 1:/# //=. - rewrite j_ub //=. + rewrite filliE 1:/# ~-1://=. + rewrite j_ub ~-1://=. case (j < 256 * k{2}) => j_tub. + have -> /=: !(256 * k{2} <= j). by rewrite -ltzNge j_tub. rewrite c_eq; first by rewrite j_lb j_tub. - do (rewrite Array768.set_neqiE 1:/#; first by move : j_tub j_lb => /#). - done. + rewrite !Array768.get_set_if ~-1://. + by have /#: k{2} = 1 \/ k{2} = 2 by smt(). + move : j_tub => /lezNgt j_tlb. rewrite j_tlb /=. have j_iota: j \in iota_ (256*k{2}) 256; first by rewrite mem_iota j_ub j_tlb. - move : j_iota. - do (rewrite Array768.get_setE 1:/#). - smt(mem_iota). + move : j_iota; rewrite !Array768.get_set_if ~-1://. + have: k{2} = 0 \/ k{2} = 1 \/ k{2} = 2 by smt(). + by rewrite -iotaredE /=; do 2! (move=> [#|] -> /=). auto => />. move => cL cR k k_tlb _ k_lb k_ub. have -> /=: k = 3. move : k_tlb k_ub => /#.