[xla:collectives] NFC: Move Config and DeviceRank from NcclApi to Collectives API

PiperOrigin-RevId: 702919277
ezhulenev authored and tensorflower-gardener committed Dec 5, 2024
1 parent fc97f62 commit 9184c12
Showing 9 changed files with 132 additions and 41 deletions.
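
In short, the NCCL-specific nested types move to the backend-agnostic collectives API. A minimal before/after sketch of the affected spellings, condensed from the diffs below (`stream_executor` and `rank` stand in for the values at the real call sites):

// Before: NCCL-specific nested types.
NcclApi::Config config;
NcclApi::DeviceRank device_rank = {stream_executor, rank};

// After: backend-agnostic types on the Collectives/GpuCollectives API.
GpuCollectives::Config config;                       // declared in gpu_collectives.h
GpuCollectives::Device gpu_device(stream_executor);  // wraps se::StreamExecutor
GpuCollectives::DeviceRank device_rank = {&gpu_device, rank};
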
8 changes: 8 additions & 0 deletions third_party/xla/xla/backends/gpu/collectives/BUILD
@@ -85,6 +85,7 @@ cc_library(
"//xla:status_macros",
"//xla:types",
"//xla:util",
"//xla/core/collectives",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:communicator",
"//xla/core/collectives:rank_id",
@@ -115,12 +116,19 @@ cc_library(

cc_library(
name = "gpu_collectives",
srcs = ["gpu_collectives.cc"],
hdrs = ["gpu_collectives.h"],
deps = [
"//xla:shape_util",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/core/collectives",
"//xla/core/collectives:clique_id",
"//xla/core/collectives:clique_key",
"//xla/stream_executor:device_memory",
"//xla/stream_executor:stream_executor_h",
"@com_google_absl//absl/status:statusor",
"@local_tsl//tsl/platform:casts",
],
)

18 changes: 11 additions & 7 deletions third_party/xla/xla/backends/gpu/collectives/gpu_clique_locking.cc
@@ -42,6 +42,7 @@ limitations under the License.
#include "xla/backends/gpu/collectives/gpu_clique_key.h"
#include "xla/backends/gpu/collectives/gpu_collectives.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/collectives.h"
#include "xla/core/collectives/communicator.h"
#include "xla/core/collectives/rank_id.h"
#include "xla/debug_options_flags.h"
@@ -60,6 +61,8 @@ limitations under the License.

namespace xla::gpu {

using DeviceRank = Collectives::DeviceRank;

//===----------------------------------------------------------------------===//
// GpuClique Acquire and Initialization Timeouts
//===----------------------------------------------------------------------===//
@@ -178,7 +181,7 @@ static void StartGpuCliqueHeartBeatMonitor() {
// defined order and do not deadlock inside underlying collective communication
// library.

static auto DeviceRanksToString(absl::Span<const NcclApi::DeviceRank> ranks) {
static auto DeviceRanksToString(absl::Span<const DeviceRank> ranks) {
return absl::StrJoin(ranks, ",", [](std::string* str, auto& rank) {
str->append(std::to_string(rank.rank.value()));
});
@@ -192,15 +195,15 @@ InitializeGpuClique(NcclApi* nccl_api, se::StreamExecutor* device, RunId run_id,
const GpuCliqueKey& clique_key,
const GpuCollectives::CliqueIdCallback& clique_id_callback,
int32_t num_local_participants, RankId rank,
NcclApi::Config& config) {
const GpuCollectives::Config& config) {
int nranks = clique_key.devices().size();
VLOG(3) << "Initialize GPU clique " << clique_key.ToString() << " rank #"
<< rank << "; num_local_participants=" << num_local_participants;

// Start the GPU clique heartbeat monitor when creating the first clique.
StartGpuCliqueHeartBeatMonitor();

using RendezvousArg = std::pair<NcclApi::DeviceRank, /*synchronized=*/bool>;
using RendezvousArg = std::pair<DeviceRank, /*synchronized=*/bool>;

// Initializes a GpuClique for given device ranks and returns a lock that
// gives access to clique communicators.
@@ -219,7 +222,7 @@ InitializeGpuClique(NcclApi* nccl_api, se::StreamExecutor* device, RunId run_id,
}
}

std::vector<NcclApi::DeviceRank> ranks;
std::vector<DeviceRank> ranks;
ranks.reserve(args.size());
for (auto* arg : args) ranks.emplace_back(arg->first);

@@ -274,7 +277,8 @@ InitializeGpuClique(NcclApi* nccl_api, se::StreamExecutor* device, RunId run_id,
absl::StrFormat("initialize clique for rank %d; clique=%s; run_id=%d",
rank.value(), clique_key.ToString(), run_id.ToInt());

NcclApi::DeviceRank device_rank = {device, rank};
GpuCollectives::Device gpu_device(device);
GpuCollectives::DeviceRank device_rank = {&gpu_device, rank};
bool synchronized = device->SynchronizeAllActivity();

// We choose not to exit early on failed synchronization, because it will lead
@@ -317,7 +321,7 @@ InitializeGpuClique(NcclApi* nccl_api, se::StreamExecutor* device, RunId run_id,
const GpuCliqueKey& clique_key,
std::shared_ptr<LockableGpuClique::Lock> parent_clique,
int32_t num_local_participants, RankId rank,
NcclApi::Config& config) {
const GpuCollectives::Config& config) {
// Find our rank in the parent clique.
const GpuCliqueKey& parent_clique_key = (*parent_clique)->key();
RankId parent_rank =
@@ -471,7 +475,7 @@ absl::StatusOr<std::shared_ptr<LockableGpuClique::Lock>> AcquireGpuClique(

// We enable resource sharing between parent and split communicators by
// default because that's the only reason why we use comm splitting.
NcclApi::Config config;
GpuCollectives::Config config;
config.split_share = true;
config.max_nchannels = max_nchannels;

52 changes: 52 additions & 0 deletions third_party/xla/xla/backends/gpu/collectives/gpu_collectives.cc
@@ -0,0 +1,52 @@
/* Copyright 2024 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/backends/gpu/collectives/gpu_collectives.h"

#include <cstddef>

#include "absl/status/statusor.h"
#include "xla/core/collectives/collectives.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/util.h"
#include "tsl/platform/casts.h"

namespace xla::gpu {

GpuCollectives::Device::Device(se::StreamExecutor* stream_executor)
: stream_executor_(stream_executor) {}

se::StreamExecutor* GpuCollectives::Device::stream_executor() const {
return stream_executor_;
}

absl::StatusOr<GpuCollectives::Device*> GpuCollectives::TryCast(
Collectives::Device* device) {
if (auto* gpu_device = tsl::down_cast<Device*>(device)) {
return gpu_device;
}
return InvalidArgument("Collectives device is not a GPU device");
}

se::DeviceMemoryBase GpuCollectives::Slice(se::DeviceMemoryBase buff,
PrimitiveType dtype, size_t offset,
size_t count) {
size_t multiplier = ShapeUtil::ByteSizeOfPrimitiveType(dtype);
return buff.GetByteSlice(offset * multiplier, count * multiplier);
}

} // namespace xla::gpu
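
Slice works purely in byte arithmetic: both the element offset and the element count are scaled by ShapeUtil::ByteSizeOfPrimitiveType(dtype) before taking a byte slice. A minimal usage sketch (the pointer and sizes below are illustrative assumptions, not part of the commit):

// Assume `dev_ptr` points to a device allocation of 1024 F32 values (4 bytes each).
void* dev_ptr = /* from a StreamExecutor allocation */ nullptr;
se::DeviceMemoryBase buffer(dev_ptr, 1024 * sizeof(float));

// 256 elements starting at element 512:
//   byte offset = 512 * 4 = 2048, byte length = 256 * 4 = 1024.
se::DeviceMemoryBase chunk =
    GpuCollectives::Slice(buffer, F32, /*offset=*/512, /*count=*/256);
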
34 changes: 34 additions & 0 deletions third_party/xla/xla/backends/gpu/collectives/gpu_collectives.h
@@ -16,12 +16,17 @@ limitations under the License.
#ifndef XLA_BACKENDS_GPU_COLLECTIVES_GPU_COLLECTIVES_H_
#define XLA_BACKENDS_GPU_COLLECTIVES_GPU_COLLECTIVES_H_

#include <cstddef>
#include <cstdint>
#include <functional>

#include "absl/status/statusor.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/clique_key.h"
#include "xla/core/collectives/collectives.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/xla_data.pb.h"

namespace xla::gpu {

@@ -32,13 +37,42 @@ class GpuCollectives : public Collectives {
using CliqueIdCallback = // NOLINT
std::function<absl::StatusOr<CliqueId>(const CliqueKey&)>;

// GPU collectives device is just a wrapper around the StreamExecutor.
class Device : public Collectives::Device {
public:
explicit Device(stream_executor::StreamExecutor* stream_executor);
stream_executor::StreamExecutor* stream_executor() const;

private:
stream_executor::StreamExecutor* stream_executor_;
};

// Casts a Collectives::Device to a GPU device and returns an error if it's
// not a GPU device.
static absl::StatusOr<Device*> TryCast(Collectives::Device* device);

// GPU communicator configuration.
//
// For NCCL backend see configuration options documentation at:
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig
struct Config {
bool split_share = false;
int64_t max_nchannels = 0;
};

// Returns true if collectives backend uses global config.
virtual bool IsGlobalConfig() const = 0;

// Returns a clique id callback passed as an argument if it's not null or a
// default callback to create a clique id if we are running in local mode.
virtual absl::StatusOr<const CliqueIdCallback*> GetCliqueIdCallback(
const CliqueIdCallback* clique_id_callback, bool is_local) = 0;

// Returns a slice of device memory `buff` containing `count` values of data
// type `dtype` starting from `offset`.
static stream_executor::DeviceMemoryBase Slice(
stream_executor::DeviceMemoryBase buff, PrimitiveType dtype,
size_t offset, size_t count);
};

} // namespace xla::gpu
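
Taken together, the call sites updated in this commit use the new types roughly as follows (a condensed sketch of the gpu_clique_locking.cc and nccl_api.cc changes elsewhere in this diff; error handling and surrounding control flow omitted, and `stream_executor`, `rank`, and `max_nchannels` stand in for the values at the real call sites):

// Producer side (clique initialization): build a config and a device rank.
GpuCollectives::Config config;
config.split_share = true;  // share resources between parent and split comms
config.max_nchannels = max_nchannels;

GpuCollectives::Device gpu_device(stream_executor);
GpuCollectives::DeviceRank device_rank = {&gpu_device, rank};

// Consumer side (NCCL backend): recover the StreamExecutor from the opaque device.
TF_ASSIGN_OR_RETURN(auto* device, GpuCollectives::TryCast(device_rank.device));
auto activate_context = device->stream_executor()->Activate();
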
1 change: 1 addition & 0 deletions third_party/xla/xla/core/collectives/BUILD
@@ -35,6 +35,7 @@ cc_library(
deps = [
":clique_id",
":communicator",
":rank_id",
"@com_google_absl//absl/status:statusor",
],
)
16 changes: 16 additions & 0 deletions third_party/xla/xla/core/collectives/collectives.h
@@ -18,6 +18,7 @@ limitations under the License.

#include "absl/status/statusor.h"
#include "xla/core/collectives/clique_id.h"
#include "xla/core/collectives/rank_id.h"

namespace xla {

@@ -35,6 +36,21 @@ class Collectives {
public:
virtual ~Collectives() = default;

// A base class for the device that the collectives are running on, i.e. in
// XLA:GPU this is the GPU device (StreamExecutor).
class Device {
public:
virtual ~Device() = default;
};

// A collective device together with its rank in the collective clique.
struct DeviceRank {
DeviceRank(Device* device, RankId rank) : device(device), rank(rank) {}

Device* device;
RankId rank;
};

// Creates a unique CliqueId.
virtual absl::StatusOr<CliqueId> CreateUniqueCliqueId() const = 0;
};
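
The base Device class is deliberately empty: each backend derives its own wrapper and downcasts where backend-specific state is needed, as GpuCollectives::TryCast does above. A purely hypothetical sketch of another backend's wrapper, not part of this commit (uses only collectives.h, rank_id.h, and tsl/platform/casts.h):

// Hypothetical example: a CPU-style backend device wrapping a NUMA node id.
class MyCpuDevice : public Collectives::Device {
 public:
  explicit MyCpuDevice(int numa_node) : numa_node_(numa_node) {}
  int numa_node() const { return numa_node_; }

 private:
  int numa_node_;
};

// Backend code holds Collectives::Device* and downcasts when needed:
MyCpuDevice cpu_device(/*numa_node=*/0);
Collectives::DeviceRank device_rank(&cpu_device, RankId(0));
auto* typed = tsl::down_cast<MyCpuDevice*>(device_rank.device);
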
16 changes: 8 additions & 8 deletions third_party/xla/xla/service/gpu/runtime/nccl_all_to_all_thunk.cc
@@ -236,12 +236,12 @@ absl::Status RunAllToAll(NcclApi* nccl_api, bool has_split_dimension,

for (int peer = 0; peer < num_ranks; ++peer) {
se::DeviceMemoryBase send_slice =
NcclApi::Slice(buffer.source_buffer, buffer.element_type,
peer * chunk_elements, chunk_elements);
nccl_api->Slice(buffer.source_buffer, buffer.element_type,
peer * chunk_elements, chunk_elements);

se::DeviceMemoryBase recv_slice =
NcclApi::Slice(buffer.destination_buffer, buffer.element_type,
peer * chunk_elements, chunk_elements);
nccl_api->Slice(buffer.destination_buffer, buffer.element_type,
peer * chunk_elements, chunk_elements);

TF_RETURN_IF_ERROR(nccl_api->Send(send_slice, buffer.element_type,
chunk_elements, peer, comm, &stream));
@@ -298,8 +298,8 @@ absl::Status RunMemCpyAllToAll(
TF_RETURN_IF_ERROR(nccl_api->GroupStart());
for (int peer = 0; peer < num_ranks; ++peer) {
se::DeviceMemoryBase recv_slice =
NcclApi::Slice(buffer.destination_buffer, buffer.element_type,
peer * chunk_elements, chunk_elements);
nccl_api->Slice(buffer.destination_buffer, buffer.element_type,
peer * chunk_elements, chunk_elements);
send_pointer_map[peer] = (uint64_t)recv_slice.opaque();

TF_RETURN_IF_ERROR(nccl_api->SendPtrToPeer(&send_pointer_map[peer],
@@ -312,8 +312,8 @@

for (int peer = 0; peer < num_ranks; ++peer) {
se::DeviceMemoryBase send_slice =
NcclApi::Slice(buffer.source_buffer, buffer.element_type,
peer * chunk_elements, chunk_elements);
nccl_api->Slice(buffer.source_buffer, buffer.element_type,
peer * chunk_elements, chunk_elements);
se::DeviceMemoryBase dst_addr =
se::DeviceMemoryBase((void*)receive_pointer_map[peer]);
TF_RETURN_IF_ERROR(
3 changes: 2 additions & 1 deletion third_party/xla/xla/service/gpu/runtime/nccl_api.cc
@@ -390,7 +390,8 @@ DefaultNcclApi::CommInitRanks(int32_t nranks, const CliqueId& clique_id,
VLOG(1) << "Initialize NCCL communicator for rank #" << ranks[i].rank
<< " of " << nranks
<< "; fingerprint(id)=" << clique_id.fingerprint();
auto activate_context = ranks[i].device->Activate();
TF_ASSIGN_OR_RETURN(auto* device, TryCast(ranks[i].device));
auto activate_context = device->stream_executor()->Activate();

TF_ASSIGN_OR_RETURN(auto nccl_unique_id, AsNcclUniqueId(clique_id));
XLA_NCCL_RETURN_IF_ERROR(
25 changes: 0 additions & 25 deletions third_party/xla/xla/service/gpu/runtime/nccl_api.h
@@ -52,14 +52,6 @@ class NcclApi : public GpuCollectives {
public:
virtual ~NcclApi() = default;

// Communicator configuration.
//
// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig
struct Config {
bool split_share = false;
int64_t max_nchannels = 0;
};

// Returns a default NcclApi for a current process. Can be a real one based on
// NCCL or a stub if XLA compiled without NCCL or CUDA support.
static NcclApi* Default();
@@ -119,23 +111,6 @@ class NcclApi : public GpuCollectives {
tsl::RCReference<PersistentPlanAllocator> allocator_;
};

struct DeviceRank {
DeviceRank(se::StreamExecutor* device, RankId rank)
: device(device), rank(rank) {}

se::StreamExecutor* device;
RankId rank;
};

// Returns a slice of device memory `buff` containing `count` values of data
// type `dtype` starting from `offset`.
static se::DeviceMemoryBase Slice(se::DeviceMemoryBase buff,
PrimitiveType dtype, size_t offset,
size_t count) {
size_t multiplier = ShapeUtil::ByteSizeOfPrimitiveType(dtype);
return buff.GetByteSlice(offset * multiplier, count * multiplier);
}

// Creates new communicators for given devices.
//
// This API doesn't have a corresponding API in NCCL and implemented as
