Skip to content

Commit

Permalink
Open XLA pin update
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi committed Oct 4, 2023
1 parent fba326e commit 63c5a13
Show file tree
Hide file tree
Showing 8 changed files with 41 additions and 144 deletions.
9 changes: 3 additions & 6 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,12 @@ http_archive(
patch_tool = "patch",
patches = [
"//openxla_patches:cache_urls.diff",
"//openxla_patches:f16_abi_clang.diff",
"//openxla_patches:gpu_race_condition.diff",
"//openxla_patches:constexpr_return.diff",
"//openxla_patches:pjrt_api_tsl_logging.diff",
"//openxla_patches:pjrt_c_api_dynamic_dimensions.diff",
"//openxla_patches:gpu_build_file.diff",
],
strip_prefix = "xla-97a5f819faf9ff793b7ba68ff1f31f74f9459c18",
strip_prefix = "xla-7a19856d74569fd1f765cd03bdee84e3b1fdc579",
urls = [
"https://github.com/openxla/xla/archive/97a5f819faf9ff793b7ba68ff1f31f74f9459c18.tar.gz",
"https://github.com/openxla/xla/archive/7a19856d74569fd1f765cd03bdee84e3b1fdc579.tar.gz",
],
)

Expand Down
19 changes: 0 additions & 19 deletions openxla_patches/f16_abi_clang.diff

This file was deleted.

25 changes: 25 additions & 0 deletions openxla_patches/gpu_build_file.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
diff --git a/xla/pjrt/gpu/BUILD b/xla/pjrt/gpu/BUILD
index 9ad1fca31..8ea07ad0a 100644
--- a/xla/pjrt/gpu/BUILD
+++ b/xla/pjrt/gpu/BUILD
@@ -237,17 +237,17 @@ cc_library(
"@com_google_absl//absl/status",
"@tsl//tsl/platform:casts",
"@tsl//tsl/platform:errors",
- ] + if_cuda([
+ ] + if_cuda_or_rocm([
+ "//xla/service/gpu:gpu_compiler",
+ ]) + if_cuda([
":nccl_id_store_cuda",
"@local_config_cuda//cuda:cuda_headers",
"//xla/stream_executor/cuda:cuda_activation_header",
"//xla/stream_executor/gpu:gpu_cudamallocasync_allocator",
- "//xla/service/gpu:gpu_compiler",
"//xla/service/gpu:nvptx_compiler_impl",
]) + if_rocm([
":nccl_id_store_rocm",
"@local_config_rocm//rocm:rocm_headers",
- "//xla/service/gpu:gpu_compiler",
"//xla/service/gpu:amdgpu_compiler_impl",
]),
alwayslink = True,
14 changes: 0 additions & 14 deletions openxla_patches/gpu_race_condition.diff

This file was deleted.

21 changes: 0 additions & 21 deletions openxla_patches/pjrt_api_tsl_logging.diff

This file was deleted.

76 changes: 0 additions & 76 deletions openxla_patches/pjrt_c_api_dynamic_dimensions.diff

This file was deleted.

2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ class ComputationClient {
virtual int GetNumProcesses() const = 0;

using DeviceAttribute =
std::variant<std::string, int64_t, std::vector<int64_t>, float>;
std::variant<std::string, int64_t, std::vector<int64_t>, float, bool>;

virtual const absl::flat_hash_map<
std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>&
Expand Down
19 changes: 12 additions & 7 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ PjRtComputationClient::PjRtComputationClient() {
client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value());
} else if (device_type == "TPU" || device_type == "TPU_C_API") {
TF_VLOG(1) << "Initializing TFRT TPU client...";
XLA_CHECK_OK(pjrt::LoadPjrtPlugin(
"tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so")));
XLA_CHECK_OK(
pjrt::LoadPjrtPlugin(
"tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so"))
.status());
tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
XLA_CHECK(tpu_status.ok());
client_ = std::move(xla::GetCApiClient("TPU").value());
Expand Down Expand Up @@ -154,15 +156,18 @@ PjRtComputationClient::PjRtComputationClient() {
.value());
} else if (device_type == "XPU") {
TF_VLOG(1) << "Initializing PjRt XPU client...";
XLA_CHECK_OK(pjrt::LoadPjrtPlugin(
"xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so")));
XLA_CHECK_OK(
pjrt::LoadPjrtPlugin(
"xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so"))
.status());
client_ = std::move(xla::GetCApiClient("XPU").value());

} else if (device_type == "NEURON") {
TF_VLOG(1) << "Initializing PjRt NEURON client...";
XLA_CHECK_OK(pjrt::LoadPjrtPlugin(
"NEURON", sys_util::GetEnvString(env::kEnvNeuronLibraryPath,
"libneuronpjrt.so")));
XLA_CHECK_OK(pjrt::LoadPjrtPlugin("NEURON", sys_util::GetEnvString(
env::kEnvNeuronLibraryPath,
"libneuronpjrt.so"))
.status());
client_ = std::move(xla::GetCApiClient("NEURON").value());
} else {
XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice,
Expand Down

0 comments on commit 63c5a13

Please sign in to comment.