Skip to content

Commit

Permalink
create a NCCL sub-communicator using ncclCommSplit
Browse files: browse the repository at this point in the history
  • Loading branch information
seunghwak committed Nov 13, 2024
1 parent 615447f commit 999c46e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 53 deletions.
31 changes: 27 additions & 4 deletions cpp/include/raft/comms/detail/mpi_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,28 @@ class mpi_comms : public comms_iface {
RAFT_MPI_TRY(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, mpi_comm_));

// initializing NCCL
RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm_, size_, id, rank_));
ncclConfig_t nccl_config = NCCL_CONFIG_INITIALIZER;
nccl_config.splitShare = 1;
RAFT_NCCL_TRY(ncclCommInitRankConfig(&nccl_comm_, size_, id, rank_, &nccl_config));

initialize();
}

/**
 * @brief Builds an mpi_comms around an existing MPI communicator and an
 *        already-created NCCL communicator.
 *
 * No NCCL initialization call is made in this constructor; the supplied
 * `nccl_comm` handle is stored as-is. The communicator size and rank are
 * queried from `mpi_comm`.
 *
 * @param mpi_comm      MPI communicator this object wraps.
 * @param owns_mpi_comm Stored in owns_mpi_comm_. Presumably controls whether
 *                      this object releases `mpi_comm` on destruction --
 *                      TODO confirm against the destructor (not visible here).
 * @param nccl_comm     NCCL communicator handle to store. NOTE(review): it is
 *                      assumed to span the same ranks as `mpi_comm`; nothing
 *                      here checks that.
 * @param stream        Stream used to construct status_ and stored in stream_.
 */
mpi_comms(MPI_Comm mpi_comm, bool owns_mpi_comm, ncclComm_t nccl_comm, rmm::cuda_stream_view stream)
: owns_mpi_comm_(owns_mpi_comm),
mpi_comm_(mpi_comm),
nccl_comm_(nccl_comm),
// size_/rank_ initializers are placeholders only: both members are
// overwritten by the MPI_Comm_size / MPI_Comm_rank queries in the body.
size_(0),
rank_(1),
status_(stream),
next_request_id_(0),
stream_(stream)
{
// MPI must already be initialized before the communicator can be queried;
// RAFT_EXPECTS presumably raises an error when the flag is zero.
int mpi_is_initialized = 0;
RAFT_MPI_TRY(MPI_Initialized(&mpi_is_initialized));
RAFT_EXPECTS(mpi_is_initialized, "ERROR: MPI is not initialized!");
// Replace the placeholder size/rank with the values MPI reports for mpi_comm_.
RAFT_MPI_TRY(MPI_Comm_size(mpi_comm_, &size_));
RAFT_MPI_TRY(MPI_Comm_rank(mpi_comm_, &rank_));

// Remaining setup is delegated to initialize() (defined elsewhere in the class).
initialize();
}
Expand All @@ -150,9 +171,11 @@ class mpi_comms : public comms_iface {

/**
 * @brief Splits this communicator into sub-communicators.
 *
 * Both the MPI communicator and the NCCL communicator are split with the
 * same (color, key) pair, so the resulting MPI and NCCL sub-communicators
 * are built from the same grouping arguments.
 *
 * @param color Ranks passing the same color end up in the same
 *              sub-communicator.
 * @param key   Controls the ordering of ranks inside the new
 *              sub-communicator.
 * @return A new mpi_comms wrapping the split MPI and NCCL communicators;
 *         it is constructed with owns_mpi_comm = true.
 */
std::unique_ptr<comms_iface> comm_split(int color, int key) const
{
  MPI_Comm new_mpi_comm;
  RAFT_MPI_TRY(MPI_Comm_split(mpi_comm_, color, key, &new_mpi_comm));

  // Derive the NCCL sub-communicator from the parent instead of
  // bootstrapping a fresh one; a null config inherits the parent's
  // configuration.
  ncclComm_t new_nccl_comm{};
  RAFT_NCCL_TRY(ncclCommSplit(nccl_comm_, color, key, &new_nccl_comm, nullptr));

  return std::unique_ptr<comms_iface>(new mpi_comms(new_mpi_comm, true, new_nccl_comm, stream_));
}

void barrier() const
Expand Down
54 changes: 5 additions & 49 deletions cpp/include/raft/comms/detail/std_comms.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,55 +140,11 @@ class std_comms : public comms_iface {

std::unique_ptr<comms_iface> comm_split(int color, int key) const
{
rmm::device_uvector<int> d_colors(get_size(), stream_);
rmm::device_uvector<int> d_keys(get_size(), stream_);

update_device(d_colors.data() + get_rank(), &color, 1, stream_);
update_device(d_keys.data() + get_rank(), &key, 1, stream_);

allgather(d_colors.data() + get_rank(), d_colors.data(), 1, datatype_t::INT32, stream_);
allgather(d_keys.data() + get_rank(), d_keys.data(), 1, datatype_t::INT32, stream_);
this->sync_stream(stream_);

std::vector<int> h_colors(get_size());
std::vector<int> h_keys(get_size());

update_host(h_colors.data(), d_colors.data(), get_size(), stream_);
update_host(h_keys.data(), d_keys.data(), get_size(), stream_);

this->sync_stream(stream_);

ncclComm_t nccl_comm;

// Create a structure to allgather...
ncclUniqueId id{};
rmm::device_uvector<ncclUniqueId> d_nccl_ids(get_size(), stream_);

if (key == 0) { RAFT_NCCL_TRY(ncclGetUniqueId(&id)); }

update_device(d_nccl_ids.data() + get_rank(), &id, 1, stream_);

allgather(d_nccl_ids.data() + get_rank(),
d_nccl_ids.data(),
sizeof(ncclUniqueId),
datatype_t::UINT8,
stream_);

auto offset =
std::distance(thrust::make_zip_iterator(h_colors.begin(), h_keys.begin()),
std::find_if(thrust::make_zip_iterator(h_colors.begin(), h_keys.begin()),
thrust::make_zip_iterator(h_colors.end(), h_keys.end()),
[color](auto tuple) { return thrust::get<0>(tuple) == color; }));

auto subcomm_size = std::count(h_colors.begin(), h_colors.end(), color);

update_host(&id, d_nccl_ids.data() + offset, 1, stream_);

this->sync_stream(stream_);

RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm, subcomm_size, id, key));

return std::unique_ptr<comms_iface>(new std_comms(nccl_comm, subcomm_size, key, stream_, true));
ncclComm_t new_nccl_comm{};
RAFT_NCCL_TRY(ncclCommSplit(nccl_comm_, color, key, &new_nccl_comm, nullptr));
int new_nccl_comm_size{};
RAFT_NCCL_TRY(ncclCommCount(new_nccl_comm, &new_nccl_comm_size));
return std::unique_ptr<comms_iface>(new std_comms(new_nccl_comm, new_nccl_comm_size, key, stream_, true));
}

void barrier() const
Expand Down

0 comments on commit 999c46e

Please sign in to comment.