diff --git a/torch_xla/csrc/runtime/pjrt_registry.cc b/torch_xla/csrc/runtime/pjrt_registry.cc
index 6123b8fd889..01abbed0f6a 100644
--- a/torch_xla/csrc/runtime/pjrt_registry.cc
+++ b/torch_xla/csrc/runtime/pjrt_registry.cc
@@ -41,6 +41,46 @@ std::optional<std::string> GetPjRtPluginPath(const std::string& device_type) {
                                             : std::nullopt;
 }
 
+// Creates an XlaCoordinator for this process and wires `kv_get`/`kv_put` to
+// its distributed client, prefixing every key with "gpu:". The coordinator
+// address comes from MASTER_ADDR (default "localhost") and the port from
+// XLA_COORDINATOR_PORT (default XlaCoordinator::kDefaultCoordinatorPort).
+// Whatever `coordinator` holds on entry is replaced by the new instance,
+// which is also the return value.
+std::unique_ptr<XlaCoordinator> SetKeyValueCallback(
+    int global_process_rank, int global_world_size,
+    std::unique_ptr<XlaCoordinator> coordinator,
+    xla::PjRtClient::KeyValueGetCallback& kv_get,
+    xla::PjRtClient::KeyValuePutCallback& kv_put) {
+  TF_VLOG(3) << "function=" << __FUNCTION__ << ": ";
+  std::string master_addr =
+      runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
+  std::string port = runtime::sys_util::GetEnvString(
+      "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort);
+
+  // Use the XlaCoordinator as the distributed key-value store.
+  TF_VLOG(3) << "Creating a XlaCoordinator for global_process_rank="
+             << global_process_rank
+             << ", global_world_size=" << global_world_size
+             << ", master_addr=" << master_addr << ", port=" << port;
+  coordinator = std::make_unique<XlaCoordinator>(
+      global_process_rank, global_world_size, master_addr, port);
+  std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
+      coordinator->GetClient();
+  std::string key_prefix = "gpu:";
+  kv_get = [distributed_client, key_prefix](
+               std::string_view k,
+               absl::Duration timeout) -> xla::StatusOr<std::string> {
+    return distributed_client->BlockingKeyValueGet(absl::StrCat(key_prefix, k),
+                                                   timeout);
+  };
+  kv_put = [distributed_client, key_prefix](std::string_view k,
+                                            std::string_view v) -> xla::Status {
+    return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
+  };
+  return coordinator;
+}
+
 }  // namespace
 
 void RegisterPjRtPlugin(std::string name, std::string library_path) {
@@ -90,39 +130,29 @@ InitializePjRt(const std::string& device_type) {
     int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank);
     int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
     int
global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size); - std::string master_addr = - runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost"); - std::string port = runtime::sys_util::GetEnvString( - "XLA_COORDINATOR_PORT", XlaCoordinator::kDefaultCoordinatorPort); xla::PjRtClient::KeyValueGetCallback kv_get = nullptr; xla::PjRtClient::KeyValuePutCallback kv_put = nullptr; std::optional> allowed_devices; bool spmd = sys_util::GetEnvBool("XLA_USE_SPMD", false); - if (!spmd) { - allowed_devices = std::set{local_process_rank}; - } else if (global_world_size > 1) { - allowed_devices = - std::make_optional>(std::set{local_process_rank}); - // Use the XlaCoordinator as the distributed key-value store. - coordinator = std::make_unique( - global_process_rank, global_world_size, master_addr, port); - std::shared_ptr distributed_client = - coordinator->GetClient(); - std::string key_prefix = "gpu:"; - kv_get = [distributed_client, key_prefix]( - std::string_view k, - absl::Duration timeout) -> xla::StatusOr { - return distributed_client->BlockingKeyValueGet( - absl::StrCat(key_prefix, k), timeout); - }; - kv_put = [distributed_client, key_prefix]( - std::string_view k, std::string_view v) -> xla::Status { - return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v); - }; + if (spmd) { + if (global_world_size > 1) { + coordinator = + SetKeyValueCallback(global_process_rank, global_world_size, + std::move(coordinator), kv_get, kv_put); + } + } else { + if (global_world_size > 1) { + allowed_devices = std::set{local_process_rank}; + coordinator = + SetKeyValueCallback(global_process_rank, global_world_size, + std::move(coordinator), kv_get, kv_put); + } } TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id=" - << global_process_rank << ", num_nodes=" << global_world_size; + << global_process_rank << ", num_nodes=" << global_world_size + << ", spmd case=" << spmd; + xla::GpuClientOptions options; options.allocator_config = 
GetGpuAllocatorConfig(); options.node_id = global_process_rank;