diff --git a/src/async_query.rs b/src/async_query.rs index f05d59c..99ed02d 100644 --- a/src/async_query.rs +++ b/src/async_query.rs @@ -40,7 +40,7 @@ impl WMIConnection { &query_language, &query, WBEM_FLAG_BIDIRECTIONAL, - None, + &self.ctx.0, &p_sink_handle, )?; } diff --git a/src/connection.rs b/src/connection.rs index 8d907c6..ff89c01 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,3 +1,4 @@ +use crate::context::WMIContext; use crate::utils::WMIResult; use crate::WMIError; use log::debug; @@ -126,6 +127,7 @@ fn _test_com_lib_not_send(_s: impl Send) {} pub struct WMIConnection { _com_con: COMLibrary, pub svc: IWbemServices, + pub(crate) ctx: WMIContext, } /// A connection to the local WMI provider, which provides querying capabilities. @@ -151,10 +153,12 @@ impl WMIConnection { pub fn with_namespace_path(namespace_path: &str, com_lib: COMLibrary) -> WMIResult { let loc = create_locator()?; let svc = create_services(&loc, namespace_path)?; + let ctx = WMIContext::new()?; let this = Self { _com_con: com_lib, svc, + ctx, }; this.set_proxy()?; diff --git a/src/context.rs b/src/context.rs new file mode 100644 index 0000000..f89ce0c --- /dev/null +++ b/src/context.rs @@ -0,0 +1,144 @@ +use crate::{WMIConnection, WMIResult}; +use log::debug; +use windows::Win32::System::{ + Com::{CoCreateInstance, CLSCTX_INPROC_SERVER}, + Wmi::{IWbemContext, WbemContext}, +}; +use windows_core::{BSTR, VARIANT}; + +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum ContextValueType { + String(String), + I4(i32), + R8(f64), + Bool(bool), +} + +impl From for VARIANT { + fn from(value: ContextValueType) -> Self { + match value { + ContextValueType::Bool(b) => Self::from(b), + ContextValueType::I4(i4) => Self::from(i4), + ContextValueType::R8(r8) => Self::from(r8), + ContextValueType::String(str) => Self::from(BSTR::from(str)), + } + } +} + +#[derive(Clone, Debug)] +pub struct WMIContext(pub(crate) IWbemContext); + +impl WMIContext { + /// Creates a new instances of [`WMIContext`] + pub(crate) fn new() -> WMIResult { + debug!("Calling CoCreateInstance for CLSID_WbemContext"); + + let ctx = unsafe { CoCreateInstance(&WbemContext, None, CLSCTX_INPROC_SERVER)? }; + + debug!("Got context {:?}", ctx); + + Ok(WMIContext(ctx)) + } + + /// Sets the specified named context value for use in providing additional context information to queries. + /// + /// Note the context values will persist across subsequent queries until [`WMIConnection::delete_all`] is called. + pub fn set_value(&mut self, key: &str, value: impl Into) -> WMIResult<()> { + let value = value.into(); + unsafe { self.0.SetValue(&BSTR::from(key), 0, &value.into())? }; + + Ok(()) + } + + /// Clears all named values from the underlying context object. + pub fn delete_all(&mut self) -> WMIResult<()> { + unsafe { self.0.DeleteAll()? }; + + Ok(()) + } +} + +impl WMIConnection { + /// Returns a mutable reference to the [`WMIContext`] object + pub fn ctx(&mut self) -> &mut WMIContext { + &mut self.ctx + } +} + +macro_rules! impl_from_type { + ($target_type:ty, $variant:ident) => { + impl From<$target_type> for ContextValueType { + fn from(value: $target_type) -> Self { + Self::$variant(value.into()) + } + } + }; +} + +impl_from_type!(&str, String); +impl_from_type!(i32, I4); +impl_from_type!(f64, R8); +impl_from_type!(bool, Bool); + +#[allow(non_snake_case)] +#[allow(non_camel_case_types)] +#[allow(dead_code)] +#[cfg(test)] +mod tests { + use super::*; + use crate::COMLibrary; + use serde::Deserialize; + + #[test] + fn verify_ctx_values_used() { + let com_con = COMLibrary::new().unwrap(); + let mut wmi_con = + WMIConnection::with_namespace_path("ROOT\\StandardCimv2", com_con).unwrap(); + + #[derive(Deserialize, PartialEq, Eq, PartialOrd, Ord, Debug)] + struct MSFT_NetAdapter { + InterfaceName: String, + } + + let mut orig_adapters = wmi_con.query::().unwrap(); + assert!(!orig_adapters.is_empty()); + + // With 'IncludeHidden' set to 'true', expect the response to contain additional adapters + wmi_con.ctx().set_value("IncludeHidden", true).unwrap(); + let all_adapters = wmi_con.query::().unwrap(); + assert!(all_adapters.len() > orig_adapters.len()); + + wmi_con.ctx().delete_all().unwrap(); + let mut adapters = wmi_con.query::().unwrap(); + adapters.sort(); + orig_adapters.sort(); + assert_eq!(adapters, orig_adapters); + } + + #[tokio::test] + async fn async_verify_ctx_values_used() { + let com_con = COMLibrary::new().unwrap(); + let mut wmi_con = + WMIConnection::with_namespace_path("ROOT\\StandardCimv2", com_con).unwrap(); + + #[derive(Deserialize, PartialEq, Eq, PartialOrd, Ord, Debug)] + struct MSFT_NetAdapter { + InterfaceName: String, + } + + let mut orig_adapters = wmi_con.async_query::().await.unwrap(); + assert!(!orig_adapters.is_empty()); + + // With 'IncludeHidden' set to 'true', expect the response to contain additional adapters + wmi_con.ctx().set_value("IncludeHidden", true).unwrap(); + let all_adapters = wmi_con.async_query::().await.unwrap(); + assert!(all_adapters.len() > orig_adapters.len()); + + wmi_con.ctx().delete_all().unwrap(); + let mut adapters = wmi_con.async_query::().await.unwrap(); + adapters.sort(); + orig_adapters.sort(); + assert_eq!(adapters, orig_adapters); + } +} diff --git a/src/lib.rs b/src/lib.rs index 0beade0..82b13d3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -273,6 +273,7 @@ pub mod datetime; #[cfg(feature = "time")] mod datetime_time; +pub mod context; pub mod de; pub mod duration; pub mod query; diff --git a/src/query.rs b/src/query.rs index a93578e..cebf31d 100644 --- a/src/query.rs +++ b/src/query.rs @@ -279,7 +279,7 @@ impl WMIConnection { &query_language, &query, WBEM_FLAG_FORWARD_ONLY | WBEM_FLAG_RETURN_IMMEDIATELY, - None, + &self.ctx.0, )? }; @@ -431,7 +431,7 @@ impl WMIConnection { self.svc.GetObject( &object_path, WBEM_FLAG_RETURN_WBEM_COMPLETE, - None, + &self.ctx.0, Some(&mut pcls_obj), None, )?; @@ -536,7 +536,7 @@ impl WMIConnection { /// Query all the associators of type T of the given object. /// The `object_path` argument can be provided by querying an object wih it's `__Path` property. - /// `AssocClass` must be have the name as the conneting association class between the original object and the results. + /// `AssocClass` must be have the name as the connecting association class between the original object and the results. /// See for example. /// /// ```edition2018