diff --git a/src/common/rand/rand_nist.c b/src/common/rand/rand_nist.c index 66475bd28b..e316c1e34c 100644 --- a/src/common/rand/rand_nist.c +++ b/src/common/rand/rand_nist.c @@ -19,6 +19,7 @@ You are solely responsible for determining the appropriateness of using and dist #include #include +#include #ifdef OQS_USE_OPENSSL #include @@ -127,22 +128,22 @@ void OQS_randombytes_nist_kat(unsigned char *x, size_t xlen) { DRBG_ctx.reseed_counter++; } -OQS_API void OQS_randombytes_nist_kat_get_state(void *out) { - AES256_CTR_DRBG_struct *out_state = (AES256_CTR_DRBG_struct *)out; - if (out_state != NULL) { - memcpy(out_state->Key, DRBG_ctx.Key, sizeof(DRBG_ctx.Key)); - memcpy(out_state->V, DRBG_ctx.V, sizeof(DRBG_ctx.V)); - out_state->reseed_counter = DRBG_ctx.reseed_counter; - } +void OQS_randombytes_nist_kat_get_state(void *out) { + AES256_CTR_DRBG_struct *out_state = (AES256_CTR_DRBG_struct *)out; + if (out_state != NULL) { + memcpy(out_state->Key, DRBG_ctx.Key, sizeof(DRBG_ctx.Key)); + memcpy(out_state->V, DRBG_ctx.V, sizeof(DRBG_ctx.V)); + out_state->reseed_counter = DRBG_ctx.reseed_counter; + } } -OQS_API void OQS_randombytes_nist_kat_set_state(const void *in) { - AES256_CTR_DRBG_struct *in_state = (AES256_CTR_DRBG_struct *)in; - if (in_state != NULL) { - memcpy(DRBG_ctx.Key, in_state->Key, sizeof(DRBG_ctx.Key)); - memcpy(DRBG_ctx.V, in_state->V, sizeof(DRBG_ctx.V)); - DRBG_ctx.reseed_counter = in_state->reseed_counter; - } +void OQS_randombytes_nist_kat_set_state(const void *in) { + AES256_CTR_DRBG_struct *in_state = (AES256_CTR_DRBG_struct *)in; + if (in_state != NULL) { + memcpy(DRBG_ctx.Key, in_state->Key, sizeof(DRBG_ctx.Key)); + memcpy(DRBG_ctx.V, in_state->V, sizeof(DRBG_ctx.V)); + DRBG_ctx.reseed_counter = in_state->reseed_counter; + } } static void AES256_CTR_DRBG_Update(unsigned char *provided_data, unsigned char *Key, unsigned char *V) { diff --git a/tests/kat_sig.c b/tests/kat_sig.c index f6654972e2..89e1f27e2b 100644 --- a/tests/kat_sig.c +++ b/tests/kat_sig.c @@ -243,13 +243,13 @@ OQS_STATUS sig_kat(const char *method_name, bool all) { size_t signature_len = 0; size_t signed_msg_len = 0; OQS_STATUS rc, ret = OQS_ERROR; - OQS_KAT_PRNG *prng = NULL; - int max_count; + OQS_KAT_PRNG *prng = NULL; + int max_count; - prng = OQS_KAT_PRNG_new(method_name); - if (prng == NULL) { - goto err; - } + prng = OQS_KAT_PRNG_new(method_name); + if (prng == NULL) { + goto err; + } sig = OQS_SIG_new(method_name); if (sig == NULL) { @@ -347,7 +347,7 @@ OQS_STATUS sig_kat(const char *method_name, bool all) { OQS_MEM_insecure_free(signature); OQS_MEM_insecure_free(msg); OQS_SIG_free(sig); - OQS_KAT_PRNG_free(prng); + OQS_KAT_PRNG_free(prng); return ret; } diff --git a/tests/test_helpers.c b/tests/test_helpers.c index 291fd87e55..03bf9e172b 100644 --- a/tests/test_helpers.c +++ b/tests/test_helpers.c @@ -1,9 +1,13 @@ +// SPDX-License-Identifier: MIT +#include +#include #include -#include "test_helpers.h" #include +#include // Internal NIST DRBG API +#include // Internal SHA3 API -#define HQC_PRNG_DOMAIN 1 +#include "test_helpers.h" /* HQC PRNG implementation */ @@ -12,15 +16,15 @@ static OQS_SHA3_shake256_inc_ctx hqc_prng_state = { NULL }; // Allocate the state. static void hqc_prng_new(void) { - OQS_SHA3_shake256_inc_init(&hqc_prng_state); + OQS_SHA3_shake256_inc_init(&hqc_prng_state); } // entropy_input must have length 48. // If personalization_string is non-null, its length must also be 48. static void hqc_prng_seed(const uint8_t *entropy_input, const uint8_t *personalization_string) { - uint8_t domain = HQC_PRNG_DOMAIN; - // reset state - OQS_SHA3_shake256_inc_ctx_reset(&hqc_prng_state); + uint8_t domain = 1; + // reset state + OQS_SHA3_shake256_inc_ctx_reset(&hqc_prng_state); OQS_SHA3_shake256_inc_absorb(&hqc_prng_state, entropy_input, 48); if (personalization_string != NULL) { OQS_SHA3_shake256_inc_absorb(&hqc_prng_state, personalization_string, 48); @@ -35,26 +39,25 @@ static void hqc_prng_randombytes(uint8_t *random_array, size_t bytes_to_read) { } static void hqc_prng_get_state(void *out) { - OQS_SHA3_shake256_inc_ctx_clone((OQS_SHA3_shake256_inc_ctx *)out, &hqc_prng_state); + OQS_SHA3_shake256_inc_ctx_clone((OQS_SHA3_shake256_inc_ctx *)out, &hqc_prng_state); } static void hqc_prng_set_state(const void *in) { - OQS_SHA3_shake256_inc_ctx_clone(&hqc_prng_state, (OQS_SHA3_shake256_inc_ctx *)in); + OQS_SHA3_shake256_inc_ctx_clone(&hqc_prng_state, (const OQS_SHA3_shake256_inc_ctx *)in); } -static void hqc_prng_free(void *saved_state) { - OQS_SHA3_shake256_inc_ctx *hqc_saved_state = (OQS_SHA3_shake256_inc_ctx *)saved_state; - if (hqc_saved_state != NULL) { - OQS_SHA3_shake256_inc_ctx_release(hqc_saved_state); - } - if (hqc_prng_state.ctx != NULL) { - OQS_SHA3_shake256_inc_ctx_release(&hqc_prng_state); - hqc_prng_state.ctx = NULL; - } +static void hqc_prng_free(OQS_KAT_PRNG_state *saved_state) { + if (saved_state != NULL) { + OQS_SHA3_shake256_inc_ctx_release(&saved_state->hqc_state); + } + if (hqc_prng_state.ctx != NULL) { + OQS_SHA3_shake256_inc_ctx_release(&hqc_prng_state); + hqc_prng_state.ctx = NULL; + } } /* Additional NIST DRBG implementation */ -static void nist_drbg_free(void *saved_state) {} +static void nist_drbg_free(OQS_KAT_PRNG_state *saved_state) {} /* Helpers for identifying algorithms */ static int is_mceliece(const char *method_name) { @@ -79,67 +82,66 @@ static int is_hqc(const char *method_name) { /* OQS_KAT_PRNG interface implementation */ OQS_KAT_PRNG *OQS_KAT_PRNG_new(const char *method_name) { - OQS_KAT_PRNG *prng = malloc(sizeof(OQS_KAT_PRNG)); - if (prng != NULL) { - prng->max_kats = is_mceliece(method_name) ? 10 : 100; - if (is_hqc(method_name)) { - // set randombytes function - OQS_randombytes_custom_algorithm(&hqc_prng_randombytes); - // reset the PRNG - hqc_prng_new(); - // initialize saved state - OQS_SHA3_shake256_inc_init(&prng->saved_state.hqc_state); - // TODO set callbacks - prng->seed = &hqc_prng_seed; - prng->get_state = &hqc_prng_get_state; - prng->set_state = &hqc_prng_set_state; - prng->free = &hqc_prng_free; - } else { - // set randombytes function - if (OQS_randombytes_switch_algorithm(OQS_RAND_alg_nist_kat) == OQS_SUCCESS) { - // TODO set callbacks - prng->seed = &OQS_randombytes_nist_kat_init_256bit; - prng->get_state = &OQS_randombytes_nist_kat_get_state; - prng->set_state = &OQS_randombytes_nist_kat_set_state; - prng->free = &nist_drbg_free; - } else { - OQS_MEM_insecure_free(prng); - prng = NULL; - } - } - } - return prng; + OQS_KAT_PRNG *prng = malloc(sizeof(OQS_KAT_PRNG)); + if (prng != NULL) { + prng->max_kats = is_mceliece(method_name) ? 10 : 100; + if (is_hqc(method_name)) { + // set randombytes function + OQS_randombytes_custom_algorithm(&hqc_prng_randombytes); + // reset the PRNG + hqc_prng_new(); + // initialize saved state + OQS_SHA3_shake256_inc_init(&prng->saved_state.hqc_state); + // TODO set callbacks + prng->seed = &hqc_prng_seed; + prng->get_state = &hqc_prng_get_state; + prng->set_state = &hqc_prng_set_state; + prng->free = &hqc_prng_free; + } else { + // set randombytes function + if (OQS_randombytes_switch_algorithm(OQS_RAND_alg_nist_kat) == OQS_SUCCESS) { + // TODO set callbacks + prng->seed = &OQS_randombytes_nist_kat_init_256bit; + prng->get_state = &OQS_randombytes_nist_kat_get_state; + prng->set_state = &OQS_randombytes_nist_kat_set_state; + prng->free = &nist_drbg_free; + } else { + OQS_MEM_insecure_free(prng); + prng = NULL; + } + } + } + return prng; } // entropy_input must have length 48. // If personalization_string is non-null, its length must also be 48. void OQS_KAT_PRNG_seed(OQS_KAT_PRNG *prng, const uint8_t *entropy_input, const uint8_t *personalization_string) { - if (prng != NULL) { - prng->seed(entropy_input, personalization_string); - } + if (prng != NULL) { + prng->seed(entropy_input, personalization_string); + } } void OQS_KAT_PRNG_save_state(OQS_KAT_PRNG *prng) { - if (prng != NULL) { - prng->get_state(&prng->saved_state); - } + if (prng != NULL) { + prng->get_state(&prng->saved_state); + } } void OQS_KAT_PRNG_restore_state(OQS_KAT_PRNG *prng) { - if (prng != NULL) { - prng->set_state(&prng->saved_state); - } + if (prng != NULL) { + prng->set_state(&prng->saved_state); + } } void OQS_KAT_PRNG_free(OQS_KAT_PRNG *prng) { - if (prng != NULL) { - // saved_state needs to be handled dynamically - prng->free(&prng->saved_state); - } - OQS_MEM_insecure_free(prng); + if (prng != NULL) { + // saved_state needs to be handled dynamically + prng->free(&prng->saved_state); + } + OQS_MEM_insecure_free(prng); } - /* Displays hexadecimal strings */ void OQS_print_hex_string(const char *label, const uint8_t *str, size_t len) { printf("%-20s (%4zu bytes): ", label, len); @@ -160,4 +162,3 @@ void OQS_fprintBstr(FILE *fp, const char *S, const uint8_t *A, size_t L) { } fprintf(fp, "\n"); } - diff --git a/tests/test_helpers.h b/tests/test_helpers.h index cd34f696be..244e9d4b48 100644 --- a/tests/test_helpers.h +++ b/tests/test_helpers.h @@ -1,3 +1,4 @@ +// SPDX-License-Identifier: MIT #ifndef OQS_TEST_HELPERS_H #define OQS_TEST_HELPERS_H @@ -8,23 +9,23 @@ #include typedef union { - OQS_SHA3_shake256_inc_ctx hqc_state; - // struct definition copied from rand_nist.c - struct { - unsigned char Key[32]; - unsigned char V[16]; - int reseed_counter; - } nist_state; + OQS_SHA3_shake256_inc_ctx hqc_state; + // struct definition copied from rand_nist.c + struct { + unsigned char Key[32]; + unsigned char V[16]; + int reseed_counter; + } nist_state; } OQS_KAT_PRNG_state; typedef struct { - int max_kats; - OQS_KAT_PRNG_state saved_state; - // The caller should use the OQS_KAT_PRNG_* functions instead of these callbacks. + int max_kats; + OQS_KAT_PRNG_state saved_state; + // The caller should use the OQS_KAT_PRNG_* functions instead of these callbacks. void (*seed)(const uint8_t *, const uint8_t *); void (*get_state)(void *); void (*set_state)(const void *); - void (*free)(void *); + void (*free)(OQS_KAT_PRNG_state *); } OQS_KAT_PRNG; OQS_KAT_PRNG *OQS_KAT_PRNG_new(const char *method_name);