Skip to content

Commit

Permalink
poly_basemul: compute the zetas counter from i (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbgl authored Feb 7, 2024
1 parent cf9bdf4 commit 82da326
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 21 deletions.
12 changes: 6 additions & 6 deletions code/jasmin/mlkem_ref/extraction/jkem.ec
Original file line number Diff line number Diff line change
Expand Up @@ -733,9 +733,9 @@ module M(SC:Syscall_t) = {
bp:W16.t Array256.t) : W16.t Array256.t = {

var srp:W16.t Array256.t;
var zetasctr:W64.t;
var i:W64.t;
var zetasp:W16.t Array128.t;
var zetasp:W16.t Array64.t;
var zetasctr:W64.t;
var zeta_0:W16.t;
var a0:W16.t;
var b0:W16.t;
Expand All @@ -747,13 +747,13 @@ module M(SC:Syscall_t) = {
srp <- witness;
zetasp <- witness;
srp <- rp;
zetasctr <- (W64.of_int 64);
i <- (W64.of_int 0);

while ((i \ult (W64.of_int 256))) {
zetasp <- jzetas;
while ((i \ult (W64.of_int (256 - 3)))) {
zetasp <- (Array64.init (fun i_0 => jzetas.[64 + i_0]));
zetasctr <- i;
zetasctr <- (zetasctr `>>` (W8.of_int 2));
zeta_0 <- zetasp.[(W64.to_uint zetasctr)];
zetasctr <- (zetasctr + (W64.of_int 1));
a0 <- ap.[(W64.to_uint i)];
b0 <- bp.[(W64.to_uint i)];
i <- (i + (W64.of_int 1));
Expand Down
12 changes: 6 additions & 6 deletions code/jasmin/mlkem_ref/poly.jinc
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,22 @@ fn _poly_basemul(reg ptr u16[MLKEM_N] rp, reg const ptr u16[MLKEM_N] ap bp) -> r
reg u16 b0;
reg u16 b1;
reg u16 t;
reg ptr u16[128] zetasp;
reg ptr u16[64] zetasp;
reg u64 zetasctr;
reg u64 i;

stack ptr u16[MLKEM_N] srp;

srp = rp;

zetasctr = 64;
i = 0;

while(i < MLKEM_N)
while(i < MLKEM_N - 3)
{
zetasp = jzetas;
zeta = zetasp[(int)zetasctr];
zetasctr += 1;
zetasp = jzetas[64:64];
zetasctr = i;
zetasctr >>= 2;
zeta = zetasp[zetasctr];

a0 = ap[(int)i];
b0 = bp[(int)i];
Expand Down
18 changes: 9 additions & 9 deletions proof/correctness/MLKEM_Poly.ec
Original file line number Diff line number Diff line change
Expand Up @@ -1675,7 +1675,6 @@ proc.
seq 2 : #pre; first by auto => />.
while (#pre /\ srp = rp /\
0<= to_uint i <= 256 /\ to_uint i %% 4 = 0 /\
to_uint zetasctr = to_uint i %/ 4 + 64 /\
signed_bound_cxq rp 0 (to_uint i) 3 /\
isbasemul _ap _bp zetas (lift_array256 rp) (to_uint i));
last by auto => /> &hr ??; split; [ smt() | move => ??;rewrite ultE /#].
Expand All @@ -1691,18 +1690,19 @@ wp;ecall(fqmul_corr_h (to_sint a0) (to_sint b0)).
wp;ecall(fqmul_corr_h (to_sint r0) (to_sint zeta_0)).
wp;ecall(fqmul_corr_h (to_sint a1) (to_sint b1)).
auto => /> &hr; rewrite /signed_bound_cxq /isbasemul /lift_array256 !ultE /=.
move => ba bb ??? zctl bprev vprev entry.
move => ba bb ??? bprev vprev entry.
have -> /= : i{hr} + W64.one - W64.one = i{hr} by ring.
have -> /= : i{hr} + W64.of_int 3 - W64.one = i{hr} + W64.of_int 2 by ring.
rewrite !to_uintD_small /=;1..5:smt();rewrite !zctl.
rewrite !to_uintD_small /=; 1..4:smt().
rewrite !W64.shr_div /= Array64.Array64.initiE /=; 1: smt().
move => r1 r1val r2 r2val r3 r3val r4 r4val r5 r5val r6 r6val r7 r7val r8 r8val r9 r9val r10 r10val.

have /= [#] redbl1 redbh1 redv1 :=
(SREDCp_corr (to_sint ap{hr}.[to_uint i{hr} + 1] * to_sint bp{hr}.[to_uint i{hr} + 1]) _ _); 1,2:
by rewrite /R /=; smt().

have /= [#] redbl2 redbh2 redv2 :=
(SREDCp_corr (to_sint r1 * to_sint jzetas.[to_uint i{hr} %/ 4 + 64]) _ _); 1,2:
(SREDCp_corr (to_sint r1 * to_sint jzetas.[64 + to_uint i{hr} %/ 4]) _ _); 1,2:
by rewrite /R /=; move : zeta_bound; rewrite /minimum_residues; smt().

have /= [#] redbl3 redbh3 redv3 :=
Expand All @@ -1722,7 +1722,7 @@ have /= [#] redbl6 redbh6 redv6 :=
by rewrite /R /=; smt().

have /= [#] redbl7 redbh7 redv7 :=
(SREDCp_corr (to_sint r6 * to_sint (- jzetas.[to_uint i{hr} %/ 4 + 64])) _ _); 1: by rewrite /R /=; smt().
(SREDCp_corr (to_sint r6 * to_sint (- jzetas.[64 + to_uint i{hr} %/ 4])) _ _); 1: by rewrite /R /=; smt().
+ rewrite /R /=; move : zeta_bound; rewrite /minimum_residues /bpos16 => zb.
rewrite to_sintN /=; do split; smt().

Expand Down Expand Up @@ -1752,14 +1752,14 @@ rewrite -r9val in redbl9;rewrite -r9val in redbh9;rewrite -r9val eq_incoeff in r
rewrite -r10val in redbl10;rewrite -r10val in redbh10;rewrite -r10val eq_incoeff in redv10.

move : zetas_montE; rewrite /array_mont /lift_array128 tP => zv.
move : (zv (to_uint i{hr} %/ 4 + 64) _); 1: smt().
move : (zv (64 + to_uint i{hr} %/ 4) _); 1: smt().
rewrite !mapiE /=; 1,2: smt() => <-.
move => zv1.
move : (zv (to_uint i{hr} %/ 4) _); 1: smt().
rewrite !mapiE /=; 1,2: smt() => <-.
move => zv2.

do split; 1..4: smt().
do split; 1..3: smt().

+ move => k kb; case (k < to_uint i{hr}).
+ by move => *; rewrite !set_neqiE /#.
Expand Down Expand Up @@ -1800,7 +1800,7 @@ do split.
rewrite set_eqiE; 1,2: smt().
rewrite to_sintD_small; 1: smt().
rewrite incoeffD redv2 redv3 !incoeffM redv1 !incoeffM -zv1 /zetas !initiE /=; 1,2: smt().
rewrite kval.
rewrite kval (IntID.addrC 64).
move : (zetavals1 (to_uint i{hr}) _ _); 1,2: smt().
rewrite /zetas /= initiE 1:/# /= => ->.
rewrite ZqField.mulrA -ZqField.mulrA rrinvcoeff.
Expand All @@ -1817,7 +1817,7 @@ rewrite set_neqiE; 1,2: smt().
rewrite set_eqiE; 1,2: smt().
rewrite to_sintD_small; 1: smt().
rewrite incoeffD redv7 redv8 redv6 -zv1 /zetas !initiE /=; 1,2: smt().
rewrite !incoeffM kval.
rewrite !incoeffM kval (IntID.addrC 64).
move : (zetavals1 (to_uint i{hr}) _ _); 1,2: smt().
rewrite /zetas /= initiE 1:/# /= => ->.
congr;congr;rewrite -ZqRing.mulrA -(ZqRing.mulrC _ (incoeff 169));congr.
Expand Down

0 comments on commit 82da326

Please sign in to comment.