From fd3ed3b285b71074b413b2c17ec3278ff7685239 Mon Sep 17 00:00:00 2001 From: valaphee <32491319+valaphee@users.noreply.github.com> Date: Thu, 16 May 2024 15:20:55 +0200 Subject: [PATCH] Impl quad swap for hlsl, msl and wgsl, finish spv front --- naga/src/back/hlsl/writer.rs | 28 +++++++++++++++++++++++++++- naga/src/back/msl/writer.rs | 22 +++++++++++++++++++++- naga/src/back/wgsl/writer.rs | 21 ++++++++++++++++++++- naga/src/front/spv/mod.rs | 14 ++++++++++++-- 4 files changed, 80 insertions(+), 5 deletions(-) diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index b825fd1314d..6eea0392178 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2210,7 +2210,33 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } writeln!(self.out, ");")?; } - Statement::SubgroupQuadSwap { direction, argument, result } => {} + Statement::SubgroupQuadSwap { direction, argument, result } => { + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + match direction { + crate::Direction::X => { + write!(self.out, "QuadReadAcrossX(")?; + }, + crate::Direction::Y => { + write!(self.out, "QuadReadAcrossY(")?; + }, + crate::Direction::Diagonal => { + write!(self.out, "QuadReadAcrossDiagonal(")?; + }, + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; + } } Ok(()) diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index ee5f1c74ab7..d8833084a93 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3220,7 +3220,27 @@ impl Writer { } writeln!(self.out, ");")?; }, - crate::Statement::SubgroupQuadSwap { direction, argument, result } => {} + crate::Statement::SubgroupQuadSwap { direction, argument, result } => { + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + write!(self.out, "{NAMESPACE}::quad_shuffle_xor(")?; + self.put_expression(argument, &context.expression, true)?; + write!(self.out, ", ")?; + match direction { + crate::Direction::X => { + write!(self.out, "0x01")?; + }, + crate::Direction::Y => { + write!(self.out, "0x10")?; + }, + crate::Direction::Diagonal => { + write!(self.out, "0x11")?; + }, + } + writeln!(self.out, ");")?; + } } } diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index c6a5f9484e0..be34447a3e3 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -1046,7 +1046,26 @@ impl Writer { } writeln!(self.out, ");")?; } - Statement::SubgroupQuadSwap { direction, argument, result } => {} + Statement::SubgroupQuadSwap { direction, argument, result } => { + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match direction { + crate::Direction::X => { + write!(self.out, "quadSwapX(")?; + }, + crate::Direction::Y => { + write!(self.out, "quadSwapY(")?; + }, + crate::Direction::Diagonal => { + write!(self.out, "quadSwapDiagonal(")?; + }, + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; + } } Ok(()) diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 8b10c1c99e2..68d195fcd16 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -3959,7 +3959,7 @@ impl> Frontend { let result_id = self.next()?; let exec_scope_id = self.next()?; let argument_id = self.next()?; - let direction = self.next()?; + let direction_id = self.next()?; let argument_lookup = self.lookup_expression.lookup(argument_id)?; let argument_handle = get_expr_handle!(argument_id, argument_lookup); @@ -3969,6 +3969,16 @@ impl> Frontend { .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + let direction_const = self.lookup_constant.lookup(direction_id)?; + let direction_const = resolve_constant(ctx.gctx(), &direction_const.inner) + .ok_or(Error::InvalidOperand)?; + let direction = match direction_const { + 0 => crate::Direction::X, + 1 => crate::Direction::Y, + 2 => crate::Direction::Diagonal, + _ => unreachable!() + }; + let result_type = self.lookup_type.lookup(result_type_id)?; let result_handle = ctx.expressions.append( @@ -3988,7 +3998,7 @@ impl> Frontend { block.push( crate::Statement::SubgroupQuadSwap { - direction: crate::Direction::X, + direction, result: result_handle, argument: argument_handle, },