Skip to content

Commit

Permalink
Revert "Update Openxla-pin to 04/24" (#6980)
Browse files Browse the repository at this point in the history
  • Loading branch information
ManfeiBai authored Apr 26, 2024
1 parent 2bf59e0 commit 023e2c8
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 11 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ http_archive(
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:f16_abi_clang.diff",
],
strip_prefix = "xla-fe08041b23d8baa0d00967913a1d6e8a0c348df3",
strip_prefix = "xla-54ca388f9ad9e8bbcb0ef823752d6b47a99d0b5f",
urls = [
"https://github.com/openxla/xla/archive/fe08041b23d8baa0d00967913a1d6e8a0c348df3.tar.gz",
"https://github.com/openxla/xla/archive/54ca388f9ad9e8bbcb0ef823752d6b47a99d0b5f.tar.gz",
],
)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

base_dir = os.path.dirname(os.path.abspath(__file__))

_date = '20240425'
_date = '20240418'
_libtpu_version = f'0.1.dev{_date}'
_libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
_jax_version = f'0.4.27.dev{_date}'
Expand Down
13 changes: 6 additions & 7 deletions torch_xla/csrc/runtime/ifrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ std::string IfrtComputationClient::IfrtDeviceToString(
xla::ifrt::Device* const device) const {
std::string platform =
absl::AsciiStrToUpper(device->client()->platform_name());
int ordinal = global_ordinals_.at(device->Id().value());
int ordinal = global_ordinals_.at(device->id());
std::string str = absl::StrFormat("%s:%d", platform, ordinal);
return str;
}
Expand Down Expand Up @@ -124,12 +124,11 @@ IfrtComputationClient::IfrtComputationClient() {
// a device's global ordinal separately from its device ID. Order the
// devices by increasing ID to assign global ordinals.
std::vector<xla::ifrt::Device*> ordered_devices(client_->device_count());
std::partial_sort_copy(
client_->devices().begin(), client_->devices().end(),
ordered_devices.begin(), ordered_devices.end(),
[](auto& a, auto& b) { return a->Id().value() < b->Id().value(); });
std::partial_sort_copy(client_->devices().begin(), client_->devices().end(),
ordered_devices.begin(), ordered_devices.end(),
[](auto& a, auto& b) { return a->id() < b->id(); });
for (auto* device : ordered_devices) {
global_ordinals_[device->Id().value()] = global_ordinals_.size();
global_ordinals_[device->id()] = global_ordinals_.size();
std::string device_str = IfrtDeviceToString(device);
string_to_device_.emplace(device_str, device);
}
Expand Down Expand Up @@ -616,7 +615,7 @@ std::vector<std::string> IfrtComputationClient::GetAllDevices() const {
int IfrtComputationClient::GetNumProcesses() const {
int max_process_index = client_->process_index();
for (auto* device : client_->devices()) {
max_process_index = std::max(max_process_index, device->ProcessIndex());
max_process_index = std::max(max_process_index, device->process_index());
}

return max_process_index + 1;
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/ifrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class IfrtComputationClient : public ComputationClient {
// global_ordinals_ tracks a map from PjRtDeviceId to the device's
// dense global ordinal.
std::unordered_map<int, int> global_ordinals_;
std::unordered_map<std::string, xla::ifrt::Device* const> string_to_device_;
std::unordered_map<std::string, xla::PjRtDevice* const> string_to_device_;
std::shared_ptr<std::vector<std::string>> replication_devices_;
OperationManager operation_manager_;
tsl::thread::ThreadPool pool_ = tsl::thread::ThreadPool(
Expand Down

0 comments on commit 023e2c8

Please sign in to comment.