From ab577c6586c5b40ed84b5a6e72af0dc330a4ab4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alexander=20Mei=C3=9Fner?= Date: Sat, 14 Oct 2023 19:35:39 +0200 Subject: [PATCH] Implements all frontends and backends. --- src/back/dot/mod.rs | 66 ++++++++++- src/back/glsl/features.rs | 23 ++++ src/back/glsl/mod.rs | 102 ++++++++++++++++- src/back/hlsl/conv.rs | 14 +-- src/back/hlsl/writer.rs | 177 +++++++++++++++++++++++++++-- src/back/msl/mod.rs | 7 +- src/back/msl/writer.rs | 109 ++++++++++++++++-- src/back/spv/block.rs | 31 +---- src/back/spv/instructions.rs | 17 ++- src/back/spv/subgroup.rs | 111 ++++++++++++++---- src/back/spv/writer.rs | 26 ++++- src/back/wgsl/writer.rs | 98 +++++++++++++++- src/front/spv/error.rs | 6 +- src/front/spv/mod.rs | 213 ++++++++++++++++++++++++++++++++++- src/front/wgsl/lower/mod.rs | 98 ++++++++-------- src/front/wgsl/parse/conv.rs | 25 +++- 16 files changed, 959 insertions(+), 164 deletions(-) diff --git a/src/back/dot/mod.rs b/src/back/dot/mod.rs index 3174b7c6b6..86f4797b56 100644 --- a/src/back/dot/mod.rs +++ b/src/back/dot/mod.rs @@ -294,16 +294,78 @@ impl StatementGraph { } => { self.dependencies.push((id, argument, "arg")); self.emits.push((id, result)); - "SubgroupCollectiveOperation" // FIXME + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + "SubgroupAll" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + "SubgroupAny" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + "SubgroupAdd" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + "SubgroupMul" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + "SubgroupMax" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + "SubgroupMin" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + "SubgroupAnd" + } + (crate::CollectiveOperation::Reduce, 
crate::SubgroupOperation::Or) => { + "SubgroupOr" + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + "SubgroupXor" + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupPrefixExclusiveAdd", + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupPrefixExclusiveMul", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => "SubgroupPrefixInclusiveAdd", + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => "SubgroupPrefixInclusiveMul", + _ => unimplemented!(), + } } S::SubgroupGather { mode, argument, result, } => { + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + self.dependencies.push((id, index, "index")) + } + } self.dependencies.push((id, argument, "arg")); self.emits.push((id, result)); - "SubgroupGather" // FIXME + match mode { + crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst", + crate::GatherMode::Broadcast(_) => "SubgroupBroadcast", + crate::GatherMode::Shuffle(_) => "SubgroupShuffle", + crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown", + crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp", + crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor", + } } }; // Set the last node to the merge node diff --git a/src/back/glsl/features.rs b/src/back/glsl/features.rs index b1fff4d4bc..483a3a9348 100644 --- a/src/back/glsl/features.rs +++ b/src/back/glsl/features.rs @@ -43,6 +43,8 @@ bitflags::bitflags! 
{ const IMAGE_SIZE = 1 << 20; /// Dual source blending const DUAL_SOURCE_BLENDING = 1 << 21; + /// Subgroup operations + const SUBGROUP_OPERATIONS = 1 << 22; } } @@ -106,6 +108,7 @@ impl FeaturesManager { check_feature!(SAMPLE_VARIABLES, 400, 300); check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310); check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */); + check_feature!(SUBGROUP_OPERATIONS, 430, 310); match version { Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300), _ => check_feature!(MULTI_VIEW, 140, 310), @@ -235,6 +238,22 @@ impl FeaturesManager { writeln!(out, "#extension GL_EXT_blend_func_extended : require")?; } + if self.0.contains(Features::SUBGROUP_OPERATIONS) { + // https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt + writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?; + writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?; + writeln!( + out, + "#extension GL_KHR_shader_subgroup_arithmetic : require" + )?; + writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?; + writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?; + writeln!( + out, + "#extension GL_KHR_shader_subgroup_shuffle_relative : require" + )?; + } + Ok(()) } } @@ -455,6 +474,10 @@ impl<'a, W> Writer<'a, W> { } } } + Expression::SubgroupBallotResult | + Expression::SubgroupOperationResult { .. 
} => { + features.request(Features::SUBGROUP_OPERATIONS) + } _ => {} } } diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index e6d3648f1d..03a06890f8 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2263,7 +2263,7 @@ impl<'a, W: Write> Writer<'a, W> { Some(predicate) => self.write_expr(predicate, ctx)?, None => write!(self.out, "true")?, } - write!(self.out, ");")?; + writeln!(self.out, ");")?; } Statement::SubgroupCollectiveOperation { op, @@ -2271,14 +2271,103 @@ impl<'a, W: Write> Writer<'a, W> { argument, result, } => { - unimplemented!(); // FIXME: + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? 
+ } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupInclusiveMul(")? + } + _ => unimplemented!(), + } + self.write_expr(argument, ctx)?; + writeln!(self.out, ");")?; } Statement::SubgroupGather { mode, argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + let res_ty = ctx.info[result].ty.inner_with(&self.module.types); + self.write_value_type(res_ty)?; + write!(self.out, " {res_name} = ")?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(argument, ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(index, ctx)?; + } + } + writeln!(self.out, ");")?; } } @@ -4013,7 +4102,7 @@ impl<'a, W: Write> Writer<'a, W> 
{ writeln!(self.out, "{level}memoryBarrierShared();")?; } if flags.contains(crate::Barrier::SUB_GROUP) { - unimplemented!() // FIXME + writeln!(self.out, "{level}subgroupMemoryBarrier();")?; } writeln!(self.out, "{level}barrier();")?; Ok(()) @@ -4192,9 +4281,10 @@ const fn glsl_built_in( Bi::WorkGroupSize => "gl_WorkGroupSize", Bi::NumWorkGroups => "gl_NumWorkGroups", // subgroup - Bi::NumSubgroups | Bi::SubgroupId => todo!(), - Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", + Bi::NumSubgroups => "gl_NumSubgroups", + Bi::SubgroupId => "gl_SubgroupID", Bi::SubgroupSize => "gl_SubgroupSize", + Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", } } diff --git a/src/back/hlsl/conv.rs b/src/back/hlsl/conv.rs index 3f51278cc1..d3fb76e401 100644 --- a/src/back/hlsl/conv.rs +++ b/src/back/hlsl/conv.rs @@ -166,13 +166,13 @@ impl crate::BuiltIn { // to this field will get replaced with references to `SPECIAL_CBUF_VAR` // in `Writer::write_expr`. Self::NumWorkGroups => "SV_GroupID", - - Self::NumSubgroups | Self::SubgroupId => todo!(), - Self::SubgroupInvocationId - | Self::SubgroupSize - | Self::BaseInstance - | Self::BaseVertex - | Self::WorkGroupSize => return Err(Error::Unimplemented(format!("builtin {self:?}"))), + Self::SubgroupSize + | Self::SubgroupInvocationId + | Self::NumSubgroups + | Self::SubgroupId => return Err(Error::Unimplemented(format!("builtin {self:?}"))), + Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => { + return Err(Error::Unimplemented(format!("builtin {self:?}"))) + } Self::PointSize | Self::ViewIndex | Self::PointCoord => { return Err(Error::Custom(format!("Unsupported builtin {self:?}"))) } diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index a04fef7a1b..1eab43a4c3 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -1130,7 +1130,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, " {name}(")?; let need_workgroup_variables_initialization = - 
self.need_workgroup_variables_initialization(func_ctx, module); + self.need_workgroup_variables_initialization(func, func_ctx, module); // Write function arguments for non entry point functions match func_ctx.ty { @@ -1166,7 +1166,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { write!(self.out, "{} {}", ep_input.ty_name, ep_input.arg_name,)?; } else { let stage = module.entry_points[ep_index as usize].stage; + let mut arg_num = 0; for (index, arg) in func.arguments.iter().enumerate() { + if matches!( + arg.binding, + Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupSize)) + | Some(crate::Binding::BuiltIn( + crate::BuiltIn::SubgroupInvocationId + )) + | Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) + | Some(crate::Binding::BuiltIn(crate::BuiltIn::SubgroupId)) + ) { + continue; + } + arg_num += 1; + if index != 0 { write!(self.out, ", ")?; } @@ -1186,7 +1200,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } if need_workgroup_variables_initialization { - if !func.arguments.is_empty() { + if arg_num > 0 { write!(self.out, ", ")?; } write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?; @@ -1217,6 +1231,53 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_workgroup_variables_initialization(func_ctx, module)?; } + if let back::FunctionType::EntryPoint(ep_index) = func_ctx.ty { + let ep = &module.entry_points[ep_index as usize]; + for (index, arg) in func.arguments.iter().enumerate() { + if let Some(crate::Binding::BuiltIn(builtin)) = arg.binding { + if matches!( + builtin, + crate::BuiltIn::SubgroupSize + | crate::BuiltIn::SubgroupInvocationId + | crate::BuiltIn::NumSubgroups + | crate::BuiltIn::SubgroupId + ) { + let level = back::Level(1); + write!(self.out, "{level}const ")?; + + self.write_type(module, arg.ty)?; + + let argument_name = + &self.names[&NameKey::EntryPointArgument(ep_index, index as u32)]; + write!(self.out, " {argument_name} = ")?; + + match builtin { + crate::BuiltIn::SubgroupSize => { + 
writeln!(self.out, "WaveGetLaneCount();")? + } + crate::BuiltIn::SubgroupInvocationId => { + writeln!(self.out, "WaveGetLaneIndex();")? + } + crate::BuiltIn::NumSubgroups => writeln!( + self.out, + "({}u + WaveGetLaneCount() - 1u) / WaveGetLaneCount();", + ep.workgroup_size[0] * ep.workgroup_size[1] * ep.workgroup_size[2] + )?, + crate::BuiltIn::SubgroupId => { + writeln!( + self.out, + "(__local_invocation_id.x * {}u + __local_invocation_id.y * {}u + __local_invocation_id.z) / WaveGetLaneCount();", + ep.workgroup_size[0] * ep.workgroup_size[1], + ep.workgroup_size[1], + )?; + } + _ => unreachable!(), + } + } + } + } + } + if let back::FunctionType::EntryPoint(index) = func_ctx.ty { self.write_ep_arguments_initialization(module, func, index)?; } @@ -1267,14 +1328,20 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { fn need_workgroup_variables_initialization( &mut self, + func: &crate::Function, func_ctx: &back::FunctionCtx, module: &Module, ) -> bool { - self.options.zero_initialize_workgroup_memory + func.arguments.iter().any(|arg| { + matches!( + arg.binding, + Some(crate::Binding::BuiltIn(crate::BuiltIn::NumSubgroups)) + ) + }) || (self.options.zero_initialize_workgroup_memory && func_ctx.ty.is_compute_entry_point(module) && module.global_variables.iter().any(|(handle, var)| { !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup - }) + })) } fn write_workgroup_variables_initialization( @@ -2006,7 +2073,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { Statement::RayQuery { .. 
} => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; - let name = format!("{}{}", back::BAKE_PREFIX, result.index()); write!(self.out, "const uint4 {name} = ")?; self.named_expressions.insert(result, name); @@ -2024,14 +2090,109 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? + } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "WaveActiveAllTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "WaveActiveAnyTrue(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "WaveActiveSum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "WaveActiveProduct(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "WaveActiveMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "WaveActiveMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "WaveActiveBitAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "WaveActiveBitOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "WaveActiveBitXor(")? 
+ } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "WavePrefixSum(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "WavePrefixProduct(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " + WavePrefixSum(")?; + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + self.write_expr(module, argument, func_ctx)?; + write!(self.out, " * WavePrefixProduct(")?; + } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; } Statement::SubgroupGather { mode, argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + write!(self.out, "const ")?; + let name = format!("{}{}", back::BAKE_PREFIX, result.index()); + match func_ctx.info[result].ty { + proc::TypeResolution::Handle(handle) => self.write_type(module, handle)?, + proc::TypeResolution::Value(ref value) => { + self.write_value_type(module, value)? 
+ } + }; + write!(self.out, " {name} = ")?; + self.named_expressions.insert(result, name); + + if matches!(mode, crate::GatherMode::BroadcastFirst) { + write!(self.out, "WaveReadLaneFirst(")?; + self.write_expr(module, argument, func_ctx)?; + } else { + write!(self.out, "WaveReadLaneAt(")?; + self.write_expr(module, argument, func_ctx)?; + write!(self.out, ", ")?; + match mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(index) | crate::GatherMode::Shuffle(index) => { + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleDown(index) => { + write!(self.out, "WaveGetLaneIndex() + ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleUp(index) => { + write!(self.out, "WaveGetLaneIndex() - ")?; + self.write_expr(module, index, func_ctx)?; + } + crate::GatherMode::ShuffleXor(index) => { + write!(self.out, "WaveGetLaneIndex() ^ ")?; + self.write_expr(module, index, func_ctx)?; + } + } + } + writeln!(self.out, ");")?; } } @@ -3251,7 +3412,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, "{level}GroupMemoryBarrierWithGroupSync();")?; } if barrier.contains(crate::Barrier::SUB_GROUP) { - unimplemented!() // FIXME + // Does not exist in DirectX } Ok(()) } diff --git a/src/back/msl/mod.rs b/src/back/msl/mod.rs index 4e0d8489e4..eee825a83b 100644 --- a/src/back/msl/mod.rs +++ b/src/back/msl/mod.rs @@ -438,9 +438,10 @@ impl ResolvedBinding { Bi::WorkGroupSize => "dispatch_threads_per_threadgroup", Bi::NumWorkGroups => "threadgroups_per_grid", // subgroup - Bi::NumSubgroups | Bi::SubgroupId => todo!(), - Bi::SubgroupInvocationId => "simdgroup_index_in_threadgroup", - Bi::SubgroupSize => "simdgroups_per_threadgroup", + Bi::NumSubgroups => "simdgroups_per_threadgroup", + Bi::SubgroupId => "simdgroup_index_in_threadgroup", + Bi::SubgroupSize => "threads_per_simdgroup", + Bi::SubgroupInvocationId => "thread_index_in_simdgroup", Bi::CullDistance | Bi::ViewIndex => { 
return Err(Error::UnsupportedBuiltIn(built_in)) } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 629b98c96d..ce57588240 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -3017,14 +3017,13 @@ impl Writer { let name = self.namer.call(""); self.start_baking_expression(result, &context.expression, &name)?; self.named_expressions.insert(result, name); - write!(self.out, "{NAMESPACE}::simd_ballot(;")?; - match predicate { - Some(predicate) => { - self.put_expression(predicate, &context.expression, true)? - } - None => write!(self.out, "true")?, + write!(self.out, "uint4((uint64_t){NAMESPACE}::simd_ballot(")?; + if let Some(predicate) = predicate { + self.put_expression(predicate, &context.expression, true)?; + } else { + write!(self.out, "true")?; } - writeln!(self.out, ");")?; + writeln!(self.out, "), 0, 0, 0);")?; } crate::Statement::SubgroupCollectiveOperation { op, @@ -3032,14 +3031,101 @@ impl Writer { argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "{NAMESPACE}::simd_all(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "{NAMESPACE}::simd_any(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "{NAMESPACE}::simd_sum(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "{NAMESPACE}::simd_product(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "{NAMESPACE}::simd_max(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "{NAMESPACE}::simd_min(")? 
+ } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "{NAMESPACE}::simd_and(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "{NAMESPACE}::simd_or(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "{NAMESPACE}::simd_xor(")? + } + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_sum(")?, + ( + crate::CollectiveOperation::ExclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_exclusive_product(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Add, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_sum(")?, + ( + crate::CollectiveOperation::InclusiveScan, + crate::SubgroupOperation::Mul, + ) => write!(self.out, "{NAMESPACE}::simd_prefix_inclusive_product(")?, + _ => unimplemented!(), + } + self.put_expression(argument, &context.expression, true)?; + writeln!(self.out, ");")?; } crate::Statement::SubgroupGather { mode, argument, result, } => { - unimplemented!(); // FIXME + write!(self.out, "{level}")?; + let name = self.namer.call(""); + self.start_baking_expression(result, &context.expression, &name)?; + self.named_expressions.insert(result, name); + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "{NAMESPACE}::simd_broadcast_first(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "{NAMESPACE}::simd_broadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_down(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_up(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "{NAMESPACE}::simd_shuffle_xor(")?; + } + } + 
self.put_expression(argument, &context.expression, true)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.put_expression(index, &context.expression, true)?; + } + } + writeln!(self.out, ");")?; } } } @@ -4378,7 +4464,10 @@ impl Writer { )?; } if flags.contains(crate::Barrier::SUB_GROUP) { - unimplemented!(); // FIXME + writeln!( + self.out, + "{level}{NAMESPACE}::simdgroup_barrier({NAMESPACE}::mem_flags::mem_threadgroup);", + )?; } Ok(()) } diff --git a/src/back/spv/block.rs b/src/back/spv/block.rs index 9fba489f79..50883ce071 100644 --- a/src/back/spv/block.rs +++ b/src/back/spv/block.rs @@ -2340,30 +2340,11 @@ impl<'w> BlockContext<'w> { crate::Statement::RayQuery { query, ref fun } => { self.write_ray_query_function(query, fun, &mut block); } - crate::Statement::SubgroupBallot { result, predicate } => { - self.writer.require_any( - "GroupNonUniformBallot", - &[spirv::Capability::GroupNonUniformBallot], - )?; - let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { - vector_size: Some(crate::VectorSize::Quad), - kind: crate::ScalarKind::Uint, - width: 4, - pointer_space: None, - })); - let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); - let predicate = match predicate { - Some(predicate) => self.cached[predicate], - None => self.writer.get_constant_scalar(crate::Literal::Bool(true)), - }; - let id = self.gen_id(); - block.body.push(Instruction::group_non_uniform_ballot( - vec4_u32_type_id, - id, - exec_scope_id, - predicate, - )); - self.cached[result] = id; + crate::Statement::SubgroupBallot { + result, + ref predicate, + } => { + self.write_subgroup_ballot(predicate, result, &mut block)?; } crate::Statement::SubgroupCollectiveOperation { ref op, @@ -2378,7 +2359,7 @@ 
impl<'w> BlockContext<'w> { argument, result, } => { - self.write_subgroup_broadcast(mode, argument, result, &mut block)?; + self.write_subgroup_gather(mode, argument, result, &mut block)?; } } } diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index 8a528065ef..5f7c6b34fd 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -1054,33 +1054,34 @@ impl super::Instruction { instruction } - pub(super) fn group_non_uniform_broadcast( + pub(super) fn group_non_uniform_broadcast_first( result_type_id: Word, id: Word, exec_scope_id: Word, value: Word, - index: Word, ) -> Self { - let mut instruction = Self::new(Op::GroupNonUniformBroadcast); + let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); instruction.add_operand(value); - instruction.add_operand(index); instruction } - pub(super) fn group_non_uniform_broadcast_first( + pub(super) fn group_non_uniform_gather( + op: Op, result_type_id: Word, id: Word, exec_scope_id: Word, value: Word, + index: Word, ) -> Self { - let mut instruction = Self::new(Op::GroupNonUniformBroadcastFirst); + let mut instruction = Self::new(op); instruction.set_type(result_type_id); instruction.set_result(id); instruction.add_operand(exec_scope_id); instruction.add_operand(value); + instruction.add_operand(index); instruction } @@ -1092,10 +1093,6 @@ impl super::Instruction { group_op: Option, value: Word, ) -> Self { - println!( - "{:?}", - (op, result_type_id, id, exec_scope_id, group_op, value) - ); let mut instruction = Self::new(op); instruction.set_type(result_type_id); instruction.set_result(id); diff --git a/src/back/spv/subgroup.rs b/src/back/spv/subgroup.rs index 7206ec9312..79db752a6c 100644 --- a/src/back/spv/subgroup.rs +++ b/src/back/spv/subgroup.rs @@ -1,7 +1,43 @@ use super::{Block, BlockContext, Error, Instruction}; -use crate::{arena::Handle, TypeInner}; +use 
crate::{ + arena::Handle, + back::spv::{LocalType, LookupType}, + TypeInner, +}; impl<'w> BlockContext<'w> { + pub(super) fn write_subgroup_ballot( + &mut self, + predicate: &Option>, + result: Handle, + block: &mut Block, + ) -> Result<(), Error> { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + let vec4_u32_type_id = self.get_type_id(LookupType::Local(LocalType::Value { + vector_size: Some(crate::VectorSize::Quad), + kind: crate::ScalarKind::Uint, + width: 4, + pointer_space: None, + })); + let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); + let predicate = if let Some(predicate) = *predicate { + self.cached[predicate] + } else { + self.writer.get_constant_scalar(crate::Literal::Bool(true)) + }; + let id = self.gen_id(); + block.body.push(Instruction::group_non_uniform_ballot( + vec4_u32_type_id, + id, + exec_scope_id, + predicate, + )); + self.cached[result] = id; + Ok(()) + } pub(super) fn write_subgroup_operation( &mut self, op: &crate::SubgroupOperation, @@ -11,7 +47,7 @@ impl<'w> BlockContext<'w> { block: &mut Block, ) -> Result<(), Error> { use crate::SubgroupOperation as sg; - match op { + match *op { sg::All | sg::Any => { self.writer.require_any( "GroupNonUniformVote", @@ -21,11 +57,7 @@ impl<'w> BlockContext<'w> { _ => { self.writer.require_any( "GroupNonUniformArithmetic", - &[ - spirv::Capability::GroupNonUniformArithmetic, - spirv::Capability::GroupNonUniformClustered, - spirv::Capability::GroupNonUniformPartitionedNV, - ], + &[spirv::Capability::GroupNonUniformArithmetic], )?; } } @@ -35,14 +67,14 @@ impl<'w> BlockContext<'w> { let result_type_id = self.get_expression_type_id(result_ty); let result_ty_inner = result_ty.inner_with(&self.ir_module.types); - let (is_scalar, kind) = match result_ty_inner { + let (is_scalar, kind) = match *result_ty_inner { TypeInner::Scalar { kind, .. } => (true, kind), TypeInner::Vector { kind, .. 
} => (false, kind), _ => unimplemented!(), }; use crate::ScalarKind as sk; - let spirv_op = match (kind, op) { + let spirv_op = match (kind, *op) { (sk::Bool, sg::All) if is_scalar => spirv::Op::GroupNonUniformAll, (sk::Bool, sg::Any) if is_scalar => spirv::Op::GroupNonUniformAny, (_, sg::All | sg::Any) => unimplemented!(), @@ -71,9 +103,9 @@ impl<'w> BlockContext<'w> { let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); use crate::CollectiveOperation as c; - let group_op = match op { + let group_op = match *op { sg::All | sg::Any => None, - _ => Some(match collective_op { + _ => Some(match *collective_op { c::Reduce => spirv::GroupOperation::Reduce, c::InclusiveScan => spirv::GroupOperation::InclusiveScan, c::ExclusiveScan => spirv::GroupOperation::ExclusiveScan, @@ -92,7 +124,7 @@ impl<'w> BlockContext<'w> { self.cached[result] = id; Ok(()) } - pub(super) fn write_subgroup_broadcast( + pub(super) fn write_subgroup_gather( &mut self, mode: &crate::GatherMode, argument: Handle, @@ -103,6 +135,26 @@ impl<'w> BlockContext<'w> { "GroupNonUniformBallot", &[spirv::Capability::GroupNonUniformBallot], )?; + match *mode { + crate::GatherMode::BroadcastFirst | crate::GatherMode::Broadcast(_) => { + self.writer.require_any( + "GroupNonUniformBallot", + &[spirv::Capability::GroupNonUniformBallot], + )?; + } + crate::GatherMode::Shuffle(_) | crate::GatherMode::ShuffleXor(_) => { + self.writer.require_any( + "GroupNonUniformShuffle", + &[spirv::Capability::GroupNonUniformShuffle], + )?; + } + crate::GatherMode::ShuffleDown(_) | crate::GatherMode::ShuffleUp(_) => { + self.writer.require_any( + "GroupNonUniformShuffleRelative", + &[spirv::Capability::GroupNonUniformShuffleRelative], + )?; + } + } let id = self.gen_id(); let result_ty = &self.fun_info[result].ty; @@ -111,17 +163,7 @@ impl<'w> BlockContext<'w> { let exec_scope_id = self.get_index_constant(spirv::Scope::Subgroup as u32); let arg_id = self.cached[argument]; - match mode { - 
crate::GatherMode::Broadcast(index) => { - let index_id = self.cached[*index]; - block.body.push(Instruction::group_non_uniform_broadcast( - result_type_id, - id, - exec_scope_id, - arg_id, - index_id, - )); - } + match *mode { crate::GatherMode::BroadcastFirst => { block .body @@ -132,10 +174,29 @@ impl<'w> BlockContext<'w> { arg_id, )); } - crate::GatherMode::Shuffle(index) + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) | crate::GatherMode::ShuffleDown(index) | crate::GatherMode::ShuffleUp(index) - | crate::GatherMode::ShuffleXor(index) => todo!(), + | crate::GatherMode::ShuffleXor(index) => { + let index_id = self.cached[index]; + let op = match *mode { + crate::GatherMode::BroadcastFirst => unreachable!(), + crate::GatherMode::Broadcast(_) => spirv::Op::GroupNonUniformBroadcast, + crate::GatherMode::Shuffle(_) => spirv::Op::GroupNonUniformShuffle, + crate::GatherMode::ShuffleDown(_) => spirv::Op::GroupNonUniformShuffleDown, + crate::GatherMode::ShuffleUp(_) => spirv::Op::GroupNonUniformShuffleUp, + crate::GatherMode::ShuffleXor(_) => spirv::Op::GroupNonUniformShuffleXor, + }; + block.body.push(Instruction::group_non_uniform_gather( + op, + result_type_id, + id, + exec_scope_id, + arg_id, + index_id, + )); + } } self.cached[result] = id; Ok(()) diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 6a55288308..da0fdf766f 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1594,17 +1594,23 @@ impl Writer { Bi::WorkGroupSize => BuiltIn::WorkgroupSize, Bi::NumWorkGroups => BuiltIn::NumWorkgroups, // Subgroup - Bi::NumSubgroups | Bi::SubgroupId => todo!(), - Bi::SubgroupInvocationId => { + Bi::NumSubgroups => { self.require_any( - "`subgroup_invocation_id` built-in", + "`num_subgroups` built-in", &[spirv::Capability::GroupNonUniform], )?; - BuiltIn::SubgroupLocalInvocationId + BuiltIn::NumSubgroups + } + Bi::SubgroupId => { + self.require_any( + "`subgroup_id` built-in", + 
&[spirv::Capability::GroupNonUniform], + )?; + BuiltIn::SubgroupId } Bi::SubgroupSize => { self.require_any( - "`subgroup_invocation_id` built-in", + "`subgroup_size` built-in", &[ spirv::Capability::GroupNonUniform, spirv::Capability::SubgroupBallotKHR, @@ -1612,6 +1618,16 @@ impl Writer { )?; BuiltIn::SubgroupSize } + Bi::SubgroupInvocationId => { + self.require_any( + "`subgroup_invocation_id` built-in", + &[ + spirv::Capability::GroupNonUniform, + spirv::Capability::SubgroupBallotKHR, + ], + )?; + BuiltIn::SubgroupLocalInvocationId + } }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); diff --git a/src/back/wgsl/writer.rs b/src/back/wgsl/writer.rs index 8fc9881f10..0bc2dfceb0 100644 --- a/src/back/wgsl/writer.rs +++ b/src/back/wgsl/writer.rs @@ -919,6 +919,10 @@ impl Writer { if barrier.contains(crate::Barrier::WORK_GROUP) { writeln!(self.out, "{level}workgroupBarrier();")?; } + + if barrier.contains(crate::Barrier::SUB_GROUP) { + writeln!(self.out, "{level}subgroupBarrier();")?; + } } Statement::RayQuery { .. } => unreachable!(), Statement::SubgroupBallot { result, predicate } => { @@ -939,14 +943,99 @@ impl Writer { argument, result, } => { - unimplemented!() // FIXME + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match (collective_op, op) { + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => { + write!(self.out, "subgroupAll(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => { + write!(self.out, "subgroupAny(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupAdd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupMul(")? 
+ } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => { + write!(self.out, "subgroupMax(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => { + write!(self.out, "subgroupMin(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => { + write!(self.out, "subgroupAnd(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => { + write!(self.out, "subgroupOr(")? + } + (crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => { + write!(self.out, "subgroupXor(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupPrefixExclusiveAdd(")? + } + (crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupPrefixExclusiveMul(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => { + write!(self.out, "subgroupPrefixInclusiveAdd(")? + } + (crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => { + write!(self.out, "subgroupPrefixInclusiveMul(")? 
+ } + _ => unimplemented!(), + } + self.write_expr(module, argument, func_ctx)?; + writeln!(self.out, ");")?; } Statement::SubgroupGather { mode, argument, result, } => { - unimplemented!() // FIXME + write!(self.out, "{level}")?; + let res_name = format!("{}{}", back::BAKE_PREFIX, result.index()); + self.start_named_expr(module, result, func_ctx, &res_name)?; + self.named_expressions.insert(result, res_name); + + match mode { + crate::GatherMode::BroadcastFirst => { + write!(self.out, "subgroupBroadcastFirst(")?; + } + crate::GatherMode::Broadcast(_) => { + write!(self.out, "subgroupBroadcast(")?; + } + crate::GatherMode::Shuffle(_) => { + write!(self.out, "subgroupShuffle(")?; + } + crate::GatherMode::ShuffleDown(_) => { + write!(self.out, "subgroupShuffleDown(")?; + } + crate::GatherMode::ShuffleUp(_) => { + write!(self.out, "subgroupShuffleUp(")?; + } + crate::GatherMode::ShuffleXor(_) => { + write!(self.out, "subgroupShuffleXor(")?; + } + } + self.write_expr(module, argument, func_ctx)?; + match mode { + crate::GatherMode::BroadcastFirst => {} + crate::GatherMode::Broadcast(index) + | crate::GatherMode::Shuffle(index) + | crate::GatherMode::ShuffleDown(index) + | crate::GatherMode::ShuffleUp(index) + | crate::GatherMode::ShuffleXor(index) => { + write!(self.out, ", ")?; + self.write_expr(module, index, func_ctx)?; + } + } + writeln!(self.out, ");")?; } } @@ -1789,9 +1878,10 @@ fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { Bi::SampleMask => "sample_mask", Bi::PrimitiveIndex => "primitive_index", Bi::ViewIndex => "view_index", - Bi::NumSubgroups | Bi::SubgroupId => todo!(), - Bi::SubgroupInvocationId => "subgroup_invocation_id", + Bi::NumSubgroups => "num_subgroups", + Bi::SubgroupId => "subgroup_id", Bi::SubgroupSize => "subgroup_size", + Bi::SubgroupInvocationId => "subgroup_invocation_id", Bi::BaseInstance | Bi::BaseVertex | Bi::ClipDistance diff --git a/src/front/spv/error.rs b/src/front/spv/error.rs index 2f9bf2d1bc..8508ede042 
100644 --- a/src/front/spv/error.rs +++ b/src/front/spv/error.rs @@ -54,6 +54,8 @@ pub enum Error { UnknownBinaryOperator(spirv::Op), #[error("unknown relational function {0:?}")] UnknownRelationalFunction(spirv::Op), + #[error("unsupported group operation %{0}")] + UnsupportedGroupOperation(spirv::Word), #[error("invalid parameter {0:?}")] InvalidParameter(spirv::Op), #[error("invalid operand count {1} for {0:?}")] @@ -116,8 +118,8 @@ pub enum Error { FunctionCallCycle(spirv::Word), #[error("invalid array size {0:?}")] InvalidArraySize(Handle), - #[error("invalid barrier scope %{0}")] - InvalidBarrierScope(spirv::Word), + #[error("invalid execution scope %{0}")] + InvalidExecutionScope(spirv::Word), #[error("invalid barrier memory semantics %{0}")] InvalidBarrierMemorySemantics(spirv::Word), #[error( diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index c78407ef99..e7c082c6a5 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -3650,7 +3650,7 @@ impl> Frontend { let semantics_const = self.lookup_constant.lookup(semantics_id)?; let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) - .ok_or(Error::InvalidBarrierScope(exec_scope_id))?; + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; let semantics = resolve_constant(ctx.gctx(), semantics_const.handle) .ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?; @@ -3691,6 +3691,209 @@ impl> Frontend { }, ); } + Op::GroupNonUniformBallot => { + inst.expect(5)?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let predicate_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) +
.ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let predicate = if self + .lookup_constant + .lookup(predicate_id) + .ok() + .filter(|predicate_const| { + matches!( + ctx.gctx().const_expressions + [ctx.gctx().constants[predicate_const.handle].init], + crate::Expression::Literal(crate::Literal::Bool(true)) + ) + }) + .is_some() + { + None + } else { + let predicate_lookup = self.lookup_expression.lookup(predicate_id)?; + let predicate_handle = get_expr_handle!(predicate_id, predicate_lookup); + Some(predicate_handle) + }; + + block.push( + crate::Statement::SubgroupBallot { + result: result_handle, + predicate, + }, + span, + ); + } + spirv::Op::GroupNonUniformAll + | spirv::Op::GroupNonUniformAny + | spirv::Op::GroupNonUniformIAdd + | spirv::Op::GroupNonUniformFAdd + | spirv::Op::GroupNonUniformIMul + | spirv::Op::GroupNonUniformFMul + | spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax + | spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin + | spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalAnd + | spirv::Op::GroupNonUniformLogicalOr + | spirv::Op::GroupNonUniformLogicalXor => { + inst.expect( + if matches!( + inst.op, + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny + ) { + 5 + } else { + 6 + }, + )?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let collective_op_id = match inst.op { + spirv::Op::GroupNonUniformAll | spirv::Op::GroupNonUniformAny => { + crate::CollectiveOperation::Reduce + } + _ => { + let group_op_id = self.next()?; + match spirv::GroupOperation::from_u32(group_op_id) { + Some(spirv::GroupOperation::Reduce) => { + crate::CollectiveOperation::Reduce + } + Some(spirv::GroupOperation::InclusiveScan) => { + crate::CollectiveOperation::InclusiveScan + } +
Some(spirv::GroupOperation::ExclusiveScan) => { + crate::CollectiveOperation::ExclusiveScan + } + _ => return Err(Error::UnsupportedGroupOperation(group_op_id)), + } + } + }; + let argument_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let op_id = match inst.op { + spirv::Op::GroupNonUniformAll => crate::SubgroupOperation::All, + spirv::Op::GroupNonUniformAny => crate::SubgroupOperation::Any, + spirv::Op::GroupNonUniformIAdd | spirv::Op::GroupNonUniformFAdd => { + crate::SubgroupOperation::Add + } + spirv::Op::GroupNonUniformIMul | spirv::Op::GroupNonUniformFMul => { + crate::SubgroupOperation::Mul + } + spirv::Op::GroupNonUniformSMax + | spirv::Op::GroupNonUniformUMax + | spirv::Op::GroupNonUniformFMax => crate::SubgroupOperation::Max, + spirv::Op::GroupNonUniformSMin + | spirv::Op::GroupNonUniformUMin + | spirv::Op::GroupNonUniformFMin => crate::SubgroupOperation::Min, + spirv::Op::GroupNonUniformBitwiseAnd + | spirv::Op::GroupNonUniformLogicalAnd => crate::SubgroupOperation::And, + spirv::Op::GroupNonUniformBitwiseOr + | spirv::Op::GroupNonUniformLogicalOr => crate::SubgroupOperation::Or, + spirv::Op::GroupNonUniformBitwiseXor + | spirv::Op::GroupNonUniformLogicalXor => crate::SubgroupOperation::Xor, + _ => unreachable!(), + }; + + block.push( + crate::Statement::SubgroupCollectiveOperation { + result: result_handle, + op: op_id, + collective_op: collective_op_id, + argument: argument_handle, + }, + span, + ); + } + Op::GroupNonUniformBroadcastFirst 
+ | Op::GroupNonUniformBroadcast + | Op::GroupNonUniformShuffle + | Op::GroupNonUniformShuffleDown + | Op::GroupNonUniformShuffleUp + | Op::GroupNonUniformShuffleXor => { + inst.expect( + if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + 5 + } else { + 6 + }, + )?; + let _result_type_id = self.next()?; + let result_id = self.next()?; + let exec_scope_id = self.next()?; + let argument_id = self.next()?; + + let result_lookup = self.lookup_expression.lookup(result_id)?; + let result_handle = get_expr_handle!(result_id, result_lookup); + let argument_lookup = self.lookup_expression.lookup(argument_id)?; + let argument_handle = get_expr_handle!(argument_id, argument_lookup); + + let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?; + let _exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle) + .filter(|exec_scope| *exec_scope == spirv::Scope::Subgroup as u32) + .ok_or(Error::InvalidExecutionScope(exec_scope_id))?; + + let mode = if matches!(inst.op, spirv::Op::GroupNonUniformBroadcastFirst) { + crate::GatherMode::BroadcastFirst + } else { + let index_id = self.next()?; + let index_lookup = self.lookup_expression.lookup(index_id)?; + let index_handle = get_expr_handle!(index_id, index_lookup); + match inst.op { + spirv::Op::GroupNonUniformBroadcast => { + crate::GatherMode::Broadcast(index_handle) + } + spirv::Op::GroupNonUniformShuffle => { + crate::GatherMode::Shuffle(index_handle) + } + spirv::Op::GroupNonUniformShuffleDown => { + crate::GatherMode::ShuffleDown(index_handle) + } + spirv::Op::GroupNonUniformShuffleUp => { + crate::GatherMode::ShuffleUp(index_handle) + } + spirv::Op::GroupNonUniformShuffleXor => { + crate::GatherMode::ShuffleXor(index_handle) + } + _ => unreachable!(), + } + }; + + block.push( + crate::Statement::SubgroupGather { + result: result_handle, + mode, + argument: argument_handle, + }, + span, + ); + } _ => return Err(Error::UnsupportedInstruction(self.state, inst.op)), } }; @@ -3811,7 +4014,10 @@
impl> Frontend { | S::Store { .. } | S::ImageStore { .. } | S::Atomic { .. } - | S::RayQuery { .. } => {} + | S::RayQuery { .. } + | S::SubgroupBallot { .. } + | S::SubgroupCollectiveOperation { .. } + | S::SubgroupGather { .. } => {} S::Call { function: ref mut callee, ref arguments, @@ -3842,9 +4048,6 @@ impl> Frontend { } } S::WorkGroupUniformLoad { .. } => unreachable!(), - S::SubgroupBallot { .. } => unreachable!(), // FIXME?? - S::SubgroupCollectiveOperation { .. } => unreachable!(), - S::SubgroupGather { .. } => unreachable!(), } i += 1; } diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 127008d437..d3e47f6959 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -1918,7 +1918,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } else if let Some(fun) = Texture::map(function.name) { self.texture_sample_helper(fun, arguments, span, ctx.reborrow())? } else if let Some((op, cop)) = conv::map_subgroup_operation(function.name) { - return Ok(Some(self.subgroup_helper(span, op, cop, arguments, ctx)?)); + return Ok(Some( + self.subgroup_operation_helper(span, op, cop, arguments, ctx)?, + )); + } else if let Some(mode) = conv::map_subgroup_gather(function.name) { + return Ok(Some( + self.subgroup_gather_helper(span, mode, arguments, ctx)?, + )); } else { match function.name { "select" => { @@ -2316,53 +2322,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .push(crate::Statement::SubgroupBallot { result, predicate }, span); return Ok(Some(result)); } - "subgroupBroadcast" => { - let mut args = ctx.prepare_args(arguments, 2, span); - - let index = self.expression(args.next()?, ctx.reborrow())?; - let argument = self.expression(args.next()?, ctx.reborrow())?; - args.finish()?; - - let ty = ctx.register_type(argument)?; - - let result = ctx.interrupt_emitter( - crate::Expression::SubgroupOperationResult { ty }, - span, - )?; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block.push( - 
crate::Statement::SubgroupGather { - mode: crate::GatherMode::Broadcast(index), - argument, - result, - }, - span, - ); - return Ok(Some(result)); - } - "subgroupBroadcastFirst" => { - let mut args = ctx.prepare_args(arguments, 1, span); - - let argument = self.expression(args.next()?, ctx.reborrow())?; - args.finish()?; - - let ty = ctx.register_type(argument)?; - - let result = ctx.interrupt_emitter( - crate::Expression::SubgroupOperationResult { ty }, - span, - )?; - let rctx = ctx.runtime_expression_ctx(span)?; - rctx.block.push( - crate::Statement::SubgroupGather { - mode: crate::GatherMode::BroadcastFirst, - argument, - result, - }, - span, - ); - return Ok(Some(result)); - } _ => return Err(Error::UnknownIdent(function.span, function.name)), } }; @@ -2554,7 +2513,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { depth_ref, }) } - fn subgroup_helper( + + fn subgroup_operation_helper( &mut self, span: Span, op: crate::SubgroupOperation, @@ -2584,6 +2544,46 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { Ok(result) } + fn subgroup_gather_helper( + &mut self, + span: Span, + mode: crate::GatherMode, + arguments: &[Handle>], + mut ctx: ExpressionContext<'source, '_, '_>, + ) -> Result, Error<'source>> { + let mut args = ctx.prepare_args(arguments, 2, span); + + let argument = self.expression(args.next()?, ctx.reborrow())?; + let index = if let crate::GatherMode::BroadcastFirst = mode { + Handle::new(NonZeroU32::new(u32::MAX).unwrap()) + } else { + self.expression(args.next()?, ctx.reborrow())? 
+ }; + args.finish()?; + + let ty = ctx.register_type(argument)?; + + let result = + ctx.interrupt_emitter(crate::Expression::SubgroupOperationResult { ty }, span)?; + let rctx = ctx.runtime_expression_ctx(span)?; + rctx.block.push( + crate::Statement::SubgroupGather { + mode: match mode { + crate::GatherMode::BroadcastFirst => crate::GatherMode::BroadcastFirst, + crate::GatherMode::Broadcast(_) => crate::GatherMode::Broadcast(index), + crate::GatherMode::Shuffle(_) => crate::GatherMode::Shuffle(index), + crate::GatherMode::ShuffleDown(_) => crate::GatherMode::ShuffleDown(index), + crate::GatherMode::ShuffleUp(_) => crate::GatherMode::ShuffleUp(index), + crate::GatherMode::ShuffleXor(_) => crate::GatherMode::ShuffleXor(index), + }, + argument, + result, + }, + span, + ); + Ok(result) + } + fn r#struct( &mut self, s: &ast::Struct<'source>, diff --git a/src/front/wgsl/parse/conv.rs b/src/front/wgsl/parse/conv.rs index 160213e4a3..c53f4df753 100644 --- a/src/front/wgsl/parse/conv.rs +++ b/src/front/wgsl/parse/conv.rs @@ -35,8 +35,10 @@ pub fn map_built_in(word: &str, span: Span) -> Result> "workgroup_id" => crate::BuiltIn::WorkGroupId, "num_workgroups" => crate::BuiltIn::NumWorkGroups, // subgroup - "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, + "num_subgroups" => crate::BuiltIn::NumSubgroups, + "subgroup_id" => crate::BuiltIn::SubgroupId, "subgroup_size" => crate::BuiltIn::SubgroupSize, + "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, _ => return Err(Error::UnknownBuiltin(span)), }) } @@ -254,8 +256,25 @@ pub fn map_subgroup_operation( "subgroupAnd" => (sg::And, co::Reduce), "subgroupOr" => (sg::Or, co::Reduce), "subgroupXor" => (sg::Xor, co::Reduce), - "subgroupPrefixAdd" => (sg::Add, co::InclusiveScan), - "subgroupPrefixMul" => (sg::Mul, co::InclusiveScan), + "subgroupPrefixExclusiveAdd" => (sg::Add, co::ExclusiveScan), + "subgroupPrefixExclusiveMul" => (sg::Mul, co::ExclusiveScan), + "subgroupPrefixInclusiveAdd" => (sg::Add, 
co::InclusiveScan), + "subgroupPrefixInclusiveMul" => (sg::Mul, co::InclusiveScan), + _ => return None, + }) +} + +pub fn map_subgroup_gather(word: &str) -> Option { + use crate::GatherMode as gm; + use crate::Handle; + use std::num::NonZeroU32; + Some(match word { + "subgroupBroadcastFirst" => gm::BroadcastFirst, + "subgroupBroadcast" => gm::Broadcast(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffle" => gm::Shuffle(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleDown" => gm::ShuffleDown(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleUp" => gm::ShuffleUp(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), + "subgroupShuffleXor" => gm::ShuffleXor(Handle::new(NonZeroU32::new(u32::MAX).unwrap())), _ => return None, }) }