From 297e7ec123aa397ac55f7330b42b091258d5ce1f Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins@google.com>
Date: Thu, 30 Nov 2023 09:24:58 -0800
Subject: [PATCH] [XLA:CPU] Enforce a major-to-minor layout constraint on the
 TopK custom call.

The emitter depends on this layout, but layout assignment doesn't
enforce it.

This bug was revealed by a change adding an AllGather op that enforced
a different layout constraint on one such TopK operator.

PiperOrigin-RevId: 586697397
---
 .../xla/service/cpu/cpu_layout_assignment.cc  | 22 ++++++++++++++-----
 third_party/xla/xla/service/cpu/ir_emitter.cc | 11 ++++++----
 2 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc
index 8b124ddaa60397..dbfda4a5362a68 100644
--- a/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc
+++ b/third_party/xla/xla/service/cpu/cpu_layout_assignment.cc
@@ -15,7 +15,10 @@ limitations under the License.
 
 #include "xla/service/cpu/cpu_layout_assignment.h"
 
+#include <cstdint>
 #include <numeric>
+#include <optional>
+#include <vector>
 
 #include "absl/container/flat_hash_map.h"
 #include "xla/map_util.h"
@@ -78,12 +81,17 @@ static optional<int64_t> ShouldMakeOperandColumnMajor(
   return it->second ? operand_idx : nullopt;
 }
 
-static Shape RowMajorShape(const Shape& old_shape) {
-  Shape new_shape(old_shape);
-  std::vector<int64_t> dimension_order(new_shape.dimensions_size());
-  std::iota(dimension_order.rbegin(), dimension_order.rend(), 0);
-  *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order);
-  return new_shape;
+static Shape RowMajorShape(Shape shape) {
+  ShapeUtil::ForEachMutableSubshape(
+      &shape, [](Shape* subshape, const ShapeIndex& index) {
+        if (!subshape->IsArray()) {
+          return;
+        }
+        std::vector<int64_t> dimension_order(subshape->dimensions_size());
+        std::iota(dimension_order.rbegin(), dimension_order.rend(), 0);
+        *subshape->mutable_layout() = LayoutUtil::MakeLayout(dimension_order);
+      });
+  return shape;
 }
 
 static Shape ColMajorShape(const Shape& old_shape) {
@@ -103,6 +111,8 @@
   } else if (instr.opcode() == HloOpcode::kDot) {
     return DotOperandsAndResultMustHaveRowMajorLayout(instr,
                                                       target_machine_features);
+  } else if (instr.opcode() == HloOpcode::kCustomCall) {
+    return instr.custom_call_target() == "TopK";
   }
   return false;
 }
diff --git a/third_party/xla/xla/service/cpu/ir_emitter.cc b/third_party/xla/xla/service/cpu/ir_emitter.cc
index 18523dec844113..15446b21998770 100644
--- a/third_party/xla/xla/service/cpu/ir_emitter.cc
+++ b/third_party/xla/xla/service/cpu/ir_emitter.cc
@@ -2378,13 +2378,16 @@ Status IrEmitter::HandleTopK(HloInstruction* hlo) {
   const HloInstruction* input = hlo->operand(0);
   const int64_t k = hlo->shape().tuple_shapes(0).dimensions().back();
   const bool has_batch = hlo->shape().tuple_shapes(0).dimensions_size() == 2;
-  TF_RET_CHECK(input->shape().element_type() == F32);
+  TF_RET_CHECK(input->shape().element_type() == F32) << hlo->ToString();
   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
-      hlo->shape().tuple_shapes(0).layout()));
+      hlo->shape().tuple_shapes(0).layout()))
+      << hlo->ToString();
   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
-      hlo->shape().tuple_shapes(1).layout()));
+      hlo->shape().tuple_shapes(1).layout()))
+      << hlo->ToString();
   TF_RET_CHECK(
-      LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout()));
+      LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout()))
+      << hlo->ToString();
 
   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice,
                       assignment_.GetUniqueSlice(hlo->operand(0), {}));
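A note on the layout convention, since it is easy to invert:
LayoutUtil::MakeLayout takes a minor-to-major dimension order, so the
std::iota-over-reverse-iterators idiom in RowMajorShape produces
{rank-1, ..., 1, 0}, i.e. dimension 0 is major-most. The following is a
minimal standalone sketch of that index math using only the C++ standard
library; RowMajorMinorToMajor is a hypothetical name chosen for
illustration, not an XLA API.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Hypothetical helper (not an XLA API): the minor-to-major order that a
// row-major layout of the given rank uses.
std::vector<int64_t> RowMajorMinorToMajor(int64_t rank) {
  std::vector<int64_t> dimension_order(rank);
  // Filling through reverse iterators with 0, 1, 2, ... leaves the vector
  // as {rank-1, ..., 1, 0}: the last dimension is minor-most.
  std::iota(dimension_order.rbegin(), dimension_order.rend(), 0);
  return dimension_order;
}

int main() {
  // Prints "2 1 0": the minor-to-major order of a row-major rank-3 array,
  // matching what RowMajorShape assigns to each array subshape.
  for (int64_t dim : RowMajorMinorToMajor(3)) std::cout << dim << ' ';
  std::cout << '\n';
  return 0;
}

Walking every array subshape with ShapeUtil::ForEachMutableSubshape is
what lets the new RowMajorShape handle TopK's (values, indices) result
tuple; the old version set a layout only on the top-level shape, which is
not meaningful for a tuple.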