From 4e792e409fd423f7daf4e59944621b3584cf64f6 Mon Sep 17 00:00:00 2001 From: Alan Baker Date: Tue, 15 Oct 2024 13:42:45 -0400 Subject: [PATCH] Fix switch case construct validation Fixes https://crbug.com/tint/372311599 * Stop using block depth in switch validation and instead use the more robust structured exit logic from the switch construct * This is valid because the function has already handled the additional valid cases for case constructs --- source/val/validate_cfg.cpp | 31 +++++++------ test/val/val_cfg_test.cpp | 86 +++++++++++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 13 deletions(-) diff --git a/source/val/validate_cfg.cpp b/source/val/validate_cfg.cpp index 77e4f4f2fd..63e018ae3c 100644 --- a/source/val/validate_cfg.cpp +++ b/source/val/validate_cfg.cpp @@ -468,13 +468,13 @@ std::string ConstructErrorString(const Construct& construct, // headed by |target_block| branches to multiple case constructs. spv_result_t FindCaseFallThrough( ValidationState_t& _, BasicBlock* target_block, uint32_t* case_fall_through, - const BasicBlock* merge, const std::unordered_set& case_targets, - Function* function) { + const Construct& switch_construct, + const std::unordered_set& case_targets) { + const auto* merge = switch_construct.exit_block(); std::vector stack; stack.push_back(target_block); std::unordered_set visited; bool target_reachable = target_block->structurally_reachable(); - int target_depth = function->GetBlockDepth(target_block); while (!stack.empty()) { auto block = stack.back(); stack.pop_back(); @@ -492,9 +492,14 @@ spv_result_t FindCaseFallThrough( } else { // Exiting the case construct to non-merge block. if (!case_targets.count(block->id())) { - int depth = function->GetBlockDepth(block); - if ((depth < target_depth) || - (depth == target_depth && block->is_type(kBlockTypeContinue))) { + // We have already filtered out the following: + // * The switch's merge + // * Other case targets + // * Blocks in the same case construct + // + // So the only remaining valid branches are the structured exits from the + // overall selection construct of the switch. + if (switch_construct.IsStructuredExit(_, block)) { continue; } @@ -526,9 +531,10 @@ spv_result_t FindCaseFallThrough( } spv_result_t StructuredSwitchChecks(ValidationState_t& _, Function* function, - const Instruction* switch_inst, - const BasicBlock* header, - const BasicBlock* merge) { + const Construct& switch_construct) { + const auto* header = switch_construct.entry_block(); + const auto* merge = switch_construct.exit_block(); + const auto* switch_inst = header->terminator(); std::unordered_set case_targets; for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) { uint32_t target = switch_inst->GetOperandAs(i); @@ -546,6 +552,7 @@ spv_result_t StructuredSwitchChecks(ValidationState_t& _, Function* function, break; } } + std::unordered_map seen_to_fall_through; for (uint32_t i = 1; i < switch_inst->operands().size(); i += 2) { uint32_t target = switch_inst->GetOperandAs(i); @@ -566,7 +573,7 @@ spv_result_t StructuredSwitchChecks(ValidationState_t& _, Function* function, } if (auto error = FindCaseFallThrough(_, target_block, &case_fall_through, - merge, case_targets, function)) { + switch_construct, case_targets)) { return error; } @@ -866,9 +873,7 @@ spv_result_t StructuredControlFlowChecks( // Checks rules for case constructs. if (construct.type() == ConstructType::kSelection && header->terminator()->opcode() == spv::Op::OpSwitch) { - const auto terminator = header->terminator(); - if (auto error = - StructuredSwitchChecks(_, function, terminator, header, merge)) { + if (auto error = StructuredSwitchChecks(_, function, construct)) { return error; } } diff --git a/test/val/val_cfg_test.cpp b/test/val/val_cfg_test.cpp index d3b5e90209..ae2e45bb62 100644 --- a/test/val/val_cfg_test.cpp +++ b/test/val/val_cfg_test.cpp @@ -5155,6 +5155,92 @@ TEST_F(ValidateCFG, StructurallyUnreachableContinuePredecessor) { EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); } +TEST_F(ValidateCFG, FullyLoopPrecedingSwitchToContinue) { + const std::string text = R"( + OpCapability Shader + OpMemoryModel Logical GLSL450 + OpEntryPoint Fragment %main "main" + OpExecutionMode %main OriginUpperLeft + OpName %main "main" + %void = OpTypeVoid + %3 = OpTypeFunction %void + %bool = OpTypeBool + %true = OpConstantTrue %bool + %int = OpTypeInt 32 1 + %int_0 = OpConstant %int 0 + %int_1 = OpConstant %int 1 + %main = OpFunction %void None %3 + %4 = OpLabel + OpBranch %7 + %7 = OpLabel + OpLoopMerge %8 %6 None + OpBranch %5 + %5 = OpLabel + OpSelectionMerge %9 None + OpBranchConditional %true %10 %9 + %10 = OpLabel + OpSelectionMerge %16 None + OpSwitch %int_0 %13 + %13 = OpLabel + OpBranch %19 + %19 = OpLabel + OpLoopMerge %20 %18 None + OpBranch %17 + %17 = OpLabel + OpReturn + %18 = OpLabel + OpBranch %19 + %20 = OpLabel + OpSelectionMerge %23 None + OpSwitch %int_1 %21 + %21 = OpLabel + OpBranch %6 + %23 = OpLabel + OpBranch %16 + %16 = OpLabel + OpBranch %9 + %9 = OpLabel + OpBranch %6 + %6 = OpLabel + OpBranch %7 + %8 = OpLabel + OpUnreachable + OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + +TEST_F(ValidateCFG, CaseBreak) { + const std::string text = R"( +OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint Fragment %main "main" +OpExecutionMode %main OriginUpperLeft +OpName %main "main" +%void = OpTypeVoid +%3 = OpTypeFunction %void +%bool = OpTypeBool +%true = OpConstantTrue %bool +%int = OpTypeInt 32 1 +%int_0 = OpConstant %int 0 +%int_1 = OpConstant %int 1 +%main = OpFunction %void None %3 +%4 = OpLabel +OpSelectionMerge %merge None +OpSwitch %int_1 %case 2 %merge +%case = OpLabel +OpBranch %merge +%merge = OpLabel +OpReturn +OpFunctionEnd +)"; + + CompileSuccessfully(text); + EXPECT_EQ(SPV_SUCCESS, ValidateInstructions()); +} + } // namespace } // namespace val } // namespace spvtools