From 30b0b98340fca15ec4b0e91571aa34503c1257a7 Mon Sep 17 00:00:00 2001 From: Try Date: Thu, 31 Oct 2024 00:20:00 +0100 Subject: [PATCH] added task shader support --- ...er-basic-lines.msl3.spv14.vk.nocompat.mesh | 33 +++--- ...basic-triangle.msl3.spv14.vk.nocompat.mesh | 32 ++--- .../task-basic.msl3.spv14.vk.nocompat.task | 22 ++++ .../task-func.msl3.spv14.vk.nocompat.task | 23 ++++ ...er-basic-lines.msl3.spv14.vk.nocompat.mesh | 18 +-- ...basic-triangle.msl3.spv14.vk.nocompat.mesh | 24 ++-- .../task-basic.msl3.spv14.vk.nocompat.task | 22 ++++ .../task-func.msl3.spv14.vk.nocompat.task | 37 ++++++ ...er-basic-lines.msl3.spv14.vk.nocompat.mesh | 3 +- ...basic-triangle.msl3.spv14.vk.nocompat.mesh | 3 +- .../task-basic.msl3.spv14.vk.nocompat.task | 22 ++++ .../task-func.msl3.spv14.vk.nocompat.task | 32 +++++ spirv_common.hpp | 3 +- spirv_glsl.cpp | 4 + spirv_msl.cpp | 109 +++++++++++++----- spirv_msl.hpp | 4 +- 16 files changed, 304 insertions(+), 87 deletions(-) create mode 100644 reference/opt/shaders-msl/task/task-basic.msl3.spv14.vk.nocompat.task create mode 100644 reference/opt/shaders-msl/task/task-func.msl3.spv14.vk.nocompat.task create mode 100644 reference/shaders-msl/task/task-basic.msl3.spv14.vk.nocompat.task create mode 100644 reference/shaders-msl/task/task-func.msl3.spv14.vk.nocompat.task create mode 100644 shaders-msl/task/task-basic.msl3.spv14.vk.nocompat.task create mode 100644 shaders-msl/task/task-func.msl3.spv14.vk.nocompat.task diff --git a/reference/opt/shaders-msl/mesh/mesh-shader-basic-lines.msl3.spv14.vk.nocompat.mesh b/reference/opt/shaders-msl/mesh/mesh-shader-basic-lines.msl3.spv14.vk.nocompat.mesh index 9862f16df..5d0d9287e 100644 --- a/reference/opt/shaders-msl/mesh/mesh-shader-basic-lines.msl3.spv14.vk.nocompat.mesh +++ b/reference/opt/shaders-msl/mesh/mesh-shader-basic-lines.msl3.spv14.vk.nocompat.mesh @@ -89,6 +89,13 @@ struct BlockOutPrim float4 b; }; +struct TaskPayload +{ + float a; + float b; + int c; +}; + constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(2u, 3u, 4u); struct spvPerVertex @@ -116,33 +123,33 @@ struct spvPerPrimitive using spvMesh_t = mesh; static inline __attribute__((always_inline)) -void _4(threadgroup spvUnsafeArray& gl_PrimitiveLineIndicesEXT, thread uint& gl_LocalInvocationIndex, threadgroup spvUnsafeArray& gl_MeshPrimitivesEXT, thread uint3& gl_GlobalInvocationID, threadgroup spvUnsafeArray& gl_MeshVerticesEXT, threadgroup spvUnsafeArray& vOut, threadgroup spvUnsafeArray& outputs, threadgroup spvUnsafeArray& vPrim, thread uint3& gl_WorkGroupID, threadgroup spvUnsafeArray& prim_outputs, threadgroup uint2& spvMeshSizes) +void _4(threadgroup spvUnsafeArray& gl_PrimitiveLineIndicesEXT, thread uint& gl_LocalInvocationIndex, threadgroup spvUnsafeArray& gl_MeshPrimitivesEXT, thread uint3& gl_GlobalInvocationID, threadgroup spvUnsafeArray& gl_MeshVerticesEXT, threadgroup spvUnsafeArray& vOut, threadgroup spvUnsafeArray& outputs, threadgroup spvUnsafeArray& vPrim, thread uint3& gl_WorkGroupID, threadgroup spvUnsafeArray& prim_outputs, const object_data TaskPayload& payload, threadgroup uint2& spvMeshSizes) { spvSetMeshOutputsEXT(gl_LocalInvocationIndex, spvMeshSizes, 24u, 22u); - float3 _158 = float3(gl_GlobalInvocationID); - float _159 = _158.x; - gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = float4(_159, _158.yz, 1.0); + float3 _163 = float3(gl_GlobalInvocationID); + float _164 = _163.x; + gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = float4(_164, _163.yz, 1.0); gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_PointSize = 2.0; gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_ClipDistance[0] = 4.0; - vOut[gl_LocalInvocationIndex] = float4(_159, _158.yz, 2.0); + vOut[gl_LocalInvocationIndex] = float4(_164, _163.yz, 2.0); outputs[gl_LocalInvocationIndex].a = float4(5.0); outputs[gl_LocalInvocationIndex].b = float4(6.0); threadgroup_barrier(mem_flags::mem_threadgroup); if (gl_LocalInvocationIndex < 22u) { vPrim[gl_LocalInvocationIndex] = float4(float3(gl_WorkGroupID), 3.0); - prim_outputs[gl_LocalInvocationIndex].a = float4(0.0); - prim_outputs[gl_LocalInvocationIndex].b = float4(1.0); + prim_outputs[gl_LocalInvocationIndex].a = float4(payload.a); + prim_outputs[gl_LocalInvocationIndex].b = float4(payload.b); gl_PrimitiveLineIndicesEXT[gl_LocalInvocationIndex] = uint2(0u, 1u) + uint2(gl_LocalInvocationIndex); - int _206 = int(gl_GlobalInvocationID.x); - gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_PrimitiveID = _206; - gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_Layer = _206 + 1; - gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_ViewportIndex = _206 + 2; + int _217 = int(gl_GlobalInvocationID.x); + gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_PrimitiveID = _217; + gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_Layer = _217 + 1; + gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_ViewportIndex = _217 + 2; gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_CullPrimitiveEXT = short((gl_GlobalInvocationID.x & 1u) != 0u); } } -[[mesh]] void main0(uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], spvMesh_t spvMesh) +[[mesh]] void main0(uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], spvMesh_t spvMesh, const object_data TaskPayload& payload [[payload]]) { threadgroup uint2 spvMeshSizes; threadgroup spvUnsafeArray gl_PrimitiveLineIndicesEXT; @@ -154,7 +161,7 @@ void _4(threadgroup spvUnsafeArray& gl_PrimitiveLineIndicesEXT, threa threadgroup spvUnsafeArray prim_outputs; threadgroup spvUnsafeArray shared_float; if (gl_LocalInvocationIndex == 0) spvMeshSizes.y = 0u; - _4(gl_PrimitiveLineIndicesEXT, gl_LocalInvocationIndex, gl_MeshPrimitivesEXT, gl_GlobalInvocationID, gl_MeshVerticesEXT, vOut, outputs, vPrim, gl_WorkGroupID, prim_outputs, spvMeshSizes); + _4(gl_PrimitiveLineIndicesEXT, gl_LocalInvocationIndex, gl_MeshPrimitivesEXT, gl_GlobalInvocationID, gl_MeshVerticesEXT, vOut, outputs, vPrim, gl_WorkGroupID, prim_outputs, payload, spvMeshSizes); threadgroup_barrier(mem_flags::mem_threadgroup); if (spvMeshSizes.y == 0) { diff --git a/reference/opt/shaders-msl/mesh/mesh-shader-basic-triangle.msl3.spv14.vk.nocompat.mesh b/reference/opt/shaders-msl/mesh/mesh-shader-basic-triangle.msl3.spv14.vk.nocompat.mesh index f0430820e..6e7889e6b 100644 --- a/reference/opt/shaders-msl/mesh/mesh-shader-basic-triangle.msl3.spv14.vk.nocompat.mesh +++ b/reference/opt/shaders-msl/mesh/mesh-shader-basic-triangle.msl3.spv14.vk.nocompat.mesh @@ -81,6 +81,13 @@ struct BlockOutPrim float4 b; }; +struct TaskPayload +{ + float a; + float b; + int c; +}; + struct gl_MeshPerPrimitiveEXT { uint gl_PrimitiveID [[primitive_id]]; @@ -91,13 +98,6 @@ struct gl_MeshPerPrimitiveEXT constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(2u, 3u, 4u); -struct TaskPayload -{ - float a; - float b; - int c; -}; - struct spvPerVertex { float4 gl_Position [[position]]; @@ -123,7 +123,7 @@ struct spvPerPrimitive using spvMesh_t = mesh; static inline __attribute__((always_inline)) -void _4(threadgroup spvUnsafeArray& gl_MeshVerticesEXT, thread uint& gl_LocalInvocationIndex, thread uint3& gl_GlobalInvocationID, threadgroup spvUnsafeArray& vOut, threadgroup spvUnsafeArray& outputs, threadgroup spvUnsafeArray& vPrim, thread uint3& gl_WorkGroupID, threadgroup spvUnsafeArray& prim_outputs, threadgroup spvUnsafeArray& gl_PrimitiveTriangleIndicesEXT, threadgroup spvUnsafeArray& gl_MeshPrimitivesEXT, threadgroup uint2& spvMeshSizes) +void _4(threadgroup spvUnsafeArray& gl_MeshVerticesEXT, thread uint& gl_LocalInvocationIndex, thread uint3& gl_GlobalInvocationID, threadgroup spvUnsafeArray& vOut, threadgroup spvUnsafeArray& outputs, threadgroup spvUnsafeArray& vPrim, thread uint3& gl_WorkGroupID, threadgroup spvUnsafeArray& prim_outputs, const object_data TaskPayload& payload, threadgroup spvUnsafeArray& gl_PrimitiveTriangleIndicesEXT, threadgroup spvUnsafeArray& gl_MeshPrimitivesEXT, threadgroup uint2& spvMeshSizes) { spvSetMeshOutputsEXT(gl_LocalInvocationIndex, spvMeshSizes, 24u, 22u); float3 _27 = float3(gl_GlobalInvocationID); @@ -138,18 +138,18 @@ void _4(threadgroup spvUnsafeArray& gl_MeshVerticesEXT, if (gl_LocalInvocationIndex < 22u) { vPrim[gl_LocalInvocationIndex] = float4(float3(gl_WorkGroupID), 3.0); - prim_outputs[gl_LocalInvocationIndex].a = float4(0.0); - prim_outputs[gl_LocalInvocationIndex].b = float4(1.0); + prim_outputs[gl_LocalInvocationIndex].a = float4(payload.a); + prim_outputs[gl_LocalInvocationIndex].b = float4(payload.b); gl_PrimitiveTriangleIndicesEXT[gl_LocalInvocationIndex] = uint3(0u, 1u, 2u) + uint3(gl_LocalInvocationIndex); - int _116 = int(gl_GlobalInvocationID.x); - gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_PrimitiveID = _116; - gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_Layer = _116 + 1; - gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_ViewportIndex = _116 + 2; + int _123 = int(gl_GlobalInvocationID.x); + gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_PrimitiveID = _123; + gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_Layer = _123 + 1; + gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_ViewportIndex = _123 + 2; gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_CullPrimitiveEXT = short((gl_GlobalInvocationID.x & 1u) != 0u); } } -[[mesh]] void main0(uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], spvMesh_t spvMesh) +[[mesh]] void main0(uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], spvMesh_t spvMesh, const object_data TaskPayload& payload [[payload]]) { threadgroup uint2 spvMeshSizes; threadgroup spvUnsafeArray gl_MeshVerticesEXT; @@ -161,7 +161,7 @@ void _4(threadgroup spvUnsafeArray& gl_MeshVerticesEXT, threadgroup spvUnsafeArray gl_MeshPrimitivesEXT; threadgroup spvUnsafeArray shared_float; if (gl_LocalInvocationIndex == 0) spvMeshSizes.y = 0u; - _4(gl_MeshVerticesEXT, gl_LocalInvocationIndex, gl_GlobalInvocationID, vOut, outputs, vPrim, gl_WorkGroupID, prim_outputs, gl_PrimitiveTriangleIndicesEXT, gl_MeshPrimitivesEXT, spvMeshSizes); + _4(gl_MeshVerticesEXT, gl_LocalInvocationIndex, gl_GlobalInvocationID, vOut, outputs, vPrim, gl_WorkGroupID, prim_outputs, payload, gl_PrimitiveTriangleIndicesEXT, gl_MeshPrimitivesEXT, spvMeshSizes); threadgroup_barrier(mem_flags::mem_threadgroup); if (spvMeshSizes.y == 0) { diff --git a/reference/opt/shaders-msl/task/task-basic.msl3.spv14.vk.nocompat.task b/reference/opt/shaders-msl/task/task-basic.msl3.spv14.vk.nocompat.task new file mode 100644 index 000000000..93c3e40b3 --- /dev/null +++ b/reference/opt/shaders-msl/task/task-basic.msl3.spv14.vk.nocompat.task @@ -0,0 +1,22 @@ +#include +#include + +using namespace metal; + +struct TaskPayload +{ + float a; + float b; + int c; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u); + +[[object]] void main0(mesh_grid_properties spvMgp, object_data TaskPayload& payload [[payload]]) +{ + payload.a = 1.2000000476837158203125; + payload.b = 2.2999999523162841796875; + payload.c = 3; + spvMgp.set_threadgroups_per_grid(uint3(1u, 2u, 3u)); +} + diff --git a/reference/opt/shaders-msl/task/task-func.msl3.spv14.vk.nocompat.task b/reference/opt/shaders-msl/task/task-func.msl3.spv14.vk.nocompat.task new file mode 100644 index 000000000..e942b3a92 --- /dev/null +++ b/reference/opt/shaders-msl/task/task-func.msl3.spv14.vk.nocompat.task @@ -0,0 +1,23 @@ +#include +#include + +using namespace metal; + +struct TaskPayload +{ + float a; + float b; + int c; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u); + +[[object]] void main0(mesh_grid_properties spvMgp, object_data TaskPayload& payload [[payload]]) +{ + payload.a = 1.2000000476837158203125; + payload.b = 2.2999999523162841796875; + payload.c = 3; + threadgroup_barrier(mem_flags::mem_threadgroup); + spvMgp.set_threadgroups_per_grid(uint3(1u, 2u, 3u)); +} + diff --git a/reference/shaders-msl/mesh/mesh-shader-basic-lines.msl3.spv14.vk.nocompat.mesh b/reference/shaders-msl/mesh/mesh-shader-basic-lines.msl3.spv14.vk.nocompat.mesh index d07e267e3..2811b4587 100644 --- a/reference/shaders-msl/mesh/mesh-shader-basic-lines.msl3.spv14.vk.nocompat.mesh +++ b/reference/shaders-msl/mesh/mesh-shader-basic-lines.msl3.spv14.vk.nocompat.mesh @@ -89,8 +89,6 @@ struct BlockOutPrim float4 b; }; -constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(2u, 3u, 4u); - struct TaskPayload { float a; @@ -98,6 +96,8 @@ struct TaskPayload int c; }; +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(2u, 3u, 4u); + struct spvPerVertex { float4 gl_Position [[position]]; @@ -133,7 +133,7 @@ void main3(threadgroup spvUnsafeArray& gl_PrimitiveLineIndicesEXT, th } static inline __attribute__((always_inline)) -void main2(threadgroup spvUnsafeArray& gl_PrimitiveLineIndicesEXT, thread uint& gl_LocalInvocationIndex, threadgroup spvUnsafeArray& gl_MeshPrimitivesEXT, thread uint3& gl_GlobalInvocationID, threadgroup spvUnsafeArray& gl_MeshVerticesEXT, threadgroup spvUnsafeArray& vOut, threadgroup spvUnsafeArray& outputs, threadgroup spvUnsafeArray& vPrim, thread uint3& gl_WorkGroupID, threadgroup spvUnsafeArray& prim_outputs, threadgroup uint2& spvMeshSizes) +void main2(threadgroup spvUnsafeArray& gl_PrimitiveLineIndicesEXT, thread uint& gl_LocalInvocationIndex, threadgroup spvUnsafeArray& gl_MeshPrimitivesEXT, thread uint3& gl_GlobalInvocationID, threadgroup spvUnsafeArray& gl_MeshVerticesEXT, threadgroup spvUnsafeArray& vOut, threadgroup spvUnsafeArray& outputs, threadgroup spvUnsafeArray& vPrim, thread uint3& gl_WorkGroupID, threadgroup spvUnsafeArray& prim_outputs, const object_data TaskPayload& payload, threadgroup uint2& spvMeshSizes) { spvSetMeshOutputsEXT(gl_LocalInvocationIndex, spvMeshSizes, 24u, 22u); gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = float4(float3(gl_GlobalInvocationID), 1.0); @@ -146,19 +146,19 @@ void main2(threadgroup spvUnsafeArray& gl_PrimitiveLineIndicesEXT, th if (gl_LocalInvocationIndex < 22u) { vPrim[gl_LocalInvocationIndex] = float4(float3(gl_WorkGroupID), 3.0); - prim_outputs[gl_LocalInvocationIndex].a = float4(0.0); - prim_outputs[gl_LocalInvocationIndex].b = float4(1.0); + prim_outputs[gl_LocalInvocationIndex].a = float4(payload.a); + prim_outputs[gl_LocalInvocationIndex].b = float4(payload.b); main3(gl_PrimitiveLineIndicesEXT, gl_LocalInvocationIndex, gl_MeshPrimitivesEXT, gl_GlobalInvocationID); } } static inline __attribute__((always_inline)) -void _4(threadgroup spvUnsafeArray& gl_PrimitiveLineIndicesEXT, thread uint& gl_LocalInvocationIndex, threadgroup spvUnsafeArray& gl_MeshPrimitivesEXT, thread uint3& gl_GlobalInvocationID, threadgroup spvUnsafeArray& gl_MeshVerticesEXT, threadgroup spvUnsafeArray& vOut, threadgroup spvUnsafeArray& outputs, threadgroup spvUnsafeArray& vPrim, thread uint3& gl_WorkGroupID, threadgroup spvUnsafeArray& prim_outputs, threadgroup uint2& spvMeshSizes) +void _4(threadgroup spvUnsafeArray& gl_PrimitiveLineIndicesEXT, thread uint& gl_LocalInvocationIndex, threadgroup spvUnsafeArray& gl_MeshPrimitivesEXT, thread uint3& gl_GlobalInvocationID, threadgroup spvUnsafeArray& gl_MeshVerticesEXT, threadgroup spvUnsafeArray& vOut, threadgroup spvUnsafeArray& outputs, threadgroup spvUnsafeArray& vPrim, thread uint3& gl_WorkGroupID, threadgroup spvUnsafeArray& prim_outputs, const object_data TaskPayload& payload, threadgroup uint2& spvMeshSizes) { - main2(gl_PrimitiveLineIndicesEXT, gl_LocalInvocationIndex, gl_MeshPrimitivesEXT, gl_GlobalInvocationID, gl_MeshVerticesEXT, vOut, outputs, vPrim, gl_WorkGroupID, prim_outputs, spvMeshSizes); + main2(gl_PrimitiveLineIndicesEXT, gl_LocalInvocationIndex, gl_MeshPrimitivesEXT, gl_GlobalInvocationID, gl_MeshVerticesEXT, vOut, outputs, vPrim, gl_WorkGroupID, prim_outputs, payload, spvMeshSizes); } -[[mesh]] void main0(uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], spvMesh_t spvMesh) +[[mesh]] void main0(uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], spvMesh_t spvMesh, const object_data TaskPayload& payload [[payload]]) { threadgroup uint2 spvMeshSizes; threadgroup spvUnsafeArray gl_PrimitiveLineIndicesEXT; @@ -170,7 +170,7 @@ void _4(threadgroup spvUnsafeArray& gl_PrimitiveLineIndicesEXT, threa threadgroup spvUnsafeArray prim_outputs; threadgroup spvUnsafeArray shared_float; if (gl_LocalInvocationIndex == 0) spvMeshSizes.y = 0u; - _4(gl_PrimitiveLineIndicesEXT, gl_LocalInvocationIndex, gl_MeshPrimitivesEXT, gl_GlobalInvocationID, gl_MeshVerticesEXT, vOut, outputs, vPrim, gl_WorkGroupID, prim_outputs, spvMeshSizes); + _4(gl_PrimitiveLineIndicesEXT, gl_LocalInvocationIndex, gl_MeshPrimitivesEXT, gl_GlobalInvocationID, gl_MeshVerticesEXT, vOut, outputs, vPrim, gl_WorkGroupID, prim_outputs, payload, spvMeshSizes); threadgroup_barrier(mem_flags::mem_threadgroup); if (spvMeshSizes.y == 0) { diff --git a/reference/shaders-msl/mesh/mesh-shader-basic-triangle.msl3.spv14.vk.nocompat.mesh b/reference/shaders-msl/mesh/mesh-shader-basic-triangle.msl3.spv14.vk.nocompat.mesh index b9dfccf3a..b40dd48ec 100644 --- a/reference/shaders-msl/mesh/mesh-shader-basic-triangle.msl3.spv14.vk.nocompat.mesh +++ b/reference/shaders-msl/mesh/mesh-shader-basic-triangle.msl3.spv14.vk.nocompat.mesh @@ -81,6 +81,13 @@ struct BlockOutPrim float4 b; }; +struct TaskPayload +{ + float a; + float b; + int c; +}; + struct gl_MeshPerPrimitiveEXT { uint gl_PrimitiveID [[primitive_id]]; @@ -91,13 +98,6 @@ struct gl_MeshPerPrimitiveEXT constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(2u, 3u, 4u); -struct TaskPayload -{ - float a; - float b; - int c; -}; - struct spvPerVertex { float4 gl_Position [[position]]; @@ -123,7 +123,7 @@ struct spvPerPrimitive using spvMesh_t = mesh; static inline __attribute__((always_inline)) -void _4(threadgroup spvUnsafeArray& gl_MeshVerticesEXT, thread uint& gl_LocalInvocationIndex, thread uint3& gl_GlobalInvocationID, threadgroup spvUnsafeArray& vOut, threadgroup spvUnsafeArray& outputs, threadgroup spvUnsafeArray& vPrim, thread uint3& gl_WorkGroupID, threadgroup spvUnsafeArray& prim_outputs, threadgroup spvUnsafeArray& gl_PrimitiveTriangleIndicesEXT, threadgroup spvUnsafeArray& gl_MeshPrimitivesEXT, threadgroup uint2& spvMeshSizes) +void _4(threadgroup spvUnsafeArray& gl_MeshVerticesEXT, thread uint& gl_LocalInvocationIndex, thread uint3& gl_GlobalInvocationID, threadgroup spvUnsafeArray& vOut, threadgroup spvUnsafeArray& outputs, threadgroup spvUnsafeArray& vPrim, thread uint3& gl_WorkGroupID, threadgroup spvUnsafeArray& prim_outputs, const object_data TaskPayload& payload, threadgroup spvUnsafeArray& gl_PrimitiveTriangleIndicesEXT, threadgroup spvUnsafeArray& gl_MeshPrimitivesEXT, threadgroup uint2& spvMeshSizes) { spvSetMeshOutputsEXT(gl_LocalInvocationIndex, spvMeshSizes, 24u, 22u); gl_MeshVerticesEXT[gl_LocalInvocationIndex].gl_Position = float4(float3(gl_GlobalInvocationID), 1.0); @@ -136,8 +136,8 @@ void _4(threadgroup spvUnsafeArray& gl_MeshVerticesEXT, if (gl_LocalInvocationIndex < 22u) { vPrim[gl_LocalInvocationIndex] = float4(float3(gl_WorkGroupID), 3.0); - prim_outputs[gl_LocalInvocationIndex].a = float4(0.0); - prim_outputs[gl_LocalInvocationIndex].b = float4(1.0); + prim_outputs[gl_LocalInvocationIndex].a = float4(payload.a); + prim_outputs[gl_LocalInvocationIndex].b = float4(payload.b); gl_PrimitiveTriangleIndicesEXT[gl_LocalInvocationIndex] = uint3(0u, 1u, 2u) + uint3(gl_LocalInvocationIndex); gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_PrimitiveID = int(gl_GlobalInvocationID.x); gl_MeshPrimitivesEXT[gl_LocalInvocationIndex].gl_Layer = int(gl_GlobalInvocationID.x) + 1; @@ -146,7 +146,7 @@ void _4(threadgroup spvUnsafeArray& gl_MeshVerticesEXT, } } -[[mesh]] void main0(uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], spvMesh_t spvMesh) +[[mesh]] void main0(uint gl_LocalInvocationIndex [[thread_index_in_threadgroup]], uint3 gl_GlobalInvocationID [[thread_position_in_grid]], uint3 gl_WorkGroupID [[threadgroup_position_in_grid]], spvMesh_t spvMesh, const object_data TaskPayload& payload [[payload]]) { threadgroup uint2 spvMeshSizes; threadgroup spvUnsafeArray gl_MeshVerticesEXT; @@ -158,7 +158,7 @@ void _4(threadgroup spvUnsafeArray& gl_MeshVerticesEXT, threadgroup spvUnsafeArray gl_MeshPrimitivesEXT; threadgroup spvUnsafeArray shared_float; if (gl_LocalInvocationIndex == 0) spvMeshSizes.y = 0u; - _4(gl_MeshVerticesEXT, gl_LocalInvocationIndex, gl_GlobalInvocationID, vOut, outputs, vPrim, gl_WorkGroupID, prim_outputs, gl_PrimitiveTriangleIndicesEXT, gl_MeshPrimitivesEXT, spvMeshSizes); + _4(gl_MeshVerticesEXT, gl_LocalInvocationIndex, gl_GlobalInvocationID, vOut, outputs, vPrim, gl_WorkGroupID, prim_outputs, payload, gl_PrimitiveTriangleIndicesEXT, gl_MeshPrimitivesEXT, spvMeshSizes); threadgroup_barrier(mem_flags::mem_threadgroup); if (spvMeshSizes.y == 0) { diff --git a/reference/shaders-msl/task/task-basic.msl3.spv14.vk.nocompat.task b/reference/shaders-msl/task/task-basic.msl3.spv14.vk.nocompat.task new file mode 100644 index 000000000..93c3e40b3 --- /dev/null +++ b/reference/shaders-msl/task/task-basic.msl3.spv14.vk.nocompat.task @@ -0,0 +1,22 @@ +#include +#include + +using namespace metal; + +struct TaskPayload +{ + float a; + float b; + int c; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u); + +[[object]] void main0(mesh_grid_properties spvMgp, object_data TaskPayload& payload [[payload]]) +{ + payload.a = 1.2000000476837158203125; + payload.b = 2.2999999523162841796875; + payload.c = 3; + spvMgp.set_threadgroups_per_grid(uint3(1u, 2u, 3u)); +} + diff --git a/reference/shaders-msl/task/task-func.msl3.spv14.vk.nocompat.task b/reference/shaders-msl/task/task-func.msl3.spv14.vk.nocompat.task new file mode 100644 index 000000000..7b7a59641 --- /dev/null +++ b/reference/shaders-msl/task/task-func.msl3.spv14.vk.nocompat.task @@ -0,0 +1,37 @@ +#pragma clang diagnostic ignored "-Wmissing-prototypes" + +#include +#include + +using namespace metal; + +struct TaskPayload +{ + float a; + float b; + int c; +}; + +constant uint3 gl_WorkGroupSize [[maybe_unused]] = uint3(1u); + +static inline __attribute__((always_inline)) +void foo(object_data TaskPayload& payload) +{ + payload.a = 1.2000000476837158203125; + payload.b = 2.2999999523162841796875; + payload.c = 3; +} + +static inline __attribute__((always_inline)) +void boo(thread mesh_grid_properties& spvMgp) +{ + spvMgp.set_threadgroups_per_grid(uint3(1u, 2u, 3u)); +} + +[[object]] void main0(mesh_grid_properties spvMgp, object_data TaskPayload& payload [[payload]]) +{ + foo(payload); + threadgroup_barrier(mem_flags::mem_threadgroup); + boo(spvMgp); +} + diff --git a/shaders-msl/mesh/mesh-shader-basic-lines.msl3.spv14.vk.nocompat.mesh b/shaders-msl/mesh/mesh-shader-basic-lines.msl3.spv14.vk.nocompat.mesh index 003ede481..9a71d7c8e 100644 --- a/shaders-msl/mesh/mesh-shader-basic-lines.msl3.spv14.vk.nocompat.mesh +++ b/shaders-msl/mesh/mesh-shader-basic-lines.msl3.spv14.vk.nocompat.mesh @@ -36,8 +36,7 @@ struct TaskPayload int c; }; -// taskPayloadSharedEXT TaskPayload payload; -const TaskPayload payload = {0.0, 1.0, 2}; +taskPayloadSharedEXT TaskPayload payload; void main3() { diff --git a/shaders-msl/mesh/mesh-shader-basic-triangle.msl3.spv14.vk.nocompat.mesh b/shaders-msl/mesh/mesh-shader-basic-triangle.msl3.spv14.vk.nocompat.mesh index 8b908b366..e86e35515 100644 --- a/shaders-msl/mesh/mesh-shader-basic-triangle.msl3.spv14.vk.nocompat.mesh +++ b/shaders-msl/mesh/mesh-shader-basic-triangle.msl3.spv14.vk.nocompat.mesh @@ -36,8 +36,7 @@ struct TaskPayload int c; }; -// taskPayloadSharedEXT TaskPayload payload; -const TaskPayload payload = {0.0, 1.0, 2}; +taskPayloadSharedEXT TaskPayload payload; void main() { diff --git a/shaders-msl/task/task-basic.msl3.spv14.vk.nocompat.task b/shaders-msl/task/task-basic.msl3.spv14.vk.nocompat.task new file mode 100644 index 000000000..6422977fc --- /dev/null +++ b/shaders-msl/task/task-basic.msl3.spv14.vk.nocompat.task @@ -0,0 +1,22 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +layout(local_size_x = 1) in; + +struct TaskPayload +{ + float a; + float b; + int c; +}; + +taskPayloadSharedEXT TaskPayload payload; + +void main() +{ + payload.a = 1.2; + payload.b = 2.3; + payload.c = 3; + + EmitMeshTasksEXT(1, 2, 3); +} diff --git a/shaders-msl/task/task-func.msl3.spv14.vk.nocompat.task b/shaders-msl/task/task-func.msl3.spv14.vk.nocompat.task new file mode 100644 index 000000000..e750af2a4 --- /dev/null +++ b/shaders-msl/task/task-func.msl3.spv14.vk.nocompat.task @@ -0,0 +1,32 @@ +#version 450 +#extension GL_EXT_mesh_shader : require + +layout(local_size_x = 1) in; + +struct TaskPayload +{ + float a; + float b; + int c; +}; + +taskPayloadSharedEXT TaskPayload payload; + +void foo() +{ + payload.a = 1.2; + payload.b = 2.3; + payload.c = 3; +} + +void boo() +{ + EmitMeshTasksEXT(1, 2, 3); +} + +void main() +{ + foo(); + barrier(); + boo(); +} diff --git a/spirv_common.hpp b/spirv_common.hpp index 7149b0755..b599e0335 100644 --- a/spirv_common.hpp +++ b/spirv_common.hpp @@ -579,7 +579,8 @@ struct SPIRType : IVariant ControlPointArray, Interpolant, Char, - Meshlet + // MSL specific type, that is used by 'object'(analog of 'task' from glsl) shader. + MeshGridProperties }; // Scalar/vector/matrix support. diff --git a/spirv_glsl.cpp b/spirv_glsl.cpp index 5c86cd402..3df175171 100644 --- a/spirv_glsl.cpp +++ b/spirv_glsl.cpp @@ -16499,6 +16499,10 @@ void CompilerGLSL::emit_function(SPIRFunction &func, const Bitset &return_flags) { auto &var = get(v); var.deferred_declaration = false; + if (var.storage == StorageClassTaskPayloadWorkgroupEXT) + { + continue; + } if (variable_decl_is_remapped_storage(var, StorageClassWorkgroup)) { diff --git a/spirv_msl.cpp b/spirv_msl.cpp index cf5265ac8..43f0caa2a 100644 --- a/spirv_msl.cpp +++ b/spirv_msl.cpp @@ -1086,6 +1086,31 @@ void CompilerMSL::build_implicit_builtins() set_name(var_id, "spvMeshSizes"); builtin_mesh_sizes_id = var_id; } + + if (get_execution_model() == spv::ExecutionModelTaskEXT) + { + uint32_t offset = ir.increase_bound_by(3); + uint32_t type_id = offset; + uint32_t type_ptr_id = offset + 1; + uint32_t var_id = offset + 2; + + SPIRType mesh_grid_type { OpTypeStruct }; + mesh_grid_type.basetype = SPIRType::MeshGridProperties; + set(type_id, mesh_grid_type); + + SPIRType mesh_grid_type_ptr = mesh_grid_type; + mesh_grid_type_ptr.op = spv::OpTypePointer; + mesh_grid_type_ptr.pointer = true; + mesh_grid_type_ptr.pointer_depth++; + mesh_grid_type_ptr.parent_type = type_id; + mesh_grid_type_ptr.storage = StorageClassOutput; + + auto &ptr_in_type = set(type_ptr_id, mesh_grid_type_ptr); + ptr_in_type.self = type_id; + set(var_id, type_ptr_id, StorageClassOutput); + set_name(var_id, "spvMgp"); + builtin_task_grid_id = var_id; + } } // Checks if the specified builtin variable (e.g. gl_InstanceIndex) is marked as active. @@ -1237,35 +1262,6 @@ uint32_t CompilerMSL::get_uint_type_id() return uint_type_id; } -uint32_t CompilerMSL::get_shared_uint_type_id() -{ - if (shared_uint_type_id != 0) - return shared_uint_type_id; - - shared_uint_type_id = ir.increase_bound_by(1); - - SPIRType type { OpTypeInt }; - type.basetype = SPIRType::UInt; - type.width = 32; - type.storage = spv::StorageClassWorkgroup; - set(shared_uint_type_id, type); - return shared_uint_type_id; -} - -uint32_t CompilerMSL::get_meshlet_type_id() -{ - if (meshlet_type_id != 0) - return meshlet_type_id; - - meshlet_type_id = ir.increase_bound_by(1); - - SPIRType type { OpTypeStruct }; - type.basetype = SPIRType::Meshlet; - // type.storage = StorageClassWorkgroup; // threadgroup is not alowed with mesh<> - set(meshlet_type_id, type); - return meshlet_type_id; -} - void CompilerMSL::emit_entry_point_declarations() { // FIXME: Get test coverage here ... @@ -1848,7 +1844,8 @@ void CompilerMSL::localize_global_variables() { uint32_t v_id = *iter; auto &var = get(v_id); - if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup) + if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup || + var.storage == StorageClassTaskPayloadWorkgroupEXT) { if (!variable_is_lut(var)) entry_func.add_local_variable(v_id); @@ -2217,6 +2214,12 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std:: // We should consider a more unified system here to reduce boiler-plate. // This kind of analysis is done in several places ... } + + if (b.terminator == SPIRBlock::EmitMeshTasks) + { + if (builtin_task_grid_id != 0) + added_arg_ids.insert(builtin_task_grid_id); + } } function_global_vars[func_id] = added_arg_ids; @@ -11210,6 +11213,21 @@ void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &) if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression) set(ed_var.initializer, "{}", ed_var.basetype, true); } + + // add `taskPayloadSharedEXT` variable to entry-point arguments + for (auto &v : func.local_variables) + { + auto &var = get(v); + if (var.storage != StorageClassTaskPayloadWorkgroupEXT) + continue; + + add_local_variable_name(v); + SPIRFunction::Parameter arg = {}; + arg.id = v; + arg.type = var.basetype; + arg.alias_global_variable = true; + decl += join(", ", argument_decl(arg), " [[payload]]"); + } } for (auto &arg : func.arguments) @@ -13336,6 +13354,9 @@ string CompilerMSL::func_type_decl(SPIRType &type) case ExecutionModelMeshEXT: entry_type = "[[mesh]]"; break; + case ExecutionModelTaskEXT: + entry_type = "[[object]]"; + break; default: entry_type = "unknown"; break; @@ -13479,6 +13500,13 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo addr_space = "threadgroup"; break; + case StorageClassTaskPayloadWorkgroupEXT: + if (is_mesh_shader()) + addr_space = "const object_data"; + else + addr_space = "object_data"; + break; + default: break; } @@ -13885,6 +13913,13 @@ void CompilerMSL::entry_point_args_builtin(string &ep_args) ep_args += ", "; ep_args += join("spvMesh_t spvMesh"); } + + if (get_execution_model() == ExecutionModelTaskEXT) + { + if (!ep_args.empty()) + ep_args += ", "; + ep_args += join("mesh_grid_properties spvMgp"); + } } string CompilerMSL::entry_point_args_argument_buffer(bool append_comma) @@ -15769,6 +15804,9 @@ string CompilerMSL::to_qualifiers_glsl(uint32_t id) auto *var = maybe_get(id); auto &type = expression_type(id); + if (type.storage == StorageClassTaskPayloadWorkgroupEXT) + quals += "object_data "; + if (type.storage == StorageClassWorkgroup || (var && variable_decl_is_remapped_storage(*var, StorageClassWorkgroup))) quals += "threadgroup "; @@ -15953,6 +15991,8 @@ string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id, bool member) break; case SPIRType::RayQuery: return "raytracing::intersection_query"; + case SPIRType::MeshGridProperties: + return "mesh_grid_properties"; default: return "unknown_type"; @@ -19307,6 +19347,15 @@ void CompilerMSL::emit_mesh_outputs() } } +void CompilerMSL::emit_mesh_tasks(SPIRBlock &block) +{ + // GLSL: Once this instruction is called, the workgroup must be terminated immediately, and the mesh shaders are launched. + // TODO: find relieble and clean of terminating shader. + flush_variable_declaration(builtin_task_grid_id); + statement("spvMgp.set_threadgroups_per_grid(uint3(", to_unpacked_expression(block.mesh.groups[0]), ", ", + to_unpacked_expression(block.mesh.groups[1]), ", ", to_unpacked_expression(block.mesh.groups[2]), "));"); +} + string CompilerMSL::additional_fixed_sample_mask_str() const { char print_buffer[32]; diff --git a/spirv_msl.hpp b/spirv_msl.hpp index 17f249e85..4aaad01a8 100644 --- a/spirv_msl.hpp +++ b/spirv_msl.hpp @@ -871,6 +871,7 @@ class CompilerMSL : public CompilerGLSL void emit_block_hints(const SPIRBlock &block) override; void emit_mesh_entry_point(); void emit_mesh_outputs(); + void emit_mesh_tasks(SPIRBlock &block) override; // Allow Metal to use the array template to make arrays a value type std::string type_to_array_glsl(const SPIRType &type, uint32_t variable_id) override; @@ -1077,8 +1078,6 @@ class CompilerMSL : public CompilerGLSL std::string get_tess_factor_struct_name(); SPIRType &get_uint_type(); uint32_t get_uint_type_id(); - uint32_t get_shared_uint_type_id(); - uint32_t get_meshlet_type_id(); void emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, spv::Op opcode, uint32_t mem_order_1, uint32_t mem_order_2, bool has_mem_order_2, uint32_t op0, uint32_t op1 = 0, bool op1_is_pointer = false, bool op1_is_literal = false, uint32_t op2 = 0); @@ -1113,6 +1112,7 @@ class CompilerMSL : public CompilerGLSL uint32_t builtin_workgroup_size_id = 0; uint32_t builtin_mesh_primitive_indices_id = 0; uint32_t builtin_mesh_sizes_id = 0; + uint32_t builtin_task_grid_id = 0; uint32_t builtin_frag_depth_id = 0; uint32_t swizzle_buffer_id = 0; uint32_t buffer_size_buffer_id = 0;