[XLA:GPU] Fix for "Transform dimension propagation to the functional paradigm"

I think it is now possible that FusionContext::GetPropagatedDimOrdersAndRequirementsIfProfitablyFusible succeeds but context.CombineDimOrdersAndReqs subsequently fails because of a "splittable_dimension_major_part_size" requirement. In that case the old TF_RET_CHECK would return an internal error, so both failures are now handled by simply stopping output fusion.

Also, continue is equivalent to break in this specific context, so I changed it to break.
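
For context, the merged condition relies on C++ short-circuit evaluation: std::get<DimOrdersAndReqs>(result) is only reached when the variant actually holds that alternative, and a CombineDimOrdersAndReqs failure then takes the same exit path as a propagation failure. Below is a minimal self-contained sketch of that pattern; the struct, the variant, and the free CombineDimOrdersAndReqs function are hypothetical stand-ins, not the XLA code.

    #include <iostream>
    #include <string>
    #include <variant>

    // Hypothetical stand-in for XLA's DimOrdersAndReqs.
    struct DimOrdersAndReqs {
      int splittable_dimension_major_part_size;
    };
    // Hypothetical stand-in for the propagation result: either a rejection
    // reason or a usable set of dim orders and requirements.
    using Result = std::variant<std::string, DimOrdersAndReqs>;

    // Hypothetical stand-in for FusionContext::CombineDimOrdersAndReqs:
    // merging requirements can fail, e.g. on a conflicting requirement.
    bool CombineDimOrdersAndReqs(const DimOrdersAndReqs& reqs) {
      return reqs.splittable_dimension_major_part_size == 1;
    }

    int main() {
      Result result = DimOrdersAndReqs{/*splittable_dimension_major_part_size=*/2};
      // || short-circuits: std::get is evaluated only when holds_alternative
      // returned true, so the combined condition never accesses the wrong
      // variant alternative, and either failure stops output fusion.
      if (!std::holds_alternative<DimOrdersAndReqs>(result) ||
          !CombineDimOrdersAndReqs(std::get<DimOrdersAndReqs>(result))) {
        std::cout << "stop fusing outputs\n";  // corresponds to the new break
      }
      return 0;
    }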

PiperOrigin-RevId: 585669068
tdanyluk authored and tensorflower-gardener committed Nov 27, 2023
1 parent a2074a7 commit 417a92f
Showing 1 changed file with 3 additions and 4 deletions.
third_party/xla/xla/service/gpu/gemm_rewriter_triton.cc
@@ -1508,11 +1508,10 @@ StatusOr<FusionDecision> FuseDot(HloInstruction& dot,
         user->operand_index(fusion_output),
         context.dim_orders().at(fusion_output), gpu_version,
         context.hero_properties());
-    if (!std::holds_alternative<DimOrdersAndReqs>(result)) {
-      continue;
+    if (!std::holds_alternative<DimOrdersAndReqs>(result) ||
+        !context.CombineDimOrdersAndReqs(std::get<DimOrdersAndReqs>(result))) {
+      break;
     }
-    TF_RET_CHECK(
-        context.CombineDimOrdersAndReqs(std::get<DimOrdersAndReqs>(result)));
     for (HloInstruction* operand : user->operands()) {
       if (!output_old_to_new_map.contains(operand)) {
         context.TryToFuseWithInputsRecursively(*operand, gpu_version,
