diff --git a/src/async_query.rs b/src/async_query.rs index f05d59c..d668484 100644 --- a/src/async_query.rs +++ b/src/async_query.rs @@ -40,7 +40,7 @@ impl WMIConnection { &query_language, &query, WBEM_FLAG_BIDIRECTIONAL, - None, + &self.ctx, &p_sink_handle, )?; } diff --git a/src/connection.rs b/src/connection.rs index 8d907c6..9db5555 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -13,7 +13,8 @@ use windows::Win32::System::Com::{ }; use windows::Win32::System::Rpc::{RPC_C_AUTHN_WINNT, RPC_C_AUTHZ_NONE}; use windows::Win32::System::Wmi::{ - IWbemLocator, IWbemServices, WbemLocator, WBEM_FLAG_CONNECT_USE_MAX_WAIT, + IWbemContext, IWbemLocator, IWbemServices, WbemContext, WbemLocator, + WBEM_FLAG_CONNECT_USE_MAX_WAIT, }; /// A marker to indicate that the current thread was `CoInitialize`d. @@ -126,6 +127,7 @@ fn _test_com_lib_not_send(_s: impl Send) {} pub struct WMIConnection { _com_con: COMLibrary, pub svc: IWbemServices, + pub ctx: IWbemContext, } /// A connection to the local WMI provider, which provides querying capabilities. @@ -151,10 +153,12 @@ impl WMIConnection { pub fn with_namespace_path(namespace_path: &str, com_lib: COMLibrary) -> WMIResult { let loc = create_locator()?; let svc = create_services(&loc, namespace_path)?; + let ctx = create_context()?; let this = Self { _com_con: com_lib, svc, + ctx, }; this.set_proxy()?; @@ -191,6 +195,16 @@ fn create_locator() -> WMIResult { Ok(loc) } +fn create_context() -> WMIResult { + debug!("Calling CoCreateInstance for CLSID_WbemContext"); + + let ctx = unsafe { CoCreateInstance(&WbemContext, None, CLSCTX_INPROC_SERVER)? }; + + debug!("Got context {:?}", ctx); + + Ok(ctx) +} + fn create_services(loc: &IWbemLocator, path: &str) -> WMIResult { debug!("Calling ConnectServer"); diff --git a/src/context.rs b/src/context.rs new file mode 100644 index 0000000..b275fd1 --- /dev/null +++ b/src/context.rs @@ -0,0 +1,103 @@ +use std::collections::HashMap; + +use serde::Serialize; +use windows_core::{BSTR, VARIANT}; + +use crate::{WMIConnection, WMIResult}; + +#[derive(Debug, PartialEq, Serialize, Clone)] +#[serde(untagged)] +pub enum ContextValueType { + String(String), + I4(i32), + R8(f64), + Bool(bool), +} + +impl From for VARIANT { + fn from(value: ContextValueType) -> Self { + match value { + ContextValueType::Bool(b) => Self::from(b), + ContextValueType::I4(i4) => Self::from(i4), + ContextValueType::R8(r8) => Self::from(r8), + ContextValueType::String(str) => Self::from(BSTR::from(str)), + } + } +} + +impl WMIConnection { + /// Sets the specified named context values for use in providing additional context information to queries. + /// + /// Note the context values will persist across subsequent queries until [`WMIConnection::clear_ctx_values`] is called. + pub fn set_ctx_values( + &mut self, + ctx_values: HashMap, + ) -> WMIResult<()> { + for (k, v) in ctx_values { + let key = BSTR::from(k); + let value = v.clone().into(); + unsafe { self.ctx.SetValue(&key, 0, &value)? }; + } + + Ok(()) + } + + /// Clears all named values from the underlying context object. + pub fn clear_ctx_values(&mut self) -> WMIResult<()> { + unsafe { self.ctx.DeleteAll().map_err(Into::into) } + } +} + +macro_rules! impl_from_type { + ($target_type:ty, $variant:ident) => { + impl From<$target_type> for ContextValueType { + fn from(value: $target_type) -> Self { + Self::$variant(value.into()) + } + } + }; +} + +impl_from_type!(&str, String); +impl_from_type!(i32, I4); +impl_from_type!(f64, R8); +impl_from_type!(bool, Bool); + +#[allow(non_snake_case)] +#[allow(non_camel_case_types)] +#[allow(dead_code)] +#[cfg(test)] +mod tests { + use super::*; + use crate::COMLibrary; + use serde::Deserialize; + + #[test] + fn verify_ctx_values_used() { + let com_con = COMLibrary::new().unwrap(); + let mut wmi_con = + WMIConnection::with_namespace_path("ROOT\\StandardCimv2", com_con).unwrap(); + + #[derive(Deserialize, PartialEq, Eq, PartialOrd, Ord, Debug)] + struct MSFT_NetAdapter { + InterfaceName: String, + } + + let mut orig_adapters = wmi_con.query::().unwrap(); + assert!(!orig_adapters.is_empty()); + + let mut ctx_values = HashMap::new(); + ctx_values.insert("IncludeHidden".into(), true.into()); + wmi_con.set_ctx_values(ctx_values).unwrap(); + + // With 'IncludeHidden' set to 'true', expect the response to contain additional adapters + let all_adapters = wmi_con.query::().unwrap(); + assert!(all_adapters.len() > orig_adapters.len()); + + wmi_con.clear_ctx_values().unwrap(); + let mut adapters = wmi_con.query::().unwrap(); + adapters.sort(); + orig_adapters.sort(); + assert_eq!(adapters, orig_adapters); + } +} diff --git a/src/lib.rs b/src/lib.rs index 0beade0..82b13d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -273,6 +273,7 @@ pub mod datetime; #[cfg(feature = "time")] mod datetime_time; +pub mod context; pub mod de; pub mod duration; pub mod query; diff --git a/src/query.rs b/src/query.rs index a93578e..4b2d747 100644 --- a/src/query.rs +++ b/src/query.rs @@ -279,7 +279,7 @@ impl WMIConnection { &query_language, &query, WBEM_FLAG_FORWARD_ONLY | WBEM_FLAG_RETURN_IMMEDIATELY, - None, + &self.ctx, )? }; @@ -536,7 +536,7 @@ impl WMIConnection { /// Query all the associators of type T of the given object. /// The `object_path` argument can be provided by querying an object wih it's `__Path` property. - /// `AssocClass` must be have the name as the conneting association class between the original object and the results. + /// `AssocClass` must be have the name as the connecting association class between the original object and the results. /// See for example. /// /// ```edition2018