diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py
index 76ea6b71672..6549e54cb65 100644
--- a/test/spmd/test_xla_distributed_checkpoint.py
+++ b/test/spmd/test_xla_distributed_checkpoint.py
@@ -1,14 +1,17 @@
 import functools
 import os
+import signal
 import sys
 import tempfile
-import unittest
 import test_xla_sharding_base
 import threading
+import time
+import unittest
 
 import torch
 import torch.distributed as dist
 import torch.distributed.checkpoint as dist_cp
+import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.runtime as xr
 import torch_xla.experimental.xla_sharding as xs
@@ -486,6 +489,36 @@ def test_master_ip_discovery(self, patched_get_worker_ips):
     patched_get_worker_ips.return_value = ['10.0.0.1', '10.0.0.2']
     self.assertTrue(xr.get_master_ip(), '10.0.0.1')
 
+  def test_preemption_sync_manager(self):
+    try:
+      torch_xla._XLAC._ensure_dist_runtime_initialized(
+          global_rank=0, world_size=1, master_addr="127.1")
+      torch_xla._XLAC._activate_preemption_sync_manager()
+      sync_point_reached = torch_xla._XLAC._sync_point_reached
+
+      # No sync point for the first several steps
+      sigterm_step = 10
+      for step in range(sigterm_step):
+        self.assertFalse(sync_point_reached(step))
+
+      # Send a SIGTERM to the current process to trigger a sync point
+      os.kill(os.getpid(), signal.SIGTERM)
+
+      # Allow the signal to be processed. The PreemptionSyncManager must
+      # receive the SIGTERM, which happens asynchronously, and the state
+      # must be propagated through the distributed runtime. Eventually,
+      # sync_point_reached will return True.
+      success = False
+      for attempt in range(10):
+        success = sync_point_reached(sigterm_step + attempt)
+        if success:
+          break
+        time.sleep(1)
+      self.assertTrue(success, "Sync point was never reached after SIGTERM")
+    finally:
+      # Scope the distributed runtime to the lifespan of the test.
+      torch_xla._XLAC._ensure_dist_runtime_shutdown()
+
 
 if __name__ == '__main__':
   test = unittest.main()
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index 1e6bb020fe5..e9a0d02aa40 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -38,6 +38,7 @@
 #include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/ops/xla_ops.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
+#include "torch_xla/csrc/runtime/distributed_runtime.h"
 #include "torch_xla/csrc/runtime/metrics.h"
 #include "torch_xla/csrc/runtime/metrics_analysis.h"
 #include "torch_xla/csrc/runtime/metrics_reader.h"
@@ -1876,6 +1877,39 @@ void InitXlaModuleBindings(py::module m) {
         xla::HloModule::CreateFromProto(module_proto, config).value());
     return module->ToString();
   });
+  // Initialize a distributed runtime if one has not already been created.
+  m.def("_ensure_dist_runtime_initialized",
+        [](int global_rank, int world_size, std::string master_addr,
+           std::string master_port) {
+          if (!runtime::DistributedRuntime::IsInitialized()) {
+            runtime::DistributedRuntime::Initialize(global_rank, world_size,
+                                                    master_addr, master_port);
+          }
+        },
+        py::arg("global_rank"), py::arg("world_size"), py::arg("master_addr"),
+        py::arg("master_port") =
+            runtime::DistributedRuntime::kDefaultCoordinatorPort);
+  // Shut down the distributed runtime if it's active.
+  m.def("_ensure_dist_runtime_shutdown", []() {
+    if (runtime::DistributedRuntime::IsInitialized()) {
+      runtime::DistributedRuntime::Shutdown();
+    }
+  });
+  // Create a PreemptionSyncManager for the DistributedRuntime. The
+  // PreemptionSyncManager will register a SIGTERM handler as a side effect.
+  m.def("_activate_preemption_sync_manager", []() {
+    XLA_CHECK(runtime::DistributedRuntime::IsInitialized())
+        << "DistributedRuntime must be initialized to register "
+           "PreemptionSyncManager";
+    runtime::DistributedRuntime::Get().ActivatePreemptionSyncManager();
+  });
+  // Check whether a sync point has been reached. This method requires that the
+  // distributed runtime be initialized and a PreemptionSyncManager activated.
+  m.def("_sync_point_reached", [](int step) {
+    XLA_CHECK(runtime::DistributedRuntime::IsInitialized())
+        << "DistributedRuntime must be initialized";
+    return runtime::DistributedRuntime::Get().ReachedSyncPoint(step);
+  });
   m.def("_is_placecholder", [](at::Tensor& input) {
     XLATensorPtr xtensor = bridge::GetXlaTensor(input);
     return xtensor->CurrentDataHandle() &&
diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD
index db054e32289..bd8401f0649 100644
--- a/torch_xla/csrc/runtime/BUILD
+++ b/torch_xla/csrc/runtime/BUILD
@@ -172,6 +172,7 @@ cc_library(
         ":debug_macros",
         ":sys_util",
         "@xla//xla/pjrt/distributed",
+        "@tsl//tsl/distributed_runtime/preemption:preemption_sync_manager",
     ],
 )
 
diff --git a/torch_xla/csrc/runtime/distributed_runtime.cc b/torch_xla/csrc/runtime/distributed_runtime.cc
index dc3dbaf4eb4..de5313ef7c2 100644
--- a/torch_xla/csrc/runtime/distributed_runtime.cc
+++ b/torch_xla/csrc/runtime/distributed_runtime.cc
@@ -6,16 +6,13 @@
 namespace torch_xla {
 namespace runtime {
 
-const std::string DistributedRuntime::default_coordinator_port = "8547";
-
-DistributedRuntime::DistributedRuntime(int global_rank, std::string master_addr,
+DistributedRuntime::DistributedRuntime(int global_rank, int world_size,
+                                       std::string master_addr,
                                        std::string port) {
   std::string dist_service_addr = absl::StrJoin({master_addr, port}, ":");
   if (global_rank == 0) {
-    int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
-    int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
     xla::CoordinationServiceImpl::Options service_options;
-    service_options.num_nodes = global_world_size;
+    service_options.num_nodes = world_size;
     xla::StatusOr<std::unique_ptr<xla::DistributedRuntimeService>>
         dist_runtime_service = xla::GetDistributedRuntimeService(
             dist_service_addr, service_options);
@@ -50,5 +47,22 @@ std::shared_ptr<xla::DistributedRuntimeClient> DistributedRuntime::GetClient() {
   return dist_runtime_client_;
 }
 
+void DistributedRuntime::ActivatePreemptionSyncManager() {
+  if (preemption_sync_manager_ == nullptr) {
+    preemption_sync_manager_ = std::move(tsl::CreatePreemptionSyncManager());
+    auto client = dist_runtime_client_->GetCoordinationServiceAgent();
+    XLA_CHECK(client.ok()) << "Failed to retrieve the CoordinationServiceAgent";
+    auto status = preemption_sync_manager_->Initialize(client.value());
+    XLA_CHECK(status.ok()) << "Failed to initialize the PreemptionSyncManager";
+  }
+}
+
+bool DistributedRuntime::ReachedSyncPoint(int step) {
+  XLA_CHECK(preemption_sync_manager_ != nullptr)
+      << "A PreemptionSyncManager has not been registered with the "
+         "DistributedRuntime.";
+  return preemption_sync_manager_->ReachedSyncPoint(step);
+}
+
 }  // namespace runtime
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/runtime/distributed_runtime.h b/torch_xla/csrc/runtime/distributed_runtime.h
index f26ef3d008c..8ea22f28921 100644
--- a/torch_xla/csrc/runtime/distributed_runtime.h
+++ b/torch_xla/csrc/runtime/distributed_runtime.h
@@ -3,33 +3,78 @@
 
 #include <memory>
 
+#include "torch_xla/csrc/runtime/debug_macros.h"
+#include "tsl/distributed_runtime/preemption/preemption_sync_manager.h"
 #include "xla/pjrt/distributed/distributed.h"
 
 namespace torch_xla {
 namespace runtime {
 
+// DistributedRuntime serves as the point of entry for all operations which
+// require the XLA distributed runtime, such as preemption coordination.
 class DistributedRuntime {
  public:
-  static const std::string default_coordinator_port;
-  static DistributedRuntime& getInstance(int global_rank,
-                                         std::string master_addr,
-                                         std::string port) {
-    static DistributedRuntime dist_runtime_instance(global_rank, master_addr,
-                                                    port);
-    return dist_runtime_instance;
+  static inline const std::string kDefaultCoordinatorPort = "8547";
+
+  // Returns true if the distributed runtime has already been initialized.
+  static bool IsInitialized() { return dist_runtime_ != nullptr; }
+
+  // Initialize the shared DistributedRuntime object. This creates a
+  // DistributedRuntimeClient on each worker, and on global_rank 0 initializes
+  // the corresponding DistributedRuntimeService.
+  static void Initialize(int global_rank, int world_size,
+                         std::string master_addr, std::string port) {
+    XLA_CHECK(!IsInitialized()) << "DistributedRuntime already initialized";
+    dist_runtime_ = std::unique_ptr<DistributedRuntime>(
+        new DistributedRuntime(global_rank, world_size, master_addr, port));
+  }
+
+  // Shut down the distributed runtime. All associated resources will be
+  // released, and subsequent calls to IsInitialized will return false.
+  // The distributed runtime may later be re-initialized.
+  static void Shutdown() {
+    XLA_CHECK(IsInitialized())
+        << "Must initialize distributed runtime before shutdown";
+    dist_runtime_ = nullptr;
+  }
+
+  // Retrieve the shared DistributedRuntime object.
+  static DistributedRuntime& Get() {
+    XLA_CHECK(IsInitialized())
+        << "Must initialize distributed runtime before retrieval";
+    return *dist_runtime_;
   }
+
   ~DistributedRuntime();
   DistributedRuntime(DistributedRuntime const&) = delete;
   void operator=(DistributedRuntime const&) = delete;
 
+  // Retrieve the DistributedRuntimeClient.
   std::shared_ptr<xla::DistributedRuntimeClient> GetClient();
 
+  // Register a PreemptionSyncManager for the distributed runtime if none is
+  // active. The PreemptionSyncManager will register a SIGTERM handler, and
+  // when any host has received a preemption notice, all hosts are made aware
+  // through the ReachedSyncPoint API. See the documentation of
+  // tsl::PreemptionSyncManager for the full semantics:
+  // https://github.com/google/tsl/blob/3bbe663/tsl/distributed_runtime/preemption/preemption_sync_manager.h#L34
+  void ActivatePreemptionSyncManager();
+
+  // A pass-through API to PreemptionSyncManager::ReachedSyncPoint. The
+  // PreemptionSyncManager must be activated within the DistributedRuntime.
+  // Returns true when the input step has been identified as a sync point, and
+  // false otherwise.
+  bool ReachedSyncPoint(int step);
+
  private:
-  DistributedRuntime(int global_rank, std::string master_addr,
+  static inline std::unique_ptr<DistributedRuntime> dist_runtime_;
+
+  DistributedRuntime(int global_rank, int world_size, std::string master_addr,
                      std::string port);
+
   std::unique_ptr<xla::DistributedRuntimeService> dist_runtime_service_;
   std::shared_ptr<xla::DistributedRuntimeClient> dist_runtime_client_;
+  std::shared_ptr<tsl::PreemptionSyncManager> preemption_sync_manager_;
 };
 
 }  // namespace runtime
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index 2ae0768856b..ce6c747b3cb 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -109,13 +109,17 @@ PjRtComputationClient::PjRtComputationClient() {
     bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true);
     int local_process_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0);
     int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank);
+    int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
+    int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
     std::string master_addr =
        runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
     std::string port = runtime::sys_util::GetEnvString(
-        "XLA_COORDINATOR_PORT", DistributedRuntime::default_coordinator_port);
-    std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
-        DistributedRuntime::getInstance(global_process_rank, master_addr, port)
-            .GetClient();
+        "XLA_COORDINATOR_PORT", DistributedRuntime::kDefaultCoordinatorPort);
+
+    // Use the DistributedRuntime as the distributed key-value store.
+    DistributedRuntime::Initialize(global_process_rank, global_world_size,
+                                   master_addr, port);
+    auto distributed_client = DistributedRuntime::Get().GetClient();
     auto allowed_devices =
         std::make_optional<std::set<int>>(std::set{local_process_rank});
     xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
@@ -132,8 +136,6 @@ PjRtComputationClient::PjRtComputationClient() {
        return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
      };
    }
-    int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
-    int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
     TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
                << global_process_rank << ", num_nodes=" << global_world_size;
     client_ = std::move(xla::GetStreamExecutorGpuClient(
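
Usage sketch (illustrative only; a minimal single-host sketch mirroring the defaults used in the test above, where `train_step` and `save_checkpoint` are hypothetical helpers not defined in this change): the new bindings are meant to be polled from a training loop so that, when any host receives a SIGTERM, every host eventually agrees on a common step at which to checkpoint.

    import torch_xla

    # Initialize the distributed runtime and activate preemption detection.
    # master_port defaults to kDefaultCoordinatorPort ("8547") on the C++ side.
    torch_xla._XLAC._ensure_dist_runtime_initialized(
        global_rank=0, world_size=1, master_addr="localhost")
    torch_xla._XLAC._activate_preemption_sync_manager()

    try:
      for step in range(1000):
        train_step()  # hypothetical training-step helper
        # After a SIGTERM on any host, ReachedSyncPoint eventually reports the
        # same sync-point step everywhere, so all hosts checkpoint together.
        if torch_xla._XLAC._sync_point_reached(step):
          save_checkpoint()  # hypothetical checkpoint helper
          break
    finally:
      # Release the coordination service resources, as in the test above.
      torch_xla._XLAC._ensure_dist_runtime_shutdown()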