Skip to content

Commit

Permalink
Don't set $TPU_LIBRARY_PATH during import
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Oct 11, 2023
1 parent 418c751 commit 84a783d
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 8 deletions.
11 changes: 7 additions & 4 deletions torch_xla/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,30 @@ def _aws_ec2_inf_trn_init():


def _setup_tpu_vm_library_path() -> bool:
"""Returns true if $TPU_LIBRARY is set or can be inferred.
"""Returns true if $TPU_LIBRARY_PATH is set or can be inferred.
We load libtpu.so in the following order of precedence:
1. User-set $TPU_LIBRARY_PATH
2. libtpu.so included in torch_xla/lib
3. libtpu-nightly pip package
Sets $PTXLA_TPU_LIBRARY_PATH if path is inferred by us to prevent conflicts
with other frameworks. This env var will be removed in a future version.
"""
if 'TPU_LIBRARY_PATH' in os.environ:
if 'TPU_LIBRARY_PATH' in os.environ or 'PTXLA_TPU_LIBRARY_PATH' in os.environ:
return True

module_path = os.path.dirname(__file__)
bundled_libtpu_path = os.path.join(module_path, 'lib/libtpu.so')
if os.path.isfile(bundled_libtpu_path) and not os.getenv('TPU_LIBRARY_PATH'):
logger.info('Using bundled libtpu.so (%s)', bundled_libtpu_path)
os.environ['TPU_LIBRARY_PATH'] = bundled_libtpu_path
os.environ['PTXLA_TPU_LIBRARY_PATH'] = bundled_libtpu_path
return True

try:
import libtpu
libtpu.configure_library_path()
os.environ['PTXLA_TPU_LIBRARY_PATH'] = libtpu.get_library_path()
return True
except ImportError:
return False
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/env_vars.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const char* const kEnvPjRtTpuMaxInflightComputations =
const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT";
const char* const kEnvPjrtAsyncGpuClient = "PJRT_GPU_ASYNC_CLIENT";
const char* const kEnvTpuLibraryPath = "TPU_LIBRARY_PATH";
const char* const kEnvInferredTpuLibraryPath = "PTXLA_TPU_LIBRARY_PATH";
const char* const kEnvXpuLibraryPath = "XPU_LIBRARY_PATH";
const char* const kEnvNeuronLibraryPath = "NEURON_LIBRARY_PATH";
const char* const kEnvPjrtDistServiceAddr = "PJRT_DIST_SERVICE_ADDR";
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/env_vars.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ extern const char* const kEnvPjRtTpuMaxInflightComputations;
extern const char* const kEnvPjrtAsyncCpuClient;
extern const char* const kEnvPjrtAsyncGpuClient;
extern const char* const kEnvTpuLibraryPath;
extern const char* const kEnvInferredTpuLibraryPath;
extern const char* const kEnvXpuLibraryPath;
extern const char* const kEnvNeuronLibraryPath;
extern const char* const kEnvPjrtDistServiceAddr;
Expand Down
9 changes: 5 additions & 4 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,11 @@ PjRtComputationClient::PjRtComputationClient() {
client_ = std::move(xla::GetTfrtCpuClient(async, cpu_device_count).value());
} else if (device_type == "TPU" || device_type == "TPU_C_API") {
TF_VLOG(1) << "Initializing TFRT TPU client...";
XLA_CHECK_OK(
pjrt::LoadPjrtPlugin(
"tpu", sys_util::GetEnvString(env::kEnvTpuLibraryPath, "libtpu.so"))
.status());
// Prefer $TPU_LIBRARY_PATH if set
auto tpu_library_path = sys_util::GetEnvString(
env::kEnvTpuLibraryPath,
sys_util::GetEnvString(env::kEnvInferredTpuLibraryPath, "libtpu.so"));
XLA_CHECK_OK(pjrt::LoadPjrtPlugin("tpu", tpu_library_path).status());
tsl::Status tpu_status = pjrt::InitializePjrtPlugin("tpu");
XLA_CHECK(tpu_status.ok());
client_ = std::move(xla::GetCApiClient("TPU").value());
Expand Down

0 comments on commit 84a783d

Please sign in to comment.