diff --git a/tests/vectors_kem.c b/tests/vectors_kem.c index a3dfe99cf..eef566358 100644 --- a/tests/vectors_kem.c +++ b/tests/vectors_kem.c @@ -78,6 +78,20 @@ static void hexStringToByteArray(const char *hexString, uint8_t *byteArray) { } #ifdef OQS_ENABLE_KEM_ML_KEM +/* fetch value of 'K' from MlL-KEM version */ +uint8_t get_ml_kem_k(const char *method) { + if (0 == strcmp(method, OQS_KEM_alg_ml_kem_512)) { + return 2; + } else if (0 == strcmp(method, OQS_KEM_alg_ml_kem_768)) { + return 3; + } else if (0 == strcmp(method, OQS_KEM_alg_ml_kem_1024)) { + return 4; + } else { + return 0; // Default/error case + } +} + +/* sanity check for private/decaps key */ static inline bool sanityCheckSK(const uint8_t *sk, const char *method_name) { /* sanity checks */ if ((NULL == sk) || (NULL == method_name) || (false == is_ml_kem(method_name))) { @@ -88,7 +102,11 @@ static inline bool sanityCheckSK(const uint8_t *sk, const char *method_name) { uint8_t pkdig[SHA256_OP_LEN] = {0}; /* fetch the value of k according to the ML-KEM algorithm as per FIPS-203 K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */ - uint8_t K = (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) ? 2 : (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) ? 3 : 4; + uint8_t K = get_ml_kem_k(method_name); + if (0 == K) { + fprintf(stderr, "K value can be fetched only for ML-KEM !\n"); + return false; + } /* calcualte hash of the public key(len = 384k+32) stored in private key at offset of 384k */ OQS_SHA3_sha3_256(pkdig, sk + (ML_KEM_BLOCKSIZE * K), (ML_KEM_BLOCKSIZE * K) + 32); /* compare it with public key hash stored at 768k+32 offset */ @@ -98,6 +116,7 @@ static inline bool sanityCheckSK(const uint8_t *sk, const char *method_name) { return true; } +/* sanity check for public/encaps key */ static inline bool sanityCheckPK(const uint8_t *pk, size_t pkLen, const char *method_name) { /* sanity checks */ if ((NULL == pk) || (0 == pkLen) || (NULL == method_name) || (false == is_ml_kem(method_name))) { @@ -107,7 +126,11 @@ static inline bool sanityCheckPK(const uint8_t *pk, size_t pkLen, const char *me unsigned int i, j; /* fetch the value of k according to the ML-KEM algorithm as per FIPS-203 K = 2 for ML-KEM-512, K = 3 for ML-KEM-768 & K = 4 for ML-KEM-1024 */ - uint8_t K = (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_512)) ? 2 : (0 == strcmp(method_name, OQS_KEM_alg_ml_kem_768)) ? 3 : 4; + uint8_t K = get_ml_kem_k(method_name); + if (0 == K) { + fprintf(stderr, "K value can be fetched only for ML-KEM !\n"); + return false; + } /* buffer to hold decoded value. max value used, so same buffer could be used for ML-KEM versions encaps key is of length 384K bytes(384K*8 bits). Grouped into 12-bit values, the buffer requires (384*K*8)/12 = 256*K entries of 12 bits */ uint16_t buffd[ML_KEM_N * ML_KEM_K_MAX] = {0}; @@ -119,8 +142,8 @@ static inline bool sanityCheckPK(const uint8_t *pk, size_t pkLen, const char *me buff_dec = &buffd[i * ML_KEM_N]; const uint8_t *curr_pk = &pk[i * ML_KEM_BLOCKSIZE]; for (j = 0; j < ML_KEM_N / 2; j++) { - buff_dec[2 * j] = ((curr_pk[3 * j + 0] >> 0) | ((uint16_t)curr_pk[3 * j + 1] << 8)) & 0xFFF; - buff_dec[2 * j] = MOD_Q(buff_dec[2 * j]); + buff_dec[2 * j + 0] = ((curr_pk[3 * j + 0] >> 0) | ((uint16_t)curr_pk[3 * j + 1] << 8)) & 0xFFF; + buff_dec[2 * j + 0] = MOD_Q(buff_dec[2 * j]); buff_dec[2 * j + 1] = ((curr_pk[3 * j + 1] >> 4) | ((uint16_t)curr_pk[3 * j + 2] << 4)) & 0xFFF; buff_dec[2 * j + 1] = MOD_Q(buff_dec[2 * j + 1]); }