diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index fbc75bc5e4d3..29ed825d015b 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -342,9 +342,12 @@ def setUp(self): super().setUp() # Initialize the a minimal process group dist.init_process_group( - backend='gloo', init_method='tcp://127.1:8932', world_size=1, rank=0) + backend='gloo', + init_method='tcp://localhost:8932', + world_size=1, + rank=0) torch_xla._XLAC._ensure_xla_coordinator_initialized( - global_rank=0, world_size=1, master_addr="127.1") + global_rank=0, world_size=1, master_addr="localhost") def tearDown(self): super().tearDown() diff --git a/torch_xla/csrc/runtime/computation_client.h b/torch_xla/csrc/runtime/computation_client.h index 4d3df1d23ff6..145a6d0aa091 100644 --- a/torch_xla/csrc/runtime/computation_client.h +++ b/torch_xla/csrc/runtime/computation_client.h @@ -30,7 +30,10 @@ namespace torch_xla { namespace runtime { -// Forward declaration +// Forward declare XlaCoordinator to avoid logging macro redefinition from the +// transitively included PJRT header. +// TODO(jonbolin): We need a way to ensure the right macros are included +// regardless of the import order. 
class XlaCoordinator; // Somehow the compiler doesn't allow type that has default member being diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.cc b/torch_xla/csrc/runtime/pjrt_computation_client.cc index 5f49fd2e3570..dc6ad1df6701 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.cc +++ b/torch_xla/csrc/runtime/pjrt_computation_client.cc @@ -9,12 +9,12 @@ #include "pjrt_computation_client.h" #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" -//#include "torch_xla/csrc/runtime/xla_coordinator.h" #include "torch_xla/csrc/runtime/env_vars.h" #include "torch_xla/csrc/runtime/multi_wait.h" #include "torch_xla/csrc/runtime/stablehlo_helper.h" #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/thread_pool.h" +#include "torch_xla/csrc/runtime/xla_coordinator.h" #include "tsl/profiler/lib/traceme.h" #include "xla/client/xla_builder.h" #include "xla/client/xla_computation.h" @@ -119,7 +119,8 @@ PjRtComputationClient::PjRtComputationClient() { // Use the XlaCoordinator as the distributed key-value store. 
coordinator_ = std::make_unique<XlaCoordinator>( global_process_rank, global_world_size, master_addr, port); - auto distributed_client = coordinator_->GetClient(); + std::shared_ptr<xla::DistributedRuntimeClient> distributed_client = + coordinator_->GetClient(); auto allowed_devices = std::make_optional<std::set<int>>(std::set{local_process_rank}); xla::PjRtClient::KeyValueGetCallback kv_get = nullptr; @@ -187,6 +188,11 @@ PjRtComputationClient::PjRtComputationClient() { device_locks_.emplace(spmd_device_str, std::make_unique<std::shared_mutex>()); } +PjRtComputationClient::~PjRtComputationClient() { + coordinator_ = nullptr; + client_ = nullptr; +} + bool PjRtComputationClient::CoordinatorInitialized() const { return coordinator_ != nullptr; } diff --git a/torch_xla/csrc/runtime/pjrt_computation_client.h b/torch_xla/csrc/runtime/pjrt_computation_client.h index 0d2715e26537..faebd4892b87 100644 --- a/torch_xla/csrc/runtime/pjrt_computation_client.h +++ b/torch_xla/csrc/runtime/pjrt_computation_client.h @@ -11,7 +11,6 @@ #include "torch_xla/csrc/runtime/computation_client.h" #include "torch_xla/csrc/runtime/debug_macros.h" #include "torch_xla/csrc/runtime/util.h" -#include "torch_xla/csrc/runtime/xla_coordinator.h" #include "xla/client/xla_computation.h" #include "xla/literal.h" #include "xla/pjrt/pjrt_client.h" @@ -24,6 +23,7 @@ namespace runtime { class PjRtComputationClient : public ComputationClient { public: PjRtComputationClient(); + ~PjRtComputationClient(); DataPtr CreateDataPlaceholder(std::string device, xla::Shape shape) override; @@ -90,9 +90,9 @@ class PjRtComputationClient : public ComputationClient { std::map<std::string, Metric> GetMetrics() const override; - void InitializeCoordinator( - int global_rank, int world_size, std::string master_addr, - std::string port = XlaCoordinator::kDefaultCoordinatorPort) override; + void InitializeCoordinator(int global_rank, int world_size, + std::string master_addr, + std::string port) override; XlaCoordinator& GetCoordinator() override; diff --git a/torch_xla/csrc/runtime/xla_coordinator.cc 
b/torch_xla/csrc/runtime/xla_coordinator.cc index 606fe5cb470a..72855d8681ea 100644 --- a/torch_xla/csrc/runtime/xla_coordinator.cc +++ b/torch_xla/csrc/runtime/xla_coordinator.cc @@ -30,6 +30,7 @@ XlaCoordinator::XlaCoordinator(int global_rank, int world_size, } XlaCoordinator::~XlaCoordinator() { + preemption_sync_manager_ = nullptr; if (dist_runtime_client_ != nullptr) { XLA_CHECK(dist_runtime_client_->Shutdown().ok()) << "Failed to shut down the distributed runtime client."; diff --git a/torch_xla/csrc/runtime/xla_coordinator.h b/torch_xla/csrc/runtime/xla_coordinator.h index 88cc3e752dd9..ae85c79a9416 100644 --- a/torch_xla/csrc/runtime/xla_coordinator.h +++ b/torch_xla/csrc/runtime/xla_coordinator.h @@ -4,12 +4,7 @@ #include <memory> #include "tsl/distributed_runtime/preemption/preemption_sync_manager.h" - -// Forward declaration -namespace xla { -class DistributedRuntimeClient; -class DistributedRuntimeService; -} // namespace xla +#include "xla/pjrt/distributed/distributed.h" namespace torch_xla { namespace runtime {