Skip to content

Commit

Permalink
Revert "Don't set $TPU_LIBRARY_PATH during import (#5698)" (#5731)
Browse files Browse the repository at this point in the history
This reverts commit 146f2a0.
  • Loading branch information
alanwaketan authored and jonb377 committed Oct 31, 2023
1 parent 10d3cdb commit 967996b
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 13 deletions.
9 changes: 3 additions & 6 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,13 @@ def _aws_ec2_inf_trn_init():


def _setup_tpu_vm_library_path() -> bool:
"""Returns true if $TPU_LIBRARY_PATH is set or can be inferred.
"""Returns true if $TPU_LIBRARY is set or can be inferred.
We load libtpu.so in the following order of precedence:
1. User-set $TPU_LIBRARY_PATH
2. libtpu.so included in torch_xla/lib
3. libtpu-nightly pip package
Sets $PTXLA_TPU_LIBRARY_PATH if path is inferred by us to prevent conflicts
with other frameworks. This env var will be removed in a future version.
"""
if 'TPU_LIBRARY_PATH' in os.environ:
return True
Expand All @@ -90,12 +87,12 @@ def _setup_tpu_vm_library_path() -> bool:
bundled_libtpu_path = os.path.join(module_path, 'lib/libtpu.so')
if os.path.isfile(bundled_libtpu_path) and not os.getenv('TPU_LIBRARY_PATH'):
logger.info('Using bundled libtpu.so (%s)', bundled_libtpu_path)
os.environ['PTXLA_TPU_LIBRARY_PATH'] = bundled_libtpu_path
os.environ['TPU_LIBRARY_PATH'] = bundled_libtpu_path
return True

try:
import libtpu
os.environ['PTXLA_TPU_LIBRARY_PATH'] = libtpu.get_library_path()
libtpu.configure_library_path()
return True
except ImportError:
return False
Expand Down
1 change: 0 additions & 1 deletion torch_xla/csrc/runtime/env_vars.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ const char* const kEnvPjRtTpuMaxInflightComputations =
const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT";
const char* const kEnvPjrtAsyncGpuClient = "PJRT_GPU_ASYNC_CLIENT";
const char* const kEnvTpuLibraryPath = "TPU_LIBRARY_PATH";
const char* const kEnvInferredTpuLibraryPath = "PTXLA_TPU_LIBRARY_PATH";
const char* const kEnvXpuLibraryPath = "XPU_LIBRARY_PATH";
const char* const kEnvNeuronLibraryPath = "NEURON_LIBRARY_PATH";
const char* const kEnvPjrtDistServiceAddr = "PJRT_DIST_SERVICE_ADDR";
Expand Down
1 change: 0 additions & 1 deletion torch_xla/csrc/runtime/env_vars.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ extern const char* const kEnvPjRtTpuMaxInflightComputations;
extern const char* const kEnvPjrtAsyncCpuClient;
extern const char* const kEnvPjrtAsyncGpuClient;
extern const char* const kEnvTpuLibraryPath;
extern const char* const kEnvInferredTpuLibraryPath;
extern const char* const kEnvXpuLibraryPath;
extern const char* const kEnvNeuronLibraryPath;
extern const char* const kEnvPjrtDistServiceAddr;
Expand Down
9 changes: 4 additions & 5 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,10 @@ PjRtComputationClient::PjRtComputationClient() {
client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value());
} else if (device_type == "TPU" || device_type == "TPU_C_API") {
TF_VLOG(1) << "Initializing TFRT TPU client...";
// Prefer $TPU_LIBRARY_PATH if set
auto tpu_library_path = sys_util::GetEnvString(
env::kEnvTpuLibraryPath,
sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so"));
XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status());
XLA_CHECK_OK(
pjrt::LoadPjrtPlugin(
"tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so"))
.status());
tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
XLA_CHECK(tpu_status.ok());
client_ = std::move(xla::GetCApiClient("TPU").value());
Expand Down

0 comments on commit 967996b

Please sign in to comment.