diff --git a/crates/libs/registry/src/data.rs b/crates/libs/registry/src/data.rs index fa1660f391..6dd4cae3c6 100644 --- a/crates/libs/registry/src/data.rs +++ b/crates/libs/registry/src/data.rs @@ -20,19 +20,12 @@ impl Data { } } - // Returns the buffer as a slice of u16 for reading wide characters. The slice trims off any trailing zero bytes. + // Returns the buffer as a slice of u16 for reading wide characters. pub fn as_wide(&self) -> &[u16] { if self.ptr.is_null() { &[] } else { - let mut wide = - unsafe { core::slice::from_raw_parts(self.ptr as *const u16, self.len / 2) }; - - while wide.last() == Some(&0) { - wide = &wide[..wide.len() - 1]; - } - - wide + unsafe { core::slice::from_raw_parts(self.ptr as *const u16, self.len / 2) } } } diff --git a/crates/libs/registry/src/key.rs b/crates/libs/registry/src/key.rs index f6b5191951..9261b4ecca 100644 --- a/crates/libs/registry/src/key.rs +++ b/crates/libs/registry/src/key.rs @@ -101,7 +101,7 @@ impl Key { name: T, value: &windows_strings::HSTRING, ) -> Result<()> { - self.set_bytes(name, Type::String, value.as_bytes()) + self.set_bytes(name, Type::String, as_bytes(value)) } /// Sets the name and value in the registry key. @@ -115,7 +115,7 @@ impl Key { name: T, value: &windows_strings::HSTRING, ) -> Result<()> { - self.set_bytes(name, Type::ExpandString, value.as_bytes()) + self.set_bytes(name, Type::ExpandString, as_bytes(value)) } /// Sets the name and value in the registry key. @@ -193,12 +193,24 @@ impl Key { /// # Safety /// /// The `PCWSTR` pointer needs to be valid for reads up until and including the next `\0`. + #[track_caller] pub unsafe fn raw_set_bytes>( &self, name: N, ty: Type, value: &[u8], ) -> Result<()> { + if cfg!(debug_assertions) { + // RegSetValueExW expects string data to be null terminated. + if matches!(ty, Type::String | Type::ExpandString | Type::MultiString) { + debug_assert!( + value.get(value.len() - 2) == Some(&0), + "`value` isn't null-terminated" + ); + debug_assert!(value.last() == Some(&0), "`value` isn't null-terminated"); + } + } + let result = RegSetValueExW( self.0, name.as_ref().as_ptr(), diff --git a/crates/libs/registry/src/lib.rs b/crates/libs/registry/src/lib.rs index 6d04f58022..44ba2f4982 100644 --- a/crates/libs/registry/src/lib.rs +++ b/crates/libs/registry/src/lib.rs @@ -76,3 +76,8 @@ fn from_le_bytes(ty: Type, from: &[u8]) -> Result { _ => Err(invalid_data()), } } + +// Get the string as 8-bit bytes including the two terminating null bytes. +fn as_bytes(value: &HSTRING) -> &[u8] { + unsafe { core::slice::from_raw_parts(value.as_ptr() as *const _, (value.len() + 1) * 2) } +} diff --git a/crates/libs/registry/src/pcwstr.rs b/crates/libs/registry/src/pcwstr.rs index 10fb656e3f..cfdf75e70a 100644 --- a/crates/libs/registry/src/pcwstr.rs +++ b/crates/libs/registry/src/pcwstr.rs @@ -31,6 +31,7 @@ impl OwnedPcwstr { self.0.as_ptr() } + // Get the string as 8-bit bytes including the two terminating null bytes. pub fn as_bytes(&self) -> &[u8] { unsafe { core::slice::from_raw_parts(self.as_ptr() as *const _, self.0.len() * 2) } } diff --git a/crates/libs/registry/src/value.rs b/crates/libs/registry/src/value.rs index 8d200358c7..753e9d416c 100644 --- a/crates/libs/registry/src/value.rs +++ b/crates/libs/registry/src/value.rs @@ -17,6 +17,11 @@ impl Value { pub fn set_ty(&mut self, ty: Type) { self.ty = ty; } + + /// Gets the value as a slice of u16 for raw wide characters. + pub fn as_wide(&self) -> &[u16] { + self.data.as_wide() + } } impl core::ops::Deref for Value { @@ -71,7 +76,7 @@ impl TryFrom for String { type Error = Error; fn try_from(from: Value) -> Result { match from.ty { - Type::String | Type::ExpandString => Ok(Self::from_utf16(from.data.as_wide())?), + Type::String | Type::ExpandString => Ok(Self::from_utf16(trim(from.data.as_wide()))?), _ => Err(invalid_data()), } } @@ -102,6 +107,26 @@ impl TryFrom for Vec { } } +impl TryFrom for HSTRING { + type Error = Error; + fn try_from(from: Value) -> Result { + match from.ty { + Type::String | Type::ExpandString => Ok(Self::from_wide(trim(from.data.as_wide()))?), + _ => Err(invalid_data()), + } + } +} + +impl TryFrom<&HSTRING> for Value { + type Error = Error; + fn try_from(from: &HSTRING) -> Result { + Ok(Self { + data: Data::from_slice(as_bytes(from))?, + ty: Type::String, + }) + } +} + impl TryFrom<&[u8]> for Value { type Error = Error; fn try_from(from: &[u8]) -> Result { @@ -121,3 +146,11 @@ impl TryFrom<[u8; N]> for Value { }) } } + +fn trim(mut wide: &[u16]) -> &[u16] { + while wide.last() == Some(&0) { + wide = &wide[..wide.len() - 1]; + } + + wide +} diff --git a/crates/libs/strings/src/hstring.rs b/crates/libs/strings/src/hstring.rs index c6511ef27e..6cd0a1ebc4 100644 --- a/crates/libs/strings/src/hstring.rs +++ b/crates/libs/strings/src/hstring.rs @@ -33,11 +33,6 @@ impl HSTRING { unsafe { core::slice::from_raw_parts(self.as_ptr(), self.len()) } } - /// Get the string as 8-bit bytes. - pub fn as_bytes(&self) -> &[u8] { - unsafe { core::slice::from_raw_parts(self.as_ptr() as *const _, self.len() * 2) } - } - /// Returns a raw pointer to the `HSTRING` buffer. pub fn as_ptr(&self) -> *const u16 { if let Some(header) = self.as_header() { diff --git a/crates/tests/registry/tests/bytes.rs b/crates/tests/registry/tests/bytes.rs index dd10c4b3d4..fdbd0a3891 100644 --- a/crates/tests/registry/tests/bytes.rs +++ b/crates/tests/registry/tests/bytes.rs @@ -1,4 +1,5 @@ use windows_registry::*; +use windows_strings::*; #[test] fn bytes() -> Result<()> { @@ -20,5 +21,10 @@ fn bytes() -> Result<()> { assert_eq!(value.ty(), Type::Other(1234)); assert_eq!(*value, [1, 2, 3, 4]); + assert_eq!( + unsafe { key.raw_get_info(w!("other"))? }, + (Type::Other(1234), 4) + ); + Ok(()) } diff --git a/crates/tests/registry/tests/hstring.rs b/crates/tests/registry/tests/hstring.rs index 13816f5652..19d51d50e7 100644 --- a/crates/tests/registry/tests/hstring.rs +++ b/crates/tests/registry/tests/hstring.rs @@ -1,5 +1,5 @@ use windows_registry::*; -use windows_strings::h; +use windows_strings::*; #[test] fn hstring() -> Result<()> { @@ -9,6 +9,10 @@ fn hstring() -> Result<()> { key.set_hstring("hstring", h!("simple"))?; assert_eq!(&key.get_hstring("hstring")?, h!("simple")); + assert_eq!( + unsafe { key.raw_get_info(w!("hstring"))? }, + (Type::String, 14) + ); // You can embed nulls. key.set_hstring("hstring", h!("hstring\0value\0"))?; diff --git a/crates/tests/registry/tests/u32.rs b/crates/tests/registry/tests/u32.rs index b09186389e..207102991d 100644 --- a/crates/tests/registry/tests/u32.rs +++ b/crates/tests/registry/tests/u32.rs @@ -1,4 +1,5 @@ use windows_registry::*; +use windows_strings::*; #[test] fn u32() -> Result<()> { @@ -11,5 +12,7 @@ fn u32() -> Result<()> { assert_eq!(key.get_u32("u32")?, 123u32); assert_eq!(key.get_u64("u32")?, 123u64); + assert_eq!(unsafe { key.raw_get_info(w!("u32"))? }, (Type::U32, 4)); + Ok(()) } diff --git a/crates/tests/registry/tests/u64.rs b/crates/tests/registry/tests/u64.rs index c096b7aa00..6e6074c201 100644 --- a/crates/tests/registry/tests/u64.rs +++ b/crates/tests/registry/tests/u64.rs @@ -1,4 +1,5 @@ use windows_registry::*; +use windows_strings::*; #[test] fn u64() -> Result<()> { @@ -11,5 +12,7 @@ fn u64() -> Result<()> { assert_eq!(key.get_u32("u64")?, 123u32); assert_eq!(key.get_u64("u64")?, 123u64); + assert_eq!(unsafe { key.raw_get_info(w!("u64"))? }, (Type::U64, 8)); + Ok(()) } diff --git a/crates/tests/registry/tests/value.rs b/crates/tests/registry/tests/value.rs index afb7456238..deb0929621 100644 --- a/crates/tests/registry/tests/value.rs +++ b/crates/tests/registry/tests/value.rs @@ -1,4 +1,5 @@ use windows_registry::*; +use windows_strings::*; #[test] fn value() -> Result<()> { @@ -11,17 +12,29 @@ fn value() -> Result<()> { assert_eq!(key.get_value("u32")?, Value::try_from(123u32)?); assert_eq!(key.get_u32("u32")?, 123u32); assert_eq!(key.get_u64("u32")?, 123u64); + assert_eq!(u32::try_from(key.get_value("u32")?)?, 123u32); + + assert_eq!(unsafe { key.raw_get_info(w!("u32"))? }, (Type::U32, 4)); key.set_value("u64", &Value::try_from(123u64)?)?; assert_eq!(key.get_type("u64")?, Type::U64); assert_eq!(key.get_value("u64")?, Value::try_from(123u64)?); assert_eq!(key.get_u32("u64")?, 123u32); assert_eq!(key.get_u64("u64")?, 123u64); + assert_eq!(u64::try_from(key.get_value("u64")?)?, 123u64); + + assert_eq!(unsafe { key.raw_get_info(w!("u64"))? }, (Type::U64, 8)); key.set_value("string", &Value::try_from("string")?)?; assert_eq!(key.get_type("string")?, Type::String); assert_eq!(key.get_value("string")?, Value::try_from("string")?); assert_eq!(key.get_string("string")?, "string"); + assert_eq!(String::try_from(key.get_value("string")?)?, "string"); + + assert_eq!( + unsafe { key.raw_get_info(w!("string"))? }, + (Type::String, 14) + ); let mut value = Value::try_from("expand")?; value.set_ty(Type::ExpandString); @@ -31,15 +44,43 @@ fn value() -> Result<()> { assert_eq!(key.get_value("expand")?, value); assert_eq!(key.get_string("expand")?, "expand"); + assert_eq!( + unsafe { key.raw_get_info(w!("expand"))? }, + (Type::ExpandString, 14) + ); + key.set_value("bytes", &Value::try_from([1u8, 2u8, 3u8])?)?; assert_eq!(key.get_type("bytes")?, Type::Bytes); assert_eq!(key.get_value("bytes")?, Value::try_from([1, 2, 3])?); + assert_eq!(unsafe { key.raw_get_info(w!("bytes"))? }, (Type::Bytes, 3)); + let mut value = Value::try_from([1u8, 2u8, 3u8, 4u8].as_slice())?; value.set_ty(Type::Other(1234)); key.set_value("slice", &value)?; assert_eq!(key.get_type("slice")?, Type::Other(1234)); assert_eq!(key.get_value("slice")?, value); + assert_eq!( + unsafe { key.raw_get_info(w!("slice"))? }, + (Type::Other(1234), 4) + ); + + key.set_value("hstring", &Value::try_from(h!("HSTRING"))?)?; + assert_eq!(key.get_type("hstring")?, Type::String); + assert_eq!(key.get_value("hstring")?, Value::try_from(h!("HSTRING"))?); + assert_eq!(key.get_string("hstring")?, "HSTRING"); + assert_eq!(HSTRING::try_from(key.get_value("hstring")?)?, "HSTRING"); + + assert_eq!( + unsafe { key.raw_get_info(w!("hstring"))? }, + (Type::String, 16) + ); + + let abc = Value::try_from("abc")?; + assert_eq!(abc.as_wide(), &[97, 98, 99, 0]); + let abc = Value::try_from(h!("abcd"))?; + assert_eq!(abc.as_wide(), &[97, 98, 99, 100, 0]); + Ok(()) } diff --git a/crates/tests/strings/tests/hstring.rs b/crates/tests/strings/tests/hstring.rs index bb7a35a586..23777ad9e0 100644 --- a/crates/tests/strings/tests/hstring.rs +++ b/crates/tests/strings/tests/hstring.rs @@ -4,6 +4,7 @@ use windows_strings::*; fn hstring() -> Result<()> { let s = HSTRING::from("hello"); assert_eq!(s.len(), 5); + assert_eq!(s.as_wide().len(), 5); Ok(()) }