Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat): Add memfd secret based allocation #16

Merged
merged 16 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 144 additions & 24 deletions src/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@

extern crate std;

#[cfg(target_os = "linux")]
use self::memfd_secret_alloc::*;

use self::raw_alloc::*;
use self::std::process::abort;
use self::std::sync::Once;
use core::mem;
use core::ptr::{ self, NonNull };
use core::ptr::{self, NonNull};
use core::slice;
use getrandom::getrandom;
use self::std::sync::Once;
use self::std::process::abort;
use self::raw_alloc::*;


const GARBAGE_VALUE: u8 = 0xd0;
const CANARY_SIZE: usize = 16;
Expand All @@ -20,16 +22,17 @@ static mut PAGE_SIZE: usize = 0;
static mut PAGE_MASK: usize = 0;
static mut CANARY: [u8; CANARY_SIZE] = [0; CANARY_SIZE];


// -- alloc init --

#[inline]
unsafe fn alloc_init() {
#[cfg(unix)] {
#[cfg(unix)]
{
PAGE_SIZE = libc::sysconf(libc::_SC_PAGESIZE) as usize;
}

#[cfg(windows)] {
#[cfg(windows)]
{
let mut si = mem::MaybeUninit::uninit();
windows_sys::Win32::System::SystemInformation::GetSystemInfo(si.as_mut_ptr());
PAGE_SIZE = (*si.as_ptr()).dwPageSize as usize;
Expand All @@ -44,11 +47,10 @@ unsafe fn alloc_init() {
getrandom(&mut CANARY).unwrap();
}


// -- aligned alloc / aligned free --

mod raw_alloc {
use super::std::alloc::{ alloc, dealloc, Layout };
use super::std::alloc::{alloc, dealloc, Layout};
use super::*;

#[inline]
Expand All @@ -64,6 +66,41 @@ mod raw_alloc {
}
}

#[cfg(target_os = "linux")]
mod memfd_secret_alloc {
use core::convert::TryInto;

use super::*;

#[inline]
pub unsafe fn alloc_memfd_secret(size: usize) -> Option<NonNull<u8>> {
let fd: Result<libc::c_int, _> = libc::syscall(libc::SYS_memfd_secret, 0).try_into();

if fd.is_err() || fd.unwrap() < 0 {
return None;
}

let fd = fd.unwrap();
prabhpreet marked this conversation as resolved.
Show resolved Hide resolved

// File size is set using ftruncate
let _ = libc::ftruncate(fd, size as libc::off_t);

let ptr = libc::mmap(
ptr::null_mut(),
size,
Prot::ReadWrite,
libc::MAP_SHARED,
fd,
0,
);

if ptr == libc::MAP_FAILED {
return None;
}

NonNull::new(ptr as *mut u8)
}
}

// -- mprotect --

Expand Down Expand Up @@ -105,7 +142,6 @@ pub mod Prot {
pub const TargetsNoUpdate: Ty = windows_sys::Win32::System::Memory::PAGE_TARGETS_NO_UPDATE;
}


/// Unix `mprotect`.
#[cfg(unix)]
#[inline]
Expand All @@ -121,7 +157,6 @@ pub unsafe fn _mprotect(ptr: *mut u8, len: usize, prot: Prot::Ty) -> bool {
windows_sys::Win32::System::Memory::VirtualProtect(ptr.cast(), len, prot, old.as_mut_ptr()) != 0
}


/// Secure `mprotect`.
#[cfg(any(unix, windows))]
pub unsafe fn mprotect<T: ?Sized>(memptr: NonNull<T>, prot: Prot::Ty) -> bool {
Expand All @@ -133,7 +168,6 @@ pub unsafe fn mprotect<T: ?Sized>(memptr: NonNull<T>, prot: Prot::Ty) -> bool {
_mprotect(unprotected_ptr, unprotected_size, prot)
}


// -- malloc / free --

#[inline]
Expand Down Expand Up @@ -167,7 +201,11 @@ unsafe fn _malloc(size: usize) -> Option<*mut u8> {

// mprotect ptr
_mprotect(base_ptr.add(PAGE_SIZE), PAGE_SIZE, Prot::NoAccess);
_mprotect(unprotected_ptr.add(unprotected_size), PAGE_SIZE, Prot::NoAccess);
_mprotect(
unprotected_ptr.add(unprotected_size),
PAGE_SIZE,
Prot::NoAccess,
);
crate::mlock(unprotected_ptr, unprotected_size);

let canary_ptr = unprotected_ptr.add(unprotected_size - size_with_canary);
Expand All @@ -181,24 +219,76 @@ unsafe fn _malloc(size: usize) -> Option<*mut u8> {
Some(user_ptr)
}

#[cfg(target_os = "linux")]
unsafe fn _memfd_secret(size: usize) -> Option<*mut u8> {
ALLOC_INIT.call_once(|| alloc_init());

if size >= ::core::usize::MAX - PAGE_SIZE * 4 {
return None;
}

// aligned alloc ptr
let size_with_canary = CANARY_SIZE + size;
let unprotected_size = page_round(size_with_canary);
let total_size = PAGE_SIZE + PAGE_SIZE + unprotected_size + PAGE_SIZE;
let base_ptr = alloc_memfd_secret(total_size)?.as_ptr();
let unprotected_ptr = base_ptr.add(PAGE_SIZE * 2);

// mprotect can be used to change protection flag after mmap setup
// https://www.gnu.org/software/libc/manual/html_node/Memory-Protection.html#index-mprotect
_mprotect(base_ptr.add(PAGE_SIZE), PAGE_SIZE, Prot::NoAccess);
_mprotect(
unprotected_ptr.add(unprotected_size),
PAGE_SIZE,
Prot::NoAccess,
);

let canary_ptr = unprotected_ptr.add(unprotected_size - size_with_canary);
let user_ptr = canary_ptr.add(CANARY_SIZE);
ptr::copy_nonoverlapping(CANARY.as_ptr(), canary_ptr, CANARY_SIZE);
ptr::write_unaligned(base_ptr as *mut usize, unprotected_size);
_mprotect(base_ptr, PAGE_SIZE, Prot::ReadOnly);

assert_eq!(unprotected_ptr_from_user_ptr(user_ptr), unprotected_ptr);

Some(user_ptr)
}

/// Secure `malloc`.
#[inline]
pub unsafe fn malloc<T>() -> Option<NonNull<T>> {
_malloc(mem::size_of::<T>())
.map(|memptr| {
ptr::write_bytes(memptr, GARBAGE_VALUE, mem::size_of::<T>());
NonNull::new_unchecked(memptr as *mut T)
})
_malloc(mem::size_of::<T>()).map(|memptr| {
ptr::write_bytes(memptr, GARBAGE_VALUE, mem::size_of::<T>());
NonNull::new_unchecked(memptr as *mut T)
})
}

/// Secure `malloc_sized`.
#[inline]
pub unsafe fn malloc_sized(size: usize) -> Option<NonNull<[u8]>> {
_malloc(size)
.map(|memptr| {
ptr::write_bytes(memptr, GARBAGE_VALUE, size);
NonNull::new_unchecked(slice::from_raw_parts_mut(memptr, size))
})
_malloc(size).map(|memptr| {
ptr::write_bytes(memptr, GARBAGE_VALUE, size);
NonNull::new_unchecked(slice::from_raw_parts_mut(memptr, size))
})
}

#[inline]
#[cfg(target_os = "linux")]
pub unsafe fn memfd_secret<T>() -> Option<NonNull<T>> {
_memfd_secret(mem::size_of::<T>()).map(|memptr| {
ptr::write_bytes(memptr, GARBAGE_VALUE, mem::size_of::<T>());
NonNull::new_unchecked(memptr as *mut T)
})
}

/// Secure `malloc_sized`.
#[inline]
#[cfg(target_os = "linux")]
pub unsafe fn memfd_secret_sized(size: usize) -> Option<NonNull<[u8]>> {
_memfd_secret(size).map(|memptr| {
ptr::write_bytes(memptr, GARBAGE_VALUE, size);
NonNull::new_unchecked(slice::from_raw_parts_mut(memptr, size))
})
}

/// Secure `free`.
Expand All @@ -224,3 +314,33 @@ pub unsafe fn free<T: ?Sized>(memptr: NonNull<T>) {

free_aligned(base_ptr, total_size);
}

/// Secure `free` for memfd_secret,
/// i.e. provides read write access back to mprotect guard pages
/// and unmaps mmap
#[cfg(target_os = "linux")]
pub unsafe fn free_memfd_secret<T: ?Sized>(memptr: NonNull<T>) {
use libc::c_void;

let memptr = memptr.as_ptr() as *mut u8;

// get unprotected ptr
let canary_ptr = memptr.sub(CANARY_SIZE);
let unprotected_ptr = unprotected_ptr_from_user_ptr(memptr);
let base_ptr = unprotected_ptr.sub(PAGE_SIZE * 2);
let unprotected_size = ptr::read(base_ptr as *const usize);

// check
if !crate::memeq(canary_ptr as *const u8, CANARY.as_ptr(), CANARY_SIZE) {
abort();
}

// free
let total_size = PAGE_SIZE + PAGE_SIZE + unprotected_size + PAGE_SIZE;
_mprotect(base_ptr, total_size, Prot::ReadWrite);

let res = libc::munmap(base_ptr as *mut c_void, total_size);
if res < 0 {
abort();
}
}
23 changes: 12 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
#![no_std]

#![cfg_attr(feature = "nightly", feature(core_intrinsics))]
#![allow(clippy::missing_safety_doc)]

mod mlock;
mod alloc;
mod mlock;

use core::ptr;

#[cfg(feature = "use_os")]
pub use mlock::{ mlock, munlock };
pub use mlock::{mlock, munlock};

#[cfg(feature = "alloc")]
pub use alloc::{ Prot, mprotect, malloc, malloc_sized, free };

pub use alloc::{free, malloc, malloc_sized, mprotect, Prot};

#[cfg(feature = "alloc")]
#[cfg(target_os = "linux")]
pub use alloc::{free_memfd_secret, memfd_secret, memfd_secret_sized};
// -- memcmp --

/// Secure `memeq`.
Expand All @@ -26,30 +27,30 @@ pub unsafe fn memeq(b1: *const u8, b2: *const u8, len: usize) -> bool {
.eq(&0)
}


/// Secure `memcmp`.
#[inline(never)]
pub unsafe fn memcmp(b1: *const u8, b2: *const u8, len: usize) -> i32 {
let mut res = 0;
for i in (0..len).rev() {
let diff = i32::from(ptr::read_volatile(b1.add(i)))
- i32::from(ptr::read_volatile(b2.add(i)));
let diff =
i32::from(ptr::read_volatile(b1.add(i))) - i32::from(ptr::read_volatile(b2.add(i)));
res = (res & (((diff - 1) & !diff) >> 8)) | diff;
}
((res - 1) >> 8) + (res >> 8) + 1
}


// -- memset / memzero --

/// General `memset`.
#[inline(never)]
pub unsafe fn memset(s: *mut u8, c: u8, n: usize) {
#[cfg(feature = "nightly")] {
#[cfg(feature = "nightly")]
{
core::intrinsics::volatile_set_memory(s, c, n);
}

#[cfg(not(feature = "nightly"))] {
#[cfg(not(feature = "nightly"))]
{
let s = ptr::read_volatile(&s);
let c = ptr::read_volatile(&c);
let n = ptr::read_volatile(&n);
Expand Down
13 changes: 8 additions & 5 deletions src/mlock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

#![cfg(feature = "use_os")]


/// Cross-platform `mlock`.
///
/// * Unix `mlock`.
/// * Windows `VirtualLock`.
pub unsafe fn mlock(addr: *mut u8, len: usize) -> bool {
#[cfg(unix)] {
#[cfg(unix)]
{
#[cfg(target_os = "linux")]
libc::madvise(addr as *mut libc::c_void, len, libc::MADV_DONTDUMP);

Expand All @@ -18,7 +18,8 @@ pub unsafe fn mlock(addr: *mut u8, len: usize) -> bool {
libc::mlock(addr as *mut libc::c_void, len) == 0
}

#[cfg(windows)] {
#[cfg(windows)]
{
windows_sys::Win32::System::Memory::VirtualLock(addr.cast(), len) != 0
}
}
Expand All @@ -30,7 +31,8 @@ pub unsafe fn mlock(addr: *mut u8, len: usize) -> bool {
pub unsafe fn munlock(addr: *mut u8, len: usize) -> bool {
crate::memzero(addr, len);

#[cfg(unix)] {
#[cfg(unix)]
{
#[cfg(target_os = "linux")]
libc::madvise(addr as *mut libc::c_void, len, libc::MADV_DODUMP);

Expand All @@ -40,7 +42,8 @@ pub unsafe fn munlock(addr: *mut u8, len: usize) -> bool {
libc::munlock(addr as *mut libc::c_void, len) == 0
}

#[cfg(windows)] {
#[cfg(windows)]
{
windows_sys::Win32::System::Memory::VirtualUnlock(addr.cast(), len) != 0
}
}