From c8e74ee4736c57d834be0e1268fc9d6b87a1a47a Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Sun, 3 Dec 2023 22:48:57 -0500 Subject: [PATCH 01/10] Update WORKSPACE --- WORKSPACE | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 2d3e69033fd..4007eefee91 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -46,9 +46,9 @@ http_archive( "//openxla_patches:quant_dequant_converter.diff", "//openxla_patches:stablehlo_quant_seralization.diff", ], - strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478", + strip_prefix = "xla-8744c9a94782cd7804f015e6d29df253437af3cb", urls = [ - "https://github.com/openxla/xla/archive/4f8381651977dff16b1d86bb4b198eb733c5f478.tar.gz", + "https://github.com/openxla/xla/archive/8744c9a94782cd7804f015e6d29df253437af3cb.tar.gz", ], ) From b31387ed0c8d2504cc053e81f38ae2e659f73e04 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Sun, 3 Dec 2023 22:52:48 -0500 Subject: [PATCH 02/10] Update cache_urls.diff --- openxla_patches/cache_urls.diff | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/openxla_patches/cache_urls.diff b/openxla_patches/cache_urls.diff index 10aeadbb2a4..72cd103b92c 100644 --- a/openxla_patches/cache_urls.diff +++ b/openxla_patches/cache_urls.diff @@ -28,4 +28,5 @@ index a4574d75d..f9ce37094 100644 + "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT), ], build_file = "//third_party/llvm:llvm.BUILD", - patch_file = [ \ No newline at end of file + patch_file = [ + From 067f5f4fd288e548a43091c7554c71109a7c4a76 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Sun, 3 Dec 2023 22:53:26 -0500 Subject: [PATCH 03/10] Update constexpr_return.diff --- openxla_patches/constexpr_return.diff | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/openxla_patches/constexpr_return.diff b/openxla_patches/constexpr_return.diff index 99825c02409..0872b5f6e78 100644 --- a/openxla_patches/constexpr_return.diff +++ b/openxla_patches/constexpr_return.diff @@ -1,12 +1,12 @@ diff --git a/xla/primitive_util.h b/xla/primitive_util.h -index 696147844..dfea15a4d 100644 +index 63fa4e193..ab352626c 100644 --- a/xla/primitive_util.h +++ b/xla/primitive_util.h -@@ -748,6 +748,7 @@ inline bool FitsInIntegralType(int64_t x, PrimitiveType ty) { +@@ -706,6 +706,7 @@ inline bool FitsInIntegralType(int64_t x, PrimitiveType ty) { std::numeric_limits::max() >= x; } LOG(FATAL) << "Invalid primitive type " << PrimitiveType_Name(ty); -+ return false; ++ return false; }, ty); } From 497339b7130ea3e63320942649fdae21c438fb5e Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Sun, 3 Dec 2023 22:53:47 -0500 Subject: [PATCH 04/10] Update gpu_race_condition.diff --- openxla_patches/gpu_race_condition.diff | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/openxla_patches/gpu_race_condition.diff b/openxla_patches/gpu_race_condition.diff index dfdc3aa7460..683b156e7d2 100644 --- a/openxla_patches/gpu_race_condition.diff +++ b/openxla_patches/gpu_race_condition.diff @@ -1,8 +1,8 @@ diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc -index 242961dd1..787275868 100644 +index 1f9903cb3..763b7fc23 100644 --- a/xla/service/gpu/gpu_executable.cc +++ b/xla/service/gpu/gpu_executable.cc -@@ -563,8 +563,7 @@ StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( +@@ -589,8 +589,7 @@ StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( } // Force synchronous execution if the allocator requires it. @@ -10,5 +10,5 @@ index 242961dd1..787275868 100644 - !memory_allocator->AllowsAsynchronousDeallocation(); + const bool block_host_until_done = true; - - // Lock the GPU with a shared lock so that we don't interfere with autotuning \ No newline at end of file + // Lock the GPU with a shared lock so that we don't interfere with autotuning + // that may be running during JIT compilation while allowing multiple XLA From 04ccca8819d3f672fe98301028617a9ca2bdfcaf Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Sun, 3 Dec 2023 22:54:10 -0500 Subject: [PATCH 05/10] Delete openxla_patches/gpu_topk_rewriter.diff --- openxla_patches/gpu_topk_rewriter.diff | 184 ------------------------- 1 file changed, 184 deletions(-) delete mode 100644 openxla_patches/gpu_topk_rewriter.diff diff --git a/openxla_patches/gpu_topk_rewriter.diff b/openxla_patches/gpu_topk_rewriter.diff deleted file mode 100644 index 47ee3fa0f0a..00000000000 --- a/openxla_patches/gpu_topk_rewriter.diff +++ /dev/null @@ -1,184 +0,0 @@ -diff --git a/xla/service/topk_rewriter.cc b/xla/service/topk_rewriter.cc -index da872d962..1b7141055 100644 ---- a/xla/service/topk_rewriter.cc -+++ b/xla/service/topk_rewriter.cc -@@ -196,6 +196,8 @@ std::optional TopkRewriter::SortIsInTopK(HloInstruction* inst) { - return std::nullopt; - } - const int64_t sort_dim = sort->sort_dimension(); -+ const int64_t batch_dim = sort_dim == 1 ? 0 : 1; -+ const bool has_batch = data->shape().rank() == 2; - - bool supported = true; - std::optional k; -@@ -220,15 +222,10 @@ std::optional TopkRewriter::SortIsInTopK(HloInstruction* inst) { - supported = false; - break; - } -- for (int64_t i = 0; i < slice->slice_limits().size(); ++i) { -- if (i != sort_dim && -- slice->slice_limits(i) != slice->operand(0)->shape().dimensions(i)) { -- // Slicing along a non-sort dimension isn't supported. -- supported = false; -- break; -- } -- } -- if (!supported) { -+ if (has_batch && slice->slice_limits(batch_dim) != -+ slice->operand(0)->shape().dimensions(batch_dim)) { -+ // Slicing along the batch dimension isn't supported. -+ supported = false; - break; - } - if (k == std::nullopt) { -@@ -260,57 +257,29 @@ StatusOr TopkRewriter::TransformToCustomCall( - HloSortInstruction* sort = DynCast(inst); - HloInstruction* data = sort->mutable_operand(0); - const PrimitiveType element_type = data->shape().element_type(); -- const Shape data_shape = data->shape(); - -- if (element_type != F32 && element_type != BF16) { -+ if ((data->shape().rank() != 1 && data->shape().rank() != 2) || -+ (element_type != F32 && element_type != BF16)) { - continue; - } - -- // Sort dimension must be the first or last dimension. - const int64_t sort_dim = sort->sort_dimension(); -- if (sort_dim != 0 && sort_dim != data_shape.rank() - 1) { -- continue; -- } -+ const int64_t batch_dim = sort_dim == 1 ? 0 : 1; -+ const bool has_batch = data->shape().rank() == 2; - - // Profitability check. - if (!is_profitable_to_convert_(sort, *k)) { - continue; - } - -- HloInstruction* input = data; -- const bool has_batch = data_shape.rank() >= 2; -- const int64_t input_size = data_shape.dimensions(sort_dim); -- int64_t batch_size = 1; -- Shape topk_input_shape; -- -- if (has_batch) { -- // The TopK custom call expects either a 1d tensor or a 2d tensor with -- // the last dimension being the sort dimension. An input with rank > 2 -- // is reshaped into a 2d tensor by combining non-sort dimensions into a -- // single batch dimension. The original non-sort dimensions are -- // restored for the outputs with another reshape after the custom call. -- batch_size = -- ShapeUtil::ElementsIn(data_shape) / data_shape.dimensions(sort_dim); -- topk_input_shape = -- ShapeUtil::MakeShape(element_type, {batch_size, input_size}); -- -- if (data_shape.rank() > 2) { -- // Reshape to 2d. -- input = comp->AddInstruction(HloInstruction::CreateReshape( -- sort_dim == 0 -- ? ShapeUtil::MakeShape(element_type, {input_size, batch_size}) -- : ShapeUtil::MakeShape(element_type, -- {batch_size, input_size}), -- input)); -- } -- -- if (sort_dim == 0) { -- // Transpose for the custom call when sorting the first dimension. -- input = comp->AddInstruction( -- HloInstruction::CreateTranspose(topk_input_shape, input, {1, 0})); -- } -- } else { -- topk_input_shape = data_shape; -+ const int64_t batch_size = -+ has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1; -+ const int64_t input_size = sort->operand(0)->shape().dimensions(sort_dim); -+ HloInstruction* input = sort->mutable_operand(0); -+ if (has_batch && sort_dim == 0) { -+ input = comp->AddInstruction(HloInstruction::CreateTranspose( -+ ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input, -+ {1, 0})); - } - - Shape topk_shape = -@@ -331,26 +300,13 @@ StatusOr TopkRewriter::TransformToCustomCall( - comp->AddInstruction(HloInstruction::CreateGetTupleElement( - topk->shape().tuple_shapes(1), topk, 1)); - -- if (has_batch) { -- if (sort_dim == 0) { -- // Transpose back. -- value_gte = comp->AddInstruction(HloInstruction::CreateTranspose( -- ShapeUtil::MakeShape(element_type, {k.value(), batch_size}), -- value_gte, {1, 0})); -- index_gte = comp->AddInstruction(HloInstruction::CreateTranspose( -- ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte, -- {1, 0})); -- } -- if (data_shape.rank() > 2) { -- // Reshape back. -- std::vector shape_dim(data_shape.dimensions().begin(), -- data_shape.dimensions().end()); -- shape_dim[sort_dim] = k.value(); -- value_gte = comp->AddInstruction(HloInstruction::CreateReshape( -- ShapeUtil::MakeShape(element_type, shape_dim), value_gte)); -- index_gte = comp->AddInstruction(HloInstruction::CreateReshape( -- ShapeUtil::MakeShape(S32, shape_dim), index_gte)); -- } -+ if (has_batch && sort_dim == 0) { -+ value_gte = comp->AddInstruction(HloInstruction::CreateTranspose( -+ ShapeUtil::MakeShape(element_type, {k.value(), batch_size}), -+ value_gte, {1, 0})); -+ index_gte = comp->AddInstruction(HloInstruction::CreateTranspose( -+ ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte, -+ {1, 0})); - } - - for (HloInstruction* user : sort->users()) { -diff --git a/xla/service/topk_rewriter_test.cc b/xla/service/topk_rewriter_test.cc -index 36e723737..25ce150e0 100644 ---- a/xla/service/topk_rewriter_test.cc -+++ b/xla/service/topk_rewriter_test.cc -@@ -326,42 +326,6 @@ ENTRY cluster { - EXPECT_THAT(cc->custom_call_target(), "TopK"); - } - --TEST_F(TopkRewriterTest, RewriteReshape) { -- const std::string hlo_string = R"( --HloModule module --)" + getComparator() + R"( --ENTRY cluster { -- %arg_tuple.1 = f32[3,8,1234567] parameter(0) -- %iota.4 = s32[3,8,1234567] iota(), iota_dimension=2 -- %sort.27 = (f32[3,8,1234567], s32[3,8,1234567]) sort(%arg_tuple.1, %iota.4), -- dimensions={2}, is_stable=true, to_apply=%compare -- %get-tuple-element.28 = f32[3, 8,1234567] get-tuple-element(%sort.27), index=0 -- %slice.29 = f32[3,8,5] slice(%get-tuple-element.28), slice={[0:3], [0:8], [0:5]} -- %get-tuple-element.30 = s32[3,8,1234567] get-tuple-element(%sort.27), index=1 -- %slice.31 = s32[3,8,5] slice(%get-tuple-element.30), slice={[0:3], [0:8], [0:5]} -- ROOT %tuple.32 = (f32[3,8,5], s32[3,8,5]) tuple(%slice.29, %slice.31) --})"; -- TF_ASSERT_OK_AND_ASSIGN(auto module, -- ParseAndReturnVerifiedModule(hlo_string)); -- TopkRewriter rewriter( -- [](const HloSortInstruction*, int64_t) { return true; }); -- TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get())); -- TF_ASSERT_OK(HloDCE().Run(module.get()).status()); -- EXPECT_TRUE(changed); -- EXPECT_THAT(module->entry_computation()->root_instruction(), -- GmockMatch(m::Tuple( -- m::Reshape(m::GetTupleElement( -- m::CustomCall(m::Reshape(m::Parameter(0))), 0)), -- m::Reshape(m::GetTupleElement( -- m::CustomCall(m::Reshape(m::Parameter(0))), 1))))); -- const HloInstruction* cc = module->entry_computation() -- ->root_instruction() -- ->operand(0) -- ->operand(0) -- ->operand(0); -- EXPECT_THAT(cc->custom_call_target(), "TopK"); --} -- - TEST_F(TopkRewriterTest, RewriteNoIota) { - const std::string hlo_string = R"( - HloModule module From 384e5bbe90b1b966ae00e44fb0381a11b2cee334 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Sun, 3 Dec 2023 22:55:02 -0500 Subject: [PATCH 06/10] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a8a04c4c286..ae3d9824667 100644 --- a/setup.py +++ b/setup.py @@ -72,7 +72,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_libtpu_version = '0.1.dev20231022' +_libtpu_version = '0.1.dev20231125' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' From 9d16ae1e2c23cf2366b46b1f068d57745e1599d7 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Sun, 3 Dec 2023 22:55:56 -0500 Subject: [PATCH 07/10] Update computation_client.h --- torch_xla/csrc/runtime/computation_client.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index a1223c5ef7e..9d0b239f212 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -314,7 +314,7 @@ class ComputationClient { virtual int GetNumProcesses() const = 0; using DeviceAttribute = - std::variant, float, bool>; + std::variant, float>; virtual const absl::flat_hash_map< std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& From 55eaded50f7ce74eafbc5c2bd4f606e6839bf1e3 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Sun, 3 Dec 2023 22:57:10 -0500 Subject: [PATCH 08/10] Update pjrt_computation_client.cc --- .../csrc/runtime/pjrt_computation_client.cc | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 0bd42b3ad6a..520c5f91cb4 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -153,29 +153,28 @@ PjRtComputationClient::PjRtComputationClient() { xla::PjRtClient::KeyValuePutCallback kv_put = nullptr; if (distributed_client != nullptr) { std::string key_prefix = "gpu:"; - kv_get = [distributed_client, key_prefix](const std::string& k, - absl::Duration timeout) { + kv_get = [distributed_client, key_prefix]( + std::string_view k, + absl::Duration timeout) -> xla::StatusOr { return distributed_client->BlockingKeyValueGet( absl::StrCat(key_prefix, k), timeout); }; - kv_put = [distributed_client, key_prefix](const std::string& k, - const std::string& v) { + kv_put = [distributed_client, key_prefix]( + std::string_view k, std::string_view v) -> xla::Status { return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); }; } TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" << global_process_rank << ", num_nodes=" << global_world_size; - client_ = std::move(xla::GetStreamExecutorGpuClient( - /*asynchronous=*/async, - /*allocator_config=*/GetGpuAllocatorConfig(), - /*node_id=*/global_process_rank, - /*num_nodes=*/global_world_size, - /*allowed_devices=*/allowed_devices, - /*platform_name=*/"gpu", - /*should_stage_host_to_device_transfers=*/true, - /*kv_get=*/kv_get, - /*kv_put=*/kv_put) - .value()); + xla::GpuClientOptions options; + options.allocator_config = GetGpuAllocatorConfig(); + options.node_id = global_process_rank; + options.num_nodes = global_world_size; + options.allowed_devices = allowed_devices; + options.platform_name = "gpu"; + options.kv_get = kv_get; + options.kv_put = kv_put; + client_ = std::move(xla::GetStreamExecutorGpuClient(options).value()); } else if (device_type == "XPU") { TF_VLOG(1) << "Initializing PjRt XPU client..."; XLA_CHECK_OK( From 339ad729fa3be86df3e3da49027ccf2a2fb99afa Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Sun, 3 Dec 2023 23:21:13 -0500 Subject: [PATCH 09/10] Update WORKSPACE --- WORKSPACE | 1 - 1 file changed, 1 deletion(-) diff --git a/WORKSPACE b/WORKSPACE index 4007eefee91..b2154725282 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -42,7 +42,6 @@ http_archive( "//openxla_patches:constexpr_return.diff", "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:f16_abi_clang.diff", - "//openxla_patches:gpu_topk_rewriter.diff", "//openxla_patches:quant_dequant_converter.diff", "//openxla_patches:stablehlo_quant_seralization.diff", ], From 4a220e7ba0b210a28d066f80e4c97c32211b286f Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Mon, 4 Dec 2023 13:17:17 -0500 Subject: [PATCH 10/10] Update pjrt_computation_client.cc --- torch_xla/csrc/runtime/pjrt_computation_client.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 520c5f91cb4..bf3bccd210e 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -172,6 +172,7 @@ PjRtComputationClient::PjRtComputationClient() { options.num_nodes = global_world_size; options.allowed_devices = allowed_devices; options.platform_name = "gpu"; + options.should_stage_host_to_device_transfers = true; options.kv_get = kv_get; options.kv_put = kv_put; client_ = std::move(xla::GetStreamExecutorGpuClient(options).value());