From 417a92f664750c16ef49931b9e3e06527b0f65bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Mon, 27 Nov 2023 09:03:01 -0800 Subject: [PATCH] [XLA:GPU] Fix for "Transform dimension propagation to the functional paradigm" I think that now it's perhaps possible that FusionContext::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible succeeds, but context.CombineDimOrdersAndReqs fails because of a "splittable_dimension_major_part_size" requirement. Also continue is the same as break in this specific context, so I changed it to break. PiperOrigin-RevId: 585669068 --- third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc index 5f4a28605c6b89..573ea2cdc4d999 100644 --- a/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc +++ b/third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc @@ -1508,11 +1508,10 @@ StatusOr FuseDot(HloInstruction& dot, user->operand_index(fusion_output), context.dim_orders().at(fusion_output), gpu_version, context.hero_properties()); - if (!std::holds_alternative<DimOrdersAndReqs>(result)) { - continue; + if (!std::holds_alternative<DimOrdersAndReqs>(result) || + !context.CombineDimOrdersAndReqs(std::get<DimOrdersAndReqs>(result))) { + break; } - TF_RET_CHECK( - context.CombineDimOrdersAndReqs(std::get<DimOrdersAndReqs>(result))); for (HloInstruction* operand : user->operands()) { if (!output_old_to_new_map.contains(operand)) { context.TryToFuseWithInputsRecursively(*operand, gpu_version,