Skip to content

Commit

Permalink
ML-KEM implementations are constant-time
Browse files Browse the repository at this point in the history
  • Loading branch information
vbgl committed Feb 7, 2024
1 parent bc31973 commit 4f85573
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 22 deletions.
21 changes: 21 additions & 0 deletions .github/workflows/ct.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: "Check Constant-Time Security"
on:
pull_request:
push:

jobs:
checkCT:
name: CT
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
dir: [ 'ref', 'avx2' ]
steps:
- uses: actions/checkout@v4
- uses: cachix/install-nix-action@v25
with:
nix_path: nixpkgs=channel:nixos-unstable
- uses: DeterminateSystems/magic-nix-cache-action@v3
- run: nix-shell --arg full false --run "echo Dependencies OK…"
- run: nix-shell --arg full false --run "make -C code/jasmin/mlkem_${{matrix.dir}}/ ct"
5 changes: 4 additions & 1 deletion code/jasmin/mlkem_avx2/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ test/test_polyvec_%: test/test_polyvec_%.c $(HEADERS) $(SOURCES) $(INCS) jpolyve
%.s: %.jazz
$(JASMINC) -o $@ $(JFLAGS) $^

.PHONY: clean
.PHONY: ct clean

ct:
$(JASMINC) -checkCT -infer jkem.jazz

clean:
-rm -f *.o
Expand Down
7 changes: 5 additions & 2 deletions code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -4354,6 +4354,7 @@ module M(SC:Syscall_t) = {
var val1:W16.t;
var t:W16.t;
var val2:W16.t;
var cond:bool;
var cnd0:W64.t;
var cnd1:W64.t;

Expand All @@ -4375,13 +4376,15 @@ module M(SC:Syscall_t) = {
t <- (t `<<` (W8.of_int 4));
val2 <- (val2 `|` t);
pos <- (pos + (W64.of_int 1));
if ((val1 \ult (W16.of_int 3329))) {
cond <- (val1 \ult (W16.of_int 3329));
if (cond) {
rp.[(W64.to_uint ctr)] <- val1;
ctr <- (ctr + (W64.of_int 1));
} else {

}
if ((val2 \ult (W16.of_int 3329))) {
cond <- (val2 \ult (W16.of_int 3329));
if (cond) {
if ((ctr \ult (W64.of_int 256))) {
rp.[(W64.to_uint ctr)] <- val2;
ctr <- (ctr + (W64.of_int 1));
Expand Down
9 changes: 7 additions & 2 deletions code/jasmin/mlkem_avx2/gen_matrix.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,18 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE]
val2 |= t;
pos += 1;

if(val1 < MLKEM_Q)
reg bool cond;
#[declassify]
cond = val1 < MLKEM_Q;
if cond
{
rp[(int)ctr] = val1;
ctr += 1;
}

if(val2 < MLKEM_Q)
#[declassify]
cond = val2 < MLKEM_Q;
if cond
{
if(ctr < MLKEM_N)
{
Expand Down
5 changes: 4 additions & 1 deletion code/jasmin/mlkem_ref/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ test/test_polyvec_%: test/test_polyvec_%.c $(HEADERS) $(SOURCES) jpolyvec.s
$(JASMINC) -o $@ $(JFLAGS) $^


.PHONY: clean
.PHONY: ct clean

ct:
$(JASMINC) -checkCT -infer jkem.jazz

clean:
-rm -f *.s
Expand Down
7 changes: 5 additions & 2 deletions code/jasmin/mlkem_ref/extraction/jkem.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1618,6 +1618,7 @@ module M(SC:Syscall_t) = {
var val1:W16.t;
var t:W16.t;
var val2:W16.t;
var cond:bool;
var cnd0:W64.t;
var cnd1:W64.t;

Expand All @@ -1639,13 +1640,15 @@ module M(SC:Syscall_t) = {
t <- (t `<<` (W8.of_int 4));
val2 <- (val2 `|` t);
pos <- (pos + (W64.of_int 1));
if ((val1 \ult (W16.of_int 3329))) {
cond <- (val1 \ult (W16.of_int 3329));
if (cond) {
rp.[(W64.to_uint ctr)] <- val1;
ctr <- (ctr + (W64.of_int 1));
} else {

}
if ((val2 \ult (W16.of_int 3329))) {
cond <- (val2 \ult (W16.of_int 3329));
if (cond) {
if ((ctr \ult (W64.of_int 256))) {
rp.[(W64.to_uint ctr)] <- val2;
ctr <- (ctr + (W64.of_int 1));
Expand Down
9 changes: 7 additions & 2 deletions code/jasmin/mlkem_ref/gen_matrix.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,18 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE]
val2 |= t;
pos += 1;

if(val1 < MLKEM_Q)
reg bool cond;
#[declassify]
cond = val1 < MLKEM_Q;
if cond
{
rp[(int)ctr] = val1;
ctr += 1;
}

if(val2 < MLKEM_Q)
#[declassify]
cond = val2 < MLKEM_Q;
if cond
{
if(ctr < MLKEM_N)
{
Expand Down
10 changes: 5 additions & 5 deletions proof/correctness/MLKEM_InnerPKE.ec
Original file line number Diff line number Diff line change
Expand Up @@ -582,12 +582,12 @@ auto => /> &1 &2 ?????????; do split; 1,2,4:smt().
+ by rewrite mergebytes to_uintD_small; smt().
by rewrite mergebytes2 !to_uintD_small; smt().

seq 2 2 : (to_uint ctr0{1} = j0{2} /\
to_uint pos{1} = k{2} /\
seq 4 2 : (to_uint ctr0{1} = j0{2} /\
to_uint pos{1} = k{2} /\
#{/~exit{1}}post).

+ if; 1: by move => &1 &2; rewrite ultE qE; smt().
+ sp 2 2; if{2}.
+ sp 1 0; if; 1: by move => &1 &2; rewrite ultE qE; smt().
+ sp 3 2; if{2}.
+ rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE; smt().
rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /= to_uintD_small /= /#.
auto => /> &1 aar ctrl rpl 8?; rewrite ultE /= => *; do split; 2..3:smt().
Expand Down Expand Up @@ -633,7 +633,7 @@ seq 2 2 : (to_uint ctr0{1} = j0{2} /\
rewrite set_eqiE 1,2:/#.
by rewrite to_sint_unsigned; rewrite /to_sint /smod /=; smt(W16.to_uint_cmp).

if{2}.
sp 1 0; if{2}.
+ rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /#.
rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /#.
auto => /> &1 &2 8?; rewrite ultE /= => *; do split; 2..3:smt().
Expand Down
20 changes: 13 additions & 7 deletions shell.nix
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
with import (fetchTarball {
url = https://github.com/NixOS/nixpkgs/archive/53fbe41cf76b6a685004194e38e889bc8857e8c2.tar.gz;
sha256 = "sha256:1fyc4kbhv7rrfzya74yprvd70prlcsv56b7n0fv47kn7rznvvr2b";
}) {};
{ pkgs ?
import (fetchTarball {
url = https://github.com/NixOS/nixpkgs/archive/53fbe41cf76b6a685004194e38e889bc8857e8c2.tar.gz;
sha256 = "sha256:1fyc4kbhv7rrfzya74yprvd70prlcsv56b7n0fv47kn7rznvvr2b";
}) {}
, full ? true
}:

with pkgs;

let
oc = ocaml-ng.ocamlPackages_4_14;
Expand All @@ -28,14 +33,15 @@ let
altergo = callPackage ./config/alt-ergo.nix { ocamlPackages = oc; } ;
in

mkShell {
mkShell ({
JASMINC = "${jasmin-compiler.bin}/bin/jasminc";
} // lib.optionalAttrs full {
packages = [
ec
altergo
cvc4
z3
];

JASMINC = "${jasmin-compiler.bin}/bin/jasminc";
EC_RDIRS = "Jasmin:${jasmin-compiler.lib}/lib/jasmin/easycrypt";
}
})

0 comments on commit 4f85573

Please sign in to comment.