Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bad error message when PjRtComputationClient throws exception #5946

Merged
merged 5 commits into from
Dec 12, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 15 additions & 21 deletions torch_xla/csrc/runtime/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,29 @@

namespace torch_xla {
namespace runtime {
namespace {

std::atomic<bool> g_computation_client_initialized(false);

ComputationClient* CreateClient() {
bool was_initialized = g_computation_client_initialized.exchange(true);
XLA_CHECK(!was_initialized) << "ComputationClient already initialized";
if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) {
tsl::testing::InstallStacktraceHandler();
}

ComputationClient* client;
ComputationClient* GetComputationClient() {
static std::unique_ptr<ComputationClient> client = []() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between this and the original approach, since both rely on the static initializer? It seems like the logic has just been moved out of CreateClient and into the lambda.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that's right. This lets me skip checking the case where this function gets called twice, the handling of which causes the bad error message. C++11 statics will ensure that it only completes once, and the function is anonymous now, which will prevent future code from erroneously calling CreateClient somewhere else in the future.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, the lambda is guaranteed to only ever be called once. Do we know how CreateClient was called twice in the first place?

Copy link
Collaborator Author

@will-cromar will-cromar Dec 12, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added that check to defend against future errors (likely by future me). There was never a case in the original code where CreateClient actually completed twice or was called concurrently.

The bad error messaging occurred when the constructor of PjRtComputationClient threw an exception. g_computation_client_initialized was never reset to false, GetComputationClientIfInitialized calls GetComputationClient during teardown, which re-runs CreateClient, which complains because g_computation_client_initialized never got reset.

In this PR, I only set g_computation_client_initialized after the actual runtime init completed without breaking, since there are guaranteed to be no other concurrent callers of CreateClient.

if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) {
tsl::testing::InstallStacktraceHandler();
}

if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") {
client = new PjRtComputationClient();
} else {
g_computation_client_initialized = false;
will-cromar marked this conversation as resolved.
Show resolved Hide resolved
XLA_ERROR() << "$PJRT_DEVICE is not set." << std::endl;
}
std::unique_ptr<ComputationClient> client;

XLA_CHECK(client != nullptr);
if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") {
client = std::make_unique<PjRtComputationClient>();
} else {
XLA_ERROR() << "$PJRT_DEVICE is not set." << std::endl;
}

return client;
}
XLA_CHECK(client);

} // namespace
g_computation_client_initialized = true;
return client;
}();

ComputationClient* GetComputationClient() {
static auto client = std::unique_ptr<ComputationClient>(CreateClient());
return client.get();
}

Expand Down