From 023e2c83dcf20b973ad4ec60b27469fcada02af5 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Fri, 26 Apr 2024 10:28:21 -0700 Subject: [PATCH] Revert "Update Openxla-pin to 04/24" (#6980) --- WORKSPACE | 4 ++-- setup.py | 2 +- torch_xla/csrc/runtime/ifrt_computation_client.cc | 13 ++++++------- torch_xla/csrc/runtime/ifrt_computation_client.h | 2 +- 4 files changed, 10 insertions(+), 11 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 9fe770bedff..9c6963dae65 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -50,9 +50,9 @@ http_archive( "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:f16_abi_clang.diff", ], - strip_prefix = "xla-fe08041b23d8baa0d00967913a1d6e8a0c348df3", + strip_prefix = "xla-54ca388f9ad9e8bbcb0ef823752d6b47a99d0b5f", urls = [ - "https://github.com/openxla/xla/archive/fe08041b23d8baa0d00967913a1d6e8a0c348df3.tar.gz", + "https://github.com/openxla/xla/archive/54ca388f9ad9e8bbcb0ef823752d6b47a99d0b5f.tar.gz", ], ) diff --git a/setup.py b/setup.py index 31f4eaf679c..dbe47007aff 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ base_dir = os.path.dirname(os.path.abspath(__file__)) -_date = '20240425' +_date = '20240418' _libtpu_version = f'0.1.dev{_date}' _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl' _jax_version = f'0.4.27.dev{_date}' diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.cc b/torch_xla/csrc/runtime/ifrt_computation_client.cc index c48cf1555ff..20ee9b0bfa6 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.cc +++ b/torch_xla/csrc/runtime/ifrt_computation_client.cc @@ -96,7 +96,7 @@ std::string IfrtComputationClient::IfrtDeviceToString( xla::ifrt::Device* const device) const { std::string platform = absl::AsciiStrToUpper(device->client()->platform_name()); - int ordinal = global_ordinals_.at(device->Id().value()); + int ordinal = global_ordinals_.at(device->id()); std::string str = absl::StrFormat("%s:%d", platform, ordinal); return str; } @@ -124,12 +124,11 @@ IfrtComputationClient::IfrtComputationClient() { // a device's global ordinal separately from its device ID. Order the // devices by increasing ID to assign global ordinals. std::vector ordered_devices(client_->device_count()); - std::partial_sort_copy( - client_->devices().begin(), client_->devices().end(), - ordered_devices.begin(), ordered_devices.end(), - [](auto& a, auto& b) { return a->Id().value() < b->Id().value(); }); + std::partial_sort_copy(client_->devices().begin(), client_->devices().end(), + ordered_devices.begin(), ordered_devices.end(), + [](auto& a, auto& b) { return a->id() < b->id(); }); for (auto* device : ordered_devices) { - global_ordinals_[device->Id().value()] = global_ordinals_.size(); + global_ordinals_[device->id()] = global_ordinals_.size(); std::string device_str = IfrtDeviceToString(device); string_to_device_.emplace(device_str, device); } @@ -616,7 +615,7 @@ std::vector IfrtComputationClient::GetAllDevices() const { int IfrtComputationClient::GetNumProcesses() const { int max_process_index = client_->process_index(); for (auto* device : client_->devices()) { - max_process_index = std::max(max_process_index, device->ProcessIndex()); + max_process_index = std::max(max_process_index, device->process_index()); } return max_process_index + 1; diff --git a/torch_xla/csrc/runtime/ifrt_computation_client.h b/torch_xla/csrc/runtime/ifrt_computation_client.h index 38d0de97204..d6d914ad8da 100644 --- a/torch_xla/csrc/runtime/ifrt_computation_client.h +++ b/torch_xla/csrc/runtime/ifrt_computation_client.h @@ -134,7 +134,7 @@ class IfrtComputationClient : public ComputationClient { // global_ordinals_ tracks a map from PjRtDeviceId to the device's // dense global ordinal. std::unordered_map global_ordinals_; - std::unordered_map string_to_device_; + std::unordered_map string_to_device_; std::shared_ptr> replication_devices_; OperationManager operation_manager_; tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(