diff --git a/CHANGELOG.md b/CHANGELOG.md index 69d1f06239..641c780c2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -52,6 +52,7 @@ Previously, `DeviceExt::create_texture_with_data` only allowed data to be provid #### General - Added `DownlevelFlags::VERTEX_AND_INSTANCE_INDEX_RESPECTS_RESPECTIVE_FIRST_VALUE_IN_INDIRECT_DRAW` to know if `@builtin(vertex_index)` and `@builtin(instance_index)` will respect the `first_vertex` / `first_instance` in indirect calls. If this is not present, both will always start counting from 0. Currently enabled on all backends except DX12. By @cwfitzgerald in [#4722](https://github.com/gfx-rs/wgpu/pull/4722) +- Add `SUBGROUP_COMPUTE, SUBGROUP_FRAGMENT, SUBGROUP_VERTEX` features. By @exrook and @lichtso in [#4190](https://github.com/gfx-rs/wgpu/pull/4190) #### OpenGL - `@builtin(instance_index)` now properly reflects the range provided in the draw call instead of always counting from 0. By @cwfitzgerald in [#4722](https://github.com/gfx-rs/wgpu/pull/4722). diff --git a/naga-cli/src/bin/naga.rs b/naga-cli/src/bin/naga.rs index 7543e02539..5e38442c85 100644 --- a/naga-cli/src/bin/naga.rs +++ b/naga-cli/src/bin/naga.rs @@ -400,6 +400,8 @@ fn run() -> Result<(), Box> { // Validate the IR before compaction. let info = match naga::valid::Validator::new(params.validation_flags, validation_caps) + .subgroup_stages(naga::valid::ShaderStages::all()) + .subgroup_operations(naga::valid::SubgroupOperationSet::all()) .validate(&module) { Ok(info) => Some(info), diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 1556371df1..86f4797b56 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -279,6 +279,94 @@ impl StatementGraph { crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } + S::SubgroupBallot { result, predicate } => { + if let Some(predicate) = predicate { + self.dependencies.push((id, predicate, "predicate")); + } + self.emits.push((id, result)); + "SubgroupBallot" + } + S::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + self.dependencies.push((id, argument, "arg")); + self.emits.push((id, result)); + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + "SubgroupAll" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + "SubgroupAny" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + "SubgroupAdd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + "SubgroupMul" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + "SubgroupMax" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + "SubgroupMin" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + "SubgroupAnd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + "SubgroupOr" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + "SubgroupXor" + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupPrefixExclusiveAdd", + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupPrefixExclusiveMul", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupPrefixInclusiveAdd", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupPrefixInclusiveMul", + _ => unimplemented!(), + } + } + S::SubgroupGather 
{ + mode, + argument, + result, + } => { + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + self.dependencies.push((id, index, "index")) + } + } + self.dependencies.push((id, argument, "arg")); + self.emits.push((id, result)); + match mode { + crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst", + crate::GatherMode::Broadcast(_) => "SubgroupBroadcast", + crate::GatherMode::Shuffle(_) => "SubgroupShuffle", + crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown", + crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp", + crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor", + } + } }; // Set the last node to the merge node last_node = merge_id; @@ -586,6 +674,8 @@ fn write_function_expressions( let ty = if committed { "Committed" } else { "Candidate" }; (format!("rayQueryGet{}Intersection", ty).into(), 4) } + E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4), + E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4), }; // give uniform expressions an outline diff --git a/naga/src/back/glsl/features.rs b/naga/src/back/glsl/features.rs index aaebfde9cb..acafa50e4c 100644 --- a/naga/src/back/glsl/features.rs +++ b/naga/src/back/glsl/features.rs @@ -48,6 +48,8 @@ bitflags::bitflags! { /// /// We can always support this, either through the language or a polyfill const INSTANCE_INDEX = 1 << 22; + /// Subgroup operations + const SUBGROUP_OPERATIONS = 1 << 23; } } @@ -115,6 +117,7 @@ impl FeaturesManager { check_feature!(SAMPLE_VARIABLES, 400, 300); check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310); check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */); + check_feature!(SUBGROUP_OPERATIONS, 430, 310); match version { Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300), _ => check_feature!(MULTI_VIEW, 140, 310), @@ -251,6 +254,22 @@ impl FeaturesManager { } } + if self.0.contains(Features::SUBGROUP_OPERATIONS) { + // https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt + writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?; + writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?; + writeln!( + out, + "#extension GL_KHR_shader_subgroup_arithmetic : require" + )?; + writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?; + writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?; + writeln!( + out, + "#extension GL_KHR_shader_subgroup_shuffle_relative : require" + )?; + } + Ok(()) } } @@ -469,6 +488,10 @@ impl<'a, W> Writer<'a, W> { } } } + Expression::SubgroupBallotResult | + Expression::SubgroupOperationResult { .. } => { + features.request(Features::SUBGROUP_OPERATIONS) + } _ => {} } } diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index cd3075f70a..0141436d6b 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2379,6 +2379,125 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, ");")?; } Statement::RayQuery { .. 
} => unreachable!(), + Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + write!(self.out, "subgroupBallot(")?; + match predicate { + Some(predicate) => self.write_expr(predicate, ctx)?, + None => write!(self.out, "true")?, + } + writeln!(self.out, ");")?; + } + Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupInclusiveMul(")? 
+ } + _ => unimplemented!(), + } + self.write_expr(argument, ctx)?; + writeln!(self.out, ");")?; + } + Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(argument, ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(index, ctx)?; + } + } + writeln!(self.out, ");")?; + } } Ok(()) @@ -3567,7 +3686,9 @@ impl<'a, W: Write> Writer<'a, W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult - | Expression::WorkGroupUniformLoadResult { .. } => unreachable!(), + | Expression::WorkGroupUniformLoadResult { .. } + | Expression::SubgroupOperationResult { .. } + | Expression::SubgroupBallotResult => unreachable!(), // `ArrayLength` is written as `expr.length()` and we convert it to a uint Expression::ArrayLength(expr) => { write!(self.out, "uint(")?; @@ -4131,6 +4252,9 @@ impl<'a, W: Write> Writer<'a, W> { if flags.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}memoryBarrierShared();")?; } + if flags.contains(crate::Barrier::SUB_GROUP) { + writeln!(self.out, "{level}subgroupMemoryBarrier();")?; + } writeln!(self.out, "{level}barrier();")?; Ok(()) } @@ -4397,6 +4521,11 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s Bi::WorkGroupId => "gl_WorkGroupID", Bi::WorkGroupSize => "gl_WorkGroupSize", Bi::NumWorkGroups => "gl_NumWorkGroups", + // subgroup + Bi::NumSubgroups => "gl_NumSubgroups", + Bi::SubgroupId => "gl_SubgroupID", + Bi::SubgroupSize => "gl_SubgroupSize", + Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", } } diff --git a/naga/src/back/hlsl/conv.rs b/naga/src/back/hlsl/conv.rs index da17c35704..a6200e1f89 100644 --- a/naga/src/back/hlsl/conv.rs +++ b/naga/src/back/hlsl/conv.rs @@ -168,6 +168,10 @@ impl crate::BuiltIn { // to this field will get replaced with references to `SPECIAL_CBUF_VAR` // in `Writer::write_expr`. 
Self::NumWorkGroups => "SV_GroupID", + Self::SubgroupSize + | Self::SubgroupInvocationId + | Self::NumSubgroups + | Self::SubgroupId => return Err(Error::Unimplemented(format!("builtin {self:?}"))), Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => { return Err(Error::Unimplemented(format!("builtin {self:?}"))) } diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 24d54fc0e5..38638ffe2a 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -1126,7 +1126,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, " {name}(")?; let need_workgroup_variables_initialization = - self.need_workgroup_variables_initialization(func_ctx, module); + self.need_workgroup_variables_initialization(func, func_ctx, module); // Write function arguments for non entry point functions match func_ctx.ty { @@ -1162,7 +1162,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?; } else { let stage = module.entry_points[ep_index as usize].stage; + let mut arg_num = 0; for (index, arg) in func.arguments.iter().enumerate() { + if matches!( + arg.binding, + Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) + | Some(crate::Binding::BuiltIn( + crate::BuiltIn::SubgroupInvocationId + )) + | Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) + | Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) + ) { + continue; + } + arg_num += 1; + if index != 0 { write!(self.out, ", ")?; } @@ -1182,7 +1196,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } if need_workgroup_variables_initialization { - if !func.arguments.is_empty() { + if arg_num > 0 { write!(self.out, ", ")?; } write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?; @@ -1213,6 +1227,53 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_workgroup_variables_initialization(func_ctx, module)?; } + if let back::FunctionType::EntryPoint(ep_index) = func_ctx.ty { + let ep = &module.entry_points[ep_index as usize]; + for (index, arg) in func.arguments.iter().enumerate() { + if let Some(crate::Binding::BuiltIn(builtin)) = arg.binding { + if matches!( + builtin, + crate::BuiltIn::SubgroupSize + | crate::BuiltIn::SubgroupInvocationId + | crate::BuiltIn::NumSubgroups + | crate::BuiltIn::SubgroupId + ) { + let level = back::Level(1); + write!(self.out, "{level}const ")?; + + self.write_type(module, arg.ty)?; + + let argument_name = + &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)]; + write!(self.out, " {argument_name} = ")?; + + match builtin { + crate::BuiltIn::SubgroupSize => { + writeln!(self.out, "WaveGetLaneCount();")? + } + crate::BuiltIn::SubgroupInvocationId => { + writeln!(self.out, "WaveGetLaneIndex();")? 
+ } + crate::BuiltIn::NumSubgroups => writeln!( + self.out, + "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount();", + ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2] + )?, + crate::BuiltIn::SubgroupId => { + writeln!( + self.out, + "(__local_invocation_id.z * {}u + __local_invocation_id.y * {}u + __local_invocation_id.x) / WaveGetLaneCount();", + ep.workgroup_size[0] * ep.workgroup_size[1], + ep.workgroup_size[1], + )?; + } + _ => unreachable!(), + } + } + } + } + } + if let back::FunctionType::EntryPoint(index) = func_ctx.ty { self.write_ep_arguments_initialization(module, func, index)?; } @@ -1263,14 +1324,20 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { fn need_workgroup_variables_initialization( &mut self, + func: &crate::Function, func_ctx: &back::FunctionCtx, module: &Module, ) -> bool { - self.options.zero_initialize_workgroup_memory + func.arguments.iter().any(|arg| { + matches!( + arg.binding, + Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) + ) + }) || (self.options.zero_initialize_workgroup_memory && func_ctx.ty.is_compute_entry_point(module) && module.global_variables.iter().any(|(handle, var)| { !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup - }) + })) } fn write_workgroup_variables_initialization( @@ -2000,6 +2067,129 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}}}")? } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + write!(self.out, "const uint4 {name} = ")?; + self.named_expressions.insert(result, name); + + write!(self.out, "WaveActiveBallot(")?; + match predicate { + Some(predicate) => self.write_expr(module, predicate, func_ctx)?, + None => write!(self.out, "true")?, + } + writeln!(self.out, ");")?; + } + Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "WaveActiveAllTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "WaveActiveAnyTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "WaveActiveSum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "WaveActiveProduct(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "WaveActiveMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "WaveActiveMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "WaveActiveBitAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "WaveActiveBitOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "WaveActiveBitXor(")? 
+ } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "WavePrefixSum(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "WavePrefixProduct(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " + WavePrefixSum(")?; + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " * WavePrefixProduct(")?; + } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; + } + Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + if matches!(mode, crate::GatherMode::BroadcastFirst) { + write!(self.out, "WaveReadLaneFirst(")?; + self.write_expr(module, argument, func_ctx)?; + } else { + write!(self.out, "WaveReadLaneAt(")?; + self.write_expr(module, argument, func_ctx)?; + write!(self.out, ", ")?; + match mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => { + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleDown(index) => { + write!(self.out, "WaveGetLaneIndex() + ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleUp(index) => { + write!(self.out, "WaveGetLaneIndex() - ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleXor(index) => { + write!(self.out, "WaveGetLaneIndex() ^ ")?; + self.write_expr(module, index, func_ctx)?; + } + } + } + writeln!(self.out, ");")?; + } } Ok(()) @@ -3153,7 +3343,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::WorkGroupUniformLoadResult { .. } - | Expression::RayQueryProceedResult => {} + | Expression::RayQueryProceedResult + | Expression::SubgroupBallotResult + | Expression::SubgroupOperationResult { .. 
} => {} } if !closing_bracket.is_empty() { @@ -3220,6 +3412,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?; } + if barrier.contains(crate::Barrier::SUB_GROUP) { + // Does not exist in DirectX + } Ok(()) } } diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 5ef18730c9..eee825a83b 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -437,6 +437,11 @@ impl ResolvedBinding { Bi::WorkGroupId => "threadgroup_position_in_grid", Bi::WorkGroupSize => "dispatch_threads_per_threadgroup", Bi::NumWorkGroups => "threadgroups_per_grid", + // subgroup + Bi::NumSubgroups => "simdgroups_per_threadgroup", + Bi::SubgroupId => "simdgroup_index_in_threadgroup", + Bi::SubgroupSize => "threads_per_simdgroup", + Bi::SubgroupInvocationId => "thread_index_in_simdgroup", Bi::CullDistance | Bi::ViewIndex => { return Err(Error::UnsupportedBuiltIn(built_in)) } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 17154c3cd5..33102d5673 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1954,6 +1954,8 @@ impl Writer { crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } + | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } | crate::Expression::RayQueryProceedResult => { unreachable!() } @@ -3033,6 +3035,121 @@ impl Writer { } } } + crate::Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?; + if let Some(predicate) = predicate { + self.put_expression(predicate, &context.expression, true)?; + } else { + write!(self.out, "true")?; + } + writeln!(self.out, "), 0, 0, 0);")?; + } + crate::Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "{NAMESPACE}::simd_all(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "{NAMESPACE}::simd_any(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "{NAMESPACE}::simd_sum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "{NAMESPACE}::simd_product(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "{NAMESPACE}::simd_max(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "{NAMESPACE}::simd_min(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "{NAMESPACE}::simd_and(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "{NAMESPACE}::simd_or(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "{NAMESPACE}::simd_xor(")? 
+ } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?, + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?, + _ => unimplemented!(), + } + self.put_expression(argument, &context.expression, true)?; + writeln!(self.out, ");")?; + } + crate::Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "{NAMESPACE}::simd_broadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?; + } + } + self.put_expression(argument, &context.expression, true)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.put_expression(index, &context.expression, true)?; + } + } + writeln!(self.out, ");")?; + } } } @@ -4374,6 +4491,12 @@ impl Writer { "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", )?; } + if flags.contains(crate::Barrier::SUB_GROUP) { + writeln!( + self.out, + "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", + )?; + } Ok(()) } } @@ -4644,8 +4767,8 @@ fn test_stack_size() { } let stack_size = addresses_end - addresses_start; // check the size (in debug only) - // last observed macOS value: 19152 (CI) - if !(9000..=20000).contains(&stack_size) { + // last observed macOS value: 22256 (CI) + if !(15000..=25000).contains(&stack_size) { panic!("`put_block` stack size {stack_size} has changed!"); } } diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index df6ecd00ff..d20617d2ed 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1128,7 +1128,9 @@ impl<'w> BlockContext<'w> { crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } - | crate::Expression::RayQueryProceedResult => self.cached[expr_handle], + | crate::Expression::RayQueryProceedResult + | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. 
} => self.cached[expr_handle], crate::Expression::As { expr, kind, @@ -2339,6 +2341,27 @@ impl<'w> BlockContext<'w> { crate::Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } + crate::Statement::SubgroupBallot { + result, + ref predicate, + } => { + self.write_subgroup_ballot(predicate, result, &mut block)?; + } + crate::Statement::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?; + } + crate::Statement::SubgroupGather { + ref mode, + argument, + result, + } => { + self.write_subgroup_gather(mode, argument, result, &mut block)?; + } } } diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index b963793ad3..5f7c6b34fd 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1037,6 +1037,73 @@ impl super::Instruction { instruction.add_operand(semantics_id); instruction } + + // Group Instructions + + pub(super) fn group_non_uniform_ballot( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + predicate: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBallot); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(predicate); + + instruction + } + pub(super) fn group_non_uniform_broadcast_first( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + + instruction + } + pub(super) fn group_non_uniform_gather( + op: Op, + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + index: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + instruction.add_operand(index); + + instruction + } + pub(super) fn group_non_uniform_arithmetic( + op: Op, + result_type_id: Word, + id: Word, + exec_scope_id: Word, + group_op: Option, + value: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + if let Some(group_op) = group_op { + instruction.add_operand(group_op as u32); + } + instruction.add_operand(value); + + instruction + } } impl From for spirv::ImageFormat { diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index b7d57be0d4..8335501aa5 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -13,6 +13,7 @@ mod layout; mod ray; mod recyclable; mod selection; +mod subgroup; mod writer; pub use spirv::Capability; diff --git a/naga/src/back/spv/subgroup.rs b/naga/src/back/spv/subgroup.rs new file mode 100644 index 0000000000..917c723fd6 --- /dev/null +++ b/naga/src/back/spv/subgroup.rs @@ -0,0 +1,207 @@ +use super::{Block, BlockContext, Error, Instruction}; +use crate::{ + arena::Handle, + back::spv::{LocalType, LookupType}, + TypeInner, +}; + +impl<'w> BlockContext<'w> { + pub(super) fn write_subgroup_ballot( + &mut self, + predicate: &Option>, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + 
&[spirv::Capability::GroupNonUniformBallot], + )?; + let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + scalar: crate::Scalar::U32, + pointer_space: None, + })); + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + let predicate = if let Some(predicate) = *predicate { + self.cached[predicate] + } else { + self.writer.get_constant_scalar(crate::Literal::Bool(true)) + }; + let id = self.gen_id(); + block.body.push(Instruction::group_non_uniform_ballot( + vec4_u32_type_id, + id, + exec_scope_id, + predicate, + )); + self.cached[result] = id; + Ok(()) + } + pub(super) fn write_subgroup_operation( + &mut self, + op: &crate::SubgroupOperation, + collective_op: &crate::CollectiveOperation, + argument: Handle, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + use crate::SubgroupOperation as sg; + match *op { + sg::All | sg::Any => { + self.writer.require_any( + "GroupNonUniformVote", + &[spirv::Capability::GroupNonUniformVote], + )?; + } + _ => { + self.writer.require_any( + "GroupNonUniformArithmetic", + &[spirv::Capability::GroupNonUniformArithmetic], + )?; + } + } + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + let result_ty_inner = result_ty.inner_with(&self.ir_module.types); + + let (is_scalar, scalar) = match *result_ty_inner { + TypeInner::Scalar(kind) => (true, kind), + TypeInner::Vector { scalar: kind, .. } => (false, kind), + _ => unimplemented!(), + }; + + use crate::ScalarKind as sk; + let spirv_op = match (scalar.kind, *op) { + (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll, + (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny, + (_, sg::All | sg::Any) => unimplemented!(), + + (sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd, + (sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd, + (sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul, + (sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul, + (sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax, + (sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax, + (sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax, + (sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin, + (sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin, + (sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin, + (sk::Bool, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(), + + (sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd, + (sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr, + (sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor, + (sk::Float, sg::And | sg::Or | sg::Xor) => unimplemented!(), + (sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd, + (sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr, + (sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor, + }; + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + use crate::CollectiveOperation as c; + let group_op = match *op { + sg::All | sg::Any => None, + _ => Some(match *collective_op { + c::Reduce => spirv::GroupOperation::Reduce, + c::InclusiveScan => spirv::GroupOperation::InclusiveScan, + c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, + }), + }; + + let arg_id = self.cached[argument]; + block.body.push(Instruction::group_non_uniform_arithmetic( + spirv_op, + result_type_id, + id, + exec_scope_id, + group_op, + 
arg_id, + )); + self.cached[result] = id; + Ok(()) + } + pub(super) fn write_subgroup_gather( + &mut self, + mode: &crate::GatherMode, + argument: Handle, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + match *mode { + crate::GatherMode::BroadcastFirst | crate::GatherMode::Broadcast(_) => { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + } + crate::GatherMode::Shuffle(_) | crate::GatherMode::ShuffleXor(_) => { + self.writer.require_any( + "GroupNonUniformShuffle", + &[spirv::Capability::GroupNonUniformShuffle], + )?; + } + crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => { + self.writer.require_any( + "GroupNonUniformShuffleRelative", + &[spirv::Capability::GroupNonUniformShuffleRelative], + )?; + } + } + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + let arg_id = self.cached[argument]; + match *mode { + crate::GatherMode::BroadcastFirst => { + block + .body + .push(Instruction::group_non_uniform_broadcast_first( + result_type_id, + id, + exec_scope_id, + arg_id, + )); + } + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let index_id = self.cached[index]; + let op = match *mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + // Use shuffle to emit broadcast to allow the index to + // be dynamically uniform on Vulkan 1.1. 
The argument to + // OpGroupNonUniformBroadcast must be a constant pre SPIR-V + // 1.5 (vulkan 1.2) + crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformShuffle, + crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle, + crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown, + crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp, + crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor, + }; + block.body.push(Instruction::group_non_uniform_gather( + op, + result_type_id, + id, + exec_scope_id, + arg_id, + index_id, + )); + } + } + self.cached[result] = id; + Ok(()) + } +} diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index ef0532b2ea..7d51b63b28 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1299,7 +1299,11 @@ impl Writer { spirv::MemorySemantics::WORKGROUP_MEMORY, flags.contains(crate::Barrier::WORK_GROUP), ); - let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32); + let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) { + self.get_index_constant(spirv::Scope::Subgroup as u32) + } else { + self.get_index_constant(spirv::Scope::Workgroup as u32) + }; let mem_scope_id = self.get_index_constant(memory_scope as u32); let semantics_id = self.get_index_constant(semantics.bits()); block.body.push(Instruction::control_barrier( @@ -1574,6 +1578,41 @@ impl Writer { Bi::WorkGroupId => BuiltIn::WorkgroupId, Bi::WorkGroupSize => BuiltIn::WorkgroupSize, Bi::NumWorkGroups => BuiltIn::NumWorkgroups, + // Subgroup + Bi::NumSubgroups => { + self.require_any( + "`num_subgroups` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::NumSubgroups + } + Bi::SubgroupId => { + self.require_any( + "`subgroup_id` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::SubgroupId + } + Bi::SubgroupSize => { + self.require_any( + "`subgroup_size` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupSize + } + Bi::SubgroupInvocationId => { + self.require_any( + "`subgroup_invocation_id` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupLocalInvocationId + } }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 10da339968..9493ce977d 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -917,8 +917,124 @@ impl Writer { if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}workgroupBarrier();")?; } + + if barrier.contains(crate::Barrier::SUB_GROUP) { + writeln!(self.out, "{level}subgroupBarrier();")?; + } } Statement::RayQuery { .. 
} => unreachable!(), + Statement::SubgroupBallot { result, predicate } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + write!(self.out, "subgroupBallot(")?; + if let Some(predicate) = predicate { + self.write_expr(module, predicate, func_ctx)?; + } + writeln!(self.out, ");")?; + } + Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupPrefixExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupPrefixExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupPrefixInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupPrefixInclusiveMul(")? 
+ } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; + } + Statement::SubgroupGather { + mode, + argument, + result, + } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(module, argument, func_ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + } + } + writeln!(self.out, ");")?; + } } Ok(()) @@ -1672,6 +1788,8 @@ impl Writer { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult + | Expression::SubgroupBallotResult + | Expression::SubgroupOperationResult { .. } | Expression::WorkGroupUniformLoadResult { .. } => {} } @@ -1773,6 +1891,10 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { Bi::SampleMask => "sample_mask", Bi::PrimitiveIndex => "primitive_index", Bi::ViewIndex => "view_index", + Bi::NumSubgroups => "num_subgroups", + Bi::SubgroupId => "subgroup_id", + Bi::SubgroupSize => "subgroup_size", + Bi::SubgroupInvocationId => "subgroup_invocation_id", Bi::BaseInstance | Bi::BaseVertex | Bi::ClipDistance diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index 301bbe3240..9493c286ca 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -71,6 +71,7 @@ impl<'tracer> ExpressionTracer<'tracer> { | Ex::GlobalVariable(_) | Ex::LocalVariable(_) | Ex::CallResult(_) + | Ex::SubgroupBallotResult | Ex::RayQueryProceedResult => {} Ex::Constant(handle) => { @@ -186,6 +187,7 @@ impl<'tracer> ExpressionTracer<'tracer> { Ex::AtomicResult { ty, comparison: _ } => self.types_used.insert(ty), Ex::WorkGroupUniformLoadResult { ty } => self.types_used.insert(ty), Ex::ArrayLength(expr) => self.expressions_used.insert(expr), + Ex::SubgroupOperationResult { ty } => self.types_used.insert(ty), Ex::RayQueryGetIntersection { query, committed: _, @@ -217,6 +219,7 @@ impl ModuleMap { | Ex::GlobalVariable(_) | Ex::LocalVariable(_) | Ex::CallResult(_) + | Ex::SubgroupBallotResult | Ex::RayQueryProceedResult => {} // Expressions that contain handles that need to be adjusted. 
@@ -344,6 +347,7 @@ impl ModuleMap { comparison: _, } => self.types.adjust(ty), Ex::WorkGroupUniformLoadResult { ref mut ty } => self.types.adjust(ty), + Ex::SubgroupOperationResult { ref mut ty } => self.types.adjust(ty), Ex::ArrayLength(ref mut expr) => adjust(expr), Ex::RayQueryGetIntersection { ref mut query, diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 0698b57258..a124281bc1 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -97,6 +97,39 @@ impl FunctionTracer<'_> { self.expressions_used.insert(query); self.trace_ray_query_function(fun); } + St::SubgroupBallot { result, predicate } => { + if let Some(predicate) = predicate { + self.expressions_used.insert(predicate) + } + self.expressions_used.insert(result) + } + St::SubgroupCollectiveOperation { + op: _, + collective_op: _, + argument, + result, + } => { + self.expressions_used.insert(argument); + self.expressions_used.insert(result) + } + St::SubgroupGather { + mode, + argument, + result, + } => { + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + self.expressions_used.insert(index) + } + } + self.expressions_used.insert(argument); + self.expressions_used.insert(result) + } // Trivial statements. St::Break @@ -250,6 +283,40 @@ impl FunctionMap { adjust(query); self.adjust_ray_query_function(fun); } + St::SubgroupBallot { + ref mut result, + ref mut predicate, + } => { + if let Some(ref mut predicate) = *predicate { + adjust(predicate); + } + adjust(result); + } + St::SubgroupCollectiveOperation { + op: _, + collective_op: _, + ref mut argument, + ref mut result, + } => { + adjust(argument); + adjust(result); + } + St::SubgroupGather { + ref mut mode, + ref mut argument, + ref mut result, + } => { + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(ref mut index) + | crate::GatherMode::Shuffle(ref mut index) + | crate::GatherMode::ShuffleDown(ref mut index) + | crate::GatherMode::ShuffleUp(ref mut index) + | crate::GatherMode::ShuffleXor(ref mut index) => adjust(index), + } + adjust(argument); + adjust(result); + } // Trivial statements. 
St::Break diff --git a/naga/src/front/spv/convert.rs b/naga/src/front/spv/convert.rs index efd95898b8..a00e7e525b 100644 --- a/naga/src/front/spv/convert.rs +++ b/naga/src/front/spv/convert.rs @@ -154,6 +154,11 @@ pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result crate::BuiltIn::WorkGroupId, Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize, Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups, + // subgroup + Some(Bi::NumSubgroups) => crate::BuiltIn::NumSubgroups, + Some(Bi::SubgroupId) => crate::BuiltIn::SubgroupId, + Some(Bi::SubgroupSize) => crate::BuiltIn::SubgroupSize, + Some(Bi::SubgroupLocalInvocationId) => crate::BuiltIn::SubgroupInvocationId, _ => return Err(Error::UnsupportedBuiltIn(word)), }) } diff --git a/naga/src/front/spv/error.rs b/naga/src/front/spv/error.rs index 2f9bf2d1bc..cc6cd98801 100644 --- a/naga/src/front/spv/error.rs +++ b/naga/src/front/spv/error.rs @@ -54,6 +54,8 @@ pub enum Error { UnknownBinaryOperator(spirv::Op), #[error("unknown relational function {0:?}")] UnknownRelationalFunction(spirv::Op), + #[error("unsupported group operation %{0}")] + UnsupportedGroupOperation(spirv::Word), #[error("invalid parameter {0:?}")] InvalidParameter(spirv::Op), #[error("invalid operand count {1} for {0:?}")] @@ -116,8 +118,8 @@ pub enum Error { FunctionCallCycle(spirv::Word), #[error("invalid array size {0:?}")] InvalidArraySize(Handle), - #[error("invalid barrier scope %{0}")] - InvalidBarrierScope(spirv::Word), + #[error("invalid execution scope %{0}")] + InvalidExecutionScope(spirv::Word), #[error("invalid barrier memory semantics %{0}")] InvalidBarrierMemorySemantics(spirv::Word), #[error( diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index e7f07ebc58..c4fa937992 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -3650,7 +3650,7 @@ impl> Frontend { let semantics_const = self.lookup_constant.lookup(semantics_id)?; let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) - .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; let semantics = resolve_constant(ctx.gctx(), semantics_const.handle) .ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?; @@ -3691,6 +3691,254 @@ impl> Frontend { }, ); } + Op::GroupNonUniformBallot => { + inst.expect(5)?; + block.extend(emitter.finish(ctx.expressions)); + let result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let predicate_id = self.next()?; + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let predicate = if self + .lookup_constant + .lookup(predicate_id) + .ok() + .filter(|predicate_const| { + matches!( + ctx.gctx().const_expressions + [ctx.gctx().constants[predicate_const.handle].init], + crate::Expression::Literal(crate::Literal::Bool(true)) + ) + }) + .is_some() + { + None + } else { + let predicate_lookup = self.lookup_expression.lookup(predicate_id)?; + let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup); + Some(predicate_handle) + }; + + let result_handle = ctx + .expressions + .append(crate::Expression::SubgroupBallotResult, span); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + + 
block.push( + crate::Statement::SubgroupBallot { + result: result_handle, + predicate, + }, + span, + ); + emitter.start(ctx.expressions); + } + spirv::Op::GroupNonUniformAll + | spirv::Op::GroupNonUniformAny + | spirv::Op::GroupNonUniformIAdd + | spirv::Op::GroupNonUniformFAdd + | spirv::Op::GroupNonUniformIMul + | spirv::Op::GroupNonUniformFMul + | spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax + | spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin + | spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalAnd + | spirv::Op::GroupNonUniformLogicalOr + | spirv::Op::GroupNonUniformLogicalXor => { + block.extend(emitter.finish(ctx.expressions)); + inst.expect( + if matches!( + inst.op, + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny + ) { + 5 + } else { + 6 + }, + )?; + let result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let collective_op_id = match inst.op { + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => { + crate::CollectiveOperation::Reduce + } + _ => { + let group_op_id = self.next()?; + match spirv::GroupOperation::from_u32(group_op_id) { + Some(spirv::GroupOperation::Reduce) => { + crate::CollectiveOperation::Reduce + } + Some(spirv::GroupOperation::InclusiveScan) => { + crate::CollectiveOperation::InclusiveScan + } + Some(spirv::GroupOperation::ExclusiveScan) => { + crate::CollectiveOperation::ExclusiveScan + } + _ => return Err(Error::UnsupportedGroupOperation(group_op_id)), + } + } + }; + let argument_id = self.next()?; + + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let op_id = match inst.op { + spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All, + spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any, + spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => { + crate::SubgroupOperation::Add + } + spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => { + crate::SubgroupOperation::Mul + } + spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max, + spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min, + spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And, + spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or, + spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor, + _ => unreachable!(), + }; + + let result_type = self.lookup_type.lookup(result_type_id)?; + + let result_handle = ctx.expressions.append( + crate::Expression::SubgroupOperationResult { + ty: result_type.handle, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + + block.push( + 
crate::Statement::SubgroupCollectiveOperation { + result: result_handle, + op: op_id, + collective_op: collective_op_id, + argument: argument_handle, + }, + span, + ); + emitter.start(ctx.expressions); + } + Op::GroupNonUniformBroadcastFirst + | Op::GroupNonUniformBroadcast + | Op::GroupNonUniformShuffle + | Op::GroupNonUniformShuffleDown + | Op::GroupNonUniformShuffleUp + | Op::GroupNonUniformShuffleXor => { + inst.expect( + if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + 5 + } else { + 6 + }, + )?; + block.extend(emitter.finish(ctx.expressions)); + let result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let argument_id = self.next()?; + + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + crate::GatherMode::BroadcastFirst + } else { + let index_id = self.next()?; + let index_lookup = self.lookup_expression.lookup(index_id)?; + let index_handle = get_expr_handle!(index_id, index_lookup); + match inst.op { + spirv::Op::GroupNonUniformBroadcast => { + crate::GatherMode::Broadcast(index_handle) + } + spirv::Op::GroupNonUniformShuffle => { + crate::GatherMode::Shuffle(index_handle) + } + spirv::Op::GroupNonUniformShuffleDown => { + crate::GatherMode::ShuffleDown(index_handle) + } + spirv::Op::GroupNonUniformShuffleUp => { + crate::GatherMode::ShuffleUp(index_handle) + } + spirv::Op::GroupNonUniformShuffleXor => { + crate::GatherMode::ShuffleXor(index_handle) + } + _ => unreachable!(), + } + }; + + let result_type = self.lookup_type.lookup(result_type_id)?; + + let result_handle = ctx.expressions.append( + crate::Expression::SubgroupOperationResult { + ty: result_type.handle, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + + block.push( + crate::Statement::SubgroupGather { + result: result_handle, + mode, + argument: argument_handle, + }, + span, + ); + emitter.start(ctx.expressions); + } _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), } }; @@ -3811,7 +4059,10 @@ impl> Frontend { | S::Store { .. } | S::ImageStore { .. } | S::Atomic { .. } - | S::RayQuery { .. } => {} + | S::RayQuery { .. } + | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupGather { .. 
} => {} S::Call { function: ref mut callee, ref arguments, diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index a727d6379b..6965b7a2a2 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -834,6 +834,29 @@ impl Texture { } } +enum SubgroupGather { + BroadcastFirst, + Broadcast, + Shuffle, + ShuffleDown, + ShuffleUp, + ShuffleXor, +} + +impl SubgroupGather { + pub fn map(word: &str) -> Option { + Some(match word { + "subgroupBroadcastFirst" => Self::BroadcastFirst, + "subgroupBroadcast" => Self::Broadcast, + "subgroupShuffle" => Self::Shuffle, + "subgroupShuffleDown" => Self::ShuffleDown, + "subgroupShuffleUp" => Self::ShuffleUp, + "subgroupShuffleXor" => Self::ShuffleXor, + _ => return None, + }) + } +} + pub struct Lowerer<'source, 'temp> { index: &'temp Index<'source>, layouter: Layouter, @@ -1834,6 +1857,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } else if let Some(fun) = Texture::map(function.name) { self.texture_sample_helper(fun, arguments, span, ctx)? + } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { + return Ok(Some( + self.subgroup_operation_helper(span, op, cop, arguments, ctx)?, + )); + } else if let Some(mode) = SubgroupGather::map(function.name) { + return Ok(Some( + self.subgroup_gather_helper(span, mode, arguments, ctx)?, + )); } else { match function.name { "select" => { @@ -2001,6 +2032,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span); return Ok(None); } + "subgroupBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::Barrier(crate::Barrier::SUB_GROUP), span); + return Ok(None); + } "workgroupUniformLoad" => { let mut args = ctx.prepare_args(arguments, 1, span); let expr = args.next()?; @@ -2208,6 +2247,22 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; return Ok(Some(handle)); } + "subgroupBallot" => { + let mut args = ctx.prepare_args(arguments, 0, span); + let predicate = if arguments.len() == 1 { + Some(self.expression(args.next()?, ctx)?) 
+ } else { + None + }; + args.finish()?; + + let result = ctx + .interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::SubgroupBallot { result, predicate }, span); + return Ok(Some(result)); + } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; @@ -2399,6 +2454,80 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { }) } + fn subgroup_operation_helper( + &mut self, + span: Span, + op: crate::SubgroupOperation, + collective_op: crate::CollectiveOperation, + arguments: &[Handle>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + }, + span, + ); + Ok(result) + } + + fn subgroup_gather_helper( + &mut self, + span: Span, + mode: SubgroupGather, + arguments: &[Handle>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 2, span); + + let argument = self.expression(args.next()?, ctx)?; + + use SubgroupGather as Sg; + let mode = if let Sg::BroadcastFirst = mode { + crate::GatherMode::BroadcastFirst + } else { + let index = self.expression(args.next()?, ctx)?; + match mode { + Sg::Broadcast => crate::GatherMode::Broadcast(index), + Sg::Shuffle => crate::GatherMode::Shuffle(index), + Sg::ShuffleDown => crate::GatherMode::ShuffleDown(index), + Sg::ShuffleUp => crate::GatherMode::ShuffleUp(index), + Sg::ShuffleXor => crate::GatherMode::ShuffleXor(index), + Sg::BroadcastFirst => unreachable!(), + } + }; + + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupGather { + mode, + argument, + result, + }, + span, + ); + Ok(result) + } + fn r#struct( &mut self, s: &ast::Struct<'source>, diff --git a/naga/src/front/wgsl/parse/conv.rs b/naga/src/front/wgsl/parse/conv.rs index 08f1e39285..b2c563b7ce 100644 --- a/naga/src/front/wgsl/parse/conv.rs +++ b/naga/src/front/wgsl/parse/conv.rs @@ -35,6 +35,11 @@ pub fn map_built_in(word: &str, span: Span) -> Result> "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex, "workgroup_id" => crate::BuiltIn::WorkGroupId, "num_workgroups" => crate::BuiltIn::NumWorkGroups, + // subgroup + "num_subgroups" => crate::BuiltIn::NumSubgroups, + "subgroup_id" => crate::BuiltIn::SubgroupId, + "subgroup_size" => crate::BuiltIn::SubgroupSize, + "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, _ => return Err(Error::UnknownBuiltin(span)), }) } @@ -252,3 +257,26 @@ pub fn map_conservative_depth( _ => Err(Error::UnknownConservativeDepth(span)), } } + +pub fn map_subgroup_operation( + word: &str, +) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> { + use crate::CollectiveOperation as co; + use crate::SubgroupOperation as sg; + Some(match word { + "subgroupAll" => (sg::All, co::Reduce), + "subgroupAny" => (sg::Any, co::Reduce), + "subgroupAdd" => (sg::Add, co::Reduce), + "subgroupMul" => (sg::Mul, 
co::Reduce), + "subgroupMin" => (sg::Min, co::Reduce), + "subgroupMax" => (sg::Max, co::Reduce), + "subgroupAnd" => (sg::And, co::Reduce), + "subgroupOr" => (sg::Or, co::Reduce), + "subgroupXor" => (sg::Xor, co::Reduce), + "subgroupPrefixExclusiveAdd" => (sg::Add, co::ExclusiveScan), + "subgroupPrefixExclusiveMul" => (sg::Mul, co::ExclusiveScan), + "subgroupPrefixInclusiveAdd" => (sg::Add, co::InclusiveScan), + "subgroupPrefixInclusiveMul" => (sg::Mul, co::InclusiveScan), + _ => return None, + }) +} diff --git a/naga/src/lib.rs b/naga/src/lib.rs index e140ad6aef..5f57fa5e91 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -435,6 +435,11 @@ pub enum BuiltIn { WorkGroupId, WorkGroupSize, NumWorkGroups, + // subgroup + NumSubgroups, + SubgroupId, + SubgroupSize, + SubgroupInvocationId, } /// Number of bytes per scalar. @@ -1267,6 +1272,46 @@ pub enum SwizzleComponent { W = 3, } +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum GatherMode { + BroadcastFirst, + Broadcast(Handle<Expression>), + Shuffle(Handle<Expression>), + ShuffleDown(Handle<Expression>), + ShuffleUp(Handle<Expression>), + ShuffleXor(Handle<Expression>), +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum SubgroupOperation { + All = 0, + Any = 1, + Add = 2, + Mul = 3, + Min = 4, + Max = 5, + And = 6, + Or = 7, + Xor = 8, +} + +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum CollectiveOperation { + Reduce = 0, + InclusiveScan = 1, + ExclusiveScan = 2, +} + bitflags::bitflags! { /// Memory barrier flags. #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -1275,9 +1320,11 @@ bitflags::bitflags! { #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] pub struct Barrier: u32 { /// Barrier affects all `AddressSpace::Storage` accesses. - const STORAGE = 0x1; + const STORAGE = 1 << 0; /// Barrier affects all `AddressSpace::WorkGroup` accesses. - const WORK_GROUP = 0x2; + const WORK_GROUP = 1 << 1; + /// Barrier synchronizes execution across all invocations within a subgroup that execute this instruction. + const SUB_GROUP = 1 << 2; } } @@ -1576,6 +1623,15 @@ pub enum Expression { query: Handle<Expression>, committed: bool, }, + /// Result of a [`SubgroupBallot`] statement. + /// + /// [`SubgroupBallot`]: Statement::SubgroupBallot + SubgroupBallotResult, + /// Result of a [`SubgroupCollectiveOperation`] or [`SubgroupGather`] statement. + /// + /// [`SubgroupCollectiveOperation`]: Statement::SubgroupCollectiveOperation + /// [`SubgroupGather`]: Statement::SubgroupGather + SubgroupOperationResult { ty: Handle<Type> }, } pub use block::Block; @@ -1848,6 +1904,39 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, + /// Calculate a bitmask using a boolean from each active thread in the subgroup + SubgroupBallot { + /// The [`SubgroupBallotResult`] expression representing this statement's result.
+ /// + /// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult + result: Handle<Expression>, + /// The value from this thread to store in the ballot + predicate: Option<Handle<Expression>>, + }, + /// Gather a value from another active thread in the subgroup + SubgroupGather { + /// Specifies which thread to gather from + mode: GatherMode, + /// The value to broadcast over + argument: Handle<Expression>, + /// The [`SubgroupOperationResult`] expression representing this statement's result. + /// + /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult + result: Handle<Expression>, + }, + /// Compute a collective operation across all active threads in the subgroup + SubgroupCollectiveOperation { + /// What operation to compute + op: SubgroupOperation, + /// How to combine the results + collective_op: CollectiveOperation, + /// The value to compute over + argument: Handle<Expression>, + /// The [`SubgroupOperationResult`] expression representing this statement's result. + /// + /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult + result: Handle<Expression>, + }, } /// A function argument. diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index e3c07f9e16..cf1404b07a 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -133,6 +133,8 @@ pub enum ConstantEvaluatorError { ImageExpression, #[error("Constants don't support ray query expressions")] RayQueryExpression, + #[error("Constants don't support subgroup expressions")] + SubgroupExpression, #[error("Cannot access the type")] InvalidAccessBase, #[error("Cannot access at the index")] @@ -439,6 +441,12 @@ impl<'a> ConstantEvaluator<'a> { Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => { Err(ConstantEvaluatorError::RayQueryExpression) } + Expression::SubgroupBallotResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } + Expression::SubgroupOperationResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } } } diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index a5239d4eca..5edf55cb73 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -37,6 +37,9 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::RayQuery { .. } | S::Atomic { .. } | S::WorkGroupUniformLoad { .. } + | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupGather { .. } + | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index 9c4403445c..9df9fb218a 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -595,6 +595,7 @@ impl<'a> ResolveContext<'a> { | crate::BinaryOperator::ShiftRight => past(left)?.clone(), }, crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty), + crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty), crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty), crate::Expression::Select { accept, .. } => past(accept)?.clone(), crate::Expression::Derivative { expr, ..
} => past(expr)?.clone(), @@ -882,6 +883,10 @@ impl<'a> ResolveContext<'a> { .ok_or(ResolveError::MissingSpecialType)?; TypeResolution::Handle(result) } + crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector { + scalar: crate::Scalar::U32, + size: crate::VectorSize::Quad, + }), }) } } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index df6fc5e9b0..98259b4168 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -740,6 +740,14 @@ impl FunctionInfo { non_uniform_result: self.add_ref(query), requirements: UniformityRequirements::empty(), }, + E::SubgroupBallotResult => Uniformity { + non_uniform_result: Some(handle), + requirements: UniformityRequirements::empty(), + }, + E::SubgroupOperationResult { .. } => Uniformity { + non_uniform_result: Some(handle), + requirements: UniformityRequirements::empty(), + }, }; let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; @@ -982,6 +990,42 @@ impl FunctionInfo { } FunctionUniformity::new() } + S::SubgroupBallot { + result: _, + predicate, + } => { + if let Some(predicate) = predicate { + let _ = self.add_ref(predicate); + } + FunctionUniformity::new() + } + S::SubgroupCollectiveOperation { + op: _, + collective_op: _, + argument, + result: _, + } => { + let _ = self.add_ref(argument); + FunctionUniformity::new() + } + S::SubgroupGather { + mode, + argument, + result: _, + } => { + let _ = self.add_ref(argument); + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let _ = self.add_ref(index); + } + } + FunctionUniformity::new() + } }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 1f57c55441..d33e7e3c0f 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -1590,6 +1590,7 @@ impl super::Validator { return Err(ExpressionError::InvalidRayQueryType(query)); } }, + E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages, }; Ok(stages) } diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index f5da7d0764..2c3e8c8369 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -47,6 +47,17 @@ pub enum AtomicError { ResultTypeMismatch(Handle), } +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum SubgroupError { + #[error("Operand {0:?} has invalid type.")] + InvalidOperand(Handle), + #[error("Result type for {0:?} doesn't match the statement")] + ResultTypeMismatch(Handle), + #[error("Support for subgroup operation {0:?} is required")] + UnsupportedOperation(super::SubgroupOperationSet), +} + #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum LocalVariableError { @@ -135,6 +146,8 @@ pub enum FunctionError { InvalidRayDescriptor(Handle), #[error("Ray Query {0:?} does not have a matching type")] InvalidRayQueryType(Handle), + #[error("Shader requires capability {0:?}")] + MissingCapability(super::Capabilities), #[error( "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}" )] @@ -155,6 +168,8 @@ pub enum FunctionError { WorkgroupUniformLoadExpressionMismatch(Handle), #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. 
It should be a Pointer in Workgroup address space")] WorkgroupUniformLoadInvalidPointer(Handle), + #[error("Subgroup operation is invalid")] + InvalidSubgroup(#[from] SubgroupError), } bitflags::bitflags! { @@ -399,6 +414,108 @@ impl super::Validator { } Ok(()) } + fn validate_subgroup_operation( + &mut self, + op: &crate::SubgroupOperation, + _collective_op: &crate::CollectiveOperation, + argument: Handle, + result: Handle, + context: &BlockContext, + ) -> Result<(), WithSpan> { + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + + let (is_scalar, scalar) = match *argument_inner { + crate::TypeInner::Scalar(scalar) => (true, scalar), + crate::TypeInner::Vector { scalar, .. } => (false, scalar), + _ => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + }; + + use crate::ScalarKind as sk; + use crate::SubgroupOperation as sg; + match (scalar.kind, *op) { + (sk::Bool, sg::All | sg::Any) if is_scalar => {} + (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {} + (sk::Sint | sk::Uint | sk::Bool, sg::And | sg::Or | sg::Xor) => {} + + (_, sg::All | sg::Any) + | (sk::Bool, sg::Add | sg::Mul | sg::Min | sg::Max) + | (sk::Float, sg::And | sg::Or | sg::Xor) => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + }; + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } + fn validate_subgroup_broadcast( + &mut self, + mode: &crate::GatherMode, + argument: Handle, + result: Handle, + context: &BlockContext, + ) -> Result<(), WithSpan> { + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let index_ty = context.resolve_type(index, &self.valid_expression_set)?; + match *index_ty { + crate::TypeInner::Scalar(crate::Scalar::U32) => {} + _ => { + log::error!( + "Subgroup gather index type {:?}, expected unsigned int", + index_ty + ); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(index, context.expressions) + .into_other()); + } + } + } + } + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + if !matches!(*argument_inner, + crate::TypeInner::Scalar ( scalar, .. ) | crate::TypeInner::Vector { scalar, .. 
} + if matches!(scalar.kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float) + ) { + log::error!("Subgroup gather operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } fn validate_block_impl( &mut self, @@ -613,8 +730,27 @@ impl super::Validator { stages &= super::ShaderStages::FRAGMENT; finished = true; } - S::Barrier(_) => { + S::Barrier(barrier) => { stages &= super::ShaderStages::COMPUTE; + if barrier.contains(crate::Barrier::SUB_GROUP) { + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + if !self + .subgroup_operations + .contains(super::SubgroupOperationSet::BASIC) + { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation( + super::SubgroupOperationSet::BASIC, + ), + ) + .with_span_static(span, "support for this operation is not present")); + } + } } S::Store { pointer, value } => { let mut current = pointer; @@ -904,6 +1040,86 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } + S::SubgroupBallot { result, predicate } => { + stages &= self.subgroup_stages; + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + if !self + .subgroup_operations + .contains(super::SubgroupOperationSet::BALLOT) + { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation( + super::SubgroupOperationSet::BALLOT, + ), + ) + .with_span_static(span, "support for this operation is not present")); + } + if let Some(predicate) = predicate { + let predicate_inner = + context.resolve_type(predicate, &self.valid_expression_set)?; + if !matches!( + *predicate_inner, + crate::TypeInner::Scalar(crate::Scalar::BOOL,) + ) { + log::error!( + "Subgroup ballot predicate type {:?} expected bool", + predicate_inner + ); + return Err(SubgroupError::InvalidOperand(predicate) + .with_span_handle(predicate, context.expressions) + .into_other()); + } + } + self.emit_expression(result, context)?; + } + S::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + stages &= self.subgroup_stages; + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + let operation = op.required_operations(); + if !self.subgroup_operations.contains(operation) { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation(operation), + ) + .with_span_static(span, "support for this operation is not present")); + } + self.validate_subgroup_operation(op, collective_op, argument, result, context)?; + } + S::SubgroupGather { + ref mode, + argument, + result, + } => { + stages &= self.subgroup_stages; + if 
!self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + let operation = mode.required_operations(); + if !self.subgroup_operations.contains(operation) { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation(operation), + ) + .with_span_static(span, "support for this operation is not present")); + } + self.validate_subgroup_broadcast(mode, argument, result, context)?; + } } } Ok(BlockInfo { stages, finished }) diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index e482f293bb..904d4d6154 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -390,6 +390,8 @@ impl super::Validator { } crate::Expression::AtomicResult { .. } | crate::Expression::RayQueryProceedResult + | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } => (), crate::Expression::ArrayLength(array) => { handle.check_dep(array)?; @@ -535,6 +537,38 @@ impl super::Validator { } Ok(()) } + crate::Statement::SubgroupBallot { result, predicate } => { + validate_expr_opt(predicate)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::SubgroupCollectiveOperation { + op: _, + collective_op: _, + argument, + result, + } => { + validate_expr(argument)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::SubgroupGather { + mode, + argument, + result, + } => { + validate_expr(argument)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => validate_expr(index)?, + } + validate_expr(result)?; + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 57863048e5..2699589111 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -162,6 +162,10 @@ impl VaryingContext<'_> { Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX, Bi::ViewIndex => Capabilities::MULTIVIEW, Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING, + Bi::NumSubgroups + | Bi::SubgroupId + | Bi::SubgroupSize + | Bi::SubgroupInvocationId => Capabilities::SUBGROUP, _ => Capabilities::empty(), }; if !self.capabilities.contains(required) { @@ -249,6 +253,17 @@ impl VaryingContext<'_> { scalar: crate::Scalar::U32, }, ), + Bi::NumSubgroups | Bi::SubgroupId => ( + self.stage == St::Compute && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::SubgroupSize | Bi::SubgroupInvocationId => ( + match self.stage { + St::Compute | St::Fragment => !self.output, + St::Vertex => false, + }, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), }; if !visible { diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 70a4d39d2a..7a4623b3b4 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -108,6 +108,8 @@ bitflags::bitflags! { const DUAL_SOURCE_BLENDING = 0x2000; /// Support for arrayed cube textures. const CUBE_ARRAY_TEXTURES = 0x4000; + /// Support for subgroup operations + const SUBGROUP = 0x8000; } } @@ -117,6 +119,57 @@ impl Default for Capabilities { } } +bitflags::bitflags! 
{ + /// Supported subgroup operations + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] + pub struct SubgroupOperationSet: u8 { + /// Elect, Barrier + const BASIC = 1 << 0; + /// Any, All + const VOTE = 1 << 1; + /// reductions, scans + const ARITHMETIC = 1 << 2; + /// ballot, broadcast + const BALLOT = 1 << 3; + /// shuffle, shuffle xor + const SHUFFLE = 1 << 4; + /// shuffle up, down + const SHUFFLE_RELATIVE = 1 << 5; + // We don't support these operations yet + // /// Clustered + // const CLUSTERED = 1 << 6; + // /// Quad supported + // const QUAD_FRAMENT_COMPUTE = 1 << 7; + // /// Quad supported in all stages + // const QUAD_ALL_STAGES = 1 << 8; + } +} + +impl super::SubgroupOperation { + const fn required_operations(&self) -> SubgroupOperationSet { + use SubgroupOperationSet as S; + match *self { + Self::All | Self::Any => S::VOTE, + Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => { + S::ARITHMETIC + } + } + } +} + +impl super::GatherMode { + const fn required_operations(&self) -> SubgroupOperationSet { + use SubgroupOperationSet as S; + match *self { + Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT, + Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE, + Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE, + } + } +} + bitflags::bitflags! { /// Validation flags. #[cfg_attr(feature = "serialize", derive(serde::Serialize))] @@ -164,6 +217,8 @@ impl ops::Index> for ModuleInfo { pub struct Validator { flags: ValidationFlags, capabilities: Capabilities, + subgroup_stages: ShaderStages, + subgroup_operations: SubgroupOperationSet, types: Vec, layouter: Layouter, location_mask: BitSet, @@ -284,6 +339,8 @@ impl Validator { Validator { flags, capabilities, + subgroup_stages: ShaderStages::empty(), + subgroup_operations: SubgroupOperationSet::empty(), types: Vec::new(), layouter: Layouter::default(), location_mask: BitSet::new(), @@ -294,6 +351,16 @@ impl Validator { } } + pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self { + self.subgroup_stages = stages; + self + } + + pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self { + self.subgroup_operations = operations; + self + } + /// Reset the validator internals pub fn reset(&mut self) { self.types.clear(); diff --git a/naga/tests/in/spv/subgroup-operations-s.param.ron b/naga/tests/in/spv/subgroup-operations-s.param.ron new file mode 100644 index 0000000000..122542d1f6 --- /dev/null +++ b/naga/tests/in/spv/subgroup-operations-s.param.ron @@ -0,0 +1,27 @@ +( + god_mode: true, + spv: ( + version: (1, 3), + ), + msl: ( + lang_version: (2, 4), + per_entry_point_map: {}, + inline_samplers: [], + spirv_cross_compatibility: false, + fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, + ), + glsl: ( + version: Desktop(430), + writer_flags: (""), + binding_map: { }, + zero_initialize_workgroup_memory: true, + ), + hlsl: ( + shader_model: V6_0, + binding_map: {}, + fake_missing_bindings: true, + special_constants_binding: None, + zero_initialize_workgroup_memory: true, + ), +) diff --git a/naga/tests/in/spv/subgroup-operations-s.spv b/naga/tests/in/spv/subgroup-operations-s.spv new file mode 100644 index 0000000000..d4bf0191db Binary files /dev/null and b/naga/tests/in/spv/subgroup-operations-s.spv differ diff --git a/naga/tests/in/spv/subgroup-operations-s.spvasm 
b/naga/tests/in/spv/subgroup-operations-s.spvasm new file mode 100644 index 0000000000..72c68aa46c --- /dev/null +++ b/naga/tests/in/spv/subgroup-operations-s.spvasm @@ -0,0 +1,75 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 54 +OpCapability Shader +OpCapability GroupNonUniform +OpCapability GroupNonUniformBallot +OpCapability GroupNonUniformVote +OpCapability GroupNonUniformArithmetic +OpCapability GroupNonUniformShuffle +OpCapability GroupNonUniformShuffleRelative +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %15 "main" %6 %9 %11 %13 +OpExecutionMode %15 LocalSize 1 1 1 +OpDecorate %6 BuiltIn NumSubgroups +OpDecorate %9 BuiltIn SubgroupId +OpDecorate %11 BuiltIn SubgroupSize +OpDecorate %13 BuiltIn SubgroupLocalInvocationId +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypeBool +%7 = OpTypePointer Input %3 +%6 = OpVariable %7 Input +%9 = OpVariable %7 Input +%11 = OpVariable %7 Input +%13 = OpVariable %7 Input +%16 = OpTypeFunction %2 +%17 = OpConstant %3 1 +%18 = OpConstant %3 0 +%19 = OpConstant %3 4 +%21 = OpConstant %3 3 +%22 = OpConstant %3 2 +%23 = OpConstant %3 8 +%26 = OpTypeVector %3 4 +%28 = OpConstantTrue %4 +%15 = OpFunction %2 None %16 +%5 = OpLabel +%8 = OpLoad %3 %6 +%10 = OpLoad %3 %9 +%12 = OpLoad %3 %11 +%14 = OpLoad %3 %13 +OpBranch %20 +%20 = OpLabel +OpControlBarrier %21 %22 %23 +%24 = OpBitwiseAnd %3 %14 %17 +%25 = OpIEqual %4 %24 %17 +%27 = OpGroupNonUniformBallot %26 %21 %25 +%29 = OpGroupNonUniformBallot %26 %21 %28 +%30 = OpINotEqual %4 %14 %18 +%31 = OpGroupNonUniformAll %4 %21 %30 +%32 = OpIEqual %4 %14 %18 +%33 = OpGroupNonUniformAny %4 %21 %32 +%34 = OpGroupNonUniformIAdd %3 %21 Reduce %14 +%35 = OpGroupNonUniformIMul %3 %21 Reduce %14 +%36 = OpGroupNonUniformUMin %3 %21 Reduce %14 +%37 = OpGroupNonUniformUMax %3 %21 Reduce %14 +%38 = OpGroupNonUniformBitwiseAnd %3 %21 Reduce %14 +%39 = OpGroupNonUniformBitwiseOr %3 %21 Reduce %14 +%40 = OpGroupNonUniformBitwiseXor %3 %21 Reduce %14 +%41 = OpGroupNonUniformIAdd %3 %21 ExclusiveScan %14 +%42 = OpGroupNonUniformIMul %3 %21 ExclusiveScan %14 +%43 = OpGroupNonUniformIAdd %3 %21 InclusiveScan %14 +%44 = OpGroupNonUniformIMul %3 %21 InclusiveScan %14 +%45 = OpGroupNonUniformBroadcastFirst %3 %21 %14 +%46 = OpGroupNonUniformBroadcast %3 %21 %14 %19 +%47 = OpISub %3 %12 %17 +%48 = OpISub %3 %47 %14 +%49 = OpGroupNonUniformShuffle %3 %21 %14 %48 +%50 = OpGroupNonUniformShuffleDown %3 %21 %14 %17 +%51 = OpGroupNonUniformShuffleUp %3 %21 %14 %17 +%52 = OpISub %3 %12 %17 +%53 = OpGroupNonUniformShuffleXor %3 %21 %14 %52 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/in/subgroup-operations.param.ron b/naga/tests/in/subgroup-operations.param.ron new file mode 100644 index 0000000000..122542d1f6 --- /dev/null +++ b/naga/tests/in/subgroup-operations.param.ron @@ -0,0 +1,27 @@ +( + god_mode: true, + spv: ( + version: (1, 3), + ), + msl: ( + lang_version: (2, 4), + per_entry_point_map: {}, + inline_samplers: [], + spirv_cross_compatibility: false, + fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, + ), + glsl: ( + version: Desktop(430), + writer_flags: (""), + binding_map: { }, + zero_initialize_workgroup_memory: true, + ), + hlsl: ( + shader_model: V6_0, + binding_map: {}, + fake_missing_bindings: true, + special_constants_binding: None, + zero_initialize_workgroup_memory: true, + ), +) diff --git a/naga/tests/in/subgroup-operations.wgsl b/naga/tests/in/subgroup-operations.wgsl new file mode 100644 index 
0000000000..4239be114f --- /dev/null +++ b/naga/tests/in/subgroup-operations.wgsl @@ -0,0 +1,33 @@ +@compute @workgroup_size(1) +fn main( + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, +) { + subgroupBarrier(); + + subgroupBallot((subgroup_invocation_id & 1u) == 1u); + subgroupBallot(); + + subgroupAll(subgroup_invocation_id != 0u); + subgroupAny(subgroup_invocation_id == 0u); + subgroupAdd(subgroup_invocation_id); + subgroupMul(subgroup_invocation_id); + subgroupMin(subgroup_invocation_id); + subgroupMax(subgroup_invocation_id); + subgroupAnd(subgroup_invocation_id); + subgroupOr(subgroup_invocation_id); + subgroupXor(subgroup_invocation_id); + subgroupPrefixExclusiveAdd(subgroup_invocation_id); + subgroupPrefixExclusiveMul(subgroup_invocation_id); + subgroupPrefixInclusiveAdd(subgroup_invocation_id); + subgroupPrefixInclusiveMul(subgroup_invocation_id); + + subgroupBroadcastFirst(subgroup_invocation_id); + subgroupBroadcast(subgroup_invocation_id, 4u); + subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id); + subgroupShuffleDown(subgroup_invocation_id, 1u); + subgroupShuffleUp(subgroup_invocation_id, 1u); + subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u); +} diff --git a/naga/tests/out/glsl/subgroup-operations-s.main.Compute.glsl b/naga/tests/out/glsl/subgroup-operations-s.main.Compute.glsl new file mode 100644 index 0000000000..cc1aac5417 --- /dev/null +++ b/naga/tests/out/glsl/subgroup-operations-s.main.Compute.glsl @@ -0,0 +1,58 @@ +#version 430 core +#extension GL_ARB_compute_shader : require +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_vote : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_shader_subgroup_ballot : require +#extension GL_KHR_shader_subgroup_shuffle : require +#extension GL_KHR_shader_subgroup_shuffle_relative : require +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + +uint num_subgroups_1 = 0u; + +uint subgroup_id_1 = 0u; + +uint subgroup_size_1 = 0u; + +uint subgroup_invocation_id_1 = 0u; + + +void main_1() { + uint _e5 = subgroup_size_1; + uint _e6 = subgroup_invocation_id_1; + uvec4 _e9 = subgroupBallot(((_e6 & 1u) == 1u)); + uvec4 _e10 = subgroupBallot(true); + bool _e12 = subgroupAll((_e6 != 0u)); + bool _e14 = subgroupAny((_e6 == 0u)); + uint _e15 = subgroupAdd(_e6); + uint _e16 = subgroupMul(_e6); + uint _e17 = subgroupMin(_e6); + uint _e18 = subgroupMax(_e6); + uint _e19 = subgroupAnd(_e6); + uint _e20 = subgroupOr(_e6); + uint _e21 = subgroupXor(_e6); + uint _e22 = subgroupExclusiveAdd(_e6); + uint _e23 = subgroupExclusiveMul(_e6); + uint _e24 = subgroupInclusiveAdd(_e6); + uint _e25 = subgroupInclusiveMul(_e6); + uint _e26 = subgroupBroadcastFirst(_e6); + uint _e27 = subgroupBroadcast(_e6, 4u); + uint _e30 = subgroupShuffle(_e6, ((_e5 - 1u) - _e6)); + uint _e31 = subgroupShuffleDown(_e6, 1u); + uint _e32 = subgroupShuffleUp(_e6, 1u); + uint _e34 = subgroupShuffleXor(_e6, (_e5 - 1u)); + return; +} + +void main() { + uint num_subgroups = gl_NumSubgroups; + uint subgroup_id = gl_SubgroupID; + uint subgroup_size = gl_SubgroupSize; + uint subgroup_invocation_id = gl_SubgroupInvocationID; + num_subgroups_1 = num_subgroups; + subgroup_id_1 = subgroup_id; + subgroup_size_1 = subgroup_size; + subgroup_invocation_id_1 = subgroup_invocation_id; + main_1(); +} + diff --git 
a/naga/tests/out/glsl/subgroup-operations.main.Compute.glsl b/naga/tests/out/glsl/subgroup-operations.main.Compute.glsl new file mode 100644 index 0000000000..9a92460a89 --- /dev/null +++ b/naga/tests/out/glsl/subgroup-operations.main.Compute.glsl @@ -0,0 +1,42 @@ +#version 430 core +#extension GL_ARB_compute_shader : require +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_vote : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_shader_subgroup_ballot : require +#extension GL_KHR_shader_subgroup_shuffle : require +#extension GL_KHR_shader_subgroup_shuffle_relative : require +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + +void main() { + uint num_subgroups = gl_NumSubgroups; + uint subgroup_id = gl_SubgroupID; + uint subgroup_size = gl_SubgroupSize; + uint subgroup_invocation_id = gl_SubgroupInvocationID; + subgroupMemoryBarrier(); + barrier(); + uvec4 _e8 = subgroupBallot(((subgroup_invocation_id & 1u) == 1u)); + uvec4 _e9 = subgroupBallot(true); + bool _e12 = subgroupAll((subgroup_invocation_id != 0u)); + bool _e15 = subgroupAny((subgroup_invocation_id == 0u)); + uint _e16 = subgroupAdd(subgroup_invocation_id); + uint _e17 = subgroupMul(subgroup_invocation_id); + uint _e18 = subgroupMin(subgroup_invocation_id); + uint _e19 = subgroupMax(subgroup_invocation_id); + uint _e20 = subgroupAnd(subgroup_invocation_id); + uint _e21 = subgroupOr(subgroup_invocation_id); + uint _e22 = subgroupXor(subgroup_invocation_id); + uint _e23 = subgroupExclusiveAdd(subgroup_invocation_id); + uint _e24 = subgroupExclusiveMul(subgroup_invocation_id); + uint _e25 = subgroupInclusiveAdd(subgroup_invocation_id); + uint _e26 = subgroupInclusiveMul(subgroup_invocation_id); + uint _e27 = subgroupBroadcastFirst(subgroup_invocation_id); + uint _e29 = subgroupBroadcast(subgroup_invocation_id, 4u); + uint _e33 = subgroupShuffle(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + uint _e35 = subgroupShuffleDown(subgroup_invocation_id, 1u); + uint _e37 = subgroupShuffleUp(subgroup_invocation_id, 1u); + uint _e40 = subgroupShuffleXor(subgroup_invocation_id, (subgroup_size - 1u)); + return; +} + diff --git a/naga/tests/out/hlsl/subgroup-operations-s.hlsl b/naga/tests/out/hlsl/subgroup-operations-s.hlsl new file mode 100644 index 0000000000..e1e399e213 --- /dev/null +++ b/naga/tests/out/hlsl/subgroup-operations-s.hlsl @@ -0,0 +1,49 @@ +static uint num_subgroups_1 = (uint)0; +static uint subgroup_id_1 = (uint)0; +static uint subgroup_size_1 = (uint)0; +static uint subgroup_invocation_id_1 = (uint)0; + +void main_1() +{ + uint _expr5 = subgroup_size_1; + uint _expr6 = subgroup_invocation_id_1; + const uint4 _e9 = WaveActiveBallot(((_expr6 & 1u) == 1u)); + const uint4 _e10 = WaveActiveBallot(true); + const bool _e12 = WaveActiveAllTrue((_expr6 != 0u)); + const bool _e14 = WaveActiveAnyTrue((_expr6 == 0u)); + const uint _e15 = WaveActiveSum(_expr6); + const uint _e16 = WaveActiveProduct(_expr6); + const uint _e17 = WaveActiveMin(_expr6); + const uint _e18 = WaveActiveMax(_expr6); + const uint _e19 = WaveActiveBitAnd(_expr6); + const uint _e20 = WaveActiveBitOr(_expr6); + const uint _e21 = WaveActiveBitXor(_expr6); + const uint _e22 = WavePrefixSum(_expr6); + const uint _e23 = WavePrefixProduct(_expr6); + const uint _e24 = _expr6 + WavePrefixSum(_expr6); + const uint _e25 = _expr6 * WavePrefixProduct(_expr6); + const uint _e26 = WaveReadLaneFirst(_expr6); + const uint _e27 = WaveReadLaneAt(_expr6, 4u); + const uint 
_e30 = WaveReadLaneAt(_expr6, ((_expr5 - 1u) - _expr6)); + const uint _e31 = WaveReadLaneAt(_expr6, WaveGetLaneIndex() + 1u); + const uint _e32 = WaveReadLaneAt(_expr6, WaveGetLaneIndex() - 1u); + const uint _e34 = WaveReadLaneAt(_expr6, WaveGetLaneIndex() ^ (_expr5 - 1u)); + return; +} + +[numthreads(1, 1, 1)] +void main(uint3 __local_invocation_id : SV_GroupThreadID) +{ + if (all(__local_invocation_id == uint3(0u, 0u, 0u))) { + } + GroupMemoryBarrierWithGroupSync(); + const uint num_subgroups = (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount(); + const uint subgroup_id = (__local_invocation_id.z * 1u + __local_invocation_id.y * 1u + __local_invocation_id.x) / WaveGetLaneCount(); + const uint subgroup_size = WaveGetLaneCount(); + const uint subgroup_invocation_id = WaveGetLaneIndex(); + num_subgroups_1 = num_subgroups; + subgroup_id_1 = subgroup_id; + subgroup_size_1 = subgroup_size; + subgroup_invocation_id_1 = subgroup_invocation_id; + main_1(); +} diff --git a/naga/tests/out/hlsl/subgroup-operations-s.ron b/naga/tests/out/hlsl/subgroup-operations-s.ron new file mode 100644 index 0000000000..b973fe3da1 --- /dev/null +++ b/naga/tests/out/hlsl/subgroup-operations-s.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_6_0", + ), + ], +) diff --git a/naga/tests/out/hlsl/subgroup-operations.hlsl b/naga/tests/out/hlsl/subgroup-operations.hlsl new file mode 100644 index 0000000000..a79fd8a38d --- /dev/null +++ b/naga/tests/out/hlsl/subgroup-operations.hlsl @@ -0,0 +1,33 @@ +[numthreads(1, 1, 1)] +void main(uint3 __local_invocation_id : SV_GroupThreadID) +{ + if (all(__local_invocation_id == uint3(0u, 0u, 0u))) { + } + GroupMemoryBarrierWithGroupSync(); + const uint num_subgroups = (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount(); + const uint subgroup_id = (__local_invocation_id.z * 1u + __local_invocation_id.y * 1u + __local_invocation_id.x) / WaveGetLaneCount(); + const uint subgroup_size = WaveGetLaneCount(); + const uint subgroup_invocation_id = WaveGetLaneIndex(); + const uint4 _e8 = WaveActiveBallot(((subgroup_invocation_id & 1u) == 1u)); + const uint4 _e9 = WaveActiveBallot(true); + const bool _e12 = WaveActiveAllTrue((subgroup_invocation_id != 0u)); + const bool _e15 = WaveActiveAnyTrue((subgroup_invocation_id == 0u)); + const uint _e16 = WaveActiveSum(subgroup_invocation_id); + const uint _e17 = WaveActiveProduct(subgroup_invocation_id); + const uint _e18 = WaveActiveMin(subgroup_invocation_id); + const uint _e19 = WaveActiveMax(subgroup_invocation_id); + const uint _e20 = WaveActiveBitAnd(subgroup_invocation_id); + const uint _e21 = WaveActiveBitOr(subgroup_invocation_id); + const uint _e22 = WaveActiveBitXor(subgroup_invocation_id); + const uint _e23 = WavePrefixSum(subgroup_invocation_id); + const uint _e24 = WavePrefixProduct(subgroup_invocation_id); + const uint _e25 = subgroup_invocation_id + WavePrefixSum(subgroup_invocation_id); + const uint _e26 = subgroup_invocation_id * WavePrefixProduct(subgroup_invocation_id); + const uint _e27 = WaveReadLaneFirst(subgroup_invocation_id); + const uint _e29 = WaveReadLaneAt(subgroup_invocation_id, 4u); + const uint _e33 = WaveReadLaneAt(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + const uint _e35 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() + 1u); + const uint _e37 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() - 1u); + const uint _e40 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() ^ (subgroup_size - 
1u)); + return; +} diff --git a/naga/tests/out/hlsl/subgroup-operations.ron b/naga/tests/out/hlsl/subgroup-operations.ron new file mode 100644 index 0000000000..b973fe3da1 --- /dev/null +++ b/naga/tests/out/hlsl/subgroup-operations.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_6_0", + ), + ], +) diff --git a/naga/tests/out/msl/subgroup-operations-s.msl b/naga/tests/out/msl/subgroup-operations-s.msl new file mode 100644 index 0000000000..3a6f30231c --- /dev/null +++ b/naga/tests/out/msl/subgroup-operations-s.msl @@ -0,0 +1,55 @@ +// language: metal2.4 +#include +#include + +using metal::uint; + + +void main_1( + thread uint& subgroup_size_1, + thread uint& subgroup_invocation_id_1 +) { + uint _e5 = subgroup_size_1; + uint _e6 = subgroup_invocation_id_1; + metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((_e6 & 1u) == 1u), 0, 0, 0); + metal::uint4 unnamed_1 = uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0); + bool unnamed_2 = metal::simd_all(_e6 != 0u); + bool unnamed_3 = metal::simd_any(_e6 == 0u); + uint unnamed_4 = metal::simd_sum(_e6); + uint unnamed_5 = metal::simd_product(_e6); + uint unnamed_6 = metal::simd_min(_e6); + uint unnamed_7 = metal::simd_max(_e6); + uint unnamed_8 = metal::simd_and(_e6); + uint unnamed_9 = metal::simd_or(_e6); + uint unnamed_10 = metal::simd_xor(_e6); + uint unnamed_11 = metal::simd_prefix_exclusive_sum(_e6); + uint unnamed_12 = metal::simd_prefix_exclusive_product(_e6); + uint unnamed_13 = metal::simd_prefix_inclusive_sum(_e6); + uint unnamed_14 = metal::simd_prefix_inclusive_product(_e6); + uint unnamed_15 = metal::simd_broadcast_first(_e6); + uint unnamed_16 = metal::simd_broadcast(_e6, 4u); + uint unnamed_17 = metal::simd_shuffle(_e6, (_e5 - 1u) - _e6); + uint unnamed_18 = metal::simd_shuffle_down(_e6, 1u); + uint unnamed_19 = metal::simd_shuffle_up(_e6, 1u); + uint unnamed_20 = metal::simd_shuffle_xor(_e6, _e5 - 1u); + return; +} + +struct main_Input { +}; +kernel void main_( + uint num_subgroups [[simdgroups_per_threadgroup]] +, uint subgroup_id [[simdgroup_index_in_threadgroup]] +, uint subgroup_size [[threads_per_simdgroup]] +, uint subgroup_invocation_id [[thread_index_in_simdgroup]] +) { + uint num_subgroups_1 = {}; + uint subgroup_id_1 = {}; + uint subgroup_size_1 = {}; + uint subgroup_invocation_id_1 = {}; + num_subgroups_1 = num_subgroups; + subgroup_id_1 = subgroup_id; + subgroup_size_1 = subgroup_size; + subgroup_invocation_id_1 = subgroup_invocation_id; + main_1(subgroup_size_1, subgroup_invocation_id_1); +} diff --git a/naga/tests/out/msl/subgroup-operations.msl b/naga/tests/out/msl/subgroup-operations.msl new file mode 100644 index 0000000000..fe41696892 --- /dev/null +++ b/naga/tests/out/msl/subgroup-operations.msl @@ -0,0 +1,39 @@ +// language: metal2.4 +#include +#include + +using metal::uint; + + +struct main_Input { +}; +kernel void main_( + uint num_subgroups [[simdgroups_per_threadgroup]] +, uint subgroup_id [[simdgroup_index_in_threadgroup]] +, uint subgroup_size [[threads_per_simdgroup]] +, uint subgroup_invocation_id [[thread_index_in_simdgroup]] +) { + metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup); + metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((subgroup_invocation_id & 1u) == 1u), 0, 0, 0); + metal::uint4 unnamed_1 = uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0); + bool unnamed_2 = metal::simd_all(subgroup_invocation_id != 0u); + bool unnamed_3 = metal::simd_any(subgroup_invocation_id == 0u); + uint unnamed_4 = 
metal::simd_sum(subgroup_invocation_id); + uint unnamed_5 = metal::simd_product(subgroup_invocation_id); + uint unnamed_6 = metal::simd_min(subgroup_invocation_id); + uint unnamed_7 = metal::simd_max(subgroup_invocation_id); + uint unnamed_8 = metal::simd_and(subgroup_invocation_id); + uint unnamed_9 = metal::simd_or(subgroup_invocation_id); + uint unnamed_10 = metal::simd_xor(subgroup_invocation_id); + uint unnamed_11 = metal::simd_prefix_exclusive_sum(subgroup_invocation_id); + uint unnamed_12 = metal::simd_prefix_exclusive_product(subgroup_invocation_id); + uint unnamed_13 = metal::simd_prefix_inclusive_sum(subgroup_invocation_id); + uint unnamed_14 = metal::simd_prefix_inclusive_product(subgroup_invocation_id); + uint unnamed_15 = metal::simd_broadcast_first(subgroup_invocation_id); + uint unnamed_16 = metal::simd_broadcast(subgroup_invocation_id, 4u); + uint unnamed_17 = metal::simd_shuffle(subgroup_invocation_id, (subgroup_size - 1u) - subgroup_invocation_id); + uint unnamed_18 = metal::simd_shuffle_down(subgroup_invocation_id, 1u); + uint unnamed_19 = metal::simd_shuffle_up(subgroup_invocation_id, 1u); + uint unnamed_20 = metal::simd_shuffle_xor(subgroup_invocation_id, subgroup_size - 1u); + return; +} diff --git a/naga/tests/out/spv/subgroup-operations.spvasm b/naga/tests/out/spv/subgroup-operations.spvasm new file mode 100644 index 0000000000..73d3d52c61 --- /dev/null +++ b/naga/tests/out/spv/subgroup-operations.spvasm @@ -0,0 +1,75 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 54 +OpCapability Shader +OpCapability GroupNonUniform +OpCapability GroupNonUniformBallot +OpCapability GroupNonUniformVote +OpCapability GroupNonUniformArithmetic +OpCapability GroupNonUniformShuffle +OpCapability GroupNonUniformShuffleRelative +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %15 "main" %6 %9 %11 %13 +OpExecutionMode %15 LocalSize 1 1 1 +OpDecorate %6 BuiltIn NumSubgroups +OpDecorate %9 BuiltIn SubgroupId +OpDecorate %11 BuiltIn SubgroupSize +OpDecorate %13 BuiltIn SubgroupLocalInvocationId +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypeBool +%7 = OpTypePointer Input %3 +%6 = OpVariable %7 Input +%9 = OpVariable %7 Input +%11 = OpVariable %7 Input +%13 = OpVariable %7 Input +%16 = OpTypeFunction %2 +%17 = OpConstant %3 1 +%18 = OpConstant %3 0 +%19 = OpConstant %3 4 +%21 = OpConstant %3 3 +%22 = OpConstant %3 2 +%23 = OpConstant %3 8 +%26 = OpTypeVector %3 4 +%28 = OpConstantTrue %4 +%15 = OpFunction %2 None %16 +%5 = OpLabel +%8 = OpLoad %3 %6 +%10 = OpLoad %3 %9 +%12 = OpLoad %3 %11 +%14 = OpLoad %3 %13 +OpBranch %20 +%20 = OpLabel +OpControlBarrier %21 %22 %23 +%24 = OpBitwiseAnd %3 %14 %17 +%25 = OpIEqual %4 %24 %17 +%27 = OpGroupNonUniformBallot %26 %21 %25 +%29 = OpGroupNonUniformBallot %26 %21 %28 +%30 = OpINotEqual %4 %14 %18 +%31 = OpGroupNonUniformAll %4 %21 %30 +%32 = OpIEqual %4 %14 %18 +%33 = OpGroupNonUniformAny %4 %21 %32 +%34 = OpGroupNonUniformIAdd %3 %21 Reduce %14 +%35 = OpGroupNonUniformIMul %3 %21 Reduce %14 +%36 = OpGroupNonUniformUMin %3 %21 Reduce %14 +%37 = OpGroupNonUniformUMax %3 %21 Reduce %14 +%38 = OpGroupNonUniformBitwiseAnd %3 %21 Reduce %14 +%39 = OpGroupNonUniformBitwiseOr %3 %21 Reduce %14 +%40 = OpGroupNonUniformBitwiseXor %3 %21 Reduce %14 +%41 = OpGroupNonUniformIAdd %3 %21 ExclusiveScan %14 +%42 = OpGroupNonUniformIMul %3 %21 ExclusiveScan %14 +%43 = OpGroupNonUniformIAdd %3 %21 InclusiveScan %14 +%44 = OpGroupNonUniformIMul %3 %21 InclusiveScan %14 +%45 = OpGroupNonUniformBroadcastFirst %3 
%21 %14 +%46 = OpGroupNonUniformShuffle %3 %21 %14 %19 +%47 = OpISub %3 %12 %17 +%48 = OpISub %3 %47 %14 +%49 = OpGroupNonUniformShuffle %3 %21 %14 %48 +%50 = OpGroupNonUniformShuffleDown %3 %21 %14 %17 +%51 = OpGroupNonUniformShuffleUp %3 %21 %14 %17 +%52 = OpISub %3 %12 %17 +%53 = OpGroupNonUniformShuffleXor %3 %21 %14 %52 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/subgroup-operations-s.wgsl b/naga/tests/out/wgsl/subgroup-operations-s.wgsl new file mode 100644 index 0000000000..54e3d60b3a --- /dev/null +++ b/naga/tests/out/wgsl/subgroup-operations-s.wgsl @@ -0,0 +1,40 @@ +var num_subgroups_1: u32; +var subgroup_id_1: u32; +var subgroup_size_1: u32; +var subgroup_invocation_id_1: u32; + +fn main_1() { + let _e5 = subgroup_size_1; + let _e6 = subgroup_invocation_id_1; + let _e9 = subgroupBallot(((_e6 & 1u) == 1u)); + let _e10 = subgroupBallot(); + let _e12 = subgroupAll((_e6 != 0u)); + let _e14 = subgroupAny((_e6 == 0u)); + let _e15 = subgroupAdd(_e6); + let _e16 = subgroupMul(_e6); + let _e17 = subgroupMin(_e6); + let _e18 = subgroupMax(_e6); + let _e19 = subgroupAnd(_e6); + let _e20 = subgroupOr(_e6); + let _e21 = subgroupXor(_e6); + let _e22 = subgroupPrefixExclusiveAdd(_e6); + let _e23 = subgroupPrefixExclusiveMul(_e6); + let _e24 = subgroupPrefixInclusiveAdd(_e6); + let _e25 = subgroupPrefixInclusiveMul(_e6); + let _e26 = subgroupBroadcastFirst(_e6); + let _e27 = subgroupBroadcast(_e6, 4u); + let _e30 = subgroupShuffle(_e6, ((_e5 - 1u) - _e6)); + let _e31 = subgroupShuffleDown(_e6, 1u); + let _e32 = subgroupShuffleUp(_e6, 1u); + let _e34 = subgroupShuffleXor(_e6, (_e5 - 1u)); + return; +} + +@compute @workgroup_size(1, 1, 1) +fn main(@builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_size) subgroup_size: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32) { + num_subgroups_1 = num_subgroups; + subgroup_id_1 = subgroup_id; + subgroup_size_1 = subgroup_size; + subgroup_invocation_id_1 = subgroup_invocation_id; + main_1(); +} diff --git a/naga/tests/out/wgsl/subgroup-operations.wgsl b/naga/tests/out/wgsl/subgroup-operations.wgsl new file mode 100644 index 0000000000..c53aa3e2cf --- /dev/null +++ b/naga/tests/out/wgsl/subgroup-operations.wgsl @@ -0,0 +1,26 @@ +@compute @workgroup_size(1, 1, 1) +fn main(@builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_size) subgroup_size: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32) { + subgroupBarrier(); + let _e8 = subgroupBallot(((subgroup_invocation_id & 1u) == 1u)); + let _e9 = subgroupBallot(); + let _e12 = subgroupAll((subgroup_invocation_id != 0u)); + let _e15 = subgroupAny((subgroup_invocation_id == 0u)); + let _e16 = subgroupAdd(subgroup_invocation_id); + let _e17 = subgroupMul(subgroup_invocation_id); + let _e18 = subgroupMin(subgroup_invocation_id); + let _e19 = subgroupMax(subgroup_invocation_id); + let _e20 = subgroupAnd(subgroup_invocation_id); + let _e21 = subgroupOr(subgroup_invocation_id); + let _e22 = subgroupXor(subgroup_invocation_id); + let _e23 = subgroupPrefixExclusiveAdd(subgroup_invocation_id); + let _e24 = subgroupPrefixExclusiveMul(subgroup_invocation_id); + let _e25 = subgroupPrefixInclusiveAdd(subgroup_invocation_id); + let _e26 = subgroupPrefixInclusiveMul(subgroup_invocation_id); + let _e27 = subgroupBroadcastFirst(subgroup_invocation_id); + let _e29 = subgroupBroadcast(subgroup_invocation_id, 4u); + let _e33 = 
subgroupShuffle(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + let _e35 = subgroupShuffleDown(subgroup_invocation_id, 1u); + let _e37 = subgroupShuffleUp(subgroup_invocation_id, 1u); + let _e40 = subgroupShuffleXor(subgroup_invocation_id, (subgroup_size - 1u)); + return; +} diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 4ad17f1a2a..420fbc1316 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -258,10 +258,18 @@ fn check_targets( let params = input.read_parameters(); let name = &input.file_name; - let capabilities = if params.god_mode { - naga::valid::Capabilities::all() + let (capabilities, subgroup_stages, subgroup_operations) = if params.god_mode { + ( + naga::valid::Capabilities::all(), + naga::valid::ShaderStages::all(), + naga::valid::SubgroupOperationSet::all(), + ) } else { - naga::valid::Capabilities::default() + ( + naga::valid::Capabilities::default(), + naga::valid::ShaderStages::empty(), + naga::valid::SubgroupOperationSet::empty(), + ) }; #[cfg(feature = "serialize")] @@ -274,6 +282,8 @@ fn check_targets( } let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) + .subgroup_stages(subgroup_stages) + .subgroup_operations(subgroup_operations) .validate(module) .unwrap_or_else(|_| panic!("Naga module validation failed on test '{}'", name.display())); @@ -291,6 +301,8 @@ fn check_targets( } naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) + .subgroup_stages(subgroup_stages) + .subgroup_operations(subgroup_operations) .validate(module) .unwrap_or_else(|_| { panic!( @@ -783,6 +795,10 @@ fn convert_wgsl() { "f64", Targets::SPIRV | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), + ( + "subgroup-operations", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() { @@ -862,6 +878,11 @@ fn convert_spv_all() { true, Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ); + convert_spv( + "subgroup-operations-s", + false, + Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ); } #[cfg(feature = "glsl-in")] diff --git a/tests/tests/root.rs b/tests/tests/root.rs index 5fd119b2c9..7dedd73d6f 100644 --- a/tests/tests/root.rs +++ b/tests/tests/root.rs @@ -32,6 +32,7 @@ mod scissor_tests; mod shader; mod shader_primitive_index; mod shader_view_format; +mod subgroup_operations; mod texture_bounds; mod transfer; mod vertex_indices; diff --git a/tests/tests/subgroup_operations/mod.rs b/tests/tests/subgroup_operations/mod.rs new file mode 100644 index 0000000000..3fb386873d --- /dev/null +++ b/tests/tests/subgroup_operations/mod.rs @@ -0,0 +1,135 @@ +use std::{borrow::Cow, num::NonZeroU64}; + +use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters}; + +const THREAD_COUNT: u64 = 128; +const TEST_COUNT: u32 = 31; + +#[gpu_test] +static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .features(wgpu::Features::SUBGROUP_COMPUTE) + .limits(wgpu::Limits::downlevel_defaults()) + .expect_fail(wgpu_test::FailureCase::molten_vk()) + .expect_fail( + // Expect metal to fail on tests involving operations in divergent control flow + wgpu_test::FailureCase::backend(wgpu::Backends::METAL) + .panic("thread 0 failed tests: 27,\nthread 1 failed tests: 27, 28,\n"), + ), + ) + .run_sync(|ctx| { + let device = &ctx.device; + + let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: None, + 
+            size: THREAD_COUNT * std::mem::size_of::<u32>() as u64,
+            usage: wgpu::BufferUsages::STORAGE
+                | wgpu::BufferUsages::COPY_DST
+                | wgpu::BufferUsages::COPY_SRC,
+            mapped_at_creation: false,
+        });
+
+        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
+            label: Some("bind group layout"),
+            entries: &[wgpu::BindGroupLayoutEntry {
+                binding: 0,
+                visibility: wgpu::ShaderStages::COMPUTE,
+                ty: wgpu::BindingType::Buffer {
+                    ty: wgpu::BufferBindingType::Storage { read_only: false },
+                    has_dynamic_offset: false,
+                    min_binding_size: NonZeroU64::new(
+                        THREAD_COUNT * std::mem::size_of::<u32>() as u64,
+                    ),
+                },
+                count: None,
+            }],
+        });
+
+        let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
+            label: None,
+            source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))),
+        });
+
+        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
+            label: Some("main"),
+            bind_group_layouts: &[&bind_group_layout],
+            push_constant_ranges: &[],
+        });
+
+        let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
+            label: None,
+            layout: Some(&pipeline_layout),
+            module: &cs_module,
+            entry_point: "main",
+        });
+
+        let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
+            entries: &[wgpu::BindGroupEntry {
+                binding: 0,
+                resource: storage_buffer.as_entire_binding(),
+            }],
+            layout: &bind_group_layout,
+            label: Some("bind group"),
+        });
+
+        let mut encoder =
+            device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
+        {
+            let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
+                label: None,
+                timestamp_writes: None,
+            });
+            cpass.set_pipeline(&compute_pipeline);
+            cpass.set_bind_group(0, &bind_group, &[]);
+            cpass.dispatch_workgroups(THREAD_COUNT as u32, 1, 1);
+        }
+
+        let mapping_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
+            label: Some("Mapping buffer"),
+            size: THREAD_COUNT * std::mem::size_of::<u32>() as u64,
+            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
+            mapped_at_creation: false,
+        });
+        encoder.copy_buffer_to_buffer(
+            &storage_buffer,
+            0,
+            &mapping_buffer,
+            0,
+            THREAD_COUNT * std::mem::size_of::<u32>() as u64,
+        );
+        ctx.queue.submit(Some(encoder.finish()));
+
+        mapping_buffer
+            .slice(..)
+            .map_async(wgpu::MapMode::Read, |_| ());
+        ctx.device.poll(wgpu::Maintain::Wait);
+        let mapping_buffer_view = mapping_buffer.slice(..).get_mapped_range();
+        let result: &[u32; THREAD_COUNT as usize] = bytemuck::from_bytes(&mapping_buffer_view);
+        let expected_mask = (1u64 << (TEST_COUNT)) - 1; // generate full mask
+        let expected_array = [expected_mask as u32; THREAD_COUNT as usize];
+        if result != &expected_array {
+            use std::fmt::Write;
+            let mut msg = String::new();
+            writeln!(
+                &mut msg,
+                "Got from GPU:\n{:x?}\n expected:\n{:x?}",
+                result, &expected_array,
+            )
+            .unwrap();
+            for (thread, (result, expected)) in result
+                .iter()
+                .zip(expected_array)
+                .enumerate()
+                .filter(|(_, (r, e))| *r != e)
+            {
+                write!(&mut msg, "thread {} failed tests:", thread).unwrap();
+                let difference = result ^ expected;
+                for i in (0..u32::BITS).filter(|i| (difference & (1 << i)) != 0) {
+                    write!(&mut msg, " {},", i).unwrap();
+                }
+                writeln!(&mut msg).unwrap();
+            }
+            panic!("{}", msg);
+        }
+    });
diff --git a/tests/tests/subgroup_operations/shader.wgsl b/tests/tests/subgroup_operations/shader.wgsl
new file mode 100644
index 0000000000..0f1dc47cd9
--- /dev/null
+++ b/tests/tests/subgroup_operations/shader.wgsl
@@ -0,0 +1,173 @@
+@group(0)
+@binding(0)
+var<storage, read_write> storage_buffer: array<u32>;
+
+@compute
+@workgroup_size(128)
+fn main(
+    @builtin(global_invocation_id) global_id: vec3<u32>,
+    @builtin(num_subgroups) num_subgroups: u32,
+    @builtin(subgroup_id) subgroup_id: u32,
+    @builtin(subgroup_size) subgroup_size: u32,
+    @builtin(subgroup_invocation_id) subgroup_invocation_id: u32,
+) {
+    var passed = 0u;
+    var expected: u32;
+
+    var mask = 1u << 0u;
+    passed |= mask * u32(num_subgroups == 128u / subgroup_size);
+    mask = 1u << 1u;
+    passed |= mask * u32(subgroup_id == global_id.x / subgroup_size);
+    mask = 1u << 2u;
+    passed |= mask * u32(subgroup_invocation_id == global_id.x % subgroup_size);
+
+    var expected_ballot = vec4<u32>(0u);
+    for(var i = 0u; i < subgroup_size; i += 1u) {
+        expected_ballot[i / 32u] |= ((global_id.x - subgroup_invocation_id + i) & 1u) << (i % 32u);
+    }
+    mask = 1u << 3u;
+    passed |= mask * u32(dot(vec4<u32>(1u), vec4<u32>(subgroupBallot((subgroup_invocation_id & 1u) == 1u) == expected_ballot)) == 4u);
+
+    mask = 1u << 4u;
+    passed |= mask * u32(subgroupAll(true));
+    mask = 1u << 5u;
+    passed |= mask * u32(!subgroupAll(subgroup_invocation_id != 0u));
+
+    mask = 1u << 6u;
+    passed |= mask * u32(subgroupAny(subgroup_invocation_id == 0u));
+    mask = 1u << 7u;
+    passed |= mask * u32(!subgroupAny(false));
+
+    expected = 0u;
+    for(var i = 0u; i < subgroup_size; i += 1u) {
+        expected += global_id.x - subgroup_invocation_id + i + 1u;
+    }
+    mask = 1u << 8u;
+    passed |= mask * u32(subgroupAdd(global_id.x + 1u) == expected);
+
+    expected = 1u;
+    for(var i = 0u; i < subgroup_size; i += 1u) {
+        expected *= global_id.x - subgroup_invocation_id + i + 1u;
+    }
+    mask = 1u << 9u;
+    passed |= mask * u32(subgroupMul(global_id.x + 1u) == expected);
+
+    expected = 0u;
+    for(var i = 0u; i < subgroup_size; i += 1u) {
+        expected = max(expected, global_id.x - subgroup_invocation_id + i + 1u);
+    }
+    mask = 1u << 10u;
+    passed |= mask * u32(subgroupMax(global_id.x + 1u) == expected);
+
+    expected = 0xFFFFFFFFu;
+    for(var i = 0u; i < subgroup_size; i += 1u) {
+        expected = min(expected, global_id.x - subgroup_invocation_id + i + 1u);
+    }
+    mask = 1u << 11u;
+    passed |= mask * u32(subgroupMin(global_id.x + 1u) == expected);
+
+    expected = 0xFFFFFFFFu;
+    for(var i = 0u; i < subgroup_size; i += 1u) {
+        expected &= global_id.x - subgroup_invocation_id + i + 1u;
+    }
+    mask = 1u << 12u;
+    passed |= mask * u32(subgroupAnd(global_id.x + 1u) == expected);
+
+    expected = 0u;
+    for(var i = 0u; i < subgroup_size; i += 1u) {
+        expected |= global_id.x - subgroup_invocation_id + i + 1u;
+    }
+    mask = 1u << 13u;
+    passed |= mask * u32(subgroupOr(global_id.x + 1u) == expected);
+
+    expected = 0u;
+    for(var i = 0u; i < subgroup_size; i += 1u) {
+        expected ^= global_id.x - subgroup_invocation_id + i + 1u;
+    }
+    mask = 1u << 14u;
+    passed |= mask * u32(subgroupXor(global_id.x + 1u) == expected);
+
+    expected = 0u;
+    for(var i = 0u; i < subgroup_invocation_id; i += 1u) {
+        expected += global_id.x - subgroup_invocation_id + i + 1u;
+    }
+    mask = 1u << 15u;
+    passed |= mask * u32(subgroupPrefixExclusiveAdd(global_id.x + 1u) == expected);
+
+    expected = 1u;
+    for(var i = 0u; i < subgroup_invocation_id; i += 1u) {
+        expected *= global_id.x - subgroup_invocation_id + i + 1u;
+    }
+    mask = 1u << 16u;
+    passed |= mask * u32(subgroupPrefixExclusiveMul(global_id.x + 1u) == expected);
+
+    expected = 0u;
+    for(var i = 0u; i <= subgroup_invocation_id; i += 1u) {
+        expected += global_id.x - subgroup_invocation_id + i + 1u;
+    }
+    mask = 1u << 17u;
+    passed |= mask * u32(subgroupPrefixInclusiveAdd(global_id.x + 1u) == expected);
+
+    expected = 1u;
+    for(var i = 0u; i <= subgroup_invocation_id; i += 1u) {
+        expected *= global_id.x - subgroup_invocation_id + i + 1u;
+    }
+    mask = 1u << 18u;
+    passed |= mask * u32(subgroupPrefixInclusiveMul(global_id.x + 1u) == expected);
+
+    mask = 1u << 19u;
+    passed |= mask * u32(subgroupBroadcastFirst(u32(subgroup_invocation_id != 0u)) == 0u);
+    mask = 1u << 20u;
+    passed |= mask * u32(subgroupBroadcastFirst(u32(subgroup_invocation_id == 0u)) == 1u);
+    mask = 1u << 21u;
+    passed |= mask * u32(subgroupBroadcast(subgroup_invocation_id, 1u) == 1u);
+    mask = 1u << 22u;
+    passed |= mask * u32(subgroupShuffle(subgroup_invocation_id, subgroup_invocation_id) == subgroup_invocation_id);
+    mask = 1u << 23u;
+    passed |= mask * u32(subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id) == subgroup_size - 1u - subgroup_invocation_id);
+    mask = 1u << 24u;
+    passed |= mask * u32(subgroup_invocation_id == subgroup_size - 1u || subgroupShuffleDown(subgroup_invocation_id, 1u) == subgroup_invocation_id + 1u);
+    mask = 1u << 25u;
+    passed |= mask * u32(subgroup_invocation_id == 0u || subgroupShuffleUp(subgroup_invocation_id, 1u) == subgroup_invocation_id - 1u);
+    mask = 1u << 26u;
+    passed |= mask * u32(subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u) == (subgroup_invocation_id ^ (subgroup_size - 1u)));
+
+    mask = 1u << 27u;
+    if subgroup_invocation_id % 2u == 0u {
+        passed |= mask * u32(subgroupAdd(1u) == (subgroup_size / 2u));
+    } else {
+        passed |= mask * u32(subgroupAdd(1u) == (subgroup_size / 2u));
+    }
+
+    mask = 1u << 28u;
+    switch subgroup_invocation_id % 3u {
+        case 0u: {
+            passed |= mask * u32(subgroupBroadcastFirst(subgroup_invocation_id) == 0u);
+        }
+        case 1u: {
+            passed |= mask * u32(subgroupBroadcastFirst(subgroup_invocation_id) == 1u);
+        }
+        case 2u: {
+            passed |= mask * u32(subgroupBroadcastFirst(subgroup_invocation_id) == 2u);
+        }
+        default { }
+    }
+
+    mask = 1u << 29u;
+    expected = 0u;
+    for (var i = subgroup_size; i >= 0u; i -= 1u) {
+        expected = subgroupAdd(1u);
+        if i == subgroup_invocation_id {
+            break;
+        }
+    }
+    passed |= mask * u32(expected == (subgroup_invocation_id + 1u));
+
+    // Keep this test last; it verifies we are still convergent after running the other tests
+    mask = 1u << 30u;
+    passed |= mask * u32(subgroup_size == subgroupAdd(1u));
+
+    // Increment TEST_COUNT in subgroup_operations/mod.rs if adding more tests
+
+    storage_buffer[global_id.x] = passed;
+}
diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs
index 6d2e140a8b..a9a3814426 100644
--- a/wgpu-core/src/device/resource.rs
+++ b/wgpu-core/src/device/resource.rs
@@ -1360,6 +1360,14 @@ impl Device {
                .flags
                .contains(wgt::DownlevelFlags::CUBE_ARRAY_TEXTURES),
        );
+        caps.set(
+            Caps::SUBGROUP,
+            self.features.intersects(
+                wgt::Features::SUBGROUP_COMPUTE
+                    | wgt::Features::SUBGROUP_FRAGMENT
+                    | wgt::Features::SUBGROUP_VERTEX,
+            ),
+        );
        let debug_source =
            if self.instance_flags.contains(wgt::InstanceFlags::DEBUG) && !source.is_empty() {
@@ -1375,7 +1383,30 @@
                None
            };
+        let mut subgroup_stages = naga::valid::ShaderStages::empty();
+        subgroup_stages.set(
+            naga::valid::ShaderStages::COMPUTE,
+            self.features.contains(wgt::Features::SUBGROUP_COMPUTE),
+        );
+        subgroup_stages.set(
+            naga::valid::ShaderStages::FRAGMENT,
+            self.features.contains(wgt::Features::SUBGROUP_FRAGMENT),
+        );
+        subgroup_stages.set(
+            naga::valid::ShaderStages::VERTEX,
+            self.features.contains(wgt::Features::SUBGROUP_VERTEX),
+        );
+
+        let subgroup_operations = if caps.contains(Caps::SUBGROUP) {
+            use naga::valid::SubgroupOperationSet as S;
+            S::BASIC | S::VOTE | S::ARITHMETIC | S::BALLOT | S::SHUFFLE | S::SHUFFLE_RELATIVE
+        } else {
+            naga::valid::SubgroupOperationSet::empty()
+        };
+
        let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), caps)
+            .subgroup_stages(subgroup_stages)
+            .subgroup_operations(subgroup_operations)
            .validate(&module)
            .map_err(|inner| {
                pipeline::CreateShaderModuleError::Validation(pipeline::ShaderError {
diff --git a/wgpu-hal/src/dx12/adapter.rs b/wgpu-hal/src/dx12/adapter.rs
index c2ffda8429..72695b212c 100644
--- a/wgpu-hal/src/dx12/adapter.rs
+++ b/wgpu-hal/src/dx12/adapter.rs
@@ -127,6 +127,11 @@ impl super::Adapter {
            )
        });
+        // If we don't have dxc, we reduce the max to 5.1
+        if dxc_container.is_none() {
+            shader_model_support.HighestShaderModel = d3d12_ty::D3D_SHADER_MODEL_5_1;
+        }
+
        let mut workarounds = super::Workarounds::default();
        let info = wgt::AdapterInfo {
@@ -294,6 +299,11 @@ impl super::Adapter {
            bgra8unorm_storage_supported,
        );
+        features.set(
+            wgt::Features::SUBGROUP_COMPUTE | wgt::Features::SUBGROUP_FRAGMENT,
+            shader_model_support.HighestShaderModel >= d3d12_ty::D3D_SHADER_MODEL_6_0,
+        );
+
        // TODO: Determine if IPresentationManager is supported
        let presentation_timer = auxil::dxgi::time::PresentationTimer::new_dxgi();
diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs
index e05091053a..88af3887b5 100644
--- a/wgpu-hal/src/metal/adapter.rs
+++ b/wgpu-hal/src/metal/adapter.rs
@@ -804,6 +804,12 @@ impl super::PrivateCapabilities {
                None
            },
            timestamp_query_support,
+            supports_simd_scoped_operations: family_check
+                && (device.supports_family(MTLGPUFamily::Metal3)
+                    || device.supports_family(MTLGPUFamily::Mac2)
+                    || device.supports_family(MTLGPUFamily::Apple7)
+                    || device.supports_family(MTLGPUFamily::Apple8)
+                    || device.supports_family(MTLGPUFamily::Apple9)),
        }
    }
@@ -884,6 +890,10 @@ impl super::PrivateCapabilities {
        features.set(F::RG11B10UFLOAT_RENDERABLE, self.format_rg11b10_all);
        features.set(F::SHADER_UNUSED_VERTEX_OUTPUT, true);
+        if self.supports_simd_scoped_operations {
+            features.insert(F::SUBGROUP_COMPUTE | F::SUBGROUP_FRAGMENT | F::SUBGROUP_VERTEX);
+        }
+
        features
    }
diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs
index 0ddf96ed4a..46cf41c927 100644
--- a/wgpu-hal/src/metal/mod.rs
+++ b/wgpu-hal/src/metal/mod.rs
@@ -263,6 +263,7 @@ struct PrivateCapabilities {
    supports_shader_primitive_index: bool,
    has_unified_memory: Option,
    timestamp_query_support: TimestampQuerySupport,
+    supports_simd_scoped_operations: bool,
}
#[derive(Clone, Debug)]
diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs
index 0af87eb072..fc4b47c95a 100644
--- a/wgpu-hal/src/vulkan/adapter.rs
+++ b/wgpu-hal/src/vulkan/adapter.rs
@@ -495,6 +495,36 @@ impl PhysicalDeviceFeatures {
            );
        }
+        if let Some(ref subgroup) = caps.subgroup {
+            if subgroup.supported_operations.contains(
+                vk::SubgroupFeatureFlags::BASIC
+                    | vk::SubgroupFeatureFlags::VOTE
+                    | vk::SubgroupFeatureFlags::ARITHMETIC
+                    | vk::SubgroupFeatureFlags::BALLOT
+                    | vk::SubgroupFeatureFlags::SHUFFLE
+                    | vk::SubgroupFeatureFlags::SHUFFLE_RELATIVE,
+            ) {
+                features.set(
+                    F::SUBGROUP_COMPUTE,
+                    subgroup
+                        .supported_stages
+                        .contains(vk::ShaderStageFlags::COMPUTE),
+                );
+                features.set(
+                    F::SUBGROUP_FRAGMENT,
+                    subgroup
+                        .supported_stages
+                        .contains(vk::ShaderStageFlags::FRAGMENT),
+                );
+                features.set(
+                    F::SUBGROUP_VERTEX,
+                    subgroup
+                        .supported_stages
+                        .contains(vk::ShaderStageFlags::VERTEX),
+                );
+            }
+        }
+
        let supports_depth_format = |format| {
            supports_format(
                instance,
@@ -573,6 +603,8 @@ pub struct PhysicalDeviceCapabilities {
    maintenance_3: Option,
    descriptor_indexing: Option,
    driver: Option,
+    subgroup: Option<vk::PhysicalDeviceSubgroupProperties>,
+
    /// The effective driver api version supported by the physical device.
    /// The device API version.
    ///
    /// Which is the version of Vulkan supported for device-level functionality.
@@ -838,6 +870,13 @@ impl super::InstanceShared {
                builder = builder.push_next(next);
            }
+            if capabilities.device_api_version >= vk::API_VERSION_1_1 {
+                let next = capabilities
+                    .subgroup
+                    .insert(vk::PhysicalDeviceSubgroupProperties::default());
+                builder = builder.push_next(next);
+            }
+
            let mut properties2 = builder.build();
            unsafe {
                get_device_properties.get_physical_device_properties2(phd, &mut properties2);
@@ -1293,6 +1332,19 @@ impl super::Adapter {
            capabilities.push(spv::Capability::Geometry);
        }
+        if features.intersects(
+            wgt::Features::SUBGROUP_COMPUTE
+                | wgt::Features::SUBGROUP_FRAGMENT
+                | wgt::Features::SUBGROUP_VERTEX,
+        ) {
+            capabilities.push(spv::Capability::GroupNonUniform);
+            capabilities.push(spv::Capability::GroupNonUniformVote);
+            capabilities.push(spv::Capability::GroupNonUniformArithmetic);
+            capabilities.push(spv::Capability::GroupNonUniformBallot);
+            capabilities.push(spv::Capability::GroupNonUniformShuffle);
+            capabilities.push(spv::Capability::GroupNonUniformShuffleRelative);
+        }
+
        if features.intersects(
            wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING
                | wgt::Features::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING,
@@ -1320,7 +1372,15 @@ impl super::Adapter {
            true, // could check `super::Workarounds::SEPARATE_ENTRY_POINTS`
        );
        spv::Options {
-            lang_version: (1, 0),
+            lang_version: if features.intersects(
+                wgt::Features::SUBGROUP_COMPUTE
+                    | wgt::Features::SUBGROUP_FRAGMENT
+                    | wgt::Features::SUBGROUP_VERTEX,
+            ) {
+                (1, 3)
+            } else {
+                (1, 0)
+            },
            flags,
            capabilities: Some(capabilities.iter().cloned().collect()),
            bounds_check_policies: naga::proc::BoundsCheckPolicies {
diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs
index 3e9ab37a24..bc81376d45 100644
--- a/wgpu-types/src/lib.rs
+++ b/wgpu-types/src/lib.rs
@@ -780,10 +780,35 @@ bitflags::bitflags! {
        /// This is a native only feature.
        const TEXTURE_FORMAT_NV12 = 1 << 55;
-        // 55..59 available
-        // Shader:
+        /// Allows compute shaders to use the subgroup operation built-ins
+        ///
+        /// Supported Platforms:
+        /// - Vulkan
+        /// - DX12
+        /// - Metal
+        ///
+        /// This is a native only feature.
+        const SUBGROUP_COMPUTE = 1 << 56;
+        /// Allows fragment shaders to use the subgroup operation built-ins
+        ///
+        /// Supported Platforms:
+        /// - Vulkan
+        /// - DX12
+        /// - Metal
+        ///
+        /// This is a native only feature.
+        const SUBGROUP_FRAGMENT = 1 << 57;
+        /// Allows vertex shaders to use the subgroup operation built-ins
+        ///
+        /// Supported Platforms:
+        /// - Vulkan
+        /// - Metal
+        ///
+        /// This is a native only feature.
+        const SUBGROUP_VERTEX = 1 << 58;
+
        /// Enables 64-bit floating point types in SPIR-V shaders.
        ///
        /// Note: even when supported by GPU hardware, 64-bit floating point operations are