diff --git a/src/alloc/allocext/linux.rs b/src/alloc/allocext/linux.rs index bcc5512..9ff23d3 100644 --- a/src/alloc/allocext/linux.rs +++ b/src/alloc/allocext/linux.rs @@ -15,12 +15,12 @@ mod memfd_secret_alloc { use super::*; use core::convert::TryInto; - /// Allocate memfd_secret with given size and optionally at given address ptr + /// Allocate memfd_secret with given size at given address ptr /// Returns tuple of ptr to memory and file descriptor of memfd_secret #[inline] pub unsafe fn alloc_memfd_secret_at_ptr( size: usize, - ptr: Option<*mut libc::c_void>, + ptr: *mut libc::c_void, ) -> Option<(NonNull, libc::c_int)> { let fd: Result = libc::syscall(libc::SYS_memfd_secret, 0).try_into(); @@ -29,23 +29,25 @@ mod memfd_secret_alloc { // File size is set using ftruncate let _ = libc::ftruncate(fd, size as libc::off_t); - let ptr = libc::mmap( - ptr.unwrap_or_else(ptr::null_mut), + let ptr_out = libc::mmap( + ptr, size, Prot::ReadWrite, - libc::MAP_SHARED - | if ptr.is_some() { libc::MAP_FIXED } else { 0 } - | MAP_LOCKED - | MAP_POPULATE, + libc::MAP_SHARED | libc::MAP_FIXED | MAP_LOCKED | MAP_POPULATE, fd, 0, ); - if ptr == libc::MAP_FAILED { + if ptr_out == libc::MAP_FAILED { return None; } - NonNull::new(ptr as *mut u8).map(|ptr| (ptr, fd)) + if ptr_out != ptr { + libc::munmap(ptr_out, size); + return None; + } + + NonNull::new(ptr_out as *mut u8).map(|ptr| (ptr, fd)) } } @@ -95,7 +97,7 @@ unsafe fn _memfd_secret(size: usize) -> Option<*mut u8> { let unprotected_ptr = base_ptr.add(front_guard_size); let Some((unprotected_ptr_val, fd)) = - alloc_memfd_secret_at_ptr(unprotected_size, Some(unprotected_ptr as *mut libc::c_void)) + alloc_memfd_secret_at_ptr(unprotected_size, unprotected_ptr as *mut libc::c_void) else { libc::munmap(base_ptr_stored as *mut libc::c_void, PAGE_SIZE); libc::munmap(base_ptr as *mut libc::c_void, total_size);