From 75ac8b8029db29462bbb859877b7b0734fac1d72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=CC=81=20Bacelar=20Almeida?= Date: Fri, 18 Oct 2024 18:26:33 +0100 Subject: [PATCH] transition to new keccak --- code/jasmin/mlkem_avx2/Makefile | 13 +- code/jasmin/mlkem_avx2/gen_matrix.jinc | 55 ++- code/jasmin/mlkem_avx2/indcpa.jinc | 7 +- .../keccak/keccak1600_array_avx2_ASIZE.jinc | 384 ++++++++++++++++ .../mlkem_avx2/keccak/keccak1600_avx2.jinc | 318 +++++++++++++ .../mlkem_avx2/keccak/keccak1600_globals.jinc | 11 + .../keccak/keccak1600_imem_avx2.jinc | 349 +++++++++++++++ .../keccak/keccak1600_orig_avx2.jinc | 281 ++++++++++++ .../keccak/keccak1600_orig_avx2_ASIZE.jinc | 247 ++++++++++ .../keccak/keccak1600x4_array_avx2_ASIZE.jinc | 416 +++++++++++++++++ .../mlkem_avx2/keccak/keccak1600x4_avx2.jinc | 92 ++++ .../keccak/keccak1600x4_imem_avx2.jinc | 421 ++++++++++++++++++ .../mlkem_avx2/keccak/keccakf1600_avx2.jinc | 58 ++- .../keccak/keccakf1600_globals.jinc | 36 ++ .../mlkem_avx2/keccak/keccakf1600x4_avx2.jinc | 333 ++++++++++++++ .../keccak/subreadwrite_array_ASIZE.jinc | 261 +++++++++++ .../mlkem_avx2/keccak/subreadwrite_imem.jinc | 244 ++++++++++ .../mlkem_avx2/{ => keccak_OLD}/fips202.jinc | 4 +- .../{ => keccak_OLD}/fips202_4x.jinc | 22 + .../{ => keccak_OLD}/fips202_common.jinc | 0 .../{ => keccak_OLD}/gen_matrix_old.jinc | 0 .../{keccak => keccak_OLD}/keccakf1600.jinc | 0 .../keccakf1600_4x_avx2_compact.jinc | 0 .../keccak_OLD/keccakf1600_avx2.jinc | 316 +++++++++++++ .../keccakf1600_generic.jinc | 0 code/jasmin/mlkem_avx2/kem.h | 10 +- code/jasmin/mlkem_avx2/kem.jinc | 19 +- code/jasmin/mlkem_avx2/mlkem_keccak_avx2.jinc | 251 +++++++++++ .../mlkem_keccak_avx2_TRANSITION.jinc | 175 ++++++++ code/jasmin/mlkem_avx2/poly.jinc | 81 ++++ code/jasmin/mlkem_avx2/test/test_kem.c | 10 +- 31 files changed, 4356 insertions(+), 58 deletions(-) create mode 100644 code/jasmin/mlkem_avx2/keccak/keccak1600_array_avx2_ASIZE.jinc create mode 100644 code/jasmin/mlkem_avx2/keccak/keccak1600_avx2.jinc create mode 100644 code/jasmin/mlkem_avx2/keccak/keccak1600_globals.jinc create mode 100644 code/jasmin/mlkem_avx2/keccak/keccak1600_imem_avx2.jinc create mode 100644 code/jasmin/mlkem_avx2/keccak/keccak1600_orig_avx2.jinc create mode 100644 code/jasmin/mlkem_avx2/keccak/keccak1600_orig_avx2_ASIZE.jinc create mode 100644 code/jasmin/mlkem_avx2/keccak/keccak1600x4_array_avx2_ASIZE.jinc create mode 100644 code/jasmin/mlkem_avx2/keccak/keccak1600x4_avx2.jinc create mode 100644 code/jasmin/mlkem_avx2/keccak/keccak1600x4_imem_avx2.jinc create mode 100644 code/jasmin/mlkem_avx2/keccak/keccakf1600_globals.jinc create mode 100644 code/jasmin/mlkem_avx2/keccak/keccakf1600x4_avx2.jinc create mode 100644 code/jasmin/mlkem_avx2/keccak/subreadwrite_array_ASIZE.jinc create mode 100644 code/jasmin/mlkem_avx2/keccak/subreadwrite_imem.jinc rename code/jasmin/mlkem_avx2/{ => keccak_OLD}/fips202.jinc (98%) rename code/jasmin/mlkem_avx2/{ => keccak_OLD}/fips202_4x.jinc (98%) rename code/jasmin/mlkem_avx2/{ => keccak_OLD}/fips202_common.jinc (100%) rename code/jasmin/mlkem_avx2/{ => keccak_OLD}/gen_matrix_old.jinc (100%) rename code/jasmin/mlkem_avx2/{keccak => keccak_OLD}/keccakf1600.jinc (100%) rename code/jasmin/mlkem_avx2/{keccak => keccak_OLD}/keccakf1600_4x_avx2_compact.jinc (100%) create mode 100644 code/jasmin/mlkem_avx2/keccak_OLD/keccakf1600_avx2.jinc rename code/jasmin/mlkem_avx2/{keccak => keccak_OLD}/keccakf1600_generic.jinc (100%) create mode 100644 code/jasmin/mlkem_avx2/mlkem_keccak_avx2.jinc create mode 100644 
code/jasmin/mlkem_avx2/mlkem_keccak_avx2_TRANSITION.jinc

diff --git a/code/jasmin/mlkem_avx2/Makefile b/code/jasmin/mlkem_avx2/Makefile
index 5bb618f0..b9bfd118 100644
--- a/code/jasmin/mlkem_avx2/Makefile
+++ b/code/jasmin/mlkem_avx2/Makefile
@@ -3,6 +3,8 @@
 -include ../../Makefile.conf
 
+JADDFLAGS ?= -lazy-regalloc
+
 CC ?= /usr/bin/gcc
 GFLAGS ?=
 CFLAGS := -Wall -Wextra -g -Ofast -fomit-frame-pointer
@@ -11,9 +13,12 @@ OS := $(shell uname -s)
 
 .SECONDARY: jpoly.s jpolyvec.s jfips202.s jindcpa.s jindcpa.o jkem.s
 
-default: test speed
+default: test
+#default: test speed
+
+test: test/test_kem
 
-test: test/test_poly_compress \
+testX: test/test_poly_compress \
 	test/test_poly_decompress \
 	test/test_poly_tobytes \
 	test/test_poly_frombytes \
@@ -71,7 +76,7 @@ test/test_indcpa: test/test_indcpa.c $(HEADERS) $(SOURCES) $(INCS) jindcpa.o
 	$(CC) $(CFLAGS) -o $@ $(SOURCES) jindcpa.o $<
 
 test/test_kem: test/test_kem.c $(HEADERS) $(SOURCES) $(INCS) jkem.o
-	$(CC) $(CFLAGS) -o $@ $(SOURCES) jkem.o ~/Desktop/Repos/jasmin/compiler/syscall/jasmin_syscall.o $<
+	$(CC) $(CFLAGS) -o $@ $(SOURCES) jkem.o $(current_dir)../jasmin/compiler/syscall/jasmin_syscall.o $<
 
 test/speed_indcpa: test/speed_indcpa.c $(HEADERS) $(SOURCES) $(INCS) jindcpa.o
 	$(CC) $(CFLAGS) -o $@ $(SOURCES) jindcpa.o $<
@@ -92,7 +97,7 @@ test/test_polyvec_%: test/test_polyvec_%.c $(HEADERS) $(SOURCES) $(INCS) jpolyve
 	$(CC) $(CFLAGS) -o $@ $(SOURCES) jpolyvec.s $<
 
 %.s: %.jazz
-	$(JASMINC) -o $@ $(JFLAGS) $^
+	$(JASMINC) $(JADDFLAGS) -o $@ $(JFLAGS) $^
 
 .PHONY: ct sct clean

diff --git a/code/jasmin/mlkem_avx2/gen_matrix.jinc b/code/jasmin/mlkem_avx2/gen_matrix.jinc
index 4ecf9e13..637f1784 100644
--- a/code/jasmin/mlkem_avx2/gen_matrix.jinc
+++ b/code/jasmin/mlkem_avx2/gen_matrix.jinc
@@ -1,5 +1,11 @@
-require "keccak/keccakf1600_4x_avx2_compact.jinc"
+/* // OLD INTERFACE
+require "keccak/keccakf1600x4_avx2.jinc"
 require "keccak/keccakf1600_avx2.jinc"
+*/
+// NEW INTERFACE
+require "mlkem_keccak_avx2_TRANSITION.jinc"
+
+
 require "params.jinc"
 require "gen_matrix_globals.jinc"
 
@@ -33,7 +39,7 @@ inline fn comp_u64_l_int_and_u64_l_int(
 
 // BUF_size per entry: 21(rate) + 21(rate) + 25(keccak_state) + 1(pad)
 param int BUF_size = 536; // 168*2+200 (was in u64s: 3*21 + 4 + 1; //544 bytes;
-
+/*
 // deinterleave u64-lanes of 4 u256 regs
 inline fn _4u64x4_u256x4(reg u256 y0 y1 y2 y3) -> reg u256, reg u256, reg u256, reg u256 {
   reg u256 x0, x1, x2, x3;
@@ -331,6 +337,8 @@ inline fn xof_init_avx2
   return state;
 }
 
+*/
+
 /*
 DEFS:
@@ -685,10 +693,12 @@ inline fn gen_matrix_get_indexes(
   reg u64 b,
   reg u64 _t)
   ->
-  reg u16[4]
+  stack u8[4*2]
+// reg u16[4]
 {
   reg u64 t;
-  reg u16[4] idx;
+// reg u16[4] idx;
+  stack u8[4*2] idx;
   reg ptr u16[2*4*2] gmi;
 
   gmi = gen_matrix_indexes;
@@ -696,10 +706,10 @@ inline fn gen_matrix_get_indexes(
   t = _t;
   t <<= 3; // t * 8
   b += t;
-  idx[0] = gmi[(int) b + 0];
-  idx[1] = gmi[(int) b + 1];
-  idx[2] = gmi[(int) b + 2];
-  idx[3] = gmi[(int) b + 3];
+  idx[u16 0] = gmi[(int) b + 0];
+  idx[u16 1] = gmi[(int) b + 1];
+  idx[u16 2] = gmi[(int) b + 2];
+  idx[u16 3] = gmi[(int) b + 3];
 
   return idx;
 }
@@ -710,15 +720,18 @@ fn __gen_matrix_fill_polynomial
 ) -> reg ptr u16[MLKEM_N], reg ptr u8[BUF_size]
 {
   reg u64 counter, buf_offset;
-  reg u256[7] stavx2;
+// reg u256[7] stavx2;
 
   buf_offset = 0;
   counter = 0;
   pol, counter = _gen_matrix_buf_rejection(pol, counter, buf, buf_offset);
 
   buf_offset = 2*168;
   while (counter < MLKEM_N) {
+/*
    stavx2 = _stavx2_pack_at(buf, buf_offset);
    stavx2 = _keccakf1600_avx2(stavx2);
    buf = _stavx2_unpack_at(buf, buf_offset, 
stavx2); +*/ + buf = _shake128_next_state(buf); pol, counter = _gen_matrix_buf_rejection(pol, counter, buf, buf_offset); } @@ -733,19 +746,24 @@ fn _gen_matrix_sample_four_polynomials , reg u64 transposed ) -> reg ptr u16[4*MLKEM_N], reg ptr u8[BUF_size * 4] { - reg u64 buf_offset; +// reg u64 buf_offset; reg ptr u16[MLKEM_N] pol; stack u256[25] state; reg ptr u256[25] stx4; - reg u16[4] indexes; +// reg u16[4] indexes; + stack u8[4*2] indexes; + indexes = gen_matrix_get_indexes(mat_entry, transposed); stx4 = state; + stx4 = _shake128x4_absorb_A32_A2(stx4, rho, indexes); + _, buf = _shake128x4_squeeze3blocks(stx4, buf); +/* stx4 = xof_init_x4(rho, indexes); buf_offset = 0; while (buf_offset < 3*168) { - stx4 = _keccakf1600_4x(stx4); + stx4 = _keccakf1600_avx2x4(stx4); buf[BUF_size * 0 : BUF_size], buf[BUF_size * 1 : BUF_size], @@ -759,7 +777,7 @@ fn _gen_matrix_sample_four_polynomials buf_offset += 168; } - +*/ pol = polx4[0*MLKEM_N:MLKEM_N]; pol, buf[BUF_size * 0 : BUF_size] = __gen_matrix_fill_polynomial(pol, buf[BUF_size * 0 : BUF_size]); polx4[0*MLKEM_N:MLKEM_N] = pol; @@ -787,8 +805,10 @@ inline fn __gen_matrix_sample_one_polynomial ) -> reg ptr u16[MLKEM_N], reg ptr u8[BUF_size] { reg u256[7] stavx2; - reg u64 buf_offset; - +// reg u64 buf_offset; + stack u8[2] pos; + +/* stavx2 = xof_init_avx2(rho, rc); buf_offset = 0; while (buf_offset < 3*168) { @@ -796,6 +816,11 @@ inline fn __gen_matrix_sample_one_polynomial buf = _stavx2_unpack_at( buf, buf_offset, stavx2 ); buf_offset += 168; } +*/ + + pos[u16 0] = rc; + stavx2 = _shake128_absorb_A32_A2(rho, pos); + buf = _shake128_squeeze3blocks(buf, stavx2); pol, buf = __gen_matrix_fill_polynomial(pol, buf); diff --git a/code/jasmin/mlkem_avx2/indcpa.jinc b/code/jasmin/mlkem_avx2/indcpa.jinc index f2323f1b..a9e827c9 100644 --- a/code/jasmin/mlkem_avx2/indcpa.jinc +++ b/code/jasmin/mlkem_avx2/indcpa.jinc @@ -24,7 +24,8 @@ fn __indcpa_keypair(#spill_to_mmx reg u64 pkp skp, reg ptr u8[MLKEM_SYMBYTES] ra inbuf[u64 i] = t64; } - buf = _sha3_512_32(buf, inbuf); + //buf = _sha3_512_32(buf, inbuf); + buf = _sha3_512A_A32(buf, inbuf); for i=0 to MLKEM_SYMBYTES/8 { @@ -105,10 +106,10 @@ fn __indcpa_enc_0(#mmx reg u64 sctp, reg ptr u8[MLKEM_INDCPA_MSGBYTES] msgp, reg aat = _gen_matrix_avx2(aat, publicseed, transposed); lnoiseseed = s_noiseseed; - nonce = 0; sp[0:MLKEM_N], sp[MLKEM_N:MLKEM_N], sp[2*MLKEM_N:MLKEM_N], ep[0:MLKEM_N] = _poly_getnoise_eta1_4x(sp[0:MLKEM_N], sp[MLKEM_N:MLKEM_N], sp[2*MLKEM_N:MLKEM_N], ep[0:MLKEM_N], lnoiseseed, nonce); + lnoiseseed = s_noiseseed; nonce = 4; ep[MLKEM_N:MLKEM_N], ep[2*MLKEM_N:MLKEM_N], epp, bp[0:MLKEM_N] = _poly_getnoise_eta1_4x(ep[MLKEM_N:MLKEM_N], ep[2*MLKEM_N:MLKEM_N], epp, bp[0:MLKEM_N], lnoiseseed, nonce); @@ -177,10 +178,10 @@ fn __indcpa_enc_1( aat = _gen_matrix_avx2(aat, publicseed, transposed); lnoiseseed = s_noiseseed; - nonce = 0; sp[0:MLKEM_N], sp[MLKEM_N:MLKEM_N], sp[2*MLKEM_N:MLKEM_N], ep[0:MLKEM_N] = _poly_getnoise_eta1_4x(sp[0:MLKEM_N], sp[MLKEM_N:MLKEM_N], sp[2*MLKEM_N:MLKEM_N], ep[0:MLKEM_N], lnoiseseed, nonce); + lnoiseseed = s_noiseseed; nonce = 4; ep[MLKEM_N:MLKEM_N], ep[2*MLKEM_N:MLKEM_N], epp, bp[0:MLKEM_N] = _poly_getnoise_eta1_4x(ep[MLKEM_N:MLKEM_N], ep[2*MLKEM_N:MLKEM_N], epp, bp[0:MLKEM_N], lnoiseseed, nonce); diff --git a/code/jasmin/mlkem_avx2/keccak/keccak1600_array_avx2_ASIZE.jinc b/code/jasmin/mlkem_avx2/keccak/keccak1600_array_avx2_ASIZE.jinc new file mode 100644 index 00000000..9e3cfc5a --- /dev/null +++ b/code/jasmin/mlkem_avx2/keccak/keccak1600_array_avx2_ASIZE.jinc @@ -0,0 
+1,384 @@ +/* DEPENDENCIES +require "keccak1600_avx2.jinc" +param int ASIZE = 1002; +*/ + +require "subreadwrite_array_ASIZE.jinc" + +/* + ONE-SHOT (FIXED-SIZE) ARRAY ABSORB + ================================== +*/ + +inline fn __addstate_array_avx2 +( reg u256[7] st +, reg const ptr u8[ASIZE] buf +, reg u64 offset +, inline int LEN +, inline int TRAILB +) -> reg ptr u256[7] /* st */ + , reg u64 /* offset */ +{ + reg u64 t64; + reg u256 r0, r1, r2, r3, r4, r5, r6; + reg u128 t128_0, t128_1; + inline int DELTA; + DELTA = 0; + + DELTA, LEN, TRAILB, t64 = __aread_subu64(buf, offset, DELTA, LEN, TRAILB); + t128_0 = (128u) t64; + r0 = #VPBROADCAST_4u64(t128_0); + st[0] ^= r0; + + DELTA, LEN, TRAILB, r1 = __aread_subu256(buf, offset, DELTA, LEN, TRAILB); + st[1] ^= r1; + + if (0 < LEN ) { + DELTA, LEN, TRAILB, t64 = __aread_subu64(buf,offset, DELTA, LEN, TRAILB); + t128_1 = (128u) t64; + + DELTA, LEN, TRAILB, r3 = __aread_subu256(buf, offset, DELTA, LEN, TRAILB); + + DELTA, LEN, TRAILB, t64 = __aread_subu64(buf, offset, DELTA, LEN, TRAILB); + t128_0 = (128u) t64; + + DELTA, LEN, TRAILB, r4 = __aread_subu256(buf, offset, DELTA, LEN, TRAILB); + + DELTA, LEN, TRAILB, t64 = __aread_subu64(buf, offset, DELTA, LEN, TRAILB); + t128_1 = #VPINSR_2u64(t128_1, t64, 1); + + DELTA, LEN, TRAILB, r5 = __aread_subu256(buf, offset, DELTA, LEN, TRAILB); + + DELTA, LEN, TRAILB, t64 = __aread_subu64(buf, offset, DELTA, LEN, TRAILB); + t128_0 = #VPINSR_2u64(t128_0, t64, 1); + r2 = (2u128)[t128_1, t128_0]; + st[2] ^= r2; + + DELTA, LEN, TRAILB, r6 = __aread_subu256(buf, offset, DELTA, LEN, TRAILB); + + st = __addstate_r3456( st, r3, r4, r5, r6); + } + offset += DELTA; + return st, offset; +} + +inline fn __absorb_array_avx2 +( reg u256[7] st +, reg const ptr u8[ASIZE] buf +, reg u64 offset +, inline int LEN +, inline int RATE8 +, inline int TRAILB /* closes state if !=0 (i.e. adds trailbyte and padding) */ +) -> reg u256[7] /* st */ + , reg u64 /* offset */ +{ + reg u64 i; + inline int ALL, ITERS; + + ALL = LEN + (TRAILB!=0 ? 1 : 0); + + // continue by processing full blocks + ITERS = LEN / RATE8; // number of full blocks + if (0 < ITERS) { + i = 0; + while ( i < ITERS ) { + st, offset = __addstate_array_avx2(st, buf, offset, RATE8, 0); + st = _keccakf1600_avx2(st); + i += 1; + } + } + + // last incomplete block + LEN = LEN % RATE8; + st, offset = __addstate_array_avx2(st, buf, offset, LEN, TRAILB); + if (TRAILB!=0) { st = __addratebit_avx2(st, RATE8); } + + return st, offset; +} + +/* + INCREMENTAL (FIXED-SIZE) MEMORY ABSORB + ====================================== +*/ + +inline fn __pstate_array_avx2 +( reg mut ptr u64[25] pst +, inline int AT /* bytes (0 <= AT < 200) */ +, reg const ptr u8[ASIZE] buf +, reg u64 offset +, inline int LEN +, inline int TRAILB +) -> reg ptr u64[25] /* pst */ + , inline int /* AT */ + , reg u64 /* offset */ +{ + inline int DELTA, LO, ALL; + reg u64 at, t64; + reg u256 t256; + reg u128 t128; + + DELTA = 0; + ALL = AT+LEN; // total bytes to process (excluding trail byte, if !=0) + LO = AT % 8; // leftover bytes + at = AT / 8; // current pstate position + + if ( 0 < LO ) { // process first word... 
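+    // worked example (illustrative numbers, not from the original sources):
+    // AT=13, LEN=2 gives LO=5, at=1; the two input bytes are shifted left by
+    // 8*LO = 40 bits so they land in bytes 5..6 of pst[1], on top of the
+    // 5 bytes already buffered there.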
+ if ( LO+LEN < 8) { // ...not enough to fill a word (just update it) + if ( TRAILB != 0 ) { ALL += 1; } + DELTA, _, TRAILB, t64 = __aread_subu64(buf, offset, DELTA, LEN, TRAILB); + t64 <<= 8*LO; + pst[(int) at] ^= t64; + LO = 0; + AT = 0; + LEN = 0; + } else { // process first word + if ( 8 <= LEN ) { + t64 = buf.[u64 offset + DELTA]; + DELTA += (8-LO); + } else { + DELTA, _, _, t64 = __aread_subu64(buf, offset, DELTA, 8-LO, 0); + } + LEN -= 8-LO; + AT += 8-LO; + t64 <<= 8*LO; + pst[(int) at] ^= t64; + at += 1; + } + } + + // continue processing remaining bytes + if (32 <= LEN) { + offset += DELTA; + DELTA = 0; + while ( at < AT/8+4*(LEN/32)) { + t256 = buf.[u256 offset]; + offset += 32; + pst.[u256 8*at] = t256; + at += 4; + } + LEN = LEN % 32; + } + if (16 <= LEN) { + t128 = buf.[u128 offset + DELTA]; + DELTA += 16; + pst.[u128 8*at] = t128; + at += 2; + LEN -= 16; + } + if (8 <= LEN) { + t64 = buf.[u64 offset + DELTA]; + DELTA += 8; + pst.[u64 8*at] = t64; + at += 1; + LEN -= 8; + } + + // process last word (possibly closing the state) + LO = (AT+LEN) % 8; + if ( 0 < LO || TRAILB != 0 ) { + if ( TRAILB != 0 ) { ALL += 1; } + DELTA, _, TRAILB, t64 = __aread_subu64(buf, offset, DELTA, LO, TRAILB); + pst[u64 (ALL/8)] = t64; + } + offset += DELTA; + return pst, ALL, offset; +} + +inline fn __pabsorb_array_avx2 +( reg mut ptr u64[25] pst +, inline int AT +, reg u256[7] st +, reg const ptr u8[ASIZE] buf +, reg u64 offset +, inline int LEN +, inline int RATE8 +, inline int TRAILB /* closes state if !=0 (i.e. adds trailbyte and padding) */ +) -> reg ptr u64[25] /* pst */ + , inline int /* AT */ + , reg u256[7] /* st */ + , reg u64 /* offset */ +{ + reg u64 i; + inline int ALL, ITERS; + + ALL = AT + LEN; + if ( (AT+LEN) < RATE8 ) { // not enough to fill a block! 
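+    // usage sketch (illustrative only; R136/SHA3 are from
+    // keccak1600_globals.jinc, the buffer/offset names are hypothetical):
+    //   pst, at, st, off = __pabsorb_array_avx2(pst, 0, st, buf, off, 100, R136, 0);    // just buffers 100 bytes
+    //   pst, at, st, off = __pabsorb_array_avx2(pst, at, st, buf, off, 50, R136, SHA3); // permutes, closes
+    // the second call completes the 136-byte block with 36 bytes, permutes,
+    // then absorbs the remaining 14 bytes together with the 0x06 trail byte
+    // and the rate bit, closing the state.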
+ pst, AT, offset = __pstate_array_avx2(pst, AT, buf, offset, LEN, TRAILB); + if (TRAILB != 0) { // add pstate and closes the state + i = AT/8 + 1; + if (AT <= 5*8) { // only st[0..1] is affected + while (i < 5) { pst[i] = 0; i += 1; } + st = __addpst01(st, pst); + st = __addratebit_avx2(st, RATE8); + } else { // all state is affected + while (i < RATE8/8) { pst[i] = 0; i += 1; } + pst[u8 RATE8-1] ^= 0x80; + st = _addpstate_avx2(st, pst); + } + } + } else { // at least a block is filled + if ( AT != 0 ) { // start by filling the first block + pst, _, offset = __pstate_array_avx2(pst, AT, buf, offset, RATE8-AT, 0); + LEN = LEN - (RATE8-AT); + st = _addpstate_avx2(st, pst); + st = _keccakf1600_avx2(st); + AT = 0; + } + + // continue by processing full blocks + ITERS = LEN / RATE8; // number of full blocks + i = 0; + while ( i < ITERS ) { + st, offset = __addstate_array_avx2(st, buf, offset, RATE8, 0); + st = _keccakf1600_avx2(st); + i += 1; + } + + // last incomplete block + LEN = ALL % RATE8; + if (TRAILB!=0) { + st, offset = __addstate_array_avx2(st, buf, offset, LEN, TRAILB); + st = __addratebit_avx2(st, RATE8); + } else if ( LEN != 0) { + pst, AT, offset = __pstate_array_avx2(pst, 0, buf, offset, LEN, TRAILB); +/* + if (TRAILB != 0) { // add pstate and closes the state + i = AT/8 + 1; + if (AT <= 5*8) { // only st[0..1] is affected + while (i < 5) { pst[i] = 0; i += 1; } + st = __addpst01(st, pst); + st = __addratebit_avx2(st, RATE8); + } else { // all state is affected + while (i < RATE8/8) { pst[i] = 0; i += 1; } + pst[u8 RATE8-1] ^= 0x80; + st = _addpstate_avx2(st, pst); + } + } +*/ + } + } + return pst, AT, st, offset; +} + +/* + ONE-SHOT (FIXED-SIZE) MEMORY SQUEEZE + ==================================== +*/ + +inline fn __dumpstate_array_avx2 +( reg mut ptr u8[ASIZE] buf +, reg u64 offset +, inline int LEN +, reg u256[7] st +) -> reg ptr u8[ASIZE] /* buf */ + , reg u64 /* offset */ +{ + reg u64 t; + reg u128 t128_0, t128_1; + reg u256 t256_0, t256_1, t256_2, t256_3, t256_4; + inline int DELTA; + + DELTA = 0; + + // reg0 + if (8 <= LEN) { + buf, DELTA, _ = __awrite_subu256(buf, offset, DELTA, 8, st[0]); + LEN -= 8; + } else { + buf, DELTA, LEN = __awrite_subu256(buf, offset, DELTA, LEN, st[0]); + } + + // reg1 + buf, DELTA, LEN = __awrite_subu256(buf, offset, DELTA, LEN, st[1]); + + // reg2 (5) + if (0 reg ptr u8[ASIZE] /* buf */ + , reg u256[7] /* st */ +{ + reg u64 i; + inline int ITERS, LO; + ITERS = LEN/RATE8; + LO = LEN%RATE8; + if (0 reg u256[7] +{ + inline int i; + reg u256[7] st; + + for i=0 to 7 { st[i] = #set0_256(); } + + return st; +} + +/* + PSTATE - UNPERMUTED KECCAK STATE + ================================ +*/ +inline fn __pstate_init_avx2 +( reg mut ptr u64[25] pst +) -> reg ptr u64[25] +{ + inline int i; + reg u64 z64; + reg u256 z256; + + z256 = #set0_256(); + for i=0 to 25/4 { pst[u256 i] = z256; } + z64 = 0; + pst[24] = z64; + + return pst; +} + +inline fn __perm_reg3456_avx2 +( reg u256 r3 r4 r5 r6 +) -> reg u256 /* st[3] */ + , reg u256 /* st[4] */ + , reg u256 /* st[5] */ + , reg u256 /* st[6] */ +{ + reg u256 t256_0, t256_1, t256_2; + reg u256 st3, st4, st5, st6; + // [ 16 7 8 19 ] + t256_0 = #VPBLEND_8u32(r3, r5, (8u1)[1,1,0,0,0,0,1,1]); + // [ 11 22 23 14 ] + t256_1 = #VPBLEND_8u32(r6, r4, (8u1)[1,1,0,0,0,0,1,1]); + // [ 6 12 13 9 ] + t256_2 = #VPBLEND_8u32(r4, r3, (8u1)[1,1,0,0,0,0,1,1]); + // [ 16 7 23 14 ] + st3 = #VPBLEND_8u32(t256_0, t256_1, (8u1)[1,1,1,1,0,0,0,0]); + // [ 11 22 8 19 ] + st4 = #VPBLEND_8u32(t256_1, t256_0, (8u1)[1,1,1,1,0,0,0,0]); + // 
[ 21 17 18 24 ] + t256_0 = #VPBLEND_8u32(r5, r6, (8u1)[1,1,0,0,0,0,1,1]); + // [ 21 17 13 9 ] + st5 = #VPBLEND_8u32(t256_0, t256_2, (8u1)[1,1,1,1,0,0,0,0]); + // [ 6 12 18 24 ] + st6 = #VPBLEND_8u32(t256_2, t256_0, (8u1)[1,1,1,1,0,0,0,0]); + + return st3, st4, st5, st6; +} + +inline fn __unperm_reg3456_avx2 +( reg u256 st3 st4 st5 st6 +) -> reg u256 /* r3 */ + , reg u256 /* r4 */ + , reg u256 /* r5 */ + , reg u256 /* r6 */ +{ + reg u256 t256_0, t256_1, t256_2, t256_3; + reg u256 r3, r4, r5, r6; + // [ 16, 7, 8, 19 ] + t256_0 = #VPBLEND_8u32(st3, st4, (8u1)[1,1,1,1,0,0,0,0]); + // [ 11, 22, 23, 14 ] + t256_1 = #VPBLEND_8u32(st4, st3, (8u1)[1,1,1,1,0,0,0,0]); + // [ 21, 17, 18, 24 ] + t256_2 = #VPBLEND_8u32(st5, st6, (8u1)[1,1,1,1,0,0,0,0]); + // [ 6, 12, 13, 9 ] + t256_3 = #VPBLEND_8u32(st6, st5, (8u1)[1,1,1,1,0,0,0,0]); + // [ 6, 7, 8, 9 ] + r3 = #VPBLEND_8u32(t256_0, t256_3, (8u1)[1,1,0,0,0,0,1,1]); + // [ 11, 12, 13, 14 ] + r4 = #VPBLEND_8u32(t256_3, t256_1, (8u1)[1,1,0,0,0,0,1,1]); + // [ 16, 17, 18, 19 ] + r5 = #VPBLEND_8u32(t256_2, t256_0, (8u1)[1,1,0,0,0,0,1,1]); + // [ 21, 22, 23, 24 ] + r6 = #VPBLEND_8u32(t256_1, t256_2, (8u1)[1,1,0,0,0,0,1,1]); + + return r3, r4, r5, r6; +} + +/* + STATE READ + ========== +*/ +inline fn __state_from_pstate_avx2 +( reg const ptr u64[25] pst +) -> reg u256[7] +{ + reg u256[7] st; + reg u128 t128_0, t128_1; + reg u64 t; + + st[0] = #VPBROADCAST_4u64(pst.[u64 0]); + st[1] = pst.[u256 8]; + + // [ 5 - ] + t128_0 = #VMOV(pst.[u64 5*8]); + // [ 6 7 8 9 ] + st[3] = pst.[u256 6*8]; + // [ 10 - ] + t128_1 = #VMOV(pst.[u64 10*8]); + // [ 11 12 13 14 ] + st[4] = pst.[u256 11*8]; + // [ 5 15 ] + t = pst.[u64 15*8]; + t128_0 = #VPINSR_2u64(t128_0, t, 1); + // [ 16 17 18 19 ] + st[5] = pst.[u256 16*8]; + // [ 10 20 ] + t = pst.[u64 20*8]; + t128_1 = #VPINSR_2u64(t128_1, t, 1); + // [ 10 20 5 15 ] + st[2] = (2u128)[t128_0, t128_1]; + // [ 21 22 23 24 ] + st[6] = pst.[u256 21*8]; + st[3], st[4], st[5], st[6] = __perm_reg3456_avx2(st[3], st[4], st[5], st[6]); + + return st; +} + + +inline fn __addstate_r3456 +( reg u256[7] st +, reg u256 r3 r4 r5 r6 +) -> reg u256[7] +{ + r3, r4, r5, r6 = __perm_reg3456_avx2(r3, r4, r5, r6); + st[3] ^= r3; + st[4] ^= r4; + st[5] ^= r5; + st[6] ^= r6; +/* + reg u256 t256_0, t256_1, t256_2; + // [ 16 7 8 19 ] + t256_0 = #VPBLEND_8u32(r3, r5, (8u1)[1,1,0,0,0,0,1,1]); + // [ 11 22 23 14 ] + t256_1 = #VPBLEND_8u32(r6, r4, (8u1)[1,1,0,0,0,0,1,1]); + // [ 6 12 13 9 ] + t256_2 = #VPBLEND_8u32(r4, r3, (8u1)[1,1,0,0,0,0,1,1]); + // [ 16 7 23 14 ] + r3 = #VPBLEND_8u32(t256_0, t256_1, (8u1)[1,1,1,1,0,0,0,0]); + st[3] ^= r3; + // [ 11 22 8 19 ] + r4 = #VPBLEND_8u32(t256_1, t256_0, (8u1)[1,1,1,1,0,0,0,0]); + st[4] ^= r4; + // [ 21 17 18 24 ] + t256_0 = #VPBLEND_8u32(r5, r6, (8u1)[1,1,0,0,0,0,1,1]); + // [ 21 17 13 9 ] + r5 = #VPBLEND_8u32(t256_0, t256_2, (8u1)[1,1,1,1,0,0,0,0]); + st[5] ^= r5; + // [ 6 12 18 24 ] + r6 = #VPBLEND_8u32(t256_2, t256_0, (8u1)[1,1,1,1,0,0,0,0]); + st[6] ^= r6; +*/ + return st; +} + +inline fn __addpst01 +( reg u256[7] st +, reg const ptr u64[25] pst +) -> reg u256[7] +{ + reg u256 t256; + t256 = #VPBROADCAST_4u64(pst.[u64 0]); + st[0] ^= t256; + t256 = pst.[u256 8*1]; + st[1] ^= t256; + return st; +} + +inline fn __addpst23456 // remaining entries +( reg u256[7] st +, reg const ptr u64[25] pst +) -> reg u256[7] +{ + reg u256 r2, r3, r4, r5, r6; + reg u128 t128_0, t128_1; + reg u64 t; + + // [ 5 - ] + t128_0 = #VMOV(pst.[u64 5*8]); + // [ 6 7 8 9 ] + r3 = pst.[u256 6*8]; + // [ 10 - ] + t128_1 = #VMOV(pst.[u64 10*8]); 
+ // [ 11 12 13 14 ] + r4 = pst.[u256 11*8]; + // [ 5 15 ] + t = pst.[u64 15*8]; + t128_0 = #VPINSR_2u64(t128_0, t, 1); + // [ 16 17 18 19 ] + r5 = pst.[u256 16*8]; + // [ 10 20 ] + t = pst.[u64 20*8]; + t128_1 = #VPINSR_2u64(t128_1, t, 1); + // [ 10 20 5 15 ] + r2 = (2u128)[t128_0, t128_1]; + st[2] ^= r2; + // [ 21 22 23 24 ] + r6 = pst.[u256 21*8]; + + st = __addstate_r3456(st, r3, r4, r5, r6); + + return st; +} + +fn _addpstate_avx2 +( reg u256[7] st +, reg const ptr u64[25] pst +) -> reg u256[7] +{ + st = __addpst01(st, pst); + st = __addpst23456(st, pst); + return st; +} + +/* + ADD RATE BIT + ============ +*/ + +inline fn __stavx2_pos(inline int POS) -> inline int, inline int { + inline int R, L; + //0: [ 0 0 0 0 ] + R = 0; L = 0; + if (0 < POS) { + //1: [ 1 2 3 4 ] + if (POS <= 4) { R = 1; L = POS-1; } + //2: [ 10 20 5 15 ] + else if (POS == 10) { R = 2; L = 0; } + else if (POS == 20) { R = 2; L = 1; } + else if (POS == 5 ) { R = 2; L = 2; } + else if (POS == 15) { R = 2; L = 3; } + //3: [ 16 7 23 14 ] + else if (POS == 16) { R = 3; L = 0; } + else if (POS == 7 ) { R = 3; L = 1; } + else if (POS == 23) { R = 3; L = 2; } + else if (POS == 14) { R = 3; L = 3; } + //4: [ 11 22 8 19 ] + else if (POS == 11) { R = 4; L = 0; } + else if (POS == 22) { R = 4; L = 1; } + else if (POS == 8 ) { R = 4; L = 2; } + else if (POS == 19) { R = 4; L = 3; } + //5: [ 21 17 13 9 ] + else if (POS == 21) { R = 5; L = 0; } + else if (POS == 17) { R = 5; L = 1; } + else if (POS == 13) { R = 5; L = 2; } + else if (POS == 9 ) { R = 5; L = 3; } + //6: [ 6 12 18 24 ] + else if (POS == 6 ) { R = 6; L = 0; } + else if (POS == 12) { R = 6; L = 1; } + else if (POS == 18) { R = 6; L = 2; } + else if (POS == 24) { R = 6; L = 3; } + } + return R,L; +} + +inline fn __u64_to_u256 +( reg u64 x +, inline int L +) -> reg u256 +{ + reg u256 t256; + reg u128 t128; + + if (L % 2 == 0) { + t128 = (128u) x; + } else { + t128 = #set0_128(); + t128 = #VPINSR_2u64(t128, x, 1); + } + t256 = #set0_256(); + if (L / 2 == 0) { + t256 = #VINSERTI128(t256, t128, 0); + } else { + t256 = #VINSERTI128(t256, t128, 1); + } + + return t256; +} + +inline fn __addratebit_avx2 +( reg u256[7] st +, inline int RATE8 +) -> reg ptr u256[7] +{ + inline int R, L; + reg u256 t256; + + reg u64 t64; + t64 = 1; + t64 <<= (8*RATE8-1) % 64; // obs: should be 63 for all admissible rates! 
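+  // e.g. SHAKE128, RATE8=168: (8*168-1) % 64 = 63, and the pad bit belongs
+  // to lane (168-1)/8 = 20; __stavx2_pos(20) below yields R=2, L=1, i.e.
+  // the second u64 of st[2] (the [10 20 5 15] register).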
+ R,L = __stavx2_pos((RATE8-1)/8); + if (R==0) { + t256 = #VPBROADCAST_4u64(t64); + } else { + t256 = __u64_to_u256(t64, L); + } + st[R] ^= t256; + return st; +} + diff --git a/code/jasmin/mlkem_avx2/keccak/keccak1600_globals.jinc b/code/jasmin/mlkem_avx2/keccak/keccak1600_globals.jinc new file mode 100644 index 00000000..290d1693 --- /dev/null +++ b/code/jasmin/mlkem_avx2/keccak/keccak1600_globals.jinc @@ -0,0 +1,11 @@ +param int R72 = 72; +param int R104 = 104; +param int R136 = 136; +param int R144 = 144; +param int R168 = 168; + +param int UNFINISHED = 0; +param int SHA3 = 0x06; +param int RAWSHAKE = 0x07; +param int SHAKE = 0x1F; + diff --git a/code/jasmin/mlkem_avx2/keccak/keccak1600_imem_avx2.jinc b/code/jasmin/mlkem_avx2/keccak/keccak1600_imem_avx2.jinc new file mode 100644 index 00000000..e3e6ff99 --- /dev/null +++ b/code/jasmin/mlkem_avx2/keccak/keccak1600_imem_avx2.jinc @@ -0,0 +1,349 @@ +require "keccak1600_avx2.jinc" + +/* + ONE-SHOT (FIXED-SIZE) MEMORY ABSORB + =================================== +*/ + +inline fn __addstate_imem_avx2 +( reg u256[7] st +, reg u64 buf +, inline int LEN +, inline int TRAILB +) -> reg ptr u256[7] /* st */ + , reg u64 /* buf */ +{ + reg u64 t64; + reg u256 r0, r1, r2, r3, r4, r5, r6; + reg u128 t128_0, t128_1; + + buf, LEN, TRAILB, r0 = __mread_bcast_4subu64(buf, LEN, TRAILB); + st[0] ^= r0; + + buf, LEN, TRAILB, r1 = __mread_subu256(buf, LEN, TRAILB); + st[1] ^= r1; + + if (0 < LEN ) { + buf, LEN, TRAILB, t64 = __mread_subu64(buf, LEN, TRAILB); + t128_1 = (128u) t64; + + buf, LEN, TRAILB, r3 = __mread_subu256(buf, LEN, TRAILB); + + buf, LEN, TRAILB, t64 = __mread_subu64(buf, LEN, TRAILB); + t128_0 = (128u) t64; + + buf, LEN, TRAILB, r4 = __mread_subu256(buf, LEN, TRAILB); + + buf, LEN, TRAILB, t64 = __mread_subu64(buf, LEN, TRAILB); + t128_1 = #VPINSR_2u64(t128_1, t64, 1); + + buf, LEN, TRAILB, r5 = __mread_subu256(buf, LEN, TRAILB); + + buf, LEN, TRAILB, t64 = __mread_subu64(buf, LEN, TRAILB); + t128_0 = #VPINSR_2u64(t128_0, t64, 1); + r2 = (2u128)[t128_1, t128_0]; + st[2] ^= r2; + + buf, LEN, TRAILB, r6 = __mread_subu256(buf, LEN, TRAILB); + + st = __addstate_r3456( st, r3, r4, r5, r6); + } + return st, buf; +} + +inline fn __absorb_imem_avx2 +( reg u256[7] st +, reg u64 buf +, inline int LEN +, inline int RATE8 +, inline int TRAILB /* closes state if !=0 (i.e. adds trailbyte and padding) */ +) -> reg u256[7] /* st */ + , reg u64 /* buf */ +{ + reg u64 i; + inline int ALL, ITERS; + + ALL = LEN + (TRAILB!=0 ? 1 : 0); + + // continue by processing full blocks + ITERS = LEN / RATE8; // number of full blocks + if (0 < ITERS) { + i = 0; + while ( i < ITERS ) { + st, buf = __addstate_imem_avx2(st, buf, RATE8, 0); + st = _keccakf1600_avx2(st); + i += 1; + } + } + + // last incomplete block + LEN = LEN % RATE8; + st, buf = __addstate_imem_avx2(st, buf, LEN, TRAILB); + if (TRAILB!=0) { st = __addratebit_avx2(st, RATE8); } + + return st, buf; +} + +/* + INCREMENTAL (FIXED-SIZE) MEMORY ABSORB + ====================================== +*/ + +inline fn __pstate_imem_avx2 +( reg mut ptr u64[25] pst +, inline int AT /* bytes (0 <= AT < 200) */ +, reg u64 buf +, inline int LEN +, inline int TRAILB +) -> reg ptr u64[25] /* pst */ + , inline int /* AT */ + , reg u64 /* buf */ +{ + inline int LO, ALL; + reg u64 at, t64; + reg u256 t256; + reg u128 t128; + + ALL = AT+LEN; // total bytes to process (excluding trail byte, if !=0) + LO = AT % 8; // leftover bytes + at = AT / 8; // current pstate position + + if ( 0 < LO ) { // process first word... 
+ if ( LO+LEN < 8) { // ...not enough to fill a word (just update it) + if ( TRAILB != 0 ) { ALL += 1; } + buf, _, TRAILB, t64 = __mread_subu64(buf, LEN, TRAILB); + t64 <<= 8*LO; + pst[(int) at] ^= t64; + LO = 0; + AT = 0; + LEN = 0; + } else { // process first word + if ( 8 <= LEN ) { + t64 = (u64)[buf]; + buf += (8-LO); + } else { + buf, _, _, t64 = __mread_subu64(buf, 8-LO, 0); + } + LEN -= 8-LO; + AT += 8-LO; + t64 <<= 8*LO; + pst[(int) at] ^= t64; + at += 1; + } + } + + // continue processing remaining bytes + if (32 <= LEN) { + while ( at < AT/8+4*(LEN/32)) { + t256 = (u256)[buf]; + buf += 32; + pst.[u256 8*at] = t256; + at += 4; + } + LEN = LEN % 32; + } + if (16 <= LEN) { + t128 = (u128)[buf]; + buf += 16; + pst.[u128 8*at] = t128; + at += 2; + LEN -= 16; + } + if (8 <= LEN) { + t64 = (u64)[buf]; + buf += 8; + pst.[u64 8*at] = t64; + at += 1; + LEN -= 8; + } + + // process last word (possibly closing the state) + LO = (AT+LEN) % 8; + if ( 0 < LEN || TRAILB != 0 ) { + if ( TRAILB != 0 ) { ALL += 1; } + buf, _, TRAILB, t64 = __mread_subu64(buf, LO, TRAILB); + pst[u64 (ALL/8)] = t64; + } + + return pst, ALL, buf; +} + +inline fn __pabsorb_imem_avx2 +( reg mut ptr u64[25] pst +, inline int AT +, reg u256[7] st +, reg u64 buf +, inline int LEN +, inline int RATE8 +, inline int TRAILB /* closes state if !=0 (i.e. adds trailbyte and padding) */ +) -> reg ptr u64[25] /* pst */ + , inline int /* AT */ + , reg u256[7] /* st */ + , reg u64 /* buf */ +{ + reg u64 i; + inline int ALL, ITERS; + + ALL = AT + LEN; + if ( (AT+LEN) < RATE8 ) { // not enough to fill a block! + pst, AT, buf = __pstate_imem_avx2(pst, AT, buf, LEN, TRAILB); + if (TRAILB != 0) { // add pstate and closes the state + i = AT/8 + 1; + if (AT <= 5*8) { // only st[0..1] is affected + while (i < 5) { pst[i] = 0; i += 1; } + st = __addpst01(st, pst); + st = __addratebit_avx2(st, RATE8); + } else { // all state is affected + while (i < RATE8/8) { pst[i] = 0; i += 1; } + pst[u8 RATE8-1] ^= 0x80; + st = _addpstate_avx2(st, pst); + } + } + } else { // at least a block is filled + if ( AT != 0 ) { // start by filling the first block + pst, _, buf = __pstate_imem_avx2(pst, AT, buf, RATE8-AT, 0); + LEN = LEN - (RATE8-AT); + st = _addpstate_avx2(st, pst); + st = _keccakf1600_avx2(st); + AT = 0; + } + + // continue by processing full blocks + ITERS = LEN / RATE8; // number of full blocks + i = 0; + while ( i < ITERS ) { + st, buf = __addstate_imem_avx2(st, buf, RATE8, 0); + st = _keccakf1600_avx2(st); + i += 1; + } + + // last incomplete block + LEN = ALL % RATE8; + if (TRAILB!=0) { + st, buf = __addstate_imem_avx2(st, buf, LEN, TRAILB); + st = __addratebit_avx2(st, RATE8); + AT = 0; + } else if ( LEN != 0) { + pst, AT, buf = __pstate_imem_avx2(pst, 0, buf, LEN, TRAILB); + } + } + return pst, AT, st, buf; +} + +/* + ONE-SHOT (FIXED-SIZE) MEMORY SQUEEZE + ==================================== +*/ + +inline fn __dumpstate_imem_avx2 +( reg u64 buf +, inline int LEN +, reg u256[7] st +) -> reg u64 +{ + reg u64 t; + reg u128 t128_0, t128_1; + reg u256 t256_0, t256_1, t256_2, t256_3, t256_4; + + // reg0 + if (8 <= LEN) { + buf, _ = __mwrite_subu256(buf, 8, st[0]); + LEN -= 8; + } else { + buf, LEN = __mwrite_subu256(buf, LEN, st[0]); + } + + // reg1 + buf, LEN = __mwrite_subu256(buf, LEN, st[1]); + + // reg2 (5) + if (0 reg u64 /* buf */ + , reg u256[7] /* st */ +{ + reg u64 i; + inline int ITERS, LO; + ITERS = LEN/RATE8; + LO = LEN%RATE8; + if (0 reg u256[7] +{ + inline int i; + reg u256[7] state; + + for i=0 to 7 + { state[i] = #set0_256(); } 
+ + return state; +} + + +inline fn __init_s_state_avx2() -> stack u64[28] +{ + inline int i; + stack u64[28] s_state; + reg u256 zero; + + zero = #set0_256(); + for i=0 to 7 + { s_state[u256 i] = zero; } + + return s_state; +} + + +inline fn __add_full_block_avx2( + reg u256[7] state, + stack u64[28] s_state, + reg ptr u64[25] a_jagged_p, + reg u64 in inlen, + reg u64 rate +) -> reg u256[7], stack u64[28], reg u64, reg u64 +{ + + inline int i; + reg u64 j l t rate8; + + rate8 = rate; + rate8 >>= 3; + j = 0; + while ( j < rate8 ) + { + t = [in + 8*j]; + l = a_jagged_p[(int) j]; + s_state[(int) l] = t; + j += 1; + } + + //TODO: check & change to #VPBROADCAST_4u64 + t = s_state[0]; + s_state[1] = t; + s_state[2] = t; + s_state[3] = t; + + for i = 0 to 7 + { state[i] ^= s_state[u256 i]; } + + in += rate; + inlen -= rate; + + return state, s_state, in, inlen; +} + + +// TODO: refactor when this feature is available: https://github.com/haslab/libjbn/wiki/Feature-request-%231#procedural-parameters +inline fn __add_final_block_avx2( + reg u256[7] state, + stack u64[28] s_state, + reg ptr u64[25] a_jagged_p, + reg u64 in inlen, + reg u8 trail_byte, + reg u64 rate +) -> reg u256[7] +{ + inline int i; + reg u64 j l t inlen8; + reg u8 c; + + s_state = __init_s_state_avx2(); + + inlen8 = inlen; + inlen8 >>= 3; + j = 0; + while ( j < inlen8 ) + { + t = [in + 8*j]; + l = a_jagged_p[(int) j]; + s_state[(int) l] = t; + j += 1; + } + l = a_jagged_p[(int) j]; + l <<= 3; + j <<= 3; + + while ( j < inlen ) + { + c = (u8)[in + j]; + s_state[u8 (int) l] = c; + j += 1; + l += 1; + } + + s_state[u8 (int) l] = trail_byte; + + // j = (rate-1) >> 3; + j = rate; j -= 1; j >>= 3; + l = a_jagged_p[(int) j]; + l <<= 3; + // l += ((rate-1) & 0x7) + j = rate; j -= 1; j &= 0x7; + l += j; + + s_state[u8 (int) l] ^= 0x80; + + t = s_state[0]; + s_state[1] = t; + s_state[2] = t; + s_state[3] = t; + + for i = 0 to 7 + { state[i] ^= s_state[u256 i]; } + + return state; +} + + +// obs: @pre: len <= rate_in_bytes +inline fn __xtr_full_block_avx2( + reg u256[7] state, + reg ptr u64[25] a_jagged_p, + reg u64 out, + reg u64 len +) -> reg u64 +{ + inline int i; + stack u64[28] s_state; + reg u64 j l t len8; + + for i = 0 to 7 + { s_state[u256 i] = state[i]; } + + len8 = len; + len8 >>= 3; + j = 0; + while ( j < len8 ) + { + l = a_jagged_p[(int) j]; + t = s_state[(int) l]; + [out + 8*j] = t; + j += 1; + } + + out += len; + + return out; +} + + +// obs: @pre: len <= rate_in_bytes +inline fn __xtr_bytes_avx2( + reg u256[7] state, + reg ptr u64[25] a_jagged_p, + reg u64 out, + reg u64 len +) -> reg u64 +{ + inline int i; + stack u64[28] s_state; + reg u64 j l t len8; + reg u8 c; + + for i = 0 to 7 + { s_state[u256 i] = state[i]; } + + len8 = len; + len8 >>= 3; + j = 0; + while ( j < len8 ) + { l = a_jagged_p[(int) j]; + t = s_state[(int) l]; + [out + 8*j] = t; + j += 1; + } + l = a_jagged_p[(int)j]; + j <<= 3; + l <<= 3; + + while ( j < len ) + { + c = s_state[u8 (int) l]; + (u8)[out + j] = c; + j += 1; + l += 1; + } + + out += len; + + return out; +} + + +inline fn __absorb_avx2( + reg u256[7] state, + reg u64 in inlen, + reg u8 trail_byte, + reg u64 rate +) -> reg u256[7] +{ + stack u64[28] s_state; + reg ptr u64[25] a_jagged_p; + + a_jagged_p = KECCAK_A_JAGGED; + s_state = __init_s_state_avx2(); + + // intermediate blocks + while ( inlen >= rate ) + { + state, s_state, in, inlen = __add_full_block_avx2(state, s_state, a_jagged_p, in, inlen, rate); + state = _keccakf1600_avx2_(state); + } + + // final block + state = 
__add_final_block_avx2(state, s_state, a_jagged_p, in, inlen, trail_byte, rate); + + return state; +} + + +inline fn __squeeze_avx2(reg u256[7] state, reg u64 out outlen rate) -> reg u256[7] +{ + reg ptr u64[25] a_jagged_p; + + a_jagged_p = KECCAK_A_JAGGED; + + // intermediate blocks + while ( outlen > rate ) + { + state = _keccakf1600_avx2_(state); + out = __xtr_full_block_avx2(state, a_jagged_p, out, rate); + outlen -= rate; + } + + state = _keccakf1600_avx2_(state); + out = __xtr_bytes_avx2(state, a_jagged_p, out, outlen); + return state; +} + + +inline fn __keccak1600_avx2(reg u64 out outlen in inlen, reg u8 trail_byte, reg u64 rate) +{ + reg u256[7] state; + + state = __keccak_init_avx2(); + + // absorb + state = __absorb_avx2(state, in, inlen, trail_byte, rate); + + // squeeze + _ = __squeeze_avx2(state, out, outlen, rate); +} + + +fn _keccak1600_avx2(reg u64 out outlen in inlen, reg u8 trail_byte, reg u64 rate) +{ + __keccak1600_avx2(out, outlen, in, inlen, trail_byte, rate); +} + diff --git a/code/jasmin/mlkem_avx2/keccak/keccak1600_orig_avx2_ASIZE.jinc b/code/jasmin/mlkem_avx2/keccak/keccak1600_orig_avx2_ASIZE.jinc new file mode 100644 index 00000000..7d4c3e25 --- /dev/null +++ b/code/jasmin/mlkem_avx2/keccak/keccak1600_orig_avx2_ASIZE.jinc @@ -0,0 +1,247 @@ +/* DEPENDENCIES: +require "keccak1600_orig_avx2.jinc" +param int ASIZE=101; +*/ + +inline fn __addarray_full_block_avx2 +( reg u256[7] state +, stack u64[28] s_state +, reg ptr u64[25] a_jagged_p +, reg const ptr u8[ASIZE] in +, reg u64 offset +, reg u64 inlen +, reg u64 rate +) -> reg u256[7] /* st */ + , stack u64[28] /* pst */ + , reg u64 /* offset */ + , reg u64 /* inlen */ +{ + + inline int i; + reg u64 j l t rate8; + + rate8 = rate; + rate8 >>= 3; + j = 0; + while ( j < rate8 ) + { + t = in.[u64 offset + 8*j]; + l = a_jagged_p[(int) j]; + s_state[(int) l] = t; + j += 1; + } + + //TODO: check & change to #VPBROADCAST_4u64 + t = s_state[0]; + s_state[1] = t; + s_state[2] = t; + s_state[3] = t; + + for i = 0 to 7 + { state[i] ^= s_state[u256 i]; } + + offset += rate; + inlen -= rate; + + return state, s_state, offset, inlen; +} + + +// TODO: refactor when this feature is available: https://github.com/haslab/libjbn/wiki/Feature-request-%231#procedural-parameters +inline fn __addarray_final_block_avx2 +( reg u256[7] state +, stack u64[28] s_state +, reg ptr u64[25] a_jagged_p +, reg const ptr u8[ASIZE] in +, reg u64 offset +, reg u64 inlen +, reg u8 trail_byte +, reg u64 rate +) -> reg u256[7] +{ + inline int i; + reg u64 j l t inlen8; + reg u8 c; + + s_state = __init_s_state_avx2(); + + inlen8 = inlen; + inlen8 >>= 3; + j = 0; + while ( j < inlen8 ) + { + t = in.[u64 offset + 8*j]; + l = a_jagged_p[(int) j]; + s_state[(int) l] = t; + j += 1; + } + l = a_jagged_p[(int) j]; + l <<= 3; + j <<= 3; + + while ( j < inlen ) + { + c = in.[u8 offset + j]; + s_state[u8 (int) l] = c; + j += 1; + l += 1; + } + + s_state[u8 (int) l] = trail_byte; + + // j = (rate-1) >> 3; + j = rate; j -= 1; j >>= 3; + l = a_jagged_p[(int) j]; + l <<= 3; + // l += ((rate-1) & 0x7) + j = rate; j -= 1; j &= 0x7; + l += j; + + s_state[u8 (int) l] ^= 0x80; + + t = s_state[0]; + s_state[1] = t; + s_state[2] = t; + s_state[3] = t; + + for i = 0 to 7 + { state[i] ^= s_state[u256 i]; } + + return state; +} + + +// obs: @pre: len <= rate_in_bytes +inline fn __xtrarray_full_block_avx2 +( reg u256[7] state +, reg ptr u64[25] a_jagged_p +, reg mut ptr u8[ASIZE] out +, reg u64 offset +, reg u64 len +) -> reg ptr u8[ASIZE] /* out */ + , reg u64 /* offset */ +{ + 
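+  // note (inferred from how KECCAK_A_JAGGED is used throughout this file):
+  // a_jagged_p maps the canonical lane index j to the slot of that lane
+  // inside the 28-u64 image of the seven scattered AVX2 registers, so the
+  // output is produced in lane order 0..24.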
inline int i; + stack u64[28] s_state; + reg u64 j l t len8; + + for i = 0 to 7 + { s_state[u256 i] = state[i]; } + + len8 = len; + len8 >>= 3; + j = 0; + while ( j < len8 ) + { + l = a_jagged_p[(int) j]; + t = s_state[(int) l]; + out.[u64 offset + 8*j] = t; + j += 1; + } + + offset += len; + + return out, offset; +} + + +// obs: @pre: len <= rate_in_bytes +inline fn __xtrarray_bytes_avx2 +( reg u256[7] state +, reg ptr u64[25] a_jagged_p +, reg mut ptr u8[ASIZE] out +, reg u64 offset +, reg u64 len +) -> reg ptr u8[ASIZE] /* out */ + , reg u64 /* offset */ +{ + inline int i; + stack u64[28] s_state; + reg u64 j l t len8; + reg u8 c; + + for i = 0 to 7 + { s_state[u256 i] = state[i]; } + + len8 = len; + len8 >>= 3; + j = 0; + while ( j < len8 ) + { l = a_jagged_p[(int) j]; + t = s_state[(int) l]; + out.[u64 offset + 8*j] = t; + j += 1; + } + l = a_jagged_p[(int)j]; + j <<= 3; + l <<= 3; + + while ( j < len ) + { + c = s_state[u8 (int) l]; + out.[u8 offset + j] = c; + j += 1; + l += 1; + } + + offset += len; + + return out, offset; +} + + +inline fn __absorbarray_avx2 +( reg const ptr u8[ASIZE] in +, reg u64 offset +, reg u64 inlen +, reg u8 trail_byte +, reg u64 rate +) -> reg u256[7] +{ reg u256[7] state; + stack u64[28] s_state; + reg ptr u64[25] a_jagged_p; + + a_jagged_p = KECCAK_A_JAGGED; + s_state = __init_s_state_avx2(); + + // intermediate blocks + while ( inlen >= rate ) + { + state, s_state, offset, inlen = __addarray_full_block_avx2(state, s_state, a_jagged_p, in, offset, inlen, rate); + state = _keccakf1600_avx2_(state); + } + + // final block + state = __addarray_final_block_avx2(state, s_state, a_jagged_p, in, offset, inlen, trail_byte, rate); + + return state; +} + + +inline fn __squeezearray_avx2 +( reg u256[7] state +, reg mut ptr u8[ASIZE] out +, reg u64 offset +, reg u64 outlen +, reg u64 rate +) -> reg ptr u8[ASIZE] /* out */ + , reg u64 /* offset */ + , reg u256[7] /* state */ +{ + reg ptr u64[25] a_jagged_p; + + a_jagged_p = KECCAK_A_JAGGED; + + // intermediate blocks + while ( outlen > rate ) + { + state = _keccakf1600_avx2_(state); + out, offset = __xtrarray_full_block_avx2(state, a_jagged_p, out, offset, rate); + outlen -= rate; + } + + state = _keccakf1600_avx2_(state); + out, offset = __xtrarray_bytes_avx2(state, a_jagged_p, out, offset, outlen); + return out, offset, state; +} + diff --git a/code/jasmin/mlkem_avx2/keccak/keccak1600x4_array_avx2_ASIZE.jinc b/code/jasmin/mlkem_avx2/keccak/keccak1600x4_array_avx2_ASIZE.jinc new file mode 100644 index 00000000..dab79aef --- /dev/null +++ b/code/jasmin/mlkem_avx2/keccak/keccak1600x4_array_avx2_ASIZE.jinc @@ -0,0 +1,416 @@ +/* DEPENDENCIES: +require "keccak1600x4_avx2.jinc" +param int ASIZE = 1001; +*/ + +require "subreadwrite_array_ASIZE.jinc" + + +/* + INCREMENTAL ARRAY BROADCAST ABSORB + ================================== +*/ + +inline fn __addstate_bcast_array_avx2x4 +( reg mut ptr u256[25] st +, inline int AT /* bytes (0 <= AT < 200) */ +, reg const ptr u8[ASIZE] buf +, reg u64 offset +, inline int LEN +, inline int TRAILB +) -> reg ptr u256[25] /* st */ + , inline int /* AT */ + , reg u64 /* offset */ +{ + inline int DELTA, LO, ALL; + reg u64 at; + reg u256 t256; + + ALL = AT+LEN; // total bytes to process (excluding trail byte, if !=0) + LO = AT % 8; // leftover bytes + at = 32 * (AT / 8); // current pstate position + DELTA = 0; + + if ( 0 < LO ) { // process first word... 
+ if ( LO+LEN < 8) { // ...not enough to fill a word (just update it) + if ( TRAILB != 0 ) { ALL += 1; } + DELTA, _, TRAILB, t256 = __aread_bcast_4subu64(buf, offset, DELTA, LEN, TRAILB); + t256 = #VPSLL_4u64(t256, 8*LO); + t256 ^= st.[u256 (int) at]; + st.[u256 (int) at] = t256; + LO = 0; + AT = 0; + LEN = 0; + } else { // process first word + if ( 8 <= LEN ) { + t256 = #VPBROADCAST_4u64(buf.[u64 offset + DELTA]); + DELTA += (8-LO); + } else { + DELTA, _, _, t256 = __aread_bcast_4subu64(buf, offset, DELTA, 8-LO, 0); + } + LEN -= 8-LO; + AT += 8-LO; + t256 = #VPSLL_4u64(t256, 8*LO); + t256 ^= st.[u256 (int) at]; + st.[u256 (int) at] = t256; + at += 32; + } + } + + offset += DELTA; + DELTA = 0; + // continue processing remaining bytes + if (8 <= LEN) { + while ( at < 32*(AT/8)+32*(LEN/8)) { + t256 = #VPBROADCAST_4u64(buf.[u64 offset]); + offset += 8; + t256 ^= st.[u256 at]; + st.[u256 at] = t256; + at += 32; + } + LEN = (AT+LEN) % 8; + } + + // process last word (possibly closing the state) + LO = (AT+LEN) % 8; + if ( 0 < LO || TRAILB != 0 ) { + if ( TRAILB != 0 ) { ALL += 1; } + DELTA, _, TRAILB, t256 = __aread_bcast_4subu64(buf, offset, DELTA, LO, TRAILB); + offset += DELTA; + t256 ^= st.[u256 at]; + st.[u256 at] = t256; + } + return st, ALL, offset; +} + +inline fn __absorb_bcast_array_avx2x4 +( reg mut ptr u256[25] st +, inline int AT +, reg const ptr u8[ASIZE] buf +, reg u64 offset +, inline int LEN +, inline int RATE8 +, inline int TRAILB /* closes state if !=0 (i.e. adds trailbyte and padding) */ +) -> reg ptr u256[25] /* st */ + , inline int /* AT */ + , reg u64 /* offset */ +{ + reg u64 i; + inline int ALL, ITERS; + + ALL = AT + LEN; + if ( (AT+LEN) < RATE8 ) { // not enough to fill a block! + st, AT, offset = __addstate_bcast_array_avx2x4(st, AT, buf, offset, LEN, TRAILB); + if (TRAILB != 0) { // add pstate and closes the state + st = __addratebit_avx2x4(st, RATE8); + } + } else { // at least a block is filled + if ( AT != 0 ) { // start by filling the first block + st, _, offset = __addstate_bcast_array_avx2x4(st, AT, buf, offset, RATE8-AT, 0); + LEN = LEN - (RATE8-AT); + st = _keccakf1600_avx2x4(st); + AT = 0; + } + + // continue by processing full blocks + ITERS = LEN / RATE8; // number of full blocks + i = 0; + while ( i < ITERS ) { + st, _, offset = __addstate_bcast_array_avx2x4(st, 0, buf, offset, RATE8, 0); + st = _keccakf1600_avx2x4(st); + i += 1; + } + + // last incomplete block + LEN = ALL % RATE8; + st, AT, offset = __addstate_bcast_array_avx2x4(st, 0, buf, offset, LEN, TRAILB); + if (TRAILB!=0) { st = __addratebit_avx2x4(st, RATE8); } + } + return st, AT, offset; +} + +/* + INCREMENTAL (FIXED-SIZE) MEMORY 4-way ABSORB + ============================================ +*/ + +inline fn __addstate_array_avx2x4 +( reg mut ptr u256[25] st +, inline int AT /* bytes (0 <= AT < 200) */ +, reg const ptr u8[ASIZE] buf0 buf1 buf2 buf3 +, reg u64 offset +, inline int LEN +, inline int TRAILB +) -> reg ptr u256[25] /* st */ + , inline int /* AT */ + , reg u64 /* offset */ +{ + inline int DELTA, LO, ALL; + reg u64 at, t0, t1, t2, t3; + reg u256 t256_0, t256_1, t256_2, t256_3; + + ALL = AT+LEN; // total bytes to process (excluding trail byte, if !=0) + LO = AT % 8; // leftover bytes + at = 4 * (AT / 8); // current pstate position (referencing u64 words) +//at = 0, 4, 8, ... + DELTA = 0; + + if ( 0 < LO ) { // process first word... 
+    if ( LO+LEN < 8) { // ...not enough to fill a word (just update it)
+      if ( TRAILB != 0 ) { ALL += 1; }
+      _, _, _, t0 = __aread_subu64(buf0, offset, DELTA, LEN, TRAILB);
+      _, _, _, t1 = __aread_subu64(buf1, offset, DELTA, LEN, TRAILB);
+      _, _, _, t2 = __aread_subu64(buf2, offset, DELTA, LEN, TRAILB);
+      DELTA, _, _, t3 = __aread_subu64(buf3, offset, DELTA, LEN, TRAILB);
+      t0 <<= 8*LO;
+      t0 ^= st[u64 at + 0];
+      st[u64 at + 0] = t0;
+      t1 <<= 8*LO;
+      t1 ^= st[u64 at + 1];
+      st[u64 at + 1] = t1;
+      t2 <<= 8*LO;
+      t2 ^= st[u64 at + 2];
+      st[u64 at + 2] = t2;
+      t3 <<= 8*LO;
+      t3 ^= st[u64 at + 3];
+      st[u64 at + 3] = t3;
+      LO = 0;
+      AT = 0;
+      LEN = 0;
+      TRAILB = 0;
+    } else { // process first word
+      if ( 8 <= LEN ) {
+        t0 = buf0.[u64 offset + DELTA];
+        t1 = buf1.[u64 offset + DELTA];
+        t2 = buf2.[u64 offset + DELTA];
+        t3 = buf3.[u64 offset + DELTA];
+        offset += 8-LO;
+      } else {
+        _, _, _, t0 = __aread_subu64(buf0, offset, DELTA, 8-LO, TRAILB);
+        _, _, _, t1 = __aread_subu64(buf1, offset, DELTA, 8-LO, TRAILB);
+        _, _, _, t2 = __aread_subu64(buf2, offset, DELTA, 8-LO, TRAILB);
+        DELTA, _, _, t3 = __aread_subu64(buf3, offset, DELTA, 8-LO, TRAILB);
+      }
+      LEN -= 8-LO;
+      AT += 8-LO;
+      t0 <<= 8*LO;
+      t0 ^= st[u64 at + 0];
+      st[u64 at + 0] = t0;
+      t1 <<= 8*LO;
+      t1 ^= st[u64 at + 1];
+      st[u64 at + 1] = t1;
+      t2 <<= 8*LO;
+      t2 ^= st[u64 at + 2];
+      st[u64 at + 2] = t2;
+      t3 <<= 8*LO;
+      t3 ^= st[u64 at + 3];
+      st[u64 at + 3] = t3;
+      at += 4;
+    }
+  }
+  offset += DELTA;
+  DELTA = 0;
+  // continue processing remaining bytes
+  if (8 <= LEN) {
+    while ( at < 4*(AT/8)+16*(LEN/32) ) { // 16 u64 slots per 32 input bytes
+      t256_0 = buf0.[u256 offset];
+      t256_1 = buf1.[u256 offset];
+      t256_2 = buf2.[u256 offset];
+      t256_3 = buf3.[u256 offset];
+      offset += 32;
+      t256_0, t256_1, t256_2, t256_3 = __4u64x4_u256x4(t256_0, t256_1, t256_2, t256_3);
+      st.[u256 8*at] = t256_0;
+      st.[u256 8*at+32] = t256_1;
+      st.[u256 8*at+64] = t256_2;
+      st.[u256 8*at+96] = t256_3;
+      at += 16;
+    }
+    while ( at < 4*(AT/8)+4*(LEN/8)) {
+      t0 = buf0.[u64 offset];
+      t0 ^= st[u64 at + 0];
+      st[u64 at + 0] = t0;
+      t1 = buf1.[u64 offset];
+      t1 ^= st[u64 at + 1];
+      st[u64 at + 1] = t1;
+      t2 = buf2.[u64 offset];
+      t2 ^= st[u64 at + 2];
+      st[u64 at + 2] = t2;
+      t3 = buf3.[u64 offset];
+      offset += 8;
+      t3 ^= st[u64 at + 3];
+      st[u64 at + 3] = t3;
+      at += 4;
+    }
+    LEN = (AT+LEN) % 8;
+  }
+
+  // process last word (possibly closing the state)
+  LO = (AT+LEN) % 8;
+  if ( 0 < LO || TRAILB != 0 ) {
+    _, _, _, t0 = __aread_subu64(buf0, offset, DELTA, LO, TRAILB);
+    _, _, _, t1 = __aread_subu64(buf1, offset, DELTA, LO, TRAILB);
+    _, _, _, t2 = __aread_subu64(buf2, offset, DELTA, LO, TRAILB);
+    DELTA, _, _, t3 = __aread_subu64(buf3, offset, DELTA, LO, TRAILB);
+    offset += DELTA;
+    if ( TRAILB != 0 ) { ALL += 1; TRAILB = 0; }
+    t0 ^= st[u64 at + 0];
+    st[u64 at + 0] = t0;
+    t1 ^= st[u64 at + 1];
+    st[u64 at + 1] = t1;
+    t2 ^= st[u64 at + 2];
+    st[u64 at + 2] = t2;
+    t3 ^= st[u64 at + 3];
+    st[u64 at + 3] = t3;
+  }
+
+  return st, ALL, offset;
+}
+
+
+inline fn __absorb_array_avx2x4
+( reg mut ptr u256[25] st
+, inline int AT
+, reg const ptr u8[ASIZE] buf0 buf1 buf2 buf3
+, reg u64 offset
+, inline int LEN
+, inline int RATE8
+, inline int TRAILB /* closes state if !=0 (i.e. adds trailbyte and padding) */
+) -> reg ptr u256[25] /* st */
+   , inline int /* AT */
+   , reg u64 /* offset */
+{
+  reg u64 i;
+  inline int ALL, ITERS;
+
+  ALL = AT + LEN;
+  if ( (AT+LEN) < RATE8 ) { // not enough to fill a block! 
+ st, AT, offset + = __addstate_array_avx2x4(st, AT, buf0, buf1, buf2, buf3, offset, LEN, TRAILB); + if (TRAILB != 0) { // add pstate and closes the state + st = __addratebit_avx2x4(st, RATE8); + } + } else { // at least a block is filled + if ( AT != 0 ) { // start by filling the first block + st, _, offset + = __addstate_array_avx2x4(st, AT, buf0, buf1, buf2, buf3, offset, RATE8-AT, 0); + LEN = LEN - (RATE8-AT); + st = _keccakf1600_avx2x4(st); + AT = 0; + } + + // continue by processing full blocks + ITERS = LEN / RATE8; // number of full blocks + i = 0; + while ( i < ITERS ) { + st, _, offset + = __addstate_array_avx2x4(st, 0, buf0, buf1, buf2, buf3, offset, RATE8, 0); + st = _keccakf1600_avx2x4(st); + i += 1; + } + + // last incomplete block + LEN = ALL % RATE8; + st, AT, offset + = __addstate_array_avx2x4(st, 0, buf0, buf1, buf2, buf3, offset, LEN, TRAILB); + if (TRAILB!=0) { st = __addratebit_avx2x4(st, RATE8); } + } + return st, AT, offset; +} + + +/* + ONE-SHOT (FIXED-SIZE) MEMORY SQUEEZE + ==================================== +*/ +inline fn __dumpstate_array_avx2x4 +( reg mut ptr u8[ASIZE] buf0 buf1 buf2 buf3 +, reg u64 offset +, inline int LEN +, reg const ptr u256[25] st +) -> reg ptr u8[ASIZE] /* buf0 */ + , reg ptr u8[ASIZE] /* buf1 */ + , reg ptr u8[ASIZE] /* buf2 */ + , reg ptr u8[ASIZE] /* buf3 */ + , reg u64 /* offset */ +{ + reg u256 x0, x1, x2, x3; + reg u64 i, t0, t1, t2, t3; + i = 0; + while (i reg ptr u8[ASIZE] /* buf0 */ + , reg ptr u8[ASIZE] /* buf1 */ + , reg ptr u8[ASIZE] /* buf2 */ + , reg ptr u8[ASIZE] /* buf3 */ + , reg u64 /* offset */ + , reg ptr u256[25] /* st */ +{ + reg u64 i; + inline int ITERS, LO; + ITERS = LEN/RATE8; + LO = LEN%RATE8; + + if (0 reg ptr u256[25] +{ + reg u64 i; + reg u256 z256; + z256 = #set0_256(); + i = 0; + while (i < 32*25) { + st.[u256 (int) i] = z256; + i += 32; + } + return st; +} + +/* + ADD RATE BIT + ============ +*/ + +inline fn __addratebit_avx2x4 +( reg mut ptr u256[25] st +, inline int RATE8 +) -> reg ptr u256[25] +{ + reg u256 t256; + reg u128 t128; + reg u64 t64; + t64 = 1; + t64 <<= (8*RATE8-1) % 64; // obs: should be 63 for all admissible rates! 
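+  // e.g. SHA3-256, RATE8=136: (8*136-1) % 64 = 63 and (136-1)/8 = 16, so
+  // bit 63 of lane 16 is flipped; since row st[16] holds lane 16 of all
+  // four interleaved states, one broadcast closes all four at once.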
+ t128 = (128u) t64; + t256 = #VPBROADCAST_4u64(t128); + t256 ^= st[(RATE8-1)/8]; + st[(RATE8-1)/8] = t256; + return st; +} + +/* + State25 to/from State4x25 + ========================= +*/ +// pack 4 keccak states (st25) into a 4-way state (st4x) +inline fn __u256x4_4u64x4 +( reg u256 x0 x1 x2 x3 +) -> reg u256, reg u256, reg u256, reg u256 { + // x0 = l00 l01 l02 l03 + // x1 = l10 l11 l12 l13 + // x2 = l20 l21 l22 l23 + // x3 = l30 l31 l32 l33 + reg u256 y0, y1, y2, y3; + y0 = #VPUNPCKL_4u64(x0, x1); // y0 = l00 l10 l02 l12 + y1 = #VPUNPCKH_4u64(x0, x1); // y1 = l01 l11 l03 l13 + y2 = #VPUNPCKL_4u64(x2, x3); // y2 = l20 l30 l22 l32 + y3 = #VPUNPCKH_4u64(x2, x3); // y3 = l21 l31 l23 l33 + + x0 = #VPERM2I128(y0, y2, 0x20); // x0 = l00 l10 l20 l30 + x1 = #VPERM2I128(y1, y3, 0x20); // x1 = l01 l11 l21 l31 + x2 = #VPERM2I128(y0, y2, 0x31); // x2 = l02 l12 l22 l32 + x3 = #VPERM2I128(y1, y3, 0x31); // x3 = l03 l13 l23 l33 + + return x0, x1, x2, x3; +} + +// extracts 4 keccak states (st25) from a 4-way state (st4x) +inline fn __4u64x4_u256x4 +( reg u256 y0 y1 y2 y3 +) -> reg u256, reg u256, reg u256, reg u256 { + // y0 = l00 l10 l20 l30 + // y1 = l01 l11 l21 l31 + // y2 = l02 l12 l22 l32 + // y3 = l03 l13 l23 l33 + reg u256 x0, x1, x2, x3; + x0 = #VPERM2I128(y0, y2, 0x20); // x0 = l00 l10 l02 l12 + x1 = #VPERM2I128(y1, y3, 0x20); // x1 = l01 l11 l03 l13 + x2 = #VPERM2I128(y0, y2, 0x31); // x2 = l20 l30 l22 l32 + x3 = #VPERM2I128(y1, y3, 0x31); // x3 = l21 l31 l23 l33 + + y0 = #VPUNPCKL_4u64(x0, x1); // y0 = l00 l01 l02 l03 + y1 = #VPUNPCKH_4u64(x0, x1); // y1 = l10 l11 l12 l13 + y2 = #VPUNPCKL_4u64(x2, x3); // y2 = l20 l21 l22 l23 + y3 = #VPUNPCKH_4u64(x2, x3); // y3 = l30 l31 l32 l33 + + return y0, y1, y2, y3; +} diff --git a/code/jasmin/mlkem_avx2/keccak/keccak1600x4_imem_avx2.jinc b/code/jasmin/mlkem_avx2/keccak/keccak1600x4_imem_avx2.jinc new file mode 100644 index 00000000..1733146c --- /dev/null +++ b/code/jasmin/mlkem_avx2/keccak/keccak1600x4_imem_avx2.jinc @@ -0,0 +1,421 @@ +require "keccak1600x4_avx2.jinc" + + + + +/* + INCREMENTAL (FIXED-SIZE) MEMORY BROADCAST ABSORB + ================================================ +*/ + +inline fn __addstate_bcast_imem_avx2x4 +( reg mut ptr u256[25] st +, inline int AT /* bytes (0 <= AT < 200) */ +, reg u64 buf +, inline int LEN +, inline int TRAILB +) -> reg ptr u256[25] /* st */ + , inline int /* AT */ + , reg u64 /* buf */ +{ + inline int LO, ALL; + reg u64 at; + reg u256 t256; + + ALL = AT+LEN; // total bytes to process (excluding trail byte, if !=0) + LO = AT % 8; // leftover bytes + at = 32 * (AT / 8); // current pstate position + + if ( 0 < LO ) { // process first word... 
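+    // layout reminder (restating the 4-way convention used above): u64
+    // lane i of instance j lives at byte 32*i + 8*j of st, so the byte
+    // cursor `at` advances by 32 per absorbed input word and a single
+    // VPBROADCAST_4u64 feeds the same data to all four instances.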
+ if ( LO+LEN < 8) { // ...not enough to fill a word (just update it) + if ( TRAILB != 0 ) { ALL += 1; } + buf, _, TRAILB, t256 = __mread_bcast_4subu64(buf, LEN, TRAILB); + t256 = #VPSLL_4u64(t256, 8*LO); + t256 ^= st.[u256 (int) at]; + st.[u256 (int) at] = t256; + LO = 0; + AT = 0; + LEN = 0; + } else { // process first word + if ( 8 <= LEN ) { + t256 = #VPBROADCAST_4u64((u64)[buf]); + buf += (8-LO); + } else { + buf, _, _, t256 = __mread_bcast_4subu64(buf, 8-LO, 0); + } + LEN -= 8-LO; + AT += 8-LO; + t256 = #VPSLL_4u64(t256, 8*LO); + t256 ^= st.[u256 (int) at]; + st.[u256 (int) at] = t256; + at += 32; + } + } + + // continue processing remaining bytes + if (8 <= LEN) { + while ( at < 32*(AT/8)+32*(LEN/8)) { + t256 = #VPBROADCAST_4u64((u64)[buf]); + buf += 8; + t256 ^= st.[u256 at]; + st.[u256 at] = t256; + at += 32; + } + LEN = (AT+LEN) % 8; + } + + // process last word (possibly closing the state) + LO = (AT+LEN) % 8; + if ( 0 < LO || TRAILB != 0 ) { + if ( TRAILB != 0 ) { ALL += 1; } + buf, _, TRAILB, t256 = __mread_bcast_4subu64(buf, LO, TRAILB); + t256 ^= st.[u256 at]; + st.[u256 at] = t256; + } + + return st, ALL, buf; +} + +inline fn __absorb_bcast_imem_avx2x4 +( reg mut ptr u256[25] st +, inline int AT +, reg u64 buf +, inline int LEN +, inline int RATE8 +, inline int TRAILB /* closes state if !=0 (i.e. adds trailbyte and padding) */ +) -> reg ptr u256[25] /* st */ + , inline int /* AT */ + , reg u64 /* buf */ +{ + reg u64 i; + inline int ALL, ITERS; + + ALL = AT + LEN; + if ( (AT+LEN) < RATE8 ) { // not enough to fill a block! + st, AT, buf = __addstate_bcast_imem_avx2x4(st, AT, buf, LEN, TRAILB); + if (TRAILB != 0) { // add pstate and closes the state + st = __addratebit_avx2x4(st, RATE8); + } + } else { // at least a block is filled + if ( AT != 0 ) { // start by filling the first block + st, _, buf = __addstate_bcast_imem_avx2x4(st, AT, buf, RATE8-AT, 0); + LEN = LEN - (RATE8-AT); + st = _keccakf1600_avx2x4(st); + AT = 0; + } + + // continue by processing full blocks + ITERS = LEN / RATE8; // number of full blocks + i = 0; + while ( i < ITERS ) { + st, _, buf = __addstate_bcast_imem_avx2x4(st, 0, buf, RATE8, 0); + st = _keccakf1600_avx2x4(st); + i += 1; + } + + // last incomplete block + LEN = ALL % RATE8; + st, AT, buf = __addstate_bcast_imem_avx2x4(st, 0, buf, LEN, TRAILB); + if (TRAILB!=0) { st = __addratebit_avx2x4(st, RATE8); } + } + return st, AT, buf; +} + +/* + INCREMENTAL (FIXED-SIZE) MEMORY 4-way ABSORB + ============================================ +*/ + +inline fn __addstate_imem_avx2x4 +( reg mut ptr u256[25] st +, inline int AT /* bytes (0 <= AT < 200) */ +, reg u64 buf0 buf1 buf2 buf3 +, inline int LEN +, inline int TRAILB +) -> reg ptr u256[25] /* st */ + , inline int /* AT */ + , reg u64 /* buf0 */ + , reg u64 /* buf1 */ + , reg u64 /* buf2 */ + , reg u64 /* buf3 */ +{ + inline int LO, ALL; + reg u64 at, t0, t1, t2, t3; + reg u256 t256_0, t256_1, t256_2, t256_3; + + ALL = AT+LEN; // total bytes to process (excluding trail byte, if !=0) + LO = AT % 8; // leftover bytes + at = 4 * (AT / 8); // current pstate position (referencing u64 words) +//at = 0, 4, 8, ... + + if ( 0 < LO ) { // process first word... 
+ if ( LO+LEN < 8) { // ...not enough to fill a word (just update it)
+ if ( TRAILB != 0 ) { ALL += 1; }
+ buf0, _, _, t0 = __mread_subu64(buf0, LEN, TRAILB);
+ buf1, _, _, t1 = __mread_subu64(buf1, LEN, TRAILB);
+ buf2, _, _, t2 = __mread_subu64(buf2, LEN, TRAILB);
+ buf3, _, _, t3 = __mread_subu64(buf3, LEN, TRAILB);
+ t0 <<= 8*LO;
+ t0 ^= st[u64 at + 0];
+ st[u64 at + 0] = t0;
+ t1 <<= 8*LO;
+ t1 ^= st[u64 at + 1];
+ st[u64 at + 1] = t1;
+ t2 <<= 8*LO;
+ t2 ^= st[u64 at + 2];
+ st[u64 at + 2] = t2;
+ t3 <<= 8*LO;
+ t3 ^= st[u64 at + 3];
+ st[u64 at + 3] = t3;
+ LO = 0;
+ AT = 0;
+ LEN = 0;
+ TRAILB = 0;
+ } else { // process first word
+ if ( 8 <= LEN ) {
+ t0 = (u64)[buf0];
+ buf0 += 8-LO;
+ t1 = (u64)[buf1];
+ buf1 += 8-LO;
+ t2 = (u64)[buf2];
+ buf2 += 8-LO;
+ t3 = (u64)[buf3];
+ buf3 += 8-LO;
+ } else {
+ // note: the trail byte must not be injected here (more data may follow);
+ // it is handled by the last-word step, as in the bcast variant above
+ buf0, _, _, t0 = __mread_subu64(buf0, 8-LO, 0);
+ buf1, _, _, t1 = __mread_subu64(buf1, 8-LO, 0);
+ buf2, _, _, t2 = __mread_subu64(buf2, 8-LO, 0);
+ buf3, _, _, t3 = __mread_subu64(buf3, 8-LO, 0);
+ }
+ LEN -= 8-LO;
+ AT += 8-LO;
+ t0 <<= 8*LO;
+ t0 ^= st[u64 at + 0];
+ st[u64 at + 0] = t0;
+ t1 <<= 8*LO;
+ t1 ^= st[u64 at + 1];
+ st[u64 at + 1] = t1;
+ t2 <<= 8*LO;
+ t2 ^= st[u64 at + 2];
+ st[u64 at + 2] = t2;
+ t3 <<= 8*LO;
+ t3 ^= st[u64 at + 3];
+ st[u64 at + 3] = t3;
+ at += 4;
+ }
+ }
+
+ // continue processing remaining bytes
+ if (8 <= LEN) {
+ while ( at < 4*(AT/8)+16*(LEN/32) ) {
+ t256_0 = (u256)[buf0];
+ buf0 += 32;
+ t256_1 = (u256)[buf1];
+ buf1 += 32;
+ t256_2 = (u256)[buf2];
+ buf2 += 32;
+ t256_3 = (u256)[buf3];
+ buf3 += 32;
+ t256_0, t256_1, t256_2, t256_3 = __4u64x4_u256x4(t256_0, t256_1, t256_2, t256_3);
+ t256_0 ^= st.[u256 8*at];
+ st.[u256 8*at] = t256_0;
+ t256_1 ^= st.[u256 8*at+32];
+ st.[u256 8*at+32] = t256_1;
+ t256_2 ^= st.[u256 8*at+64];
+ st.[u256 8*at+64] = t256_2;
+ t256_3 ^= st.[u256 8*at+96];
+ st.[u256 8*at+96] = t256_3;
+ at += 16;
+ }
+ while ( at < 4*(AT/8)+4*(LEN/8)) {
+ t0 = (u64)[buf0];
+ buf0 += 8;
+ t0 ^= st[u64 at + 0];
+ st[u64 at + 0] = t0;
+ t1 = (u64)[buf1];
+ buf1 += 8;
+ t1 ^= st[u64 at + 1];
+ st[u64 at + 1] = t1;
+ t2 = (u64)[buf2];
+ buf2 += 8;
+ t2 ^= st[u64 at + 2];
+ st[u64 at + 2] = t2;
+ t3 = (u64)[buf3];
+ buf3 += 8;
+ t3 ^= st[u64 at + 3];
+ st[u64 at + 3] = t3;
+ at += 4;
+ }
+ LEN = (AT+LEN) % 8;
+ }
+
+ // process last word (possibly closing the state)
+ LO = (AT+LEN) % 8;
+ if ( 0 < LO || TRAILB != 0 ) {
+ buf0, _, _, t0 = __mread_subu64(buf0, LO, TRAILB);
+ buf1, _, _, t1 = __mread_subu64(buf1, LO, TRAILB);
+ buf2, _, _, t2 = __mread_subu64(buf2, LO, TRAILB);
+ buf3, _, _, t3 = __mread_subu64(buf3, LO, TRAILB);
+ if ( TRAILB != 0 ) { ALL += 1; TRAILB = 0; }
+ t0 ^= st[u64 at + 0];
+ st[u64 at + 0] = t0;
+ t1 ^= st[u64 at + 1];
+ st[u64 at + 1] = t1;
+ t2 ^= st[u64 at + 2];
+ st[u64 at + 2] = t2;
+ t3 ^= st[u64 at + 3];
+ st[u64 at + 3] = t3;
+ }
+
+ return st, ALL, buf0, buf1, buf2, buf3;
+}
+
+
+inline fn __absorb_imem_avx2x4
+( reg mut ptr u256[25] st
+, inline int AT
+, reg u64 buf0
+, reg u64 buf1
+, reg u64 buf2
+, reg u64 buf3
+, inline int LEN
+, inline int RATE8
+, inline int TRAILB /* closes state if !=0 (i.e. adds trailbyte and padding) */
+) -> reg ptr u256[25] /* st */
+ , inline int /* AT */
+ , reg u64 /* buf0 */
+ , reg u64 /* buf1 */
+ , reg u64 /* buf2 */
+ , reg u64 /* buf3 */
+{
+ reg u64 i;
+ inline int ALL, ITERS;
+
+ ALL = AT + LEN;
+ if ( (AT+LEN) < RATE8 ) { // not enough to fill a block!
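+ // the whole update fits in the currently buffered block: just extend it
+ // (a nonzero TRAILB also appends the trail byte and, below, the rate bit
+ // that closes the sponge); otherwise the else-branch completes block 0,
+ // permutes, absorbs ITERS = LEN/RATE8 full blocks, and keeps the tail.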
+ st, AT, buf0, buf1, buf2, buf3
+ = __addstate_imem_avx2x4(st, AT, buf0, buf1, buf2, buf3, LEN, TRAILB);
+ if (TRAILB != 0) { // add pstate and closes the state
+ st = __addratebit_avx2x4(st, RATE8);
+ }
+ } else { // at least a block is filled
+ if ( AT != 0 ) { // start by filling the first block
+ st, _, buf0, buf1, buf2, buf3
+ = __addstate_imem_avx2x4(st, AT, buf0, buf1, buf2, buf3, RATE8-AT, 0);
+ LEN = LEN - (RATE8-AT);
+ st = _keccakf1600_avx2x4(st);
+ AT = 0;
+ }
+
+ // continue by processing full blocks
+ ITERS = LEN / RATE8; // number of full blocks
+ i = 0;
+ while ( i < ITERS ) {
+ st, _, buf0, buf1, buf2, buf3
+ = __addstate_imem_avx2x4(st, 0, buf0, buf1, buf2, buf3, RATE8, 0);
+ st = _keccakf1600_avx2x4(st);
+ i += 1;
+ }
+
+ // last incomplete block
+ LEN = ALL % RATE8;
+ st, AT, buf0, buf1, buf2, buf3
+ = __addstate_imem_avx2x4(st, 0, buf0, buf1, buf2, buf3, LEN, TRAILB);
+ if (TRAILB!=0) { st = __addratebit_avx2x4(st, RATE8); }
+ }
+ return st, AT, buf0, buf1, buf2, buf3;
+}
+
+
+/*
+ ONE-SHOT (FIXED-SIZE) MEMORY SQUEEZE
+ ====================================
+*/
+inline fn __dumpstate_imem_avx2x4
+( reg u64 buf0 buf1 buf2 buf3
+, inline int LEN
+, reg const ptr u256[25] st
+) -> reg u64 /* buf0 */
+ , reg u64 /* buf1 */
+ , reg u64 /* buf2 */
+ , reg u64 /* buf3 */
+{
+ reg u256 x0, x1, x2, x3;
+ reg u64 i, t0, t1, t2, t3;
+ i = 0;
+ while (i < 32*(LEN/32)) { // full 4-lane (32-byte) chunks per buffer
+ x0 = st.[u256 4*i];
+ x1 = st.[u256 4*i + 32];
+ x2 = st.[u256 4*i + 64];
+ x3 = st.[u256 4*i + 96];
+ x0, x1, x2, x3 = __4u64x4_u256x4(x0, x1, x2, x3);
+ (u256)[buf0] = x0;
+ buf0 += 32;
+ (u256)[buf1] = x1;
+ buf1 += 32;
+ (u256)[buf2] = x2;
+ buf2 += 32;
+ (u256)[buf3] = x3;
+ buf3 += 32;
+ i += 32;
+ }
+ while (i < 8*(LEN/8)) { // remaining whole lanes
+ t0 = st.[u64 4*i];
+ (u64)[buf0] = t0;
+ buf0 += 8;
+ t1 = st.[u64 4*i + 8];
+ (u64)[buf1] = t1;
+ buf1 += 8;
+ t2 = st.[u64 4*i + 16];
+ (u64)[buf2] = t2;
+ buf2 += 8;
+ t3 = st.[u64 4*i + 24];
+ (u64)[buf3] = t3;
+ buf3 += 8;
+ i += 8;
+ }
+ return buf0, buf1, buf2, buf3;
+}
+
+inline fn __squeeze_imem_avx2x4
+( reg u64 buf0 buf1 buf2 buf3
+, inline int LEN
+, reg mut ptr u256[25] st
+, inline int RATE8
+) -> reg u64 /* buf0 */
+ , reg u64 /* buf1 */
+ , reg u64 /* buf2 */
+ , reg u64 /* buf3 */
+ , reg ptr u256[25] /* st */
+{
+ reg u64 i;
+ inline int ITERS, LO;
+ ITERS = LEN/RATE8;
+ LO = LEN%RATE8;
+ if (0 < ITERS) {
+ i = 0;
+ while (i < ITERS) {
+ st = _keccakf1600_avx2x4(st);
+ buf0, buf1, buf2, buf3 = __dumpstate_imem_avx2x4(buf0, buf1, buf2, buf3, RATE8, st);
+ i += 1;
+ }
+ }
+ if (0 < LO) {
+ st = _keccakf1600_avx2x4(st);
+ buf0, buf1, buf2, buf3 = __dumpstate_imem_avx2x4(buf0, buf1, buf2, buf3, LO, st);
+ }
+ return buf0, buf1, buf2, buf3, st;
+}
diff --git a/code/jasmin/mlkem_avx2/keccak/keccakf1600_avx2.jinc b/code/jasmin/mlkem_avx2/keccak/keccakf1600_avx2.jinc
--- a/code/jasmin/mlkem_avx2/keccak/keccakf1600_avx2.jinc
+++ b/code/jasmin/mlkem_avx2/keccak/keccakf1600_avx2.jinc
@@ ... @@ fn _keccakf1600_avx2(reg u256[7] state) -> reg u256[7]
 reg u64 r iotas_o;
 
 reg ptr u256[24] iotas_p;
- reg ptr u256[6] rhotates_left_p;
- reg ptr u256[6] rhotates_right_p;
 
 iotas_p = KECCAK_IOTAS;
 iotas_o = 0;
- rhotates_left_p = KECCAK_RHOTATES_LEFT;
- rhotates_right_p = KECCAK_RHOTATES_RIGHT;
 
 r = KECCAK_ROUNDS;
 while
@@ -97,32 +93,32 @@ fn _keccakf1600_avx2(reg u256[7] state) -> reg u256[7]
 d14 = d14 ^ t[4];
 
 //######################################## Rho + Pi + pre-Chi shuffle
- t[3] = #VPSLLV_4u64(state[2], rhotates_left_p[0] );
- state[2] = #VPSRLV_4u64(state[2], rhotates_right_p[0] );
+ t[3] = #VPSLLV_4u64(state[2], KECCAK_RHOTATES_LEFT[0] );
+ state[2] = #VPSRLV_4u64(state[2], KECCAK_RHOTATES_RIGHT[0] );
 state[2] = state[2] | t[3];
 state[3] = state[3] ^ d14;
- t[4] = #VPSLLV_4u64(state[3], rhotates_left_p[2] );
- state[3] = #VPSRLV_4u64(state[3], rhotates_right_p[2] );
+ t[4] = #VPSLLV_4u64(state[3], KECCAK_RHOTATES_LEFT[2] );
+ state[3] = #VPSRLV_4u64(state[3], KECCAK_RHOTATES_RIGHT[2] );
 state[3] = state[3] | t[4];
 state[4] = state[4] ^ d14;
- t[5] = #VPSLLV_4u64(state[4], rhotates_left_p[3] );
- state[4] = #VPSRLV_4u64(state[4], rhotates_right_p[3] );
+ t[5] = #VPSLLV_4u64(state[4], KECCAK_RHOTATES_LEFT[3] );
+ state[4] = #VPSRLV_4u64(state[4], KECCAK_RHOTATES_RIGHT[3] );
 state[4] = state[4] | t[5];
 state[5] = state[5] ^ d14;
- t[6] = #VPSLLV_4u64(state[5], rhotates_left_p[4] );
- state[5] = #VPSRLV_4u64(state[5], rhotates_right_p[4] );
+ t[6] = #VPSLLV_4u64(state[5], KECCAK_RHOTATES_LEFT[4] );
+ state[5] = #VPSRLV_4u64(state[5], KECCAK_RHOTATES_RIGHT[4] );
 state[5] = state[5] | t[6];
 state[6] = state[6] ^ d14;
 t[3] = #VPERMQ(state[2], (4u2)[2,0,3,1]);
 t[4] = #VPERMQ(state[3], (4u2)[2,0,3,1]);
- t[7] = #VPSLLV_4u64(state[6], rhotates_left_p[5] );
- t[1] = #VPSRLV_4u64(state[6], rhotates_right_p[5] );
+ t[7] = #VPSLLV_4u64(state[6], KECCAK_RHOTATES_LEFT[5] );
+ t[1] = #VPSRLV_4u64(state[6], KECCAK_RHOTATES_RIGHT[5] );
 t[1] = t[1] | t[7];
 state[1] = state[1] ^ d14;
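+ // each lane x is rotated by its rho offset r as rol64(x, r) = (x << r) | (x >> (64 - r)),
+ // using per-lane variable shift counts from the KECCAK_RHOTATES_LEFT/RIGHT tables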
t[5] = #VPERMQ(state[4], (4u2)[0,1,2,3]); t[6] = #VPERMQ(state[5], (4u2)[1,3,0,2]); - t[8] = #VPSLLV_4u64(state[1], rhotates_left_p[1] ); - t[2] = #VPSRLV_4u64(state[1], rhotates_right_p[1] ); + t[8] = #VPSLLV_4u64(state[1], KECCAK_RHOTATES_LEFT[1] ); + t[2] = #VPSRLV_4u64(state[1], KECCAK_RHOTATES_RIGHT[1] ); t[2] = t[2] | t[8]; //######################################## Chi @@ -190,6 +186,34 @@ fn _keccakf1600_avx2(reg u256[7] state) -> reg u256[7] return state; } +/* +export fn testF(reg mut ptr u256[7] stm) -> reg ptr u256[7] +{ + reg u256[7] st; + inline int i; + for i = 0 to 7 { st[i] = stm[i]; } + st = _keccakf1600_avx2(st); + for i = 0 to 7 { stm[i] = st[i]; } + return stm; +} +*/ + +inline +fn _keccakf1600_avx2_(reg u256[7] state) -> reg u256[7] +{ + inline int i; + reg u256[7] st; + + for i = 0 to 7 { st[i] = state[i]; } + + st = _keccakf1600_avx2(st); + + for i = 0 to 7 { state[i] = st[i]; } + + return state; +} + + // converts a (plain) keccak state (st25) into the avx2 representation inline fn __stavx2_pack ( reg const ptr u64[25] st diff --git a/code/jasmin/mlkem_avx2/keccak/keccakf1600_globals.jinc b/code/jasmin/mlkem_avx2/keccak/keccakf1600_globals.jinc new file mode 100644 index 00000000..ccdcbb26 --- /dev/null +++ b/code/jasmin/mlkem_avx2/keccak/keccakf1600_globals.jinc @@ -0,0 +1,36 @@ +param int KECCAK_ROUNDS = 24; + +inline fn keccakf1600_index(inline int x y) -> inline int +{ + inline int r; + r = (x % 5) + 5 * (y % 5); + return r; +} + + +inline fn keccakf1600_rho_offsets(inline int i) -> inline int +{ + inline int r x y z t; + + r = 0; + x = 1; + y = 0; + + for t = 0 to 24 + { if (i == x + 5 * y) + { r = ((t + 1) * (t + 2) / 2) % 64; } + z = (2 * x + 3 * y) % 5; + x = y; + y = z; + } + + return r; +} + +inline fn keccakf1600_rhotates(inline int x y) -> inline int +{ + inline int i r; + i = keccakf1600_index(x, y); + r = keccakf1600_rho_offsets(i); + return r; +} diff --git a/code/jasmin/mlkem_avx2/keccak/keccakf1600x4_avx2.jinc b/code/jasmin/mlkem_avx2/keccak/keccakf1600x4_avx2.jinc new file mode 100644 index 00000000..cde4d741 --- /dev/null +++ b/code/jasmin/mlkem_avx2/keccak/keccakf1600x4_avx2.jinc @@ -0,0 +1,333 @@ + +require "keccakf1600_globals.jinc" + +u256[24] KECCAK1600_RC_AVX2 = +{ (4u64)[0x0000000000000001, 0x0000000000000001, 0x0000000000000001, 0x0000000000000001], + (4u64)[0x0000000000008082, 0x0000000000008082, 0x0000000000008082, 0x0000000000008082], + (4u64)[0x800000000000808a, 0x800000000000808a, 0x800000000000808a, 0x800000000000808a], + (4u64)[0x8000000080008000, 0x8000000080008000, 0x8000000080008000, 0x8000000080008000], + (4u64)[0x000000000000808b, 0x000000000000808b, 0x000000000000808b, 0x000000000000808b], + (4u64)[0x0000000080000001, 0x0000000080000001, 0x0000000080000001, 0x0000000080000001], + (4u64)[0x8000000080008081, 0x8000000080008081, 0x8000000080008081, 0x8000000080008081], + (4u64)[0x8000000000008009, 0x8000000000008009, 0x8000000000008009, 0x8000000000008009], + (4u64)[0x000000000000008a, 0x000000000000008a, 0x000000000000008a, 0x000000000000008a], + (4u64)[0x0000000000000088, 0x0000000000000088, 0x0000000000000088, 0x0000000000000088], + (4u64)[0x0000000080008009, 0x0000000080008009, 0x0000000080008009, 0x0000000080008009], + (4u64)[0x000000008000000a, 0x000000008000000a, 0x000000008000000a, 0x000000008000000a], + (4u64)[0x000000008000808b, 0x000000008000808b, 0x000000008000808b, 0x000000008000808b], + (4u64)[0x800000000000008b, 0x800000000000008b, 0x800000000000008b, 0x800000000000008b], + (4u64)[0x8000000000008089, 
0x8000000000008089, 0x8000000000008089, 0x8000000000008089], + (4u64)[0x8000000000008003, 0x8000000000008003, 0x8000000000008003, 0x8000000000008003], + (4u64)[0x8000000000008002, 0x8000000000008002, 0x8000000000008002, 0x8000000000008002], + (4u64)[0x8000000000000080, 0x8000000000000080, 0x8000000000000080, 0x8000000000000080], + (4u64)[0x000000000000800a, 0x000000000000800a, 0x000000000000800a, 0x000000000000800a], + (4u64)[0x800000008000000a, 0x800000008000000a, 0x800000008000000a, 0x800000008000000a], + (4u64)[0x8000000080008081, 0x8000000080008081, 0x8000000080008081, 0x8000000080008081], + (4u64)[0x8000000000008080, 0x8000000000008080, 0x8000000000008080, 0x8000000000008080], + (4u64)[0x0000000080000001, 0x0000000080000001, 0x0000000080000001, 0x0000000080000001], + (4u64)[0x8000000080008008, 0x8000000080008008, 0x8000000080008008, 0x8000000080008008] +}; + +u256 ROL56 = 0x181F1E1D1C1B1A191017161514131211080F0E0D0C0B0A090007060504030201; +u256 ROL8 = 0x1E1D1C1B1A19181F16151413121110170E0D0C0B0A09080F0605040302010007; + +// C[x] = A[x,0] ^ A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] +inline fn keccakf1600_4x_theta_sum(reg ptr u256[25] a) -> reg u256[5] +{ + inline int x y; + reg u256[5] c; + + // C[x] = A[x, 0] + for x=0 to 5 + { c[x] = a[x + 0]; } + + // C[x] ^= A[x,1] ^ A[x,2] ^ A[x,3] ^ A[x,4] + for y=1 to 5 + { for x=0 to 5 + { c[x] ^= a[x + y*5]; } + } + + return c; +} + +inline fn keccakf1600_4x_rol(reg u256[5] a, inline int x r, reg u256 r8 r56) -> reg u256[5] +{ + reg u256 t; + + if(r == 8) + { a[x] = #VPSHUFB_256(a[x], r8); } + else { if(r == 56) + { a[x] = #VPSHUFB_256(a[x], r56); } + else + { t = #VPSLL_4u64(a[x], r); + a[x] = #VPSRL_4u64(a[x], 64 - r); + a[x] |= t; } + } + + return a; +} + +// D[x] = C[x-1] ^ ROT(C[x+1], 1) +inline fn keccakf1600_4x_theta_rol(reg u256[5] c, reg u256 r8 r56) -> reg u256[5] +{ + inline int x; + reg u256[5] d; + + for x = 0 to 5 + { // D[x] = C[x + 1] + d[x] = c[(x + 1) % 5]; + + // D[x] = ROT(D[x], 1) + d = keccakf1600_4x_rol(d, x, 1, r8, r56); + + // D[x] ^= C[x-1] + d[x] ^= c[(x - 1 + 5) % 5]; + } + + return d; +} + + +// B[x] = ROT( (A[x',y'] ^ D[x']), r[x',y'] ) with (x',y') = M^-1 (x,y) +// +// M = (0 1) M^-1 = (1 3) x' = 1x + 3y +// (2 3) (1 0) y' = 1x + 0y +// +inline fn keccakf1600_4x_rol_sum( + reg ptr u256[25] a, + reg u256[5] d, + inline int y, + reg u256 r8 r56 +) -> reg u256[5] +{ + inline int r x x_ y_; + reg u256[5] b; + + for x = 0 to 5 + { + x_ = (x + 3*y) % 5; + y_ = x; + r = keccakf1600_rhotates(x_, y_); + + // B[x] = A[x',y'] + b[x] = a[x_ + y_*5]; + + // B[x] ^= D[x']; + b[x] ^= d[x_]; + + // B[x] = ROT( B[x], r[x',y'] ); + if(r != 0) + { b = keccakf1600_4x_rol(b, x, r, r8, r56); } + } + + return b; +} + + +// E[x, y] = B[x] ^ ( (!B[x+1]) & B[x+2] ) +// -- when x and y are 0: E[0,0] ^= RC[i]; +inline fn keccakf1600_4x_set_row( + reg ptr u256[25] e, + reg u256[5] b, + inline int y, + reg u256 rc +) -> reg ptr u256[25] +{ + inline int x x1 x2; + reg u256 t; + + for x=0 to 5 + { + x1 = (x + 1) % 5; + x2 = (x + 2) % 5; + + t = #VPANDN_256(b[x1], b[x2]); + + t ^= b[x]; + if( x==0 && y==0 ){ t ^= rc; } + e[x + y*5] = t; + } + + return e; +} + + +fn keccakf1600_4x_round(reg ptr u256[25] e a, reg u256 rc r8 r56) -> reg ptr u256[25] +{ + inline int y; + reg u256[5] b c d; + + c = keccakf1600_4x_theta_sum(a); + d = keccakf1600_4x_theta_rol(c, r8, r56); + + for y = 0 to 5 + { b = keccakf1600_4x_rol_sum(a, d, y, r8, r56); + e = keccakf1600_4x_set_row(e, b, y, rc); + } + + return e; +} + 
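+// One permutation call applies KECCAK_ROUNDS = 24 rounds. The driver below
+// issues them in pairs, alternating which state buffer is read and which is
+// written, so conceptually:
+//   e = round(e, a, RC[2i]); a = round(a, e, RC[2i+1]);   (12 pairs)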
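+// A commented-out usage sketch (illustrative, mirrors the call pattern of the
+// x4 absorb/squeeze code in this patch; it is not part of the original file):
+/*
+inline fn __example_permute4x(reg mut ptr u256[25] st4x) -> reg ptr u256[25]
+{
+ // st4x holds four interleaved keccak states; one call = 24 rounds
+ st4x = _keccakf1600_avx2x4(st4x);
+ return st4x;
+}
+*/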
+//////////////////////////////////////////////////////////////////////////////// + +inline fn __keccakf1600_avx2x4(reg ptr u256[25] a) -> reg ptr u256[25] +{ + #mmx reg ptr u256[25] a_s; + + reg ptr u256[24] RC; + + stack u256[25] s_e; + reg ptr u256[25] e; + + reg u256 rc r8 r56; + reg u64 c; + + RC = KECCAK1600_RC_AVX2; + e = s_e; + r8 = ROL8; + r56 = ROL56; + + c = 0; + while(c < (KECCAK_ROUNDS*32)) + { + rc = RC.[(int) c]; + e = keccakf1600_4x_round(e, a, rc, r8, r56); + + // just an expensive pointer swap (#todo request feature) + a_s = a; s_e = e; + a = a_s; e = s_e; + + rc = RC.[(int) c + 32]; + a = keccakf1600_4x_round(a, e, rc, r8, r56); + + // just an expensive pointer swap (#todo request feature) + a_s = a; s_e = e; + a = a_s; e = s_e; + + c += 64; + } + + return a; +} + +fn _keccakf1600_avx2x4(reg ptr u256[25] a) -> reg ptr u256[25] +{ + a = __keccakf1600_avx2x4(a); + return a; +} + +inline fn _keccakf1600_avx2x4_(reg ptr u256[25] a) -> reg ptr u256[25] +{ + a = a; + a = _keccakf1600_avx2x4(a); + a = a; + return a; +} + +/* +// pack 4 keccak states (st25) into a 4-way state (st4x) +inline fn __u256x4_4u64x4 +( reg u256 x0 x1 x2 x3 +) -> reg u256, reg u256, reg u256, reg u256 { + // x0 = l00 l01 l02 l03 + // x1 = l10 l11 l12 l13 + // x2 = l20 l21 l22 l23 + // x3 = l30 l31 l32 l33 + reg u256 y0, y1, y2, y3; + y0 = #VPUNPCKL_4u64(x0, x1); // y0 = l00 l10 l02 l12 + y1 = #VPUNPCKH_4u64(x0, x1); // y1 = l01 l11 l03 l13 + y2 = #VPUNPCKL_4u64(x2, x3); // y2 = l20 l30 l22 l32 + y3 = #VPUNPCKH_4u64(x2, x3); // y3 = l21 l31 l23 l33 + + x0 = #VPERM2I128(y0, y2, 0x20); // x0 = l00 l10 l20 l30 + x1 = #VPERM2I128(y1, y3, 0x20); // x1 = l01 l11 l21 l31 + x2 = #VPERM2I128(y0, y2, 0x31); // x2 = l02 l12 l22 l32 + x3 = #VPERM2I128(y1, y3, 0x31); // x3 = l03 l13 l23 l33 + + return x0, x1, x2, x3; +} + +inline fn __st4x_pack +( reg mut ptr u256[25] st4x +, reg const ptr u64[25] st0 st1 st2 st3 +) -> reg ptr u256[25] { + inline int i; + reg u256 x0, x1, x2, x3; + reg u64 t0, t1, t2, t3; + for i = 0 to 6 { + x0 = st0[u256 i]; + x1 = st1[u256 i]; + x2 = st2[u256 i]; + x3 = st3[u256 i]; + x0, x1, x2, x3 = __u256x4_4u64x4(x0, x1, x2, x3); + st4x[4*i+0] = x0; + st4x[4*i+1] = x1; + st4x[4*i+2] = x2; + st4x[4*i+3] = x3; + } + t0 = st0[24]; + t1 = st1[24]; + t2 = st2[24]; + t3 = st3[24]; + st4x[u64 4*24+0] = t0; + st4x[u64 4*24+1] = t1; + st4x[u64 4*24+2] = t2; + st4x[u64 4*24+3] = t3; + + return st4x; +} + + + +// extracts 4 keccak states (st25) from a 4-way state (st4x) +inline fn __4u64x4_u256x4 +( reg u256 y0 y1 y2 y3 +) -> reg u256, reg u256, reg u256, reg u256 { + // y0 = l00 l10 l20 l30 + // y1 = l01 l11 l21 l31 + // y2 = l02 l12 l22 l32 + // y3 = l03 l13 l23 l33 + reg u256 x0, x1, x2, x3; + x0 = #VPERM2I128(y0, y2, 0x20); // x0 = l00 l10 l02 l12 + x1 = #VPERM2I128(y1, y3, 0x20); // x1 = l01 l11 l03 l13 + x2 = #VPERM2I128(y0, y2, 0x31); // x2 = l20 l30 l22 l32 + x3 = #VPERM2I128(y1, y3, 0x31); // x3 = l21 l31 l23 l33 + + y0 = #VPUNPCKL_4u64(x0, x1); // y0 = l00 l01 l02 l03 + y1 = #VPUNPCKH_4u64(x0, x1); // y1 = l10 l11 l12 l13 + y2 = #VPUNPCKL_4u64(x2, x3); // y2 = l20 l21 l22 l23 + y3 = #VPUNPCKH_4u64(x2, x3); // y3 = l30 l31 l32 l33 + + return y0, y1, y2, y3; +} + +inline fn __st4x_unpack +( reg mut ptr u64[25] st0 st1 st2 st3 +, reg const ptr u256[25] st4x +) -> reg ptr u64[25], reg ptr u64[25], reg ptr u64[25], reg ptr u64[25] { + inline int i; + reg u256 x0, x1, x2, x3; + reg u64 t0, t1, t2, t3; + for i = 0 to 6 { + x0 = st4x[u256 4*i+0]; + x1 = st4x[u256 4*i+1]; + x2 = st4x[u256 4*i+2]; + x3 = 
st4x[u256 4*i+3]; + x0, x1, x2, x3 = __4u64x4_u256x4(x0, x1, x2, x3); + st0.[u256 4*8*i] = x0; + st1.[u256 4*8*i] = x1; + st2.[u256 4*8*i] = x2; + st3.[u256 4*8*i] = x3; + } + t0 = st4x[u64 4*24+0]; + t1 = st4x[u64 4*24+1]; + t2 = st4x[u64 4*24+2]; + t3 = st4x[u64 4*24+3]; + st0.[u64 8*24] = t0; + st1.[u64 8*24] = t1; + st2.[u64 8*24] = t2; + st3.[u64 8*24] = t3; + + return st0, st1, st2, st3; +} +*/ diff --git a/code/jasmin/mlkem_avx2/keccak/subreadwrite_array_ASIZE.jinc b/code/jasmin/mlkem_avx2/keccak/subreadwrite_array_ASIZE.jinc new file mode 100644 index 00000000..d33937e2 --- /dev/null +++ b/code/jasmin/mlkem_avx2/keccak/subreadwrite_array_ASIZE.jinc @@ -0,0 +1,261 @@ +/** + READ A FIXED NUMBER OF BYTES INTO A WORD +**/ + +inline fn __aread_subu64 +( reg const ptr u8[ASIZE] buf +, reg u64 offset +, inline int DELTA +, inline int LEN +, inline int TRAIL +) -> inline int /* DELTA */ + , inline int /* LEN */ + , inline int /* TRAIL */ + , reg u64 /* w */ +{ + reg u64 w, t16, t8; + inline int ILEN; + ILEN = LEN; + if (LEN <=s 0) { + w = TRAIL; + TRAIL = 0; + } else if (8 <=s LEN) { + w = buf.[u64 offset + DELTA]; + DELTA += 8; + LEN -= 8; + } else { + if (4 <=s LEN) { + w = (64u) buf.[u32 offset + DELTA]; + DELTA += 4; + LEN -= 4; + } else { + w = 0; + } + if (2 <=s LEN) { + t16 = (64u) buf.[u16 offset + DELTA]; + DELTA += 2; + LEN -= 2; + } else { + t16 = 0; + } + if (1 <=s LEN || TRAIL != 0) { + if (1 <=s LEN) { + t8 = (64u) buf.[u8 offset + DELTA]; + if (TRAIL != 0) { t8 |= 256*TRAIL; } + DELTA += 1; + LEN -= 1; + } else { + t8 = TRAIL; + } + TRAIL = 0; + t8 <<= 8*(2*((ILEN/2) % 2)); + t16 |= t8; + } + t16 <<= 8*(4*((ILEN/4) % 2)); + w |= t16; + } + return DELTA, LEN, TRAIL, w; +} + +inline fn __aread_bcast_4subu64 +( reg const ptr u8[ASIZE] buf +, reg u64 offset +, inline int DELTA +, inline int LEN +, inline int TRAIL +) -> inline int /* DELTA */ + , inline int /* LEN */ + , inline int /* TRAIL */ + , reg u256 /* w */ +{ + reg u64 t64; + reg u128 t128; + reg u256 w; + if (LEN <=s 0 && TRAIL==0) { + w = #set0_256(); + } else { + if (8 <= LEN) { + w = #VPBROADCAST_4u64(buf.[u64 offset + DELTA]); + DELTA += 8; + LEN -= 8; + } else { + DELTA, LEN, TRAIL, t64 = __aread_subu64(buf, offset, DELTA, LEN, TRAIL); + t128 = (128u) t64; + w = #VPBROADCAST_4u64(t128); + } + } + return DELTA, LEN, TRAIL, w; +} + +inline fn __aread_subu128 +( reg const ptr u8[ASIZE] buf +, reg u64 offset +, inline int DELTA +, inline int LEN +, inline int TRAIL +) -> inline int /* DELTA */ + , inline int /* LEN */ + , inline int /* TRAIL */ + , reg u128 /* w */ +{ + reg u128 w; + reg u64 t64; + if (LEN <=s 0 && TRAIL==0) { + w = #set0_128(); + } else if (16 <=s LEN) { + w = buf.[u128 offset + DELTA]; + DELTA += 16; + LEN -= 16; + } else { + if (8 <=s LEN) { + w = #VMOV(buf.[u64 offset + DELTA]); + DELTA += 8; + LEN -= 8; + DELTA, LEN, TRAIL, t64 = __aread_subu64(buf, offset, DELTA, LEN, TRAIL); + w = #VPINSR_2u64(w, t64, 1); + } else { + DELTA, LEN, TRAIL, t64 = __aread_subu64(buf, offset, DELTA, LEN, TRAIL); + w = (128u) t64; + } + } + return DELTA, LEN, TRAIL, w; +} + +inline fn __aread_subu256 +( reg const ptr u8[ASIZE] buf +, reg u64 offset +, inline int DELTA +, inline int LEN +, inline int TRAIL +) -> inline int /* DELTA */ + , inline int /* LEN */ + , inline int /* TRAIL */ + , reg u256 /* w */ +{ + reg u256 w; + reg u128 t128_0, t128_1; + if (LEN <=s 0 && TRAIL==0) { + w = #set0_256(); + } else if (32 <=s LEN) { + w = buf.[u256 offset + DELTA]; + DELTA += 32; + LEN -= 32; + } else { + if (16 <=s LEN) { + 
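+ // at least the low 128 bits are present: load them directly and let
+ // __aread_subu128 assemble the (possibly partial, possibly trailed) high half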
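+ // (TRAIL, when nonzero, is the domain-separation trail byte that the
+ // recursive sub-reads append immediately after the last message byte)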
+ t128_0 = buf.[u128 offset + DELTA];
+ DELTA += 16;
+ LEN -= 16;
+ DELTA, LEN, TRAIL, t128_1 = __aread_subu128(buf, offset, DELTA, LEN, TRAIL);
+ w = (2u128)[t128_1, t128_0];
+ } else {
+ t128_1 = #set0_128();
+ DELTA, LEN, TRAIL, t128_0 = __aread_subu128(buf, offset, DELTA, LEN, TRAIL);
+ w = (2u128)[t128_1, t128_0];
+ }
+ }
+ return DELTA, LEN, TRAIL, w;
+}
+
+
+/**
+ WRITE A FIXED NUMBER OF BYTES FROM A WORD
+**/
+inline fn __awrite_subu64
+( reg mut ptr u8[ASIZE] buf
+, reg u64 offset
+, inline int DELTA
+, inline int LEN
+, reg u64 w
+) -> reg ptr u8[ASIZE] /* buf */
+ , inline int /* DELTA */
+ , inline int /* LEN */
+{
+ if (0 <s LEN) {
+ if (8 <=s LEN) {
+ buf.[u64 offset + DELTA] = w;
+ DELTA += 8;
+ LEN -= 8;
+ } else {
+ if (4 <=s LEN) {
+ buf.[u32 offset + DELTA] = (32u) w;
+ w >>= 32;
+ DELTA += 4;
+ LEN -= 4;
+ }
+ if (2 <=s LEN) {
+ buf.[u16 offset + DELTA] = (16u) w;
+ w >>= 16;
+ DELTA += 2;
+ LEN -= 2;
+ }
+ if (1 <=s LEN) {
+ buf.[u8 offset + DELTA] = (8u) w;
+ DELTA += 1;
+ LEN -= 1;
+ }
+ }
+ }
+ return buf, DELTA, LEN;
+}
+
+inline fn __awrite_subu128
+( reg mut ptr u8[ASIZE] buf
+, reg u64 offset
+, inline int DELTA
+, inline int LEN
+, reg u128 w
+) -> reg ptr u8[ASIZE] /* buf */
+ , inline int /* DELTA */
+ , inline int /* LEN */
+{
+ reg u64 t64;
+ if (0 <s LEN) {
+ if (16 <=s LEN) {
+ buf.[u128 offset + DELTA] = w;
+ DELTA += 16;
+ LEN -= 16;
+ } else {
+ if (8 <=s LEN) {
+ t64 = #VMOVLPD(w);
+ buf.[u64 offset + DELTA] = t64;
+ DELTA += 8;
+ LEN -= 8;
+ t64 = #VMOVHPD(w);
+ buf, DELTA, LEN = __awrite_subu64(buf, offset, DELTA, LEN, t64);
+ } else {
+ t64 = #VMOVLPD(w);
+ buf, DELTA, LEN = __awrite_subu64(buf, offset, DELTA, LEN, t64);
+ }
+ }
+ }
+ return buf, DELTA, LEN;
+}
+
+inline fn __awrite_subu256
+( reg mut ptr u8[ASIZE] buf
+, reg u64 offset
+, inline int DELTA
+, inline int LEN
+, reg u256 w
+) -> reg ptr u8[ASIZE] /* buf */
+ , inline int /* DELTA */
+ , inline int /* LEN */
+{
+ reg u128 t128;
+ if (0 <s LEN) {
+ if (32 <=s LEN) {
+ buf.[u256 offset + DELTA] = w;
+ DELTA += 32;
+ LEN -= 32;
+ } else {
+ if (16 <=s LEN) {
+ t128 = (128u) w;
+ buf.[u128 offset + DELTA] = t128;
+ DELTA += 16;
+ LEN -= 16;
+ t128 = #VEXTRACTI128(w, 1);
+ buf, DELTA, LEN = __awrite_subu128(buf, offset, DELTA, LEN, t128);
+ } else {
+ t128 = (128u) w;
+ buf, DELTA, LEN = __awrite_subu128(buf, offset, DELTA, LEN, t128);
+ }
+ }
+ }
+ return buf, DELTA, LEN;
+}
diff --git a/code/jasmin/mlkem_avx2/keccak/subreadwrite_imem.jinc b/code/jasmin/mlkem_avx2/keccak/subreadwrite_imem.jinc
new file mode 100644
--- /dev/null
+++ b/code/jasmin/mlkem_avx2/keccak/subreadwrite_imem.jinc
@@ -0,0 +1,244 @@
+/**
+ READ A FIXED NUMBER OF BYTES INTO A WORD
+**/
+
+inline fn __mread_subu64
+( reg u64 buf
+, inline int LEN
+, inline int TRAIL
+) -> reg u64 /* buf */
+ , inline int /* LEN */
+ , inline int /* TRAIL */
+ , reg u64 /* w */
+{
+ reg u64 w, t16, t8;
+ inline int ILEN;
+ ILEN = LEN;
+ if (LEN <=s 0) {
+ w = TRAIL;
+ TRAIL = 0;
+ } else if (8 <=s LEN) {
+ w = (u64)[buf];
+ buf += 8;
+ LEN -= 8;
+ } else {
+ if (4 <=s LEN) {
+ w = (64u) (u32)[buf];
+ buf += 4;
+ LEN -= 4;
+ } else {
+ w = 0;
+ }
+ if (2 <=s LEN) {
+ t16 = (64u) (u16)[buf];
+ buf += 2;
+ LEN -= 2;
+ } else {
+ t16 = 0;
+ }
+ if (1 <=s LEN || TRAIL != 0) {
+ if (1 <=s LEN) {
+ t8 = (64u) (u8)[buf];
+ if (TRAIL != 0) { t8 |= 256*TRAIL; }
+ buf += 1;
+ LEN -= 1;
+ } else {
+ t8 = TRAIL;
+ }
+ TRAIL = 0;
+ t8 <<= 8*(2*((ILEN/2) % 2));
+ t16 |= t8;
+ }
+ t16 <<= 8*(4*((ILEN/4) % 2));
+ w |= t16;
+ }
+ return buf, LEN, TRAIL, w;
+}
+
+inline fn __mread_bcast_4subu64
+( reg u64 buf
+, inline int LEN
+, inline int TRAIL
+) -> reg u64 /* buf */
+ , inline int /* LEN */
+ , inline int /* TRAIL */
+ , reg u256 /* w */
+{
+ reg u64 t64;
+ reg u128 t128;
+ reg u256 w;
+ if (LEN <=s 0 && TRAIL==0) {
+ w = #set0_256();
+ } else {
+ if (8 <= LEN) {
+ w = #VPBROADCAST_4u64((u64)[buf]);
+ buf += 8;
+ LEN -= 8;
+ } else {
+ buf, LEN, TRAIL, t64 = __mread_subu64(buf, LEN, TRAIL);
+ t128 = (128u) t64;
+ w = #VPBROADCAST_4u64(t128);
+ }
+ }
+ return buf, LEN, TRAIL, w;
+}
+
+inline fn __mread_subu128
+( reg u64 buf
+, inline int LEN
+, inline int TRAIL
+) -> reg u64 /* buf */
+ , inline int /* LEN */
+ , inline int /* TRAIL */
+ , reg u128 /* w */
+{
+ reg u128 w;
+ reg u64 t64;
+ if (LEN <=s 0 && TRAIL==0) {
+ w = #set0_128();
+ } else if (16 <=s LEN) {
+ w = (u128) [buf];
+ buf += 16;
+ LEN -= 16;
+ } else {
+ if (8 <=s LEN) {
+ w = #VMOV((u64)[buf]);
+ buf += 8;
+ LEN -= 8;
+ buf, LEN, TRAIL, t64 = __mread_subu64(buf, LEN, TRAIL);
+ w = #VPINSR_2u64(w, t64, 1);
+ } else {
+ buf, LEN, TRAIL, t64 = __mread_subu64(buf, LEN, TRAIL);
+ w = (128u) t64;
+ }
+ }
+ return buf, LEN, TRAIL, w;
+}
+
+inline fn __mread_subu256
+( reg u64 buf
+, inline int LEN
+, inline int TRAIL
+) -> reg u64 /* buf */
+ , inline int /* LEN */
+ , inline int /* TRAIL */
+ , reg u256 /* w */
+{
+ reg u256 w;
+ reg u128 t128_0, t128_1;
+ if (LEN <=s 0 && TRAIL==0) {
+ w = #set0_256();
+ } else if (32 <=s LEN) {
+ w = (u256)[buf];
+ buf += 32;
+ LEN -= 32;
+ } else {
+ if (16 <=s LEN) {
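+ // memory variant of the same split: consume the low u128 from [buf], then
+ // recurse through __mread_subu128 for the remaining (<16) bytes and trail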
+ t128_0 = (u128) [buf];
+ buf += 16;
+ LEN -= 16;
+ buf, LEN, TRAIL, t128_1 = __mread_subu128(buf, LEN, TRAIL);
+ w = (2u128)[t128_1, t128_0];
+ } else {
+ t128_1 = #set0_128();
+ buf, LEN, TRAIL, t128_0 = __mread_subu128(buf, LEN, TRAIL);
+ w = (2u128)[t128_1, t128_0];
+ }
+ }
+ return buf, LEN, TRAIL, w;
+}
+
+
+/**
+ WRITE A FIXED NUMBER OF BYTES FROM A WORD
+**/
+inline fn __mwrite_subu64
+( reg u64 buf
+, inline int LEN
+, reg u64 w
+) -> reg u64 /* buf */
+ , inline int /* LEN */
+{
+ if (0 <s LEN) {
+ if (8 <=s LEN) {
+ (u64)[buf] = w;
+ buf += 8;
+ LEN -= 8;
+ } else {
+ if (4 <=s LEN) {
+ (u32)[buf] = (32u) w;
+ w >>= 32;
+ buf += 4;
+ LEN -= 4;
+ }
+ if (2 <=s LEN) {
+ (u16)[buf] = (16u) w;
+ w >>= 16;
+ buf += 2;
+ LEN -= 2;
+ }
+ if (1 <=s LEN) {
+ (u8)[buf] = (8u) w;
+ buf += 1;
+ LEN -= 1;
+ }
+ }
+ }
+ return buf, LEN;
+}
+
+inline fn __mwrite_subu128
+( reg u64 buf
+, inline int LEN
+, reg u128 w
+) -> reg u64 /* buf */
+ , inline int /* LEN */
+{
+ reg u64 t64;
+ if (0 <s LEN) {
+ if (16 <=s LEN) {
+ (u128)[buf] = w;
+ buf += 16;
+ LEN -= 16;
+ } else {
+ if (8 <=s LEN) {
+ t64 = #VMOVLPD(w);
+ (u64)[buf] = t64;
+ buf += 8;
+ LEN -= 8;
+ t64 = #VMOVHPD(w);
+ buf, LEN = __mwrite_subu64(buf, LEN, t64);
+ } else {
+ t64 = #VMOVLPD(w);
+ buf, LEN = __mwrite_subu64(buf, LEN, t64);
+ }
+ }
+ }
+ return buf, LEN;
+}
+
+inline fn __mwrite_subu256
+( reg u64 buf
+, inline int LEN
+, reg u256 w
+) -> reg u64 /* buf */
+ , inline int /* LEN */
+{
+ reg u128 t128;
+ if (0 <s LEN) {
+ if (32 <=s LEN) {
+ (u256)[buf] = w;
+ buf += 32;
+ LEN -= 32;
+ } else {
+ if (16 <=s LEN) {
+ t128 = (128u) w;
+ (u128)[buf] = t128;
+ buf += 16;
+ LEN -= 16;
+ t128 = #VEXTRACTI128(w, 1);
+ buf, LEN = __mwrite_subu128(buf, LEN, t128);
+ } else {
+ t128 = (128u) w;
+ buf, LEN = __mwrite_subu128(buf, LEN, t128);
+ }
+ }
+ }
+ return buf, LEN;
+}
diff --git a/code/jasmin/mlkem_avx2/fips202.jinc b/code/jasmin/mlkem_avx2/keccak_OLD/fips202.jinc
similarity index 98%
rename from code/jasmin/mlkem_avx2/fips202.jinc
rename to code/jasmin/mlkem_avx2/keccak_OLD/fips202.jinc
diff --git a/code/jasmin/mlkem_avx2/fips202_4x.jinc b/code/jasmin/mlkem_avx2/keccak_OLD/fips202_4x.jinc
similarity index 98%
rename from code/jasmin/mlkem_avx2/fips202_4x.jinc
rename to code/jasmin/mlkem_avx2/keccak_OLD/fips202_4x.jinc
--- a/code/jasmin/mlkem_avx2/fips202_4x.jinc
+++ b/code/jasmin/mlkem_avx2/keccak_OLD/fips202_4x.jinc
@@ ... @@
+fn __shake256_squeezeblock4xTRANSITION(reg ptr u256[25] state, reg ptr u8[128] h0 h1 h2 h3) -> reg ptr u256[25], reg ptr u8[128], reg ptr u8[128], reg ptr u8[128], reg ptr u8[128]
+{
+ reg u256 t256;
+ reg u128 t128;
+ inline int i;
+
+ state = _KeccakF1600_StatePermute4x(state);
+
+ for i = 0 to (128 / 8) {
+ t256 = state[i];
+ t128 = (128u)t256;
+ h0[u64 i] = #VMOVLPD(t128);
+ h1[u64 i] = #VMOVHPD(t128);
+ t128 = #VEXTRACTI128(t256, 1);
+ h2[u64 i] = #VMOVLPD(t128);
+ h3[u64 i] = #VMOVHPD(t128);
+ }
+
+ return state, h0, h1, h2, h3;
+}
diff --git a/code/jasmin/mlkem_avx2/fips202_common.jinc b/code/jasmin/mlkem_avx2/keccak_OLD/fips202_common.jinc
similarity index 100%
rename from code/jasmin/mlkem_avx2/fips202_common.jinc
rename to code/jasmin/mlkem_avx2/keccak_OLD/fips202_common.jinc
diff --git a/code/jasmin/mlkem_avx2/gen_matrix_old.jinc b/code/jasmin/mlkem_avx2/keccak_OLD/gen_matrix_old.jinc
similarity index 100%
rename from code/jasmin/mlkem_avx2/gen_matrix_old.jinc
rename to code/jasmin/mlkem_avx2/keccak_OLD/gen_matrix_old.jinc
diff --git a/code/jasmin/mlkem_avx2/keccak/keccakf1600.jinc b/code/jasmin/mlkem_avx2/keccak_OLD/keccakf1600.jinc
similarity index 100%
rename from code/jasmin/mlkem_avx2/keccak/keccakf1600.jinc
rename to code/jasmin/mlkem_avx2/keccak_OLD/keccakf1600.jinc
diff --git a/code/jasmin/mlkem_avx2/keccak/keccakf1600_4x_avx2_compact.jinc b/code/jasmin/mlkem_avx2/keccak_OLD/keccakf1600_4x_avx2_compact.jinc
similarity index 100%
rename from code/jasmin/mlkem_avx2/keccak/keccakf1600_4x_avx2_compact.jinc
rename to code/jasmin/mlkem_avx2/keccak_OLD/keccakf1600_4x_avx2_compact.jinc
diff --git a/code/jasmin/mlkem_avx2/keccak_OLD/keccakf1600_avx2.jinc b/code/jasmin/mlkem_avx2/keccak_OLD/keccakf1600_avx2.jinc
new file mode 100644
index 00000000..bbc0d321
--- /dev/null
+++ b/code/jasmin/mlkem_avx2/keccak_OLD/keccakf1600_avx2.jinc
@@ -0,0 +1,316 @@
+require "keccakf1600_generic.jinc"
+
+u256[24] KECCAK_IOTAS =
+{ (4u64)[0x0000000000000001, 0x0000000000000001, 0x0000000000000001, 0x0000000000000001]
+ ,(4u64)[0x0000000000008082, 0x0000000000008082, 0x0000000000008082, 0x0000000000008082]
+ ,(4u64)[0x800000000000808a, 0x800000000000808a, 0x800000000000808a, 0x800000000000808a]
+ ,(4u64)[0x8000000080008000, 0x8000000080008000, 0x8000000080008000, 0x8000000080008000]
+ ,(4u64)[0x000000000000808b, 0x000000000000808b, 0x000000000000808b, 0x000000000000808b]
+ ,(4u64)[0x0000000080000001, 0x0000000080000001, 0x0000000080000001, 0x0000000080000001]
+ ,(4u64)[0x8000000080008081, 0x8000000080008081, 0x8000000080008081, 0x8000000080008081]
+ ,(4u64)[0x8000000000008009, 0x8000000000008009, 0x8000000000008009, 0x8000000000008009]
+ ,(4u64)[0x000000000000008a,
0x000000000000008a, 0x000000000000008a, 0x000000000000008a] + ,(4u64)[0x0000000000000088, 0x0000000000000088, 0x0000000000000088, 0x0000000000000088] + ,(4u64)[0x0000000080008009, 0x0000000080008009, 0x0000000080008009, 0x0000000080008009] + ,(4u64)[0x000000008000000a, 0x000000008000000a, 0x000000008000000a, 0x000000008000000a] + ,(4u64)[0x000000008000808b, 0x000000008000808b, 0x000000008000808b, 0x000000008000808b] + ,(4u64)[0x800000000000008b, 0x800000000000008b, 0x800000000000008b, 0x800000000000008b] + ,(4u64)[0x8000000000008089, 0x8000000000008089, 0x8000000000008089, 0x8000000000008089] + ,(4u64)[0x8000000000008003, 0x8000000000008003, 0x8000000000008003, 0x8000000000008003] + ,(4u64)[0x8000000000008002, 0x8000000000008002, 0x8000000000008002, 0x8000000000008002] + ,(4u64)[0x8000000000000080, 0x8000000000000080, 0x8000000000000080, 0x8000000000000080] + ,(4u64)[0x000000000000800a, 0x000000000000800a, 0x000000000000800a, 0x000000000000800a] + ,(4u64)[0x800000008000000a, 0x800000008000000a, 0x800000008000000a, 0x800000008000000a] + ,(4u64)[0x8000000080008081, 0x8000000080008081, 0x8000000080008081, 0x8000000080008081] + ,(4u64)[0x8000000000008080, 0x8000000000008080, 0x8000000000008080, 0x8000000000008080] + ,(4u64)[0x0000000080000001, 0x0000000080000001, 0x0000000080000001, 0x0000000080000001] + ,(4u64)[0x8000000080008008, 0x8000000080008008, 0x8000000080008008, 0x8000000080008008] +}; + + +u256[6] KECCAK_RHOTATES_LEFT = +{ + (4u64)[41, 36, 18, 3], + (4u64)[27, 28, 62, 1], + (4u64)[39, 56, 6, 45], + (4u64)[ 8, 55, 61, 10], + (4u64)[20, 25, 15, 2], + (4u64)[14, 21, 43, 44] +}; + + +u256[6] KECCAK_RHOTATES_RIGHT = +{ + (4u64)[64-41, 64-36, 64-18, 64- 3], + (4u64)[64-27, 64-28, 64-62, 64- 1], + (4u64)[64-39, 64-56, 64- 6, 64-45], + (4u64)[64- 8, 64-55, 64-61, 64-10], + (4u64)[64-20, 64-25, 64-15, 64- 2], + (4u64)[64-14, 64-21, 64-43, 64-44] +}; + + +fn _keccakf1600_avx2(reg u256[7] state) -> reg u256[7] +{ + reg u256[9] t; + reg u256 c00 c14 d00 d14; + + reg bool zf; + reg u64 r iotas_o; + + reg ptr u256[24] iotas_p; + reg ptr u256[6] rhotates_left_p; + reg ptr u256[6] rhotates_right_p; + + iotas_p = KECCAK_IOTAS; + iotas_o = 0; + rhotates_left_p = KECCAK_RHOTATES_LEFT; + rhotates_right_p = KECCAK_RHOTATES_RIGHT; + + r = KECCAK_ROUNDS; + while + { + //######################################## Theta + c00 = #VPSHUFD_256(state[2], (4u2)[1,0,3,2]); + c14 = state[5] ^ state[3]; + t[2] = state[4] ^ state[6]; + c14 = c14 ^ state[1]; + c14 = c14 ^ t[2]; + t[4] = #VPERMQ(c14, (4u2)[2,1,0,3]); + c00 = c00 ^ state[2]; + t[0] = #VPERMQ(c00, (4u2)[1,0,3,2]); + t[1] = c14 >>4u64 63; + t[2] = c14 +4u64 c14; + t[1] = t[1] | t[2]; + d14 = #VPERMQ(t[1], (4u2)[0,3,2,1]); + d00 = t[1] ^ t[4]; + d00 = #VPERMQ(d00, (4u2)[0,0,0,0]); + c00 = c00 ^ state[0]; + c00 = c00 ^ t[0]; + t[0] = c00 >>4u64 63; + t[1] = c00 +4u64 c00; + t[1] = t[1] | t[0]; + state[2] = state[2] ^ d00; + state[0] = state[0] ^ d00; + d14 = #VPBLEND_8u32(d14, t[1], (8u1)[1,1,0,0,0,0,0,0]); + t[4] = #VPBLEND_8u32(t[4], c00, (8u1)[0,0,0,0,0,0,1,1]); + d14 = d14 ^ t[4]; + + //######################################## Rho + Pi + pre-Chi shuffle + t[3] = #VPSLLV_4u64(state[2], rhotates_left_p[0] ); + state[2] = #VPSRLV_4u64(state[2], rhotates_right_p[0] ); + state[2] = state[2] | t[3]; + state[3] = state[3] ^ d14; + t[4] = #VPSLLV_4u64(state[3], rhotates_left_p[2] ); + state[3] = #VPSRLV_4u64(state[3], rhotates_right_p[2] ); + state[3] = state[3] | t[4]; + state[4] = state[4] ^ d14; + t[5] = #VPSLLV_4u64(state[4], rhotates_left_p[3] ); + state[4] 
= #VPSRLV_4u64(state[4], rhotates_right_p[3] ); + state[4] = state[4] | t[5]; + state[5] = state[5] ^ d14; + t[6] = #VPSLLV_4u64(state[5], rhotates_left_p[4] ); + state[5] = #VPSRLV_4u64(state[5], rhotates_right_p[4] ); + state[5] = state[5] | t[6]; + state[6] = state[6] ^ d14; + t[3] = #VPERMQ(state[2], (4u2)[2,0,3,1]); + t[4] = #VPERMQ(state[3], (4u2)[2,0,3,1]); + t[7] = #VPSLLV_4u64(state[6], rhotates_left_p[5] ); + t[1] = #VPSRLV_4u64(state[6], rhotates_right_p[5] ); + t[1] = t[1] | t[7]; + state[1] = state[1] ^ d14; + t[5] = #VPERMQ(state[4], (4u2)[0,1,2,3]); + t[6] = #VPERMQ(state[5], (4u2)[1,3,0,2]); + t[8] = #VPSLLV_4u64(state[1], rhotates_left_p[1] ); + t[2] = #VPSRLV_4u64(state[1], rhotates_right_p[1] ); + t[2] = t[2] | t[8]; + + //######################################## Chi + t[7] = #VPSRLDQ_256(t[1], 8); + t[0] = !t[1] & t[7]; + state[3] = #VPBLEND_8u32(t[2], t[6], (8u1)[0,0,0,0,1,1,0,0]); + t[8] = #VPBLEND_8u32(t[4], t[2], (8u1)[0,0,0,0,1,1,0,0]); + state[5] = #VPBLEND_8u32(t[3], t[4], (8u1)[0,0,0,0,1,1,0,0]); + t[7] = #VPBLEND_8u32(t[2], t[3], (8u1)[0,0,0,0,1,1,0,0]); + state[3] = #VPBLEND_8u32(state[3], t[4], (8u1)[0,0,1,1,0,0,0,0]); + t[8] = #VPBLEND_8u32(t[8], t[5], (8u1)[0,0,1,1,0,0,0,0]); + state[5] = #VPBLEND_8u32(state[5], t[2], (8u1)[0,0,1,1,0,0,0,0]); + t[7] = #VPBLEND_8u32(t[7], t[6], (8u1)[0,0,1,1,0,0,0,0]); + state[3] = #VPBLEND_8u32(state[3], t[5], (8u1)[1,1,0,0,0,0,0,0]); + t[8] = #VPBLEND_8u32(t[8], t[6], (8u1)[1,1,0,0,0,0,0,0]); + state[5] = #VPBLEND_8u32(state[5], t[6], (8u1)[1,1,0,0,0,0,0,0]); + t[7] = #VPBLEND_8u32(t[7], t[4], (8u1)[1,1,0,0,0,0,0,0]); + state[3] = !state[3] & t[8]; + state[5] = !state[5] & t[7]; + state[6] = #VPBLEND_8u32(t[5], t[2], (8u1)[0,0,0,0,1,1,0,0]); + t[8] = #VPBLEND_8u32(t[3], t[5], (8u1)[0,0,0,0,1,1,0,0]); + state[3] = state[3] ^ t[3]; + state[6] = #VPBLEND_8u32(state[6], t[3], (8u1)[0,0,1,1,0,0,0,0]); + t[8] = #VPBLEND_8u32(t[8], t[4], (8u1)[0,0,1,1,0,0,0,0]); + state[5] = state[5] ^ t[5]; + state[6] = #VPBLEND_8u32(state[6], t[4], (8u1)[1,1,0,0,0,0,0,0]); + t[8] = #VPBLEND_8u32(t[8], t[2], (8u1)[1,1,0,0,0,0,0,0]); + state[6] = !state[6] & t[8]; + state[6] = state[6] ^ t[6]; + state[4] = #VPERMQ(t[1], (4u2)[0,1,3,2]); + t[8] = #VPBLEND_8u32(state[4], state[0], (8u1)[0,0,1,1,0,0,0,0]); + state[1] = #VPERMQ(t[1], (4u2)[0,3,2,1]); + state[1] = #VPBLEND_8u32(state[1], state[0], (8u1)[1,1,0,0,0,0,0,0]); + state[1] = !state[1] & t[8]; + state[2] = #VPBLEND_8u32(t[4], t[5], (8u1)[0,0,0,0,1,1,0,0]); + t[7] = #VPBLEND_8u32(t[6], t[4], (8u1)[0,0,0,0,1,1,0,0]); + state[2] = #VPBLEND_8u32(state[2], t[6], (8u1)[0,0,1,1,0,0,0,0]); + t[7] = #VPBLEND_8u32(t[7], t[3], (8u1)[0,0,1,1,0,0,0,0]); + state[2] = #VPBLEND_8u32(state[2], t[3], (8u1)[1,1,0,0,0,0,0,0]); + t[7] = #VPBLEND_8u32(t[7], t[5], (8u1)[1,1,0,0,0,0,0,0]); + state[2] = !state[2] & t[7]; + state[2] = state[2] ^ t[2]; + t[0] = #VPERMQ(t[0], (4u2)[0,0,0,0]); + state[3] = #VPERMQ(state[3], (4u2)[0,1,2,3]); + state[5] = #VPERMQ(state[5], (4u2)[2,0,3,1]); + state[6] = #VPERMQ(state[6], (4u2)[1,3,0,2]); + state[4] = #VPBLEND_8u32(t[6], t[3], (8u1)[0,0,0,0,1,1,0,0]); + t[7] = #VPBLEND_8u32(t[5], t[6], (8u1)[0,0,0,0,1,1,0,0]); + state[4] = #VPBLEND_8u32(state[4], t[5], (8u1)[0,0,1,1,0,0,0,0]); + t[7] = #VPBLEND_8u32(t[7], t[2], (8u1)[0,0,1,1,0,0,0,0]); + state[4] = #VPBLEND_8u32(state[4], t[2], (8u1)[1,1,0,0,0,0,0,0]); + t[7] = #VPBLEND_8u32(t[7], t[3], (8u1)[1,1,0,0,0,0,0,0]); + state[4] = !state[4] & t[7]; + state[0] = state[0] ^ t[0]; + state[1] = state[1] ^ t[1]; + state[4] = state[4] ^ 
t[4];
+
+ //######################################## Iota
+ state[0] = state[0] ^ iotas_p.[(int) iotas_o];
+ iotas_o += 32;
+
+ _,_,_,zf,r = #DEC_64(r);
+ }(!zf)
+
+ return state;
+}
+
+// converts a (plain) keccak state (st25) into the avx2 representation
+inline fn __stavx2_pack
+( reg const ptr u64[25] st
+) -> reg u256[7] {
+ // 3*r256 (avoidable...)
+ reg u256[7] state;
+ reg u256 t256_0 t256_1 t256_2;
+ reg u128 t128_0, t128_1;
+ reg u64 r;
+
+ // [ 0 0 0 0 ]
+ state[0] = #VPBROADCAST_4u64(st.[u64 8*0]);
+ // [ 1 2 3 4 ]
+ state[1] = st.[u256 1*8];
+ // [ 5 - ]
+ t128_0 = #VMOV(st[5]);
+ // [ 6 7 8 9 ]
+ state[3] = st.[u256 6*8];
+ // [ 10 - ]
+ t128_1 = #VMOV(st[10]);
+ // [ 11 12 13 14 ]
+ state[4] = st.[u256 11*8];
+ // [ 5 15 ]
+ r = st[15];
+ t128_0 = #VPINSR_2u64(t128_0, r, 1);
+ // [ 16 17 18 19 ]
+ state[5] = st.[u256 16*8];
+ // [ 10 20 ]
+ r = st[20];
+ t128_1 = #VPINSR_2u64(t128_1, r, 1);
+ // alternative not currently supported: VPGATHERDQ for filling state[2]
+ // [ 10 20 5 15 ]
+ state[2] = (2u128)[t128_0, t128_1];
+ // [ 21 22 23 24 ]
+ state[6] = st.[u256 21*8];
+
+ // [ 16 7 8 19 ]
+ t256_0 = #VPBLEND_8u32(state[3], state[5], (8u1)[1,1,0,0,0,0,1,1]);
+ // [ 11 22 23 14 ]
+ t256_1 = #VPBLEND_8u32(state[6], state[4], (8u1)[1,1,0,0,0,0,1,1]);
+ // [ 6 12 13 9 ]
+ t256_2 = #VPBLEND_8u32(state[4], state[3], (8u1)[1,1,0,0,0,0,1,1]);
+
+ // [ 16 7 23 14 ]
+ state[3] = #VPBLEND_8u32(t256_0, t256_1, (8u1)[1,1,1,1,0,0,0,0]);
+ // [ 11 22 8 19 ]
+ state[4] = #VPBLEND_8u32(t256_1, t256_0, (8u1)[1,1,1,1,0,0,0,0]);
+
+ // [ 21 17 18 24 ]
+ t256_0 = #VPBLEND_8u32(state[5], state[6], (8u1)[1,1,0,0,0,0,1,1]);
+
+ // [ 21 17 13 9 ]
+ state[5] = #VPBLEND_8u32(t256_0, t256_2, (8u1)[1,1,1,1,0,0,0,0]);
+ // [ 6 12 18 24 ]
+ state[6] = #VPBLEND_8u32(t256_2, t256_0, (8u1)[1,1,1,1,0,0,0,0]);
+
+ // [ 0 0 0 0 ]
+ // [ 1 2 3 4 ]
+ // [ 10 20 5 15 ]
+ // [ 16 7 23 14 ]
+ // [ 11 22 8 19 ]
+ // [ 21 17 13 9 ]
+ // [ 6 12 18 24 ]
+ return state;
+}
+
+// recovers a (plain) keccak state (st25) from an avx2-encoded one
+inline fn __stavx2_unpack
+( reg mut ptr u64[25] st
+, reg u256[7] state
+) -> reg ptr u64[25] {
+ // 5*r256 + 2*r128 (avoidable) (+7*r256)
+ reg u256 t256_0 t256_1 t256_2 t256_3 t256_4;
+ reg u128 t128_0, t128_1;
+
+ // [ 0, 0 ]
+ t128_0 = (128u) state[0];
+ st[0] = #VMOVLPD(t128_0);
+ // [ 1, 2, 3, 4 ]
+ st.[u256 1*8] = state[1];
+
+ // [ 16, 7, 8, 19 ]
+ t256_0 = #VPBLEND_8u32(state[3], state[4], (8u1)[1,1,1,1,0,0,0,0]);
+ // [ 11, 22, 23, 14 ]
+ t256_1 = #VPBLEND_8u32(state[4], state[3], (8u1)[1,1,1,1,0,0,0,0]);
+ // [ 21, 17, 18, 24 ]
+ t256_2 = #VPBLEND_8u32(state[5], state[6], (8u1)[1,1,1,1,0,0,0,0]);
+ // [ 6, 12, 13, 9 ]
+ t256_3 = #VPBLEND_8u32(state[6], state[5], (8u1)[1,1,1,1,0,0,0,0]);
+
+ // [ 5, 15 ]
+ t128_1 = #VEXTRACTI128(state[2], 1);
+ st[5] = #VMOVLPD(t128_1);
+
+ // [ 6, 7, 8, 9 ]
+ t256_4 = #VPBLEND_8u32(t256_0, t256_3, (8u1)[1,1,0,0,0,0,1,1]);
+ st.[u256 6*8] = t256_4;
+
+ // [ 10, 20 ]
+ t128_0 = (128u) state[2];
+ st[10] = #VMOVLPD(t128_0);
+
+ // [ 11, 12, 13, 14 ]
+ t256_4 = #VPBLEND_8u32(t256_3, t256_1, (8u1)[1,1,0,0,0,0,1,1]);
+ st.[u256 11*8] = t256_4;
+
+ // [ 15 ]
+ st[15] = #VMOVHPD(t128_1);
+
+ // [ 16, 17, 18, 19 ]
+ t256_4 = #VPBLEND_8u32(t256_2, t256_0, (8u1)[1,1,0,0,0,0,1,1]);
+ st.[u256 16*8] = t256_4;
+
+ // [ 20 ]
+ st[20] = #VMOVHPD(t128_0);
+
+ // [ 21, 22, 23, 24 ]
+ t256_4 = #VPBLEND_8u32(t256_1, t256_2, (8u1)[1,1,0,0,0,0,1,1]);
+ st.[u256 21*8] = t256_4;
+
+ return st;
+}
+
diff --git 
a/code/jasmin/mlkem_avx2/keccak/keccakf1600_generic.jinc b/code/jasmin/mlkem_avx2/keccak_OLD/keccakf1600_generic.jinc similarity index 100% rename from code/jasmin/mlkem_avx2/keccak/keccakf1600_generic.jinc rename to code/jasmin/mlkem_avx2/keccak_OLD/keccakf1600_generic.jinc diff --git a/code/jasmin/mlkem_avx2/kem.h b/code/jasmin/mlkem_avx2/kem.h index c86b2540..130c142d 100644 --- a/code/jasmin/mlkem_avx2/kem.h +++ b/code/jasmin/mlkem_avx2/kem.h @@ -16,24 +16,24 @@ void crypto_kem_dec(unsigned char *m, const unsigned char *c, const unsigned char *sk); -void jade_kem_mlkem_mlkem768_amd64_avx2v_keypair_derand(unsigned char *pk, +void jade_kem_mlkem_mlkem768_amd64_avx2_keypair_derand(unsigned char *pk, unsigned char *sk, const unsigned char *randomness); -void jade_kem_mlkem_mlkem768_amd64_avx2v_enc_derand(unsigned char *c, +void jade_kem_mlkem_mlkem768_amd64_avx2_enc_derand(unsigned char *c, const unsigned char *m, const unsigned char *pk, const unsigned char *coins); -void jade_kem_mlkem_mlkem768_amd64_avx2v_keypair(unsigned char *pk, +void jade_kem_mlkem_mlkem768_amd64_avx2_keypair(unsigned char *pk, unsigned char *sk); -void jade_kem_mlkem_mlkem768_amd64_avx2v_enc(unsigned char *c, +void jade_kem_mlkem_mlkem768_amd64_avx2_enc(unsigned char *c, const unsigned char *m, const unsigned char *pk); -void jade_kem_mlkem_mlkem768_amd64_avx2v_dec(unsigned char *m, +void jade_kem_mlkem_mlkem768_amd64_avx2_dec(unsigned char *m, const unsigned char *c, const unsigned char *sk); diff --git a/code/jasmin/mlkem_avx2/kem.jinc b/code/jasmin/mlkem_avx2/kem.jinc index 7be45bb2..1dc19d54 100644 --- a/code/jasmin/mlkem_avx2/kem.jinc +++ b/code/jasmin/mlkem_avx2/kem.jinc @@ -32,9 +32,10 @@ fn __crypto_kem_keypair_jazz(reg u64 pkp, reg u64 skp, reg ptr u8[MLKEM_SYMBYTES s_skp = skp; pkp = s_pkp; - t64 = MLKEM_PUBLICKEYBYTES; - h_pk = _isha3_256(h_pk, pkp, t64); + //t64 = MLKEM_PUBLICKEYBYTES; + //h_pk = _isha3_256(h_pk, pkp, t64); + h_pk = _sha3_256A_M1184(h_pk, pkp); skp = s_skp; for i=0 to 4 @@ -74,10 +75,12 @@ fn __crypto_kem_enc_jazz(reg u64 ctp, reg u64 shkp, reg u64 pkp, reg ptr u8[MLKE buf[u64 i] = t64; } - t64 = MLKEM_PUBLICKEYBYTES; - buf[MLKEM_SYMBYTES:MLKEM_SYMBYTES] = _isha3_256(buf[MLKEM_SYMBYTES:MLKEM_SYMBYTES], pkp, t64); + //t64 = MLKEM_PUBLICKEYBYTES; + //buf[MLKEM_SYMBYTES:MLKEM_SYMBYTES] = _isha3_256(buf[MLKEM_SYMBYTES:MLKEM_SYMBYTES], pkp, t64); + buf[MLKEM_SYMBYTES:MLKEM_SYMBYTES] = _sha3_256A_M1184(buf[MLKEM_SYMBYTES:MLKEM_SYMBYTES], pkp); - kr = _sha3_512_64(kr, buf); + //kr = _sha3_512_64(kr, buf); + kr = _sha3_512A_A64(kr, buf); pkp = s_pkp; @@ -118,7 +121,8 @@ fn __crypto_kem_dec_jazz(reg u64 shkp, reg u64 ctp, reg u64 skp) s_skp = skp; - kr = _sha3_512_64(kr, buf); + //kr = _sha3_512_64(kr, buf); + kr = _sha3_512A_A64(kr, buf); pkp = s_skp; pkp += 12 * MLKEM_K * MLKEM_N>>3; @@ -138,7 +142,8 @@ fn __crypto_kem_dec_jazz(reg u64 shkp, reg u64 ctp, reg u64 skp) /* fixme: should this be done in memory? 
*/
 shkp = s_shkp;
- _shake256_1120_32(shkp, zp, ctp);
+ //_shake256_1120_32(shkp, zp, ctp);
+ _shake256_M32__M32_M1088(shkp, zp, ctp);
 shkp = s_shkp;
 cnd = s_cnd;
diff --git a/code/jasmin/mlkem_avx2/mlkem_keccak_avx2.jinc b/code/jasmin/mlkem_avx2/mlkem_keccak_avx2.jinc
new file mode 100644
index 00000000..16cdfdb2
--- /dev/null
+++ b/code/jasmin/mlkem_avx2/mlkem_keccak_avx2.jinc
@@ -0,0 +1,251 @@
+require "keccak/keccak1600_imem_avx2.jinc"
+require "keccak/keccak1600x4_imem_avx2.jinc"
+
+namespace A1 {
+ param int ASIZE = 1;
+ require "keccak/keccak1600_array_avx2_ASIZE.jinc"
+ require "keccak/keccak1600x4_array_avx2_ASIZE.jinc"
+}
+
+namespace A2 {
+ param int ASIZE = 2;
+ require "keccak/keccak1600_array_avx2_ASIZE.jinc"
+ require "keccak/keccak1600x4_array_avx2_ASIZE.jinc"
+}
+
+namespace A32 {
+ param int ASIZE = 32;
+ require "keccak/keccak1600_array_avx2_ASIZE.jinc"
+ require "keccak/keccak1600x4_array_avx2_ASIZE.jinc"
+}
+
+namespace A64 {
+ param int ASIZE = 64;
+ require "keccak/keccak1600_array_avx2_ASIZE.jinc"
+}
+
+namespace A128 {
+ param int ASIZE = 128;
+ require "keccak/keccak1600_array_avx2_ASIZE.jinc"
+ require "keccak/keccak1600x4_array_avx2_ASIZE.jinc"
+}
+
+namespace ABUFLEN {
+ param int ASIZE = 536;
+ require "keccak/keccak1600_array_avx2_ASIZE.jinc"
+ require "keccak/keccak1600x4_array_avx2_ASIZE.jinc"
+}
+
+
+fn _sha3_256A_A32
+( #spill_to_mmx reg mut ptr u8[32] out
+, reg const ptr u8[32] in
+) -> reg ptr u8[32]
+{ reg u256[7] st;
+ reg u64 offset;
+ st = __state_init_avx2();
+ offset = 0;
+ st, _ = A32::__absorb_array_avx2(st, in, offset, 32, R136, SHA3);
+ offset = 0;
+ out, _ = A32::__squeeze_array_avx2(out, offset, 32, st, R136);
+ return out;
+}
+
+fn _sha3_256A_M1184
+( #spill_to_mmx reg mut ptr u8[32] out
+, #spill_to_mmx reg u64 in
+) -> reg ptr u8[32]
+{ reg u256[7] st;
+ reg u64 offset;
+ st = __state_init_avx2();
+ st, _ = __absorb_imem_avx2(st, in, 1184, R136, SHA3);
+ offset = 0;
+ out, _ = A32::__squeeze_array_avx2(out, offset, 32, st, R136);
+ return out;
+}
+
+fn _sha3_512A_A32
+( #spill_to_mmx reg mut ptr u8[64] out
+, reg const ptr u8[32] in
+) -> reg ptr u8[64]
+{ reg u256[7] st;
+ reg u64 offset;
+ st = __state_init_avx2();
+ offset = 0;
+ st, _ = A32::__absorb_array_avx2(st, in, offset, 32, R72, SHA3);
+ offset = 0;
+ out, _ = A64::__squeeze_array_avx2(out, offset, 64, st, R72);
+ return out;
+}
+
+fn _sha3_512A_A64
+( reg mut ptr u8[64] out
+, reg const ptr u8[64] in
+) -> reg ptr u8[64]
+{ reg u256[7] st;
+ reg u64 offset;
+ st = __state_init_avx2();
+ offset = 0;
+ st, _ = A64::__absorb_array_avx2(st, in, offset, 64, R72, SHA3);
+ offset = 0;
+ out, _ = A64::__squeeze_array_avx2(out, offset, 64, st, R72);
+ return out;
+}
+
+fn _shake256_M32__M32_M1088
+( reg u64 out
+, reg u64 in0 in1 // 32+MLKEM_INDCPA_CIPHERTEXTBYTES
+)
+{ reg u256[7] st;
+ stack u64[25] pst_s;
+ reg ptr u64[25] pst;
+ st = __state_init_avx2();
+ pst = pst_s;
+ pst = __pstate_init_avx2(pst);
+ pst, _, st, _ = __pabsorb_imem_avx2(pst, 0, st, in0, 32, R136, UNFINISHED);
+ pst, _, st, _ = __pabsorb_imem_avx2(pst, 32, st, in1, 1088, R136, SHAKE);
+ _, _ = __squeeze_imem_avx2(out, 32, st, R136);
+}
+
+fn _shake256_A128__A32_A1
+( reg mut ptr u8[128] out
+, reg const ptr u8[32] seed
+, reg const ptr u8[1] nonce
+) -> reg ptr u8[128]
+{ reg u256[7] st;
+ stack u64[25] pst_s;
+ reg ptr u64[25] pst;
+ reg u64 offset;
+ st = __state_init_avx2();
+ pst = pst_s;
+ pst = __pstate_init_avx2(pst);
+ offset = 0;
+ pst, _, st, _ = A32::__pabsorb_array_avx2(pst, 0, st, seed, offset, 32,
R136, UNFINISHED); + offset = 0; + pst, _, st, _ = A1::__pabsorb_array_avx2(pst, 32, st, nonce, offset, 1, R136, SHAKE); + offset = 0; + out, _ = A128::__squeeze_array_avx2(out, offset, 128, st, R136); + + return out; +} + +fn _shake256x4_A128__A32_A1 +( reg mut ptr u8[128] out0 out1 out2 out3 +, reg const ptr u8[32] seed +, reg const ptr u8[4] nonces +) -> reg ptr u8[128] /* out0 */ + , reg ptr u8[128] /* out1 */ + , reg ptr u8[128] /* out2 */ + , reg ptr u8[128] /* out3 */ +{ stack u256[25] st_s; + reg ptr u256[25] st; + reg u64 offset; + st = st_s; + st = __state_init_avx2x4(st); + offset = 0; + st, _, _ = A32::__absorb_bcast_array_avx2x4(st, 0, seed, offset, 32, R136, UNFINISHED); + offset = 0; + st, _, _ = A1::__absorb_array_avx2x4(st, 32, nonces[0:1], nonces[1:1], nonces[2:1], nonces[3:1], offset, 1, R136, SHAKE); + offset = 0; + out0, out1, out2, out3, _, st + = A128::__squeeze_array_avx2x4(out0, out1, out2, out3, offset, 128, st, R136); + st_s = st; + + return out0, out1, out2, out3; +} + +fn _shake128_absorb_A32_A2 +( reg const ptr u8[32] seed +, reg const ptr u8[2] pos +) -> reg u256[7] +{ reg u256[7] st; + stack u64[25] pst_s; + reg ptr u64[25] pst; + reg u64 offset; + st = __state_init_avx2(); + pst = pst_s; + pst = __pstate_init_avx2(pst); + offset = 0; + pst, _, st, _ = A32::__pabsorb_array_avx2(pst, 0, st, seed, offset, 32, R168, UNFINISHED); + offset = 0; + pst, _, st, _ = A2::__pabsorb_array_avx2(pst, 32, st, pos, offset, 2, R168, SHAKE); + + return st; +} + +fn _shake128x4_absorb_A32_A2 +( reg mut ptr u256[25] st +, reg const ptr u8[32] seed +, reg const ptr u8[4*2] pos +) -> reg ptr u256[25] +{ inline int AT; + reg u64 offset; + st = __state_init_avx2x4(st); + offset = 0; + st, AT, _ = A32::__absorb_bcast_array_avx2x4(st, 0, seed, offset, 32, R168, UNFINISHED); + offset = 0; + st, _, _ = A2::__absorb_array_avx2x4(st, AT, pos[0:2], pos[2:2], pos[4:2], pos[6:2], offset, 2, R168, SHAKE); + + return st; +} + +fn _shake128_squeeze3blocks +( reg mut ptr u8[ABUFLEN::ASIZE] buf +, reg u256[7] st +) -> reg ptr u8[ABUFLEN::ASIZE] +{ + reg u64 offset; + st = _keccakf1600_avx2(st); + offset = 0; + buf, offset = ABUFLEN::__dumpstate_array_avx2(buf, offset, R168, st); + st = _keccakf1600_avx2(st); + buf, offset = ABUFLEN::__dumpstate_array_avx2(buf, offset, R168, st); + st = _keccakf1600_avx2(st); + buf, offset = ABUFLEN::__dumpstate_array_avx2(buf, offset, 200, st); + return buf; +} + +fn _shake128_next_state +( reg mut ptr u8[ABUFLEN::ASIZE] buf +) -> reg ptr u8[ABUFLEN::ASIZE] /* buf */ +{ + reg u256[7] st; + reg ptr u64[25] pst; + reg u64 offset; + pst = buf[u64 2*(168/8) : 25]; + st = __state_from_pstate_avx2(pst); + st = _keccakf1600_avx2(st); + offset = 2*168; + buf, _ = ABUFLEN::__dumpstate_array_avx2(buf, offset, 200, st); + return buf; +} + +fn _shake128x4_squeeze3blocks +( reg mut ptr u256[25] st +, reg mut ptr u8[4*ABUFLEN::ASIZE] buf +) -> reg ptr u256[25] /* st */ + , reg ptr u8[4*ABUFLEN::ASIZE] /* buf */ +{ + reg ptr u8[ABUFLEN::ASIZE] buf0 buf1 buf2 buf3; + reg u64 offset; + buf0 = buf[0*ABUFLEN::ASIZE : ABUFLEN::ASIZE]; + buf1 = buf[1*ABUFLEN::ASIZE : ABUFLEN::ASIZE]; + buf2 = buf[2*ABUFLEN::ASIZE : ABUFLEN::ASIZE]; + buf3 = buf[3*ABUFLEN::ASIZE : ABUFLEN::ASIZE]; + offset = 0; + st = _keccakf1600_avx2x4(st); + buf0, buf1, buf2, buf3, offset + = ABUFLEN::__dumpstate_array_avx2x4(buf0, buf1, buf2, buf3, offset, R168, st); + st = _keccakf1600_avx2x4(st); + buf0, buf1, buf2, buf3, offset + = ABUFLEN::__dumpstate_array_avx2x4(buf0, buf1, buf2, buf3, offset, R168, st); + 
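+ // third block: dump all 200 state bytes instead of just the 168-byte rate,
+ // so the tail of each 536-byte buffer also keeps the lane state needed to
+ // resume squeezing later (cf. _shake128_next_state)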
st = _keccakf1600_avx2x4(st); + buf0, buf1, buf2, buf3, offset + = ABUFLEN::__dumpstate_array_avx2x4(buf0, buf1, buf2, buf3, offset, 200, st); + buf[0*ABUFLEN::ASIZE : ABUFLEN::ASIZE] = buf0; + buf[1*ABUFLEN::ASIZE : ABUFLEN::ASIZE] = buf1; + buf[2*ABUFLEN::ASIZE : ABUFLEN::ASIZE] = buf2; + buf[3*ABUFLEN::ASIZE : ABUFLEN::ASIZE] = buf3; + + return st, buf; +} diff --git a/code/jasmin/mlkem_avx2/mlkem_keccak_avx2_TRANSITION.jinc b/code/jasmin/mlkem_avx2/mlkem_keccak_avx2_TRANSITION.jinc new file mode 100644 index 00000000..66d7d7f5 --- /dev/null +++ b/code/jasmin/mlkem_avx2/mlkem_keccak_avx2_TRANSITION.jinc @@ -0,0 +1,175 @@ + +namespace OLD_KECCAK { +require "keccak_OLD/fips202.jinc" +require "keccak_OLD/fips202_4x.jinc" + +inline fn _sha3_256A_M1184 +( #spill_to_mmx reg mut ptr u8[32] out +, #spill_to_mmx reg u64 in +) -> reg ptr u8[32] +{ reg u64 inlen; + inlen = 1184; + out = _isha3_256(out, in, inlen); + return out; +} + +inline fn _shake256_M32__M32_M1088 +( reg u64 out +, reg u64 in0 in1 // 32+MLKEM_INDCPA_CIPHERTEXTBYTES +) +{ _shake256_1120_32(out, in0, in1); } + +inline fn _shake256_A128__A32_A1 +( reg mut ptr u8[128] out +, reg const ptr u8[32] seed +, reg const ptr u8[1] nonce +) -> reg ptr u8[128] +{ reg u256 t256; + reg u8 t8; + stack u8[33] in_s; + reg ptr u8[33] in; + in = in_s; + t256 = seed[u256 0]; + in[u256 0] = t256; + t8 = nonce[0]; + in[32] = t8; + out = _shake256_128_33(out, in); + return out; +} + +inline fn _shake256x4_A128__A32_A1 +( reg mut ptr u8[128] out0 out1 out2 out3 +, reg const ptr u8[32] seed +, reg const ptr u8[4] nonce +) -> reg ptr u8[128] /* out0 */ + , reg ptr u8[128] /* out1 */ + , reg ptr u8[128] /* out2 */ + , reg ptr u8[128] /* out3 */ +{ reg u256 t256; + reg u8 t8; + stack u8[33] in0 in1 in2 in3; + stack u256[25] st_s; + reg ptr u256[25] st; + st = st_s; + t256 = seed[u256 0]; + in0[u256 0] = t256; + t8 = nonce[0]; + in0[32] = t8; + t8 = nonce[1]; + in1[u256 0] = t256; + in1[32] = t8; + t8 = nonce[2]; + in2[u256 0] = t256; + in2[32] = t8; + t8 = nonce[3]; + in3[u256 0] = t256; + in3[32] = t8; + st = _shake256_absorb4x_33(st, in0, in1, in2, in3); + st, out0, out1, out2, out3 = __shake256_squeezeblock4xTRANSITION(st, out0, out1, out2, out3); + return out0, out1, out2, out3; +} + +inline fn _sha3_256A_A32 +( #spill_to_mmx reg mut ptr u8[32] out +, reg const ptr u8[32] in +) -> reg ptr u8[32] +{ out = _isha3_256_32(out, in); return out; } + +inline fn _sha3_512A_A64 +( reg mut ptr u8[64] out +, reg const ptr u8[64] in +) -> reg ptr u8[64] +{ out = _sha3_512_64(out, in); return out; } + +inline fn _sha3_512A_A32 +( reg mut ptr u8[64] out +, reg const ptr u8[32] in +) -> reg ptr u8[64] +{ out = _sha3_512_32(out, in); return out; } + +} // OLD_KECCAK + +namespace NEW_KECCAK { +require "mlkem_keccak_avx2.jinc" +} + +inline fn _sha3_256A_M1184 +( #spill_to_mmx reg mut ptr u8[32] out +, #spill_to_mmx reg u64 in +) -> reg ptr u8[32] +{ out = NEW_KECCAK::_sha3_256A_M1184(out, in); return out; } + + +inline fn _shake256_M32__M32_M1088 +( reg u64 out +, reg u64 in0 in1 // 32+MLKEM_INDCPA_CIPHERTEXTBYTES +) +{ NEW_KECCAK::_shake256_M32__M32_M1088(out, in0, in1); } + +inline fn _shake256_A128__A32_A1 +( reg mut ptr u8[128] out +, reg const ptr u8[32] seed +, reg const ptr u8[1] nonce +) -> reg ptr u8[128] +{ out = NEW_KECCAK::_shake256_A128__A32_A1(out, seed, nonce); return out; } + +inline fn _shake256x4_A128__A32_A1 +( reg mut ptr u8[128] out0 out1 out2 out3 +, reg const ptr u8[32] seed +, reg const ptr u8[4] nonces +) -> reg ptr u8[128] /* out0 */ + , reg 
ptr u8[128] /* out1 */ + , reg ptr u8[128] /* out2 */ + , reg ptr u8[128] /* out3 */ +{ out0, out1, out2, out3 = NEW_KECCAK::_shake256x4_A128__A32_A1(out0, out1, out2, out3, seed, nonces); return out0, out1, out2, out3; } + +inline fn _sha3_256A_A32 +( #spill_to_mmx reg mut ptr u8[32] out +, reg const ptr u8[32] in +) -> reg ptr u8[32] +{ out = NEW_KECCAK::_sha3_256A_A32(out, in); return out; } + +inline fn _sha3_512A_A64 +( reg mut ptr u8[64] out +, reg const ptr u8[64] in +) -> reg ptr u8[64] +{ out = NEW_KECCAK::_sha3_512A_A64(out, in); return out; } + +inline fn _sha3_512A_A32 +( reg mut ptr u8[64] out +, reg const ptr u8[32] in +) -> reg ptr u8[64] +{ out = NEW_KECCAK::_sha3_512A_A32(out, in); return out; } + + +// Only available on the new version!!! +inline fn _shake128_absorb_A32_A2 +( reg const ptr u8[32] seed +, reg const ptr u8[2] pos +) -> reg u256[7] +{ reg u256[7] st; st = NEW_KECCAK::_shake128_absorb_A32_A2(seed, pos); return st; } + +inline fn _shake128x4_absorb_A32_A2 +( reg mut ptr u256[25] st +, reg const ptr u8[32] seed +, reg const ptr u8[8] pos +) -> reg ptr u256[25] +{ st = NEW_KECCAK::_shake128x4_absorb_A32_A2(st, seed, pos); return st; } + +inline fn _shake128_squeeze3blocks +( reg mut ptr u8[536] buf +, reg u256[7] st +) -> reg ptr u8[536] +{ buf = NEW_KECCAK::_shake128_squeeze3blocks( buf, st); return buf; } + +inline fn _shake128_next_state +( reg mut ptr u8[536] buf +) -> reg ptr u8[536] /* buf */ +{ buf = NEW_KECCAK::_shake128_next_state(buf); return buf; } + +inline fn _shake128x4_squeeze3blocks +( reg mut ptr u256[25] st +, reg mut ptr u8[4*536] buf +) -> reg ptr u256[25] /* st */ + , reg ptr u8[4*536] /* buf */ +{ st, buf = NEW_KECCAK::_shake128x4_squeeze3blocks(st, buf); return st, buf; } diff --git a/code/jasmin/mlkem_avx2/poly.jinc b/code/jasmin/mlkem_avx2/poly.jinc index 6e902b8c..6a528888 100644 --- a/code/jasmin/mlkem_avx2/poly.jinc +++ b/code/jasmin/mlkem_avx2/poly.jinc @@ -2,8 +2,13 @@ require "params.jinc" require "shuffle.jinc" require "consts.jinc" require "reduce.jinc" + +require "mlkem_keccak_avx2_TRANSITION.jinc" +require "keccak_OLD/fips202_common.jinc" +/* replaced by "mlkem_keccak_avx2"... 
require "fips202.jinc" require "fips202_4x.jinc" +*/ fn _poly_add2(reg ptr u16[MLKEM_N] rp bp) -> stack u16[MLKEM_N] { @@ -764,6 +769,7 @@ fn _poly_getnoise(reg ptr u16[MLKEM_N] rp, reg ptr u8[MLKEM_SYMBYTES] seed, reg } */ +/* OLD_KECCAK inline fn __shake256_squeezenblocks4x(reg ptr u256[25] state, reg ptr u8[NOISE_NBLOCKS * SHAKE256_RATE] buf0 buf1 buf2 buf3) -> reg ptr u256[25], reg ptr u8[NOISE_NBLOCKS*SHAKE256_RATE], reg ptr u8[NOISE_NBLOCKS*SHAKE256_RATE], reg ptr u8[NOISE_NBLOCKS*SHAKE256_RATE], reg ptr u8[NOISE_NBLOCKS*SHAKE256_RATE] { @@ -808,7 +814,47 @@ fn _poly_getnoise_eta1_4x(reg ptr u16[MLKEM_N] r0 r1 r2 r3, reg ptr u8[MLKEM_SYM return r0, r1, r2, r3; } +*/ + +#[returnaddress="stack"] +fn _poly_getnoise_eta1_4x +( reg ptr u16[MLKEM_N] r0 r1 r2 r3 +, reg ptr u8[MLKEM_SYMBYTES] seed +, reg u8 nonce +) -> reg ptr u16[MLKEM_N] + , reg ptr u16[MLKEM_N] + , reg ptr u16[MLKEM_N] + , reg ptr u16[MLKEM_N] +{ + stack u8[128] buf0_s buf1_s buf2_s buf3_s; + stack u8[4] nonces; + reg ptr u8[128] buf0, buf1, buf2, buf3; + + buf0 = buf0_s; buf1 = buf1_s; buf2 = buf2_s; buf3 = buf3_s; + +() = #spill(r0,r1,r2,r3); + + nonces[0] = nonce; + nonce += 1; + nonces[1] = nonce; + nonce += 1; + nonces[2] = nonce; + nonce += 1; + nonces[3] = nonce; + + + buf0, buf1, buf2, buf3 = _shake256x4_A128__A32_A1(buf0, buf1, buf2, buf3, seed, nonces); + +() = #unspill(r0,r1,r2,r3); + r0 = __poly_cbd_eta1(r0, buf0); + r1 = __poly_cbd_eta1(r1, buf1); + r2 = __poly_cbd_eta1(r2, buf2); + r3 = __poly_cbd_eta1(r3, buf3); + + return r0, r1, r2, r3; +} +/* OLD_KECCAK #[returnaddress="stack"] fn _poly_getnoise_eta1122_4x(reg ptr u16[MLKEM_N] r0 r1 r2 r3, reg ptr u8[MLKEM_SYMBYTES] seed, reg u8 nonce) -> reg ptr u16[MLKEM_N], reg ptr u16[MLKEM_N], reg ptr u16[MLKEM_N], reg ptr u16[MLKEM_N] { @@ -840,6 +886,41 @@ fn _poly_getnoise_eta1122_4x(reg ptr u16[MLKEM_N] r0 r1 r2 r3, reg ptr u8[MLKEM_ return r0, r1, r2, r3; } +*/ + +#[returnaddress="stack"] +fn _poly_getnoise_eta1122_4x +( reg ptr u16[MLKEM_N] r0 r1 r2 r3 +, reg ptr u8[MLKEM_SYMBYTES] seed +, reg u8 nonce +) -> reg ptr u16[MLKEM_N] + , reg ptr u16[MLKEM_N] + , reg ptr u16[MLKEM_N] + , reg ptr u16[MLKEM_N] +{ + stack u8[128] buf0_s buf1_s buf2_s buf3_s; + stack u8[4] nonces; + reg ptr u8[128] buf0, buf1, buf2, buf3; + + buf0 = buf0_s; buf1 = buf1_s; buf2 = buf2_s; buf3 = buf3_s; + + nonces[0] = nonce; + nonce += 1; + nonces[1] = nonce; + nonce += 1; + nonces[2] = nonce; + nonce += 1; + nonces[3] = nonce; + + buf0, buf1, buf2, buf3 = _shake256x4_A128__A32_A1(buf0, buf1, buf2, buf3, seed, nonces); + + r0 = __poly_cbd_eta1(r0, buf0); + r1 = __poly_cbd_eta1(r1, buf1); + r2 = __poly_cbd_eta2(r2, buf2); + r3 = __poly_cbd_eta2(r3, buf3); + + return r0, r1, r2, r3; +} inline diff --git a/code/jasmin/mlkem_avx2/test/test_kem.c b/code/jasmin/mlkem_avx2/test/test_kem.c index 7fbca06b..524fd67d 100644 --- a/code/jasmin/mlkem_avx2/test/test_kem.c +++ b/code/jasmin/mlkem_avx2/test/test_kem.c @@ -25,7 +25,7 @@ int main(void) fclose(urandom); /* TEST KEYPAIR */ - jade_kem_mlkem_mlkem768_amd64_avx2v_keypair_derand(pk1, sk1, randomness0); + jade_kem_mlkem_mlkem768_amd64_avx2_keypair_derand(pk1, sk1, randomness0); crypto_kem_keypair(pk0, sk0, randomness0); for(int i=0;i