Skip to content

Commit

Permalink
Ref: poly (de)compress: use a single loop counter
Browse files Browse the repository at this point in the history
  • Loading branch information
vbgl committed Feb 9, 2024
1 parent 98255e3 commit 9827d66
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 41 deletions.
16 changes: 4 additions & 12 deletions code/jasmin/mlkem_ref/extraction/jkem.ec
Original file line number Diff line number Diff line change
Expand Up @@ -850,25 +850,22 @@ module M(SC:Syscall_t) = {
W16.t Array256.t = {

var i:W64.t;
var j:W64.t;
var t:W16.t;
var d0:W32.t;
var d1:W32.t;

a <@ _poly_csubq (a);
i <- (W64.of_int 0);
j <- (W64.of_int 0);

while ((i \ult (W64.of_int 128))) {
t <- a.[(W64.to_uint j)];
t <- a.[(W64.to_uint ((W64.of_int 2) * i))];
d0 <- (zeroextu32 t);
d0 <- (d0 `<<` (W8.of_int 4));
d0 <- (d0 + (W32.of_int 1665));
d0 <- (d0 * (W32.of_int 80635));
d0 <- (d0 `>>` (W8.of_int 28));
d0 <- (d0 `&` (W32.of_int 15));
j <- (j + (W64.of_int 1));
t <- a.[(W64.to_uint j)];
t <- a.[(W64.to_uint (((W64.of_int 2) * i) + (W64.of_int 1)))];
d1 <- (zeroextu32 t);
d1 <- (d1 `<<` (W8.of_int 4));
d1 <- (d1 + (W32.of_int 1665));
Expand All @@ -879,21 +876,18 @@ module M(SC:Syscall_t) = {
d0 <- (d0 `|` d1);
rp.[(W64.to_uint i)] <- (truncateu8 d0);
i <- (i + (W64.of_int 1));
j <- (j + (W64.of_int 1));
}
return (rp, a);
}

proc _poly_decompress (rp:W16.t Array256.t, ap:W64.t) : W16.t Array256.t = {

var i:W64.t;
var j:W64.t;
var t:W8.t;
var d0:W16.t;
var d1:W16.t;

i <- (W64.of_int 0);
j <- (W64.of_int 0);

while ((i \ult (W64.of_int 128))) {
t <- (loadW8 Glob.mem (W64.to_uint (ap + i)));
Expand All @@ -907,10 +901,8 @@ module M(SC:Syscall_t) = {
d1 <- (d1 + (W16.of_int 8));
d0 <- (d0 `>>` (W8.of_int 4));
d1 <- (d1 `>>` (W8.of_int 4));
rp.[(W64.to_uint j)] <- d0;
j <- (j + (W64.of_int 1));
rp.[(W64.to_uint j)] <- d1;
j <- (j + (W64.of_int 1));
rp.[(W64.to_uint ((W64.of_int 2) * i))] <- d0;
rp.[(W64.to_uint (((W64.of_int 2) * i) + (W64.of_int 1)))] <- d1;
i <- (i + (W64.of_int 1));
}
return (rp);
Expand Down
20 changes: 7 additions & 13 deletions code/jasmin/mlkem_ref/poly.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -176,23 +176,21 @@ fn _i_poly_compress(reg ptr u8[MLKEM_POLYCOMPRESSEDBYTES] rp, reg ptr u16[MLKEM_
{
reg u16 t;
reg u32 d0, d1;
reg u64 i j;
reg u64 i;

a = _poly_csubq(a);

i = 0;
j = 0;
while(i < 128)
{
t = a[(int)j];
t = a[2 * i];
d0 = (32u)t;
d0 <<= 4;
d0 += 1665;
d0 *= 80635;
d0 >>= 28;
d0 &= 0xf;
j += 1;
t = a[(int)j];
t = a[2 * i + 1];
d1 = (32u)t;
d1 <<= 4;
d1 += 1665;
Expand All @@ -201,9 +199,8 @@ fn _i_poly_compress(reg ptr u8[MLKEM_POLYCOMPRESSEDBYTES] rp, reg ptr u16[MLKEM_
d1 &= 0xf;
d1 <<= 4;
d0 |= d1;
rp[(int) i] = d0;
rp[i] = d0;
i += 1;
j += 1;
}
return rp, a;
}
Expand All @@ -213,10 +210,9 @@ fn _poly_decompress(reg ptr u16[MLKEM_N] rp, reg u64 ap) -> stack u16[MLKEM_N]
{
reg u8 t;
reg u16 d0, d1;
reg u64 i j;
reg u64 i;

i = 0;
j = 0;

while (i < 128) {
t = (u8)[ap+i];
Expand All @@ -230,10 +226,8 @@ fn _poly_decompress(reg ptr u16[MLKEM_N] rp, reg u64 ap) -> stack u16[MLKEM_N]
d1 += 8;
d0 >>= 4;
d1 >>= 4;
rp[(int)j] = d0;
j += 1;
rp[(int)j] = d1;
j += 1;
rp[2 * i] = d0;
rp[2 * i + 1] = d1;
i += 1;
}
return rp;
Expand Down
32 changes: 16 additions & 16 deletions proof/correctness/MLKEM_Poly.ec
Original file line number Diff line number Diff line change
Expand Up @@ -861,22 +861,22 @@ lemma i_poly_compress_corr _a :
res{1}.`1 = res{2} /\
pos_bound256_cxq res{1}.`2 0 256 1].
proc => /=.
seq 3 3 : (#{/~a{1}}pre /\ to_uint i{1} = i{2} /\ i{2} = 0 /\
to_uint j{1} = j{2} /\ j{2} = 0 /\
seq 2 3 : (#{/~a{1}}pre /\ to_uint i{1} = i{2} /\ i{2} = 0 /\
j{2} = 0 /\
pos_bound256_cxq a{1} 0 256 1 /\ lift_array256 a{1} = _a).
wp => /=;call{1} (poly_csubq_corr _a); 1: by auto => /#.

while (#{~i{2}=0}{~j{2}=0}pre /\ to_uint i{1} = i{2} /\ 0<=i{2}<=128 /\
to_uint j{1} = j{2} /\ j{2} = 2* i{2} /\
j{2} = 2* i{2} /\
forall k, 0<=k<i{2} => rp{1}.[k] = r{2}.[k]); last first.
+ auto => /> &1 &2; rewrite ultE of_uintK /= =>
bnd ??; split; 1: by smt().
move => i' j' ra'; rewrite ultE of_uintK /= => rr exit _ ibl ibh jv prev.
bnd ?; split; 1: by smt().
move => i' j' ra'; rewrite ultE of_uintK /= => rr exit ibl ibh prev.
by rewrite tP => k kb /#.

auto => /> &1 &2 ??; rewrite /pos_bound256_cxq /=.
rewrite ultE of_uintK /= => ?????.
rewrite !to_uintD_small /=; 1..3: smt().
rewrite ultE of_uintK /= => ????.
rewrite !to_uintD_small ?to_uintM_small /= 1..4:/#.
do split; 1..3: by smt().
+ move => k kbl kbh.
case (k = to_uint i{1}); last first.
Expand All @@ -892,9 +892,9 @@ do split; 1..3: by smt().
case (k = to_uint i{1}); last by smt(Array128.set_neqiE).
move => iv; have -> : 15 = 2^4 - 1 by auto.
rewrite !and_mod //.
pose x := (((zeroextu32 a{1}.[to_uint j{1}] `<<` (of_int 4)%W8) +
pose x := (((zeroextu32 a{1}.[2 * to_uint i{1}] `<<` (of_int 4)%W8) +
(of_int 1665)%W32) * (of_int 80635)%W32 `>>`(of_int 28)%W8).
pose y := (((zeroextu32 a{1}.[to_uint j{1} + 1] `<<` (of_int 4)%W8) +
pose y := (((zeroextu32 a{1}.[2 * to_uint i{1} + 1] `<<` (of_int 4)%W8) +
(of_int 1665)%W32) * (of_int 80635)%W32 `>>` (of_int 28)%W8).
rewrite to_uint_eq to_uint_truncateu8 !of_uintK to_uint_orw_disjoint.
+ apply W32.ext_eq => i ib; rewrite /W32.(`&`) map2E initiE //=.
Expand Down Expand Up @@ -938,24 +938,24 @@ lemma poly_decompress_corr mem _p (_a : W8.t Array128.t):
lift_array256 res{1} = decompress_poly 4 res{2} /\
pos_bound256_cxq res{1} 0 256 1 ].
proc.
while(#pre /\ to_uint i{1} = i{2} /\ to_uint j{1} = 2*i{2} /\ 0 <= i{2} <= 128 /\
while(#pre /\ to_uint i{1} = i{2} /\ 0 <= i{2} <= 128 /\
forall k, 0<=k<i{2}*2 => incoeff (to_sint rp{1}.[k]) = decompress 4 r{2}.[k] /\
0 <= to_sint rp{1}.[k] <q); last first.
+ auto => /> &1; rewrite /lift_array256 /decompress_poly /pos_bound256_cxq /=.
move => vrl vrh; split; 1: by smt().
move => i j rp rr; rewrite ultE of_uintK /= => exit _ ibl ibh jv prior; split; 2: by smt().
move => i rp rr; rewrite ultE of_uintK /= => exit _ ibl ibh prior; split; 2: by smt().
by rewrite tP => x xb; rewrite !mapiE //= /#.

auto => /> &1 &2 vpl vph il ih jv prior; rewrite ultE of_uintK /= => enter _;
auto => /> &1 &2 vpl vph il ih prior; rewrite ultE of_uintK /= => enter _;
split; 2: by rewrite ultE of_uintK /= to_uintD_small /#.
rewrite !to_uintD_small /=; 1..4: by smt().
do split; 1..3: by smt().
rewrite !to_uintD_small ?to_uintM_small /= 1..5:/#.
do split; 1..2: by smt().
move => k kl kh.
rewrite /(`>>`) /= /load_array32 /=.

split;last first.
+ case(k < to_uint i{1} * 2); 1: by smt(Array256.set_neqiE).
case(k = to_uint j{1}).
case(k = 2 * to_uint i{1}).
+ move => ->;rewrite set_neqiE;1,2:smt().
rewrite set_eqiE;1,2:smt().
move => _; have pow : W16.of_int 15 = W16.of_int (2^4-1) by smt().
Expand All @@ -975,7 +975,7 @@ split;last first.
by rewrite /smod /=; smt(W8.to_uint_cmp pow2_8).

case(k < to_uint i{1} * 2); 1: by smt(Array256.set_neqiE).
case(k = to_uint j{1}).
case(k = 2 * to_uint i{1}).
+ move => ->;rewrite set_neqiE;1,2:smt().
rewrite set_eqiE;1,2:smt().
rewrite set_neqiE;1,2:smt().
Expand Down

0 comments on commit 9827d66

Please sign in to comment.