Skip to content

Commit

Permalink
Add allocator check in tests/test_code_conventions.py
Browse files Browse the repository at this point in the history
Signed-off-by: Songling Han <[email protected]>
  • Loading branch information
songlingatpan committed Sep 24, 2024
1 parent c4b647e commit f47e341
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 52 deletions.
37 changes: 19 additions & 18 deletions src/common/common.c
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ OQS_API int OQS_MEM_secure_bcmp(const void *a, const void *b, size_t len) {
}

OQS_API void OQS_MEM_cleanse(void *ptr, size_t len) {
if (ptr == NULL) {
return;
}
#if defined(OQS_USE_OPENSSL)
OSSL_FUNC(OPENSSL_cleanse)(ptr, len);
#elif defined(_WIN32)
Expand All @@ -267,20 +270,19 @@ OQS_API void OQS_MEM_cleanse(void *ptr, size_t len) {
explicit_memset(ptr, 0, len);
#elif defined(__STDC_LIB_EXT1__) || defined(OQS_HAVE_MEMSET_S)
if (0U < len && memset_s(ptr, (rsize_t)len, 0, (rsize_t)len) != 0) {
abort();
return; //abort();
}
#else
typedef void *(*memset_t)(void *, int, size_t);
static volatile memset_t memset_func = memset;
memset_func(ptr, 0, len);
#endif
}

void *OQS_MEM_checked_malloc(size_t len) {
void *ptr = OQS_MEM_malloc(len);
if (ptr == NULL) {
fprintf(stderr, "Memory allocation failed\n");
abort();
return NULL; //abort();
}

return ptr;
Expand All @@ -290,7 +292,7 @@ void *OQS_MEM_checked_aligned_alloc(size_t alignment, size_t size) {
void *ptr = OQS_MEM_aligned_alloc(alignment, size);
if (ptr == NULL) {
fprintf(stderr, "Memory allocation failed\n");
abort();
return NULL; //abort();
}

return ptr;
Expand All @@ -299,7 +301,7 @@ void *OQS_MEM_checked_aligned_alloc(size_t alignment, size_t size) {
OQS_API void OQS_MEM_secure_free(void *ptr, size_t len) {
if (ptr != NULL) {
OQS_MEM_cleanse(ptr, len);
OQS_MEM_insecure_free(ptr); // IGNORE free-check
OQS_MEM_insecure_free(ptr);
}
}

Expand Down Expand Up @@ -372,7 +374,7 @@ void *OQS_MEM_aligned_alloc(size_t alignment, size_t size) {
// |
// diff = ptr - buffer
const size_t offset = alignment - 1 + sizeof(uint8_t);
uint8_t *buffer = malloc(size + offset);
uint8_t *buffer = malloc(size + offset); // IGNORE memory-check
if (!buffer) {
return NULL;
}
Expand All @@ -382,7 +384,7 @@ void *OQS_MEM_aligned_alloc(size_t alignment, size_t size) {
ptrdiff_t diff = ptr - buffer;
if (diff > UINT8_MAX) {
// This should never happen in our code, but just to be safe
free(buffer); // IGNORE free-check
free(buffer); // IGNORE memory-check
errno = EINVAL;
return NULL;
}
Expand All @@ -395,24 +397,23 @@ void *OQS_MEM_aligned_alloc(size_t alignment, size_t size) {
}

void OQS_MEM_aligned_free(void *ptr) {
if (ptr == NULL) {
return;
}
#if defined(OQS_USE_OPENSSL)
// Use OpenSSL's free function
if (ptr) {
uint8_t *u8ptr = ptr;
OPENSSL_free(u8ptr - u8ptr[-1]);
}
uint8_t *u8ptr = ptr;
OPENSSL_free(u8ptr - u8ptr[-1]);
#elif defined(OQS_HAVE_ALIGNED_ALLOC) || defined(OQS_HAVE_POSIX_MEMALIGN) || defined(OQS_HAVE_MEMALIGN)
free(ptr); // IGNORE free-check
free(ptr); // IGNORE memory-check
#elif defined(__MINGW32__) || defined(__MINGW64__)
__mingw_aligned_free(ptr);
#elif defined(_MSC_VER)
_aligned_free(ptr);
#else
if (ptr) {
// Reconstruct the pointer returned from malloc using the difference
// stored one byte ahead of ptr.
uint8_t *u8ptr = ptr;
free(u8ptr - u8ptr[-1]); // IGNORE free-check
}
// Reconstruct the pointer returned from malloc using the difference
// stored one byte ahead of ptr.
uint8_t *u8ptr = ptr;
free(u8ptr - u8ptr[-1]); // IGNORE memory-check
#endif
}
8 changes: 4 additions & 4 deletions src/common/sha2/sha2_c.c
Original file line number Diff line number Diff line change
Expand Up @@ -588,22 +588,22 @@ void oqs_sha2_sha512_inc_ctx_clone_c(sha512ctx *stateout, const sha512ctx *state

/* Destroy the hash state. */
void oqs_sha2_sha224_inc_ctx_release_c(sha224ctx *state) {
OQS_MEM_insecure_free(state->ctx); // IGNORE free-check
OQS_MEM_insecure_free(state->ctx);
}

/* Destroy the hash state. */
void oqs_sha2_sha256_inc_ctx_release_c(sha256ctx *state) {
OQS_MEM_insecure_free(state->ctx); // IGNORE free-check
OQS_MEM_insecure_free(state->ctx);
}

/* Destroy the hash state. */
void oqs_sha2_sha384_inc_ctx_release_c(sha384ctx *state) {
OQS_MEM_insecure_free(state->ctx); // IGNORE free-check
OQS_MEM_insecure_free(state->ctx);
}

/* Destroy the hash state. */
void oqs_sha2_sha512_inc_ctx_release_c(sha512ctx *state) {
OQS_MEM_insecure_free(state->ctx); // IGNORE free-check
OQS_MEM_insecure_free(state->ctx);
}

void oqs_sha2_sha256_inc_blocks_c(sha256ctx *state, const uint8_t *in, size_t inblocks) {
Expand Down
8 changes: 4 additions & 4 deletions src/common/sha3/ossl_sha3x4.c
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ static void SHA3_shake128_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *
OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx3);
OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen);
memcpy(out3, tmp + s->n_out, outlen);
OQS_MEM_insecure_free(tmp); // IGNORE free-check
OQS_MEM_insecure_free(tmp);
}
OSSL_FUNC(EVP_MD_CTX_free)(clone);
s->n_out += outlen;
Expand All @@ -117,7 +117,7 @@ static void SHA3_shake128_x4_inc_ctx_release(OQS_SHA3_shake128_x4_inc_ctx *state
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx1);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx2);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx3);
OQS_MEM_insecure_free(s); // IGNORE free-check
OQS_MEM_insecure_free(s);
}

static void SHA3_shake128_x4_inc_ctx_reset(OQS_SHA3_shake128_x4_inc_ctx *state) {
Expand Down Expand Up @@ -215,7 +215,7 @@ static void SHA3_shake256_x4_inc_squeeze(uint8_t *out0, uint8_t *out1, uint8_t *
OSSL_FUNC(EVP_MD_CTX_copy_ex)(clone, s->mdctx3);
OSSL_FUNC(EVP_DigestFinalXOF)(clone, tmp, s->n_out + outlen);
memcpy(out3, tmp + s->n_out, outlen);
OQS_MEM_insecure_free(tmp); // IGNORE free-check
OQS_MEM_insecure_free(tmp);
}
OSSL_FUNC(EVP_MD_CTX_free)(clone);
s->n_out += outlen;
Expand All @@ -238,7 +238,7 @@ static void SHA3_shake256_x4_inc_ctx_release(OQS_SHA3_shake256_x4_inc_ctx *state
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx1);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx2);
OSSL_FUNC(EVP_MD_CTX_free)(s->mdctx3);
OQS_MEM_insecure_free(s); // IGNORE free-check
OQS_MEM_insecure_free(s);
}

static void SHA3_shake256_x4_inc_ctx_reset(OQS_SHA3_shake256_x4_inc_ctx *state) {
Expand Down
10 changes: 5 additions & 5 deletions src/sig_stfl/lms/external/hss_alloc.c
Original file line number Diff line number Diff line change
Expand Up @@ -542,15 +542,15 @@ void hss_free_working_key(struct hss_working_key *w) {
unsigned j, k;
for (j=0; j<MAX_SUBLEVELS; j++)
for (k=0; k<3; k++)
OQS_MEM_insecure_free(tree->subtree[j][k]); // IGNORE free-check
OQS_MEM_insecure_free(tree->subtree[j][k]);
hss_zeroize( tree, sizeof *tree ); /* We have seeds here */
}
OQS_MEM_insecure_free(tree); // IGNORE free-check
OQS_MEM_insecure_free(tree);
}
for (i=0; i<MAX_HSS_LEVELS-1; i++) {
OQS_MEM_insecure_free(w->signed_pk[i]); // IGNORE free-check
OQS_MEM_insecure_free(w->signed_pk[i]);
}
OQS_MEM_insecure_free(w->stack); // IGNORE free-check
OQS_MEM_insecure_free(w->stack);
hss_zeroize( w, sizeof *w ); /* We have secret information here */
OQS_MEM_insecure_free(w); // IGNORE free-check
OQS_MEM_insecure_free(w);
}
4 changes: 2 additions & 2 deletions src/sig_stfl/lms/external/hss_generate.c
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ bool hss_generate_working_key(
#if DO_FLOATING_POINT
/* Don't leak suborders on an intermediate error */
for (i=0; i<(sequence_t)count_order; i++) {
OQS_MEM_insecure_free( order[i].sub ); // IGNORE free-check
OQS_MEM_insecure_free( order[i].sub );
}
#endif
info->error_code = got_error;
Expand Down Expand Up @@ -831,7 +831,7 @@ bool hss_generate_working_key(
hash_size, tree->h, I);
}

OQS_MEM_insecure_free( sub ); // IGNORE free-check
OQS_MEM_insecure_free( sub );
p_order->sub = 0;
}
#endif
Expand Down
4 changes: 2 additions & 2 deletions src/sig_stfl/lms/external/hss_keygen.c
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ bool hss_generate_private_key(
} else {
hss_zeroize( context, PRIVATE_KEY_LEN );
}
OQS_MEM_insecure_free(temp_buffer); // IGNORE free-check
OQS_MEM_insecure_free(temp_buffer);
return false;
}

Expand Down Expand Up @@ -355,7 +355,7 @@ bool hss_generate_private_key(
/* Hey, what do you know -- it all worked! */
hss_zeroize( private_key, sizeof private_key ); /* Zeroize local copy of */
/* the private key */
OQS_MEM_insecure_free(temp_buffer); // IGNORE free-check
OQS_MEM_insecure_free(temp_buffer);
return true;
}
#endif
Expand Down
10 changes: 5 additions & 5 deletions src/sig_stfl/lms/external/hss_thread_pthread.c
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ struct thread_collection *hss_thread_init(int num_thread) {
col->num_thread = num_thread;

if (0 != pthread_mutex_init( &col->lock, 0 )) {
OQS_MEM_insecure_free(col); // IGNORE free-check
OQS_MEM_insecure_free(col);
return 0;
}

if (0 != pthread_mutex_init( &col->write_lock, 0 )) {
pthread_mutex_destroy( &col->lock );
OQS_MEM_insecure_free(col); // IGNORE free-check
OQS_MEM_insecure_free(col);
return 0;
}

Expand Down Expand Up @@ -126,7 +126,7 @@ static void *worker_thread( void *arg ) {
(w->function)(w->x.detail, col);

/* Ok, we did that */
OQS_MEM_insecure_free(w); // IGNORE free-check
OQS_MEM_insecure_free(w);

/* Check if there's anything else to do */
pthread_mutex_lock( &col->lock );
Expand Down Expand Up @@ -219,7 +219,7 @@ void hss_thread_issue_work(struct thread_collection *col,
/* Hmmm, couldn't spawn it; fall back */
default: /* On error condition */
pthread_mutex_unlock( &col->lock );
OQS_MEM_insecure_free(w); // IGNORE free-check
OQS_MEM_insecure_free(w);
function( detail, col );
return;
}
Expand Down Expand Up @@ -277,7 +277,7 @@ void hss_thread_done(struct thread_collection *col) {

pthread_mutex_destroy( &col->lock );
pthread_mutex_destroy( &col->write_lock );
OQS_MEM_insecure_free(col); // IGNORE free-check
OQS_MEM_insecure_free(col);
}

void hss_thread_before_write(struct thread_collection *col) {
Expand Down
36 changes: 24 additions & 12 deletions tests/test_code_conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,38 @@ def test_spdx():
print(result)
assert False

# Ensure "free" is not used unprotected in the main OQS code.
@helpers.filtered_test
@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not needed on Windows")
def test_free():
def test_memory_functions():
c_files = []
for path, _, files in os.walk('src'):
c_files += [os.path.join(path,f) for f in files if f[-2:] == '.c']
c_files += [os.path.join(path, f) for f in files if f.endswith('.c')]

memory_functions = ['free', 'malloc', 'calloc', 'realloc', 'strdup']
okay = True

for fn in c_files:
with open(fn) as f:
# Find all lines that contain 'free(' but not '_free('
for no, line in enumerate(f,1):
if not re.match(r'^.*[^_]free\(.*$', line):
content = f.read()
lines = content.splitlines()
for no, line in enumerate(lines, 1):
# Skip comments
if line.strip().startswith('//') or line.strip().startswith('/*'):
continue
if 'IGNORE free-check' in line:
# Check if we're inside a multi-line comment
if '/*' in content[:content.find(line)] and '*/' not in content[:content.find(line)]:
continue
okay = False
print("Suspicious `free` in {}:{}:{}".format(fn,no,line))
assert okay, "'free' is used in some files. These should be changed to 'OQS_MEM_secure_free' or 'OQS_MEM_insecure_free' as appropriate. If you are sure you want to use 'free' in a particular spot, add the comment '// IGNORE free-check' on the line where 'free' occurs."
for func in memory_functions:
if re.search(r'\b{}\('.format(func), line) and not re.search(r'\b_{}\('.format(func), line):
if 'IGNORE memory-check' in line:
continue
okay = False
print(f"Suspicious `{func}` in {fn}:{no}:{line.strip()}")

assert okay, ("Standard memory functions are used in some files. "
"These should be changed to OQS_MEM_* equivalents as appropriate. "
"If you are sure you want to use these functions in a particular spot, "
"add the comment '// IGNORE memory-check' on the line where the function occurs.")

if __name__ == "__main__":
test_memory_functions()
import sys
pytest.main(sys.argv)

0 comments on commit f47e341

Please sign in to comment.