From a565ef9ee4f476484b0ca22fc302fda7ae57021f Mon Sep 17 00:00:00 2001 From: Will Cromar Date: Tue, 14 Nov 2023 15:44:52 -0800 Subject: [PATCH] Don't set $TPU_LIBRARY_PATH during import This reverts commit 4baef3c661a66f09d6fe8af91c02884e6ef3c314. --- torch_xla/__init__.py | 9 ++++++--- torch_xla/csrc/runtime/env_vars.cc | 1 + torch_xla/csrc/runtime/env_vars.h | 1 + torch_xla/csrc/runtime/pjrt_computation_client.cc | 9 +++++---- 4 files changed, 13 insertions(+), 7 deletions(-) diff --git a/torch_xla/__init__.py b/torch_xla/__init__.py index 2d522687048..7191b5d5bb9 100644 --- a/torch_xla/__init__.py +++ b/torch_xla/__init__.py @@ -87,13 +87,16 @@ def _aws_ec2_inf_trn_init(): def _setup_tpu_vm_library_path() -> bool: - """Returns true if $TPU_LIBRARY is set or can be inferred. + """Returns true if $TPU_LIBRARY_PATH is set or can be inferred. We load libtpu.so in the following order of precedence: 1. User-set $TPU_LIBRARY_PATH 2. libtpu.so included in torch_xla/lib 3. libtpu-nightly pip package + + Sets $PTXLA_TPU_LIBRARY_PATH if path is inferred by us to prevent conflicts + with other frameworks. This env var will be removed in a future version. """ if 'TPU_LIBRARY_PATH' in os.environ: return True @@ -102,12 +105,12 @@ def _setup_tpu_vm_library_path() -> bool: bundled_libtpu_path = os.path.join(module_path, 'lib/libtpu.so') if os.path.isfile(bundled_libtpu_path) and not os.getenv('TPU_LIBRARY_PATH'): logger.info('Using bundled libtpu.so (%s)', bundled_libtpu_path) - os.environ['TPU_LIBRARY_PATH'] = bundled_libtpu_path + os.environ['PTXLA_TPU_LIBRARY_PATH'] = bundled_libtpu_path return True try: import libtpu - libtpu.configure_library_path() + os.environ['PTXLA_TPU_LIBRARY_PATH'] = libtpu.get_library_path() return True except ImportError: return False diff --git a/torch_xla/csrc/runtime/env_vars.cc b/torch_xla/csrc/runtime/env_vars.cc index a7b0fdd74d6..733574a4818 100644 --- a/torch_xla/csrc/runtime/env_vars.cc +++ b/torch_xla/csrc/runtime/env_vars.cc @@ -14,6 +14,7 @@ const char* const kEnvPjRtTpuMaxInflightComputations = const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT"; const char* const kEnvPjrtAsyncGpuClient = "PJRT_GPU_ASYNC_CLIENT"; const char* const kEnvTpuLibraryPath = "TPU_LIBRARY_PATH"; +const char* const kEnvInferredTpuLibraryPath = "PTXLA_TPU_LIBRARY_PATH"; const char* const kEnvXpuLibraryPath = "XPU_LIBRARY_PATH"; const char* const kEnvNeuronLibraryPath = "NEURON_LIBRARY_PATH"; const char* const kEnvPjrtDistServiceAddr = "PJRT_DIST_SERVICE_ADDR"; diff --git a/torch_xla/csrc/runtime/env_vars.h b/torch_xla/csrc/runtime/env_vars.h index bc8a6fbc667..e7e1ef81964 100644 --- a/torch_xla/csrc/runtime/env_vars.h +++ b/torch_xla/csrc/runtime/env_vars.h @@ -24,6 +24,7 @@ extern const char* const kEnvPjRtTpuMaxInflightComputations; extern const char* const kEnvPjrtAsyncCpuClient; extern const char* const kEnvPjrtAsyncGpuClient; extern const char* const kEnvTpuLibraryPath; +extern const char* const kEnvInferredTpuLibraryPath; extern const char* const kEnvXpuLibraryPath; extern const char* const kEnvNeuronLibraryPath; extern const char* const kEnvPjrtDistServiceAddr; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 0fa3a790092..ab998d3ce0d 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -114,10 +114,11 @@ PjRtComputationClient::PjRtComputationClient() { client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value()); } else if (device_type == "TPU" || device_type == "TPU_C_API") { TF_VLOG(1) << "Initializing TFRT TPU client..."; - XLA_CHECK_OK( - pjrt::LoadPjrtPlugin( - "tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so")) - .status()); + // Prefer $TPU_LIBRARY_PATH if set + auto tpu_library_path = sys_util::GetEnvString( + env::kEnvTpuLibraryPath, + sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so")); + XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status()); tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu"); XLA_CHECK_OK(tpu_status); client_ = std::move(xla::GetCApiClient("TPU").value());