Skip to content

Commit

Permalink
Update gpu_topk_rewriter.diff
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Nov 27, 2023
1 parent 5952bab commit bfea83d
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions openxla_patches/gpu_topk_rewriter.diff
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ index a0fc5758a..da9be54c8 100644
TopKCustomCall CreateTopKCustomCall(HloInstruction* input,
const int64_t sort_dim, const int64_t k,
HloComputation* comparator,
- HloComputation* comp) {
+ HloComputation* comp,
+ int64_t batch_dim,
+ bool has_batch,
+ HloSortInstruction* sort) {
Shape data_shape = input->shape();
PrimitiveType element_type = data_shape.element_type();
Expand Down Expand Up @@ -72,12 +72,12 @@ index a0fc5758a..da9be54c8 100644
- HloInstruction::CreateTranspose(topk_input_shape, input, {1, 0}));
- }
- } else {
- topk_input_shape = data_shape;
+ const int64_t batch_size =
+ has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1;
+ const int64_t input_size = sort->operand(0)->shape().dimensions(sort_dim);
+ // HloInstruction* input = sort->mutable_operand(0);
+ if (has_batch && sort_dim == 0) {
+ input = comp->AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input,
+ {1, 0}));
}
Expand Down Expand Up @@ -109,6 +109,7 @@ index a0fc5758a..da9be54c8 100644
+ if (has_batch && sort_dim == 0) {
+ value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(element_type, {k, batch_size}),
+ value_gte, {1, 0}));
+ index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(S32, {k, batch_size}), index_gte,
+ {1, 0}));
Expand All @@ -117,7 +118,6 @@ index a0fc5758a..da9be54c8 100644
}
@@ -343,15 +308,14 @@ StatusOr<HloInstruction*> TopkRewriter::TransformPatternToCustomCall(
HloInstruction* data = sort->mutable_operand(0);
const PrimitiveType element_type = data->shape().element_type();

- if (element_type != F32 && element_type != BF16) {
+ if ((data->shape().rank() != 1 && data->shape().rank() != 2) ||
Expand Down Expand Up @@ -146,6 +146,7 @@ index a0fc5758a..da9be54c8 100644
if (sort->operand_count() == 2) {
@@ -530,3 +494,4 @@ StatusOr<bool> TopkDecomposer::Run(
}

} // namespace xla
+
diff --git a/xla/service/topk_rewriter_test.cc b/xla/service/topk_rewriter_test.cc
Expand Down

0 comments on commit bfea83d

Please sign in to comment.