diff --git a/WORKSPACE b/WORKSPACE index a4e4027a67ef..073e270d76a6 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -39,15 +39,12 @@ http_archive( patch_tool = "patch", patches = [ "//openxla_patches:cache_urls.diff", - "//openxla_patches:f16_abi_clang.diff", - "//openxla_patches:gpu_race_condition.diff", "//openxla_patches:constexpr_return.diff", - "//openxla_patches:pjrt_api_tsl_logging.diff", - "//openxla_patches:pjrt_c_api_dynamic_dimensions.diff", + "//openxla_patches:gpu_build_file.diff", ], - strip_prefix = "xla-97a5f819faf9ff793b7ba68ff1f31f74f9459c18", + strip_prefix = "xla-7a19856d74569fd1f765cd03bdee84e3b1fdc579", urls = [ - "https://github.com/openxla/xla/archive/97a5f819faf9ff793b7ba68ff1f31f74f9459c18.tar.gz", + "https://github.com/openxla/xla/archive/7a19856d74569fd1f765cd03bdee84e3b1fdc579.tar.gz", ], ) diff --git a/openxla_patches/f16_abi_clang.diff b/openxla_patches/f16_abi_clang.diff deleted file mode 100644 index 24cc8e5b74d5..000000000000 --- a/openxla_patches/f16_abi_clang.diff +++ /dev/null @@ -1,19 +0,0 @@ -upstream CI will fail without this -diff --git a/xla/service/cpu/runtime_fp16.h b/xla/service/cpu/runtime_fp16.h -index 3f7af5197..ce4491c5d 100644 ---- a/xla/service/cpu/runtime_fp16.h -+++ b/xla/service/cpu/runtime_fp16.h -@@ -18,12 +18,7 @@ limitations under the License. - - #include - --// _Float16 always gets us the correct ABI type, so use that if available. --// AArch64 GCC defines __FLT16_MANT_DIG__ even when _Float16 is not available. --#if defined(__FLT16_MANT_DIG__) && \ -- (defined(__clang__) || !(defined(__GNUC__) && defined(__aarch64__))) --using XlaF16ABIType = _Float16; --#elif defined(__x86_64__) -+#if defined(__x86_64__) - // Older versions of Clang don't have _Float16. Since both float and _Float16 - // are passed in the same register we can use the wider type and careful casting - // to conform to x86_64 psABI. This only works with the assumption that we're \ No newline at end of file diff --git a/openxla_patches/gpu_build_file.diff b/openxla_patches/gpu_build_file.diff new file mode 100644 index 000000000000..0be682c2974a --- /dev/null +++ b/openxla_patches/gpu_build_file.diff @@ -0,0 +1,25 @@ +diff --git a/xla/pjrt/gpu/BUILD b/xla/pjrt/gpu/BUILD +index 9ad1fca31..8ea07ad0a 100644 +--- a/xla/pjrt/gpu/BUILD ++++ b/xla/pjrt/gpu/BUILD +@@ -237,17 +237,17 @@ cc_library( + "@com_google_absl//absl/status", + "@tsl//tsl/platform:casts", + "@tsl//tsl/platform:errors", +- ] + if_cuda([ ++ ] + if_cuda_or_rocm([ ++ "//xla/service/gpu:gpu_compiler", ++ ]) + if_cuda([ + ":nccl_id_store_cuda", + "@local_config_cuda//cuda:cuda_headers", + "//xla/stream_executor/cuda:cuda_activation_header", + "//xla/stream_executor/gpu:gpu_cudamallocasync_allocator", +- "//xla/service/gpu:gpu_compiler", + "//xla/service/gpu:nvptx_compiler_impl", + ]) + if_rocm([ + ":nccl_id_store_rocm", + "@local_config_rocm//rocm:rocm_headers", +- "//xla/service/gpu:gpu_compiler", + "//xla/service/gpu:amdgpu_compiler_impl", + ]), + alwayslink = True, diff --git a/openxla_patches/gpu_race_condition.diff b/openxla_patches/gpu_race_condition.diff deleted file mode 100644 index dfdc3aa74608..000000000000 --- a/openxla_patches/gpu_race_condition.diff +++ /dev/null @@ -1,14 +0,0 @@ -diff --git a/xla/service/gpu/gpu_executable.cc b/xla/service/gpu/gpu_executable.cc -index 242961dd1..787275868 100644 ---- a/xla/service/gpu/gpu_executable.cc -+++ b/xla/service/gpu/gpu_executable.cc -@@ -563,8 +563,7 @@ StatusOr GpuExecutable::ExecuteAsyncOnStreamImpl( - } - - // Force synchronous execution if the allocator requires it. -- const bool block_host_until_done = -- !memory_allocator->AllowsAsynchronousDeallocation(); -+ const bool block_host_until_done = true; - - - // Lock the GPU with a shared lock so that we don't interfere with autotuning \ No newline at end of file diff --git a/openxla_patches/pjrt_api_tsl_logging.diff b/openxla_patches/pjrt_api_tsl_logging.diff deleted file mode 100644 index 296bed91ad68..000000000000 --- a/openxla_patches/pjrt_api_tsl_logging.diff +++ /dev/null @@ -1,21 +0,0 @@ -# Fixes log spam when loading libtpu. We should fix this upstream. -diff --git a/xla/pjrt/pjrt_api.cc b/xla/pjrt/pjrt_api.cc -index 132cfaff0..887e842e0 100644 ---- a/xla/pjrt/pjrt_api.cc -+++ b/xla/pjrt/pjrt_api.cc -@@ -17,7 +17,6 @@ limitations under the License. - - #include - --#include "absl/log/log.h" - #include "absl/status/status.h" - #include "absl/strings/str_cat.h" - #include "xla/pjrt/c/pjrt_c_api.h" -@@ -33,6 +32,7 @@ limitations under the License. - #include "xla/pjrt/c/pjrt_c_api_helpers.h" - #include "xla/status.h" - #include "xla/statusor.h" -+#include "tsl/platform/logging.h" - #include "tsl/platform/errors.h" - - namespace pjrt { diff --git a/openxla_patches/pjrt_c_api_dynamic_dimensions.diff b/openxla_patches/pjrt_c_api_dynamic_dimensions.diff deleted file mode 100644 index ee1ec00eced5..000000000000 --- a/openxla_patches/pjrt_c_api_dynamic_dimensions.diff +++ /dev/null @@ -1,76 +0,0 @@ -# Partial backport of 6308dba2903e78961ac4122f361bc91b09f36891. Remove in next -# pin update. -diff --git a/xla/pjrt/pjrt_c_api_client.cc b/xla/pjrt/pjrt_c_api_client.cc -index ef0b6686c..c0341e81e 100644 ---- a/xla/pjrt/pjrt_c_api_client.cc -+++ b/xla/pjrt/pjrt_c_api_client.cc -@@ -1584,6 +1584,34 @@ bool PjRtCApiBuffer::has_dynamic_dimensions() const { - return args.num_dynamic_dims > 0; - } - -+absl::Span PjRtCApiBuffer::is_dynamic_dimension() const { -+ { -+ absl::MutexLock lock(&mu_); -+ if (!is_dynamic_dimension_.has_value()) { -+ absl::InlinedVector& is_dynamic_dimension_value = -+ is_dynamic_dimension_.emplace(); -+ is_dynamic_dimension_value.assign(dimensions().size(), false); -+ -+ PJRT_Buffer_DynamicDimensionIndices_Args args; -+ args.struct_size = PJRT_Buffer_DynamicDimensionIndices_Args_STRUCT_SIZE; -+ args.priv = nullptr; -+ args.buffer = buffer_.get(); -+ const PJRT_Api* api = pjrt_c_api(); -+ std::unique_ptr error( -+ api->PJRT_Buffer_DynamicDimensionIndices(&args), -+ pjrt::MakeErrorDeleter(api)); -+ if (error && pjrt::GetErrorCode(error.get(), api) == -+ PJRT_Error_Code_UNIMPLEMENTED) { -+ return *is_dynamic_dimension_; -+ } -+ for (int i = 0; i < args.num_dynamic_dims; ++i) { -+ is_dynamic_dimension_value[args.dynamic_dim_indices[i]] = true; -+ } -+ } -+ } -+ return *is_dynamic_dimension_; -+} -+ - StatusOr> PjRtCApiBuffer::logical_dimensions() { - PJRT_Buffer_UnpaddedDimensions_Args args; - args.struct_size = PJRT_Buffer_UnpaddedDimensions_Args_STRUCT_SIZE; -diff --git a/xla/pjrt/pjrt_c_api_client.h b/xla/pjrt/pjrt_c_api_client.h -index 9c460f246..279608e60 100644 ---- a/xla/pjrt/pjrt_c_api_client.h -+++ b/xla/pjrt/pjrt_c_api_client.h -@@ -27,6 +27,7 @@ limitations under the License. - #include - - #include "absl/container/flat_hash_map.h" -+#include "absl/container/inlined_vector.h" - #include "absl/log/check.h" - #include "absl/log/log.h" - #include "absl/strings/string_view.h" -@@ -369,11 +370,7 @@ class PjRtCApiBuffer : public PjRtBuffer { - - bool has_dynamic_dimensions() const override; - -- absl::Span is_dynamic_dimension() const override { -- LOG(FATAL) << "PjRtCApiBuffer::is_dynamic_dimension() not implemented. " -- << "Considering using has_dynamic_dimensions() or " -- "logical_dimensions() if applicable."; -- } -+ absl::Span is_dynamic_dimension() const override; - - StatusOr> logical_dimensions() override; - -@@ -455,6 +452,9 @@ class PjRtCApiBuffer : public PjRtBuffer { - std::shared_ptr::Promise> readiness_promise_; - // Set and cached the first time layout() is called. - mutable std::optional layout_; -+ // Set and cached the first time is_dynamic_dimension() is called. -+ mutable std::optional> -+ is_dynamic_dimension_; - // Used to synchronize concurrent setting of cached values. - mutable absl::Mutex mu_; - }; diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index a09768c6a9ef..1e610be7959b 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -329,7 +329,7 @@ class ComputationClient { virtual int GetNumProcesses() const = 0; using DeviceAttribute = - std::variant, float>; + std::variant, float, bool>; virtual const absl::flat_hash_map< std::string, torch_xla::runtime::ComputationClient::DeviceAttribute>& diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 4f175be7d717..30ebd247e935 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -109,8 +109,10 @@ PjRtComputationClient::PjRtComputationClient() { client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value()); } else if (device_type == "TPU" || device_type == "TPU_C_API") { TF_VLOG(1) << "Initializing TFRT TPU client..."; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin( - "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so"))); + XLA_CHECK_OK( + pjrt::LoadPjrtPlugin( + "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so")) + .status()); tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); XLA_CHECK(tpu_status.ok()); client_ = std::move(xla::GetCApiClient("TPU").value()); @@ -154,15 +156,18 @@ PjRtComputationClient::PjRtComputationClient() { .value()); } else if (device_type == "XPU") { TF_VLOG(1) << "Initializing PjRt XPU client..."; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin( - "xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so"))); + XLA_CHECK_OK( + pjrt::LoadPjrtPlugin( + "xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so")) + .status()); client_ = std::move(xla::GetCApiClient("XPU").value()); } else if (device_type == "NEURON") { TF_VLOG(1) << "Initializing PjRt NEURON client..."; - XLA_CHECK_OK(pjrt::LoadPjrtPlugin( - "NEURON", sys_util::GetEnvString(env::kEnvNeuronLibraryPath, - "libneuronpjrt.so"))); + XLA_CHECK_OK(pjrt::LoadPjrtPlugin("NEURON", sys_util::GetEnvString( + env::kEnvNeuronLibraryPath, + "libneuronpjrt.so")) + .status()); client_ = std::move(xla::GetCApiClient("NEURON").value()); } else { XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice,