From 5a66cae01566750bb03672938407034ce11a01d2 Mon Sep 17 00:00:00 2001
From: Manfei <41607353+ManfeiBai@users.noreply.github.com>
Date: Mon, 4 Dec 2023 20:11:01 -0500
Subject: [PATCH] Update OpenXLA-pin to Nov24 (#6012)

---
 WORKSPACE                                   |   5 +-
 openxla_patches/cache_urls.diff             |   3 +-
 openxla_patches/constexpr_return.diff       |   6 +-
 openxla_patches/gpu_race_condition.diff     |   8 +-
 openxla_patches/gpu_topk_rewriter.diff      | 184 ------------------
 setup.py                                    |   2 +-
 torch_xla/csrc/runtime/computation_client.h |   2 +-
 .../csrc/runtime/pjrt_computation_client.cc |  30 +--
 8 files changed, 28 insertions(+), 212 deletions(-)
 delete mode 100644 openxla_patches/gpu_topk_rewriter.diff

diff --git a/WORKSPACE b/WORKSPACE
index 2d3e69033fda..b21547252824 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -42,13 +42,12 @@ http_archive(
         "//openxla_patches:constexpr_return.diff",
         "//openxla_patches:gpu_race_condition.diff",
         "//openxla_patches:f16_abi_clang.diff",
-        "//openxla_patches:gpu_topk_rewriter.diff",
         "//openxla_patches:quant_dequant_converter.diff",
         "//openxla_patches:stablehlo_quant_seralization.diff",
     ],
-    strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478",
+    strip_prefix = "xla-8744c9a94782cd7804f015e6d29df253437af3cb",
     urls = [
-        "https://github.com/openxla/xla/archive/4f8381651977dff16b1d86bb4b198eb733c5f478.tar.gz",
+        "https://github.com/openxla/xla/archive/8744c9a94782cd7804f015e6d29df253437af3cb.tar.gz",
     ],
 )
 
diff --git a/openxla_patches/cache_urls.diff b/openxla_patches/cache_urls.diff
index 10aeadbb2a45..72cd103b92c7 100644
--- a/openxla_patches/cache_urls.diff
+++ b/openxla_patches/cache_urls.diff
@@ -28,4 +28,5 @@ index a4574d75d..f9ce37094 100644
 +        "https://storage.googleapis.com/tpu-pytorch/llvm-raw/{commit}.tar.gz".format(commit = LLVM_COMMIT),
      ],
      build_file = "//third_party/llvm:llvm.BUILD",
-     patch_file = [
\ No newline at end of file
+     patch_file = [
+
diff --git a/openxla_patches/constexpr_return.diff b/openxla_patches/constexpr_return.diff
index 99825c024093..0872b5f6e781 100644
--- a/openxla_patches/constexpr_return.diff
+++ b/openxla_patches/constexpr_return.diff
@@ -1,12 +1,12 @@
 diff --git a/xla/primitive_util.h b/xla/primitive_util.h
-index 696147844..dfea15a4d 100644
+index 63fa4e193..ab352626c 100644
 --- a/xla/primitive_util.h
 +++ b/xla/primitive_util.h
-@@ -748,6 +748,7 @@ inline bool FitsInIntegralType(int64_t x, PrimitiveType ty) {
+@@ -706,6 +706,7 @@ inline bool FitsInIntegralType(int64_t x, PrimitiveType ty) {
             std::numeric_limits<NativeT>::max() >= x;
          }
          LOG(FATAL) << "Invalid primitive type " << PrimitiveType_Name(ty);
-+        return false;
++        return false;
        },
        ty);
  }
diff --git a/openxla_patches/gpu_race_condition.diff b/openxla_patches/gpu_race_condition.diff
index dfdc3aa74608..683b156e7d2a 100644
--- a/openxla_patches/gpu_race_condition.diff
+++ b/openxla_patches/gpu_race_condition.diff
@@ -1,8 +1,8 @@
 diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc
-index 242961dd1..787275868 100644
+index 1f9903cb3..763b7fc23 100644
 --- a/xla/service/gpu/gpu_executable.cc
 +++ b/xla/service/gpu/gpu_executable.cc
-@@ -563,8 +563,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStreamImpl(
+@@ -589,8 +589,7 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStreamImpl(
    }
 
    // Force synchronous execution if the allocator requires it.
@@ -10,5 +10,5 @@ index 242961dd1..787275868 100644
 -      !memory_allocator->AllowsAsynchronousDeallocation();
 +  const bool block_host_until_done = true;
- 
-   // Lock the GPU with a shared lock so that we don't interfere with autotuning
\ No newline at end of file
+   // Lock the GPU with a shared lock so that we don't interfere with autotuning
+   // that may be running during JIT compilation while allowing multiple XLA
diff --git a/openxla_patches/gpu_topk_rewriter.diff b/openxla_patches/gpu_topk_rewriter.diff
deleted file mode 100644
index 47ee3fa0f0a8..000000000000
--- a/openxla_patches/gpu_topk_rewriter.diff
+++ /dev/null
@@ -1,184 +0,0 @@
-diff --git a/xla/service/topk_rewriter.cc b/xla/service/topk_rewriter.cc
-index da872d962..1b7141055 100644
---- a/xla/service/topk_rewriter.cc
-+++ b/xla/service/topk_rewriter.cc
-@@ -196,6 +196,8 @@ std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) {
-     return std::nullopt;
-   }
-   const int64_t sort_dim = sort->sort_dimension();
-+  const int64_t batch_dim = sort_dim == 1 ? 0 : 1;
-+  const bool has_batch = data->shape().rank() == 2;
- 
-   bool supported = true;
-   std::optional<int64_t> k;
-@@ -220,15 +222,10 @@ std::optional<int64_t> TopkRewriter::SortIsInTopK(HloInstruction* inst) {
-       supported = false;
-       break;
-     }
--    for (int64_t i = 0; i < slice->slice_limits().size(); ++i) {
--      if (i != sort_dim &&
--          slice->slice_limits(i) != slice->operand(0)->shape().dimensions(i)) {
--        // Slicing along a non-sort dimension isn't supported.
--        supported = false;
--        break;
--      }
--    }
--    if (!supported) {
-+    if (has_batch && slice->slice_limits(batch_dim) !=
-+                         slice->operand(0)->shape().dimensions(batch_dim)) {
-+      // Slicing along the batch dimension isn't supported.
-+      supported = false;
-       break;
-     }
-     if (k == std::nullopt) {
-@@ -260,57 +257,29 @@ StatusOr<bool> TopkRewriter::TransformToCustomCall(
-     HloSortInstruction* sort = DynCast<HloSortInstruction>(inst);
-     HloInstruction* data = sort->mutable_operand(0);
-     const PrimitiveType element_type = data->shape().element_type();
--    const Shape data_shape = data->shape();
- 
--    if (element_type != F32 && element_type != BF16) {
-+    if ((data->shape().rank() != 1 && data->shape().rank() != 2) ||
-+        (element_type != F32 && element_type != BF16)) {
-       continue;
-     }
- 
--    // Sort dimension must be the first or last dimension.
-     const int64_t sort_dim = sort->sort_dimension();
--    if (sort_dim != 0 && sort_dim != data_shape.rank() - 1) {
--      continue;
--    }
-+    const int64_t batch_dim = sort_dim == 1 ? 0 : 1;
-+    const bool has_batch = data->shape().rank() == 2;
- 
-     // Profitability check.
-     if (!is_profitable_to_convert_(sort, *k)) {
-       continue;
-     }
- 
--    HloInstruction* input = data;
--    const bool has_batch = data_shape.rank() >= 2;
--    const int64_t input_size = data_shape.dimensions(sort_dim);
--    int64_t batch_size = 1;
--    Shape topk_input_shape;
--
--    if (has_batch) {
--      // The TopK custom call expects either a 1d tensor or a 2d tensor with
--      // the last dimension being the sort dimension. An input with rank > 2
--      // is reshaped into a 2d tensor by combining non-sort dimensions into a
--      // single batch dimension. The original non-sort dimensions are
--      // restored for the outputs with another reshape after the custom call.
--      batch_size =
--          ShapeUtil::ElementsIn(data_shape) / data_shape.dimensions(sort_dim);
--      topk_input_shape =
--          ShapeUtil::MakeShape(element_type, {batch_size, input_size});
--
--      if (data_shape.rank() > 2) {
--        // Reshape to 2d.
--        input = comp->AddInstruction(HloInstruction::CreateReshape(
--            sort_dim == 0
--                ? ShapeUtil::MakeShape(element_type, {input_size, batch_size})
--                : ShapeUtil::MakeShape(element_type,
--                                       {batch_size, input_size}),
--            input));
--      }
--
--      if (sort_dim == 0) {
--        // Transpose for the custom call when sorting the first dimension.
--        input = comp->AddInstruction(
--            HloInstruction::CreateTranspose(topk_input_shape, input, {1, 0}));
--      }
--    } else {
--      topk_input_shape = data_shape;
-+    const int64_t batch_size =
-+        has_batch ? sort->operand(0)->shape().dimensions(batch_dim) : 1;
-+    const int64_t input_size = sort->operand(0)->shape().dimensions(sort_dim);
-+    HloInstruction* input = sort->mutable_operand(0);
-+    if (has_batch && sort_dim == 0) {
-+      input = comp->AddInstruction(HloInstruction::CreateTranspose(
-+          ShapeUtil::MakeShape(element_type, {batch_size, input_size}), input,
-+          {1, 0}));
-     }
- 
-     Shape topk_shape =
-@@ -331,26 +300,13 @@ StatusOr<bool> TopkRewriter::TransformToCustomCall(
-         comp->AddInstruction(HloInstruction::CreateGetTupleElement(
-             topk->shape().tuple_shapes(1), topk, 1));
- 
--    if (has_batch) {
--      if (sort_dim == 0) {
--        // Transpose back.
--        value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
--            ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
--            value_gte, {1, 0}));
--        index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
--            ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
--            {1, 0}));
--      }
--      if (data_shape.rank() > 2) {
--        // Reshape back.
--        std::vector<int64_t> shape_dim(data_shape.dimensions().begin(),
--                                       data_shape.dimensions().end());
--        shape_dim[sort_dim] = k.value();
--        value_gte = comp->AddInstruction(HloInstruction::CreateReshape(
--            ShapeUtil::MakeShape(element_type, shape_dim), value_gte));
--        index_gte = comp->AddInstruction(HloInstruction::CreateReshape(
--            ShapeUtil::MakeShape(S32, shape_dim), index_gte));
--      }
-+    if (has_batch && sort_dim == 0) {
-+      value_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
-+          ShapeUtil::MakeShape(element_type, {k.value(), batch_size}),
-+          value_gte, {1, 0}));
-+      index_gte = comp->AddInstruction(HloInstruction::CreateTranspose(
-+          ShapeUtil::MakeShape(S32, {k.value(), batch_size}), index_gte,
-+          {1, 0}));
-     }
- 
-     for (HloInstruction* user : sort->users()) {
-diff --git a/xla/service/topk_rewriter_test.cc b/xla/service/topk_rewriter_test.cc
-index 36e723737..25ce150e0 100644
---- a/xla/service/topk_rewriter_test.cc
-+++ b/xla/service/topk_rewriter_test.cc
-@@ -326,42 +326,6 @@ ENTRY cluster {
-   EXPECT_THAT(cc->custom_call_target(), "TopK");
- }
- 
--TEST_F(TopkRewriterTest, RewriteReshape) {
--  const std::string hlo_string = R"(
--HloModule module
--)" + getComparator() + R"(
--ENTRY cluster {
--  %arg_tuple.1 = f32[3,8,1234567] parameter(0)
--  %iota.4 = s32[3,8,1234567] iota(), iota_dimension=2
--  %sort.27 = (f32[3,8,1234567], s32[3,8,1234567]) sort(%arg_tuple.1, %iota.4),
--    dimensions={2}, is_stable=true, to_apply=%compare
--  %get-tuple-element.28 = f32[3, 8,1234567] get-tuple-element(%sort.27), index=0
--  %slice.29 = f32[3,8,5] slice(%get-tuple-element.28), slice={[0:3], [0:8], [0:5]}
--  %get-tuple-element.30 = s32[3,8,1234567] get-tuple-element(%sort.27), index=1
--  %slice.31 = s32[3,8,5] slice(%get-tuple-element.30), slice={[0:3], [0:8], [0:5]}
--  ROOT %tuple.32 = (f32[3,8,5], s32[3,8,5]) tuple(%slice.29, %slice.31)
--})";
--  TF_ASSERT_OK_AND_ASSIGN(auto module,
--                          ParseAndReturnVerifiedModule(hlo_string));
--  TopkRewriter rewriter(
--      [](const HloSortInstruction*, int64_t) { return true; });
--  TF_ASSERT_OK_AND_ASSIGN(bool changed, rewriter.Run(module.get()));
--  TF_ASSERT_OK(HloDCE().Run(module.get()).status());
--  EXPECT_TRUE(changed);
--  EXPECT_THAT(module->entry_computation()->root_instruction(),
--              GmockMatch(m::Tuple(
--                  m::Reshape(m::GetTupleElement(
--                      m::CustomCall(m::Reshape(m::Parameter(0))), 0)),
--                  m::Reshape(m::GetTupleElement(
--                      m::CustomCall(m::Reshape(m::Parameter(0))), 1)))));
--  const HloInstruction* cc = module->entry_computation()
--                                 ->root_instruction()
--                                 ->operand(0)
--                                 ->operand(0)
--                                 ->operand(0);
--  EXPECT_THAT(cc->custom_call_target(), "TopK");
--}
--
- TEST_F(TopkRewriterTest, RewriteNoIota) {
-   const std::string hlo_string = R"(
- HloModule module
diff --git a/setup.py b/setup.py
index a8a04c4c286a..ae3d98246672 100644
--- a/setup.py
+++ b/setup.py
@@ -72,7 +72,7 @@
 base_dir = os.path.dirname(os.path.abspath(__file__))
 
-_libtpu_version = '0.1.dev20231022'
+_libtpu_version = '0.1.dev20231125'
 _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
 
diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h
index a1223c5ef7e1..9d0b239f2129 100644
--- a/torch_xla/csrc/runtime/computation_client.h
+++ b/torch_xla/csrc/runtime/computation_client.h
@@ -314,7 +314,7 @@ class ComputationClient {
   virtual int GetNumProcesses() const = 0;
 
   using DeviceAttribute =
-      std::variant, float, bool>;
+      std::variant, float>;
 
   virtual const absl::flat_hash_map<
       std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>&
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 0bd42b3ad6a3..bf3bccd210e1 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -153,29 +153,29 @@ PjRtComputationClient::PjRtComputationClient() {
     xla::PjRtClient::KeyValuePutCallback kv_put = nullptr;
     if (distributed_client != nullptr) {
       std::string key_prefix = "gpu:";
-      kv_get = [distributed_client, key_prefix](const std::string& k,
-                                                absl::Duration timeout) {
+      kv_get = [distributed_client, key_prefix](
+                   std::string_view k,
+                   absl::Duration timeout) -> xla::StatusOr<std::string> {
         return distributed_client->BlockingKeyValueGet(
             absl::StrCat(key_prefix, k), timeout);
       };
-      kv_put = [distributed_client, key_prefix](const std::string& k,
-                                                const std::string& v) {
+      kv_put = [distributed_client, key_prefix](
+                   std::string_view k, std::string_view v) -> xla::Status {
        return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
       };
     }
     TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
                << global_process_rank << ", num_nodes=" << global_world_size;
-    client_ = std::move(xla::GetStreamExecutorGpuClient(
-        /*asynchronous=*/async,
-        /*allocator_config=*/GetGpuAllocatorConfig(),
-        /*node_id=*/global_process_rank,
-        /*num_nodes=*/global_world_size,
-        /*allowed_devices=*/allowed_devices,
-        /*platform_name=*/"gpu",
-        /*should_stage_host_to_device_transfers=*/true,
-        /*kv_get=*/kv_get,
-        /*kv_put=*/kv_put)
-                            .value());
+    xla::GpuClientOptions options;
+    options.allocator_config = GetGpuAllocatorConfig();
+    options.node_id = global_process_rank;
+    options.num_nodes = global_world_size;
+    options.allowed_devices = allowed_devices;
+    options.platform_name = "gpu";
+    options.should_stage_host_to_device_transfers = true;
+    options.kv_get = kv_get;
+    options.kv_put = kv_put;
+    client_ = std::move(xla::GetStreamExecutorGpuClient(options).value());
   } else if (device_type == "XPU") {
     TF_VLOG(1) << "Initializing PjRt XPU client...";
     XLA_CHECK_OK(