Skip to content

Commit

Permalink
fix(d3d12): remove nullability in ComPtr
Browse files Browse the repository at this point in the history
  • Loading branch information
ErichDonGubler committed Sep 25, 2023
1 parent e8ff8ee commit d31c3ab
Show file tree
Hide file tree
Showing 16 changed files with 225 additions and 234 deletions.
90 changes: 23 additions & 67 deletions d3d12/src/com.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,80 +3,44 @@ use std::{
fmt,
hash::{Hash, Hasher},
ops::Deref,
ptr,
};
use winapi::{ctypes::c_void, um::unknwnbase::IUnknown, Interface};
use winapi::{um::unknwnbase::IUnknown, Interface};

#[repr(transparent)]
pub struct ComPtr<T: Interface>(*mut T);

impl<T: Interface> ComPtr<T> {
/// Creates a null ComPtr.
pub fn null() -> Self {
ComPtr(ptr::null_mut())
/// Create a ComPtr from a raw pointer. This will _not_ call AddRef on the pointer, assuming
/// that it has already been called.
///
/// # Safety
///
/// - `raw` must be a valid pointer to a COM object that implements T.
pub unsafe fn from_reffed(raw: *mut T) -> Self {
debug_assert!(!raw.is_null());
ComPtr(raw)
}

/// Create a ComPtr from a raw pointer. This will call AddRef on the pointer.
///
/// # Safety
///
/// - if `raw` is not null, it must be a valid pointer to a COM object that implements T.
/// - `raw` must be a valid pointer to a COM object that implements T.
pub unsafe fn from_raw(raw: *mut T) -> Self {
if !raw.is_null() {
(*(raw as *mut IUnknown)).AddRef();
}
debug_assert!(!raw.is_null());
(*(raw as *mut IUnknown)).AddRef();
ComPtr(raw)
}

/// Returns true if the inner pointer is null.
pub fn is_null(&self) -> bool {
self.0.is_null()
}

/// Returns the raw inner pointer. May be null.
/// Returns the raw inner pointer.
pub fn as_ptr(&self) -> *const T {
self.0
}

/// Returns the raw inner pointer as mutable. May be null.
/// Returns the raw inner pointer as mutable.
pub fn as_mut_ptr(&self) -> *mut T {
self.0
}

/// Returns a mutable reference to the inner pointer casted as a pointer to c_void.
///
/// This is useful when D3D functions initialize objects by filling in a pointer to pointer
/// by taking `void**` as an argument.
///
/// # Safety
///
/// - Any modifications done to this pointer must result in the pointer either:
/// - being set to null
/// - being set to a valid pointer to a COM object that implements T
pub unsafe fn mut_void(&mut self) -> &mut *mut c_void {
// SAFETY: We must first get a reference pointing to our internal pointer
// and only then cast it. As if we cast it, then take a reference, we would
// end up with a reference to a temporary.
let refer: &mut *mut T = &mut self.0;
let void: *mut *mut c_void = refer.cast();

// SAFETY: This reference is valid for the duration of the borrow due our mutable borrow of self.
&mut *void
}

/// Returns a mutable reference to the inner pointer.
///
/// This is useful when D3D functions initialize objects by filling in a pointer to pointer
/// by taking `T**` as an argument.
///
/// # Safety
///
/// - Any modifications done to this pointer must result in the pointer either:
/// - being set to null
/// - being set to a valid pointer to a COM object that implements T
pub fn mut_self(&mut self) -> &mut *mut T {
&mut self.0
}
}

impl<T: Interface> ComPtr<T> {
Expand All @@ -86,7 +50,6 @@ impl<T: Interface> ComPtr<T> {
///
/// - This pointer must not be null.
pub unsafe fn as_unknown(&self) -> &IUnknown {
debug_assert!(!self.is_null());
&*(self.0 as *mut IUnknown)
}

Expand All @@ -95,44 +58,37 @@ impl<T: Interface> ComPtr<T> {
/// # Safety
///
/// - This pointer must not be null.
pub unsafe fn cast<U>(&self) -> D3DResult<ComPtr<U>>
pub unsafe fn cast<U>(&self) -> D3DResult<Option<ComPtr<U>>>
where
U: Interface,
{
debug_assert!(!self.is_null());
let mut obj = ComPtr::<U>::null();
let hr = self
.as_unknown()
.QueryInterface(&U::uuidof(), obj.mut_void());
let mut obj = std::ptr::null_mut();
let hr = self.as_unknown().QueryInterface(&U::uuidof(), &mut obj);
let obj = (!obj.is_null()).then(|| ComPtr::from_reffed(obj.cast()));
(obj, hr)
}
}

impl<T: Interface> Clone for ComPtr<T> {
fn clone(&self) -> Self {
if !self.is_null() {
unsafe {
self.as_unknown().AddRef();
}
unsafe {
self.as_unknown().AddRef();
}
ComPtr(self.0)
}
}

impl<T: Interface> Drop for ComPtr<T> {
fn drop(&mut self) {
if !self.is_null() {
unsafe {
self.as_unknown().Release();
}
unsafe {
self.as_unknown().Release();
}
}
}

impl<T: Interface> Deref for ComPtr<T> {
type Target = T;
fn deref(&self) -> &T {
assert!(!self.is_null());
unsafe { &*self.0 }
}
}
Expand Down
19 changes: 13 additions & 6 deletions d3d12/src/command_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,13 @@ impl GraphicsCommandList {
unsafe { self.Close() }
}

pub fn reset(&self, allocator: &CommandAllocator, initial_pso: PipelineState) -> HRESULT {
unsafe { self.Reset(allocator.as_mut_ptr(), initial_pso.as_mut_ptr()) }
pub fn reset(
&self,
allocator: &CommandAllocator,
initial_pso: Option<&PipelineState>,
) -> HRESULT {
let initial_pso = initial_pso.map_or(ptr::null_mut(), |pso| pso.as_mut_ptr());
unsafe { self.Reset(allocator.as_mut_ptr(), initial_pso) }
}

pub fn discard_resource(&self, resource: Resource, region: DiscardRegion) {
Expand Down Expand Up @@ -284,15 +289,17 @@ impl GraphicsCommandList {
}
}

pub fn set_compute_root_signature(&self, signature: &RootSignature) {
pub fn set_compute_root_signature(&self, signature: Option<&RootSignature>) {
unsafe {
self.SetComputeRootSignature(signature.as_mut_ptr());
self.SetComputeRootSignature(signature.map_or(ptr::null_mut(), |sig| sig.as_mut_ptr()));
}
}

pub fn set_graphics_root_signature(&self, signature: &RootSignature) {
pub fn set_graphics_root_signature(&self, signature: Option<&RootSignature>) {
unsafe {
self.SetGraphicsRootSignature(signature.as_mut_ptr());
self.SetGraphicsRootSignature(
signature.map_or(ptr::null_mut(), |sig| sig.as_mut_ptr()),
);
}
}

Expand Down
10 changes: 6 additions & 4 deletions d3d12/src/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ impl crate::D3D12Lib {
*mut *mut winapi::ctypes::c_void,
) -> crate::HRESULT;

let mut debug = Debug::null();
let mut debug = std::ptr::null_mut();
let hr = unsafe {
let func: libloading::Symbol<Fun> = self.lib.get(b"D3D12GetDebugInterface")?;
func(&d3d12sdklayers::ID3D12Debug::uuidof(), debug.mut_void())
func(&d3d12sdklayers::ID3D12Debug::uuidof(), &mut debug)
};
let debug = unsafe { ComPtr::from_reffed(debug.cast()) };

Ok((debug, hr))
}
Expand All @@ -26,13 +27,14 @@ impl crate::D3D12Lib {
impl Debug {
#[cfg(feature = "implicit-link")]
pub fn get_interface() -> crate::D3DResult<Self> {
let mut debug = Debug::null();
let mut debug = std::ptr::null_mut();
let hr = unsafe {
winapi::um::d3d12::D3D12GetDebugInterface(
&d3d12sdklayers::ID3D12Debug::uuidof(),
debug.mut_void(),
&mut debug,
)
};
let debug = unsafe { ComPtr::from_reffed(debug.cast()) };

(debug, hr)
}
Expand Down
23 changes: 11 additions & 12 deletions d3d12/src/descriptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ bitflags::bitflags! {
}

pub type RootSignature = ComPtr<d3d12::ID3D12RootSignature>;
pub type BlobResult = D3DResult<(Blob, Error)>;
pub type BlobResult = D3DResult<(Blob, Option<Error>)>;

#[cfg(feature = "libloading")]
impl crate::D3D12Lib {
Expand All @@ -293,12 +293,14 @@ impl crate::D3D12Lib {
Flags: flags.bits(),
};

let mut blob = Blob::null();
let mut error = Error::null();
let mut blob = std::ptr::null_mut();
let mut error = std::ptr::null_mut();
let hr = unsafe {
let func: libloading::Symbol<Fun> = self.lib.get(b"D3D12SerializeRootSignature")?;
func(&desc, version as _, blob.mut_self(), error.mut_self())
func(&desc, version as _, &mut blob, &mut error)
};
let blob = unsafe { ComPtr::from_reffed(blob) };
let error = (!error.is_null()).then(|| unsafe { ComPtr::from_reffed(error) });

Ok(((blob, error), hr))
}
Expand All @@ -312,8 +314,8 @@ impl RootSignature {
static_samplers: &[StaticSampler],
flags: RootSignatureFlags,
) -> BlobResult {
let mut blob = Blob::null();
let mut error = Error::null();
let mut blob = std::ptr::null_mut();
let mut error = std::ptr::null_mut();

let desc = d3d12::D3D12_ROOT_SIGNATURE_DESC {
NumParameters: parameters.len() as _,
Expand All @@ -324,13 +326,10 @@ impl RootSignature {
};

let hr = unsafe {
d3d12::D3D12SerializeRootSignature(
&desc,
version as _,
blob.mut_self(),
error.mut_self(),
)
d3d12::D3D12SerializeRootSignature(&desc, version as _, &mut blob, &mut error)
};
let blob = unsafe { ComPtr::from_reffed(blob) };
let error = unsafe { ComPtr::from_reffed(error) };

((blob, error), hr)
}
Expand Down
Loading

0 comments on commit d31c3ab

Please sign in to comment.