diff --git a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec index 6896d69b..918d34b5 100644 --- a/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec +++ b/code/jasmin/mlkem_avx2/extraction/jkem_avx2.ec @@ -137,17 +137,17 @@ W64.of_int (-9223372036854775808); W64.of_int (-9223372036854775808); W64.of_int (-9223372036854775808)]. -abbrev KECCAK_RC = Array24.of_list witness [W64.of_int 1; W64.of_int 32898; -W64.of_int (-9223372036854742902); W64.of_int (-9223372034707259392); -W64.of_int 32907; W64.of_int 2147483649; W64.of_int (-9223372034707259263); -W64.of_int (-9223372036854743031); W64.of_int 138; W64.of_int 136; -W64.of_int 2147516425; W64.of_int 2147483658; W64.of_int 2147516555; -W64.of_int (-9223372036854775669); W64.of_int (-9223372036854742903); -W64.of_int (-9223372036854743037); W64.of_int (-9223372036854743038); -W64.of_int (-9223372036854775680); W64.of_int 32778; -W64.of_int (-9223372034707292150); W64.of_int (-9223372034707259263); -W64.of_int (-9223372036854742912); W64.of_int 2147483649; -W64.of_int (-9223372034707259384)]. +abbrev KECCAK1600_RC = Array24.of_list witness [W64.of_int 1; +W64.of_int 32898; W64.of_int (-9223372036854742902); +W64.of_int (-9223372034707259392); W64.of_int 32907; W64.of_int 2147483649; +W64.of_int (-9223372034707259263); W64.of_int (-9223372036854743031); +W64.of_int 138; W64.of_int 136; W64.of_int 2147516425; W64.of_int 2147483658; +W64.of_int 2147516555; W64.of_int (-9223372036854775669); +W64.of_int (-9223372036854742903); W64.of_int (-9223372036854743037); +W64.of_int (-9223372036854743038); W64.of_int (-9223372036854775680); +W64.of_int 32778; W64.of_int (-9223372034707292150); +W64.of_int (-9223372034707259263); W64.of_int (-9223372036854742912); +W64.of_int 2147483649; W64.of_int (-9223372034707259384)]. abbrev jdmontx16 = Array16.of_list witness [W16.of_int 1353; W16.of_int 1353; @@ -737,7 +737,7 @@ module M(SC:Syscall_t) = { return (rd); } - proc __index (x:int, y:int) : int = { + proc keccakf1600_index (x:int, y:int) : int = { var r:int; @@ -745,7 +745,7 @@ module M(SC:Syscall_t) = { return (r); } - proc __keccak_rho_offsets (i:int) : int = { + proc keccakf1600_rho_offsets (i:int) : int = { var aux: int; var r:int; @@ -772,175 +772,191 @@ module M(SC:Syscall_t) = { return (r); } - proc __rhotates (x:int, y:int) : int = { + proc keccakf1600_rhotates (x:int, y:int) : int = { var r:int; var i:int; - i <@ __index (x, y); - r <@ __keccak_rho_offsets (i); + i <@ keccakf1600_index (x, y); + r <@ keccakf1600_rho_offsets (i); return (r); } - proc __theta_sum_scalar (a:W64.t Array25.t) : W64.t Array5.t = { + proc keccakf1600_theta_sum (a:W64.t Array25.t) : W64.t Array5.t = { var aux: int; var c:W64.t Array5.t; - var i:int; - var ti:int; - var j:int; + var x:int; + var y:int; c <- witness; - i <- 0; - while (i < 5) { - ti <@ __index (i, 0); - c.[i] <- a.[ti]; - i <- i + 1; + x <- 0; + while (x < 5) { + c.[x] <- a.[(x + 0)]; + x <- x + 1; } - j <- 1; - while (j < 5) { - i <- 0; - while (i < 5) { - ti <@ __index (i, j); - c.[i] <- (c.[i] `^` a.[ti]); - i <- i + 1; + y <- 1; + while (y < 5) { + x <- 0; + while (x < 5) { + c.[x] <- (c.[x] `^` a.[(x + (y * 5))]); + x <- x + 1; } - j <- j + 1; + y <- y + 1; } return (c); } - proc __theta_rol_scalar (c:W64.t Array5.t) : W64.t Array5.t = { + proc keccakf1600_theta_rol (c:W64.t Array5.t) : W64.t Array5.t = { var aux_1: bool; var aux_0: bool; var aux: int; var aux_2: W64.t; var d:W64.t Array5.t; - var i:int; + var x:int; var _0:bool; var _1:bool; d <- witness; - i <- 0; - while (i < 5) { - d.[i] <- c.[((i + 1) %% 5)]; - (aux_1, aux_0, aux_2) <- ROL_64 d.[i] (W8.of_int 1); + x <- 0; + while (x < 5) { + d.[x] <- c.[((x + 1) %% 5)]; + (aux_1, aux_0, aux_2) <- ROL_64 d.[x] (W8.of_int 1); _0 <- aux_1; _1 <- aux_0; - d.[i] <- aux_2; - d.[i] <- (d.[i] `^` c.[((i + 4) %% 5)]); - i <- i + 1; + d.[x] <- aux_2; + d.[x] <- (d.[x] `^` c.[(((x - 1) + 5) %% 5)]); + x <- x + 1; } return (d); } - proc __rol_sum_scalar (d:W64.t Array5.t, a:W64.t Array25.t, offset:int) : + proc keccakf1600_rol_sum (a:W64.t Array25.t, d:W64.t Array5.t, y:int) : W64.t Array5.t = { var aux_1: bool; var aux_0: bool; var aux: int; var aux_2: W64.t; - var c:W64.t Array5.t; - var j:int; - var j1:int; - var k:int; - var ti:int; + var b:W64.t Array5.t; + var x:int; + var x_:int; + var y_:int; + var r:int; var _0:bool; var _1:bool; - c <- witness; - j <- 0; - while (j < 5) { - j1 <- ((j + offset) %% 5); - k <@ __rhotates (j1, j); - ti <@ __index (j1, j); - c.[j] <- a.[ti]; - c.[j] <- (c.[j] `^` d.[j1]); - (aux_1, aux_0, aux_2) <- ROL_64 c.[j] (W8.of_int k); - _0 <- aux_1; - _1 <- aux_0; - c.[j] <- aux_2; - j <- j + 1; + b <- witness; + x <- 0; + while (x < 5) { + x_ <- ((x + (3 * y)) %% 5); + y_ <- x; + r <@ keccakf1600_rhotates (x_, y_); + b.[x] <- a.[(x_ + (y_ * 5))]; + b.[x] <- (b.[x] `^` d.[x_]); + if ((r <> 0)) { + (aux_1, aux_0, aux_2) <- ROL_64 b.[x] (W8.of_int r); + _0 <- aux_1; + _1 <- aux_0; + b.[x] <- aux_2; + } else { + + } + x <- x + 1; } - return (c); + return (b); } - proc __set_row_scalar (r:W64.t Array25.t, row:int, c:W64.t Array5.t, - iota_0:W64.t) : W64.t Array25.t = { + proc keccakf1600_set_row (e:W64.t Array25.t, b:W64.t Array5.t, y:int, + s_rc:W64.t) : W64.t Array25.t = { var aux: int; - var j:int; - var j1:int; - var j2:int; + var x:int; + var x1:int; + var x2:int; var t:W64.t; - var ti:int; - j <- 0; - while (j < 5) { - j1 <- ((j + 1) %% 5); - j2 <- ((j + 2) %% 5); - t <- ((invw c.[j1]) `&` c.[j2]); - if (((row = 0) /\ (j = 0))) { - t <- (t `^` iota_0); + x <- 0; + while (x < 5) { + x1 <- ((x + 1) %% 5); + x2 <- ((x + 2) %% 5); + t <- ((invw b.[x1]) `&` b.[x2]); + t <- (t `^` b.[x]); + if (((x = 0) /\ (y = 0))) { + t <- (t `^` s_rc); } else { } - t <- (t `^` c.[j]); - ti <@ __index (j, row); - r.[ti] <- t; - j <- j + 1; + e.[(x + (y * 5))] <- t; + x <- x + 1; } - return (r); + return (e); } - proc __round2x_scalar (a:W64.t Array25.t, r:W64.t Array25.t, iota_0:W64.t) : - W64.t Array25.t * W64.t Array25.t = { + proc keccakf1600_round (e:W64.t Array25.t, a:W64.t Array25.t, rc:W64.t) : + W64.t Array25.t = { + var aux: int; + var s_rc:W64.t; var c:W64.t Array5.t; var d:W64.t Array5.t; + var y:int; + var b:W64.t Array5.t; + b <- witness; c <- witness; d <- witness; - c <@ __theta_sum_scalar (a); - d <@ __theta_rol_scalar (c); - c <@ __rol_sum_scalar (d, a, 0); - r <@ __set_row_scalar (r, 0, c, iota_0); - c <@ __rol_sum_scalar (d, a, 3); - r <@ __set_row_scalar (r, 1, c, iota_0); - c <@ __rol_sum_scalar (d, a, 1); - r <@ __set_row_scalar (r, 2, c, iota_0); - c <@ __rol_sum_scalar (d, a, 4); - r <@ __set_row_scalar (r, 3, c, iota_0); - c <@ __rol_sum_scalar (d, a, 2); - r <@ __set_row_scalar (r, 4, c, iota_0); - return (a, r); + s_rc <- rc; + c <@ keccakf1600_theta_sum (a); + d <@ keccakf1600_theta_rol (c); + y <- 0; + while (y < 5) { + b <@ keccakf1600_rol_sum (a, d, y); + e <@ keccakf1600_set_row (e, b, y, s_rc); + y <- y + 1; + } + return (e); } - proc _keccakf1600_scalar (a:W64.t Array25.t) : W64.t Array25.t = { - - var iotas_p:W64.t Array24.t; - var round:W64.t; - var iota_0:W64.t; - var round_s:W64.t; - var r:W64.t Array25.t; - iotas_p <- witness; - r <- witness; - iotas_p <- KECCAK_RC; - round <- (W64.of_int 0); - - while ((round \ult (W64.of_int 24))) { - iota_0 <- iotas_p.[(W64.to_uint round)]; - round_s <- round; - (a, r) <@ __round2x_scalar (a, r, iota_0); - round <- round_s; - round <- (round + (W64.of_int 1)); - iota_0 <- iotas_p.[(W64.to_uint round)]; - round_s <- round; - (r, a) <@ __round2x_scalar (r, a, iotas_p.[(W64.to_uint round)]); - round <- round_s; - round <- (round + (W64.of_int 1)); + proc __keccakf1600 (a:W64.t Array25.t) : W64.t Array25.t = { + + var rC:W64.t Array24.t; + var s_e:W64.t Array25.t; + var e:W64.t Array25.t; + var c:W64.t; + var rc:W64.t; + rC <- witness; + e <- witness; + s_e <- witness; + rC <- KECCAK1600_RC; + e <- s_e; + c <- (W64.of_int 0); + + while ((c \ult (W64.of_int 24))) { + rc <- rC.[(W64.to_uint c)]; + e <@ keccakf1600_round (e, a, rc); + rc <- rC.[((W64.to_uint c) + 1)]; + a <@ keccakf1600_round (a, e, rc); + c <- (c + (W64.of_int 2)); } return (a); } + proc _keccakf1600 (a:W64.t Array25.t) : W64.t Array25.t = { + + + + a <@ __keccakf1600 (a); + return (a); + } + + proc _keccakf1600_ (a:W64.t Array25.t) : W64.t Array25.t = { + + + + a <- a; + a <@ _keccakf1600 (a); + a <- a; + return (a); + } + proc __st0 (state:W64.t Array25.t) : W64.t Array25.t = { var aux: int; @@ -1040,14 +1056,14 @@ module M(SC:Syscall_t) = { s_in <- in_0; s_ilen <- ilen; s_r8 <- r8; - state <@ _keccakf1600_scalar (state); + state <@ _keccakf1600_ (state); in_0 <- s_in; ilen <- s_ilen; r8 <- s_r8; } t8 <- (W8.of_int 6); state <@ __add_final_block (state, in_0, ilen, t8, r8); - state <@ _keccakf1600_scalar (state); + state <@ _keccakf1600_ (state); out <- s_out; i <- 0; while (i < 4) { @@ -1093,7 +1109,7 @@ module M(SC:Syscall_t) = { i <- i + 1; } s_in <- in1; - state <@ _keccakf1600_scalar (state); + state <@ _keccakf1600_ (state); r8 <- (W64.of_int 136); ilen <- (W64.of_int (((3 * 320) + 128) - (136 - 32))); in_0 <- s_in; @@ -1104,14 +1120,14 @@ module M(SC:Syscall_t) = { s_in <- in_0; s_ilen <- ilen; s_r8 <- r8; - state <@ _keccakf1600_scalar (state); + state <@ _keccakf1600_ (state); in_0 <- s_in; ilen <- s_ilen; r8 <- s_r8; } t8 <- (W8.of_int 31); state <@ __add_final_block (state, in_0, ilen, t8, r8); - state <@ _keccakf1600_scalar (state); + state <@ _keccakf1600_ (state); out <- s_out; aux <- (32 %/ 8); i <- 0; @@ -1149,7 +1165,7 @@ module M(SC:Syscall_t) = { (WArray200.get64 (WArray200.set8 (WArray200.init64 (fun i_0 => state.[i_0])) (72 - 1) (( (get8 (WArray200.init64 (fun i_0 => state.[i_0])) (72 - 1)) `^` (W8.of_int 128))))); out_s <- out; - state <@ _keccakf1600_scalar (state); + state <@ _keccakf1600_ (state); out <- out_s; i <- 0; while (i < 8) { @@ -1187,7 +1203,7 @@ module M(SC:Syscall_t) = { (WArray200.get64 (WArray200.set8 (WArray200.init64 (fun i_0 => state.[i_0])) (72 - 1) (( (get8 (WArray200.init64 (fun i_0 => state.[i_0])) (72 - 1)) `^` (W8.of_int 128))))); out_s <- out; - state <@ _keccakf1600_scalar (state); + state <@ _keccakf1600_ (state); out <- out_s; i <- 0; while (i < 8) { @@ -1240,7 +1256,7 @@ module M(SC:Syscall_t) = { var t:W64.t; out_s <- witness; out_s <- out; - state <@ _keccakf1600_scalar (state); + state <@ _keccakf1600_ (state); out <- out_s; aux <- (168 %/ 8); i <- 0; diff --git a/code/jasmin/mlkem_avx2/fips202.jinc b/code/jasmin/mlkem_avx2/fips202.jinc index 8fce754e..178c5c02 100644 --- a/code/jasmin/mlkem_avx2/fips202.jinc +++ b/code/jasmin/mlkem_avx2/fips202.jinc @@ -162,7 +162,7 @@ fn __keccak1600_scalar( s_inlen = inlen; s_rate = rate; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); inlen = s_inlen; in = s_in; @@ -180,7 +180,7 @@ fn __keccak1600_scalar( s_outlen = outlen; s_rate = rate; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = s_out; outlen = s_outlen; @@ -191,7 +191,7 @@ fn __keccak1600_scalar( s_out = out; } - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = s_out; outlen = s_outlen; @@ -267,7 +267,7 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] s_ilen = ilen; s_r8 = r8; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); in = s_in; ilen = s_ilen; @@ -277,7 +277,7 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] t8 = 0x06; state = __add_final_block(state, in, ilen, t8, r8); - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = s_out; @@ -315,7 +315,7 @@ fn __isha3_512(reg ptr u8[64] out, reg u64 in, inline int inlen) -> stack u8[64] s_ilen = ilen; s_r8 = r8; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); in = s_in; ilen = s_ilen; @@ -325,7 +325,7 @@ fn __isha3_512(reg ptr u8[64] out, reg u64 in, inline int inlen) -> stack u8[64] t8 = 0x06; state = __add_final_block(state, in, ilen, t8, r8); - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = s_out; @@ -361,7 +361,7 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { s_in = in1; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); r8 = SHAKE256_RATE; ilen = MLKEM_INDCPA_CIPHERTEXTBYTES - (SHAKE256_RATE - MLKEM_SYMBYTES); @@ -376,7 +376,7 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { s_ilen = ilen; s_r8 = r8; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); in = s_in; ilen = s_ilen; @@ -386,7 +386,7 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { t8 = 0x1f; state = __add_final_block(state, in, ilen, t8, r8); - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = s_out; @@ -422,7 +422,7 @@ fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[12 state[u8 33] ^= 0x1f; state[u8 SHAKE256_RATE-1] ^= 0x80; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = sout; @@ -455,7 +455,7 @@ fn _isha3_256_32(reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u state[u8 MLKEM_SYMBYTES] ^= 0x06; state[u8 SHA3_256_RATE - 1] = 0x80; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = s_out; @@ -489,7 +489,7 @@ fn _sha3_512_64(reg ptr u8[64] out, reg const ptr u8[64] in) -> stack u8[64] out_s = out; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = out_s; @@ -523,7 +523,7 @@ fn _sha3_512_32(reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64] out_s = out; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = out_s; @@ -567,7 +567,7 @@ fn _shake128_squeezeblock(reg ptr u64[25] state, reg ptr u8[SHAKE128_RATE] out) inline int i; out_s = out; - state = _keccakf1600_scalar(state); + state = _keccakf1600_(state); out = out_s; for i = 0 to SHAKE128_RATE/8 diff --git a/code/jasmin/mlkem_avx2/keccakf1600.jinc b/code/jasmin/mlkem_avx2/keccakf1600.jinc index 02996b6a..c5aa5e62 100644 --- a/code/jasmin/mlkem_avx2/keccakf1600.jinc +++ b/code/jasmin/mlkem_avx2/keccakf1600.jinc @@ -1,4 +1,6 @@ -u64[24] KECCAK_RC = +param int KECCAK_ROUNDS = 24; + +u64[24] KECCAK1600_RC = { 0x0000000000000001 ,0x0000000000008082 ,0x800000000000808a @@ -25,14 +27,14 @@ u64[24] KECCAK_RC = ,0x8000000080008008 }; -inline fn __index(inline int x y) -> inline int +inline fn keccakf1600_index(inline int x y) -> inline int { inline int r; r = (x % 5) + 5 * (y % 5); return r; } -inline fn __keccak_rho_offsets(inline int i) -> inline int +inline fn keccakf1600_rho_offsets(inline int i) -> inline int { inline int r x y z t; @@ -40,10 +42,9 @@ inline fn __keccak_rho_offsets(inline int i) -> inline int x = 1; y = 0; - for t = 0 to 24 { - if (i == x + 5 * y) { - r = ((t + 1) * (t + 2) / 2) % 64; - } + for t = 0 to 24 + { if (i == x + 5 * y) + { r = ((t + 1) * (t + 2) / 2) % 64; } z = (2 * x + 3 * y) % 5; x = y; y = z; @@ -52,143 +53,178 @@ inline fn __keccak_rho_offsets(inline int i) -> inline int return r; } -inline fn __rhotates(inline int x y) -> inline int +inline fn keccakf1600_rhotates(inline int x y) -> inline int { inline int i r; - i = __index(x, y); - r = __keccak_rho_offsets(i); + i = keccakf1600_index(x, y); + r = keccakf1600_rho_offsets(i); return r; } -inline fn __theta_sum_scalar(reg ptr u64[25] a) -> reg u64[5] +// C[x] = A[x,0] ^ A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] +inline fn keccakf1600_theta_sum(reg ptr u64[25] a) -> reg u64[5] { - inline int i j ti; + inline int x y; reg u64[5] c; - for i=0 to 5 - { - ti = __index(i, 0); - c[i] = a[ti]; - } + // C[x] = A[x, 0] + for x=0 to 5 + { c[x] = a[x + 0]; } - for j=1 to 5 - { for i=0 to 5 - { - ti = __index(i, j); - c[i] ^= a[ti]; - } + // C[x] ^= A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] + for y=1 to 5 + { for x=0 to 5 + { c[x] ^= a[x + y*5]; } } return c; } -inline fn __theta_rol_scalar(reg u64[5] c) -> reg u64[5] +// D[x] = C[x-1] ^ ROT(C[x+1], 1) +inline fn keccakf1600_theta_rol(reg u64[5] c) -> reg u64[5] { - inline int i; + inline int x; reg u64[5] d; - for i = 0 to 5 - { d[i] = c[(i+1)%5]; - _, _, d[i] = #ROL_64(d[i], 1); - d[i] ^= c[(i+4)%5]; + for x = 0 to 5 + { // D[x] = C[x + 1] + d[x] = c[(x + 1) % 5]; + + // D[x] = ROT(D[x], 1) + _, _, d[x] = #ROL_64(d[x], 1); + + // D[x] ^= C[x-1] + d[x] ^= c[(x - 1 + 5) % 5]; } return d; } -inline fn __rol_sum_scalar( - reg u64[5] d, +// B[x] = ROT( (A[x',y'] ^ D[x']), r[x',y'] ) with (x',y') = M^-1 (x,y) +// +// M = (0 1) M^-1 = (1 3) x' = 1x + 3y +// (2 3) (1 0) y' = 1x + 0y +// +inline fn keccakf1600_rol_sum( reg ptr u64[25] a, - inline int offset -) -> reg u64[5] + reg u64[5] d, + inline int y) + -> + reg u64[5] { - inline int j j1 k ti; - reg u64[5] c; + inline int r x x_ y_; + reg u64[5] b; - for j = 0 to 5 + for x = 0 to 5 { - j1 = (j+offset) % 5; - k = __rhotates(j1, j); - ti = __index(j1, j); - c[j] = a[ti]; - c[j] ^= d[j1]; - _, _, c[j] = #ROL_64(c[j], k); + x_ = (x + 3*y) % 5; + y_ = x; + r = keccakf1600_rhotates(x_, y_); + + // B[x] = A[x',y'] + b[x] = a[x_ + y_*5]; + + // B[x] ^= D[x']; + b[x] ^= d[x_]; + + // B[x] = ROT( B[x], r[x',y'] ); + if(r != 0) + { _, _, b[x] = #ROL_64(b[x], r); } + } - return c; + return b; } -inline fn __set_row_scalar( - reg ptr u64[25] r, - inline int row, - reg u64[5] c, - reg u64 iota -) -> reg ptr u64[25] +// E[x, y] = B[x] ^ ( (!B[x+1]) & B[x+2] ) +// -- when x and y are 0: E[0,0] ^= RC[i]; +inline fn keccakf1600_set_row( + reg ptr u64[25] e, + reg u64[5] b, + inline int y, + stack u64 s_rc) + -> + reg ptr u64[25] { - inline int j j1 j2 ti; + inline int x x1 x2; reg u64 t; - for j= 0 to 5 + for x=0 to 5 { - j1 = (j+1) % 5; - j2 = (j+2) % 5; - t = !c[j1] & c[j2]; - if row==0 && j==0 { t ^= iota; } - t ^= c[j]; - ti = __index(j, row); - r[ti] = t; + x1 = (x + 1) % 5; + x2 = (x + 2) % 5; + + t = !b[x1] & b[x2]; // bmi1 + //t = b[x1]; t = !t; t &= b[x2]; + + t ^= b[x]; + if( x==0 && y==0 ){ t ^= s_rc; } + e[x + y*5] = t; } - return r; + return e; } -inline fn __round2x_scalar(reg ptr u64[25] a r, reg u64 iota) -> reg ptr u64[25], reg ptr u64[25] +inline fn keccakf1600_round( + reg ptr u64[25] e, + reg ptr u64[25] a, + reg u64 rc) + -> + reg ptr u64[25] { - reg u64[5] c d; - - c = __theta_sum_scalar(a); - d = __theta_rol_scalar(c); - c = __rol_sum_scalar(d, a, 0); - r = __set_row_scalar(r, 0, c, iota); - c = __rol_sum_scalar(d, a, 3); - r = __set_row_scalar(r, 1, c, iota); - c = __rol_sum_scalar(d, a, 1); - r = __set_row_scalar(r, 2, c, iota); - c = __rol_sum_scalar(d, a, 4); - r = __set_row_scalar(r, 3, c, iota); - c = __rol_sum_scalar(d, a, 2); - r = __set_row_scalar(r, 4, c, iota); - - return a, r; + inline int y; + reg u64[5] b c d; + stack u64 s_rc; + + s_rc = rc; + + c = keccakf1600_theta_sum(a); + d = keccakf1600_theta_rol(c); + + for y = 0 to 5 + { b = keccakf1600_rol_sum(a, d, y); + e = keccakf1600_set_row(e, b, y, s_rc); + } + + return e; } -#[returnaddress="stack"] -fn _keccakf1600_scalar(reg ptr u64[25] a) -> reg ptr u64[25] +inline fn __keccakf1600(reg ptr u64[25] a) -> reg ptr u64[25] { - stack u64[25] r; - reg ptr u64[24] iotas_p; - reg u64 iota; - reg u64 round; - stack u64 round_s; + reg ptr u64[24] RC; + stack u64[25] s_e; + reg ptr u64[25] e; - iotas_p = KECCAK_RC; + reg u64 c rc; - round = 0; + RC = KECCAK1600_RC; + e = s_e; - while(round < 24) + c = 0; + while (c < KECCAK_ROUNDS) { - iota = iotas_p[(int) round]; - round_s = round; - a, r = __round2x_scalar(a, r, iota); - round = round_s; - round += 1; - - iota = iotas_p[(int) round]; - round_s = round; - r, a = __round2x_scalar(r, a, iotas_p[(int) round]); - round = round_s; - round += 1; + rc = RC[(int) c]; + e = keccakf1600_round(e, a, rc); + + rc = RC[(int) c + 1]; + a = keccakf1600_round(a, e, rc); + + c += 2; } return a; } + +fn _keccakf1600(reg ptr u64[25] a) -> reg ptr u64[25] +{ + a = __keccakf1600(a); + return a; +} + +inline fn _keccakf1600_(reg ptr u64[25] a) -> reg ptr u64[25] +{ + a = a; + a = _keccakf1600(a); + a = a; + return a; +} diff --git a/code/jasmin/mlkem_ref/extraction/jkem.ec b/code/jasmin/mlkem_ref/extraction/jkem.ec index c3ce1bda..08acfd2a 100644 --- a/code/jasmin/mlkem_ref/extraction/jkem.ec +++ b/code/jasmin/mlkem_ref/extraction/jkem.ec @@ -79,7 +79,7 @@ W16.of_int 996; W16.of_int 991; W16.of_int 958; W16.of_int 1869; W16.of_int 1522; W16.of_int 1628]. -abbrev roundconstants = Array24.of_list witness [W64.of_int 1; +abbrev KECCAK1600_RC = Array24.of_list witness [W64.of_int 1; W64.of_int 32898; W64.of_int (-9223372036854742902); W64.of_int (-9223372034707259392); W64.of_int 32907; W64.of_int 2147483649; W64.of_int (-9223372034707259263); W64.of_int (-9223372036854743031); @@ -148,7 +148,7 @@ module M(SC:Syscall_t) = { return (r); } - proc __index (x:int, y:int) : int = { + proc keccakf1600_index (x:int, y:int) : int = { var r:int; @@ -156,53 +156,7 @@ module M(SC:Syscall_t) = { return (r); } - proc __theta (a:W64.t Array25.t) : W64.t Array25.t = { - var aux_1: bool; - var aux_0: bool; - var aux: int; - var aux_2: W64.t; - - var x:int; - var c:W64.t Array5.t; - var y:int; - var d:W64.t Array5.t; - var _0:bool; - var _1:bool; - c <- witness; - d <- witness; - x <- 0; - while (x < 5) { - c.[x] <- (W64.of_int 0); - y <- 0; - while (y < 5) { - c.[x] <- (c.[x] `^` a.[(x + (5 * y))]); - y <- y + 1; - } - x <- x + 1; - } - x <- 0; - while (x < 5) { - d.[x] <- c.[((x + 1) %% 5)]; - (aux_1, aux_0, aux_2) <- ROL_64 d.[x] (W8.of_int 1); - _0 <- aux_1; - _1 <- aux_0; - d.[x] <- aux_2; - d.[x] <- (d.[x] `^` c.[((x + 4) %% 5)]); - x <- x + 1; - } - x <- 0; - while (x < 5) { - y <- 0; - while (y < 5) { - a.[(x + (5 * y))] <- (a.[(x + (5 * y))] `^` d.[x]); - y <- y + 1; - } - x <- x + 1; - } - return (a); - } - - proc __keccakRhoOffsets (i:int) : int = { + proc keccakf1600_rho_offsets (i:int) : int = { var aux: int; var r:int; @@ -229,124 +183,191 @@ module M(SC:Syscall_t) = { return (r); } - proc __rho (a:W64.t Array25.t) : W64.t Array25.t = { + proc keccakf1600_rhotates (x:int, y:int) : int = { + + var r:int; + var i:int; + + i <@ keccakf1600_index (x, y); + r <@ keccakf1600_rho_offsets (i); + return (r); + } + + proc keccakf1600_theta_sum (a:W64.t Array25.t) : W64.t Array5.t = { + var aux: int; + + var c:W64.t Array5.t; + var x:int; + var y:int; + c <- witness; + x <- 0; + while (x < 5) { + c.[x] <- a.[(x + 0)]; + x <- x + 1; + } + y <- 1; + while (y < 5) { + x <- 0; + while (x < 5) { + c.[x] <- (c.[x] `^` a.[(x + (y * 5))]); + x <- x + 1; + } + y <- y + 1; + } + return (c); + } + + proc keccakf1600_theta_rol (c:W64.t Array5.t) : W64.t Array5.t = { var aux_1: bool; var aux_0: bool; var aux: int; var aux_2: W64.t; + var d:W64.t Array5.t; var x:int; - var y:int; - var i:int; - var z:int; var _0:bool; var _1:bool; + d <- witness; + x <- 0; + while (x < 5) { + d.[x] <- c.[((x + 1) %% 5)]; + (aux_1, aux_0, aux_2) <- ROL_64 d.[x] (W8.of_int 1); + _0 <- aux_1; + _1 <- aux_0; + d.[x] <- aux_2; + d.[x] <- (d.[x] `^` c.[(((x - 1) + 5) %% 5)]); + x <- x + 1; + } + return (d); + } + + proc keccakf1600_rol_sum (a:W64.t Array25.t, d:W64.t Array5.t, y:int) : + W64.t Array5.t = { + var aux_1: bool; + var aux_0: bool; + var aux: int; + var aux_2: W64.t; + var b:W64.t Array5.t; + var x:int; + var x_:int; + var y_:int; + var r:int; + var _0:bool; + var _1:bool; + b <- witness; x <- 0; while (x < 5) { - y <- 0; - while (y < 5) { - i <@ __index (x, y); - z <@ __keccakRhoOffsets (i); - (aux_1, aux_0, aux_2) <- ROL_64 a.[i] (W8.of_int z); + x_ <- ((x + (3 * y)) %% 5); + y_ <- x; + r <@ keccakf1600_rhotates (x_, y_); + b.[x] <- a.[(x_ + (y_ * 5))]; + b.[x] <- (b.[x] `^` d.[x_]); + if ((r <> 0)) { + (aux_1, aux_0, aux_2) <- ROL_64 b.[x] (W8.of_int r); _0 <- aux_1; _1 <- aux_0; - a.[i] <- aux_2; - y <- y + 1; + b.[x] <- aux_2; + } else { + } x <- x + 1; } - return (a); + return (b); } - proc __pi (a:W64.t Array25.t) : W64.t Array25.t = { + proc keccakf1600_set_row (e:W64.t Array25.t, b:W64.t Array5.t, y:int, + s_rc:W64.t) : W64.t Array25.t = { var aux: int; - var i:int; - var t:W64.t; - var b:W64.t Array25.t; - var y:int; var x:int; - b <- witness; - i <- 0; - while (i < 25) { - t <- a.[i]; - b.[i] <- t; - i <- i + 1; - } + var x1:int; + var x2:int; + var t:W64.t; + x <- 0; while (x < 5) { - y <- 0; - while (y < 5) { - t <- b.[(x + (5 * y))]; - i <@ __index (y, ((2 * x) + (3 * y))); - a.[i] <- t; - y <- y + 1; + x1 <- ((x + 1) %% 5); + x2 <- ((x + 2) %% 5); + t <- b.[x1]; + t <- (invw t); + t <- (t `&` b.[x2]); + t <- (t `^` b.[x]); + if (((x = 0) /\ (y = 0))) { + t <- (t `^` s_rc); + } else { + } + e.[(x + (y * 5))] <- t; x <- x + 1; } - return (a); + return (e); } - proc __chi (a:W64.t Array25.t) : W64.t Array25.t = { + proc keccakf1600_round (e:W64.t Array25.t, a:W64.t Array25.t, rc:W64.t) : + W64.t Array25.t = { var aux: int; - var x:int; - var y:int; - var i:int; + var s_rc:W64.t; var c:W64.t Array5.t; + var d:W64.t Array5.t; + var y:int; + var b:W64.t Array5.t; + b <- witness; c <- witness; + d <- witness; + s_rc <- rc; + c <@ keccakf1600_theta_sum (a); + d <@ keccakf1600_theta_rol (c); y <- 0; while (y < 5) { - x <- 0; - while (x < 5) { - i <@ __index ((x + 1), y); - c.[x] <- a.[i]; - c.[x] <- (invw c.[x]); - i <@ __index ((x + 2), y); - c.[x] <- (c.[x] `&` a.[i]); - i <@ __index (x, y); - c.[x] <- (c.[x] `^` a.[i]); - x <- x + 1; - } - x <- 0; - while (x < 5) { - a.[(x + (5 * y))] <- c.[x]; - x <- x + 1; - } + b <@ keccakf1600_rol_sum (a, d, y); + e <@ keccakf1600_set_row (e, b, y, s_rc); y <- y + 1; } + return (e); + } + + proc __keccakf1600 (a:W64.t Array25.t) : W64.t Array25.t = { + + var rC:W64.t Array24.t; + var s_e:W64.t Array25.t; + var e:W64.t Array25.t; + var c:W64.t; + var rc:W64.t; + rC <- witness; + e <- witness; + s_e <- witness; + rC <- KECCAK1600_RC; + e <- s_e; + c <- (W64.of_int 0); + + while ((c \ult (W64.of_int 24))) { + rc <- rC.[(W64.to_uint c)]; + e <@ keccakf1600_round (e, a, rc); + rc <- rC.[((W64.to_uint c) + 1)]; + a <@ keccakf1600_round (a, e, rc); + c <- (c + (W64.of_int 2)); + } return (a); } - proc __iota (a:W64.t Array25.t, c:W64.t) : W64.t Array25.t = { + proc _keccakf1600 (a:W64.t Array25.t) : W64.t Array25.t = { - a.[0] <- (a.[0] `^` c); + a <@ __keccakf1600 (a); return (a); } - proc __keccakf1600_ref (state:W64.t Array25.t) : W64.t Array25.t = { - - var constptr:W64.t Array24.t; - var rctr:W64.t; - constptr <- witness; - constptr <- roundconstants; - rctr <- (W64.of_int 0); - - while ((rctr \ult (W64.of_int 192))) { - state <@ __theta (state); - state <@ __rho (state); - state <@ __pi (state); - state <@ __chi (state); - constptr <- roundconstants; - state <@ __iota (state, - (get64_direct (WArray192.init64 (fun i => constptr.[i])) - (W64.to_uint rctr))); - rctr <- (rctr + (W64.of_int 8)); - } - return (state); + proc _keccakf1600_ (a:W64.t Array25.t) : W64.t Array25.t = { + + + + a <- a; + a <@ _keccakf1600 (a); + a <- a; + return (a); } proc __st0 (state:W64.t Array25.t) : W64.t Array25.t = { @@ -452,7 +473,7 @@ module M(SC:Syscall_t) = { Array25.init (WArray200.get64 (WArray200.set8 (WArray200.init64 (fun i_0 => state.[i_0])) (136 - 1) (( (get8 (WArray200.init64 (fun i_0 => state.[i_0])) (136 - 1)) `^` (W8.of_int 128))))); - state <@ __keccakf1600_ref (state); + state <@ _keccakf1600_ (state); out <- sout; i <- 0; while (i < 128) { @@ -467,6 +488,7 @@ module M(SC:Syscall_t) = { var aux: int; var s_out:W64.t; + var s_in1:W64.t; var state:W64.t Array25.t; var i:int; var t64:W64.t; @@ -479,6 +501,7 @@ module M(SC:Syscall_t) = { var t8:W8.t; state <- witness; s_out <- out; + s_in1 <- in1; state <@ __st0 (state); aux <- (32 %/ 8); i <- 0; @@ -495,10 +518,10 @@ module M(SC:Syscall_t) = { state.[i] <- (state.[i] `^` t64); i <- i + 1; } - state <@ __keccakf1600_ref (state); + state <@ _keccakf1600_ (state); r8 <- (W64.of_int 136); ilen <- (W64.of_int (((3 * 320) + 128) - (136 - 32))); - in_0 <- in1; + in_0 <- s_in1; in_0 <- (in_0 + (W64.of_int (136 - 32))); while ((r8 \ule ilen)) { @@ -506,14 +529,14 @@ module M(SC:Syscall_t) = { s_in <- in_0; s_ilen <- ilen; s_r8 <- r8; - state <@ __keccakf1600_ref (state); + state <@ _keccakf1600_ (state); in_0 <- s_in; ilen <- s_ilen; r8 <- s_r8; } t8 <- (W8.of_int 31); state <@ __add_final_block (state, in_0, ilen, t8, r8); - state <@ __keccakf1600_ref (state); + state <@ _keccakf1600_ (state); out <- s_out; aux <- (32 %/ 8); i <- 0; @@ -529,10 +552,13 @@ module M(SC:Syscall_t) = { proc _sha3512_32 (out:W8.t Array64.t, in_0:W8.t Array32.t) : W8.t Array64.t = { var aux: int; + var s_out:W8.t Array64.t; var state:W64.t Array25.t; var i:int; var c:W8.t; + s_out <- witness; state <- witness; + s_out <- out; state <@ __st0 (state); i <- 0; while (i < 32) { @@ -551,7 +577,8 @@ module M(SC:Syscall_t) = { Array25.init (WArray200.get64 (WArray200.set8 (WArray200.init64 (fun i_0 => state.[i_0])) (72 - 1) (( (get8 (WArray200.init64 (fun i_0 => state.[i_0])) (72 - 1)) `^` (W8.of_int 128))))); - state <@ __keccakf1600_ref (state); + state <@ _keccakf1600_ (state); + out <- s_out; i <- 0; while (i < 64) { c <- (get8 (WArray200.init64 (fun i_0 => state.[i_0])) i); @@ -593,10 +620,13 @@ module M(SC:Syscall_t) = { W64.t Array25.t * W8.t Array168.t = { var aux: int; + var s_out:W8.t Array168.t; var i:int; var c:W8.t; - - state <@ __keccakf1600_ref (state); + s_out <- witness; + s_out <- out; + state <@ _keccakf1600_ (state); + out <- s_out; i <- 0; while (i < 168) { c <- (get8 (WArray200.init64 (fun i_0 => state.[i_0])) i); @@ -631,14 +661,14 @@ module M(SC:Syscall_t) = { s_in <- in_0; s_ilen <- ilen; s_r8 <- r8; - state <@ __keccakf1600_ref (state); + state <@ _keccakf1600_ (state); in_0 <- s_in; ilen <- s_ilen; r8 <- s_r8; } t8 <- (W8.of_int 6); state <@ __add_final_block (state, in_0, ilen, t8, r8); - state <@ __keccakf1600_ref (state); + state <@ _keccakf1600_ (state); out <- s_out; i <- 0; while (i < 4) { @@ -676,7 +706,7 @@ module M(SC:Syscall_t) = { (WArray200.get64 (WArray200.set8 (WArray200.init64 (fun i_0 => state.[i_0])) (72 - 1) (( (get8 (WArray200.init64 (fun i_0 => state.[i_0])) (72 - 1)) `^` (W8.of_int 128))))); out_s <- out; - state <@ __keccakf1600_ref (state); + state <@ _keccakf1600_ (state); out <- out_s; i <- 0; while (i < 8) { @@ -1857,6 +1887,7 @@ module M(SC:Syscall_t) = { noiseseed:W8.t Array32.t) : unit = { var aux: W16.t Array256.t; + var s_noiseseed:W8.t Array32.t; var pkpv:W16.t Array768.t; var i:W64.t; var t64:W64.t; @@ -1877,8 +1908,10 @@ module M(SC:Syscall_t) = { k <- witness; pkpv <- witness; publicseed <- witness; + s_noiseseed <- witness; sp_0 <- witness; v <- witness; + s_noiseseed <- noiseseed; pkpv <@ __polyvec_frombytes (pkp); i <- (W64.of_int 0); pkp <- (pkp + (W64.of_int (3 * 384))); @@ -1895,41 +1928,41 @@ module M(SC:Syscall_t) = { aat <@ __gen_matrix (publicseed, (W64.of_int 1)); nonce <- (W8.of_int 0); aux <@ _poly_getnoise ((Array256.init (fun i_0 => sp_0.[0 + i_0])), - noiseseed, nonce); + s_noiseseed, nonce); sp_0 <- Array768.init (fun i_0 => if 0 <= i_0 < 0 + 256 then aux.[i_0-0] else sp_0.[i_0]); nonce <- (W8.of_int 1); aux <@ _poly_getnoise ((Array256.init (fun i_0 => sp_0.[256 + i_0])), - noiseseed, nonce); + s_noiseseed, nonce); sp_0 <- Array768.init (fun i_0 => if 256 <= i_0 < 256 + 256 then aux.[i_0-256] else sp_0.[i_0]); nonce <- (W8.of_int 2); aux <@ _poly_getnoise ((Array256.init (fun i_0 => sp_0.[(2 * 256) + i_0])), - noiseseed, nonce); + s_noiseseed, nonce); sp_0 <- Array768.init (fun i_0 => if (2 * 256) <= i_0 < (2 * 256) + 256 then aux.[i_0-(2 * 256)] else sp_0.[i_0]); nonce <- (W8.of_int 3); aux <@ _poly_getnoise ((Array256.init (fun i_0 => ep.[0 + i_0])), - noiseseed, nonce); + s_noiseseed, nonce); ep <- Array768.init (fun i_0 => if 0 <= i_0 < 0 + 256 then aux.[i_0-0] else ep.[i_0]); nonce <- (W8.of_int 4); aux <@ _poly_getnoise ((Array256.init (fun i_0 => ep.[256 + i_0])), - noiseseed, nonce); + s_noiseseed, nonce); ep <- Array768.init (fun i_0 => if 256 <= i_0 < 256 + 256 then aux.[i_0-256] else ep.[i_0]); nonce <- (W8.of_int 5); aux <@ _poly_getnoise ((Array256.init (fun i_0 => ep.[(2 * 256) + i_0])), - noiseseed, nonce); + s_noiseseed, nonce); ep <- Array768.init (fun i_0 => if (2 * 256) <= i_0 < (2 * 256) + 256 then aux.[i_0-(2 * 256)] else ep.[i_0]); nonce <- (W8.of_int 6); - epp <@ _poly_getnoise (epp, noiseseed, nonce); + epp <@ _poly_getnoise (epp, s_noiseseed, nonce); sp_0 <@ __polyvec_ntt (sp_0); aux <@ __polyvec_pointwise_acc ((Array768.init (fun i_0 => aat.[0 + i_0])), sp_0); @@ -1966,6 +1999,7 @@ module M(SC:Syscall_t) = { var aux_0: W8.t Array960.t; var aux: W16.t Array256.t; + var s_noiseseed:W8.t Array32.t; var sctp:W8.t Array1088.t; var pkpv:W16.t Array768.t; var i:W64.t; @@ -1986,9 +2020,11 @@ module M(SC:Syscall_t) = { k <- witness; pkpv <- witness; publicseed <- witness; + s_noiseseed <- witness; sctp <- witness; sp_0 <- witness; v <- witness; + s_noiseseed <- noiseseed; sctp <- ctp; pkpv <@ __polyvec_frombytes (pkp); i <- (W64.of_int 0); @@ -2006,41 +2042,41 @@ module M(SC:Syscall_t) = { aat <@ __gen_matrix (publicseed, (W64.of_int 1)); nonce <- (W8.of_int 0); aux <@ _poly_getnoise ((Array256.init (fun i_0 => sp_0.[0 + i_0])), - noiseseed, nonce); + s_noiseseed, nonce); sp_0 <- Array768.init (fun i_0 => if 0 <= i_0 < 0 + 256 then aux.[i_0-0] else sp_0.[i_0]); nonce <- (W8.of_int 1); aux <@ _poly_getnoise ((Array256.init (fun i_0 => sp_0.[256 + i_0])), - noiseseed, nonce); + s_noiseseed, nonce); sp_0 <- Array768.init (fun i_0 => if 256 <= i_0 < 256 + 256 then aux.[i_0-256] else sp_0.[i_0]); nonce <- (W8.of_int 2); aux <@ _poly_getnoise ((Array256.init (fun i_0 => sp_0.[(2 * 256) + i_0])), - noiseseed, nonce); + s_noiseseed, nonce); sp_0 <- Array768.init (fun i_0 => if (2 * 256) <= i_0 < (2 * 256) + 256 then aux.[i_0-(2 * 256)] else sp_0.[i_0]); nonce <- (W8.of_int 3); aux <@ _poly_getnoise ((Array256.init (fun i_0 => ep.[0 + i_0])), - noiseseed, nonce); + s_noiseseed, nonce); ep <- Array768.init (fun i_0 => if 0 <= i_0 < 0 + 256 then aux.[i_0-0] else ep.[i_0]); nonce <- (W8.of_int 4); aux <@ _poly_getnoise ((Array256.init (fun i_0 => ep.[256 + i_0])), - noiseseed, nonce); + s_noiseseed, nonce); ep <- Array768.init (fun i_0 => if 256 <= i_0 < 256 + 256 then aux.[i_0-256] else ep.[i_0]); nonce <- (W8.of_int 5); aux <@ _poly_getnoise ((Array256.init (fun i_0 => ep.[(2 * 256) + i_0])), - noiseseed, nonce); + s_noiseseed, nonce); ep <- Array768.init (fun i_0 => if (2 * 256) <= i_0 < (2 * 256) + 256 then aux.[i_0-(2 * 256)] else ep.[i_0]); nonce <- (W8.of_int 6); - epp <@ _poly_getnoise (epp, noiseseed, nonce); + epp <@ _poly_getnoise (epp, s_noiseseed, nonce); sp_0 <@ __polyvec_ntt (sp_0); aux <@ __polyvec_pointwise_acc ((Array768.init (fun i_0 => aat.[0 + i_0])), sp_0); @@ -2275,6 +2311,7 @@ module M(SC:Syscall_t) = { var pkp:W64.t; var ctpc:W8.t Array1088.t; var cnd:W64.t; + var s_cnd:W64.t; var zp:W64.t; buf <- witness; ctpc <- witness; @@ -2304,11 +2341,14 @@ module M(SC:Syscall_t) = { pkp, (Array32.init (fun i_0 => kr.[32 + i_0]))); ctp <- s_ctp; cnd <@ __verify (ctp, ctpc); + s_cnd <- cnd; zp <- s_skp; zp <- (zp + (W64.of_int 64)); zp <- (zp + (W64.of_int (((24 * 3) * 256) `|>>` 3))); + shkp <- s_shkp; _shake256_1120_32 (shkp, zp, ctp); shkp <- s_shkp; + cnd <- s_cnd; __cmov (shkp, (Array32.init (fun i_0 => kr.[0 + i_0])), cnd); return (); } diff --git a/code/jasmin/mlkem_ref/fips202.jinc b/code/jasmin/mlkem_ref/fips202.jinc index 0ca1e83a..793fe166 100644 --- a/code/jasmin/mlkem_ref/fips202.jinc +++ b/code/jasmin/mlkem_ref/fips202.jinc @@ -3,167 +3,236 @@ param int SHAKE256_RATE = 136; param int SHA3_256_RATE = 136; param int SHA3_512_RATE = 72; -inline -fn __index(inline int x, inline int y) -> inline int { +param int KECCAK_ROUNDS = 24; + +u64[24] KECCAK1600_RC = +{ 0x0000000000000001 + ,0x0000000000008082 + ,0x800000000000808a + ,0x8000000080008000 + ,0x000000000000808b + ,0x0000000080000001 + ,0x8000000080008081 + ,0x8000000000008009 + ,0x000000000000008a + ,0x0000000000000088 + ,0x0000000080008009 + ,0x000000008000000a + ,0x000000008000808b + ,0x800000000000008b + ,0x8000000000008089 + ,0x8000000000008003 + ,0x8000000000008002 + ,0x8000000000000080 + ,0x000000000000800a + ,0x800000008000000a + ,0x8000000080008081 + ,0x8000000000008080 + ,0x0000000080000001 + ,0x8000000080008008 +}; + +inline fn keccakf1600_index(inline int x y) -> inline int +{ inline int r; r = (x % 5) + 5 * (y % 5); return r; } +inline fn keccakf1600_rho_offsets(inline int i) -> inline int +{ + inline int r x y z t; -inline -fn __ROL64(reg u64 x, inline int c) -> reg u64 { - reg u64 y; - _, _, y = #ROL_64(x, c); - return y; + r = 0; + x = 1; + y = 0; + + for t = 0 to 24 + { if (i == x + 5 * y) + { r = ((t + 1) * (t + 2) / 2) % 64; } + z = (2 * x + 3 * y) % 5; + x = y; + y = z; + } + + return r; } -inline -fn __theta(reg ptr u64[25] a) -> reg ptr u64[25] { - inline int x, y; - reg u64[5] c, d; +inline fn keccakf1600_rhotates(inline int x y) -> inline int +{ + inline int i r; + i = keccakf1600_index(x, y); + r = keccakf1600_rho_offsets(i); + return r; +} - for x = 0 to 5 { - c[x] = 0; - for y = 0 to 5 { - c[x] ^= a[x + 5 * y]; - } +// C[x] = A[x,0] ^ A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] +inline fn keccakf1600_theta_sum(reg ptr u64[25] a) -> reg u64[5] +{ + inline int x y; + reg u64[5] c; + + // C[x] = A[x, 0] + for x=0 to 5 + { c[x] = a[x + 0]; } + + // C[x] ^= A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] + for y=1 to 5 + { for x=0 to 5 + { c[x] ^= a[x + y*5]; } } - for x = 0 to 5 { - /* d[x] = __ROL64(c[(x + 1) % 5], 1); */ - /* extraction fails */ + return c; +} - /* _, _, d[x] = #ROL_64(c[(x + 1) % 5], 1);*/ - /* d[x] ^= c[(x + 4) % 5];*/ - /* does not compile */ +// D[x] = C[x-1] ^ ROT(C[x+1], 1) +inline fn keccakf1600_theta_rol(reg u64[5] c) -> reg u64[5] +{ + inline int x; + reg u64[5] d; + for x = 0 to 5 + { // D[x] = C[x + 1] d[x] = c[(x + 1) % 5]; + + // D[x] = ROT(D[x], 1) _, _, d[x] = #ROL_64(d[x], 1); - d[x] ^= c[(x + 4) % 5]; - } - for x = 0 to 5 { - for y = 0 to 5 { - a[x + 5 * y] ^= d[x]; - } + // D[x] ^= C[x-1] + d[x] ^= c[(x - 1 + 5) % 5]; } - return a; + return d; } +// B[x] = ROT( (A[x',y'] ^ D[x']), r[x',y'] ) with (x',y') = M^-1 (x,y) +// +// M = (0 1) M^-1 = (1 3) x' = 1x + 3y +// (2 3) (1 0) y' = 1x + 0y +// +inline fn keccakf1600_rol_sum( + reg ptr u64[25] a, + reg u64[5] d, + inline int y) + -> + reg u64[5] +{ + inline int r x x_ y_; + reg u64[5] b; -inline -fn __keccakRhoOffsets(inline int i) -> inline int { - inline int r, x, y, z, t; - - r = 0; - x = 1; - y = 0; - for t = 0 to 24 { - if (i == x + 5 * y) { - r = ((t + 1) * (t + 2) / 2) % 64; - } - z = (2 * x + 3 * y) % 5; - x = y; - y = z; - } + for x = 0 to 5 + { + x_ = (x + 3*y) % 5; + y_ = x; + r = keccakf1600_rhotates(x_, y_); - return r; -} + // B[x] = A[x',y'] + b[x] = a[x_ + y_*5]; + // B[x] ^= D[x']; + b[x] ^= d[x_]; -inline -fn __rho(reg ptr u64[25] a) -> reg ptr u64[25] { - inline int x, y, i, z; + // B[x] = ROT( B[x], r[x',y'] ); + if(r != 0) + { _, _, b[x] = #ROL_64(b[x], r); } - for x = 0 to 5 { - for y = 0 to 5 { - i = __index(x, y); - z = __keccakRhoOffsets(i); - _, _, a[i] = #ROL_64(a[i], z); - } } - return a; + return b; } - -inline -fn __pi(reg ptr u64[25] a) -> reg ptr u64[25] { - stack u64[25] b; +// E[x, y] = B[x] ^ ( (!B[x+1]) & B[x+2] ) +// -- when x and y are 0: E[0,0] ^= RC[i]; +inline fn keccakf1600_set_row( + reg ptr u64[25] e, + reg u64[5] b, + inline int y, + stack u64 s_rc) + -> + reg ptr u64[25] +{ + inline int x x1 x2; reg u64 t; - inline int x, y, i; - for i = 0 to 25 { t = a[i]; b[i] = t; } - for x = 0 to 5 { - for y = 0 to 5 { - t = b[x + 5 * y]; - i = __index(y, 2 * x + 3 * y); - a[i] = t; - } - } - return a; -} + for x=0 to 5 + { + x1 = (x + 1) % 5; + x2 = (x + 2) % 5; + // t = !b[x1] & b[x2]; // bmi1 + t = b[x1]; t = !t; t &= b[x2]; -inline -fn __chi(reg ptr u64[25] a) -> reg ptr u64[25] { - inline int x, y, i; - reg u64[5] c; - for y = 0 to 5 { - for x = 0 to 5 { - i = __index(x + 1, y); - c[x] = a[i]; - c[x] = !c[x]; - i = __index(x + 2, y); - c[x] &= a[i]; - i = __index(x, y); - c[x] ^= a[i]; - } - for x = 0 to 5 { - a[x + 5 * y] = c[x]; - } + t ^= b[x]; + if( x==0 && y==0 ){ t ^= s_rc; } + e[x + y*5] = t; } - return a; + + return e; } +inline fn keccakf1600_round( + reg ptr u64[25] e, + reg ptr u64[25] a, + reg u64 rc) + -> + reg ptr u64[25] +{ + inline int y; + reg u64[5] b c d; + stack u64 s_rc; + + s_rc = rc; -inline -fn __iota(reg ptr u64[25] a, reg u64 c) -> reg ptr u64[25] { - a[0] ^= c; - return a; + c = keccakf1600_theta_sum(a); + d = keccakf1600_theta_rol(c); + + for y = 0 to 5 + { b = keccakf1600_rol_sum(a, d, y); + e = keccakf1600_set_row(e, b, y, s_rc); + } + + return e; } -u64[24] roundconstants = {0x0000000000000001, 0x0000000000008082, 0x800000000000808a, 0x8000000080008000, - 0x000000000000808b, 0x0000000080000001, 0x8000000080008081, 0x8000000000008009, - 0x000000000000008a, 0x0000000000000088, 0x0000000080008009, 0x000000008000000a, - 0x000000008000808b, 0x800000000000008b, 0x8000000000008089, 0x8000000000008003, - 0x8000000000008002, 0x8000000000000080, 0x000000000000800a, 0x800000008000000a, - 0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008}; +inline fn __keccakf1600(reg ptr u64[25] a) -> reg ptr u64[25] +{ + reg ptr u64[24] RC; + stack u64[25] s_e; + reg ptr u64[25] e; + reg u64 c rc; -fn __keccakf1600_ref(reg ptr u64[25] state) -> reg ptr u64[25] { - reg ptr u64[24] constptr; + RC = KECCAK1600_RC; + e = s_e; - reg u64 rctr; - - constptr = roundconstants; - rctr = 0; + c = 0; + while (c < KECCAK_ROUNDS) + { + rc = RC[(int) c]; + e = keccakf1600_round(e, a, rc); - while (rctr < 192) { - state = __theta(state); - state = __rho(state); - state = __pi(state); - state = __chi(state); - constptr = roundconstants; - state = __iota(state, constptr.[(int)rctr]); - rctr += 8; + rc = RC[(int) c + 1]; + a = keccakf1600_round(a, e, rc); + + c += 2; } - return state; + return a; } +fn _keccakf1600(reg ptr u64[25] a) -> reg ptr u64[25] +{ + a = __keccakf1600(a); + return a; +} + +inline fn _keccakf1600_(reg ptr u64[25] a) -> reg ptr u64[25] +{ + a = a; + a = _keccakf1600(a); + a = a; + return a; +} inline fn __st0(reg ptr u64[25] state) -> reg ptr u64[25] @@ -325,7 +394,7 @@ fn ____keccak1600_ref( s_inlen = inlen; s_rate = rate; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); inlen = s_inlen; in = s_in; @@ -343,7 +412,7 @@ fn ____keccak1600_ref( s_outlen = outlen; s_rate = rate; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); out = s_out; outlen = s_outlen; @@ -354,7 +423,7 @@ fn ____keccak1600_ref( s_out = out; } - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); out = s_out; outlen = s_outlen; @@ -414,7 +483,7 @@ fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[12 state[u8 33] ^= 0x1f; state[u8 SHAKE256_RATE-1] ^= 0x80; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); out = sout; @@ -427,13 +496,14 @@ fn _shake256_128_33(reg ptr u8[128] out, reg const ptr u8[33] in) -> stack u8[12 fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { stack u64[25] state; - stack u64 s_out; + stack u64 s_out s_in1; stack u64 s_in s_ilen s_r8; reg u64 ilen r8 t64 in; reg u8 t8; inline int i; s_out = out; + s_in1 = in1; state = __st0(state); for i = 0 to MLKEM_SYMBYTES/8 { @@ -446,11 +516,11 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { state[u64 i] ^= t64; } - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); r8 = SHAKE256_RATE; ilen = MLKEM_CT_LEN - (SHAKE256_RATE - MLKEM_SYMBYTES); - in = in1; + in = s_in1; in += SHAKE256_RATE - MLKEM_SYMBYTES; while(ilen >= r8) @@ -461,7 +531,7 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { s_ilen = ilen; s_r8 = r8; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); in = s_in; ilen = s_ilen; @@ -471,7 +541,7 @@ fn _shake256_1120_32(reg u64 out, reg u64 in0 in1) { t8 = 0x1f; state = __add_final_block(state, in, ilen, t8, r8); - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); out = s_out; @@ -489,6 +559,9 @@ fn _sha3512_32(reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64] stack u64[25] state; reg u8 c; inline int i; + stack ptr u8[64] s_out; + + s_out = out; state = __st0(state); @@ -499,8 +572,9 @@ fn _sha3512_32(reg ptr u8[64] out, reg const ptr u8[32] in) -> stack u8[64] state[u8 32] ^= 0x06; state[u8 SHA3_512_RATE-1] ^= 0x80; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); + out = s_out; for i = 0 to 64 { c = state[u8 (int) i]; out[i] = c; @@ -529,12 +603,16 @@ fn _shake128_absorb34(reg ptr u64[25] state, reg const ptr u8[34] in) -> reg ptr fn _shake128_squeezeblock(reg ptr u64[25] state, reg ptr u8[SHAKE128_RATE] out) -> reg ptr u64[25], reg ptr u8[SHAKE128_RATE] { + stack ptr u8[SHAKE128_RATE] s_out; reg u8 c; inline int i; - state = __keccakf1600_ref(state); + s_out = out; + + state = _keccakf1600_(state); - for i = 0 to SHAKE128_RATE { + out = s_out; + for i = 0 to SHAKE128_RATE { // SHAKE128 rate is 168: or 21 u64: TODO: 'compress' this for loop c = state[u8 (int) i]; out[i] = c; } @@ -567,7 +645,7 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] s_ilen = ilen; s_r8 = r8; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); in = s_in; ilen = s_ilen; @@ -577,7 +655,7 @@ fn _isha3_256(reg ptr u8[32] out, reg u64 in inlen) -> reg ptr u8[32] t8 = 0x06; state = __add_final_block(state, in, ilen, t8, r8); - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); out = s_out; @@ -611,7 +689,7 @@ fn _isha3_256_32(reg ptr u8[32] out, reg ptr u8[MLKEM_SYMBYTES] in) -> reg ptr u state[u8 MLKEM_SYMBYTES] ^= 0x06; state[u8 SHA3_256_RATE - 1] = 0x80; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); out = s_out; @@ -645,7 +723,7 @@ fn _sha3_512_64(reg ptr u8[64] out, reg const ptr u8[64] in) -> stack u8[64] out_s = out; - state = __keccakf1600_ref(state); + state = _keccakf1600_(state); out = out_s; diff --git a/code/jasmin/mlkem_ref/indcpa.jinc b/code/jasmin/mlkem_ref/indcpa.jinc index 5e959a51..a0f2a463 100644 --- a/code/jasmin/mlkem_ref/indcpa.jinc +++ b/code/jasmin/mlkem_ref/indcpa.jinc @@ -93,6 +93,9 @@ fn __indcpa_enc(stack u64 sctp, reg ptr u8[32] msgp, reg u64 pkp, reg ptr u8[MLK reg u64 i t64; reg u64 ctp; reg u8 nonce; + stack ptr u8[MLKEM_SYMBYTES] s_noiseseed; + + s_noiseseed = noiseseed; pkpv = __polyvec_frombytes(pkp); @@ -111,21 +114,25 @@ fn __indcpa_enc(stack u64 sctp, reg ptr u8[32] msgp, reg u64 pkp, reg ptr u8[MLK aat = __gen_matrix(publicseed, 1); nonce = 0; - sp[0:MLKEM_N] = _poly_getnoise(sp[0:MLKEM_N], noiseseed, nonce); + sp[0:MLKEM_N] = _poly_getnoise(sp[0:MLKEM_N], s_noiseseed, nonce); + nonce = 1; - sp[MLKEM_N:MLKEM_N] = _poly_getnoise(sp[MLKEM_N:MLKEM_N], noiseseed, nonce); + sp[MLKEM_N:MLKEM_N] = _poly_getnoise(sp[MLKEM_N:MLKEM_N], s_noiseseed, nonce); + nonce = 2; - sp[2*MLKEM_N:MLKEM_N] = _poly_getnoise(sp[2*MLKEM_N:MLKEM_N], noiseseed, nonce); + sp[2*MLKEM_N:MLKEM_N] = _poly_getnoise(sp[2*MLKEM_N:MLKEM_N], s_noiseseed, nonce); nonce = 3; - ep[0:MLKEM_N] = _poly_getnoise(ep[0:MLKEM_N], noiseseed, nonce); + ep[0:MLKEM_N] = _poly_getnoise(ep[0:MLKEM_N], s_noiseseed, nonce); + nonce = 4; - ep[MLKEM_N:MLKEM_N] = _poly_getnoise(ep[MLKEM_N:MLKEM_N], noiseseed, nonce); + ep[MLKEM_N:MLKEM_N] = _poly_getnoise(ep[MLKEM_N:MLKEM_N], s_noiseseed, nonce); + nonce = 5; - ep[2*MLKEM_N:MLKEM_N] = _poly_getnoise(ep[2*MLKEM_N:MLKEM_N], noiseseed, nonce); + ep[2*MLKEM_N:MLKEM_N] = _poly_getnoise(ep[2*MLKEM_N:MLKEM_N], s_noiseseed, nonce); nonce = 6; - epp = _poly_getnoise(epp, noiseseed, nonce); + epp = _poly_getnoise(epp, s_noiseseed, nonce); sp = __polyvec_ntt(sp); @@ -160,7 +167,9 @@ fn __iindcpa_enc(reg ptr u8[MLKEM_CT_LEN] ctp, reg ptr u8[32] msgp, reg u64 pkp, reg u64 i t64; reg u8 nonce; stack ptr u8[MLKEM_CT_LEN] sctp; + stack ptr u8[MLKEM_SYMBYTES] s_noiseseed; + s_noiseseed = noiseseed; sctp = ctp; pkpv = __polyvec_frombytes(pkp); @@ -180,21 +189,25 @@ fn __iindcpa_enc(reg ptr u8[MLKEM_CT_LEN] ctp, reg ptr u8[32] msgp, reg u64 pkp, aat = __gen_matrix(publicseed, 1); nonce = 0; - sp[0:MLKEM_N] = _poly_getnoise(sp[0:MLKEM_N], noiseseed, nonce); + sp[0:MLKEM_N] = _poly_getnoise(sp[0:MLKEM_N], s_noiseseed, nonce); + nonce = 1; - sp[MLKEM_N:MLKEM_N] = _poly_getnoise(sp[MLKEM_N:MLKEM_N], noiseseed, nonce); + sp[MLKEM_N:MLKEM_N] = _poly_getnoise(sp[MLKEM_N:MLKEM_N], s_noiseseed, nonce); + nonce = 2; - sp[2*MLKEM_N:MLKEM_N] = _poly_getnoise(sp[2*MLKEM_N:MLKEM_N], noiseseed, nonce); + sp[2*MLKEM_N:MLKEM_N] = _poly_getnoise(sp[2*MLKEM_N:MLKEM_N], s_noiseseed, nonce); nonce = 3; - ep[0:MLKEM_N] = _poly_getnoise(ep[0:MLKEM_N], noiseseed, nonce); + ep[0:MLKEM_N] = _poly_getnoise(ep[0:MLKEM_N], s_noiseseed, nonce); + nonce = 4; - ep[MLKEM_N:MLKEM_N] = _poly_getnoise(ep[MLKEM_N:MLKEM_N], noiseseed, nonce); + ep[MLKEM_N:MLKEM_N] = _poly_getnoise(ep[MLKEM_N:MLKEM_N], s_noiseseed, nonce); + nonce = 5; - ep[2*MLKEM_N:MLKEM_N] = _poly_getnoise(ep[2*MLKEM_N:MLKEM_N], noiseseed, nonce); + ep[2*MLKEM_N:MLKEM_N] = _poly_getnoise(ep[2*MLKEM_N:MLKEM_N], s_noiseseed, nonce); nonce = 6; - epp = _poly_getnoise(epp, noiseseed, nonce); + epp = _poly_getnoise(epp, s_noiseseed, nonce); sp = __polyvec_ntt(sp); diff --git a/code/jasmin/mlkem_ref/kem.jinc b/code/jasmin/mlkem_ref/kem.jinc index 4795a352..ee8c60ea 100644 --- a/code/jasmin/mlkem_ref/kem.jinc +++ b/code/jasmin/mlkem_ref/kem.jinc @@ -98,7 +98,7 @@ fn __crypto_kem_dec_jazz(reg u64 shkp, reg u64 ctp, reg u64 skp) { stack u8[MLKEM_CT_LEN] ctpc; stack u8[2*MLKEM_SYMBYTES] kr buf; - stack u64 s_skp s_ctp s_shkp; + stack u64 s_skp s_ctp s_shkp s_cnd; reg u64 pkp hp zp t64 cnd; inline int i; @@ -127,14 +127,17 @@ fn __crypto_kem_dec_jazz(reg u64 shkp, reg u64 ctp, reg u64 skp) ctp = s_ctp; cnd = __verify(ctp, ctpc); + s_cnd = cnd; zp = s_skp; zp += 64; zp += 24 * MLKEM_K * MLKEM_N>>3; /* fixme: should this be done in memory? */ + shkp = s_shkp; _shake256_1120_32(shkp, zp, ctp); shkp = s_shkp; + cnd = s_cnd; __cmov(shkp, kr[0:MLKEM_SYMBYTES], cnd); } diff --git a/proof/correctness/MLKEM_InnerPKE.ec b/proof/correctness/MLKEM_InnerPKE.ec index 98ae995f..e7f13f82 100644 --- a/proof/correctness/MLKEM_InnerPKE.ec +++ b/proof/correctness/MLKEM_InnerPKE.ec @@ -411,12 +411,12 @@ equiv squeezeblock_ignore : Jkem.M(Jkem.Syscall)._shake128_squeezeblock ~Jkem.M(Jkem.Syscall)._shake128_squeezeblock : arg{1}.`1 = arg{2}.`1 ==> ={res}. proof. -proc => /=; seq 1 1: (#pre); 1: by call(_:true) => /=; sim. +proc => /=; seq 3 3: (#pre); 1: by call(_:true) => /=; sim. while (={i} /\ 0<=i{1}<=168 /\ ={state} /\ forall k, 0<=k out{1}.[k] = out{2}.[k]); 1: by auto => />; smt(Array168.set_eqiE Array168.set_neqiE). auto => /> *; split; 1: by smt(). by move => *;rewrite tP => k kb; smt(). -qed. +qed. (* equiv absorb_ignore : @@ -1031,37 +1031,40 @@ equiv auxenc_good : ={Glob.mem,arg} ==> ={Glob.mem,res}. proc. swap {1} 6 -5. -swap {1} 10 -8. +swap {1} 12 -10. seq 2 2 : (#pre /\ ={pkpv}); 1: by sim. swap {1} 6 -5. -swap {1} [9..11] -7. +swap {1} [11..13] -9. seq 4 4 : (#pre /\ ={publicseed}); 1: by sim. -swap {1} 9 -7. +swap {1} 11 -9. seq 2 3 : (#pre /\ aat{1}=at{2}); 1: by call auxgenmatrix_good;auto => />. swap {1} 4 -3. -swap {1} 7 -5. +swap {1} 9 -7. seq 2 2 : (#pre /\ ={k}); 1: by sim. swap {1} [2..3] -1. swap {1} 4 -3. -swap {1} [6..27] -1. -swap {1} 4 21. -swap {1} 3 19. +swap {1} [5..26] -1. +swap {1} 5 29. +swap {1} 25 2. + +seq 25 1 : (#pre /\ ={sp_0,ep,epp}); last by sim. -seq 23 1 : (#pre /\ ={sp_0,ep,epp}). -transitivity {1} {(sp_0,ep,epp) <@ AuxMLKEM.sample_noise3_jasmin(noiseseed);} - (={Glob.mem,msgp,pkp,pkpv,noiseseed,publicseed,k} /\ aat{1} = aat{2} /\ sctp{1} = sctp{2} ==> +swap {1} 5 -3. +swap {1} 4 20. + +transitivity {1} {s_noiseseed <- noiseseed; (sp_0,ep,epp) <@ AuxMLKEM.sample_noise3_jasmin(s_noiseseed);} + (={Glob.mem,msgp,pkp,pkpv,noiseseed,publicseed,k} /\ aat{1} = aat{2} /\ sctp{1} = sctp{2} ==> ={Glob.mem,msgp,pkp,pkpv,noiseseed,publicseed,k,sp_0,ep,epp} /\ aat{1} = aat{2} /\ sctp{1} = sctp{2} ) (={Glob.mem,msgp,pkp,pkpv,noiseseed,publicseed,k} /\ aat{1} = at{2} /\ sctp{1} = ctp{2} ==> - ={Glob.mem,msgp,pkp,pkpv,noiseseed,publicseed,k,sp_0,ep,epp} /\ aat{1} = at{2} /\ sctp{1} = ctp{2} ); 1,2: smt(). -+ by inline AuxMLKEM.sample_noise3_jasmin AuxMLKEM.sample_noise2_jasmin; sim. -+ by conseq />; call sample_noise_good3; auto => />. - -by sim. + ={Glob.mem,msgp,pkp,pkpv,noiseseed,publicseed,k,sp_0,ep,epp} /\ aat{1} = at{2} /\ sctp{1} = ctp{2} ); 1, 2: smt(). ++ seq 2 1 : (={s_noiseseed} /\ #pre); first by auto. + by inline AuxMLKEM.sample_noise3_jasmin AuxMLKEM.sample_noise2_jasmin; sim; auto. +by conseq />; call sample_noise_good3; auto => />. qed. equiv auxienc_good : @@ -1069,39 +1072,41 @@ equiv auxienc_good : ={Glob.mem,arg} ==> ={Glob.mem,res}. proc. swap {1} 6 -5. -swap {1} 12 -10. +swap {1} 14 -12. swap {2} 1 29. seq 2 2 : (#pre /\ ={pkpv}); 1: by sim. swap {1} 6 -5. -swap {1} [11..13] -9. +swap {1} [13..15] -11. seq 4 4 : (#pre /\ ={publicseed}); 1: by sim. -swap {1} 11 -9. +swap {1} 13 -11. seq 2 3 : (#pre /\ aat{1}=at{2}); 1: by call auxgenmatrix_good;auto => />. swap {1} 4 -3. -swap {1} 9 -7. +swap {1} 11 -9. seq 2 2 : (#pre /\ ={k}); 1: by sim. - swap {1} [2..3] -1. swap {1} 5 -4. swap {1} [8..27] -1. -swap {1} 3 22. +swap {1} 3 25. swap {1} 3 24. +swap {1} 27 3. + +swap {1} 6 38. +swap {1} 1 42. -seq 25 1 : (#pre /\ ={sp_0,ep,epp}). -transitivity {1} {(sp_0,ep,epp) <@ AuxMLKEM.sample_noise3_jasmin(noiseseed);} +seq 26 1 : (#pre /\ ={sp_0,ep,epp}). +transitivity {1} {s_noiseseed <- noiseseed; (sp_0,ep,epp) <@ AuxMLKEM.sample_noise3_jasmin(s_noiseseed);} (={Glob.mem,msgp,pkp,pkpv,noiseseed,publicseed,k} /\ aat{1} = aat{2} /\ ctp{1} = ctp{2} ==> ={Glob.mem,msgp,pkp,pkpv,noiseseed,publicseed,k,sp_0,ep,epp} /\ aat{1} = aat{2} /\ ctp{1} = ctp{2} ) (={Glob.mem,msgp,pkp,pkpv,noiseseed,publicseed,k} /\ aat{1} = at{2} /\ ctp{1} = ctp{2} ==> ={Glob.mem,msgp,pkp,pkpv,noiseseed,publicseed,k,sp_0,ep,epp} /\ aat{1} = at{2} /\ ctp{1} = ctp{2} ); 1,2: smt(). -+ by inline AuxMLKEM.sample_noise3_jasmin AuxMLKEM.sample_noise2_jasmin; sim; auto => />. ++ inline AuxMLKEM.sample_noise3_jasmin AuxMLKEM.sample_noise2_jasmin; sim; auto => />. + by conseq />; call sample_noise_good3; auto => />. -swap {1} 2 1. -by sim. +by sim. qed. (******* CORRECTNESS PROOFS TOP LEVEL ************) diff --git a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec index 2e4e0015..9eaad408 100644 --- a/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec +++ b/proof/correctness/avx2/MLKEM_InnerPKE_avx2.ec @@ -72,33 +72,59 @@ axiom sha3equiv : equiv [ (* is this in the sha3 paper? *) Jkem_avx2.M(Jkem_avx2.Syscall)._sha3_512_32 ~Jkem.M(Jkem.Syscall)._sha3512_32 : ={arg} ==> ={res}]. +lemma keccakf1600_set_row_ll : islossless M(Syscall).keccakf1600_set_row. +proc. by unroll for ^while; auto. qed. + +lemma keccakf1600_rho_offsets_ll : islossless M(Syscall).keccakf1600_rho_offsets. +proc. by unroll for ^while; islossless. qed. + +lemma keccakf1600_rhotates_ll : islossless M(Syscall).keccakf1600_rhotates. +proc. by call keccakf1600_rho_offsets_ll; islossless. qed. + +lemma keccakf1600_theta_rol_ll : islossless M(Syscall).keccakf1600_theta_rol. +proc. by unroll for ^while; islossless. qed. + +lemma keccakf1600_theta_sum_ll : islossless M(Syscall).keccakf1600_theta_sum. +proc. by do 6!(unroll for ^while); islossless. qed. + +lemma keccakf1600_rol_sum_ll : islossless M(Syscall).keccakf1600_rol_sum. +proc. +while (x <= 5) (5 - x); auto; last smt(). +conseq => /=; call keccakf1600_rhotates_ll; auto => /#. +qed. + +lemma keccakf1600_round_ll : islossless Jkem.M(Syscall).keccakf1600_round. +proc; auto. +while (y <= 5) (5 - y); auto. ++ call keccakf1600_set_row_ll. + call keccakf1600_rol_sum_ll. + auto; smt(). +call keccakf1600_theta_rol_ll. +call keccakf1600_theta_sum_ll. +auto; smt(). +qed. + +lemma keccakf1600_ll : islossless Jkem.M(Syscall)._keccakf1600_. +proc; auto. +call (:true); auto. +call (:true); auto. +while (to_uint c <= 24 /\ to_uint c %% 2 = 0) (24 - to_uint c); auto; last by move => /> *; rewrite ultE to_uint_small //= /#. +call keccakf1600_round_ll; auto. +call keccakf1600_round_ll; auto. +move => /> ??; rewrite ultE to_uintD_small to_uint_small //= /#. +qed. + lemma sha3ll : islossless Jkem.M(Jkem.Syscall)._shake256_128_33. -proc. -while (0<=i<=128) (128 - i);1 : by move => *; auto => /> /#. -wp; call (_: true). -+ while (0<=to_uint rctr <= 192 /\ to_uint rctr %% 8 = 0) (192 - to_uint rctr); last by - auto => /> ?; rewrite ultE; smt(W64.to_uint_cmp W64.to_uintD_small). - move => *; wp; conseq (_: _ ==> true); 1: by - auto => /> &hr; rewrite ultE /= => [#] ???; rewrite !to_uintD_small /=; smt(W64.to_uint_cmp). - call(_: true); 1: by islossless. - wp; call(_: true). - + by do 11!(unroll for ^while); islossless. - call(_: true). - + by do 7!(unroll for ^while); islossless. - call(_: true). - + while (0<=x<=5) (5 - x);last by auto => /> /#. - move => *;wp; while (0<=y<=5) (5 - y);last by auto => /> /#. - move => *;inline 1;wp;sp => /=; conseq (_: _ ==> true); 1: by smt(). - call(_: true). - + by (unroll for ^while); islossless. - by auto => />. - call(_: true). - + by do 13!(unroll for ^while); islossless. - by auto => />. -conseq (_: true); 1: by smt(). -by inline *; do 2!(unroll for ^while); islossless. +proof. +proc. +unroll for 12; wp; conseq => /=. +call keccakf1600_ll; auto. +conseq => /=. +unroll for ^while; auto. +conseq => /=. +inline *; unroll for ^while; auto. qed. - + axiom shake128_equiv_absorb : equiv [ M(Syscall)._shake128_absorb34 ~ Jkem_avx2.M(Jkem_avx2.Syscall)._shake128_absorb34 : ={state, in_0} ==> ={res}]. @@ -317,7 +343,9 @@ qed. hoare matrix_bound : M(Syscall).__gen_matrix : 0 <= to_uint transposed <2 ==> pos_bound2304_cxq res 0 2304 2. conseq auxgenmatrix_good matrix_bound_aux. -move => /> &1 H H0; exists (seed{1},transposed{1} = W64.one); 1: by smt(@W64). +move => /> &1 H H0; exists (seed{1}, to_uint transposed{1} = 1) => /=. +- rewrite W64.to_uint_eq. + by have [] -> : (to_uint transposed{1} = 0 \/ to_uint transposed{1} = 1) by smt(). by smt(). qed. @@ -1620,14 +1648,14 @@ transitivity {1} {Jkem.M(Jkem.Syscall).__indcpa_enc(sctp,msgp,pkp,noiseseed);} inline{1} 1; inline {2} 1. wp. -seq 49 57 : (={ctp,Glob.mem} /\ +seq 49 59 : (={ctp,Glob.mem} /\ pos_bound256_cxq v{1} 0 256 2 /\ pos_bound256_cxq v{2} 0 256 2 /\ lift_array256 v{1} = lift_array256 v{2} /\ valid_ptr (to_uint ctp{1}) 128); last by exists *Glob.mem{1}, (to_uint ctp{1}); elim* => memm _p; call (compressequiv memm _p); auto. -seq 47 55 : (={ctp,Glob.mem} /\ +seq 47 57 : (={ctp,Glob.mem} /\ pos_bound256_cxq v{1} 0 256 2 /\ pos_bound256_cxq v{2} 0 256 2 /\ pos_bound768_cxq bp{1} 0 768 2 /\ @@ -1654,7 +1682,8 @@ swap {1} 3 -2; swap {2} 3 -2; seq 1 1: (#pre /\ ={pkp0} /\ pkp0{2}=pkp{1}); 1: b sp 3 3. swap {1} 18 -1. (* avoid dealing with stack noise seed *) -seq 17 15 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\ +seq 17 17 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\ + s_noiseseed{2} = noiseseed0{2} /\ pos_bound256_cxq k{1} 0 256 1 /\ pos_bound256_cxq k{2} 0 256 1 /\ lift_array768 pkpv{1} = nttunpackv (lift_array768 pkpv{2}) /\ @@ -1666,8 +1695,8 @@ seq 17 15 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\ + call (genmatrixequiv true). call frommsgequiv_noperm. conseq />. smt(). conseq (_: _ ==> lift_array768 pkpv{1} = nttunpackv (lift_array768 pkpv{2}) /\ - pos_bound768_cxq pkpv{1} 0 768 2 /\ pos_bound768_cxq pkpv{2} 0 768 2 /\ ={publicseed,pkp0,bp,ep,epp,v,sp_0,Glob.mem} /\ pkp0{2} = pkp{1}). - auto => /> &2 ????????? rl rr H H0 H1 ????. + pos_bound768_cxq pkpv{1} 0 768 2 /\ pos_bound768_cxq pkpv{2} 0 768 2 /\ ={publicseed,pkp0,bp,ep,epp,v,sp_0,Glob.mem} /\ pkp0{2} = pkp{1} /\ s_noiseseed{2} = noiseseed0{2}). + auto => /> &2 ????????? rl rr H H0 H1 ????. + rewrite tP => k kb. move : H; rewrite /lift_array256 tP => H. move : (H k kb); rewrite !mapiE //=. @@ -1679,7 +1708,7 @@ seq 17 15 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k} /\ rewrite ifF. smt(W16.to_uint_cmp). rewrite ifF. smt(W16.to_uint_cmp). smt(W16.to_uint_eq). - seq 12 10 : (#{/~publicseed{2}}post /\ ={publicseed}). + seq 12 12 : (#{/~publicseed{2}}post /\ ={publicseed}). wp;sp; conseq />. call (polyvec_frombytes_equiv). auto => />. smt(). @@ -1706,11 +1735,12 @@ seq 12 20 : (#{/~bp{1}=bp{2}}pre /\ sctp0{1} = sctp{1} /\ msgp0{1} = msgp{1} /\ noiseseed0{1} = noiseseed{1} /\ - (={Glob.mem, msgp, pkp, noiseseed, sctp} /\ + (={Glob.mem, msgp, noiseseed, pkp, sctp} /\ valid_ptr _pkp (384 * 3 + 32) /\ valid_ptr _ctp (3 * 320 + 128) /\ Glob.mem{1} = mem /\ to_uint sctp{1} = _ctp /\ to_uint pkp{1} = _pkp) /\ ={pkp0}) /\ ={publicseed, bp, ep, epp, v, sp_0, k} /\ + s_noiseseed{2} = noiseseed0{2} /\ pos_bound256_cxq k{1} 0 256 1 /\ pos_bound256_cxq k{2} 0 256 1 /\ lift_array768 pkpv{1} = nttunpackv (lift_array768 pkpv{2}) /\ @@ -1998,13 +2028,13 @@ transitivity {1} { r <@Jkem.M(Jkem.Syscall).__iindcpa_enc(ctp,msgp,pkp,noiseseed inline{1} 1; inline {2} 1. wp. -seq 51 59 : (={ctp0,Glob.mem} /\ Glob.mem{1} = mem /\ +seq 51 61 : (={ctp0,Glob.mem} /\ Glob.mem{1} = mem /\ pos_bound256_cxq v{1} 0 256 2 /\ pos_bound256_cxq v{2} 0 256 2 /\ lift_array256 v{1} = lift_array256 v{2}); last by exists *Glob.mem{1}; elim* => memm; call (compressequiv_1 memm); auto => />; smt(Array1088.tP Array1088.initiE). -seq 49 57 : (={ctp0,Glob.mem} /\ Glob.mem{1} = mem /\ +seq 49 59 : (={ctp0,Glob.mem} /\ Glob.mem{1} = mem /\ pos_bound256_cxq v{1} 0 256 2 /\ pos_bound256_cxq v{2} 0 256 2 /\ pos_bound768_cxq bp{1} 0 768 2 /\ @@ -2029,7 +2059,8 @@ swap {1} 3 -2; swap {2} 3 -2; seq 1 1: (#pre /\ ={pkp0} /\ pkp0{2} = pkp{1}); 1: sp 3 3. swap {1} 19 1. (* avoid dealing with stack noise seed *) -seq 19 17 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k,sctp} /\ +seq 19 19 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k,sctp} /\ + s_noiseseed{2} = noiseseed0{2} /\ pos_bound256_cxq k{1} 0 256 1 /\ pos_bound256_cxq k{2} 0 256 1 /\ lift_array768 pkpv{1} = nttunpackv (lift_array768 pkpv{2}) /\ @@ -2041,7 +2072,7 @@ seq 19 17 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k,sctp} /\ + call (genmatrixequiv true). wp;call frommsgequiv_noperm. conseq />. smt(). conseq (_: _ ==> lift_array768 pkpv{1} = nttunpackv (lift_array768 pkpv{2}) /\ - pos_bound768_cxq pkpv{1} 0 768 2 /\ pos_bound768_cxq pkpv{2} 0 768 2 /\ ={publicseed,pkp0,bp,ep,epp,v,sp_0,sctp,Glob.mem} /\ pkp0{2} = pkp{1}). + pos_bound768_cxq pkpv{1} 0 768 2 /\ pos_bound768_cxq pkpv{2} 0 768 2 /\ ={publicseed,pkp0,bp,ep,epp,v,sp_0,sctp,Glob.mem} /\ pkp0{2} = pkp{1} /\ s_noiseseed{2} = noiseseed0{2}). auto => /> &2 ??????? rl rr H H0 H1 ????. + rewrite tP => k kb. move : H; rewrite /lift_array256 tP => H. @@ -2054,7 +2085,7 @@ seq 19 17 : (#pre /\ ={publicseed, bp,ep,epp,v,sp_0,k,sctp} /\ rewrite ifF. smt(W16.to_uint_cmp). rewrite ifF. smt(W16.to_uint_cmp). smt(W16.to_uint_eq). - seq 14 12 : (#{/~publicseed{2}}post /\ ={publicseed}). + seq 14 14 : (#{/~publicseed{2}}post /\ ={publicseed}). wp;sp; conseq />. call (polyvec_frombytes_equiv). auto => />. smt(). @@ -2083,6 +2114,7 @@ seq 12 20 : (#{/~bp{1}=bp{2}}pre /\ (={Glob.mem, msgp, pkp, noiseseed} /\ valid_ptr _pkp (384 * 3 + 32) /\ Glob.mem{1} = mem /\ to_uint pkp{1} = _pkp) /\ ={pkp0}) /\ ={publicseed, bp, ep, epp, v, sp_0, k} /\ + s_noiseseed{2} = noiseseed0{2} /\ pos_bound256_cxq k{1} 0 256 1 /\ pos_bound256_cxq k{2} 0 256 1 /\ lift_array768 pkpv{1} = nttunpackv (lift_array768 pkpv{2}) /\