Skip to content

Commit

Permalink
Implement subgroupBallot for wgsl-in, wgsl-out, spv-out, hlsl-out
Browse files Browse the repository at this point in the history
TODO: metal out, figure out what needs to be done in validation
  • Loading branch information
exrook committed Sep 30, 2023
1 parent 9f3cdb6 commit b851a52
Show file tree
Hide file tree
Showing 19 changed files with 128 additions and 5 deletions.
5 changes: 5 additions & 0 deletions src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,10 @@ impl StatementGraph {
crate::RayQueryFunction::Terminate => "RayQueryTerminate",
}
}
S::SubgroupBallot { result } => {
self.emits.push((id, result));
"SubgroupBallot"
}
};
// Set the last node to the merge node
last_node = merge_id;
Expand Down Expand Up @@ -586,6 +590,7 @@ fn write_function_expressions(
let ty = if committed { "Committed" } else { "Candidate" };
(format!("rayQueryGet{}Intersection", ty).into(), 4)
}
E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4),
};

// give uniform expressions an outline
Expand Down
13 changes: 12 additions & 1 deletion src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2238,6 +2238,16 @@ impl<'a, W: Write> Writer<'a, W> {
writeln!(self.out, ");")?;
}
Statement::RayQuery { .. } => unreachable!(),
Statement::SubgroupBallot { result } => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);

writeln!(self.out, "subgroupBallot(true);")?;
}
}

Ok(())
Expand Down Expand Up @@ -3426,7 +3436,8 @@ impl<'a, W: Write> Writer<'a, W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult
| Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
| Expression::WorkGroupUniformLoadResult { .. }
| Expression::SubgroupBallotResult => unreachable!(),
// `ArrayLength` is written as `expr.length()` and we convert it to a uint
Expression::ArrayLength(expr) => {
write!(self.out, "uint(")?;
Expand Down
12 changes: 11 additions & 1 deletion src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2008,6 +2008,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, "{level}}}")?
}
Statement::RayQuery { .. } => unreachable!(),
Statement::SubgroupBallot { result } => {
write!(self.out, "{level}")?;

let name = format!("{}{}", back::BAKE_PREFIX, result.index());
write!(self.out, "const uint4 {name} = ")?;
self.named_expressions.insert(result, name);

writeln!(self.out, "WaveActiveBallot(true);")?;
}
}

Ok(())
Expand Down Expand Up @@ -3186,7 +3195,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::WorkGroupUniformLoadResult { .. }
| Expression::RayQueryProceedResult => {}
| Expression::RayQueryProceedResult
| Expression::SubgroupBallotResult => {}
}

if !closing_bracket.is_empty() {
Expand Down
2 changes: 2 additions & 0 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1963,6 +1963,7 @@ impl<W: Write> Writer<W> {
}
write!(self.out, "}}")?;
}
crate::Expression::SubgroupBallotResult => todo!(),
}
Ok(())
}
Expand Down Expand Up @@ -2976,6 +2977,7 @@ impl<W: Write> Writer<W> {
}
}
}
crate::Statement::SubgroupBallot { result } => todo!(),
}
}

Expand Down
21 changes: 20 additions & 1 deletion src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,8 @@ impl<'w> BlockContext<'w> {
crate::Expression::CallResult(_)
| crate::Expression::AtomicResult { .. }
| crate::Expression::WorkGroupUniformLoadResult { .. }
| crate::Expression::RayQueryProceedResult => self.cached[expr_handle],
| crate::Expression::RayQueryProceedResult
| crate::Expression::SubgroupBallotResult => self.cached[expr_handle],
crate::Expression::As {
expr,
kind,
Expand Down Expand Up @@ -2327,6 +2328,24 @@ impl<'w> BlockContext<'w> {
crate::Statement::RayQuery { query, ref fun } => {
self.write_ray_query_function(query, fun, &mut block);
}
crate::Statement::SubgroupBallot { result } => {
let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(crate::VectorSize::Quad),
kind: crate::ScalarKind::Uint,
width: 4,
pointer_space: Some(spirv::StorageClass::Output),
}));
let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);
let predicate = self.writer.get_constant_scalar(crate::Literal::Bool(true));
let id = self.gen_id();
block.body.push(Instruction::group_non_uniform_ballot(
vec4_u32_type_id,
id,
exec_scope_id,
predicate,
));
self.cached[result] = id;
}
}
}

Expand Down
17 changes: 17 additions & 0 deletions src/back/spv/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1037,6 +1037,23 @@ impl super::Instruction {
instruction.add_operand(semantics_id);
instruction
}

// Group Instructions

pub(super) fn group_non_uniform_ballot(
result_type_id: Word,
id: Word,
exec_scope_id: Word,
predicate: Word,
) -> Self {
let mut instruction = Self::new(Op::GroupNonUniformBallot);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(exec_scope_id);
instruction.add_operand(predicate);

instruction
}
}

impl From<crate::StorageFormat> for spirv::ImageFormat {
Expand Down
9 changes: 9 additions & 0 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,14 @@ impl<W: Write> Writer<W> {
}
}
Statement::RayQuery { .. } => unreachable!(),
Statement::SubgroupBallot { result } => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_named_expr(module, result, func_ctx, &res_name)?;
self.named_expressions.insert(result, res_name);

writeln!(self.out, "subgroupBallot();")?;
}
}

Ok(())
Expand Down Expand Up @@ -1668,6 +1676,7 @@ impl<W: Write> Writer<W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult
| Expression::SubgroupBallotResult
| Expression::WorkGroupUniformLoadResult { .. } => {}
}

Expand Down
2 changes: 2 additions & 0 deletions src/compact/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ impl<'tracer> ExpressionTracer<'tracer> {
| Ex::GlobalVariable(_)
| Ex::LocalVariable(_)
| Ex::CallResult(_)
| Ex::SubgroupBallotResult // FIXME: ???
| Ex::RayQueryProceedResult => {}

Ex::Constant(handle) => {
Expand Down Expand Up @@ -222,6 +223,7 @@ impl ModuleMap {
| Ex::GlobalVariable(_)
| Ex::LocalVariable(_)
| Ex::CallResult(_)
| Ex::SubgroupBallotResult // FIXME: ???
| Ex::RayQueryProceedResult => {}

// Expressions that contain handles that need to be adjusted.
Expand Down
4 changes: 4 additions & 0 deletions src/compact/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ impl FunctionTracer<'_> {
self.trace_expression(query);
self.trace_ray_query_function(fun);
}
St::SubgroupBallot { result } => {
self.trace_expression(result);
}

// Trivial statements.
St::Break
Expand Down Expand Up @@ -252,6 +255,7 @@ impl FunctionMap {
adjust(query);
self.adjust_ray_query_function(fun);
}
St::SubgroupBallot { ref mut result } => adjust(result),

// Trivial statements.
St::Break
Expand Down
1 change: 1 addition & 0 deletions src/front/glsl/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ impl<'a> ConstantSolver<'a> {
Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => {
Err(ConstantSolvingError::RayQueryExpression)
}
Expression::SubgroupBallotResult { .. } => unreachable!(),
}
}

Expand Down
1 change: 1 addition & 0 deletions src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3837,6 +3837,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
}
}
S::WorkGroupUniformLoad { .. } => unreachable!(),
S::SubgroupBallot { .. } => unreachable!(),
}
i += 1;
}
Expand Down
10 changes: 10 additions & 0 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2157,6 +2157,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)?;
return Ok(Some(handle));
}
"subgroupBallot" => {
ctx.prepare_args(arguments, 0, span).finish()?;

let result = ctx
.interrupt_emitter(crate::Expression::SubgroupBallotResult, span);
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block
.push(crate::Statement::SubgroupBallot { result }, span);
return Ok(Some(result));
}
_ => return Err(Error::UnknownIdent(function.span, function.name)),
}
};
Expand Down
16 changes: 14 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1392,7 +1392,9 @@ pub enum Expression {
///
/// For [`TypeInner::Atomic`] the result is a corresponding scalar.
/// For other types behind the `pointer<T>`, the result is `T`.
Load { pointer: Handle<Expression> },
Load {
pointer: Handle<Expression>,
},
/// Sample a point from a sampled or a depth image.
ImageSample {
image: Handle<Expression>,
Expand Down Expand Up @@ -1532,7 +1534,10 @@ pub enum Expression {
/// Result of calling another function.
CallResult(Handle<Function>),
/// Result of an atomic operation.
AtomicResult { ty: Handle<Type>, comparison: bool },
AtomicResult {
ty: Handle<Type>,
comparison: bool,
},
/// Result of a [`WorkGroupUniformLoad`] statement.
///
/// [`WorkGroupUniformLoad`]: Statement::WorkGroupUniformLoad
Expand Down Expand Up @@ -1560,6 +1565,7 @@ pub enum Expression {
query: Handle<Expression>,
committed: bool,
},
SubgroupBallotResult,
}

pub use block::Block;
Expand Down Expand Up @@ -1832,6 +1838,12 @@ pub enum Statement {
/// The specific operation we're performing on `query`.
fun: RayQueryFunction,
},
SubgroupBallot {
/// The [`SubgroupBallotResult`] expression representing this load's result.
///
/// [`SubgroupBallotResult`]: Expression::SubgroupBallotResult
result: Handle<Expression>,
},
}

/// A function argument.
Expand Down
1 change: 1 addition & 0 deletions src/proc/terminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) {
| S::RayQuery { .. }
| S::Atomic { .. }
| S::WorkGroupUniformLoad { .. }
| S::SubgroupBallot { .. }
| S::Barrier(_)),
)
| None => block.push(S::Return { value: None }, Default::default()),
Expand Down
5 changes: 5 additions & 0 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -906,6 +906,11 @@ impl<'a> ResolveContext<'a> {
.ok_or(ResolveError::MissingSpecialType)?;
TypeResolution::Handle(result)
}
crate::Expression::SubgroupBallotResult => TypeResolution::Value(Ti::Vector {
kind: crate::ScalarKind::Uint,
size: crate::VectorSize::Quad,
width: 4,
}),
})
}
}
Expand Down
5 changes: 5 additions & 0 deletions src/valid/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -723,6 +723,10 @@ impl FunctionInfo {
non_uniform_result: self.add_ref(query),
requirements: UniformityRequirements::empty(),
},
E::SubgroupBallotResult => Uniformity {
non_uniform_result: None,
requirements: UniformityRequirements::empty(),
},
};

let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?;
Expand Down Expand Up @@ -966,6 +970,7 @@ impl FunctionInfo {
}
FunctionUniformity::new()
}
S::SubgroupBallot { result: _ } => FunctionUniformity::new(),
};

disruptor = disruptor.or(uniformity.exit_disruptor());
Expand Down
1 change: 1 addition & 0 deletions src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1498,6 +1498,7 @@ impl super::Validator {
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
E::SubgroupBallotResult => ShaderStages::COMPUTE | ShaderStages::FRAGMENT,
};
Ok(stages)
}
Expand Down
3 changes: 3 additions & 0 deletions src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,9 @@ impl super::Validator {
crate::RayQueryFunction::Terminate => {}
}
}
S::SubgroupBallot { result } => {
self.emit_expression(result, context)?;
}
}
}
Ok(BlockInfo { stages, finished })
Expand Down
5 changes: 5 additions & 0 deletions src/valid/handles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ impl super::Validator {
}
crate::Expression::AtomicResult { .. }
| crate::Expression::RayQueryProceedResult
| crate::Expression::SubgroupBallotResult
| crate::Expression::WorkGroupUniformLoadResult { .. } => (),
crate::Expression::ArrayLength(array) => {
handle.check_dep(array)?;
Expand Down Expand Up @@ -539,6 +540,10 @@ impl super::Validator {
}
Ok(())
}
crate::Statement::SubgroupBallot { result } => {
validate_expr(result)?;
Ok(())
}
crate::Statement::Break
| crate::Statement::Continue
| crate::Statement::Kill
Expand Down

0 comments on commit b851a52

Please sign in to comment.