Skip to content

Commit

Permalink
Fix confusing error message when PjRtComputationClient throws excep…
Browse files Browse the repository at this point in the history
…tion
  • Loading branch information
will-cromar committed Nov 29, 2023
1 parent b9475d9 commit e379c6f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
23 changes: 11 additions & 12 deletions torch_xla/csrc/runtime/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,37 @@ namespace torch_xla {
namespace runtime {
namespace {

std::atomic<bool> g_computation_client_initialized(false);

ComputationClient* CreateClient() {
bool was_initialized = g_computation_client_initialized.exchange(true);
XLA_CHECK(!was_initialized) << "ComputationClient already initialized";
std::unique_ptr<ComputationClient> CreateClient() {
if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) {
tsl::testing::InstallStacktraceHandler();
}

ComputationClient* client;
std::unique_ptr<ComputationClient> client = nullptr;

if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") {
client = new PjRtComputationClient();
client = std::make_unique<PjRtComputationClient>();
} else {
g_computation_client_initialized = false;
XLA_ERROR() << "$PJRT_DEVICE is not set." << std::endl;
}

XLA_CHECK(client != nullptr);
XLA_CHECK(client);

return client;
}

} // namespace

ComputationClient* GetComputationClient() {
static auto client = std::unique_ptr<ComputationClient>(CreateClient());
ComputationClient* GetComputationClient(bool create = true) {
static std::unique_ptr<ComputationClient> client = nullptr;
if (!client && create) {
static std::once_flag flag;
std::call_once(flag, [](){ client = CreateClient(); });
}
return client.get();
}

ComputationClient* GetComputationClientIfInitialized() {
return g_computation_client_initialized ? GetComputationClient() : nullptr;
return GetComputationClient(/*create=*/false);
}

} // namespace runtime
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace torch_xla {
namespace runtime {

// Returns the ComputationClient singleton.
ComputationClient* GetComputationClient();
ComputationClient* GetComputationClient(bool create = true);

ComputationClient* GetComputationClientIfInitialized();

Expand Down

0 comments on commit e379c6f

Please sign in to comment.