Skip to content

Commit

Permalink
subgroup: add validation for each subgroup operation type
Browse files Browse the repository at this point in the history
supported operations and stages subgroup operations are supported on
can be passed to the validator after creating it

operations are grouped to follow vulkan:
- basic: elect, barrier
- vote: any, all
- arithmetic: reductions, scan
- ballot: ballot, broadcasts,
- shuffle: shuffles,
- shuffle relative: shuffle up, down
  • Loading branch information
exrook committed Oct 24, 2023
1 parent 36844c8 commit 63e2da9
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 15 deletions.
2 changes: 2 additions & 0 deletions cli/src/bin/naga.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {

// Validate the IR before compaction.
let info = match naga::valid::Validator::new(params.validation_flags, validation_caps)
.subgroup_stages(naga::valid::ShaderStages::all())
.subgroup_operations(naga::valid::SubgroupOperationSet::all())
.validate(&module)
{
Ok(info) => Some(info),
Expand Down
3 changes: 1 addition & 2 deletions src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1537,8 +1537,7 @@ impl super::Validator {
return Err(ExpressionError::InvalidRayQueryType(query));
}
},
E::SubgroupBallotResult => ShaderStages::COMPUTE | ShaderStages::FRAGMENT,
E::SubgroupOperationResult { .. } => ShaderStages::COMPUTE | ShaderStages::FRAGMENT,
E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages,
};
Ok(stages)
}
Expand Down
61 changes: 51 additions & 10 deletions src/valid/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub enum SubgroupError {
InvalidOperand(Handle<crate::Expression>),
#[error("Result type for {0:?} doesn't match the statement")]
ResultTypeMismatch(Handle<crate::Expression>),
#[error("Support for subgroup operation {0:?} is required")]
UnsupportedOperation(super::SubgroupOperationSet),
}

#[derive(Clone, Debug, thiserror::Error)]
Expand Down Expand Up @@ -750,13 +752,24 @@ impl super::Validator {
}
S::Barrier(barrier) => {
stages &= super::ShaderStages::COMPUTE;
if barrier.contains(crate::Barrier::SUB_GROUP)
&& !self.capabilities.contains(super::Capabilities::SUBGROUP)
{
return Err(FunctionError::MissingCapability(
super::Capabilities::SUBGROUP,
)
.with_span_static(span, "subgroup operation"));
if barrier.contains(crate::Barrier::SUB_GROUP) {
if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
return Err(FunctionError::MissingCapability(
super::Capabilities::SUBGROUP,
)
.with_span_static(span, "missing capability for this operation"));
}
if !self
.subgroup_operations
.contains(super::SubgroupOperationSet::BASIC)
{
return Err(FunctionError::InvalidSubgroup(
SubgroupError::UnsupportedOperation(
super::SubgroupOperationSet::BASIC,
),
)
.with_span_static(span, "support for this operation is not present"));
}
}
}
S::Store { pointer, value } => {
Expand Down Expand Up @@ -1048,11 +1061,23 @@ impl super::Validator {
}
}
S::SubgroupBallot { result, predicate } => {
stages &= self.subgroup_stages;
if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
return Err(FunctionError::MissingCapability(
super::Capabilities::SUBGROUP,
)
.with_span_static(span, "subgroup operation"));
.with_span_static(span, "missing capability for this operation"));
}
if !self
.subgroup_operations
.contains(super::SubgroupOperationSet::BALLOT)
{
return Err(FunctionError::InvalidSubgroup(
SubgroupError::UnsupportedOperation(
super::SubgroupOperationSet::BALLOT,
),
)
.with_span_static(span, "support for this operation is not present"));
}
if let Some(predicate) = predicate {
let predicate_inner =
Expand Down Expand Up @@ -1081,11 +1106,19 @@ impl super::Validator {
argument,
result,
} => {
stages &= self.subgroup_stages;
if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
return Err(FunctionError::MissingCapability(
super::Capabilities::SUBGROUP,
)
.with_span_static(span, "subgroup operation"));
.with_span_static(span, "missing capability for this operation"));
}
let operation = op.required_operations();
if !self.subgroup_operations.contains(operation) {
return Err(FunctionError::InvalidSubgroup(
SubgroupError::UnsupportedOperation(operation),
)
.with_span_static(span, "support for this operation is not present"));
}
self.validate_subgroup_operation(op, collective_op, argument, result, context)?;
}
Expand All @@ -1094,11 +1127,19 @@ impl super::Validator {
argument,
result,
} => {
stages &= self.subgroup_stages;
if !self.capabilities.contains(super::Capabilities::SUBGROUP) {
return Err(FunctionError::MissingCapability(
super::Capabilities::SUBGROUP,
)
.with_span_static(span, "subgroup operation"));
.with_span_static(span, "missing capability for this operation"));
}
let operation = mode.required_operations();
if !self.subgroup_operations.contains(operation) {
return Err(FunctionError::InvalidSubgroup(
SubgroupError::UnsupportedOperation(operation),
)
.with_span_static(span, "support for this operation is not present"));
}
self.validate_subgroup_broadcast(mode, argument, result, context)?;
}
Expand Down
65 changes: 65 additions & 0 deletions src/valid/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,57 @@ impl Default for Capabilities {
}
}

bitflags::bitflags! {
/// Supported subgroup operations
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct SubgroupOperationSet: u8 {
/// Elect, Barrier
const BASIC = 1 << 0;
/// Any, All
const VOTE = 1 << 1;
/// reductions, scans
const ARITHMETIC = 1 << 2;
/// ballot, broadcast
const BALLOT = 1 << 3;
/// shuffle, shuffle xor
const SHUFFLE = 1 << 4;
/// shuffle up, down
const SHUFFLE_RELATIVE = 1 << 5;
// We don't support these operations yet
// /// Clustered
// const CLUSTERED = 1 << 6;
// /// Quad supported
// const QUAD_FRAMENT_COMPUTE = 1 << 7;
// /// Quad supported in all stages
// const QUAD_ALL_STAGES = 1 << 8;
}
}

impl super::SubgroupOperation {
const fn required_operations(&self) -> SubgroupOperationSet {
use SubgroupOperationSet as S;
match *self {
Self::All | Self::Any => S::VOTE,
Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => {
S::ARITHMETIC
}
}
}
}

impl super::GatherMode {
const fn required_operations(&self) -> SubgroupOperationSet {
use SubgroupOperationSet as S;
match *self {
Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT,
Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE,
Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE,
}
}
}

bitflags::bitflags! {
/// Validation flags.
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
Expand Down Expand Up @@ -175,6 +226,8 @@ impl ops::Index<Handle<crate::Expression>> for ModuleInfo {
pub struct Validator {
flags: ValidationFlags,
capabilities: Capabilities,
subgroup_stages: ShaderStages,
subgroup_operations: SubgroupOperationSet,
types: Vec<r#type::TypeInfo>,
layouter: Layouter,
location_mask: BitSet,
Expand Down Expand Up @@ -291,6 +344,8 @@ impl Validator {
Validator {
flags,
capabilities,
subgroup_stages: ShaderStages::empty(),
subgroup_operations: SubgroupOperationSet::empty(),
types: Vec::new(),
layouter: Layouter::default(),
location_mask: BitSet::new(),
Expand All @@ -301,6 +356,16 @@ impl Validator {
}
}

pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self {
self.subgroup_stages = stages;
self
}

pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self {
self.subgroup_operations = operations;
self
}

/// Reset the validator internals
pub fn reset(&mut self) {
self.types.clear();
Expand Down
18 changes: 15 additions & 3 deletions tests/snapshots.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,18 @@ fn check_targets(
let params = input.read_parameters();
let name = &input.file_name;

let capabilities = if params.god_mode {
naga::valid::Capabilities::all()
let (capabilities, subgroup_stages, subgroup_operations) = if params.god_mode {
(
naga::valid::Capabilities::all(),
naga::valid::ShaderStages::all(),
naga::valid::SubgroupOperationSet::all(),
)
} else {
naga::valid::Capabilities::default()
(
naga::valid::Capabilities::default(),
naga::valid::ShaderStages::empty(),
naga::valid::SubgroupOperationSet::empty(),
)
};

#[cfg(feature = "serialize")]
Expand All @@ -276,6 +284,8 @@ fn check_targets(
}

let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities)
.subgroup_stages(subgroup_stages)
.subgroup_operations(subgroup_operations)
.validate(module)
.expect(&format!(
"Naga module validation failed on test '{}'",
Expand All @@ -296,6 +306,8 @@ fn check_targets(
}

naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities)
.subgroup_stages(subgroup_stages)
.subgroup_operations(subgroup_operations)
.validate(module)
.expect(&format!(
"Post-compaction module validation failed on test '{}'",
Expand Down

0 comments on commit 63e2da9

Please sign in to comment.