From 88a549554f1686e91525bdb2a302998f920b6cd8 Mon Sep 17 00:00:00 2001 From: Jiewen Tan Date: Wed, 25 Oct 2023 14:23:05 -0700 Subject: [PATCH] Revert "Don't set $TPU_LIBRARY_PATH during import (#5698)" (#5731) This reverts commit 146f2a0140c5217a31d55cfdebb1efb6bc05b41a. --- torch_xla/__init__.py | 9 +++------ torch_xla/csrc/runtime/env_vars.cc | 1 - torch_xla/csrc/runtime/env_vars.h | 1 - torch_xla/csrc/runtime/pjrt_computation_client.cc | 9 ++++----- 4 files changed, 7 insertions(+), 13 deletions(-) diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index ce1787d0fe3a..eeaa2aaba0c9 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -72,16 +72,13 @@ def _aws_ec2_inf_trn_init(): def _setup_tpu_vm_library_path() -> bool: - """Returns true if $TPU_LIBRARY_PATH is set or can be inferred. + """Returns true if $TPU_LIBRARY is set or can be inferred. We load libtpu.so in the following order of precedence: 1. User-set $TPU_LIBRARY_PATH 2. libtpu.so included in torch_xla/lib 3. libtpu-nightly pip package - - Sets $PTXLA_TPU_LIBRARY_PATH if path is inferred by us to prevent conflicts - with other frameworks. This env var will be removed in a future version. """ if 'TPU_LIBRARY_PATH' in os.environ: return True @@ -90,12 +87,12 @@ def _setup_tpu_vm_library_path() -> bool: bundled_libtpu_path = os.path.join(module_path, 'lib/libtpu.so') if os.path.isfile(bundled_libtpu_path) and not os.getenv('TPU_LIBRARY_PATH'): logger.info('Using bundled libtpu.so (%s)', bundled_libtpu_path) - os.environ['PTXLA_TPU_LIBRARY_PATH'] = bundled_libtpu_path + os.environ['TPU_LIBRARY_PATH'] = bundled_libtpu_path return True try: import libtpu - os.environ['PTXLA_TPU_LIBRARY_PATH'] = libtpu.get_library_path() + libtpu.configure_library_path() return True except ImportError: return False diff --git a/torch_xla/csrc/runtime/env_vars.cc b/torch_xla/csrc/runtime/env_vars.cc index 00ffb1f2a25f..42040a9cca5f 100644 --- a/torch_xla/csrc/runtime/env_vars.cc +++ b/torch_xla/csrc/runtime/env_vars.cc @@ -14,7 +14,6 @@ const char* const kEnvPjRtTpuMaxInflightComputations = const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT"; const char* const kEnvPjrtAsyncGpuClient = "PJRT_GPU_ASYNC_CLIENT"; const char* const kEnvTpuLibraryPath = "TPU_LIBRARY_PATH"; -const char* const kEnvInferredTpuLibraryPath = "PTXLA_TPU_LIBRARY_PATH"; const char* const kEnvXpuLibraryPath = "XPU_LIBRARY_PATH"; const char* const kEnvNeuronLibraryPath = "NEURON_LIBRARY_PATH"; const char* const kEnvPjrtDistServiceAddr = "PJRT_DIST_SERVICE_ADDR"; diff --git a/torch_xla/csrc/runtime/env_vars.h b/torch_xla/csrc/runtime/env_vars.h index 72849003765b..e54ba8f72cd8 100644 --- a/torch_xla/csrc/runtime/env_vars.h +++ b/torch_xla/csrc/runtime/env_vars.h @@ -24,7 +24,6 @@ extern const char* const kEnvPjRtTpuMaxInflightComputations; extern const char* const kEnvPjrtAsyncCpuClient; extern const char* const kEnvPjrtAsyncGpuClient; extern const char* const kEnvTpuLibraryPath; -extern const char* const kEnvInferredTpuLibraryPath; extern const char* const kEnvXpuLibraryPath; extern const char* const kEnvNeuronLibraryPath; extern const char* const kEnvPjrtDistServiceAddr; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index f5c7b724aace..2ae0768856b7 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -94,11 +94,10 @@ PjRtComputationClient::PjRtComputationClient() { client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value()); } else if (device_type == "TPU" || device_type == "TPU_C_API") { TF_VLOG(1) << "Initializing TFRT TPU client..."; - // Prefer $TPU_LIBRARY_PATH if set - auto tpu_library_path = sys_util::GetEnvString( - env::kEnvTpuLibraryPath, - sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so")); - XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status()); + XLA_CHECK_OK( + pjrt::LoadPjrtPlugin( + "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so")) + .status()); tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); XLA_CHECK(tpu_status.ok()); client_ = std::move(xla::GetCApiClient("TPU").value());