Document bounds and remove redundant reductions (#208)
* Remove redundant `poly_reduce()` calls

The typical flow of arithmetic data is as follows:
- Forward NTT
- Basemul
- One of:
  * Remove Montgomery factor + invNTT
  * Remove Montgomery factor
- Reduction
- Serialization

At present, `poly_reduce()` features in multiple places in this flow,
serving multiple purposes:

First, it is used after the fwd NTT and after the base mul. The purpose
of those `poly_reduce()` calls is to merely keep the data small, with
details of canonicity and sign being irrelevant.

Second, `poly_reduce()` is used ahead of serialization:
Here, we do need _unsigned_ canonical representatives. The C and AArch64
versions of `poly_reduce()` perform a _signed_ canonical Barrett reduction,
forcing the serialization functions to complete the normalization through a
conditional addition. The AVX2 code uses an _unsigned_ 'almost-canonical'
`poly_reduce()`, giving representatives in [0,q] (inclusive upper bound!)
which are subject to conditional _subtraction_ during serialization.

In a nutshell, conceptually we'd like an unsigned canonical reduction
ahead of serialization, and non-strict (faster) reduction everywhere else.
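
The two normalization styles described above can be sketched in C as follows. This is a minimal illustration under stated assumptions, not the repository's code: `reduce_signed()`, `to_unsigned_canonical()` and `cond_sub_q()` are hypothetical helpers, `reduce_signed()` uses the C remainder operator as a stand-in for a real Barrett reduction, and the branches stand in for the branch-free arithmetic an actual implementation would use.

```c
#include <assert.h>
#include <stdint.h>

#define Q 3329 /* ML-KEM modulus */

/* Signed canonical representative in [-(Q-1)/2, (Q-1)/2].
 * Hypothetical stand-in for a signed Barrett reduction. */
static int16_t reduce_signed(int16_t a) {
  int16_t r = (int16_t)(a % Q); /* C remainder: r lies in (-Q, Q) */
  if (r > (Q - 1) / 2) r = (int16_t)(r - Q);
  if (r < -(Q - 1) / 2) r = (int16_t)(r + Q);
  return r;
}

/* Conditional addition: signed canonical -> unsigned canonical [0, Q). */
static int16_t to_unsigned_canonical(int16_t r) {
  assert(r >= -(Q - 1) / 2 && r <= (Q - 1) / 2);
  if (r < 0) r = (int16_t)(r + Q);
  return r;
}

/* Conditional subtraction: 'almost-canonical' [0, Q] -> [0, Q). */
static int16_t cond_sub_q(int16_t r) {
  assert(r >= 0 && r <= Q);
  if (r >= Q) r = (int16_t)(r - Q);
  return r;
}
```

Either path ends in [0, Q); the difference is only which single correction step the serialization code has to perform.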

Looking closer, the uses of `poly_reduce()` after the forward NTT and after
the basemul appear unnecessary.

Reduction after basemul:
- Basemul output fed into `poly_tomont()` is reduced through the Montgomery
  multiplication that's done in `poly_tomont()`.
- Basemul output fed into invNTT is potentially problematic if not reduced:
  The invNTT does currently _not_ reduce its input. _However_, the invNTT
  does need a scaling by `mont^2/128`. This scaling currently happens
  at the _end_ of the invNTT. Moving it to the _start_ of the invNTT instead
  reduces the input in the same way as `poly_tomont()` does.
  This change affects both the C reference NTT and the AArch64 NTT.

Reduction after fwd NTT: A reduction after the forward NTT is not needed
since base multiplication does not overflow provided _one_ factor is bounded
by q, which is always the case in ML-KEM.
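
The bound behind this argument can be illustrated for a single Montgomery multiplication. The sketch below is an assumption-laden illustration rather than the code in this commit: the names `montgomery_reduce()` and `fqmul()` follow the Kyber reference convention, but the body is rewritten with unsigned arithmetic and an exact division so that it relies only on fully defined C behaviour. It covers one product only, not the accumulation performed by the full base multiplication.

```c
#include <assert.h>
#include <stdint.h>

#define Q 3329
#define QINV 62209u /* Q^-1 mod 2^16 */

/* Montgomery reduction: returns r with r * 2^16 = a (mod Q). */
static int16_t montgomery_reduce(int32_t a) {
  /* The bound |r| < Q derived below needs |a| < 2^15 * Q. */
  assert(a > -(32768 * Q) && a < 32768 * Q);
  /* u = a * Q^-1 mod 2^16, computed with well-defined unsigned wraparound */
  uint16_t u = (uint16_t)((uint32_t)a * QINV);
  /* Signed representative t of u in [-2^15, 2^15) */
  int32_t t = (u >= 32768u) ? (int32_t)u - 65536 : (int32_t)u;
  /* t * Q = a (mod 2^16), so a - t*Q is an exact multiple of 2^16, and
   * |a - t*Q| < 2^15*Q + 2^15*Q = 2^16*Q gives |r| < Q. */
  return (int16_t)((a - t * Q) / 65536);
}

/* If |a| < Q, then |a * b| < 2^15 * Q for every int16_t b, so the
 * result is bounded by Q in absolute value without any prior reduction. */
static int16_t fqmul(int16_t a, int16_t b) {
  return montgomery_reduce((int32_t)a * b);
}
```

With one factor below q in absolute value, the other factor may be any `int16_t` value and the product still satisfies the precondition of the reduction.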

Signed-off-by: Hanno Becker <[email protected]>

* Document and runtime-check bounds for arithmetic data

This commit documents bounds on arithmetic data as it flows
through common operations like NTT, invNTT and basemul.

Moreover, it introduces debug macros like POLY_BOUND or
POLYVEC_BOUND which can be used in debug builds to check
the documented bounds at runtime. At a later point, those
assertions should be proved by converting them to CBMC
assertions.

Finally, the bounds show that canonicity of the result of
the NTT is not needed, and this commit removes the Barrett
reduction at the end of the AArch64 ASM NTT.

Signed-off-by: Hanno Becker <[email protected]>

* Reduce pk & sk after unpacking

It's not standards-required for the unpacked sk to be reduced
(its coefficients can be up to 2^12 in size), but our bounds
reasoning assumes it to be.

For pk, it must be checked at the top-level that the byte stream
unpacks to coefficients in bound, which however has not yet been
implemented. Until that's done, reduce the pk after unpacking, so
that the lower level routines are only called with canonical pk.
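
As a rough sketch of what such a top-level check could look like, the following C code decodes 12-bit coefficients and rejects any value that is not below q. The helpers `decode12()` and `modulus_check()` are hypothetical names introduced for illustration; the commit itself does not contain this check and reduces the pk instead.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#define Q 3329

/* Decode two 12-bit coefficients from three bytes (little-endian packing,
 * as used by the ML-KEM byte encoding). */
static void decode12(uint16_t out[2], const uint8_t in[3]) {
  out[0] = (uint16_t)((in[0] | ((uint16_t)in[1] << 8)) & 0xFFF);
  out[1] = (uint16_t)(((in[1] >> 4) | ((uint16_t)in[2] << 4)) & 0xFFF);
}

/* Hypothetical modulus check: returns 1 if every coefficient encoded in
 * buf is already in [0, Q), and 0 otherwise. The public key is public
 * data, so this sketch makes no attempt to be constant-time. */
static int modulus_check(const uint8_t *buf, size_t len) {
  size_t i;
  assert(len % 3 == 0);
  for (i = 0; i < len; i += 3) {
    uint16_t c[2];
    decode12(c, buf + i);
    if (c[0] >= Q || c[1] >= Q) return 0;
  }
  return 1;
}
```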

Signed-off-by: Hanno Becker <[email protected]>

* Add Python script confirming Barrett/Montgomery relation and bound

Signed-off-by: Hanno Becker <[email protected]>

* Remove TODO in AArch64 intt_clean.S

Signed-off-by: Hanno Becker <[email protected]>

* Add bound for fwd NTT and static non-overflow assertion in keygen

Signed-off-by: Hanno Becker <[email protected]>

* Fix formatting

Signed-off-by: Hanno Becker <[email protected]>

* Address further review feedback

Signed-off-by: Hanno Becker <[email protected]>

* Introduce single contractual bound for [inv]NTT output

Reasoning about safety of the C 'frontend' should not depend
on details of the arithmetic backend (ref/native). We thus
introduce single bounds NTT_BOUND and INVNTT_BOUND on the
absolute values of the output of the [inv]NTT that any
implementation needs to satisfy.

For every specific implementation, we still define and check
(in debug builds) for tighter bounds, plus add a static assertion
that the implementation-specific bound is smaller than the
contractual bound.
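
The relationship between the two kinds of bound can be sketched as below. The numeric values and the name `NTT_BOUND_EXAMPLE_BACKEND` are illustrative assumptions, not the values defined in this commit, and the `STATIC_ASSERT` here is a simplified array-size variant rather than the bit-field macro added in `debug.h`.

```c
#include <stdint.h>

#define KYBER_Q 3329

/* C99-compatible static assertion: a negative array size fails to compile. */
#define STATIC_ASSERT(cond, name) \
  typedef char static_assertion_##name[(cond) ? 1 : -1]

/* Contractual bound every backend must meet (illustrative value). */
#define NTT_BOUND (8 * KYBER_Q)
/* Tighter bound of one hypothetical backend (illustrative value). */
#define NTT_BOUND_EXAMPLE_BACKEND (6 * KYBER_Q)

/* The backend-specific bound must not exceed the contract ... */
STATIC_ASSERT(NTT_BOUND_EXAMPLE_BACKEND <= NTT_BOUND, backend_meets_contract);
/* ... and the frontend reasons about overflow using the contract only. */
STATIC_ASSERT(NTT_BOUND + KYBER_Q < INT16_MAX, frontend_add_cannot_overflow);
```

The frontend assertion never mentions a backend, so swapping the arithmetic backend cannot invalidate it as long as the first assertion still holds.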

Signed-off-by: Hanno Becker <[email protected]>

* Rewrite bounds-checking macros to work with poly and poly_mulcache

Signed-off-by: Hanno Becker <[email protected]>

* Fix typo in ntt.c

Signed-off-by: Hanno Becker <[email protected]>

* Document output bounds for poly_mulcache_compute()

Signed-off-by: Hanno Becker <[email protected]>

* Document+Check input bound for polyvec_basemul_acc_montgomery_cached

Signed-off-by: Hanno Becker <[email protected]>

* Add input bounds check for NTT

Signed-off-by: Hanno Becker <[email protected]>

* Make wording of bounds estimates more accessible

Signed-off-by: Hanno Becker <[email protected]>

* Document implementation-defined C behaviour in montgomery_reduce()

See pq-crystals/kyber#77
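
For context, and assuming the behaviour in question is the narrowing conversion of an out-of-range value to `int16_t` (which C leaves implementation-defined), a fully defined alternative can be sketched as follows. `wrap_to_int16()` is a hypothetical helper name, not a function from this commit.

```c
#include <stdint.h>

/* Reduce x modulo 2^16 into [-2^15, 2^15) using only fully defined
 * conversions. A plain (int16_t)x is implementation-defined whenever
 * x does not fit into int16_t. */
static int16_t wrap_to_int16(int32_t x) {
  uint16_t u = (uint16_t)(uint32_t)x; /* unsigned conversion is modular */
  return (u < 0x8000u) ? (int16_t)u : (int16_t)((int32_t)u - 65536);
}
```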

Signed-off-by: Hanno Becker <[email protected]>

* Run functional CI in debugging mode

Signed-off-by: Hanno Becker <[email protected]>

---------

Signed-off-by: Hanno Becker <[email protected]>
hanno-becker authored Oct 11, 2024
1 parent c6ec86b commit 33773e3
Showing 22 changed files with 2,883 additions and 2,316 deletions.
10 changes: 9 additions & 1 deletion .github/workflows/ci.yml
@@ -42,10 +42,18 @@ jobs:
runs-on: ${{ matrix.target.runner }}
steps:
- uses: actions/checkout@v4
- name: native tests
- name: native build
uses: ./.github/actions/multi-functest
with:
compile_mode: native
func: false
nistkat: false
kat: false
- name: native tests (+debug)
uses: ./.github/actions/multi-functest
with:
compile_mode: native
cflags: "-DMLKEM_DEBUG"
- name: cross tests (opt only)
if: ${{ matrix.target.runner == 'pqcp-arm64' && (success() || failure()) }}
uses: ./.github/actions/multi-functest
2 changes: 1 addition & 1 deletion mk/schemes.mk
@@ -1,5 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
SOURCES = $(wildcard mlkem/*.c)
SOURCES = $(wildcard mlkem/*.c) $(wildcard mlkem/debug/*.c)
ifeq ($(OPT),1)
SOURCES += $(wildcard mlkem/native/aarch64/*.[csS]) $(wildcard mlkem/native/x86_64/*.[csS])
CPPFLAGS += -DMLKEM_USE_NATIVE
39 changes: 39 additions & 0 deletions mlkem/debug/debug.c
@@ -0,0 +1,39 @@
// SPDX-License-Identifier: Apache-2.0
#include "debug.h"

#if defined(MLKEM_DEBUG)

static char debug_buf[256];

void mlkem_debug_check_bounds(const char *file, int line,
const char *description, const int16_t *ptr,
unsigned len, int16_t lower_bound_inclusive,
int16_t upper_bound_inclusive) {
int err = 0;
unsigned i;
for (i = 0; i < len; i++) {
int16_t val = ptr[i];
if (val < lower_bound_inclusive || val > upper_bound_inclusive) {
snprintf(debug_buf, sizeof(debug_buf),
"%s, index %u, value %d out of bounds [%d,%d]", description, i,
(int)val, (int)lower_bound_inclusive,
(int)upper_bound_inclusive);
mlkem_debug_print_error(file, line, debug_buf);
err = 1;
}
}

if (err == 1)
exit(1);
}

void mlkem_debug_print_error(const char *file, int line, const char *msg) {
fprintf(stderr, "[ERROR:%s:%04d] %s\n", file, line, msg);
fflush(stderr);
}

#else /* MLKEM_DEBUG */

int empty_cu_debug;

#endif /* MLKEM_DEBUG */
105 changes: 105 additions & 0 deletions mlkem/debug/debug.h
@@ -0,0 +1,105 @@
// SPDX-License-Identifier: Apache-2.0
#ifndef MLKEM_DEBUG_H
#define MLKEM_DEBUG_H

#if defined(MLKEM_DEBUG)
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/*************************************************
* Name: mlkem_debug_check_bounds
*
* Description: Check whether values in an array of int16_t
* are within specified bounds.
*
* Prints an error message to stderr and calls
* exit(1) if not.
*
* Arguments: - file: filename
* - line: line number
* - description: Textual description of check
* - ptr: Base of array to be checked
* - len: Number of int16_t in ptr
* - lower_bound_inclusive: Inclusive lower bound
* - upper_bound_inclusive: Inclusive upper bound
**************************************************/
void mlkem_debug_check_bounds(const char *file, int line,
const char *description, const int16_t *ptr,
unsigned len, int16_t lower_bound_inclusive,
int16_t upper_bound_inclusive);

/* Print error message to stderr alongside file and line information */
void mlkem_debug_print_error(const char *file, int line, const char *msg);

/* Check absolute bounds in array of int16_t's
* ptr: Base of array, expression of type int16_t*
* len: Number of int16_t in array
* abs_bound: Exclusive upper bound on absolute value to check
* msg: Message to print on failure */
#define BOUND(ptr, len, abs_bound, msg) \
do { \
mlkem_debug_check_bounds(__FILE__, __LINE__, (msg), (int16_t *)(ptr), \
(len), -((abs_bound)-1), ((abs_bound)-1)); \
} while (0)

/* Check absolute bounds on coefficients in polynomial or mulcache
* ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check
* abs_bound: Exclusive upper bound on absolute value to check
* msg: Message to print on failure */
#define POLY_BOUND_MSG(ptr, abs_bound, msg) \
BOUND((ptr)->coeffs, (sizeof((ptr)->coeffs) / sizeof(int16_t)), (abs_bound), \
msg)

/* Check absolute bounds on coefficients in polynomial
* ptr: poly* or poly_mulcache* pointer to polynomial (cache) to check
* abs_bound: Exclusive upper bound on absolute value to check */
#define POLY_BOUND(ptr, abs_bound) \
POLY_BOUND_MSG((ptr), (abs_bound), "poly bound for " #ptr)

/* Check absolute bounds on coefficients in vector of polynomials
* ptr: polyvec* or polyvec_mulcache* pointer to vector of polynomials to check
* abs_bound: Exclusive upper bound on absolute value to check */
#define POLYVEC_BOUND(ptr, abs_bound) \
do { \
for (unsigned _debug_polyvec_bound_idx = 0; \
_debug_polyvec_bound_idx < KYBER_K; _debug_polyvec_bound_idx++) \
POLY_BOUND_MSG(&(ptr)->vec[_debug_polyvec_bound_idx], (abs_bound), \
"polyvec bound for " #ptr ".vec[i]"); \
} while (0)

// Following AWS-LC to define a C99-compliant static assert
#define MLKEM_CONCAT(left, right) left##right
#define MLKEM_STATIC_ASSERT_DEFINE(cond, msg) \
typedef struct { \
unsigned int MLKEM_CONCAT(static_assertion_, msg) : (cond) ? 1 : -1; \
} MLKEM_CONCAT(static_assertion_, msg) __attribute__((unused));

#define MLKEM_STATIC_ASSERT_ADD_LINE0(cond, suffix) \
MLKEM_STATIC_ASSERT_DEFINE(cond, MLKEM_CONCAT(at_line_, suffix))
#define MLKEM_STATIC_ASSERT_ADD_LINE1(cond, line, suffix) \
MLKEM_STATIC_ASSERT_ADD_LINE0(cond, MLKEM_CONCAT(line, suffix))
#define MLKEM_STATIC_ASSERT_ADD_LINE2(cond, suffix) \
MLKEM_STATIC_ASSERT_ADD_LINE1(cond, __LINE__, suffix)
#define MLKEM_STATIC_ASSERT_ADD_ERROR(cond, suffix) \
MLKEM_STATIC_ASSERT_ADD_LINE2(cond, MLKEM_CONCAT(_error_is_, suffix))
#define STATIC_ASSERT(cond, error) MLKEM_STATIC_ASSERT_ADD_ERROR(cond, error)

#else /* MLKEM_DEBUG */

#define BOUND(...) \
do { \
} while (0)
#define POLY_BOUND(...) \
do { \
} while (0)
#define POLYVEC_BOUND(...) \
do { \
} while (0)
#define POLY_BOUND_MSG(...) \
do { \
} while (0)
#define STATIC_ASSERT(...)

#endif /* MLKEM_DEBUG */

#endif /* MLKEM_DEBUG_H */
26 changes: 26 additions & 0 deletions mlkem/indcpa.c
@@ -15,6 +15,7 @@
#include "symmetric.h"

#include "arith_native.h"
#include "debug/debug.h"

/*************************************************
* Name: pack_pk
@@ -29,6 +30,7 @@
**************************************************/
static void pack_pk(uint8_t r[KYBER_INDCPA_PUBLICKEYBYTES], polyvec *pk,
const uint8_t seed[KYBER_SYMBYTES]) {
POLYVEC_BOUND(pk, KYBER_Q);
polyvec_tobytes(r, pk);
memcpy(r + KYBER_POLYVECBYTES, seed, KYBER_SYMBYTES);
}
@@ -48,6 +50,11 @@ static void unpack_pk(polyvec *pk, uint8_t seed[KYBER_SYMBYTES],
const uint8_t packedpk[KYBER_INDCPA_PUBLICKEYBYTES]) {
polyvec_frombytes(pk, packedpk);
memcpy(seed, packedpk + KYBER_POLYVECBYTES, KYBER_SYMBYTES);

// TODO! pk must be subject to a "modulus check" at the top-level
// crypto_kem_enc_derand(). Once that's done, the reduction is no
// longer necessary here.
polyvec_reduce(pk);
}

/*************************************************
@@ -60,6 +67,7 @@ static void unpack_pk(polyvec *pk, uint8_t seed[KYBER_SYMBYTES],
*key)
**************************************************/
static void pack_sk(uint8_t r[KYBER_INDCPA_SECRETKEYBYTES], polyvec *sk) {
POLYVEC_BOUND(sk, KYBER_Q);
polyvec_tobytes(r, sk);
}

@@ -76,6 +84,7 @@ static void pack_sk(uint8_t r[KYBER_INDCPA_SECRETKEYBYTES], polyvec *sk) {
static void unpack_sk(polyvec *sk,
const uint8_t packedsk[KYBER_INDCPA_SECRETKEYBYTES]) {
polyvec_frombytes(sk, packedsk);
polyvec_reduce(sk);
}

/*************************************************
@@ -245,6 +254,9 @@ void gen_matrix(polyvec *a, const uint8_t seed[KYBER_SYMBYTES],
* - const uint8_t *coins: pointer to input randomness
* (of length KYBER_SYMBYTES bytes)
**************************************************/

// Check that the arithmetic in indcpa_keypair_derand() does not overflow
STATIC_ASSERT(NTT_BOUND + KYBER_Q < INT16_MAX, indcpa_keypair_bound_0)

void indcpa_keypair_derand(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES],
uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES],
const uint8_t coins[KYBER_SYMBYTES]) {
@@ -289,8 +301,10 @@ void indcpa_keypair_derand(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES],
poly_tomont(&pkpv.vec[i]);
}

// Arithmetic cannot overflow, see static assertion at the top
polyvec_add(&pkpv, &pkpv, &e);
polyvec_reduce(&pkpv);
polyvec_reduce(&skpv);

pack_sk(sk, &skpv);
pack_pk(pk, &pkpv, publicseed);
@@ -311,6 +325,12 @@ void indcpa_keypair_derand(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES],
* - const uint8_t *coins: pointer to input random coins used as
*seed (of length KYBER_SYMBYTES) to deterministically generate all randomness
**************************************************/

// Check that the arithmetic in indcpa_enc() does not overflow
STATIC_ASSERT(INVNTT_BOUND + KYBER_ETA1 < INT16_MAX, indcpa_enc_bound_0)
STATIC_ASSERT(INVNTT_BOUND + KYBER_ETA2 + KYBER_Q < INT16_MAX,
indcpa_enc_bound_1)

void indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES],
const uint8_t m[KYBER_INDCPA_MSGBYTES],
const uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES],
@@ -355,6 +375,7 @@ void indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES],
polyvec_invntt_tomont(&b);
poly_invntt_tomont(&v);

// Arithmetic cannot overflow, see static assertion at the top
polyvec_add(&b, &b, &ep);
poly_add(&v, &v, &epp);
poly_add(&v, &v, &k);
@@ -377,6 +398,10 @@ void indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES],
* - const uint8_t *sk: pointer to input secret key
* (of length KYBER_INDCPA_SECRETKEYBYTES)
**************************************************/

// Check that the arithmetic in indcpa_dec() does not overflow
STATIC_ASSERT(INVNTT_BOUND + KYBER_Q < INT16_MAX, indcpa_dec_bound_0)

void indcpa_dec(uint8_t m[KYBER_INDCPA_MSGBYTES],
const uint8_t c[KYBER_INDCPA_BYTES],
const uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES]) {
@@ -390,6 +415,7 @@ void indcpa_dec(uint8_t m[KYBER_INDCPA_MSGBYTES],
polyvec_basemul_acc_montgomery(&mp, &skpv, &b);
poly_invntt_tomont(&mp);

// Arithmetic cannot overflow, see static assertion at the top
poly_sub(&mp, &v, &mp);
poly_reduce(&mp);

5 changes: 2 additions & 3 deletions mlkem/native/aarch64/intt_123_45_67_twiddles.S
@@ -477,9 +477,8 @@ roots_l34:
.short 0
.short 0
roots_l012:
// layer 0 root modified to include ninv
.short 266 // originally: 1600
.short 2618 // originally: 15749
.short 1600
.short 15749
.short 40
.short 394
.short 749