From 54e7069fca1dbfcbd746c645b6ab371c37cb9299 Mon Sep 17 00:00:00 2001 From: Manuel Barbosa Date: Wed, 11 Dec 2024 18:59:00 +0100 Subject: [PATCH] pending --- .../extraction/jkem_avx2_stack.ec | 63 +- code/jasmin/mlkem_avx2_stack/indcpa.jinc | 4 +- code/jasmin/mlkem_avx2_stack/jkem.s | 577 ++++++++-------- code/jasmin/mlkem_avx2_stack/kem.jinc | 2 +- code/jasmin/mlkem_avx2_stack/verify.jinc | 44 +- .../avx2/MLKEM_InnerPKE_avx2_stack.ec | 34 + .../correctness/avx2/MLKEM_KEM_avx2_stack.ec | 642 +++++++++++++++++- 7 files changed, 969 insertions(+), 397 deletions(-) create mode 100644 proof/correctness/avx2/MLKEM_InnerPKE_avx2_stack.ec diff --git a/code/jasmin/mlkem_avx2_stack/extraction/jkem_avx2_stack.ec b/code/jasmin/mlkem_avx2_stack/extraction/jkem_avx2_stack.ec index 8347222a..69b0d0c2 100644 --- a/code/jasmin/mlkem_avx2_stack/extraction/jkem_avx2_stack.ec +++ b/code/jasmin/mlkem_avx2_stack/extraction/jkem_avx2_stack.ec @@ -9643,7 +9643,7 @@ module M = { return ct; } proc __indcpa_dec (msgp:W8.t Array32.t, ct:W8.t Array1088.t, - sk:W8.t Array2400.t) : W8.t Array32.t = { + sk:W8.t Array1152.t) : W8.t Array32.t = { var bp:W16.t Array768.t; var v:W16.t Array256.t; var skpv:W16.t Array768.t; @@ -9657,7 +9657,7 @@ module M = { bp <@ __i_polyvec_decompress (ct); v <@ _i_poly_decompress (v, (Array128.init (fun i => ct.[((3 * 320) + i)]))); - skpv <@ __i_polyvec_frombytes ((Array1152.init (fun i => sk.[(0 + i)]))); + skpv <@ __i_polyvec_frombytes (sk); bp <@ __polyvec_ntt (bp); t <@ __polyvec_pointwise_acc (t, skpv, bp); t <@ _poly_invntt (t); @@ -9675,9 +9675,6 @@ module M = { var f:W256.t; var g:W256.t; var zf:bool; - var off:int; - var t1:W8.t; - var t2:W8.t; var _0:bool; var _1:bool; var _2:bool; @@ -9697,63 +9694,23 @@ module M = { } ( _0, _1, _2, _3, zf) <- (VPTEST_256 h h); cnd <- ((! zf) ? t64 : cnd); - off <- ((((3 * 320) + 128) %/ 32) * 32); - aux <- ((3 * 320) + 128); - i <- off; - while ((i < aux)) { - t1 <- (get8_direct (WArray1088.init8 (fun i_0 => ct.[i_0])) i); - t2 <- (get8_direct (WArray1088.init8 (fun i_0 => ctpc.[i_0])) i); - t1 <- (t1 `^` t2); - t64 <- (zeroextu64 t1); - cnd <- (cnd `|` t64); - i <- (i + 1); - } - cnd <- (- cnd); - cnd <- (cnd `>>` (W8.of_int 63)); return cnd; } proc __cmov (dst:W8.t Array32.t, src:W8.t Array32.t, cnd:W64.t) : W8.t Array32.t = { - var aux:int; var scnd:W64.t; var m:W256.t; - var i:int; var f:W256.t; var g:W256.t; - var off:int; - var bcond:W8.t; - var t2:W8.t; - var t1:W8.t; cnd <- (- cnd); scnd <- cnd; m <- (VPBROADCAST_4u64 scnd); - aux <- (32 %/ 32); - i <- 0; - while ((i < aux)) { - f <- (get256_direct (WArray32.init8 (fun i_0 => src.[i_0])) (32 * i)); - g <- (get256_direct (WArray32.init8 (fun i_0 => dst.[i_0])) (32 * i)); - f <- (VPBLENDVB_256 f g m); - dst <- - (Array32.init - (WArray32.get8 - (WArray32.set256_direct (WArray32.init8 (fun i_0 => dst.[i_0])) - (32 * i) f))); - i <- (i + 1); - } - off <- ((32 %/ 32) * 32); - bcond <- (truncateu8 cnd); - i <- off; - while ((i < 32)) { - t2 <- (get8_direct (WArray32.init8 (fun i_0 => dst.[i_0])) i); - t1 <- src.[i]; - t2 <- (t2 `^` t1); - t2 <- (t2 `&` (truncateu8 cnd)); - t1 <- (t1 `^` t2); - dst <- - (Array32.init - (WArray32.get8 - (WArray32.set8_direct (WArray32.init8 (fun i_0 => dst.[i_0])) i t1))); - i <- (i + 1); - } + f <- (get256_direct (WArray32.init8 (fun i => src.[i])) 0); + g <- (get256_direct (WArray32.init8 (fun i => dst.[i])) 0); + f <- (VPBLENDVB_256 f g m); + dst <- + (Array32.init + (WArray32.get8 + (WArray32.set256_direct (WArray32.init8 (fun i => dst.[i])) 0 f))); return dst; } proc __crypto_kem_keypair_jazz (pk:W8.t Array1184.t, sk:W8.t Array2400.t, @@ -9909,7 +9866,7 @@ module M = { ) i_0)) ); aux <@ __indcpa_dec ((Array32.init (fun i_0 => buf.[(0 + i_0)])), - ct, s_sk); + ct, (Array1152.init (fun i_0 => s_sk.[(0 + i_0)]))); buf <- (Array64.init (fun i_0 => (if (0 <= i_0 < (0 + 32)) then aux.[(i_0 - 0)] else buf.[i_0])) diff --git a/code/jasmin/mlkem_avx2_stack/indcpa.jinc b/code/jasmin/mlkem_avx2_stack/indcpa.jinc index 39745c3e..86d9fed2 100644 --- a/code/jasmin/mlkem_avx2_stack/indcpa.jinc +++ b/code/jasmin/mlkem_avx2_stack/indcpa.jinc @@ -143,7 +143,7 @@ fn __indcpa_enc( inline -fn __indcpa_dec(reg ptr u8[MLKEM_INDCPA_MSGBYTES] msgp, reg ptr u8[MLKEM_CIPHERTEXTBYTES] ct, reg ptr u8[MLKEM_SECRETKEYBYTES] sk) -> reg ptr u8[MLKEM_INDCPA_MSGBYTES] +fn __indcpa_dec(reg ptr u8[MLKEM_INDCPA_MSGBYTES] msgp, reg ptr u8[MLKEM_CIPHERTEXTBYTES] ct, reg ptr u8[MLKEM_POLYVECBYTES] sk) -> reg ptr u8[MLKEM_INDCPA_MSGBYTES] { stack u16[MLKEM_N] t v mp; stack u16[MLKEM_VECN] bp skpv; @@ -151,7 +151,7 @@ fn __indcpa_dec(reg ptr u8[MLKEM_INDCPA_MSGBYTES] msgp, reg ptr u8[MLKEM_CIPHERT bp = __i_polyvec_decompress(ct); v = _i_poly_decompress(v, ct[MLKEM_POLYVECCOMPRESSEDBYTES:MLKEM_POLYCOMPRESSEDBYTES]); - skpv = __i_polyvec_frombytes(sk[0:MLKEM_POLYVECBYTES]); + skpv = __i_polyvec_frombytes(sk); bp = __polyvec_ntt(bp); t = __polyvec_pointwise_acc(t, skpv, bp); diff --git a/code/jasmin/mlkem_avx2_stack/jkem.s b/code/jasmin/mlkem_avx2_stack/jkem.s index abc94c96..f04eb526 100644 --- a/code/jasmin/mlkem_avx2_stack/jkem.s +++ b/code/jasmin/mlkem_avx2_stack/jkem.s @@ -1895,294 +1895,293 @@ jade_kem_mlkem_mlkem768_amd64_avx2_dec: call L_i_poly_decompress$1 Ljade_kem_mlkem_mlkem768_amd64_avx2_dec$60: leaq 24(%rsp), %rsp - movq %rdx, %rax - movq (%rax), %rdx - movq %rdx, 5536(%rsp) - movq 8(%rax), %rdx - movq %rdx, 5544(%rsp) - movq 16(%rax), %rdx - movq %rdx, 5552(%rsp) - movq 24(%rax), %rdx - movq %rdx, 5560(%rsp) - movq 32(%rax), %rdx - movq %rdx, 5568(%rsp) - movq 40(%rax), %rdx - movq %rdx, 5576(%rsp) - movq 48(%rax), %rdx - movq %rdx, 5584(%rsp) - movq 56(%rax), %rdx - movq %rdx, 5592(%rsp) - movq 64(%rax), %rdx - movq %rdx, 5600(%rsp) - movq 72(%rax), %rdx - movq %rdx, 5608(%rsp) - movq 80(%rax), %rdx - movq %rdx, 5616(%rsp) - movq 88(%rax), %rdx - movq %rdx, 5624(%rsp) - movq 96(%rax), %rdx - movq %rdx, 5632(%rsp) - movq 104(%rax), %rdx - movq %rdx, 5640(%rsp) - movq 112(%rax), %rdx - movq %rdx, 5648(%rsp) - movq 120(%rax), %rdx - movq %rdx, 5656(%rsp) - movq 128(%rax), %rdx - movq %rdx, 5664(%rsp) - movq 136(%rax), %rdx - movq %rdx, 5672(%rsp) - movq 144(%rax), %rdx - movq %rdx, 5680(%rsp) - movq 152(%rax), %rdx - movq %rdx, 5688(%rsp) - movq 160(%rax), %rdx - movq %rdx, 5696(%rsp) - movq 168(%rax), %rdx - movq %rdx, 5704(%rsp) - movq 176(%rax), %rdx - movq %rdx, 5712(%rsp) - movq 184(%rax), %rdx - movq %rdx, 5720(%rsp) - movq 192(%rax), %rdx - movq %rdx, 5728(%rsp) - movq 200(%rax), %rdx - movq %rdx, 5736(%rsp) - movq 208(%rax), %rdx - movq %rdx, 5744(%rsp) - movq 216(%rax), %rdx - movq %rdx, 5752(%rsp) - movq 224(%rax), %rdx - movq %rdx, 5760(%rsp) - movq 232(%rax), %rdx - movq %rdx, 5768(%rsp) - movq 240(%rax), %rdx - movq %rdx, 5776(%rsp) - movq 248(%rax), %rdx - movq %rdx, 5784(%rsp) - movq 256(%rax), %rdx - movq %rdx, 5792(%rsp) - movq 264(%rax), %rdx - movq %rdx, 5800(%rsp) - movq 272(%rax), %rdx - movq %rdx, 5808(%rsp) - movq 280(%rax), %rdx - movq %rdx, 5816(%rsp) - movq 288(%rax), %rdx - movq %rdx, 5824(%rsp) - movq 296(%rax), %rdx - movq %rdx, 5832(%rsp) - movq 304(%rax), %rdx - movq %rdx, 5840(%rsp) - movq 312(%rax), %rdx - movq %rdx, 5848(%rsp) - movq 320(%rax), %rdx - movq %rdx, 5856(%rsp) - movq 328(%rax), %rdx - movq %rdx, 5864(%rsp) - movq 336(%rax), %rdx - movq %rdx, 5872(%rsp) - movq 344(%rax), %rdx - movq %rdx, 5880(%rsp) - movq 352(%rax), %rdx - movq %rdx, 5888(%rsp) - movq 360(%rax), %rdx - movq %rdx, 5896(%rsp) - movq 368(%rax), %rdx - movq %rdx, 5904(%rsp) - movq 376(%rax), %rdx - movq %rdx, 5912(%rsp) - movq 384(%rax), %rdx - movq %rdx, 5920(%rsp) - movq 392(%rax), %rdx - movq %rdx, 5928(%rsp) - movq 400(%rax), %rdx - movq %rdx, 5936(%rsp) - movq 408(%rax), %rdx - movq %rdx, 5944(%rsp) - movq 416(%rax), %rdx - movq %rdx, 5952(%rsp) - movq 424(%rax), %rdx - movq %rdx, 5960(%rsp) - movq 432(%rax), %rdx - movq %rdx, 5968(%rsp) - movq 440(%rax), %rdx - movq %rdx, 5976(%rsp) - movq 448(%rax), %rdx - movq %rdx, 5984(%rsp) - movq 456(%rax), %rdx - movq %rdx, 5992(%rsp) - movq 464(%rax), %rdx - movq %rdx, 6000(%rsp) - movq 472(%rax), %rdx - movq %rdx, 6008(%rsp) - movq 480(%rax), %rdx - movq %rdx, 6016(%rsp) - movq 488(%rax), %rdx - movq %rdx, 6024(%rsp) - movq 496(%rax), %rdx - movq %rdx, 6032(%rsp) - movq 504(%rax), %rdx - movq %rdx, 6040(%rsp) - movq 512(%rax), %rdx - movq %rdx, 6048(%rsp) - movq 520(%rax), %rdx - movq %rdx, 6056(%rsp) - movq 528(%rax), %rdx - movq %rdx, 6064(%rsp) - movq 536(%rax), %rdx - movq %rdx, 6072(%rsp) - movq 544(%rax), %rdx - movq %rdx, 6080(%rsp) - movq 552(%rax), %rdx - movq %rdx, 6088(%rsp) - movq 560(%rax), %rdx - movq %rdx, 6096(%rsp) - movq 568(%rax), %rdx - movq %rdx, 6104(%rsp) - movq 576(%rax), %rdx - movq %rdx, 6112(%rsp) - movq 584(%rax), %rdx - movq %rdx, 6120(%rsp) - movq 592(%rax), %rdx - movq %rdx, 6128(%rsp) - movq 600(%rax), %rdx - movq %rdx, 6136(%rsp) - movq 608(%rax), %rdx - movq %rdx, 6144(%rsp) - movq 616(%rax), %rdx - movq %rdx, 6152(%rsp) - movq 624(%rax), %rdx - movq %rdx, 6160(%rsp) - movq 632(%rax), %rdx - movq %rdx, 6168(%rsp) - movq 640(%rax), %rdx - movq %rdx, 6176(%rsp) - movq 648(%rax), %rdx - movq %rdx, 6184(%rsp) - movq 656(%rax), %rdx - movq %rdx, 6192(%rsp) - movq 664(%rax), %rdx - movq %rdx, 6200(%rsp) - movq 672(%rax), %rdx - movq %rdx, 6208(%rsp) - movq 680(%rax), %rdx - movq %rdx, 6216(%rsp) - movq 688(%rax), %rdx - movq %rdx, 6224(%rsp) - movq 696(%rax), %rdx - movq %rdx, 6232(%rsp) - movq 704(%rax), %rdx - movq %rdx, 6240(%rsp) - movq 712(%rax), %rdx - movq %rdx, 6248(%rsp) - movq 720(%rax), %rdx - movq %rdx, 6256(%rsp) - movq 728(%rax), %rdx - movq %rdx, 6264(%rsp) - movq 736(%rax), %rdx - movq %rdx, 6272(%rsp) - movq 744(%rax), %rdx - movq %rdx, 6280(%rsp) - movq 752(%rax), %rdx - movq %rdx, 6288(%rsp) - movq 760(%rax), %rdx - movq %rdx, 6296(%rsp) - movq 768(%rax), %rdx - movq %rdx, 6304(%rsp) - movq 776(%rax), %rdx - movq %rdx, 6312(%rsp) - movq 784(%rax), %rdx - movq %rdx, 6320(%rsp) - movq 792(%rax), %rdx - movq %rdx, 6328(%rsp) - movq 800(%rax), %rdx - movq %rdx, 6336(%rsp) - movq 808(%rax), %rdx - movq %rdx, 6344(%rsp) - movq 816(%rax), %rdx - movq %rdx, 6352(%rsp) - movq 824(%rax), %rdx - movq %rdx, 6360(%rsp) - movq 832(%rax), %rdx - movq %rdx, 6368(%rsp) - movq 840(%rax), %rdx - movq %rdx, 6376(%rsp) - movq 848(%rax), %rdx - movq %rdx, 6384(%rsp) - movq 856(%rax), %rdx - movq %rdx, 6392(%rsp) - movq 864(%rax), %rdx - movq %rdx, 6400(%rsp) - movq 872(%rax), %rdx - movq %rdx, 6408(%rsp) - movq 880(%rax), %rdx - movq %rdx, 6416(%rsp) - movq 888(%rax), %rdx - movq %rdx, 6424(%rsp) - movq 896(%rax), %rdx - movq %rdx, 6432(%rsp) - movq 904(%rax), %rdx - movq %rdx, 6440(%rsp) - movq 912(%rax), %rdx - movq %rdx, 6448(%rsp) - movq 920(%rax), %rdx - movq %rdx, 6456(%rsp) - movq 928(%rax), %rdx - movq %rdx, 6464(%rsp) - movq 936(%rax), %rdx - movq %rdx, 6472(%rsp) - movq 944(%rax), %rdx - movq %rdx, 6480(%rsp) - movq 952(%rax), %rdx - movq %rdx, 6488(%rsp) - movq 960(%rax), %rdx - movq %rdx, 6496(%rsp) - movq 968(%rax), %rdx - movq %rdx, 6504(%rsp) - movq 976(%rax), %rdx - movq %rdx, 6512(%rsp) - movq 984(%rax), %rdx - movq %rdx, 6520(%rsp) - movq 992(%rax), %rdx - movq %rdx, 6528(%rsp) - movq 1000(%rax), %rdx - movq %rdx, 6536(%rsp) - movq 1008(%rax), %rdx - movq %rdx, 6544(%rsp) - movq 1016(%rax), %rdx - movq %rdx, 6552(%rsp) - movq 1024(%rax), %rdx - movq %rdx, 6560(%rsp) - movq 1032(%rax), %rdx - movq %rdx, 6568(%rsp) - movq 1040(%rax), %rdx - movq %rdx, 6576(%rsp) - movq 1048(%rax), %rdx - movq %rdx, 6584(%rsp) - movq 1056(%rax), %rdx - movq %rdx, 6592(%rsp) - movq 1064(%rax), %rdx - movq %rdx, 6600(%rsp) - movq 1072(%rax), %rdx - movq %rdx, 6608(%rsp) - movq 1080(%rax), %rdx - movq %rdx, 6616(%rsp) - movq 1088(%rax), %rdx - movq %rdx, 6624(%rsp) - movq 1096(%rax), %rdx - movq %rdx, 6632(%rsp) - movq 1104(%rax), %rdx - movq %rdx, 6640(%rsp) - movq 1112(%rax), %rdx - movq %rdx, 6648(%rsp) - movq 1120(%rax), %rdx - movq %rdx, 6656(%rsp) - movq 1128(%rax), %rdx - movq %rdx, 6664(%rsp) - movq 1136(%rax), %rdx - movq %rdx, 6672(%rsp) - movq 1144(%rax), %rax + movq (%rdx), %rax + movq %rax, 5536(%rsp) + movq 8(%rdx), %rax + movq %rax, 5544(%rsp) + movq 16(%rdx), %rax + movq %rax, 5552(%rsp) + movq 24(%rdx), %rax + movq %rax, 5560(%rsp) + movq 32(%rdx), %rax + movq %rax, 5568(%rsp) + movq 40(%rdx), %rax + movq %rax, 5576(%rsp) + movq 48(%rdx), %rax + movq %rax, 5584(%rsp) + movq 56(%rdx), %rax + movq %rax, 5592(%rsp) + movq 64(%rdx), %rax + movq %rax, 5600(%rsp) + movq 72(%rdx), %rax + movq %rax, 5608(%rsp) + movq 80(%rdx), %rax + movq %rax, 5616(%rsp) + movq 88(%rdx), %rax + movq %rax, 5624(%rsp) + movq 96(%rdx), %rax + movq %rax, 5632(%rsp) + movq 104(%rdx), %rax + movq %rax, 5640(%rsp) + movq 112(%rdx), %rax + movq %rax, 5648(%rsp) + movq 120(%rdx), %rax + movq %rax, 5656(%rsp) + movq 128(%rdx), %rax + movq %rax, 5664(%rsp) + movq 136(%rdx), %rax + movq %rax, 5672(%rsp) + movq 144(%rdx), %rax + movq %rax, 5680(%rsp) + movq 152(%rdx), %rax + movq %rax, 5688(%rsp) + movq 160(%rdx), %rax + movq %rax, 5696(%rsp) + movq 168(%rdx), %rax + movq %rax, 5704(%rsp) + movq 176(%rdx), %rax + movq %rax, 5712(%rsp) + movq 184(%rdx), %rax + movq %rax, 5720(%rsp) + movq 192(%rdx), %rax + movq %rax, 5728(%rsp) + movq 200(%rdx), %rax + movq %rax, 5736(%rsp) + movq 208(%rdx), %rax + movq %rax, 5744(%rsp) + movq 216(%rdx), %rax + movq %rax, 5752(%rsp) + movq 224(%rdx), %rax + movq %rax, 5760(%rsp) + movq 232(%rdx), %rax + movq %rax, 5768(%rsp) + movq 240(%rdx), %rax + movq %rax, 5776(%rsp) + movq 248(%rdx), %rax + movq %rax, 5784(%rsp) + movq 256(%rdx), %rax + movq %rax, 5792(%rsp) + movq 264(%rdx), %rax + movq %rax, 5800(%rsp) + movq 272(%rdx), %rax + movq %rax, 5808(%rsp) + movq 280(%rdx), %rax + movq %rax, 5816(%rsp) + movq 288(%rdx), %rax + movq %rax, 5824(%rsp) + movq 296(%rdx), %rax + movq %rax, 5832(%rsp) + movq 304(%rdx), %rax + movq %rax, 5840(%rsp) + movq 312(%rdx), %rax + movq %rax, 5848(%rsp) + movq 320(%rdx), %rax + movq %rax, 5856(%rsp) + movq 328(%rdx), %rax + movq %rax, 5864(%rsp) + movq 336(%rdx), %rax + movq %rax, 5872(%rsp) + movq 344(%rdx), %rax + movq %rax, 5880(%rsp) + movq 352(%rdx), %rax + movq %rax, 5888(%rsp) + movq 360(%rdx), %rax + movq %rax, 5896(%rsp) + movq 368(%rdx), %rax + movq %rax, 5904(%rsp) + movq 376(%rdx), %rax + movq %rax, 5912(%rsp) + movq 384(%rdx), %rax + movq %rax, 5920(%rsp) + movq 392(%rdx), %rax + movq %rax, 5928(%rsp) + movq 400(%rdx), %rax + movq %rax, 5936(%rsp) + movq 408(%rdx), %rax + movq %rax, 5944(%rsp) + movq 416(%rdx), %rax + movq %rax, 5952(%rsp) + movq 424(%rdx), %rax + movq %rax, 5960(%rsp) + movq 432(%rdx), %rax + movq %rax, 5968(%rsp) + movq 440(%rdx), %rax + movq %rax, 5976(%rsp) + movq 448(%rdx), %rax + movq %rax, 5984(%rsp) + movq 456(%rdx), %rax + movq %rax, 5992(%rsp) + movq 464(%rdx), %rax + movq %rax, 6000(%rsp) + movq 472(%rdx), %rax + movq %rax, 6008(%rsp) + movq 480(%rdx), %rax + movq %rax, 6016(%rsp) + movq 488(%rdx), %rax + movq %rax, 6024(%rsp) + movq 496(%rdx), %rax + movq %rax, 6032(%rsp) + movq 504(%rdx), %rax + movq %rax, 6040(%rsp) + movq 512(%rdx), %rax + movq %rax, 6048(%rsp) + movq 520(%rdx), %rax + movq %rax, 6056(%rsp) + movq 528(%rdx), %rax + movq %rax, 6064(%rsp) + movq 536(%rdx), %rax + movq %rax, 6072(%rsp) + movq 544(%rdx), %rax + movq %rax, 6080(%rsp) + movq 552(%rdx), %rax + movq %rax, 6088(%rsp) + movq 560(%rdx), %rax + movq %rax, 6096(%rsp) + movq 568(%rdx), %rax + movq %rax, 6104(%rsp) + movq 576(%rdx), %rax + movq %rax, 6112(%rsp) + movq 584(%rdx), %rax + movq %rax, 6120(%rsp) + movq 592(%rdx), %rax + movq %rax, 6128(%rsp) + movq 600(%rdx), %rax + movq %rax, 6136(%rsp) + movq 608(%rdx), %rax + movq %rax, 6144(%rsp) + movq 616(%rdx), %rax + movq %rax, 6152(%rsp) + movq 624(%rdx), %rax + movq %rax, 6160(%rsp) + movq 632(%rdx), %rax + movq %rax, 6168(%rsp) + movq 640(%rdx), %rax + movq %rax, 6176(%rsp) + movq 648(%rdx), %rax + movq %rax, 6184(%rsp) + movq 656(%rdx), %rax + movq %rax, 6192(%rsp) + movq 664(%rdx), %rax + movq %rax, 6200(%rsp) + movq 672(%rdx), %rax + movq %rax, 6208(%rsp) + movq 680(%rdx), %rax + movq %rax, 6216(%rsp) + movq 688(%rdx), %rax + movq %rax, 6224(%rsp) + movq 696(%rdx), %rax + movq %rax, 6232(%rsp) + movq 704(%rdx), %rax + movq %rax, 6240(%rsp) + movq 712(%rdx), %rax + movq %rax, 6248(%rsp) + movq 720(%rdx), %rax + movq %rax, 6256(%rsp) + movq 728(%rdx), %rax + movq %rax, 6264(%rsp) + movq 736(%rdx), %rax + movq %rax, 6272(%rsp) + movq 744(%rdx), %rax + movq %rax, 6280(%rsp) + movq 752(%rdx), %rax + movq %rax, 6288(%rsp) + movq 760(%rdx), %rax + movq %rax, 6296(%rsp) + movq 768(%rdx), %rax + movq %rax, 6304(%rsp) + movq 776(%rdx), %rax + movq %rax, 6312(%rsp) + movq 784(%rdx), %rax + movq %rax, 6320(%rsp) + movq 792(%rdx), %rax + movq %rax, 6328(%rsp) + movq 800(%rdx), %rax + movq %rax, 6336(%rsp) + movq 808(%rdx), %rax + movq %rax, 6344(%rsp) + movq 816(%rdx), %rax + movq %rax, 6352(%rsp) + movq 824(%rdx), %rax + movq %rax, 6360(%rsp) + movq 832(%rdx), %rax + movq %rax, 6368(%rsp) + movq 840(%rdx), %rax + movq %rax, 6376(%rsp) + movq 848(%rdx), %rax + movq %rax, 6384(%rsp) + movq 856(%rdx), %rax + movq %rax, 6392(%rsp) + movq 864(%rdx), %rax + movq %rax, 6400(%rsp) + movq 872(%rdx), %rax + movq %rax, 6408(%rsp) + movq 880(%rdx), %rax + movq %rax, 6416(%rsp) + movq 888(%rdx), %rax + movq %rax, 6424(%rsp) + movq 896(%rdx), %rax + movq %rax, 6432(%rsp) + movq 904(%rdx), %rax + movq %rax, 6440(%rsp) + movq 912(%rdx), %rax + movq %rax, 6448(%rsp) + movq 920(%rdx), %rax + movq %rax, 6456(%rsp) + movq 928(%rdx), %rax + movq %rax, 6464(%rsp) + movq 936(%rdx), %rax + movq %rax, 6472(%rsp) + movq 944(%rdx), %rax + movq %rax, 6480(%rsp) + movq 952(%rdx), %rax + movq %rax, 6488(%rsp) + movq 960(%rdx), %rax + movq %rax, 6496(%rsp) + movq 968(%rdx), %rax + movq %rax, 6504(%rsp) + movq 976(%rdx), %rax + movq %rax, 6512(%rsp) + movq 984(%rdx), %rax + movq %rax, 6520(%rsp) + movq 992(%rdx), %rax + movq %rax, 6528(%rsp) + movq 1000(%rdx), %rax + movq %rax, 6536(%rsp) + movq 1008(%rdx), %rax + movq %rax, 6544(%rsp) + movq 1016(%rdx), %rax + movq %rax, 6552(%rsp) + movq 1024(%rdx), %rax + movq %rax, 6560(%rsp) + movq 1032(%rdx), %rax + movq %rax, 6568(%rsp) + movq 1040(%rdx), %rax + movq %rax, 6576(%rsp) + movq 1048(%rdx), %rax + movq %rax, 6584(%rsp) + movq 1056(%rdx), %rax + movq %rax, 6592(%rsp) + movq 1064(%rdx), %rax + movq %rax, 6600(%rsp) + movq 1072(%rdx), %rax + movq %rax, 6608(%rsp) + movq 1080(%rdx), %rax + movq %rax, 6616(%rsp) + movq 1088(%rdx), %rax + movq %rax, 6624(%rsp) + movq 1096(%rdx), %rax + movq %rax, 6632(%rsp) + movq 1104(%rdx), %rax + movq %rax, 6640(%rsp) + movq 1112(%rdx), %rax + movq %rax, 6648(%rsp) + movq 1120(%rdx), %rax + movq %rax, 6656(%rsp) + movq 1128(%rdx), %rax + movq %rax, 6664(%rsp) + movq 1136(%rdx), %rax + movq %rax, 6672(%rsp) + movq 1144(%rdx), %rax movq %rax, 6680(%rsp) leaq 8224(%rsp), %rdx leaq 5536(%rsp), %rdi @@ -4348,8 +4347,6 @@ Ljade_kem_mlkem_mlkem768_amd64_avx2_dec$2: vpor %ymm1, %ymm0, %ymm0 vptest %ymm0, %ymm0 cmovne %rsi, %rdx - negq %rdx - shrq $63, %rdx movq %rdx, %mm0 movq 17448(%rsp), %rax movq 22224(%rsp), %rcx diff --git a/code/jasmin/mlkem_avx2_stack/kem.jinc b/code/jasmin/mlkem_avx2_stack/kem.jinc index 6d0501fc..dd4d517e 100644 --- a/code/jasmin/mlkem_avx2_stack/kem.jinc +++ b/code/jasmin/mlkem_avx2_stack/kem.jinc @@ -112,7 +112,7 @@ fn __crypto_kem_dec_jazz(reg ptr u8[MLKEM_SYMBYTES] shk, reg ptr u8[MLKEM_CIPHER s_ct = ct; s_sk = #copy(sk); - buf[0:MLKEM_INDCPA_MSGBYTES] = __indcpa_dec(buf[0:MLKEM_INDCPA_MSGBYTES], ct, s_sk); + buf[0:MLKEM_INDCPA_MSGBYTES] = __indcpa_dec(buf[0:MLKEM_INDCPA_MSGBYTES], ct, s_sk[0:MLKEM_POLYVECBYTES]); // we do not need to do pointer arithmetic here diff --git a/code/jasmin/mlkem_avx2_stack/verify.jinc b/code/jasmin/mlkem_avx2_stack/verify.jinc index ead3e682..c4e98eab 100644 --- a/code/jasmin/mlkem_avx2_stack/verify.jinc +++ b/code/jasmin/mlkem_avx2_stack/verify.jinc @@ -5,9 +5,8 @@ fn __verify(reg ptr u8[MLKEM_INDCPA_CIPHERTEXTBYTES] ct, reg ptr u8[MLKEM_INDCPA { reg u256 f g h; reg u64 cnd t64; - reg u8 t1 t2; reg bool zf; - inline int i off; + inline int i; cnd = 0; t64 = 1; @@ -25,20 +24,6 @@ fn __verify(reg ptr u8[MLKEM_INDCPA_CIPHERTEXTBYTES] ct, reg ptr u8[MLKEM_INDCPA cnd = t64 if !zf; - off = MLKEM_INDCPA_CIPHERTEXTBYTES/32 * 32; - - for i=off to MLKEM_INDCPA_CIPHERTEXTBYTES - { - t1 = ct.[i]; - t2 = ctpc.[i]; - t1 ^= t2; - t64 = (64u)t1; - cnd |= t64; - } - - cnd = -cnd; - cnd >>= 63; - return cnd; } @@ -47,35 +32,16 @@ fn __cmov(reg ptr u8[MLKEM_SYMBYTES] dst, reg ptr u8[MLKEM_SYMBYTES] src, reg u6 { reg u256 f g m; stack u64 scnd; - reg u8 t1 t2 bcond; - inline int i off; cnd = -cnd; scnd = cnd; m = #VPBROADCAST_4u64(scnd); - for i=0 to MLKEM_SYMBYTES/32 - { - f = src.[u256 32*i]; - g = dst.[u256 32*i]; - f = #VPBLENDVB_256(f, g, m); - dst.[u256 32*i] = f; - } - - off = MLKEM_SYMBYTES/32 * 32; - - /* fixme: unused in 768, hence untested */ - bcond = (8u)cnd; - for i=off to MLKEM_SYMBYTES - { - t2 = dst.[i]; - t1 = src[i]; - t2 = t2 ^ t1; - t2 = t2 & cnd; - t1 ^= t2; - dst.[i] = t1; - } + f = src.[u256 0]; + g = dst.[u256 0]; + f = #VPBLENDVB_256(f, g, m); + dst.[u256 0] = f; return dst; } diff --git a/proof/correctness/avx2/MLKEM_InnerPKE_avx2_stack.ec b/proof/correctness/avx2/MLKEM_InnerPKE_avx2_stack.ec new file mode 100644 index 00000000..a8702489 --- /dev/null +++ b/proof/correctness/avx2/MLKEM_InnerPKE_avx2_stack.ec @@ -0,0 +1,34 @@ +require import AllCore IntDiv List. +require import Jkem_avx2_stack MLKEM. +from Jasmin require import JModel_x86. + +require import Array32 Array64 Array148 Array960 Array128 Array1088 Array1152 Array1184 Array2400. +require import Array8 Array300 WArray64 WArray1184 WArray2400. +import MLKEM InnerPKE. + +require import MLKEM_KEM_avx2. + +equiv mlkem_correct_kg_avx2_stack : +M.__indcpa_keypair ~ InnerPKE.InnerPKE.kg_derand : +arg{1}.`3 = arg{2} ==> + res{2}.`2 = Array1152.init ("_.[_]" res{1}.`2) /\ + res{2}.`1.`1 = (init (fun (i : int) => res{1}.`1.[i]))%Array1152 /\ + res{2}.`1.`2 = (init (fun (i : int) => res{1}.`1.[i + 1152]))%Array32. +admitted. + +equiv mlkem_correct_enc_avx2_stack : +M.__indcpa_enc ~ InnerPKE.InnerPKE.enc_derand : +arg{1}.`2 = arg{2}.`2 /\ arg{1}.`4 = arg{2}.`3 /\ + Array1152.init(fun i => arg{1}.`3.[i]) = arg{2}.`1.`1 /\ + Array32.init(fun i => arg{1}.`3.[i+1152]) = arg{2}.`1.`2 ==> + Array960.init(fun i => res{1}.[i]) = res{2}.`1 /\ + Array128.init(fun i => res{1}.[i+960]) = res{2}.`2. +admitted. + +equiv mlkem_correct_dec_avx2_stack : +M.__indcpa_dec ~ InnerPKE.InnerPKE.dec : +arg{1}.`3 = arg{2}.`1 /\ +Array960.init(fun i => arg{1}.`2.[i]) = arg{2}.`2.`1 /\ +Array128.init(fun i => arg{1}.`2.[i+960]) = arg{2}.`2.`2 ==> + ={res}. +admitted. diff --git a/proof/correctness/avx2/MLKEM_KEM_avx2_stack.ec b/proof/correctness/avx2/MLKEM_KEM_avx2_stack.ec index 18059445..524665b3 100644 --- a/proof/correctness/avx2/MLKEM_KEM_avx2_stack.ec +++ b/proof/correctness/avx2/MLKEM_KEM_avx2_stack.ec @@ -2,25 +2,322 @@ require import AllCore IntDiv List. require import Jkem_avx2_stack MLKEM. from Jasmin require import JModel_x86. -require import Array32 Array64 Array960 Array128 Array1088 Array1152 Array1184 Array2400. -import MLKEM. +require import Array32 Array64 Array136 Array148 Array960 Array128 Array1120 Array1088 Array1152 Array1184 Array2400. +require import Array4 Array8 Array300 WArray32 WArray64 WArray1088 WArray1120 WArray1184 WArray2400. +import MLKEM Symmetric. -print MLKEM. +require import MLKEM_KEM_avx2. +require import MLKEM_InnerPKE_avx2_stack. + +(* sha3 assumptions *) +axiom sha3_256A_M1184_ph_stack (inp : W8.t Array1184.t): + phoare [ Jkem_avx2_stack.M._sha3_256A_A1184 + : arg.`2 = inp + ==> + res = SHA3_256_1184_32 + (Array1152.init (fun i => inp.[i]), + Array32.init (fun i => inp.[1152+i]))] = 1%r. + +axiom sha3_512A_512A_A64_stack m hpk: + phoare [ Jkem_avx2_stack.M._sha3_512A_A64 + : + m = Array32.init (fun i => arg.`2.[i]) /\ + hpk = Array32.init (fun i => arg.`2.[i+32]) + ==> + let (_K,r) = SHA3_512_64_64 m hpk in + _K = Array32.init (fun i => res.[i]) /\ + r = Array32.init (fun i => res.[i+32]) ] = 1%r. + +axiom shake256_A32_A1120_ph_stack (_in : W8.t Array1120.t): + phoare[ Jkem_avx2_stack.M._shake256_A32__A1120 : + arg.`2=_in ==> + res = + SHAKE_256_1120_32 + (init (fun (k : int) => _in.[k]))%Array32 + ((init (fun (k : int) => _in.[32 + k]))%Array960, + (init (fun (k : int) => _in.[960 + 32 + k]))%Array128)] = 1%r. + +lemma copy32 (a : W8.t Array32.t) : +Array32.init (fun (i : int) => WArray32.get8 (WArray32.init64 + (fun (i0 : int) => copy_64 (Array4.init + (fun (i1 : int) => WArray32.get64 (WArray32.init8 + (fun (i2 : int) => a.[i2])) i1)).[i0])) i) = a. +proof. +rewrite tP => k kb. +rewrite initiE 1:/# /= /get8. +rewrite initiE 1:/# /= /copy_64. +rewrite initiE 1:/# /= /get64_direct. +rewrite W8u8.pack8bE 1:/# /=. +rewrite initiE 1:/# /=. +rewrite initiE /#. +qed. + +lemma copy64 (a : W8.t Array64.t) : +Array64.init (fun (i : int) => WArray64.get8 (WArray64.init64 + (fun (i0 : int) => copy_64 (Array8.init + (fun (i1 : int) => WArray64.get64 (WArray64.init8 + (fun (i2 : int) => a.[i2])) i1)).[i0])) i) = a. +proof. +rewrite tP => k kb. +rewrite initiE 1:/# /= /get8. +rewrite initiE 1:/# /= /copy_64. +rewrite initiE 1:/# /= /get64_direct. +rewrite W8u8.pack8bE 1:/# /=. +rewrite initiE 1:/# /=. +rewrite initiE /#. +qed. + +lemma copy1088 (a : W8.t Array1088.t) : +Array1088.init (fun (i : int) => WArray1088.get8 (WArray1088.init64 + (fun (i0 : int) => copy_64 (Array136.init + (fun (i1 : int) => WArray1088.get64 (WArray1088.init8 + (fun (i2 : int) => a.[i2])) i1)).[i0])) i) = a. +rewrite tP => k kb. +rewrite initiE 1:/# /= /get8. +rewrite initiE 1:/# /= /copy_64. +rewrite initiE 1:/# /= /get64_direct. +rewrite W8u8.pack8bE 1:/# /=. +rewrite initiE 1:/# /=. +rewrite initiE /#. +qed. + +lemma copy1184 (a : W8.t Array1184.t) : +Array1184.init (fun (i : int) => WArray1184.get8 (WArray1184.init64 + (fun (i0 : int) => copy_64 (Array148.init + (fun (i1 : int) => WArray1184.get64 (WArray1184.init8 + (fun (i2 : int) => a.[i2])) i1)).[i0])) i) = a. +rewrite tP => k kb. +rewrite initiE 1:/# /= /get8. +rewrite initiE 1:/# /= /copy_64. +rewrite initiE 1:/# /= /get64_direct. +rewrite W8u8.pack8bE 1:/# /=. +rewrite initiE 1:/# /=. +rewrite initiE /#. +qed. + +lemma copy2400 (a : W8.t Array2400.t) : +Array2400.init (fun (i : int) => WArray2400.get8 (WArray2400.init64 + (fun (i0 : int) => copy_64 (Array300.init + (fun (i1 : int) => WArray2400.get64 (WArray2400.init8 + (fun (i2 : int) => a.[i2])) i1)).[i0])) i) = a. +rewrite tP => k kb. +rewrite initiE 1:/# /= /get8. +rewrite initiE 1:/# /= /copy_64. +rewrite initiE 1:/# /= /get64_direct. +rewrite W8u8.pack8bE 1:/# /=. +rewrite initiE 1:/# /=. +rewrite initiE /#. +qed. + +(* lemma mlkem_kem_correct_kg : equiv [Jkem_avx2_stack.M.jade_kem_mlkem_mlkem768_amd64_avx2_keypair_derand ~ MLKEM.kg_derand : coins{2}.`1 = Array32.init(fun i => coins{1}.[i]) /\ coins{2}.`2 = Array32.init(fun i => coins{1}.[32 + i]) ==> let (pk,sk) = res{2} in let (t,rho) = pk in - sk.`1 = Array1152.init(fun i => res{1}.`1.[i]) /\ - sk.`2.`1 = Array1152.init(fun i => res{1}.`1.[i+1152]) /\ - sk.`2.`2 = Array32.init(fun i => res{1}.`1.[i+1152+1152]) /\ - sk.`3 = Array32.init(fun i => res{1}.`1.[i+1152+1152 + 32]) /\ - sk.`4 = Array32.init(fun i => res{1}.`1.[i+1152+1152 + 32 + 32]) /\ - t = Array1152.init(fun i => res{1}.`2.[i]) /\ - rho = Array32.init(fun i => res{1}.`2.[i+1152])]. -admitted. + sk.`1 = Array1152.init(fun i => res{1}.`2.[i]) /\ + sk.`2.`1 = Array1152.init(fun i => res{1}.`2.[i+1152]) /\ + sk.`2.`2 = Array32.init(fun i => res{1}.`2.[i+1152+1152]) /\ + sk.`3 = Array32.init(fun i => res{1}.`2.[i+1152+1152 + 32]) /\ + sk.`4 = Array32.init(fun i => res{1}.`2.[i+1152+1152 + 32 + 32]) /\ + t = Array1152.init(fun i => res{1}.`1.[i]) /\ + rho = Array32.init(fun i => res{1}.`1.[i+1152])]. +proof. +proc => /=. +proc rewrite {1} 8 (copy64). +proc rewrite {1} 16 (copy1184). +proc rewrite {1} 17 (copy2400). +sp;wp;inline {1} 1. + +sp 3 0; seq 6 0 : #pre;1: by auto. +wp;sp 1 0. + +seq 11 1 : ( + z{2} = coins{2}.`2 /\ + coins{2}.`2 = (init (fun (i0 : int) => s_randomnessp{1}.[32 + i0]))%Array32 /\ + sk{2} = Array1152.init (fun i => sk0{1}.[i]) /\ + pk{2}.`1 = Array1152.init (fun i => sk0{1}.[i+1152]) /\ + pk{2}.`2 = Array32.init (fun i => sk0{1}.[i + 1152 + 1152]) /\ + H_pk pk{2} = Array32.init (fun i => sk0{1}.[i + 1152 + 1152 + 32]) /\ + pk{2}.`1 = Array1152.init (fun i => pk0{1}.[i]) /\ + pk{2}.`2 = Array32.init (fun i => pk0{1}.[i + 1152]) +); last first. ++ while {1} (aux{1} = 4 /\ + coins{2}.`2 = randomnessp2{1} /\ + sk{2} = Array1152.init (fun i => sk0{1}.[i]) /\ + pk{2}.`1 = Array1152.init (fun i => sk0{1}.[i+1152]) /\ + pk{2}.`2 = Array32.init (fun i => sk0{1}.[i + 1152 + 1152]) /\ + H_pk pk{2} = Array32.init (fun i => sk0{1}.[i + 1152 + 1152 + 32]) /\ + pk{2}.`1 = Array1152.init (fun i => pk0{1}.[i]) /\ + pk{2}.`2 = Array32.init (fun i => pk0{1}.[i + 1152]) /\ + 0 <= i{1} <= 4 /\ + forall k, 0<=k + WArray2400.get64_direct (WArray2400.init + (fun ii => sk0{1}.[ii])) (2368 + k*8) = + pack8_t (W8u8.Pack.init (fun ii => coins{2}.`2.[k*8+ii]))) + (4 - i{1}). + + move => &m z0; auto => /> &hr; rewrite !tP => + ???????prev?. + do split;5..6,8..:smt(). + + move => i ib;rewrite initiE 1:/# initiE 1:/# /= initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((296 + i{!hr}) * 8 <= i && i < (296 + i{!hr}) * 8 + 8 ) => *. + + by rewrite WArray32.get64E pack8bE 1:/# initiE 1:/# /= initiE /#. + by rewrite /get8 initiE /#. + + move => i ib;rewrite initiE 1:/# /= initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((296 + i{!hr}) * 8 <= i + 1152 && i + 1152 < (296 + i{!hr}) * 8 + 8) => *. + + by rewrite WArray32.get64E pack8bE 1:/# initiE 1:/# /= initiE /#. + by rewrite /get8 initiE 1:/# /=;smt(Array1152.initiE). + + move => i ib;rewrite initiE 1:/# /= initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((296 + i{!hr}) * 8 <= i + 2304 && i + 2304 < (296 + i{!hr}) * 8 + 8) => *. + + by rewrite WArray32.get64E pack8bE 1:/# initiE 1:/# /= initiE /#. + by rewrite /get8 initiE 1:/# /=; smt(Array32.initiE). + + move => i ib; pose xx := H_pk pk{m}. + rewrite initiE 1:/# /= initiE 1:/# /= initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((296 + i{!hr}) * 8 <= i + 2336 && i + 2336 < (296 + i{!hr}) * 8 + 8) => *. + + by rewrite WArray32.get64E pack8bE 1:/# initiE 1:/# /= initiE /#. + by rewrite /get8 initiE; smt(Array32.initiE). + + move => ii iibl iibh. + rewrite /get64_direct (inj_eq W8u8.pack8_t pack_inj) packP => k kh. + rewrite initiE 1:/# /= initiE 1:/# /= initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ( (296 + i{hr}) * 8 <= 2368 + ii * 8 + k && 2368 + ii * 8 + k < (296 + i{hr}) * 8 + 8) => *. + + rewrite pack8bE 1:/# initiE 1:/# /= initiE 1: /# /= initiE 1:/# /=. + by smt(Array32.initiE). + move : (prev ii _);1:by smt(). + rewrite /get64_direct (inj_eq W8u8.pack8_t pack_inj) packP => Hii. + rewrite -Hii 1:/#. + by rewrite /get8 initiE 1:/# /= initiE 1:/# /= initiE /#. + +auto =>/> &1 &2 *;do split. + + move => ii iibl iibh. + rewrite /get64_direct (inj_eq W8u8.pack8_t pack_inj) packP => k kh. + by rewrite initiE 1:/# /= initiE 1:/# /= initiE /#. + move => i1 sk0; do split =>???????H;1:smt(). + + move => rho t?; rewrite tP => k kb. + move : (H (k %/ 8) _);1:smt(). + rewrite /get64_direct (inj_eq W8u8.pack8_t pack_inj) packP => Hkk. + move : (Hkk (k %% 8) _); 1: smt(). + rewrite initiE 1:/# /= initiE 1:/# /= initiE 1:/# /=. + + by rewrite initiE /#. + +wp;conseq (_: _ ==> + sk{2} = Array1152.init (fun i => sk0{1}.[i]) /\ + pk{2}.`1 = Array1152.init (fun i => sk0{1}.[i+1152]) /\ + pk{2}.`2 = Array32.init (fun i => sk0{1}.[i + 1152 + 1152]) /\ + H_pk pk{2} = Array32.init (fun i => sk0{1}.[i + 1152 + 1152 + 32]) /\ + pk{2}.`1 = Array1152.init (fun i => pk0{1}.[i]) /\ + pk{2}.`2 = Array32.init (fun i => pk0{1}.[i + 1152]) +); 1: by auto. + +seq 2 1 : ( + sk{2} = Array1152.init (fun i => sk0{1}.[i]) /\ + pk{2}.`1 = Array1152.init (fun i => pk0{1}.[i]) /\ + pk{2}.`2 = Array32.init (fun i => pk0{1}.[i + 1152])). +call (mlkem_correct_kg_avx2_stack). +auto => /> /#. + +swap {1} 1 8. swap {1} 4 1. +seq 3 0 : (#pre /\ + sk{2} = Array1152.init (fun i => sk0{1}.[i]) /\ + pk{2}.`1 = Array1152.init (fun i => sk0{1}.[i+1152]) /\ + pk{2}.`2 = Array32.init (fun i => sk0{1}.[i + 1152 + 1152]) /\ + pk{2}.`1 = Array1152.init (fun i => pk0{1}.[i]) /\ + pk{2}.`2 = Array32.init (fun i => pk0{1}.[i + 1152]) +). + ++ wp;while {1} (#pre /\ + aux{1} = (3 * 384 + 32) %/ 8 /\ 0<=i{1} <= aux{1} /\ + (forall k, 0<= k < min (8 * i{1}) 1152 => + pk{2}.`1.[k] = sk0{1}.[1152+k]) /\ + (forall k, 1152 <= k < min (8 * i{1}) (1152 + 32) => + pk{2}.`2.[k-1152] = sk0{1}.[1152 + k])) + ((3 * 384 + 32) %/ 8 - i{1}). + move => &m z;auto => /> &hr; rewrite !tP. + move => pkv1 pkv2???prev1 prev2. + do split. + + move => i ib; rewrite initiE 1:/# /= initiE 1:/# /= initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((144 + i{!hr}) * 8 <= i && i < (144 + i{!hr}) * 8 + 8) => *. + + by rewrite WArray1184.get64E pack8bE 1:/# initiE 1:/# /= initiE /#. + by rewrite /get8 initiE; smt(Array32.initiE). + + by smt(). + + by smt(). + + move => i ib ibb; rewrite initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((144 + i{!hr}) * 8 <= 1152 + i && 1152 + i < (144 + i{!hr}) * 8 + 8) => *. + + rewrite WArray1184.get64E pack8bE 1:/# initiE 1:/# /=. + by rewrite /get8 initiE; smt(Array1152.initiE). + + by rewrite /get8 initiE; smt(Array1152.initiE). + + move => i ib ibb; rewrite initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ( (144 + i{!hr}) * 8 <= 1152 + i && 1152 + i < (144 + i{!hr}) * 8 + 8) => *. + + rewrite WArray1184.get64E pack8bE 1:/# initiE 1:/# /=. + by rewrite /get8 initiE 1:/#; smt(Array32.initiE). + + by rewrite /get8 initiE; smt(Array1152.initiE). + + by smt(). + + auto => /> &1 &2; rewrite !tP. + move => ??. + do split; 1..2: by smt(). + move => il skl. + rewrite !tP; split; 1: smt(). + move => ??????; do split. + + by move => *; rewrite initiE //= /#. + by move => *; rewrite initiE //= /#. + +seq 1 0 : (#pre /\ h_pk{1} = H_pk pk{2}). + +ecall {1} (sha3_256A_M1184_ph_stack pk0{1}); 1: by inline *; auto => />* /#. + +wp;while {1} (#pre /\ + 0 <= i{1} <= 4 /\ + forall k, 0 <= k < i{1} * 8 => sk0{1}.[2336 + k] = ((H_pk pk{2})).[k]) (4 - i{1}). +move => &m z; auto => /> &1 &2; rewrite !tP. +move => pkv1s pkv2s pkv1 pkv2 ? prev ?;do split. + + move => i ib; rewrite initiE 1:/# /= initiE 1:/# /= initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((292 + i{!1}) * 8 <= i && i < (292 + i{!1}) * 8 + 8) => *. + + by rewrite WArray32.get64E pack8bE 1:/# initiE 1:/# /= initiE /#. + by rewrite /get8 initiE; smt(Array2400.initiE). + + move => i ib; rewrite initiE 1:/# /= initiE 1:/# /= initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((292 + i{!1}) * 8 <= i && i < (292 + i{!1}) * 8 + 8) => *. + + by rewrite WArray32.get64E pack8bE 1:/# initiE 1:/# /= initiE /#. + by rewrite /get8 initiE; smt(Array2400.initiE). + + move => i ib; rewrite initiE 1:/# /= initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((292 + i{!1}) * 8 <= i + 1152 && i + 1152 < (292 + i{!1}) * 8 + 8) => *. + + by rewrite WArray32.get64E pack8bE 1:/# initiE 1:/# /= initiE /#. + by rewrite /get8 initiE; smt(Array1152.initiE). + + move => i ib; rewrite initiE 1:/# /= initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((292 + i{!1}) * 8 <= i + 2304 && i + 2304 < (292 + i{!1}) * 8 + 8) => *. + + by rewrite WArray32.get64E pack8bE 1:/# initiE 1:/# /= initiE /#. + by rewrite /get8 initiE; smt(Array32.initiE). + + by smt(). + + by smt(). + + move => i ib ibb; + rewrite initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ( (292 + i{!1}) * 8 <= 2336 + i && 2336 + i < (292 + i{!1}) * 8 + 8) => *. + + rewrite WArray32.get64E pack8bE 1:/# initiE 1:/#. + rewrite /get8 initiE 1:/# /= initiE 1:/# /=. + by rewrite /H_pk /SHA3_256_1184_32 /= get_of_list /#. + + by rewrite /get8 initiE; smt(Array1152.initiE). + + by smt(). + + +auto => /> &1 &2; rewrite !tP => *; split; 1: smt(). +move => i ib *;do split;1:smt(). +by move => *;rewrite tP => k kb; smt(Array32.initiE). +qed. + lemma mlkem_kem_correct_enc : equiv [Jkem_avx2_stack.M.jade_kem_mlkem_mlkem768_amd64_avx2_enc_derand ~ MLKEM.enc_derand: @@ -33,7 +330,212 @@ lemma mlkem_kem_correct_enc : c.`2 = Array128.init(fun i => res{1}.`1.[i+960]) /\ k = res{1}.`2 ]. -admitted. +proc => //=. +proc rewrite {1} 10 (copy1184). +proc rewrite {1} 11 (copy32). +proc rewrite {1} 21 (copy1088). +proc rewrite {1} 22 (copy32). +inline {1} 16. +swap {1} [9..19] 5. +seq 13 0 : #pre; 1: by auto. +sp 14 0;wp. +seq 9 4 : (#[/1:-2]post /\ + (forall k, 0<=k<32 => kr{1}.[k]=_K{2}.[k])); last first. ++ wp;conseq />. while {1} ( + aux{1} = 4 + /\ 0<=i{1}<=4 + /\ (forall k, 0<=k<32 => kr{1}.[k]=_K{2}.[k]) + /\ (forall k, 0<=k<8*i{1} => _K{2}.[k] = shk0{1}.[k])) (aux{1} - i{1}). + + auto => /> &hr H H0 H1 H2 H3;do split;1,2,4:smt(). + + move => k kb kbh; rewrite initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case (8 * i{hr} <= k && k < 8 * i{hr} + 8 ) => *. + + rewrite WArray64.get64E pack8bE 1:/# initiE 1:/# /= initiE /= 1:/#. by smt(). + by rewrite /get8 initiE; smt(Array32.initiE). + + auto => /> &1 &2;rewrite !tP => *;do split. + + move => *; smt(Array960.initiE). + + move => *; do split;1:smt(). + by move => *; smt(Array32.tP). + +wp; call (mlkem_correct_enc_avx2_stack). +wp; ecall {1} (sha3_512A_512A_A64_stack coins{2} (H_pk pk{2})). +wp; ecall {1} (sha3_256A_M1184_ph_stack pk0{1}). +seq 3 0 : (#pre /\ randomnessp{1} = Array32.init (fun i => buf{1}.[i])). ++ sp ; conseq />. + while {1} (0<=i{1}<=aux{1} /\ aux{1} = 4 /\ randomnessp{1} = coins{2} /\ (forall k, 0<=k randomnessp{1}.[k] = buf{1}.[k])) (aux{1} - i{1}); last first. + + auto => /> &1 &2 *; split; 1: by smt(). + move => buf i1; split; 1: smt(). + by move => *; rewrite tP => k kn; rewrite initiE //= /#. + move => &2 ?. + auto => /> &1 il ih premv ihh; do split; 1,2,4:smt(). + move => k kl kh; rewrite initiE 1:/# /=. + rewrite WArray64.get8_set64_directE 1..2:/#. + case (8 * i{1} <= k && k < 8 * i{1} + 8). + + move => *. + rewrite WArray32.get64E pack8bE 1:/# !initiE 1:/# /= /init8. + by rewrite !WArray32.initiE /#. + by move => *; rewrite /get8; rewrite WArray64.initiE /#. +auto => /> &1 &2. +move => [#] pkv1 pkv2; do split. ++ by rewrite tP => i ib; rewrite !initiE /= /#. ++ rewrite tP => i ib; rewrite initiE /= 1:/# initiE /= 1:/# /=. + rewrite initiE 1:/# /= initiE 1:/# /= ifT 1:/#. + by congr;congr;congr;congr;smt(). + +rewrite !tP => H H0 rr0 H1 H2;do split. ++ by move => i ib;rewrite -pkv1 initiE 1:/# /= initiE 1:/# /= ifF 1:/# initiE /#. ++ move : H2;rewrite !tP => H2 i ib. + move : (H2 i ib);rewrite !initiE 1..3:/# /= addrC => <-. + rewrite /G_mhpk /SHA3_512_64_64 /= get_of_list /#. ++ by smt(). ++ by smt(). +by move => ????r0 r1 ??;do split;smt(Array32.initiE). +qed. +*) + +lemma verify_correct_h_stack _ctp _ctp1 : + hoare [Jkem_avx2_stack.M.__verify : + arg = (_ctp,_ctp1) ==> + (_ctp = _ctp1 => + res = W64.of_int 0) /\ + (_ctp <> _ctp1 => + res = W64.of_int 1)]. +proc => /=. + +wp; while (#pre /\ 0 <= i{hr} <= 34 /\ aux{hr} = 34 /\ + (to_uint h{hr} = 0 <=> + forall k, 0 <= k < i{hr}*32 => _ctp.[k] = _ctp1.[k])); last first. ++ auto => />; split; 1: by smt(). + move => h i ??? [HL HR]; split; last first. + + rewrite tP => H1. + rewrite /VPTEST_256 /= /ZF_of /= ifT //. + move : H1; rewrite negb_forall /= => [[k H]]. + move : H; rewrite negb_imply /= => [# H1 H2 H3]. + rewrite implybE in HL; elim HL; smt(to_uint0). + + + rewrite tP => H1. + rewrite /VPTEST_256 /= /ZF_of /= ifF //. + rewrite implybE in HR; elim HR. + + rewrite negb_forall /= => [[k H]]. + move : H; rewrite negb_imply /= => [# H2 H3 H4]. + smt(Array1088.initiE). + smt(W256.to_uint_eq W256.to_uint0). + +auto => /> &hr ?? [HL HR] ?. +pose x := + ((WArray1088.get256 ((WArray1088.init8 ("_.[_]" _ctp1))) i{hr})) `^` + (WArray1088.get256 ((WArray1088.init8 ("_.[_]" _ctp))) i{hr}). +do split; 1..2: by smt(). ++ move => H0 k kbl kbh. + + have H1 : to_uint h{hr} = 0. + + have : h{hr} = W256.zero; last by rewrite to_uint_eq /=. + have : (h{hr} `|` x) = W256.zero by rewrite to_uint_eq H0 /=. + by rewrite !wordP; smt(orwE zerowE). + + have H2 : x = W256.zero. + + have : (h{hr} `|` x) = W256.zero by rewrite to_uint_eq H0 /=. + by rewrite !wordP; smt(orwE zerowE). + + case (k < i{hr}*32); 1: by move => *; apply (HL _ _) => // /#. + + move => kv. + move : H2; rewrite /x W256.WRing.addr_eq0 /oppw /=. + rewrite /init8 /loadW256 /loadW8 /get256_direct wordP => H2. + apply W8.wordP => j jb. + move : (H2 ((k - i{hr} * 32) * 8 + j) _). smt(). + rewrite !pack32wE /=; 1,2: smt(). + rewrite !initiE /=; 1,2: smt(). + rewrite !initiE /=; 1,2: smt(). + by smt(). + +move => H0. +have -> : h{hr} = W256.zero by rewrite to_uint_eq /= /#. +rewrite or0w;have -> : x = W256.zero; last by auto. +rewrite wordP => k kb; rewrite zerowE /x. +move : H0;rewrite /init8 /loadW256 /loadW8 /get256_direct => H0. +pose xx := pack32_t _. +pose yy := pack32_t _. +have -> : xx = yy; last by rewrite xorwK_s //. +rewrite /xx /yy wordP => j jb. + rewrite !pack32wE /=; 1,2: smt(). + rewrite !initiE /=; 1,2: smt(). + rewrite !initiE /=; 1,2: smt(). + by smt(). +qed. + +lemma verify_ll : islossless Jkem_avx2_stack.M.__verify. +proc. +wp. +while (0 <= i{hr} <= 34 /\ aux{hr} = 34) (34 - i{hr}). +auto => /> /#. +auto => /> /#. +qed. + +lemma verify_correct_stack _ctp _ctp1 : + phoare [Jkem_avx2_stack.M.__verify : + arg = (_ctp,_ctp1) ==> + (_ctp = _ctp1 => + res = W64.of_int 0) /\ + (_ctp <> _ctp1 => + res = W64.of_int 1)] = 1%r + by conseq verify_ll (verify_correct_h_stack _ctp _ctp1). + +lemma cmov_correct_h_stack _dst _src _cnd : + hoare [Jkem_avx2_stack.M.__cmov : + dst = _dst /\ src = _src /\ cnd = _cnd ==> + (_cnd = W64.of_int 0 => + res = _src) /\ + (_cnd = W64.of_int 1 => + res = _dst)]. +proc => /=. +seq 1 : (#{/~cnd}pre /\ (_cnd = W64.zero => cnd = W64.zero) /\ + (_cnd = W64.one => cnd = W64.onew)); + 1: by auto => /> /=; split; [ by ring | by rewrite W64.minus_one /=]. + + +auto => /> &hr H H0. + +have msb1 : forall i, 0 <= i < 8 => msb (W64.onew \bits8 i) = true. ++ move => i ib; rewrite /msb /=. + have -> : W64.onew \bits8 i = W8.onew. + + rewrite wordP => k kb. + rewrite onewE /= kb /= bits8iE //= /#. + by rewrite to_uint_onew /=. + +have msb0 : forall i, 0 <= i <8 => msb (W64.zero \bits8 i) = false. ++ move => i ib; rewrite /msb /=. + have -> : W64.zero \bits8 i = W8.zero. + + rewrite wordP => k kb. + rewrite zerowE /= bits8iE //= /#. + by rewrite to_uint0 /=. + +do split. ++ rewrite tP => H1 k kb;rewrite (H H1). + rewrite initiE 1:/# /= kb /=. +rewrite /get256_direct /init8 /loadW256 /loadW8 /= wordP => i ib. +rewrite /VPBLENDVB_256 /VPBROADCAST_4u64 /(\bits8) -iotaredE /= /VPBLENDVB_128 /= initiE 1:/# /=. +by rewrite !msb0 1..8:/# pack32E initiE /= 1:/# /of_list initiE /= /#. + ++ move =>cone; move : (H0 cone) => -> /=;rewrite tP => k kb. +rewrite initiE 1:/# /= kb /=. +rewrite /get256_direct /init8 /loadW256 /loadW8 /= wordP => i ib. +rewrite /VPBLENDVB_256 /VPBROADCAST_4u64 /(\bits8) -iotaredE /= /VPBLENDVB_128 /= !msb1 //= initiE //=. +by rewrite pack32E initiE /= 1:/# /of_list !initiE /= /#. +qed. + +lemma cmov_ll_stack : islossless Jkem_avx2_stack.M.__cmov by islossless. + + +lemma cmov_correct_stack _dst _src _cnd: + phoare [Jkem_avx2_stack.M.__cmov : + dst=_dst /\ src = _src /\ cnd = _cnd ==> + (_cnd = W64.of_int 0 => + res = _src) /\ + (_cnd = W64.of_int 1 => + res = _dst)] = 1%r + by conseq cmov_ll_stack (cmov_correct_h_stack _dst _src _cnd). lemma mlkem_kem_correct_dec : equiv [Jkem_avx2_stack.M.jade_kem_mlkem_mlkem768_amd64_avx2_dec ~ MLKEM.dec: @@ -48,4 +550,120 @@ lemma mlkem_kem_correct_dec : ==> res{2} = res{1}.`1 ]. +proc. +proc rewrite {1} 8 (copy1088). +proc rewrite {1} 9 (copy2400). +proc rewrite {1} 17 (copy32). +inline {1} 13. +proc rewrite {1} 25 (copy2400). + +swap {1} [16..22] -15. +seq 13 0 : #pre; 1: by auto. +sp;seq 1 1 : (#pre /\ aux{1} = m{2}); + 1: by call (mlkem_correct_dec_avx2_stack); 1: by auto => /> /#. + +seq 5 1 : (#pre /\ + (forall k, 0<=k<32 => buf{1}.[k] = m{2}.[k]) /\ + (forall k, 0<=k<32 => kr{1}.[k] = _K{2}.[k]) /\ + (forall k, 0<=k<32 => kr{1}.[k+32] = r{2}.[k])). ++ ecall {1} (sha3_512A_512A_A64_stack m{2} hpk{2}). + wp; conseq (_: _ ==> + (forall k, 0<=k<32 => buf{1}.[k] = m{2}.[k]) /\ + (forall k, 32<=k<64 => buf{1}.[k] = s_sk{1}.[2336 + k - 32]) /\ + (forall k, 0<=k<32 => buf{1}.[k] = aux{1}.[k])). ++ auto => /> &1 &2; rewrite !tP => ???buf bvl bvh ; split. + + split; rewrite tP => *; rewrite initiE /= 1:/#;1: by smt(). + by rewrite initiE /#. + move => H H0 H1 H2 H3; do split. + + move => k kbl kbh. + move : H2;rewrite tP => H2. + move : (H2 k _);1:smt(). + rewrite initiE 1:/# /= initiE 1:/# /= => <-. + by rewrite /G_mhpk /SHA3_512_64_64 /= initiE 1:/# /=. + + move => k kbl kbh. + move : H3;rewrite tP => H3. + move : (H3 k _);1:smt(). + rewrite initiE 1:/# /= initiE 1:/# /= => <-. + by rewrite /G_mhpk /SHA3_512_64_64 /= initiE 1:/# /=. + + while {1} (0<=i{1}<=4 /\ aux_0{1} = 4 /\ + (forall (k : int), 32 <= k && k < 32 + 8*i{1} => buf{1}.[k] = s_sk{1}.[2336 + k - 32]) /\ + forall (k : int), 0 <= k && k < 32 => buf{1}.[k] = aux{1}.[k]) (4 - i{1}); last first. + + auto => /> &1 &2 ??? /=;do split. + + by move => *;rewrite initiE 1:/# /= ifT 1:/#;smt(Array32.initiE). + + by move => *;rewrite initiE 1:/# /= ifT 1:/#;smt(Array1152.initiE). + + by smt(). + move => *; auto => /> &hr ?????; do split; 1..2,5: by smt(). + + move => k kbl kbh; rewrite initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((i{hr} + 4) * 8 <= k && k < (i{hr} + 4) * 8 + 8 ) => *. + + by rewrite WArray2400.get64E pack8bE 1:/# initiE 1:/# /= initiE /#. + by rewrite /get8 initiE; smt(Array1152.initiE). + + move => k kbl kbh; rewrite initiE 1:/# /=. + rewrite get8_set64_directE 1,2:/# /=. + case ((i{hr} + 4) * 8 <= k && k < (i{hr} + 4) * 8 + 8) => *. + + by rewrite WArray2400.get64E pack8bE 1:/# initiE 1:/# /= initiE /#. + by rewrite /get8 initiE; smt(Array1152.initiE). + +swap {2} 1 1. +seq 1 1 : (#pre /\ + ctc{1} = Array1088.init (fun i => if i < 960 then c{2}.`1.[i] else c{2}.`2.[i-960])). ++ call (mlkem_correct_enc_avx2_stack). + + auto => /> &1 &2 ?; rewrite !tP => ?????; do split. + + by smt(Array32.initiE). + + by smt(Array32.initiE). + + move => *;rewrite initiE 1:/# /= initiE 1:/# /=. + by smt(Array1152.initiE). + + move => *;rewrite initiE 1:/# /= initiE 1:/# /=. + by smt(Array32.initiE). + + move => /= ? ? bufv ? krv rl rr H. + rewrite tP => i ib. + rewrite !initiE //=. + by smt(Array128.initiE Array960.initiE). + +seq 1 0 : #pre; 1: by auto. +seq 1 0 : (#pre /\ + (c{2} = cph{2} => cnd{1} = W64.of_int 0) /\ + (c{2} <> cph{2} => cnd{1} = W64.of_int 1)). ++ conseq (_: _ ==> (c{2} = cph{2} => cnd{1} = W64.of_int 0) /\ + (c{2} <> cph{2} => cnd{1} = W64.of_int 1)); 1: smt(). + ecall {1} (verify_correct_stack ct0{1} ctc{1}). + auto => /> &1 &2; rewrite /= !tP. + move => ??cphv???rst Heq Hdiff. + rewrite (_: cph{2} = (cph{2}.`1, cph{2}.`2)) /= in cphv; 1: by smt(). + move : cphv; move => [cphv1 cphv2]. + split. + + move => ceq; rewrite (Heq _); last by done. + move => i0 ib; rewrite !initiE //=. + case (i0 < 960). + + by move => ibb; rewrite ceq cphv1 1: /# initiE /= /#. + by move => ibb; rewrite ceq cphv2 1: /# initiE /= /#. + move => neq;rewrite Hdiff. + have : exists i0, 0<= i0 < 1088 /\ + (ciphertext{1}.[i0] <> + (Array1088.init (fun (i1 : int) => if i1 < 960 then c{2}.`1.[i1] else c{2}.`2.[i1 - 960])).[i0]); last by smt(). + case (c{2}.`1 <> cph{2}.`1). + + move => neq1. rewrite tP in neq1. + have [k kb] : exists k, 0<=k<960 /\ c{2}.`1.[k] <> cph{2}.`1.[k] by smt(). + exists k; split; 1: by smt(). + rewrite !initiE /= 1:/# /=. + move : (cphv1 k _); 1: smt(). + by rewrite initiE /= /#. + + move => eq1. + have neq2 : c{2}.`2 <> cph{2}.`2 by move : neq eq1; smt(). + rewrite tP in neq2. + have [k kb] : exists k, 0<=k<128 /\ c{2}.`2.[k] <> cph{2}.`2.[k] by smt(). + exists (k + 960); split; 1: by smt(). + rewrite !initiE /= 1:/# /=. + move : (cphv2 k _); 1: smt(). + by rewrite initiE /= /#. + done. + + +wp;ecall {1} (cmov_correct_stack shk0{1} (Array32.init (fun (i_0 : int) => kr{1}.[0 + i_0])) cnd{1}). + +wp;ecall{1} (shake256_A32_A1120_ph_stack zp_ct{1}). +conseq />;1: smt(). admitted.