From d31c3ab0a52be3d35cfd9442dfb3fc2bc12d1a76 Mon Sep 17 00:00:00 2001 From: Erich Gubler Date: Wed, 20 Sep 2023 17:14:43 -0400 Subject: [PATCH] fix(d3d12): remove nullability in `ComPtr` --- d3d12/src/com.rs | 90 +++++++------------------ d3d12/src/command_list.rs | 19 ++++-- d3d12/src/debug.rs | 10 +-- d3d12/src/descriptor.rs | 23 +++---- d3d12/src/device.rs | 86 +++++++++++------------ d3d12/src/dxgi.rs | 56 ++++++++------- d3d12/src/lib.rs | 1 - d3d12/src/pso.rs | 10 +-- wgpu-hal/src/auxil/dxgi/factory.rs | 24 ++++--- wgpu-hal/src/dx11/library.rs | 13 ++-- wgpu-hal/src/dx12/adapter.rs | 2 +- wgpu-hal/src/dx12/command.rs | 15 ++--- wgpu-hal/src/dx12/device.rs | 38 ++++++----- wgpu-hal/src/dx12/mod.rs | 14 ++-- wgpu-hal/src/dx12/shader_compilation.rs | 10 +-- wgpu-hal/src/dx12/suballocation.rs | 48 +++++++------ 16 files changed, 225 insertions(+), 234 deletions(-) diff --git a/d3d12/src/com.rs b/d3d12/src/com.rs index f0eca926fe..db790e8a67 100644 --- a/d3d12/src/com.rs +++ b/d3d12/src/com.rs @@ -3,80 +3,44 @@ use std::{ fmt, hash::{Hash, Hasher}, ops::Deref, - ptr, }; -use winapi::{ctypes::c_void, um::unknwnbase::IUnknown, Interface}; +use winapi::{um::unknwnbase::IUnknown, Interface}; #[repr(transparent)] pub struct ComPtr(*mut T); impl ComPtr { - /// Creates a null ComPtr. - pub fn null() -> Self { - ComPtr(ptr::null_mut()) + /// Create a ComPtr from a raw pointer. This will _not_ call AddRef on the pointer, assuming + /// that it has already been called. + /// + /// # Safety + /// + /// - `raw` must be a valid pointer to a COM object that implements T. + pub unsafe fn from_reffed(raw: *mut T) -> Self { + debug_assert!(!raw.is_null()); + ComPtr(raw) } /// Create a ComPtr from a raw pointer. This will call AddRef on the pointer. /// /// # Safety /// - /// - if `raw` is not null, it must be a valid pointer to a COM object that implements T. + /// - `raw` must be a valid pointer to a COM object that implements T. pub unsafe fn from_raw(raw: *mut T) -> Self { - if !raw.is_null() { - (*(raw as *mut IUnknown)).AddRef(); - } + debug_assert!(!raw.is_null()); + (*(raw as *mut IUnknown)).AddRef(); ComPtr(raw) } - /// Returns true if the inner pointer is null. - pub fn is_null(&self) -> bool { - self.0.is_null() - } - - /// Returns the raw inner pointer. May be null. + /// Returns the raw inner pointer. pub fn as_ptr(&self) -> *const T { self.0 } - /// Returns the raw inner pointer as mutable. May be null. + /// Returns the raw inner pointer as mutable. pub fn as_mut_ptr(&self) -> *mut T { self.0 } - - /// Returns a mutable reference to the inner pointer casted as a pointer to c_void. - /// - /// This is useful when D3D functions initialize objects by filling in a pointer to pointer - /// by taking `void**` as an argument. - /// - /// # Safety - /// - /// - Any modifications done to this pointer must result in the pointer either: - /// - being set to null - /// - being set to a valid pointer to a COM object that implements T - pub unsafe fn mut_void(&mut self) -> &mut *mut c_void { - // SAFETY: We must first get a reference pointing to our internal pointer - // and only then cast it. As if we cast it, then take a reference, we would - // end up with a reference to a temporary. - let refer: &mut *mut T = &mut self.0; - let void: *mut *mut c_void = refer.cast(); - - // SAFETY: This reference is valid for the duration of the borrow due our mutable borrow of self. - &mut *void - } - - /// Returns a mutable reference to the inner pointer. - /// - /// This is useful when D3D functions initialize objects by filling in a pointer to pointer - /// by taking `T**` as an argument. - /// - /// # Safety - /// - /// - Any modifications done to this pointer must result in the pointer either: - /// - being set to null - /// - being set to a valid pointer to a COM object that implements T - pub fn mut_self(&mut self) -> &mut *mut T { - &mut self.0 - } } impl ComPtr { @@ -86,7 +50,6 @@ impl ComPtr { /// /// - This pointer must not be null. pub unsafe fn as_unknown(&self) -> &IUnknown { - debug_assert!(!self.is_null()); &*(self.0 as *mut IUnknown) } @@ -95,25 +58,21 @@ impl ComPtr { /// # Safety /// /// - This pointer must not be null. - pub unsafe fn cast(&self) -> D3DResult> + pub unsafe fn cast(&self) -> D3DResult>> where U: Interface, { - debug_assert!(!self.is_null()); - let mut obj = ComPtr::::null(); - let hr = self - .as_unknown() - .QueryInterface(&U::uuidof(), obj.mut_void()); + let mut obj = std::ptr::null_mut(); + let hr = self.as_unknown().QueryInterface(&U::uuidof(), &mut obj); + let obj = (!obj.is_null()).then(|| ComPtr::from_reffed(obj.cast())); (obj, hr) } } impl Clone for ComPtr { fn clone(&self) -> Self { - if !self.is_null() { - unsafe { - self.as_unknown().AddRef(); - } + unsafe { + self.as_unknown().AddRef(); } ComPtr(self.0) } @@ -121,10 +80,8 @@ impl Clone for ComPtr { impl Drop for ComPtr { fn drop(&mut self) { - if !self.is_null() { - unsafe { - self.as_unknown().Release(); - } + unsafe { + self.as_unknown().Release(); } } } @@ -132,7 +89,6 @@ impl Drop for ComPtr { impl Deref for ComPtr { type Target = T; fn deref(&self) -> &T { - assert!(!self.is_null()); unsafe { &*self.0 } } } diff --git a/d3d12/src/command_list.rs b/d3d12/src/command_list.rs index 168d935e30..2068d7f38f 100644 --- a/d3d12/src/command_list.rs +++ b/d3d12/src/command_list.rs @@ -153,8 +153,13 @@ impl GraphicsCommandList { unsafe { self.Close() } } - pub fn reset(&self, allocator: &CommandAllocator, initial_pso: PipelineState) -> HRESULT { - unsafe { self.Reset(allocator.as_mut_ptr(), initial_pso.as_mut_ptr()) } + pub fn reset( + &self, + allocator: &CommandAllocator, + initial_pso: Option<&PipelineState>, + ) -> HRESULT { + let initial_pso = initial_pso.map_or(ptr::null_mut(), |pso| pso.as_mut_ptr()); + unsafe { self.Reset(allocator.as_mut_ptr(), initial_pso) } } pub fn discard_resource(&self, resource: Resource, region: DiscardRegion) { @@ -284,15 +289,17 @@ impl GraphicsCommandList { } } - pub fn set_compute_root_signature(&self, signature: &RootSignature) { + pub fn set_compute_root_signature(&self, signature: Option<&RootSignature>) { unsafe { - self.SetComputeRootSignature(signature.as_mut_ptr()); + self.SetComputeRootSignature(signature.map_or(ptr::null_mut(), |sig| sig.as_mut_ptr())); } } - pub fn set_graphics_root_signature(&self, signature: &RootSignature) { + pub fn set_graphics_root_signature(&self, signature: Option<&RootSignature>) { unsafe { - self.SetGraphicsRootSignature(signature.as_mut_ptr()); + self.SetGraphicsRootSignature( + signature.map_or(ptr::null_mut(), |sig| sig.as_mut_ptr()), + ); } } diff --git a/d3d12/src/debug.rs b/d3d12/src/debug.rs index 3a6abc46b7..e0e2729137 100644 --- a/d3d12/src/debug.rs +++ b/d3d12/src/debug.rs @@ -13,11 +13,12 @@ impl crate::D3D12Lib { *mut *mut winapi::ctypes::c_void, ) -> crate::HRESULT; - let mut debug = Debug::null(); + let mut debug = std::ptr::null_mut(); let hr = unsafe { let func: libloading::Symbol = self.lib.get(b"D3D12GetDebugInterface")?; - func(&d3d12sdklayers::ID3D12Debug::uuidof(), debug.mut_void()) + func(&d3d12sdklayers::ID3D12Debug::uuidof(), &mut debug) }; + let debug = unsafe { ComPtr::from_reffed(debug.cast()) }; Ok((debug, hr)) } @@ -26,13 +27,14 @@ impl crate::D3D12Lib { impl Debug { #[cfg(feature = "implicit-link")] pub fn get_interface() -> crate::D3DResult { - let mut debug = Debug::null(); + let mut debug = std::ptr::null_mut(); let hr = unsafe { winapi::um::d3d12::D3D12GetDebugInterface( &d3d12sdklayers::ID3D12Debug::uuidof(), - debug.mut_void(), + &mut debug, ) }; + let debug = unsafe { ComPtr::from_reffed(debug.cast()) }; (debug, hr) } diff --git a/d3d12/src/descriptor.rs b/d3d12/src/descriptor.rs index e0b4e6a665..01c83e01de 100644 --- a/d3d12/src/descriptor.rs +++ b/d3d12/src/descriptor.rs @@ -266,7 +266,7 @@ bitflags::bitflags! { } pub type RootSignature = ComPtr; -pub type BlobResult = D3DResult<(Blob, Error)>; +pub type BlobResult = D3DResult<(Blob, Option)>; #[cfg(feature = "libloading")] impl crate::D3D12Lib { @@ -293,12 +293,14 @@ impl crate::D3D12Lib { Flags: flags.bits(), }; - let mut blob = Blob::null(); - let mut error = Error::null(); + let mut blob = std::ptr::null_mut(); + let mut error = std::ptr::null_mut(); let hr = unsafe { let func: libloading::Symbol = self.lib.get(b"D3D12SerializeRootSignature")?; - func(&desc, version as _, blob.mut_self(), error.mut_self()) + func(&desc, version as _, &mut blob, &mut error) }; + let blob = unsafe { ComPtr::from_reffed(blob) }; + let error = (!error.is_null()).then(|| unsafe { ComPtr::from_reffed(error) }); Ok(((blob, error), hr)) } @@ -312,8 +314,8 @@ impl RootSignature { static_samplers: &[StaticSampler], flags: RootSignatureFlags, ) -> BlobResult { - let mut blob = Blob::null(); - let mut error = Error::null(); + let mut blob = std::ptr::null_mut(); + let mut error = std::ptr::null_mut(); let desc = d3d12::D3D12_ROOT_SIGNATURE_DESC { NumParameters: parameters.len() as _, @@ -324,13 +326,10 @@ impl RootSignature { }; let hr = unsafe { - d3d12::D3D12SerializeRootSignature( - &desc, - version as _, - blob.mut_self(), - error.mut_self(), - ) + d3d12::D3D12SerializeRootSignature(&desc, version as _, &mut blob, &mut error) }; + let blob = unsafe { ComPtr::from_reffed(blob) }; + let error = unsafe { ComPtr::from_reffed(error) }; ((blob, error), hr) } diff --git a/d3d12/src/device.rs b/d3d12/src/device.rs index 96d6866ad8..0c54024aa6 100644 --- a/d3d12/src/device.rs +++ b/d3d12/src/device.rs @@ -6,8 +6,8 @@ use crate::{ descriptor::{CpuDescriptor, DescriptorHeapFlags, DescriptorHeapType, RenderTargetViewDesc}, heap::{Heap, HeapFlags, HeapProperties}, pso, query, queue, Blob, CachedPSO, CommandAllocator, CommandQueue, D3DResult, DescriptorHeap, - Fence, GraphicsCommandList, NodeMask, PipelineState, QueryHeap, Resource, RootSignature, - Shader, TextureAddressMode, + Fence, GraphicsCommandList, NodeMask, PipelineState, QueryHeap, RootSignature, Shader, + TextureAddressMode, }; use std::ops::Range; use winapi::{um::d3d12, Interface}; @@ -28,16 +28,17 @@ impl crate::D3D12Lib { *mut *mut winapi::ctypes::c_void, ) -> crate::HRESULT; - let mut device = Device::null(); + let mut device = std::ptr::null_mut(); let hr = unsafe { let func: libloading::Symbol = self.lib.get(b"D3D12CreateDevice")?; func( adapter.as_unknown() as *const _ as *mut _, feature_level as _, &d3d12::ID3D12Device::uuidof(), - device.mut_void(), + &mut device, ) }; + let device = unsafe { ComPtr::from_reffed(device.cast()) }; Ok((device, hr)) } @@ -49,15 +50,16 @@ impl Device { adapter: ComPtr, feature_level: crate::FeatureLevel, ) -> D3DResult { - let mut device = Device::null(); + let mut device = std::ptr::null_mut(); let hr = unsafe { d3d12::D3D12CreateDevice( adapter.as_unknown() as *const _ as *mut _, feature_level as _, &d3d12::ID3D12Device::uuidof(), - device.mut_void(), + &mut device, ) }; + let device = unsafe { ComPtr::from_reffed(device.cast()) }; (device, hr) } @@ -69,7 +71,7 @@ impl Device { alignment: u64, flags: HeapFlags, ) -> D3DResult { - let mut heap = Heap::null(); + let mut heap = std::ptr::null_mut(); let desc = d3d12::D3D12_HEAP_DESC { SizeInBytes: size_in_bytes, @@ -78,20 +80,22 @@ impl Device { Flags: flags.bits(), }; - let hr = unsafe { self.CreateHeap(&desc, &d3d12::ID3D12Heap::uuidof(), heap.mut_void()) }; + let hr = unsafe { self.CreateHeap(&desc, &d3d12::ID3D12Heap::uuidof(), &mut heap) }; + let heap = unsafe { ComPtr::from_reffed(heap.cast()) }; (heap, hr) } pub fn create_command_allocator(&self, list_type: CmdListType) -> D3DResult { - let mut allocator = CommandAllocator::null(); + let mut allocator = std::ptr::null_mut(); let hr = unsafe { self.CreateCommandAllocator( list_type as _, &d3d12::ID3D12CommandAllocator::uuidof(), - allocator.mut_void(), + &mut allocator, ) }; + let allocator = unsafe { ComPtr::from_reffed(allocator.cast()) }; (allocator, hr) } @@ -110,14 +114,11 @@ impl Device { NodeMask: node_mask, }; - let mut queue = CommandQueue::null(); + let mut queue = std::ptr::null_mut(); let hr = unsafe { - self.CreateCommandQueue( - &desc, - &d3d12::ID3D12CommandQueue::uuidof(), - queue.mut_void(), - ) + self.CreateCommandQueue(&desc, &d3d12::ID3D12CommandQueue::uuidof(), &mut queue) }; + let queue = unsafe { ComPtr::from_reffed(queue.cast()) }; (queue, hr) } @@ -136,14 +137,11 @@ impl Device { NodeMask: node_mask, }; - let mut heap = DescriptorHeap::null(); + let mut heap = std::ptr::null_mut(); let hr = unsafe { - self.CreateDescriptorHeap( - &desc, - &d3d12::ID3D12DescriptorHeap::uuidof(), - heap.mut_void(), - ) + self.CreateDescriptorHeap(&desc, &d3d12::ID3D12DescriptorHeap::uuidof(), &mut heap) }; + let heap = unsafe { ComPtr::from_reffed(heap.cast()) }; (heap, hr) } @@ -156,20 +154,22 @@ impl Device { &self, list_type: CmdListType, allocator: &CommandAllocator, - initial: PipelineState, + initial: Option<&PipelineState>, node_mask: NodeMask, ) -> D3DResult { - let mut command_list = GraphicsCommandList::null(); + let mut command_list = std::ptr::null_mut(); + let initial = initial.map_or(std::ptr::null_mut(), |i| i.as_mut_ptr()); let hr = unsafe { self.CreateCommandList( node_mask, list_type as _, allocator.as_mut_ptr(), - initial.as_mut_ptr(), + initial, &d3d12::ID3D12GraphicsCommandList::uuidof(), - command_list.mut_void(), + &mut command_list, ) }; + let command_list = unsafe { ComPtr::from_reffed(command_list.cast()) }; (command_list, hr) } @@ -186,14 +186,11 @@ impl Device { NodeMask: node_mask, }; - let mut query_heap = QueryHeap::null(); + let mut query_heap = std::ptr::null_mut(); let hr = unsafe { - self.CreateQueryHeap( - &desc, - &d3d12::ID3D12QueryHeap::uuidof(), - query_heap.mut_void(), - ) + self.CreateQueryHeap(&desc, &d3d12::ID3D12QueryHeap::uuidof(), &mut query_heap) }; + let query_heap = unsafe { ComPtr::from_reffed(query_heap.cast()) }; (query_heap, hr) } @@ -215,15 +212,16 @@ impl Device { pub fn create_compute_pipeline_state( &self, - root_signature: &RootSignature, + root_signature: Option<&RootSignature>, cs: Shader, node_mask: NodeMask, cached_pso: CachedPSO, flags: pso::PipelineStateFlags, ) -> D3DResult { - let mut pipeline = PipelineState::null(); + let mut pipeline = std::ptr::null_mut(); + let root_signature = root_signature.map_or(std::ptr::null_mut(), |sig| sig.as_mut_ptr()); let desc = d3d12::D3D12_COMPUTE_PIPELINE_STATE_DESC { - pRootSignature: root_signature.as_mut_ptr(), + pRootSignature: root_signature, CS: *cs, NodeMask: node_mask, CachedPSO: *cached_pso, @@ -234,9 +232,10 @@ impl Device { self.CreateComputePipelineState( &desc, &d3d12::ID3D12PipelineState::uuidof(), - pipeline.mut_void(), + &mut pipeline, ) }; + let pipeline = unsafe { ComPtr::from_reffed(pipeline.cast()) }; (pipeline, hr) } @@ -275,16 +274,17 @@ impl Device { blob: Blob, node_mask: NodeMask, ) -> D3DResult { - let mut signature = RootSignature::null(); + let mut signature = std::ptr::null_mut(); let hr = unsafe { self.CreateRootSignature( node_mask, blob.GetBufferPointer(), blob.GetBufferSize(), &d3d12::ID3D12RootSignature::uuidof(), - signature.mut_void(), + &mut signature, ) }; + let signature = unsafe { ComPtr::from_reffed(signature.cast()) }; (signature, hr) } @@ -295,7 +295,7 @@ impl Device { stride: u32, node_mask: NodeMask, ) -> D3DResult { - let mut signature = CommandSignature::null(); + let mut signature = std::ptr::null_mut(); let desc = d3d12::D3D12_COMMAND_SIGNATURE_DESC { ByteStride: stride, NumArgumentDescs: arguments.len() as _, @@ -308,9 +308,10 @@ impl Device { &desc, std::ptr::null_mut(), &d3d12::ID3D12CommandSignature::uuidof(), - signature.mut_void(), + &mut signature, ) }; + let signature = unsafe { ComPtr::from_reffed(signature.cast()) }; (signature, hr) } @@ -329,15 +330,16 @@ impl Device { // TODO: interface not complete pub fn create_fence(&self, initial: u64) -> D3DResult { - let mut fence = Fence::null(); + let mut fence = std::ptr::null_mut(); let hr = unsafe { self.CreateFence( initial, d3d12::D3D12_FENCE_FLAG_NONE, &d3d12::ID3D12Fence::uuidof(), - fence.mut_void(), + &mut fence, ) }; + let fence = unsafe { ComPtr::from_reffed(fence.cast()) }; (fence, hr) } diff --git a/d3d12/src/dxgi.rs b/d3d12/src/dxgi.rs index ca5db4e6da..c8b306b272 100644 --- a/d3d12/src/dxgi.rs +++ b/d3d12/src/dxgi.rs @@ -115,15 +115,16 @@ impl DxgiLib { *mut *mut winapi::ctypes::c_void, ) -> HRESULT; - let mut factory = Factory4::null(); + let mut factory = std::ptr::null_mut(); let hr = unsafe { let func: libloading::Symbol = self.lib.get(b"CreateDXGIFactory2")?; func( flags.bits(), &dxgi1_4::IDXGIFactory4::uuidof(), - factory.mut_void(), + &mut factory, ) }; + let factory = unsafe { ComPtr::from_reffed(factory.cast()) }; Ok((factory, hr)) } @@ -134,11 +135,12 @@ impl DxgiLib { *mut *mut winapi::ctypes::c_void, ) -> HRESULT; - let mut factory = Factory1::null(); + let mut factory = std::ptr::null_mut(); let hr = unsafe { let func: libloading::Symbol = self.lib.get(b"CreateDXGIFactory1")?; - func(&dxgi::IDXGIFactory1::uuidof(), factory.mut_void()) + func(&dxgi::IDXGIFactory1::uuidof(), &mut factory) }; + let factory = unsafe { ComPtr::from_reffed(factory.cast()) }; Ok((factory, hr)) } @@ -149,12 +151,13 @@ impl DxgiLib { *mut *mut winapi::ctypes::c_void, ) -> HRESULT; - let mut factory = FactoryMedia::null(); + let mut factory = std::ptr::null_mut(); let hr = unsafe { // https://learn.microsoft.com/en-us/windows/win32/api/dxgi1_3/nn-dxgi1_3-idxgifactorymedia let func: libloading::Symbol = self.lib.get(b"CreateDXGIFactory1")?; - func(&dxgi1_3::IDXGIFactoryMedia::uuidof(), factory.mut_void()) + func(&dxgi1_3::IDXGIFactoryMedia::uuidof(), &mut factory) }; + let factory = unsafe { ComPtr::from_reffed(factory.cast()) }; Ok((factory, hr)) } @@ -166,11 +169,12 @@ impl DxgiLib { *mut *mut winapi::ctypes::c_void, ) -> HRESULT; - let mut queue = InfoQueue::null(); + let mut queue = std::ptr::null_mut(); let hr = unsafe { let func: libloading::Symbol = self.lib.get(b"DXGIGetDebugInterface1")?; - func(0, &dxgidebug::IDXGIInfoQueue::uuidof(), queue.mut_void()) + func(0, &dxgidebug::IDXGIInfoQueue::uuidof(), &mut queue) }; + let queue = unsafe { ComPtr::from_reffed(queue.cast()) }; Ok((queue, hr)) } } @@ -244,8 +248,9 @@ impl Factory1 { Flags: desc.flags, }; - let mut swapchain = SwapChain::null(); - let hr = unsafe { self.CreateSwapChain(queue, &mut desc, swapchain.mut_self()) }; + let mut swapchain = std::ptr::null_mut(); + let hr = unsafe { self.CreateSwapChain(queue, &mut desc, &mut swapchain) }; + let swapchain = unsafe { ComPtr::from_reffed(swapchain) }; (swapchain, hr) } @@ -262,7 +267,7 @@ impl Factory2 { hwnd: HWND, desc: &SwapchainDesc, ) -> D3DResult { - let mut swap_chain = SwapChain1::null(); + let mut swap_chain = std::ptr::null_mut(); let hr = unsafe { self.CreateSwapChainForHwnd( queue, @@ -270,9 +275,10 @@ impl Factory2 { &desc.to_desc1(), ptr::null(), ptr::null_mut(), - swap_chain.mut_self(), + &mut swap_chain, ) }; + let swap_chain = unsafe { ComPtr::from_reffed(swap_chain) }; (swap_chain, hr) } @@ -285,15 +291,16 @@ impl Factory2 { queue: *mut IUnknown, desc: &SwapchainDesc, ) -> D3DResult { - let mut swap_chain = SwapChain1::null(); + let mut swap_chain = std::ptr::null_mut(); let hr = unsafe { self.CreateSwapChainForComposition( queue, &desc.to_desc1(), ptr::null_mut(), - swap_chain.mut_self(), + &mut swap_chain, ) }; + let swap_chain = unsafe { ComPtr::from_reffed(swap_chain) }; (swap_chain, hr) } @@ -302,21 +309,23 @@ impl Factory2 { impl Factory4 { #[cfg(feature = "implicit-link")] pub fn create(flags: FactoryCreationFlags) -> D3DResult { - let mut factory = Factory4::null(); + let mut factory = std::ptr::null_mut(); let hr = unsafe { dxgi1_3::CreateDXGIFactory2( flags.bits(), &dxgi1_4::IDXGIFactory4::uuidof(), - factory.mut_void(), + &mut factory, ) }; + let factory = unsafe { ComPtr::from_reffed(factory.cast()) }; (factory, hr) } pub fn enumerate_adapters(&self, id: u32) -> D3DResult { - let mut adapter = Adapter1::null(); - let hr = unsafe { self.EnumAdapters1(id, adapter.mut_self()) }; + let mut adapter = std::ptr::null_mut(); + let hr = unsafe { self.EnumAdapters1(id, &mut adapter) }; + let adapter = unsafe { ComPtr::from_reffed(adapter) }; (adapter, hr) } @@ -332,16 +341,17 @@ impl FactoryMedia { surface_handle: HANDLE, desc: &SwapchainDesc, ) -> D3DResult { - let mut swap_chain = SwapChain1::null(); + let mut swap_chain = std::ptr::null_mut(); let hr = unsafe { self.CreateSwapChainForCompositionSurfaceHandle( queue, surface_handle, &desc.to_desc1(), ptr::null_mut(), - swap_chain.mut_self(), + &mut swap_chain, ) }; + let swap_chain = unsafe { ComPtr::from_reffed(swap_chain) }; (swap_chain, hr) } @@ -364,9 +374,9 @@ bitflags::bitflags! { impl SwapChain { pub fn get_buffer(&self, id: u32) -> D3DResult { - let mut resource = Resource::null(); - let hr = - unsafe { self.GetBuffer(id, &d3d12::ID3D12Resource::uuidof(), resource.mut_void()) }; + let mut resource = std::ptr::null_mut(); + let hr = unsafe { self.GetBuffer(id, &d3d12::ID3D12Resource::uuidof(), &mut resource) }; + let resource = unsafe { ComPtr::from_reffed(resource.cast()) }; (resource, hr) } diff --git a/d3d12/src/lib.rs b/d3d12/src/lib.rs index db0aecdd93..e990a110db 100644 --- a/d3d12/src/lib.rs +++ b/d3d12/src/lib.rs @@ -100,7 +100,6 @@ pub type Blob = ComPtr; pub type Error = ComPtr; impl Error { pub fn as_c_str(&self) -> &CStr { - assert!(!self.is_null()); unsafe { let data = self.GetBufferPointer(); CStr::from_ptr(data as *const _ as *const _) diff --git a/d3d12/src/pso.rs b/d3d12/src/pso.rs index 72cc8bf133..542222c193 100644 --- a/d3d12/src/pso.rs +++ b/d3d12/src/pso.rs @@ -72,8 +72,8 @@ impl<'a> Shader<'a> { entry: &ffi::CStr, flags: ShaderCompileFlags, ) -> D3DResult<(Blob, Error)> { - let mut shader = Blob::null(); - let mut error = Error::null(); + let mut shader = std::ptr::null_mut(); + let mut error = std::ptr::null_mut(); let hr = unsafe { d3dcompiler::D3DCompile( @@ -86,10 +86,12 @@ impl<'a> Shader<'a> { target.as_ptr() as *const _, flags.bits(), 0, - shader.mut_self(), - error.mut_self(), + &mut shader, + &mut error, ) }; + let shader = unsafe { ComPtr::from_reffed(shader) }; + let error = unsafe { ComPtr::from_reffed(error) }; ((shader, error), hr) } diff --git a/wgpu-hal/src/auxil/dxgi/factory.rs b/wgpu-hal/src/auxil/dxgi/factory.rs index b22da9c4dc..7cc86dcf54 100644 --- a/wgpu-hal/src/auxil/dxgi/factory.rs +++ b/wgpu-hal/src/auxil/dxgi/factory.rs @@ -1,5 +1,5 @@ use winapi::{ - shared::{dxgi, dxgi1_2, dxgi1_4, dxgi1_6, winerror}, + shared::{dxgi1_2, dxgi1_4, dxgi1_6, winerror}, Interface, }; @@ -21,13 +21,13 @@ pub fn enumerate_adapters(factory: d3d12::DxgiFactory) -> Vec::null(); + let mut adapter4 = std::ptr::null_mut(); let hr = unsafe { factory6.EnumAdapterByGpuPreference( cur_index, dxgi1_6::DXGI_GPU_PREFERENCE_HIGH_PERFORMANCE, &dxgi1_6::IDXGIAdapter4::uuidof(), - adapter4.mut_void(), + &mut adapter4, ) }; @@ -39,13 +39,15 @@ pub fn enumerate_adapters(factory: d3d12::DxgiFactory) -> Vec::null(); - let hr = unsafe { factory.EnumAdapters1(cur_index, adapter1.mut_self()) }; + let mut adapter1 = std::ptr::null_mut(); + let hr = unsafe { factory.EnumAdapters1(cur_index, &mut adapter1) }; + let adapter1 = unsafe { d3d12::ComPtr::from_reffed(adapter1) }; if hr == winerror::DXGI_ERROR_NOT_FOUND { break; @@ -60,23 +62,25 @@ pub fn enumerate_adapters(factory: d3d12::DxgiFactory) -> Vec Adapter3 unsafe { match adapter1.cast::().into_result() { - Ok(adapter3) => { + Ok(Some(adapter3)) => { adapters.push(d3d12::DxgiAdapter::Adapter3(adapter3)); continue; } Err(err) => { log::info!("Failed casting Adapter1 to Adapter3: {}", err); } + Ok(None) => unreachable!(), } } // Adapter1 -> Adapter2 unsafe { match adapter1.cast::().into_result() { - Ok(adapter2) => { + Ok(Some(adapter2)) => { adapters.push(d3d12::DxgiAdapter::Adapter2(adapter2)); continue; } + Ok(None) => unreachable!(), Err(err) => { log::info!("Failed casting Adapter1 to Adapter2: {}", err); } @@ -154,9 +158,10 @@ pub fn create_factory( // Try to cast the IDXGIFactory4 into IDXGIFactory6 let factory6 = unsafe { factory4.cast::().into_result() }; match factory6 { - Ok(factory6) => { + Ok(Some(factory6)) => { return Ok((lib_dxgi, d3d12::DxgiFactory::Factory6(factory6))); } + Ok(None) => unreachable!(), // If we require factory6, hard error. Err(err) if required_factory_type == DxgiFactoryType::Factory6 => { // err is a Cow, not an Error implementor @@ -195,9 +200,10 @@ pub fn create_factory( // Try to cast the IDXGIFactory1 into IDXGIFactory2 let factory2 = unsafe { factory1.cast::().into_result() }; match factory2 { - Ok(factory2) => { + Ok(Some(factory2)) => { return Ok((lib_dxgi, d3d12::DxgiFactory::Factory2(factory2))); } + Ok(None) => unreachable!(), // If we require factory2, hard error. Err(err) if required_factory_type == DxgiFactoryType::Factory2 => { // err is a Cow, not an Error implementor diff --git a/wgpu-hal/src/dx11/library.rs b/wgpu-hal/src/dx11/library.rs index c2b5315ba1..21e8cb8ee2 100644 --- a/wgpu-hal/src/dx11/library.rs +++ b/wgpu-hal/src/dx11/library.rs @@ -63,7 +63,7 @@ impl D3D11Lib { d3dcommon::D3D_FEATURE_LEVEL_9_1, ]; - let mut device = d3d12::ComPtr::::null(); + let mut device = std::ptr::null_mut(); let mut feature_level: d3dcommon::D3D_FEATURE_LEVEL = 0; // We need to try this twice. If the first time fails due to E_INVALIDARG @@ -81,7 +81,7 @@ impl D3D11Lib { feature_levels.as_ptr(), feature_levels.len() as u32, d3d11::D3D11_SDK_VERSION, - device.mut_self(), + &mut device, &mut feature_level, ptr::null_mut(), // device context ) @@ -98,12 +98,13 @@ impl D3D11Lib { feature_levels[1..].as_ptr(), feature_levels[1..].len() as u32, d3d11::D3D11_SDK_VERSION, - device.mut_self(), + &mut device, &mut feature_level, ptr::null_mut(), // device context ) }; } + let device = unsafe { d3d12::ComPtr::from_reffed(device) }; // Any errors here are real and we should complain about if let Err(err) = hr.into_result() { @@ -116,24 +117,26 @@ impl D3D11Lib { // Device -> Device2 unsafe { match device.cast::().into_result() { - Ok(device2) => { + Ok(Some(device2)) => { return Some((super::D3D11Device::Device2(device2), feature_level)); } Err(hr) => { log::info!("Failed to cast device to ID3D11Device2: {}", hr) } + Ok(None) => unreachable!(), } } // Device -> Device1 unsafe { match device.cast::().into_result() { - Ok(device1) => { + Ok(Some(device1)) => { return Some((super::D3D11Device::Device1(device1), feature_level)); } Err(hr) => { log::info!("Failed to cast device to ID3D11Device1: {}", hr) } + Ok(None) => unreachable!(), } } diff --git a/wgpu-hal/src/dx12/adapter.rs b/wgpu-hal/src/dx12/adapter.rs index 3959deeccd..0032a6999c 100644 --- a/wgpu-hal/src/dx12/adapter.rs +++ b/wgpu-hal/src/dx12/adapter.rs @@ -26,7 +26,7 @@ impl Drop for super::Adapter { impl super::Adapter { pub unsafe fn report_live_objects(&self) { - if let Ok(debug_device) = unsafe { + if let Ok(Some(debug_device)) = unsafe { self.raw .cast::() .into_result() diff --git a/wgpu-hal/src/dx12/command.rs b/wgpu-hal/src/dx12/command.rs index 719e63a36f..de9e9cafe6 100644 --- a/wgpu-hal/src/dx12/command.rs +++ b/wgpu-hal/src/dx12/command.rs @@ -249,9 +249,7 @@ impl crate::CommandEncoder for super::CommandEncoder { unsafe fn begin_encoding(&mut self, label: crate::Label) -> Result<(), crate::DeviceError> { let list = loop { if let Some(list) = self.free_lists.pop() { - let reset_result = list - .reset(&self.allocator, d3d12::PipelineState::null()) - .into_result(); + let reset_result = list.reset(&self.allocator, None).into_result(); if reset_result.is_ok() { break Some(list); } @@ -264,12 +262,7 @@ impl crate::CommandEncoder for super::CommandEncoder { list } else { self.device - .create_graphics_command_list( - d3d12::CmdListType::Direct, - &self.allocator, - d3d12::PipelineState::null(), - 0, - ) + .create_graphics_command_list(d3d12::CmdListType::Direct, &self.allocator, None, 0) .into_device_result("Create command list")? }; @@ -956,7 +949,7 @@ impl crate::CommandEncoder for super::CommandEncoder { if self.pass.layout.signature != pipeline.layout.signature { // D3D12 requires full reset on signature change - list.set_graphics_root_signature(&pipeline.layout.signature); + list.set_graphics_root_signature(pipeline.layout.signature.as_ref()); self.reset_signature(&pipeline.layout); }; @@ -1166,7 +1159,7 @@ impl crate::CommandEncoder for super::CommandEncoder { if self.pass.layout.signature != pipeline.layout.signature { // D3D12 requires full reset on signature change - list.set_compute_root_signature(&pipeline.layout.signature); + list.set_compute_root_signature(pipeline.layout.signature.as_ref()); self.reset_signature(&pipeline.layout); }; diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index 86c0fc0cec..19d56854c6 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -35,20 +35,20 @@ impl super::Device { wgt::Dx12Compiler::Fxc => None, }; - let mut idle_fence = d3d12::Fence::null(); + let mut idle_fence = std::ptr::null_mut(); let hr = unsafe { profiling::scope!("ID3D12Device::CreateFence"); raw.CreateFence( 0, d3d12_ty::D3D12_FENCE_FLAG_NONE, &d3d12_ty::ID3D12Fence::uuidof(), - idle_fence.mut_void(), + &mut idle_fence, ) }; + let idle_fence = unsafe { d3d12::ComPtr::from_reffed(idle_fence.cast()) }; hr.into_device_result("Idle fence creation")?; - let mut zero_buffer = d3d12::Resource::null(); - unsafe { + let zero_buffer = unsafe { let raw_desc = d3d12_ty::D3D12_RESOURCE_DESC { Dimension: d3d12_ty::D3D12_RESOURCE_DIMENSION_BUFFER, Alignment: 0, @@ -77,6 +77,7 @@ impl super::Device { }; profiling::scope!("Zero Buffer Allocation"); + let mut zero_buffer = std::ptr::null_mut(); raw.CreateCommittedResource( &heap_properties, d3d12_ty::D3D12_HEAP_FLAG_NONE, @@ -84,9 +85,10 @@ impl super::Device { d3d12_ty::D3D12_RESOURCE_STATE_COMMON, ptr::null(), &d3d12_ty::ID3D12Resource::uuidof(), - zero_buffer.mut_void(), + &mut zero_buffer, ) .into_device_result("Zero buffer creation")?; + d3d12::Resource::from_reffed(zero_buffer.cast()) // Note: without `D3D12_HEAP_FLAG_CREATE_NOT_ZEROED` // this resource is zeroed by default. @@ -342,9 +344,8 @@ impl crate::Device for super::Device { Flags: conv::map_buffer_usage_to_resource_flags(desc.usage), }; - let mut resource = d3d12::Resource::null(); - let (hr, allocation) = - super::suballocation::create_buffer_resource(self, desc, raw_desc, &mut resource)?; + let (hr, allocation, resource) = + super::suballocation::create_buffer_resource(self, desc, raw_desc)?; hr.into_device_result("Buffer creation")?; if let Some(label) = desc.label { @@ -423,8 +424,7 @@ impl crate::Device for super::Device { Flags: conv::map_texture_usage_to_resource_flags(desc.usage), }; - let mut resource = d3d12::Resource::null(); - let (hr, allocation) = create_texture_resource(self, desc, raw_desc, &mut resource)?; + let (hr, allocation, resource) = create_texture_resource(self, desc, raw_desc)?; hr.into_device_result("Texture creation")?; if let Some(label) = desc.label { @@ -1015,7 +1015,7 @@ impl crate::Device for super::Device { })? .into_device_result("Root signature serialization")?; - if !error.is_null() { + if let Some(error) = error { log::error!( "Root signature serialization error: {:?}", error.as_c_str().to_str().unwrap() @@ -1037,7 +1037,7 @@ impl crate::Device for super::Device { Ok(super::PipelineLayout { shared: super::PipelineLayoutShared { - signature: raw, + signature: Some(raw), total_root_elements: parameters.len() as super::RootIndex, special_constants_root_index, root_constant_info, @@ -1338,7 +1338,7 @@ impl crate::Device for super::Device { }; let raw_desc = d3d12_ty::D3D12_GRAPHICS_PIPELINE_STATE_DESC { - pRootSignature: desc.layout.shared.signature.as_mut_ptr(), + pRootSignature: desc.layout.shared.signature.as_ref().unwrap().as_mut_ptr(), VS: *blob_vs.create_native_shader(), PS: match blob_fs { Some(ref shader) => *shader.create_native_shader(), @@ -1403,17 +1403,18 @@ impl crate::Device for super::Device { Flags: d3d12_ty::D3D12_PIPELINE_STATE_FLAG_NONE, }; - let mut raw = d3d12::PipelineState::null(); + let mut raw = std::ptr::null_mut(); let hr = { profiling::scope!("ID3D12Device::CreateGraphicsPipelineState"); unsafe { self.raw.CreateGraphicsPipelineState( &raw_desc, &d3d12_ty::ID3D12PipelineState::uuidof(), - raw.mut_void(), + &mut raw, ) } }; + let raw = unsafe { d3d12::PipelineState::from_reffed(raw.cast()) }; unsafe { blob_vs.destroy() }; if let Some(blob_fs) = blob_fs { @@ -1446,7 +1447,7 @@ impl crate::Device for super::Device { let pair = { profiling::scope!("ID3D12Device::CreateComputePipelineState"); self.raw.create_compute_pipeline_state( - &desc.layout.shared.signature, + desc.layout.shared.signature.as_ref(), blob_cs.create_native_shader(), 0, d3d12::CachedPSO::null(), @@ -1506,15 +1507,16 @@ impl crate::Device for super::Device { unsafe fn destroy_query_set(&self, _set: super::QuerySet) {} unsafe fn create_fence(&self) -> Result { - let mut raw = d3d12::Fence::null(); + let mut raw = std::ptr::null_mut(); let hr = unsafe { self.raw.CreateFence( 0, d3d12_ty::D3D12_FENCE_FLAG_NONE, &d3d12_ty::ID3D12Fence::uuidof(), - raw.mut_void(), + &mut raw, ) }; + let raw = unsafe { d3d12::ComPtr::from_reffed(raw.cast()) }; hr.into_device_result("Fence creation")?; Ok(super::Fence { raw }) } diff --git a/wgpu-hal/src/dx12/mod.rs b/wgpu-hal/src/dx12/mod.rs index 2279c0dcfc..4806ebe16e 100644 --- a/wgpu-hal/src/dx12/mod.rs +++ b/wgpu-hal/src/dx12/mod.rs @@ -318,7 +318,7 @@ impl PassState { has_label: false, resolves: ArrayVec::new(), layout: PipelineLayoutShared { - signature: d3d12::RootSignature::null(), + signature: None, total_root_elements: 0, special_constants_root_index: None, root_constant_info: None, @@ -521,7 +521,7 @@ struct RootConstantInfo { #[derive(Clone)] struct PipelineLayoutShared { - signature: d3d12::RootSignature, + signature: Option, total_root_elements: RootIndex, special_constants_root_index: Option, root_constant_info: Option, @@ -731,11 +731,12 @@ impl crate::Surface for Surface { } match unsafe { swap_chain1.cast::() }.into_result() { - Ok(swap_chain3) => swap_chain3, + Ok(Some(swap_chain3)) => swap_chain3, Err(err) => { log::error!("Unable to cast swap chain: {}", err); return Err(crate::SurfaceError::Other("swap chain cast to 3")); } + Ok(None) => unreachable!(), } } }; @@ -760,10 +761,9 @@ impl crate::Surface for Surface { let mut resources = Vec::with_capacity(config.swap_chain_size as usize); for i in 0..config.swap_chain_size { - let mut resource = d3d12::Resource::null(); - unsafe { - swap_chain.GetBuffer(i, &d3d12_ty::ID3D12Resource::uuidof(), resource.mut_void()) - }; + let mut resource = std::ptr::null_mut(); + unsafe { swap_chain.GetBuffer(i, &d3d12_ty::ID3D12Resource::uuidof(), &mut resource) }; + let resource = unsafe { d3d12::ComPtr::from_reffed(resource.cast()) }; resources.push(resource); } diff --git a/wgpu-hal/src/dx12/shader_compilation.rs b/wgpu-hal/src/dx12/shader_compilation.rs index 55a8f595d1..74c86dd870 100644 --- a/wgpu-hal/src/dx12/shader_compilation.rs +++ b/wgpu-hal/src/dx12/shader_compilation.rs @@ -23,7 +23,7 @@ pub(super) fn compile_fxc( log::Level, ) { profiling::scope!("compile_fxc"); - let mut shader_data = d3d12::Blob::null(); + let mut shader_data = std::ptr::null_mut(); let mut compile_flags = d3dcompiler::D3DCOMPILE_ENABLE_STRICTNESS; if device .private_caps @@ -32,7 +32,7 @@ pub(super) fn compile_fxc( { compile_flags |= d3dcompiler::D3DCOMPILE_DEBUG | d3dcompiler::D3DCOMPILE_SKIP_OPTIMIZATION; } - let mut error = d3d12::Blob::null(); + let mut error = std::ptr::null_mut(); let hr = unsafe { profiling::scope!("d3dcompiler::D3DCompile"); d3dcompiler::D3DCompile( @@ -45,10 +45,11 @@ pub(super) fn compile_fxc( full_stage.as_ptr().cast(), compile_flags, 0, - shader_data.mut_void().cast(), - error.mut_void().cast(), + &mut shader_data, + &mut error, ) }; + let shader_data = unsafe { d3d12::ComPtr::from_reffed(shader_data.cast()) }; match hr.into_result() { Ok(()) => ( @@ -58,6 +59,7 @@ pub(super) fn compile_fxc( Err(e) => { let mut full_msg = format!("FXC D3DCompile error ({e})"); if !error.is_null() { + let error = unsafe { d3d12::Blob::from_reffed(error.cast()) }; use std::fmt::Write as _; let message = unsafe { std::slice::from_raw_parts( diff --git a/wgpu-hal/src/dx12/suballocation.rs b/wgpu-hal/src/dx12/suballocation.rs index 9625b2ae3a..43bf601744 100644 --- a/wgpu-hal/src/dx12/suballocation.rs +++ b/wgpu-hal/src/dx12/suballocation.rs @@ -64,15 +64,15 @@ mod placed { device: &crate::dx12::Device, desc: &crate::BufferDescriptor, raw_desc: d3d12_ty::D3D12_RESOURCE_DESC, - resource: &mut ComPtr, - ) -> Result<(HRESULT, Option), crate::DeviceError> { + ) -> Result<(HRESULT, Option, ComPtr), crate::DeviceError> + { let is_cpu_read = desc.usage.contains(crate::BufferUses::MAP_READ); let is_cpu_write = desc.usage.contains(crate::BufferUses::MAP_WRITE); // It's a workaround for Intel Xe drivers. if !device.private_caps.suballocation_supported { - return super::committed::create_buffer_resource(device, desc, raw_desc, resource) - .map(|(hr, _)| (hr, None)); + return super::committed::create_buffer_resource(device, desc, raw_desc) + .map(|(hr, _, res)| (hr, None, res)); } let location = match (is_cpu_read, is_cpu_write) { @@ -102,6 +102,7 @@ mod placed { ); let allocation = allocator.allocator.allocate(&allocation_desc)?; + let mut resource = std::ptr::null_mut(); let hr = unsafe { device.raw.CreatePlacedResource( allocation.heap().as_winapi() as *mut _, @@ -110,23 +111,24 @@ mod placed { d3d12_ty::D3D12_RESOURCE_STATE_COMMON, ptr::null(), &d3d12_ty::ID3D12Resource::uuidof(), - resource.mut_void(), + &mut resource, ) }; + let resource = unsafe { d3d12::ComPtr::from_reffed(resource.cast()) }; - Ok((hr, Some(AllocationWrapper { allocation }))) + Ok((hr, Some(AllocationWrapper { allocation }), resource)) } pub(crate) fn create_texture_resource( device: &crate::dx12::Device, desc: &crate::TextureDescriptor, raw_desc: d3d12_ty::D3D12_RESOURCE_DESC, - resource: &mut ComPtr, - ) -> Result<(HRESULT, Option), crate::DeviceError> { + ) -> Result<(HRESULT, Option, ComPtr), crate::DeviceError> + { // It's a workaround for Intel Xe drivers. if !device.private_caps.suballocation_supported { - return super::committed::create_texture_resource(device, desc, raw_desc, resource) - .map(|(hr, _)| (hr, None)); + return super::committed::create_texture_resource(device, desc, raw_desc) + .map(|(hr, _, res)| (hr, None, res)); } let location = MemoryLocation::GpuOnly; @@ -149,6 +151,7 @@ mod placed { ); let allocation = allocator.allocator.allocate(&allocation_desc)?; + let mut resource = std::ptr::null_mut(); let hr = unsafe { device.raw.CreatePlacedResource( allocation.heap().as_winapi() as *mut _, @@ -157,11 +160,12 @@ mod placed { d3d12_ty::D3D12_RESOURCE_STATE_COMMON, ptr::null(), // clear value &d3d12_ty::ID3D12Resource::uuidof(), - resource.mut_void(), + &mut resource, ) }; + let resource = unsafe { d3d12::ComPtr::from_reffed(resource.cast()) }; - Ok((hr, Some(AllocationWrapper { allocation }))) + Ok((hr, Some(AllocationWrapper { allocation }), resource)) } pub(crate) fn free_buffer_allocation( @@ -254,8 +258,8 @@ mod committed { device: &crate::dx12::Device, desc: &crate::BufferDescriptor, raw_desc: d3d12_ty::D3D12_RESOURCE_DESC, - resource: &mut ComPtr, - ) -> Result<(HRESULT, Option), crate::DeviceError> { + ) -> Result<(HRESULT, Option, ComPtr), crate::DeviceError> + { let is_cpu_read = desc.usage.contains(crate::BufferUses::MAP_READ); let is_cpu_write = desc.usage.contains(crate::BufferUses::MAP_WRITE); @@ -278,6 +282,7 @@ mod committed { VisibleNodeMask: 0, }; + let mut resource = std::ptr::null_mut(); let hr = unsafe { device.raw.CreateCommittedResource( &heap_properties, @@ -290,19 +295,20 @@ mod committed { d3d12_ty::D3D12_RESOURCE_STATE_COMMON, ptr::null(), &d3d12_ty::ID3D12Resource::uuidof(), - resource.mut_void(), + &mut resource, ) }; + let resource = unsafe { d3d12::ComPtr::from_reffed(resource.cast()) }; - Ok((hr, None)) + Ok((hr, None, resource)) } pub(crate) fn create_texture_resource( device: &crate::dx12::Device, _desc: &crate::TextureDescriptor, raw_desc: d3d12_ty::D3D12_RESOURCE_DESC, - resource: &mut ComPtr, - ) -> Result<(HRESULT, Option), crate::DeviceError> { + ) -> Result<(HRESULT, Option, ComPtr), crate::DeviceError> + { let heap_properties = d3d12_ty::D3D12_HEAP_PROPERTIES { Type: d3d12_ty::D3D12_HEAP_TYPE_CUSTOM, CPUPageProperty: d3d12_ty::D3D12_CPU_PAGE_PROPERTY_NOT_AVAILABLE, @@ -314,6 +320,7 @@ mod committed { VisibleNodeMask: 0, }; + let mut resource = std::ptr::null_mut(); let hr = unsafe { device.raw.CreateCommittedResource( &heap_properties, @@ -326,11 +333,12 @@ mod committed { d3d12_ty::D3D12_RESOURCE_STATE_COMMON, ptr::null(), // clear value &d3d12_ty::ID3D12Resource::uuidof(), - resource.mut_void(), + &mut resource, ) }; + let resource = unsafe { d3d12::ComPtr::from_reffed(resource.cast()) }; - Ok((hr, None)) + Ok((hr, None, resource)) } #[allow(unused)]