Update XLA pin, 04/19/2024 (#6944)
yeounoh authored Apr 19, 2024
1 parent 0417d4d commit 2ec7706
Showing 6 changed files with 25 additions and 21 deletions.
WORKSPACE (4 changes: 2 additions & 2 deletions)
@@ -50,9 +50,9 @@ http_archive(
         "//openxla_patches:gpu_race_condition.diff",
         "//openxla_patches:f16_abi_clang.diff",
     ],
-    strip_prefix = "xla-1acf05ef0d41181caaf0cd691aa9d453ffc41a73",
+    strip_prefix = "xla-54ca388f9ad9e8bbcb0ef823752d6b47a99d0b5f",
     urls = [
-        "https://github.com/openxla/xla/archive/1acf05ef0d41181caaf0cd691aa9d453ffc41a73.tar.gz",
+        "https://github.com/openxla/xla/archive/54ca388f9ad9e8bbcb0ef823752d6b47a99d0b5f.tar.gz",
     ],
 )

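Note: Bazel's http_archive strips strip_prefix from every path in the tarball, and GitHub names the archive's top-level directory xla-<commit>. A pin bump therefore has to update strip_prefix and the archive URL in lockstep to the same commit, here 54ca388f9ad9e8bbcb0ef823752d6b47a99d0b5f.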
setup.py (2 changes: 1 addition & 1 deletion)
@@ -64,7 +64,7 @@

 base_dir = os.path.dirname(os.path.abspath(__file__))

-_date = '20240409'
+_date = '20240418'
 _libtpu_version = f'0.1.dev{_date}'
 _libtpu_storage_path = f'https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-{_libtpu_version}-py3-none-any.whl'
 _jax_version = f'0.4.27.dev{_date}'
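Substituting the new date into the f-strings above, the pinned nightlies resolve to _libtpu_version = '0.1.dev20240418' and _jax_version = '0.4.27.dev20240418', with the libtpu wheel URL following suit (derived mechanically from the code shown, not checked against the storage bucket).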
torch_xla/csrc/runtime/BUILD (2 changes: 1 addition & 1 deletion)
@@ -237,7 +237,7 @@ cc_library(
     deps = [
         ":debug_macros",
         ":sys_util",
-        "@tsl//tsl/distributed_runtime/preemption:preemption_sync_manager",
+        "@xla//xla/tsl/distributed_runtime/preemption:preemption_sync_manager",
         "@xla//xla/pjrt/distributed",
     ],
 )
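The label changes because TSL now lives inside the XLA tree under xla/tsl, so the separate @tsl workspace gives way to a target under @xla; the matching include-path update lands in xla_coordinator.h below.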
torch_xla/csrc/runtime/ifrt_computation_client.cc (28 changes: 16 additions & 12 deletions)
@@ -58,18 +58,6 @@ torch::lazy::hash_t hash_comp_env(
     xla::ifrt::Client* client,
     std::vector<xla::ifrt::Device*>& ordered_devices) {
   torch::lazy::hash_t hash = hash::HashXlaEnvVars();
-  auto topology_desc = client->GetTopologyForDevices(ordered_devices);
-  if (topology_desc.ok()) {
-    // Some backends support a topology description which provides a better
-    // view of the specific compilation environment.
-    auto serialized = topology_desc.value()->Serialize();
-    if (serialized.ok()) {
-      return torch::lazy::HashCombine(
-          hash,
-          torch::lazy::DataHash(serialized->data(), serialized->length()));
-    }
-    // If serialization fails, fallthrough to the manual approach.
-  }
   std::string platform_name(client->platform_name());
   std::string platform_version(client->platform_version());
   hash = torch::lazy::HashCombine(
@@ -78,10 +66,26 @@
   hash = torch::lazy::HashCombine(
       hash, torch::lazy::StringHash(platform_version.c_str()));
   // Include global devices in the hash, ensuring order is consistent.
+  xla::ifrt::DeviceList::Devices ifrt_devices;
   for (auto& device : ordered_devices) {
     std::string device_str(device->ToString());
     hash = torch::lazy::HashCombine(
         hash, torch::lazy::StringHash(device_str.c_str()));
+    ifrt_devices.push_back(device);
   }
+
+  xla::ifrt::DeviceList device_list(std::move(ifrt_devices));
+  auto topology_desc = client->GetTopologyForDevices(device_list);
+  if (topology_desc.ok()) {
+    // Some backends support a topology description which provides a better
+    // view of the specific compilation environment.
+    auto serialized = topology_desc.value()->Serialize();
+    if (serialized.ok()) {
+      return torch::lazy::HashCombine(
+          hash,
+          torch::lazy::DataHash(serialized->data(), serialized->length()));
+    }
+    // If serialization fails, fallthrough to the manual approach.
+  }
   return hash;
 }
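This hunk tracks an IFRT signature change: GetTopologyForDevices now expects an xla::ifrt::DeviceList rather than the raw vector of Device pointers, so the devices are gathered into a DeviceList while they are being hashed. A minimal sketch of the new calling convention, mirroring the diff (illustrative only, assuming the IFRT headers at the pinned revision):

  // Collect the raw device pointers into the container type DeviceList owns...
  xla::ifrt::DeviceList::Devices devices;
  for (xla::ifrt::Device* d : ordered_devices) {
    devices.push_back(d);
  }
  // ...then wrap them; GetTopologyForDevices takes the DeviceList and still
  // returns a StatusOr, so the .ok() / .value() handling above is unchanged.
  xla::ifrt::DeviceList device_list(std::move(devices));
  auto topology_desc = client->GetTopologyForDevices(device_list);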
torch_xla/csrc/runtime/pjrt_computation_client.cc (8 changes: 4 additions & 4 deletions)
@@ -463,7 +463,7 @@ std::vector<xla::Literal> PjRtComputationClient::TransferFromDevice(
   metrics::TimedSection timed(TransferFromDeviceMetric());
   tsl::profiler::TraceMe activity("PjRtComputationClient::TransferFromDevice",
                                   tsl::profiler::TraceMeLevel::kInfo);
-  std::vector<xla::PjRtFuture<absl::Status>> futures;
+  std::vector<xla::PjRtFuture<>> futures;
   futures.reserve(handles.size());
   std::vector<xla::Literal> literals;
   literals.reserve(handles.size());
@@ -679,7 +679,7 @@ PjRtComputationClient::ExecuteComputation(
   TF_VLOG(5) << "ExecuteComputation acquiring PJRT device lock for " << device
              << " Done";

-  std::optional<xla::PjRtFuture<xla::Status>> returned_future;
+  std::optional<xla::PjRtFuture<>> returned_future;
   std::vector<std::unique_ptr<xla::PjRtBuffer>> results =
       pjrt_computation.executable
           ->ExecuteSharded(buffers, pjrt_device, execute_options,
@@ -779,8 +779,8 @@ PjRtComputationClient::ExecuteReplicated(
   TF_VLOG(5) << "ExecuteReplicated acquiring PJRT device lock for "
             << spmd_device_str << " Done";

-  std::optional<std::vector<xla::PjRtFuture<xla::Status>>> returned_futures =
-      std::vector<xla::PjRtFuture<xla::Status>>();
+  std::optional<std::vector<xla::PjRtFuture<>>> returned_futures =
+      std::vector<xla::PjRtFuture<>>();
   std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> results;
   {
     tsl::profiler::TraceMe activity(
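All three hunks follow one upstream API change: the status-only future is now spelled xla::PjRtFuture<> instead of PjRtFuture<absl::Status> / PjRtFuture<xla::Status>. A minimal sketch of consuming such futures (AwaitAll is a hypothetical helper, not part of this diff; assumes xla/pjrt/pjrt_future.h at the pinned revision, where PjRtFuture<>::Await() returns absl::Status):

  #include <vector>
  #include "absl/status/status.h"
  #include "xla/pjrt/pjrt_future.h"

  // Block on a batch of status-only futures and surface the first error,
  // in the style of the TransferFromDevice code above.
  absl::Status AwaitAll(std::vector<xla::PjRtFuture<>>& futures) {
    for (auto& future : futures) {
      absl::Status status = future.Await();
      if (!status.ok()) return status;
    }
    return absl::OkStatus();
  }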
torch_xla/csrc/runtime/xla_coordinator.h (2 changes: 1 addition & 1 deletion)
@@ -3,8 +3,8 @@

 #include <memory>

-#include "tsl/distributed_runtime/preemption/preemption_sync_manager.h"
 #include "xla/pjrt/distributed/distributed.h"
+#include "xla/tsl/distributed_runtime/preemption/preemption_sync_manager.h"

 namespace torch_xla {
 namespace runtime {
