diff --git a/core/src/file.rs b/core/src/file.rs index 3c2af45..8b1f8ec 100644 --- a/core/src/file.rs +++ b/core/src/file.rs @@ -1,8 +1,7 @@ //! Implementation of in-memory files -use std::io::{Read, Seek, SeekFrom, Write}; use std::marker::PhantomData; use std::mem::size_of; -use std::os::unix::prelude::{AsRawFd, IntoRawFd, RawFd}; +use std::os::unix::prelude::{AsRawFd, FileExt, IntoRawFd, RawFd}; use anyhow::{anyhow, Result}; use memfd::{FileSeal, Memfd, MemfdOptions}; @@ -79,21 +78,21 @@ impl TempFile { // TODO: Use an approach without unsafe let bytes = unsafe { std::slice::from_raw_parts(value as *const T as *const u8, size_of::()) }; - let mut file = self.get_memfd()?.into_file(); - file.seek(SeekFrom::Start(0)).typ(SystemError::Panic)?; - file.write_all(bytes) + let file = self.get_memfd()?.into_file(); + file.write_all_at(bytes, 0) .map_err(anyhow::Error::from) .typ(SystemError::Panic) } /// Returns all of the TempFile's data pub fn read(&self) -> TypedResult { - let mut buf = Vec::with_capacity(size_of::()); - let mut file = self.get_memfd()?.into_file(); - file.seek(SeekFrom::Start(0)).typ(SystemError::Panic)?; - file.read_to_end(buf.as_mut()).typ(SystemError::Panic)?; + let mut buf = vec![0u8; size_of::()]; + let buf = buf.as_mut_slice(); + let file = self.get_memfd()?.into_file(); + file.read_at(buf, 0).typ(SystemError::Panic)?; // TODO: Use an approach without unsafe - Ok(unsafe { buf.as_slice().align_to::().1[0].clone() }) + let aligned = unsafe { buf.align_to::() }; + Ok(aligned.1[0].clone()) } /// Returns a mutable memory map from a TempFile diff --git a/partition/src/apex.rs b/partition/src/apex.rs index 53cb5b3..679b31a 100644 --- a/partition/src/apex.rs +++ b/partition/src/apex.rs @@ -152,11 +152,12 @@ impl ApexSamplingPortP4 for ApexLinuxPartition { sampling_port_id: SamplingPortId, message: &mut [ApexByte], ) -> Result<(Validity, MessageSize), ErrorReturnCode> { - if let Some((port, val)) = SAMPLING_PORTS - .read() - .unwrap() - .get(sampling_port_id as usize - 1) - { + let read = if let Ok(read) = SAMPLING_PORTS.read() { + read + } else { + return Err(ErrorReturnCode::NotAvailable); + }; + if let Some((port, val)) = read.get(sampling_port_id as usize - 1) { if let Some(port) = CONSTANTS.sampling.get(*port) { if message.is_empty() { return Err(ErrorReturnCode::InvalidParam);