Ensure dist runtime is in sync before shutting down. (#5714)
* Added the new dist runtime class.

* Converted the class to a singleton (a minimal sketch of the pattern follows the change summary below).

* The tests pass.

* Cleaned up.

* Fixed a build error.

* Use std::call_once. But it results in early shutting down the dist rt service before the job finishes.

* Revert "Use std::call_once. But it results in early shutting down the dist rt service before the job finishes."

This reverts commit d6e2b98.

* fix comments

* fix comment
vanbasten23 authored Oct 24, 2023
1 parent 8f121c3 commit fa7828f
Showing 8 changed files with 117 additions and 98 deletions.
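
As the commit notes, the std::call_once attempt shut the distributed runtime service down before the job finished and was reverted; the version that landed instead holds DistributedRuntime as a function-local static inside getInstance (see distributed_runtime.h below), so the shutdown in its destructor runs only at process exit. Here is a minimal, illustrative C++ sketch of that pattern; the Coordinator class and all of its members are placeholders, not the real API:

#include <iostream>
#include <string>

// Sketch of the function-local-static (Meyers) singleton pattern used by
// DistributedRuntime. "Coordinator" is illustrative only.
class Coordinator {
 public:
  // The static local is constructed once, thread-safely, on first use; its
  // destructor runs at normal process exit, so the service is not torn down
  // while the job is still running.
  static Coordinator& getInstance(int rank, const std::string& addr) {
    static Coordinator instance(rank, addr);
    return instance;
  }
  Coordinator(const Coordinator&) = delete;
  void operator=(const Coordinator&) = delete;

  ~Coordinator() {
    std::cout << "shutting down service for rank " << rank_ << "\n";
  }

 private:
  Coordinator(int rank, std::string addr)
      : rank_(rank), addr_(std::move(addr)) {
    std::cout << "starting service at " << addr_ << " for rank " << rank_
              << "\n";
  }
  int rank_;
  std::string addr_;
};

int main() {
  // Repeated calls return the same instance; shutdown happens only at exit.
  Coordinator& c = Coordinator::getInstance(0, "localhost:8547");
  (void)c;
  return 0;
}
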
1 change: 0 additions & 1 deletion test/pjrt/test_runtime_gpu.py
@@ -178,7 +178,6 @@ def _reduce_scatter(pin_layout):
return out.cpu().numpy()

# 2023-08-02 04:16:36.520884: F external/xla/xla/service/layout_assignment.cc:157] Check failed: ShapeUtil::Compatible(shape_layout.shape(), instruction->operand(operand_no)->shape()) f32[1]{0} is not compatible with f32[2]{0} (for operand 0 of instruction %reduce-scatter.10 = f32[1]{0} reduce-scatter(f32[2]{0} %add.5), replica_groups={}, constrain_layout=true, dimensions={0}, to_apply=%AddComputation.6)
@unittest.skip("Failed with known error.")
@parameterized.named_parameters(('pinned', True), ('unpinned', False))
def test_reduce_scatter(self, pin_layout):
results = pjrt.run_multiprocess(self._reduce_scatter, pin_layout)
29 changes: 0 additions & 29 deletions torch_xla/_internal/gpu.py
@@ -1,10 +1,6 @@
import os
import atexit
import torch_xla
import torch_xla.core.xla_env_vars as xenv

distributed_service = None


def num_local_processes() -> int:
"""Returns number of processes to create on this host.
@@ -17,28 +13,3 @@ def num_local_processes() -> int:
"Must set `GPU_NUM_DEVICES` environment variable to use the PjRt GPU client"
os.environ[xenv.LOCAL_WORLD_SIZE] = os.environ[xenv.GPU_NUM_DEVICES]
return int(os.environ[xenv.LOCAL_WORLD_SIZE])


def initialize_distributed_runtime(global_world_size: int) -> None:
"""Configures GPU distributed runtime parameters.
Must be run before using any XLA devices.
Args:
global_world_size: number of devices in the cluster.
"""
if global_world_size > 1:
global distributed_service
if distributed_service is None:
num_nodes = global_world_size
distributed_service = torch_xla._XLAC._xla_get_distributed_runtime_service(
num_nodes)
atexit.register(shutdown_distributed_runtime)


def shutdown_distributed_runtime() -> None:
"""Destroy the distributed runtime after a distributed computation."""
global distributed_service
if distributed_service:
distributed_service.shutdown()
distributed_service = None
13 changes: 0 additions & 13 deletions torch_xla/_internal/pjrt.py
@@ -104,13 +104,6 @@ def _run_singleprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]:
return fn(*args, **kwargs)


def should_initialize_dist_runtime(local_rank: int):
if dist.is_torchelastic_launched():
assert xenv.RANK in os.environ, 'Environment variable is not set.'
return xu.getenv_as(xenv.RANK, int) == 0
return local_rank == 0


@runtime.requires_pjrt
def initialize_multiprocess(local_rank: int, local_world_size: int):
os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_RANK, str(local_rank))
@@ -120,12 +113,6 @@ def initialize_multiprocess(local_rank: int, local_world_size: int):
tpu.configure_topology(local_rank, local_world_size)
elif runtime.device_type() == 'NEURON':
neuron.initialize_env(local_rank)
elif runtime.device_type() in ('GPU', 'ROCM', 'CUDA'):
global_world_size = xu.getenv_as(
xenv.WORLD_SIZE, int, xu.getenv_as(xenv.LOCAL_WORLD_SIZE, int, 1))
assert global_world_size >= 0
if should_initialize_dist_runtime(local_rank):
gpu.initialize_distributed_runtime(global_world_size)

devices = xm.get_xla_supported_devices()
xm.set_replication(xm.xla_device(), devices)
24 changes: 0 additions & 24 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1915,30 +1915,6 @@ void InitXlaModuleBindings(py::module m) {
SetAllReduceToken(device, token);
});

/* The distributed runtime service is used by the PjRt GPU client. */
py::class_<xla::DistributedRuntimeService,
std::unique_ptr<xla::DistributedRuntimeService>>
distributed_runtime_service(m, "DistributedRuntimeService");
distributed_runtime_service.def("shutdown",
&xla::DistributedRuntimeService::Shutdown,
py::call_guard<py::gil_scoped_release>());
m.def(
"_xla_get_distributed_runtime_service",
[](int num_nodes) -> std::unique_ptr<xla::DistributedRuntimeService> {
std::string master_addr =
runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
std::string port =
runtime::sys_util::GetEnvString("XLA_COORDINATOR_PORT", "8547");
std::string dist_service_addr = absl::StrJoin({master_addr, port}, ":");
XLA_CHECK(num_nodes > 0) << "num_nodes must be positive: " << num_nodes;

xla::CoordinationServiceImpl::Options options;
options.num_nodes = num_nodes;
return std::move(
xla::GetDistributedRuntimeService(dist_service_addr, options)
.value());
});

BuildProfilerSubmodule(&m);
BuildLoweringContextSubmodule(&m);

12 changes: 12 additions & 0 deletions torch_xla/csrc/runtime/BUILD
@@ -88,6 +88,7 @@ cc_library(
deps = [
":computation_client",
":debug_macros",
":distributed_runtime",
":env_vars",
":multi_wait",
":stablehlo_helper",
@@ -163,6 +164,17 @@ cc_library(
],
)

cc_library(
name = "distributed_runtime",
srcs = ["distributed_runtime.cc"],
hdrs = ["distributed_runtime.h"],
deps = [
":debug_macros",
":sys_util",
"@xla//xla/pjrt/distributed",
],
)

cc_library(
name = "metrics",
srcs = ["metrics.cc"],
54 changes: 54 additions & 0 deletions torch_xla/csrc/runtime/distributed_runtime.cc
@@ -0,0 +1,54 @@
#include "torch_xla/csrc/runtime/distributed_runtime.h"

#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"

namespace torch_xla {
namespace runtime {

const std::string DistributedRuntime::default_coordinator_port = "8547";

DistributedRuntime::DistributedRuntime(int global_rank, std::string master_addr,
std::string port) {
std::string dist_service_addr = absl::StrJoin({master_addr, port}, ":");
if (global_rank == 0) {
int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
xla::CoordinationServiceImpl::Options service_options;
service_options.num_nodes = global_world_size;
xla::StatusOr<std::unique_ptr<xla::DistributedRuntimeService>>
dist_runtime_service = xla::GetDistributedRuntimeService(
dist_service_addr, service_options);
XLA_CHECK(dist_runtime_service.ok())
<< "Failed to initialize distributed runtime service.";
dist_runtime_service_ = std::move(dist_runtime_service.value());
}

xla::DistributedRuntimeClient::Options client_options;
client_options.node_id = global_rank;
dist_runtime_client_ =
xla::GetDistributedRuntimeClient(dist_service_addr, client_options);
XLA_CHECK(dist_runtime_client_->Connect().ok())
<< "Failed to initialize distributed runtime client";
}

DistributedRuntime::~DistributedRuntime() {
if (dist_runtime_client_ != nullptr) {
XLA_CHECK(dist_runtime_client_->Shutdown().ok())
<< "Failed to shut down the distributed runtime client.";
dist_runtime_client_ = nullptr;
}
if (dist_runtime_service_ != nullptr) {
dist_runtime_service_->Shutdown();
dist_runtime_service_ = nullptr;
}
}

std::shared_ptr<xla::DistributedRuntimeClient> DistributedRuntime::GetClient() {
XLA_CHECK(dist_runtime_client_ != nullptr)
<< "distributed runtime client is null.";
return dist_runtime_client_;
}

} // namespace runtime
} // namespace torch_xla
38 changes: 38 additions & 0 deletions torch_xla/csrc/runtime/distributed_runtime.h
@@ -0,0 +1,38 @@
#ifndef XLA_CLIENT_DISTRIBUTED_RUNTIME_H_
#define XLA_CLIENT_DISTRIBUTED_RUNTIME_H_

#include <memory>

#include "xla/pjrt/distributed/distributed.h"

namespace torch_xla {
namespace runtime {

class DistributedRuntime {
public:
static const std::string default_coordinator_port;
static DistributedRuntime& getInstance(int global_rank,
std::string master_addr,
std::string port) {
static DistributedRuntime dist_runtime_instance(global_rank, master_addr,
port);
return dist_runtime_instance;
}
~DistributedRuntime();
DistributedRuntime(DistributedRuntime const&) = delete;
void operator=(DistributedRuntime const&) = delete;

std::shared_ptr<xla::DistributedRuntimeClient> GetClient();

private:
DistributedRuntime(int global_rank, std::string master_addr,
std::string port);

std::unique_ptr<xla::DistributedRuntimeService> dist_runtime_service_;
std::shared_ptr<xla::DistributedRuntimeClient> dist_runtime_client_;
};

} // namespace runtime
} // namespace torch_xla

#endif // XLA_CLIENT_DISTRIBUTED_RUNTIME_H_
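
As a usage note — a sketch under the assumption that the header above is available on the include path — a caller obtains the shared coordination client through the singleton as shown here. It mirrors the call site added to pjrt_computation_client.cc below; GetCoordClient is an illustrative helper, not part of this change:

#include <memory>
#include <string>

#include "torch_xla/csrc/runtime/distributed_runtime.h"

// Illustrative helper: fetch the process-wide coordination client.
std::shared_ptr<xla::DistributedRuntimeClient> GetCoordClient(int global_rank) {
  std::string master_addr = "localhost";  // normally read from MASTER_ADDR
  std::string port =
      torch_xla::runtime::DistributedRuntime::default_coordinator_port;
  return torch_xla::runtime::DistributedRuntime::getInstance(
             global_rank, master_addr, port)
      .GetClient();
}
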
44 changes: 13 additions & 31 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
@@ -9,6 +9,7 @@
#include "pjrt_computation_client.h"
#include "torch_xla/csrc/runtime/computation_client.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/distributed_runtime.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/multi_wait.h"
#include "torch_xla/csrc/runtime/stablehlo_helper.h"
@@ -37,29 +38,6 @@ namespace {

static std::string spmd_device_str = "SPMD:0";

// Initializes a distributed runtime client if dist_service_addr is specified
std::shared_ptr<xla::DistributedRuntimeClient>
MaybeInitializeDistributedRuntimeClient(int local_rank) {
std::shared_ptr<xla::DistributedRuntimeClient> client;
int global_world_size = sys_util::GetEnvInt(
"WORLD_SIZE", sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1));
if (global_world_size < 2) {
return std::move(client);
}
std::string master_addr = sys_util::GetEnvString("MASTER_ADDR", "localhost");
std::string port =
runtime::sys_util::GetEnvString("XLA_COORDINATOR_PORT", "8547");
std::string dist_service_addr = absl::StrJoin({master_addr, port}, ":");
xla::DistributedRuntimeClient::Options options;
options.node_id = local_rank;
TF_VLOG(3) << "Getting distributed runtime client for address="
<< dist_service_addr << ", node_id=" << options.node_id;
client = xla::GetDistributedRuntimeClient(dist_service_addr, options);
XLA_CHECK(client->Connect().ok())
<< "Failed to initialize distributed runtime client";
return std::move(client);
}

// Builds a map from the device's global ordinal to its index in the `devices`
// array.
std::unordered_map<int, int> build_index_map(
@@ -131,10 +109,14 @@ PjRtComputationClient::PjRtComputationClient() {
TF_VLOG(1) << "Initializing PjRt GPU client...";
bool async = sys_util::GetEnvBool(env::kEnvPjrtAsyncGpuClient, true);
int local_process_rank = sys_util::GetEnvInt(env::kEnvPjRtLocalRank, 0);

int global_process_rank = sys_util::GetEnvInt("RANK", local_process_rank);
auto distributed_client =
MaybeInitializeDistributedRuntimeClient(global_process_rank);
std::string master_addr =
runtime::sys_util::GetEnvString("MASTER_ADDR", "localhost");
std::string port = runtime::sys_util::GetEnvString(
"XLA_COORDINATOR_PORT", DistributedRuntime::default_coordinator_port);
std::shared_ptr<xla::DistributedRuntimeClient> distributed_client =
DistributedRuntime::getInstance(global_process_rank, master_addr, port)
.GetClient();
auto allowed_devices =
std::make_optional<std::set<int>>(std::set{local_process_rank});
xla::PjRtClient::KeyValueGetCallback kv_get = nullptr;
@@ -151,15 +133,15 @@ PjRtComputationClient::PjRtComputationClient() {
return distributed_client->KeyValueSet(absl::StrCat(key_prefix, k), v);
};
}
int global_world_size = sys_util::GetEnvInt(
"WORLD_SIZE", sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1));
int local_world_size = sys_util::GetEnvInt("LOCAL_WORLD_SIZE", 1);
int global_world_size = sys_util::GetEnvInt("WORLD_SIZE", local_world_size);
TF_VLOG(3) << "Getting StreamExecutorGpuClient for node_id="
<< global_process_rank << ", num_nodes=" << global_world_size;
client_ = std::move(xla::GetStreamExecutorGpuClient(
/*asynchronous=*/async, xla::GpuAllocatorConfig{},
/*asynchronous=*/async,
/*allocator_config=*/xla::GpuAllocatorConfig{},
/*node_id=*/global_process_rank,
/*num_nodes=*/
global_world_size,
/*num_nodes=*/global_world_size,
/*allowed_devices=*/allowed_devices,
/*platform_name=*/"gpu",
/*should_stage_host_to_device_transfers=*/true,
