Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Deref for COM interface hierarchies in windows-bindgen #2968

Merged
merged 5 commits into from
Apr 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
34 changes: 14 additions & 20 deletions crates/libs/bindgen/src/rust/com_methods.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::*;

pub fn writer(writer: &Writer, def: metadata::TypeDef, kind: metadata::InterfaceKind, method: metadata::MethodDef, method_names: &mut MethodNames, virtual_names: &mut MethodNames, base_count: usize) -> TokenStream {
pub fn writer(writer: &Writer, def: metadata::TypeDef, kind: metadata::InterfaceKind, method: metadata::MethodDef, method_names: &mut MethodNames, virtual_names: &mut MethodNames) -> TokenStream {
let signature = metadata::method_def_signature(def.namespace(), method, &[]);

let name = method_names.add(method);
Expand All @@ -15,12 +15,6 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef, kind: metadata::Interface
return quote! {};
}

let mut bases = quote! {};

for _ in 0..base_count {
bases.combine(&quote! { .base__ });
}

let kind = signature.kind();
match kind {
metadata::SignatureKind::Query(_) => {
Expand All @@ -33,7 +27,7 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef, kind: metadata::Interface
#features
pub unsafe fn #name<#generics>(&self, #params) -> windows_core::Result<T> #where_clause {
let mut result__ = std::ptr::null_mut();
(windows_core::Interface::vtable(self)#bases.#vname)(windows_core::Interface::as_raw(self), #args).and_then(||windows_core::Type::from_abi(result__))
(windows_core::Interface::vtable(self).#vname)(windows_core::Interface::as_raw(self), #args).and_then(||windows_core::Type::from_abi(result__))
}
}
}
Expand All @@ -46,7 +40,7 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef, kind: metadata::Interface
quote! {
#features
pub unsafe fn #name<#generics>(&self, #params result__: *mut Option<T>) -> windows_core::Result<()> #where_clause {
(windows_core::Interface::vtable(self)#bases.#vname)(windows_core::Interface::as_raw(self), #args).ok()
(windows_core::Interface::vtable(self).#vname)(windows_core::Interface::as_raw(self), #args).ok()
}
}
}
Expand All @@ -67,7 +61,7 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef, kind: metadata::Interface
#features
pub unsafe fn #name<#generics>(&self, #params) -> windows_core::Result<#return_type> #where_clause {
let mut result__ = std::mem::zeroed();
(windows_core::Interface::vtable(self)#bases.#vname)(windows_core::Interface::as_raw(self), #args).#map
(windows_core::Interface::vtable(self).#vname)(windows_core::Interface::as_raw(self), #args).#map
}
}
}
Expand All @@ -78,7 +72,7 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef, kind: metadata::Interface
quote! {
#features
pub unsafe fn #name<#generics>(&self, #params) -> windows_core::Result<()> #where_clause {
(windows_core::Interface::vtable(self)#bases.#vname)(windows_core::Interface::as_raw(self), #args).ok()
(windows_core::Interface::vtable(self).#vname)(windows_core::Interface::as_raw(self), #args).ok()
}
}
}
Expand All @@ -95,7 +89,7 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef, kind: metadata::Interface
#features
pub unsafe fn #name<#generics>(&self, #params) -> windows_core::Result<#return_type> #where_clause {
let mut result__ = std::mem::zeroed();
(windows_core::Interface::vtable(self)#bases.#vname)(windows_core::Interface::as_raw(self), #args);
(windows_core::Interface::vtable(self).#vname)(windows_core::Interface::as_raw(self), #args);
windows_core::Type::from_abi(result__)
}
}
Expand All @@ -112,7 +106,7 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef, kind: metadata::Interface
#features
pub unsafe fn #name<#generics>(&self, #params) -> #return_type #where_clause {
let mut result__ = std::mem::zeroed();
(windows_core::Interface::vtable(self)#bases.#vname)(windows_core::Interface::as_raw(self), #args);
(windows_core::Interface::vtable(self).#vname)(windows_core::Interface::as_raw(self), #args);
#map
}
}
Expand All @@ -127,7 +121,7 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef, kind: metadata::Interface
#features
pub unsafe fn #name<#generics>(&self, #params) -> #return_type #where_clause {
let mut result__: #return_type = core::mem::zeroed();
(windows_core::Interface::vtable(self)#bases.#vname)(windows_core::Interface::as_raw(self), &mut result__, #args);
(windows_core::Interface::vtable(self).#vname)(windows_core::Interface::as_raw(self), &mut result__, #args);
result__
}
}
Expand All @@ -140,7 +134,7 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef, kind: metadata::Interface
quote! {
#features
pub unsafe fn #name<#generics>(&self, #params) #return_type #where_clause {
(windows_core::Interface::vtable(self)#bases.#vname)(windows_core::Interface::as_raw(self), #args)
(windows_core::Interface::vtable(self).#vname)(windows_core::Interface::as_raw(self), #args)
}
}
}
Expand All @@ -151,7 +145,7 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef, kind: metadata::Interface
quote! {
#features
pub unsafe fn #name<#generics>(&self, #params) #where_clause {
(windows_core::Interface::vtable(self)#bases.#vname)(windows_core::Interface::as_raw(self), #args)
(windows_core::Interface::vtable(self).#vname)(windows_core::Interface::as_raw(self), #args)
}
}
}
Expand All @@ -166,7 +160,7 @@ pub fn gen_upcall(writer: &Writer, sig: &metadata::Signature, inner: TokenStream
let result = writer.param_name(sig.params[sig.params.len() - 1].def);

quote! {
match #inner(#(#invoke_args,)*) {
match #inner(this, #(#invoke_args,)*) {
Ok(ok__) => {
// use `core::ptr::write` since the result could be uninitialized
core::ptr::write(#result, core::mem::transmute(ok__));
Expand All @@ -180,21 +174,21 @@ pub fn gen_upcall(writer: &Writer, sig: &metadata::Signature, inner: TokenStream
let invoke_args = sig.params.iter().map(|param| gen_win32_invoke_arg(writer, param));

quote! {
#inner(#(#invoke_args,)*).into()
#inner(this, #(#invoke_args,)*).into()
}
}
metadata::SignatureKind::ReturnStruct => {
let invoke_args = sig.params.iter().map(|param| gen_win32_invoke_arg(writer, param));

quote! {
*result__ = #inner(#(#invoke_args,)*)
*result__ = #inner(this, #(#invoke_args,)*)
}
}
_ => {
let invoke_args = sig.params.iter().map(|param| gen_win32_invoke_arg(writer, param));

quote! {
#inner(#(#invoke_args,)*)
#inner(this, #(#invoke_args,)*)
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/libs/bindgen/src/rust/delegates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ fn gen_win_delegate(writer: &Writer, def: metadata::TypeDef) -> TokenStream {

let vtbl_signature = writer.vtbl_signature(def, true, &signature);
let invoke = winrt_methods::writer(writer, def, generics, metadata::InterfaceKind::Default, method, &mut MethodNames::new(), &mut MethodNames::new());
let invoke_upcall = winrt_methods::gen_upcall(writer, &signature, quote! { ((*this).invoke) });
let invoke_upcall = winrt_methods::gen_upcall(writer, &signature, quote! { (this.invoke) }, false);

let mut tokens = if generics.is_empty() {
let iid = writer.guid_literal(metadata::type_def_guid(def));
Expand Down Expand Up @@ -148,7 +148,7 @@ fn gen_win_delegate(writer: &Writer, def: metadata::TypeDef) -> TokenStream {
remaining
}
unsafe extern "system" fn Invoke #vtbl_signature {
let this = this as *mut *mut core::ffi::c_void as *mut Self;
let this = &mut *(this as *mut *mut core::ffi::c_void as *mut Self);
#invoke_upcall
}
}
Expand Down
6 changes: 2 additions & 4 deletions crates/libs/bindgen/src/rust/implements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef) -> TokenStream {
let runtime_name = writer.runtime_name_trait(def, generics, &type_ident, &constraints, &features);

let mut method_names = MethodNames::new();
method_names.add_vtable_types(def);

let method_traits = def.methods().map(|method| {
let name = method_names.add(method);
Expand All @@ -68,14 +67,14 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef) -> TokenStream {
});

let mut method_names = MethodNames::new();
method_names.add_vtable_types(def);

let method_impls = def.methods().map(|method| {
let name = method_names.add(method);
let signature = metadata::method_def_signature(def.namespace(), method, generics);
let vtbl_signature = writer.vtbl_signature(def, true, &signature);
let call = quote! { #impl_ident::#name };

let invoke_upcall = if def.flags().contains(metadata::TypeAttributes::WindowsRuntime) { winrt_methods::gen_upcall(writer, &signature, quote! { this.#name }) } else { com_methods::gen_upcall(writer, &signature, quote! { this.#name }) };
let invoke_upcall = if def.flags().contains(metadata::TypeAttributes::WindowsRuntime) { winrt_methods::gen_upcall(writer, &signature, call, true) } else { com_methods::gen_upcall(writer, &signature, call) };

if has_unknown_base {
quote! {
Expand Down Expand Up @@ -114,7 +113,6 @@ pub fn writer(writer: &Writer, def: metadata::TypeDef) -> TokenStream {
}

let mut method_names = MethodNames::new();
method_names.add_vtable_types(def);

for method in def.methods() {
let name = method_names.add(method);
Expand Down
31 changes: 15 additions & 16 deletions crates/libs/bindgen/src/rust/interfaces.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,26 +87,25 @@ fn gen_win_interface(writer: &Writer, def: metadata::TypeDef) -> TokenStream {
}
}
} else {
let mut bases = vtables.len();
for ty in &vtables {
match ty {
metadata::Type::IUnknown | metadata::Type::IInspectable => {}
metadata::Type::TypeDef(def, _) => {
let kind = if def.type_name() == metadata::TypeName::IDispatch { metadata::InterfaceKind::None } else { metadata::InterfaceKind::Default };
for method in def.methods() {
methods.combine(&com_methods::writer(writer, *def, kind, method, method_names, virtual_names, bases));
}
}
rest => unimplemented!("{rest:?}"),
}

bases -= 1;
}
for method in def.methods() {
methods.combine(&com_methods::writer(writer, def, metadata::InterfaceKind::Default, method, method_names, virtual_names, 0));
methods.combine(&com_methods::writer(writer, def, metadata::InterfaceKind::Default, method, method_names, virtual_names));
}
}

if let Some(base) = vtables.last() {
let base = writer.type_name(base);

tokens.combine(&quote! {
#features
impl<#constraints> std::ops::Deref for #ident {
type Target = #base;
fn deref(&self) -> &Self::Target {
unsafe { std::mem::transmute(self) }
}
}
});
}

if !vtables.is_empty() && generics.is_empty() {
let mut hierarchy = format!("windows_core::imp::interface_hierarchy!({ident}");
let mut hierarchy_cfg = cfg.clone();
Expand Down
10 changes: 0 additions & 10 deletions crates/libs/bindgen/src/rust/method_names.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,6 @@ impl MethodNames {
to_ident(&name)
}
}

pub fn add_vtable_types(&mut self, def: metadata::TypeDef) {
for def in metadata::type_def_vtables(def) {
if let metadata::Type::TypeDef(def, _) = def {
for method in def.methods() {
self.add(method);
}
}
}
}
}

fn method_def_special_name(row: metadata::MethodDef) -> String {
Expand Down
13 changes: 9 additions & 4 deletions crates/libs/bindgen/src/rust/winrt_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,21 @@ fn gen_winrt_abi_args(writer: &Writer, params: &[metadata::SignatureParam]) -> T
tokens
}

pub fn gen_upcall(writer: &Writer, sig: &metadata::Signature, inner: TokenStream) -> TokenStream {
pub fn gen_upcall(writer: &Writer, sig: &metadata::Signature, inner: TokenStream, this: bool) -> TokenStream {
let invoke_args = sig.params.iter().map(|param| gen_winrt_invoke_arg(writer, param));
let this = if this {
quote! { this, }
} else {
quote! {}
};

match &sig.return_type {
metadata::Type::Void => quote! {
#inner(#(#invoke_args,)*).into()
#inner(#this #(#invoke_args,)*).into()
},
_ if sig.return_type.is_winrt_array() => {
quote! {
match #inner(#(#invoke_args,)*) {
match #inner(#this #(#invoke_args,)*) {
Ok(ok__) => {
let (ok_data__, ok_data_len__) = ok__.into_abi();
// use `core::ptr::write` since `result` could be uninitialized
Expand All @@ -202,7 +207,7 @@ pub fn gen_upcall(writer: &Writer, sig: &metadata::Signature, inner: TokenStream
};

quote! {
match #inner(#(#invoke_args,)*) {
match #inner(#this #(#invoke_args,)*) {
Ok(ok__) => {
// use `core::ptr::write` since `result` could be uninitialized
core::ptr::write(result__, core::mem::transmute_copy(&ok__));
Expand Down
1 change: 0 additions & 1 deletion crates/libs/bindgen/src/rust/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,6 @@ impl Writer {
let vtbl = self.type_def_vtbl_name(def, generics);
let mut methods = quote! {};
let mut method_names = MethodNames::new();
method_names.add_vtable_types(def);
let phantoms = self.generic_named_phantoms(generics);
let crate_name = self.crate_name();

Expand Down
42 changes: 42 additions & 0 deletions crates/libs/core/src/imp/com_bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,25 @@ pub const E_BOUNDS: windows_core::HRESULT = windows_core::HRESULT(0x8000000B_u32
pub const E_NOINTERFACE: windows_core::HRESULT = windows_core::HRESULT(0x80004002_u32 as _);
pub const E_OUTOFMEMORY: windows_core::HRESULT = windows_core::HRESULT(0x8007000E_u32 as _);
windows_core::imp::define_interface!(IAgileObject, IAgileObject_Vtbl, 0x94ea2b94_e9cc_49e0_c0ff_ee64ca8f5b90);
impl std::ops::Deref for IAgileObject {
type Target = windows_core::IUnknown;
fn deref(&self) -> &Self::Target {
unsafe { std::mem::transmute(self) }
}
}
windows_core::imp::interface_hierarchy!(IAgileObject, windows_core::IUnknown);
impl IAgileObject {}
#[repr(C)]
pub struct IAgileObject_Vtbl {
pub base__: windows_core::IUnknown_Vtbl,
}
windows_core::imp::define_interface!(IAgileReference, IAgileReference_Vtbl, 0xc03f6a43_65a4_9818_987e_e0b810d2a6f2);
impl std::ops::Deref for IAgileReference {
type Target = windows_core::IUnknown;
fn deref(&self) -> &Self::Target {
unsafe { std::mem::transmute(self) }
}
}
windows_core::imp::interface_hierarchy!(IAgileReference, windows_core::IUnknown);
impl IAgileReference {
pub unsafe fn Resolve<T>(&self) -> windows_core::Result<T>
Expand All @@ -86,6 +98,12 @@ pub struct IAgileReference_Vtbl {
pub Resolve: unsafe extern "system" fn(*mut core::ffi::c_void, *const windows_core::GUID, *mut *mut core::ffi::c_void) -> windows_core::HRESULT,
}
windows_core::imp::define_interface!(IPropertyValue, IPropertyValue_Vtbl, 0x4bd682dd_7554_40e9_9a9b_82654ede7e62);
impl std::ops::Deref for IPropertyValue {
type Target = windows_core::IInspectable;
fn deref(&self) -> &Self::Target {
unsafe { std::mem::transmute(self) }
}
}
windows_core::imp::interface_hierarchy!(IPropertyValue, windows_core::IUnknown, windows_core::IInspectable);
impl IPropertyValue {
pub fn Type(&self) -> windows_core::Result<PropertyType> {
Expand Down Expand Up @@ -403,6 +421,12 @@ pub struct IPropertyValueStatics_Vtbl {
pub struct IReference<T>(windows_core::IUnknown, core::marker::PhantomData<T>)
where
T: windows_core::RuntimeType + 'static;
impl<T: windows_core::RuntimeType + 'static> std::ops::Deref for IReference<T> {
type Target = windows_core::IInspectable;
fn deref(&self) -> &Self::Target {
unsafe { std::mem::transmute(self) }
}
}
impl<T: windows_core::RuntimeType + 'static> windows_core::CanInto<windows_core::IUnknown> for IReference<T> {}
impl<T: windows_core::RuntimeType + 'static> windows_core::CanInto<windows_core::IInspectable> for IReference<T> {}
impl<T: windows_core::RuntimeType + 'static> windows_core::CanInto<IPropertyValue> for IReference<T> {
Expand Down Expand Up @@ -650,6 +674,12 @@ where
pub T: core::marker::PhantomData<T>,
}
windows_core::imp::define_interface!(IStringable, IStringable_Vtbl, 0x96369f54_8eb6_48f0_abce_c1b211e627c3);
impl std::ops::Deref for IStringable {
type Target = windows_core::IInspectable;
fn deref(&self) -> &Self::Target {
unsafe { std::mem::transmute(self) }
}
}
windows_core::imp::interface_hierarchy!(IStringable, windows_core::IUnknown, windows_core::IInspectable);
impl IStringable {
pub fn ToString(&self) -> windows_core::Result<windows_core::HSTRING> {
Expand All @@ -669,6 +699,12 @@ pub struct IStringable_Vtbl {
pub ToString: unsafe extern "system" fn(*mut core::ffi::c_void, *mut std::mem::MaybeUninit<windows_core::HSTRING>) -> windows_core::HRESULT,
}
windows_core::imp::define_interface!(IWeakReference, IWeakReference_Vtbl, 0x00000037_0000_0000_c000_000000000046);
impl std::ops::Deref for IWeakReference {
type Target = windows_core::IUnknown;
fn deref(&self) -> &Self::Target {
unsafe { std::mem::transmute(self) }
}
}
windows_core::imp::interface_hierarchy!(IWeakReference, windows_core::IUnknown);
impl IWeakReference {
pub unsafe fn Resolve<T>(&self) -> windows_core::Result<T>
Expand All @@ -685,6 +721,12 @@ pub struct IWeakReference_Vtbl {
pub Resolve: unsafe extern "system" fn(*mut core::ffi::c_void, *const windows_core::GUID, *mut *mut core::ffi::c_void) -> windows_core::HRESULT,
}
windows_core::imp::define_interface!(IWeakReferenceSource, IWeakReferenceSource_Vtbl, 0x00000038_0000_0000_c000_000000000046);
impl std::ops::Deref for IWeakReferenceSource {
type Target = windows_core::IUnknown;
fn deref(&self) -> &Self::Target {
unsafe { std::mem::transmute(self) }
}
}
windows_core::imp::interface_hierarchy!(IWeakReferenceSource, windows_core::IUnknown);
impl IWeakReferenceSource {
pub unsafe fn GetWeakReference(&self) -> windows_core::Result<IWeakReference> {
Expand Down
Loading