Skip to content

Commit

Permalink
[tint] Skip MergeReturn for entry points
Browse files Browse the repository at this point in the history
The 'MergeReturn' transform is to ensure the convergence of the callee
for the caller. Entry points are not called and therefore do not need
this transform.


Bug:341073176
Change-Id: I2de5d2eebfd1c752c0ed370cb97b7d55ac4c6876
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/203174
Reviewed-by: James Price <[email protected]>
Commit-Queue: Peter McNeeley <[email protected]>
  • Loading branch information
Peter McNeeley authored and Dawn LUCI CQ committed Aug 20, 2024
1 parent bc3246b commit 4c92f4b
Show file tree
Hide file tree
Showing 11 changed files with 1,119 additions and 1,209 deletions.
6 changes: 6 additions & 0 deletions src/tint/lang/spirv/writer/raise/merge_return.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ struct State {
/// Process the function.
/// @param fn the function to process
void Process(core::ir::Function* fn) {
if (fn->Stage() != core::ir::Function::PipelineStage::kUndefined) {
// Entry points are not called and do not require this transformation to ensure
// convergence.
return;
}

// Find all of the nested return instructions in the function.
for (const auto& usage : fn->UsagesUnsorted()) {
if (auto* ret = usage->instruction->As<core::ir::Return>()) {
Expand Down
37 changes: 37 additions & 0 deletions src/tint/lang/spirv/writer/raise/merge_return_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,43 @@ TEST_F(SpirvWriter_MergeReturnTest, IfElse_OneSideReturns) {
EXPECT_EQ(expect, str());
}

TEST_F(SpirvWriter_MergeReturnTest, NoModify_EntryPoint_IfElse_OneSideReturns) {
auto* cond = b.FunctionParam(ty.bool_());
auto* func = b.Function("entrypointfunction", ty.void_(),
core::ir::Function::PipelineStage::kCompute, {{2, 3, 4}});
func->SetParams({cond});
b.Append(func->Block(), [&] {
auto* ifelse = b.If(cond);
b.Append(ifelse->True(), [&] { b.Return(func); });
b.Append(ifelse->False(), [&] { b.ExitIf(ifelse); });

b.Return(func);
});

auto* src = R"(
%entrypointfunction = @compute @workgroup_size(2, 3, 4) func(%2:bool):void {
$B1: {
if %2 [t: $B2, f: $B3] { # if_1
$B2: { # true
ret
}
$B3: { # false
exit_if # if_1
}
}
ret
}
}
)";
EXPECT_EQ(src, str());

auto* expect = src;

Run(MergeReturn);

EXPECT_EQ(expect, str());
}

// This is the same as the above tests, but we create the return instructions in a different order
// to make sure that creation order doesn't matter.
TEST_F(SpirvWriter_MergeReturnTest, IfElse_OneSideReturns_ReturnsCreatedInDifferentOrder) {
Expand Down
195 changes: 91 additions & 104 deletions test/tint/bug/chromium/1273230.wgsl.expected.spvasm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.3
; Generator: Google Tint Compiler; 1
; Bound: 317
; Bound: 310
; Schema: 0
OpCapability Shader
%80 = OpExtInstImport "GLSL.std.450"
Expand Down Expand Up @@ -78,7 +78,6 @@
OpName %g55 "g55"
OpName %main_count_inner "main_count_inner"
OpName %GlobalInvocationID "GlobalInvocationID"
OpName %continue_execution "continue_execution"
OpName %triangleIndex "triangleIndex"
OpName %i0 "i0"
OpName %i1 "i1"
Expand Down Expand Up @@ -205,20 +204,17 @@
%_ptr_Function_int = OpTypePointer Function %int
%215 = OpTypeFunction %void %v3uint
%bool = OpTypeBool
%_ptr_Function_bool = OpTypePointer Function %bool
%true = OpConstantTrue %bool
%false = OpConstantFalse %bool
%float_3 = OpConstant %float 3
%int_1 = OpConstant %int 1
%285 = OpTypeFunction %uint %uint %uint
%299 = OpTypeFunction %v3uint %v3float
%303 = OpConstantNull %v3float
%278 = OpTypeFunction %uint %uint %uint
%292 = OpTypeFunction %v3uint %v3float
%296 = OpConstantNull %v3float
%v3bool = OpTypeVector %bool 3
%306 = OpConstantNull %v3uint
%299 = OpConstantNull %v3uint
%float_4_29496704e_09 = OpConstant %float 4.29496704e+09
%308 = OpConstantComposite %v3float %float_4_29496704e_09 %float_4_29496704e_09 %float_4_29496704e_09
%301 = OpConstantComposite %v3float %float_4_29496704e_09 %float_4_29496704e_09 %float_4_29496704e_09
%uint_4294967295 = OpConstant %uint 4294967295
%311 = OpConstantComposite %v3uint %uint_4294967295 %uint_4294967295 %uint_4294967295
%304 = OpConstantComposite %v3uint %uint_4294967295 %uint_4294967295 %uint_4294967295
%marg8uintin = OpFunction %void None %34
%35 = OpLabel
OpReturn
Expand Down Expand Up @@ -405,7 +401,6 @@
%main_count_inner = OpFunction %void None %215
%GlobalInvocationID = OpFunctionParameter %v3uint
%216 = OpLabel
%continue_execution = OpVariable %_ptr_Function_bool Function
%triangleIndex = OpVariable %_ptr_Function_uint Function
%i0 = OpVariable %_ptr_Function_uint Function
%i1 = OpVariable %_ptr_Function_uint Function
Expand All @@ -417,108 +412,100 @@
%voxelPos_0 = OpVariable %_ptr_Function_v3float Function
%lIndex = OpVariable %_ptr_Function_uint Function
%triangleOffset = OpVariable %_ptr_Function_int Function
OpStore %continue_execution %true
%221 = OpCompositeExtract %uint %GlobalInvocationID 0
OpStore %triangleIndex %221
%223 = OpLoad %uint %triangleIndex
%224 = OpAccessChain %_ptr_Uniform_uint %1 %uint_0 %uint_0
%225 = OpLoad %uint %224
%226 = OpUGreaterThanEqual %bool %223 %225
OpSelectionMerge %227 None
OpBranchConditional %226 %228 %227
%228 = OpLabel
OpStore %continue_execution %false
OpBranch %227
%227 = OpLabel
%230 = OpLoad %bool %continue_execution
OpSelectionMerge %231 None
OpBranchConditional %230 %232 %231
%232 = OpLabel
%233 = OpFunctionCall %void %doIgnore
%234 = OpLoad %uint %triangleIndex
%235 = OpIMul %uint %uint_3 %234
%236 = OpIAdd %uint %235 %uint_0
%237 = OpAccessChain %_ptr_StorageBuffer_uint %indices %uint_0 %236
%238 = OpLoad %uint %237
OpStore %i0 %238
%240 = OpLoad %uint %i0
%241 = OpIMul %uint %uint_3 %240
%242 = OpIAdd %uint %241 %uint_1
%243 = OpAccessChain %_ptr_StorageBuffer_uint %indices %uint_0 %242
%244 = OpLoad %uint %243
OpStore %i1 %244
%246 = OpLoad %uint %i0
%247 = OpIMul %uint %uint_3 %246
%248 = OpIAdd %uint %247 %uint_2
%249 = OpAccessChain %_ptr_StorageBuffer_uint %indices %uint_0 %248
%250 = OpLoad %uint %249
OpStore %i2 %250
%252 = OpLoad %uint %i0
%253 = OpFunctionCall %v3float %loadPosition %252
OpStore %p0 %253
%255 = OpLoad %uint %i0
%256 = OpFunctionCall %v3float %loadPosition %255
OpStore %p1 %256
%258 = OpLoad %uint %i2
%259 = OpFunctionCall %v3float %loadPosition %258
OpStore %p2 %259
%261 = OpLoad %v3float %p0
%262 = OpLoad %v3float %p2
%263 = OpFAdd %v3float %261 %262
%264 = OpLoad %v3float %p1
%265 = OpFAdd %v3float %263 %264
%266 = OpCompositeConstruct %v3float %float_3 %float_3 %float_3
%268 = OpFDiv %v3float %265 %266
OpStore %center %268
%270 = OpLoad %v3float %p1
%271 = OpFunctionCall %v3float %toVoxelPos %270
OpStore %voxelPos_0 %271
%273 = OpAccessChain %_ptr_Uniform_uint %1 %uint_0 %uint_1
%274 = OpLoad %uint %273
%275 = OpLoad %v3float %p0
%276 = OpFunctionCall %uint %toIndex1D %274 %275
OpStore %lIndex %276
%278 = OpLoad %uint %i1
%279 = OpAccessChain %_ptr_StorageBuffer_int %LUT %uint_0 %278
%280 = OpAtomicIAdd %int %279 %uint_1 %uint_0 %int_1
OpStore %triangleOffset %280
OpBranch %231
%231 = OpLabel
%217 = OpCompositeExtract %uint %GlobalInvocationID 0
OpStore %triangleIndex %217
%219 = OpLoad %uint %triangleIndex
%220 = OpAccessChain %_ptr_Uniform_uint %1 %uint_0 %uint_0
%221 = OpLoad %uint %220
%222 = OpUGreaterThanEqual %bool %219 %221
OpSelectionMerge %224 None
OpBranchConditional %222 %225 %224
%225 = OpLabel
OpReturn
%224 = OpLabel
%226 = OpFunctionCall %void %doIgnore
%227 = OpLoad %uint %triangleIndex
%228 = OpIMul %uint %uint_3 %227
%229 = OpIAdd %uint %228 %uint_0
%230 = OpAccessChain %_ptr_StorageBuffer_uint %indices %uint_0 %229
%231 = OpLoad %uint %230
OpStore %i0 %231
%233 = OpLoad %uint %i0
%234 = OpIMul %uint %uint_3 %233
%235 = OpIAdd %uint %234 %uint_1
%236 = OpAccessChain %_ptr_StorageBuffer_uint %indices %uint_0 %235
%237 = OpLoad %uint %236
OpStore %i1 %237
%239 = OpLoad %uint %i0
%240 = OpIMul %uint %uint_3 %239
%241 = OpIAdd %uint %240 %uint_2
%242 = OpAccessChain %_ptr_StorageBuffer_uint %indices %uint_0 %241
%243 = OpLoad %uint %242
OpStore %i2 %243
%245 = OpLoad %uint %i0
%246 = OpFunctionCall %v3float %loadPosition %245
OpStore %p0 %246
%248 = OpLoad %uint %i0
%249 = OpFunctionCall %v3float %loadPosition %248
OpStore %p1 %249
%251 = OpLoad %uint %i2
%252 = OpFunctionCall %v3float %loadPosition %251
OpStore %p2 %252
%254 = OpLoad %v3float %p0
%255 = OpLoad %v3float %p2
%256 = OpFAdd %v3float %254 %255
%257 = OpLoad %v3float %p1
%258 = OpFAdd %v3float %256 %257
%259 = OpCompositeConstruct %v3float %float_3 %float_3 %float_3
%261 = OpFDiv %v3float %258 %259
OpStore %center %261
%263 = OpLoad %v3float %p1
%264 = OpFunctionCall %v3float %toVoxelPos %263
OpStore %voxelPos_0 %264
%266 = OpAccessChain %_ptr_Uniform_uint %1 %uint_0 %uint_1
%267 = OpLoad %uint %266
%268 = OpLoad %v3float %p0
%269 = OpFunctionCall %uint %toIndex1D %267 %268
OpStore %lIndex %269
%271 = OpLoad %uint %i1
%272 = OpAccessChain %_ptr_StorageBuffer_int %LUT %uint_0 %271
%273 = OpAtomicIAdd %int %272 %uint_1 %uint_0 %int_1
OpStore %triangleOffset %273
OpReturn
OpFunctionEnd
%tint_div_u32 = OpFunction %uint None %285
%tint_div_u32 = OpFunction %uint None %278
%lhs = OpFunctionParameter %uint
%rhs = OpFunctionParameter %uint
%286 = OpLabel
%287 = OpIEqual %bool %rhs %uint_0
%288 = OpSelect %uint %287 %uint_1 %rhs
%289 = OpUDiv %uint %lhs %288
OpReturnValue %289
%279 = OpLabel
%280 = OpIEqual %bool %rhs %uint_0
%281 = OpSelect %uint %280 %uint_1 %rhs
%282 = OpUDiv %uint %lhs %281
OpReturnValue %282
OpFunctionEnd
%tint_mod_u32 = OpFunction %uint None %285
%tint_mod_u32 = OpFunction %uint None %278
%lhs_0 = OpFunctionParameter %uint
%rhs_0 = OpFunctionParameter %uint
%292 = OpLabel
%293 = OpIEqual %bool %rhs_0 %uint_0
%294 = OpSelect %uint %293 %uint_1 %rhs_0
%295 = OpUDiv %uint %lhs_0 %294
%296 = OpIMul %uint %295 %294
%297 = OpISub %uint %lhs_0 %296
OpReturnValue %297
%285 = OpLabel
%286 = OpIEqual %bool %rhs_0 %uint_0
%287 = OpSelect %uint %286 %uint_1 %rhs_0
%288 = OpUDiv %uint %lhs_0 %287
%289 = OpIMul %uint %288 %287
%290 = OpISub %uint %lhs_0 %289
OpReturnValue %290
OpFunctionEnd
%tint_v3f32_to_v3u32 = OpFunction %v3uint None %299
%tint_v3f32_to_v3u32 = OpFunction %v3uint None %292
%value = OpFunctionParameter %v3float
%300 = OpLabel
%301 = OpConvertFToU %v3uint %value
%302 = OpFOrdGreaterThanEqual %v3bool %value %303
%305 = OpSelect %v3uint %302 %301 %306
%307 = OpFOrdLessThanEqual %v3bool %value %308
%310 = OpSelect %v3uint %307 %305 %311
OpReturnValue %310
%293 = OpLabel
%294 = OpConvertFToU %v3uint %value
%295 = OpFOrdGreaterThanEqual %v3bool %value %296
%298 = OpSelect %v3uint %295 %294 %299
%300 = OpFOrdLessThanEqual %v3bool %value %301
%303 = OpSelect %v3uint %300 %298 %304
OpReturnValue %303
OpFunctionEnd
%main_count = OpFunction %void None %34
%314 = OpLabel
%315 = OpLoad %v3uint %main_count_global_invocation_id_Input
%316 = OpFunctionCall %void %main_count_inner %315
%307 = OpLabel
%308 = OpLoad %v3uint %main_count_global_invocation_id_Input
%309 = OpFunctionCall %void %main_count_inner %308
OpReturn
OpFunctionEnd
Loading

0 comments on commit 4c92f4b

Please sign in to comment.