Skip to content

Commit

Permalink
Rudimentary impl of quad ops, impl quad ops for spirv
Browse files Browse the repository at this point in the history
  • Loading branch information
valaphee committed May 9, 2024
1 parent 3b6112d commit 5a77691
Show file tree
Hide file tree
Showing 19 changed files with 362 additions and 44 deletions.
17 changes: 16 additions & 1 deletion naga/src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,8 @@ impl StatementGraph {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => {
self.dependencies.push((id, index, "index"))
}
}
Expand All @@ -365,6 +366,20 @@ impl StatementGraph {
crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown",
crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp",
crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor",
crate::GatherMode::QuadBroadcast(_) => "SubgroupQuadBroadcast"
}
}
S::SubgroupQuadSwap {
direction,
argument,
result
} => {
self.dependencies.push((id, argument, "arg"));
self.emits.push((id, result));
match direction {
crate::Direction::X => "SubgroupQuadSwapX",
crate::Direction::Y => "SubgroupQuadSwapY",
crate::Direction::Diagonal => "SubgroupQuadSwapDiagonal",
}
}
};
Expand Down
28 changes: 27 additions & 1 deletion naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2494,6 +2494,9 @@ impl<'a, W: Write> Writer<'a, W> {
crate::GatherMode::ShuffleXor(_) => {
write!(self.out, "subgroupShuffleXor(")?;
}
crate::GatherMode::QuadBroadcast(_) => {
write!(self.out, "subgroupQuadBroadcast(")?;
}
}
self.write_expr(argument, ctx)?;
match mode {
Expand All @@ -2502,13 +2505,36 @@ impl<'a, W: Write> Writer<'a, W> {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => {
write!(self.out, ", ")?;
self.write_expr(index, ctx)?;
}
}
writeln!(self.out, ");")?;
}
Statement::SubgroupQuadSwap { direction, argument, result } => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);

match direction {
crate::Direction::X => {
write!(self.out, "subgroupQuadSwapHorizontal(")?;
}
crate::Direction::Y => {
write!(self.out, "subgroupQuadSwapVertical(")?;
}
crate::Direction::Diagonal => {
write!(self.out, "subgroupQuadSwapDiagonal(")?;
}
}
self.write_expr(argument, ctx)?;
writeln!(self.out, ");")?;
}
}

Ok(())
Expand Down
57 changes: 34 additions & 23 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2172,34 +2172,45 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, " {name} = ")?;
self.named_expressions.insert(result, name);

if matches!(mode, crate::GatherMode::BroadcastFirst) {
write!(self.out, "WaveReadLaneFirst(")?;
self.write_expr(module, argument, func_ctx)?;
} else {
write!(self.out, "WaveReadLaneAt(")?;
self.write_expr(module, argument, func_ctx)?;
write!(self.out, ", ")?;
match mode {
crate::GatherMode::BroadcastFirst => unreachable!(),
crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => {
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleDown(index) => {
write!(self.out, "WaveGetLaneIndex() + ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleUp(index) => {
write!(self.out, "WaveGetLaneIndex() - ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleXor(index) => {
write!(self.out, "WaveGetLaneIndex() ^ ")?;
self.write_expr(module, index, func_ctx)?;
match mode {
crate::GatherMode::BroadcastFirst => {
write!(self.out, "WaveReadLaneFirst(")?;
self.write_expr(module, argument, func_ctx)?;
}
crate::GatherMode::QuadBroadcast(index) => {
write!(self.out, "QuadReadLaneAt(")?;
self.write_expr(module, argument, func_ctx)?;
write!(self.out, ", ")?;
self.write_expr(module, index, func_ctx)?;
}
_ => {
write!(self.out, "WaveReadLaneAt(")?;
self.write_expr(module, argument, func_ctx)?;
write!(self.out, ", ")?;
match mode {
crate::GatherMode::BroadcastFirst => unreachable!(),
crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => {
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleDown(index) => {
write!(self.out, "WaveGetLaneIndex() + ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleUp(index) => {
write!(self.out, "WaveGetLaneIndex() - ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::ShuffleXor(index) => {
write!(self.out, "WaveGetLaneIndex() ^ ")?;
self.write_expr(module, index, func_ctx)?;
}
crate::GatherMode::QuadBroadcast(_) => unreachable!()
}
}
}
writeln!(self.out, ");")?;
}
Statement::SubgroupQuadSwap { direction, argument, result } => {}
}

Ok(())
Expand Down
9 changes: 7 additions & 2 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3201,6 +3201,9 @@ impl<W: Write> Writer<W> {
crate::GatherMode::ShuffleXor(_) => {
write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?;
}
crate::GatherMode::QuadBroadcast(_) => {
write!(self.out, "{NAMESPACE}::quad_broadcast(")?;
}
}
self.put_expression(argument, &context.expression, true)?;
match mode {
Expand All @@ -3209,13 +3212,15 @@ impl<W: Write> Writer<W> {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => {
write!(self.out, ", ")?;
self.put_expression(index, &context.expression, true)?;
}
}
writeln!(self.out, ");")?;
}
},
crate::Statement::SubgroupQuadSwap { direction, argument, result } => {}
}
}

Expand Down
11 changes: 10 additions & 1 deletion naga/src/back/pipeline_constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -669,13 +669,22 @@ fn adjust_stmt(new_pos: &[Handle<Expression>], stmt: &mut Statement) {
| crate::GatherMode::Shuffle(ref mut index)
| crate::GatherMode::ShuffleDown(ref mut index)
| crate::GatherMode::ShuffleUp(ref mut index)
| crate::GatherMode::ShuffleXor(ref mut index) => {
| crate::GatherMode::ShuffleXor(ref mut index)
| crate::GatherMode::QuadBroadcast(ref mut index) => {
adjust(index);
}
}
adjust(argument);
adjust(result)
}
Statement::SubgroupQuadSwap {
ref mut argument,
ref mut result,
..
} => {
adjust(argument);
adjust(result);
}
Statement::Call {
ref mut arguments,
ref mut result,
Expand Down
34 changes: 34 additions & 0 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2513,6 +2513,40 @@ impl<'w> BlockContext<'w> {
} => {
self.write_subgroup_gather(mode, argument, result, &mut block)?;
}
crate::Statement::SubgroupQuadSwap {
ref direction,
argument,
result
} => {
self.writer.require_any(
"GroupNonUniformQuad",
&[spirv::Capability::GroupNonUniformQuad],
)?;

let id = self.gen_id();
let result_ty = &self.fun_info[result].ty;
let result_type_id = self.get_expression_type_id(result_ty);

let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32);

let arg_id = self.cached[argument];

let direction = self.get_index_constant(match direction {
crate::Direction::X => 0,
crate::Direction::Y => 1,
crate::Direction::Diagonal => 2,
});

block
.body
.push(Instruction::group_non_uniform_quad_swap(
result_type_id,
id,
exec_scope_id,
arg_id,
direction
));
}
}
}

Expand Down
32 changes: 32 additions & 0 deletions naga/src/back/spv/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,38 @@ impl super::Instruction {
}
instruction.add_operand(value);

instruction
}
pub(super) fn group_non_uniform_quad_broadcast(
result_type_id: Word,
id: Word,
exec_scope_id: Word,
value: Word,
index: Word,
) -> Self {
let mut instruction = Self::new(Op::GroupNonUniformQuadBroadcast);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(exec_scope_id);
instruction.add_operand(value);
instruction.add_operand(index);

instruction
}
pub(super) fn group_non_uniform_quad_swap(
result_type_id: Word,
id: Word,
exec_scope_id: Word,
value: Word,
direction: Word,
) -> Self {
let mut instruction = Self::new(Op::GroupNonUniformQuadSwap);
instruction.set_type(result_type_id);
instruction.set_result(id);
instruction.add_operand(exec_scope_id);
instruction.add_operand(value);
instruction.add_operand(direction);

instruction
}
}
Expand Down
14 changes: 9 additions & 5 deletions naga/src/back/spv/subgroup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,6 @@ impl<'w> BlockContext<'w> {
result: Handle<crate::Expression>,
block: &mut Block,
) -> Result<(), Error> {
self.writer.require_any(
"GroupNonUniformBallot",
&[spirv::Capability::GroupNonUniformBallot],
)?;
match *mode {
crate::GatherMode::BroadcastFirst | crate::GatherMode::Broadcast(_) => {
self.writer.require_any(
Expand All @@ -153,6 +149,12 @@ impl<'w> BlockContext<'w> {
&[spirv::Capability::GroupNonUniformShuffleRelative],
)?;
}
crate::GatherMode::QuadBroadcast(_) => {
self.writer.require_any(
"GroupNonUniformQuad",
&[spirv::Capability::GroupNonUniformQuad],
)?;
}
}

let id = self.gen_id();
Expand All @@ -177,7 +179,8 @@ impl<'w> BlockContext<'w> {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => {
let index_id = self.cached[index];
let op = match *mode {
crate::GatherMode::BroadcastFirst => unreachable!(),
Expand All @@ -190,6 +193,7 @@ impl<'w> BlockContext<'w> {
crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown,
crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp,
crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor,
crate::GatherMode::QuadBroadcast(_) => spirv::Op::GroupNonUniformQuadBroadcast,
};
block.body.push(Instruction::group_non_uniform_gather(
op,
Expand Down
7 changes: 6 additions & 1 deletion naga/src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,9 @@ impl<W: Write> Writer<W> {
crate::GatherMode::ShuffleXor(_) => {
write!(self.out, "subgroupShuffleXor(")?;
}
crate::GatherMode::QuadBroadcast(_) => {
write!(self.out, "quadBroadcast(")?;
}
}
self.write_expr(module, argument, func_ctx)?;
match mode {
Expand All @@ -1035,13 +1038,15 @@ impl<W: Write> Writer<W> {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => {
write!(self.out, ", ")?;
self.write_expr(module, index, func_ctx)?;
}
}
writeln!(self.out, ");")?;
}
Statement::SubgroupQuadSwap { direction, argument, result } => {}
}

Ok(())
Expand Down
14 changes: 12 additions & 2 deletions naga/src/compact/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,18 @@ impl FunctionTracer<'_> {
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
| crate::GatherMode::ShuffleXor(index)
| crate::GatherMode::QuadBroadcast(index) => {
self.expressions_used.insert(index)
}
}
self.expressions_used.insert(argument);
self.expressions_used.insert(result)
}
St::SubgroupQuadSwap { direction, argument, result } => {
self.expressions_used.insert(argument);
self.expressions_used.insert(result)
},

// Trivial statements.
St::Break
Expand Down Expand Up @@ -312,11 +317,16 @@ impl FunctionMap {
| crate::GatherMode::Shuffle(ref mut index)
| crate::GatherMode::ShuffleDown(ref mut index)
| crate::GatherMode::ShuffleUp(ref mut index)
| crate::GatherMode::ShuffleXor(ref mut index) => adjust(index),
| crate::GatherMode::ShuffleXor(ref mut index)
| crate::GatherMode::QuadBroadcast(ref mut index) => adjust(index),
}
adjust(argument);
adjust(result);
}
St::SubgroupQuadSwap { direction: _, ref mut argument, ref mut result } => {
adjust(argument);
adjust(result);
}

// Trivial statements.
St::Break
Expand Down
Loading

0 comments on commit 5a77691

Please sign in to comment.