From 4f855738151719720e8a3f691a88a94ebe4ecdba Mon Sep 17 00:00:00 2001
From: Vincent Laporte <Vincent.Laporte@inria.fr>
Date: Tue, 6 Feb 2024 11:20:38 +0100
Subject: [PATCH] ML-KEM implementations are constant-time

---
 .github/workflows/ct.yml                      | 21 +++++++++++++++++++
 code/jasmin/mlkem_avx2/Makefile               |  5 ++++-
 .../jasmin/mlkem_avx2/extraction/jkem_avx2.ec |  7 +++++--
 code/jasmin/mlkem_avx2/gen_matrix.jinc        |  9 ++++++--
 code/jasmin/mlkem_ref/Makefile                |  5 ++++-
 code/jasmin/mlkem_ref/extraction/jkem.ec      |  7 +++++--
 code/jasmin/mlkem_ref/gen_matrix.jinc         |  9 ++++++--
 proof/correctness/MLKEM_InnerPKE.ec           | 10 ++++-----
 shell.nix                                     | 20 +++++++++++-------
 9 files changed, 71 insertions(+), 22 deletions(-)
 create mode 100644 .github/workflows/ct.yml

diff --git a/.github/workflows/ct.yml b/.github/workflows/ct.yml
new file mode 100644
index 00000000..f71d25fa
--- /dev/null
+++ b/.github/workflows/ct.yml
@@ -0,0 +1,21 @@
+name: "Check Constant-Time Security"
+on:
+  pull_request:
+  push:
+
+jobs:
+  checkCT:
+    name: CT
+    runs-on: ubuntu-latest
+    strategy:
+      fail-fast: false
+      matrix:
+        dir: [ 'ref', 'avx2' ]
+    steps:
+    - uses: actions/checkout@v4
+    - uses: cachix/install-nix-action@v25
+      with:
+        nix_path: nixpkgs=channel:nixos-unstable
+    - uses: DeterminateSystems/magic-nix-cache-action@v3
+    - run: nix-shell --arg full false --run "echo Dependencies OK…"
+    - run: nix-shell --arg full false --run "make -C code/jasmin/mlkem_${{matrix.dir}}/ ct"
diff --git a/code/jasmin/mlkem_avx2/Makefile b/code/jasmin/mlkem_avx2/Makefile
index b9ece5c9..64e42e58 100644
--- a/code/jasmin/mlkem_avx2/Makefile
+++ b/code/jasmin/mlkem_avx2/Makefile
@@ -94,7 +94,10 @@ test/test_polyvec_%: test/test_polyvec_%.c $(HEADERS) $(SOURCES) $(INCS) jpolyve
 %.s: %.jazz
 	$(JASMINC) -o $@ $(JFLAGS) $^
 
-.PHONY: clean
+.PHONY: ct clean
+
+ct:
+	$(JASMINC) -checkCT -infer jkem.jazz
 
 clean:
 	-rm -f *.o
diff --git a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec
index 581d06e7..84eae86c 100644
--- a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec
+++ b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec
@@ -4354,6 +4354,7 @@ module M(SC:Syscall_t) = {
     var val1:W16.t;
     var t:W16.t;
     var val2:W16.t;
+    var cond:bool;
     var cnd0:W64.t;
     var cnd1:W64.t;
     
@@ -4375,13 +4376,15 @@ module M(SC:Syscall_t) = {
       t <- (t `<<` (W8.of_int 4));
       val2 <- (val2 `|` t);
       pos <- (pos + (W64.of_int 1));
-      if ((val1 \ult (W16.of_int 3329))) {
+      cond <- (val1 \ult (W16.of_int 3329));
+      if (cond) {
         rp.[(W64.to_uint ctr)] <- val1;
         ctr <- (ctr + (W64.of_int 1));
       } else {
         
       }
-      if ((val2 \ult (W16.of_int 3329))) {
+      cond <- (val2 \ult (W16.of_int 3329));
+      if (cond) {
         if ((ctr \ult (W64.of_int 256))) {
           rp.[(W64.to_uint ctr)] <- val2;
           ctr <- (ctr + (W64.of_int 1));
diff --git a/code/jasmin/mlkem_avx2/gen_matrix.jinc b/code/jasmin/mlkem_avx2/gen_matrix.jinc
index 3fbcfa20..0cfba0b3 100644
--- a/code/jasmin/mlkem_avx2/gen_matrix.jinc
+++ b/code/jasmin/mlkem_avx2/gen_matrix.jinc
@@ -33,13 +33,18 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE]
     val2 |= t;
     pos += 1;
 
-    if(val1 < MLKEM_Q)
+    reg bool cond;
+    #[declassify]
+    cond = val1 < MLKEM_Q;
+    if cond
     {
       rp[(int)ctr] = val1;
       ctr += 1;
     }
 
-    if(val2 < MLKEM_Q)
+    #[declassify]
+    cond = val2 < MLKEM_Q;
+    if cond
     {
       if(ctr < MLKEM_N)
       {
diff --git a/code/jasmin/mlkem_ref/Makefile b/code/jasmin/mlkem_ref/Makefile
index 076644d3..6a3c0454 100644
--- a/code/jasmin/mlkem_ref/Makefile
+++ b/code/jasmin/mlkem_ref/Makefile
@@ -81,7 +81,10 @@ test/test_polyvec_%: test/test_polyvec_%.c $(HEADERS) $(SOURCES) jpolyvec.s
 	$(JASMINC) -o $@ $(JFLAGS) $^ 
 
 
-.PHONY: clean
+.PHONY: ct clean
+
+ct:
+	$(JASMINC) -checkCT -infer jkem.jazz
 
 clean:
 	-rm -f *.s
diff --git a/code/jasmin/mlkem_ref/extraction/jkem.ec b/code/jasmin/mlkem_ref/extraction/jkem.ec
index b2874dc5..30cbaf5d 100644
--- a/code/jasmin/mlkem_ref/extraction/jkem.ec
+++ b/code/jasmin/mlkem_ref/extraction/jkem.ec
@@ -1618,6 +1618,7 @@ module M(SC:Syscall_t) = {
     var val1:W16.t;
     var t:W16.t;
     var val2:W16.t;
+    var cond:bool;
     var cnd0:W64.t;
     var cnd1:W64.t;
     
@@ -1639,13 +1640,15 @@ module M(SC:Syscall_t) = {
       t <- (t `<<` (W8.of_int 4));
       val2 <- (val2 `|` t);
       pos <- (pos + (W64.of_int 1));
-      if ((val1 \ult (W16.of_int 3329))) {
+      cond <- (val1 \ult (W16.of_int 3329));
+      if (cond) {
         rp.[(W64.to_uint ctr)] <- val1;
         ctr <- (ctr + (W64.of_int 1));
       } else {
         
       }
-      if ((val2 \ult (W16.of_int 3329))) {
+      cond <- (val2 \ult (W16.of_int 3329));
+      if (cond) {
         if ((ctr \ult (W64.of_int 256))) {
           rp.[(W64.to_uint ctr)] <- val2;
           ctr <- (ctr + (W64.of_int 1));
diff --git a/code/jasmin/mlkem_ref/gen_matrix.jinc b/code/jasmin/mlkem_ref/gen_matrix.jinc
index 5fa706ca..7461a13f 100644
--- a/code/jasmin/mlkem_ref/gen_matrix.jinc
+++ b/code/jasmin/mlkem_ref/gen_matrix.jinc
@@ -31,13 +31,18 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE]
     val2 |= t;
     pos += 1;
 
-    if(val1 < MLKEM_Q)
+    reg bool cond;
+    #[declassify]
+    cond = val1 < MLKEM_Q;
+    if cond
     {
       rp[(int)ctr] = val1;
       ctr += 1;
     }
 
-    if(val2 < MLKEM_Q)
+    #[declassify]
+    cond = val2 < MLKEM_Q;
+    if cond
     {
       if(ctr < MLKEM_N)
       {
diff --git a/proof/correctness/MLKEM_InnerPKE.ec b/proof/correctness/MLKEM_InnerPKE.ec
index 50d6febe..3f598fbc 100644
--- a/proof/correctness/MLKEM_InnerPKE.ec
+++ b/proof/correctness/MLKEM_InnerPKE.ec
@@ -582,12 +582,12 @@ auto => /> &1 &2 ?????????; do split; 1,2,4:smt().
 + by rewrite mergebytes to_uintD_small; smt().
 by rewrite mergebytes2  !to_uintD_small; smt().
 
-seq 2 2 : (to_uint ctr0{1} = j0{2} /\
-           to_uint pos{1} = k{2} /\ 
+seq 4 2 : (to_uint ctr0{1} = j0{2} /\
+           to_uint pos{1} = k{2} /\
            #{/~exit{1}}post).
 
-+ if; 1: by move => &1 &2; rewrite ultE qE; smt().
-  + sp 2 2; if{2}. 
++ sp 1 0; if; 1: by move => &1 &2; rewrite ultE qE; smt().
+  + sp 3 2; if{2}.
     + rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE; smt().
       rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /= to_uintD_small /= /#.
       auto => /> &1 aar ctrl rpl 8?; rewrite ultE /= => *; do split; 2..3:smt().
@@ -633,7 +633,7 @@ seq 2 2 : (to_uint ctr0{1} = j0{2} /\
     rewrite set_eqiE 1,2:/#.
     by rewrite to_sint_unsigned; rewrite /to_sint /smod /=; smt(W16.to_uint_cmp).
 
-  if{2}. 
+  sp 1 0; if{2}.
   + rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /#.
     rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /#.
     auto => /> &1 &2 8?; rewrite ultE /= => *; do split; 2..3:smt().
diff --git a/shell.nix b/shell.nix
index 611ca747..3f755eb8 100644
--- a/shell.nix
+++ b/shell.nix
@@ -1,7 +1,12 @@
-with import (fetchTarball {
-    url = https://github.com/NixOS/nixpkgs/archive/53fbe41cf76b6a685004194e38e889bc8857e8c2.tar.gz;
-    sha256 = "sha256:1fyc4kbhv7rrfzya74yprvd70prlcsv56b7n0fv47kn7rznvvr2b";
-}) {};
+{ pkgs ?
+    import (fetchTarball {
+      url = https://github.com/NixOS/nixpkgs/archive/53fbe41cf76b6a685004194e38e889bc8857e8c2.tar.gz;
+      sha256 = "sha256:1fyc4kbhv7rrfzya74yprvd70prlcsv56b7n0fv47kn7rznvvr2b";
+    }) {}
+, full ? true
+}:
+
+with pkgs;
 
 let
   oc = ocaml-ng.ocamlPackages_4_14;
@@ -28,7 +33,9 @@ let
   altergo = callPackage ./config/alt-ergo.nix { ocamlPackages = oc; } ;
 in
 
-mkShell {
+mkShell ({
+  JASMINC = "${jasmin-compiler.bin}/bin/jasminc";
+} // lib.optionalAttrs full {
   packages = [
     ec
     altergo
@@ -36,6 +43,5 @@ mkShell {
     z3
   ];
 
-  JASMINC = "${jasmin-compiler.bin}/bin/jasminc";
   EC_RDIRS = "Jasmin:${jasmin-compiler.lib}/lib/jasmin/easycrypt";
-}
+})