Skip to content

Commit

Permalink
subgroup: add optional predicate for subgroupBallot
Browse files Browse the repository at this point in the history
  • Loading branch information
exrook authored and Lichtso committed Oct 14, 2023
1 parent ec30689 commit a50d65a
Show file tree
Hide file tree
Showing 12 changed files with 93 additions and 19 deletions.
5 changes: 4 additions & 1 deletion src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,10 @@ impl StatementGraph {
crate::RayQueryFunction::Terminate => "RayQueryTerminate",
}
}
S::SubgroupBallot { result } => {
S::SubgroupBallot { result, predicate } => {
if let Some(predicate) = predicate {
self.dependencies.push((id, predicate, "predicate"));
}
self.emits.push((id, result));
"SubgroupBallot"
}
Expand Down
9 changes: 7 additions & 2 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2238,15 +2238,20 @@ impl<'a, W: Write> Writer<'a, W> {
writeln!(self.out, ");")?;
}
Statement::RayQuery { .. } => unreachable!(),
Statement::SubgroupBallot { result } => {
Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);

writeln!(self.out, "subgroupBallot(true);")?;
write!(self.out, "subgroupBallot(")?;
match predicate {
Some(predicate) => self.write_expr(predicate, ctx)?,
None => write!(self.out, "true")?,
}
write!(self.out, ");")?;
}
Statement::SubgroupCollectiveOperation {
ref op,
Expand Down
9 changes: 7 additions & 2 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2008,14 +2008,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, "{level}}}")?
}
Statement::RayQuery { .. } => unreachable!(),
Statement::SubgroupBallot { result } => {
Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;

let name = format!("{}{}", back::BAKE_PREFIX, result.index());
write!(self.out, "const uint4 {name} = ")?;
self.named_expressions.insert(result, name);

writeln!(self.out, "WaveActiveBallot(true);")?;
write!(self.out, "WaveActiveBallot(")?;
match predicate {
Some(predicate) => self.write_expr(module, predicate, func_ctx)?,
None => write!(self.out, "true")?,
}
writeln!(self.out, ");")?;
}
Statement::SubgroupCollectiveOperation {
ref op,
Expand Down
11 changes: 9 additions & 2 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2996,12 +2996,19 @@ impl<W: Write> Writer<W> {
}
}
}
crate::Statement::SubgroupBallot { result } => {
crate::Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;
let name = self.namer.call("");
self.start_baking_expression(result, &context.expression, &name)?;
self.named_expressions.insert(result, name);
write!(self.out, "{NAMESPACE}::simd_active_threads_mask();")?;
write!(self.out, "{NAMESPACE}::simd_ballot(;")?;
match predicate {
Some(predicate) => {
self.put_expression(predicate, &context.expression, true)?
}
None => write!(self.out, "true")?,
}
writeln!(self.out, ");")?;
}
crate::Statement::SubgroupCollectiveOperation {
ref op,
Expand Down
7 changes: 5 additions & 2 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2350,7 +2350,7 @@ impl<'w> BlockContext<'w> {
crate::Statement::RayQuery { query, ref fun } => {
self.write_ray_query_function(query, fun, &mut block);
}
crate::Statement::SubgroupBallot { result } => {
crate::Statement::SubgroupBallot { result, predicate } => {
self.writer.require_any(
"GroupNonUniformBallot",
&[spirv::Capability::GroupNonUniformBallot],
Expand All @@ -2362,7 +2362,10 @@ impl<'w> BlockContext<'w> {
pointer_space: None,
}));
let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
let predicate = self.writer.get_constant_scalar(crate::Literal::Bool(true));
let predicate = match predicate {
Some(predicate) => self.cached[predicate],
None => self.writer.get_constant_scalar(crate::Literal::Bool(true)),
};
let id = self.gen_id();
block.body.push(Instruction::group_non_uniform_ballot(
vec4_u32_type_id,
Expand Down
8 changes: 6 additions & 2 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -921,13 +921,17 @@ impl<W: Write> Writer<W> {
}
}
Statement::RayQuery { .. } => unreachable!(),
Statement::SubgroupBallot { result } => {
Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_named_expr(module, result, func_ctx, &res_name)?;
self.named_expressions.insert(result, res_name);

writeln!(self.out, "subgroupBallot();")?;
writeln!(self.out, "subgroupBallot(")?;
if let Some(predicate) = predicate {
self.write_expr(module, predicate, func_ctx)?;
}
writeln!(self.out, ");")?;
}
Statement::SubgroupCollectiveOperation {
ref op,
Expand Down
15 changes: 13 additions & 2 deletions src/compact/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,10 @@ impl FunctionTracer<'_> {
self.trace_expression(query);
self.trace_ray_query_function(fun);
}
St::SubgroupBallot { result } => {
St::SubgroupBallot { result, predicate } => {
if let Some(predicate) = predicate {
self.trace_expression(predicate);
}
self.trace_expression(result);
}
St::SubgroupCollectiveOperation {
Expand Down Expand Up @@ -267,7 +270,15 @@ impl FunctionMap {
adjust(query);
self.adjust_ray_query_function(fun);
}
St::SubgroupBallot { ref mut result } => adjust(result),
St::SubgroupBallot {
ref mut result,
ref mut predicate,
} => {
if let Some(ref mut predicate) = predicate {
adjust(predicate);
}
adjust(result);
}
St::SubgroupCollectiveOperation {
ref mut op,
ref mut collective_op,
Expand Down
10 changes: 8 additions & 2 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2267,13 +2267,19 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
return Ok(Some(handle));
}
"subgroupBallot" => {
ctx.prepare_args(arguments, 0, span).finish()?;
let mut args = ctx.prepare_args(arguments, 0, span);
let predicate = if arguments.len() == 1 {
Some(self.expression(args.next()?, ctx.reborrow())?)
} else {
None
};
args.finish()?;

let result = ctx
.interrupt_emitter(crate::Expression::SubgroupBallotResult, span);
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block
.push(crate::Statement::SubgroupBallot { result }, span);
.push(crate::Statement::SubgroupBallot { result, predicate }, span);
return Ok(Some(result));
}
"subgroupBroadcast" => {
Expand Down
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1896,6 +1896,8 @@ pub enum Statement {
///
/// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult
result: Handle<Expression>,
/// The value from this thread to store in the ballot
predicate: Option<Handle<Expression>>,
},

SubgroupBroadcast {
Expand Down
12 changes: 10 additions & 2 deletions src/valid/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -991,14 +991,22 @@ impl FunctionInfo {
}
FunctionUniformity::new()
}
S::SubgroupBallot { result: _ } => FunctionUniformity::new(),
S::SubgroupBallot {
result: _,
predicate,
} => {
if let Some(predicate) = predicate {
let _ = self.add_ref(predicate);
}
FunctionUniformity::new()
}
S::SubgroupCollectiveOperation {
ref op,
ref collective_op,
argument,
result: _,
} => {
let _ = self.add_ref(argument);
let _ = self.add_ref(argument); // FIXME
FunctionUniformity::new()
}
S::SubgroupBroadcast {
Expand Down
21 changes: 20 additions & 1 deletion src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,26 @@ impl super::Validator {
crate::RayQueryFunction::Terminate => {}
}
}
S::SubgroupBallot { result } => {
S::SubgroupBallot { result, predicate } => {
if let Some(predicate) = predicate {
let predicate_inner =
context.resolve_type(predicate, &self.valid_expression_set)?;
match predicate_inner {
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Bool,
..
} => {}
_ => {
log::error!(
"Subgroup ballot predicate type {:?} expected bool",
predicate_inner
);
return Err(SubgroupError::InvalidOperand(predicate)
.with_span_handle(predicate, context.expressions)
.into_other());
}
}
}
self.emit_expression(result, context)?;
}
S::SubgroupCollectiveOperation {
Expand Down
3 changes: 2 additions & 1 deletion src/valid/handles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,8 @@ impl super::Validator {
}
Ok(())
}
crate::Statement::SubgroupBallot { result } => {
crate::Statement::SubgroupBallot { result, predicate } => {
validate_expr_opt(predicate)?;
validate_expr(result)?;
Ok(())
}
Expand Down

0 comments on commit a50d65a

Please sign in to comment.