diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py
index 6549e54cb65..fbc75bc5e4d 100644
--- a/test/spmd/test_xla_distributed_checkpoint.py
+++ b/test/spmd/test_xla_distributed_checkpoint.py
@@ -343,6 +343,8 @@ def setUp(self):
     # Initialize a minimal process group
     dist.init_process_group(
         backend='gloo', init_method='tcp://127.1:8932', world_size=1, rank=0)
+    torch_xla._XLAC._ensure_xla_coordinator_initialized(
+        global_rank=0, world_size=1, master_addr="127.1")
 
   def tearDown(self):
     super().tearDown()
@@ -491,8 +493,6 @@ def test_master_ip_discovery(self, patched_get_worker_ips):
 
   def test_preemption_sync_manager(self):
     try:
-      torch_xla._XLAC._ensure_dist_runtime_initialized(
-          global_rank=0, world_size=1, master_addr="127.1")
       torch_xla._XLAC._activate_preemption_sync_manager()
       sync_point_reached = torch_xla._XLAC._sync_point_reached
@@ -516,8 +516,8 @@ def test_preemption_sync_manager(self):
         time.sleep(1)
       self.assertTrue(success, "Sync point was never reached after SIGTERM")
     finally:
-      # Scope the distributed runtime to the lifespan of the test.
-      torch_xla._XLAC._ensure_dist_runtime_shutdown()
+      # Scope the PreemptionSyncManager to the lifespan of the test.
+      torch_xla._XLAC._deactivate_preemption_sync_manager()
 
 
 if __name__ == '__main__':
diff --git a/torch_xla/csrc/BUILD b/torch_xla/csrc/BUILD
index 352db2d34fb..d4eed1ec83f 100644
--- a/torch_xla/csrc/BUILD
+++ b/torch_xla/csrc/BUILD
@@ -259,6 +259,7 @@ ptxla_cc_library(
         "//torch_xla/csrc/runtime:sys_util",
         "//torch_xla/csrc/runtime:thread_pool",
         "//torch_xla/csrc/runtime:util",
+        "//torch_xla/csrc/runtime:xla_coordinator",
        "//torch_xla/csrc/runtime:xla_util",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp
index e9a0d02aa40..3e2b32cc2d8 100644
--- a/torch_xla/csrc/init_python_bindings.cpp
+++ b/torch_xla/csrc/init_python_bindings.cpp
@@ -38,7 +38,6 @@
 #include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/ops/xla_ops.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
-#include "torch_xla/csrc/runtime/distributed_runtime.h"
 #include "torch_xla/csrc/runtime/metrics.h"
 #include "torch_xla/csrc/runtime/metrics_analysis.h"
 #include "torch_xla/csrc/runtime/metrics_reader.h"
@@ -48,6 +47,7 @@
 #include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/runtime/thread_pool.h"
 #include "torch_xla/csrc/runtime/util.h"
+#include "torch_xla/csrc/runtime/xla_coordinator.h"
 #include "torch_xla/csrc/runtime/xla_util.h"
 #include "torch_xla/csrc/shape_helper.h"
 #include "torch_xla/csrc/tensor_impl.h"
@@ -1877,38 +1877,44 @@ void InitXlaModuleBindings(py::module m) {
        xla::HloModule::CreateFromProto(module_proto, config).value());
    return module->ToString();
  });
-  // Initialize a distributed runtime if one has not already been created.
-  m.def("_ensure_dist_runtime_initialized",
+  // Initialize the XlaCoordinator in the runtime if not already initialized.
+ m.def("_ensure_xla_coordinator_initialized", [](int global_rank, int world_size, std::string master_addr, std::string master_port) { - if (!runtime::DistributedRuntime::IsInitialized()) { - runtime::DistributedRuntime::Initialize(global_rank, world_size, - master_addr, master_port); + auto comp_client = runtime::GetComputationClient(); + if (!comp_client->CoordinatorInitialized()) { + runtime::GetComputationClient()->InitializeCoordinator( + global_rank, world_size, master_addr, master_port); } }, py::arg("global_rank"), py::arg("world_size"), py::arg("master_addr"), py::arg("master_port") = - runtime::DistributedRuntime::kDefaultCoordinatorPort); - // Shutdown the distributed runtime if it's active. - m.def("_ensure_dist_runtime_shutdown", []() { - if (runtime::DistributedRuntime::IsInitialized()) { - runtime::DistributedRuntime::Shutdown(); - } - }); - // Create a PreemptionSyncManager for the DistributedRuntime. The + runtime::XlaCoordinator::kDefaultCoordinatorPort); + // Create a PreemptionSyncManager for the XlaCoordinator. The // PreemptionSyncManager will register a SIGTERM handler as a side effect. m.def("_activate_preemption_sync_manager", []() { - XLA_CHECK(runtime::DistributedRuntime::IsInitialized()) - << "DistributedRuntime must be initialized to register " - "PreemptionSyncManager"; - runtime::DistributedRuntime::Get().ActivatePreemptionSyncManager(); + auto comp_client = runtime::GetComputationClient(); + XLA_CHECK(comp_client->CoordinatorInitialized()) + << "Coordinator must be initialized"; + auto& coordinator = comp_client->GetCoordinator(); + coordinator.ActivatePreemptionSyncManager(); + }); + // Deactivate the PreemptionSyncManager in the XlaCoordinator if one is active + m.def("_deactivate_preemption_sync_manager", []() { + auto comp_client = runtime::GetComputationClient(); + XLA_CHECK(comp_client->CoordinatorInitialized()) + << "Coordinator must be initialized"; + auto& coordinator = comp_client->GetCoordinator(); + coordinator.DeactivatePreemptionSyncManager(); }); // Check whether a sync point has been reached. This method requires that the // distributed runtime be initialized and a PreemptionSyncManager activated. 
m.def("_sync_point_reached", [](int step) { - XLA_CHECK(runtime::DistributedRuntime::IsInitialized()) - << "DistributedRuntime must be initialized"; - return runtime::DistributedRuntime::Get().ReachedSyncPoint(step); + auto comp_client = runtime::GetComputationClient(); + XLA_CHECK(comp_client->CoordinatorInitialized()) + << "Coordinator must be initialized"; + auto& coordinator = comp_client->GetCoordinator(); + return coordinator.ReachedSyncPoint(step); }); m.def("_is_placecholder", [](at::Tensor& input) { XLATensorPtr xtensor = bridge::GetXlaTensor(input); diff --git a/torch_xla/csrc/runtime/BUILD b/torch_xla/csrc/runtime/BUILD index bd8401f0649..da12492d2cd 100644 --- a/torch_xla/csrc/runtime/BUILD +++ b/torch_xla/csrc/runtime/BUILD @@ -63,6 +63,7 @@ cc_library( ":sys_util", ":types", ":util", + ":xla_coordinator", "//torch_xla/csrc:device", "@tsl//tsl/platform:stacktrace_handler", "@xla//xla:literal_util", @@ -88,12 +89,12 @@ cc_library( deps = [ ":computation_client", ":debug_macros", - ":distributed_runtime", ":env_vars", ":multi_wait", ":stablehlo_helper", ":tf_logging", ":thread_pool", + ":xla_coordinator", "@xla//xla:literal", "@xla//xla:shape_util", "@xla//xla/client:xla_computation", @@ -165,9 +166,9 @@ cc_library( ) cc_library( - name = "distributed_runtime", - srcs = ["distributed_runtime.cc"], - hdrs = ["distributed_runtime.h"], + name = "xla_coordinator", + srcs = ["xla_coordinator.cc"], + hdrs = ["xla_coordinator.h"], deps = [ ":debug_macros", ":sys_util", diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index db4bac21916..4d3df1d23ff 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -30,6 +30,9 @@ namespace torch_xla { namespace runtime { +// Forward declaration +class XlaCoordinator; + // Somehow the compiler doesn't allow type that has default member being // used as a default parameter in a method defined in the same scope. // Therefore, ClientExecuteOptions is defined here instead of within @@ -348,6 +351,17 @@ class ComputationClient { // the local devices will be waited for. virtual void WaitDeviceOps(const std::vector& devices) = 0; + // Check whether the XlaCoordinator has been initialized. + virtual bool CoordinatorInitialized() const = 0; + + // Initialize the XlaCoordinator for the runtime. + virtual void InitializeCoordinator(int global_rank, int world_size, + std::string master_addr, + std::string port) = 0; + + // Return the XlaCoordinator for the runtime. + virtual XlaCoordinator& GetCoordinator() = 0; + // Utility API around the vector based Compile() API to compile a single // computation. ComputationPtr Compile(xla::XlaComputation computation, diff --git a/torch_xla/csrc/runtime/distributed_runtime.h b/torch_xla/csrc/runtime/distributed_runtime.h deleted file mode 100644 index 8ea22f28921..00000000000 --- a/torch_xla/csrc/runtime/distributed_runtime.h +++ /dev/null @@ -1,83 +0,0 @@ -#ifndef XLA_CLIENT_DISTRIBUTED_RUNTIME_H_ -#define XLA_CLIENT_DISTRIBUTED_RUNTIME_H_ - -#include - -#include "torch_xla/csrc/runtime/debug_macros.h" -#include "tsl/distributed_runtime/preemption/preemption_sync_manager.h" -#include "xla/pjrt/distributed/distributed.h" - -namespace torch_xla { -namespace runtime { - -// DistributedRuntime serves as the point of entry for all operations which -// required the XLA distributed runtime, such as preemption coordination. 
-class DistributedRuntime {
- public:
-  static inline const std::string kDefaultCoordinatorPort = "8547";
-
-  // Returns true if the distributed runtime has already been initialized.
-  static bool IsInitialized() { return dist_runtime_ != nullptr; }
-
-  // Initialize the shared DistributedRuntime object. This creates a
-  // DistributedRuntimeClient on each worker, and on global_rank 0 initializes
-  // the corresponding DistributedRuntimeService.
-  static void Initialize(int global_rank, int world_size,
-                         std::string master_addr, std::string port) {
-    XLA_CHECK(!IsInitialized()) << "DistributedRuntime already initialized";
-    dist_runtime_ = std::unique_ptr<DistributedRuntime>(
-        new DistributedRuntime(global_rank, world_size, master_addr, port));
-  }
-
-  // Shutdown the distributed runtime. All associated resources will be
-  // released, and subsequent calls to IsInitialized will return false.
-  // The distributed runtime may later be re-initialized.
-  static void Shutdown() {
-    XLA_CHECK(IsInitialized())
-        << "Must initialize distributed runtime before shutdown";
-    dist_runtime_ = nullptr;
-  }
-
-  // Retrieve the shared DistributedRuntime object.
-  static DistributedRuntime& Get() {
-    XLA_CHECK(IsInitialized())
-        << "Must initialize distributed runtime before retrieval";
-    return *dist_runtime_;
-  }
-
-  ~DistributedRuntime();
-  DistributedRuntime(DistributedRuntime const&) = delete;
-  void operator=(DistributedRuntime const&) = delete;
-
-  // Retrieve the DistributedRuntimeClient.
-  std::shared_ptr<xla::DistributedRuntimeClient> GetClient();
-
-  // Register a PreemptionSyncManager for the distributed runtime if none is
-  // active. The PreemptionSyncManager will register a SIGTERM handler, and
-  // when any host has received a preemption notice, all hosts are made aware
-  // through the ReachedSyncPoint API. See the documentation of
-  // tsl::PreemptionSyncManager for the full semantics:
-  // https://github.com/google/tsl/blob/3bbe663/tsl/distributed_runtime/preemption/preemption_sync_manager.h#L34
-  void ActivatePreemptionSyncManager();
-
-  // A pass-throguh API to the PreemptionSyncManager::ReachedSyncPoint.
-  // The PreemptionSyncManager must be activated within the DistributedRuntime.
-  // Returns true when the input step has been identified as a sync point, and
-  // false otherwise.
-  bool ReachedSyncPoint(int step);
-
- private:
-  static inline std::unique_ptr<DistributedRuntime> dist_runtime_;
-
-  DistributedRuntime(int global_rank, int world_size, std::string master_addr,
-                     std::string port);
-
-  std::unique_ptr<xla::DistributedRuntimeService> dist_runtime_service_;
-  std::shared_ptr<xla::DistributedRuntimeClient> dist_runtime_client_;
-  std::shared_ptr<tsl::PreemptionSyncManager> preemption_sync_manager_;
-};
-
-}  // namespace runtime
-}  // namespace torch_xla
-
-#endif  // XLA_CLIENT_DISTRIBUTED_RUNTIME_H_
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc
index ce6c747b3cb..5f49fd2e357 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.cc
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -9,7 +9,7 @@
 #include "pjrt_computation_client.h"
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
-#include "torch_xla/csrc/runtime/distributed_runtime.h"
+#include "torch_xla/csrc/runtime/xla_coordinator.h"
 #include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/multi_wait.h"
 #include "torch_xla/csrc/runtime/stablehlo_helper.h"
@@ -114,12 +114,12 @@ PjRtComputationClient::PjRtComputationClient() {
      std::string master_addr =
          runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
      std::string port = runtime::sys_util::GetEnvString(
-          "XLA_COORDINATOR_PORT", DistributedRuntime::kDefaultCoordinatorPort);
+          "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);
 
-      // Use the DistributedRuntime as the distributed key-value store.
-      DistributedRuntime::Initialize(global_process_rank, global_world_size,
-                                     master_addr, port);
-      auto distributed_client = DistributedRuntime::Get().GetClient();
+      // Use the XlaCoordinator as the distributed key-value store.
+      coordinator_ = std::make_unique<XlaCoordinator>(
+          global_process_rank, global_world_size, master_addr, port);
+      auto distributed_client = coordinator_->GetClient();
      auto allowed_devices =
          std::make_optional<std::set<int>>(std::set{local_process_rank});
      xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
@@ -187,6 +187,26 @@ PjRtComputationClient::PjRtComputationClient() {
    device_locks_.emplace(spmd_device_str, std::make_unique<std::shared_mutex>());
 }
 
+bool PjRtComputationClient::CoordinatorInitialized() const {
+  return coordinator_ != nullptr;
+}
+
+void PjRtComputationClient::InitializeCoordinator(int global_rank,
+                                                  int world_size,
+                                                  std::string master_addr,
+                                                  std::string port) {
+  XLA_CHECK(coordinator_ == nullptr)
+      << "Can only initialize the XlaCoordinator once.";
+  coordinator_ = std::make_unique<XlaCoordinator>(global_rank, world_size,
+                                                  master_addr, port);
+}
+
+XlaCoordinator& PjRtComputationClient::GetCoordinator() {
+  XLA_CHECK(coordinator_ != nullptr)
+      << "XlaCoordinator has not been initialized";
+  return *coordinator_;
+}
+
 void PjRtComputationClient::PjRtData::Assign(
    const torch::lazy::BackendData& data) {
  const PjRtData& pjrt_data = dynamic_cast<const PjRtData&>(data);
diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h
index f4fc73bb79e..0d2715e2653 100644
--- a/torch_xla/csrc/runtime/pjrt_computation_client.h
+++ b/torch_xla/csrc/runtime/pjrt_computation_client.h
@@ -11,6 +11,7 @@
 #include "torch_xla/csrc/runtime/computation_client.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/util.h"
+#include "torch_xla/csrc/runtime/xla_coordinator.h"
 #include "xla/client/xla_computation.h"
 #include "xla/literal.h"
 #include "xla/pjrt/pjrt_client.h"
@@ -89,6 +90,14 @@ class PjRtComputationClient : public ComputationClient {
 
  std::map<std::string, Metric> GetMetrics() const override;
 
+  void InitializeCoordinator(
+      int global_rank, int world_size, std::string master_addr,
+      std::string port = XlaCoordinator::kDefaultCoordinatorPort) override;
+
+  XlaCoordinator& GetCoordinator() override;
+
+  bool CoordinatorInitialized() const override;
+
  // NOT IMPLEMENTED
 
  MemoryInfo GetMemoryInfo(const std::string& device) override {
@@ -97,6 +106,7 @@ class PjRtComputationClient : public ComputationClient {
 
 private:
  std::shared_ptr<xla::PjRtClient> client_;
+  std::unique_ptr<XlaCoordinator> coordinator_;
  // global_ordinals_ tracks a map from PjRtDeviceId to the device's
  // dense global ordinal.
  std::unordered_map<int, int> global_ordinals_;
diff --git a/torch_xla/csrc/runtime/distributed_runtime.cc b/torch_xla/csrc/runtime/xla_coordinator.cc
similarity index 78%
rename from torch_xla/csrc/runtime/distributed_runtime.cc
rename to torch_xla/csrc/runtime/xla_coordinator.cc
index de5313ef7c2..606fe5cb470 100644
--- a/torch_xla/csrc/runtime/distributed_runtime.cc
+++ b/torch_xla/csrc/runtime/xla_coordinator.cc
@@ -1,14 +1,14 @@
-#include "torch_xla/csrc/runtime/distributed_runtime.h"
+#include "torch_xla/csrc/runtime/xla_coordinator.h"
 
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
+#include "xla/pjrt/distributed/distributed.h"
 
 namespace torch_xla {
 namespace runtime {
 
-DistributedRuntime::DistributedRuntime(int global_rank, int world_size,
-                                       std::string master_addr,
-                                       std::string port) {
+XlaCoordinator::XlaCoordinator(int global_rank, int world_size,
+                               std::string master_addr, std::string port) {
  std::string dist_service_addr = absl::StrJoin({master_addr, port}, ":");
  if (global_rank == 0) {
    xla::CoordinationServiceImpl::Options service_options;
@@ -29,7 +29,7 @@ DistributedRuntime::DistributedRuntime(int global_rank, int world_size,
      << "Failed to initialize distributed runtime client";
 }
 
-DistributedRuntime::~DistributedRuntime() {
+XlaCoordinator::~XlaCoordinator() {
  if (dist_runtime_client_ != nullptr) {
    XLA_CHECK(dist_runtime_client_->Shutdown().ok())
        << "Failed to shut down the distributed runtime client.";
@@ -41,13 +41,13 @@ XlaCoordinator::~XlaCoordinator() {
  }
 }
 
-std::shared_ptr<xla::DistributedRuntimeClient> DistributedRuntime::GetClient() {
+std::shared_ptr<xla::DistributedRuntimeClient> XlaCoordinator::GetClient() {
  XLA_CHECK(dist_runtime_client_ != nullptr)
      << "distributed runtime client is null.";
  return dist_runtime_client_;
 }
 
-void DistributedRuntime::ActivatePreemptionSyncManager() {
+void XlaCoordinator::ActivatePreemptionSyncManager() {
  if (preemption_sync_manager_ == nullptr) {
    preemption_sync_manager_ = std::move(tsl::CreatePreemptionSyncManager());
    auto client = dist_runtime_client_->GetCoordinationServiceAgent();
@@ -57,10 +57,14 @@ void XlaCoordinator::ActivatePreemptionSyncManager() {
  }
 }
 
-bool DistributedRuntime::ReachedSyncPoint(int step) {
+void XlaCoordinator::DeactivatePreemptionSyncManager() {
+  preemption_sync_manager_ = nullptr;
+}
+
+bool XlaCoordinator::ReachedSyncPoint(int step) {
  XLA_CHECK(preemption_sync_manager_ != nullptr)
      << "A PreemptionSyncManager has not been registered with the "
-         "DistributedRuntime.";
+         "XlaCoordinator.";
  return preemption_sync_manager_->ReachedSyncPoint(step);
 }
diff --git a/torch_xla/csrc/runtime/xla_coordinator.h b/torch_xla/csrc/runtime/xla_coordinator.h
new file mode 100644
index 00000000000..88cc3e752dd
--- /dev/null
+++ b/torch_xla/csrc/runtime/xla_coordinator.h
@@ -0,0 +1,58 @@
+#ifndef PTXLA_RUNTIME_COORDINATOR_H_
+#define PTXLA_RUNTIME_COORDINATOR_H_
+
+#include <memory>
+
"tsl/distributed_runtime/preemption/preemption_sync_manager.h" + +// Forward declaration +namespace xla { +class DistributedRuntimeClient; +class DistributedRuntimeService; +} // namespace xla + +namespace torch_xla { +namespace runtime { + +// XlaCoordinator serves as the point of entry for all operations which +// required the XLA distributed runtime, such as preemption coordination. +class XlaCoordinator { + public: + static inline const std::string kDefaultCoordinatorPort = "8547"; + + XlaCoordinator(int global_rank, int world_size, std::string master_addr, + std::string port); + + ~XlaCoordinator(); + + // Retrieve the DistributedRuntimeClient. + std::shared_ptr GetClient(); + + // Register a PreemptionSyncManager for the distributed runtime if none is + // active. The PreemptionSyncManager will register a SIGTERM handler, and + // when any host has received a preemption notice, all hosts are made aware + // through the ReachedSyncPoint API. See the documentation of + // tsl::PreemptionSyncManager for the full semantics: + // https://github.com/google/tsl/blob/3bbe663/tsl/distributed_runtime/preemption/preemption_sync_manager.h#L34 + void ActivatePreemptionSyncManager(); + + // If the PreemptionSyncManager is active, this will deactivate it and + // destroy the current instance. + void DeactivatePreemptionSyncManager(); + + // A pass-through API to PreemptionSyncManager::ReachedSyncPoint. + // The PreemptionSyncManager must be activated within the XlaCoordinator. + // Returns true when the input step has been identified as a sync point, and + // false otherwise. + bool ReachedSyncPoint(int step); + + private: + std::unique_ptr dist_runtime_service_; + std::shared_ptr dist_runtime_client_; + std::unique_ptr preemption_sync_manager_; +}; + +} // namespace runtime +} // namespace torch_xla + +#endif // PTXLA_RUNTIME_COORDINATOR_H_