Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add checks for ML-KEM keys #2009

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,15 @@ target_link_libraries(vectors_sig PRIVATE ${TEST_DEPS})
add_executable(vectors_kem vectors_kem.c)
target_link_libraries(vectors_kem PRIVATE ${TEST_DEPS})

if(CMAKE_SYSTEM_NAME STREQUAL "Windows" AND BUILD_SHARED_LIBS)
# workaround for Windows .dll
if(MINGW OR MSYS OR CYGWIN OR CMAKE_CROSSCOMPILING)
target_link_options(vectors_kem PRIVATE -Wl,--allow-multiple-definition)
else()
target_link_options(vectors_kem PRIVATE "/FORCE:MULTIPLE")
endif()
endif()

# Enable Valgrind-based timing side-channel analysis for test_kem and test_sig
if(OQS_ENABLE_TEST_CONSTANT_TIME AND NOT OQS_DEBUG_BUILD)
message(WARNING "OQS_ENABLE_TEST_CONSTANT_TIME is incompatible with CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}.")
Expand Down
114 changes: 108 additions & 6 deletions tests/vectors_kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,34 @@
#include <sys/stat.h>

#include <oqs/oqs.h>

#include <oqs/sha3.h>
#include "system_info.c"

#ifdef OQS_ENABLE_KEM_ML_KEM
/* macros for sanity checks for encaps and decaps key */
#define ML_KEM_BLOCKSIZE 384
#define ML_KEM_K_MAX 4
#define ML_KEM_N 256
#define ML_KEM_1024_PK_SIZE 1568
SWilson4 marked this conversation as resolved.
Show resolved Hide resolved
#define ML_KEM_Q 3329
#define SHA256_OP_LEN 32
/* since x is 12 bits, max value could be 4095. the below macro uses this to implement a simple time constant mod 3329 */
#define MOD_Q(x) ((x) - ((x >= ML_KEM_Q) * ML_KEM_Q))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't believe we can count on comparison operators being constant time. I suggest doing Barrett reduction here instead, similar to the reference implementation. (That code computes a centred representation; we'd just need an additional addition.)

I realize that this is overkill for testing code, but there's a possibility we use this file as a guide for patching the algorithm source later.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct, added a simple mod function as this was a test file. will change it to proper mod function after few tests at my end.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @abhinav-thales, as per discussion in today's OQS meeting, we're OK with this being non–constant time, as long as the comment stating that it is constant time is removed. Feel free to go ahead with that approach if it's simpler for you. (In that case my preference would be to simply use the % operator so that the operation does not appear to be constant time.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @SWilson4 , thanks for the update.
But any particular reason, why non-constant time is ok ? having it time constant would be the ideal scenario IMO.
in the meantime, I have tested modQ using "Barrett reduction" and another approach using shift operators as well.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @SWilson4 , thanks for the update. But any particular reason, why non-constant time is ok ? having it time constant would be the ideal scenario IMO. in the meantime, I have tested modQ using "Barrett reduction" and another approach using shift operators as well.

I brought it up with the OQS team, and we decided that we're OK with a non–constant time function here because the code is limited to a test file, in the interest of not holding up the PR. If you prefer to submit a constant-time implementation, that's fine too :)

Copy link
Contributor Author

@abhinav-thales abhinav-thales Jan 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @SWilson4 Got it. i will try to submit the code this week

#endif //OQS_ENABLE_KEM_ML_KEM

struct {
const uint8_t *pos;
} prng_state = {
.pos = 0
};

/* MLKEM-specific functions */
static inline bool is_ml_kem(const char *method_name) {
return (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512))
|| (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768))
|| (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_1024));
}

static void fprintBstr(FILE *fp, const char *S, const uint8_t *A, size_t L) {
size_t i;
fprintf(fp, "%s", S);
Expand Down Expand Up @@ -58,13 +77,75 @@ static void hexStringToByteArray(const char *hexString, uint8_t *byteArray) {
}
}

/* ML_KEM-specific functions */
static inline bool is_ml_kem(const char *method_name) {
return (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512))
|| (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768))
|| (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_1024));
#ifdef OQS_ENABLE_KEM_ML_KEM
static inline bool sanityCheckSK(const uint8_t *sk, const char *method_name) {
/* sanity checks */
if ((NULL == sk) || (NULL == method_name) || (false == is_ml_kem(method_name))) {
fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or invalid method !\n", method_name);
return false;
}
/* buffer to hold public key hash */
uint8_t pkdig[SHA256_OP_LEN] = {0};
/* fetch the value of k according to the ML-KEM algorithm as per FIPS-203
K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */
uint8_t K = (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) ? 2 : (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) ? 3 : 4;
SWilson4 marked this conversation as resolved.
Show resolved Hide resolved
/* calcualte hash of the public key(len = 384k+32) stored in private key at offset of 384k */
OQS_SHA3_sha3_256(pkdig, sk + (ML_KEM_BLOCKSIZE * K), (ML_KEM_BLOCKSIZE * K) + 32);
/* compare it with public key hash stored at 768k+32 offset */
if (0 != memcmp(pkdig, sk + (ML_KEM_BLOCKSIZE * K * 2) + 32, SHA256_OP_LEN)) {
return false;
}
return true;
}

static inline bool sanityCheckPK(const uint8_t *pk, size_t pkLen, const char *method_name) {
/* sanity checks */
if ((NULL == pk) || (0 == pkLen) || (NULL == method_name) || (false == is_ml_kem(method_name))) {
fprintf(stderr, "[vectors_kem] %s ERROR: inputs NULL or zero or invalid method !\n", method_name);
return false;
}
unsigned int i, j;
/* fetch the value of k according to the ML-KEM algorithm as per FIPS-203
K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */
uint8_t K = (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) ? 2 : (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) ? 3 : 4;
SWilson4 marked this conversation as resolved.
Show resolved Hide resolved
/* buffer to hold decoded value. max value used, so same buffer could be used for ML-KEM versions
encaps key is of length 384K bytes(384K*8 bits). Grouped into 12-bit values, the buffer requires (384*K*8)/12 = 256*K entries of 12 bits */
uint16_t buffd[ML_KEM_N * ML_KEM_K_MAX] = {0};
/* buffer to hold encoded value */
uint8_t buffe[ML_KEM_1024_PK_SIZE] = {0};
uint16_t *buff_dec;
/* perform byte decoding as per Algo 6 of FIPS 203 */
for (i = 0; i < K; i++) {
buff_dec = &buffd[i * ML_KEM_N];
const uint8_t *curr_pk = &pk[i * ML_KEM_BLOCKSIZE];
for (j = 0; j < ML_KEM_N / 2; j++) {
buff_dec[2 * j] = ((curr_pk[3 * j + 0] >> 0) | ((uint16_t)curr_pk[3 * j + 1] << 8)) & 0xFFF;
buff_dec[2 * j] = MOD_Q(buff_dec[2 * j]);
SWilson4 marked this conversation as resolved.
Show resolved Hide resolved
buff_dec[2 * j + 1] = ((curr_pk[3 * j + 1] >> 4) | ((uint16_t)curr_pk[3 * j + 2] << 4)) & 0xFFF;
buff_dec[2 * j + 1] = MOD_Q(buff_dec[2 * j + 1]);
}
}
/* perform byte encoding as per Algo 5 of FIPS 203 */
for (i = 0; i < K; i++) {
uint16_t t0, t1;
buff_dec = &buffd[i * ML_KEM_N];
uint8_t *buff_enc = &buffe[i * ML_KEM_BLOCKSIZE];
for (j = 0; j < ML_KEM_N / 2; j++) {
t0 = buff_dec[2 * j];
t1 = buff_dec[2 * j + 1];
buff_enc[3 * j + 0] = (uint8_t)(t0 >> 0);
buff_enc[3 * j + 1] = (uint8_t)((t0 >> 8) | (t1 << 4));
buff_enc[3 * j + 2] = (uint8_t)(t1 >> 4);
}
}
/* compare the encoded value with original public key. discard value of `rho(32 bytes)` during comparision as its not encoded */
if (0 != memcmp(buffe, pk, pkLen - 32)) {
return false;
}
return true;
}
#endif //OQS_ENABLE_KEM_ML_KEM

static void MLKEM_randombytes_init(const uint8_t *entropy_input, const uint8_t *personalization_string) {
(void) personalization_string;
prng_state.pos = entropy_input;
Expand Down Expand Up @@ -134,6 +215,13 @@ static OQS_STATUS kem_kg_vector(const char *method_name,
fprintBstr(fh, "ek: ", public_key, kem->length_public_key);
fprintBstr(fh, "dk: ", secret_key, kem->length_secret_key);

#ifdef OQS_ENABLE_KEM_ML_KEM
if ((false == sanityCheckPK(public_key, kem->length_public_key, method_name)) || (false == sanityCheckSK(secret_key, method_name))) {
fprintf(stderr, "[vectors_kem] %s ERROR: generated public key or private key are corrupted !\n", method_name);
goto err;
}
#endif //OQS_ENABLE_KEM_ML_KEM

if (!memcmp(public_key, kg_pk, kem->length_public_key) && !memcmp(secret_key, kg_sk, kem->length_secret_key)) {
ret = OQS_SUCCESS;
} else {
Expand Down Expand Up @@ -208,6 +296,13 @@ static OQS_STATUS kem_vector_encdec_aft(const char *method_name,
goto err;
}

#ifdef OQS_ENABLE_KEM_ML_KEM
if (false == sanityCheckPK(encdec_pk, kem->length_public_key, method_name)) {
fprintf(stderr, "[vectors_kem] %s ERROR: passed encapsulation key is corrupted !\n", method_name);
goto err;
}
#endif //OQS_ENABLE_KEM_ML_KEM

rc = OQS_KEM_encaps(kem, ct_encaps, ss_encaps, encdec_pk);
if (rc != OQS_SUCCESS) {
fprintf(stderr, "[vectors_kem] %s ERROR: OQS_KEM_encaps failed!\n", method_name);
Expand Down Expand Up @@ -273,6 +368,13 @@ static OQS_STATUS kem_vector_encdec_val(const char *method_name,
goto err;
}

#ifdef OQS_ENABLE_KEM_ML_KEM
if (false == sanityCheckSK(encdec_sk, method_name)) {
fprintf(stderr, "[vectors_kem] %s ERROR: passed decapsulation key is corrupted !\n", method_name);
goto err;
}
#endif //OQS_ENABLE_KEM_ML_KEM

rc = OQS_KEM_decaps(kem, ss_decaps, encdec_c, encdec_sk);
if (rc != OQS_SUCCESS) {
fprintf(stderr, "[vectors_kem] %s ERROR: OQS_KEM_encaps failed!\n", method_name);
Expand Down
Loading