Skip to content

Commit

Permalink
more tests and better termination handling
Browse files Browse the repository at this point in the history
  • Loading branch information
kennykerr committed Aug 5, 2024
1 parent f5ed95c commit cad6f75
Show file tree
Hide file tree
Showing 12 changed files with 80 additions and 21 deletions.
11 changes: 2 additions & 9 deletions crates/libs/registry/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,12 @@ impl Data {
}
}

// Returns the buffer as a slice of u16 for reading wide characters. The slice trims off any trailing zero bytes.
// Returns the buffer as a slice of u16 for reading wide characters.
pub fn as_wide(&self) -> &[u16] {
if self.ptr.is_null() {
&[]
} else {
let mut wide =
unsafe { core::slice::from_raw_parts(self.ptr as *const u16, self.len / 2) };

while wide.last() == Some(&0) {
wide = &wide[..wide.len() - 1];
}

wide
unsafe { core::slice::from_raw_parts(self.ptr as *const u16, self.len / 2) }
}
}

Expand Down
16 changes: 14 additions & 2 deletions crates/libs/registry/src/key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl Key {
name: T,
value: &windows_strings::HSTRING,
) -> Result<()> {
self.set_bytes(name, Type::String, value.as_bytes())
self.set_bytes(name, Type::String, as_bytes(value))
}

/// Sets the name and value in the registry key.
Expand All @@ -115,7 +115,7 @@ impl Key {
name: T,
value: &windows_strings::HSTRING,
) -> Result<()> {
self.set_bytes(name, Type::ExpandString, value.as_bytes())
self.set_bytes(name, Type::ExpandString, as_bytes(value))
}

/// Sets the name and value in the registry key.
Expand Down Expand Up @@ -193,12 +193,24 @@ impl Key {
/// # Safety
///
/// The `PCWSTR` pointer needs to be valid for reads up until and including the next `\0`.
#[track_caller]
pub unsafe fn raw_set_bytes<N: AsRef<PCWSTR>>(
&self,
name: N,
ty: Type,
value: &[u8],
) -> Result<()> {
if cfg!(debug_assertions) {
// RegSetValueExW expects string data to be null terminated.
if matches!(ty, Type::String | Type::ExpandString | Type::MultiString) {
debug_assert!(
value.get(value.len() - 2) == Some(&0),
"`value` isn't null-terminated"
);
debug_assert!(value.last() == Some(&0), "`value` isn't null-terminated");
}
}

let result = RegSetValueExW(
self.0,
name.as_ref().as_ptr(),
Expand Down
5 changes: 5 additions & 0 deletions crates/libs/registry/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,8 @@ fn from_le_bytes(ty: Type, from: &[u8]) -> Result<u64> {
_ => Err(invalid_data()),
}
}

// Get the string as 8-bit bytes including the two terminating null bytes.
fn as_bytes(value: &HSTRING) -> &[u8] {
unsafe { core::slice::from_raw_parts(value.as_ptr() as *const _, (value.len() + 1) * 2) }
}
1 change: 1 addition & 0 deletions crates/libs/registry/src/pcwstr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ impl OwnedPcwstr {
self.0.as_ptr()
}

// Get the string as 8-bit bytes including the two terminating null bytes.
pub fn as_bytes(&self) -> &[u8] {
unsafe { core::slice::from_raw_parts(self.as_ptr() as *const _, self.0.len() * 2) }
}
Expand Down
16 changes: 13 additions & 3 deletions crates/libs/registry/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ impl TryFrom<Value> for String {
type Error = Error;
fn try_from(from: Value) -> Result<Self> {
match from.ty {
Type::String | Type::ExpandString => Ok(Self::from_utf16(from.data.as_wide())?),
Type::String | Type::ExpandString => {
Ok(Self::from_utf16(trim(trim(from.data.as_wide())))?)
}
_ => Err(invalid_data()),
}
}
Expand Down Expand Up @@ -106,7 +108,7 @@ impl TryFrom<Value> for HSTRING {
type Error = Error;
fn try_from(from: Value) -> Result<Self> {
match from.ty {
Type::String | Type::ExpandString => Ok(Self::from_wide(from.data.as_wide())?),
Type::String | Type::ExpandString => Ok(Self::from_wide(trim(from.data.as_wide()))?),
_ => Err(invalid_data()),
}
}
Expand All @@ -116,7 +118,7 @@ impl TryFrom<&HSTRING> for Value {
type Error = Error;
fn try_from(from: &HSTRING) -> Result<Self> {
Ok(Self {
data: Data::from_slice(from.as_bytes())?,
data: Data::from_slice(as_bytes(from))?,
ty: Type::String,
})
}
Expand All @@ -141,3 +143,11 @@ impl<const N: usize> TryFrom<[u8; N]> for Value {
})
}
}

fn trim(mut wide: &[u16]) -> &[u16] {
while wide.last() == Some(&0) {
wide = &wide[..wide.len() - 1];
}

wide
}
5 changes: 0 additions & 5 deletions crates/libs/strings/src/hstring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@ impl HSTRING {
unsafe { core::slice::from_raw_parts(self.as_ptr(), self.len()) }
}

/// Get the string as 8-bit bytes.
pub fn as_bytes(&self) -> &[u8] {
unsafe { core::slice::from_raw_parts(self.as_ptr() as *const _, self.len() * 2) }
}

/// Returns a raw pointer to the `HSTRING` buffer.
pub fn as_ptr(&self) -> *const u16 {
if let Some(header) = self.as_header() {
Expand Down
6 changes: 6 additions & 0 deletions crates/tests/registry/tests/bytes.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use windows_registry::*;
use windows_strings::*;

#[test]
fn bytes() -> Result<()> {
Expand All @@ -20,5 +21,10 @@ fn bytes() -> Result<()> {
assert_eq!(value.ty(), Type::Other(1234));
assert_eq!(*value, [1, 2, 3, 4]);

assert_eq!(
unsafe { key.raw_get_info(w!("other"))? },
(Type::Other(1234), 4)
);

Ok(())
}
6 changes: 5 additions & 1 deletion crates/tests/registry/tests/hstring.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use windows_registry::*;
use windows_strings::h;
use windows_strings::*;

#[test]
fn hstring() -> Result<()> {
Expand All @@ -9,6 +9,10 @@ fn hstring() -> Result<()> {

key.set_hstring("hstring", h!("simple"))?;
assert_eq!(&key.get_hstring("hstring")?, h!("simple"));
assert_eq!(
unsafe { key.raw_get_info(w!("hstring"))? },
(Type::String, 14)
);

// You can embed nulls.
key.set_hstring("hstring", h!("hstring\0value\0"))?;
Expand Down
3 changes: 3 additions & 0 deletions crates/tests/registry/tests/u32.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use windows_registry::*;
use windows_strings::*;

#[test]
fn u32() -> Result<()> {
Expand All @@ -11,5 +12,7 @@ fn u32() -> Result<()> {
assert_eq!(key.get_u32("u32")?, 123u32);
assert_eq!(key.get_u64("u32")?, 123u64);

assert_eq!(unsafe { key.raw_get_info(w!("u32"))? }, (Type::U32, 4));

Ok(())
}
3 changes: 3 additions & 0 deletions crates/tests/registry/tests/u64.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use windows_registry::*;
use windows_strings::*;

#[test]
fn u64() -> Result<()> {
Expand All @@ -11,5 +12,7 @@ fn u64() -> Result<()> {
assert_eq!(key.get_u32("u64")?, 123u32);
assert_eq!(key.get_u64("u64")?, 123u64);

assert_eq!(unsafe { key.raw_get_info(w!("u64"))? }, (Type::U64, 8));

Ok(())
}
28 changes: 27 additions & 1 deletion crates/tests/registry/tests/value.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use windows_registry::*;
use windows_strings::h;
use windows_strings::*;

#[test]
fn value() -> Result<()> {
Expand All @@ -14,19 +14,28 @@ fn value() -> Result<()> {
assert_eq!(key.get_u64("u32")?, 123u64);
assert_eq!(u32::try_from(key.get_value("u32")?)?, 123u32);

assert_eq!(unsafe { key.raw_get_info(w!("u32"))? }, (Type::U32, 4));

key.set_value("u64", &Value::try_from(123u64)?)?;
assert_eq!(key.get_type("u64")?, Type::U64);
assert_eq!(key.get_value("u64")?, Value::try_from(123u64)?);
assert_eq!(key.get_u32("u64")?, 123u32);
assert_eq!(key.get_u64("u64")?, 123u64);
assert_eq!(u64::try_from(key.get_value("u64")?)?, 123u64);

assert_eq!(unsafe { key.raw_get_info(w!("u64"))? }, (Type::U64, 8));

key.set_value("string", &Value::try_from("string")?)?;
assert_eq!(key.get_type("string")?, Type::String);
assert_eq!(key.get_value("string")?, Value::try_from("string")?);
assert_eq!(key.get_string("string")?, "string");
assert_eq!(String::try_from(key.get_value("string")?)?, "string");

assert_eq!(
unsafe { key.raw_get_info(w!("string"))? },
(Type::String, 14)
);

let mut value = Value::try_from("expand")?;
value.set_ty(Type::ExpandString);
assert_eq!(value.ty(), Type::ExpandString);
Expand All @@ -35,21 +44,38 @@ fn value() -> Result<()> {
assert_eq!(key.get_value("expand")?, value);
assert_eq!(key.get_string("expand")?, "expand");

assert_eq!(
unsafe { key.raw_get_info(w!("expand"))? },
(Type::ExpandString, 14)
);

key.set_value("bytes", &Value::try_from([1u8, 2u8, 3u8])?)?;
assert_eq!(key.get_type("bytes")?, Type::Bytes);
assert_eq!(key.get_value("bytes")?, Value::try_from([1, 2, 3])?);

assert_eq!(unsafe { key.raw_get_info(w!("bytes"))? }, (Type::Bytes, 3));

let mut value = Value::try_from([1u8, 2u8, 3u8, 4u8].as_slice())?;
value.set_ty(Type::Other(1234));
key.set_value("slice", &value)?;
assert_eq!(key.get_type("slice")?, Type::Other(1234));
assert_eq!(key.get_value("slice")?, value);

assert_eq!(
unsafe { key.raw_get_info(w!("slice"))? },
(Type::Other(1234), 4)
);

key.set_value("hstring", &Value::try_from(h!("HSTRING"))?)?;
assert_eq!(key.get_type("hstring")?, Type::String);
assert_eq!(key.get_value("hstring")?, Value::try_from(h!("HSTRING"))?);
assert_eq!(key.get_string("hstring")?, "HSTRING");
assert_eq!(HSTRING::try_from(key.get_value("hstring")?)?, "HSTRING");

assert_eq!(
unsafe { key.raw_get_info(w!("hstring"))? },
(Type::String, 16)
);

Ok(())
}
1 change: 1 addition & 0 deletions crates/tests/strings/tests/hstring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use windows_strings::*;
fn hstring() -> Result<()> {
let s = HSTRING::from("hello");
assert_eq!(s.len(), 5);
assert_eq!(s.as_wide().len(), 5);

Ok(())
}
Expand Down

0 comments on commit cad6f75

Please sign in to comment.