From b851a52b019c00fa473a6c8fac32602672f7fe42 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Fri, 29 Sep 2023 23:31:39 -0400 Subject: [PATCH] Implement subgroupBallot for wgsl-in, wgsl-out, spv-out, hlsl-out TODO: metal out, figure out what needs to be done in validation --- src/back/dot/mod.rs | 5 +++++ src/back/glsl/mod.rs | 13 ++++++++++++- src/back/hlsl/writer.rs | 12 +++++++++++- src/back/msl/writer.rs | 2 ++ src/back/spv/block.rs | 21 ++++++++++++++++++++- src/back/spv/instructions.rs | 17 +++++++++++++++++ src/back/wgsl/writer.rs | 9 +++++++++ src/compact/expressions.rs | 2 ++ src/compact/statements.rs | 4 ++++ src/front/glsl/constants.rs | 1 + src/front/spv/mod.rs | 1 + src/front/wgsl/lower/mod.rs | 10 ++++++++++ src/lib.rs | 16 ++++++++++++++-- src/proc/terminator.rs | 1 + src/proc/typifier.rs | 5 +++++ src/valid/analyzer.rs | 5 +++++ src/valid/expression.rs | 1 + src/valid/function.rs | 3 +++ src/valid/handles.rs | 5 +++++ 19 files changed, 128 insertions(+), 5 deletions(-) diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index 1556371df1..b24eeae5ad 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -279,6 +279,10 @@ impl StatementGraph { crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } + S::SubgroupBallot { result } => { + self.emits.push((id, result)); + "SubgroupBallot" + } }; // Set the last node to the merge node last_node = merge_id; @@ -586,6 +590,7 @@ fn write_function_expressions( let ty = if committed { "Committed" } else { "Candidate" }; (format!("rayQueryGet{}Intersection", ty).into(), 4) } + E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4), }; // give uniform expressions an outline diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index eb2c2ed025..6e7f1ce9c6 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2238,6 +2238,16 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, ");")?; } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + writeln!(self.out, "subgroupBallot(true);")?; + } } Ok(()) @@ -3426,7 +3436,8 @@ impl<'a, W: Write> Writer<'a, W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult - | Expression::WorkGroupUniformLoadResult { .. } => unreachable!(), + | Expression::WorkGroupUniformLoadResult { .. } + | Expression::SubgroupBallotResult => unreachable!(), // `ArrayLength` is written as `expr.length()` and we convert it to a uint Expression::ArrayLength(expr) => { write!(self.out, "uint(")?; diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index e9809bd2d4..4e622a72b0 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2008,6 +2008,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}}}")? } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result } => { + write!(self.out, "{level}")?; + + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + write!(self.out, "const uint4 {name} = ")?; + self.named_expressions.insert(result, name); + + writeln!(self.out, "WaveActiveBallot(true);")?; + } } Ok(()) @@ -3186,7 +3195,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::WorkGroupUniformLoadResult { .. } - | Expression::RayQueryProceedResult => {} + | Expression::RayQueryProceedResult + | Expression::SubgroupBallotResult => {} } if !closing_bracket.is_empty() { diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 77231a286d..ca85bfd3f8 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1963,6 +1963,7 @@ impl Writer { } write!(self.out, "}}")?; } + crate::Expression::SubgroupBallotResult => todo!(), } Ok(()) } @@ -2976,6 +2977,7 @@ impl Writer { } } } + crate::Statement::SubgroupBallot { result } => todo!(), } } diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index c5246ad190..90de9ce127 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -1118,7 +1118,8 @@ impl<'w> BlockContext<'w> { crate::Expression::CallResult(_) | crate::Expression::AtomicResult { .. } | crate::Expression::WorkGroupUniformLoadResult { .. } - | crate::Expression::RayQueryProceedResult => self.cached[expr_handle], + | crate::Expression::RayQueryProceedResult + | crate::Expression::SubgroupBallotResult => self.cached[expr_handle], crate::Expression::As { expr, kind, @@ -2327,6 +2328,24 @@ impl<'w> BlockContext<'w> { crate::Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } + crate::Statement::SubgroupBallot { result } => { + let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + kind: crate::ScalarKind::Uint, + width: 4, + pointer_space: Some(spirv::StorageClass::Output), + })); + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + let predicate = self.writer.get_constant_scalar(crate::Literal::Bool(true)); + let id = self.gen_id(); + block.body.push(Instruction::group_non_uniform_ballot( + vec4_u32_type_id, + id, + exec_scope_id, + predicate, + )); + self.cached[result] = id; + } } } diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index 11794fc73b..b562e591b2 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -1037,6 +1037,23 @@ impl super::Instruction { instruction.add_operand(semantics_id); instruction } + + // Group Instructions + + pub(super) fn group_non_uniform_ballot( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + predicate: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformBallot); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(predicate); + + instruction + } } impl From for spirv::ImageFormat { diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 980d768782..bbdefe1b7d 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -921,6 +921,14 @@ impl Writer { } } Statement::RayQuery { .. } => unreachable!(), + Statement::SubgroupBallot { result } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + writeln!(self.out, "subgroupBallot();")?; + } } Ok(()) @@ -1668,6 +1676,7 @@ impl Writer { Expression::CallResult(_) | Expression::AtomicResult { .. } | Expression::RayQueryProceedResult + | Expression::SubgroupBallotResult | Expression::WorkGroupUniformLoadResult { .. } => {} } diff --git a/src/compact/expressions.rs b/src/compact/expressions.rs index 4ccf559c4e..a9592a229c 100644 --- a/src/compact/expressions.rs +++ b/src/compact/expressions.rs @@ -55,6 +55,7 @@ impl<'tracer> ExpressionTracer<'tracer> { | Ex::GlobalVariable(_) | Ex::LocalVariable(_) | Ex::CallResult(_) + | Ex::SubgroupBallotResult // FIXME: ??? | Ex::RayQueryProceedResult => {} Ex::Constant(handle) => { @@ -222,6 +223,7 @@ impl ModuleMap { | Ex::GlobalVariable(_) | Ex::LocalVariable(_) | Ex::CallResult(_) + | Ex::SubgroupBallotResult // FIXME: ??? | Ex::RayQueryProceedResult => {} // Expressions that contain handles that need to be adjusted. diff --git a/src/compact/statements.rs b/src/compact/statements.rs index 2c33df0494..3f47d545b4 100644 --- a/src/compact/statements.rs +++ b/src/compact/statements.rs @@ -94,6 +94,9 @@ impl FunctionTracer<'_> { self.trace_expression(query); self.trace_ray_query_function(fun); } + St::SubgroupBallot { result } => { + self.trace_expression(result); + } // Trivial statements. St::Break @@ -252,6 +255,7 @@ impl FunctionMap { adjust(query); self.adjust_ray_query_function(fun); } + St::SubgroupBallot { ref mut result } => adjust(result), // Trivial statements. St::Break diff --git a/src/front/glsl/constants.rs b/src/front/glsl/constants.rs index 81ec6f3a8c..670605858d 100644 --- a/src/front/glsl/constants.rs +++ b/src/front/glsl/constants.rs @@ -349,6 +349,7 @@ impl<'a> ConstantSolver<'a> { Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => { Err(ConstantSolvingError::RayQueryExpression) } + Expression::SubgroupBallotResult { .. } => unreachable!(), } } diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index 5fd54b023f..94c9673ad0 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -3837,6 +3837,7 @@ impl> Frontend { } } S::WorkGroupUniformLoad { .. } => unreachable!(), + S::SubgroupBallot { .. } => unreachable!(), } i += 1; } diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 56d8709708..67dd97c1d9 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -2157,6 +2157,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { )?; return Ok(Some(handle)); } + "subgroupBallot" => { + ctx.prepare_args(arguments, 0, span).finish()?; + + let result = ctx + .interrupt_emitter(crate::Expression::SubgroupBallotResult, span); + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block + .push(crate::Statement::SubgroupBallot { result }, span); + return Ok(Some(result)); + } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; diff --git a/src/lib.rs b/src/lib.rs index 9351572892..ff5a48d976 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1392,7 +1392,9 @@ pub enum Expression { /// /// For [`TypeInner::Atomic`] the result is a corresponding scalar. /// For other types behind the `pointer`, the result is `T`. - Load { pointer: Handle }, + Load { + pointer: Handle, + }, /// Sample a point from a sampled or a depth image. ImageSample { image: Handle, @@ -1532,7 +1534,10 @@ pub enum Expression { /// Result of calling another function. CallResult(Handle), /// Result of an atomic operation. - AtomicResult { ty: Handle, comparison: bool }, + AtomicResult { + ty: Handle, + comparison: bool, + }, /// Result of a [`WorkGroupUniformLoad`] statement. /// /// [`WorkGroupUniformLoad`]: Statement::WorkGroupUniformLoad @@ -1560,6 +1565,7 @@ pub enum Expression { query: Handle, committed: bool, }, + SubgroupBallotResult, } pub use block::Block; @@ -1832,6 +1838,12 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, + SubgroupBallot { + /// The [`SubgroupBallotResult`] expression representing this load's result. + /// + /// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult + result: Handle, + }, } /// A function argument. diff --git a/src/proc/terminator.rs b/src/proc/terminator.rs index a5239d4eca..d2dde729f1 100644 --- a/src/proc/terminator.rs +++ b/src/proc/terminator.rs @@ -37,6 +37,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::RayQuery { .. } | S::Atomic { .. } | S::WorkGroupUniformLoad { .. } + | S::SubgroupBallot { .. } | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index a6130ad796..854ddd9db4 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -906,6 +906,11 @@ impl<'a> ResolveContext<'a> { .ok_or(ResolveError::MissingSpecialType)?; TypeResolution::Handle(result) } + crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector { + kind: crate::ScalarKind::Uint, + size: crate::VectorSize::Quad, + width: 4, + }), }) } } diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 9b21b7f732..8810bcc67d 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -723,6 +723,10 @@ impl FunctionInfo { non_uniform_result: self.add_ref(query), requirements: UniformityRequirements::empty(), }, + E::SubgroupBallotResult => Uniformity { + non_uniform_result: None, + requirements: UniformityRequirements::empty(), + }, }; let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; @@ -966,6 +970,7 @@ impl FunctionInfo { } FunctionUniformity::new() } + S::SubgroupBallot { result: _ } => FunctionUniformity::new(), }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 3426bf008e..ac464d904e 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1498,6 +1498,7 @@ impl super::Validator { return Err(ExpressionError::InvalidRayQueryType(query)); } }, + E::SubgroupBallotResult => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, }; Ok(stages) } diff --git a/src/valid/function.rs b/src/valid/function.rs index 06aa27c84b..48f3141d83 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -917,6 +917,9 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } + S::SubgroupBallot { result } => { + self.emit_expression(result, context)?; + } } } Ok(BlockInfo { stages, finished }) diff --git a/src/valid/handles.rs b/src/valid/handles.rs index da95f60842..a1764788b8 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -394,6 +394,7 @@ impl super::Validator { } crate::Expression::AtomicResult { .. } | crate::Expression::RayQueryProceedResult + | crate::Expression::SubgroupBallotResult | crate::Expression::WorkGroupUniformLoadResult { .. } => (), crate::Expression::ArrayLength(array) => { handle.check_dep(array)?; @@ -539,6 +540,10 @@ impl super::Validator { } Ok(()) } + crate::Statement::SubgroupBallot { result } => { + validate_expr(result)?; + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill