diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index 1884310c5fda..1e6bb020fe50 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -95,7 +95,6 @@ void PrepareToExit() { runtime::GetComputationClientIfInitialized(); if (client != nullptr) { XLAGraphExecutor::Get()->WaitDeviceOps({}); - client->PrepareToExit(); } } diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 1e610be7959b..db4bac21916a 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -344,8 +344,6 @@ class ComputationClient { virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0; - virtual void PrepareToExit() = 0; - // Block until pass in devices' async operation are finished. If empty, all // the local devices will be waited for. virtual void WaitDeviceOps(const std::vector& devices) = 0; diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index d7a11611a034..f4fc73bb79e5 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -85,8 +85,6 @@ class PjRtComputationClient : public ComputationClient { std::shared_ptr> GetReplicationDevices() override; - void PrepareToExit() override { return; }; - void WaitDeviceOps(const std::vector& devices) override; std::map GetMetrics() const override; diff --git a/torch_xla/csrc/runtime/runtime.cc b/torch_xla/csrc/runtime/runtime.cc index 8cfd06951842..69e5bb74319f 100644 --- a/torch_xla/csrc/runtime/runtime.cc +++ b/torch_xla/csrc/runtime/runtime.cc @@ -10,10 +10,11 @@ namespace torch_xla { namespace runtime { namespace { -std::atomic g_computation_client(nullptr); -std::once_flag g_computation_client_once; +std::atomic g_computation_client_initialized(false); ComputationClient* CreateClient() { + bool was_initialized = g_computation_client_initialized.exchange(true); + XLA_CHECK(!was_initialized) << "ComputationClient already initialized"; if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) { tsl::testing::InstallStacktraceHandler(); } @@ -23,6 +24,7 @@ ComputationClient* CreateClient() { if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") { client = new PjRtComputationClient(); } else { + g_computation_client_initialized = false; XLA_ERROR() << "$PJRT_DEVICE is not set." << std::endl; } @@ -34,13 +36,12 @@ ComputationClient* CreateClient() { } // namespace ComputationClient* GetComputationClient() { - std::call_once(g_computation_client_once, - [&]() { g_computation_client = std::move(CreateClient()); }); - return g_computation_client.load(); + static auto client = std::unique_ptr(CreateClient()); + return client.get(); } ComputationClient* GetComputationClientIfInitialized() { - return g_computation_client.load(); + return g_computation_client_initialized ? GetComputationClient() : nullptr; } } // namespace runtime