From bfea83d16f328f01fd0a138f45771a1722c6e533 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Sun, 26 Nov 2023 21:45:05 -0800 Subject: [PATCH] Update gpu_topk_rewriter.diff --- openxla_patches/gpu_topk_rewriter.diff | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/openxla_patches/gpu_topk_rewriter.diff b/openxla_patches/gpu_topk_rewriter.diff index 2aa24950d32..258c99c55d7 100644 --- a/openxla_patches/gpu_topk_rewriter.diff +++ b/openxla_patches/gpu_topk_rewriter.diff @@ -35,9 +35,9 @@ index a0fc5758a..da9be54c8 100644 TopKCustomCall CreateTopKCustomCall(HloInstruction* input, const int64_t sort_dim, const int64_t k, HloComputation* comparator, +- HloComputation* comp) { + HloComputation* comp, + int64_t batch_dim, -+ bool has_batch, + HloSortInstruction* sort) { Shape data_shape = input->shape(); PrimitiveType element_type = data_shape.element_type(); @@ -72,12 +72,12 @@ index a0fc5758a..da9be54c8 100644 - HloInstruction::CreateTranspose(topk_input_shape, input, {1, 0})); - } - } else { +- topk_input_shape = data_shape; + const int64_t batch_size = + has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1; + const int64_t input_size = sort->operand(0)->shape().dimensions(sort_dim); + // HloInstruction* input = sort->mutable_operand(0); + if (has_batch && sort_dim == 0) { -+ input = comp->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input, + {1, 0})); } @@ -109,6 +109,7 @@ index a0fc5758a..da9be54c8 100644 + if (has_batch && sort_dim == 0) { + value_gte = comp->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(element_type, {k, batch_size}), ++ value_gte, {1, 0})); + index_gte = comp->AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(S32, {k, batch_size}), index_gte, + {1, 0})); @@ -117,7 +118,6 @@ index a0fc5758a..da9be54c8 100644 } @@ -343,15 +308,14 @@ StatusOr TopkRewriter::TransformPatternToCustomCall( HloInstruction* data = sort->mutable_operand(0); - const PrimitiveType element_type = data->shape().element_type(); - if (element_type != F32 && element_type != BF16) { + if ((data->shape().rank() != 1 && data->shape().rank() != 2) || @@ -146,6 +146,7 @@ index a0fc5758a..da9be54c8 100644 if (sort->operand_count() == 2) { @@ -530,3 +494,4 @@ StatusOr TopkDecomposer::Run( } + } // namespace xla + diff --git a/xla/service/topk_rewriter_test.cc b/xla/service/topk_rewriter_test.cc