diff --git a/src/sig_stfl/lms/sig_stfl_lms_functions.c b/src/sig_stfl/lms/sig_stfl_lms_functions.c index 59265f3110..018b04b21e 100644 --- a/src/sig_stfl/lms/sig_stfl_lms_functions.c +++ b/src/sig_stfl/lms/sig_stfl_lms_functions.c @@ -7,8 +7,9 @@ #include "external/hss_verify_inc.h" #include "external/hss_sign_inc.h" #include "external/hss.h" +#include "external/endian.h" +#include "external/hss_internal.h" #include "sig_stfl_lms_wrap.h" -#include #define DEFAULT_AUX_DATA 10916 /* Use 10+k of aux data (which works well */ /* with the above default parameter set) */ @@ -136,17 +137,39 @@ OQS_API OQS_STATUS OQS_SIG_STFL_alg_lms_verify(const uint8_t *message, size_t me } OQS_API OQS_STATUS OQS_SIG_STFL_lms_sigs_left(unsigned long long *remain, const OQS_SIG_STFL_SECRET_KEY *secret_key) { + OQS_STATUS status; + uint8_t *priv_key = NULL; + unsigned long long total_sigs = 0; + sequence_t current_count = 0; + oqs_lms_key_data *oqs_key_data = NULL; if (remain == NULL || secret_key == NULL) { return OQS_ERROR; } + + status = OQS_SIG_STFL_lms_sigs_total(&total_sigs, secret_key); + if (status != OQS_SUCCESS) { + return OQS_ERROR; + } + /* Lock secret key to ensure data integrity use */ if ((secret_key->lock_key) && (secret_key->mutex)) { secret_key->lock_key(secret_key->mutex); } - remain = 0; + oqs_key_data = secret_key->secret_key_data; + if (oqs_key_data == NULL) { + goto err; + } + priv_key = oqs_key_data->sec_key; + if (priv_key == NULL) { + goto err; + } + + current_count = get_bigendian(priv_key + PRIVATE_KEY_INDEX, PRIVATE_KEY_INDEX_LEN /*0, 8 */); + *remain = (total_sigs - (unsigned long long)current_count); +err: /* Unlock secret key */ if ((secret_key->unlock_key) && (secret_key->mutex)) { secret_key->unlock_key(secret_key->mutex); @@ -156,11 +179,38 @@ OQS_API OQS_STATUS OQS_SIG_STFL_lms_sigs_left(unsigned long long *remain, const OQS_API OQS_STATUS OQS_SIG_STFL_lms_sigs_total(unsigned long long *total, const OQS_SIG_STFL_SECRET_KEY *secret_key) { + uint8_t *priv_key = NULL; + oqs_lms_key_data *oqs_key_data = NULL; + struct hss_working_key *working_key = NULL; + + if (total == NULL || secret_key == NULL) { return OQS_ERROR; } - total = 0; + oqs_key_data = secret_key->secret_key_data; + if (!oqs_key_data) { + return OQS_ERROR; + } + + priv_key = oqs_key_data->sec_key; + if (!priv_key) { + return OQS_ERROR; + } + + working_key = hss_load_private_key(NULL, priv_key, + 0, + NULL, + 0, + 0); + if (!working_key) { + return OQS_ERROR; + } + + + + *total = (unsigned long long)working_key->max_count; + OQS_MEM_secure_free(working_key, sizeof(struct hss_working_key)); return OQS_SUCCESS; } @@ -293,7 +343,7 @@ int oqs_sig_stfl_lms_keypair(uint8_t *pk, OQS_SIG_STFL_SECRET_KEY *sk, const uin memcpy(pk, oqs_key_data->public_key, len_public_key); sk->secret_key_data = oqs_key_data; } else { - OQS_MEM_insecure_free(oqs_key_data->sec_key); + OQS_MEM_secure_free(oqs_key_data->sec_key, sk->length_secret_key * sizeof(uint8_t)); OQS_MEM_insecure_free(oqs_key_data->aux_data); OQS_MEM_insecure_free(oqs_key_data); oqs_key_data = NULL; @@ -329,7 +379,6 @@ int oqs_sig_stfl_lms_sign(OQS_SIG_STFL_SECRET_KEY *sk, 0, 0); if (!w) { - printf( "Error loading private key\n" ); hss_free_working_key(w); return 0; } @@ -340,14 +389,12 @@ int oqs_sig_stfl_lms_sign(OQS_SIG_STFL_SECRET_KEY *sk, sig_len = hss_get_signature_len_from_working_key(w); if (sig_len == 0) { - printf( "Error getting signature len\n" ); hss_free_working_key(w); return 0; } sig = malloc(sig_len); if (!sig) { - printf( "Error during malloc\n" ); hss_free_working_key(w); return -1; } diff --git a/tests/test_sig_stfl.c b/tests/test_sig_stfl.c index 8a3cb2252d..3a1ed49293 100644 --- a/tests/test_sig_stfl.c +++ b/tests/test_sig_stfl.c @@ -677,6 +677,24 @@ static OQS_STATUS sig_stfl_test_secret_key(const char *method_name) { goto err; } + /* + * Get max num signature and the amount remaining + */ + unsigned long long num_sig_left = 0, max_num_sigs = 0; + rc = OQS_SIG_STFL_sigs_total((const OQS_SIG_STFL *)sig_obj, &max_num_sigs, (const OQS_SIG_STFL_SECRET_KEY *)sk); + if (rc != OQS_SUCCESS) { + fprintf(stderr, "OQS STFL key: Failed to get max number of sig from %s.\n", method_name); + goto err; + } + printf("%s Maximum num of sign operations = %llu\n", method_name, max_num_sigs); + + rc = OQS_SIG_STFL_sigs_remaining((const OQS_SIG_STFL *)sig_obj, &num_sig_left, (const OQS_SIG_STFL_SECRET_KEY *)sk); + if (rc != OQS_SUCCESS) { + fprintf(stderr, "OQS STFL key: Failed to get the remaining number of sig from %s.\n", method_name); + goto err; + } + printf("%s Remaining number of sign operations = %llu\n", method_name, num_sig_left); + /* write sk key to disk */ rc = OQS_SECRET_KEY_STFL_serialize_key(sk, &to_file_sk_len, &to_file_sk_buf); if (rc != OQS_SUCCESS) { @@ -837,6 +855,25 @@ static OQS_STATUS sig_stfl_test_sig_gen(const char *method_name) { } + /* + * Get max num signature and the amount remaining + */ + unsigned long long num_sig_left = 0, max_num_sigs = 0; + rc = OQS_SIG_STFL_sigs_total((const OQS_SIG_STFL *)lock_test_sig_obj, &max_num_sigs, (const OQS_SIG_STFL_SECRET_KEY *)lock_test_sk); + if (rc != OQS_SUCCESS) { + fprintf(stderr, "OQS STFL key: Failed to get max number of sig from %s.\n", method_name); + goto err; + } + printf("%s Maximum num of sign operations = %llu\n", method_name, max_num_sigs); + + rc = OQS_SIG_STFL_sigs_remaining((const OQS_SIG_STFL *)lock_test_sig_obj, &num_sig_left, (const OQS_SIG_STFL_SECRET_KEY *)lock_test_sk); + if (rc != OQS_SUCCESS) { + fprintf(stderr, "OQS STFL key: Failed to get the remaining number of sig from %s.\n", method_name); + goto err; + } + printf("%s Remaining number of sign operations = %llu\n", method_name, num_sig_left); + + printf("================================================================================\n"); printf("Sig Gen 1 %s\n", method_name); printf("================================================================================\n"); @@ -850,7 +887,23 @@ static OQS_STATUS sig_stfl_test_sig_gen(const char *method_name) { goto err; } - sleep(3); + /* + * Get max num signature and the amount remaining + */ + num_sig_left = 0, max_num_sigs = 0; + rc = OQS_SIG_STFL_sigs_total((const OQS_SIG_STFL *)lock_test_sig_obj, &max_num_sigs, (const OQS_SIG_STFL_SECRET_KEY *)lock_test_sk); + if (rc != OQS_SUCCESS) { + fprintf(stderr, "OQS STFL key: Failed to get max number of sig from %s.\n", method_name); + goto err; + } + printf("%s Maximum num of sign operations = %llu\n", method_name, max_num_sigs); + + rc = OQS_SIG_STFL_sigs_remaining((const OQS_SIG_STFL *)lock_test_sig_obj, &num_sig_left, (const OQS_SIG_STFL_SECRET_KEY *)lock_test_sk); + if (rc != OQS_SUCCESS) { + fprintf(stderr, "OQS STFL key: Failed to get the remaining number of sig from %s.\n", method_name); + goto err; + } + printf("%s Remaining number of sign operations = %llu\n", method_name, num_sig_left); printf("================================================================================\n"); printf("Sig Gen 2 %s\n", method_name); @@ -868,6 +921,25 @@ static OQS_STATUS sig_stfl_test_sig_gen(const char *method_name) { printf("================================================================================\n"); printf("Stateful Key Gen %s Passed.\n", method_name); printf("================================================================================\n"); + + /* + * Get max num signature and the amount remaining + */ + num_sig_left = 0, max_num_sigs = 0; + rc = OQS_SIG_STFL_sigs_total((const OQS_SIG_STFL *)lock_test_sig_obj, &max_num_sigs, (const OQS_SIG_STFL_SECRET_KEY *)lock_test_sk); + if (rc != OQS_SUCCESS) { + fprintf(stderr, "OQS STFL key: Failed to get max number of sig from %s.\n", method_name); + goto err; + } + printf("%s Maximum num of sign operations = %llu\n", method_name, max_num_sigs); + + rc = OQS_SIG_STFL_sigs_remaining((const OQS_SIG_STFL *)lock_test_sig_obj, &num_sig_left, (const OQS_SIG_STFL_SECRET_KEY *)lock_test_sk); + if (rc != OQS_SUCCESS) { + fprintf(stderr, "OQS STFL key: Failed to get the remaining number of sig from %s.\n", method_name); + goto err; + } + printf("%s Remaining number of sign operations = %llu\n", method_name, num_sig_left); + goto end_it; err: rc = OQS_ERROR;