From 14f6be2f9e35c4ff5655fc01cbad4606e618c9a8 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Fri, 29 Sep 2023 23:31:39 -0400 Subject: [PATCH 01/46] subgroup: Implement subgroupBallot for wgsl-in, wgsl-out, spv-out, hlsl-out TODO: metal out, figure out what needs to be done in validation --- naga/src/back/dot/mod.rs | 5 +++++ naga/src/back/glsl/mod.rs | 13 ++++++++++++- naga/src/back/hlsl/writer.rs | 12 +++++++++++- naga/src/back/msl/writer.rs | 2 ++ naga/src/back/spv/block.rs | 21 ++++++++++++++++++++- naga/src/back/spv/instructions.rs | 17 +++++++++++++++++ naga/src/back/wgsl/writer.rs | 9 +++++++++ naga/src/compact/expressions.rs | 2 ++ naga/src/compact/statements.rs | 4 ++++ naga/src/front/spv/mod.rs | 1 + naga/src/front/wgsl/lower/mod.rs | 10 ++++++++++ naga/src/lib.rs | 16 ++++++++++++++-- naga/src/proc/constant_evaluator.rs | 3 +++ naga/src/proc/terminator.rs | 1 + naga/src/proc/typifier.rs | 5 +++++ naga/src/valid/analyzer.rs | 5 +++++ naga/src/valid/expression.rs | 1 + naga/src/valid/function.rs | 3 +++ naga/src/valid/handles.rs | 5 +++++ 19 files changed, 130 insertions(+), 5 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 1556371df1..b24eeae5ad 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -279,6 +279,10 @@ impl StatementGraph { crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } + S::SubgroupBallot { result } => { + self.emits.push((id, result)); + "SubgroupBallot" + } }; // Set the last node to the merge node last_node = merge_id; @@ -586,6 +590,7 @@ fn write_function_expressions( let ty = if committed { "Committed" } else { "Candidate" }; (format!("rayQueryGet{}Intersection", ty).into(), 4) } + E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4), }; // give uniform expressions an outline diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 592c72a9a5..370ed49231 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2294,6 +2294,16 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, ");")?; } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + writeln!(self.out, "subgroupBallot(true);")?; + } } Ok(()) @@ -3467,7 +3477,8 @@ impl<'a, W: Write> Writer<'a, W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult - | Expression::WorkGroupUniformLoadResult { .. } => unreachable!(), + | Expression::WorkGroupUniformLoadResult { .. } + | Expression::SubgroupBallotResult => unreachable!(), // `ArrayLength` is written as `expr.length()` and we convert it to a uint Expression::ArrayLength(expr) => { write!(self.out, "uint(")?; diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index f26604476a..fc92cbd800 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2004,6 +2004,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}}}")? } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result } => { + write!(self.out, "{level}")?; + + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + write!(self.out, "const uint4 {name} = ")?; + self.named_expressions.insert(result, name); + + writeln!(self.out, "WaveActiveBallot(true);")?; + } } Ok(()) @@ -3152,7 +3161,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::WorkGroupUniformLoadResult { .. } - | Expression::RayQueryProceedResult => {} + | Expression::RayQueryProceedResult + | Expression::SubgroupBallotResult => {} } if !closing_bracket.is_empty() { diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 09f7b1c73f..d8d6426664 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1997,6 +1997,7 @@ impl Writer { } write!(self.out, "}}")?; } + crate::Expression::SubgroupBallotResult => todo!(), } Ok(()) } @@ -3010,6 +3011,7 @@ impl Writer { } } } + crate::Statement::SubgroupBallot { .. } => todo!(), } } diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 0471d957f0..357b7c3459 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1130,7 +1130,8 @@ impl<'w> BlockContext<'w> { crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } - | crate::Expression::RayQueryProceedResult => self.cached[expr_handle], + | crate::Expression::RayQueryProceedResult + | crate::Expression::SubgroupBallotResult => self.cached[expr_handle], crate::Expression::As { expr, kind, @@ -2338,6 +2339,24 @@ impl<'w> BlockContext<'w> { crate::Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } + crate::Statement::SubgroupBallot { result } => { + let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + kind: crate::ScalarKind::Uint, + width: 4, + pointer_space: None, + })); + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + let predicate = self.writer.get_constant_scalar(crate::Literal::Bool(true)); + let id = self.gen_id(); + block.body.push(Instruction::group_non_uniform_ballot( + vec4_u32_type_id, + id, + exec_scope_id, + predicate, + )); + self.cached[result] = id; + } } } diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index b963793ad3..1ca58431d5 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1037,6 +1037,23 @@ impl super::Instruction { instruction.add_operand(semantics_id); instruction } + + // Group Instructions + + pub(super) fn group_non_uniform_ballot( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + predicate: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBallot); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(predicate); + + instruction + } } impl From for spirv::ImageFormat { diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 075d85558c..18887825ea 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -921,6 +921,14 @@ impl Writer { } } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + writeln!(self.out, "subgroupBallot();")?; + } } Ok(()) @@ -1659,6 +1667,7 @@ impl Writer { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult + | Expression::SubgroupBallotResult | Expression::WorkGroupUniformLoadResult { .. } => {} } diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index c1326e92be..d62d00a85f 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -55,6 +55,7 @@ impl<'tracer> ExpressionTracer<'tracer> { | Ex::GlobalVariable(_) | Ex::LocalVariable(_) | Ex::CallResult(_) + | Ex::SubgroupBallotResult // FIXME: ??? | Ex::RayQueryProceedResult => {} Ex::Constant(handle) => { @@ -222,6 +223,7 @@ impl ModuleMap { | Ex::GlobalVariable(_) | Ex::LocalVariable(_) | Ex::CallResult(_) + | Ex::SubgroupBallotResult // FIXME: ??? | Ex::RayQueryProceedResult => {} // Expressions that contain handles that need to be adjusted. diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 4c62771023..3b27c8b71a 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -95,6 +95,9 @@ impl FunctionTracer<'_> { self.trace_expression(query); self.trace_ray_query_function(fun); } + St::SubgroupBallot { result } => { + self.trace_expression(result); + } // Trivial statements. St::Break @@ -244,6 +247,7 @@ impl FunctionMap { adjust(query); self.adjust_ray_query_function(fun); } + St::SubgroupBallot { ref mut result } => adjust(result), // Trivial statements. St::Break diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 083205a45b..2658a95d30 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -3842,6 +3842,7 @@ impl> Frontend { } } S::WorkGroupUniformLoad { .. } => unreachable!(), + S::SubgroupBallot { .. } => unreachable!(), } i += 1; } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 236656c45d..cef62521fd 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -2241,6 +2241,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; return Ok(Some(handle)); } + "subgroupBallot" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let result = ctx + .interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::SubgroupBallot { result }, span); + return Ok(Some(result)); + } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 300c6e4820..31cae75d8d 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -1399,7 +1399,9 @@ pub enum Expression { /// /// For [`TypeInner::Atomic`] the result is a corresponding scalar. /// For other types behind the `pointer`, the result is `T`. - Load { pointer: Handle }, + Load { + pointer: Handle, + }, /// Sample a point from a sampled or a depth image. ImageSample { image: Handle, @@ -1539,7 +1541,10 @@ pub enum Expression { /// Result of calling another function. CallResult(Handle), /// Result of an atomic operation. - AtomicResult { ty: Handle, comparison: bool }, + AtomicResult { + ty: Handle, + comparison: bool, + }, /// Result of a [`WorkGroupUniformLoad`] statement. /// /// [`WorkGroupUniformLoad`]: Statement::WorkGroupUniformLoad @@ -1567,6 +1572,7 @@ pub enum Expression { query: Handle, committed: bool, }, + SubgroupBallotResult, } pub use block::Block; @@ -1839,6 +1845,12 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, + SubgroupBallot { + /// The [`SubgroupBallotResult`] expression representing this load's result. + /// + /// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult + result: Handle, + }, } /// A function argument. diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 2082743975..d45ab4683f 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -133,6 +133,8 @@ pub enum ConstantEvaluatorError { ImageExpression, #[error("Constants don't support ray query expressions")] RayQueryExpression, + #[error("Constants don't support subgroup operations")] + SubgroupOperation, #[error("Cannot access the type")] InvalidAccessBase, #[error("Cannot access at the index")] @@ -439,6 +441,7 @@ impl<'a> ConstantEvaluator<'a> { Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => { Err(ConstantEvaluatorError::RayQueryExpression) } + Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupOperation), } } diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index a5239d4eca..d2dde729f1 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -37,6 +37,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::RayQuery { .. } | S::Atomic { .. } | S::WorkGroupUniformLoad { .. } + | S::SubgroupBallot { .. } | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index ad9eec94d2..6241c5bad8 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -905,6 +905,11 @@ impl<'a> ResolveContext<'a> { .ok_or(ResolveError::MissingSpecialType)?; TypeResolution::Handle(result) } + crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector { + kind: crate::ScalarKind::Uint, + size: crate::VectorSize::Quad, + width: 4, + }), }) } } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index ff1db071c8..d23caaf473 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -740,6 +740,10 @@ impl FunctionInfo { non_uniform_result: self.add_ref(query), requirements: UniformityRequirements::empty(), }, + E::SubgroupBallotResult => Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::empty(), + }, }; let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; @@ -983,6 +987,7 @@ impl FunctionInfo { } FunctionUniformity::new() } + S::SubgroupBallot { result: _ } => FunctionUniformity::new(), }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index f77844b4b1..890a4b9973 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -1537,6 +1537,7 @@ impl super::Validator { return Err(ExpressionError::InvalidRayQueryType(query)); } }, + E::SubgroupBallotResult => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, }; Ok(stages) } diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index d967f4b1f3..52f51a2810 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -919,6 +919,9 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } + S::SubgroupBallot { result } => { + self.emit_expression(result, context)?; + } } } Ok(BlockInfo { stages, finished }) diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index c68ded074b..547dfac551 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -394,6 +394,7 @@ impl super::Validator { } crate::Expression::AtomicResult { .. } | crate::Expression::RayQueryProceedResult + | crate::Expression::SubgroupBallotResult | crate::Expression::WorkGroupUniformLoadResult { .. } => (), crate::Expression::ArrayLength(array) => { handle.check_dep(array)?; @@ -539,6 +540,10 @@ impl super::Validator { } Ok(()) } + crate::Statement::SubgroupBallot { result } => { + validate_expr(result)?; + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill From e969f09bc35dcbea0f53107a71e104a8252ded00 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 30 Sep 2023 15:54:00 -0400 Subject: [PATCH 02/46] subgroup: subgroupBallot metal out --- naga/src/back/msl/writer.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index d8d6426664..145a4f1ff0 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1935,6 +1935,7 @@ impl Writer { crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } + | crate::Expression::SubgroupBallotResult | crate::Expression::RayQueryProceedResult => { unreachable!() } @@ -1997,7 +1998,6 @@ impl Writer { } write!(self.out, "}}")?; } - crate::Expression::SubgroupBallotResult => todo!(), } Ok(()) } @@ -3011,7 +3011,13 @@ impl Writer { } } } - crate::Statement::SubgroupBallot { .. } => todo!(), + crate::Statement::SubgroupBallot { result } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + write!(self.out, "{NAMESPACE}::simd_active_threads_mask();")?; + } } } From 05feb8830ff5ce54a325dc0852a6c1fb85befa7b Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 30 Sep 2023 17:53:44 -0400 Subject: [PATCH 03/46] subgroup: require GroupNonUnifomBallot capability --- naga/src/back/spv/block.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 357b7c3459..5c36cd8bcc 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -2340,6 +2340,10 @@ impl<'w> BlockContext<'w> { self.write_ray_query_function(query, fun, &mut block); } crate::Statement::SubgroupBallot { result } => { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { vector_size: Some(crate::VectorSize::Quad), kind: crate::ScalarKind::Uint, From 2ff88098f7d6be2c5006afe1c3bb3177459f8a55 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 30 Sep 2023 22:59:39 -0400 Subject: [PATCH 04/46] subgroup: Add subgroup invocation id and subgroup size builtins --- naga/src/back/glsl/mod.rs | 3 +++ naga/src/back/hlsl/conv.rs | 9 ++++++--- naga/src/back/msl/mod.rs | 3 +++ naga/src/back/spv/writer.rs | 18 ++++++++++++++++++ naga/src/back/wgsl/writer.rs | 2 ++ naga/src/front/wgsl/parse/conv.rs | 3 +++ naga/src/lib.rs | 3 +++ naga/src/valid/interface.rs | 11 +++++++++++ 8 files changed, 49 insertions(+), 3 deletions(-) diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 370ed49231..0869e4d9ee 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -4212,6 +4212,9 @@ const fn glsl_built_in( Bi::WorkGroupId => "gl_WorkGroupID", Bi::WorkGroupSize => "gl_WorkGroupSize", Bi::NumWorkGroups => "gl_NumWorkGroups", + // subgroup + Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", + Bi::SubgroupSize => "gl_SubgroupSize", } } diff --git a/naga/src/back/hlsl/conv.rs b/naga/src/back/hlsl/conv.rs index 19bde6926a..19c4da5e74 100644 --- a/naga/src/back/hlsl/conv.rs +++ b/naga/src/back/hlsl/conv.rs @@ -166,9 +166,12 @@ impl crate::BuiltIn { // to this field will get replaced with references to `SPECIAL_CBUF_VAR` // in `Writer::write_expr`. Self::NumWorkGroups => "SV_GroupID", - Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => { - return Err(Error::Unimplemented(format!("builtin {self:?}"))) - } + + Self::SubgroupInvocationId + | Self::SubgroupSize + | Self::BaseInstance + | Self::BaseVertex + | Self::WorkGroupSize => return Err(Error::Unimplemented(format!("builtin {self:?}"))), Self::PointSize | Self::ViewIndex | Self::PointCoord => { return Err(Error::Custom(format!("Unsupported builtin {self:?}"))) } diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 5ef18730c9..9e23d2a08d 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -437,6 +437,9 @@ impl ResolvedBinding { Bi::WorkGroupId => "threadgroup_position_in_grid", Bi::WorkGroupSize => "dispatch_threads_per_threadgroup", Bi::NumWorkGroups => "threadgroups_per_grid", + // subgroup + Bi::SubgroupInvocationId => "simdgroup_index_in_threadgroup", + Bi::SubgroupSize => "simdgroups_per_threadgroup", Bi::CullDistance | Bi::ViewIndex => { return Err(Error::UnsupportedBuiltIn(built_in)) } diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 24cb14a161..a445fb2b10 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1589,6 +1589,24 @@ impl Writer { Bi::WorkGroupId => BuiltIn::WorkgroupId, Bi::WorkGroupSize => BuiltIn::WorkgroupSize, Bi::NumWorkGroups => BuiltIn::NumWorkgroups, + // Subgroup + Bi::SubgroupInvocationId => { + self.require_any( + "`subgroup_invocation_id` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::SubgroupLocalInvocationId + } + Bi::SubgroupSize => { + self.require_any( + "`subgroup_invocation_id` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupSize + } }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 18887825ea..33d45635c4 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1769,6 +1769,8 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { Bi::SampleMask => "sample_mask", Bi::PrimitiveIndex => "primitive_index", Bi::ViewIndex => "view_index", + Bi::SubgroupInvocationId => "subgroup_invocation_id", + Bi::SubgroupSize => "subgroup_size", Bi::BaseInstance | Bi::BaseVertex | Bi::ClipDistance diff --git a/naga/src/front/wgsl/parse/conv.rs b/naga/src/front/wgsl/parse/conv.rs index 51977173d6..a27bdb1cbc 100644 --- a/naga/src/front/wgsl/parse/conv.rs +++ b/naga/src/front/wgsl/parse/conv.rs @@ -34,6 +34,9 @@ pub fn map_built_in(word: &str, span: Span) -> Result> "local_invocation_index" => crate::BuiltIn::LocalInvocationIndex, "workgroup_id" => crate::BuiltIn::WorkGroupId, "num_workgroups" => crate::BuiltIn::NumWorkGroups, + // subgroup + "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, + "subgroup_size" => crate::BuiltIn::SubgroupSize, _ => return Err(Error::UnknownBuiltin(span)), }) } diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 31cae75d8d..c4d9eef359 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -435,6 +435,9 @@ pub enum BuiltIn { WorkGroupId, WorkGroupSize, NumWorkGroups, + // subgroup + SubgroupInvocationId, + SubgroupSize, } /// Number of bytes per scalar. diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 6c41ece81f..c1ffc4447a 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -299,6 +299,17 @@ impl VaryingContext<'_> { width, }, ), + Bi::SubgroupInvocationId | Bi::SubgroupSize => ( + match self.stage { + St::Compute | St::Fragment => !self.output, + St::Vertex => false, + }, + *ty_inner + == Ti::Scalar { + kind: Sk::Uint, + width, + }, + ), }; if !visible { From 8cbf423168805eea2067cae50f9cd81a8a9271f7 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sun, 1 Oct 2023 20:08:30 -0400 Subject: [PATCH 05/46] subgroup: SubgroupInvocationId is only valid in compute stages --- naga/src/valid/interface.rs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index c1ffc4447a..bf4f397224 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -299,7 +299,15 @@ impl VaryingContext<'_> { width, }, ), - Bi::SubgroupInvocationId | Bi::SubgroupSize => ( + Bi::SubgroupInvocationId => ( + self.stage == St::Compute && !self.output, + *ty_inner + == Ti::Scalar { + kind: Sk::Uint, + width, + }, + ), + Bi::SubgroupSize => ( match self.stage { St::Compute | St::Fragment => !self.output, St::Vertex => false, From 943060343fdb48461014c1e0461b3a6762129fee Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Tue, 3 Oct 2023 16:20:25 -0400 Subject: [PATCH 06/46] subgroup: expierment with subgroupBarrier() based on OpControlbarrier SPIR-V OpControlBarrier with execution scope Subgroup has implementation defined behavior when executed nonuniformly. OpenCL SPIR-V execution spec say nonuniform execution is UB. Vulkan SPIR-V execution spec says nothing :). --- naga/src/back/glsl/mod.rs | 3 +++ naga/src/back/hlsl/writer.rs | 3 +++ naga/src/back/msl/writer.rs | 3 +++ naga/src/back/spv/writer.rs | 6 +++++- naga/src/front/wgsl/lower/mod.rs | 8 ++++++++ naga/src/lib.rs | 6 ++++-- 6 files changed, 26 insertions(+), 3 deletions(-) diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 0869e4d9ee..0b680df256 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -4035,6 +4035,9 @@ impl<'a, W: Write> Writer<'a, W> { if flags.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}memoryBarrierShared();")?; } + if flags.contains(crate::Barrier::SUB_GROUP) { + unimplemented!() // FIXME + } writeln!(self.out, "{level}barrier();")?; Ok(()) } diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index fc92cbd800..821dcc743d 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -3229,6 +3229,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?; } + if barrier.contains(crate::Barrier::SUB_GROUP) { + unimplemented!() // FIXME + } Ok(()) } } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 145a4f1ff0..c2a1eaddd0 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -4354,6 +4354,9 @@ impl Writer { "{level}{NAMESPACE}::threadgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", )?; } + if flags.contains(crate::Barrier::SUB_GROUP) { + unimplemented!(); // FIXME + } Ok(()) } } diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index a445fb2b10..077751d90b 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1314,7 +1314,11 @@ impl Writer { spirv::MemorySemantics::WORKGROUP_MEMORY, flags.contains(crate::Barrier::WORK_GROUP), ); - let exec_scope_id = self.get_index_constant(spirv::Scope::Workgroup as u32); + let exec_scope_id = if flags.contains(crate::Barrier::SUB_GROUP) { + self.get_index_constant(spirv::Scope::Subgroup as u32) + } else { + self.get_index_constant(spirv::Scope::Workgroup as u32) + }; let mem_scope_id = self.get_index_constant(memory_scope as u32); let semantics_id = self.get_index_constant(semantics.bits()); block.body.push(Instruction::control_barrier( diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index cef62521fd..57a1f860d1 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -2034,6 +2034,14 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span); return Ok(None); } + "subgroupBarrier" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::Barrier(crate::Barrier::SUB_GROUP), span); + return Ok(None); + } "workgroupUniformLoad" => { let mut args = ctx.prepare_args(arguments, 1, span); let expr = args.next()?; diff --git a/naga/src/lib.rs b/naga/src/lib.rs index c4d9eef359..a6f755fae6 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -1269,9 +1269,11 @@ bitflags::bitflags! { #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] pub struct Barrier: u32 { /// Barrier affects all `AddressSpace::Storage` accesses. - const STORAGE = 0x1; + const STORAGE = 1 << 0; /// Barrier affects all `AddressSpace::WorkGroup` accesses. - const WORK_GROUP = 0x2; + const WORK_GROUP = 1 << 1; + /// Barrier synchronizes execution across all invocations within a subgroup that exectue this instruction. + const SUB_GROUP = 1 << 2; } } From 44f69292871bdbb41847d78f609f425ce2c48631 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 14 Oct 2023 14:20:43 +0200 Subject: [PATCH 07/46] subgroup: add statement for rest of subgroup ops --- naga/src/back/dot/mod.rs | 20 +++++ naga/src/back/glsl/mod.rs | 16 ++++ naga/src/back/hlsl/writer.rs | 18 +++- naga/src/back/msl/writer.rs | 16 ++++ naga/src/back/spv/block.rs | 18 +++- naga/src/back/spv/instructions.rs | 17 ++++ naga/src/back/spv/mod.rs | 1 + naga/src/back/spv/subgroup.rs | 83 +++++++++++++++++++ naga/src/back/wgsl/writer.rs | 16 ++++ naga/src/compact/expressions.rs | 2 + naga/src/compact/statements.rs | 40 +++++++++ naga/src/front/spv/mod.rs | 4 +- naga/src/front/wgsl/lower/mod.rs | 39 +++++++++ naga/src/lib.rs | 82 +++++++++++++++++++ naga/src/proc/constant_evaluator.rs | 4 +- naga/src/proc/terminator.rs | 2 + naga/src/proc/typifier.rs | 1 + naga/src/valid/analyzer.rs | 26 +++++- naga/src/valid/expression.rs | 1 + naga/src/valid/function.rs | 122 ++++++++++++++++++++++++++++ naga/src/valid/handles.rs | 22 +++++ 21 files changed, 545 insertions(+), 5 deletions(-) create mode 100644 naga/src/back/spv/subgroup.rs diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index b24eeae5ad..5ba3ffe49b 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -283,6 +283,25 @@ impl StatementGraph { self.emits.push((id, result)); "SubgroupBallot" } + S::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + self.dependencies.push((id, argument, "arg")); + self.emits.push((id, result)); + "SubgroupCollectiveOperation" // FIXME + } + S::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + self.dependencies.push((id, argument, "arg")); + self.emits.push((id, result)); + "SubgroupBroadcast" // FIXME + } }; // Set the last node to the merge node last_node = merge_id; @@ -591,6 +610,7 @@ fn write_function_expressions( (format!("rayQueryGet{}Intersection", ty).into(), 4) } E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4), + E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4), }; // give uniform expressions an outline diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 0b680df256..fc23a5f891 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2304,6 +2304,21 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, "subgroupBallot(true);")?; } + Statement::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + unimplemented!(); // FIXME: + } + Statement::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + unimplemented!(); // FIXME + } } Ok(()) @@ -3478,6 +3493,7 @@ impl<'a, W: Write> Writer<'a, W> { | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult | Expression::WorkGroupUniformLoadResult { .. } + | Expression::SubgroupOperationResult { .. } | Expression::SubgroupBallotResult => unreachable!(), // `ArrayLength` is written as `expr.length()` and we convert it to a uint Expression::ArrayLength(expr) => { diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 821dcc743d..b14d0c4bde 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2013,6 +2013,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "WaveActiveBallot(true);")?; } + Statement::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + unimplemented!(); // FIXME + } + Statement::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + unimplemented!(); // FIXME + } } Ok(()) @@ -3162,7 +3177,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { | Expression::AtomicResult { .. } | Expression::WorkGroupUniformLoadResult { .. } | Expression::RayQueryProceedResult - | Expression::SubgroupBallotResult => {} + | Expression::SubgroupBallotResult + | Expression::SubgroupOperationResult { .. } => {} } if !closing_bracket.is_empty() { diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index c2a1eaddd0..cc4baf70ff 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -1936,6 +1936,7 @@ impl Writer { | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } | crate::Expression::RayQueryProceedResult => { unreachable!() } @@ -3018,6 +3019,21 @@ impl Writer { self.named_expressions.insert(result, name); write!(self.out, "{NAMESPACE}::simd_active_threads_mask();")?; } + crate::Statement::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + unimplemented!(); // FIXME + } + crate::Statement::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + unimplemented!(); // FIXME + } } } diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 5c36cd8bcc..fa799b12c4 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1131,7 +1131,8 @@ impl<'w> BlockContext<'w> { | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } | crate::Expression::RayQueryProceedResult - | crate::Expression::SubgroupBallotResult => self.cached[expr_handle], + | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } => self.cached[expr_handle], crate::Expression::As { expr, kind, @@ -2361,6 +2362,21 @@ impl<'w> BlockContext<'w> { )); self.cached[result] = id; } + crate::Statement::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + self.write_subgroup_operation(op, collective_op, argument, result, &mut block); + } + crate::Statement::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + unimplemented!() // FIXME + } } } diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 1ca58431d5..725014dee4 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1052,6 +1052,23 @@ impl super::Instruction { instruction.add_operand(exec_scope_id); instruction.add_operand(predicate); + instruction + } + pub(super) fn group_non_uniform_arithmetic( + op: Op, + result_type_id: Word, + id: Word, + exec_scope_id: Word, + group_op: spirv::GroupOperation, + value: Word, + ) -> Self { + let mut instruction = Self::new(op); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(group_op as u32); + instruction.add_operand(value); + instruction } } diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index ac7281fc6b..a7a4bd302c 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -13,6 +13,7 @@ mod layout; mod ray; mod recyclable; mod selection; +mod subgroup; mod writer; pub use spirv::Capability; diff --git a/naga/src/back/spv/subgroup.rs b/naga/src/back/spv/subgroup.rs new file mode 100644 index 0000000000..a4f7b018f4 --- /dev/null +++ b/naga/src/back/spv/subgroup.rs @@ -0,0 +1,83 @@ +use super::{Block, BlockContext, Error, Instruction}; +use crate::{arena::Handle, TypeInner}; + +impl<'w> BlockContext<'w> { + pub(super) fn write_subgroup_operation( + &mut self, + op: &crate::SubgroupOperation, + collective_op: &crate::CollectiveOperation, + argument: Handle, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformArithmetic", + &[ + spirv::Capability::GroupNonUniformArithmetic, + spirv::Capability::GroupNonUniformClustered, + spirv::Capability::GroupNonUniformPartitionedNV, + ], + )?; + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + let result_ty_inner = result_ty.inner_with(&self.ir_module.types); + let kind = result_ty_inner.scalar_kind().unwrap(); + + let (is_scalar, kind) = match result_ty_inner { + TypeInner::Scalar { kind, .. } => (true, kind), + TypeInner::Vector { kind, .. } => (false, kind), + _ => unimplemented!(), + }; + + use crate::ScalarKind as sk; + use crate::SubgroupOperation as sg; + let spirv_op = match (kind, op) { + (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll, + (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny, + (_, sg::All | sg::Any) => unimplemented!(), + + (sk::Sint | sk::Uint, sg::Add) => spirv::Op::GroupNonUniformIAdd, + (sk::Float, sg::Add) => spirv::Op::GroupNonUniformFAdd, + (sk::Sint | sk::Uint, sg::Mul) => spirv::Op::GroupNonUniformIMul, + (sk::Float, sg::Mul) => spirv::Op::GroupNonUniformFMul, + (sk::Sint, sg::Max) => spirv::Op::GroupNonUniformSMax, + (sk::Uint, sg::Max) => spirv::Op::GroupNonUniformUMax, + (sk::Float, sg::Max) => spirv::Op::GroupNonUniformFMax, + (sk::Sint, sg::Min) => spirv::Op::GroupNonUniformSMin, + (sk::Uint, sg::Min) => spirv::Op::GroupNonUniformUMin, + (sk::Float, sg::Min) => spirv::Op::GroupNonUniformFMin, + (sk::Bool, sg::Add | sg::Mul | sg::Min | sg::Max) => unimplemented!(), + + (sk::Sint | sk::Uint, sg::And) => spirv::Op::GroupNonUniformBitwiseAnd, + (sk::Sint | sk::Uint, sg::Or) => spirv::Op::GroupNonUniformBitwiseOr, + (sk::Sint | sk::Uint, sg::Xor) => spirv::Op::GroupNonUniformBitwiseXor, + (sk::Float, sg::And | sg::Or | sg::Xor) => unimplemented!(), + (sk::Bool, sg::And) => spirv::Op::GroupNonUniformLogicalAnd, + (sk::Bool, sg::Or) => spirv::Op::GroupNonUniformLogicalOr, + (sk::Bool, sg::Xor) => spirv::Op::GroupNonUniformLogicalXor, + }; + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + use crate::CollectiveOperation as c; + let group_op = match collective_op { + c::Reduce => spirv::GroupOperation::Reduce, + c::InclusiveScan => spirv::GroupOperation::InclusiveScan, + c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, + }; + + let arg_id = self.cached[argument]; + block.body.push(Instruction::group_non_uniform_arithmetic( + spirv_op, + result_type_id, + id, + exec_scope_id, + group_op, + arg_id, + )); + self.cached[result] = id; + Ok(()) + } +} diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 33d45635c4..b66863f1df 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -929,6 +929,21 @@ impl Writer { writeln!(self.out, "subgroupBallot();")?; } + Statement::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + unimplemented!() // FIXME + } + Statement::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + unimplemented!() // FIXME + } } Ok(()) @@ -1668,6 +1683,7 @@ impl Writer { | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult | Expression::SubgroupBallotResult + | Expression::SubgroupOperationResult { .. } | Expression::WorkGroupUniformLoadResult { .. } => {} } diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index d62d00a85f..1533362407 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -157,6 +157,7 @@ impl<'tracer> ExpressionTracer<'tracer> { Ex::AtomicResult { ty, comparison: _ } => self.trace_type(ty), Ex::WorkGroupUniformLoadResult { ty } => self.trace_type(ty), Ex::ArrayLength(expr) => work_list.push(expr), + Ex::SubgroupOperationResult { ty } => self.trace_type(ty), Ex::RayQueryGetIntersection { query, committed: _, @@ -351,6 +352,7 @@ impl ModuleMap { comparison: _, } => self.types.adjust(ty), Ex::WorkGroupUniformLoadResult { ref mut ty } => self.types.adjust(ty), + Ex::SubgroupOperationResult { ref mut ty } => self.types.adjust(ty), Ex::ArrayLength(ref mut expr) => adjust(expr), Ex::RayQueryGetIntersection { ref mut query, diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 3b27c8b71a..462553b9d6 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -98,6 +98,26 @@ impl FunctionTracer<'_> { St::SubgroupBallot { result } => { self.trace_expression(result); } + St::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + self.trace_expression(argument); + self.trace_expression(result); + } + St::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + if let crate::BroadcastMode::Index(expr) = *mode { + self.trace_expression(expr); + } + self.trace_expression(argument); + self.trace_expression(result); + } // Trivial statements. St::Break @@ -248,6 +268,26 @@ impl FunctionMap { self.adjust_ray_query_function(fun); } St::SubgroupBallot { ref mut result } => adjust(result), + St::SubgroupCollectiveOperation { + ref mut op, + ref mut collective_op, + ref mut argument, + ref mut result, + } => { + adjust(argument); + adjust(result); + } + St::SubgroupBroadcast { + ref mut mode, + ref mut argument, + ref mut result, + } => { + if let crate::BroadcastMode::Index(expr) = mode { + adjust(expr); + } + adjust(argument); + adjust(result); + } // Trivial statements. St::Break diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 2658a95d30..561e32e22d 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -3842,7 +3842,9 @@ impl> Frontend { } } S::WorkGroupUniformLoad { .. } => unreachable!(), - S::SubgroupBallot { .. } => unreachable!(), + S::SubgroupBallot { .. } => unreachable!(), // FIXME?? + S::SubgroupCollectiveOperation { .. } => unreachable!(), + S::SubgroupBroadcast { .. } => unreachable!(), } i += 1; } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 57a1f860d1..1112f7fcb6 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -2259,6 +2259,45 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(crate::Statement::SubgroupBallot { result }, span); return Ok(Some(result)); } + "subgroupBroadcast" => { + unimplemented!(); // FIXME + } + "subgroupBroadcastFirst" => { + unimplemented!(); // FIXME + } + "subgroupAll" => { + unimplemented!(); // FIXME + } + "subgroupAny" => { + unimplemented!(); // FIXME + } + "subgroupAdd" => { + unimplemented!(); // FIXME + } + "subgroupMul" => { + unimplemented!(); // FIXME + } + "subgroupMin" => { + unimplemented!(); // FIXME + } + "subgroupMax" => { + unimplemented!(); // FIXME + } + "subgroupAnd" => { + unimplemented!(); // FIXME + } + "subgroupOr" => { + unimplemented!(); // FIXME + } + "subgroupXor" => { + unimplemented!(); // FIXME + } + "subgroupPrefixAdd" => { + unimplemented!(); // FIXME + } + "subgroupPrefixMul" => { + unimplemented!(); // FIXME + } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; diff --git a/naga/src/lib.rs b/naga/src/lib.rs index a6f755fae6..d837f8336b 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -1261,6 +1261,42 @@ pub enum SwizzleComponent { W = 3, } +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum BroadcastMode { + First, + Index(Handle), +} + +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum SubgroupOperation { + All = 0, + Any = 1, + Add = 2, + Mul = 3, + Min = 4, + Max = 5, + And = 6, + Or = 7, + Xor = 8, +} + +#[repr(u8)] +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum CollectiveOperation { + Reduce = 0, + InclusiveScan = 1, + ExclusiveScan = 2, +} + bitflags::bitflags! { /// Memory barrier flags. #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -1578,6 +1614,9 @@ pub enum Expression { committed: bool, }, SubgroupBallotResult, + SubgroupOperationResult { + ty: Handle, + }, } pub use block::Block; @@ -1850,12 +1889,55 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, + // subgroupBallot(bool) -> vec4 SubgroupBallot { /// The [`SubgroupBallotResult`] expression representing this load's result. /// /// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult result: Handle, }, + + // subgroupBroadcast(value, lane) -> value + // subgroupBroadcastFirst(value) -> value + SubgroupBroadcast { + /// Specifies which thread to broadcast from + mode: BroadcastMode, + /// The value to broadcast over + argument: Handle, + /// The [`SubgroupBroadcastResult`] expression representing this load's result. + /// + /// [`SubgroupBroadcastResult`]: Expression::SubgroupBroadcastResult + result: Handle, + }, + + // Reduction on bool + // subgroupAll(bool) -> bool + // subgroupAny(bool) -> bool + // Reduction on float, int + // subgroupMin(value) -> value + // subgroupMax(value) -> value + // subgroupAdd(value) -> value + // subgroupMul(value) -> value + // Reduction on int + // subgroupAnd(value) -> value + // subgroupOr(value) -> value + // subgroupXor(value) -> value + // Scan on float, int + // subgroupPrefixAdd(value) -> value + // subgroupPrefixMul(value) -> value + /// Compute a collective operation across all active threads in th subgroup + SubgroupCollectiveOperation { + /// What operation to compute + op: SubgroupOperation, + /// How to combine the results + collective_op: CollectiveOperation, + /// The value to compute over + argument: Handle, + /// The [`SubgroupOperationResult`] expression representing this load's result. + /// + /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult + result: Handle, + }, } /// A function argument. diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index d45ab4683f..d13a66ec2f 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -441,7 +441,9 @@ impl<'a> ConstantEvaluator<'a> { Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => { Err(ConstantEvaluatorError::RayQueryExpression) } - Expression::SubgroupBallotResult => Err(ConstantEvaluatorError::SubgroupOperation), + Expression::SubgroupBallotResult | Expression::SubgroupOperationResult { .. } => { + Err(ConstantEvaluatorError::SubgroupOperation) + } } } diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index d2dde729f1..35111a11de 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -38,6 +38,8 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::Atomic { .. } | S::WorkGroupUniformLoad { .. } | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupBroadcast { .. } | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/naga/src/proc/typifier.rs b/naga/src/proc/typifier.rs index 6241c5bad8..c2b38ac73b 100644 --- a/naga/src/proc/typifier.rs +++ b/naga/src/proc/typifier.rs @@ -638,6 +638,7 @@ impl<'a> ResolveContext<'a> { | crate::BinaryOperator::ShiftRight => past(left)?.clone(), }, crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty), + crate::Expression::SubgroupOperationResult { ty } => TypeResolution::Handle(ty), crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty), crate::Expression::Select { accept, .. } => past(accept)?.clone(), crate::Expression::Derivative { expr, .. } => past(expr)?.clone(), diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index d23caaf473..e9ca4ee7c7 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -741,7 +741,11 @@ impl FunctionInfo { requirements: UniformityRequirements::empty(), }, E::SubgroupBallotResult => Uniformity { - non_uniform_result: None, + non_uniform_result: None, // FIXME + requirements: UniformityRequirements::empty(), + }, + E::SubgroupOperationResult { ty } => Uniformity { + non_uniform_result: None, // FIXME requirements: UniformityRequirements::empty(), }, }; @@ -988,6 +992,26 @@ impl FunctionInfo { FunctionUniformity::new() } S::SubgroupBallot { result: _ } => FunctionUniformity::new(), + S::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result: _, + } => { + let _ = self.add_ref(argument); + FunctionUniformity::new() + } + S::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + let _ = self.add_ref(argument); + if let crate::BroadcastMode::Index(expr) = *mode { + let _ = self.add_ref(expr); + } + FunctionUniformity::new() + } }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 890a4b9973..f840bba42d 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -1538,6 +1538,7 @@ impl super::Validator { } }, E::SubgroupBallotResult => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, + E::SubgroupOperationResult { ty } => ShaderStages::COMPUTE, // FIXME }; Ok(stages) } diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 52f51a2810..090fffd4c6 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -51,6 +51,15 @@ pub enum AtomicError { ResultTypeMismatch(Handle), } +#[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] +pub enum SubgroupError { + #[error("Operand {0:?} has invalid type.")] + InvalidOperand(Handle), + #[error("Result type for {0:?} doesn't match the statement")] + ResultTypeMismatch(Handle), +} + #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum LocalVariableError { @@ -159,6 +168,8 @@ pub enum FunctionError { WorkgroupUniformLoadExpressionMismatch(Handle), #[error("The expression {0:?} is not valid as a WorkGroupUniformLoad argument. It should be a Pointer in Workgroup address space")] WorkgroupUniformLoadInvalidPointer(Handle), + #[error("Subgroup operation is invalid")] + InvalidSubgroup(#[from] SubgroupError), } bitflags::bitflags! { @@ -413,6 +424,102 @@ impl super::Validator { } Ok(()) } + #[cfg(feature = "validate")] + fn validate_subgroup_operation( + &mut self, + op: &crate::SubgroupOperation, + _collective_op: &crate::CollectiveOperation, + argument: Handle, + result: Handle, + context: &BlockContext, + ) -> Result<(), WithSpan> { + self.emit_expression(argument, context)?; + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + + let (is_scalar, kind) = match argument_inner { + crate::TypeInner::Scalar { kind, .. } => (true, kind), + crate::TypeInner::Vector { kind, .. } => (false, kind), + _ => unimplemented!(), + }; + + use crate::ScalarKind as sk; + use crate::SubgroupOperation as sg; + match (kind, op) { + (sk::Bool, sg::All | sg::Any) if is_scalar => {} + (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {} + (sk::Sint | sk::Uint | sk::Bool, sg::And | sg::Or | sg::Xor) => {} + + (_, sg::All | sg::Any) + | (sk::Bool, sg::Add | sg::Mul | sg::Min | sg::Max) + | (sk::Float, sg::And | sg::Or | sg::Xor) => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + }; + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } + #[cfg(feature = "validate")] + fn validate_subgroup_broadcast( + &mut self, + mode: &crate::BroadcastMode, + argument: Handle, + result: Handle, + context: &BlockContext, + ) -> Result<(), WithSpan> { + if let crate::BroadcastMode::Index(expr) = *mode { + self.emit_expression(expr, context)?; + let index_ty = context.resolve_type(expr, &self.valid_expression_set)?; + match index_ty { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + .. + } => {} + _ => { + log::error!("Subgroup broadcast index type {:?}", index_ty); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + } + } + self.emit_expression(argument, context)?; + let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; + + match argument_inner { + crate::TypeInner::Scalar { .. } | crate::TypeInner::Vector { .. } => {} + _ => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } + } + + self.emit_expression(result, context)?; + match context.expressions[result] { + crate::Expression::SubgroupOperationResult { ty } + if { &context.types[ty].inner == argument_inner } => {} + _ => { + return Err(SubgroupError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } + } + Ok(()) + } #[cfg(feature = "validate")] fn validate_block_impl( @@ -922,6 +1029,21 @@ impl super::Validator { S::SubgroupBallot { result } => { self.emit_expression(result, context)?; } + S::SubgroupCollectiveOperation { + ref op, + ref collective_op, + argument, + result, + } => { + self.validate_subgroup_operation(op, collective_op, argument, result, context)?; + } + S::SubgroupBroadcast { + ref mode, + argument, + result, + } => { + self.validate_subgroup_broadcast(mode, argument, result, context)?; + } } } Ok(BlockInfo { stages, finished }) diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 547dfac551..399665342f 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -395,6 +395,7 @@ impl super::Validator { crate::Expression::AtomicResult { .. } | crate::Expression::RayQueryProceedResult | crate::Expression::SubgroupBallotResult + | crate::Expression::SubgroupOperationResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } => (), crate::Expression::ArrayLength(array) => { handle.check_dep(array)?; @@ -544,6 +545,27 @@ impl super::Validator { validate_expr(result)?; Ok(()) } + crate::Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + } => { + validate_expr(argument)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::SubgroupBroadcast { + mode, + argument, + result, + } => { + if let crate::BroadcastMode::Index(expr) = mode { + validate_expr(expr)?; + } + validate_expr(result)?; + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill From b5be66eebc445c5de300d328462cf750ae17e518 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Thu, 5 Oct 2023 09:43:29 -0400 Subject: [PATCH 08/46] subgroup: fix doc error on SubgroupBroadcast --- naga/src/lib.rs | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/naga/src/lib.rs b/naga/src/lib.rs index d837f8336b..171801fa54 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -1897,34 +1897,17 @@ pub enum Statement { result: Handle, }, - // subgroupBroadcast(value, lane) -> value - // subgroupBroadcastFirst(value) -> value SubgroupBroadcast { /// Specifies which thread to broadcast from mode: BroadcastMode, /// The value to broadcast over argument: Handle, - /// The [`SubgroupBroadcastResult`] expression representing this load's result. + /// The [`SubgroupOperationResult`] expression representing this load's result. /// - /// [`SubgroupBroadcastResult`]: Expression::SubgroupBroadcastResult + /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult result: Handle, }, - // Reduction on bool - // subgroupAll(bool) -> bool - // subgroupAny(bool) -> bool - // Reduction on float, int - // subgroupMin(value) -> value - // subgroupMax(value) -> value - // subgroupAdd(value) -> value - // subgroupMul(value) -> value - // Reduction on int - // subgroupAnd(value) -> value - // subgroupOr(value) -> value - // subgroupXor(value) -> value - // Scan on float, int - // subgroupPrefixAdd(value) -> value - // subgroupPrefixMul(value) -> value /// Compute a collective operation across all active threads in th subgroup SubgroupCollectiveOperation { /// What operation to compute From a277ec55ecbb38458c8cd8a4cbb0c94642eee183 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Thu, 5 Oct 2023 11:46:55 -0400 Subject: [PATCH 09/46] subgroup: wgsl-in and spv-out for subgroup operations --- naga/src/back/spv/block.rs | 4 +- naga/src/back/spv/instructions.rs | 40 ++++++++++- naga/src/back/spv/subgroup.rs | 84 +++++++++++++++++++---- naga/src/front/wgsl/lower/mod.rs | 109 ++++++++++++++++++++---------- naga/src/front/wgsl/parse/conv.rs | 21 ++++++ naga/src/valid/function.rs | 12 ++-- 6 files changed, 211 insertions(+), 59 deletions(-) diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index fa799b12c4..b2366f3b16 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -2368,14 +2368,14 @@ impl<'w> BlockContext<'w> { argument, result, } => { - self.write_subgroup_operation(op, collective_op, argument, result, &mut block); + self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?; } crate::Statement::SubgroupBroadcast { ref mode, argument, result, } => { - unimplemented!() // FIXME + self.write_subgroup_broadcast(mode, argument, result, &mut block)?; } } } diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 725014dee4..8a528065ef 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1054,19 +1054,55 @@ impl super::Instruction { instruction } + pub(super) fn group_non_uniform_broadcast( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + index: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBroadcast); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + instruction.add_operand(index); + + instruction + } + pub(super) fn group_non_uniform_broadcast_first( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + + instruction + } pub(super) fn group_non_uniform_arithmetic( op: Op, result_type_id: Word, id: Word, exec_scope_id: Word, - group_op: spirv::GroupOperation, + group_op: Option, value: Word, ) -> Self { + println!( + "{:?}", + (op, result_type_id, id, exec_scope_id, group_op, value) + ); let mut instruction = Self::new(op); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); - instruction.add_operand(group_op as u32); + if let Some(group_op) = group_op { + instruction.add_operand(group_op as u32); + } instruction.add_operand(value); instruction diff --git a/naga/src/back/spv/subgroup.rs b/naga/src/back/spv/subgroup.rs index a4f7b018f4..3c8b7d827a 100644 --- a/naga/src/back/spv/subgroup.rs +++ b/naga/src/back/spv/subgroup.rs @@ -10,20 +10,30 @@ impl<'w> BlockContext<'w> { result: Handle, block: &mut Block, ) -> Result<(), Error> { - self.writer.require_any( - "GroupNonUniformArithmetic", - &[ - spirv::Capability::GroupNonUniformArithmetic, - spirv::Capability::GroupNonUniformClustered, - spirv::Capability::GroupNonUniformPartitionedNV, - ], - )?; + use crate::SubgroupOperation as sg; + match op { + sg::All | sg::Any => { + self.writer.require_any( + "GroupNonUniformVote", + &[spirv::Capability::GroupNonUniformVote], + )?; + } + _ => { + self.writer.require_any( + "GroupNonUniformArithmetic", + &[ + spirv::Capability::GroupNonUniformArithmetic, + spirv::Capability::GroupNonUniformClustered, + spirv::Capability::GroupNonUniformPartitionedNV, + ], + )?; + } + } let id = self.gen_id(); let result_ty = &self.fun_info[result].ty; let result_type_id = self.get_expression_type_id(result_ty); let result_ty_inner = result_ty.inner_with(&self.ir_module.types); - let kind = result_ty_inner.scalar_kind().unwrap(); let (is_scalar, kind) = match result_ty_inner { TypeInner::Scalar { kind, .. } => (true, kind), @@ -32,7 +42,6 @@ impl<'w> BlockContext<'w> { }; use crate::ScalarKind as sk; - use crate::SubgroupOperation as sg; let spirv_op = match (kind, op) { (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll, (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny, @@ -62,10 +71,13 @@ impl<'w> BlockContext<'w> { let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); use crate::CollectiveOperation as c; - let group_op = match collective_op { - c::Reduce => spirv::GroupOperation::Reduce, - c::InclusiveScan => spirv::GroupOperation::InclusiveScan, - c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, + let group_op = match op { + sg::All | sg::Any => None, + _ => Some(match collective_op { + c::Reduce => spirv::GroupOperation::Reduce, + c::InclusiveScan => spirv::GroupOperation::InclusiveScan, + c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, + }), }; let arg_id = self.cached[argument]; @@ -80,4 +92,48 @@ impl<'w> BlockContext<'w> { self.cached[result] = id; Ok(()) } + pub(super) fn write_subgroup_broadcast( + &mut self, + mode: &crate::BroadcastMode, + argument: Handle, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + let arg_id = self.cached[argument]; + match mode { + crate::BroadcastMode::Index(index) => { + let index_id = self.cached[*index]; + block.body.push(Instruction::group_non_uniform_broadcast( + result_type_id, + id, + exec_scope_id, + arg_id, + index_id, + )); + } + crate::BroadcastMode::First => { + block + .body + .push(Instruction::group_non_uniform_broadcast_first( + result_type_id, + id, + exec_scope_id, + arg_id, + )); + } + } + self.cached[result] = id; + Ok(()) + } } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 1112f7fcb6..3d2d921a01 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1866,6 +1866,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } } else if let Some(fun) = Texture::map(function.name) { self.texture_sample_helper(fun, arguments, span, ctx)? + } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { + return Ok(Some(self.subgroup_helper(span, op, cop, arguments, ctx)?)); } else { match function.name { "select" => { @@ -2260,43 +2262,51 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Ok(Some(result)); } "subgroupBroadcast" => { - unimplemented!(); // FIXME + let mut args = ctx.prepare_args(arguments, 2, span); + + let index = self.expression(args.next()?, ctx)?; + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = ctx.interrupt_emitter( + crate::Expression::SubgroupOperationResult { ty }, + span, + )?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupBroadcast { + mode: crate::BroadcastMode::Index(index), + argument, + result, + }, + span, + ); + return Ok(Some(result)); } "subgroupBroadcastFirst" => { - unimplemented!(); // FIXME - } - "subgroupAll" => { - unimplemented!(); // FIXME - } - "subgroupAny" => { - unimplemented!(); // FIXME - } - "subgroupAdd" => { - unimplemented!(); // FIXME - } - "subgroupMul" => { - unimplemented!(); // FIXME - } - "subgroupMin" => { - unimplemented!(); // FIXME - } - "subgroupMax" => { - unimplemented!(); // FIXME - } - "subgroupAnd" => { - unimplemented!(); // FIXME - } - "subgroupOr" => { - unimplemented!(); // FIXME - } - "subgroupXor" => { - unimplemented!(); // FIXME - } - "subgroupPrefixAdd" => { - unimplemented!(); // FIXME - } - "subgroupPrefixMul" => { - unimplemented!(); // FIXME + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = ctx.interrupt_emitter( + crate::Expression::SubgroupOperationResult { ty }, + span, + )?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupBroadcast { + mode: crate::BroadcastMode::First, + argument, + result, + }, + span, + ); + return Ok(Some(result)); } _ => return Err(Error::UnknownIdent(function.span, function.name)), } @@ -2488,6 +2498,35 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { depth_ref, }) } + fn subgroup_helper( + &mut self, + span: Span, + op: crate::SubgroupOperation, + collective_op: crate::CollectiveOperation, + arguments: &[Handle>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupCollectiveOperation { + op, + collective_op, + argument, + result, + }, + span, + ); + Ok(result) + } fn r#struct( &mut self, diff --git a/naga/src/front/wgsl/parse/conv.rs b/naga/src/front/wgsl/parse/conv.rs index a27bdb1cbc..160213e4a3 100644 --- a/naga/src/front/wgsl/parse/conv.rs +++ b/naga/src/front/wgsl/parse/conv.rs @@ -238,3 +238,24 @@ pub fn map_conservative_depth( _ => Err(Error::UnknownConservativeDepth(span)), } } + +pub fn map_subgroup_operation( + word: &str, +) -> Option<(crate::SubgroupOperation, crate::CollectiveOperation)> { + use crate::CollectiveOperation as co; + use crate::SubgroupOperation as sg; + Some(match word { + "subgroupAll" => (sg::All, co::Reduce), + "subgroupAny" => (sg::Any, co::Reduce), + "subgroupAdd" => (sg::Add, co::Reduce), + "subgroupMul" => (sg::Mul, co::Reduce), + "subgroupMin" => (sg::Min, co::Reduce), + "subgroupMax" => (sg::Max, co::Reduce), + "subgroupAnd" => (sg::And, co::Reduce), + "subgroupOr" => (sg::Or, co::Reduce), + "subgroupXor" => (sg::Xor, co::Reduce), + "subgroupPrefixAdd" => (sg::Add, co::InclusiveScan), + "subgroupPrefixMul" => (sg::Mul, co::InclusiveScan), + _ => return None, + }) +} diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 090fffd4c6..171e6a747b 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -433,7 +433,6 @@ impl super::Validator { result: Handle, context: &BlockContext, ) -> Result<(), WithSpan> { - self.emit_expression(argument, context)?; let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; let (is_scalar, kind) = match argument_inner { @@ -480,7 +479,6 @@ impl super::Validator { context: &BlockContext, ) -> Result<(), WithSpan> { if let crate::BroadcastMode::Index(expr) = *mode { - self.emit_expression(expr, context)?; let index_ty = context.resolve_type(expr, &self.valid_expression_set)?; match index_ty { crate::TypeInner::Scalar { @@ -488,20 +486,22 @@ impl super::Validator { .. } => {} _ => { - log::error!("Subgroup broadcast index type {:?}", index_ty); + log::error!( + "Subgroup broadcast index type {:?}, expected unsigned int", + index_ty + ); return Err(SubgroupError::InvalidOperand(argument) - .with_span_handle(argument, context.expressions) + .with_span_handle(expr, context.expressions) .into_other()); } } } - self.emit_expression(argument, context)?; let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; match argument_inner { crate::TypeInner::Scalar { .. } | crate::TypeInner::Vector { .. } => {} _ => { - log::error!("Subgroup operand type {:?}", argument_inner); + log::error!("Subgroup broadcast operand type {:?}", argument_inner); return Err(SubgroupError::InvalidOperand(argument) .with_span_handle(argument, context.expressions) .into_other()); From dd49f99e9d486cf9fda5026c4312ffb3f3a47eb9 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Thu, 5 Oct 2023 13:06:47 -0400 Subject: [PATCH 10/46] subgroup: add optional predicate for subgroupBallot --- naga/src/back/dot/mod.rs | 5 ++++- naga/src/back/glsl/mod.rs | 9 +++++++-- naga/src/back/hlsl/writer.rs | 9 +++++++-- naga/src/back/msl/writer.rs | 11 +++++++++-- naga/src/back/spv/block.rs | 7 +++++-- naga/src/back/wgsl/writer.rs | 8 ++++++-- naga/src/compact/statements.rs | 15 +++++++++++++-- naga/src/front/wgsl/lower/mod.rs | 10 ++++++++-- naga/src/lib.rs | 2 ++ naga/src/valid/analyzer.rs | 12 ++++++++++-- naga/src/valid/function.rs | 21 ++++++++++++++++++++- naga/src/valid/handles.rs | 3 ++- 12 files changed, 93 insertions(+), 19 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 5ba3ffe49b..fec2c60d32 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -279,7 +279,10 @@ impl StatementGraph { crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } - S::SubgroupBallot { result } => { + S::SubgroupBallot { result, predicate } => { + if let Some(predicate) = predicate { + self.dependencies.push((id, predicate, "predicate")); + } self.emits.push((id, result)); "SubgroupBallot" } diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index fc23a5f891..948fc3645b 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2294,7 +2294,7 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, ");")?; } Statement::RayQuery { .. } => unreachable!(), - Statement::SubgroupBallot { result } => { + Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); let res_ty = ctx.info[result].ty.inner_with(&self.module.types); @@ -2302,7 +2302,12 @@ impl<'a, W: Write> Writer<'a, W> { write!(self.out, " {res_name} = ")?; self.named_expressions.insert(result, res_name); - writeln!(self.out, "subgroupBallot(true);")?; + write!(self.out, "subgroupBallot(")?; + match predicate { + Some(predicate) => self.write_expr(predicate, ctx)?, + None => write!(self.out, "true")?, + } + write!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { ref op, diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index b14d0c4bde..98b60f4954 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2004,14 +2004,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}}}")? } Statement::RayQuery { .. } => unreachable!(), - Statement::SubgroupBallot { result } => { + Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = format!("{}{}", back::BAKE_PREFIX, result.index()); write!(self.out, "const uint4 {name} = ")?; self.named_expressions.insert(result, name); - writeln!(self.out, "WaveActiveBallot(true);")?; + write!(self.out, "WaveActiveBallot(")?; + match predicate { + Some(predicate) => self.write_expr(module, predicate, func_ctx)?, + None => write!(self.out, "true")?, + } + writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { ref op, diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index cc4baf70ff..30cf02ddc4 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3012,12 +3012,19 @@ impl Writer { } } } - crate::Statement::SubgroupBallot { result } => { + crate::Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = self.namer.call(""); self.start_baking_expression(result, &context.expression, &name)?; self.named_expressions.insert(result, name); - write!(self.out, "{NAMESPACE}::simd_active_threads_mask();")?; + write!(self.out, "{NAMESPACE}::simd_ballot(;")?; + match predicate { + Some(predicate) => { + self.put_expression(predicate, &context.expression, true)? + } + None => write!(self.out, "true")?, + } + writeln!(self.out, ");")?; } crate::Statement::SubgroupCollectiveOperation { ref op, diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index b2366f3b16..222d0cde39 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -2340,7 +2340,7 @@ impl<'w> BlockContext<'w> { crate::Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } - crate::Statement::SubgroupBallot { result } => { + crate::Statement::SubgroupBallot { result, predicate } => { self.writer.require_any( "GroupNonUniformBallot", &[spirv::Capability::GroupNonUniformBallot], @@ -2352,7 +2352,10 @@ impl<'w> BlockContext<'w> { pointer_space: None, })); let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); - let predicate = self.writer.get_constant_scalar(crate::Literal::Bool(true)); + let predicate = match predicate { + Some(predicate) => self.cached[predicate], + None => self.writer.get_constant_scalar(crate::Literal::Bool(true)), + }; let id = self.gen_id(); block.body.push(Instruction::group_non_uniform_ballot( vec4_u32_type_id, diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index b66863f1df..91718d4f4d 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -921,13 +921,17 @@ impl Writer { } } Statement::RayQuery { .. } => unreachable!(), - Statement::SubgroupBallot { result } => { + Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); self.start_named_expr(module, result, func_ctx, &res_name)?; self.named_expressions.insert(result, res_name); - writeln!(self.out, "subgroupBallot();")?; + writeln!(self.out, "subgroupBallot(")?; + if let Some(predicate) = predicate { + self.write_expr(module, predicate, func_ctx)?; + } + writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { ref op, diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 462553b9d6..0e8c0e4e81 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -95,7 +95,10 @@ impl FunctionTracer<'_> { self.trace_expression(query); self.trace_ray_query_function(fun); } - St::SubgroupBallot { result } => { + St::SubgroupBallot { result, predicate } => { + if let Some(predicate) = predicate { + self.trace_expression(predicate); + } self.trace_expression(result); } St::SubgroupCollectiveOperation { @@ -267,7 +270,15 @@ impl FunctionMap { adjust(query); self.adjust_ray_query_function(fun); } - St::SubgroupBallot { ref mut result } => adjust(result), + St::SubgroupBallot { + ref mut result, + ref mut predicate, + } => { + if let Some(ref mut predicate) = predicate { + adjust(predicate); + } + adjust(result); + } St::SubgroupCollectiveOperation { ref mut op, ref mut collective_op, diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 3d2d921a01..d3a353be0c 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -2252,13 +2252,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Ok(Some(handle)); } "subgroupBallot" => { - ctx.prepare_args(arguments, 0, span).finish()?; + let mut args = ctx.prepare_args(arguments, 0, span); + let predicate = if arguments.len() == 1 { + Some(self.expression(args.next()?, ctx)?) + } else { + None + }; + args.finish()?; let result = ctx .interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block - .push(crate::Statement::SubgroupBallot { result }, span); + .push(crate::Statement::SubgroupBallot { result, predicate }, span); return Ok(Some(result)); } "subgroupBroadcast" => { diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 171801fa54..b50b1c1f23 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -1895,6 +1895,8 @@ pub enum Statement { /// /// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult result: Handle, + /// The value from this thread to store in the ballot + predicate: Option>, }, SubgroupBroadcast { diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index e9ca4ee7c7..262e3ffec7 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -991,14 +991,22 @@ impl FunctionInfo { } FunctionUniformity::new() } - S::SubgroupBallot { result: _ } => FunctionUniformity::new(), + S::SubgroupBallot { + result: _, + predicate, + } => { + if let Some(predicate) = predicate { + let _ = self.add_ref(predicate); + } + FunctionUniformity::new() + } S::SubgroupCollectiveOperation { ref op, ref collective_op, argument, result: _, } => { - let _ = self.add_ref(argument); + let _ = self.add_ref(argument); // FIXME FunctionUniformity::new() } S::SubgroupBroadcast { diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 171e6a747b..91127a40ad 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -1026,7 +1026,26 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } - S::SubgroupBallot { result } => { + S::SubgroupBallot { result, predicate } => { + if let Some(predicate) = predicate { + let predicate_inner = + context.resolve_type(predicate, &self.valid_expression_set)?; + match predicate_inner { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Bool, + .. + } => {} + _ => { + log::error!( + "Subgroup ballot predicate type {:?} expected bool", + predicate_inner + ); + return Err(SubgroupError::InvalidOperand(predicate) + .with_span_handle(predicate, context.expressions) + .into_other()); + } + } + } self.emit_expression(result, context)?; } S::SubgroupCollectiveOperation { diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 399665342f..91209b460b 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -541,7 +541,8 @@ impl super::Validator { } Ok(()) } - crate::Statement::SubgroupBallot { result } => { + crate::Statement::SubgroupBallot { result, predicate } => { + validate_expr_opt(predicate)?; validate_expr(result)?; Ok(()) } From 18ceb01f968331897a908270d7512af5794e6343 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Thu, 12 Oct 2023 01:07:12 +0200 Subject: [PATCH 11/46] Renames SubgroupBroadcast => SubgroupGather and BroadcastMode => GatherMode. --- naga/src/back/dot/mod.rs | 4 ++-- naga/src/back/glsl/mod.rs | 2 +- naga/src/back/hlsl/writer.rs | 2 +- naga/src/back/msl/writer.rs | 2 +- naga/src/back/spv/block.rs | 2 +- naga/src/back/spv/subgroup.rs | 6 +++--- naga/src/back/wgsl/writer.rs | 2 +- naga/src/compact/statements.rs | 8 ++++---- naga/src/front/spv/mod.rs | 2 +- naga/src/front/wgsl/lower/mod.rs | 8 ++++---- naga/src/lib.rs | 14 +++++++------- naga/src/proc/terminator.rs | 2 +- naga/src/valid/analyzer.rs | 4 ++-- naga/src/valid/function.rs | 6 +++--- naga/src/valid/handles.rs | 12 +++++++----- 15 files changed, 39 insertions(+), 37 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index fec2c60d32..d08eee631a 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -296,14 +296,14 @@ impl StatementGraph { self.emits.push((id, result)); "SubgroupCollectiveOperation" // FIXME } - S::SubgroupBroadcast { + S::SubgroupGather { ref mode, argument, result, } => { self.dependencies.push((id, argument, "arg")); self.emits.push((id, result)); - "SubgroupBroadcast" // FIXME + "SubgroupGather" // FIXME } }; // Set the last node to the merge node diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 948fc3645b..ae24109eee 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2317,7 +2317,7 @@ impl<'a, W: Write> Writer<'a, W> { } => { unimplemented!(); // FIXME: } - Statement::SubgroupBroadcast { + Statement::SubgroupGather { ref mode, argument, result, diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 98b60f4954..222a0bf02c 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2026,7 +2026,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } => { unimplemented!(); // FIXME } - Statement::SubgroupBroadcast { + Statement::SubgroupGather { ref mode, argument, result, diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 30cf02ddc4..9d1bc84955 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3034,7 +3034,7 @@ impl Writer { } => { unimplemented!(); // FIXME } - crate::Statement::SubgroupBroadcast { + crate::Statement::SubgroupGather { ref mode, argument, result, diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 222d0cde39..9fba489f79 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -2373,7 +2373,7 @@ impl<'w> BlockContext<'w> { } => { self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?; } - crate::Statement::SubgroupBroadcast { + crate::Statement::SubgroupGather { ref mode, argument, result, diff --git a/naga/src/back/spv/subgroup.rs b/naga/src/back/spv/subgroup.rs index 3c8b7d827a..ca193562b7 100644 --- a/naga/src/back/spv/subgroup.rs +++ b/naga/src/back/spv/subgroup.rs @@ -94,7 +94,7 @@ impl<'w> BlockContext<'w> { } pub(super) fn write_subgroup_broadcast( &mut self, - mode: &crate::BroadcastMode, + mode: &crate::GatherMode, argument: Handle, result: Handle, block: &mut Block, @@ -112,7 +112,7 @@ impl<'w> BlockContext<'w> { let arg_id = self.cached[argument]; match mode { - crate::BroadcastMode::Index(index) => { + crate::GatherMode::Broadcast(index) => { let index_id = self.cached[*index]; block.body.push(Instruction::group_non_uniform_broadcast( result_type_id, @@ -122,7 +122,7 @@ impl<'w> BlockContext<'w> { index_id, )); } - crate::BroadcastMode::First => { + crate::GatherMode::BroadcastFirst => { block .body .push(Instruction::group_non_uniform_broadcast_first( diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 91718d4f4d..c6217ea2a8 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -941,7 +941,7 @@ impl Writer { } => { unimplemented!() // FIXME } - Statement::SubgroupBroadcast { + Statement::SubgroupGather { ref mode, argument, result, diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 0e8c0e4e81..14526bef61 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -110,12 +110,12 @@ impl FunctionTracer<'_> { self.trace_expression(argument); self.trace_expression(result); } - St::SubgroupBroadcast { + St::SubgroupGather { ref mode, argument, result, } => { - if let crate::BroadcastMode::Index(expr) = *mode { + if let crate::GatherMode::Broadcast(expr) = *mode { self.trace_expression(expr); } self.trace_expression(argument); @@ -288,12 +288,12 @@ impl FunctionMap { adjust(argument); adjust(result); } - St::SubgroupBroadcast { + St::SubgroupGather { ref mut mode, ref mut argument, ref mut result, } => { - if let crate::BroadcastMode::Index(expr) = mode { + if let crate::GatherMode::Broadcast(expr) = mode { adjust(expr); } adjust(argument); diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 561e32e22d..c78407ef99 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -3844,7 +3844,7 @@ impl> Frontend { S::WorkGroupUniformLoad { .. } => unreachable!(), S::SubgroupBallot { .. } => unreachable!(), // FIXME?? S::SubgroupCollectiveOperation { .. } => unreachable!(), - S::SubgroupBroadcast { .. } => unreachable!(), + S::SubgroupGather { .. } => unreachable!(), } i += 1; } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index d3a353be0c..e97da5db7f 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -2282,8 +2282,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( - crate::Statement::SubgroupBroadcast { - mode: crate::BroadcastMode::Index(index), + crate::Statement::SubgroupGather { + mode: crate::GatherMode::Broadcast(index), argument, result, }, @@ -2305,8 +2305,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( - crate::Statement::SubgroupBroadcast { - mode: crate::BroadcastMode::First, + crate::Statement::SubgroupGather { + mode: crate::GatherMode::BroadcastFirst, argument, result, }, diff --git a/naga/src/lib.rs b/naga/src/lib.rs index b50b1c1f23..cf16e6786c 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -436,8 +436,8 @@ pub enum BuiltIn { WorkGroupSize, NumWorkGroups, // subgroup - SubgroupInvocationId, SubgroupSize, + SubgroupInvocationId, } /// Number of bytes per scalar. @@ -1265,9 +1265,9 @@ pub enum SwizzleComponent { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum BroadcastMode { - First, - Index(Handle), +pub enum GatherMode { + BroadcastFirst, + Broadcast(Handle), } #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] @@ -1899,9 +1899,9 @@ pub enum Statement { predicate: Option>, }, - SubgroupBroadcast { - /// Specifies which thread to broadcast from - mode: BroadcastMode, + SubgroupGather { + /// Specifies which thread to gather from + mode: GatherMode, /// The value to broadcast over argument: Handle, /// The [`SubgroupOperationResult`] expression representing this load's result. diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index 35111a11de..5edf55cb73 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -39,7 +39,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::WorkGroupUniformLoad { .. } | S::SubgroupBallot { .. } | S::SubgroupCollectiveOperation { .. } - | S::SubgroupBroadcast { .. } + | S::SubgroupGather { .. } | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 262e3ffec7..deb775aebf 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1009,13 +1009,13 @@ impl FunctionInfo { let _ = self.add_ref(argument); // FIXME FunctionUniformity::new() } - S::SubgroupBroadcast { + S::SubgroupGather { ref mode, argument, result, } => { let _ = self.add_ref(argument); - if let crate::BroadcastMode::Index(expr) = *mode { + if let crate::GatherMode::Broadcast(expr) = *mode { let _ = self.add_ref(expr); } FunctionUniformity::new() diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 91127a40ad..03e67f5f18 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -473,12 +473,12 @@ impl super::Validator { #[cfg(feature = "validate")] fn validate_subgroup_broadcast( &mut self, - mode: &crate::BroadcastMode, + mode: &crate::GatherMode, argument: Handle, result: Handle, context: &BlockContext, ) -> Result<(), WithSpan> { - if let crate::BroadcastMode::Index(expr) = *mode { + if let crate::GatherMode::Broadcast(expr) = *mode { let index_ty = context.resolve_type(expr, &self.valid_expression_set)?; match index_ty { crate::TypeInner::Scalar { @@ -1056,7 +1056,7 @@ impl super::Validator { } => { self.validate_subgroup_operation(op, collective_op, argument, result, context)?; } - S::SubgroupBroadcast { + S::SubgroupGather { ref mode, argument, result, diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 91209b460b..6d9923a23b 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -547,8 +547,8 @@ impl super::Validator { Ok(()) } crate::Statement::SubgroupCollectiveOperation { - op, - collective_op, + op: _, + collective_op: _, argument, result, } => { @@ -556,13 +556,15 @@ impl super::Validator { validate_expr(result)?; Ok(()) } - crate::Statement::SubgroupBroadcast { + crate::Statement::SubgroupGather { mode, argument, result, } => { - if let crate::BroadcastMode::Index(expr) = mode { - validate_expr(expr)?; + validate_expr(argument)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) => validate_expr(index)?, } validate_expr(result)?; Ok(()) From 7c911ba3bf96f10dc8396a23aa8d31a27093dac4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 14 Oct 2023 18:35:27 +0200 Subject: [PATCH 12/46] General fixes. --- naga/src/back/dot/mod.rs | 6 +-- naga/src/back/glsl/mod.rs | 6 +-- naga/src/back/hlsl/writer.rs | 6 +-- naga/src/back/msl/writer.rs | 6 +-- naga/src/back/wgsl/writer.rs | 6 +-- naga/src/compact/statements.rs | 22 ++++---- naga/src/proc/constant_evaluator.rs | 11 ++-- naga/src/valid/analyzer.rs | 14 ++--- naga/src/valid/expression.rs | 2 +- naga/src/valid/function.rs | 81 ++++++++++++++++------------- 10 files changed, 86 insertions(+), 74 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index d08eee631a..3174b7c6b6 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -287,8 +287,8 @@ impl StatementGraph { "SubgroupBallot" } S::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { @@ -297,7 +297,7 @@ impl StatementGraph { "SubgroupCollectiveOperation" // FIXME } S::SubgroupGather { - ref mode, + mode, argument, result, } => { diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index ae24109eee..c8fe4f2fa0 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2310,15 +2310,15 @@ impl<'a, W: Write> Writer<'a, W> { write!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { unimplemented!(); // FIXME: } Statement::SubgroupGather { - ref mode, + mode, argument, result, } => { diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 222a0bf02c..a04fef7a1b 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2019,15 +2019,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { unimplemented!(); // FIXME } Statement::SubgroupGather { - ref mode, + mode, argument, result, } => { diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 9d1bc84955..629b98c96d 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3027,15 +3027,15 @@ impl Writer { writeln!(self.out, ");")?; } crate::Statement::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { unimplemented!(); // FIXME } crate::Statement::SubgroupGather { - ref mode, + mode, argument, result, } => { diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index c6217ea2a8..a2b2497c0c 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -934,15 +934,15 @@ impl Writer { writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { unimplemented!() // FIXME } Statement::SubgroupGather { - ref mode, + mode, argument, result, } => { diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 14526bef61..04a184daf8 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -102,8 +102,8 @@ impl FunctionTracer<'_> { self.trace_expression(result); } St::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op: _, + collective_op: _, argument, result, } => { @@ -111,12 +111,13 @@ impl FunctionTracer<'_> { self.trace_expression(result); } St::SubgroupGather { - ref mode, + mode, argument, result, } => { - if let crate::GatherMode::Broadcast(expr) = *mode { - self.trace_expression(expr); + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) => self.trace_expression(index), } self.trace_expression(argument); self.trace_expression(result); @@ -274,14 +275,14 @@ impl FunctionMap { ref mut result, ref mut predicate, } => { - if let Some(ref mut predicate) = predicate { + if let Some(ref mut predicate) = *predicate { adjust(predicate); } adjust(result); } St::SubgroupCollectiveOperation { - ref mut op, - ref mut collective_op, + op: _, + collective_op: _, ref mut argument, ref mut result, } => { @@ -293,8 +294,9 @@ impl FunctionMap { ref mut argument, ref mut result, } => { - if let crate::GatherMode::Broadcast(expr) = mode { - adjust(expr); + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(ref mut index) => adjust(index), } adjust(argument); adjust(result); diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index d13a66ec2f..c945805813 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -133,8 +133,8 @@ pub enum ConstantEvaluatorError { ImageExpression, #[error("Constants don't support ray query expressions")] RayQueryExpression, - #[error("Constants don't support subgroup operations")] - SubgroupOperation, + #[error("Constants don't support subgroup expressions")] + SubgroupExpression, #[error("Cannot access the type")] InvalidAccessBase, #[error("Cannot access at the index")] @@ -441,8 +441,11 @@ impl<'a> ConstantEvaluator<'a> { Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => { Err(ConstantEvaluatorError::RayQueryExpression) } - Expression::SubgroupBallotResult | Expression::SubgroupOperationResult { .. } => { - Err(ConstantEvaluatorError::SubgroupOperation) + Expression::SubgroupBallotResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } + Expression::SubgroupOperationResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) } } } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index deb775aebf..4095f03e41 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -744,7 +744,7 @@ impl FunctionInfo { non_uniform_result: None, // FIXME requirements: UniformityRequirements::empty(), }, - E::SubgroupOperationResult { ty } => Uniformity { + E::SubgroupOperationResult { .. } => Uniformity { non_uniform_result: None, // FIXME requirements: UniformityRequirements::empty(), }, @@ -1001,21 +1001,21 @@ impl FunctionInfo { FunctionUniformity::new() } S::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op: _, + collective_op: _, argument, result: _, } => { - let _ = self.add_ref(argument); // FIXME + let _ = self.add_ref(argument); FunctionUniformity::new() } S::SubgroupGather { - ref mode, + mode, argument, - result, + result: _, } => { let _ = self.add_ref(argument); - if let crate::GatherMode::Broadcast(expr) = *mode { + if let crate::GatherMode::Broadcast(expr) = mode { let _ = self.add_ref(expr); } FunctionUniformity::new() diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index f840bba42d..03ad851dbf 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -1538,7 +1538,7 @@ impl super::Validator { } }, E::SubgroupBallotResult => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, - E::SubgroupOperationResult { ty } => ShaderStages::COMPUTE, // FIXME + E::SubgroupOperationResult { .. } => ShaderStages::COMPUTE, // FIXME }; Ok(stages) } diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 03e67f5f18..e88b763258 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -435,15 +435,20 @@ impl super::Validator { ) -> Result<(), WithSpan> { let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; - let (is_scalar, kind) = match argument_inner { + let (is_scalar, kind) = match *argument_inner { crate::TypeInner::Scalar { kind, .. } => (true, kind), crate::TypeInner::Vector { kind, .. } => (false, kind), - _ => unimplemented!(), + _ => { + log::error!("Subgroup operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); + } }; use crate::ScalarKind as sk; use crate::SubgroupOperation as sg; - match (kind, op) { + match (kind, *op) { (sk::Bool, sg::All | sg::Any) if is_scalar => {} (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {} (sk::Sint | sk::Uint | sk::Bool, sg::And | sg::Or | sg::Xor) => {} @@ -478,34 +483,36 @@ impl super::Validator { result: Handle, context: &BlockContext, ) -> Result<(), WithSpan> { - if let crate::GatherMode::Broadcast(expr) = *mode { - let index_ty = context.resolve_type(expr, &self.valid_expression_set)?; - match index_ty { - crate::TypeInner::Scalar { - kind: crate::ScalarKind::Uint, - .. - } => {} - _ => { - log::error!( - "Subgroup broadcast index type {:?}, expected unsigned int", - index_ty - ); - return Err(SubgroupError::InvalidOperand(argument) - .with_span_handle(expr, context.expressions) - .into_other()); + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) => { + let index_ty = context.resolve_type(index, &self.valid_expression_set)?; + match *index_ty { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + .. + } => {} + _ => { + log::error!( + "Subgroup gather index type {:?}, expected unsigned int", + index_ty + ); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(index, context.expressions) + .into_other()); + } } } } let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; - - match argument_inner { - crate::TypeInner::Scalar { .. } | crate::TypeInner::Vector { .. } => {} - _ => { - log::error!("Subgroup broadcast operand type {:?}", argument_inner); - return Err(SubgroupError::InvalidOperand(argument) - .with_span_handle(argument, context.expressions) - .into_other()); - } + if !matches!(*argument_inner, + crate::TypeInner::Scalar { kind, .. } | crate::TypeInner::Vector { kind, .. } + if matches!(kind, crate::ScalarKind::Uint | crate::ScalarKind::Sint | crate::ScalarKind::Float) + ) { + log::error!("Subgroup gather operand type {:?}", argument_inner); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(argument, context.expressions) + .into_other()); } self.emit_expression(result, context)?; @@ -1030,20 +1037,20 @@ impl super::Validator { if let Some(predicate) = predicate { let predicate_inner = context.resolve_type(predicate, &self.valid_expression_set)?; - match predicate_inner { + if !matches!( + *predicate_inner, crate::TypeInner::Scalar { kind: crate::ScalarKind::Bool, .. - } => {} - _ => { - log::error!( - "Subgroup ballot predicate type {:?} expected bool", - predicate_inner - ); - return Err(SubgroupError::InvalidOperand(predicate) - .with_span_handle(predicate, context.expressions) - .into_other()); } + ) { + log::error!( + "Subgroup ballot predicate type {:?} expected bool", + predicate_inner + ); + return Err(SubgroupError::InvalidOperand(predicate) + .with_span_handle(predicate, context.expressions) + .into_other()); } } self.emit_expression(result, context)?; From f5c4ad7ab3cd8fc48e458422ae6632499c67f58c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 14 Oct 2023 18:55:14 +0200 Subject: [PATCH 13/46] Adds BuiltIn::NumSubgroups, BuiltIn::SubgroupId. --- naga/src/back/glsl/mod.rs | 1 + naga/src/back/hlsl/conv.rs | 1 + naga/src/back/msl/mod.rs | 1 + naga/src/back/spv/writer.rs | 1 + naga/src/back/wgsl/writer.rs | 1 + naga/src/lib.rs | 2 ++ naga/src/valid/interface.rs | 4 ++-- 7 files changed, 9 insertions(+), 2 deletions(-) diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index c8fe4f2fa0..abef3f9d94 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -4237,6 +4237,7 @@ const fn glsl_built_in( Bi::WorkGroupSize => "gl_WorkGroupSize", Bi::NumWorkGroups => "gl_NumWorkGroups", // subgroup + Bi::NumSubgroups | Bi::SubgroupId => todo!(), Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", Bi::SubgroupSize => "gl_SubgroupSize", } diff --git a/naga/src/back/hlsl/conv.rs b/naga/src/back/hlsl/conv.rs index 19c4da5e74..3f51278cc1 100644 --- a/naga/src/back/hlsl/conv.rs +++ b/naga/src/back/hlsl/conv.rs @@ -167,6 +167,7 @@ impl crate::BuiltIn { // in `Writer::write_expr`. Self::NumWorkGroups => "SV_GroupID", + Self::NumSubgroups | Self::SubgroupId => todo!(), Self::SubgroupInvocationId | Self::SubgroupSize | Self::BaseInstance diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 9e23d2a08d..4e0d8489e4 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -438,6 +438,7 @@ impl ResolvedBinding { Bi::WorkGroupSize => "dispatch_threads_per_threadgroup", Bi::NumWorkGroups => "threadgroups_per_grid", // subgroup + Bi::NumSubgroups | Bi::SubgroupId => todo!(), Bi::SubgroupInvocationId => "simdgroup_index_in_threadgroup", Bi::SubgroupSize => "simdgroups_per_threadgroup", Bi::CullDistance | Bi::ViewIndex => { diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 077751d90b..6a55288308 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1594,6 +1594,7 @@ impl Writer { Bi::WorkGroupSize => BuiltIn::WorkgroupSize, Bi::NumWorkGroups => BuiltIn::NumWorkgroups, // Subgroup + Bi::NumSubgroups | Bi::SubgroupId => todo!(), Bi::SubgroupInvocationId => { self.require_any( "`subgroup_invocation_id` built-in", diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index a2b2497c0c..8fc9881f10 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1789,6 +1789,7 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { Bi::SampleMask => "sample_mask", Bi::PrimitiveIndex => "primitive_index", Bi::ViewIndex => "view_index", + Bi::NumSubgroups | Bi::SubgroupId => todo!(), Bi::SubgroupInvocationId => "subgroup_invocation_id", Bi::SubgroupSize => "subgroup_size", Bi::BaseInstance diff --git a/naga/src/lib.rs b/naga/src/lib.rs index cf16e6786c..23fc871325 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -436,6 +436,8 @@ pub enum BuiltIn { WorkGroupSize, NumWorkGroups, // subgroup + NumSubgroups, + SubgroupId, SubgroupSize, SubgroupInvocationId, } diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index bf4f397224..4b5f66492c 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -299,7 +299,7 @@ impl VaryingContext<'_> { width, }, ), - Bi::SubgroupInvocationId => ( + Bi::NumSubgroups | Bi::SubgroupId => ( self.stage == St::Compute && !self.output, *ty_inner == Ti::Scalar { @@ -307,7 +307,7 @@ impl VaryingContext<'_> { width, }, ), - Bi::SubgroupSize => ( + Bi::SubgroupSize | Bi::SubgroupInvocationId => ( match self.stage { St::Compute | St::Fragment => !self.output, St::Vertex => false, From f931a4716f0b7cf6113c90c21ccbdc3a1d2147e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 14 Oct 2023 18:55:14 +0200 Subject: [PATCH 14/46] Adds GatherMode::Shuffle, GatherMode::ShuffleDown, GatherMode::ShuffleUp, GatherMode::ShuffleXor. --- naga/src/back/spv/subgroup.rs | 4 ++++ naga/src/compact/statements.rs | 12 ++++++++++-- naga/src/lib.rs | 4 ++++ naga/src/valid/analyzer.rs | 11 +++++++++-- naga/src/valid/function.rs | 6 +++++- naga/src/valid/handles.rs | 6 +++++- 6 files changed, 37 insertions(+), 6 deletions(-) diff --git a/naga/src/back/spv/subgroup.rs b/naga/src/back/spv/subgroup.rs index ca193562b7..7206ec9312 100644 --- a/naga/src/back/spv/subgroup.rs +++ b/naga/src/back/spv/subgroup.rs @@ -132,6 +132,10 @@ impl<'w> BlockContext<'w> { arg_id, )); } + crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => todo!(), } self.cached[result] = id; Ok(()) diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 04a184daf8..37074c4299 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -117,7 +117,11 @@ impl FunctionTracer<'_> { } => { match mode { crate::GatherMode::BroadcastFirst => {} - crate::GatherMode::Broadcast(index) => self.trace_expression(index), + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => self.trace_expression(index), } self.trace_expression(argument); self.trace_expression(result); @@ -296,7 +300,11 @@ impl FunctionMap { } => { match *mode { crate::GatherMode::BroadcastFirst => {} - crate::GatherMode::Broadcast(ref mut index) => adjust(index), + crate::GatherMode::Broadcast(ref mut index) + | crate::GatherMode::Shuffle(ref mut index) + | crate::GatherMode::ShuffleDown(ref mut index) + | crate::GatherMode::ShuffleUp(ref mut index) + | crate::GatherMode::ShuffleXor(ref mut index) => adjust(index), } adjust(argument); adjust(result); diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 23fc871325..08f535f565 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -1270,6 +1270,10 @@ pub enum SwizzleComponent { pub enum GatherMode { BroadcastFirst, Broadcast(Handle), + Shuffle(Handle), + ShuffleDown(Handle), + ShuffleUp(Handle), + ShuffleXor(Handle), } #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 4095f03e41..f4347df1dd 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1015,8 +1015,15 @@ impl FunctionInfo { result: _, } => { let _ = self.add_ref(argument); - if let crate::GatherMode::Broadcast(expr) = mode { - let _ = self.add_ref(expr); + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let _ = self.add_ref(index); + } } FunctionUniformity::new() } diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index e88b763258..729a6405c2 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -485,7 +485,11 @@ impl super::Validator { ) -> Result<(), WithSpan> { match *mode { crate::GatherMode::BroadcastFirst => {} - crate::GatherMode::Broadcast(index) => { + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { let index_ty = context.resolve_type(index, &self.valid_expression_set)?; match *index_ty { crate::TypeInner::Scalar { diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 6d9923a23b..e1674f0804 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -564,7 +564,11 @@ impl super::Validator { validate_expr(argument)?; match mode { crate::GatherMode::BroadcastFirst => {} - crate::GatherMode::Broadcast(index) => validate_expr(index)?, + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => validate_expr(index)?, } validate_expr(result)?; Ok(()) From d8e17e7998893ef4c437be91bb941c011bbf5026 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 14 Oct 2023 19:35:39 +0200 Subject: [PATCH 15/46] Implements all frontends and backends. --- naga/src/back/dot/mod.rs | 66 ++++++++- naga/src/back/glsl/features.rs | 23 ++++ naga/src/back/glsl/mod.rs | 102 +++++++++++++- naga/src/back/hlsl/conv.rs | 14 +- naga/src/back/hlsl/writer.rs | 177 +++++++++++++++++++++++-- naga/src/back/msl/mod.rs | 7 +- naga/src/back/msl/writer.rs | 109 +++++++++++++-- naga/src/back/spv/block.rs | 31 +---- naga/src/back/spv/instructions.rs | 17 +-- naga/src/back/spv/subgroup.rs | 111 ++++++++++++---- naga/src/back/spv/writer.rs | 26 +++- naga/src/back/wgsl/writer.rs | 98 +++++++++++++- naga/src/front/spv/error.rs | 6 +- naga/src/front/spv/mod.rs | 213 +++++++++++++++++++++++++++++- naga/src/front/wgsl/lower/mod.rs | 98 +++++++------- naga/src/front/wgsl/parse/conv.rs | 25 +++- 16 files changed, 959 insertions(+), 164 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 3174b7c6b6..86f4797b56 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -294,16 +294,78 @@ impl StatementGraph { } => { self.dependencies.push((id, argument, "arg")); self.emits.push((id, result)); - "SubgroupCollectiveOperation" // FIXME + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + "SubgroupAll" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + "SubgroupAny" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + "SubgroupAdd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + "SubgroupMul" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + "SubgroupMax" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + "SubgroupMin" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + "SubgroupAnd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + "SubgroupOr" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + "SubgroupXor" + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupPrefixExclusiveAdd", + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupPrefixExclusiveMul", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupPrefixInclusiveAdd", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupPrefixInclusiveMul", + _ => unimplemented!(), + } } S::SubgroupGather { mode, argument, result, } => { + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + self.dependencies.push((id, index, "index")) + } + } self.dependencies.push((id, argument, "arg")); self.emits.push((id, result)); - "SubgroupGather" // FIXME + match mode { + crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst", + crate::GatherMode::Broadcast(_) => "SubgroupBroadcast", + crate::GatherMode::Shuffle(_) => "SubgroupShuffle", + crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown", + crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp", + crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor", + } } }; // Set the last node to the merge node diff --git a/naga/src/back/glsl/features.rs b/naga/src/back/glsl/features.rs index d6f3aae35d..39beda997e 100644 --- a/naga/src/back/glsl/features.rs +++ b/naga/src/back/glsl/features.rs @@ -43,6 +43,8 @@ bitflags::bitflags! { const IMAGE_SIZE = 1 << 20; /// Dual source blending const DUAL_SOURCE_BLENDING = 1 << 21; + /// Subgroup operations + const SUBGROUP_OPERATIONS = 1 << 22; } } @@ -105,6 +107,7 @@ impl FeaturesManager { check_feature!(SAMPLE_VARIABLES, 400, 300); check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310); check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */); + check_feature!(SUBGROUP_OPERATIONS, 430, 310); match version { Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300), _ => check_feature!(MULTI_VIEW, 140, 310), @@ -234,6 +237,22 @@ impl FeaturesManager { writeln!(out, "#extension GL_EXT_blend_func_extended : require")?; } + if self.0.contains(Features::SUBGROUP_OPERATIONS) { + // https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt + writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?; + writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?; + writeln!( + out, + "#extension GL_KHR_shader_subgroup_arithmetic : require" + )?; + writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?; + writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?; + writeln!( + out, + "#extension GL_KHR_shader_subgroup_shuffle_relative : require" + )?; + } + Ok(()) } } @@ -454,6 +473,10 @@ impl<'a, W> Writer<'a, W> { } } } + Expression::SubgroupBallotResult | + Expression::SubgroupOperationResult { .. } => { + features.request(Features::SUBGROUP_OPERATIONS) + } _ => {} } } diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index abef3f9d94..7ee460896f 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2307,7 +2307,7 @@ impl<'a, W: Write> Writer<'a, W> { Some(predicate) => self.write_expr(predicate, ctx)?, None => write!(self.out, "true")?, } - write!(self.out, ");")?; + writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { op, @@ -2315,14 +2315,103 @@ impl<'a, W: Write> Writer<'a, W> { argument, result, } => { - unimplemented!(); // FIXME: + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupInclusiveMul(")? + } + _ => unimplemented!(), + } + self.write_expr(argument, ctx)?; + writeln!(self.out, ");")?; } Statement::SubgroupGather { mode, argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(argument, ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(index, ctx)?; + } + } + writeln!(self.out, ");")?; } } @@ -4057,7 +4146,7 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, "{level}memoryBarrierShared();")?; } if flags.contains(crate::Barrier::SUB_GROUP) { - unimplemented!() // FIXME + writeln!(self.out, "{level}subgroupMemoryBarrier();")?; } writeln!(self.out, "{level}barrier();")?; Ok(()) @@ -4237,9 +4326,10 @@ const fn glsl_built_in( Bi::WorkGroupSize => "gl_WorkGroupSize", Bi::NumWorkGroups => "gl_NumWorkGroups", // subgroup - Bi::NumSubgroups | Bi::SubgroupId => todo!(), - Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", + Bi::NumSubgroups => "gl_NumSubgroups", + Bi::SubgroupId => "gl_SubgroupID", Bi::SubgroupSize => "gl_SubgroupSize", + Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", } } diff --git a/naga/src/back/hlsl/conv.rs b/naga/src/back/hlsl/conv.rs index 3f51278cc1..d3fb76e401 100644 --- a/naga/src/back/hlsl/conv.rs +++ b/naga/src/back/hlsl/conv.rs @@ -166,13 +166,13 @@ impl crate::BuiltIn { // to this field will get replaced with references to `SPECIAL_CBUF_VAR` // in `Writer::write_expr`. Self::NumWorkGroups => "SV_GroupID", - - Self::NumSubgroups | Self::SubgroupId => todo!(), - Self::SubgroupInvocationId - | Self::SubgroupSize - | Self::BaseInstance - | Self::BaseVertex - | Self::WorkGroupSize => return Err(Error::Unimplemented(format!("builtin {self:?}"))), + Self::SubgroupSize + | Self::SubgroupInvocationId + | Self::NumSubgroups + | Self::SubgroupId => return Err(Error::Unimplemented(format!("builtin {self:?}"))), + Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => { + return Err(Error::Unimplemented(format!("builtin {self:?}"))) + } Self::PointSize | Self::ViewIndex | Self::PointCoord => { return Err(Error::Custom(format!("Unsupported builtin {self:?}"))) } diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index a04fef7a1b..1eab43a4c3 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -1130,7 +1130,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, " {name}(")?; let need_workgroup_variables_initialization = - self.need_workgroup_variables_initialization(func_ctx, module); + self.need_workgroup_variables_initialization(func, func_ctx, module); // Write function arguments for non entry point functions match func_ctx.ty { @@ -1166,7 +1166,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?; } else { let stage = module.entry_points[ep_index as usize].stage; + let mut arg_num = 0; for (index, arg) in func.arguments.iter().enumerate() { + if matches!( + arg.binding, + Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) + | Some(crate::Binding::BuiltIn( + crate::BuiltIn::SubgroupInvocationId + )) + | Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) + | Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) + ) { + continue; + } + arg_num += 1; + if index != 0 { write!(self.out, ", ")?; } @@ -1186,7 +1200,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } if need_workgroup_variables_initialization { - if !func.arguments.is_empty() { + if arg_num > 0 { write!(self.out, ", ")?; } write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?; @@ -1217,6 +1231,53 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_workgroup_variables_initialization(func_ctx, module)?; } + if let back::FunctionType::EntryPoint(ep_index) = func_ctx.ty { + let ep = &module.entry_points[ep_index as usize]; + for (index, arg) in func.arguments.iter().enumerate() { + if let Some(crate::Binding::BuiltIn(builtin)) = arg.binding { + if matches!( + builtin, + crate::BuiltIn::SubgroupSize + | crate::BuiltIn::SubgroupInvocationId + | crate::BuiltIn::NumSubgroups + | crate::BuiltIn::SubgroupId + ) { + let level = back::Level(1); + write!(self.out, "{level}const ")?; + + self.write_type(module, arg.ty)?; + + let argument_name = + &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)]; + write!(self.out, " {argument_name} = ")?; + + match builtin { + crate::BuiltIn::SubgroupSize => { + writeln!(self.out, "WaveGetLaneCount();")? + } + crate::BuiltIn::SubgroupInvocationId => { + writeln!(self.out, "WaveGetLaneIndex();")? + } + crate::BuiltIn::NumSubgroups => writeln!( + self.out, + "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount();", + ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2] + )?, + crate::BuiltIn::SubgroupId => { + writeln!( + self.out, + "(__local_invocation_id.x * {}u + __local_invocation_id.y * {}u + __local_invocation_id.z) / WaveGetLaneCount();", + ep.workgroup_size[0] * ep.workgroup_size[1], + ep.workgroup_size[1], + )?; + } + _ => unreachable!(), + } + } + } + } + } + if let back::FunctionType::EntryPoint(index) = func_ctx.ty { self.write_ep_arguments_initialization(module, func, index)?; } @@ -1267,14 +1328,20 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { fn need_workgroup_variables_initialization( &mut self, + func: &crate::Function, func_ctx: &back::FunctionCtx, module: &Module, ) -> bool { - self.options.zero_initialize_workgroup_memory + func.arguments.iter().any(|arg| { + matches!( + arg.binding, + Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) + ) + }) || (self.options.zero_initialize_workgroup_memory && func_ctx.ty.is_compute_entry_point(module) && module.global_variables.iter().any(|(handle, var)| { !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup - }) + })) } fn write_workgroup_variables_initialization( @@ -2006,7 +2073,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Statement::RayQuery { .. } => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; - let name = format!("{}{}", back::BAKE_PREFIX, result.index()); write!(self.out, "const uint4 {name} = ")?; self.named_expressions.insert(result, name); @@ -2024,14 +2090,109 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "WaveActiveAllTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "WaveActiveAnyTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "WaveActiveSum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "WaveActiveProduct(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "WaveActiveMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "WaveActiveMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "WaveActiveBitAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "WaveActiveBitOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "WaveActiveBitXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "WavePrefixSum(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "WavePrefixProduct(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " + WavePrefixSum(")?; + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " * WavePrefixProduct(")?; + } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; } Statement::SubgroupGather { mode, argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + if matches!(mode, crate::GatherMode::BroadcastFirst) { + write!(self.out, "WaveReadLaneFirst(")?; + self.write_expr(module, argument, func_ctx)?; + } else { + write!(self.out, "WaveReadLaneAt(")?; + self.write_expr(module, argument, func_ctx)?; + write!(self.out, ", ")?; + match mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => { + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleDown(index) => { + write!(self.out, "WaveGetLaneIndex() + ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleUp(index) => { + write!(self.out, "WaveGetLaneIndex() - ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleXor(index) => { + write!(self.out, "WaveGetLaneIndex() ^ ")?; + self.write_expr(module, index, func_ctx)?; + } + } + } + writeln!(self.out, ");")?; } } @@ -3251,7 +3412,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?; } if barrier.contains(crate::Barrier::SUB_GROUP) { - unimplemented!() // FIXME + // Does not exist in DirectX } Ok(()) } diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 4e0d8489e4..eee825a83b 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -438,9 +438,10 @@ impl ResolvedBinding { Bi::WorkGroupSize => "dispatch_threads_per_threadgroup", Bi::NumWorkGroups => "threadgroups_per_grid", // subgroup - Bi::NumSubgroups | Bi::SubgroupId => todo!(), - Bi::SubgroupInvocationId => "simdgroup_index_in_threadgroup", - Bi::SubgroupSize => "simdgroups_per_threadgroup", + Bi::NumSubgroups => "simdgroups_per_threadgroup", + Bi::SubgroupId => "simdgroup_index_in_threadgroup", + Bi::SubgroupSize => "threads_per_simdgroup", + Bi::SubgroupInvocationId => "thread_index_in_simdgroup", Bi::CullDistance | Bi::ViewIndex => { return Err(Error::UnsupportedBuiltIn(built_in)) } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 629b98c96d..ce57588240 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3017,14 +3017,13 @@ impl Writer { let name = self.namer.call(""); self.start_baking_expression(result, &context.expression, &name)?; self.named_expressions.insert(result, name); - write!(self.out, "{NAMESPACE}::simd_ballot(;")?; - match predicate { - Some(predicate) => { - self.put_expression(predicate, &context.expression, true)? - } - None => write!(self.out, "true")?, + write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?; + if let Some(predicate) = predicate { + self.put_expression(predicate, &context.expression, true)?; + } else { + write!(self.out, "true")?; } - writeln!(self.out, ");")?; + writeln!(self.out, "), 0, 0, 0);")?; } crate::Statement::SubgroupCollectiveOperation { op, @@ -3032,14 +3031,101 @@ impl Writer { argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "{NAMESPACE}::simd_all(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "{NAMESPACE}::simd_any(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "{NAMESPACE}::simd_sum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "{NAMESPACE}::simd_product(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "{NAMESPACE}::simd_max(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "{NAMESPACE}::simd_min(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "{NAMESPACE}::simd_and(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "{NAMESPACE}::simd_or(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "{NAMESPACE}::simd_xor(")? + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?, + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?, + _ => unimplemented!(), + } + self.put_expression(argument, &context.expression, true)?; + writeln!(self.out, ");")?; } crate::Statement::SubgroupGather { mode, argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "{NAMESPACE}::simd_broadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?; + } + } + self.put_expression(argument, &context.expression, true)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.put_expression(index, &context.expression, true)?; + } + } + writeln!(self.out, ");")?; } } } @@ -4378,7 +4464,10 @@ impl Writer { )?; } if flags.contains(crate::Barrier::SUB_GROUP) { - unimplemented!(); // FIXME + writeln!( + self.out, + "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", + )?; } Ok(()) } diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 9fba489f79..50883ce071 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -2340,30 +2340,11 @@ impl<'w> BlockContext<'w> { crate::Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } - crate::Statement::SubgroupBallot { result, predicate } => { - self.writer.require_any( - "GroupNonUniformBallot", - &[spirv::Capability::GroupNonUniformBallot], - )?; - let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(crate::VectorSize::Quad), - kind: crate::ScalarKind::Uint, - width: 4, - pointer_space: None, - })); - let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); - let predicate = match predicate { - Some(predicate) => self.cached[predicate], - None => self.writer.get_constant_scalar(crate::Literal::Bool(true)), - }; - let id = self.gen_id(); - block.body.push(Instruction::group_non_uniform_ballot( - vec4_u32_type_id, - id, - exec_scope_id, - predicate, - )); - self.cached[result] = id; + crate::Statement::SubgroupBallot { + result, + ref predicate, + } => { + self.write_subgroup_ballot(predicate, result, &mut block)?; } crate::Statement::SubgroupCollectiveOperation { ref op, @@ -2378,7 +2359,7 @@ impl<'w> BlockContext<'w> { argument, result, } => { - self.write_subgroup_broadcast(mode, argument, result, &mut block)?; + self.write_subgroup_gather(mode, argument, result, &mut block)?; } } } diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index 8a528065ef..5f7c6b34fd 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1054,33 +1054,34 @@ impl super::Instruction { instruction } - pub(super) fn group_non_uniform_broadcast( + pub(super) fn group_non_uniform_broadcast_first( result_type_id: Word, id: Word, exec_scope_id: Word, value: Word, - index: Word, ) -> Self { - let mut instruction = Self::new(Op::GroupNonUniformBroadcast); + let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); instruction.add_operand(value); - instruction.add_operand(index); instruction } - pub(super) fn group_non_uniform_broadcast_first( + pub(super) fn group_non_uniform_gather( + op: Op, result_type_id: Word, id: Word, exec_scope_id: Word, value: Word, + index: Word, ) -> Self { - let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); + let mut instruction = Self::new(op); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); instruction.add_operand(value); + instruction.add_operand(index); instruction } @@ -1092,10 +1093,6 @@ impl super::Instruction { group_op: Option, value: Word, ) -> Self { - println!( - "{:?}", - (op, result_type_id, id, exec_scope_id, group_op, value) - ); let mut instruction = Self::new(op); instruction.set_type(result_type_id); instruction.set_result(id); diff --git a/naga/src/back/spv/subgroup.rs b/naga/src/back/spv/subgroup.rs index 7206ec9312..79db752a6c 100644 --- a/naga/src/back/spv/subgroup.rs +++ b/naga/src/back/spv/subgroup.rs @@ -1,7 +1,43 @@ use super::{Block, BlockContext, Error, Instruction}; -use crate::{arena::Handle, TypeInner}; +use crate::{ + arena::Handle, + back::spv::{LocalType, LookupType}, + TypeInner, +}; impl<'w> BlockContext<'w> { + pub(super) fn write_subgroup_ballot( + &mut self, + predicate: &Option>, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + kind: crate::ScalarKind::Uint, + width: 4, + pointer_space: None, + })); + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + let predicate = if let Some(predicate) = *predicate { + self.cached[predicate] + } else { + self.writer.get_constant_scalar(crate::Literal::Bool(true)) + }; + let id = self.gen_id(); + block.body.push(Instruction::group_non_uniform_ballot( + vec4_u32_type_id, + id, + exec_scope_id, + predicate, + )); + self.cached[result] = id; + Ok(()) + } pub(super) fn write_subgroup_operation( &mut self, op: &crate::SubgroupOperation, @@ -11,7 +47,7 @@ impl<'w> BlockContext<'w> { block: &mut Block, ) -> Result<(), Error> { use crate::SubgroupOperation as sg; - match op { + match *op { sg::All | sg::Any => { self.writer.require_any( "GroupNonUniformVote", @@ -21,11 +57,7 @@ impl<'w> BlockContext<'w> { _ => { self.writer.require_any( "GroupNonUniformArithmetic", - &[ - spirv::Capability::GroupNonUniformArithmetic, - spirv::Capability::GroupNonUniformClustered, - spirv::Capability::GroupNonUniformPartitionedNV, - ], + &[spirv::Capability::GroupNonUniformArithmetic], )?; } } @@ -35,14 +67,14 @@ impl<'w> BlockContext<'w> { let result_type_id = self.get_expression_type_id(result_ty); let result_ty_inner = result_ty.inner_with(&self.ir_module.types); - let (is_scalar, kind) = match result_ty_inner { + let (is_scalar, kind) = match *result_ty_inner { TypeInner::Scalar { kind, .. } => (true, kind), TypeInner::Vector { kind, .. } => (false, kind), _ => unimplemented!(), }; use crate::ScalarKind as sk; - let spirv_op = match (kind, op) { + let spirv_op = match (kind, *op) { (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll, (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny, (_, sg::All | sg::Any) => unimplemented!(), @@ -71,9 +103,9 @@ impl<'w> BlockContext<'w> { let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); use crate::CollectiveOperation as c; - let group_op = match op { + let group_op = match *op { sg::All | sg::Any => None, - _ => Some(match collective_op { + _ => Some(match *collective_op { c::Reduce => spirv::GroupOperation::Reduce, c::InclusiveScan => spirv::GroupOperation::InclusiveScan, c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, @@ -92,7 +124,7 @@ impl<'w> BlockContext<'w> { self.cached[result] = id; Ok(()) } - pub(super) fn write_subgroup_broadcast( + pub(super) fn write_subgroup_gather( &mut self, mode: &crate::GatherMode, argument: Handle, @@ -103,6 +135,26 @@ impl<'w> BlockContext<'w> { "GroupNonUniformBallot", &[spirv::Capability::GroupNonUniformBallot], )?; + match *mode { + crate::GatherMode::BroadcastFirst | crate::GatherMode::Broadcast(_) => { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + } + crate::GatherMode::Shuffle(_) | crate::GatherMode::ShuffleXor(_) => { + self.writer.require_any( + "GroupNonUniformShuffle", + &[spirv::Capability::GroupNonUniformShuffle], + )?; + } + crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => { + self.writer.require_any( + "GroupNonUniformShuffleRelative", + &[spirv::Capability::GroupNonUniformShuffleRelative], + )?; + } + } let id = self.gen_id(); let result_ty = &self.fun_info[result].ty; @@ -111,17 +163,7 @@ impl<'w> BlockContext<'w> { let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); let arg_id = self.cached[argument]; - match mode { - crate::GatherMode::Broadcast(index) => { - let index_id = self.cached[*index]; - block.body.push(Instruction::group_non_uniform_broadcast( - result_type_id, - id, - exec_scope_id, - arg_id, - index_id, - )); - } + match *mode { crate::GatherMode::BroadcastFirst => { block .body @@ -132,10 +174,29 @@ impl<'w> BlockContext<'w> { arg_id, )); } - crate::GatherMode::Shuffle(index) + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => todo!(), + | crate::GatherMode::ShuffleXor(index) => { + let index_id = self.cached[index]; + let op = match *mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformBroadcast, + crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle, + crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown, + crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp, + crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor, + }; + block.body.push(Instruction::group_non_uniform_gather( + op, + result_type_id, + id, + exec_scope_id, + arg_id, + index_id, + )); + } } self.cached[result] = id; Ok(()) diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 6a55288308..da0fdf766f 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1594,17 +1594,23 @@ impl Writer { Bi::WorkGroupSize => BuiltIn::WorkgroupSize, Bi::NumWorkGroups => BuiltIn::NumWorkgroups, // Subgroup - Bi::NumSubgroups | Bi::SubgroupId => todo!(), - Bi::SubgroupInvocationId => { + Bi::NumSubgroups => { self.require_any( - "`subgroup_invocation_id` built-in", + "`num_subgroups` built-in", &[spirv::Capability::GroupNonUniform], )?; - BuiltIn::SubgroupLocalInvocationId + BuiltIn::NumSubgroups + } + Bi::SubgroupId => { + self.require_any( + "`subgroup_id` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::SubgroupId } Bi::SubgroupSize => { self.require_any( - "`subgroup_invocation_id` built-in", + "`subgroup_size` built-in", &[ spirv::Capability::GroupNonUniform, spirv::Capability::SubgroupBallotKHR, @@ -1612,6 +1618,16 @@ impl Writer { )?; BuiltIn::SubgroupSize } + Bi::SubgroupInvocationId => { + self.require_any( + "`subgroup_invocation_id` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupLocalInvocationId + } }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 8fc9881f10..0bc2dfceb0 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -919,6 +919,10 @@ impl Writer { if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}workgroupBarrier();")?; } + + if barrier.contains(crate::Barrier::SUB_GROUP) { + writeln!(self.out, "{level}subgroupBarrier();")?; + } } Statement::RayQuery { .. } => unreachable!(), Statement::SubgroupBallot { result, predicate } => { @@ -939,14 +943,99 @@ impl Writer { argument, result, } => { - unimplemented!() // FIXME + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupPrefixExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupPrefixExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupPrefixInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupPrefixInclusiveMul(")? + } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; } Statement::SubgroupGather { mode, argument, result, } => { - unimplemented!() // FIXME + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(module, argument, func_ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + } + } + writeln!(self.out, ");")?; } } @@ -1789,9 +1878,10 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { Bi::SampleMask => "sample_mask", Bi::PrimitiveIndex => "primitive_index", Bi::ViewIndex => "view_index", - Bi::NumSubgroups | Bi::SubgroupId => todo!(), - Bi::SubgroupInvocationId => "subgroup_invocation_id", + Bi::NumSubgroups => "num_subgroups", + Bi::SubgroupId => "subgroup_id", Bi::SubgroupSize => "subgroup_size", + Bi::SubgroupInvocationId => "subgroup_invocation_id", Bi::BaseInstance | Bi::BaseVertex | Bi::ClipDistance diff --git a/naga/src/front/spv/error.rs b/naga/src/front/spv/error.rs index 2f9bf2d1bc..8508ede042 100644 --- a/naga/src/front/spv/error.rs +++ b/naga/src/front/spv/error.rs @@ -54,6 +54,8 @@ pub enum Error { UnknownBinaryOperator(spirv::Op), #[error("unknown relational function {0:?}")] UnknownRelationalFunction(spirv::Op), + #[error("unsupported group opeation %{0}")] + UnsupportedGroupOperation(spirv::Word), #[error("invalid parameter {0:?}")] InvalidParameter(spirv::Op), #[error("invalid operand count {1} for {0:?}")] @@ -116,8 +118,8 @@ pub enum Error { FunctionCallCycle(spirv::Word), #[error("invalid array size {0:?}")] InvalidArraySize(Handle), - #[error("invalid barrier scope %{0}")] - InvalidBarrierScope(spirv::Word), + #[error("invalid execution scope %{0}")] + InvalidExecutionScope(spirv::Word), #[error("invalid barrier memory semantics %{0}")] InvalidBarrierMemorySemantics(spirv::Word), #[error( diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index c78407ef99..e7c082c6a5 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -3650,7 +3650,7 @@ impl> Frontend { let semantics_const = self.lookup_constant.lookup(semantics_id)?; let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) - .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; let semantics = resolve_constant(ctx.gctx(), semantics_const.handle) .ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?; @@ -3691,6 +3691,209 @@ impl> Frontend { }, ); } + Op::GroupNonUniformBallot => { + inst.expect(4)?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let predicate_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let predicate = if self + .lookup_constant + .lookup(predicate_id) + .ok() + .filter(|predicate_const| { + matches!( + ctx.gctx().const_expressions + [ctx.gctx().constants[predicate_const.handle].init], + crate::Expression::Literal(crate::Literal::Bool(true)) + ) + }) + .is_some() + { + None + } else { + let predicate_lookup = self.lookup_expression.lookup(predicate_id)?; + let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup); + Some(predicate_handle) + }; + + block.push( + crate::Statement::SubgroupBallot { + result: result_handle, + predicate, + }, + span, + ); + } + spirv::Op::GroupNonUniformAll + | spirv::Op::GroupNonUniformAny + | spirv::Op::GroupNonUniformIAdd + | spirv::Op::GroupNonUniformFAdd + | spirv::Op::GroupNonUniformIMul + | spirv::Op::GroupNonUniformFMul + | spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax + | spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin + | spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalAnd + | spirv::Op::GroupNonUniformLogicalOr + | spirv::Op::GroupNonUniformLogicalXor => { + inst.expect( + if matches!( + inst.op, + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny + ) { + 4 + } else { + 5 + }, + )?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let collective_op_id = match inst.op { + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => { + crate::CollectiveOperation::Reduce + } + _ => { + let group_op_id = self.next()?; + match spirv::GroupOperation::from_u32(group_op_id) { + Some(spirv::GroupOperation::Reduce) => { + crate::CollectiveOperation::Reduce + } + Some(spirv::GroupOperation::InclusiveScan) => { + crate::CollectiveOperation::InclusiveScan + } + Some(spirv::GroupOperation::ExclusiveScan) => { + crate::CollectiveOperation::ExclusiveScan + } + _ => return Err(Error::UnsupportedGroupOperation(group_op_id)), + } + } + }; + let argument_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let op_id = match inst.op { + spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All, + spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any, + spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => { + crate::SubgroupOperation::Add + } + spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => { + crate::SubgroupOperation::Mul + } + spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max, + spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min, + spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And, + spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or, + spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor, + _ => unreachable!(), + }; + + block.push( + crate::Statement::SubgroupCollectiveOperation { + result: result_handle, + op: op_id, + collective_op: collective_op_id, + argument: argument_handle, + }, + span, + ); + } + Op::GroupNonUniformBroadcastFirst + | Op::GroupNonUniformBroadcast + | Op::GroupNonUniformShuffle + | Op::GroupNonUniformShuffleDown + | Op::GroupNonUniformShuffleUp + | Op::GroupNonUniformShuffleXor => { + inst.expect( + if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + 4 + } else { + 5 + }, + )?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let argument_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + crate::GatherMode::BroadcastFirst + } else { + let index_id = self.next()?; + let index_lookup = self.lookup_expression.lookup(index_id)?; + let index_handle = get_expr_handle!(index_id, index_lookup); + match inst.op { + spirv::Op::GroupNonUniformBroadcast => { + crate::GatherMode::Broadcast(index_handle) + } + spirv::Op::GroupNonUniformShuffle => { + crate::GatherMode::Shuffle(index_handle) + } + spirv::Op::GroupNonUniformShuffleDown => { + crate::GatherMode::ShuffleDown(index_handle) + } + spirv::Op::GroupNonUniformShuffleUp => { + crate::GatherMode::ShuffleUp(index_handle) + } + spirv::Op::GroupNonUniformShuffleXor => { + crate::GatherMode::ShuffleXor(index_handle) + } + _ => unreachable!(), + } + }; + + block.push( + crate::Statement::SubgroupGather { + result: result_handle, + mode, + argument: argument_handle, + }, + span, + ); + } _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), } }; @@ -3811,7 +4014,10 @@ impl> Frontend { | S::Store { .. } | S::ImageStore { .. } | S::Atomic { .. } - | S::RayQuery { .. } => {} + | S::RayQuery { .. } + | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupGather { .. } => {} S::Call { function: ref mut callee, ref arguments, @@ -3842,9 +4048,6 @@ impl> Frontend { } } S::WorkGroupUniformLoad { .. } => unreachable!(), - S::SubgroupBallot { .. } => unreachable!(), // FIXME?? - S::SubgroupCollectiveOperation { .. } => unreachable!(), - S::SubgroupGather { .. } => unreachable!(), } i += 1; } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index e97da5db7f..50875dfad0 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1867,7 +1867,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } else if let Some(fun) = Texture::map(function.name) { self.texture_sample_helper(fun, arguments, span, ctx)? } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { - return Ok(Some(self.subgroup_helper(span, op, cop, arguments, ctx)?)); + return Ok(Some( + self.subgroup_operation_helper(span, op, cop, arguments, ctx)?, + )); + } else if let Some(mode) = conv::map_subgroup_gather(function.name) { + return Ok(Some( + self.subgroup_gather_helper(span, mode, arguments, ctx)?, + )); } else { match function.name { "select" => { @@ -2267,53 +2273,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(crate::Statement::SubgroupBallot { result, predicate }, span); return Ok(Some(result)); } - "subgroupBroadcast" => { - let mut args = ctx.prepare_args(arguments, 2, span); - - let index = self.expression(args.next()?, ctx)?; - let argument = self.expression(args.next()?, ctx)?; - args.finish()?; - - let ty = ctx.register_type(argument)?; - - let result = ctx.interrupt_emitter( - crate::Expression::SubgroupOperationResult { ty }, - span, - )?; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block.push( - crate::Statement::SubgroupGather { - mode: crate::GatherMode::Broadcast(index), - argument, - result, - }, - span, - ); - return Ok(Some(result)); - } - "subgroupBroadcastFirst" => { - let mut args = ctx.prepare_args(arguments, 1, span); - - let argument = self.expression(args.next()?, ctx)?; - args.finish()?; - - let ty = ctx.register_type(argument)?; - - let result = ctx.interrupt_emitter( - crate::Expression::SubgroupOperationResult { ty }, - span, - )?; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block.push( - crate::Statement::SubgroupGather { - mode: crate::GatherMode::BroadcastFirst, - argument, - result, - }, - span, - ); - return Ok(Some(result)); - } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; @@ -2504,7 +2463,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { depth_ref, }) } - fn subgroup_helper( + + fn subgroup_operation_helper( &mut self, span: Span, op: crate::SubgroupOperation, @@ -2534,6 +2494,46 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(result) } + fn subgroup_gather_helper( + &mut self, + span: Span, + mode: crate::GatherMode, + arguments: &[Handle>], + ctx: &mut ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 2, span); + + let argument = self.expression(args.next()?, ctx)?; + let index = if let crate::GatherMode::BroadcastFirst = mode { + Handle::new(NonZeroU32::new(u32::MAX).unwrap()) + } else { + self.expression(args.next()?, ctx)? + }; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupGather { + mode: match mode { + crate::GatherMode::BroadcastFirst => crate::GatherMode::BroadcastFirst, + crate::GatherMode::Broadcast(_) => crate::GatherMode::Broadcast(index), + crate::GatherMode::Shuffle(_) => crate::GatherMode::Shuffle(index), + crate::GatherMode::ShuffleDown(_) => crate::GatherMode::ShuffleDown(index), + crate::GatherMode::ShuffleUp(_) => crate::GatherMode::ShuffleUp(index), + crate::GatherMode::ShuffleXor(_) => crate::GatherMode::ShuffleXor(index), + }, + argument, + result, + }, + span, + ); + Ok(result) + } + fn r#struct( &mut self, s: &ast::Struct<'source>, diff --git a/naga/src/front/wgsl/parse/conv.rs b/naga/src/front/wgsl/parse/conv.rs index 160213e4a3..c53f4df753 100644 --- a/naga/src/front/wgsl/parse/conv.rs +++ b/naga/src/front/wgsl/parse/conv.rs @@ -35,8 +35,10 @@ pub fn map_built_in(word: &str, span: Span) -> Result> "workgroup_id" => crate::BuiltIn::WorkGroupId, "num_workgroups" => crate::BuiltIn::NumWorkGroups, // subgroup - "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, + "num_subgroups" => crate::BuiltIn::NumSubgroups, + "subgroup_id" => crate::BuiltIn::SubgroupId, "subgroup_size" => crate::BuiltIn::SubgroupSize, + "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, _ => return Err(Error::UnknownBuiltin(span)), }) } @@ -254,8 +256,25 @@ pub fn map_subgroup_operation( "subgroupAnd" => (sg::And, co::Reduce), "subgroupOr" => (sg::Or, co::Reduce), "subgroupXor" => (sg::Xor, co::Reduce), - "subgroupPrefixAdd" => (sg::Add, co::InclusiveScan), - "subgroupPrefixMul" => (sg::Mul, co::InclusiveScan), + "subgroupPrefixExclusiveAdd" => (sg::Add, co::ExclusiveScan), + "subgroupPrefixExclusiveMul" => (sg::Mul, co::ExclusiveScan), + "subgroupPrefixInclusiveAdd" => (sg::Add, co::InclusiveScan), + "subgroupPrefixInclusiveMul" => (sg::Mul, co::InclusiveScan), + _ => return None, + }) +} + +pub fn map_subgroup_gather(word: &str) -> Option { + use crate::GatherMode as gm; + use crate::Handle; + use std::num::NonZeroU32; + Some(match word { + "subgroupBroadcastFirst" => gm::BroadcastFirst, + "subgroupBroadcast" => gm::Broadcast(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffle" => gm::Shuffle(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleDown" => gm::ShuffleDown(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleUp" => gm::ShuffleUp(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleXor" => gm::ShuffleXor(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), _ => return None, }) } From 43133d082f914e4c064e56bc1a181eb57120b33b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 14 Oct 2023 20:23:53 +0200 Subject: [PATCH 16/46] Adjusts metal backend test_stack_size(). --- naga/src/back/msl/writer.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index ce57588240..1b0791b573 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -4739,8 +4739,8 @@ fn test_stack_size() { } let stack_size = addresses_end - addresses_start; // check the size (in debug only) - // last observed macOS value: 19152 (CI) - if !(9000..=20000).contains(&stack_size) { + // last observed macOS value: 22256 (CI) + if !(15000..=25000).contains(&stack_size) { panic!("`put_block` stack size {stack_size} has changed!"); } } From 275a5a5fb790ad430830b94ca0feef46269aa26f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Fri, 20 Oct 2023 23:28:20 +0200 Subject: [PATCH 17/46] Adds test and snapshots. --- naga/tests/in/subgroup-operations.param.ron | 26 +++++++ naga/tests/in/subgroup-operations.wgsl | 32 ++++++++ .../subgroup-operations.main.Compute.glsl | 41 +++++++++++ naga/tests/out/hlsl/subgroup-operations.hlsl | 32 ++++++++ naga/tests/out/hlsl/subgroup-operations.ron | 12 +++ naga/tests/out/msl/subgroup-operations.msl | 38 ++++++++++ naga/tests/out/spv/subgroup-operations.spvasm | 73 +++++++++++++++++++ naga/tests/out/wgsl/subgroup-operations.wgsl | 26 +++++++ naga/tests/snapshots.rs | 4 + 9 files changed, 284 insertions(+) create mode 100644 naga/tests/in/subgroup-operations.param.ron create mode 100644 naga/tests/in/subgroup-operations.wgsl create mode 100644 naga/tests/out/glsl/subgroup-operations.main.Compute.glsl create mode 100644 naga/tests/out/hlsl/subgroup-operations.hlsl create mode 100644 naga/tests/out/hlsl/subgroup-operations.ron create mode 100644 naga/tests/out/msl/subgroup-operations.msl create mode 100644 naga/tests/out/spv/subgroup-operations.spvasm create mode 100644 naga/tests/out/wgsl/subgroup-operations.wgsl diff --git a/naga/tests/in/subgroup-operations.param.ron b/naga/tests/in/subgroup-operations.param.ron new file mode 100644 index 0000000000..fc444a3efe --- /dev/null +++ b/naga/tests/in/subgroup-operations.param.ron @@ -0,0 +1,26 @@ +( + spv: ( + version: (1, 3), + ), + msl: ( + lang_version: (2, 4), + per_entry_point_map: {}, + inline_samplers: [], + spirv_cross_compatibility: false, + fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, + ), + glsl: ( + version: Desktop(430), + writer_flags: (""), + binding_map: { }, + zero_initialize_workgroup_memory: true, + ), + hlsl: ( + shader_model: V6_0, + binding_map: {}, + fake_missing_bindings: true, + special_constants_binding: None, + zero_initialize_workgroup_memory: true, + ), +) diff --git a/naga/tests/in/subgroup-operations.wgsl b/naga/tests/in/subgroup-operations.wgsl new file mode 100644 index 0000000000..f30b60be47 --- /dev/null +++ b/naga/tests/in/subgroup-operations.wgsl @@ -0,0 +1,32 @@ +@compute @workgroup_size(1) +fn main( + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, +) { + subgroupBarrier(); + + subgroupBallot((subgroup_invocation_id & 1u) == 1u); + + subgroupAll(subgroup_invocation_id != 0u); + subgroupAny(subgroup_invocation_id == 0u); + subgroupAdd(subgroup_invocation_id); + subgroupMul(subgroup_invocation_id); + subgroupMin(subgroup_invocation_id); + subgroupMax(subgroup_invocation_id); + subgroupAnd(subgroup_invocation_id); + subgroupOr(subgroup_invocation_id); + subgroupXor(subgroup_invocation_id); + subgroupPrefixExclusiveAdd(subgroup_invocation_id); + subgroupPrefixExclusiveMul(subgroup_invocation_id); + subgroupPrefixInclusiveAdd(subgroup_invocation_id); + subgroupPrefixInclusiveMul(subgroup_invocation_id); + + subgroupBroadcastFirst(subgroup_invocation_id); + subgroupBroadcast(subgroup_invocation_id, 4u); + subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id); + subgroupShuffleDown(subgroup_invocation_id, 1u); + subgroupShuffleUp(subgroup_invocation_id, 1u); + subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u); +} diff --git a/naga/tests/out/glsl/subgroup-operations.main.Compute.glsl b/naga/tests/out/glsl/subgroup-operations.main.Compute.glsl new file mode 100644 index 0000000000..a37cf8e247 --- /dev/null +++ b/naga/tests/out/glsl/subgroup-operations.main.Compute.glsl @@ -0,0 +1,41 @@ +#version 430 core +#extension GL_ARB_compute_shader : require +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_vote : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_shader_subgroup_ballot : require +#extension GL_KHR_shader_subgroup_shuffle : require +#extension GL_KHR_shader_subgroup_shuffle_relative : require +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + +void main() { + uint num_subgroups = gl_NumSubgroups; + uint subgroup_id = gl_SubgroupID; + uint subgroup_size = gl_SubgroupSize; + uint subgroup_invocation_id = gl_SubgroupInvocationID; + subgroupMemoryBarrier(); + barrier(); + uvec4 _e8 = subgroupBallot(((subgroup_invocation_id & 1u) == 1u)); + bool _e11 = subgroupAll((subgroup_invocation_id != 0u)); + bool _e14 = subgroupAny((subgroup_invocation_id == 0u)); + uint _e15 = subgroupAdd(subgroup_invocation_id); + uint _e16 = subgroupMul(subgroup_invocation_id); + uint _e17 = subgroupMin(subgroup_invocation_id); + uint _e18 = subgroupMax(subgroup_invocation_id); + uint _e19 = subgroupAnd(subgroup_invocation_id); + uint _e20 = subgroupOr(subgroup_invocation_id); + uint _e21 = subgroupXor(subgroup_invocation_id); + uint _e22 = subgroupExclusiveAdd(subgroup_invocation_id); + uint _e23 = subgroupExclusiveMul(subgroup_invocation_id); + uint _e24 = subgroupInclusiveAdd(subgroup_invocation_id); + uint _e25 = subgroupInclusiveMul(subgroup_invocation_id); + uint _e26 = subgroupBroadcastFirst(subgroup_invocation_id); + uint _e28 = subgroupBroadcast(subgroup_invocation_id, 4u); + uint _e32 = subgroupShuffle(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + uint _e34 = subgroupShuffleDown(subgroup_invocation_id, 1u); + uint _e36 = subgroupShuffleUp(subgroup_invocation_id, 1u); + uint _e39 = subgroupShuffleXor(subgroup_invocation_id, (subgroup_size - 1u)); + return; +} + diff --git a/naga/tests/out/hlsl/subgroup-operations.hlsl b/naga/tests/out/hlsl/subgroup-operations.hlsl new file mode 100644 index 0000000000..baa37826e0 --- /dev/null +++ b/naga/tests/out/hlsl/subgroup-operations.hlsl @@ -0,0 +1,32 @@ +[numthreads(1, 1, 1)] +void main(uint3 __local_invocation_id : SV_GroupThreadID) +{ + if (all(__local_invocation_id == uint3(0u, 0u, 0u))) { + } + GroupMemoryBarrierWithGroupSync(); + const uint num_subgroups = (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount(); + const uint subgroup_id = (__local_invocation_id.x * 1u + __local_invocation_id.y * 1u + __local_invocation_id.z) / WaveGetLaneCount(); + const uint subgroup_size = WaveGetLaneCount(); + const uint subgroup_invocation_id = WaveGetLaneIndex(); + const uint4 _e8 = WaveActiveBallot(((subgroup_invocation_id & 1u) == 1u)); + const bool _e11 = WaveActiveAllTrue((subgroup_invocation_id != 0u)); + const bool _e14 = WaveActiveAnyTrue((subgroup_invocation_id == 0u)); + const uint _e15 = WaveActiveSum(subgroup_invocation_id); + const uint _e16 = WaveActiveProduct(subgroup_invocation_id); + const uint _e17 = WaveActiveMin(subgroup_invocation_id); + const uint _e18 = WaveActiveMax(subgroup_invocation_id); + const uint _e19 = WaveActiveBitAnd(subgroup_invocation_id); + const uint _e20 = WaveActiveBitOr(subgroup_invocation_id); + const uint _e21 = WaveActiveBitXor(subgroup_invocation_id); + const uint _e22 = WavePrefixSum(subgroup_invocation_id); + const uint _e23 = WavePrefixProduct(subgroup_invocation_id); + const uint _e24 = subgroup_invocation_id + WavePrefixSum(subgroup_invocation_id); + const uint _e25 = subgroup_invocation_id * WavePrefixProduct(subgroup_invocation_id); + const uint _e26 = WaveReadLaneFirst(subgroup_invocation_id); + const uint _e28 = WaveReadLaneAt(subgroup_invocation_id, 4u); + const uint _e32 = WaveReadLaneAt(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + const uint _e34 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() + 1u); + const uint _e36 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() - 1u); + const uint _e39 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() ^ (subgroup_size - 1u)); + return; +} diff --git a/naga/tests/out/hlsl/subgroup-operations.ron b/naga/tests/out/hlsl/subgroup-operations.ron new file mode 100644 index 0000000000..b973fe3da1 --- /dev/null +++ b/naga/tests/out/hlsl/subgroup-operations.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_6_0", + ), + ], +) diff --git a/naga/tests/out/msl/subgroup-operations.msl b/naga/tests/out/msl/subgroup-operations.msl new file mode 100644 index 0000000000..576fa3b84e --- /dev/null +++ b/naga/tests/out/msl/subgroup-operations.msl @@ -0,0 +1,38 @@ +// language: metal2.4 +#include +#include + +using metal::uint; + + +struct main_Input { +}; +kernel void main_( + uint num_subgroups [[simdgroups_per_threadgroup]] +, uint subgroup_id [[simdgroup_index_in_threadgroup]] +, uint subgroup_size [[threads_per_simdgroup]] +, uint subgroup_invocation_id [[thread_index_in_simdgroup]] +) { + metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup); + metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((subgroup_invocation_id & 1u) == 1u), 0, 0, 0); + bool unnamed_1 = metal::simd_all(subgroup_invocation_id != 0u); + bool unnamed_2 = metal::simd_any(subgroup_invocation_id == 0u); + uint unnamed_3 = metal::simd_sum(subgroup_invocation_id); + uint unnamed_4 = metal::simd_product(subgroup_invocation_id); + uint unnamed_5 = metal::simd_min(subgroup_invocation_id); + uint unnamed_6 = metal::simd_max(subgroup_invocation_id); + uint unnamed_7 = metal::simd_and(subgroup_invocation_id); + uint unnamed_8 = metal::simd_or(subgroup_invocation_id); + uint unnamed_9 = metal::simd_xor(subgroup_invocation_id); + uint unnamed_10 = metal::simd_prefix_exclusive_sum(subgroup_invocation_id); + uint unnamed_11 = metal::simd_prefix_exclusive_product(subgroup_invocation_id); + uint unnamed_12 = metal::simd_prefix_inclusive_sum(subgroup_invocation_id); + uint unnamed_13 = metal::simd_prefix_inclusive_product(subgroup_invocation_id); + uint unnamed_14 = metal::simd_broadcast_first(subgroup_invocation_id); + uint unnamed_15 = metal::simd_broadcast(subgroup_invocation_id, 4u); + uint unnamed_16 = metal::simd_shuffle(subgroup_invocation_id, (subgroup_size - 1u) - subgroup_invocation_id); + uint unnamed_17 = metal::simd_shuffle_down(subgroup_invocation_id, 1u); + uint unnamed_18 = metal::simd_shuffle_up(subgroup_invocation_id, 1u); + uint unnamed_19 = metal::simd_shuffle_xor(subgroup_invocation_id, subgroup_size - 1u); + return; +} diff --git a/naga/tests/out/spv/subgroup-operations.spvasm b/naga/tests/out/spv/subgroup-operations.spvasm new file mode 100644 index 0000000000..c2023c5473 --- /dev/null +++ b/naga/tests/out/spv/subgroup-operations.spvasm @@ -0,0 +1,73 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 52 +OpCapability Shader +OpCapability GroupNonUniform +OpCapability GroupNonUniformBallot +OpCapability GroupNonUniformVote +OpCapability GroupNonUniformArithmetic +OpCapability GroupNonUniformShuffle +OpCapability GroupNonUniformShuffleRelative +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %15 "main" %6 %9 %11 %13 +OpExecutionMode %15 LocalSize 1 1 1 +OpDecorate %6 BuiltIn NumSubgroups +OpDecorate %9 BuiltIn SubgroupId +OpDecorate %11 BuiltIn SubgroupSize +OpDecorate %13 BuiltIn SubgroupLocalInvocationId +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypeBool +%7 = OpTypePointer Input %3 +%6 = OpVariable %7 Input +%9 = OpVariable %7 Input +%11 = OpVariable %7 Input +%13 = OpVariable %7 Input +%16 = OpTypeFunction %2 +%17 = OpConstant %3 1 +%18 = OpConstant %3 0 +%19 = OpConstant %3 4 +%21 = OpConstant %3 3 +%22 = OpConstant %3 2 +%23 = OpConstant %3 8 +%26 = OpTypeVector %3 4 +%15 = OpFunction %2 None %16 +%5 = OpLabel +%8 = OpLoad %3 %6 +%10 = OpLoad %3 %9 +%12 = OpLoad %3 %11 +%14 = OpLoad %3 %13 +OpBranch %20 +%20 = OpLabel +OpControlBarrier %21 %22 %23 +%24 = OpBitwiseAnd %3 %14 %17 +%25 = OpIEqual %4 %24 %17 +%27 = OpGroupNonUniformBallot %26 %21 %25 +%28 = OpINotEqual %4 %14 %18 +%29 = OpGroupNonUniformAll %4 %21 %28 +%30 = OpIEqual %4 %14 %18 +%31 = OpGroupNonUniformAny %4 %21 %30 +%32 = OpGroupNonUniformIAdd %3 %21 Reduce %14 +%33 = OpGroupNonUniformIMul %3 %21 Reduce %14 +%34 = OpGroupNonUniformUMin %3 %21 Reduce %14 +%35 = OpGroupNonUniformUMax %3 %21 Reduce %14 +%36 = OpGroupNonUniformBitwiseAnd %3 %21 Reduce %14 +%37 = OpGroupNonUniformBitwiseOr %3 %21 Reduce %14 +%38 = OpGroupNonUniformBitwiseXor %3 %21 Reduce %14 +%39 = OpGroupNonUniformIAdd %3 %21 ExclusiveScan %14 +%40 = OpGroupNonUniformIMul %3 %21 ExclusiveScan %14 +%41 = OpGroupNonUniformIAdd %3 %21 InclusiveScan %14 +%42 = OpGroupNonUniformIMul %3 %21 InclusiveScan %14 +%43 = OpGroupNonUniformBroadcastFirst %3 %21 %14 +%44 = OpGroupNonUniformBroadcast %3 %21 %14 %19 +%45 = OpISub %3 %12 %17 +%46 = OpISub %3 %45 %14 +%47 = OpGroupNonUniformShuffle %3 %21 %14 %46 +%48 = OpGroupNonUniformShuffleDown %3 %21 %14 %17 +%49 = OpGroupNonUniformShuffleUp %3 %21 %14 %17 +%50 = OpISub %3 %12 %17 +%51 = OpGroupNonUniformShuffleXor %3 %21 %14 %50 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/subgroup-operations.wgsl b/naga/tests/out/wgsl/subgroup-operations.wgsl new file mode 100644 index 0000000000..f12f226387 --- /dev/null +++ b/naga/tests/out/wgsl/subgroup-operations.wgsl @@ -0,0 +1,26 @@ +@compute @workgroup_size(1, 1, 1) +fn main(@builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_size) subgroup_size: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32) { + subgroupBarrier(); + let _e8 = subgroupBallot( +((subgroup_invocation_id & 1u) == 1u)); + let _e11 = subgroupAll((subgroup_invocation_id != 0u)); + let _e14 = subgroupAny((subgroup_invocation_id == 0u)); + let _e15 = subgroupAdd(subgroup_invocation_id); + let _e16 = subgroupMul(subgroup_invocation_id); + let _e17 = subgroupMin(subgroup_invocation_id); + let _e18 = subgroupMax(subgroup_invocation_id); + let _e19 = subgroupAnd(subgroup_invocation_id); + let _e20 = subgroupOr(subgroup_invocation_id); + let _e21 = subgroupXor(subgroup_invocation_id); + let _e22 = subgroupPrefixExclusiveAdd(subgroup_invocation_id); + let _e23 = subgroupPrefixExclusiveMul(subgroup_invocation_id); + let _e24 = subgroupPrefixInclusiveAdd(subgroup_invocation_id); + let _e25 = subgroupPrefixInclusiveMul(subgroup_invocation_id); + let _e26 = subgroupBroadcastFirst(subgroup_invocation_id); + let _e28 = subgroupBroadcast(subgroup_invocation_id, 4u); + let _e32 = subgroupShuffle(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + let _e34 = subgroupShuffleDown(subgroup_invocation_id, 1u); + let _e36 = subgroupShuffleUp(subgroup_invocation_id, 1u); + let _e39 = subgroupShuffleXor(subgroup_invocation_id, (subgroup_size - 1u)); + return; +} diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index ed75805ae0..67573bbbaa 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -779,6 +779,10 @@ fn convert_wgsl() { Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), ("separate-entry-points", Targets::SPIRV | Targets::GLSL), + ( + "subgroup-operations", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() { From f04b71dfb3ce4ec97c6e8df4ded6460c82c531e5 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 21 Oct 2023 16:16:01 -0400 Subject: [PATCH 18/46] subgroup: fix spv-in --- naga/src/front/spv/convert.rs | 5 +++ naga/src/front/spv/mod.rs | 75 ++++++++++++++++++++++++++++------- 2 files changed, 65 insertions(+), 15 deletions(-) diff --git a/naga/src/front/spv/convert.rs b/naga/src/front/spv/convert.rs index efd95898b8..a00e7e525b 100644 --- a/naga/src/front/spv/convert.rs +++ b/naga/src/front/spv/convert.rs @@ -154,6 +154,11 @@ pub(super) fn map_builtin(word: spirv::Word, invariant: bool) -> Result crate::BuiltIn::WorkGroupId, Some(Bi::WorkgroupSize) => crate::BuiltIn::WorkGroupSize, Some(Bi::NumWorkgroups) => crate::BuiltIn::NumWorkGroups, + // subgroup + Some(Bi::NumSubgroups) => crate::BuiltIn::NumSubgroups, + Some(Bi::SubgroupId) => crate::BuiltIn::SubgroupId, + Some(Bi::SubgroupSize) => crate::BuiltIn::SubgroupSize, + Some(Bi::SubgroupLocalInvocationId) => crate::BuiltIn::SubgroupInvocationId, _ => return Err(Error::UnsupportedBuiltIn(word)), }) } diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index e7c082c6a5..bdc981c873 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -3692,15 +3692,13 @@ impl> Frontend { ); } Op::GroupNonUniformBallot => { - inst.expect(4)?; - let _result_type_id = self.next()?; + inst.expect(5)?; + block.extend(emitter.finish(ctx.expressions)); + let result_type_id = self.next()?; let result_id = self.next()?; let exec_scope_id = self.next()?; let predicate_id = self.next()?; - let result_lookup = self.lookup_expression.lookup(result_id)?; - let result_handle = get_expr_handle!(result_id, result_lookup); - let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) @@ -3726,6 +3724,18 @@ impl> Frontend { Some(predicate_handle) }; + let result_handle = ctx + .expressions + .append(crate::Expression::SubgroupBallotResult, span); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + block.push( crate::Statement::SubgroupBallot { result: result_handle, @@ -3733,6 +3743,7 @@ impl> Frontend { }, span, ); + emitter.start(ctx.expressions); } spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny @@ -3752,17 +3763,18 @@ impl> Frontend { | spirv::Op::GroupNonUniformLogicalAnd | spirv::Op::GroupNonUniformLogicalOr | spirv::Op::GroupNonUniformLogicalXor => { + block.extend(emitter.finish(ctx.expressions)); inst.expect( if matches!( inst.op, spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny ) { - 4 - } else { 5 + } else { + 6 }, )?; - let _result_type_id = self.next()?; + let result_type_id = self.next()?; let result_id = self.next()?; let exec_scope_id = self.next()?; let collective_op_id = match inst.op { @@ -3787,8 +3799,6 @@ impl> Frontend { }; let argument_id = self.next()?; - let result_lookup = self.lookup_expression.lookup(result_id)?; - let result_handle = get_expr_handle!(result_id, result_lookup); let argument_lookup = self.lookup_expression.lookup(argument_id)?; let argument_handle = get_expr_handle!(argument_id, argument_lookup); @@ -3821,6 +3831,23 @@ impl> Frontend { _ => unreachable!(), }; + let result_type = self.lookup_type.lookup(result_type_id)?; + + let result_handle = ctx.expressions.append( + crate::Expression::SubgroupOperationResult { + ty: result_type.handle, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + block.push( crate::Statement::SubgroupCollectiveOperation { result: result_handle, @@ -3830,6 +3857,7 @@ impl> Frontend { }, span, ); + emitter.start(ctx.expressions); } Op::GroupNonUniformBroadcastFirst | Op::GroupNonUniformBroadcast @@ -3839,18 +3867,17 @@ impl> Frontend { | Op::GroupNonUniformShuffleXor => { inst.expect( if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { - 4 - } else { 5 + } else { + 6 }, )?; - let _result_type_id = self.next()?; + block.extend(emitter.finish(ctx.expressions)); + let result_type_id = self.next()?; let result_id = self.next()?; let exec_scope_id = self.next()?; let argument_id = self.next()?; - let result_lookup = self.lookup_expression.lookup(result_id)?; - let result_handle = get_expr_handle!(result_id, result_lookup); let argument_lookup = self.lookup_expression.lookup(argument_id)?; let argument_handle = get_expr_handle!(argument_id, argument_lookup); @@ -3885,6 +3912,23 @@ impl> Frontend { } }; + let result_type = self.lookup_type.lookup(result_type_id)?; + + let result_handle = ctx.expressions.append( + crate::Expression::SubgroupOperationResult { + ty: result_type.handle, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + block.push( crate::Statement::SubgroupGather { result: result_handle, @@ -3893,6 +3937,7 @@ impl> Frontend { }, span, ); + emitter.start(ctx.expressions); } _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), } From 7a5bf0f7fce58aeb081eb7be2acab97822195b14 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 21 Oct 2023 16:19:38 -0400 Subject: [PATCH 19/46] subgroup: Add 0 arg subgroupBallot to tests and fix wgsl-out whitespace --- naga/src/back/wgsl/writer.rs | 2 +- naga/tests/in/subgroup-operations.wgsl | 1 + .../subgroup-operations.main.Compute.glsl | 39 +++++++------- naga/tests/out/hlsl/subgroup-operations.hlsl | 39 +++++++------- naga/tests/out/msl/subgroup-operations.msl | 39 +++++++------- naga/tests/out/spv/subgroup-operations.spvasm | 52 ++++++++++--------- naga/tests/out/wgsl/subgroup-operations.wgsl | 42 +++++++-------- 7 files changed, 110 insertions(+), 104 deletions(-) diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 0bc2dfceb0..baf2226ac6 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -931,7 +931,7 @@ impl Writer { self.start_named_expr(module, result, func_ctx, &res_name)?; self.named_expressions.insert(result, res_name); - writeln!(self.out, "subgroupBallot(")?; + write!(self.out, "subgroupBallot(")?; if let Some(predicate) = predicate { self.write_expr(module, predicate, func_ctx)?; } diff --git a/naga/tests/in/subgroup-operations.wgsl b/naga/tests/in/subgroup-operations.wgsl index f30b60be47..4239be114f 100644 --- a/naga/tests/in/subgroup-operations.wgsl +++ b/naga/tests/in/subgroup-operations.wgsl @@ -8,6 +8,7 @@ fn main( subgroupBarrier(); subgroupBallot((subgroup_invocation_id & 1u) == 1u); + subgroupBallot(); subgroupAll(subgroup_invocation_id != 0u); subgroupAny(subgroup_invocation_id == 0u); diff --git a/naga/tests/out/glsl/subgroup-operations.main.Compute.glsl b/naga/tests/out/glsl/subgroup-operations.main.Compute.glsl index a37cf8e247..9a92460a89 100644 --- a/naga/tests/out/glsl/subgroup-operations.main.Compute.glsl +++ b/naga/tests/out/glsl/subgroup-operations.main.Compute.glsl @@ -17,25 +17,26 @@ void main() { subgroupMemoryBarrier(); barrier(); uvec4 _e8 = subgroupBallot(((subgroup_invocation_id & 1u) == 1u)); - bool _e11 = subgroupAll((subgroup_invocation_id != 0u)); - bool _e14 = subgroupAny((subgroup_invocation_id == 0u)); - uint _e15 = subgroupAdd(subgroup_invocation_id); - uint _e16 = subgroupMul(subgroup_invocation_id); - uint _e17 = subgroupMin(subgroup_invocation_id); - uint _e18 = subgroupMax(subgroup_invocation_id); - uint _e19 = subgroupAnd(subgroup_invocation_id); - uint _e20 = subgroupOr(subgroup_invocation_id); - uint _e21 = subgroupXor(subgroup_invocation_id); - uint _e22 = subgroupExclusiveAdd(subgroup_invocation_id); - uint _e23 = subgroupExclusiveMul(subgroup_invocation_id); - uint _e24 = subgroupInclusiveAdd(subgroup_invocation_id); - uint _e25 = subgroupInclusiveMul(subgroup_invocation_id); - uint _e26 = subgroupBroadcastFirst(subgroup_invocation_id); - uint _e28 = subgroupBroadcast(subgroup_invocation_id, 4u); - uint _e32 = subgroupShuffle(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); - uint _e34 = subgroupShuffleDown(subgroup_invocation_id, 1u); - uint _e36 = subgroupShuffleUp(subgroup_invocation_id, 1u); - uint _e39 = subgroupShuffleXor(subgroup_invocation_id, (subgroup_size - 1u)); + uvec4 _e9 = subgroupBallot(true); + bool _e12 = subgroupAll((subgroup_invocation_id != 0u)); + bool _e15 = subgroupAny((subgroup_invocation_id == 0u)); + uint _e16 = subgroupAdd(subgroup_invocation_id); + uint _e17 = subgroupMul(subgroup_invocation_id); + uint _e18 = subgroupMin(subgroup_invocation_id); + uint _e19 = subgroupMax(subgroup_invocation_id); + uint _e20 = subgroupAnd(subgroup_invocation_id); + uint _e21 = subgroupOr(subgroup_invocation_id); + uint _e22 = subgroupXor(subgroup_invocation_id); + uint _e23 = subgroupExclusiveAdd(subgroup_invocation_id); + uint _e24 = subgroupExclusiveMul(subgroup_invocation_id); + uint _e25 = subgroupInclusiveAdd(subgroup_invocation_id); + uint _e26 = subgroupInclusiveMul(subgroup_invocation_id); + uint _e27 = subgroupBroadcastFirst(subgroup_invocation_id); + uint _e29 = subgroupBroadcast(subgroup_invocation_id, 4u); + uint _e33 = subgroupShuffle(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + uint _e35 = subgroupShuffleDown(subgroup_invocation_id, 1u); + uint _e37 = subgroupShuffleUp(subgroup_invocation_id, 1u); + uint _e40 = subgroupShuffleXor(subgroup_invocation_id, (subgroup_size - 1u)); return; } diff --git a/naga/tests/out/hlsl/subgroup-operations.hlsl b/naga/tests/out/hlsl/subgroup-operations.hlsl index baa37826e0..65d3a51a92 100644 --- a/naga/tests/out/hlsl/subgroup-operations.hlsl +++ b/naga/tests/out/hlsl/subgroup-operations.hlsl @@ -9,24 +9,25 @@ void main(uint3 __local_invocation_id : SV_GroupThreadID) const uint subgroup_size = WaveGetLaneCount(); const uint subgroup_invocation_id = WaveGetLaneIndex(); const uint4 _e8 = WaveActiveBallot(((subgroup_invocation_id & 1u) == 1u)); - const bool _e11 = WaveActiveAllTrue((subgroup_invocation_id != 0u)); - const bool _e14 = WaveActiveAnyTrue((subgroup_invocation_id == 0u)); - const uint _e15 = WaveActiveSum(subgroup_invocation_id); - const uint _e16 = WaveActiveProduct(subgroup_invocation_id); - const uint _e17 = WaveActiveMin(subgroup_invocation_id); - const uint _e18 = WaveActiveMax(subgroup_invocation_id); - const uint _e19 = WaveActiveBitAnd(subgroup_invocation_id); - const uint _e20 = WaveActiveBitOr(subgroup_invocation_id); - const uint _e21 = WaveActiveBitXor(subgroup_invocation_id); - const uint _e22 = WavePrefixSum(subgroup_invocation_id); - const uint _e23 = WavePrefixProduct(subgroup_invocation_id); - const uint _e24 = subgroup_invocation_id + WavePrefixSum(subgroup_invocation_id); - const uint _e25 = subgroup_invocation_id * WavePrefixProduct(subgroup_invocation_id); - const uint _e26 = WaveReadLaneFirst(subgroup_invocation_id); - const uint _e28 = WaveReadLaneAt(subgroup_invocation_id, 4u); - const uint _e32 = WaveReadLaneAt(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); - const uint _e34 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() + 1u); - const uint _e36 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() - 1u); - const uint _e39 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() ^ (subgroup_size - 1u)); + const uint4 _e9 = WaveActiveBallot(true); + const bool _e12 = WaveActiveAllTrue((subgroup_invocation_id != 0u)); + const bool _e15 = WaveActiveAnyTrue((subgroup_invocation_id == 0u)); + const uint _e16 = WaveActiveSum(subgroup_invocation_id); + const uint _e17 = WaveActiveProduct(subgroup_invocation_id); + const uint _e18 = WaveActiveMin(subgroup_invocation_id); + const uint _e19 = WaveActiveMax(subgroup_invocation_id); + const uint _e20 = WaveActiveBitAnd(subgroup_invocation_id); + const uint _e21 = WaveActiveBitOr(subgroup_invocation_id); + const uint _e22 = WaveActiveBitXor(subgroup_invocation_id); + const uint _e23 = WavePrefixSum(subgroup_invocation_id); + const uint _e24 = WavePrefixProduct(subgroup_invocation_id); + const uint _e25 = subgroup_invocation_id + WavePrefixSum(subgroup_invocation_id); + const uint _e26 = subgroup_invocation_id * WavePrefixProduct(subgroup_invocation_id); + const uint _e27 = WaveReadLaneFirst(subgroup_invocation_id); + const uint _e29 = WaveReadLaneAt(subgroup_invocation_id, 4u); + const uint _e33 = WaveReadLaneAt(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + const uint _e35 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() + 1u); + const uint _e37 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() - 1u); + const uint _e40 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() ^ (subgroup_size - 1u)); return; } diff --git a/naga/tests/out/msl/subgroup-operations.msl b/naga/tests/out/msl/subgroup-operations.msl index 576fa3b84e..fe41696892 100644 --- a/naga/tests/out/msl/subgroup-operations.msl +++ b/naga/tests/out/msl/subgroup-operations.msl @@ -15,24 +15,25 @@ kernel void main_( ) { metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup); metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((subgroup_invocation_id & 1u) == 1u), 0, 0, 0); - bool unnamed_1 = metal::simd_all(subgroup_invocation_id != 0u); - bool unnamed_2 = metal::simd_any(subgroup_invocation_id == 0u); - uint unnamed_3 = metal::simd_sum(subgroup_invocation_id); - uint unnamed_4 = metal::simd_product(subgroup_invocation_id); - uint unnamed_5 = metal::simd_min(subgroup_invocation_id); - uint unnamed_6 = metal::simd_max(subgroup_invocation_id); - uint unnamed_7 = metal::simd_and(subgroup_invocation_id); - uint unnamed_8 = metal::simd_or(subgroup_invocation_id); - uint unnamed_9 = metal::simd_xor(subgroup_invocation_id); - uint unnamed_10 = metal::simd_prefix_exclusive_sum(subgroup_invocation_id); - uint unnamed_11 = metal::simd_prefix_exclusive_product(subgroup_invocation_id); - uint unnamed_12 = metal::simd_prefix_inclusive_sum(subgroup_invocation_id); - uint unnamed_13 = metal::simd_prefix_inclusive_product(subgroup_invocation_id); - uint unnamed_14 = metal::simd_broadcast_first(subgroup_invocation_id); - uint unnamed_15 = metal::simd_broadcast(subgroup_invocation_id, 4u); - uint unnamed_16 = metal::simd_shuffle(subgroup_invocation_id, (subgroup_size - 1u) - subgroup_invocation_id); - uint unnamed_17 = metal::simd_shuffle_down(subgroup_invocation_id, 1u); - uint unnamed_18 = metal::simd_shuffle_up(subgroup_invocation_id, 1u); - uint unnamed_19 = metal::simd_shuffle_xor(subgroup_invocation_id, subgroup_size - 1u); + metal::uint4 unnamed_1 = uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0); + bool unnamed_2 = metal::simd_all(subgroup_invocation_id != 0u); + bool unnamed_3 = metal::simd_any(subgroup_invocation_id == 0u); + uint unnamed_4 = metal::simd_sum(subgroup_invocation_id); + uint unnamed_5 = metal::simd_product(subgroup_invocation_id); + uint unnamed_6 = metal::simd_min(subgroup_invocation_id); + uint unnamed_7 = metal::simd_max(subgroup_invocation_id); + uint unnamed_8 = metal::simd_and(subgroup_invocation_id); + uint unnamed_9 = metal::simd_or(subgroup_invocation_id); + uint unnamed_10 = metal::simd_xor(subgroup_invocation_id); + uint unnamed_11 = metal::simd_prefix_exclusive_sum(subgroup_invocation_id); + uint unnamed_12 = metal::simd_prefix_exclusive_product(subgroup_invocation_id); + uint unnamed_13 = metal::simd_prefix_inclusive_sum(subgroup_invocation_id); + uint unnamed_14 = metal::simd_prefix_inclusive_product(subgroup_invocation_id); + uint unnamed_15 = metal::simd_broadcast_first(subgroup_invocation_id); + uint unnamed_16 = metal::simd_broadcast(subgroup_invocation_id, 4u); + uint unnamed_17 = metal::simd_shuffle(subgroup_invocation_id, (subgroup_size - 1u) - subgroup_invocation_id); + uint unnamed_18 = metal::simd_shuffle_down(subgroup_invocation_id, 1u); + uint unnamed_19 = metal::simd_shuffle_up(subgroup_invocation_id, 1u); + uint unnamed_20 = metal::simd_shuffle_xor(subgroup_invocation_id, subgroup_size - 1u); return; } diff --git a/naga/tests/out/spv/subgroup-operations.spvasm b/naga/tests/out/spv/subgroup-operations.spvasm index c2023c5473..72c68aa46c 100644 --- a/naga/tests/out/spv/subgroup-operations.spvasm +++ b/naga/tests/out/spv/subgroup-operations.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.3 ; Generator: rspirv -; Bound: 52 +; Bound: 54 OpCapability Shader OpCapability GroupNonUniform OpCapability GroupNonUniformBallot @@ -33,6 +33,7 @@ OpDecorate %13 BuiltIn SubgroupLocalInvocationId %22 = OpConstant %3 2 %23 = OpConstant %3 8 %26 = OpTypeVector %3 4 +%28 = OpConstantTrue %4 %15 = OpFunction %2 None %16 %5 = OpLabel %8 = OpLoad %3 %6 @@ -45,29 +46,30 @@ OpControlBarrier %21 %22 %23 %24 = OpBitwiseAnd %3 %14 %17 %25 = OpIEqual %4 %24 %17 %27 = OpGroupNonUniformBallot %26 %21 %25 -%28 = OpINotEqual %4 %14 %18 -%29 = OpGroupNonUniformAll %4 %21 %28 -%30 = OpIEqual %4 %14 %18 -%31 = OpGroupNonUniformAny %4 %21 %30 -%32 = OpGroupNonUniformIAdd %3 %21 Reduce %14 -%33 = OpGroupNonUniformIMul %3 %21 Reduce %14 -%34 = OpGroupNonUniformUMin %3 %21 Reduce %14 -%35 = OpGroupNonUniformUMax %3 %21 Reduce %14 -%36 = OpGroupNonUniformBitwiseAnd %3 %21 Reduce %14 -%37 = OpGroupNonUniformBitwiseOr %3 %21 Reduce %14 -%38 = OpGroupNonUniformBitwiseXor %3 %21 Reduce %14 -%39 = OpGroupNonUniformIAdd %3 %21 ExclusiveScan %14 -%40 = OpGroupNonUniformIMul %3 %21 ExclusiveScan %14 -%41 = OpGroupNonUniformIAdd %3 %21 InclusiveScan %14 -%42 = OpGroupNonUniformIMul %3 %21 InclusiveScan %14 -%43 = OpGroupNonUniformBroadcastFirst %3 %21 %14 -%44 = OpGroupNonUniformBroadcast %3 %21 %14 %19 -%45 = OpISub %3 %12 %17 -%46 = OpISub %3 %45 %14 -%47 = OpGroupNonUniformShuffle %3 %21 %14 %46 -%48 = OpGroupNonUniformShuffleDown %3 %21 %14 %17 -%49 = OpGroupNonUniformShuffleUp %3 %21 %14 %17 -%50 = OpISub %3 %12 %17 -%51 = OpGroupNonUniformShuffleXor %3 %21 %14 %50 +%29 = OpGroupNonUniformBallot %26 %21 %28 +%30 = OpINotEqual %4 %14 %18 +%31 = OpGroupNonUniformAll %4 %21 %30 +%32 = OpIEqual %4 %14 %18 +%33 = OpGroupNonUniformAny %4 %21 %32 +%34 = OpGroupNonUniformIAdd %3 %21 Reduce %14 +%35 = OpGroupNonUniformIMul %3 %21 Reduce %14 +%36 = OpGroupNonUniformUMin %3 %21 Reduce %14 +%37 = OpGroupNonUniformUMax %3 %21 Reduce %14 +%38 = OpGroupNonUniformBitwiseAnd %3 %21 Reduce %14 +%39 = OpGroupNonUniformBitwiseOr %3 %21 Reduce %14 +%40 = OpGroupNonUniformBitwiseXor %3 %21 Reduce %14 +%41 = OpGroupNonUniformIAdd %3 %21 ExclusiveScan %14 +%42 = OpGroupNonUniformIMul %3 %21 ExclusiveScan %14 +%43 = OpGroupNonUniformIAdd %3 %21 InclusiveScan %14 +%44 = OpGroupNonUniformIMul %3 %21 InclusiveScan %14 +%45 = OpGroupNonUniformBroadcastFirst %3 %21 %14 +%46 = OpGroupNonUniformBroadcast %3 %21 %14 %19 +%47 = OpISub %3 %12 %17 +%48 = OpISub %3 %47 %14 +%49 = OpGroupNonUniformShuffle %3 %21 %14 %48 +%50 = OpGroupNonUniformShuffleDown %3 %21 %14 %17 +%51 = OpGroupNonUniformShuffleUp %3 %21 %14 %17 +%52 = OpISub %3 %12 %17 +%53 = OpGroupNonUniformShuffleXor %3 %21 %14 %52 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/subgroup-operations.wgsl b/naga/tests/out/wgsl/subgroup-operations.wgsl index f12f226387..c53aa3e2cf 100644 --- a/naga/tests/out/wgsl/subgroup-operations.wgsl +++ b/naga/tests/out/wgsl/subgroup-operations.wgsl @@ -1,26 +1,26 @@ @compute @workgroup_size(1, 1, 1) fn main(@builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_size) subgroup_size: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32) { subgroupBarrier(); - let _e8 = subgroupBallot( -((subgroup_invocation_id & 1u) == 1u)); - let _e11 = subgroupAll((subgroup_invocation_id != 0u)); - let _e14 = subgroupAny((subgroup_invocation_id == 0u)); - let _e15 = subgroupAdd(subgroup_invocation_id); - let _e16 = subgroupMul(subgroup_invocation_id); - let _e17 = subgroupMin(subgroup_invocation_id); - let _e18 = subgroupMax(subgroup_invocation_id); - let _e19 = subgroupAnd(subgroup_invocation_id); - let _e20 = subgroupOr(subgroup_invocation_id); - let _e21 = subgroupXor(subgroup_invocation_id); - let _e22 = subgroupPrefixExclusiveAdd(subgroup_invocation_id); - let _e23 = subgroupPrefixExclusiveMul(subgroup_invocation_id); - let _e24 = subgroupPrefixInclusiveAdd(subgroup_invocation_id); - let _e25 = subgroupPrefixInclusiveMul(subgroup_invocation_id); - let _e26 = subgroupBroadcastFirst(subgroup_invocation_id); - let _e28 = subgroupBroadcast(subgroup_invocation_id, 4u); - let _e32 = subgroupShuffle(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); - let _e34 = subgroupShuffleDown(subgroup_invocation_id, 1u); - let _e36 = subgroupShuffleUp(subgroup_invocation_id, 1u); - let _e39 = subgroupShuffleXor(subgroup_invocation_id, (subgroup_size - 1u)); + let _e8 = subgroupBallot(((subgroup_invocation_id & 1u) == 1u)); + let _e9 = subgroupBallot(); + let _e12 = subgroupAll((subgroup_invocation_id != 0u)); + let _e15 = subgroupAny((subgroup_invocation_id == 0u)); + let _e16 = subgroupAdd(subgroup_invocation_id); + let _e17 = subgroupMul(subgroup_invocation_id); + let _e18 = subgroupMin(subgroup_invocation_id); + let _e19 = subgroupMax(subgroup_invocation_id); + let _e20 = subgroupAnd(subgroup_invocation_id); + let _e21 = subgroupOr(subgroup_invocation_id); + let _e22 = subgroupXor(subgroup_invocation_id); + let _e23 = subgroupPrefixExclusiveAdd(subgroup_invocation_id); + let _e24 = subgroupPrefixExclusiveMul(subgroup_invocation_id); + let _e25 = subgroupPrefixInclusiveAdd(subgroup_invocation_id); + let _e26 = subgroupPrefixInclusiveMul(subgroup_invocation_id); + let _e27 = subgroupBroadcastFirst(subgroup_invocation_id); + let _e29 = subgroupBroadcast(subgroup_invocation_id, 4u); + let _e33 = subgroupShuffle(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + let _e35 = subgroupShuffleDown(subgroup_invocation_id, 1u); + let _e37 = subgroupShuffleUp(subgroup_invocation_id, 1u); + let _e40 = subgroupShuffleXor(subgroup_invocation_id, (subgroup_size - 1u)); return; } From 2aaf0a0a56d64eff6cabe6edc50ca201a2f8b38f Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 21 Oct 2023 16:27:43 -0400 Subject: [PATCH 20/46] subgroup: Add spv-in test --- .../in/spv/subgroup-operations-s.param.ron | 27 +++++++ naga/tests/in/spv/subgroup-operations-s.spv | Bin 0 -> 1356 bytes .../tests/in/spv/subgroup-operations-s.spvasm | 75 ++++++++++++++++++ .../subgroup-operations-s.main.Compute.glsl | 58 ++++++++++++++ .../tests/out/hlsl/subgroup-operations-s.hlsl | 49 ++++++++++++ naga/tests/out/hlsl/subgroup-operations-s.ron | 12 +++ naga/tests/out/msl/subgroup-operations-s.msl | 55 +++++++++++++ .../tests/out/wgsl/subgroup-operations-s.wgsl | 40 ++++++++++ naga/tests/snapshots.rs | 5 ++ 9 files changed, 321 insertions(+) create mode 100644 naga/tests/in/spv/subgroup-operations-s.param.ron create mode 100644 naga/tests/in/spv/subgroup-operations-s.spv create mode 100644 naga/tests/in/spv/subgroup-operations-s.spvasm create mode 100644 naga/tests/out/glsl/subgroup-operations-s.main.Compute.glsl create mode 100644 naga/tests/out/hlsl/subgroup-operations-s.hlsl create mode 100644 naga/tests/out/hlsl/subgroup-operations-s.ron create mode 100644 naga/tests/out/msl/subgroup-operations-s.msl create mode 100644 naga/tests/out/wgsl/subgroup-operations-s.wgsl diff --git a/naga/tests/in/spv/subgroup-operations-s.param.ron b/naga/tests/in/spv/subgroup-operations-s.param.ron new file mode 100644 index 0000000000..122542d1f6 --- /dev/null +++ b/naga/tests/in/spv/subgroup-operations-s.param.ron @@ -0,0 +1,27 @@ +( + god_mode: true, + spv: ( + version: (1, 3), + ), + msl: ( + lang_version: (2, 4), + per_entry_point_map: {}, + inline_samplers: [], + spirv_cross_compatibility: false, + fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, + ), + glsl: ( + version: Desktop(430), + writer_flags: (""), + binding_map: { }, + zero_initialize_workgroup_memory: true, + ), + hlsl: ( + shader_model: V6_0, + binding_map: {}, + fake_missing_bindings: true, + special_constants_binding: None, + zero_initialize_workgroup_memory: true, + ), +) diff --git a/naga/tests/in/spv/subgroup-operations-s.spv b/naga/tests/in/spv/subgroup-operations-s.spv new file mode 100644 index 0000000000000000000000000000000000000000..d4bf0191db2b92ed7c0d2c19a794e736d4863225 GIT binary patch literal 1356 zcmZ9M$xc*J5Qa}T+-6c{M8%;QL=kO8!I%)7Hf)FlS;>$N2*JCco6xiimp%?Yh_B&e z$il?`J5-V5y`-qB@6>b#=qOKjg-{B2LI|%Ud_$oua;aBLzcc;D^jp*KO@EN?3ze9+ zy0*DiYn>g`7MGq2hyKukifaR*CuFh*B*%Ms174BzNctq#C4(ZY4@sEAB@-vM(LS!X z+WSYR&Gt!4ex@Jtoz?RLkxHz0#aj9I!x-BgdtC9n@vZq~?<{XNNK;M=8OGTSe8q5e$?Iqd&D z-fRfBjKeldS)6}DHuRrKQQZ4&S?p30%lcpa<-L~t1A +#include + +using metal::uint; + + +void main_1( + thread uint& subgroup_size_1, + thread uint& subgroup_invocation_id_1 +) { + uint _e5 = subgroup_size_1; + uint _e6 = subgroup_invocation_id_1; + metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((_e6 & 1u) == 1u), 0, 0, 0); + metal::uint4 unnamed_1 = uint4((uint64_t)metal::simd_ballot(true), 0, 0, 0); + bool unnamed_2 = metal::simd_all(_e6 != 0u); + bool unnamed_3 = metal::simd_any(_e6 == 0u); + uint unnamed_4 = metal::simd_sum(_e6); + uint unnamed_5 = metal::simd_product(_e6); + uint unnamed_6 = metal::simd_min(_e6); + uint unnamed_7 = metal::simd_max(_e6); + uint unnamed_8 = metal::simd_and(_e6); + uint unnamed_9 = metal::simd_or(_e6); + uint unnamed_10 = metal::simd_xor(_e6); + uint unnamed_11 = metal::simd_prefix_exclusive_sum(_e6); + uint unnamed_12 = metal::simd_prefix_exclusive_product(_e6); + uint unnamed_13 = metal::simd_prefix_inclusive_sum(_e6); + uint unnamed_14 = metal::simd_prefix_inclusive_product(_e6); + uint unnamed_15 = metal::simd_broadcast_first(_e6); + uint unnamed_16 = metal::simd_broadcast(_e6, 4u); + uint unnamed_17 = metal::simd_shuffle(_e6, (_e5 - 1u) - _e6); + uint unnamed_18 = metal::simd_shuffle_down(_e6, 1u); + uint unnamed_19 = metal::simd_shuffle_up(_e6, 1u); + uint unnamed_20 = metal::simd_shuffle_xor(_e6, _e5 - 1u); + return; +} + +struct main_Input { +}; +kernel void main_( + uint num_subgroups [[simdgroups_per_threadgroup]] +, uint subgroup_id [[simdgroup_index_in_threadgroup]] +, uint subgroup_size [[threads_per_simdgroup]] +, uint subgroup_invocation_id [[thread_index_in_simdgroup]] +) { + uint num_subgroups_1 = {}; + uint subgroup_id_1 = {}; + uint subgroup_size_1 = {}; + uint subgroup_invocation_id_1 = {}; + num_subgroups_1 = num_subgroups; + subgroup_id_1 = subgroup_id; + subgroup_size_1 = subgroup_size; + subgroup_invocation_id_1 = subgroup_invocation_id; + main_1(subgroup_size_1, subgroup_invocation_id_1); +} diff --git a/naga/tests/out/wgsl/subgroup-operations-s.wgsl b/naga/tests/out/wgsl/subgroup-operations-s.wgsl new file mode 100644 index 0000000000..54e3d60b3a --- /dev/null +++ b/naga/tests/out/wgsl/subgroup-operations-s.wgsl @@ -0,0 +1,40 @@ +var num_subgroups_1: u32; +var subgroup_id_1: u32; +var subgroup_size_1: u32; +var subgroup_invocation_id_1: u32; + +fn main_1() { + let _e5 = subgroup_size_1; + let _e6 = subgroup_invocation_id_1; + let _e9 = subgroupBallot(((_e6 & 1u) == 1u)); + let _e10 = subgroupBallot(); + let _e12 = subgroupAll((_e6 != 0u)); + let _e14 = subgroupAny((_e6 == 0u)); + let _e15 = subgroupAdd(_e6); + let _e16 = subgroupMul(_e6); + let _e17 = subgroupMin(_e6); + let _e18 = subgroupMax(_e6); + let _e19 = subgroupAnd(_e6); + let _e20 = subgroupOr(_e6); + let _e21 = subgroupXor(_e6); + let _e22 = subgroupPrefixExclusiveAdd(_e6); + let _e23 = subgroupPrefixExclusiveMul(_e6); + let _e24 = subgroupPrefixInclusiveAdd(_e6); + let _e25 = subgroupPrefixInclusiveMul(_e6); + let _e26 = subgroupBroadcastFirst(_e6); + let _e27 = subgroupBroadcast(_e6, 4u); + let _e30 = subgroupShuffle(_e6, ((_e5 - 1u) - _e6)); + let _e31 = subgroupShuffleDown(_e6, 1u); + let _e32 = subgroupShuffleUp(_e6, 1u); + let _e34 = subgroupShuffleXor(_e6, (_e5 - 1u)); + return; +} + +@compute @workgroup_size(1, 1, 1) +fn main(@builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_size) subgroup_size: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32) { + num_subgroups_1 = num_subgroups; + subgroup_id_1 = subgroup_id; + subgroup_size_1 = subgroup_size; + subgroup_invocation_id_1 = subgroup_invocation_id; + main_1(); +} diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 67573bbbaa..3097f5457d 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -857,6 +857,11 @@ fn convert_spv_all() { true, Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ); + convert_spv( + "subgroup-operations-s", + false, + Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ); } #[cfg(feature = "glsl-in")] From fb1e3f911962bfcf6f52670b08bc8c601bb07f32 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 21 Oct 2023 16:44:14 -0400 Subject: [PATCH 21/46] subgroup: resolve fixmes & fix typo --- naga/src/compact/expressions.rs | 4 ++-- naga/src/front/spv/error.rs | 2 +- naga/src/valid/analyzer.rs | 4 ++-- naga/src/valid/expression.rs | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/naga/src/compact/expressions.rs b/naga/src/compact/expressions.rs index 1533362407..47b4155679 100644 --- a/naga/src/compact/expressions.rs +++ b/naga/src/compact/expressions.rs @@ -55,7 +55,7 @@ impl<'tracer> ExpressionTracer<'tracer> { | Ex::GlobalVariable(_) | Ex::LocalVariable(_) | Ex::CallResult(_) - | Ex::SubgroupBallotResult // FIXME: ??? + | Ex::SubgroupBallotResult | Ex::RayQueryProceedResult => {} Ex::Constant(handle) => { @@ -224,7 +224,7 @@ impl ModuleMap { | Ex::GlobalVariable(_) | Ex::LocalVariable(_) | Ex::CallResult(_) - | Ex::SubgroupBallotResult // FIXME: ??? + | Ex::SubgroupBallotResult | Ex::RayQueryProceedResult => {} // Expressions that contain handles that need to be adjusted. diff --git a/naga/src/front/spv/error.rs b/naga/src/front/spv/error.rs index 8508ede042..cc6cd98801 100644 --- a/naga/src/front/spv/error.rs +++ b/naga/src/front/spv/error.rs @@ -54,7 +54,7 @@ pub enum Error { UnknownBinaryOperator(spirv::Op), #[error("unknown relational function {0:?}")] UnknownRelationalFunction(spirv::Op), - #[error("unsupported group opeation %{0}")] + #[error("unsupported group operation %{0}")] UnsupportedGroupOperation(spirv::Word), #[error("invalid parameter {0:?}")] InvalidParameter(spirv::Op), diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index f4347df1dd..e10b2b7eb6 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -741,11 +741,11 @@ impl FunctionInfo { requirements: UniformityRequirements::empty(), }, E::SubgroupBallotResult => Uniformity { - non_uniform_result: None, // FIXME + non_uniform_result: None, requirements: UniformityRequirements::empty(), }, E::SubgroupOperationResult { .. } => Uniformity { - non_uniform_result: None, // FIXME + non_uniform_result: None, requirements: UniformityRequirements::empty(), }, }; diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 03ad851dbf..b76091a122 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -1538,7 +1538,7 @@ impl super::Validator { } }, E::SubgroupBallotResult => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, - E::SubgroupOperationResult { .. } => ShaderStages::COMPUTE, // FIXME + E::SubgroupOperationResult { .. } => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, }; Ok(stages) } From 209225ae7a205c15c239bc135d8d1d54280be01c Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 21 Oct 2023 17:03:13 -0400 Subject: [PATCH 22/46] subgroup: Add subgroup capability --- naga/src/valid/function.rs | 30 ++++++++++++++++++++- naga/src/valid/interface.rs | 4 +++ naga/src/valid/mod.rs | 2 ++ naga/tests/in/subgroup-operations.param.ron | 1 + 4 files changed, 36 insertions(+), 1 deletion(-) diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 729a6405c2..09e368a7db 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -148,6 +148,8 @@ pub enum FunctionError { InvalidRayDescriptor(Handle), #[error("Ray Query {0:?} does not have a matching type")] InvalidRayQueryType(Handle), + #[error("Shader requires capability {0:?}")] + MissingCapability(super::Capabilities), #[error( "Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}" )] @@ -746,8 +748,16 @@ impl super::Validator { stages &= super::ShaderStages::FRAGMENT; finished = true; } - S::Barrier(_) => { + S::Barrier(barrier) => { stages &= super::ShaderStages::COMPUTE; + if barrier.contains(crate::Barrier::SUB_GROUP) + && !self.capabilities.contains(super::Capabilities::SUBGROUP) + { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "subgroup operation")); + } } S::Store { pointer, value } => { let mut current = pointer; @@ -1038,6 +1048,12 @@ impl super::Validator { } } S::SubgroupBallot { result, predicate } => { + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "subgroup operation")); + } if let Some(predicate) = predicate { let predicate_inner = context.resolve_type(predicate, &self.valid_expression_set)?; @@ -1065,6 +1081,12 @@ impl super::Validator { argument, result, } => { + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "subgroup operation")); + } self.validate_subgroup_operation(op, collective_op, argument, result, context)?; } S::SubgroupGather { @@ -1072,6 +1094,12 @@ impl super::Validator { argument, result, } => { + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "subgroup operation")); + } self.validate_subgroup_broadcast(mode, argument, result, context)?; } } diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 4b5f66492c..29bbf3e989 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -168,6 +168,10 @@ impl VaryingContext<'_> { Bi::PrimitiveIndex => Capabilities::PRIMITIVE_INDEX, Bi::ViewIndex => Capabilities::MULTIVIEW, Bi::SampleIndex => Capabilities::MULTISAMPLED_SHADING, + Bi::NumSubgroups + | Bi::SubgroupId + | Bi::SubgroupSize + | Bi::SubgroupInvocationId => Capabilities::SUBGROUP, _ => Capabilities::empty(), }; if !self.capabilities.contains(required) { diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 2fb0a72775..59f1f0e630 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -117,6 +117,8 @@ bitflags::bitflags! { const DUAL_SOURCE_BLENDING = 0x2000; /// Support for arrayed cube textures. const CUBE_ARRAY_TEXTURES = 0x4000; + /// Support for subgroup operations + const SUBGROUP = 0x8000; } } diff --git a/naga/tests/in/subgroup-operations.param.ron b/naga/tests/in/subgroup-operations.param.ron index fc444a3efe..122542d1f6 100644 --- a/naga/tests/in/subgroup-operations.param.ron +++ b/naga/tests/in/subgroup-operations.param.ron @@ -1,4 +1,5 @@ ( + god_mode: true, spv: ( version: (1, 3), ), From 77d33b93b90126f24642ea245e3698d154825423 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Wed, 18 Oct 2023 08:24:37 -0400 Subject: [PATCH 23/46] subgroup: Treat subgroup operation results as non-uniform --- naga/src/valid/analyzer.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index e10b2b7eb6..55838e4953 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -741,11 +741,11 @@ impl FunctionInfo { requirements: UniformityRequirements::empty(), }, E::SubgroupBallotResult => Uniformity { - non_uniform_result: None, + non_uniform_result: Some(handle), requirements: UniformityRequirements::empty(), }, E::SubgroupOperationResult { .. } => Uniformity { - non_uniform_result: None, + non_uniform_result: Some(handle), requirements: UniformityRequirements::empty(), }, }; From 1e51650d577e12a969a9bf05aa68500c0ff9ecee Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 21 Oct 2023 17:23:34 -0400 Subject: [PATCH 24/46] subgroup: refactor wgsl subgroup gather parsing --- naga/src/front/wgsl/lower/mod.rs | 53 +++++++++++++++++++++++-------- naga/src/front/wgsl/parse/conv.rs | 15 --------- 2 files changed, 40 insertions(+), 28 deletions(-) diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 50875dfad0..150d56bc5a 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -829,6 +829,29 @@ impl Texture { } } +enum SubgroupGather { + BroadcastFirst, + Broadcast, + Shuffle, + ShuffleDown, + ShuffleUp, + ShuffleXor, +} + +impl SubgroupGather { + pub fn map(word: &str) -> Option { + Some(match word { + "subgroupBroadcastFirst" => Self::BroadcastFirst, + "subgroupBroadcast" => Self::Broadcast, + "subgroupShuffle" => Self::Shuffle, + "subgroupShuffleDown" => Self::ShuffleDown, + "subgroupShuffleUp" => Self::ShuffleUp, + "subgroupShuffleXor" => Self::ShuffleXor, + _ => return None, + }) + } +} + pub struct Lowerer<'source, 'temp> { index: &'temp Index<'source>, layouter: Layouter, @@ -1870,7 +1893,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { return Ok(Some( self.subgroup_operation_helper(span, op, cop, arguments, ctx)?, )); - } else if let Some(mode) = conv::map_subgroup_gather(function.name) { + } else if let Some(mode) = SubgroupGather::map(function.name) { return Ok(Some( self.subgroup_gather_helper(span, mode, arguments, ctx)?, )); @@ -2497,18 +2520,29 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { fn subgroup_gather_helper( &mut self, span: Span, - mode: crate::GatherMode, + mode: SubgroupGather, arguments: &[Handle>], ctx: &mut ExpressionContext<'source, '_, '_>, ) -> Result, Error<'source>> { let mut args = ctx.prepare_args(arguments, 2, span); let argument = self.expression(args.next()?, ctx)?; - let index = if let crate::GatherMode::BroadcastFirst = mode { - Handle::new(NonZeroU32::new(u32::MAX).unwrap()) + + use SubgroupGather as Sg; + let mode = if let Sg::BroadcastFirst = mode { + crate::GatherMode::BroadcastFirst } else { - self.expression(args.next()?, ctx)? + let index = self.expression(args.next()?, ctx)?; + match mode { + Sg::Broadcast => crate::GatherMode::Broadcast(index), + Sg::Shuffle => crate::GatherMode::Shuffle(index), + Sg::ShuffleDown => crate::GatherMode::ShuffleDown(index), + Sg::ShuffleUp => crate::GatherMode::ShuffleUp(index), + Sg::ShuffleXor => crate::GatherMode::ShuffleXor(index), + Sg::BroadcastFirst => unreachable!(), + } }; + args.finish()?; let ty = ctx.register_type(argument)?; @@ -2518,14 +2552,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( crate::Statement::SubgroupGather { - mode: match mode { - crate::GatherMode::BroadcastFirst => crate::GatherMode::BroadcastFirst, - crate::GatherMode::Broadcast(_) => crate::GatherMode::Broadcast(index), - crate::GatherMode::Shuffle(_) => crate::GatherMode::Shuffle(index), - crate::GatherMode::ShuffleDown(_) => crate::GatherMode::ShuffleDown(index), - crate::GatherMode::ShuffleUp(_) => crate::GatherMode::ShuffleUp(index), - crate::GatherMode::ShuffleXor(_) => crate::GatherMode::ShuffleXor(index), - }, + mode, argument, result, }, diff --git a/naga/src/front/wgsl/parse/conv.rs b/naga/src/front/wgsl/parse/conv.rs index c53f4df753..61fd1bb37e 100644 --- a/naga/src/front/wgsl/parse/conv.rs +++ b/naga/src/front/wgsl/parse/conv.rs @@ -263,18 +263,3 @@ pub fn map_subgroup_operation( _ => return None, }) } - -pub fn map_subgroup_gather(word: &str) -> Option { - use crate::GatherMode as gm; - use crate::Handle; - use std::num::NonZeroU32; - Some(match word { - "subgroupBroadcastFirst" => gm::BroadcastFirst, - "subgroupBroadcast" => gm::Broadcast(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), - "subgroupShuffle" => gm::Shuffle(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), - "subgroupShuffleDown" => gm::ShuffleDown(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), - "subgroupShuffleUp" => gm::ShuffleUp(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), - "subgroupShuffleXor" => gm::ShuffleXor(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), - _ => return None, - }) -} From 68fc4d5d7ddcc63fd2fe6be9a6bc906c490cd293 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 21 Oct 2023 18:18:14 -0400 Subject: [PATCH 25/46] subgroup: doc comments for subgroup `Statement`s and `Expression`s --- naga/src/lib.rs | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 08f535f565..c42856a553 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -1446,9 +1446,7 @@ pub enum Expression { /// /// For [`TypeInner::Atomic`] the result is a corresponding scalar. /// For other types behind the `pointer`, the result is `T`. - Load { - pointer: Handle, - }, + Load { pointer: Handle }, /// Sample a point from a sampled or a depth image. ImageSample { image: Handle, @@ -1588,10 +1586,7 @@ pub enum Expression { /// Result of calling another function. CallResult(Handle), /// Result of an atomic operation. - AtomicResult { - ty: Handle, - comparison: bool, - }, + AtomicResult { ty: Handle, comparison: bool }, /// Result of a [`WorkGroupUniformLoad`] statement. /// /// [`WorkGroupUniformLoad`]: Statement::WorkGroupUniformLoad @@ -1619,10 +1614,15 @@ pub enum Expression { query: Handle, committed: bool, }, + /// Result of a [`SubgroupBallot`] statement. + /// + /// [`SubgroupBallot`]: Statement::SubgroupBallot SubgroupBallotResult, - SubgroupOperationResult { - ty: Handle, - }, + /// Result of a [`SubgroupCollectiveOperation`] or [`SubgroupGather`] statement. + /// + /// [`SubgroupCollectiveOperation`]: Statement::SubgroupCollectiveOperation + /// [`SubgroupGather`]: Statement::SubgroupGather + SubgroupOperationResult { ty: Handle }, } pub use block::Block; @@ -1895,7 +1895,7 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, - // subgroupBallot(bool) -> vec4 + /// Calculate a bitmask using a boolean from each active thread in the subgroup SubgroupBallot { /// The [`SubgroupBallotResult`] expression representing this load's result. /// @@ -1904,7 +1904,7 @@ pub enum Statement { /// The value from this thread to store in the ballot predicate: Option>, }, - + /// Gather a value from another active thread in the subgroup SubgroupGather { /// Specifies which thread to gather from mode: GatherMode, @@ -1915,8 +1915,7 @@ pub enum Statement { /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult result: Handle, }, - - /// Compute a collective operation across all active threads in th subgroup + /// Compute a collective operation across all active threads in the subgroup SubgroupCollectiveOperation { /// What operation to compute op: SubgroupOperation, From 91c569da001134a60e9ccf27e97f990e505f3f64 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Mon, 23 Oct 2023 17:20:35 -0700 Subject: [PATCH 26/46] subgroup: add validation for each subgroup operation type supported operations and stages subgroup operations are supported on can be passed to the validator after creating it operations are grouped to follow vulkan: - basic: elect, barrier - vote: any, all - arithmetic: reductions, scan - ballot: ballot, broadcasts, - shuffle: shuffles, - shuffle relative: shuffle up, down --- naga-cli/src/bin/naga.rs | 2 ++ naga/src/valid/expression.rs | 3 +- naga/src/valid/function.rs | 61 ++++++++++++++++++++++++++------ naga/src/valid/mod.rs | 67 ++++++++++++++++++++++++++++++++++++ naga/tests/snapshots.rs | 18 ++++++++-- 5 files changed, 136 insertions(+), 15 deletions(-) diff --git a/naga-cli/src/bin/naga.rs b/naga-cli/src/bin/naga.rs index 3b0873a376..26264b0b0a 100644 --- a/naga-cli/src/bin/naga.rs +++ b/naga-cli/src/bin/naga.rs @@ -401,6 +401,8 @@ fn run() -> Result<(), Box> { // Validate the IR before compaction. let info = match naga::valid::Validator::new(params.validation_flags, validation_caps) + .subgroup_stages(naga::valid::ShaderStages::all()) + .subgroup_operations(naga::valid::SubgroupOperationSet::all()) .validate(&module) { Ok(info) => Some(info), diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index b76091a122..231fb9f009 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -1537,8 +1537,7 @@ impl super::Validator { return Err(ExpressionError::InvalidRayQueryType(query)); } }, - E::SubgroupBallotResult => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, - E::SubgroupOperationResult { .. } => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, + E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages, }; Ok(stages) } diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 09e368a7db..90183ba5b2 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -58,6 +58,8 @@ pub enum SubgroupError { InvalidOperand(Handle), #[error("Result type for {0:?} doesn't match the statement")] ResultTypeMismatch(Handle), + #[error("Support for subgroup operation {0:?} is required")] + UnsupportedOperation(super::SubgroupOperationSet), } #[derive(Clone, Debug, thiserror::Error)] @@ -750,13 +752,24 @@ impl super::Validator { } S::Barrier(barrier) => { stages &= super::ShaderStages::COMPUTE; - if barrier.contains(crate::Barrier::SUB_GROUP) - && !self.capabilities.contains(super::Capabilities::SUBGROUP) - { - return Err(FunctionError::MissingCapability( - super::Capabilities::SUBGROUP, - ) - .with_span_static(span, "subgroup operation")); + if barrier.contains(crate::Barrier::SUB_GROUP) { + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + if !self + .subgroup_operations + .contains(super::SubgroupOperationSet::BASIC) + { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation( + super::SubgroupOperationSet::BASIC, + ), + ) + .with_span_static(span, "support for this operation is not present")); + } } } S::Store { pointer, value } => { @@ -1048,11 +1061,23 @@ impl super::Validator { } } S::SubgroupBallot { result, predicate } => { + stages &= self.subgroup_stages; if !self.capabilities.contains(super::Capabilities::SUBGROUP) { return Err(FunctionError::MissingCapability( super::Capabilities::SUBGROUP, ) - .with_span_static(span, "subgroup operation")); + .with_span_static(span, "missing capability for this operation")); + } + if !self + .subgroup_operations + .contains(super::SubgroupOperationSet::BALLOT) + { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation( + super::SubgroupOperationSet::BALLOT, + ), + ) + .with_span_static(span, "support for this operation is not present")); } if let Some(predicate) = predicate { let predicate_inner = @@ -1081,11 +1106,19 @@ impl super::Validator { argument, result, } => { + stages &= self.subgroup_stages; if !self.capabilities.contains(super::Capabilities::SUBGROUP) { return Err(FunctionError::MissingCapability( super::Capabilities::SUBGROUP, ) - .with_span_static(span, "subgroup operation")); + .with_span_static(span, "missing capability for this operation")); + } + let operation = op.required_operations(); + if !self.subgroup_operations.contains(operation) { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation(operation), + ) + .with_span_static(span, "support for this operation is not present")); } self.validate_subgroup_operation(op, collective_op, argument, result, context)?; } @@ -1094,11 +1127,19 @@ impl super::Validator { argument, result, } => { + stages &= self.subgroup_stages; if !self.capabilities.contains(super::Capabilities::SUBGROUP) { return Err(FunctionError::MissingCapability( super::Capabilities::SUBGROUP, ) - .with_span_static(span, "subgroup operation")); + .with_span_static(span, "missing capability for this operation")); + } + let operation = mode.required_operations(); + if !self.subgroup_operations.contains(operation) { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation(operation), + ) + .with_span_static(span, "support for this operation is not present")); } self.validate_subgroup_broadcast(mode, argument, result, context)?; } diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 59f1f0e630..479a739cfb 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -128,6 +128,59 @@ impl Default for Capabilities { } } +bitflags::bitflags! { + /// Supported subgroup operations + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] + pub struct SubgroupOperationSet: u8 { + /// Elect, Barrier + const BASIC = 1 << 0; + /// Any, All + const VOTE = 1 << 1; + /// reductions, scans + const ARITHMETIC = 1 << 2; + /// ballot, broadcast + const BALLOT = 1 << 3; + /// shuffle, shuffle xor + const SHUFFLE = 1 << 4; + /// shuffle up, down + const SHUFFLE_RELATIVE = 1 << 5; + // We don't support these operations yet + // /// Clustered + // const CLUSTERED = 1 << 6; + // /// Quad supported + // const QUAD_FRAMENT_COMPUTE = 1 << 7; + // /// Quad supported in all stages + // const QUAD_ALL_STAGES = 1 << 8; + } +} + +#[cfg(feature = "validate")] +impl super::SubgroupOperation { + const fn required_operations(&self) -> SubgroupOperationSet { + use SubgroupOperationSet as S; + match *self { + Self::All | Self::Any => S::VOTE, + Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => { + S::ARITHMETIC + } + } + } +} + +#[cfg(feature = "validate")] +impl super::GatherMode { + const fn required_operations(&self) -> SubgroupOperationSet { + use SubgroupOperationSet as S; + match *self { + Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT, + Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE, + Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE, + } + } +} + bitflags::bitflags! { /// Validation flags. #[cfg_attr(feature = "serialize", derive(serde::Serialize))] @@ -175,6 +228,8 @@ impl ops::Index> for ModuleInfo { pub struct Validator { flags: ValidationFlags, capabilities: Capabilities, + subgroup_stages: ShaderStages, + subgroup_operations: SubgroupOperationSet, types: Vec, layouter: Layouter, location_mask: BitSet, @@ -291,6 +346,8 @@ impl Validator { Validator { flags, capabilities, + subgroup_stages: ShaderStages::empty(), + subgroup_operations: SubgroupOperationSet::empty(), types: Vec::new(), layouter: Layouter::default(), location_mask: BitSet::new(), @@ -301,6 +358,16 @@ impl Validator { } } + pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self { + self.subgroup_stages = stages; + self + } + + pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self { + self.subgroup_operations = operations; + self + } + /// Reset the validator internals pub fn reset(&mut self) { self.types.clear(); diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 3097f5457d..018d93f6ab 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -258,10 +258,18 @@ fn check_targets( let params = input.read_parameters(); let name = &input.file_name; - let capabilities = if params.god_mode { - naga::valid::Capabilities::all() + let (capabilities, subgroup_stages, subgroup_operations) = if params.god_mode { + ( + naga::valid::Capabilities::all(), + naga::valid::ShaderStages::all(), + naga::valid::SubgroupOperationSet::all(), + ) } else { - naga::valid::Capabilities::default() + ( + naga::valid::Capabilities::default(), + naga::valid::ShaderStages::empty(), + naga::valid::SubgroupOperationSet::empty(), + ) }; #[cfg(feature = "serialize")] @@ -274,6 +282,8 @@ fn check_targets( } let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) + .subgroup_stages(subgroup_stages) + .subgroup_operations(subgroup_operations) .validate(module) .unwrap_or_else(|_| panic!("Naga module validation failed on test '{}'", name.display())); @@ -291,6 +301,8 @@ fn check_targets( } naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) + .subgroup_stages(subgroup_stages) + .subgroup_operations(subgroup_operations) .validate(module) .unwrap_or_else(|_| { panic!( From 6f6f789a57545299f19b8d8c93d9d93486479d14 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 14 Oct 2023 15:23:50 +0200 Subject: [PATCH 27/46] Add feature for subgroup operations in fragment and compute shaders --- CHANGELOG.md | 4 ++++ wgpu-types/src/lib.rs | 12 +++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 725f72b991..413978e6cd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,6 +40,10 @@ Bottom level categories: ## Unreleased +### Added/New Features + +- Add `SUBGROUP_OPERATIONS` feature. By @exrook and @lichtso in [#4240](https://github.com/gfx-rs/wgpu/pull/4240) + For naga changelogs at or before v0.14.0. See [naga's changelog](naga/CHANGELOG.md). ## v0.18.0 (2023-10-25) diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index 82989598ef..4095d24e7a 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -299,7 +299,17 @@ bitflags::bitflags! { /// This is a web and native feature. const SHADER_F16 = 1 << 8; - // 9..14 available + /// Allows shaders to use the subgroup operation built-ins + /// + /// Supported Platforms: + /// - Vulkan + /// - DX12 + /// - Metal + /// + /// This is a web and native feature. + const SUBGROUP_OPERATIONS = 1 << 9; + + // 10..14 available // Texture Formats: From 7d2e273fba5e61be98c246c7b70fe88a4c6c9694 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 30 Sep 2023 18:33:19 -0400 Subject: [PATCH 28/46] Adds feature detection for Vulkan. --- wgpu-hal/src/vulkan/adapter.rs | 46 +++++++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index fd62473fd7..e0a210a787 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -489,6 +489,24 @@ impl PhysicalDeviceFeatures { ); } + if let Some(ref subgroup) = caps.subgroup { + features.set( + F::SUBGROUP_OPERATIONS, + subgroup.supported_operations.contains( + vk::SubgroupFeatureFlags::BASIC + | vk::SubgroupFeatureFlags::VOTE + | vk::SubgroupFeatureFlags::ARITHMETIC + | vk::SubgroupFeatureFlags::BALLOT + | vk::SubgroupFeatureFlags::SHUFFLE + | vk::SubgroupFeatureFlags::SHUFFLE_RELATIVE + | vk::SubgroupFeatureFlags::CLUSTERED + | vk::SubgroupFeatureFlags::QUAD, + ) && subgroup + .supported_stages + .contains(vk::ShaderStageFlags::COMPUTE | vk::ShaderStageFlags::FRAGMENT), + ); + } + let supports_depth_format = |format| { supports_format( instance, @@ -551,6 +569,8 @@ pub struct PhysicalDeviceCapabilities { maintenance_3: Option, descriptor_indexing: Option, driver: Option, + subgroup: Option, + /// The effective driver api version supported by the physical device. /// The device API version. /// /// Which is the version of Vulkan supported for device-level functionality. @@ -816,6 +836,13 @@ impl super::InstanceShared { builder = builder.push_next(next); } + if capabilities.device_api_version >= vk::API_VERSION_1_1 { + let next = capabilities + .subgroup + .insert(vk::PhysicalDeviceSubgroupProperties::default()); + builder = builder.push_next(next); + } + let mut properties2 = builder.build(); unsafe { get_device_properties.get_physical_device_properties2(phd, &mut properties2); @@ -1252,6 +1279,19 @@ impl super::Adapter { capabilities.push(spv::Capability::Geometry); } + if features.contains(wgt::Features::SUBGROUP_OPERATIONS) { + capabilities.push(spv::Capability::GroupNonUniform); + capabilities.push(spv::Capability::GroupNonUniformVote); + capabilities.push(spv::Capability::GroupNonUniformArithmetic); + capabilities.push(spv::Capability::GroupNonUniformBallot); + capabilities.push(spv::Capability::GroupNonUniformShuffle); + capabilities.push(spv::Capability::GroupNonUniformShuffleRelative); + capabilities.push(spv::Capability::GroupNonUniformClustered); + capabilities.push(spv::Capability::GroupNonUniformQuad); + capabilities.push(spv::Capability::SubgroupBallotKHR); + capabilities.push(spv::Capability::SubgroupVoteKHR); + } + if features.intersects( wgt::Features::SAMPLED_TEXTURE_AND_STORAGE_BUFFER_ARRAY_NON_UNIFORM_INDEXING | wgt::Features::UNIFORM_BUFFER_AND_STORAGE_TEXTURE_ARRAY_NON_UNIFORM_INDEXING, @@ -1279,7 +1319,11 @@ impl super::Adapter { true, // could check `super::Workarounds::SEPARATE_ENTRY_POINTS` ); spv::Options { - lang_version: (1, 0), + lang_version: if features.contains(wgt::Features::SUBGROUP_OPERATIONS) { + (1, 3) + } else { + (1, 0) + }, flags, capabilities: Some(capabilities.iter().cloned().collect()), bounds_check_policies: naga::proc::BoundsCheckPolicies { From 55e21e8bb2ac0d7a76e223145cedca371f7cdc62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Thu, 12 Oct 2023 01:09:24 +0200 Subject: [PATCH 29/46] Adds feature detection for Metal. --- wgpu-hal/src/metal/adapter.rs | 10 ++++++++++ wgpu-hal/src/metal/mod.rs | 1 + 2 files changed, 11 insertions(+) diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index 2f48712b9b..028d0baef7 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -800,6 +800,12 @@ impl super::PrivateCapabilities { None }, timestamp_query_support, + supports_simd_scoped_operations: family_check + && (device.supports_family(MTLGPUFamily::Metal3) + || device.supports_family(MTLGPUFamily::Mac2) + || device.supports_family(MTLGPUFamily::Apple7) + || device.supports_family(MTLGPUFamily::Apple8) + || device.supports_family(MTLGPUFamily::Apple9)), } } @@ -878,6 +884,10 @@ impl super::PrivateCapabilities { features.set(F::RG11B10UFLOAT_RENDERABLE, self.format_rg11b10_all); features.set(F::SHADER_UNUSED_VERTEX_OUTPUT, true); + if self.supports_simd_scoped_operations { + features.insert(F::SUBGROUP_OPERATIONS); + } + features } diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index a75439096a..15c11dd120 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -259,6 +259,7 @@ struct PrivateCapabilities { supports_shader_primitive_index: bool, has_unified_memory: Option, timestamp_query_support: TimestampQuerySupport, + supports_simd_scoped_operations: bool, } #[derive(Clone, Debug)] From 2e44fb1f547d07734dceefb4fdecabc15f0f85a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 14 Oct 2023 15:24:39 +0200 Subject: [PATCH 30/46] Adds feature detection for DirectX 12. --- wgpu-hal/src/dx12/adapter.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/wgpu-hal/src/dx12/adapter.rs b/wgpu-hal/src/dx12/adapter.rs index 23bd25e6aa..76e48e7654 100644 --- a/wgpu-hal/src/dx12/adapter.rs +++ b/wgpu-hal/src/dx12/adapter.rs @@ -292,6 +292,12 @@ impl super::Adapter { bgra8unorm_storage_supported, ); + features.set( + wgt::Features::SUBGROUP_OPERATIONS, + shader_model_support.HighestShaderModel >= d3d12_ty::D3D_SHADER_MODEL_6_0 + && matches!(dx12_shader_compiler, &wgt::Dx12Compiler::Dxc { .. }), + ); + // TODO: Determine if IPresentationManager is supported let presentation_timer = auxil::dxgi::time::PresentationTimer::new_dxgi(); From e186d0e1bb039998dc9337940df566331fec65b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Thu, 12 Oct 2023 01:02:08 +0200 Subject: [PATCH 31/46] Adds subgroup_operations tests. --- tests/tests/gpu.rs | 1 + tests/tests/subgroup_operations/mod.rs | 104 +++++++++++++++++++ tests/tests/subgroup_operations/shader.wgsl | 109 ++++++++++++++++++++ 3 files changed, 214 insertions(+) create mode 100644 tests/tests/subgroup_operations/mod.rs create mode 100644 tests/tests/subgroup_operations/shader.wgsl diff --git a/tests/tests/gpu.rs b/tests/tests/gpu.rs index a5fbcde9da..1494d0f128 100644 --- a/tests/tests/gpu.rs +++ b/tests/tests/gpu.rs @@ -27,6 +27,7 @@ mod scissor_tests; mod shader; mod shader_primitive_index; mod shader_view_format; +mod subgroup_operations; mod texture_bounds; mod transfer; mod vertex_indices; diff --git a/tests/tests/subgroup_operations/mod.rs b/tests/tests/subgroup_operations/mod.rs new file mode 100644 index 0000000000..2b6517538b --- /dev/null +++ b/tests/tests/subgroup_operations/mod.rs @@ -0,0 +1,104 @@ +use std::{borrow::Cow, num::NonZeroU64}; + +use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters}; + +const THREAD_COUNT: u64 = 128; + +#[gpu_test] +static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .features(wgpu::Features::SUBGROUP_OPERATIONS) + .limits(wgpu::Limits::downlevel_defaults()) + .expect_fail(wgpu_test::FailureCase::molten_vk()), + ) + .run_sync(|ctx| { + let device = &ctx.device; + + let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size: THREAD_COUNT * std::mem::size_of::() as u64, + usage: wgpu::BufferUsages::STORAGE + | wgpu::BufferUsages::COPY_DST + | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: Some("bind group layout"), + entries: &[wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: NonZeroU64::new( + THREAD_COUNT * std::mem::size_of::() as u64, + ), + }, + count: None, + }], + }); + + let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("shader.wgsl"))), + }); + + let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("main"), + bind_group_layouts: &[&bind_group_layout], + push_constant_ranges: &[], + }); + + let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: Some(&pipeline_layout), + module: &cs_module, + entry_point: "main", + }); + + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + entries: &[wgpu::BindGroupEntry { + binding: 0, + resource: storage_buffer.as_entire_binding(), + }], + layout: &bind_group_layout, + label: Some("bind group"), + }); + + let mut encoder = + device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None }); + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: None, + timestamp_writes: None, + }); + cpass.set_pipeline(&compute_pipeline); + cpass.set_bind_group(0, &bind_group, &[]); + cpass.dispatch_workgroups(THREAD_COUNT as u32, 1, 1); + } + + let mapping_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Mapping buffer"), + size: THREAD_COUNT * std::mem::size_of::() as u64, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + encoder.copy_buffer_to_buffer( + &storage_buffer, + 0, + &mapping_buffer, + 0, + THREAD_COUNT * std::mem::size_of::() as u64, + ); + ctx.queue.submit(Some(encoder.finish())); + + mapping_buffer + .slice(..) + .map_async(wgpu::MapMode::Read, |_| ()); + ctx.device.poll(wgpu::Maintain::Wait); + let mapping_buffer_view = mapping_buffer.slice(..).get_mapped_range(); + let result: &[u32; THREAD_COUNT as usize] = bytemuck::from_bytes(&mapping_buffer_view); + assert_eq!(result, &[27; THREAD_COUNT as usize]); + }); diff --git a/tests/tests/subgroup_operations/shader.wgsl b/tests/tests/subgroup_operations/shader.wgsl new file mode 100644 index 0000000000..70d98dacd4 --- /dev/null +++ b/tests/tests/subgroup_operations/shader.wgsl @@ -0,0 +1,109 @@ +@group(0) +@binding(0) +var storage_buffer: array; + +@compute +@workgroup_size(128) +fn main( + @builtin(global_invocation_id) global_id: vec3, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, +) { + var passed = 0u; + var expected: u32; + + passed += u32(num_subgroups == 128u / subgroup_size); + passed += u32(subgroup_id == global_id.x / subgroup_size); + passed += u32(subgroup_invocation_id == global_id.x % subgroup_size); + + var expected_ballot = vec4(0u); + for(var i = 0u; i < subgroup_size; i += 1u) { + expected_ballot[i / 32u] |= ((global_id.x - subgroup_invocation_id + i) & 1u) << (i % 32u); + } + passed += u32(dot(vec4(1u), vec4(subgroupBallot((subgroup_invocation_id & 1u) == 1u) == expected_ballot)) == 4u); + + passed += u32(subgroupAll(true)); + passed += u32(!subgroupAll(subgroup_invocation_id != 0u)); + + passed += u32(subgroupAny(subgroup_invocation_id == 0u)); + passed += u32(!subgroupAny(false)); + + expected = 0u; + for(var i = 0u; i < subgroup_size; i += 1u) { + expected += global_id.x - subgroup_invocation_id + i + 1u; + } + passed += u32(subgroupAdd(global_id.x + 1u) == expected); + + expected = 1u; + for(var i = 0u; i < subgroup_size; i += 1u) { + expected *= global_id.x - subgroup_invocation_id + i + 1u; + } + passed += u32(subgroupMul(global_id.x + 1u) == expected); + + expected = 0u; + for(var i = 0u; i < subgroup_size; i += 1u) { + expected = max(expected, global_id.x - subgroup_invocation_id + i + 1u); + } + passed += u32(subgroupMax(global_id.x + 1u) == expected); + + expected = 0xFFFFFFFFu; + for(var i = 0u; i < subgroup_size; i += 1u) { + expected = min(expected, global_id.x - subgroup_invocation_id + i + 1u); + } + passed += u32(subgroupMin(global_id.x + 1u) == expected); + + expected = 0xFFFFFFFFu; + for(var i = 0u; i < subgroup_size; i += 1u) { + expected &= global_id.x - subgroup_invocation_id + i + 1u; + } + passed += u32(subgroupAnd(global_id.x + 1u) == expected); + + expected = 0u; + for(var i = 0u; i < subgroup_size; i += 1u) { + expected |= global_id.x - subgroup_invocation_id + i + 1u; + } + passed += u32(subgroupOr(global_id.x + 1u) == expected); + + expected = 0u; + for(var i = 0u; i < subgroup_size; i += 1u) { + expected ^= global_id.x - subgroup_invocation_id + i + 1u; + } + passed += u32(subgroupXor(global_id.x + 1u) == expected); + + expected = 0u; + for(var i = 0u; i < subgroup_invocation_id; i += 1u) { + expected += global_id.x - subgroup_invocation_id + i + 1u; + } + passed += u32(subgroupPrefixExclusiveAdd(global_id.x + 1u) == expected); + + expected = 1u; + for(var i = 0u; i < subgroup_invocation_id; i += 1u) { + expected *= global_id.x - subgroup_invocation_id + i + 1u; + } + passed += u32(subgroupPrefixExclusiveMul(global_id.x + 1u) == expected); + + expected = 0u; + for(var i = 0u; i <= subgroup_invocation_id; i += 1u) { + expected += global_id.x - subgroup_invocation_id + i + 1u; + } + passed += u32(subgroupPrefixInclusiveAdd(global_id.x + 1u) == expected); + + expected = 1u; + for(var i = 0u; i <= subgroup_invocation_id; i += 1u) { + expected *= global_id.x - subgroup_invocation_id + i + 1u; + } + passed += u32(subgroupPrefixInclusiveMul(global_id.x + 1u) == expected); + + passed += u32(subgroupBroadcastFirst(u32(subgroup_invocation_id != 0u)) == 0u); + passed += u32(subgroupBroadcastFirst(u32(subgroup_invocation_id == 0u)) == 1u); + passed += u32(subgroupBroadcast(subgroup_invocation_id, 4u) == 4u); + passed += u32(subgroupShuffle(subgroup_invocation_id, subgroup_invocation_id) == subgroup_invocation_id); + passed += u32(subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id) == subgroup_size - 1u - subgroup_invocation_id); + passed += u32(subgroup_invocation_id == subgroup_size - 1u || subgroupShuffleDown(subgroup_invocation_id, 1u) == subgroup_invocation_id + 1u); + passed += u32(subgroup_invocation_id == 0u || subgroupShuffleUp(subgroup_invocation_id, 1u) == subgroup_invocation_id - 1u); + passed += u32(subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u) == (subgroup_invocation_id ^ (subgroup_size - 1u))); + + storage_buffer[global_id.x] = passed; +} From 9c12f08843c88ae69c2c6be52bf93a8c369b889f Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sun, 22 Oct 2023 16:09:23 -0700 Subject: [PATCH 32/46] Pass subgroup capability to naga --- wgpu-core/src/device/resource.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index fd85fd6a77..a521f3d030 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -1316,6 +1316,10 @@ impl Device { .flags .contains(wgt::DownlevelFlags::CUBE_ARRAY_TEXTURES), ); + caps.set( + Caps::SUBGROUP, + self.features.contains(wgt::Features::SUBGROUP_OPERATIONS), + ); let debug_source = if self.instance_flags.contains(wgt::InstanceFlags::DEBUG) { Some(hal::DebugSource { From 358c2a7d3888fc49740f25a9d82bac2edee91dbc Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Tue, 24 Oct 2023 11:19:51 -0700 Subject: [PATCH 33/46] Separate subgroup feature into one flag per shader stage --- CHANGELOG.md | 2 +- tests/tests/subgroup_operations/mod.rs | 2 +- wgpu-core/src/device/resource.rs | 29 ++++++++++++- wgpu-hal/src/dx12/adapter.rs | 4 +- wgpu-hal/src/metal/adapter.rs | 2 +- wgpu-hal/src/vulkan/adapter.rs | 58 ++++++++++++++++---------- wgpu-types/src/lib.rs | 42 +++++++++++++------ 7 files changed, 101 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 413978e6cd..5bf58ff78d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,7 +42,7 @@ Bottom level categories: ### Added/New Features -- Add `SUBGROUP_OPERATIONS` feature. By @exrook and @lichtso in [#4240](https://github.com/gfx-rs/wgpu/pull/4240) +- Add `SUBGROUP_COMPUTE, SUBGROUP_FRAGMENT, SUBGROUP_VERTEX` features. By @exrook and @lichtso in [#4240](https://github.com/gfx-rs/wgpu/pull/4240) For naga changelogs at or before v0.14.0. See [naga's changelog](naga/CHANGELOG.md). diff --git a/tests/tests/subgroup_operations/mod.rs b/tests/tests/subgroup_operations/mod.rs index 2b6517538b..d0abfb96bb 100644 --- a/tests/tests/subgroup_operations/mod.rs +++ b/tests/tests/subgroup_operations/mod.rs @@ -8,7 +8,7 @@ const THREAD_COUNT: u64 = 128; static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() .parameters( TestParameters::default() - .features(wgpu::Features::SUBGROUP_OPERATIONS) + .features(wgpu::Features::SUBGROUP_COMPUTE) .limits(wgpu::Limits::downlevel_defaults()) .expect_fail(wgpu_test::FailureCase::molten_vk()), ) diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index a521f3d030..76aca5fe9f 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -1318,7 +1318,11 @@ impl Device { ); caps.set( Caps::SUBGROUP, - self.features.contains(wgt::Features::SUBGROUP_OPERATIONS), + self.features.intersects( + wgt::Features::SUBGROUP_COMPUTE + | wgt::Features::SUBGROUP_FRAGMENT + | wgt::Features::SUBGROUP_VERTEX, + ), ); let debug_source = if self.instance_flags.contains(wgt::InstanceFlags::DEBUG) { @@ -1334,7 +1338,30 @@ impl Device { None }; + let mut subgroup_stages = naga::valid::ShaderStages::empty(); + subgroup_stages.set( + naga::valid::ShaderStages::COMPUTE, + self.features.contains(wgt::Features::SUBGROUP_COMPUTE), + ); + subgroup_stages.set( + naga::valid::ShaderStages::FRAGMENT, + self.features.contains(wgt::Features::SUBGROUP_FRAGMENT), + ); + subgroup_stages.set( + naga::valid::ShaderStages::VERTEX, + self.features.contains(wgt::Features::SUBGROUP_VERTEX), + ); + + let subgroup_operations = if caps.contains(Caps::SUBGROUP) { + use naga::valid::SubgroupOperationSet as S; + S::BASIC | S::VOTE | S::ARITHMETIC | S::BALLOT | S::SHUFFLE | S::SHUFFLE_RELATIVE + } else { + naga::valid::SubgroupOperationSet::empty() + }; + let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), caps) + .subgroup_stages(subgroup_stages) + .subgroup_operations(subgroup_operations) .validate(&module) .map_err(|inner| { pipeline::CreateShaderModuleError::Validation(pipeline::ShaderError { diff --git a/wgpu-hal/src/dx12/adapter.rs b/wgpu-hal/src/dx12/adapter.rs index 76e48e7654..24db1fb719 100644 --- a/wgpu-hal/src/dx12/adapter.rs +++ b/wgpu-hal/src/dx12/adapter.rs @@ -293,7 +293,9 @@ impl super::Adapter { ); features.set( - wgt::Features::SUBGROUP_OPERATIONS, + wgt::Features::SUBGROUP_COMPUTE + | wgt::Features::SUBGROUP_FRAGMENT + | wgt::Features::SUBGROUP_VERTEX, shader_model_support.HighestShaderModel >= d3d12_ty::D3D_SHADER_MODEL_6_0 && matches!(dx12_shader_compiler, &wgt::Dx12Compiler::Dxc { .. }), ); diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index 028d0baef7..e56e6e2937 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -885,7 +885,7 @@ impl super::PrivateCapabilities { features.set(F::SHADER_UNUSED_VERTEX_OUTPUT, true); if self.supports_simd_scoped_operations { - features.insert(F::SUBGROUP_OPERATIONS); + features.insert(F::SUBGROUP_COMPUTE | F::SUBGROUP_FRAGMENT | F::SUBGROUP_VERTEX); } features diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index e0a210a787..1cc2ead6ee 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -490,21 +490,33 @@ impl PhysicalDeviceFeatures { } if let Some(ref subgroup) = caps.subgroup { - features.set( - F::SUBGROUP_OPERATIONS, - subgroup.supported_operations.contains( - vk::SubgroupFeatureFlags::BASIC - | vk::SubgroupFeatureFlags::VOTE - | vk::SubgroupFeatureFlags::ARITHMETIC - | vk::SubgroupFeatureFlags::BALLOT - | vk::SubgroupFeatureFlags::SHUFFLE - | vk::SubgroupFeatureFlags::SHUFFLE_RELATIVE - | vk::SubgroupFeatureFlags::CLUSTERED - | vk::SubgroupFeatureFlags::QUAD, - ) && subgroup - .supported_stages - .contains(vk::ShaderStageFlags::COMPUTE | vk::ShaderStageFlags::FRAGMENT), - ); + if subgroup.supported_operations.contains( + vk::SubgroupFeatureFlags::BASIC + | vk::SubgroupFeatureFlags::VOTE + | vk::SubgroupFeatureFlags::ARITHMETIC + | vk::SubgroupFeatureFlags::BALLOT + | vk::SubgroupFeatureFlags::SHUFFLE + | vk::SubgroupFeatureFlags::SHUFFLE_RELATIVE, + ) { + features.set( + F::SUBGROUP_COMPUTE, + subgroup + .supported_stages + .contains(vk::ShaderStageFlags::COMPUTE), + ); + features.set( + F::SUBGROUP_FRAGMENT, + subgroup + .supported_stages + .contains(vk::ShaderStageFlags::FRAGMENT), + ); + features.set( + F::SUBGROUP_VERTEX, + subgroup + .supported_stages + .contains(vk::ShaderStageFlags::VERTEX), + ); + } } let supports_depth_format = |format| { @@ -1279,17 +1291,17 @@ impl super::Adapter { capabilities.push(spv::Capability::Geometry); } - if features.contains(wgt::Features::SUBGROUP_OPERATIONS) { + if features.intersects( + wgt::Features::SUBGROUP_COMPUTE + | wgt::Features::SUBGROUP_FRAGMENT + | wgt::Features::SUBGROUP_VERTEX, + ) { capabilities.push(spv::Capability::GroupNonUniform); capabilities.push(spv::Capability::GroupNonUniformVote); capabilities.push(spv::Capability::GroupNonUniformArithmetic); capabilities.push(spv::Capability::GroupNonUniformBallot); capabilities.push(spv::Capability::GroupNonUniformShuffle); capabilities.push(spv::Capability::GroupNonUniformShuffleRelative); - capabilities.push(spv::Capability::GroupNonUniformClustered); - capabilities.push(spv::Capability::GroupNonUniformQuad); - capabilities.push(spv::Capability::SubgroupBallotKHR); - capabilities.push(spv::Capability::SubgroupVoteKHR); } if features.intersects( @@ -1319,7 +1331,11 @@ impl super::Adapter { true, // could check `super::Workarounds::SEPARATE_ENTRY_POINTS` ); spv::Options { - lang_version: if features.contains(wgt::Features::SUBGROUP_OPERATIONS) { + lang_version: if features.intersects( + wgt::Features::SUBGROUP_COMPUTE + | wgt::Features::SUBGROUP_FRAGMENT + | wgt::Features::SUBGROUP_VERTEX, + ) { (1, 3) } else { (1, 0) diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index 4095d24e7a..c9a9567e35 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -299,17 +299,7 @@ bitflags::bitflags! { /// This is a web and native feature. const SHADER_F16 = 1 << 8; - /// Allows shaders to use the subgroup operation built-ins - /// - /// Supported Platforms: - /// - Vulkan - /// - DX12 - /// - Metal - /// - /// This is a web and native feature. - const SUBGROUP_OPERATIONS = 1 << 9; - - // 10..14 available + // 9..14 available // Texture Formats: @@ -766,10 +756,38 @@ bitflags::bitflags! { /// - OpenGL const SHADER_UNUSED_VERTEX_OUTPUT = 1 << 54; - // 54..59 available + // 55 available // Shader: + /// Allows compute shaders to use the subgroup operation built-ins + /// + /// Supported Platforms: + /// - Vulkan + /// - DX12 + /// - Metal + /// + /// This is a native only feature. + const SUBGROUP_COMPUTE = 1 << 56; + /// Allows fragment shaders to use the subgroup operation built-ins + /// + /// Supported Platforms: + /// - Vulkan + /// - DX12 + /// - Metal + /// + /// This is a native only feature. + const SUBGROUP_FRAGMENT = 1 << 57; + /// Allows vertx shaders to use the subgroup operation built-ins + /// + /// Supported Platforms: + /// - Vulkan + /// - DX12 + /// - Metal + /// + /// This is a native only feature. + const SUBGROUP_VERTEX = 1 << 58; + /// Enables 64-bit floating point types in SPIR-V shaders. /// /// Note: even when supported by GPU hardware, 64-bit floating point operations are From ebaf08d5b3e13555e6bed71f62ce03a4a71be7ca Mon Sep 17 00:00:00 2001 From: Connor Fitzgerald Date: Sat, 28 Oct 2023 01:49:01 -0400 Subject: [PATCH 34/46] Fix compiles --- wgpu-hal/src/dx12/adapter.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/wgpu-hal/src/dx12/adapter.rs b/wgpu-hal/src/dx12/adapter.rs index edc5c75b46..4a0e62e864 100644 --- a/wgpu-hal/src/dx12/adapter.rs +++ b/wgpu-hal/src/dx12/adapter.rs @@ -126,6 +126,11 @@ impl super::Adapter { ) }); + // If we don't have dxc, we reduce the max to 5.1 + if dxc_container.is_none() { + shader_model_support.HighestShaderModel = d3d12_ty::D3D_SHADER_MODEL_5_1; + } + let mut workarounds = super::Workarounds::default(); let info = wgt::AdapterInfo { @@ -296,8 +301,7 @@ impl super::Adapter { wgt::Features::SUBGROUP_COMPUTE | wgt::Features::SUBGROUP_FRAGMENT | wgt::Features::SUBGROUP_VERTEX, - shader_model_support.HighestShaderModel >= d3d12_ty::D3D_SHADER_MODEL_6_0 - && matches!(dx12_shader_compiler, &wgt::Dx12Compiler::Dxc { .. }), + shader_model_support.HighestShaderModel >= d3d12_ty::D3D_SHADER_MODEL_6_0, ); // TODO: Determine if IPresentationManager is supported From afaf44128f4e3a73ae54ca7a8b599dbd7f9580f0 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Wed, 1 Nov 2023 05:27:08 -0700 Subject: [PATCH 35/46] subgroups: DX12 doesn't support subgroup ops in vertex stage --- wgpu-hal/src/dx12/adapter.rs | 4 +--- wgpu-types/src/lib.rs | 3 +-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/wgpu-hal/src/dx12/adapter.rs b/wgpu-hal/src/dx12/adapter.rs index 4a0e62e864..e6c9ed4183 100644 --- a/wgpu-hal/src/dx12/adapter.rs +++ b/wgpu-hal/src/dx12/adapter.rs @@ -298,9 +298,7 @@ impl super::Adapter { ); features.set( - wgt::Features::SUBGROUP_COMPUTE - | wgt::Features::SUBGROUP_FRAGMENT - | wgt::Features::SUBGROUP_VERTEX, + wgt::Features::SUBGROUP_COMPUTE | wgt::Features::SUBGROUP_FRAGMENT, shader_model_support.HighestShaderModel >= d3d12_ty::D3D_SHADER_MODEL_6_0, ); diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index 6377e8f14c..c066db3b33 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -784,11 +784,10 @@ bitflags::bitflags! { /// /// This is a native only feature. const SUBGROUP_FRAGMENT = 1 << 57; - /// Allows vertx shaders to use the subgroup operation built-ins + /// Allows vertex shaders to use the subgroup operation built-ins /// /// Supported Platforms: /// - Vulkan - /// - DX12 /// - Metal /// /// This is a native only feature. From 7e0060e5d657efd1fc001705609b8ac52d7f3395 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 4 Nov 2023 22:47:24 -0400 Subject: [PATCH 36/46] subgroup: fix hlsl subgroup_id --- naga/src/back/hlsl/writer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 1eab43a4c3..b7ccdff39a 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -1266,7 +1266,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::BuiltIn::SubgroupId => { writeln!( self.out, - "(__local_invocation_id.x * {}u + __local_invocation_id.y * {}u + __local_invocation_id.z) / WaveGetLaneCount();", + "(__local_invocation_id.z * {}u + __local_invocation_id.y * {}u + __local_invocation_id.x) / WaveGetLaneCount();", ep.workgroup_size[0] * ep.workgroup_size[1], ep.workgroup_size[1], )?; From d822524e9134bf08f6d11a8116a459a2d7e3994a Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sat, 4 Nov 2023 22:55:08 -0400 Subject: [PATCH 37/46] subgroups: update naga snapshots --- naga/tests/out/hlsl/subgroup-operations-s.hlsl | 2 +- naga/tests/out/hlsl/subgroup-operations.hlsl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/naga/tests/out/hlsl/subgroup-operations-s.hlsl b/naga/tests/out/hlsl/subgroup-operations-s.hlsl index e86ed8ee99..e1e399e213 100644 --- a/naga/tests/out/hlsl/subgroup-operations-s.hlsl +++ b/naga/tests/out/hlsl/subgroup-operations-s.hlsl @@ -38,7 +38,7 @@ void main(uint3 __local_invocation_id : SV_GroupThreadID) } GroupMemoryBarrierWithGroupSync(); const uint num_subgroups = (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount(); - const uint subgroup_id = (__local_invocation_id.x * 1u + __local_invocation_id.y * 1u + __local_invocation_id.z) / WaveGetLaneCount(); + const uint subgroup_id = (__local_invocation_id.z * 1u + __local_invocation_id.y * 1u + __local_invocation_id.x) / WaveGetLaneCount(); const uint subgroup_size = WaveGetLaneCount(); const uint subgroup_invocation_id = WaveGetLaneIndex(); num_subgroups_1 = num_subgroups; diff --git a/naga/tests/out/hlsl/subgroup-operations.hlsl b/naga/tests/out/hlsl/subgroup-operations.hlsl index 65d3a51a92..a79fd8a38d 100644 --- a/naga/tests/out/hlsl/subgroup-operations.hlsl +++ b/naga/tests/out/hlsl/subgroup-operations.hlsl @@ -5,7 +5,7 @@ void main(uint3 __local_invocation_id : SV_GroupThreadID) } GroupMemoryBarrierWithGroupSync(); const uint num_subgroups = (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount(); - const uint subgroup_id = (__local_invocation_id.x * 1u + __local_invocation_id.y * 1u + __local_invocation_id.z) / WaveGetLaneCount(); + const uint subgroup_id = (__local_invocation_id.z * 1u + __local_invocation_id.y * 1u + __local_invocation_id.x) / WaveGetLaneCount(); const uint subgroup_size = WaveGetLaneCount(); const uint subgroup_invocation_id = WaveGetLaneIndex(); const uint4 _e8 = WaveActiveBallot(((subgroup_invocation_id & 1u) == 1u)); From 011016fb599e27c045cd6898187727d6e10c6233 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Sun, 5 Nov 2023 18:21:24 -0800 Subject: [PATCH 38/46] subgroups: fix gpu test on systems with subgroup size of 1 --- tests/tests/subgroup_operations/shader.wgsl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests/subgroup_operations/shader.wgsl b/tests/tests/subgroup_operations/shader.wgsl index 70d98dacd4..05954958d6 100644 --- a/tests/tests/subgroup_operations/shader.wgsl +++ b/tests/tests/subgroup_operations/shader.wgsl @@ -98,7 +98,7 @@ fn main( passed += u32(subgroupBroadcastFirst(u32(subgroup_invocation_id != 0u)) == 0u); passed += u32(subgroupBroadcastFirst(u32(subgroup_invocation_id == 0u)) == 1u); - passed += u32(subgroupBroadcast(subgroup_invocation_id, 4u) == 4u); + passed += u32(subgroupBroadcast(subgroup_invocation_id, 1u) == 1u); passed += u32(subgroupShuffle(subgroup_invocation_id, subgroup_invocation_id) == subgroup_invocation_id); passed += u32(subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id) == subgroup_size - 1u - subgroup_invocation_id); passed += u32(subgroup_invocation_id == subgroup_size - 1u || subgroupShuffleDown(subgroup_invocation_id, 1u) == subgroup_invocation_id + 1u); From a68ef323e571e8702b0f7399c073358a8736c2ff Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Mon, 6 Nov 2023 18:11:53 -0800 Subject: [PATCH 39/46] subgroup: emit Shuffle instead of Broadcast on spv-out --- naga/src/back/spv/subgroup.rs | 6 +++++- naga/tests/out/spv/subgroup-operations.spvasm | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/naga/src/back/spv/subgroup.rs b/naga/src/back/spv/subgroup.rs index 79db752a6c..ba46140b50 100644 --- a/naga/src/back/spv/subgroup.rs +++ b/naga/src/back/spv/subgroup.rs @@ -182,7 +182,11 @@ impl<'w> BlockContext<'w> { let index_id = self.cached[index]; let op = match *mode { crate::GatherMode::BroadcastFirst => unreachable!(), - crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformBroadcast, + // Use shuffle to emit broadcast to allow the index to + // be dynamically uniform on Vulkan 1.1. The argument to + // OpGroupNonUniformBroadcast must be a constant pre SPIR-V + // 1.5 (vulkan 1.2) + crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformShuffle, crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle, crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown, crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp, diff --git a/naga/tests/out/spv/subgroup-operations.spvasm b/naga/tests/out/spv/subgroup-operations.spvasm index 72c68aa46c..73d3d52c61 100644 --- a/naga/tests/out/spv/subgroup-operations.spvasm +++ b/naga/tests/out/spv/subgroup-operations.spvasm @@ -63,7 +63,7 @@ OpControlBarrier %21 %22 %23 %43 = OpGroupNonUniformIAdd %3 %21 InclusiveScan %14 %44 = OpGroupNonUniformIMul %3 %21 InclusiveScan %14 %45 = OpGroupNonUniformBroadcastFirst %3 %21 %14 -%46 = OpGroupNonUniformBroadcast %3 %21 %14 %19 +%46 = OpGroupNonUniformShuffle %3 %21 %14 %19 %47 = OpISub %3 %12 %17 %48 = OpISub %3 %47 %14 %49 = OpGroupNonUniformShuffle %3 %21 %14 %48 From 3bc4aa92a39ad1fdec4b11897d0274b1376de390 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Wed, 15 Nov 2023 17:44:07 -0800 Subject: [PATCH 40/46] subgroup: Use bitmask to track pass/failed tests --- tests/tests/subgroup_operations/mod.rs | 9 ++- tests/tests/subgroup_operations/shader.wgsl | 83 ++++++++++++++------- 2 files changed, 63 insertions(+), 29 deletions(-) diff --git a/tests/tests/subgroup_operations/mod.rs b/tests/tests/subgroup_operations/mod.rs index d0abfb96bb..31f395cf00 100644 --- a/tests/tests/subgroup_operations/mod.rs +++ b/tests/tests/subgroup_operations/mod.rs @@ -100,5 +100,12 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() ctx.device.poll(wgpu::Maintain::Wait); let mapping_buffer_view = mapping_buffer.slice(..).get_mapped_range(); let result: &[u32; THREAD_COUNT as usize] = bytemuck::from_bytes(&mapping_buffer_view); - assert_eq!(result, &[27; THREAD_COUNT as usize]); + let expected_mask = (1 << (27)) - 1; // generate full mask + let expected_array = [expected_mask as u32; THREAD_COUNT as usize]; + if result != &expected_array { + panic!( + "Got from GPU:\n{:x?}\n expected:\n{:x?}", + result, &expected_array, + ); + } }); diff --git a/tests/tests/subgroup_operations/shader.wgsl b/tests/tests/subgroup_operations/shader.wgsl index 05954958d6..ecaff03b84 100644 --- a/tests/tests/subgroup_operations/shader.wgsl +++ b/tests/tests/subgroup_operations/shader.wgsl @@ -14,96 +14,123 @@ fn main( var passed = 0u; var expected: u32; - passed += u32(num_subgroups == 128u / subgroup_size); - passed += u32(subgroup_id == global_id.x / subgroup_size); - passed += u32(subgroup_invocation_id == global_id.x % subgroup_size); + var mask = 1u << 0u; + passed |= mask * u32(num_subgroups == 128u / subgroup_size); + mask = 1u << 1u; + passed |= mask * u32(subgroup_id == global_id.x / subgroup_size); + mask = 1u << 2u; + passed |= mask * u32(subgroup_invocation_id == global_id.x % subgroup_size); var expected_ballot = vec4(0u); for(var i = 0u; i < subgroup_size; i += 1u) { expected_ballot[i / 32u] |= ((global_id.x - subgroup_invocation_id + i) & 1u) << (i % 32u); } - passed += u32(dot(vec4(1u), vec4(subgroupBallot((subgroup_invocation_id & 1u) == 1u) == expected_ballot)) == 4u); + mask = 1u << 3u; + passed |= mask * u32(dot(vec4(1u), vec4(subgroupBallot((subgroup_invocation_id & 1u) == 1u) == expected_ballot)) == 4u); - passed += u32(subgroupAll(true)); - passed += u32(!subgroupAll(subgroup_invocation_id != 0u)); + mask = 1u << 4u; + passed |= mask * u32(subgroupAll(true)); + mask = 1u << 5u; + passed |= mask * u32(!subgroupAll(subgroup_invocation_id != 0u)); - passed += u32(subgroupAny(subgroup_invocation_id == 0u)); - passed += u32(!subgroupAny(false)); + mask = 1u << 6u; + passed |= mask * u32(subgroupAny(subgroup_invocation_id == 0u)); + mask = 1u << 7u; + passed |= mask * u32(!subgroupAny(false)); expected = 0u; for(var i = 0u; i < subgroup_size; i += 1u) { expected += global_id.x - subgroup_invocation_id + i + 1u; } - passed += u32(subgroupAdd(global_id.x + 1u) == expected); + mask = 1u << 8u; + passed |= mask * u32(subgroupAdd(global_id.x + 1u) == expected); expected = 1u; for(var i = 0u; i < subgroup_size; i += 1u) { expected *= global_id.x - subgroup_invocation_id + i + 1u; } - passed += u32(subgroupMul(global_id.x + 1u) == expected); + mask = 1u << 9u; + passed |= mask * u32(subgroupMul(global_id.x + 1u) == expected); expected = 0u; for(var i = 0u; i < subgroup_size; i += 1u) { expected = max(expected, global_id.x - subgroup_invocation_id + i + 1u); } - passed += u32(subgroupMax(global_id.x + 1u) == expected); + mask = 1u << 10u; + passed |= mask * u32(subgroupMax(global_id.x + 1u) == expected); expected = 0xFFFFFFFFu; for(var i = 0u; i < subgroup_size; i += 1u) { expected = min(expected, global_id.x - subgroup_invocation_id + i + 1u); } - passed += u32(subgroupMin(global_id.x + 1u) == expected); + mask = 1u << 11u; + passed |= mask * u32(subgroupMin(global_id.x + 1u) == expected); expected = 0xFFFFFFFFu; for(var i = 0u; i < subgroup_size; i += 1u) { expected &= global_id.x - subgroup_invocation_id + i + 1u; } - passed += u32(subgroupAnd(global_id.x + 1u) == expected); + mask = 1u << 12u; + passed |= mask * u32(subgroupAnd(global_id.x + 1u) == expected); expected = 0u; for(var i = 0u; i < subgroup_size; i += 1u) { expected |= global_id.x - subgroup_invocation_id + i + 1u; } - passed += u32(subgroupOr(global_id.x + 1u) == expected); + mask = 1u << 13u; + passed |= mask * u32(subgroupOr(global_id.x + 1u) == expected); expected = 0u; for(var i = 0u; i < subgroup_size; i += 1u) { expected ^= global_id.x - subgroup_invocation_id + i + 1u; } - passed += u32(subgroupXor(global_id.x + 1u) == expected); + mask = 1u << 14u; + passed |= mask * u32(subgroupXor(global_id.x + 1u) == expected); expected = 0u; for(var i = 0u; i < subgroup_invocation_id; i += 1u) { expected += global_id.x - subgroup_invocation_id + i + 1u; } - passed += u32(subgroupPrefixExclusiveAdd(global_id.x + 1u) == expected); + mask = 1u << 15u; + passed |= mask * u32(subgroupPrefixExclusiveAdd(global_id.x + 1u) == expected); expected = 1u; for(var i = 0u; i < subgroup_invocation_id; i += 1u) { expected *= global_id.x - subgroup_invocation_id + i + 1u; } - passed += u32(subgroupPrefixExclusiveMul(global_id.x + 1u) == expected); + mask = 1u << 16u; + passed |= mask * u32(subgroupPrefixExclusiveMul(global_id.x + 1u) == expected); expected = 0u; for(var i = 0u; i <= subgroup_invocation_id; i += 1u) { expected += global_id.x - subgroup_invocation_id + i + 1u; } - passed += u32(subgroupPrefixInclusiveAdd(global_id.x + 1u) == expected); + mask = 1u << 17u; + passed |= mask * u32(subgroupPrefixInclusiveAdd(global_id.x + 1u) == expected); expected = 1u; for(var i = 0u; i <= subgroup_invocation_id; i += 1u) { expected *= global_id.x - subgroup_invocation_id + i + 1u; } - passed += u32(subgroupPrefixInclusiveMul(global_id.x + 1u) == expected); - - passed += u32(subgroupBroadcastFirst(u32(subgroup_invocation_id != 0u)) == 0u); - passed += u32(subgroupBroadcastFirst(u32(subgroup_invocation_id == 0u)) == 1u); - passed += u32(subgroupBroadcast(subgroup_invocation_id, 1u) == 1u); - passed += u32(subgroupShuffle(subgroup_invocation_id, subgroup_invocation_id) == subgroup_invocation_id); - passed += u32(subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id) == subgroup_size - 1u - subgroup_invocation_id); - passed += u32(subgroup_invocation_id == subgroup_size - 1u || subgroupShuffleDown(subgroup_invocation_id, 1u) == subgroup_invocation_id + 1u); - passed += u32(subgroup_invocation_id == 0u || subgroupShuffleUp(subgroup_invocation_id, 1u) == subgroup_invocation_id - 1u); - passed += u32(subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u) == (subgroup_invocation_id ^ (subgroup_size - 1u))); + mask = 1u << 18u; + passed |= mask * u32(subgroupPrefixInclusiveMul(global_id.x + 1u) == expected); + + mask = 1u << 19u; + passed |= mask * u32(subgroupBroadcastFirst(u32(subgroup_invocation_id != 0u)) == 0u); + mask = 1u << 20u; + passed |= mask * u32(subgroupBroadcastFirst(u32(subgroup_invocation_id == 0u)) == 1u); + mask = 1u << 21u; + passed |= mask * u32(subgroupBroadcast(subgroup_invocation_id, 1u) == 1u); + mask = 1u << 22u; + passed |= mask * u32(subgroupShuffle(subgroup_invocation_id, subgroup_invocation_id) == subgroup_invocation_id); + mask = 1u << 23u; + passed |= mask * u32(subgroupShuffle(subgroup_invocation_id, subgroup_size - 1u - subgroup_invocation_id) == subgroup_size - 1u - subgroup_invocation_id); + mask = 1u << 24u; + passed |= mask * u32(subgroup_invocation_id == subgroup_size - 1u || subgroupShuffleDown(subgroup_invocation_id, 1u) == subgroup_invocation_id + 1u); + mask = 1u << 25u; + passed |= mask * u32(subgroup_invocation_id == 0u || subgroupShuffleUp(subgroup_invocation_id, 1u) == subgroup_invocation_id - 1u); + mask = 1u << 26u; + passed |= mask * u32(subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u) == (subgroup_invocation_id ^ (subgroup_size - 1u))); storage_buffer[global_id.x] = passed; } From 8420c970bdd3067a95740c5cce2c0f934f3f3ed4 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Tue, 21 Nov 2023 05:33:19 -0800 Subject: [PATCH 41/46] subgroups: Print detailed error message on test failure --- tests/tests/subgroup_operations/mod.rs | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/tests/tests/subgroup_operations/mod.rs b/tests/tests/subgroup_operations/mod.rs index 31f395cf00..aa63a21935 100644 --- a/tests/tests/subgroup_operations/mod.rs +++ b/tests/tests/subgroup_operations/mod.rs @@ -103,9 +103,27 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() let expected_mask = (1 << (27)) - 1; // generate full mask let expected_array = [expected_mask as u32; THREAD_COUNT as usize]; if result != &expected_array { - panic!( + use std::fmt::Write; + let mut msg = String::new(); + writeln!( + &mut msg, "Got from GPU:\n{:x?}\n expected:\n{:x?}", result, &expected_array, - ); + ) + .unwrap(); + for (thread, (result, expected)) in result + .iter() + .zip(expected_array) + .enumerate() + .filter(|(_, (r, e))| *r != e) + { + write!(&mut msg, "thread {} failed tests:", thread).unwrap(); + let difference = result ^ expected; + for i in (0..u32::BITS).filter(|i| (difference & (1 << i)) != 0) { + write!(&mut msg, " {},", i).unwrap(); + } + writeln!(&mut msg).unwrap(); + } + panic!("{}", msg); } }); From 395cd2176edcbe0a7065c49220d20d24833b89ce Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Tue, 21 Nov 2023 05:34:02 -0800 Subject: [PATCH 42/46] subgroups: Add tests in divergent control flow --- tests/tests/subgroup_operations/mod.rs | 3 ++- tests/tests/subgroup_operations/shader.wgsl | 23 +++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/tests/subgroup_operations/mod.rs b/tests/tests/subgroup_operations/mod.rs index aa63a21935..a667dbddac 100644 --- a/tests/tests/subgroup_operations/mod.rs +++ b/tests/tests/subgroup_operations/mod.rs @@ -3,6 +3,7 @@ use std::{borrow::Cow, num::NonZeroU64}; use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters}; const THREAD_COUNT: u64 = 128; +const TEST_COUNT: u32 = 29; #[gpu_test] static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() @@ -100,7 +101,7 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() ctx.device.poll(wgpu::Maintain::Wait); let mapping_buffer_view = mapping_buffer.slice(..).get_mapped_range(); let result: &[u32; THREAD_COUNT as usize] = bytemuck::from_bytes(&mapping_buffer_view); - let expected_mask = (1 << (27)) - 1; // generate full mask + let expected_mask = (1 << (TEST_COUNT)) - 1; // generate full mask let expected_array = [expected_mask as u32; THREAD_COUNT as usize]; if result != &expected_array { use std::fmt::Write; diff --git a/tests/tests/subgroup_operations/shader.wgsl b/tests/tests/subgroup_operations/shader.wgsl index ecaff03b84..6e651f5672 100644 --- a/tests/tests/subgroup_operations/shader.wgsl +++ b/tests/tests/subgroup_operations/shader.wgsl @@ -132,5 +132,28 @@ fn main( mask = 1u << 26u; passed |= mask * u32(subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u) == (subgroup_invocation_id ^ (subgroup_size - 1u))); + mask = 1u << 27u; + if subgroup_invocation_id % 2u == 0u { + passed |= mask * u32(subgroupAdd(1u) == (subgroup_size / 2u)); + } else { + passed |= mask * u32(subgroupAdd(1u) == (subgroup_size / 2u)); + } + + mask = 1u << 28u; + switch subgroup_invocation_id % 3u { + case 0u: { + passed |= mask * u32(subgroupBroadcastFirst(subgroup_invocation_id) == 0u); + } + case 1u: { + passed |= mask * u32(subgroupBroadcastFirst(subgroup_invocation_id) == 1u); + } + case 2u: { + passed |= mask * u32(subgroupBroadcastFirst(subgroup_invocation_id) == 2u); + } + default { } + } + + // Increment TEST_COUNT in subgroup_operations/mod.rs if adding more tests + storage_buffer[global_id.x] = passed; } From f1a4d75ff0a6c7a3b5fa0fd5b990ef66956286aa Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Tue, 21 Nov 2023 15:54:42 -0800 Subject: [PATCH 43/46] subgroups: correct changelog entry PR link --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c163c05c3e..5adfd1eecd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,7 +42,7 @@ Bottom level categories: ### Added/New Features -- Add `SUBGROUP_COMPUTE, SUBGROUP_FRAGMENT, SUBGROUP_VERTEX` features. By @exrook and @lichtso in [#4240](https://github.com/gfx-rs/wgpu/pull/4240) +- Add `SUBGROUP_COMPUTE, SUBGROUP_FRAGMENT, SUBGROUP_VERTEX` features. By @exrook and @lichtso in [#4190](https://github.com/gfx-rs/wgpu/pull/4190) For naga changelogs at or before v0.14.0. See [naga's changelog](naga/CHANGELOG.md). From 4bf0479df6088e6b336fe917a7d77e10f864aad1 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Tue, 28 Nov 2023 04:18:27 -0800 Subject: [PATCH 44/46] subgroups: Expect test failures on metal --- tests/tests/subgroup_operations/mod.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/tests/subgroup_operations/mod.rs b/tests/tests/subgroup_operations/mod.rs index a667dbddac..0cc235aa2f 100644 --- a/tests/tests/subgroup_operations/mod.rs +++ b/tests/tests/subgroup_operations/mod.rs @@ -11,7 +11,12 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() TestParameters::default() .features(wgpu::Features::SUBGROUP_COMPUTE) .limits(wgpu::Limits::downlevel_defaults()) - .expect_fail(wgpu_test::FailureCase::molten_vk()), + .expect_fail(wgpu_test::FailureCase::molten_vk()) + .expect_fail( + // Expect metal to fail on tests involving operations in divergent control flow + wgpu_test::FailureCase::backend(wgpu::Backends::METAL) + .panic("thread 0 failed tests: 27,\nthread 1 failed tests: 27, 28,\n"), + ), ) .run_sync(|ctx| { let device = &ctx.device; From 156f02624769384096db6b26f02d2f42e3bc0b3d Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Tue, 28 Nov 2023 04:40:19 -0800 Subject: [PATCH 45/46] subgroups: Add test for divergent for loop --- tests/tests/subgroup_operations/mod.rs | 2 +- tests/tests/subgroup_operations/shader.wgsl | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/tests/subgroup_operations/mod.rs b/tests/tests/subgroup_operations/mod.rs index 0cc235aa2f..2bdfb8dd55 100644 --- a/tests/tests/subgroup_operations/mod.rs +++ b/tests/tests/subgroup_operations/mod.rs @@ -3,7 +3,7 @@ use std::{borrow::Cow, num::NonZeroU64}; use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters}; const THREAD_COUNT: u64 = 128; -const TEST_COUNT: u32 = 29; +const TEST_COUNT: u32 = 30; #[gpu_test] static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() diff --git a/tests/tests/subgroup_operations/shader.wgsl b/tests/tests/subgroup_operations/shader.wgsl index 6e651f5672..11d6bb2f75 100644 --- a/tests/tests/subgroup_operations/shader.wgsl +++ b/tests/tests/subgroup_operations/shader.wgsl @@ -153,6 +153,16 @@ fn main( default { } } + mask = 1u << 29u; + expected = 0u; + for (var i = subgroup_size; i >= 0u; i -= 1u) { + expected = subgroupAdd(1u); + if i == subgroup_invocation_id { + break; + } + } + passed |= mask * u32(expected == (subgroup_invocation_id + 1u)); + // Increment TEST_COUNT in subgroup_operations/mod.rs if adding more tests storage_buffer[global_id.x] = passed; From f9fc7f0172ff3a38aa363e0e42f4786f854eaec0 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Tue, 28 Nov 2023 04:59:43 -0800 Subject: [PATCH 46/46] subgroups: Verify convergence after finishing other tests --- tests/tests/subgroup_operations/mod.rs | 4 ++-- tests/tests/subgroup_operations/shader.wgsl | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/tests/subgroup_operations/mod.rs b/tests/tests/subgroup_operations/mod.rs index 2bdfb8dd55..3fb386873d 100644 --- a/tests/tests/subgroup_operations/mod.rs +++ b/tests/tests/subgroup_operations/mod.rs @@ -3,7 +3,7 @@ use std::{borrow::Cow, num::NonZeroU64}; use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters}; const THREAD_COUNT: u64 = 128; -const TEST_COUNT: u32 = 30; +const TEST_COUNT: u32 = 31; #[gpu_test] static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() @@ -106,7 +106,7 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() ctx.device.poll(wgpu::Maintain::Wait); let mapping_buffer_view = mapping_buffer.slice(..).get_mapped_range(); let result: &[u32; THREAD_COUNT as usize] = bytemuck::from_bytes(&mapping_buffer_view); - let expected_mask = (1 << (TEST_COUNT)) - 1; // generate full mask + let expected_mask = (1u64 << (TEST_COUNT)) - 1; // generate full mask let expected_array = [expected_mask as u32; THREAD_COUNT as usize]; if result != &expected_array { use std::fmt::Write; diff --git a/tests/tests/subgroup_operations/shader.wgsl b/tests/tests/subgroup_operations/shader.wgsl index 11d6bb2f75..0f1dc47cd9 100644 --- a/tests/tests/subgroup_operations/shader.wgsl +++ b/tests/tests/subgroup_operations/shader.wgsl @@ -163,6 +163,10 @@ fn main( } passed |= mask * u32(expected == (subgroup_invocation_id + 1u)); + // Keep this test last, verify we are still convergent after running other tests + mask = 1u << 30u; + passed |= mask * u32(subgroup_size == subgroupAdd(1u)); + // Increment TEST_COUNT in subgroup_operations/mod.rs if adding more tests storage_buffer[global_id.x] = passed;