diff --git a/src/alloc/mod.rs b/src/alloc/mod.rs index 77dd6cd..d161f96 100644 --- a/src/alloc/mod.rs +++ b/src/alloc/mod.rs @@ -15,15 +15,15 @@ use getrandom::getrandom; const GARBAGE_VALUE: u8 = 0xd0; const CANARY_SIZE: usize = 16; -static ALLOC_INIT: Once = Once::new(); -static mut PAGE_SIZE: usize = 0; +pub(crate) static ALLOC_INIT: Once = Once::new(); +pub(crate) static mut PAGE_SIZE: usize = 0; static mut PAGE_MASK: usize = 0; static mut CANARY: [u8; CANARY_SIZE] = [0; CANARY_SIZE]; // -- alloc init -- #[inline] -unsafe fn alloc_init() { +pub(crate) unsafe fn alloc_init() { #[cfg(unix)] { PAGE_SIZE = libc::sysconf(libc::_SC_PAGESIZE) as usize; diff --git a/src/mlock.rs b/src/mlock.rs index f85d171..bb7c6db 100644 --- a/src/mlock.rs +++ b/src/mlock.rs @@ -2,6 +2,14 @@ #![cfg(feature = "use_os")] +use crate::alloc::{alloc_init, ALLOC_INIT, PAGE_SIZE}; + +unsafe fn page_size() -> usize { + ALLOC_INIT.call_once(|| alloc_init()); + + PAGE_SIZE +} + /// Cross-platform `mlock`. /// /// * Unix `mlock`. @@ -9,13 +17,27 @@ pub unsafe fn mlock(addr: *mut u8, len: usize) -> bool { #[cfg(unix)] { + let page_size = page_size(); + + let (start_addr, end_addr) = get_page_aligned_addrs(addr, len, page_size); + + let aligned_len = end_addr - start_addr; + #[cfg(target_os = "linux")] - libc::madvise(addr as *mut libc::c_void, len, libc::MADV_DONTDUMP); + libc::madvise( + start_addr as *mut libc::c_void, + aligned_len, + libc::MADV_DONTDUMP, + ); #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))] - libc::madvise(addr as *mut libc::c_void, len, libc::MADV_NOCORE); + libc::madvise( + start_addr as *mut libc::c_void, + aligned_len, + libc::MADV_NOCORE, + ); - libc::mlock(addr as *mut libc::c_void, len) == 0 + libc::mlock(start_addr as *mut libc::c_void, aligned_len) == 0 } #[cfg(windows)] @@ -33,13 +55,27 @@ pub unsafe fn munlock(addr: *mut u8, len: usize) -> bool { #[cfg(unix)] { + let page_size = page_size(); + + let (start_addr, end_addr) = get_page_aligned_addrs(addr, len, page_size); + + let aligned_len = end_addr - start_addr; + #[cfg(target_os = "linux")] - libc::madvise(addr as *mut libc::c_void, len, libc::MADV_DODUMP); + libc::madvise( + start_addr as *mut libc::c_void, + aligned_len, + libc::MADV_DODUMP, + ); #[cfg(any(target_os = "freebsd", target_os = "dragonfly"))] - libc::madvise(addr as *mut libc::c_void, len, libc::MADV_CORE); + libc::madvise( + start_addr as *mut libc::c_void, + aligned_len, + libc::MADV_CORE, + ); - libc::munlock(addr as *mut libc::c_void, len) == 0 + libc::munlock(start_addr as *mut libc::c_void, aligned_len) == 0 } #[cfg(windows)] @@ -47,3 +83,59 @@ pub unsafe fn munlock(addr: *mut u8, len: usize) -> bool { windows_sys::Win32::System::Memory::VirtualUnlock(addr.cast(), len) != 0 } } + +unsafe fn get_page_aligned_addrs(addr: *mut u8, len: usize, ps: usize) -> (usize, usize) { + // start address of the page obtained from masked the value of the page size + // with the memory address + // + let start_addr = (addr as usize) & !(ps - 1); + + // nearest end address of the overlapping page + let end_addr = ((addr as usize) + len + ps - 1) & !(ps - 1); + (start_addr, end_addr) +} + +#[cfg(test)] +mod tests { + use super::*; + + + #[test] + fn test_get_page_aligned_addrs_exact_page_boundary() { + let addr = 0x1000 as *mut u8; + let len = 0x1000; // 4KB + let page_size = 0x1000; // 4KB page size + + unsafe { + let (start_addr, end_addr) = get_page_aligned_addrs(addr, len, page_size); + assert_eq!(start_addr, 0x1000); + assert_eq!(end_addr, 0x2000); + } + } + + #[test] + fn test_get_page_aligned_addrs_with_offset() { + let addr = 0x1234 as *mut u8; + let len = 0x1000; // 4KB + let test_page_size = 0x1000; // 4KB page size + + unsafe { + let (start_addr, end_addr) = get_page_aligned_addrs(addr, len, test_page_size); + assert_eq!(start_addr, 0x1000); + assert_eq!(end_addr, 0x3000); + } + } + + #[test] + fn test_get_page_aligned_addrs_small_length() { + let addr = 0x2000 as *mut u8; + let len = 0x100; // 256 bytes + let page_size = 0x1000; // 4KB page size + + unsafe { + let (start_addr, end_addr) = get_page_aligned_addrs(addr, len, page_size); + assert_eq!(start_addr, 0x2000); + assert_eq!(end_addr, 0x3000); + } + } +}