From 5a77691501de66a22fbccd6c09845e7e55b5538c Mon Sep 17 00:00:00 2001 From: valaphee <32491319+valaphee@users.noreply.github.com> Date: Thu, 9 May 2024 12:19:32 +0200 Subject: [PATCH] Rudimentary impl of quad ops, impl quad ops for spirv --- naga/src/back/dot/mod.rs | 17 ++++++- naga/src/back/glsl/mod.rs | 28 ++++++++++- naga/src/back/hlsl/writer.rs | 57 ++++++++++++--------- naga/src/back/msl/writer.rs | 9 +++- naga/src/back/pipeline_constants.rs | 11 +++- naga/src/back/spv/block.rs | 34 +++++++++++++ naga/src/back/spv/instructions.rs | 32 ++++++++++++ naga/src/back/spv/subgroup.rs | 14 ++++-- naga/src/back/wgsl/writer.rs | 7 ++- naga/src/compact/statements.rs | 14 +++++- naga/src/front/spv/mod.rs | 53 ++++++++++++++++++- naga/src/front/wgsl/lower/mod.rs | 70 +++++++++++++++++++++++++- naga/src/lib.rs | 22 ++++++++ naga/src/proc/terminator.rs | 1 + naga/src/valid/analyzer.rs | 11 +++- naga/src/valid/function.rs | 4 +- naga/src/valid/handles.rs | 12 ++++- naga/src/valid/mod.rs | 5 +- naga/tests/in/subgroup-operations.wgsl | 5 ++ 19 files changed, 362 insertions(+), 44 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 9a7702b3f6a..dd849c5dadf 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -352,7 +352,8 @@ impl StatementGraph { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { self.dependencies.push((id, index, "index")) } } @@ -365,6 +366,20 @@ impl StatementGraph { crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown", crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp", crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor", + crate::GatherMode::QuadBroadcast(_) => "SubgroupQuadBroadcast" + } + } + S::SubgroupQuadSwap { + direction, + argument, + result + } => { + self.dependencies.push((id, 
argument, "arg")); + self.emits.push((id, result)); + match direction { + crate::Direction::X => "SubgroupQuadSwapX", + crate::Direction::Y => "SubgroupQuadSwapY", + crate::Direction::Diagonal => "SubgroupQuadSwapDiagonal", } } }; diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index c8c7ea557d8..5e1a03e566d 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2494,6 +2494,9 @@ impl<'a, W: Write> Writer<'a, W> { crate::GatherMode::ShuffleXor(_) => { write!(self.out, "subgroupShuffleXor(")?; } + crate::GatherMode::QuadBroadcast(_) => { + write!(self.out, "subgroupQuadBroadcast(")?; + } } self.write_expr(argument, ctx)?; match mode { @@ -2502,13 +2505,36 @@ impl<'a, W: Write> Writer<'a, W> { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { write!(self.out, ", ")?; self.write_expr(index, ctx)?; } } writeln!(self.out, ");")?; } + Statement::SubgroupQuadSwap { direction, argument, result } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match direction { + crate::Direction::X => { + write!(self.out, "subgroupQuadSwapHorizontal(")?; + } + crate::Direction::Y => { + write!(self.out, "subgroupQuadSwapVertical(")?; + } + crate::Direction::Diagonal => { + write!(self.out, "subgroupQuadSwapDiagonal(")?; + } + } + self.write_expr(argument, ctx)?; + writeln!(self.out, ");")?; + } } Ok(()) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 86d8f890357..b825fd1314d 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2172,34 +2172,45 @@ 
impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, " {name} = ")?; self.named_expressions.insert(result, name); - if matches!(mode, crate::GatherMode::BroadcastFirst) { - write!(self.out, "WaveReadLaneFirst(")?; - self.write_expr(module, argument, func_ctx)?; - } else { - write!(self.out, "WaveReadLaneAt(")?; - self.write_expr(module, argument, func_ctx)?; - write!(self.out, ", ")?; - match mode { - crate::GatherMode::BroadcastFirst => unreachable!(), - crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => { - self.write_expr(module, index, func_ctx)?; - } - crate::GatherMode::ShuffleDown(index) => { - write!(self.out, "WaveGetLaneIndex() + ")?; - self.write_expr(module, index, func_ctx)?; - } - crate::GatherMode::ShuffleUp(index) => { - write!(self.out, "WaveGetLaneIndex() - ")?; - self.write_expr(module, index, func_ctx)?; - } - crate::GatherMode::ShuffleXor(index) => { - write!(self.out, "WaveGetLaneIndex() ^ ")?; - self.write_expr(module, index, func_ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "WaveReadLaneFirst(")?; + self.write_expr(module, argument, func_ctx)?; + } + crate::GatherMode::QuadBroadcast(index) => { + write!(self.out, "QuadReadLaneAt(")?; + self.write_expr(module, argument, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + } + _ => { + write!(self.out, "WaveReadLaneAt(")?; + self.write_expr(module, argument, func_ctx)?; + write!(self.out, ", ")?; + match mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => { + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleDown(index) => { + write!(self.out, "WaveGetLaneIndex() + ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleUp(index) => { + write!(self.out, "WaveGetLaneIndex() - ")?; + self.write_expr(module, index, func_ctx)?; + } + 
crate::GatherMode::ShuffleXor(index) => { + write!(self.out, "WaveGetLaneIndex() ^ ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::QuadBroadcast(_) => unreachable!() } } } writeln!(self.out, ");")?; } + Statement::SubgroupQuadSwap { direction, argument, result } => {} } Ok(()) diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index e250d0b72ca..ee5f1c74ab7 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3201,6 +3201,9 @@ impl Writer { crate::GatherMode::ShuffleXor(_) => { write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?; } + crate::GatherMode::QuadBroadcast(_) => { + write!(self.out, "{NAMESPACE}::quad_broadcast(")?; + } } self.put_expression(argument, &context.expression, true)?; match mode { @@ -3209,13 +3212,15 @@ impl Writer { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { write!(self.out, ", ")?; self.put_expression(index, &context.expression, true)?; } } writeln!(self.out, ");")?; - } + }, + crate::Statement::SubgroupQuadSwap { direction, argument, result } => {} } } diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 0dbe9cf4e86..a7e9c700734 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -669,13 +669,22 @@ fn adjust_stmt(new_pos: &[Handle], stmt: &mut Statement) { | crate::GatherMode::Shuffle(ref mut index) | crate::GatherMode::ShuffleDown(ref mut index) | crate::GatherMode::ShuffleUp(ref mut index) - | crate::GatherMode::ShuffleXor(ref mut index) => { + | crate::GatherMode::ShuffleXor(ref mut index) + | crate::GatherMode::QuadBroadcast(ref mut index) => { adjust(index); } } adjust(argument); adjust(result) } + Statement::SubgroupQuadSwap { + ref mut argument, + ref mut result, + .. 
+ } => { + adjust(argument); + adjust(result); + } Statement::Call { ref mut arguments, ref mut result, diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 120d60fc403..98070691649 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -2513,6 +2513,42 @@ impl<'w> BlockContext<'w> { } => { self.write_subgroup_gather(mode, argument, result, &mut block)?; } + crate::Statement::SubgroupQuadSwap { + ref direction, + argument, + result + } => { + self.writer.require_any( + "GroupNonUniformQuad", + &[spirv::Capability::GroupNonUniformQuad], + )?; + + let id = self.gen_id(); + let result_ty = &self.fun_info[result].ty; + let result_type_id = self.get_expression_type_id(result_ty); + + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + + let arg_id = self.cached[argument]; + + let direction = self.get_index_constant(match direction { + crate::Direction::X => 0, + crate::Direction::Y => 1, + crate::Direction::Diagonal => 2, + }); + + block + .body + .push(Instruction::group_non_uniform_quad_swap( + result_type_id, + id, + exec_scope_id, + arg_id, + direction + )); + // Record the result id so later uses of the result expression resolve to it. + self.cached[result] = id; + } } } diff --git a/naga/src/back/spv/instructions.rs b/naga/src/back/spv/instructions.rs index df2774ab9c2..073bb1f3ed1 100644 --- a/naga/src/back/spv/instructions.rs +++ b/naga/src/back/spv/instructions.rs @@ -1138,6 +1138,38 @@ impl super::Instruction { } instruction.add_operand(value); + instruction + } + pub(super) fn group_non_uniform_quad_broadcast( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + index: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformQuadBroadcast); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + instruction.add_operand(index); + + instruction + } + pub(super) fn group_non_uniform_quad_swap( + result_type_id: Word, + id: Word, + exec_scope_id: Word, + value: Word, + 
direction: Word, + ) -> Self { + let mut instruction = Self::new(Op::GroupNonUniformQuadSwap); + instruction.set_type(result_type_id); + instruction.set_result(id); + instruction.add_operand(exec_scope_id); + instruction.add_operand(value); + instruction.add_operand(direction); + instruction } } diff --git a/naga/src/back/spv/subgroup.rs b/naga/src/back/spv/subgroup.rs index c952cb11a7b..d63f6697635 100644 --- a/naga/src/back/spv/subgroup.rs +++ b/naga/src/back/spv/subgroup.rs @@ -130,10 +130,6 @@ impl<'w> BlockContext<'w> { result: Handle, block: &mut Block, ) -> Result<(), Error> { - self.writer.require_any( - "GroupNonUniformBallot", - &[spirv::Capability::GroupNonUniformBallot], - )?; match *mode { crate::GatherMode::BroadcastFirst | crate::GatherMode::Broadcast(_) => { self.writer.require_any( @@ -153,6 +149,12 @@ impl<'w> BlockContext<'w> { &[spirv::Capability::GroupNonUniformShuffleRelative], )?; } + crate::GatherMode::QuadBroadcast(_) => { + self.writer.require_any( + "GroupNonUniformQuad", + &[spirv::Capability::GroupNonUniformQuad], + )?; + } } let id = self.gen_id(); @@ -177,7 +179,8 @@ impl<'w> BlockContext<'w> { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { let index_id = self.cached[index]; let op = match *mode { crate::GatherMode::BroadcastFirst => unreachable!(), @@ -190,6 +193,7 @@ impl<'w> BlockContext<'w> { crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown, crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp, crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor, + crate::GatherMode::QuadBroadcast(_) => spirv::Op::GroupNonUniformQuadBroadcast, }; block.body.push(Instruction::group_non_uniform_gather( op, diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 
789f6f62bfc..c6a5f9484e0 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1027,6 +1027,9 @@ impl Writer { crate::GatherMode::ShuffleXor(_) => { write!(self.out, "subgroupShuffleXor(")?; } + crate::GatherMode::QuadBroadcast(_) => { + write!(self.out, "quadBroadcast(")?; + } } self.write_expr(module, argument, func_ctx)?; match mode { @@ -1035,13 +1038,15 @@ impl Writer { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { write!(self.out, ", ")?; self.write_expr(module, index, func_ctx)?; } } writeln!(self.out, ");")?; } + Statement::SubgroupQuadSwap { direction, argument, result } => {} } Ok(()) diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index a124281bc12..d2366f6eb54 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -123,13 +123,18 @@ impl FunctionTracer<'_> { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { self.expressions_used.insert(index) } } self.expressions_used.insert(argument); self.expressions_used.insert(result) } + St::SubgroupQuadSwap { direction, argument, result } => { + self.expressions_used.insert(argument); + self.expressions_used.insert(result) + }, // Trivial statements. 
St::Break @@ -312,11 +317,16 @@ impl FunctionMap { | crate::GatherMode::Shuffle(ref mut index) | crate::GatherMode::ShuffleDown(ref mut index) | crate::GatherMode::ShuffleUp(ref mut index) - | crate::GatherMode::ShuffleXor(ref mut index) => adjust(index), + | crate::GatherMode::ShuffleXor(ref mut index) + | crate::GatherMode::QuadBroadcast(ref mut index) => adjust(index), } adjust(argument); adjust(result); } + St::SubgroupQuadSwap { direction: _, ref mut argument, ref mut result } => { + adjust(argument); + adjust(result); + } // Trivial statements. St::Break diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 7ac5a18cd68..8b10c1c99e2 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -3873,7 +3873,8 @@ impl> Frontend { | Op::GroupNonUniformShuffle | Op::GroupNonUniformShuffleDown | Op::GroupNonUniformShuffleUp - | Op::GroupNonUniformShuffleXor => { + | Op::GroupNonUniformShuffleXor + | Op::GroupNonUniformQuadBroadcast => { inst.expect( if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { 5 @@ -3917,6 +3918,9 @@ impl> Frontend { spirv::Op::GroupNonUniformShuffleXor => { crate::GatherMode::ShuffleXor(index_handle) } + spirv::Op::GroupNonUniformQuadBroadcast => { + crate::GatherMode::QuadBroadcast(index_handle) + } _ => unreachable!(), } }; @@ -3948,6 +3952,59 @@ impl> Frontend { ); emitter.start(ctx.expressions); } + Op::GroupNonUniformQuadSwap => { + inst.expect(6)?; + block.extend(emitter.finish(ctx.expressions)); + let result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let argument_id = self.next()?; + let direction_id = self.next()?; + + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner) + .filter(|exec_scope| *exec_scope == 
spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + + // Decode the constant Direction operand: 0 = horizontal, 1 = vertical, 2 = diagonal. + let direction_const = self.lookup_constant.lookup(direction_id)?; + let direction = match resolve_constant(ctx.gctx(), &direction_const.inner) { + Some(0) => crate::Direction::X, + Some(1) => crate::Direction::Y, + Some(2) => crate::Direction::Diagonal, + _ => return Err(Error::InvalidOperand), + }; + + let result_type = self.lookup_type.lookup(result_type_id)?; + + let result_handle = ctx.expressions.append( + crate::Expression::SubgroupOperationResult { + ty: result_type.handle, + }, + span, + ); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: result_handle, + type_id: result_type_id, + block_id, + }, + ); + + block.push( + crate::Statement::SubgroupQuadSwap { + direction, + result: result_handle, + argument: argument_handle, + }, + span, + ); + emitter.start(ctx.expressions); + } _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), } }; @@ -4075,7 +4132,8 @@ impl> Frontend { | S::RayQuery { .. } | S::SubgroupBallot { .. } | S::SubgroupCollectiveOperation { .. } - | S::SubgroupGather { .. } => {} + | S::SubgroupGather { .. } + | S::SubgroupQuadSwap { .. } => {} S::Call { function: ref mut callee, ref arguments, diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index e7cce177230..324a5524d4c 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -881,6 +881,7 @@ enum SubgroupGather { ShuffleDown, ShuffleUp, ShuffleXor, + QuadBroadcast, } impl SubgroupGather { @@ -892,6 +893,7 @@ impl SubgroupGather { "subgroupShuffleDown" => Self::ShuffleDown, "subgroupShuffleUp" => Self::ShuffleUp, "subgroupShuffleXor" => Self::ShuffleXor, + "quadBroadcast" => Self::QuadBroadcast, _ => return None, }) } @@ -2421,6 +2423,71 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(crate::Statement::SubgroupBallot { result, predicate }, span); return Ok(Some(result)); } + "quadSwapX" => { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = 
ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupQuadSwap { + direction: crate::Direction::X, + argument, + result, + }, + span, + ); + return Ok(Some(result)) + } + + "quadSwapY" => { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupQuadSwap { + direction: crate::Direction::Y, + argument, + result, + }, + span, + ); + return Ok(Some(result)) + } + + "quadSwapDiagonal" => { + let mut args = ctx.prepare_args(arguments, 1, span); + + let argument = self.expression(args.next()?, ctx)?; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupQuadSwap { + direction: crate::Direction::Diagonal, + argument, + result, + }, + span, + ); + return Ok(Some(result)) + } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; @@ -2659,12 +2726,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } else { let index = self.expression(args.next()?, ctx)?; match mode { + Sg::BroadcastFirst => unreachable!(), Sg::Broadcast => crate::GatherMode::Broadcast(index), Sg::Shuffle => crate::GatherMode::Shuffle(index), Sg::ShuffleDown => crate::GatherMode::ShuffleDown(index), Sg::ShuffleUp => crate::GatherMode::ShuffleUp(index), Sg::ShuffleXor => crate::GatherMode::ShuffleXor(index), - Sg::BroadcastFirst => unreachable!(), + Sg::QuadBroadcast => crate::GatherMode::QuadBroadcast(index), } }; diff --git a/naga/src/lib.rs b/naga/src/lib.rs index 24e1b02c769..b011b8eb2b2 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -1300,6 
+1300,8 @@ pub enum GatherMode { ShuffleUp(Handle), /// Each gathers from their lane xored with the given by the expression ShuffleXor(Handle), + /// All gather from the lane of their quad at the index given by the expression + QuadBroadcast(Handle), } #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] @@ -1328,6 +1330,16 @@ pub enum CollectiveOperation { ExclusiveScan = 2, } +#[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum Direction { + X = 0, + Y = 1, + Diagonal = 2, +} + bitflags::bitflags! { /// Memory barrier flags. #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -1967,6 +1979,16 @@ pub enum Statement { /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult result: Handle, }, + SubgroupQuadSwap { + /// The direction (X, Y or diagonal) in which to swap + direction: Direction, + /// The value to swap + argument: Handle, + /// The [`SubgroupOperationResult`] expression representing this swap's result. + /// + /// [`SubgroupOperationResult`]: Expression::SubgroupOperationResult + result: Handle, + } } /// A function argument. diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index 5edf55cb73a..f79dd399b00 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -40,6 +40,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::SubgroupBallot { .. } | S::SubgroupCollectiveOperation { .. } | S::SubgroupGather { .. } + | S::SubgroupQuadSwap { .. 
} | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 6799e5db278..40aaf1b40f3 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1067,12 +1067,21 @@ impl FunctionInfo { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { let _ = self.add_ref(index); } } FunctionUniformity::new() } + S::SubgroupQuadSwap { + direction: _, + argument, + result: _, + } => { + let _ = self.add_ref(argument); + FunctionUniformity::new() + } }; disruptor = disruptor.or(uniformity.exit_disruptor()); diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 71128fc86da..c9508c6674a 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -498,7 +498,8 @@ impl super::Validator { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => { + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => { let index_ty = context.resolve_type(index, &self.valid_expression_set)?; match *index_ty { crate::TypeInner::Scalar(crate::Scalar::U32) => {} @@ -1144,6 +1145,7 @@ impl super::Validator { } self.validate_subgroup_gather(mode, argument, result, context)?; } + S::SubgroupQuadSwap { direction, argument, result } => {} } } Ok(BlockInfo { stages, finished }) diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 8f782040552..6970a577817 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -594,11 +594,21 @@ impl super::Validator { | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | 
crate::GatherMode::ShuffleXor(index) => validate_expr(index)?, + | crate::GatherMode::ShuffleXor(index) + | crate::GatherMode::QuadBroadcast(index) => validate_expr(index)?, } validate_expr(result)?; Ok(()) } + crate::Statement::SubgroupQuadSwap { + direction: _, + argument, + result, + } => { + validate_expr(argument)?; + validate_expr(result)?; + Ok(()) + } crate::Statement::Break | crate::Statement::Continue | crate::Statement::Kill diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index a0057f39acf..c7d4be9f556 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -144,8 +144,8 @@ bitflags::bitflags! { // We don't support these operations yet // /// Clustered // const CLUSTERED = 1 << 6; - // /// Quad supported - // const QUAD_FRAGMENT_COMPUTE = 1 << 7; + /// Quad supported + const QUAD_FRAGMENT_COMPUTE = 1 << 7; // /// Quad supported in all stages // const QUAD_ALL_STAGES = 1 << 8; } @@ -170,6 +170,7 @@ impl super::GatherMode { Self::BroadcastFirst | Self::Broadcast(_) => S::BALLOT, Self::Shuffle(_) | Self::ShuffleXor(_) => S::SHUFFLE, Self::ShuffleUp(_) | Self::ShuffleDown(_) => S::SHUFFLE_RELATIVE, + Self::QuadBroadcast(_) => S::QUAD_FRAGMENT_COMPUTE, } } } diff --git a/naga/tests/in/subgroup-operations.wgsl b/naga/tests/in/subgroup-operations.wgsl index bb6eb47fb51..26b3d98e84a 100644 --- a/naga/tests/in/subgroup-operations.wgsl +++ b/naga/tests/in/subgroup-operations.wgsl @@ -34,4 +34,10 @@ fn main( subgroupShuffleDown(subgroup_invocation_id, 1u); subgroupShuffleUp(subgroup_invocation_id, 1u); subgroupShuffleXor(subgroup_invocation_id, sizes.subgroup_size - 1u); + + // A quad has four invocations, so the broadcast lane index must be in 0..=3. + quadBroadcast(subgroup_invocation_id, 3u); + quadSwapX(subgroup_invocation_id); + quadSwapY(subgroup_invocation_id); + quadSwapDiagonal(subgroup_invocation_id); }