Skip to content

Commit

Permalink
Add Pipeline Overrides for workgroup_size (gfx-rs#6635)
Browse files Browse the repository at this point in the history
  • Loading branch information
kentslaney authored Dec 6, 2024
1 parent e15e1a1 commit b56960b
Show file tree
Hide file tree
Showing 30 changed files with 257 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ By @ErichDonGubler in [#6456](https://github.com/gfx-rs/wgpu/pull/6456), [#6148]
- Implement `quantizeToF16()` for WGSL frontend, and WGSL, SPIR-V, HLSL, MSL, and GLSL backends. By @jamienicol in [#6519](https://github.com/gfx-rs/wgpu/pull/6519).
- Add support for GLSL `usampler*` and `isampler*`. By @DavidPeicho in [#6513](https://github.com/gfx-rs/wgpu/pull/6513).
- Expose Ray Query flags as constants in WGSL. Implement candidate intersections. By @kvark in [#5429](https://github.com/gfx-rs/wgpu/pull/5429)
- Allow for override-expressions in `workgroup_size`. By @KentSlaney in [#6635](https://github.com/gfx-rs/wgpu/pull/6635).

#### General

Expand Down
38 changes: 38 additions & 0 deletions naga/src/back/pipeline_constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ pub enum PipelineConstantError {
ConstantEvaluatorError(#[from] ConstantEvaluatorError),
#[error(transparent)]
ValidationError(#[from] WithSpan<ValidationError>),
#[error("workgroup_size override isn't strictly positive")]
NegativeWorkgroupSize,
}

/// Replace all overrides in `module` with constants.
Expand Down Expand Up @@ -190,6 +192,7 @@ pub fn process_overrides<'a>(
let mut entry_points = mem::take(&mut module.entry_points);
for ep in entry_points.iter_mut() {
process_function(&mut module, &override_map, &mut ep.function)?;
process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?;
}
module.entry_points = entry_points;

Expand All @@ -202,6 +205,41 @@ pub fn process_overrides<'a>(
Ok((Cow::Owned(module), Cow::Owned(module_info)))
}

fn process_workgroup_size_override(
module: &mut Module,
adjusted_global_expressions: &HandleVec<Expression, Handle<Expression>>,
ep: &mut crate::EntryPoint,
) -> Result<(), PipelineConstantError> {
match ep.workgroup_size_overrides {
None => {}
Some(overrides) => {
overrides.iter().enumerate().try_for_each(
|(i, overridden)| -> Result<(), PipelineConstantError> {
match *overridden {
None => Ok(()),
Some(h) => {
ep.workgroup_size[i] = module
.to_ctx()
.eval_expr_to_u32(adjusted_global_expressions[h])
.map(|n| {
if n == 0 {
Err(PipelineConstantError::NegativeWorkgroupSize)
} else {
Ok(n)
}
})
.map_err(|_| PipelineConstantError::NegativeWorkgroupSize)??;
Ok(())
}
}
},
)?;
ep.workgroup_size_overrides = None;
}
}
Ok(())
}

/// Add a [`Constant`] to `module` for the override `old_h`.
///
/// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`.
Expand Down
20 changes: 20 additions & 0 deletions naga/src/compact/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,14 @@ pub fn compact(module: &mut crate::Module) {
}
}

for e in module.entry_points.iter() {
if let Some(sizes) = e.workgroup_size_overrides {
for size in sizes.iter().filter_map(|x| *x) {
module_tracer.global_expressions_used.insert(size);
}
}
}

// We assume that all functions are used.
//
// Observe which types, constant expressions, constants, and
Expand Down Expand Up @@ -176,6 +184,18 @@ pub fn compact(module: &mut crate::Module) {
}
}

// Adjust workgroup_size_overrides
log::trace!("adjusting workgroup_size_overrides");
for e in module.entry_points.iter_mut() {
if let Some(sizes) = e.workgroup_size_overrides.as_mut() {
for size in sizes.iter_mut() {
if let Some(expr) = size.as_mut() {
module_map.global_expressions.adjust(expr);
}
}
}
}

// Adjust global variables' types and initializers.
log::trace!("adjusting global variables");
for (_, global) in module.global_variables.iter_mut() {
Expand Down
1 change: 1 addition & 0 deletions naga/src/front/glsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1366,6 +1366,7 @@ impl Frontend {
early_depth_test: Some(crate::EarlyDepthTest { conservative: None })
.filter(|_| self.meta.early_fragment_tests),
workgroup_size: self.meta.workgroup_size,
workgroup_size_overrides: None,
function: Function {
arguments,
expressions,
Expand Down
1 change: 1 addition & 0 deletions naga/src/front/spv/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,7 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
stage: ep.stage,
early_depth_test: ep.early_depth_test,
workgroup_size: ep.workgroup_size,
workgroup_size_overrides: None,
function,
});

Expand Down
50 changes: 46 additions & 4 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1311,24 +1311,53 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.collect();

if let Some(ref entry) = f.entry_point {
let workgroup_size = if let Some(workgroup_size) = entry.workgroup_size {
let workgroup_size_info = if let Some(workgroup_size) = entry.workgroup_size {
// TODO: replace with try_map once stabilized
let mut workgroup_size_out = [1; 3];
let mut workgroup_size_overrides_out = [None; 3];
for (i, size) in workgroup_size.into_iter().enumerate() {
if let Some(size_expr) = size {
workgroup_size_out[i] = self.const_u32(size_expr, &mut ctx.as_const())?.0;
match self.const_u32(size_expr, &mut ctx.as_const()) {
Ok(value) => {
workgroup_size_out[i] = value.0;
}
err => {
if let Err(Error::ConstantEvaluatorError(ref ty, _)) = err {
match **ty {
crate::proc::ConstantEvaluatorError::OverrideExpr => {
workgroup_size_overrides_out[i] =
Some(self.workgroup_size_override(
size_expr,
&mut ctx.as_override(),
)?);
}
_ => {
err?;
}
}
} else {
err?;
}
}
}
}
}
workgroup_size_out
if workgroup_size_overrides_out.iter().all(|x| x.is_none()) {
(workgroup_size_out, None)
} else {
(workgroup_size_out, Some(workgroup_size_overrides_out))
}
} else {
[0; 3]
([0; 3], None)
};

let (workgroup_size, workgroup_size_overrides) = workgroup_size_info;
ctx.module.entry_points.push(crate::EntryPoint {
name: f.name.name.to_string(),
stage: entry.stage,
early_depth_test: entry.early_depth_test,
workgroup_size,
workgroup_size_overrides,
function,
});
Ok(LoweredGlobalDecl::EntryPoint)
Expand All @@ -1338,6 +1367,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
}

fn workgroup_size_override(
&mut self,
size_expr: Handle<ast::Expression<'source>>,
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<Handle<crate::Expression>, Error<'source>> {
let span = ctx.ast_expressions.get_span(size_expr);
let expr = self.expression(size_expr, ctx)?;
match resolve_inner!(ctx, expr).scalar_kind().ok_or(0) {
Ok(crate::ScalarKind::Sint) | Ok(crate::ScalarKind::Uint) => Ok(expr),
_ => Err(Error::ExpectedConstExprConcreteIntegerScalar(span)),
}
}

fn block(
&mut self,
b: &ast::Block<'source>,
Expand Down
2 changes: 2 additions & 0 deletions naga/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2187,6 +2187,8 @@ pub struct EntryPoint {
pub early_depth_test: Option<EarlyDepthTest>,
/// Workgroup size for compute stages
pub workgroup_size: [u32; 3],
/// Override expressions for workgroup size in the global_expressions arena
pub workgroup_size_overrides: Option<[Option<Handle<Expression>>; 3]>,
/// The entrance function.
pub function: Function,
}
Expand Down
5 changes: 5 additions & 0 deletions naga/src/valid/handles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ impl super::Validator {

for entry_point in entry_points.iter() {
validate_function(None, &entry_point.function)?;
if let Some(sizes) = entry_point.workgroup_size_overrides {
for size in sizes.iter().filter_map(|x| *x) {
validate_const_expr(size)?;
}
}
}

for (function_handle, function) in functions.iter() {
Expand Down
4 changes: 4 additions & 0 deletions naga/tests/out/ir/access.compact.ron
Original file line number Diff line number Diff line change
Expand Up @@ -1854,6 +1854,7 @@
stage: Vertex,
early_depth_test: None,
workgroup_size: (0, 0, 0),
workgroup_size_overrides: None,
function: (
name: Some("foo_vert"),
arguments: [
Expand Down Expand Up @@ -2156,6 +2157,7 @@
stage: Fragment,
early_depth_test: None,
workgroup_size: (0, 0, 0),
workgroup_size_overrides: None,
function: (
name: Some("foo_frag"),
arguments: [],
Expand Down Expand Up @@ -2348,6 +2350,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("assign_through_ptr"),
arguments: [],
Expand Down Expand Up @@ -2430,6 +2433,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("assign_to_ptr_components"),
arguments: [],
Expand Down
4 changes: 4 additions & 0 deletions naga/tests/out/ir/access.ron
Original file line number Diff line number Diff line change
Expand Up @@ -1854,6 +1854,7 @@
stage: Vertex,
early_depth_test: None,
workgroup_size: (0, 0, 0),
workgroup_size_overrides: None,
function: (
name: Some("foo_vert"),
arguments: [
Expand Down Expand Up @@ -2156,6 +2157,7 @@
stage: Fragment,
early_depth_test: None,
workgroup_size: (0, 0, 0),
workgroup_size_overrides: None,
function: (
name: Some("foo_frag"),
arguments: [],
Expand Down Expand Up @@ -2348,6 +2350,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("assign_through_ptr"),
arguments: [],
Expand Down Expand Up @@ -2430,6 +2433,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("assign_to_ptr_components"),
arguments: [],
Expand Down
1 change: 1 addition & 0 deletions naga/tests/out/ir/atomic_i_increment.compact.ron
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (32, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("stage::test_atomic_i_increment_wrap"),
arguments: [],
Expand Down
1 change: 1 addition & 0 deletions naga/tests/out/ir/atomic_i_increment.ron
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (32, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("stage::test_atomic_i_increment_wrap"),
arguments: [],
Expand Down
1 change: 1 addition & 0 deletions naga/tests/out/ir/collatz.compact.ron
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("main"),
arguments: [
Expand Down
1 change: 1 addition & 0 deletions naga/tests/out/ir/collatz.ron
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("main"),
arguments: [
Expand Down
1 change: 1 addition & 0 deletions naga/tests/out/ir/fetch_depth.compact.ron
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (32, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("cull::fetch_depth_wrap"),
arguments: [],
Expand Down
1 change: 1 addition & 0 deletions naga/tests/out/ir/fetch_depth.ron
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (32, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("cull::fetch_depth_wrap"),
arguments: [],
Expand Down
1 change: 1 addition & 0 deletions naga/tests/out/ir/index-by-value.compact.ron
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@
stage: Vertex,
early_depth_test: None,
workgroup_size: (0, 0, 0),
workgroup_size_overrides: None,
function: (
name: Some("index_let_array_1d"),
arguments: [
Expand Down
1 change: 1 addition & 0 deletions naga/tests/out/ir/index-by-value.ron
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@
stage: Vertex,
early_depth_test: None,
workgroup_size: (0, 0, 0),
workgroup_size_overrides: None,
function: (
name: Some("index_let_array_1d"),
arguments: [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("f"),
arguments: [],
Expand Down
1 change: 1 addition & 0 deletions naga/tests/out/ir/overrides-atomicCompareExchangeWeak.ron
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("f"),
arguments: [],
Expand Down
1 change: 1 addition & 0 deletions naga/tests/out/ir/overrides-ray-query.compact.ron
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("main"),
arguments: [],
Expand Down
1 change: 1 addition & 0 deletions naga/tests/out/ir/overrides-ray-query.ron
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("main"),
arguments: [],
Expand Down
1 change: 1 addition & 0 deletions naga/tests/out/ir/overrides.compact.ron
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("main"),
arguments: [],
Expand Down
1 change: 1 addition & 0 deletions naga/tests/out/ir/overrides.ron
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
workgroup_size_overrides: None,
function: (
name: Some("main"),
arguments: [],
Expand Down
1 change: 1 addition & 0 deletions naga/tests/out/ir/shadow.compact.ron
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,7 @@
stage: Fragment,
early_depth_test: None,
workgroup_size: (0, 0, 0),
workgroup_size_overrides: None,
function: (
name: Some("fs_main_wrap"),
arguments: [
Expand Down
Loading

0 comments on commit b56960b

Please sign in to comment.