Skip to content

Commit

Permalink
zero
Browse files Browse the repository at this point in the history
  • Loading branch information
kennykerr committed Jul 3, 2024
1 parent 9f96662 commit 44ebe7e
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 55 deletions.
40 changes: 20 additions & 20 deletions crates/libs/strings/src/hstring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ impl HSTRING {
}

/// Get the contents of this `HSTRING` as a OsString.
#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
pub fn to_os_string(&self) -> std::ffi::OsString {
std::os::windows::ffi::OsStringExt::from_wide(self.as_wide())
}
Expand All @@ -66,7 +66,7 @@ impl HSTRING {
return Ok(Self::new());
}

let ptr = HStringHeader::alloc(len.try_into()?)?;
let ptr = HStringHeader::alloc(len.try_into()?, false)?;

// Place each utf-16 character into the buffer and
// increase len as we go along.
Expand Down Expand Up @@ -154,14 +154,14 @@ impl From<&String> for HSTRING {
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl From<&std::path::Path> for HSTRING {
fn from(value: &std::path::Path) -> Self {
value.as_os_str().into()
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl From<&std::ffi::OsStr> for HSTRING {
fn from(value: &std::ffi::OsStr) -> Self {
unsafe {
Expand All @@ -174,14 +174,14 @@ impl From<&std::ffi::OsStr> for HSTRING {
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl From<std::ffi::OsString> for HSTRING {
fn from(value: std::ffi::OsString) -> Self {
value.as_os_str().into()
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl From<&std::ffi::OsString> for HSTRING {
fn from(value: &std::ffi::OsString) -> Self {
value.as_os_str().into()
Expand Down Expand Up @@ -286,28 +286,28 @@ impl PartialEq<&HSTRING> for String {
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl PartialEq<std::ffi::OsString> for HSTRING {
fn eq(&self, other: &std::ffi::OsString) -> bool {
*self == **other
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl PartialEq<std::ffi::OsString> for &HSTRING {
fn eq(&self, other: &std::ffi::OsString) -> bool {
**self == **other
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl PartialEq<&std::ffi::OsString> for HSTRING {
fn eq(&self, other: &&std::ffi::OsString) -> bool {
*self == ***other
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl PartialEq<std::ffi::OsStr> for HSTRING {
fn eq(&self, other: &std::ffi::OsStr) -> bool {
self.as_wide()
Expand All @@ -317,56 +317,56 @@ impl PartialEq<std::ffi::OsStr> for HSTRING {
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl PartialEq<std::ffi::OsStr> for &HSTRING {
fn eq(&self, other: &std::ffi::OsStr) -> bool {
**self == *other
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl PartialEq<&std::ffi::OsStr> for HSTRING {
fn eq(&self, other: &&std::ffi::OsStr) -> bool {
*self == **other
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl PartialEq<HSTRING> for std::ffi::OsStr {
fn eq(&self, other: &HSTRING) -> bool {
*other == *self
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl PartialEq<HSTRING> for &std::ffi::OsStr {
fn eq(&self, other: &HSTRING) -> bool {
*other == **self
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl PartialEq<&HSTRING> for std::ffi::OsStr {
fn eq(&self, other: &&HSTRING) -> bool {
**other == *self
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl PartialEq<HSTRING> for std::ffi::OsString {
fn eq(&self, other: &HSTRING) -> bool {
*other == **self
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl PartialEq<HSTRING> for &std::ffi::OsString {
fn eq(&self, other: &HSTRING) -> bool {
*other == ***self
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl PartialEq<&HSTRING> for std::ffi::OsString {
fn eq(&self, other: &&HSTRING) -> bool {
**other == **self
Expand All @@ -389,14 +389,14 @@ impl TryFrom<HSTRING> for String {
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl<'a> From<&'a HSTRING> for std::ffi::OsString {
fn from(hstring: &HSTRING) -> Self {
hstring.to_os_string()
}
}

#[cfg(all(feature = "std", windows))]
#[cfg(feature = "std")]
impl From<HSTRING> for std::ffi::OsString {
fn from(hstring: HSTRING) -> Self {
Self::from(&hstring)
Expand Down
2 changes: 1 addition & 1 deletion crates/libs/strings/src/hstring_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub struct HStringBuilder(*mut HStringHeader);
impl HStringBuilder {
/// Creates a preallocated `HSTRING` value.
pub fn new(len: usize) -> Result<Self> {
Ok(Self(HStringHeader::alloc(len.try_into()?)?))
Ok(Self(HStringHeader::alloc(len.try_into()?, true)?))
}

/// Shortens the string by removing any trailing 0 characters.
Expand Down
48 changes: 14 additions & 34 deletions crates/libs/strings/src/hstring_header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,73 +14,53 @@ pub struct HStringHeader {
}

impl HStringHeader {
pub fn alloc(len: u32) -> Result<*mut HStringHeader> {
pub fn alloc(len: u32, zero_memory: bool) -> Result<*mut Self> {
if len == 0 {
return Ok(core::ptr::null_mut());
}

// Allocate enough space for header and two bytes per character.
// The space for the terminating null character is already accounted for inside of `HStringHeader`.
let bytes = core::mem::size_of::<HStringHeader>() + 2 * len as usize;
let bytes = core::mem::size_of::<Self>() + 2 * len as usize;

#[cfg(windows)]
let header = unsafe { bindings::HeapAlloc(bindings::GetProcessHeap(), 0, bytes) }
as *mut HStringHeader;

#[cfg(not(windows))]
let header = unsafe {
extern "C" {
fn malloc(bytes: usize) -> *mut core::ffi::c_void;
}

malloc(bytes) as *mut HStringHeader
};
let header =
unsafe { bindings::HeapAlloc(bindings::GetProcessHeap(), 0, bytes) } as *mut Self;

if header.is_null() {
return Err(Error::from_hresult(HRESULT(bindings::E_OUTOFMEMORY)));
}

unsafe {
// Use `ptr::write` (since `header` is unintialized). `HStringHeader` is safe to be all zeros.
header.write(core::mem::MaybeUninit::<HStringHeader>::zeroed().assume_init());
header.write(core::mem::MaybeUninit::<Self>::zeroed().assume_init());
(*header).len = len;
(*header).count = RefCount::new(1);
(*header).data = &mut (*header).buffer_start;

if zero_memory {
core::ptr::write_bytes((*header).data, 0, len as usize);
}
}

Ok(header)
}

pub unsafe fn free(header: *mut HStringHeader) {
pub unsafe fn free(header: *mut Self) {
if header.is_null() {
return;
}

let header = header as *mut _;

#[cfg(windows)]
{
bindings::HeapFree(bindings::GetProcessHeap(), 0, header);
}

#[cfg(not(windows))]
{
extern "C" {
fn free(ptr: *mut core::ffi::c_void);
}

free(header);
}
bindings::HeapFree(bindings::GetProcessHeap(), 0, header as *mut _);
}

pub fn duplicate(&self) -> Result<*mut HStringHeader> {
pub fn duplicate(&self) -> Result<*mut Self> {
if self.flags & HSTRING_REFERENCE_FLAG == 0 {
// If this is not a "fast pass" string then simply increment the reference count.
self.count.add_ref();
Ok(self as *const HStringHeader as *mut HStringHeader)
Ok(self as *const Self as *mut Self)
} else {
// Otherwise, allocate a new string and copy the value into the new string.
let copy = HStringHeader::alloc(self.len)?;
let copy = Self::alloc(self.len, false)?;
// SAFETY: since we are duplicating the string it is safe to copy all data from self to the initialized `copy`.
// We copy `len + 1` characters since `len` does not account for the terminating null character.
unsafe {
Expand Down
4 changes: 4 additions & 0 deletions crates/tests/strings/tests/hstring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,9 @@ fn hstring_builder() -> Result<()> {
assert_eq!(h.len(), 5);
assert_eq!(h.as_wide(), HELLO);

// HStringBuilder will initialize memory to zero.
let b = HStringBuilder::new(5)?;
assert_eq!(*b, [0, 0, 0, 0, 0]);

Ok(())
}

0 comments on commit 44ebe7e

Please sign in to comment.