diff --git a/src/sys/ptrace/linux.rs b/src/sys/ptrace/linux.rs index de3124de95..032e69ef6d 100644 --- a/src/sys/ptrace/linux.rs +++ b/src/sys/ptrace/linux.rs @@ -17,8 +17,10 @@ pub type AddressType = *mut ::libc::c_void; target_arch = "x86_64", any(target_env = "gnu", target_env = "musl") ), - all(target_arch = "x86", target_env = "gnu") - ) + all(target_arch = "x86", target_env = "gnu"), + all(target_arch = "aarch64", target_env = "gnu"), + all(target_arch = "riscv64", target_env = "gnu"), + ), ))] use libc::user_regs_struct; @@ -170,6 +172,29 @@ libc_enum! { } } +libc_enum! { + #[cfg(all( + target_os = "linux", + target_env = "gnu", + any( + target_arch = "x86_64", + target_arch = "x86", + target_arch = "aarch64", + target_arch = "riscv64", + ) + ))] + #[repr(i32)] + /// Defining a specific register set, as used in [`getregset`] and [`setregset`]. + #[non_exhaustive] + pub enum RegisterSet { + NT_PRSTATUS, + NT_PRFPREG, + NT_PRPSINFO, + NT_TASKSTRUCT, + NT_AUXV, + } +} + libc_bitflags! { /// Ptrace options used in conjunction with the PTRACE_SETOPTIONS request. /// See `man ptrace` for more details. @@ -231,6 +256,45 @@ pub fn getregs(pid: Pid) -> Result { ptrace_get_data::(Request::PTRACE_GETREGS, pid) } +/// Get user registers, as with `ptrace(PTRACE_GETREGSET, pid, NT_PRSTATUS, ...)` +#[cfg(all( + target_os = "linux", + target_env = "gnu", + any(target_arch = "aarch64", target_arch = "riscv64",) +))] +pub fn getregs(pid: Pid) -> Result { + getregset(pid, RegisterSet::NT_PRSTATUS) +} + +/// Get a particular set of user registers, as with `ptrace(PTRACE_GETREGSET, ...)` +#[cfg(all( + target_os = "linux", + target_env = "gnu", + any( + target_arch = "x86_64", + target_arch = "x86", + target_arch = "aarch64", + target_arch = "riscv64", + ) +))] +pub fn getregset(pid: Pid, set: RegisterSet) -> Result { + let request = Request::PTRACE_GETREGSET; + let mut data = mem::MaybeUninit::::uninit(); + let mut iov = libc::iovec { + iov_base: data.as_mut_ptr().cast(), + iov_len: mem::size_of::(), + }; + unsafe { + ptrace_other( + request, + pid, + set as i32 as AddressType, + (&mut iov as *mut libc::iovec).cast(), + )?; + }; + Ok(unsafe { data.assume_init() }) +} + /// Set user registers, as with `ptrace(PTRACE_SETREGS, ...)` #[cfg(all( target_os = "linux", @@ -248,12 +312,53 @@ pub fn setregs(pid: Pid, regs: user_regs_struct) -> Result<()> { Request::PTRACE_SETREGS as RequestType, libc::pid_t::from(pid), ptr::null_mut::(), - ®s as *const _ as *const c_void, + ®s as *const user_regs_struct as *const c_void, ) }; Errno::result(res).map(drop) } +/// Set user registers, as with `ptrace(PTRACE_SETREGSET, pid, NT_PRSTATUS, ...)` +#[cfg(all( + target_os = "linux", + target_env = "gnu", + any(target_arch = "aarch64", target_arch = "riscv64",) +))] +pub fn setregs(pid: Pid, regs: user_regs_struct) -> Result<()> { + setregset(pid, RegisterSet::NT_PRSTATUS, regs) +} + +/// Set a particular set of user registers, as with `ptrace(PTRACE_SETREGSET, ...)` +#[cfg(all( + target_os = "linux", + target_env = "gnu", + any( + target_arch = "x86_64", + target_arch = "x86", + target_arch = "aarch64", + target_arch = "riscv64", + ) +))] +pub fn setregset( + pid: Pid, + set: RegisterSet, + mut regs: user_regs_struct, +) -> Result<()> { + let mut iov = libc::iovec { + iov_base: (&mut regs as *mut user_regs_struct).cast(), + iov_len: mem::size_of::(), + }; + unsafe { + ptrace_other( + Request::PTRACE_SETREGSET, + pid, + set as i32 as AddressType, + (&mut iov as *mut libc::iovec).cast(), + )?; + } + Ok(()) +} + /// Function for ptrace requests that return values from the data field. /// Some ptrace get requests populate structs or larger elements than `c_long` /// and therefore use the data field to return values. This function handles these diff --git a/test/sys/test_ptrace.rs b/test/sys/test_ptrace.rs index 246b35445d..ff2b0a9501 100644 --- a/test/sys/test_ptrace.rs +++ b/test/sys/test_ptrace.rs @@ -1,7 +1,12 @@ #[cfg(all( target_os = "linux", - any(target_arch = "x86_64", target_arch = "x86"), - target_env = "gnu" + target_env = "gnu", + any( + target_arch = "x86_64", + target_arch = "x86", + target_arch = "aarch64", + target_arch = "riscv64", + ) ))] use memoffset::offset_of; use nix::errno::Errno; @@ -179,8 +184,13 @@ fn test_ptrace_interrupt() { // ptrace::{setoptions, getregs} are only available in these platforms #[cfg(all( target_os = "linux", - any(target_arch = "x86_64", target_arch = "x86"), - target_env = "gnu" + target_env = "gnu", + any( + target_arch = "x86_64", + target_arch = "x86", + target_arch = "aarch64", + target_arch = "riscv64", + ) ))] #[test] fn test_ptrace_syscall() { @@ -226,14 +236,28 @@ fn test_ptrace_syscall() { let get_syscall_id = || ptrace::getregs(child).unwrap().orig_eax as libc::c_long; + #[cfg(target_arch = "aarch64")] + let get_syscall_id = + || ptrace::getregs(child).unwrap().regs[8] as libc::c_long; + + #[cfg(target_arch = "riscv64")] + let get_syscall_id = + || ptrace::getregs(child).unwrap().a7 as libc::c_long; + // this duplicates `get_syscall_id` for the purpose of testing `ptrace::read_user`. #[cfg(target_arch = "x86_64")] let rax_offset = offset_of!(libc::user_regs_struct, orig_rax); #[cfg(target_arch = "x86")] let rax_offset = offset_of!(libc::user_regs_struct, orig_eax); + #[cfg(target_arch = "aarch64")] + let rax_offset = offset_of!(libc::user_regs_struct, regs) + + 8 * mem::size_of::(); + #[cfg(target_arch = "riscv64")] + let rax_offset = offset_of!(libc::user_regs_struct, a7); let get_syscall_from_user_area = || { // Find the offset of `user.regs.rax` (or `user.regs.eax` for x86) + #[cfg(any(target_arch = "x86_64", target_arch = "x86"))] let rax_offset = offset_of!(libc::user, regs) + rax_offset; ptrace::read_user(child, rax_offset as _).unwrap() as libc::c_long @@ -273,3 +297,76 @@ fn test_ptrace_syscall() { } } } + +#[cfg(all( + target_os = "linux", + target_env = "gnu", + any( + target_arch = "x86_64", + target_arch = "x86", + target_arch = "aarch64", + target_arch = "riscv64", + ) +))] +#[test] +fn test_ptrace_regsets() { + use nix::sys::ptrace::{self, getregset, setregset, RegisterSet}; + use nix::sys::signal::*; + use nix::sys::wait::{waitpid, WaitStatus}; + use nix::unistd::fork; + use nix::unistd::ForkResult::*; + + require_capability!("test_ptrace_regsets", CAP_SYS_PTRACE); + + let _m = crate::FORK_MTX.lock(); + + match unsafe { fork() }.expect("Error: Fork Failed") { + Child => { + ptrace::traceme().unwrap(); + // As recommended by ptrace(2), raise SIGTRAP to pause the child + // until the parent is ready to continue + loop { + raise(Signal::SIGTRAP).unwrap(); + } + } + + Parent { child } => { + assert_eq!( + waitpid(child, None), + Ok(WaitStatus::Stopped(child, Signal::SIGTRAP)) + ); + let mut regstruct = + getregset(child, RegisterSet::NT_PRSTATUS).unwrap(); + + #[cfg(target_arch = "x86_64")] + let reg = &mut regstruct.r15; + #[cfg(target_arch = "x86")] + let reg = &mut regstruct.edx; + #[cfg(target_arch = "aarch64")] + let reg = &mut regstruct[16]; + #[cfg(target_arch = "riscv64")] + let reg = &mut regstruct[16]; + + *reg = 0xdeadbeef; + let _ = setregset(child, RegisterSet::NT_PRSTATUS, regstruct); + regstruct = getregset(child, RegisterSet::NT_PRSTATUS).unwrap(); + + #[cfg(target_arch = "x86_64")] + let reg = regstruct.r15; + #[cfg(target_arch = "x86")] + let reg = regstruct.edx; + #[cfg(target_arch = "aarch64")] + let reg = regstruct[16]; + #[cfg(target_arch = "riscv64")] + let reg = regstruct[16]; + assert_eq!(0xdeadbeef, reg); + + ptrace::cont(child, Some(Signal::SIGKILL)).unwrap(); + match waitpid(child, None) { + Ok(WaitStatus::Signaled(pid, Signal::SIGKILL, _)) + if pid == child => {} + _ => panic!("The process should have been killed"), + } + } + } +}