Skip to content

Commit

Permalink
gen-matrix: clearer loop condition
Browse files: browse the repository at this point in the history
  • Loading branch information
vbgl committed Feb 8, 2024
1 parent 53cc72e commit adc7bf1
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 198 deletions.
68 changes: 29 additions & 39 deletions code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -4350,58 +4350,48 @@ module M(SC:Syscall_t) = {

var ctr:W64.t;
var pos:W64.t;
var exit:W64.t;
var val1:W16.t;
var t:W16.t;
var val2:W16.t;
var cond:bool;
var cnd0:W64.t;
var cnd1:W64.t;

ctr <- offset;
pos <- (W64.of_int 0);
exit <- (W64.of_int 0);

while ((exit = (W64.of_int 0))) {
val1 <- (zeroextu16 buf.[(W64.to_uint pos)]);
pos <- (pos + (W64.of_int 1));
t <- (zeroextu16 buf.[(W64.to_uint pos)]);
val2 <- t;
val2 <- (val2 `>>` (W8.of_int 4));
t <- (t `&` (W16.of_int 15));
t <- (t `<<` (W8.of_int 8));
val1 <- (val1 `|` t);
pos <- (pos + (W64.of_int 1));
t <- (zeroextu16 buf.[(W64.to_uint pos)]);
t <- (t `<<` (W8.of_int 4));
val2 <- (val2 `|` t);
pos <- (pos + (W64.of_int 1));
cond <- (val1 \ult (W16.of_int 3329));
if (cond) {
rp.[(W64.to_uint ctr)] <- val1;
ctr <- (ctr + (W64.of_int 1));
} else {

}
cond <- (val2 \ult (W16.of_int 3329));
if (cond) {
if ((ctr \ult (W64.of_int 256))) {
rp.[(W64.to_uint ctr)] <- val2;

while ((pos \ult (W64.of_int (168 - 2)))) {
if ((ctr \ult (W64.of_int 256))) {
val1 <- (zeroextu16 buf.[(W64.to_uint pos)]);
t <- (zeroextu16 buf.[(W64.to_uint (pos + (W64.of_int 1)))]);
val2 <- t;
val2 <- (val2 `>>` (W8.of_int 4));
t <- (t `&` (W16.of_int 15));
t <- (t `<<` (W8.of_int 8));
val1 <- (val1 `|` t);
t <- (zeroextu16 buf.[(W64.to_uint (pos + (W64.of_int 2)))]);
t <- (t `<<` (W8.of_int 4));
val2 <- (val2 `|` t);
pos <- (pos + (W64.of_int 3));
cond <- (val1 \ult (W16.of_int 3329));
if (cond) {
rp.[(W64.to_uint ctr)] <- val1;
ctr <- (ctr + (W64.of_int 1));
} else {

}
cond <- (val2 \ult (W16.of_int 3329));
if (cond) {
if ((ctr \ult (W64.of_int 256))) {
rp.[(W64.to_uint ctr)] <- val2;
ctr <- (ctr + (W64.of_int 1));
} else {

}
} else {

}
} else {

pos <- (W64.of_int 168);
}
cnd0 <- (W64.of_int 256);
cnd0 <- (cnd0 - ctr);
cnd0 <- (cnd0 - (W64.of_int 1));
cnd1 <- (W64.of_int 168);
cnd1 <- (cnd1 - pos);
cnd1 <- (cnd1 - (W64.of_int 3));
exit <- (cnd0 `|` cnd1);
exit <- (exit `>>` (W8.of_int 63));
}
return (ctr, rp);
}
Expand Down
77 changes: 32 additions & 45 deletions code/jasmin/mlkem_avx2/gen_matrix.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -9,59 +9,46 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE]
reg u16 val1 val2;
reg u16 t;
reg u64 pos ctr;
reg u64 cnd0 cnd1 exit;


ctr = offset;
pos = 0;
exit = 0;

while(exit == 0)
{
val1 = (16u)buf[(int)pos];
pos += 1;
t = (16u)buf[(int)pos];
val2 = t;
val2 >>= 4;
t &= 0x0F;
t <<= 8;
val1 |= t;
pos += 1;

t = (16u)buf[(int)pos];
t <<= 4;
val2 |= t;
pos += 1;

reg bool cond;
#[declassify]
cond = val1 < MLKEM_Q;
if cond
{
rp[(int)ctr] = val1;
ctr += 1;
}

#[declassify]
cond = val2 < MLKEM_Q;
if cond
{
if(ctr < MLKEM_N)
{
rp[(int)ctr] = val2;
while (pos < SHAKE128_RATE - 2) {
if ctr < MLKEM_N {
val1 = (16u)buf[pos];
t = (16u)buf[pos + 1];
val2 = t;
val2 >>= 4;
t &= 0x0F;
t <<= 8;
val1 |= t;

t = (16u)buf[pos + 2];
t <<= 4;
val2 |= t;
pos += 3;

reg bool cond;
#[declassify]
cond = val1 < MLKEM_Q;
if cond {
rp[ctr] = val1;
ctr += 1;
}
}

// Check if we should exit the loop
cnd0 = MLKEM_N;
cnd0 -= ctr;
cnd0 -= 1;
cnd1 = SHAKE128_RATE;
cnd1 -= pos;
cnd1 -= 3; //TODO: (potentially) wasting 2 'good' bytes
exit = cnd0 | cnd1;
exit >>= 63;
#[declassify]
cond = val2 < MLKEM_Q;
if cond {
if(ctr < MLKEM_N)
{
rp[ctr] = val2;
ctr += 1;
}
}
} else {
pos = SHAKE128_RATE;
}
}

return ctr, rp;
Expand Down
68 changes: 29 additions & 39 deletions code/jasmin/mlkem_ref/extraction/jkem.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1621,58 +1621,48 @@ module M(SC:Syscall_t) = {

var ctr:W64.t;
var pos:W64.t;
var exit:W64.t;
var val1:W16.t;
var t:W16.t;
var val2:W16.t;
var cond:bool;
var cnd0:W64.t;
var cnd1:W64.t;

ctr <- offset;
pos <- (W64.of_int 0);
exit <- (W64.of_int 0);

while ((exit = (W64.of_int 0))) {
val1 <- (zeroextu16 buf.[(W64.to_uint pos)]);
pos <- (pos + (W64.of_int 1));
t <- (zeroextu16 buf.[(W64.to_uint pos)]);
val2 <- t;
val2 <- (val2 `>>` (W8.of_int 4));
t <- (t `&` (W16.of_int 15));
t <- (t `<<` (W8.of_int 8));
val1 <- (val1 `|` t);
pos <- (pos + (W64.of_int 1));
t <- (zeroextu16 buf.[(W64.to_uint pos)]);
t <- (t `<<` (W8.of_int 4));
val2 <- (val2 `|` t);
pos <- (pos + (W64.of_int 1));
cond <- (val1 \ult (W16.of_int 3329));
if (cond) {
rp.[(W64.to_uint ctr)] <- val1;
ctr <- (ctr + (W64.of_int 1));
} else {

}
cond <- (val2 \ult (W16.of_int 3329));
if (cond) {
if ((ctr \ult (W64.of_int 256))) {
rp.[(W64.to_uint ctr)] <- val2;

while ((pos \ult (W64.of_int (168 - 2)))) {
if ((ctr \ult (W64.of_int 256))) {
val1 <- (zeroextu16 buf.[(W64.to_uint pos)]);
t <- (zeroextu16 buf.[(W64.to_uint (pos + (W64.of_int 1)))]);
val2 <- t;
val2 <- (val2 `>>` (W8.of_int 4));
t <- (t `&` (W16.of_int 15));
t <- (t `<<` (W8.of_int 8));
val1 <- (val1 `|` t);
t <- (zeroextu16 buf.[(W64.to_uint (pos + (W64.of_int 2)))]);
t <- (t `<<` (W8.of_int 4));
val2 <- (val2 `|` t);
pos <- (pos + (W64.of_int 3));
cond <- (val1 \ult (W16.of_int 3329));
if (cond) {
rp.[(W64.to_uint ctr)] <- val1;
ctr <- (ctr + (W64.of_int 1));
} else {

}
cond <- (val2 \ult (W16.of_int 3329));
if (cond) {
if ((ctr \ult (W64.of_int 256))) {
rp.[(W64.to_uint ctr)] <- val2;
ctr <- (ctr + (W64.of_int 1));
} else {

}
} else {

}
} else {

pos <- (W64.of_int 168);
}
cnd0 <- (W64.of_int 256);
cnd0 <- (cnd0 - ctr);
cnd0 <- (cnd0 - (W64.of_int 1));
cnd1 <- (W64.of_int 168);
cnd1 <- (cnd1 - pos);
cnd1 <- (cnd1 - (W64.of_int 3));
exit <- (cnd0 `|` cnd1);
exit <- (exit `>>` (W8.of_int 63));
}
return (ctr, rp);
}
Expand Down
77 changes: 32 additions & 45 deletions code/jasmin/mlkem_ref/gen_matrix.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -7,59 +7,46 @@ fn __rej_uniform(stack u16[MLKEM_N] rp, reg u64 offset, stack u8[SHAKE128_RATE]
reg u16 val1 val2;
reg u16 t;
reg u64 pos ctr;
reg u64 cnd0 cnd1 exit;


ctr = offset;
pos = 0;
exit = 0;

while(exit == 0)
{
val1 = (16u)buf[(int)pos];
pos += 1;
t = (16u)buf[(int)pos];
val2 = t;
val2 >>= 4;
t &= 0x0F;
t <<= 8;
val1 |= t;
pos += 1;

t = (16u)buf[(int)pos];
t <<= 4;
val2 |= t;
pos += 1;

reg bool cond;
#[declassify]
cond = val1 < MLKEM_Q;
if cond
{
rp[(int)ctr] = val1;
ctr += 1;
}

#[declassify]
cond = val2 < MLKEM_Q;
if cond
{
if(ctr < MLKEM_N)
{
rp[(int)ctr] = val2;
while (pos < SHAKE128_RATE - 2) {
if ctr < MLKEM_N {
val1 = (16u)buf[pos];
t = (16u)buf[pos + 1];
val2 = t;
val2 >>= 4;
t &= 0x0F;
t <<= 8;
val1 |= t;

t = (16u)buf[pos + 2];
t <<= 4;
val2 |= t;
pos += 3;

reg bool cond;
#[declassify]
cond = val1 < MLKEM_Q;
if cond {
rp[ctr] = val1;
ctr += 1;
}
}

// Check if we should exit the loop
cnd0 = MLKEM_N;
cnd0 -= ctr;
cnd0 -= 1;
cnd1 = SHAKE128_RATE;
cnd1 -= pos;
cnd1 -= 3; //TODO: (potentially) wasting 2 'good' bytes
exit = cnd0 | cnd1;
exit >>= 63;
#[declassify]
cond = val2 < MLKEM_Q;
if cond {
if(ctr < MLKEM_N)
{
rp[ctr] = val2;
ctr += 1;
}
}
} else {
pos = SHAKE128_RATE;
}
}

return ctr, rp;
Expand Down
Loading

0 comments on commit adc7bf1

Please sign in to comment.