-
Notifications
You must be signed in to change notification settings - Fork 487
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
194 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
diff --git a/xla/service/topk_rewriter.cc b/xla/service/topk_rewriter.cc | ||
index da872d962..1b7141055 100644 | ||
--- a/xla/service/topk_rewriter.cc | ||
+++ b/xla/service/topk_rewriter.cc | ||
@@ -196,6 +196,8 @@ std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) { | ||
return std::nullopt; | ||
} | ||
const int64_t sort_dim = sort->sort_dimension(); | ||
+ const int64_t batch_dim = sort_dim == 1 ? 0 : 1; | ||
+ const bool has_batch = data->shape().rank() == 2; | ||
|
||
bool supported = true; | ||
std::optional<int64_t> k; | ||
@@ -220,15 +222,10 @@ std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) { | ||
supported = false; | ||
break; | ||
} | ||
- for (int64_t i = 0; i < slice->slice_limits().size(); ++i) { | ||
- if (i != sort_dim && | ||
- slice->slice_limits(i) != slice->operand(0)->shape().dimensions(i)) { | ||
- // Slicing along a non-sort dimension isn't supported. | ||
- supported = false; | ||
- break; | ||
- } | ||
- } | ||
- if (!supported) { | ||
+ if (has_batch && slice->slice_limits(batch_dim) != | ||
+ slice->operand(0)->shape().dimensions(batch_dim)) { | ||
+ // Slicing along the batch dimension isn't supported. | ||
+ supported = false; | ||
break; | ||
} | ||
if (k == std::nullopt) { | ||
@@ -260,57 +257,29 @@ StatusOr<bool> TopkRewriter::TransformToCustomCall( | ||
HloSortInstruction* sort = DynCast<HloSortInstruction>(inst); | ||
HloInstruction* data = sort->mutable_operand(0); | ||
const PrimitiveType element_type = data->shape().element_type(); | ||
- const Shape data_shape = data->shape(); | ||
|
||
- if (element_type != F32 && element_type != BF16) { | ||
+ if ((data->shape().rank() != 1 && data->shape().rank() != 2) || | ||
+ (element_type != F32 && element_type != BF16)) { | ||
continue; | ||
} | ||
|
||
- // Sort dimension must be the first or last dimension. | ||
const int64_t sort_dim = sort->sort_dimension(); | ||
- if (sort_dim != 0 && sort_dim != data_shape.rank() - 1) { | ||
- continue; | ||
- } | ||
+ const int64_t batch_dim = sort_dim == 1 ? 0 : 1; | ||
+ const bool has_batch = data->shape().rank() == 2; | ||
|
||
// Profitability check. | ||
if (!is_profitable_to_convert_(sort, *k)) { | ||
continue; | ||
} | ||
|
||
- HloInstruction* input = data; | ||
- const bool has_batch = data_shape.rank() >= 2; | ||
- const int64_t input_size = data_shape.dimensions(sort_dim); | ||
- int64_t batch_size = 1; | ||
- Shape topk_input_shape; | ||
- | ||
- if (has_batch) { | ||
- // The TopK custom call expects either a 1d tensor or a 2d tensor with | ||
- // the last dimension being the sort dimension. An input with rank > 2 | ||
- // is reshaped into a 2d tensor by combining non-sort dimensions into a | ||
- // single batch dimension. The original non-sort dimensions are | ||
- // restored for the outputs with another reshape after the custom call. | ||
- batch_size = | ||
- ShapeUtil::ElementsIn(data_shape) / data_shape.dimensions(sort_dim); | ||
- topk_input_shape = | ||
- ShapeUtil::MakeShape(element_type, {batch_size, input_size}); | ||
- | ||
- if (data_shape.rank() > 2) { | ||
- // Reshape to 2d. | ||
- input = comp->AddInstruction(HloInstruction::CreateReshape( | ||
- sort_dim == 0 | ||
- ? ShapeUtil::MakeShape(element_type, {input_size, batch_size}) | ||
- : ShapeUtil::MakeShape(element_type, | ||
- {batch_size, input_size}), | ||
- input)); | ||
- } | ||
- | ||
- if (sort_dim == 0) { | ||
- // Transpose for the custom call when sorting the first dimension. | ||
- input = comp->AddInstruction( | ||
- HloInstruction::CreateTranspose(topk_input_shape, input, {1, 0})); | ||
- } | ||
- } else { | ||
- topk_input_shape = data_shape; | ||
+ const int64_t batch_size = | ||
+ has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1; | ||
+ const int64_t input_size = sort->operand(0)->shape().dimensions(sort_dim); | ||
+ HloInstruction* input = sort->mutable_operand(0); | ||
+ if (has_batch && sort_dim == 0) { | ||
+ input = comp->AddInstruction(HloInstruction::CreateTranspose( | ||
+ ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input, | ||
+ {1, 0})); | ||
} | ||
|
||
Shape topk_shape = | ||
@@ -331,26 +300,13 @@ StatusOr<bool> TopkRewriter::TransformToCustomCall( | ||
comp->AddInstruction(HloInstruction::CreateGetTupleElement( | ||
topk->shape().tuple_shapes(1), topk, 1)); | ||
|
||
- if (has_batch) { | ||
- if (sort_dim == 0) { | ||
- // Transpose back. | ||
- value_gte = comp->AddInstruction(HloInstruction::CreateTranspose( | ||
- ShapeUtil::MakeShape(element_type, {k.value(), batch_size}), | ||
- value_gte, {1, 0})); | ||
- index_gte = comp->AddInstruction(HloInstruction::CreateTranspose( | ||
- ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte, | ||
- {1, 0})); | ||
- } | ||
- if (data_shape.rank() > 2) { | ||
- // Reshape back. | ||
- std::vector<int64_t> shape_dim(data_shape.dimensions().begin(), | ||
- data_shape.dimensions().end()); | ||
- shape_dim[sort_dim] = k.value(); | ||
- value_gte = comp->AddInstruction(HloInstruction::CreateReshape( | ||
- ShapeUtil::MakeShape(element_type, shape_dim), value_gte)); | ||
- index_gte = comp->AddInstruction(HloInstruction::CreateReshape( | ||
- ShapeUtil::MakeShape(S32, shape_dim), index_gte)); | ||
- } | ||
+ if (has_batch && sort_dim == 0) { | ||
+ value_gte = comp->AddInstruction(HloInstruction::CreateTranspose( | ||
+ ShapeUtil::MakeShape(element_type, {k.value(), batch_size}), | ||
+ value_gte, {1, 0})); | ||
+ index_gte = comp->AddInstruction(HloInstruction::CreateTranspose( | ||
+ ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte, | ||
+ {1, 0})); | ||
} | ||
|
||
for (HloInstruction* user : sort->users()) { | ||
diff --git a/xla/service/topk_rewriter_test.cc b/xla/service/topk_rewriter_test.cc | ||
index 36e723737..25ce150e0 100644 | ||
--- a/xla/service/topk_rewriter_test.cc | ||
+++ b/xla/service/topk_rewriter_test.cc | ||
@@ -326,42 +326,6 @@ ENTRY cluster { | ||
EXPECT_THAT(cc->custom_call_target(), "TopK"); | ||
} | ||
|
||
-TEST_F(TopkRewriterTest, RewriteReshape) { | ||
- const std::string hlo_string = R"( | ||
-HloModule module | ||
-)" + getComparator() + R"( | ||
-ENTRY cluster { | ||
- %arg_tuple.1 = f32[3,8,1234567] parameter(0) | ||
- %iota.4 = s32[3,8,1234567] iota(), iota_dimension=2 | ||
- %sort.27 = (f32[3,8,1234567], s32[3,8,1234567]) sort(%arg_tuple.1, %iota.4), | ||
- dimensions={2}, is_stable=true, to_apply=%compare | ||
- %get-tuple-element.28 = f32[3, 8,1234567] get-tuple-element(%sort.27), index=0 | ||
- %slice.29 = f32[3,8,5] slice(%get-tuple-element.28), slice={[0:3], [0:8], [0:5]} | ||
- %get-tuple-element.30 = s32[3,8,1234567] get-tuple-element(%sort.27), index=1 | ||
- %slice.31 = s32[3,8,5] slice(%get-tuple-element.30), slice={[0:3], [0:8], [0:5]} | ||
- ROOT %tuple.32 = (f32[3,8,5], s32[3,8,5]) tuple(%slice.29, %slice.31) | ||
-})"; | ||
- TF_ASSERT_OK_AND_ASSIGN(auto module, | ||
- ParseAndReturnVerifiedModule(hlo_string)); | ||
- TopkRewriter rewriter( | ||
- [](const HloSortInstruction*, int64_t) { return true; }); | ||
- TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get())); | ||
- TF_ASSERT_OK(HloDCE().Run(module.get()).status()); | ||
- EXPECT_TRUE(changed); | ||
- EXPECT_THAT(module->entry_computation()->root_instruction(), | ||
- GmockMatch(m::Tuple( | ||
- m::Reshape(m::GetTupleElement( | ||
- m::CustomCall(m::Reshape(m::Parameter(0))), 0)), | ||
- m::Reshape(m::GetTupleElement( | ||
- m::CustomCall(m::Reshape(m::Parameter(0))), 1))))); | ||
- const HloInstruction* cc = module->entry_computation() | ||
- ->root_instruction() | ||
- ->operand(0) | ||
- ->operand(0) | ||
- ->operand(0); | ||
- EXPECT_THAT(cc->custom_call_target(), "TopK"); | ||
-} | ||
- | ||
TEST_F(TopkRewriterTest, RewriteNoIota) { | ||
const std::string hlo_string = R"( | ||
HloModule module |