From 1e4201bc280c454a5f6501eef682554356c3c21f Mon Sep 17 00:00:00 2001
From: Songling Han <shan@paloaltonetworks.com>
Date: Mon, 14 Oct 2024 21:35:36 +0000
Subject: [PATCH] revert back to abort() for checked cases

Signed-off-by: Songling Han <shan@paloaltonetworks.com>
---
 src/common/sha2/sha2_armv8.c  |  4 ++--
 src/common/sha2/sha2_c.c      | 20 ++++++++++----------
 src/common/sha3/ossl_sha3.c   |  4 ++--
 src/common/sha3/ossl_sha3x4.c |  4 ++--
 src/common/sha3/xkcp_sha3.c   | 23 ++++++++++++++---------
 src/common/sha3/xkcp_sha3x4.c |  4 ++--
 6 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/src/common/sha2/sha2_armv8.c b/src/common/sha2/sha2_armv8.c
index 9bebdb8c2..5e8a6c6c2 100644
--- a/src/common/sha2/sha2_armv8.c
+++ b/src/common/sha2/sha2_armv8.c
@@ -182,7 +182,7 @@ void oqs_sha2_sha256_inc_finalize_armv8(uint8_t *out, sha256ctx *state, const ui
 		// Combine incremental data with final input
 		tmp_in = OQS_MEM_malloc(tmp_len);
 		if (!tmp_in) {
-			return;
+			abort();
 		}
 		memcpy(tmp_in, state->data, state->data_len);
 		if (in && inlen) {
@@ -258,7 +258,7 @@ void oqs_sha2_sha256_inc_blocks_armv8(sha256ctx *state, const uint8_t *in, size_
 	if (state->data_len) {
 		tmp_in = OQS_MEM_malloc(buf_len);
 		if (!tmp_in) {
-			return;
+			abort();
 		}
 		memcpy(tmp_in, state->data, state->data_len);
 		memcpy(tmp_in + state->data_len, in, buf_len - state->data_len);
diff --git a/src/common/sha2/sha2_c.c b/src/common/sha2/sha2_c.c
index 20cf23bad..09277a6c4 100644
--- a/src/common/sha2/sha2_c.c
+++ b/src/common/sha2/sha2_c.c
@@ -504,7 +504,7 @@ static const uint8_t iv_512[64] = {
 void oqs_sha2_sha224_inc_init_c(sha224ctx *state) {
 	state->ctx = OQS_MEM_malloc(PQC_SHA256CTX_BYTES);
 	if (!state->ctx) {
-		return;
+		abort();
 	}
 	for (size_t i = 0; i < 32; ++i) {
 		state->ctx[i] = iv_224[i];
@@ -520,7 +520,7 @@ void oqs_sha2_sha256_inc_init_c(sha256ctx *state) {
 	state->data_len = 0;
 	state->ctx = OQS_MEM_malloc(PQC_SHA256CTX_BYTES);
 	if (!state->ctx) {
-		return;
+		abort();
 	}
 	for (size_t i = 0; i < 32; ++i) {
 		state->ctx[i] = iv_256[i];
@@ -535,7 +535,7 @@ void oqs_sha2_sha256_inc_init_c(sha256ctx *state) {
 void oqs_sha2_sha384_inc_init_c(sha384ctx *state) {
 	state->ctx = OQS_MEM_malloc(PQC_SHA512CTX_BYTES);
 	if (!state->ctx) {
-		return;
+		abort();
 	}
 	for (size_t i = 0; i < 64; ++i) {
 		state->ctx[i] = iv_384[i];
@@ -550,7 +550,7 @@ void oqs_sha2_sha384_inc_init_c(sha384ctx *state) {
 void oqs_sha2_sha512_inc_init_c(sha512ctx *state) {
 	state->ctx = OQS_MEM_malloc(PQC_SHA512CTX_BYTES);
 	if (!state->ctx) {
-		return;
+		abort();
 	}
 	for (size_t i = 0; i < 64; ++i) {
 		state->ctx[i] = iv_512[i];
@@ -565,7 +565,7 @@ void oqs_sha2_sha512_inc_init_c(sha512ctx *state) {
 void oqs_sha2_sha224_inc_ctx_clone_c(sha224ctx *stateout, const sha224ctx *statein) {
 	stateout->ctx = OQS_MEM_malloc(PQC_SHA256CTX_BYTES);
 	if (!stateout->ctx) {
-		return;
+		abort();
 	}
 	stateout->data_len = statein->data_len;
 	memcpy(stateout->data, statein->data, 128);
@@ -575,7 +575,7 @@ void oqs_sha2_sha224_inc_ctx_clone_c(sha224ctx *stateout, const sha224ctx *state
 void oqs_sha2_sha256_inc_ctx_clone_c(sha256ctx *stateout, const sha256ctx *statein) {
 	stateout->ctx = OQS_MEM_malloc(PQC_SHA256CTX_BYTES);
 	if (!stateout->ctx) {
-		return;
+		abort();
 	}
 	stateout->data_len = statein->data_len;
 	memcpy(stateout->data, statein->data, 128);
@@ -585,7 +585,7 @@ void oqs_sha2_sha256_inc_ctx_clone_c(sha256ctx *stateout, const sha256ctx *state
 void oqs_sha2_sha384_inc_ctx_clone_c(sha384ctx *stateout, const sha384ctx *statein) {
 	stateout->ctx = OQS_MEM_malloc(PQC_SHA512CTX_BYTES);
 	if (!stateout->ctx) {
-		return;
+		abort();
 	}
 	stateout->data_len = statein->data_len;
 	memcpy(stateout->data, statein->data, 128);
@@ -595,7 +595,7 @@ void oqs_sha2_sha384_inc_ctx_clone_c(sha384ctx *stateout, const sha384ctx *state
 void oqs_sha2_sha512_inc_ctx_clone_c(sha512ctx *stateout, const sha512ctx *statein) {
 	stateout->ctx = OQS_MEM_malloc(PQC_SHA512CTX_BYTES);
 	if (!stateout->ctx) {
-		return;
+		abort();
 	}
 	stateout->data_len = statein->data_len;
 	memcpy(stateout->data, statein->data, 128);
@@ -632,7 +632,7 @@ void oqs_sha2_sha256_inc_blocks_c(sha256ctx *state, const uint8_t *in, size_t in
 	if (state->data_len) {
 		tmp_in = OQS_MEM_malloc(tmp_buflen);
 		if (!tmp_in) {
-			return;
+			abort();
 		}
 		memcpy(tmp_in, state->data, state->data_len);
 		memcpy(tmp_in + state->data_len, in, tmp_buflen - state->data_len);
@@ -711,7 +711,7 @@ void oqs_sha2_sha256_inc_finalize_c(uint8_t *out, sha256ctx *state, const uint8_
 	} else { //Combine incremental data with final input
 		tmp_in = OQS_MEM_malloc(tmp_len);
 		if (!tmp_in) {
-			return;
+			abort();
 		}
 		memcpy(tmp_in, state->data, state->data_len);
 		if (in && inlen) {
diff --git a/src/common/sha3/ossl_sha3.c b/src/common/sha3/ossl_sha3.c
index 82b00431d..274219c43 100644
--- a/src/common/sha3/ossl_sha3.c
+++ b/src/common/sha3/ossl_sha3.c
@@ -200,7 +200,7 @@ static void SHA3_shake128_inc_squeeze(uint8_t *output, size_t outlen, OQS_SHA3_s
 	} else {
 		uint8_t *tmp = OQS_MEM_malloc(s->n_out + outlen);
 		if (!tmp) {
-			return;
+			abort();
 		}
 		OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen);
 		memcpy(output, tmp + s->n_out, outlen);
@@ -277,7 +277,7 @@ static void SHA3_shake256_inc_squeeze(uint8_t *output, size_t outlen, OQS_SHA3_s
 	} else {
 		uint8_t *tmp = OQS_MEM_malloc(s->n_out + outlen);
 		if (!tmp) {
-			return;
+			abort();
 		}
 		OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen);
 		memcpy(output, tmp + s->n_out, outlen);
diff --git a/src/common/sha3/ossl_sha3x4.c b/src/common/sha3/ossl_sha3x4.c
index 5d8e45ff7..e8dbd4939 100644
--- a/src/common/sha3/ossl_sha3x4.c
+++ b/src/common/sha3/ossl_sha3x4.c
@@ -83,7 +83,7 @@ static void SHA3_shake128_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *
 	} else {
 		uint8_t *tmp = OQS_MEM_malloc(s->n_out + outlen);
 		if (!tmp) {
-			return;
+			abort();
 		}
 		OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx0);
 		OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen);
@@ -207,7 +207,7 @@ static void SHA3_shake256_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *
 	} else {
 		uint8_t *tmp = OQS_MEM_malloc(s->n_out + outlen);
 		if (!tmp) {
-			return;
+			abort();
 		}
 		OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx0);
 		OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen);
diff --git a/src/common/sha3/xkcp_sha3.c b/src/common/sha3/xkcp_sha3.c
index 400f4191a..4fe0de267 100644
--- a/src/common/sha3/xkcp_sha3.c
+++ b/src/common/sha3/xkcp_sha3.c
@@ -200,9 +200,11 @@ static void SHA3_sha3_256(uint8_t *output, const uint8_t *input, size_t inlen) {
 
 static void SHA3_sha3_256_inc_init(OQS_SHA3_sha3_256_inc_ctx *state) {
 	state->ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES);
-	if (state->ctx != NULL) {
-		keccak_inc_reset((uint64_t *)state->ctx);
+
+	if (state->ctx == NULL) {
+		abort();
 	}
+	keccak_inc_reset((uint64_t *)state->ctx);
 }
 
 static void SHA3_sha3_256_inc_absorb(OQS_SHA3_sha3_256_inc_ctx *state, const uint8_t *input, size_t inlen) {
@@ -238,9 +240,10 @@ static void SHA3_sha3_384(uint8_t *output, const uint8_t *input, size_t inlen) {
 
 static void SHA3_sha3_384_inc_init(OQS_SHA3_sha3_384_inc_ctx *state) {
 	state->ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES);
-	if (state->ctx != NULL) {
-		keccak_inc_reset((uint64_t *)state->ctx);
+	if (state->ctx == NULL) {
+		abort();
 	}
+	keccak_inc_reset((uint64_t *)state->ctx);
 }
 static void SHA3_sha3_384_inc_absorb(OQS_SHA3_sha3_384_inc_ctx *state, const uint8_t *input, size_t inlen) {
 	keccak_inc_absorb((uint64_t *)state->ctx, OQS_SHA3_SHA3_384_RATE, input, inlen);
@@ -275,9 +278,10 @@ static void SHA3_sha3_512(uint8_t *output, const uint8_t *input, size_t inlen) {
 
 static void SHA3_sha3_512_inc_init(OQS_SHA3_sha3_512_inc_ctx *state) {
 	state->ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES);
-	if (state->ctx != NULL) {
-		keccak_inc_reset((uint64_t *)state->ctx);
+	if (state->ctx == NULL) {
+		abort();
 	}
+	keccak_inc_reset((uint64_t *)state->ctx);
 }
 
 static void SHA3_sha3_512_inc_absorb(OQS_SHA3_sha3_512_inc_ctx *state, const uint8_t *input, size_t inlen) {
@@ -320,7 +324,7 @@ static void SHA3_shake128_inc_init(OQS_SHA3_shake128_inc_ctx *state) {
 	}
 	state->ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES);
 	if (state->ctx == NULL) {
-		return;
+		abort();
 	}
 	keccak_inc_reset((uint64_t *)state->ctx);
 }
@@ -364,9 +368,10 @@ static void SHA3_shake256(uint8_t *output, size_t outlen, const uint8_t *input,
 
 static void SHA3_shake256_inc_init(OQS_SHA3_shake256_inc_ctx *state) {
 	state->ctx = OQS_MEM_aligned_alloc(KECCAK_CTX_ALIGNMENT, KECCAK_CTX_BYTES);
-	if (state->ctx != NULL) {
-		keccak_inc_reset((uint64_t *)state->ctx);
+	if (state->ctx == NULL) {
+		abort();
 	}
+	keccak_inc_reset((uint64_t *)state->ctx);
 }
 
 static void SHA3_shake256_inc_absorb(OQS_SHA3_shake256_inc_ctx *state, const uint8_t *input, size_t inlen) {
diff --git a/src/common/sha3/xkcp_sha3x4.c b/src/common/sha3/xkcp_sha3x4.c
index e49324806..6b03f0baa 100644
--- a/src/common/sha3/xkcp_sha3x4.c
+++ b/src/common/sha3/xkcp_sha3x4.c
@@ -169,7 +169,7 @@ static void SHA3_shake128_x4(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_
 static void SHA3_shake128_x4_inc_init(OQS_SHA3_shake128_x4_inc_ctx *state) {
 	state->ctx = OQS_MEM_aligned_alloc(KECCAK_X4_CTX_ALIGNMENT, KECCAK_X4_CTX_BYTES);
 	if (state->ctx == NULL) {
-		return;
+		abort();
 	}
 	keccak_x4_inc_reset((uint64_t *)state->ctx);
 }
@@ -213,7 +213,7 @@ static void SHA3_shake256_x4(uint8_t *out0, uint8_t *out1, uint8_t *out2, uint8_
 static void SHA3_shake256_x4_inc_init(OQS_SHA3_shake256_x4_inc_ctx *state) {
 	state->ctx = OQS_MEM_aligned_alloc(KECCAK_X4_CTX_ALIGNMENT, KECCAK_X4_CTX_BYTES);
 	if (state->ctx == NULL) {
-		return;
+		abort();
 	}
 	keccak_x4_inc_reset((uint64_t *)state->ctx);
 }