Subgroup Operations #4190

Closed
53 commits
14f6be2
subgroup: Implement subgroupBallot for wgsl-in, wgsl-out, spv-out, hl…
exrook Sep 30, 2023
e969f09
subgroup: subgroupBallot metal out
exrook Sep 30, 2023
05feb88
subgroup: require GroupNonUnifomBallot capability
exrook Sep 30, 2023
2ff8809
subgroup: Add subgroup invocation id and subgroup size builtins
exrook Oct 1, 2023
8cbf423
subgroup: SubgroupInvocationId is only valid in compute stages
exrook Oct 2, 2023
9430603
subgroup: expierment with subgroupBarrier() based on OpControlbarrier
exrook Oct 3, 2023
44f6929
subgroup: add statement for rest of subgroup ops
exrook Oct 14, 2023
b5be66e
subgroup: fix doc error on SubgroupBroadcast
exrook Oct 5, 2023
a277ec5
subgroup: wgsl-in and spv-out for subgroup operations
exrook Oct 5, 2023
dd49f99
subgroup: add optional predicate for subgroupBallot
exrook Oct 5, 2023
18ceb01
Renames SubgroupBroadcast => SubgroupGather and BroadcastMode => Gath…
Lichtso Oct 11, 2023
7c911ba
General fixes.
Lichtso Oct 14, 2023
f5c4ad7
Adds BuiltIn::NumSubgroups, BuiltIn::SubgroupId.
Lichtso Oct 14, 2023
f931a47
Adds GatherMode::Shuffle, GatherMode::ShuffleDown, GatherMode::Shuffl…
Lichtso Oct 14, 2023
d8e17e7
Implements all frontends and backends.
Lichtso Oct 14, 2023
43133d0
Adjusts metal backend test_stack_size().
Lichtso Oct 14, 2023
275a5a5
Adds test and snapshots.
Lichtso Oct 20, 2023
f04b71d
subgroup: fix spv-in
exrook Oct 21, 2023
7a5bf0f
subgroup: Add 0 arg subgroupBallot to tests and fix wgsl-out whitespace
exrook Oct 21, 2023
2aaf0a0
subgroup: Add spv-in test
exrook Oct 21, 2023
fb1e3f9
subgroup: resolve fixmes & fix typo
exrook Oct 21, 2023
209225a
subgroup: Add subgroup capability
exrook Oct 21, 2023
77d33b9
subgroup: Treat subgroup operation results as non-uniform
exrook Oct 18, 2023
1e51650
subgroup: refactor wgsl subgroup gather parsing
exrook Oct 21, 2023
68fc4d5
subgroup: doc comments for subgroup `Statement`s and `Expression`s
exrook Oct 21, 2023
91c569d
subgroup: add validation for each subgroup operation type
exrook Oct 24, 2023
6f6f789
Add feature for subgroup operations in fragment and compute shaders
exrook Oct 14, 2023
7d2e273
Adds feature detection for Vulkan.
exrook Sep 30, 2023
55e21e8
Adds feature detection for Metal.
Lichtso Oct 11, 2023
2e44fb1
Adds feature detection for DirectX 12.
Lichtso Oct 14, 2023
e186d0e
Adds subgroup_operations tests.
Lichtso Oct 11, 2023
9c12f08
Pass subgroup capability to naga
exrook Oct 22, 2023
358c2a7
Separate subgroup feature into one flag per shader stage
exrook Oct 24, 2023
9afec48
Merge branch 'trunk' into subgroup_feature
cwfitzgerald Oct 26, 2023
edd11d1
Merge branch 'trunk' into subgroup_feature
cwfitzgerald Oct 26, 2023
c821cc9
Merge branch 'trunk' into subgroup_feature
cwfitzgerald Oct 28, 2023
ebaf08d
Fix compiles
cwfitzgerald Oct 28, 2023
afaf441
subgroups: DX12 doesn't support subgroup ops in vertex stage
exrook Nov 1, 2023
fd14a43
Merge branch 'trunk' into subgroup_feature
exrook Nov 1, 2023
7e0060e
subgroup: fix hlsl subgroup_id
exrook Nov 5, 2023
d822524
subgroups: update naga snapshots
exrook Nov 5, 2023
a86a130
Merge branch 'trunk' into subgroup_feature
exrook Nov 5, 2023
011016f
subgroups: fix gpu test on systems with subgroup size of 1
exrook Nov 6, 2023
a68ef32
subgroup: emit Shuffle instead of Broadcast on spv-out
exrook Nov 7, 2023
3bc4aa9
subgroup: Use bitmask to track pass/failed tests
exrook Nov 16, 2023
8420c97
subgroups: Print detailed error message on test failure
exrook Nov 21, 2023
395cd21
subgroups: Add tests in divergent control flow
exrook Nov 21, 2023
f1a4d75
subgroups: correct changelog entry PR link
exrook Nov 21, 2023
4dd8e06
Merge branch 'trunk' into subgroup_feature
exrook Nov 22, 2023
50355e0
Merge branch 'trunk' into HEAD
exrook Nov 28, 2023
4bf0479
subgroups: Expect test failures on metal
exrook Nov 28, 2023
156f026
subgroups: Add test for divergent for loop
exrook Nov 28, 2023
f9fc7f0
subgroups: Verify convergence after finishing other tests
exrook Nov 28, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -52,6 +52,7 @@ Previously, `DeviceExt::create_texture_with_data` only allowed data to be provid

#### General
- Added `DownlevelFlags::VERTEX_AND_INSTANCE_INDEX_RESPECTS_RESPECTIVE_FIRST_VALUE_IN_INDIRECT_DRAW` to know if `@builtin(vertex_index)` and `@builtin(instance_index)` will respect the `first_vertex` / `first_instance` in indirect calls. If this is not present, both will always start counting from 0. Currently enabled on all backends except DX12. By @cwfitzgerald in [#4722](https://github.com/gfx-rs/wgpu/pull/4722)
- Add `SUBGROUP_COMPUTE, SUBGROUP_FRAGMENT, SUBGROUP_VERTEX` features. By @exrook and @lichtso in [#4190](https://github.com/gfx-rs/wgpu/pull/4190)

#### OpenGL
- `@builtin(instance_index)` now properly reflects the range provided in the draw call instead of always counting from 0. By @cwfitzgerald in [#4722](https://github.com/gfx-rs/wgpu/pull/4722).
2 changes: 2 additions & 0 deletions naga-cli/src/bin/naga.rs
@@ -400,6 +400,8 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {

// Validate the IR before compaction.
let info = match naga::valid::Validator::new(params.validation_flags, validation_caps)
.subgroup_stages(naga::valid::ShaderStages::all())
.subgroup_operations(naga::valid::SubgroupOperationSet::all())
.validate(&module)
{
Ok(info) => Some(info),
90 changes: 90 additions & 0 deletions naga/src/back/dot/mod.rs
@@ -279,6 +279,94 @@ impl StatementGraph {
crate::RayQueryFunction::Terminate => "RayQueryTerminate",
}
}
S::SubgroupBallot { result, predicate } => {
if let Some(predicate) = predicate {
self.dependencies.push((id, predicate, "predicate"));
}
self.emits.push((id, result));
"SubgroupBallot"
}
S::SubgroupCollectiveOperation {
op,
collective_op,
argument,
result,
} => {
self.dependencies.push((id, argument, "arg"));
self.emits.push((id, result));
match (collective_op, op) {
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
"SubgroupAll"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
"SubgroupAny"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
"SubgroupAdd"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
"SubgroupMul"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
"SubgroupMax"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
"SubgroupMin"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
"SubgroupAnd"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
"SubgroupOr"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
"SubgroupXor"
}
(
crate::CollectiveOperation::ExclusiveScan,
crate::SubgroupOperation::Add,
) => "SubgroupPrefixExclusiveAdd",
(
crate::CollectiveOperation::ExclusiveScan,
crate::SubgroupOperation::Mul,
) => "SubgroupPrefixExclusiveMul",
(
crate::CollectiveOperation::InclusiveScan,
crate::SubgroupOperation::Add,
) => "SubgroupPrefixInclusiveAdd",
(
crate::CollectiveOperation::InclusiveScan,
crate::SubgroupOperation::Mul,
) => "SubgroupPrefixInclusiveMul",
_ => unimplemented!(),
}
}
S::SubgroupGather {
mode,
argument,
result,
} => {
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
self.dependencies.push((id, index, "index"))
}
}
self.dependencies.push((id, argument, "arg"));
self.emits.push((id, result));
match mode {
crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst",
crate::GatherMode::Broadcast(_) => "SubgroupBroadcast",
crate::GatherMode::Shuffle(_) => "SubgroupShuffle",
crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown",
crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp",
crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor",
}
}
};
// Set the last node to the merge node
last_node = merge_id;
@@ -586,6 +674,8 @@ fn write_function_expressions(
let ty = if committed { "Committed" } else { "Candidate" };
(format!("rayQueryGet{}Intersection", ty).into(), 4)
}
E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4),
E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4),
};

// give uniform expressions an outline
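The `SubgroupCollectiveOperation` arm in the dot backend above is essentially a lookup table from a `(collective_op, op)` pair to a node label, with every unlisted pair falling through to `unimplemented!()`. A minimal standalone sketch of that table follows. The two enums are local stand-ins that only mirror the variant names seen in the diff (they are not naga's types), and `collective_label` is a hypothetical helper that returns `None` instead of panicking.

```rust
// Stand-in enums mirroring the variant names in the diff; NOT naga's real types.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CollectiveOperation {
    Reduce,
    InclusiveScan,
    ExclusiveScan,
}

#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SubgroupOperation {
    All,
    Any,
    Add,
    Mul,
    Min,
    Max,
    And,
    Or,
    Xor,
}

/// Hypothetical helper: the graph label for a pair, or `None` for the
/// combinations the diff routes to `unimplemented!()`.
fn collective_label(c: CollectiveOperation, op: SubgroupOperation) -> Option<&'static str> {
    let label = match (c, op) {
        // Reductions accept every operation.
        (CollectiveOperation::Reduce, SubgroupOperation::All) => "SubgroupAll",
        (CollectiveOperation::Reduce, SubgroupOperation::Any) => "SubgroupAny",
        (CollectiveOperation::Reduce, SubgroupOperation::Add) => "SubgroupAdd",
        (CollectiveOperation::Reduce, SubgroupOperation::Mul) => "SubgroupMul",
        (CollectiveOperation::Reduce, SubgroupOperation::Max) => "SubgroupMax",
        (CollectiveOperation::Reduce, SubgroupOperation::Min) => "SubgroupMin",
        (CollectiveOperation::Reduce, SubgroupOperation::And) => "SubgroupAnd",
        (CollectiveOperation::Reduce, SubgroupOperation::Or) => "SubgroupOr",
        (CollectiveOperation::Reduce, SubgroupOperation::Xor) => "SubgroupXor",
        // Scans are only listed for Add and Mul.
        (CollectiveOperation::ExclusiveScan, SubgroupOperation::Add) => "SubgroupPrefixExclusiveAdd",
        (CollectiveOperation::ExclusiveScan, SubgroupOperation::Mul) => "SubgroupPrefixExclusiveMul",
        (CollectiveOperation::InclusiveScan, SubgroupOperation::Add) => "SubgroupPrefixInclusiveAdd",
        (CollectiveOperation::InclusiveScan, SubgroupOperation::Mul) => "SubgroupPrefixInclusiveMul",
        _ => return None,
    };
    Some(label)
}
```

Returning `Option` here is a choice made for the sketch only; the PR relies on validation to reject unsupported pairs before a backend sees them, which is presumably why the catch-all arm panics.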
23 changes: 23 additions & 0 deletions naga/src/back/glsl/features.rs
@@ -48,6 +48,8 @@ bitflags::bitflags! {
///
/// We can always support this, either through the language or a polyfill
const INSTANCE_INDEX = 1 << 22;
/// Subgroup operations
const SUBGROUP_OPERATIONS = 1 << 23;
}
}

@@ -115,6 +117,7 @@ impl FeaturesManager {
check_feature!(SAMPLE_VARIABLES, 400, 300);
check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */);
check_feature!(SUBGROUP_OPERATIONS, 430, 310);
match version {
Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300),
_ => check_feature!(MULTI_VIEW, 140, 310),
@@ -251,6 +254,22 @@ impl FeaturesManager {
}
}

if self.0.contains(Features::SUBGROUP_OPERATIONS) {
// https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt
writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?;
writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?;
writeln!(
out,
"#extension GL_KHR_shader_subgroup_arithmetic : require"
)?;
writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?;
writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?;
writeln!(
out,
"#extension GL_KHR_shader_subgroup_shuffle_relative : require"
)?;
}

Ok(())
}
}
@@ -469,6 +488,10 @@ impl<'a, W> Writer<'a, W> {
}
}
}
Expression::SubgroupBallotResult |
Expression::SubgroupOperationResult { .. } => {
features.request(Features::SUBGROUP_OPERATIONS)
}
_ => {}
}
}
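The `SUBGROUP_OPERATIONS` block in `features.rs` above writes six `#extension GL_KHR_shader_subgroup_* : require` lines by hand. As an illustration only, the same output can be produced from a list of suffixes. `SUBGROUP_EXTENSIONS` and `write_subgroup_extensions` are hypothetical names for this sketch and do not exist in naga; the suffixes are the ones that appear in the diff.

```rust
use std::fmt::Write;

/// Suffixes of the `GL_KHR_shader_subgroup_*` extensions the diff requires,
/// in the order the diff emits them.
const SUBGROUP_EXTENSIONS: [&str; 6] = [
    "basic",
    "vote",
    "arithmetic",
    "ballot",
    "shuffle",
    "shuffle_relative",
];

/// Hypothetical helper: writes one `#extension ... : require` line per suffix.
fn write_subgroup_extensions(out: &mut impl Write) -> std::fmt::Result {
    for suffix in SUBGROUP_EXTENSIONS.iter() {
        writeln!(out, "#extension GL_KHR_shader_subgroup_{} : require", suffix)?;
    }
    Ok(())
}
```

Calling it with a `String` as the sink yields the same six lines as the hand-written version. Whether a loop is preferable to explicit `writeln!` calls is a style question; the explicit form in the PR matches how the surrounding extensions are written.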
131 changes: 130 additions & 1 deletion naga/src/back/glsl/mod.rs
@@ -2379,6 +2379,125 @@ impl<'a, W: Write> Writer<'a, W> {
writeln!(self.out, ");")?;
}
Statement::RayQuery { .. } => unreachable!(),
Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);

write!(self.out, "subgroupBallot(")?;
match predicate {
Some(predicate) => self.write_expr(predicate, ctx)?,
None => write!(self.out, "true")?,
}
writeln!(self.out, ");")?;
}
Statement::SubgroupCollectiveOperation {
op,
collective_op,
argument,
result,
} => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);

match (collective_op, op) {
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
write!(self.out, "subgroupAll(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
write!(self.out, "subgroupAny(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupAdd(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupMul(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
write!(self.out, "subgroupMax(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
write!(self.out, "subgroupMin(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
write!(self.out, "subgroupAnd(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
write!(self.out, "subgroupOr(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
write!(self.out, "subgroupXor(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupExclusiveAdd(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupExclusiveMul(")?
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupInclusiveAdd(")?
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupInclusiveMul(")?
}
_ => unimplemented!(),
}
self.write_expr(argument, ctx)?;
writeln!(self.out, ");")?;
}
Statement::SubgroupGather {
mode,
argument,
result,
} => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);

match mode {
crate::GatherMode::BroadcastFirst => {
write!(self.out, "subgroupBroadcastFirst(")?;
}
crate::GatherMode::Broadcast(_) => {
write!(self.out, "subgroupBroadcast(")?;
}
crate::GatherMode::Shuffle(_) => {
write!(self.out, "subgroupShuffle(")?;
}
crate::GatherMode::ShuffleDown(_) => {
write!(self.out, "subgroupShuffleDown(")?;
}
crate::GatherMode::ShuffleUp(_) => {
write!(self.out, "subgroupShuffleUp(")?;
}
crate::GatherMode::ShuffleXor(_) => {
write!(self.out, "subgroupShuffleXor(")?;
}
}
self.write_expr(argument, ctx)?;
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
write!(self.out, ", ")?;
self.write_expr(index, ctx)?;
}
}
writeln!(self.out, ");")?;
}
}

Ok(())
@@ -3567,7 +3686,9 @@ impl<'a, W: Write> Writer<'a, W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult
- | Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
+ | Expression::WorkGroupUniformLoadResult { .. }
+ | Expression::SubgroupOperationResult { .. }
+ | Expression::SubgroupBallotResult => unreachable!(),
// `ArrayLength` is written as `expr.length()` and we convert it to a uint
Expression::ArrayLength(expr) => {
write!(self.out, "uint(")?;
@@ -4131,6 +4252,9 @@ impl<'a, W: Write> Writer<'a, W> {
if flags.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{level}memoryBarrierShared();")?;
}
if flags.contains(crate::Barrier::SUB_GROUP) {
writeln!(self.out, "{level}subgroupMemoryBarrier();")?;
}
writeln!(self.out, "{level}barrier();")?;
Ok(())
}
@@ -4397,6 +4521,11 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s
Bi::WorkGroupId => "gl_WorkGroupID",
Bi::WorkGroupSize => "gl_WorkGroupSize",
Bi::NumWorkGroups => "gl_NumWorkGroups",
// subgroup
Bi::NumSubgroups => "gl_NumSubgroups",
Bi::SubgroupId => "gl_SubgroupID",
Bi::SubgroupSize => "gl_SubgroupSize",
Bi::SubgroupInvocationId => "gl_SubgroupInvocationID",
}
}
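The `SubgroupGather` arm in the GLSL writer above matches on `mode` twice: once to pick the function name and once to append the optional index argument. The sketch below folds both into a single match, purely to show the shape of the emitted call. `GatherMode` here is a local stand-in that carries the already-rendered index as a string (naga's variant carries an expression handle instead), and `render_gather` is a hypothetical helper.

```rust
// Stand-in for the gather modes in the diff; the payload is the index
// expression already rendered as GLSL text (an assumption of this sketch).
#[allow(dead_code)]
enum GatherMode<'a> {
    BroadcastFirst,
    Broadcast(&'a str),
    Shuffle(&'a str),
    ShuffleDown(&'a str),
    ShuffleUp(&'a str),
    ShuffleXor(&'a str),
}

/// Hypothetical helper: renders the GLSL call for one gather statement.
fn render_gather(mode: &GatherMode<'_>, argument: &str) -> String {
    // One match yields both the GLSL function name and the optional index.
    let (name, index) = match *mode {
        GatherMode::BroadcastFirst => ("subgroupBroadcastFirst", None),
        GatherMode::Broadcast(i) => ("subgroupBroadcast", Some(i)),
        GatherMode::Shuffle(i) => ("subgroupShuffle", Some(i)),
        GatherMode::ShuffleDown(i) => ("subgroupShuffleDown", Some(i)),
        GatherMode::ShuffleUp(i) => ("subgroupShuffleUp", Some(i)),
        GatherMode::ShuffleXor(i) => ("subgroupShuffleXor", Some(i)),
    };
    match index {
        // Every mode except BroadcastFirst takes the index as a second argument.
        Some(i) => format!("{}({}, {})", name, argument, i),
        None => format!("{}({})", name, argument),
    }
}
```

The writer in the PR streams into `self.out` and calls `write_expr` between the two matches, so the two-pass structure there avoids building intermediate strings; the single-match form is only convenient when the pieces are already strings.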

4 changes: 4 additions & 0 deletions naga/src/back/hlsl/conv.rs
@@ -168,6 +168,10 @@ impl crate::BuiltIn {
// to this field will get replaced with references to `SPECIAL_CBUF_VAR`
// in `Writer::write_expr`.
Self::NumWorkGroups => "SV_GroupID",
Self::SubgroupSize
| Self::SubgroupInvocationId
| Self::NumSubgroups
| Self::SubgroupId => return Err(Error::Unimplemented(format!("builtin {self:?}"))),
Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
return Err(Error::Unimplemented(format!("builtin {self:?}")))
}