From 4f855738151719720e8a3f691a88a94ebe4ecdba Mon Sep 17 00:00:00 2001 From: Vincent Laporte <Vincent.Laporte@inria.fr> Date: Tue, 6 Feb 2024 11:20:38 +0100 Subject: [PATCH] ML-KEM implementations are constant-time --- .github/workflows/ct.yml | 21 +++++++++++++++++++ code/jasmin/mlkem_avx2/Makefile | 5 ++++- .../jasmin/mlkem_avx2/extraction/jkem_avx2.ec | 7 +++++-- code/jasmin/mlkem_avx2/gen_matrix.jinc | 9 ++++++-- code/jasmin/mlkem_ref/Makefile | 5 ++++- code/jasmin/mlkem_ref/extraction/jkem.ec | 7 +++++-- code/jasmin/mlkem_ref/gen_matrix.jinc | 9 ++++++-- proof/correctness/MLKEM_InnerPKE.ec | 10 ++++----- shell.nix | 20 +++++++++++------- 9 files changed, 71 insertions(+), 22 deletions(-) create mode 100644 .github/workflows/ct.yml diff --git a/.github/workflows/ct.yml b/.github/workflows/ct.yml new file mode 100644 index 00000000..f71d25fa --- /dev/null +++ b/.github/workflows/ct.yml @@ -0,0 +1,21 @@ +name: "Check Constant-Time Security" +on: + pull_request: + push: + +jobs: + checkCT: + name: CT + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + dir: [ 'ref', 'avx2' ] + steps: + - uses: actions/checkout@v4 + - uses: cachix/install-nix-action@v25 + with: + nix_path: nixpkgs=channel:nixos-unstable + - uses: DeterminateSystems/magic-nix-cache-action@v3 + - run: nix-shell --arg full false --run "echo Dependencies OK…" + - run: nix-shell --arg full false --run "make -C code/jasmin/mlkem_${{matrix.dir}}/ ct" diff --git a/code/jasmin/mlkem_avx2/Makefile b/code/jasmin/mlkem_avx2/Makefile index b9ece5c9..64e42e58 100644 --- a/code/jasmin/mlkem_avx2/Makefile +++ b/code/jasmin/mlkem_avx2/Makefile @@ -94,7 +94,10 @@ test/test_polyvec_%: test/test_polyvec_%.c $(HEADERS) $(SOURCES) $(INCS) jpolyve %.s: %.jazz $(JASMINC) -o $@ $(JFLAGS) $^ -.PHONY: clean +.PHONY: ct clean + +ct: + $(JASMINC) -checkCT -infer jkem.jazz clean: -rm -f *.o diff --git a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec index 581d06e7..84eae86c 100644 --- a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec +++ b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec @@ -4354,6 +4354,7 @@ module M(SC:Syscall_t) = { var val1:W16.t; var t:W16.t; var val2:W16.t; + var cond:bool; var cnd0:W64.t; var cnd1:W64.t; @@ -4375,13 +4376,15 @@ module M(SC:Syscall_t) = { t <- (t `<<` (W8.of_int 4)); val2 <- (val2 `|` t); pos <- (pos + (W64.of_int 1)); - if ((val1 \ult (W16.of_int 3329))) { + cond <- (val1 \ult (W16.of_int 3329)); + if (cond) { rp.[(W64.to_uint ctr)] <- val1; ctr <- (ctr + (W64.of_int 1)); } else { } - if ((val2 \ult (W16.of_int 3329))) { + cond <- (val2 \ult (W16.of_int 3329)); + if (cond) { if ((ctr \ult (W64.of_int 256))) { rp.[(W64.to_uint ctr)] <- val2; ctr <- (ctr + (W64.of_int 1)); diff --git a/code/jasmin/mlkem_avx2/gen_matrix.jinc b/code/jasmin/mlkem_avx2/gen_matrix.jinc index 3fbcfa20..0cfba0b3 100644 --- a/code/jasmin/mlkem_avx2/gen_matrix.jinc +++ b/code/jasmin/mlkem_avx2/gen_matrix.jinc @@ -33,13 +33,18 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE] val2 |= t; pos += 1; - if(val1 < MLKEM_Q) + reg bool cond; + #[declassify] + cond = val1 < MLKEM_Q; + if cond { rp[(int)ctr] = val1; ctr += 1; } - if(val2 < MLKEM_Q) + #[declassify] + cond = val2 < MLKEM_Q; + if cond { if(ctr < MLKEM_N) { diff --git a/code/jasmin/mlkem_ref/Makefile b/code/jasmin/mlkem_ref/Makefile index 076644d3..6a3c0454 100644 --- a/code/jasmin/mlkem_ref/Makefile +++ b/code/jasmin/mlkem_ref/Makefile @@ -81,7 +81,10 @@ test/test_polyvec_%: test/test_polyvec_%.c $(HEADERS) $(SOURCES) jpolyvec.s $(JASMINC) -o $@ $(JFLAGS) $^ -.PHONY: clean +.PHONY: ct clean + +ct: + $(JASMINC) -checkCT -infer jkem.jazz clean: -rm -f *.s diff --git a/code/jasmin/mlkem_ref/extraction/jkem.ec b/code/jasmin/mlkem_ref/extraction/jkem.ec index b2874dc5..30cbaf5d 100644 --- a/code/jasmin/mlkem_ref/extraction/jkem.ec +++ b/code/jasmin/mlkem_ref/extraction/jkem.ec @@ -1618,6 +1618,7 @@ module M(SC:Syscall_t) = { var val1:W16.t; var t:W16.t; var val2:W16.t; + var cond:bool; var cnd0:W64.t; var cnd1:W64.t; @@ -1639,13 +1640,15 @@ module M(SC:Syscall_t) = { t <- (t `<<` (W8.of_int 4)); val2 <- (val2 `|` t); pos <- (pos + (W64.of_int 1)); - if ((val1 \ult (W16.of_int 3329))) { + cond <- (val1 \ult (W16.of_int 3329)); + if (cond) { rp.[(W64.to_uint ctr)] <- val1; ctr <- (ctr + (W64.of_int 1)); } else { } - if ((val2 \ult (W16.of_int 3329))) { + cond <- (val2 \ult (W16.of_int 3329)); + if (cond) { if ((ctr \ult (W64.of_int 256))) { rp.[(W64.to_uint ctr)] <- val2; ctr <- (ctr + (W64.of_int 1)); diff --git a/code/jasmin/mlkem_ref/gen_matrix.jinc b/code/jasmin/mlkem_ref/gen_matrix.jinc index 5fa706ca..7461a13f 100644 --- a/code/jasmin/mlkem_ref/gen_matrix.jinc +++ b/code/jasmin/mlkem_ref/gen_matrix.jinc @@ -31,13 +31,18 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE] val2 |= t; pos += 1; - if(val1 < MLKEM_Q) + reg bool cond; + #[declassify] + cond = val1 < MLKEM_Q; + if cond { rp[(int)ctr] = val1; ctr += 1; } - if(val2 < MLKEM_Q) + #[declassify] + cond = val2 < MLKEM_Q; + if cond { if(ctr < MLKEM_N) { diff --git a/proof/correctness/MLKEM_InnerPKE.ec b/proof/correctness/MLKEM_InnerPKE.ec index 50d6febe..3f598fbc 100644 --- a/proof/correctness/MLKEM_InnerPKE.ec +++ b/proof/correctness/MLKEM_InnerPKE.ec @@ -582,12 +582,12 @@ auto => /> &1 &2 ?????????; do split; 1,2,4:smt(). + by rewrite mergebytes to_uintD_small; smt(). by rewrite mergebytes2 !to_uintD_small; smt(). -seq 2 2 : (to_uint ctr0{1} = j0{2} /\ - to_uint pos{1} = k{2} /\ +seq 4 2 : (to_uint ctr0{1} = j0{2} /\ + to_uint pos{1} = k{2} /\ #{/~exit{1}}post). -+ if; 1: by move => &1 &2; rewrite ultE qE; smt(). - + sp 2 2; if{2}. ++ sp 1 0; if; 1: by move => &1 &2; rewrite ultE qE; smt(). + + sp 3 2; if{2}. + rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE; smt(). rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /= to_uintD_small /= /#. auto => /> &1 aar ctrl rpl 8?; rewrite ultE /= => *; do split; 2..3:smt(). @@ -633,7 +633,7 @@ seq 2 2 : (to_uint ctr0{1} = j0{2} /\ rewrite set_eqiE 1,2:/#. by rewrite to_sint_unsigned; rewrite /to_sint /smod /=; smt(W16.to_uint_cmp). - if{2}. + sp 1 0; if{2}. + rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /#. rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /#. auto => /> &1 &2 8?; rewrite ultE /= => *; do split; 2..3:smt(). diff --git a/shell.nix b/shell.nix index 611ca747..3f755eb8 100644 --- a/shell.nix +++ b/shell.nix @@ -1,7 +1,12 @@ -with import (fetchTarball { - url = https://github.com/NixOS/nixpkgs/archive/53fbe41cf76b6a685004194e38e889bc8857e8c2.tar.gz; - sha256 = "sha256:1fyc4kbhv7rrfzya74yprvd70prlcsv56b7n0fv47kn7rznvvr2b"; -}) {}; +{ pkgs ? + import (fetchTarball { + url = https://github.com/NixOS/nixpkgs/archive/53fbe41cf76b6a685004194e38e889bc8857e8c2.tar.gz; + sha256 = "sha256:1fyc4kbhv7rrfzya74yprvd70prlcsv56b7n0fv47kn7rznvvr2b"; + }) {} +, full ? true +}: + +with pkgs; let oc = ocaml-ng.ocamlPackages_4_14; @@ -28,7 +33,9 @@ let altergo = callPackage ./config/alt-ergo.nix { ocamlPackages = oc; } ; in -mkShell { +mkShell ({ + JASMINC = "${jasmin-compiler.bin}/bin/jasminc"; +} // lib.optionalAttrs full { packages = [ ec altergo @@ -36,6 +43,5 @@ mkShell { z3 ]; - JASMINC = "${jasmin-compiler.bin}/bin/jasminc"; EC_RDIRS = "Jasmin:${jasmin-compiler.lib}/lib/jasmin/easycrypt"; -} +})