diff --git a/WORKSPACE b/WORKSPACE index d096cf8dd20f..1dd9260fd914 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -40,6 +40,7 @@ http_archive( patches = [ "//openxla_patches:cache_urls.diff", "//openxla_patches:constexpr_return.diff", +<<<<<<< Updated upstream "//openxla_patches:gpu_build_file.diff", "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:f16_abi_clang.diff", @@ -47,6 +48,15 @@ http_archive( strip_prefix = "xla-7a19856d74569fd1f765cd03bdee84e3b1fdc579", urls = [ "https://github.com/openxla/xla/archive/7a19856d74569fd1f765cd03bdee84e3b1fdc579.tar.gz", +======= + "//openxla_patches:pjrt_api_tsl_logging.diff", + "//openxla_patches:pjrt_c_api_dynamic_dimensions.diff", + "//openxla_patches:gpu_topk_rewriter.diff", + ], + strip_prefix = "xla-51b59cfb1999c6f1b3ec59851675044b2c502aae", + urls = [ + "https://github.com/openxla/xla/archive/51b59cfb1999c6f1b3ec59851675044b2c502aae.tar.gz", +>>>>>>> Stashed changes ], ) diff --git a/openxla_patches/gpu_topk_rewriter.diff b/openxla_patches/gpu_topk_rewriter.diff new file mode 100644 index 000000000000..47ee3fa0f0a8 --- /dev/null +++ b/openxla_patches/gpu_topk_rewriter.diff @@ -0,0 +1,184 @@ +diff --git a/xla/service/topk_rewriter.cc b/xla/service/topk_rewriter.cc +index da872d962..1b7141055 100644 +--- a/xla/service/topk_rewriter.cc ++++ b/xla/service/topk_rewriter.cc +@@ -196,6 +196,8 @@ std::optional TopkRewriter::SortIsInTopK(HloInstruction* inst) { + return std::nullopt; + } + const int64_t sort_dim = sort->sort_dimension(); ++ const int64_t batch_dim = sort_dim == 1 ? 0 : 1; ++ const bool has_batch = data->shape().rank() == 2; + + bool supported = true; + std::optional k; +@@ -220,15 +222,10 @@ std::optional TopkRewriter::SortIsInTopK(HloInstruction* inst) { + supported = false; + break; + } +- for (int64_t i = 0; i < slice->slice_limits().size(); ++i) { +- if (i != sort_dim && +- slice->slice_limits(i) != slice->operand(0)->shape().dimensions(i)) { +- // Slicing along a non-sort dimension isn't supported. +- supported = false; +- break; +- } +- } +- if (!supported) { ++ if (has_batch && slice->slice_limits(batch_dim) != ++ slice->operand(0)->shape().dimensions(batch_dim)) { ++ // Slicing along the batch dimension isn't supported. ++ supported = false; + break; + } + if (k == std::nullopt) { +@@ -260,57 +257,29 @@ StatusOr TopkRewriter::TransformToCustomCall( + HloSortInstruction* sort = DynCast(inst); + HloInstruction* data = sort->mutable_operand(0); + const PrimitiveType element_type = data->shape().element_type(); +- const Shape data_shape = data->shape(); + +- if (element_type != F32 && element_type != BF16) { ++ if ((data->shape().rank() != 1 && data->shape().rank() != 2) || ++ (element_type != F32 && element_type != BF16)) { + continue; + } + +- // Sort dimension must be the first or last dimension. + const int64_t sort_dim = sort->sort_dimension(); +- if (sort_dim != 0 && sort_dim != data_shape.rank() - 1) { +- continue; +- } ++ const int64_t batch_dim = sort_dim == 1 ? 0 : 1; ++ const bool has_batch = data->shape().rank() == 2; + + // Profitability check. + if (!is_profitable_to_convert_(sort, *k)) { + continue; + } + +- HloInstruction* input = data; +- const bool has_batch = data_shape.rank() >= 2; +- const int64_t input_size = data_shape.dimensions(sort_dim); +- int64_t batch_size = 1; +- Shape topk_input_shape; +- +- if (has_batch) { +- // The TopK custom call expects either a 1d tensor or a 2d tensor with +- // the last dimension being the sort dimension. An input with rank > 2 +- // is reshaped into a 2d tensor by combining non-sort dimensions into a +- // single batch dimension. The original non-sort dimensions are +- // restored for the outputs with another reshape after the custom call. +- batch_size = +- ShapeUtil::ElementsIn(data_shape) / data_shape.dimensions(sort_dim); +- topk_input_shape = +- ShapeUtil::MakeShape(element_type, {batch_size, input_size}); +- +- if (data_shape.rank() > 2) { +- // Reshape to 2d. +- input = comp->AddInstruction(HloInstruction::CreateReshape( +- sort_dim == 0 +- ? ShapeUtil::MakeShape(element_type, {input_size, batch_size}) +- : ShapeUtil::MakeShape(element_type, +- {batch_size, input_size}), +- input)); +- } +- +- if (sort_dim == 0) { +- // Transpose for the custom call when sorting the first dimension. +- input = comp->AddInstruction( +- HloInstruction::CreateTranspose(topk_input_shape, input, {1, 0})); +- } +- } else { +- topk_input_shape = data_shape; ++ const int64_t batch_size = ++ has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1; ++ const int64_t input_size = sort->operand(0)->shape().dimensions(sort_dim); ++ HloInstruction* input = sort->mutable_operand(0); ++ if (has_batch && sort_dim == 0) { ++ input = comp->AddInstruction(HloInstruction::CreateTranspose( ++ ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input, ++ {1, 0})); + } + + Shape topk_shape = +@@ -331,26 +300,13 @@ StatusOr TopkRewriter::TransformToCustomCall( + comp->AddInstruction(HloInstruction::CreateGetTupleElement( + topk->shape().tuple_shapes(1), topk, 1)); + +- if (has_batch) { +- if (sort_dim == 0) { +- // Transpose back. +- value_gte = comp->AddInstruction(HloInstruction::CreateTranspose( +- ShapeUtil::MakeShape(element_type, {k.value(), batch_size}), +- value_gte, {1, 0})); +- index_gte = comp->AddInstruction(HloInstruction::CreateTranspose( +- ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte, +- {1, 0})); +- } +- if (data_shape.rank() > 2) { +- // Reshape back. +- std::vector shape_dim(data_shape.dimensions().begin(), +- data_shape.dimensions().end()); +- shape_dim[sort_dim] = k.value(); +- value_gte = comp->AddInstruction(HloInstruction::CreateReshape( +- ShapeUtil::MakeShape(element_type, shape_dim), value_gte)); +- index_gte = comp->AddInstruction(HloInstruction::CreateReshape( +- ShapeUtil::MakeShape(S32, shape_dim), index_gte)); +- } ++ if (has_batch && sort_dim == 0) { ++ value_gte = comp->AddInstruction(HloInstruction::CreateTranspose( ++ ShapeUtil::MakeShape(element_type, {k.value(), batch_size}), ++ value_gte, {1, 0})); ++ index_gte = comp->AddInstruction(HloInstruction::CreateTranspose( ++ ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte, ++ {1, 0})); + } + + for (HloInstruction* user : sort->users()) { +diff --git a/xla/service/topk_rewriter_test.cc b/xla/service/topk_rewriter_test.cc +index 36e723737..25ce150e0 100644 +--- a/xla/service/topk_rewriter_test.cc ++++ b/xla/service/topk_rewriter_test.cc +@@ -326,42 +326,6 @@ ENTRY cluster { + EXPECT_THAT(cc->custom_call_target(), "TopK"); + } + +-TEST_F(TopkRewriterTest, RewriteReshape) { +- const std::string hlo_string = R"( +-HloModule module +-)" + getComparator() + R"( +-ENTRY cluster { +- %arg_tuple.1 = f32[3,8,1234567] parameter(0) +- %iota.4 = s32[3,8,1234567] iota(), iota_dimension=2 +- %sort.27 = (f32[3,8,1234567], s32[3,8,1234567]) sort(%arg_tuple.1, %iota.4), +- dimensions={2}, is_stable=true, to_apply=%compare +- %get-tuple-element.28 = f32[3, 8,1234567] get-tuple-element(%sort.27), index=0 +- %slice.29 = f32[3,8,5] slice(%get-tuple-element.28), slice={[0:3], [0:8], [0:5]} +- %get-tuple-element.30 = s32[3,8,1234567] get-tuple-element(%sort.27), index=1 +- %slice.31 = s32[3,8,5] slice(%get-tuple-element.30), slice={[0:3], [0:8], [0:5]} +- ROOT %tuple.32 = (f32[3,8,5], s32[3,8,5]) tuple(%slice.29, %slice.31) +-})"; +- TF_ASSERT_OK_AND_ASSIGN(auto module, +- ParseAndReturnVerifiedModule(hlo_string)); +- TopkRewriter rewriter( +- [](const HloSortInstruction*, int64_t) { return true; }); +- TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get())); +- TF_ASSERT_OK(HloDCE().Run(module.get()).status()); +- EXPECT_TRUE(changed); +- EXPECT_THAT(module->entry_computation()->root_instruction(), +- GmockMatch(m::Tuple( +- m::Reshape(m::GetTupleElement( +- m::CustomCall(m::Reshape(m::Parameter(0))), 0)), +- m::Reshape(m::GetTupleElement( +- m::CustomCall(m::Reshape(m::Parameter(0))), 1))))); +- const HloInstruction* cc = module->entry_computation() +- ->root_instruction() +- ->operand(0) +- ->operand(0) +- ->operand(0); +- EXPECT_THAT(cc->custom_call_target(), "TopK"); +-} +- + TEST_F(TopkRewriterTest, RewriteNoIota) { + const std::string hlo_string = R"( + HloModule module