diff --git a/cli/src/bin/naga.rs b/cli/src/bin/naga.rs index 3b0873a376..26264b0b0a 100644 --- a/cli/src/bin/naga.rs +++ b/cli/src/bin/naga.rs @@ -401,6 +401,8 @@ fn run() -> Result<(), Box> { // Validate the IR before compaction. let info = match naga::valid::Validator::new(params.validation_flags, validation_caps) + .subgroup_stages(naga::valid::ShaderStages::all()) + .subgroup_operations(naga::valid::SubgroupOperationSet::all()) .validate(&module) { Ok(info) => Some(info), diff --git a/src/valid/expression.rs b/src/valid/expression.rs index b76091a122..231fb9f009 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1537,8 +1537,7 @@ impl super::Validator { return Err(ExpressionError::InvalidRayQueryType(query)); } }, - E::SubgroupBallotResult => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, - E::SubgroupOperationResult { .. } => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, + E::SubgroupBallotResult | E::SubgroupOperationResult { .. } => self.subgroup_stages, }; Ok(stages) } diff --git a/src/valid/function.rs b/src/valid/function.rs index 09e368a7db..90183ba5b2 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -58,6 +58,8 @@ pub enum SubgroupError { InvalidOperand(Handle), #[error("Result type for {0:?} doesn't match the statement")] ResultTypeMismatch(Handle), + #[error("Support for subgroup operation {0:?} is required")] + UnsupportedOperation(super::SubgroupOperationSet), } #[derive(Clone, Debug, thiserror::Error)] @@ -750,13 +752,24 @@ impl super::Validator { } S::Barrier(barrier) => { stages &= super::ShaderStages::COMPUTE; - if barrier.contains(crate::Barrier::SUB_GROUP) - && !self.capabilities.contains(super::Capabilities::SUBGROUP) - { - return Err(FunctionError::MissingCapability( - super::Capabilities::SUBGROUP, - ) - .with_span_static(span, "subgroup operation")); + if barrier.contains(crate::Barrier::SUB_GROUP) { + if !self.capabilities.contains(super::Capabilities::SUBGROUP) { + return Err(FunctionError::MissingCapability( + super::Capabilities::SUBGROUP, + ) + .with_span_static(span, "missing capability for this operation")); + } + if !self + .subgroup_operations + .contains(super::SubgroupOperationSet::BASIC) + { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation( + super::SubgroupOperationSet::BASIC, + ), + ) + .with_span_static(span, "support for this operation is not present")); + } } } S::Store { pointer, value } => { @@ -1048,11 +1061,23 @@ impl super::Validator { } } S::SubgroupBallot { result, predicate } => { + stages &= self.subgroup_stages; if !self.capabilities.contains(super::Capabilities::SUBGROUP) { return Err(FunctionError::MissingCapability( super::Capabilities::SUBGROUP, ) - .with_span_static(span, "subgroup operation")); + .with_span_static(span, "missing capability for this operation")); + } + if !self + .subgroup_operations + .contains(super::SubgroupOperationSet::BALLOT) + { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation( + super::SubgroupOperationSet::BALLOT, + ), + ) + .with_span_static(span, "support for this operation is not present")); } if let Some(predicate) = predicate { let predicate_inner = @@ -1081,11 +1106,19 @@ impl super::Validator { argument, result, } => { + stages &= self.subgroup_stages; if !self.capabilities.contains(super::Capabilities::SUBGROUP) { return Err(FunctionError::MissingCapability( super::Capabilities::SUBGROUP, ) - .with_span_static(span, "subgroup operation")); + .with_span_static(span, "missing capability for this operation")); + } + let operation = op.required_operations(); + if !self.subgroup_operations.contains(operation) { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation(operation), + ) + .with_span_static(span, "support for this operation is not present")); } self.validate_subgroup_operation(op, collective_op, argument, result, context)?; } @@ -1094,11 +1127,19 @@ impl super::Validator { argument, result, } => { + stages &= self.subgroup_stages; if !self.capabilities.contains(super::Capabilities::SUBGROUP) { return Err(FunctionError::MissingCapability( super::Capabilities::SUBGROUP, ) - .with_span_static(span, "subgroup operation")); + .with_span_static(span, "missing capability for this operation")); + } + let operation = mode.required_operations(); + if !self.subgroup_operations.contains(operation) { + return Err(FunctionError::InvalidSubgroup( + SubgroupError::UnsupportedOperation(operation), + ) + .with_span_static(span, "support for this operation is not present")); } self.validate_subgroup_broadcast(mode, argument, result, context)?; } diff --git a/src/valid/mod.rs b/src/valid/mod.rs index 59f1f0e630..8cbb0fe8c6 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -128,6 +128,57 @@ impl Default for Capabilities { } } +bitflags::bitflags! { + /// Supported subgroup operations + #[cfg_attr(feature = "serialize", derive(serde::Serialize))] + #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] + #[derive(Clone, Copy, Debug, Default, Eq, PartialEq)] + pub struct SubgroupOperationSet: u8 { + /// Elect, Barrier + const BASIC = 1 << 0; + /// Any, All + const VOTE = 1 << 1; + /// reductions, scans + const ARITHMETIC = 1 << 2; + /// ballot, broadcast + const BALLOT = 1 << 3; + /// shuffle, shuffle xor + const SHUFFLE = 1 << 4; + /// shuffle up, down + const SHUFFLE_RELATIVE = 1 << 5; + // We don't support these operations yet + // /// Clustered + // const CLUSTERED = 1 << 6; + // /// Quad supported + // const QUAD_FRAMENT_COMPUTE = 1 << 7; + // /// Quad supported in all stages + // const QUAD_ALL_STAGES = 1 << 8; + } +} + +impl super::SubgroupOperation { + const fn required_operations(&self) -> SubgroupOperationSet { + use SubgroupOperationSet as S; + match *self { + Self::All | Self::Any => S::VOTE, + Self::Add | Self::Mul | Self::Min | Self::Max | Self::And | Self::Or | Self::Xor => { + S::ARITHMETIC + } + } + } +} + +impl super::GatherMode { + const fn required_operations(&self) -> SubgroupOperationSet { + use SubgroupOperationSet as S; + match *self { + Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT, + Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE, + Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE, + } + } +} + bitflags::bitflags! { /// Validation flags. #[cfg_attr(feature = "serialize", derive(serde::Serialize))] @@ -175,6 +226,8 @@ impl ops::Index> for ModuleInfo { pub struct Validator { flags: ValidationFlags, capabilities: Capabilities, + subgroup_stages: ShaderStages, + subgroup_operations: SubgroupOperationSet, types: Vec, layouter: Layouter, location_mask: BitSet, @@ -291,6 +344,8 @@ impl Validator { Validator { flags, capabilities, + subgroup_stages: ShaderStages::empty(), + subgroup_operations: SubgroupOperationSet::empty(), types: Vec::new(), layouter: Layouter::default(), location_mask: BitSet::new(), @@ -301,6 +356,16 @@ impl Validator { } } + pub fn subgroup_stages(&mut self, stages: ShaderStages) -> &mut Self { + self.subgroup_stages = stages; + self + } + + pub fn subgroup_operations(&mut self, operations: SubgroupOperationSet) -> &mut Self { + self.subgroup_operations = operations; + self + } + /// Reset the validator internals pub fn reset(&mut self) { self.types.clear(); diff --git a/tests/snapshots.rs b/tests/snapshots.rs index f52c0c9f34..ca70ddd61d 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -260,10 +260,18 @@ fn check_targets( let params = input.read_parameters(); let name = &input.file_name; - let capabilities = if params.god_mode { - naga::valid::Capabilities::all() + let (capabilities, subgroup_stages, subgroup_operations) = if params.god_mode { + ( + naga::valid::Capabilities::all(), + naga::valid::ShaderStages::all(), + naga::valid::SubgroupOperationSet::all(), + ) } else { - naga::valid::Capabilities::default() + ( + naga::valid::Capabilities::default(), + naga::valid::ShaderStages::empty(), + naga::valid::SubgroupOperationSet::empty(), + ) }; #[cfg(feature = "serialize")] @@ -276,6 +284,8 @@ fn check_targets( } let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) + .subgroup_stages(subgroup_stages) + .subgroup_operations(subgroup_operations) .validate(module) .expect(&format!( "Naga module validation failed on test '{}'", @@ -296,6 +306,8 @@ fn check_targets( } naga::valid::Validator::new(naga::valid::ValidationFlags::all(), capabilities) + .subgroup_stages(subgroup_stages) + .subgroup_operations(subgroup_operations) .validate(module) .expect(&format!( "Post-compaction module validation failed on test '{}'",