Skip to content

Commit

Permalink
use inner_serialize to avoid recursive lock
Browse files Browse the repository at this point in the history
  • Loading branch information
ducnguyen-sb committed Nov 1, 2023
1 parent 60e947c commit b2d5670
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 122 deletions.
3 changes: 3 additions & 0 deletions src/sig_stfl/xmss/sig_stfl_xmss.h
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,9 @@ OQS_SIG_STFL_SECRET_KEY *OQS_SECRET_KEY_XMSS_new(size_t length_secret_key);
/* Serialize XMSS secret key data into a byte string, and return an allocated buffer. Users must deallocate the buffer. */
OQS_STATUS OQS_SECRET_KEY_XMSS_serialize_key(uint8_t **sk_buf_ptr, size_t *sk_len, const OQS_SIG_STFL_SECRET_KEY *sk);

/* Only for internal use. Similar to OQS_SECRET_KEY_XMSS_serialize_key, this function does not acquire and release a lock. */
OQS_STATUS OQS_SECRET_KEY_XMSS_inner_serialize_key(uint8_t **sk_buf_ptr, size_t *sk_len, const OQS_SIG_STFL_SECRET_KEY *sk);

/* Deserialize XMSS byte string into an XMSS secret key data */
OQS_STATUS OQS_SECRET_KEY_XMSS_deserialize_key(OQS_SIG_STFL_SECRET_KEY *sk, const size_t sk_len, const uint8_t *sk_buf, void *context);

Expand Down
2 changes: 1 addition & 1 deletion src/sig_stfl/xmss/sig_stfl_xmss_functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ OQS_API OQS_STATUS OQS_SIG_STFL_alg_xmss_sign(uint8_t *signature, size_t *signat
* regardless, delete signature and the serialized key other wise
*/

status = OQS_SECRET_KEY_XMSS_serialize_key(&sk_key_buf_ptr, &sk_key_buf_len, secret_key);
status = OQS_SECRET_KEY_XMSS_inner_serialize_key(&sk_key_buf_ptr, &sk_key_buf_len, secret_key);
if (status != OQS_SUCCESS) {
goto err;
}
Expand Down
20 changes: 20 additions & 0 deletions src/sig_stfl/xmss/sig_stfl_xmss_secret_key_functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,26 @@ OQS_STATUS OQS_SECRET_KEY_XMSS_serialize_key(uint8_t **sk_buf_ptr, size_t *sk_le
return OQS_SUCCESS;
}

/* Only for internal use. Similar to OQS_SECRET_KEY_XMSS_serialize_key, but this function does not aquire and release lock. */
OQS_STATUS OQS_SECRET_KEY_XMSS_inner_serialize_key(uint8_t **sk_buf_ptr, size_t *sk_len, const OQS_SIG_STFL_SECRET_KEY *sk) {
if (sk == NULL || sk_len == NULL || sk_buf_ptr == NULL) {
return OQS_ERROR;
}

uint8_t *sk_buf = malloc(sk->length_secret_key * sizeof(uint8_t));
if (sk_buf == NULL) {
return OQS_ERROR;
}

// Simply copy byte string of secret_key_data
memcpy(sk_buf, sk->secret_key_data, sk->length_secret_key);

*sk_buf_ptr = sk_buf;
*sk_len = sk->length_secret_key;

return OQS_SUCCESS;
}

/* Deserialize XMSS byte string into an XMSS secret key data. */
OQS_STATUS OQS_SECRET_KEY_XMSS_deserialize_key(OQS_SIG_STFL_SECRET_KEY *sk, const size_t sk_len, const uint8_t *sk_buf, XMSS_UNUSED_ATT void *context) {
if (sk == NULL || sk_buf == NULL || (sk_len != sk->length_secret_key)) {
Expand Down
2 changes: 1 addition & 1 deletion src/sig_stfl/xmss/sig_stfl_xmssmt_functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ OQS_API OQS_STATUS OQS_SIG_STFL_alg_xmssmt_sign(uint8_t *signature, size_t *sign
* regardless, delete signature and the serialized key other wise
*/

status = OQS_SECRET_KEY_XMSS_serialize_key(&sk_key_buf_ptr, &sk_key_buf_len, secret_key);
status = OQS_SECRET_KEY_XMSS_inner_serialize_key(&sk_key_buf_ptr, &sk_key_buf_len, secret_key);
if (status != OQS_SUCCESS) {
goto err;
}
Expand Down
212 changes: 92 additions & 120 deletions tests/test_sig_stfl.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,68 +39,6 @@ static pthread_mutex_t *sk_lock = NULL;
*/
#define MAX_MARKER_LEN 50

static OQS_SIG_STFL_SECRET_KEY *lock_test_sk = NULL;
static OQS_SIG_STFL *lock_test_sig_obj = NULL;
static uint8_t *lock_test_public_key = NULL;
static char *lock_test_context = NULL;
static uint8_t *signature_1 = NULL;
static uint8_t *signature_2 = NULL;
static size_t signature_len_1;
static size_t signature_len_2;
static uint8_t message_1[] = "The quick brown fox ...";
static uint8_t message_2[] = "The quick brown fox jumped from the tree.";

/*
* Write stateful secret keys to disk.
*/
static OQS_STATUS save_secret_key(uint8_t *key_buf, size_t buf_len, void *context) {
if (key_buf == NULL || buf_len == 0 || context == NULL) {
return OQS_ERROR;
}
const char *context_char = context;

if (oqs_fstore("sk", context_char, key_buf, buf_len) == OQS_SUCCESS) {
printf("\n================================================================================\n");
printf("Updated STFL SK <%s>.\n", context_char);
printf("================================================================================\n");
return OQS_SUCCESS;
}

return OQS_ERROR;
}

#if OQS_USE_PTHREADS_IN_TESTS
static OQS_STATUS lock_sk_key(void *mutex) {
if (mutex == NULL) {
return OQS_ERROR;
}

if (!(pthread_mutex_lock((pthread_mutex_t *)mutex))) {
return OQS_SUCCESS;
}
return OQS_ERROR;
}

static OQS_STATUS unlock_sk_key(void *mutex) {
if (mutex == NULL) {
return OQS_ERROR;
}

if (!(pthread_mutex_unlock((pthread_mutex_t *)mutex))) {
return OQS_SUCCESS;
}
return OQS_ERROR;
}
#else
static OQS_STATUS lock_sk_key(void *mutex) {
return OQS_SUCCESS;
}

static OQS_STATUS unlock_sk_key(void *mutex) {
return OQS_SUCCESS;
}
#endif

//
// ALLOW TO READ HEXADECIMAL ENTRY (KEYS, DATA, TEXT, etc.)
//
Expand Down Expand Up @@ -193,6 +131,68 @@ int ReadHex(FILE *infile, unsigned char *a, unsigned long Length, char *str) {
return 1;
}

static OQS_SIG_STFL_SECRET_KEY *lock_test_sk = NULL;
static OQS_SIG_STFL *lock_test_sig_obj = NULL;
static uint8_t *lock_test_public_key = NULL;
static char *lock_test_context = NULL;
static uint8_t *signature_1 = NULL;
static uint8_t *signature_2 = NULL;
static size_t signature_len_1;
static size_t signature_len_2;
static uint8_t message_1[] = "The quick brown fox ...";
static uint8_t message_2[] = "The quick brown fox jumped from the tree.";

/*
* Write stateful secret keys to disk.
*/
static OQS_STATUS save_secret_key(uint8_t *key_buf, size_t buf_len, void *context) {
if (key_buf == NULL || buf_len == 0 || context == NULL) {
return OQS_ERROR;
}
const char *context_char = context;

if (oqs_fstore("sk", context_char, key_buf, buf_len) == OQS_SUCCESS) {
printf("\n================================================================================\n");
printf("Updated STFL SK <%s>.\n", context_char);
printf("================================================================================\n");
return OQS_SUCCESS;
}

return OQS_ERROR;
}

#if OQS_USE_PTHREADS_IN_TESTS
static OQS_STATUS lock_sk_key(void *mutex) {
if (mutex == NULL) {
return OQS_ERROR;
}

if (pthread_mutex_lock((pthread_mutex_t *)mutex)) {
return OQS_ERROR;
}
return OQS_SUCCESS;
}

static OQS_STATUS unlock_sk_key(void *mutex) {
if (mutex == NULL) {
return OQS_ERROR;
}

if (pthread_mutex_unlock((pthread_mutex_t *)mutex)) {
return OQS_ERROR;
}
return OQS_SUCCESS;
}
#else
static OQS_STATUS lock_sk_key(void *mutex) {
return OQS_SUCCESS;
}

static OQS_STATUS unlock_sk_key(void *mutex) {
return OQS_SUCCESS;
}
#endif

OQS_STATUS sig_stfl_keypair_from_keygen(OQS_SIG_STFL *sig, uint8_t *public_key, OQS_SIG_STFL_SECRET_KEY *secret_key) {
OQS_STATUS rc;

Expand Down Expand Up @@ -933,18 +933,18 @@ static void TEST_SIG_STFL_randombytes(uint8_t *random_array, size_t bytes_to_rea
#endif

#if OQS_USE_PTHREADS_IN_TESTS
struct thread_data {
typedef struct thread_data {
const char *alg_name;
const char *katfile;
OQS_STATUS rc;
OQS_STATUS rc1;
};
} thread_data_t;

struct lock_test_data {
typedef struct lock_test_data {
const char *alg_name;
const char *katfile;
OQS_STATUS rc;
};
} lock_test_data_t;

void *test_query_key(void *arg) {
struct lock_test_data *td = arg;
Expand Down Expand Up @@ -1014,30 +1014,21 @@ int main(int argc, char **argv) {
OQS_randombytes_switch_algorithm("system");
#endif

OQS_STATUS rc = OQS_SUCCESS, rc1 = OQS_SUCCESS,
rc_lck = OQS_SUCCESS, rc_sig = OQS_SUCCESS, rc_qry = OQS_SUCCESS;
OQS_STATUS rc = OQS_ERROR, rc1 = OQS_ERROR;

#if OQS_USE_PTHREADS_IN_TESTS
#define MAX_LEN_SIG_NAME_ 64
OQS_STATUS rc_create = OQS_ERROR, rc_sign = OQS_ERROR, rc_query = OQS_ERROR;

pthread_t thread;
pthread_t create_key_thread;
pthread_t sign_key_thread;
pthread_t query_key_thread;
struct thread_data td;
td.alg_name = alg_name;
td.katfile = katfile;

struct lock_test_data td_create;
struct lock_test_data td_sign;
struct lock_test_data td_query;
td_create.alg_name = alg_name;
td_sign.alg_name = alg_name;
td_query.alg_name = alg_name;

td_create.katfile = katfile;
td_sign.katfile = katfile;
td_query.katfile = katfile;
pthread_mutexattr_t attr1, attr2;

thread_data_t td = {.alg_name = alg_name, .katfile = katfile, .rc = OQS_ERROR, .rc1 = OQS_ERROR};
lock_test_data_t td_create = {.alg_name = alg_name, .katfile = katfile, .rc = OQS_ERROR};
lock_test_data_t td_sign = {.alg_name = alg_name, .katfile = katfile, .rc = OQS_ERROR};
lock_test_data_t td_query = {.alg_name = alg_name, .katfile = katfile, .rc = OQS_ERROR};

test_sk_lock = (pthread_mutex_t *)malloc(sizeof(pthread_mutex_t));
if (test_sk_lock == NULL) {
Expand All @@ -1048,64 +1039,45 @@ int main(int argc, char **argv) {
goto err;
}

if (pthread_mutexattr_init(&attr1)) {
goto err;
}
if (pthread_mutexattr_init(&attr2)) {
if (pthread_mutex_init(test_sk_lock, NULL) || pthread_mutex_init(sk_lock, NULL)) {
fprintf(stderr, "ERROR: Initializing mutex\n");
goto err;
}

pthread_mutexattr_settype(&attr1, PTHREAD_MUTEX_RECURSIVE);
pthread_mutexattr_settype(&attr2, PTHREAD_MUTEX_RECURSIVE);

if (pthread_mutex_init(test_sk_lock, &attr1)) {
goto err;
}
if (pthread_mutex_init(test_sk_lock, &attr2)) {
if (pthread_create(&thread, NULL, test_wrapper, &td)) {
fprintf(stderr, "ERROR: Creating pthread for test_wrapper\n");
goto err;
}

int trc = pthread_create(&thread, NULL, test_wrapper, &td);
if (trc) {
fprintf(stderr, "ERROR: Creating pthread\n");
OQS_destroy();
return EXIT_FAILURE;
}
pthread_join(thread, NULL);
rc = td.rc;
rc1 = td.rc1;

int trc_2 = pthread_create(&create_key_thread, NULL, test_create_keys, &td_create);
if (trc_2) {
fprintf(stderr, "ERROR: Creating pthread for stateful key gen test\n");
OQS_destroy();
return EXIT_FAILURE;
if (pthread_create(&create_key_thread, NULL, test_create_keys, &td_create)) {
fprintf(stderr, "ERROR: Creating pthread for test_create_keys\n");
goto err;
}
pthread_join(create_key_thread, NULL);
rc_lck = td_create.rc;
rc_create = td_create.rc;

int trc_3 = pthread_create(&sign_key_thread, NULL, test_sig_gen, &td_sign);
if (trc_3) {
fprintf(stderr, "ERROR: Creating pthread for sig gen test\n");
OQS_destroy();
return EXIT_FAILURE;
if (pthread_create(&sign_key_thread, NULL, test_sig_gen, &td_sign)) {
fprintf(stderr, "ERROR: Creating pthread for test_sig_gen\n");
goto err;
}
pthread_join(sign_key_thread, NULL);
rc_sig = td_sign.rc;
rc_sign = td_sign.rc;

int trc_4 = pthread_create(&query_key_thread, NULL, test_query_key, &td_query);
if (trc_4) {
fprintf(stderr, "ERROR: Creating pthread for query key test.\n");
OQS_destroy();
return EXIT_FAILURE;
if (pthread_create(&query_key_thread, NULL, test_query_key, &td_query)) {
fprintf(stderr, "ERROR: Creating pthread for test_query_key\n");
goto err;
}
pthread_join(query_key_thread, NULL);
rc_qry = td_query.rc;
rc_query = td_query.rc;

err:
if (test_sk_lock) {
pthread_mutex_destroy(test_sk_lock);
}

if (sk_lock) {
pthread_mutex_destroy(sk_lock);
}
Expand All @@ -1123,7 +1095,7 @@ int main(int argc, char **argv) {

OQS_destroy();
if (rc != OQS_SUCCESS || rc1 != OQS_SUCCESS ||
rc_lck != OQS_SUCCESS || rc_sig != OQS_SUCCESS || rc_qry != OQS_SUCCESS) {
rc_create != OQS_SUCCESS || rc_sign != OQS_SUCCESS || rc_query != OQS_SUCCESS) {
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
Expand Down

0 comments on commit b2d5670

Please sign in to comment.