diff --git a/oqsprov/oqs_kem.c b/oqsprov/oqs_kem.c index 76780a16..a56150d4 100644 --- a/oqsprov/oqs_kem.c +++ b/oqsprov/oqs_kem.c @@ -116,13 +116,42 @@ static int oqs_qs_kem_encaps_keyslot(void *vpkemctx, unsigned char *out, OQS_KEM_PRINTF("OQS Warning: OQS_KEM not initialized\n"); return -1; } - *outlen = kem_ctx->length_ciphertext; - *secretlen = kem_ctx->length_shared_secret; + if (pkemctx->kem->comp_pubkey == NULL + || pkemctx->kem->comp_pubkey[keyslot] == NULL) { + OQS_KEM_PRINTF("OQS Warning: public key is NULL\n"); + return -1; + } if (out == NULL || secret == NULL) { - OQS_KEM_PRINTF3("KEM returning lengths %ld and %ld\n", *outlen, - *secretlen); + if (outlen != NULL) { + *outlen = kem_ctx->length_ciphertext; + } + if (secretlen != NULL) { + *secretlen = kem_ctx->length_shared_secret; + } + OQS_KEM_PRINTF3("KEM returning lengths %ld and %ld\n", + kem_ctx->length_ciphertext, + kem_ctx->length_shared_secret); return 1; } + if (outlen == NULL) { + OQS_KEM_PRINTF("OQS Warning: outlen is NULL\n"); + return -1; + } + if (secretlen == NULL) { + OQS_KEM_PRINTF("OQS Warning: secretlen is NULL\n"); + return -1; + } + if (*outlen < kem_ctx->length_ciphertext) { + OQS_KEM_PRINTF("OQS Warning: out buffer too small\n"); + return -1; + } + if (*secretlen < kem_ctx->length_shared_secret) { + OQS_KEM_PRINTF("OQS Warning: secret buffer too small\n"); + return -1; + } + *outlen = kem_ctx->length_ciphertext; + *secretlen = kem_ctx->length_shared_secret; + return OQS_SUCCESS == OQS_KEM_encaps(kem_ctx, out, secret, pkemctx->kem->comp_pubkey[keyslot]); @@ -140,9 +169,36 @@ static int oqs_qs_kem_decaps_keyslot(void *vpkemctx, unsigned char *out, OQS_KEM_PRINTF("OQS Warning: OQS_KEM not initialized\n"); return -1; } - *outlen = kem_ctx->length_shared_secret; - if (out == NULL) + if (pkemctx->kem->comp_privkey == NULL + || pkemctx->kem->comp_privkey[keyslot] == NULL) { + OQS_KEM_PRINTF("OQS Warning: private key is NULL\n"); + return -1; + } + if (out == NULL) { + if (outlen != NULL) { + *outlen = kem_ctx->length_shared_secret; + } + OQS_KEM_PRINTF2("KEM returning length %ld\n", + kem_ctx->length_shared_secret); return 1; + } + if (inlen != kem_ctx->length_ciphertext) { + OQS_KEM_PRINTF("OQS Warning: wrong input length\n"); + return 0; + } + if (in == NULL) { + OQS_KEM_PRINTF("OQS Warning: in is NULL\n"); + return -1; + } + if (outlen == NULL) { + OQS_KEM_PRINTF("OQS Warning: outlen is NULL\n"); + return -1; + } + if (*outlen < kem_ctx->length_shared_secret) { + OQS_KEM_PRINTF("OQS Warning: out buffer too small\n"); + return -1; + } + *outlen = kem_ctx->length_shared_secret; return OQS_SUCCESS == OQS_KEM_decaps(kem_ctx, out, in,