Skip to content

Commit

Permalink
gen-matrix: use a slice instead of two counters (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbgl authored Feb 6, 2024
1 parent b2358e2 commit cf9bdf4
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 51 deletions.
12 changes: 8 additions & 4 deletions code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -4415,12 +4415,13 @@ module M(SC:Syscall_t) = {
var buf:W8.t Array168.t;
var poly:W16.t Array256.t;
var k:W64.t;
var l:W64.t;
var rij:W16.t Array256.t;
var t:W16.t;
buf <- witness;
extseed <- witness;
poly <- witness;
r <- witness;
rij <- witness;
state <- witness;
stransposed <- transposed;
j <- 0;
Expand Down Expand Up @@ -4451,14 +4452,17 @@ module M(SC:Syscall_t) = {
(ctr, poly) <@ __rej_uniform (poly, ctr, buf);
}
k <- (W64.of_int 0);
l <- (W64.of_int ((i * (3 * 256)) + (j * 256)));
rij <-
(Array256.init (fun i_0 => r.[((i * (3 * 256)) + (j * 256)) + i_0]));

while ((k \ult (W64.of_int 256))) {
t <- poly.[(W64.to_uint k)];
r.[(W64.to_uint l)] <- t;
rij.[(W64.to_uint k)] <- t;
k <- (k + (W64.of_int 1));
l <- (l + (W64.of_int 1));
}
r <- Array2304.init
(fun i_0 => if ((i * (3 * 256)) + (j * 256)) <= i_0 < ((i * (3 * 256)) + (j * 256)) + 256
then rij.[i_0-((i * (3 * 256)) + (j * 256))] else r.[i_0]);
j <- j + 1;
}
i <- i + 1;
Expand Down
9 changes: 5 additions & 4 deletions code/jasmin/mlkem_avx2/gen_matrix.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[

reg u8 c;
reg u16 t;
reg u64 ctr k l;
reg u64 ctr k;
stack u64 sctr;
stack u64 stransposed;
inline int j i;
Expand Down Expand Up @@ -114,14 +114,15 @@ fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[
}

k = 0;
l = i * MLKEM_VECN + j * MLKEM_N;
reg ptr u16[MLKEM_N] rij;
rij = r[i * MLKEM_VECN + j * MLKEM_N : MLKEM_N];
while (k < MLKEM_N)
{
t = poly[(int) k];
r[(int) l] = t;
rij[k] = t;
k += 1;
l += 1;
}
r[i * MLKEM_VECN + j * MLKEM_N : MLKEM_N] = rij;
}
}

Expand Down
12 changes: 8 additions & 4 deletions code/jasmin/mlkem_ref/extraction/jkem.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1682,12 +1682,13 @@ module M(SC:Syscall_t) = {
var buf:W8.t Array168.t;
var poly:W16.t Array256.t;
var k:W64.t;
var l:W64.t;
var rij:W16.t Array256.t;
var t:W16.t;
buf <- witness;
extseed <- witness;
poly <- witness;
r <- witness;
rij <- witness;
state <- witness;
stransposed <- transposed;
j <- 0;
Expand Down Expand Up @@ -1718,14 +1719,17 @@ module M(SC:Syscall_t) = {
(ctr, poly) <@ __rej_uniform (poly, ctr, buf);
}
k <- (W64.of_int 0);
l <- (W64.of_int ((i * (3 * 256)) + (j * 256)));
rij <-
(Array256.init (fun i_0 => r.[((i * (3 * 256)) + (j * 256)) + i_0]));

while ((k \ult (W64.of_int 256))) {
t <- poly.[(W64.to_uint k)];
r.[(W64.to_uint l)] <- t;
rij.[(W64.to_uint k)] <- t;
k <- (k + (W64.of_int 1));
l <- (l + (W64.of_int 1));
}
r <- Array2304.init
(fun i_0 => if ((i * (3 * 256)) + (j * 256)) <= i_0 < ((i * (3 * 256)) + (j * 256)) + 256
then rij.[i_0-((i * (3 * 256)) + (j * 256))] else r.[i_0]);
j <- j + 1;
}
i <- i + 1;
Expand Down
9 changes: 5 additions & 4 deletions code/jasmin/mlkem_ref/gen_matrix.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[

reg u8 c;
reg u16 t;
reg u64 ctr k l;
reg u64 ctr k;
stack u64 sctr;
stack u64 stransposed;
inline int j i;
Expand Down Expand Up @@ -112,14 +112,15 @@ fn __gen_matrix(stack u8[MLKEM_SYMBYTES] seed, reg u64 transposed) -> stack u16[
}

k = 0;
l = i * MLKEM_VECN + j * MLKEM_N;
reg ptr u16[MLKEM_N] rij;
rij = r[i * MLKEM_VECN + j * MLKEM_N : MLKEM_N];
while (k < MLKEM_N)
{
t = poly[(int) k];
r[(int) l] = t;
rij[k] = t;
k += 1;
l += 1;
}
r[i * MLKEM_VECN + j * MLKEM_N : MLKEM_N] = rij;
}
}
return r;
Expand Down
56 changes: 22 additions & 34 deletions proof/correctness/MLKEM_InnerPKE.ec
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ proc => /=.
inline Parse(XOF).sample.
inline Jkem.M(Jkem.Syscall).__rej_uniform.

seq 6 1: (={seed} /\ stransposed{1} = (if trans{2} then W64.one else W64.zero)); 1: by auto.
seq 7 1: (={seed} /\ stransposed{1} = (if trans{2} then W64.one else W64.zero)); 1: by auto.

seq 2 0 : (#pre /\ forall k, 0<=k<32 => extseed {1}.[k] = seed{2}.[k]).
+ conseq => />; 1: smt().
Expand Down Expand Up @@ -491,39 +491,27 @@ seq 5 4 : (#pre /\ lift_array256 poly{1} = aa{2} /\
forall k, 0 <= k < 256 =>
bpos16 poly{1}.[k] q); last first.

+ wp;conseq />; 1: smt().
while{1} (
0 <= to_uint k{1} <= 256 /\ to_uint l{1} = i{2} * 768 + j{2} * 256 + to_uint k{1} /\
0 <= j{2} < 3 /\ 0 <= i{2} < 3 /\
(forall (k0 : int), 0 <= k0 < i{2} * 768 + j{2} * 256 =>
(a{2}.[k0 %/ 768, k0 %% 768 %/ 256])%Matrix.[k0 %% 256] = incoeff (to_sint r{1}.[k0]) /\
0 <= to_sint r{1}.[k0] < q) /\
(forall (k0 : int),
i{2} * 768 + j{2} * 256 <= k0 < i{2} * 768 + j{2} * 256 + to_uint k{1} =>
r{1}.[k0] = poly{1}.[k0 %% 256] /\
0 <= to_sint r{1}.[k0] < q) /\
(forall k, 0 <= k < 256 => bpos16 poly{1}.[k] q)
) (256 - to_uint k{1}); last first.

+ auto => /> &1 &2 [#] 8?; rewrite /lift_array256 => ? k; do split; 2:smt().
+ by rewrite of_uintK /= modz_small /#.
move => kl ll rl; rewrite ultE /=; split; 1: smt().
move => *;do split; 1,2: smt().
move => k0 k0bl k0bh;rewrite setmE /= offunmE /= 1:/#.
case (k0 < i{2} * 768 + j{2} * 256); 1: smt().
move => *.
have -> : k0 %/ 768 = i{2} by smt().
have -> /= : k0 %% 768 %/ 256 = j{2} by smt().
by rewrite mapiE /#.

move => &2 z; auto => /> &1; rewrite !ultE /= => ?????.
rewrite to_uintD_small /= 1:/# => *; do split;1,2,6:smt().
+ by rewrite to_uintD_small /#.
+ by move => kk kkbl kkbh; rewrite set_neqiE /#.
+ move => kk kkbl kkbh.
case (kk < i{2} * 768 + j{2} * 256 + to_uint k{1}).
+ by move => *; rewrite set_neqiE /#.
by move => *; rewrite Array2304.set_eqiE /#.
+ wp; conseq />.
* move => &1 &2 [] [] [] -> [] ? [] [] -> -> [] ? [] ? [] ? ? ? [] ? ? ? [] [] ? ? H /=.
do split => //; 1,2: smt().
by move => ??; rewrite -H /#.
while{1}
(0 <= to_uint k{1} <= 256
/\ forall (k0 : int), 0 <= k0 < to_uint k{1} => rij{1}.[k0] = poly{1}.[k0])
(256 - to_uint k{1}).
* auto => /> 4?.
rewrite ultE /= to_uintD_small 1:/# to_uint_small //.
smt(Array256.get_setE).
auto => /> &1 &2 [#] 9? hpoly; split; 1: smt().
move => rij kl; rewrite ultE /=; split; 1: smt().
move => hkl 2? heq; split; 1: smt().
move => k ??.
rewrite initiE 1:/# /=.
rewrite /lift_array256 setmE /= offunmE /= 1:/#.
have -> /= : k < i{2} * 768 + j{2} * 256 + 256 by smt().
have -> : (k %/ 768 = i{2} /\ k %% 768 %/ 256 = j{2}) <=> i{2} * 768 + j{2} * 256 <= k by smt().
case: (i{2} * 768 + j{2} * 256 <= k); 2: smt().
by rewrite mapiE /#.

conseq />; 1: by smt().

Expand Down
2 changes: 1 addition & 1 deletion proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ symmetry.
have H : equiv [
Jkem.M(Jkem.Syscall).__gen_matrix ~ Jkem_avx2.M(Jkem_avx2.Syscall).__gen_matrix :
={arg} ==> res{2} = nttunpackm res{1}].
proc. seq 10 10 : (={r}).
proc. seq 11 11 : (={r}).
sim (M(Syscall)._shake128_absorb34 ~ Jkem_avx2.M(Jkem_avx2.Syscall)._shake128_absorb34 : true)
(M(Syscall)._shake128_squeezeblock ~ Jkem_avx2.M(Jkem_avx2.Syscall)._shake128_squeezeblock : true)
(Jkem_avx2.M(Jkem_avx2.Syscall).__rej_uniform ~
Expand Down

0 comments on commit cf9bdf4

Please sign in to comment.