Skip to content

Commit

Permalink
Declassify the random seed
Browse files Browse the repository at this point in the history
This changes what is declassified and when.

Before, declassification only occurs during rejection sampling and what
is declassified is whether rejection occurs or not.

After, the full sampling of the matrix expects its seed to be public:
what is declassified is the public key when it is read from memory in
the IND-CPA encryption, when it is derived from the output of SHA3 in
keygen.
  • Loading branch information
vbgl committed Jun 14, 2024
1 parent 5358a27 commit 7083770
Show file tree
Hide file tree
Showing 9 changed files with 20 additions and 27 deletions.
7 changes: 2 additions & 5 deletions code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -4369,7 +4369,6 @@ module M(SC:Syscall_t) = {
var val1:W16.t;
var t:W16.t;
var val2:W16.t;
var cond:bool;

ctr <- offset;
pos <- (W64.of_int 0);
Expand All @@ -4387,15 +4386,13 @@ module M(SC:Syscall_t) = {
t <- (t `<<` (W8.of_int 4));
val2 <- (val2 `|` t);
pos <- (pos + (W64.of_int 3));
cond <- (val1 \ult (W16.of_int 3329));
if (cond) {
if ((val1 \ult (W16.of_int 3329))) {
rp.[(W64.to_uint ctr)] <- val1;
ctr <- (ctr + (W64.of_int 1));
} else {

}
cond <- (val2 \ult (W16.of_int 3329));
if (cond) {
if ((val2 \ult (W16.of_int 3329))) {
if ((ctr \ult (W64.of_int 256))) {
rp.[(W64.to_uint ctr)] <- val2;
ctr <- (ctr + (W64.of_int 1));
Expand Down
1 change: 1 addition & 0 deletions code/jasmin/mlkem_avx2/gen_matrix.jazz
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ export fn gen_matrix_jazz(reg u64 ap, reg u64 seedp)

for i = 0 to MLKEM_SYMBYTES
{
#[declassify]
c = (u8)[seedp + i];
seed[i] = c;
}
Expand Down
9 changes: 2 additions & 7 deletions code/jasmin/mlkem_avx2/gen_matrix.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,12 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE]
val2 |= t;
pos += 3;

reg bool cond;
#[declassify]
cond = val1 < MLKEM_Q;
if cond {
if val1 < MLKEM_Q {
rp[ctr] = val1;
ctr += 1;
}

#[declassify]
cond = val2 < MLKEM_Q;
if cond {
if val2 < MLKEM_Q {
if(ctr < MLKEM_N)
{
rp[ctr] = val2;
Expand Down
3 changes: 3 additions & 0 deletions code/jasmin/mlkem_avx2/indcpa.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ fn __indcpa_keypair(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES] randomn

for i=0 to MLKEM_SYMBYTES/8
{
#[declassify]
t64 = buf[u64 i];
publicseed[u64 i] = t64;
t64 = buf[u64 i + MLKEM_SYMBYTES/8];
Expand Down Expand Up @@ -91,6 +92,7 @@ fn __indcpa_enc_0(stack u64 sctp, reg ptr u8[MLKEM_INDCPA_MSGBYTES] msgp, reg u6
pkp += MLKEM_POLYVECBYTES;
while (i < MLKEM_SYMBYTES/8)
{
#[declassify]
t64 = (u64)[pkp];
publicseed.[u64 8 * (int)i] = t64;
pkp += 8;
Expand Down Expand Up @@ -155,6 +157,7 @@ fn __indcpa_enc_1(reg ptr u8[MLKEM_INDCPA_CIPHERTEXTBYTES] ctp, reg ptr u8[MLKEM
pkp += MLKEM_POLYVECBYTES;
while (i < MLKEM_SYMBYTES/8)
{
#[declassify]
t64 = (u64)[pkp];
publicseed.[u64 8*(int)i] = t64;
pkp += 8;
Expand Down
7 changes: 2 additions & 5 deletions code/jasmin/mlkem_ref/extraction/jkem.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1634,7 +1634,6 @@ module M(SC:Syscall_t) = {
var val1:W16.t;
var t:W16.t;
var val2:W16.t;
var cond:bool;

ctr <- offset;
pos <- (W64.of_int 0);
Expand All @@ -1652,15 +1651,13 @@ module M(SC:Syscall_t) = {
t <- (t `<<` (W8.of_int 4));
val2 <- (val2 `|` t);
pos <- (pos + (W64.of_int 3));
cond <- (val1 \ult (W16.of_int 3329));
if (cond) {
if ((val1 \ult (W16.of_int 3329))) {
rp.[(W64.to_uint ctr)] <- val1;
ctr <- (ctr + (W64.of_int 1));
} else {

}
cond <- (val2 \ult (W16.of_int 3329));
if (cond) {
if ((val2 \ult (W16.of_int 3329))) {
if ((ctr \ult (W64.of_int 256))) {
rp.[(W64.to_uint ctr)] <- val2;
ctr <- (ctr + (W64.of_int 1));
Expand Down
2 changes: 2 additions & 0 deletions code/jasmin/mlkem_ref/gen_matrix.jazz
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
require "params.jinc"
require "gen_matrix.jinc"

export fn gen_matrix_jazz(reg u64 ap, reg u64 seedp, reg u64 transposed)
Expand All @@ -11,6 +12,7 @@ export fn gen_matrix_jazz(reg u64 ap, reg u64 seedp, reg u64 transposed)

for i = 0 to MLKEM_SYMBYTES
{
#[declassify]
c = (u8)[seedp + i];
seed[i] = c;
}
Expand Down
9 changes: 2 additions & 7 deletions code/jasmin/mlkem_ref/gen_matrix.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,12 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE]
val2 |= t;
pos += 3;

reg bool cond;
#[declassify]
cond = val1 < MLKEM_Q;
if cond {
if val1 < MLKEM_Q {
rp[ctr] = val1;
ctr += 1;
}

#[declassify]
cond = val2 < MLKEM_Q;
if cond {
if val2 < MLKEM_Q {
if(ctr < MLKEM_N)
{
rp[ctr] = val2;
Expand Down
3 changes: 3 additions & 0 deletions code/jasmin/mlkem_ref/indcpa.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ fn __indcpa_keypair(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES] randomn

for i=0 to MLKEM_SYMBYTES/8
{
#[declassify]
t64 = buf[u64 i];
publicseed[u64 i] = t64;
t64 = buf[u64 i + MLKEM_SYMBYTES/8];
Expand Down Expand Up @@ -103,6 +104,7 @@ fn __indcpa_enc(stack u64 sctp, reg ptr u8[32] msgp, reg u64 pkp, reg ptr u8[MLK
pkp += MLKEM_POLYVECBYTES;
while (i < MLKEM_SYMBYTES/8)
{
#[declassify]
t64 = (u64)[pkp];
publicseed.[u64 8 * (int)i] = t64;
pkp += 8;
Expand Down Expand Up @@ -178,6 +180,7 @@ fn __iindcpa_enc(reg ptr u8[MLKEM_CT_LEN] ctp, reg ptr u8[32] msgp, reg u64 pkp,
pkp += MLKEM_POLYVECBYTES;
while (i < MLKEM_SYMBYTES/8)
{
#[declassify]
t64 = (u64)[pkp];
publicseed.[u64 8*(int)i] = t64;
pkp += 8;
Expand Down
6 changes: 3 additions & 3 deletions proof/correctness/MLKEM_InnerPKE.ec
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,8 @@ seq 4 2 : (to_uint ctr0{1} = j0{2} /\
to_uint pos{1} = k{2} /\
#{/~pos{1} \ult (of_int (168 - 2))%W64}post).

+ sp 1 0; if; 1: by move => &1 &2; rewrite ultE qE; smt().
+ sp 3 2; if{2}.
+ if; 1: by move => &1 &2; rewrite ultE qE; smt().
+ sp 2 2; if{2}.
+ rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE; smt().
rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /= to_uintD_small /= /#.
auto => /> &1 aar ctrl rpl 8?; rewrite ultE /= => *; do split; 2..3:smt().
Expand Down Expand Up @@ -638,7 +638,7 @@ seq 4 2 : (to_uint ctr0{1} = j0{2} /\
rewrite set_eqiE 1,2:/#.
by rewrite to_sint_unsigned; rewrite /to_sint /smod /=; smt(W16.to_uint_cmp).

sp 1 0; if{2}.
if{2}.
+ rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /#.
rcondt{1} 1; 1: by move => *; auto => /> *; rewrite ultE /#.
auto => /> &1 &2 8?; rewrite ultE /= => *; do split; 2..3:smt().
Expand Down

0 comments on commit 7083770

Please sign in to comment.