[XLA:CPU] Enforce a major-to-minor layout constraint on the TopK custom call.

The emitter depends on this layout, but layout assignment doesn't enforce it. This bug was revealed by a change adding an AllGather op that enforced a different layout constraint on one such TopK operator.

PiperOrigin-RevId: 586697397
hawkinsp authored and tensorflower-gardener committed Nov 30, 2023
1 parent fb55ba3 commit 297e7ec
Showing 2 changed files with 23 additions and 10 deletions.
22 changes: 16 additions & 6 deletions third_party/xla/xla/service/cpu/cpu_layout_assignment.cc
@@ -15,7 +15,10 @@ limitations under the License.
 
 #include "xla/service/cpu/cpu_layout_assignment.h"
 
+#include <cstdint>
 #include <numeric>
+#include <optional>
+#include <vector>
 
 #include "absl/container/flat_hash_map.h"
 #include "xla/map_util.h"
@@ -78,12 +81,17 @@ static optional<int64_t> ShouldMakeOperandColumnMajor(
   return it->second ? operand_idx : nullopt;
 }
 
-static Shape RowMajorShape(const Shape& old_shape) {
-  Shape new_shape(old_shape);
-  std::vector<int64_t> dimension_order(new_shape.dimensions_size());
-  std::iota(dimension_order.rbegin(), dimension_order.rend(), 0);
-  *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order);
-  return new_shape;
+static Shape RowMajorShape(Shape shape) {
+  ShapeUtil::ForEachMutableSubshape(
+      &shape, [](Shape* subshape, const ShapeIndex& index) {
+        if (!subshape->IsArray()) {
+          return;
+        }
+        std::vector<int64_t> dimension_order(subshape->dimensions_size());
+        std::iota(dimension_order.rbegin(), dimension_order.rend(), 0);
+        *subshape->mutable_layout() = LayoutUtil::MakeLayout(dimension_order);
+      });
+  return shape;
 }
 
 static Shape ColMajorShape(const Shape& old_shape) {
@@ -103,6 +111,8 @@ static bool OperandsAndResultMustHaveRowMajorLayout(
   } else if (instr.opcode() == HloOpcode::kDot) {
     return DotOperandsAndResultMustHaveRowMajorLayout(instr,
                                                       target_machine_features);
+  } else if (instr.opcode() == HloOpcode::kCustomCall) {
+    return instr.custom_call_target() == "TopK";
   }
   return false;
 }
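For context on the constraint being enforced: in XLA a layout is stored as a minor-to-major list of dimension numbers, and the row-major ("major-to-minor") layout the TopK emitter expects is the descending list {rank-1, ..., 1, 0}, i.e. dimension 0 is the most major. Below is a minimal standalone sketch in plain C++ (no XLA headers; the helper name and the main routine are illustrative, not XLA API) of the descending-iota trick the new RowMajorShape uses, printing the order it would assign to each rank-2 array in a batched TopK result.

// Standalone sketch: reproduces the descending-iota trick from RowMajorShape
// above.  Filling a vector back-to-front with 0, 1, 2, ... yields the
// minor-to-major order {rank-1, ..., 1, 0}, i.e. a row-major layout with
// dimension 0 most major.
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

std::vector<int64_t> RowMajorMinorToMajor(int64_t rank) {
  std::vector<int64_t> order(rank);
  std::iota(order.rbegin(), order.rend(), 0);  // order = {rank-1, ..., 1, 0}
  return order;
}

int main() {
  // A batched TopK result is a tuple of two rank-2 arrays:
  // (values[batch, k], indices[batch, k]).  ForEachMutableSubshape in the
  // new RowMajorShape visits both array subshapes and assigns each the
  // order printed here.
  for (int64_t dim : RowMajorMinorToMajor(/*rank=*/2)) {
    std::cout << dim << ' ';
  }
  std::cout << '\n';  // prints "1 0"
  return 0;
}

The old helper set a layout only on the top-level shape, which works for plain arrays but not for a tuple-shaped result such as TopK's (values, indices) pair; presumably that is why the helper now walks every array subshape.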
11 changes: 7 additions & 4 deletions third_party/xla/xla/service/cpu/ir_emitter.cc
@@ -2378,13 +2378,16 @@ Status IrEmitter::HandleTopK(HloInstruction* hlo) {
   const HloInstruction* input = hlo->operand(0);
   const int64_t k = hlo->shape().tuple_shapes(0).dimensions().back();
   const bool has_batch = hlo->shape().tuple_shapes(0).dimensions_size() == 2;
-  TF_RET_CHECK(input->shape().element_type() == F32);
+  TF_RET_CHECK(input->shape().element_type() == F32) << hlo->ToString();
   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
-      hlo->shape().tuple_shapes(0).layout()));
+      hlo->shape().tuple_shapes(0).layout()))
+      << hlo->ToString();
   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
-      hlo->shape().tuple_shapes(1).layout()));
+      hlo->shape().tuple_shapes(1).layout()))
+      << hlo->ToString();
   TF_RET_CHECK(
-      LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout()));
+      LayoutUtil::IsMonotonicWithDim0Major(hlo->operand(0)->shape().layout()))
+      << hlo->ToString();
 
   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice values_slice,
                       assignment_.GetUniqueSlice(hlo->operand(0), {}));
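The ir_emitter.cc change only adds diagnostic context: each TF_RET_CHECK now streams hlo->ToString() into the error it returns, so a layout violation reports which instruction tripped it. Below is a self-contained sketch of that pattern; Status, StatusBuilder, and RET_CHECK here are invented stand-ins for illustration, not the real TSL/XLA definitions.

// Hypothetical stand-in (not the real TF_RET_CHECK): a failed check returns
// a Status, and the call site can stream extra context such as the HLO's
// string form.
#include <sstream>
#include <string>

struct Status {
  bool ok = true;
  std::string message;
};

class StatusBuilder {
 public:
  explicit StatusBuilder(const std::string& msg) { stream_ << msg; }
  template <typename T>
  StatusBuilder& operator<<(const T& value) {
    stream_ << value;
    return *this;
  }
  operator Status() const { return Status{false, stream_.str()}; }

 private:
  std::ostringstream stream_;
};

// On failure, return a Status built from the stringified condition plus
// whatever the call site streams after the macro; on success, fall through.
#define RET_CHECK(cond) \
  if (cond) {           \
  } else                \
    return StatusBuilder("RET_CHECK failed: " #cond " ")

// Usage mirroring the change above: append the instruction's text so the
// error message says which op had the wrong layout.
Status CheckRowMajor(bool layout_is_row_major, const std::string& hlo_string) {
  RET_CHECK(layout_is_row_major) << hlo_string;
  return Status{};
}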
