From a3b6b6dab265c939b3bdf08962d414cca34e4272 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Thu, 12 Oct 2023 01:07:12 +0200 Subject: [PATCH] Implements all frontends and backends. --- src/back/dot/mod.rs | 74 +++++++++++- src/back/glsl/mod.rs | 109 +++++++++++++++-- src/back/hlsl/conv.rs | 16 ++- src/back/hlsl/writer.rs | 110 +++++++++++++++-- src/back/msl/mod.rs | 6 +- src/back/msl/writer.rs | 117 +++++++++++++++--- src/back/spv/block.rs | 33 ++--- src/back/spv/instructions.rs | 17 ++- src/back/spv/subgroup.rs | 89 ++++++++++---- src/back/spv/writer.rs | 25 +++- src/back/wgsl/writer.rs | 105 ++++++++++++++-- src/compact/statements.rs | 34 ++++-- src/front/spv/error.rs | 6 +- src/front/spv/mod.rs | 213 ++++++++++++++++++++++++++++++++- src/front/wgsl/lower/mod.rs | 103 ++++++++-------- src/front/wgsl/parse/conv.rs | 25 +++- src/lib.rs | 20 ++-- src/proc/constant_evaluator.rs | 8 ++ src/proc/terminator.rs | 2 +- src/valid/analyzer.rs | 25 ++-- src/valid/expression.rs | 2 +- src/valid/function.rs | 49 ++++---- src/valid/handles.rs | 16 ++- src/valid/interface.rs | 4 +- 24 files changed, 977 insertions(+), 231 deletions(-) diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index fec2c60d32..86f4797b56 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -287,23 +287,85 @@ impl StatementGraph { "SubgroupBallot" } S::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { self.dependencies.push((id, argument, "arg")); self.emits.push((id, result)); - "SubgroupCollectiveOperation" // FIXME + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + "SubgroupAll" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + "SubgroupAny" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + "SubgroupAdd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + "SubgroupMul" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + "SubgroupMax" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + "SubgroupMin" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + "SubgroupAnd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + "SubgroupOr" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + "SubgroupXor" + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupPrefixExclusiveAdd", + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupPrefixExclusiveMul", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupPrefixInclusiveAdd", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupPrefixInclusiveMul", + _ => unimplemented!(), + } } - S::SubgroupBroadcast { - ref mode, + S::SubgroupGather { + mode, argument, result, } => { + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + self.dependencies.push((id, index, "index")) + } + } self.dependencies.push((id, argument, "arg")); self.emits.push((id, result)); - "SubgroupBroadcast" // FIXME + match mode { + crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst", + crate::GatherMode::Broadcast(_) => "SubgroupBroadcast", + crate::GatherMode::Shuffle(_) => "SubgroupShuffle", + crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown", + crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp", + crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor", + } } }; // Set the last node to the merge node diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 4ea41f19a3..e7d9f59a02 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2251,22 +2251,111 @@ impl<'a, W: Write> Writer<'a, W> { Some(predicate) => self.write_expr(predicate, ctx)?, None => write!(self.out, "true")?, } - write!(self.out, ");")?; + writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { - unimplemented!(); // FIXME: + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupInclusiveMul(")? + } + _ => unimplemented!(), + } + self.write_expr(argument, ctx)?; + writeln!(self.out, ");")?; } - Statement::SubgroupBroadcast { - ref mode, + Statement::SubgroupGather { + mode, argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(argument, ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(index, ctx)?; + } + } + writeln!(self.out, ");")?; } } @@ -4026,7 +4115,7 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, "{level}memoryBarrierShared();")?; } if flags.contains(crate::Barrier::SUB_GROUP) { - unimplemented!() // FIXME + writeln!(self.out, "{level}subgroupMemoryBarrier();")?; } writeln!(self.out, "{level}barrier();")?; Ok(()) @@ -4205,8 +4294,10 @@ const fn glsl_built_in( Bi::WorkGroupSize => "gl_WorkGroupSize", Bi::NumWorkGroups => "gl_NumWorkGroups", // subgroup - Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", + Bi::NumSubgroups => "gl_NumSubgroups", + Bi::SubgroupId => "gl_SubgroupID", Bi::SubgroupSize => "gl_SubgroupSize", + Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", } } diff --git a/src/back/hlsl/conv.rs b/src/back/hlsl/conv.rs index 19c4da5e74..434e091bcb 100644 --- a/src/back/hlsl/conv.rs +++ b/src/back/hlsl/conv.rs @@ -166,12 +166,16 @@ impl crate::BuiltIn { // to this field will get replaced with references to `SPECIAL_CBUF_VAR` // in `Writer::write_expr`. Self::NumWorkGroups => "SV_GroupID", - - Self::SubgroupInvocationId - | Self::SubgroupSize - | Self::BaseInstance - | Self::BaseVertex - | Self::WorkGroupSize => return Err(Error::Unimplemented(format!("builtin {self:?}"))), + Self::SubgroupSize => "WaveGetLaneCount()", + Self::SubgroupInvocationId => "WaveGetLaneIndex()", + Self::NumSubgroups => { + // FIXME + "(numthreads[0] * numthreads[1] * numthreads[2] / WaveGetLaneCount())" + } + Self::SubgroupId => "(SV_GroupIndex / WaveGetLaneCount())", // FIXME + Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => { + return Err(Error::Unimplemented(format!("builtin {self:?}"))) + } Self::PointSize | Self::ViewIndex | Self::PointCoord => { return Err(Error::Custom(format!("Unsupported builtin {self:?}"))) } diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 542880703f..bab998ab14 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -2010,7 +2010,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Statement::RayQuery { .. } => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; - let name = format!("{}{}", back::BAKE_PREFIX, result.index()); write!(self.out, "const uint4 {name} = ")?; self.named_expressions.insert(result, name); @@ -2023,19 +2022,114 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "WaveActiveAllTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "WaveActiveAnyTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "WaveActiveSum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "WaveActiveProduct(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "WaveActiveMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "WaveActiveMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "WaveActiveBitAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "WaveActiveBitOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "WaveActiveBitXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "WavePrefixSum(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "WavePrefixProduct(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " + WavePrefixSum(")?; + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " * WavePrefixProduct(")?; + } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; } - Statement::SubgroupBroadcast { - ref mode, + Statement::SubgroupGather { + mode, argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + if matches!(mode, crate::GatherMode::BroadcastFirst) { + write!(self.out, "WaveReadLaneFirst(")?; + self.write_expr(module, argument, func_ctx)?; + } else { + write!(self.out, "WaveReadLaneAt(")?; + self.write_expr(module, argument, func_ctx)?; + write!(self.out, ", ")?; + match mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => { + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleDown(index) => { + write!(self.out, "WaveGetLaneIndex() + ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleUp(index) => { + write!(self.out, "WaveGetLaneIndex() - ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleXor(index) => { + write!(self.out, "WaveGetLaneIndex() ^ ")?; + self.write_expr(module, index, func_ctx)?; + } + } + } + writeln!(self.out, ");")?; } } @@ -3289,7 +3383,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?; } if barrier.contains(crate::Barrier::SUB_GROUP) { - unimplemented!() // FIXME + // Does not exist in DirectX } Ok(()) } diff --git a/src/back/msl/mod.rs b/src/back/msl/mod.rs index b7f709269a..2501c45422 100644 --- a/src/back/msl/mod.rs +++ b/src/back/msl/mod.rs @@ -412,8 +412,10 @@ impl ResolvedBinding { Bi::WorkGroupSize => "dispatch_threads_per_threadgroup", Bi::NumWorkGroups => "threadgroups_per_grid", // subgroup - Bi::SubgroupInvocationId => "simdgroup_index_in_threadgroup", - Bi::SubgroupSize => "simdgroups_per_threadgroup", + Bi::NumSubgroups => "simdgroups_per_threadgroup", + Bi::SubgroupId => "simdgroup_index_in_threadgroup", + Bi::SubgroupSize => "threads_per_simdgroup", + Bi::SubgroupInvocationId => "thread_index_in_simdgroup", Bi::CullDistance | Bi::ViewIndex => { return Err(Error::UnsupportedBuiltIn(built_in)) } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 1beb08dd54..d713cc41f1 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -3001,29 +3001,115 @@ impl Writer { let name = self.namer.call(""); self.start_baking_expression(result, &context.expression, &name)?; self.named_expressions.insert(result, name); - write!(self.out, "{NAMESPACE}::simd_ballot(;")?; - match predicate { - Some(predicate) => { - self.put_expression(predicate, &context.expression, true)? - } - None => write!(self.out, "true")?, + write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?; + if let Some(predicate) = predicate { + self.put_expression(predicate, &context.expression, true)?; + } else { + write!(self.out, "true")?; } - writeln!(self.out, ");")?; + writeln!(self.out, "), 0, 0, 0);")?; } crate::Statement::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "{NAMESPACE}::simd_all(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "{NAMESPACE}::simd_any(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "{NAMESPACE}::simd_sum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "{NAMESPACE}::simd_product(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "{NAMESPACE}::simd_max(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "{NAMESPACE}::simd_min(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "{NAMESPACE}::simd_and(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "{NAMESPACE}::simd_or(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "{NAMESPACE}::simd_xor(")? + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?, + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?, + _ => unimplemented!(), + } + self.put_expression(argument, &context.expression, true)?; + writeln!(self.out, ");")?; } - crate::Statement::SubgroupBroadcast { - ref mode, + crate::Statement::SubgroupGather { + mode, argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "{NAMESPACE}::simd_broadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?; + } + } + self.put_expression(argument, &context.expression, true)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.put_expression(index, &context.expression, true)?; + } + } + writeln!(self.out, ");")?; } } } @@ -4220,7 +4306,10 @@ impl Writer { )?; } if flags.contains(crate::Barrier::SUB_GROUP) { - unimplemented!(); // FIXME + writeln!( + self.out, + "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", + )?; } Ok(()) } diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index db393b4069..63049d5ef9 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -2350,30 +2350,11 @@ impl<'w> BlockContext<'w> { crate::Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } - crate::Statement::SubgroupBallot { result, predicate } => { - self.writer.require_any( - "GroupNonUniformBallot", - &[spirv::Capability::GroupNonUniformBallot], - )?; - let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(crate::VectorSize::Quad), - kind: crate::ScalarKind::Uint, - width: 4, - pointer_space: None, - })); - let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); - let predicate = match predicate { - Some(predicate) => self.cached[predicate], - None => self.writer.get_constant_scalar(crate::Literal::Bool(true)), - }; - let id = self.gen_id(); - block.body.push(Instruction::group_non_uniform_ballot( - vec4_u32_type_id, - id, - exec_scope_id, - predicate, - )); - self.cached[result] = id; + crate::Statement::SubgroupBallot { + result, + ref predicate, + } => { + self.write_subgroup_ballot(predicate, result, &mut block)?; } crate::Statement::SubgroupCollectiveOperation { ref op, @@ -2383,12 +2364,12 @@ impl<'w> BlockContext<'w> { } => { self.write_subgroup_operation(op, collective_op, argument, result, &mut block)?; } - crate::Statement::SubgroupBroadcast { + crate::Statement::SubgroupGather { ref mode, argument, result, } => { - self.write_subgroup_broadcast(mode, argument, result, &mut block)?; + self.write_subgroup_gather(mode, argument, result, &mut block)?; } } } diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index 8a528065ef..5f7c6b34fd 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -1054,33 +1054,34 @@ impl super::Instruction { instruction } - pub(super) fn group_non_uniform_broadcast( + pub(super) fn group_non_uniform_broadcast_first( result_type_id: Word, id: Word, exec_scope_id: Word, value: Word, - index: Word, ) -> Self { - let mut instruction = Self::new(Op::GroupNonUniformBroadcast); + let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); instruction.add_operand(value); - instruction.add_operand(index); instruction } - pub(super) fn group_non_uniform_broadcast_first( + pub(super) fn group_non_uniform_gather( + op: Op, result_type_id: Word, id: Word, exec_scope_id: Word, value: Word, + index: Word, ) -> Self { - let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); + let mut instruction = Self::new(op); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); instruction.add_operand(value); + instruction.add_operand(index); instruction } @@ -1092,10 +1093,6 @@ impl super::Instruction { group_op: Option, value: Word, ) -> Self { - println!( - "{:?}", - (op, result_type_id, id, exec_scope_id, group_op, value) - ); let mut instruction = Self::new(op); instruction.set_type(result_type_id); instruction.set_result(id); diff --git a/src/back/spv/subgroup.rs b/src/back/spv/subgroup.rs index 3c8b7d827a..214b8ab377 100644 --- a/src/back/spv/subgroup.rs +++ b/src/back/spv/subgroup.rs @@ -1,7 +1,43 @@ use super::{Block, BlockContext, Error, Instruction}; -use crate::{arena::Handle, TypeInner}; +use crate::{ + arena::Handle, + back::spv::{LocalType, LookupType}, + TypeInner, +}; impl<'w> BlockContext<'w> { + pub(super) fn write_subgroup_ballot( + &mut self, + predicate: &Option>, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + kind: crate::ScalarKind::Uint, + width: 4, + pointer_space: None, + })); + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + let predicate = if let Some(predicate) = *predicate { + self.cached[predicate] + } else { + self.writer.get_constant_scalar(crate::Literal::Bool(true)) + }; + let id = self.gen_id(); + block.body.push(Instruction::group_non_uniform_ballot( + vec4_u32_type_id, + id, + exec_scope_id, + predicate, + )); + self.cached[result] = id; + Ok(()) + } pub(super) fn write_subgroup_operation( &mut self, op: &crate::SubgroupOperation, @@ -11,7 +47,7 @@ impl<'w> BlockContext<'w> { block: &mut Block, ) -> Result<(), Error> { use crate::SubgroupOperation as sg; - match op { + match *op { sg::All | sg::Any => { self.writer.require_any( "GroupNonUniformVote", @@ -35,14 +71,14 @@ impl<'w> BlockContext<'w> { let result_type_id = self.get_expression_type_id(result_ty); let result_ty_inner = result_ty.inner_with(&self.ir_module.types); - let (is_scalar, kind) = match result_ty_inner { + let (is_scalar, kind) = match *result_ty_inner { TypeInner::Scalar { kind, .. } => (true, kind), TypeInner::Vector { kind, .. } => (false, kind), _ => unimplemented!(), }; use crate::ScalarKind as sk; - let spirv_op = match (kind, op) { + let spirv_op = match (kind, *op) { (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll, (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny, (_, sg::All | sg::Any) => unimplemented!(), @@ -71,9 +107,9 @@ impl<'w> BlockContext<'w> { let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); use crate::CollectiveOperation as c; - let group_op = match op { + let group_op = match *op { sg::All | sg::Any => None, - _ => Some(match collective_op { + _ => Some(match *collective_op { c::Reduce => spirv::GroupOperation::Reduce, c::InclusiveScan => spirv::GroupOperation::InclusiveScan, c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, @@ -92,9 +128,9 @@ impl<'w> BlockContext<'w> { self.cached[result] = id; Ok(()) } - pub(super) fn write_subgroup_broadcast( + pub(super) fn write_subgroup_gather( &mut self, - mode: &crate::BroadcastMode, + mode: &crate::GatherMode, argument: Handle, result: Handle, block: &mut Block, @@ -111,18 +147,8 @@ impl<'w> BlockContext<'w> { let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); let arg_id = self.cached[argument]; - match mode { - crate::BroadcastMode::Index(index) => { - let index_id = self.cached[*index]; - block.body.push(Instruction::group_non_uniform_broadcast( - result_type_id, - id, - exec_scope_id, - arg_id, - index_id, - )); - } - crate::BroadcastMode::First => { + match *mode { + crate::GatherMode::BroadcastFirst => { block .body .push(Instruction::group_non_uniform_broadcast_first( @@ -132,6 +158,29 @@ impl<'w> BlockContext<'w> { arg_id, )); } + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let index_id = self.cached[index]; + let op = match *mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformBroadcast, + crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle, + crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown, + crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp, + crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor, + }; + block.body.push(Instruction::group_non_uniform_gather( + op, + result_type_id, + id, + exec_scope_id, + arg_id, + index_id, + )); + } } self.cached[result] = id; Ok(()) diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 0b3459f26a..9475554646 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1594,16 +1594,23 @@ impl Writer { Bi::WorkGroupSize => BuiltIn::WorkgroupSize, Bi::NumWorkGroups => BuiltIn::NumWorkgroups, // Subgroup - Bi::SubgroupInvocationId => { + Bi::NumSubgroups => { self.require_any( - "`subgroup_invocation_id` built-in", + "`num_subgroups` built-in", &[spirv::Capability::GroupNonUniform], )?; - BuiltIn::SubgroupLocalInvocationId + BuiltIn::NumSubgroups + } + Bi::SubgroupId => { + self.require_any( + "`subgroup_id` built-in", + &[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::SubgroupId } Bi::SubgroupSize => { self.require_any( - "`subgroup_invocation_id` built-in", + "`subgroup_size` built-in", &[ spirv::Capability::GroupNonUniform, spirv::Capability::SubgroupBallotKHR, @@ -1611,6 +1618,16 @@ impl Writer { )?; BuiltIn::SubgroupSize } + Bi::SubgroupInvocationId => { + self.require_any( + "`subgroup_invocation_id` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupLocalInvocationId + } }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index b7cd0b478f..cf74bac61f 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -919,6 +919,10 @@ impl Writer { if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}workgroupBarrier();")?; } + + if barrier.contains(crate::Barrier::SUB_GROUP) { + writeln!(self.out, "{level}subgroupBarrier();")?; + } } Statement::RayQuery { .. } => unreachable!(), Statement::SubgroupBallot { result, predicate } => { @@ -934,19 +938,104 @@ impl Writer { writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op, + collective_op, argument, result, } => { - unimplemented!() // FIXME + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupPrefixExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupPrefixExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupPrefixInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupPrefixInclusiveMul(")? + } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; } - Statement::SubgroupBroadcast { - ref mode, + Statement::SubgroupGather { + mode, argument, result, } => { - unimplemented!() // FIXME + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(module, argument, func_ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + } + } + writeln!(self.out, ");")?; } } @@ -1794,8 +1883,10 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { Bi::SampleMask => "sample_mask", Bi::PrimitiveIndex => "primitive_index", Bi::ViewIndex => "view_index", - Bi::SubgroupInvocationId => "subgroup_invocation_id", + Bi::NumSubgroups => "num_subgroups", + Bi::SubgroupId => "subgroup_id", Bi::SubgroupSize => "subgroup_size", + Bi::SubgroupInvocationId => "subgroup_invocation_id", Bi::BaseInstance | Bi::BaseVertex | Bi::ClipDistance diff --git a/src/compact/statements.rs b/src/compact/statements.rs index 0e8c0e4e81..916f11b0aa 100644 --- a/src/compact/statements.rs +++ b/src/compact/statements.rs @@ -102,21 +102,26 @@ impl FunctionTracer<'_> { self.trace_expression(result); } St::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op: _, + collective_op: _, argument, result, } => { self.trace_expression(argument); self.trace_expression(result); } - St::SubgroupBroadcast { - ref mode, + St::SubgroupGather { + mode, argument, result, } => { - if let crate::BroadcastMode::Index(expr) = *mode { - self.trace_expression(expr); + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => self.trace_expression(index), } self.trace_expression(argument); self.trace_expression(result); @@ -274,27 +279,32 @@ impl FunctionMap { ref mut result, ref mut predicate, } => { - if let Some(ref mut predicate) = predicate { + if let Some(ref mut predicate) = *predicate { adjust(predicate); } adjust(result); } St::SubgroupCollectiveOperation { - ref mut op, - ref mut collective_op, + op: _, + collective_op: _, ref mut argument, ref mut result, } => { adjust(argument); adjust(result); } - St::SubgroupBroadcast { + St::SubgroupGather { ref mut mode, ref mut argument, ref mut result, } => { - if let crate::BroadcastMode::Index(expr) = mode { - adjust(expr); + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(ref mut index) => adjust(index), + crate::GatherMode::Shuffle(ref mut index) => adjust(index), + crate::GatherMode::ShuffleDown(ref mut index) => adjust(index), + crate::GatherMode::ShuffleUp(ref mut index) => adjust(index), + crate::GatherMode::ShuffleXor(ref mut index) => adjust(index), } adjust(argument); adjust(result); diff --git a/src/front/spv/error.rs b/src/front/spv/error.rs index 2f9bf2d1bc..8508ede042 100644 --- a/src/front/spv/error.rs +++ b/src/front/spv/error.rs @@ -54,6 +54,8 @@ pub enum Error { UnknownBinaryOperator(spirv::Op), #[error("unknown relational function {0:?}")] UnknownRelationalFunction(spirv::Op), + #[error("unsupported group opeation %{0}")] + UnsupportedGroupOperation(spirv::Word), #[error("invalid parameter {0:?}")] InvalidParameter(spirv::Op), #[error("invalid operand count {1} for {0:?}")] @@ -116,8 +118,8 @@ pub enum Error { FunctionCallCycle(spirv::Word), #[error("invalid array size {0:?}")] InvalidArraySize(Handle), - #[error("invalid barrier scope %{0}")] - InvalidBarrierScope(spirv::Word), + #[error("invalid execution scope %{0}")] + InvalidExecutionScope(spirv::Word), #[error("invalid barrier memory semantics %{0}")] InvalidBarrierMemorySemantics(spirv::Word), #[error( diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index 6e64677871..4a78723823 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -3645,7 +3645,7 @@ impl> Frontend { let semantics_const = self.lookup_constant.lookup(semantics_id)?; let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) - .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; let semantics = resolve_constant(ctx.gctx(), semantics_const.handle) .ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?; @@ -3686,6 +3686,209 @@ impl> Frontend { }, ); } + Op::GroupNonUniformBallot => { + inst.expect(4)?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let predicate_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let predicate = if self + .lookup_constant + .lookup(predicate_id) + .ok() + .filter(|predicate_const| { + matches!( + ctx.gctx().const_expressions + [ctx.gctx().constants[predicate_const.handle].init], + crate::Expression::Literal(crate::Literal::Bool(true)) + ) + }) + .is_some() + { + None + } else { + let predicate_lookup = self.lookup_expression.lookup(predicate_id)?; + let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup); + Some(predicate_handle) + }; + + block.push( + crate::Statement::SubgroupBallot { + result: result_handle, + predicate, + }, + span, + ); + } + spirv::Op::GroupNonUniformAll + | spirv::Op::GroupNonUniformAny + | spirv::Op::GroupNonUniformIAdd + | spirv::Op::GroupNonUniformFAdd + | spirv::Op::GroupNonUniformIMul + | spirv::Op::GroupNonUniformFMul + | spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax + | spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin + | spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalAnd + | spirv::Op::GroupNonUniformLogicalOr + | spirv::Op::GroupNonUniformLogicalXor => { + inst.expect( + if matches!( + inst.op, + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny + ) { + 4 + } else { + 5 + }, + )?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let collective_op_id = match inst.op { + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => { + crate::CollectiveOperation::Reduce + } + _ => { + let group_op_id = self.next()?; + match spirv::GroupOperation::from_u32(group_op_id) { + Some(spirv::GroupOperation::Reduce) => { + crate::CollectiveOperation::Reduce + } + Some(spirv::GroupOperation::InclusiveScan) => { + crate::CollectiveOperation::InclusiveScan + } + Some(spirv::GroupOperation::ExclusiveScan) => { + crate::CollectiveOperation::ExclusiveScan + } + _ => return Err(Error::UnsupportedGroupOperation(group_op_id)), + } + } + }; + let argument_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let op_id = match inst.op { + spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All, + spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any, + spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => { + crate::SubgroupOperation::Add + } + spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => { + crate::SubgroupOperation::Mul + } + spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max, + spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min, + spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And, + spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or, + spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor, + _ => unreachable!(), + }; + + block.push( + crate::Statement::SubgroupCollectiveOperation { + result: result_handle, + op: op_id, + collective_op: collective_op_id, + argument: argument_handle, + }, + span, + ); + } + Op::GroupNonUniformBroadcastFirst + | Op::GroupNonUniformBroadcast + | Op::GroupNonUniformShuffle + | Op::GroupNonUniformShuffleDown + | Op::GroupNonUniformShuffleUp + | Op::GroupNonUniformShuffleXor => { + inst.expect( + if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + 4 + } else { + 5 + }, + )?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let argument_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + crate::GatherMode::BroadcastFirst + } else { + let index_id = self.next()?; + let index_lookup = self.lookup_expression.lookup(index_id)?; + let index_handle = get_expr_handle!(index_id, index_lookup); + match inst.op { + spirv::Op::GroupNonUniformBroadcast => { + crate::GatherMode::Broadcast(index_handle) + } + spirv::Op::GroupNonUniformShuffle => { + crate::GatherMode::Shuffle(index_handle) + } + spirv::Op::GroupNonUniformShuffleDown => { + crate::GatherMode::ShuffleDown(index_handle) + } + spirv::Op::GroupNonUniformShuffleUp => { + crate::GatherMode::ShuffleUp(index_handle) + } + spirv::Op::GroupNonUniformShuffleXor => { + crate::GatherMode::ShuffleXor(index_handle) + } + _ => unreachable!(), + } + }; + + block.push( + crate::Statement::SubgroupGather { + result: result_handle, + mode, + argument: argument_handle, + }, + span, + ); + } _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), } }; @@ -3806,7 +4009,10 @@ impl> Frontend { | S::Store { .. } | S::ImageStore { .. } | S::Atomic { .. } - | S::RayQuery { .. } => {} + | S::RayQuery { .. } + | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupGather { .. } => {} S::Call { function: ref mut callee, ref arguments, @@ -3837,9 +4043,6 @@ impl> Frontend { } } S::WorkGroupUniformLoad { .. } => unreachable!(), - S::SubgroupBallot { .. } => unreachable!(), // FIXME?? - S::SubgroupCollectiveOperation { .. } => unreachable!(), - S::SubgroupBroadcast { .. } => unreachable!(), } i += 1; } diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index cd2e62bd45..bfe9b5b4b1 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -1881,7 +1881,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } else if let Some(fun) = Texture::map(function.name) { self.texture_sample_helper(fun, arguments, span, ctx.reborrow())? } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { - return Ok(Some(self.subgroup_helper(span, op, cop, arguments, ctx)?)); + return Ok(Some( + self.subgroup_operation_helper(span, op, cop, arguments, ctx)?, + )); + } else if let Some(mode) = conv::map_subgroup_gather(function.name) { + return Ok(Some( + self.subgroup_gather_helper(span, mode, arguments, ctx)?, + )); } else { match function.name { "select" => { @@ -2276,59 +2282,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { args.finish()?; let result = ctx - .interrupt_emitter(crate::Expression::SubgroupBallotResult, span); + .interrupt_emitter(crate::Expression::SubgroupBallotResult, span)?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block .push(crate::Statement::SubgroupBallot { result, predicate }, span); return Ok(Some(result)); } - "subgroupBroadcast" => { - let mut args = ctx.prepare_args(arguments, 2, span); - - let index = self.expression(args.next()?, ctx.reborrow())?; - let argument = self.expression(args.next()?, ctx.reborrow())?; - args.finish()?; - - let ty = ctx.register_type(argument)?; - - let result = ctx.interrupt_emitter( - crate::Expression::SubgroupOperationResult { ty }, - span, - ); - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block.push( - crate::Statement::SubgroupBroadcast { - mode: crate::BroadcastMode::Index(index), - argument, - result, - }, - span, - ); - return Ok(Some(result)); - } - "subgroupBroadcastFirst" => { - let mut args = ctx.prepare_args(arguments, 1, span); - - let argument = self.expression(args.next()?, ctx.reborrow())?; - args.finish()?; - - let ty = ctx.register_type(argument)?; - - let result = ctx.interrupt_emitter( - crate::Expression::SubgroupOperationResult { ty }, - span, - ); - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block.push( - crate::Statement::SubgroupBroadcast { - mode: crate::BroadcastMode::First, - argument, - result, - }, - span, - ); - return Ok(Some(result)); - } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; @@ -2522,7 +2481,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { depth_ref, }) } - fn subgroup_helper( + + fn subgroup_operation_helper( &mut self, span: Span, op: crate::SubgroupOperation, @@ -2537,7 +2497,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let ty = ctx.register_type(argument)?; - let result = ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span); + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; let rctx = ctx.runtime_expression_ctx(span)?; rctx.block.push( crate::Statement::SubgroupCollectiveOperation { @@ -2551,6 +2512,46 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(result) } + fn subgroup_gather_helper( + &mut self, + span: Span, + mode: crate::GatherMode, + arguments: &[Handle>], + mut ctx: ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 2, span); + + let argument = self.expression(args.next()?, ctx.reborrow())?; + let index = if let crate::GatherMode::BroadcastFirst = mode { + Handle::new(NonZeroU32::new(u32::MAX).unwrap()) + } else { + self.expression(args.next()?, ctx.reborrow())? + }; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupGather { + mode: match mode { + crate::GatherMode::BroadcastFirst => crate::GatherMode::BroadcastFirst, + crate::GatherMode::Broadcast(_) => crate::GatherMode::Broadcast(index), + crate::GatherMode::Shuffle(_) => crate::GatherMode::Shuffle(index), + crate::GatherMode::ShuffleDown(_) => crate::GatherMode::ShuffleDown(index), + crate::GatherMode::ShuffleUp(_) => crate::GatherMode::ShuffleUp(index), + crate::GatherMode::ShuffleXor(_) => crate::GatherMode::ShuffleXor(index), + }, + argument, + result, + }, + span, + ); + Ok(result) + } + fn r#struct( &mut self, s: &ast::Struct<'source>, diff --git a/src/front/wgsl/parse/conv.rs b/src/front/wgsl/parse/conv.rs index 160213e4a3..c53f4df753 100644 --- a/src/front/wgsl/parse/conv.rs +++ b/src/front/wgsl/parse/conv.rs @@ -35,8 +35,10 @@ pub fn map_built_in(word: &str, span: Span) -> Result> "workgroup_id" => crate::BuiltIn::WorkGroupId, "num_workgroups" => crate::BuiltIn::NumWorkGroups, // subgroup - "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, + "num_subgroups" => crate::BuiltIn::NumSubgroups, + "subgroup_id" => crate::BuiltIn::SubgroupId, "subgroup_size" => crate::BuiltIn::SubgroupSize, + "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, _ => return Err(Error::UnknownBuiltin(span)), }) } @@ -254,8 +256,25 @@ pub fn map_subgroup_operation( "subgroupAnd" => (sg::And, co::Reduce), "subgroupOr" => (sg::Or, co::Reduce), "subgroupXor" => (sg::Xor, co::Reduce), - "subgroupPrefixAdd" => (sg::Add, co::InclusiveScan), - "subgroupPrefixMul" => (sg::Mul, co::InclusiveScan), + "subgroupPrefixExclusiveAdd" => (sg::Add, co::ExclusiveScan), + "subgroupPrefixExclusiveMul" => (sg::Mul, co::ExclusiveScan), + "subgroupPrefixInclusiveAdd" => (sg::Add, co::InclusiveScan), + "subgroupPrefixInclusiveMul" => (sg::Mul, co::InclusiveScan), + _ => return None, + }) +} + +pub fn map_subgroup_gather(word: &str) -> Option { + use crate::GatherMode as gm; + use crate::Handle; + use std::num::NonZeroU32; + Some(match word { + "subgroupBroadcastFirst" => gm::BroadcastFirst, + "subgroupBroadcast" => gm::Broadcast(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffle" => gm::Shuffle(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleDown" => gm::ShuffleDown(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleUp" => gm::ShuffleUp(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleXor" => gm::ShuffleXor(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), _ => return None, }) } diff --git a/src/lib.rs b/src/lib.rs index 9cac711313..c923e14e3c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -436,8 +436,10 @@ pub enum BuiltIn { WorkGroupSize, NumWorkGroups, // subgroup - SubgroupInvocationId, + NumSubgroups, + SubgroupId, SubgroupSize, + SubgroupInvocationId, } /// Number of bytes per scalar. @@ -1266,9 +1268,13 @@ pub enum SwizzleComponent { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum BroadcastMode { - First, - Index(Handle), +pub enum GatherMode { + BroadcastFirst, + Broadcast(Handle), + Shuffle(Handle), + ShuffleDown(Handle), + ShuffleUp(Handle), + ShuffleXor(Handle), } #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] @@ -1900,9 +1906,9 @@ pub enum Statement { predicate: Option>, }, - SubgroupBroadcast { - /// Specifies which thread to broadcast from - mode: BroadcastMode, + SubgroupGather { + /// Specifies which thread to gather from + mode: GatherMode, /// The value to broadcast over argument: Handle, /// The [`SubgroupOperationResult`] expression representing this load's result. diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 70a46114a9..5ce7bd8400 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -133,6 +133,8 @@ pub enum ConstantEvaluatorError { ImageExpression, #[error("Constants don't support ray query expressions")] RayQueryExpression, + #[error("Constants don't support subgroup expressions")] + SubgroupExpression, #[error("Cannot access the type")] InvalidAccessBase, #[error("Cannot access at the index")] @@ -439,6 +441,12 @@ impl<'a> ConstantEvaluator<'a> { Expression::RayQueryProceedResult | Expression::RayQueryGetIntersection { .. } => { Err(ConstantEvaluatorError::RayQueryExpression) } + Expression::SubgroupBallotResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } + Expression::SubgroupOperationResult { .. } => { + Err(ConstantEvaluatorError::SubgroupExpression) + } } } diff --git a/src/proc/terminator.rs b/src/proc/terminator.rs index 35111a11de..5edf55cb73 100644 --- a/src/proc/terminator.rs +++ b/src/proc/terminator.rs @@ -39,7 +39,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::WorkGroupUniformLoad { .. } | S::SubgroupBallot { .. } | S::SubgroupCollectiveOperation { .. } - | S::SubgroupBroadcast { .. } + | S::SubgroupGather { .. } | S::Barrier(_)), ) | None => block.push(S::Return { value: None }, Default::default()), diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 262e3ffec7..f4347df1dd 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -744,7 +744,7 @@ impl FunctionInfo { non_uniform_result: None, // FIXME requirements: UniformityRequirements::empty(), }, - E::SubgroupOperationResult { ty } => Uniformity { + E::SubgroupOperationResult { .. } => Uniformity { non_uniform_result: None, // FIXME requirements: UniformityRequirements::empty(), }, @@ -1001,22 +1001,29 @@ impl FunctionInfo { FunctionUniformity::new() } S::SubgroupCollectiveOperation { - ref op, - ref collective_op, + op: _, + collective_op: _, argument, result: _, } => { - let _ = self.add_ref(argument); // FIXME + let _ = self.add_ref(argument); FunctionUniformity::new() } - S::SubgroupBroadcast { - ref mode, + S::SubgroupGather { + mode, argument, - result, + result: _, } => { let _ = self.add_ref(argument); - if let crate::BroadcastMode::Index(expr) = *mode { - let _ = self.add_ref(expr); + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let _ = self.add_ref(index); + } } FunctionUniformity::new() } diff --git a/src/valid/expression.rs b/src/valid/expression.rs index e6e483b8be..fda30c3854 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1535,7 +1535,7 @@ impl super::Validator { } }, E::SubgroupBallotResult => ShaderStages::COMPUTE | ShaderStages::FRAGMENT, - E::SubgroupOperationResult { ty } => ShaderStages::COMPUTE, // FIXME + E::SubgroupOperationResult { .. } => ShaderStages::COMPUTE, // FIXME }; Ok(stages) } diff --git a/src/valid/function.rs b/src/valid/function.rs index 54e83f6c21..89ed1e6fb1 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -435,7 +435,7 @@ impl super::Validator { ) -> Result<(), WithSpan> { let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; - let (is_scalar, kind) = match argument_inner { + let (is_scalar, kind) = match *argument_inner { crate::TypeInner::Scalar { kind, .. } => (true, kind), crate::TypeInner::Vector { kind, .. } => (false, kind), _ => unimplemented!(), @@ -443,7 +443,7 @@ impl super::Validator { use crate::ScalarKind as sk; use crate::SubgroupOperation as sg; - match (kind, op) { + match (kind, *op) { (sk::Bool, sg::All | sg::Any) if is_scalar => {} (sk::Sint | sk::Uint | sk::Float, sg::Add | sg::Mul | sg::Min | sg::Max) => {} (sk::Sint | sk::Uint | sk::Bool, sg::And | sg::Or | sg::Xor) => {} @@ -473,32 +473,39 @@ impl super::Validator { #[cfg(feature = "validate")] fn validate_subgroup_broadcast( &mut self, - mode: &crate::BroadcastMode, + mode: &crate::GatherMode, argument: Handle, result: Handle, context: &BlockContext, ) -> Result<(), WithSpan> { - if let crate::BroadcastMode::Index(expr) = *mode { - let index_ty = context.resolve_type(expr, &self.valid_expression_set)?; - match index_ty { - crate::TypeInner::Scalar { - kind: crate::ScalarKind::Uint, - .. - } => {} - _ => { - log::error!( - "Subgroup broadcast index type {:?}, expected unsigned int", - index_ty - ); - return Err(SubgroupError::InvalidOperand(argument) - .with_span_handle(expr, context.expressions) - .into_other()); + match *mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + let index_ty = context.resolve_type(index, &self.valid_expression_set)?; + match *index_ty { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + .. + } => {} + _ => { + log::error!( + "Subgroup broadcast index type {:?}, expected unsigned int", + index_ty + ); + return Err(SubgroupError::InvalidOperand(argument) + .with_span_handle(index, context.expressions) + .into_other()); + } } } } let argument_inner = context.resolve_type(argument, &self.valid_expression_set)?; - match argument_inner { + match *argument_inner { crate::TypeInner::Scalar { .. } | crate::TypeInner::Vector { .. } => {} _ => { log::error!("Subgroup broadcast operand type {:?}", argument_inner); @@ -1030,7 +1037,7 @@ impl super::Validator { if let Some(predicate) = predicate { let predicate_inner = context.resolve_type(predicate, &self.valid_expression_set)?; - match predicate_inner { + match *predicate_inner { crate::TypeInner::Scalar { kind: crate::ScalarKind::Bool, .. @@ -1056,7 +1063,7 @@ impl super::Validator { } => { self.validate_subgroup_operation(op, collective_op, argument, result, context)?; } - S::SubgroupBroadcast { + S::SubgroupGather { ref mode, argument, result, diff --git a/src/valid/handles.rs b/src/valid/handles.rs index 91209b460b..e1674f0804 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -547,8 +547,8 @@ impl super::Validator { Ok(()) } crate::Statement::SubgroupCollectiveOperation { - op, - collective_op, + op: _, + collective_op: _, argument, result, } => { @@ -556,13 +556,19 @@ impl super::Validator { validate_expr(result)?; Ok(()) } - crate::Statement::SubgroupBroadcast { + crate::Statement::SubgroupGather { mode, argument, result, } => { - if let crate::BroadcastMode::Index(expr) = mode { - validate_expr(expr)?; + validate_expr(argument)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => validate_expr(index)?, } validate_expr(result)?; Ok(()) diff --git a/src/valid/interface.rs b/src/valid/interface.rs index a0850c7aa5..2824e5bdd3 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -299,7 +299,7 @@ impl VaryingContext<'_> { width, }, ), - Bi::SubgroupInvocationId => ( + Bi::NumSubgroups | Bi::SubgroupId => ( self.stage == St::Compute && !self.output, *ty_inner == Ti::Scalar { @@ -307,7 +307,7 @@ impl VaryingContext<'_> { width, }, ), - Bi::SubgroupSize => ( + Bi::SubgroupSize | Bi::SubgroupInvocationId => ( match self.stage { St::Compute | St::Fragment => !self.output, St::Vertex => false,