diff --git a/tests/in/subgroup-operations.param.ron b/tests/in/subgroup-operations.param.ron new file mode 100644 index 0000000000..fc444a3efe --- /dev/null +++ b/tests/in/subgroup-operations.param.ron @@ -0,0 +1,26 @@ +( + spv: ( + version: (1, 3), + ), + msl: ( + lang_version: (2, 4), + per_entry_point_map: {}, + inline_samplers: [], + spirv_cross_compatibility: false, + fake_missing_bindings: false, + zero_initialize_workgroup_memory: true, + ), + glsl: ( + version: Desktop(430), + writer_flags: (""), + binding_map: { }, + zero_initialize_workgroup_memory: true, + ), + hlsl: ( + shader_model: V6_0, + binding_map: {}, + fake_missing_bindings: true, + special_constants_binding: None, + zero_initialize_workgroup_memory: true, + ), +) diff --git a/tests/in/subgroup-operations.wgsl b/tests/in/subgroup-operations.wgsl new file mode 100644 index 0000000000..f30b60be47 --- /dev/null +++ b/tests/in/subgroup-operations.wgsl @@ -0,0 +1,32 @@ +@compute @workgroup_size(1) +fn main( + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, +) { + subgroupBarrier(); + + subgroupBallot((subgroup_invocation_id & 1u) == 1u); + + subgroupAll(subgroup_invocation_id != 0u); + subgroupAny(subgroup_invocation_id == 0u); + subgroupAdd(subgroup_invocation_id); + subgroupMul(subgroup_invocation_id); + subgroupMin(subgroup_invocation_id); + subgroupMax(subgroup_invocation_id); + subgroupAnd(subgroup_invocation_id); + subgroupOr(subgroup_invocation_id); + subgroupXor(subgroup_invocation_id); + subgroupPrefixExclusiveAdd(subgroup_invocation_id); + subgroupPrefixExclusiveMul(subgroup_invocation_id); + subgroupPrefixInclusiveAdd(subgroup_invocation_id); + subgroupPrefixInclusiveMul(subgroup_invocation_id); + + subgroupBroadcastFirst(subgroup_invocation_id); + subgroupBroadcast(subgroup_invocation_id, 4u); + subgroupShuffle(subgroup_invocation_id, 
subgroup_size - 1u - subgroup_invocation_id); + subgroupShuffleDown(subgroup_invocation_id, 1u); + subgroupShuffleUp(subgroup_invocation_id, 1u); + subgroupShuffleXor(subgroup_invocation_id, subgroup_size - 1u); +} diff --git a/tests/out/glsl/subgroup-operations.main.Compute.glsl b/tests/out/glsl/subgroup-operations.main.Compute.glsl new file mode 100644 index 0000000000..a37cf8e247 --- /dev/null +++ b/tests/out/glsl/subgroup-operations.main.Compute.glsl @@ -0,0 +1,41 @@ +#version 430 core +#extension GL_ARB_compute_shader : require +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_vote : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#extension GL_KHR_shader_subgroup_ballot : require +#extension GL_KHR_shader_subgroup_shuffle : require +#extension GL_KHR_shader_subgroup_shuffle_relative : require +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + +void main() { + uint num_subgroups = gl_NumSubgroups; + uint subgroup_id = gl_SubgroupID; + uint subgroup_size = gl_SubgroupSize; + uint subgroup_invocation_id = gl_SubgroupInvocationID; + subgroupMemoryBarrier(); + barrier(); + uvec4 _e8 = subgroupBallot(((subgroup_invocation_id & 1u) == 1u)); + bool _e11 = subgroupAll((subgroup_invocation_id != 0u)); + bool _e14 = subgroupAny((subgroup_invocation_id == 0u)); + uint _e15 = subgroupAdd(subgroup_invocation_id); + uint _e16 = subgroupMul(subgroup_invocation_id); + uint _e17 = subgroupMin(subgroup_invocation_id); + uint _e18 = subgroupMax(subgroup_invocation_id); + uint _e19 = subgroupAnd(subgroup_invocation_id); + uint _e20 = subgroupOr(subgroup_invocation_id); + uint _e21 = subgroupXor(subgroup_invocation_id); + uint _e22 = subgroupExclusiveAdd(subgroup_invocation_id); + uint _e23 = subgroupExclusiveMul(subgroup_invocation_id); + uint _e24 = subgroupInclusiveAdd(subgroup_invocation_id); + uint _e25 = subgroupInclusiveMul(subgroup_invocation_id); + uint _e26 = 
subgroupBroadcastFirst(subgroup_invocation_id); + uint _e28 = subgroupBroadcast(subgroup_invocation_id, 4u); + uint _e32 = subgroupShuffle(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + uint _e34 = subgroupShuffleDown(subgroup_invocation_id, 1u); + uint _e36 = subgroupShuffleUp(subgroup_invocation_id, 1u); + uint _e39 = subgroupShuffleXor(subgroup_invocation_id, (subgroup_size - 1u)); + return; +} + diff --git a/tests/out/hlsl/subgroup-operations.hlsl b/tests/out/hlsl/subgroup-operations.hlsl new file mode 100644 index 0000000000..baa37826e0 --- /dev/null +++ b/tests/out/hlsl/subgroup-operations.hlsl @@ -0,0 +1,32 @@ +[numthreads(1, 1, 1)] +void main(uint3 __local_invocation_id : SV_GroupThreadID) +{ + if (all(__local_invocation_id == uint3(0u, 0u, 0u))) { + } + GroupMemoryBarrierWithGroupSync(); + const uint num_subgroups = (1u + WaveGetLaneCount() - 1u) / WaveGetLaneCount(); + const uint subgroup_id = (__local_invocation_id.x * 1u + __local_invocation_id.y * 1u + __local_invocation_id.z) / WaveGetLaneCount(); + const uint subgroup_size = WaveGetLaneCount(); + const uint subgroup_invocation_id = WaveGetLaneIndex(); + const uint4 _e8 = WaveActiveBallot(((subgroup_invocation_id & 1u) == 1u)); + const bool _e11 = WaveActiveAllTrue((subgroup_invocation_id != 0u)); + const bool _e14 = WaveActiveAnyTrue((subgroup_invocation_id == 0u)); + const uint _e15 = WaveActiveSum(subgroup_invocation_id); + const uint _e16 = WaveActiveProduct(subgroup_invocation_id); + const uint _e17 = WaveActiveMin(subgroup_invocation_id); + const uint _e18 = WaveActiveMax(subgroup_invocation_id); + const uint _e19 = WaveActiveBitAnd(subgroup_invocation_id); + const uint _e20 = WaveActiveBitOr(subgroup_invocation_id); + const uint _e21 = WaveActiveBitXor(subgroup_invocation_id); + const uint _e22 = WavePrefixSum(subgroup_invocation_id); + const uint _e23 = WavePrefixProduct(subgroup_invocation_id); + const uint _e24 = subgroup_invocation_id + 
WavePrefixSum(subgroup_invocation_id); + const uint _e25 = subgroup_invocation_id * WavePrefixProduct(subgroup_invocation_id); + const uint _e26 = WaveReadLaneFirst(subgroup_invocation_id); + const uint _e28 = WaveReadLaneAt(subgroup_invocation_id, 4u); + const uint _e32 = WaveReadLaneAt(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + const uint _e34 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() + 1u); + const uint _e36 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() - 1u); + const uint _e39 = WaveReadLaneAt(subgroup_invocation_id, WaveGetLaneIndex() ^ (subgroup_size - 1u)); + return; +} diff --git a/tests/out/hlsl/subgroup-operations.ron b/tests/out/hlsl/subgroup-operations.ron new file mode 100644 index 0000000000..b973fe3da1 --- /dev/null +++ b/tests/out/hlsl/subgroup-operations.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_6_0", + ), + ], +) diff --git a/tests/out/msl/subgroup-operations.msl b/tests/out/msl/subgroup-operations.msl new file mode 100644 index 0000000000..576fa3b84e --- /dev/null +++ b/tests/out/msl/subgroup-operations.msl @@ -0,0 +1,38 @@ +// language: metal2.4 +#include <metal_stdlib> +#include <simd/simd.h> + +using metal::uint; + + +struct main_Input { +}; +kernel void main_( + uint num_subgroups [[simdgroups_per_threadgroup]] +, uint subgroup_id [[simdgroup_index_in_threadgroup]] +, uint subgroup_size [[threads_per_simdgroup]] +, uint subgroup_invocation_id [[thread_index_in_simdgroup]] +) { + metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup); + metal::uint4 unnamed = uint4((uint64_t)metal::simd_ballot((subgroup_invocation_id & 1u) == 1u), 0, 0, 0); + bool unnamed_1 = metal::simd_all(subgroup_invocation_id != 0u); + bool unnamed_2 = metal::simd_any(subgroup_invocation_id == 0u); + uint unnamed_3 = metal::simd_sum(subgroup_invocation_id); + uint unnamed_4 = metal::simd_product(subgroup_invocation_id); + uint unnamed_5 = 
metal::simd_min(subgroup_invocation_id); + uint unnamed_6 = metal::simd_max(subgroup_invocation_id); + uint unnamed_7 = metal::simd_and(subgroup_invocation_id); + uint unnamed_8 = metal::simd_or(subgroup_invocation_id); + uint unnamed_9 = metal::simd_xor(subgroup_invocation_id); + uint unnamed_10 = metal::simd_prefix_exclusive_sum(subgroup_invocation_id); + uint unnamed_11 = metal::simd_prefix_exclusive_product(subgroup_invocation_id); + uint unnamed_12 = metal::simd_prefix_inclusive_sum(subgroup_invocation_id); + uint unnamed_13 = metal::simd_prefix_inclusive_product(subgroup_invocation_id); + uint unnamed_14 = metal::simd_broadcast_first(subgroup_invocation_id); + uint unnamed_15 = metal::simd_broadcast(subgroup_invocation_id, 4u); + uint unnamed_16 = metal::simd_shuffle(subgroup_invocation_id, (subgroup_size - 1u) - subgroup_invocation_id); + uint unnamed_17 = metal::simd_shuffle_down(subgroup_invocation_id, 1u); + uint unnamed_18 = metal::simd_shuffle_up(subgroup_invocation_id, 1u); + uint unnamed_19 = metal::simd_shuffle_xor(subgroup_invocation_id, subgroup_size - 1u); + return; +} diff --git a/tests/out/spv/subgroup-operations.spvasm b/tests/out/spv/subgroup-operations.spvasm new file mode 100644 index 0000000000..c2023c5473 --- /dev/null +++ b/tests/out/spv/subgroup-operations.spvasm @@ -0,0 +1,73 @@ +; SPIR-V +; Version: 1.3 +; Generator: rspirv +; Bound: 52 +OpCapability Shader +OpCapability GroupNonUniform +OpCapability GroupNonUniformBallot +OpCapability GroupNonUniformVote +OpCapability GroupNonUniformArithmetic +OpCapability GroupNonUniformShuffle +OpCapability GroupNonUniformShuffleRelative +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %15 "main" %6 %9 %11 %13 +OpExecutionMode %15 LocalSize 1 1 1 +OpDecorate %6 BuiltIn NumSubgroups +OpDecorate %9 BuiltIn SubgroupId +OpDecorate %11 BuiltIn SubgroupSize +OpDecorate %13 BuiltIn SubgroupLocalInvocationId +%2 = OpTypeVoid +%3 = OpTypeInt 32 0 +%4 = OpTypeBool 
+%7 = OpTypePointer Input %3 +%6 = OpVariable %7 Input +%9 = OpVariable %7 Input +%11 = OpVariable %7 Input +%13 = OpVariable %7 Input +%16 = OpTypeFunction %2 +%17 = OpConstant %3 1 +%18 = OpConstant %3 0 +%19 = OpConstant %3 4 +%21 = OpConstant %3 3 +%22 = OpConstant %3 2 +%23 = OpConstant %3 8 +%26 = OpTypeVector %3 4 +%15 = OpFunction %2 None %16 +%5 = OpLabel +%8 = OpLoad %3 %6 +%10 = OpLoad %3 %9 +%12 = OpLoad %3 %11 +%14 = OpLoad %3 %13 +OpBranch %20 +%20 = OpLabel +OpControlBarrier %21 %22 %23 +%24 = OpBitwiseAnd %3 %14 %17 +%25 = OpIEqual %4 %24 %17 +%27 = OpGroupNonUniformBallot %26 %21 %25 +%28 = OpINotEqual %4 %14 %18 +%29 = OpGroupNonUniformAll %4 %21 %28 +%30 = OpIEqual %4 %14 %18 +%31 = OpGroupNonUniformAny %4 %21 %30 +%32 = OpGroupNonUniformIAdd %3 %21 Reduce %14 +%33 = OpGroupNonUniformIMul %3 %21 Reduce %14 +%34 = OpGroupNonUniformUMin %3 %21 Reduce %14 +%35 = OpGroupNonUniformUMax %3 %21 Reduce %14 +%36 = OpGroupNonUniformBitwiseAnd %3 %21 Reduce %14 +%37 = OpGroupNonUniformBitwiseOr %3 %21 Reduce %14 +%38 = OpGroupNonUniformBitwiseXor %3 %21 Reduce %14 +%39 = OpGroupNonUniformIAdd %3 %21 ExclusiveScan %14 +%40 = OpGroupNonUniformIMul %3 %21 ExclusiveScan %14 +%41 = OpGroupNonUniformIAdd %3 %21 InclusiveScan %14 +%42 = OpGroupNonUniformIMul %3 %21 InclusiveScan %14 +%43 = OpGroupNonUniformBroadcastFirst %3 %21 %14 +%44 = OpGroupNonUniformBroadcast %3 %21 %14 %19 +%45 = OpISub %3 %12 %17 +%46 = OpISub %3 %45 %14 +%47 = OpGroupNonUniformShuffle %3 %21 %14 %46 +%48 = OpGroupNonUniformShuffleDown %3 %21 %14 %17 +%49 = OpGroupNonUniformShuffleUp %3 %21 %14 %17 +%50 = OpISub %3 %12 %17 +%51 = OpGroupNonUniformShuffleXor %3 %21 %14 %50 +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/subgroup-operations.wgsl b/tests/out/wgsl/subgroup-operations.wgsl new file mode 100644 index 0000000000..f12f226387 --- /dev/null +++ b/tests/out/wgsl/subgroup-operations.wgsl @@ -0,0 +1,26 @@ +@compute @workgroup_size(1, 1, 1) +fn 
main(@builtin(num_subgroups) num_subgroups: u32, @builtin(subgroup_id) subgroup_id: u32, @builtin(subgroup_size) subgroup_size: u32, @builtin(subgroup_invocation_id) subgroup_invocation_id: u32) { + subgroupBarrier(); + let _e8 = subgroupBallot( +((subgroup_invocation_id & 1u) == 1u)); + let _e11 = subgroupAll((subgroup_invocation_id != 0u)); + let _e14 = subgroupAny((subgroup_invocation_id == 0u)); + let _e15 = subgroupAdd(subgroup_invocation_id); + let _e16 = subgroupMul(subgroup_invocation_id); + let _e17 = subgroupMin(subgroup_invocation_id); + let _e18 = subgroupMax(subgroup_invocation_id); + let _e19 = subgroupAnd(subgroup_invocation_id); + let _e20 = subgroupOr(subgroup_invocation_id); + let _e21 = subgroupXor(subgroup_invocation_id); + let _e22 = subgroupPrefixExclusiveAdd(subgroup_invocation_id); + let _e23 = subgroupPrefixExclusiveMul(subgroup_invocation_id); + let _e24 = subgroupPrefixInclusiveAdd(subgroup_invocation_id); + let _e25 = subgroupPrefixInclusiveMul(subgroup_invocation_id); + let _e26 = subgroupBroadcastFirst(subgroup_invocation_id); + let _e28 = subgroupBroadcast(subgroup_invocation_id, 4u); + let _e32 = subgroupShuffle(subgroup_invocation_id, ((subgroup_size - 1u) - subgroup_invocation_id)); + let _e34 = subgroupShuffleDown(subgroup_invocation_id, 1u); + let _e36 = subgroupShuffleUp(subgroup_invocation_id, 1u); + let _e39 = subgroupShuffleXor(subgroup_invocation_id, (subgroup_size - 1u)); + return; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index c3455dd864..c720e2efd1 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -782,6 +782,10 @@ fn convert_wgsl() { Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, ), ("separate-entry-points", Targets::SPIRV | Targets::GLSL), + ( + "subgroup-operations", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ), ]; for &(name, targets) in inputs.iter() {