Skip to content

Commit

Permalink
[Backport] Fix bad error message when PjRtComputationClient throws …
Browse files Browse the repository at this point in the history
…exception (#6144)
  • Loading branch information
will-cromar authored Dec 14, 2023
1 parent 41d0fe5 commit dec978d
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions torch_xla/csrc/runtime/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,29 @@

namespace torch_xla {
namespace runtime {
namespace {

std::atomic<bool> g_computation_client_initialized(false);

ComputationClient* CreateClient() {
bool was_initialized = g_computation_client_initialized.exchange(true);
XLA_CHECK(!was_initialized) << "ComputationClient already initialized";
if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) {
tsl::testing::InstallStacktraceHandler();
}

ComputationClient* client;
ComputationClient* GetComputationClient() {
static std::unique_ptr<ComputationClient> client = []() {
if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) {
tsl::testing::InstallStacktraceHandler();
}

if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") {
client = new PjRtComputationClient();
} else {
g_computation_client_initialized = false;
XLA_ERROR() << "$PJRT_DEVICE is not set." << std::endl;
}
std::unique_ptr<ComputationClient> client;

XLA_CHECK(client != nullptr);
if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") {
client = std::make_unique<PjRtComputationClient>();
} else {
XLA_ERROR() << "$PJRT_DEVICE is not set." << std::endl;
}

return client;
}
XLA_CHECK(client);

} // namespace
g_computation_client_initialized = true;
return client;
}();

ComputationClient* GetComputationClient() {
static auto client = std::unique_ptr<ComputationClient>(CreateClient());
return client.get();
}

Expand Down

0 comments on commit dec978d

Please sign in to comment.