diff --git a/crates/libs/strings/src/hstring.rs b/crates/libs/strings/src/hstring.rs index b6766b89f2..9f34b77442 100644 --- a/crates/libs/strings/src/hstring.rs +++ b/crates/libs/strings/src/hstring.rs @@ -54,7 +54,7 @@ impl HSTRING { } /// Get the contents of this `HSTRING` as a OsString. - #[cfg(all(feature = "std", windows))] + #[cfg(feature = "std")] pub fn to_os_string(&self) -> std::ffi::OsString { std::os::windows::ffi::OsStringExt::from_wide(self.as_wide()) } @@ -66,7 +66,7 @@ impl HSTRING { return Ok(Self::new()); } - let ptr = HStringHeader::alloc(len.try_into()?)?; + let ptr = HStringHeader::alloc(len.try_into()?, false)?; // Place each utf-16 character into the buffer and // increase len as we go along. @@ -154,14 +154,14 @@ impl From<&String> for HSTRING { } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl From<&std::path::Path> for HSTRING { fn from(value: &std::path::Path) -> Self { value.as_os_str().into() } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl From<&std::ffi::OsStr> for HSTRING { fn from(value: &std::ffi::OsStr) -> Self { unsafe { @@ -174,14 +174,14 @@ impl From<&std::ffi::OsStr> for HSTRING { } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl From for HSTRING { fn from(value: std::ffi::OsString) -> Self { value.as_os_str().into() } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl From<&std::ffi::OsString> for HSTRING { fn from(value: &std::ffi::OsString) -> Self { value.as_os_str().into() @@ -286,28 +286,28 @@ impl PartialEq<&HSTRING> for String { } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl PartialEq for HSTRING { fn eq(&self, other: &std::ffi::OsString) -> bool { *self == **other } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl PartialEq for &HSTRING { fn eq(&self, other: &std::ffi::OsString) -> bool { **self == **other } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl PartialEq<&std::ffi::OsString> for HSTRING { fn eq(&self, other: &&std::ffi::OsString) -> bool { *self == ***other } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl PartialEq for HSTRING { fn eq(&self, other: &std::ffi::OsStr) -> bool { self.as_wide() @@ -317,56 +317,56 @@ impl PartialEq for HSTRING { } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl PartialEq for &HSTRING { fn eq(&self, other: &std::ffi::OsStr) -> bool { **self == *other } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl PartialEq<&std::ffi::OsStr> for HSTRING { fn eq(&self, other: &&std::ffi::OsStr) -> bool { *self == **other } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl PartialEq for std::ffi::OsStr { fn eq(&self, other: &HSTRING) -> bool { *other == *self } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl PartialEq for &std::ffi::OsStr { fn eq(&self, other: &HSTRING) -> bool { *other == **self } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl PartialEq<&HSTRING> for std::ffi::OsStr { fn eq(&self, other: &&HSTRING) -> bool { **other == *self } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl PartialEq for std::ffi::OsString { fn eq(&self, other: &HSTRING) -> bool { *other == **self } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl PartialEq for &std::ffi::OsString { fn eq(&self, other: &HSTRING) -> bool { *other == ***self } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl PartialEq<&HSTRING> for std::ffi::OsString { fn eq(&self, other: &&HSTRING) -> bool { **other == **self @@ -389,14 +389,14 @@ impl TryFrom for String { } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl<'a> From<&'a HSTRING> for std::ffi::OsString { fn from(hstring: &HSTRING) -> Self { hstring.to_os_string() } } -#[cfg(all(feature = "std", windows))] +#[cfg(feature = "std")] impl From for std::ffi::OsString { fn from(hstring: HSTRING) -> Self { Self::from(&hstring) diff --git a/crates/libs/strings/src/hstring_builder.rs b/crates/libs/strings/src/hstring_builder.rs index 446451f0df..45ae2d0943 100644 --- a/crates/libs/strings/src/hstring_builder.rs +++ b/crates/libs/strings/src/hstring_builder.rs @@ -9,7 +9,7 @@ pub struct HStringBuilder(*mut HStringHeader); impl HStringBuilder { /// Creates a preallocated `HSTRING` value. pub fn new(len: usize) -> Result { - Ok(Self(HStringHeader::alloc(len.try_into()?)?)) + Ok(Self(HStringHeader::alloc(len.try_into()?, true)?)) } /// Shortens the string by removing any trailing 0 characters. diff --git a/crates/libs/strings/src/hstring_header.rs b/crates/libs/strings/src/hstring_header.rs index 84c3851519..dd6d047827 100644 --- a/crates/libs/strings/src/hstring_header.rs +++ b/crates/libs/strings/src/hstring_header.rs @@ -14,27 +14,17 @@ pub struct HStringHeader { } impl HStringHeader { - pub fn alloc(len: u32) -> Result<*mut HStringHeader> { + pub fn alloc(len: u32, zero_memory: bool) -> Result<*mut Self> { if len == 0 { return Ok(core::ptr::null_mut()); } // Allocate enough space for header and two bytes per character. // The space for the terminating null character is already accounted for inside of `HStringHeader`. - let bytes = core::mem::size_of::() + 2 * len as usize; + let bytes = core::mem::size_of::() + 2 * len as usize; - #[cfg(windows)] - let header = unsafe { bindings::HeapAlloc(bindings::GetProcessHeap(), 0, bytes) } - as *mut HStringHeader; - - #[cfg(not(windows))] - let header = unsafe { - extern "C" { - fn malloc(bytes: usize) -> *mut core::ffi::c_void; - } - - malloc(bytes) as *mut HStringHeader - }; + let header = + unsafe { bindings::HeapAlloc(bindings::GetProcessHeap(), 0, bytes) } as *mut Self; if header.is_null() { return Err(Error::from_hresult(HRESULT(bindings::E_OUTOFMEMORY))); @@ -42,45 +32,35 @@ impl HStringHeader { unsafe { // Use `ptr::write` (since `header` is unintialized). `HStringHeader` is safe to be all zeros. - header.write(core::mem::MaybeUninit::::zeroed().assume_init()); + header.write(core::mem::MaybeUninit::::zeroed().assume_init()); (*header).len = len; (*header).count = RefCount::new(1); (*header).data = &mut (*header).buffer_start; + + if zero_memory { + core::ptr::write_bytes((*header).data, 0, len as usize); + } } Ok(header) } - pub unsafe fn free(header: *mut HStringHeader) { + pub unsafe fn free(header: *mut Self) { if header.is_null() { return; } - let header = header as *mut _; - - #[cfg(windows)] - { - bindings::HeapFree(bindings::GetProcessHeap(), 0, header); - } - - #[cfg(not(windows))] - { - extern "C" { - fn free(ptr: *mut core::ffi::c_void); - } - - free(header); - } + bindings::HeapFree(bindings::GetProcessHeap(), 0, header as *mut _); } - pub fn duplicate(&self) -> Result<*mut HStringHeader> { + pub fn duplicate(&self) -> Result<*mut Self> { if self.flags & HSTRING_REFERENCE_FLAG == 0 { // If this is not a "fast pass" string then simply increment the reference count. self.count.add_ref(); - Ok(self as *const HStringHeader as *mut HStringHeader) + Ok(self as *const Self as *mut Self) } else { // Otherwise, allocate a new string and copy the value into the new string. - let copy = HStringHeader::alloc(self.len)?; + let copy = Self::alloc(self.len, false)?; // SAFETY: since we are duplicating the string it is safe to copy all data from self to the initialized `copy`. // We copy `len + 1` characters since `len` does not account for the terminating null character. unsafe { diff --git a/crates/tests/strings/tests/hstring.rs b/crates/tests/strings/tests/hstring.rs index 13597f4b90..bb7a35a586 100644 --- a/crates/tests/strings/tests/hstring.rs +++ b/crates/tests/strings/tests/hstring.rs @@ -47,5 +47,9 @@ fn hstring_builder() -> Result<()> { assert_eq!(h.len(), 5); assert_eq!(h.as_wide(), HELLO); + // HStringBuilder will initialize memory to zero. + let b = HStringBuilder::new(5)?; + assert_eq!(*b, [0, 0, 0, 0, 0]); + Ok(()) }