diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index 69e5bb74319..e2f69c44e47 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -8,35 +8,29 @@ namespace torch_xla { namespace runtime { -namespace { std::atomic g_computation_client_initialized(false); -ComputationClient* CreateClient() { - bool was_initialized = g_computation_client_initialized.exchange(true); - XLA_CHECK(!was_initialized) << "ComputationClient already initialized"; - if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) { - tsl::testing::InstallStacktraceHandler(); - } - - ComputationClient* client; +ComputationClient* GetComputationClient() { + static std::unique_ptr client = []() { + if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) { + tsl::testing::InstallStacktraceHandler(); + } - if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { - client = new PjRtComputationClient(); - } else { - g_computation_client_initialized = false; - XLA_ERROR() << "$PJRT_DEVICE is not set." << std::endl; - } + std::unique_ptr client; - XLA_CHECK(client != nullptr); + if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { + client = std::make_unique(); + } else { + XLA_ERROR() << "$PJRT_DEVICE is not set." << std::endl; + } - return client; -} + XLA_CHECK(client); -} // namespace + g_computation_client_initialized = true; + return client; + }(); -ComputationClient* GetComputationClient() { - static auto client = std::unique_ptr(CreateClient()); return client.get(); }