Skip to content

Commit

Permalink
Support PreemptionSyncManager in DistributedRuntime
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Oct 31, 2023
1 parent 83778f0 commit bb13a6a
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 21 deletions.
35 changes: 34 additions & 1 deletion test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import functools
import os
import signal
import sys
import tempfile
import unittest
import test_xla_sharding_base
import threading
import time
import unittest

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
Expand Down Expand Up @@ -486,6 +489,36 @@ def test_master_ip_discovery(self, patched_get_worker_ips):
patched_get_worker_ips.return_value = ['10.0.0.1', '10.0.0.2']
self.assertTrue(xr.get_master_ip(), '10.0.0.1')

def test_preemption_sync_manager(self):
try:
torch_xla._XLAC._ensure_dist_runtime_initialized(
global_rank=0, world_size=1, master_addr="127.1")
torch_xla._XLAC._activate_preemption_sync_manager()
sync_point_reached = torch_xla._XLAC._sync_point_reached

# No sync point for the first several steps
sigterm_step = 10
for step in range(sigterm_step):
self.assertFalse(sync_point_reached(step))

# Send a SIGTERM to the current process to trigger a sync point
os.kill(os.getpid(), signal.SIGTERM)

# Allow the signal to be processed. The PreemptionSyncManager must receive
# the SIGTERM, which happens asynchronously, and the state must be
# propagated through the distributed runtime. Eventually,
# sync_point_reached will return True.
success = False
for attempt in range(10):
success = sync_point_reached(sigterm_step + attempt)
if success:
break
time.sleep(1)
self.assertTrue(success, "Sync point was never reached after SIGTERM")
finally:
# Scope the distributed runtime to the lifespan of the test.
torch_xla._XLAC._ensure_dist_runtime_shutdown()


if __name__ == '__main__':
test = unittest.main()
Expand Down
34 changes: 34 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/xla_ops.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/distributed_runtime.h"
#include "torch_xla/csrc/runtime/metrics.h"
#include "torch_xla/csrc/runtime/metrics_analysis.h"
#include "torch_xla/csrc/runtime/metrics_reader.h"
Expand Down Expand Up @@ -1876,6 +1877,39 @@ void InitXlaModuleBindings(py::module m) {
xla::HloModule::CreateFromProto(module_proto, config).value());
return module->ToString();
});
// Initialize a distributed runtime if one has not already been created.
m.def("_ensure_dist_runtime_initialized",
[](int global_rank, int world_size, std::string master_addr,
std::string master_port) {
if (!runtime::DistributedRuntime::IsInitialized()) {
runtime::DistributedRuntime::Initialize(global_rank, world_size,
master_addr, master_port);
}
},
py::arg("global_rank"), py::arg("world_size"), py::arg("master_addr"),
py::arg("master_port") =
runtime::DistributedRuntime::kDefaultCoordinatorPort);
// Shutdown the distributed runtime if it's active.
m.def("_ensure_dist_runtime_shutdown", []() {
if (runtime::DistributedRuntime::IsInitialized()) {
runtime::DistributedRuntime::Shutdown();
}
});
// Create a PreemptionSyncManager for the DistributedRuntime. The
// PreemptionSyncManager will register a SIGTERM handler as a side effect.
m.def("_activate_preemption_sync_manager", []() {
XLA_CHECK(runtime::DistributedRuntime::IsInitialized())
<< "DistributedRuntime must be initialized to register "
"PreemptionSyncManager";
runtime::DistributedRuntime::Get().ActivatePreemptionSyncManager();
});
// Check whether a sync point has been reached. This method requires that the
// distributed runtime be initialized and a PreemptionSyncManager activated.
m.def("_sync_point_reached", [](int step) {
XLA_CHECK(runtime::DistributedRuntime::IsInitialized())
<< "DistributedRuntime must be initialized";
return runtime::DistributedRuntime::Get().ReachedSyncPoint(step);
});
m.def("_is_placecholder", [](at::Tensor& input) {
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
return xtensor->CurrentDataHandle() &&
Expand Down
1 change: 1 addition & 0 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ cc_library(
":debug_macros",
":sys_util",
"@xla//xla/pjrt/distributed",
"@tsl//tsl/distributed_runtime/preemption:preemption_sync_manager",
],
)

Expand Down
26 changes: 20 additions & 6 deletions torch_xla/csrc/runtime/distributed_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@
namespace torch_xla {
namespace runtime {

const std::string DistributedRuntime::default_coordinator_port = "8547";

DistributedRuntime::DistributedRuntime(int global_rank, std::string master_addr,
DistributedRuntime::DistributedRuntime(int global_rank, int world_size,
std::string master_addr,
std::string port) {
std::string dist_service_addr = absl::StrJoin({master_addr, port}, ":");
if (global_rank == 0) {
int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
xla::CoordinationServiceImpl::Options service_options;
service_options.num_nodes = global_world_size;
service_options.num_nodes = world_size;
xla::StatusOr<std::unique_ptr<xla::DistributedRuntimeService>>
dist_runtime_service = xla::GetDistributedRuntimeService(
dist_service_addr, service_options);
Expand Down Expand Up @@ -50,5 +47,22 @@ std::shared_ptr<xla::DistributedRuntimeClient> DistributedRuntime::GetClient() {
return dist_runtime_client_;
}

void DistributedRuntime::ActivatePreemptionSyncManager() {
if (preemption_sync_manager_ == nullptr) {
preemption_sync_manager_ = std::move(tsl::CreatePreemptionSyncManager());
auto client = dist_runtime_client_->GetCoordinationServiceAgent();
XLA_CHECK(client.ok()) << "Failed to retrieve the CoodinationServiceAgent";
auto status = preemption_sync_manager_->Initialize(client.value());
XLA_CHECK(status.ok()) << "Failed to initialize the PreemptionSyncManager";
}
}

bool DistributedRuntime::ReachedSyncPoint(int step) {
XLA_CHECK(preemption_sync_manager_ != nullptr)
<< "A PreemptionSyncManager has not been registered with the "
"DistributedRuntime.";
return preemption_sync_manager_->ReachedSyncPoint(step);
}

} // namespace runtime
} // namespace torch_xla
61 changes: 53 additions & 8 deletions torch_xla/csrc/runtime/distributed_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,78 @@

#include <memory>

#include "torch_xla/csrc/runtime/debug_macros.h"
#include "tsl/distributed_runtime/preemption/preemption_sync_manager.h"
#include "xla/pjrt/distributed/distributed.h"

namespace torch_xla {
namespace runtime {

// DistributedRuntime serves as the point of entry for all operations which
// required the XLA distributed runtime, such as preemption coordination.
class DistributedRuntime {
public:
static const std::string default_coordinator_port;
static DistributedRuntime& getInstance(int global_rank,
std::string master_addr,
std::string port) {
static DistributedRuntime dist_runtime_instance(global_rank, master_addr,
port);
return dist_runtime_instance;
static inline const std::string kDefaultCoordinatorPort = "8547";

// Returns true if the distributed runtime has already been initialized.
static bool IsInitialized() { return dist_runtime_ != nullptr; }

// Initialize the shared DistributedRuntime object. This creates a
// DistributedRuntimeClient on each worker, and on global_rank 0 initializes
// the corresponding DistributedRuntimeService.
static void Initialize(int global_rank, int world_size,
std::string master_addr, std::string port) {
XLA_CHECK(!IsInitialized()) << "DistributedRuntime already initialized";
dist_runtime_ = std::unique_ptr<DistributedRuntime>(
new DistributedRuntime(global_rank, world_size, master_addr, port));
}

// Shutdown the distributed runtime. All associated resources will be
// released, and subsequent calls to IsInitialized will return false.
// The distributed runtime may later be re-initialized.
static void Shutdown() {
XLA_CHECK(IsInitialized())
<< "Must initialize distributed runtime before shutdown";
dist_runtime_ = nullptr;
}

// Retrieve the shared DistributedRuntime object.
static DistributedRuntime& Get() {
XLA_CHECK(IsInitialized())
<< "Must initialize distributed runtime before retrieval";
return *dist_runtime_;
}

~DistributedRuntime();
DistributedRuntime(DistributedRuntime const&) = delete;
void operator=(DistributedRuntime const&) = delete;

// Retrieve the DistributedRuntimeClient.
std::shared_ptr<xla::DistributedRuntimeClient> GetClient();

// Register a PreemptionSyncManager for the distributed runtime if none is
// active. The PreemptionSyncManager will register a SIGTERM handler, and
// when any host has received a preemption notice, all hosts are made aware
// through the ReachedSyncPoint API. See the documentation of
// tsl::PreemptionSyncManager for the full semantics:
// https://github.com/google/tsl/blob/3bbe663/tsl/distributed_runtime/preemption/preemption_sync_manager.h#L34
void ActivatePreemptionSyncManager();

// A pass-throguh API to the PreemptionSyncManager::ReachedSyncPoint.
// The PreemptionSyncManager must be activated within the DistributedRuntime.
// Returns true when the input step has been identified as a sync point, and
// false otherwise.
bool ReachedSyncPoint(int step);

private:
DistributedRuntime(int global_rank, std::string master_addr,
static inline std::unique_ptr<DistributedRuntime> dist_runtime_;

DistributedRuntime(int global_rank, int world_size, std::string master_addr,
std::string port);

std::unique_ptr<xla::DistributedRuntimeService> dist_runtime_service_;
std::shared_ptr<xla::DistributedRuntimeClient> dist_runtime_client_;
std::shared_ptr<tsl::PreemptionSyncManager> preemption_sync_manager_;
};

} // namespace runtime
Expand Down
14 changes: 8 additions & 6 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,17 @@ PjRtComputationClient::PjRtComputationClient() {
bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true);
int local_process_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0);
int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank);
int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
std::string master_addr =
runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
std::string port = runtime::sys_util::GetEnvString(
"XLA_COORDINATOR_PORT", DistributedRuntime::default_coordinator_port);
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
DistributedRuntime::getInstance(global_process_rank, master_addr, port)
.GetClient();
"XLA_COORDINATOR_PORT", DistributedRuntime::kDefaultCoordinatorPort);

// Use the DistributedRuntime as the distributed key-value store.
DistributedRuntime::Initialize(global_process_rank, global_world_size,
master_addr, port);
auto distributed_client = DistributedRuntime::Get().GetClient();
auto allowed_devices =
std::make_optional<std::set<int>>(std::set{local_process_rank});
xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
Expand All @@ -132,8 +136,6 @@ PjRtComputationClient::PjRtComputationClient() {
return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
};
}
int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
<< global_process_rank << ", num_nodes=" << global_world_size;
client_ = std::move(xla::GetStreamExecutorGpuClient(
Expand Down

0 comments on commit bb13a6a

Please sign in to comment.