From 6ca81175f7960fd2da5e9d6def5521d7c8f0581b Mon Sep 17 00:00:00 2001 From: Basil Hess Date: Thu, 25 Jan 2024 10:15:43 +0100 Subject: [PATCH] some improvements in test code --- tests/test_kem_vectors.sh | 1 + tests/vectors_kem.c | 40 +++++++++++++++++++------------- tests/vectors_sig.c | 48 +++++++++++++++++++++++++++------------ 3 files changed, 59 insertions(+), 30 deletions(-) diff --git a/tests/test_kem_vectors.sh b/tests/test_kem_vectors.sh index a57d883e2c..0e64ade01f 100644 --- a/tests/test_kem_vectors.sh +++ b/tests/test_kem_vectors.sh @@ -25,6 +25,7 @@ encaps_K=$(grep "encaps_K: " "$file") output=$($build_dir/tests/vectors_kem $1 "$keygen_z$keygen_d$encaps_m" "$encaps_ek" "$encaps_k" "$decaps_dk" "$decaps_c" "$decaps_kprime") if [ $? != 0 ]; then + echo "$output" exit 1 fi diff --git a/tests/vectors_kem.c b/tests/vectors_kem.c index 04296f17d5..c21bbb5e0e 100644 --- a/tests/vectors_kem.c +++ b/tests/vectors_kem.c @@ -14,11 +14,11 @@ #include "system_info.c" -typedef struct { +struct { const uint8_t *pos; -} fixed_prng_state; - -fixed_prng_state prng_state = { .pos = 0 }; +} prng_state = { + .pos = 0 +}; /* Displays hexadecimal strings */ static void OQS_print_hex_string(const char *label, const uint8_t *str, size_t len) { @@ -228,22 +228,28 @@ int main(int argc, char **argv) { char *decaps_ciphertext = argv[6]; char *decaps_kprime = argv[7]; + OQS_KEM *kem = OQS_KEM_new(alg_name); + if (kem == NULL) { + printf("[vectors_kem] %s was not enabled at compile-time.\n", alg_name); + goto err; + } if (strlen(prng_output_stream) % 2 != 0 || - strlen(encaps_pk) % 2 != 0 || - strlen(encaps_K) % 2 != 0 || - strlen(decaps_sk) % 2 != 0 || - strlen(decaps_ciphertext) % 2 != 0 || - strlen(decaps_kprime) % 2 != 0) { - return EXIT_FAILURE; + strlen(encaps_pk) != 2 * kem->length_public_key || + strlen(encaps_K) != 2 * kem->length_shared_secret || + strlen(decaps_sk) != 2 * kem->length_secret_key || + strlen(decaps_ciphertext) != 2 * kem->length_ciphertext || + strlen(decaps_kprime) != 2 * kem->length_shared_secret ) { + rc = OQS_ERROR; + goto err; } - uint8_t *prng_output_stream_bytes = malloc(strlen(prng_output_stream) / 2); // TODO: allocate real sizes and check before to real sizes! - uint8_t *encaps_pk_bytes = malloc(strlen(encaps_pk) / 2); - uint8_t *encaps_K_bytes = malloc(strlen(encaps_K) / 2); - uint8_t *decaps_sk_bytes = malloc(strlen(decaps_sk) / 2); - uint8_t *decaps_ciphertext_bytes = malloc(strlen(decaps_ciphertext) / 2); - uint8_t *decaps_kprime_bytes = malloc(strlen(decaps_kprime) / 2); + uint8_t *prng_output_stream_bytes = malloc(strlen(prng_output_stream) / 2); + uint8_t *encaps_pk_bytes = malloc(kem->length_public_key); + uint8_t *encaps_K_bytes = malloc(kem->length_shared_secret); + uint8_t *decaps_sk_bytes = malloc(kem->length_secret_key); + uint8_t *decaps_ciphertext_bytes = malloc(kem->length_ciphertext); + uint8_t *decaps_kprime_bytes = malloc(kem->length_shared_secret); if ((prng_output_stream_bytes == NULL) || (encaps_pk_bytes == NULL) || (encaps_K_bytes == NULL) || (decaps_sk_bytes == NULL) || (decaps_ciphertext_bytes == NULL) || (decaps_kprime_bytes == NULL)) { fprintf(stderr, "[vectors_kem] ERROR: malloc failed!\n"); @@ -268,6 +274,8 @@ int main(int argc, char **argv) { OQS_MEM_insecure_free(decaps_ciphertext_bytes); OQS_MEM_insecure_free(decaps_kprime_bytes); + OQS_KEM_free(kem); + OQS_destroy(); if (rc != OQS_SUCCESS) { diff --git a/tests/vectors_sig.c b/tests/vectors_sig.c index 1cc2c08d12..8204855f2c 100644 --- a/tests/vectors_sig.c +++ b/tests/vectors_sig.c @@ -14,11 +14,11 @@ #include "system_info.c" -typedef struct { +struct { const uint8_t *pos; -} fixed_prng_state; - -fixed_prng_state prng_state = { .pos = 0 }; +} prng_state = { + .pos = 0 +}; static void fprintBstr(FILE *fp, const char *S, const uint8_t *A, size_t L) { size_t i; @@ -175,6 +175,8 @@ OQS_STATUS sig_vector(const char *method_name, } int main(int argc, char **argv) { + OQS_STATUS rc; + OQS_init(); if (argc != 8) { @@ -203,22 +205,36 @@ int main(int argc, char **argv) { char *verif_msg = argv[7]; size_t verif_msg_len = strlen(verif_msg) / 2; + OQS_SIG *sig = OQS_SIG_new(alg_name); + if (sig == NULL) { + printf("[vectors_sig] %s was not enabled at compile-time.\n", alg_name); + goto err; + } + if (strlen(prng_output_stream) % 2 != 0 || - strlen(sig_msg) % 2 != 0 || - strlen(sig_sk) % 2 != 0 || - strlen(verif_sig) % 2 != 0 || - strlen(verif_pk) % 2 != 0 || - strlen(verif_msg) % 2 != 0) { - return EXIT_FAILURE; + strlen(sig_msg) % 2 != 0 || // variable length + strlen(sig_sk) != 2 * sig->length_secret_key || + strlen(verif_sig) != 2 * sig->length_signature || + strlen(verif_pk) != 2 * sig->length_public_key || + strlen(verif_msg) % 2 != 0) { // variable length + rc = OQS_ERROR; + goto err; } uint8_t *prng_output_stream_bytes = malloc(strlen(prng_output_stream) / 2); uint8_t *sig_msg_bytes = malloc(strlen(sig_msg) / 2); - uint8_t *sig_sk_bytes = malloc(strlen(sig_sk) / 2); - uint8_t *verif_sig_bytes = malloc(strlen(verif_sig) / 2); - uint8_t *verif_pk_bytes = malloc(strlen(verif_pk) / 2); + uint8_t *sig_sk_bytes = malloc(sig->length_secret_key); + uint8_t *verif_sig_bytes = malloc(sig->length_signature); + uint8_t *verif_pk_bytes = malloc(sig->length_public_key); uint8_t *verif_msg_bytes = malloc(strlen(verif_msg) / 2); + if ((prng_output_stream_bytes == NULL) || (sig_msg_bytes == NULL) || (sig_sk_bytes == NULL) || (verif_sig_bytes == NULL) || (verif_pk_bytes == NULL) || (verif_msg_bytes == NULL)) { + fprintf(stderr, "[vectors_sig] ERROR: malloc failed!\n"); + rc = OQS_ERROR; + goto err; + } + + hexStringToByteArray(prng_output_stream, prng_output_stream_bytes); hexStringToByteArray(sig_msg, sig_msg_bytes); hexStringToByteArray(sig_sk, sig_sk_bytes); @@ -226,7 +242,9 @@ int main(int argc, char **argv) { hexStringToByteArray(verif_pk, verif_pk_bytes); hexStringToByteArray(verif_msg, verif_msg_bytes); - OQS_STATUS rc = sig_vector(alg_name, prng_output_stream_bytes, sig_msg_bytes, sig_msg_len, sig_sk_bytes, verif_sig_bytes, verif_pk_bytes, verif_msg_bytes, verif_msg_len); + rc = sig_vector(alg_name, prng_output_stream_bytes, sig_msg_bytes, sig_msg_len, sig_sk_bytes, verif_sig_bytes, verif_pk_bytes, verif_msg_bytes, verif_msg_len); + +err: OQS_MEM_insecure_free(prng_output_stream_bytes); OQS_MEM_insecure_free(sig_msg_bytes); OQS_MEM_insecure_free(sig_sk_bytes); @@ -234,6 +252,8 @@ int main(int argc, char **argv) { OQS_MEM_insecure_free(verif_pk_bytes); OQS_MEM_insecure_free(verif_msg_bytes); + OQS_SIG_free(sig); + OQS_destroy(); if (rc != OQS_SUCCESS) {