From c257748043128b7031a2057a81fe474f28830e75 Mon Sep 17 00:00:00 2001 From: Jacob Hughes Date: Mon, 23 Oct 2023 17:20:35 -0700 Subject: [PATCH] subgroup: add validation for each subgroup operation type --- cli/src/bin/naga.rs | 2 ++ src/valid/function.rs | 33 ++++++++++++++++++++-- src/valid/mod.rs | 65 +++++++++++++++++++++++++++++++++++++++++++ tests/snapshots.rs | 18 ++++++++++-- 4 files changed, 112 insertions(+), 6 deletions(-) diff --git a/cli/src/bin/naga.rs b/cli/src/bin/naga.rs index 3b0873a376..26264b0b0a 100644 --- a/cli/src/bin/naga.rs +++ b/cli/src/bin/naga.rs @@ -401,6 +401,8 @@ fn run() -> Result<(), Box> { // Validate the IR before compaction. let info = match naga::valid::Validator::new(params.validation_flags, validation_caps) + .subgroup_stages(naga::valid::ShaderStages::all()) + .subgroup_operations(naga::valid::SubgroupOperationSet::all()) .validate(&module) { Ok(info) => Some(info), diff --git a/src/valid/function.rs b/src/valid/function.rs index 09e368a7db..c315b7ee9a 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -58,6 +58,8 @@ pub enum SubgroupError { InvalidOperand(Handle), #[error("Result type for {0:?} doesn't match the statement")] ResultTypeMismatch(Handle), + #[error("Support for subgroup operation {0:?} is required")] + UnsupportedOperation(super::SubgroupOperationSet), } #[derive(Clone, Debug, thiserror::Error)] @@ -756,7 +758,7 @@ impl super::Validator { return Err(FunctionError::MissingCapability( super::Capabilities::SUBGROUP, ) - .with_span_static(span, "subgroup operation")); + .with_span_static(span, "missing capability")); } } S::Store { pointer, value } => { @@ -1052,7 +1054,18 @@ impl super::Validator { return Err(FunctionError::MissingCapability( super::Capabilities::SUBGROUP, ) - .with_span_static(span, "subgroup operation")); + .with_span_static(span, "missing capability")); + } + if !self + .subgroup_operations + .contains(super::SubgroupOperationSet::BALLOT) + { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation( + super::SubgroupOperationSet::BALLOT, + ), + ) + .with_span_static(span, "support for this operation is not present")); } if let Some(predicate) = predicate { let predicate_inner = @@ -1085,7 +1098,14 @@ impl super::Validator { return Err(FunctionError::MissingCapability( super::Capabilities::SUBGROUP, ) - .with_span_static(span, "subgroup operation")); + .with_span_static(span, "missing capability")); + } + let operation = op.required_operations(); + if !self.subgroup_operations.contains(operation) { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation(operation), + ) + .with_span_static(span, "support for this operation is not present")); } self.validate_subgroup_operation(op, collective_op, argument, result, context)?; } @@ -1100,6 +1120,13 @@ impl super::Validator { ) .with_span_static(span, "subgroup operation")); } + let operation = mode.required_operations(); + if !self.subgroup_operations.contains(operation) { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation(operation), + ) + .with_span_static(span, "support for this operation is not present")); + } self.validate_subgroup_broadcast(mode, argument, result, context)?; } } diff --git a/src/valid/mod.rs b/src/valid/mod.rs index 59f1f0e630..64028e8e37 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -128,6 +128,57 @@ impl Default for Capabilities { } } +bitflags::bitflags! { + /// Supported subgroup operations + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] + pub struct SubgroupOperationSet: u8 { + /// Elect, Barrier + const BASIC = 1 << 0; + /// Any, All + const VOTE = 1 << 1; + /// reductions, scans + const ARITHMETIC = 1 << 2; + /// ballot, broadcast + const BALLOT = 1 << 3; + /// shuffle + const SHUFFLE = 1 << 4; + /// shuffle up, down, xor + const SHUFFLE_RELATIVE = 1 << 5; + // We don't support these operations yet + // /// Clustered + // const CLUSTERED = 1 << 6; + // /// Quad supported + // const QUAD_FRAMENT_COMPUTE = 1 << 7; + // /// Quad supported in all stages + // const QUAD_ALL_STAGES = 1 << 8; + } +} + +impl super::SubgroupOperation { + fn required_operations(&self) -> SubgroupOperationSet { + use SubgroupOperationSet as S; + match self { + Self::All | Self::Any => S::VOTE, + Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => { + S::ARITHMETIC + } + } + } +} + +impl super::GatherMode { + fn required_operations(&self) -> SubgroupOperationSet { + use SubgroupOperationSet as S; + match self { + Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT, + Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE, + Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE, + } + } +} + bitflags::bitflags! { /// Validation flags. #[cfg_attr(feature = "serialize", derive(serde::Serialize))] @@ -175,6 +226,8 @@ impl ops::Index> for ModuleInfo { pub struct Validator { flags: ValidationFlags, capabilities: Capabilities, + subgroup_stages: ShaderStages, + subgroup_operations: SubgroupOperationSet, types: Vec, layouter: Layouter, location_mask: BitSet, @@ -291,6 +344,8 @@ impl Validator { Validator { flags, capabilities, + subgroup_stages: ShaderStages::empty(), + subgroup_operations: SubgroupOperationSet::empty(), types: Vec::new(), layouter: Layouter::default(), location_mask: BitSet::new(), @@ -301,6 +356,16 @@ impl Validator { } } + pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self { + self.subgroup_stages = stages; + self + } + + pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self { + self.subgroup_operations = operations; + self + } + /// Reset the validator internals pub fn reset(&mut self) { self.types.clear(); diff --git a/tests/snapshots.rs b/tests/snapshots.rs index f52c0c9f34..ca70ddd61d 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -260,10 +260,18 @@ fn check_targets( let params = input.read_parameters(); let name = &input.file_name; - let capabilities = if params.god_mode { - naga::valid::Capabilities::all() + let (capabilities, subgroup_stages, subgroup_operations) = if params.god_mode { + ( + naga::valid::Capabilities::all(), + naga::valid::ShaderStages::all(), + naga::valid::SubgroupOperationSet::all(), + ) } else { - naga::valid::Capabilities::default() + ( + naga::valid::Capabilities::default(), + naga::valid::ShaderStages::empty(), + naga::valid::SubgroupOperationSet::empty(), + ) }; #[cfg(feature = "serialize")] @@ -276,6 +284,8 @@ fn check_targets( } let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) + .subgroup_stages(subgroup_stages) + .subgroup_operations(subgroup_operations) .validate(module) .expect(&format!( "Naga module validation failed on test '{}'", @@ -296,6 +306,8 @@ fn check_targets( } naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) + .subgroup_stages(subgroup_stages) + .subgroup_operations(subgroup_operations) .validate(module) .expect(&format!( "Post-compaction module validation failed on test '{}'",