Revert cl/561479066
qihqi committed Oct 10, 2023
1 parent 67d87bd · commit 28cadf5
Showing 2 changed files with 194 additions and 0 deletions.
WORKSPACE: 10 additions & 0 deletions
@@ -40,13 +40,23 @@ http_archive(
     patches = [
         "//openxla_patches:cache_urls.diff",
         "//openxla_patches:constexpr_return.diff",
+<<<<<<< Updated upstream
         "//openxla_patches:gpu_build_file.diff",
         "//openxla_patches:gpu_race_condition.diff",
         "//openxla_patches:f16_abi_clang.diff",
     ],
     strip_prefix = "xla-7a19856d74569fd1f765cd03bdee84e3b1fdc579",
     urls = [
         "https://github.com/openxla/xla/archive/7a19856d74569fd1f765cd03bdee84e3b1fdc579.tar.gz",
+=======
+        "//openxla_patches:pjrt_api_tsl_logging.diff",
+        "//openxla_patches:pjrt_c_api_dynamic_dimensions.diff",
+        "//openxla_patches:gpu_topk_rewriter.diff",
+    ],
+    strip_prefix = "xla-51b59cfb1999c6f1b3ec59851675044b2c502aae",
+    urls = [
+        "https://github.com/openxla/xla/archive/51b59cfb1999c6f1b3ec59851675044b2c502aae.tar.gz",
+>>>>>>> Stashed changes
     ],
 )
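
Note that, as committed, this hunk adds literal Git conflict markers (<<<<<<<, =======, >>>>>>>) to WORKSPACE, which Starlark cannot parse, so any bazel invocation against this revision should fail while loading the file. A sketch of the presumably intended resolution, keeping the "Stashed changes" side since it is the one that references the gpu_topk_rewriter.diff added by this same commit (the name = "xla" line is an assumption; that part of the rule is collapsed out of the hunk above):

http_archive(
    name = "xla",  # assumed; this line sits outside the hunk shown above
    patches = [
        "//openxla_patches:cache_urls.diff",
        "//openxla_patches:constexpr_return.diff",
        "//openxla_patches:pjrt_api_tsl_logging.diff",
        "//openxla_patches:pjrt_c_api_dynamic_dimensions.diff",
        "//openxla_patches:gpu_topk_rewriter.diff",
    ],
    strip_prefix = "xla-51b59cfb1999c6f1b3ec59851675044b2c502aae",
    urls = [
        "https://github.com/openxla/xla/archive/51b59cfb1999c6f1b3ec59851675044b2c502aae.tar.gz",
    ],
)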

openxla_patches/gpu_topk_rewriter.diff: 184 additions & 0 deletions (new file)
@@ -0,0 +1,184 @@
diff --git a/xla/service/topk_rewriter.cc b/xla/service/topk_rewriter.cc
index da872d962..1b7141055 100644
--- a/xla/service/topk_rewriter.cc
+++ b/xla/service/topk_rewriter.cc
@@ -196,6 +196,8 @@ std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) {
return std::nullopt;
}
const int64_t sort_dim = sort->sort_dimension();
+ const int64_t batch_dim = sort_dim == 1 ? 0 : 1;
+ const bool has_batch = data->shape().rank() == 2;

bool supported = true;
std::optional<int64_t> k;
@@ -220,15 +222,10 @@ std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) {
supported = false;
break;
}
- for (int64_t i = 0; i < slice->slice_limits().size(); ++i) {
- if (i != sort_dim &&
- slice->slice_limits(i) != slice->operand(0)->shape().dimensions(i)) {
- // Slicing along a non-sort dimension isn't supported.
- supported = false;
- break;
- }
- }
- if (!supported) {
+ if (has_batch && slice->slice_limits(batch_dim) !=
+ slice->operand(0)->shape().dimensions(batch_dim)) {
+ // Slicing along the batch dimension isn't supported.
+ supported = false;
break;
}
if (k == std::nullopt) {
@@ -260,57 +257,29 @@ StatusOr<bool> TopkRewriter::TransformToCustomCall(
HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
HloInstruction* data = sort->mutable_operand(0);
const PrimitiveType element_type = data->shape().element_type();
- const Shape data_shape = data->shape();

- if (element_type != F32 && element_type != BF16) {
+ if ((data->shape().rank() != 1 && data->shape().rank() != 2) ||
+ (element_type != F32 && element_type != BF16)) {
continue;
}

- // Sort dimension must be the first or last dimension.
const int64_t sort_dim = sort->sort_dimension();
- if (sort_dim != 0 && sort_dim != data_shape.rank() - 1) {
- continue;
- }
+ const int64_t batch_dim = sort_dim == 1 ? 0 : 1;
+ const bool has_batch = data->shape().rank() == 2;

// Profitability check.
if (!is_profitable_to_convert_(sort, *k)) {
continue;
}

- HloInstruction* input = data;
- const bool has_batch = data_shape.rank() >= 2;
- const int64_t input_size = data_shape.dimensions(sort_dim);
- int64_t batch_size = 1;
- Shape topk_input_shape;
-
- if (has_batch) {
- // The TopK custom call expects either a 1d tensor or a 2d tensor with
- // the last dimension being the sort dimension. An input with rank > 2
- // is reshaped into a 2d tensor by combining non-sort dimensions into a
- // single batch dimension. The original non-sort dimensions are
- // restored for the outputs with another reshape after the custom call.
- batch_size =
- ShapeUtil::ElementsIn(data_shape) / data_shape.dimensions(sort_dim);
- topk_input_shape =
- ShapeUtil::MakeShape(element_type, {batch_size, input_size});
-
- if (data_shape.rank() > 2) {
- // Reshape to 2d.
- input = comp->AddInstruction(HloInstruction::CreateReshape(
- sort_dim == 0
- ? ShapeUtil::MakeShape(element_type, {input_size, batch_size})
- : ShapeUtil::MakeShape(element_type,
- {batch_size, input_size}),
- input));
- }
-
- if (sort_dim == 0) {
- // Transpose for the custom call when sorting the first dimension.
- input = comp->AddInstruction(
- HloInstruction::CreateTranspose(topk_input_shape, input, {1, 0}));
- }
- } else {
- topk_input_shape = data_shape;
+ const int64_t batch_size =
+ has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1;
+ const int64_t input_size = sort->operand(0)->shape().dimensions(sort_dim);
+ HloInstruction* input = sort->mutable_operand(0);
+ if (has_batch && sort_dim == 0) {
+ input = comp->AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input,
+ {1, 0}));
}

Shape topk_shape =
@@ -331,26 +300,13 @@ StatusOr<bool> TopkRewriter::TransformToCustomCall(
comp->AddInstruction(HloInstruction::CreateGetTupleElement(
topk->shape().tuple_shapes(1), topk, 1));

- if (has_batch) {
- if (sort_dim == 0) {
- // Transpose back.
- value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
- value_gte, {1, 0}));
- index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
- {1, 0}));
- }
- if (data_shape.rank() > 2) {
- // Reshape back.
- std::vector<int64_t> shape_dim(data_shape.dimensions().begin(),
- data_shape.dimensions().end());
- shape_dim[sort_dim] = k.value();
- value_gte = comp->AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(element_type, shape_dim), value_gte));
- index_gte = comp->AddInstruction(HloInstruction::CreateReshape(
- ShapeUtil::MakeShape(S32, shape_dim), index_gte));
- }
+ if (has_batch && sort_dim == 0) {
+ value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
+ value_gte, {1, 0}));
+ index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
+ {1, 0}));
}

for (HloInstruction* user : sort->users()) {
diff --git a/xla/service/topk_rewriter_test.cc b/xla/service/topk_rewriter_test.cc
index 36e723737..25ce150e0 100644
--- a/xla/service/topk_rewriter_test.cc
+++ b/xla/service/topk_rewriter_test.cc
@@ -326,42 +326,6 @@ ENTRY cluster {
EXPECT_THAT(cc->custom_call_target(), "TopK");
}

-TEST_F(TopkRewriterTest, RewriteReshape) {
- const std::string hlo_string = R"(
-HloModule module
-)" + getComparator() + R"(
-ENTRY cluster {
- %arg_tuple.1 = f32[3,8,1234567] parameter(0)
- %iota.4 = s32[3,8,1234567] iota(), iota_dimension=2
- %sort.27 = (f32[3,8,1234567], s32[3,8,1234567]) sort(%arg_tuple.1, %iota.4),
- dimensions={2}, is_stable=true, to_apply=%compare
- %get-tuple-element.28 = f32[3, 8,1234567] get-tuple-element(%sort.27), index=0
- %slice.29 = f32[3,8,5] slice(%get-tuple-element.28), slice={[0:3], [0:8], [0:5]}
- %get-tuple-element.30 = s32[3,8,1234567] get-tuple-element(%sort.27), index=1
- %slice.31 = s32[3,8,5] slice(%get-tuple-element.30), slice={[0:3], [0:8], [0:5]}
- ROOT %tuple.32 = (f32[3,8,5], s32[3,8,5]) tuple(%slice.29, %slice.31)
-})";
- TF_ASSERT_OK_AND_ASSIGN(auto module,
- ParseAndReturnVerifiedModule(hlo_string));
- TopkRewriter rewriter(
- [](const HloSortInstruction*, int64_t) { return true; });
- TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
- TF_ASSERT_OK(HloDCE().Run(module.get()).status());
- EXPECT_TRUE(changed);
- EXPECT_THAT(module->entry_computation()->root_instruction(),
- GmockMatch(m::Tuple(
- m::Reshape(m::GetTupleElement(
- m::CustomCall(m::Reshape(m::Parameter(0))), 0)),
- m::Reshape(m::GetTupleElement(
- m::CustomCall(m::Reshape(m::Parameter(0))), 1)))));
- const HloInstruction* cc = module->entry_computation()
- ->root_instruction()
- ->operand(0)
- ->operand(0)
- ->operand(0);
- EXPECT_THAT(cc->custom_call_target(), "TopK");
-}
-
TEST_F(TopkRewriterTest, RewriteNoIota) {
const std::string hlo_string = R"(
HloModule module
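
For context, TopkRewriter replaces a sort-then-slice pattern with a "TopK" custom call, and this patch narrows the match to rank-1 and rank-2 inputs, deleting the reshape path for higher ranks together with its RewriteReshape test. A minimal sketch of an input that would still be rewritten after the patch, adapted from the deleted test with the operand reduced to rank 2 (shapes are illustrative, and the %compare computation is elided here, just as getComparator() supplies it in the tests):

HloModule module
ENTRY cluster {
  %arg = f32[8,1234567] parameter(0)
  %iota = s32[8,1234567] iota(), iota_dimension=1
  %sort = (f32[8,1234567], s32[8,1234567]) sort(%arg, %iota),
      dimensions={1}, is_stable=true, to_apply=%compare
  %values = f32[8,1234567] get-tuple-element(%sort), index=0
  %top_values = f32[8,5] slice(%values), slice={[0:8], [0:5]}
  %indices = s32[8,1234567] get-tuple-element(%sort), index=1
  %top_indices = s32[8,5] slice(%indices), slice={[0:8], [0:5]}
  ROOT %result = (f32[8,5], s32[8,5]) tuple(%top_values, %top_indices)
}

Here the sort dimension is 1, so batch_dim is 0, both slices leave the batch dimension whole, and k is inferred as 5 from the slice limit along the sort dimension; because sort_dim != 0, no transpose is inserted around the custom call. With dimensions={0} the rewriter would instead transpose the input and both tuple outputs, per the TransformToCustomCall hunk above.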
