Skip to content

Commit

Permalink
move CommandEncoderStatus on the CommandBuffer and change its var…
Browse files Browse the repository at this point in the history
…iants to hold `CommandBufferMutable`

This makes the code more straightforward, we were previously holding invalidity state in 2 places: `CommandBuffer::data` could hold `None` and in `CommandEncoderStatus::Error`.

This commit also implements `Drop` for `CommandEncoder` which makes the destruction/reclamation code automatic. We were previously not reclaiming all command encoders (`CommandBufferMutable::destroy` didn't call `release_encoder`) even though all encoders are coming from a pool.
  • Loading branch information
teoxoy authored and jimblandy committed Dec 2, 2024
1 parent 68d336e commit 5e1fbd7
Show file tree
Hide file tree
Showing 13 changed files with 240 additions and 267 deletions.
8 changes: 4 additions & 4 deletions wgpu-core/src/command/clear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ impl Global {
let cmd_buf = hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id());
let mut cmd_buf_data = cmd_buf.try_get()?;
cmd_buf_data.check_recording()?;
let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.record()?;

#[cfg(feature = "trace")]
if let Some(ref mut list) = cmd_buf_data.commands {
Expand Down Expand Up @@ -174,8 +174,8 @@ impl Global {
let cmd_buf = hub
.command_buffers
.get(command_encoder_id.into_command_buffer_id());
let mut cmd_buf_data = cmd_buf.try_get()?;
cmd_buf_data.check_recording()?;
let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.record()?;

#[cfg(feature = "trace")]
if let Some(ref mut list) = cmd_buf_data.commands {
Expand Down
34 changes: 13 additions & 21 deletions wgpu-core/src/command/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use crate::{
end_pipeline_statistics_query,
memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
validate_and_begin_pipeline_statistics_query, ArcPassTimestampWrites, BasePass,
BindGroupStateChange, CommandBuffer, CommandEncoderError, CommandEncoderStatus, MapPassErr,
PassErrorScope, PassTimestampWrites, QueryUseError, StateChange,
BindGroupStateChange, CommandBuffer, CommandEncoderError, MapPassErr, PassErrorScope,
PassTimestampWrites, QueryUseError, StateChange,
},
device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures},
global::Global,
Expand All @@ -30,8 +30,7 @@ use wgt::{BufferAddress, DynamicOffset};

use super::{bind::BinderError, memory_init::CommandBufferTextureMemoryActions};
use crate::ray_tracing::TlasAction;
use std::sync::Arc;
use std::{fmt, mem::size_of, str};
use std::{fmt, mem::size_of, str, sync::Arc};

pub struct ComputePass {
/// All pass data & records is stored here.
Expand Down Expand Up @@ -282,7 +281,9 @@ impl Global {
/// If creation fails, an invalid pass is returned.
/// Any operation on an invalid pass will return an error.
///
/// If successful, puts the encoder into the [`CommandEncoderStatus::Locked`] state.
/// If successful, puts the encoder into the [`Locked`] state.
///
/// [`Locked`]: crate::command::CommandEncoderStatus::Locked
pub fn command_encoder_create_compute_pass(
&self,
encoder_id: id::CommandEncoderId,
Expand All @@ -299,11 +300,7 @@ impl Global {

let cmd_buf = hub.command_buffers.get(encoder_id.into_command_buffer_id());

match cmd_buf
.try_get()
.map_err(|e| e.into())
.and_then(|mut cmd_buf_data| cmd_buf_data.lock_encoder())
{
match cmd_buf.data.lock().lock_encoder() {
Ok(_) => {}
Err(e) => return make_err(e, arc_desc),
};
Expand Down Expand Up @@ -340,7 +337,8 @@ impl Global {
.hub
.command_buffers
.get(encoder_id.into_command_buffer_id());
let mut cmd_buf_data = cmd_buf.try_get().map_pass_err(pass_scope)?;
let mut cmd_buf_data = cmd_buf.data.lock();
let cmd_buf_data = cmd_buf_data.get_inner().map_pass_err(pass_scope)?;

if let Some(ref mut list) = cmd_buf_data.commands {
list.push(crate::device::trace::Command::RunComputePass {
Expand Down Expand Up @@ -408,19 +406,16 @@ impl Global {
let device = &cmd_buf.device;
device.check_is_valid().map_pass_err(pass_scope)?;

let mut cmd_buf_data = cmd_buf.try_get().map_pass_err(pass_scope)?;
cmd_buf_data.unlock_encoder().map_pass_err(pass_scope)?;
let cmd_buf_data = &mut *cmd_buf_data;
let mut cmd_buf_data = cmd_buf.data.lock();
let mut cmd_buf_data_guard = cmd_buf_data.unlock_encoder().map_pass_err(pass_scope)?;
let cmd_buf_data = &mut *cmd_buf_data_guard;

let encoder = &mut cmd_buf_data.encoder;
let status = &mut cmd_buf_data.status;

// We automatically keep extending command buffers over time, and because
// we want to insert a command buffer _before_ what we're about to record,
// we need to make sure to close the previous one.
encoder.close(&cmd_buf.device).map_pass_err(pass_scope)?;
// will be reset to true if recording is done without errors
*status = CommandEncoderStatus::Error;
let raw_encoder = encoder.open(&cmd_buf.device).map_pass_err(pass_scope)?;

let mut state = State {
Expand Down Expand Up @@ -590,10 +585,6 @@ impl Global {
state.raw_encoder.end_compute_pass();
}

// We've successfully recorded the compute pass, bring the
// command buffer out of the error state.
*status = CommandEncoderStatus::Recording;

let State {
snatch_guard,
tracker,
Expand Down Expand Up @@ -626,6 +617,7 @@ impl Global {
encoder
.close_and_swap(&cmd_buf.device)
.map_pass_err(pass_scope)?;
cmd_buf_data_guard.mark_successful();

Ok(())
}
Expand Down
Loading

0 comments on commit 5e1fbd7

Please sign in to comment.