Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ref: poly (de)compress: use a single loop counter #23

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 8 additions & 20 deletions code/jasmin/mlkem_ref/extraction/jkem.ec
Original file line number Diff line number Diff line change
Expand Up @@ -816,23 +816,21 @@ module M(SC:Syscall_t) = {
proc _poly_compress (rp:W64.t, a:W16.t Array256.t) : W16.t Array256.t = {

var i:W64.t;
var t:W16.t;
var d0:W32.t;
var d1:W32.t;

a <@ _poly_csubq (a);
i <- (W64.of_int 0);

while ((i \ult (W64.of_int 128))) {
t <- a.[(W64.to_uint ((W64.of_int 2) * i))];
d0 <- (zeroextu32 t);
d0 <- (zeroextu32 a.[(W64.to_uint ((W64.of_int 2) * i))]);
d1 <-
(zeroextu32 a.[(W64.to_uint (((W64.of_int 2) * i) + (W64.of_int 1)))]);
d0 <- (d0 `<<` (W8.of_int 4));
d0 <- (d0 + (W32.of_int 1665));
d0 <- (d0 * (W32.of_int 80635));
d0 <- (d0 `>>` (W8.of_int 28));
d0 <- (d0 `&` (W32.of_int 15));
t <- a.[(W64.to_uint (((W64.of_int 2) * i) + (W64.of_int 1)))];
d1 <- (zeroextu32 t);
d1 <- (d1 `<<` (W8.of_int 4));
d1 <- (d1 + (W32.of_int 1665));
d1 <- (d1 * (W32.of_int 80635));
Expand All @@ -850,26 +848,21 @@ module M(SC:Syscall_t) = {
W16.t Array256.t = {

var i:W64.t;
var j:W64.t;
var t:W16.t;
var d0:W32.t;
var d1:W32.t;

a <@ _poly_csubq (a);
i <- (W64.of_int 0);
j <- (W64.of_int 0);

while ((i \ult (W64.of_int 128))) {
t <- a.[(W64.to_uint j)];
d0 <- (zeroextu32 t);
d0 <- (zeroextu32 a.[(W64.to_uint ((W64.of_int 2) * i))]);
d1 <-
(zeroextu32 a.[(W64.to_uint (((W64.of_int 2) * i) + (W64.of_int 1)))]);
d0 <- (d0 `<<` (W8.of_int 4));
d0 <- (d0 + (W32.of_int 1665));
d0 <- (d0 * (W32.of_int 80635));
d0 <- (d0 `>>` (W8.of_int 28));
d0 <- (d0 `&` (W32.of_int 15));
j <- (j + (W64.of_int 1));
t <- a.[(W64.to_uint j)];
d1 <- (zeroextu32 t);
d1 <- (d1 `<<` (W8.of_int 4));
d1 <- (d1 + (W32.of_int 1665));
d1 <- (d1 * (W32.of_int 80635));
Expand All @@ -879,21 +872,18 @@ module M(SC:Syscall_t) = {
d0 <- (d0 `|` d1);
rp.[(W64.to_uint i)] <- (truncateu8 d0);
i <- (i + (W64.of_int 1));
j <- (j + (W64.of_int 1));
}
return (rp, a);
}

proc _poly_decompress (rp:W16.t Array256.t, ap:W64.t) : W16.t Array256.t = {

var i:W64.t;
var j:W64.t;
var t:W8.t;
var d0:W16.t;
var d1:W16.t;

i <- (W64.of_int 0);
j <- (W64.of_int 0);

while ((i \ult (W64.of_int 128))) {
t <- (loadW8 Glob.mem (W64.to_uint (ap + i)));
Expand All @@ -907,10 +897,8 @@ module M(SC:Syscall_t) = {
d1 <- (d1 + (W16.of_int 8));
d0 <- (d0 `>>` (W8.of_int 4));
d1 <- (d1 `>>` (W8.of_int 4));
rp.[(W64.to_uint j)] <- d0;
j <- (j + (W64.of_int 1));
rp.[(W64.to_uint j)] <- d1;
j <- (j + (W64.of_int 1));
rp.[(W64.to_uint ((W64.of_int 2) * i))] <- d0;
rp.[(W64.to_uint (((W64.of_int 2) * i) + (W64.of_int 1)))] <- d1;
i <- (i + (W64.of_int 1));
}
return (rp);
Expand Down
30 changes: 9 additions & 21 deletions code/jasmin/mlkem_ref/poly.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ fn __poly_reduce(reg ptr u16[MLKEM_N] rp) -> reg ptr u16[MLKEM_N]

fn _poly_compress(reg u64 rp, reg ptr u16[MLKEM_N] a) -> reg ptr u16[MLKEM_N]
{
reg u16 t;
reg u32 d0, d1;
reg u64 i;

Expand All @@ -150,15 +149,13 @@ fn _poly_compress(reg u64 rp, reg ptr u16[MLKEM_N] a) -> reg ptr u16[MLKEM_N]
i = 0;
while(i < 128)
{
t = a[2 * i];
d0 = (32u)t;
d0 = (32u) a[2 * i];
d1 = (32u) a[2 * i + 1];
d0 <<= 4;
d0 += 1665;
d0 *= 80635;
d0 >>= 28;
d0 &= 0xf;
t = a[2 * i + 1];
d1 = (32u)t;
d1 <<= 4;
d1 += 1665;
d1 *= 80635;
Expand All @@ -174,36 +171,30 @@ fn _poly_compress(reg u64 rp, reg ptr u16[MLKEM_N] a) -> reg ptr u16[MLKEM_N]

fn _i_poly_compress(reg ptr u8[MLKEM_POLYCOMPRESSEDBYTES] rp, reg ptr u16[MLKEM_N] a) -> reg ptr u8[MLKEM_POLYCOMPRESSEDBYTES], reg ptr u16[MLKEM_N]
{
reg u16 t;
reg u32 d0, d1;
reg u64 i j;
reg u64 i;

a = _poly_csubq(a);

i = 0;
j = 0;
while(i < 128)
{
t = a[(int)j];
d0 = (32u)t;
d0 = (32u) a[2 * i];
d1 = (32u) a[2 * i + 1];
d0 <<= 4;
d0 += 1665;
d0 *= 80635;
d0 >>= 28;
d0 &= 0xf;
j += 1;
t = a[(int)j];
d1 = (32u)t;
d1 <<= 4;
d1 += 1665;
d1 *= 80635;
d1 >>= 28;
d1 &= 0xf;
d1 <<= 4;
d0 |= d1;
rp[(int) i] = d0;
rp[i] = d0;
i += 1;
j += 1;
}
return rp, a;
}
Expand All @@ -213,10 +204,9 @@ fn _poly_decompress(reg ptr u16[MLKEM_N] rp, reg u64 ap) -> stack u16[MLKEM_N]
{
reg u8 t;
reg u16 d0, d1;
reg u64 i j;
reg u64 i;

i = 0;
j = 0;

while (i < 128) {
t = (u8)[ap+i];
Expand All @@ -230,10 +220,8 @@ fn _poly_decompress(reg ptr u16[MLKEM_N] rp, reg u64 ap) -> stack u16[MLKEM_N]
d1 += 8;
d0 >>= 4;
d1 >>= 4;
rp[(int)j] = d0;
j += 1;
rp[(int)j] = d1;
j += 1;
rp[2 * i] = d0;
rp[2 * i + 1] = d1;
i += 1;
}
return rp;
Expand Down
32 changes: 16 additions & 16 deletions proof/correctness/MLKEM_Poly.ec
Original file line number Diff line number Diff line change
Expand Up @@ -861,22 +861,22 @@ lemma i_poly_compress_corr _a :
res{1}.`1 = res{2} /\
pos_bound256_cxq res{1}.`2 0 256 1].
proc => /=.
seq 3 3 : (#{/~a{1}}pre /\ to_uint i{1} = i{2} /\ i{2} = 0 /\
to_uint j{1} = j{2} /\ j{2} = 0 /\
seq 2 3 : (#{/~a{1}}pre /\ to_uint i{1} = i{2} /\ i{2} = 0 /\
j{2} = 0 /\
pos_bound256_cxq a{1} 0 256 1 /\ lift_array256 a{1} = _a).
wp => /=;call{1} (poly_csubq_corr _a); 1: by auto => /#.

while (#{~i{2}=0}{~j{2}=0}pre /\ to_uint i{1} = i{2} /\ 0<=i{2}<=128 /\
to_uint j{1} = j{2} /\ j{2} = 2* i{2} /\
j{2} = 2* i{2} /\
forall k, 0<=k<i{2} => rp{1}.[k] = r{2}.[k]); last first.
+ auto => /> &1 &2; rewrite ultE of_uintK /= =>
bnd ??; split; 1: by smt().
move => i' j' ra'; rewrite ultE of_uintK /= => rr exit _ ibl ibh jv prev.
bnd ?; split; 1: by smt().
move => i' j' ra'; rewrite ultE of_uintK /= => rr exit ibl ibh prev.
by rewrite tP => k kb /#.

auto => /> &1 &2 ??; rewrite /pos_bound256_cxq /=.
rewrite ultE of_uintK /= => ?????.
rewrite !to_uintD_small /=; 1..3: smt().
rewrite ultE of_uintK /= => ????.
rewrite !to_uintD_small ?to_uintM_small /= 1..4:/#.
do split; 1..3: by smt().
+ move => k kbl kbh.
case (k = to_uint i{1}); last first.
Expand All @@ -892,9 +892,9 @@ do split; 1..3: by smt().
case (k = to_uint i{1}); last by smt(Array128.set_neqiE).
move => iv; have -> : 15 = 2^4 - 1 by auto.
rewrite !and_mod //.
pose x := (((zeroextu32 a{1}.[to_uint j{1}] `<<` (of_int 4)%W8) +
pose x := (((zeroextu32 a{1}.[2 * to_uint i{1}] `<<` (of_int 4)%W8) +
(of_int 1665)%W32) * (of_int 80635)%W32 `>>`(of_int 28)%W8).
pose y := (((zeroextu32 a{1}.[to_uint j{1} + 1] `<<` (of_int 4)%W8) +
pose y := (((zeroextu32 a{1}.[2 * to_uint i{1} + 1] `<<` (of_int 4)%W8) +
(of_int 1665)%W32) * (of_int 80635)%W32 `>>` (of_int 28)%W8).
rewrite to_uint_eq to_uint_truncateu8 !of_uintK to_uint_orw_disjoint.
+ apply W32.ext_eq => i ib; rewrite /W32.(`&`) map2E initiE //=.
Expand Down Expand Up @@ -938,24 +938,24 @@ lemma poly_decompress_corr mem _p (_a : W8.t Array128.t):
lift_array256 res{1} = decompress_poly 4 res{2} /\
pos_bound256_cxq res{1} 0 256 1 ].
proc.
while(#pre /\ to_uint i{1} = i{2} /\ to_uint j{1} = 2*i{2} /\ 0 <= i{2} <= 128 /\
while(#pre /\ to_uint i{1} = i{2} /\ 0 <= i{2} <= 128 /\
forall k, 0<=k<i{2}*2 => incoeff (to_sint rp{1}.[k]) = decompress 4 r{2}.[k] /\
0 <= to_sint rp{1}.[k] <q); last first.
+ auto => /> &1; rewrite /lift_array256 /decompress_poly /pos_bound256_cxq /=.
move => vrl vrh; split; 1: by smt().
move => i j rp rr; rewrite ultE of_uintK /= => exit _ ibl ibh jv prior; split; 2: by smt().
move => i rp rr; rewrite ultE of_uintK /= => exit _ ibl ibh prior; split; 2: by smt().
by rewrite tP => x xb; rewrite !mapiE //= /#.

auto => /> &1 &2 vpl vph il ih jv prior; rewrite ultE of_uintK /= => enter _;
auto => /> &1 &2 vpl vph il ih prior; rewrite ultE of_uintK /= => enter _;
split; 2: by rewrite ultE of_uintK /= to_uintD_small /#.
rewrite !to_uintD_small /=; 1..4: by smt().
do split; 1..3: by smt().
rewrite !to_uintD_small ?to_uintM_small /= 1..5:/#.
do split; 1..2: by smt().
move => k kl kh.
rewrite /(`>>`) /= /load_array32 /=.

split;last first.
+ case(k < to_uint i{1} * 2); 1: by smt(Array256.set_neqiE).
case(k = to_uint j{1}).
case(k = 2 * to_uint i{1}).
+ move => ->;rewrite set_neqiE;1,2:smt().
rewrite set_eqiE;1,2:smt().
move => _; have pow : W16.of_int 15 = W16.of_int (2^4-1) by smt().
Expand All @@ -975,7 +975,7 @@ split;last first.
by rewrite /smod /=; smt(W8.to_uint_cmp pow2_8).

case(k < to_uint i{1} * 2); 1: by smt(Array256.set_neqiE).
case(k = to_uint j{1}).
case(k = 2 * to_uint i{1}).
+ move => ->;rewrite set_neqiE;1,2:smt().
rewrite set_eqiE;1,2:smt().
rewrite set_neqiE;1,2:smt().
Expand Down