Skip to content

Commit

Permalink
Destroy the ComputationClient when the program exits (pytorch#5750)
Browse files Browse the repository at this point in the history
* Destroy the ComputationClient when the program exits

* Fix extra error when PJRT_DEVICE is not set
  • Loading branch information
will-cromar authored and mbzomowski committed Nov 16, 2023
1 parent 80e7d87 commit 50fd007
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 11 deletions.
1 change: 0 additions & 1 deletion torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ void PrepareToExit() {
runtime::GetComputationClientIfInitialized();
if (client != nullptr) {
XLAGraphExecutor::Get()->WaitDeviceOps({});
client->PrepareToExit();
}
}

Expand Down
2 changes: 0 additions & 2 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,6 @@ class ComputationClient {

virtual MemoryInfo GetMemoryInfo(const std::string& device) = 0;

virtual void PrepareToExit() = 0;

// Block until the async operations on the passed-in devices have finished. If
// `devices` is empty, all the local devices will be waited for.
virtual void WaitDeviceOps(const std::vector<std::string>& devices) = 0;
Expand Down
2 changes: 0 additions & 2 deletions torch_xla/csrc/runtime/pjrt_computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ class PjRtComputationClient : public ComputationClient {

std::shared_ptr<std::vector<std::string>> GetReplicationDevices() override;

void PrepareToExit() override { return; };

void WaitDeviceOps(const std::vector<std::string>& devices) override;

std::map<std::string, Metric> GetMetrics() const override;
Expand Down
13 changes: 7 additions & 6 deletions torch_xla/csrc/runtime/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ namespace torch_xla {
namespace runtime {
namespace {

std::atomic<ComputationClient*> g_computation_client(nullptr);
std::once_flag g_computation_client_once;
std::atomic<bool> g_computation_client_initialized(false);

ComputationClient* CreateClient() {
bool was_initialized = g_computation_client_initialized.exchange(true);
XLA_CHECK(!was_initialized) << "ComputationClient already initialized";
if (sys_util::GetEnvBool("XLA_DUMP_FATAL_STACK", false)) {
tsl::testing::InstallStacktraceHandler();
}
Expand All @@ -23,6 +24,7 @@ ComputationClient* CreateClient() {
if (sys_util::GetEnvString(env::kEnvPjRtDevice, "") != "") {
client = new PjRtComputationClient();
} else {
g_computation_client_initialized = false;
XLA_ERROR() << "$PJRT_DEVICE is not set." << std::endl;
}

Expand All @@ -34,13 +36,12 @@ ComputationClient* CreateClient() {
} // namespace

ComputationClient* GetComputationClient() {
std::call_once(g_computation_client_once,
[&]() { g_computation_client = std::move(CreateClient()); });
return g_computation_client.load();
static auto client = std::unique_ptr<ComputationClient>(CreateClient());
return client.get();
}

ComputationClient* GetComputationClientIfInitialized() {
return g_computation_client.load();
return g_computation_client_initialized ? GetComputationClient() : nullptr;
}

} // namespace runtime
Expand Down

0 comments on commit 50fd007

Please sign in to comment.