Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat): Add memfd secret based allocation #16

Merged
merged 16 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ version = "0.6.3"
authors = ["quininer kel <[email protected]>"]
description = "Rust implementation `libsodium/utils`."
repository = "https://github.com/quininer/memsec"
keywords = [ "protection", "memory", "secure" ]
keywords = ["protection", "memory", "secure"]
documentation = "https://docs.rs/memsec/"
license = "MIT"
categories = [ "no-std", "memory-management" ]
categories = ["no-std", "memory-management"]
edition = "2018"

[badges]
Expand All @@ -22,14 +22,15 @@ libc = { version = "0.2", optional = true }

[target.'cfg(windows)'.dependencies]
windows-sys = { version = "0.45", default-features = false, features = [
"Win32_System_SystemInformation",
"Win32_System_Memory",
"Win32_Foundation",
"Win32_System_Diagnostics_Debug"
"Win32_System_SystemInformation",
"Win32_System_Memory",
"Win32_Foundation",
"Win32_System_Diagnostics_Debug",
], optional = true }

[features]
default = [ "use_os", "alloc" ]
default = ["use_os", "alloc"]
nightly = []
use_os = [ "libc", "windows-sys" ]
alloc = [ "getrandom", "use_os" ]
use_os = ["libc", "windows-sys"]
alloc = ["getrandom", "use_os"]
alloc_ext = ["alloc"]
3 changes: 2 additions & 1 deletion memsec-test/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ libsodium-sys = { version = "0.2" }
nix = "0.26"

[features]
default = [ "alloc", "use_os" ]
default = [ "alloc", "use_os", "alloc_ext"]
nightly = [ "memsec/nightly" ]
use_os = [ "memsec/use_os" ]
alloc = [ "memsec/alloc" ]
alloc_ext = [ "memsec/alloc_ext" ]
prabhpreet marked this conversation as resolved.
Show resolved Hide resolved
86 changes: 86 additions & 0 deletions memsec-test/tests/allocext_linux.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#![cfg(feature = "alloc_ext")]
#![cfg(target_os = "linux")]

use std::ptr::NonNull;


#[test]
fn memfd_secret_u64_test() {
unsafe {
let mut p: NonNull<u64> = memsec::memfd_secret().unwrap();
*p.as_mut() = std::u64::MAX;
assert_eq!(*p.as_ref(), std::u64::MAX);
memsec::free_memfd_secret(p);
}
}

#[test]
fn memfd_secret_free_test() {
unsafe {
let memptr: Option<NonNull<u8>> = memsec::memfd_secret();
assert!(memptr.is_some());
if let Some(memptr) = memptr {
memsec::free_memfd_secret(memptr);
}

let memptr: Option<NonNull<()>> = memsec::memfd_secret();
assert!(memptr.is_some());
if let Some(memptr) = memptr {
memsec::free_memfd_secret(memptr);
}

let memptr: Option<NonNull<[u8]>> = memsec::memfd_secret_sized(1024);
assert!(memptr.is_some());
if let Some(memptr) = memptr {
memsec::free_memfd_secret(memptr);
}

// let memptr: Option<NonNull<[u8; std::usize::MAX - 1]>> = memsec::memfd_secret();
// assert!(memptr.is_none());
}
}

#[test]
fn memfd_secret_mprotect_1_test() {
unsafe {
let mut x: NonNull<[u8; 16]> = memsec::memfd_secret().unwrap();

memsec::memset(x.as_mut().as_mut_ptr(), 0x01, 16);
assert!(memsec::mprotect(x, memsec::Prot::ReadOnly));
assert!(memsec::memeq(x.as_ref().as_ptr(), [1; 16].as_ptr(), 16));
assert!(memsec::mprotect(x, memsec::Prot::NoAccess));
assert!(memsec::mprotect(x, memsec::Prot::ReadWrite));
memsec::memzero(x.as_mut().as_mut_ptr(), 16);
memsec::free_memfd_secret(x);
}

unsafe {
let mut x: NonNull<[u8; 4096]> = memsec::memfd_secret().unwrap();
memsec::memset(x.as_mut().as_mut_ptr(), 0x02, 96);
memsec::free_memfd_secret(x);
}

unsafe {
let mut x: NonNull<[u8; 4100]> = memsec::memfd_secret().unwrap();
memsec::memset(x.as_mut().as_mut_ptr().offset(100), 0x03, 3000);
memsec::free_memfd_secret(x);
}

unsafe {
let mut x = memsec::memfd_secret_sized(16).unwrap();

memsec::memset(x.as_mut().as_mut_ptr(), 0x01, 16);
assert!(memsec::mprotect(x, memsec::Prot::ReadOnly));
assert!(memsec::memeq(x.as_ref().as_ptr(), [1; 16].as_ptr(), 16));
assert!(memsec::mprotect(x, memsec::Prot::NoAccess));
assert!(memsec::mprotect(x, memsec::Prot::ReadWrite));
memsec::memzero(x.as_mut().as_mut_ptr(), 16);
memsec::free_memfd_secret(x);
}

unsafe {
let mut x = memsec::memfd_secret_sized(4100).unwrap();
memsec::memset(x.as_mut().as_mut_ptr().offset(100), 0x03, 3000);
memsec::free_memfd_secret(x);
}
}
138 changes: 138 additions & 0 deletions src/alloc/allocext.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
//! allocext
//! OS Specific allocation
//!
//!

#![cfg(feature = "alloc_ext")]
extern crate std;
use self::std::process::abort;
use crate::{alloc::*, Prot };
use core::mem;
use core::ptr::{self, NonNull};
use core::slice;

#[cfg(target_os = "linux")]
use self::memfd_secret_alloc::*;

#[cfg(target_os = "linux")]
mod memfd_secret_alloc {
use core::convert::TryInto;

use super::*;

#[inline]
pub unsafe fn alloc_memfd_secret(size: usize) -> Option<NonNull<u8>> {
let fd: Result<libc::c_int, _> = libc::syscall(libc::SYS_memfd_secret, 0).try_into();

if fd.is_err() || fd.unwrap() < 0 {
return None;
}

let Ok(fd) = fd else {
return None;
};

// File size is set using ftruncate
let _ = libc::ftruncate(fd, size as libc::off_t);

let ptr = libc::mmap(
ptr::null_mut(),
size,
Prot::ReadWrite,
libc::MAP_SHARED,
fd,
0,
);

if ptr == libc::MAP_FAILED {
return None;
}

NonNull::new(ptr as *mut u8)
}
}

#[cfg(target_os = "linux")]
unsafe fn _memfd_secret(size: usize) -> Option<*mut u8> {
ALLOC_INIT.call_once(|| alloc_init());

if size >= ::core::usize::MAX - PAGE_SIZE * 4 {
return None;
}

// aligned alloc ptr
let size_with_canary = CANARY_SIZE + size;
let unprotected_size = page_round(size_with_canary);
let total_size = PAGE_SIZE + PAGE_SIZE + unprotected_size + PAGE_SIZE;
let base_ptr = alloc_memfd_secret(total_size)?.as_ptr();
let unprotected_ptr = base_ptr.add(PAGE_SIZE * 2);

// mprotect can be used to change protection flag after mmap setup
// https://www.gnu.org/software/libc/manual/html_node/Memory-Protection.html#index-mprotect
_mprotect(base_ptr.add(PAGE_SIZE), PAGE_SIZE, Prot::NoAccess);
_mprotect(
unprotected_ptr.add(unprotected_size),
PAGE_SIZE,
Prot::NoAccess,
);

let canary_ptr = unprotected_ptr.add(unprotected_size - size_with_canary);
let user_ptr = canary_ptr.add(CANARY_SIZE);
ptr::copy_nonoverlapping(CANARY.as_ptr(), canary_ptr, CANARY_SIZE);
ptr::write_unaligned(base_ptr as *mut usize, unprotected_size);
_mprotect(base_ptr, PAGE_SIZE, Prot::ReadOnly);

assert_eq!(unprotected_ptr_from_user_ptr(user_ptr), unprotected_ptr);

Some(user_ptr)
}

/// Linux specific `memfd_secret` backed allocation
#[inline]
#[cfg(target_os = "linux")]
pub unsafe fn memfd_secret<T>() -> Option<NonNull<T>> {
_memfd_secret(mem::size_of::<T>()).map(|memptr| {
ptr::write_bytes(memptr, GARBAGE_VALUE, mem::size_of::<T>());
NonNull::new_unchecked(memptr as *mut T)
})
}

/// Linux specific `memfd_secret` backed `sized` allocation
#[inline]
#[cfg(target_os = "linux")]
pub unsafe fn memfd_secret_sized(size: usize) -> Option<NonNull<[u8]>> {
_memfd_secret(size).map(|memptr| {
ptr::write_bytes(memptr, GARBAGE_VALUE, size);
NonNull::new_unchecked(slice::from_raw_parts_mut(memptr, size))
})
}

/// Secure `free` for memfd_secret allocations,
/// i.e. provides read write access back to mprotect guard pages
/// and unmaps mmaped secrets
#[cfg(target_os = "linux")]
prabhpreet marked this conversation as resolved.
Show resolved Hide resolved
pub unsafe fn free_memfd_secret<T: ?Sized>(memptr: NonNull<T>) {
use libc::c_void;

let memptr = memptr.as_ptr() as *mut u8;

// get unprotected ptr
let canary_ptr = memptr.sub(CANARY_SIZE);
let unprotected_ptr = unprotected_ptr_from_user_ptr(memptr);
let base_ptr = unprotected_ptr.sub(PAGE_SIZE * 2);
let unprotected_size = ptr::read(base_ptr as *const usize);

// check
if !crate::memeq(canary_ptr as *const u8, CANARY.as_ptr(), CANARY_SIZE) {
abort();
}

// free
let total_size = PAGE_SIZE + PAGE_SIZE + unprotected_size + PAGE_SIZE;
_mprotect(base_ptr, total_size, Prot::ReadWrite);

let res = libc::munmap(base_ptr as *mut c_void, total_size);
if res < 0 {
abort();
}
}
Loading