diff --git a/cpp/include/raft/comms/detail/mpi_comms.hpp b/cpp/include/raft/comms/detail/mpi_comms.hpp
index a07146052e..1825da0233 100644
--- a/cpp/include/raft/comms/detail/mpi_comms.hpp
+++ b/cpp/include/raft/comms/detail/mpi_comms.hpp
@@ -126,7 +126,28 @@ class mpi_comms : public comms_iface {
     RAFT_MPI_TRY(MPI_Bcast((void*)&id, sizeof(id), MPI_BYTE, 0, mpi_comm_));
 
     // initializing NCCL
-    RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm_, size_, id, rank_));
+    ncclConfig_t nccl_config = NCCL_CONFIG_INITIALIZER;
+    nccl_config.splitShare = 1;  // per NCCL docs: share resources with ncclCommSplit children
+    RAFT_NCCL_TRY(ncclCommInitRankConfig(&nccl_comm_, size_, id, rank_, &nccl_config));
+
+    initialize();
+  }
+
+  mpi_comms(MPI_Comm mpi_comm, bool owns_mpi_comm, ncclComm_t nccl_comm, rmm::cuda_stream_view stream)
+    : owns_mpi_comm_(owns_mpi_comm),
+      mpi_comm_(mpi_comm),
+      nccl_comm_(nccl_comm),  // caller-supplied NCCL communicator; none is created here
+      size_(0),               // placeholder, overwritten by MPI_Comm_size below
+      rank_(1),               // placeholder, overwritten by MPI_Comm_rank below
+      status_(stream),
+      next_request_id_(0),
+      stream_(stream)
+  {
+    int mpi_is_initialized = 0;
+    RAFT_MPI_TRY(MPI_Initialized(&mpi_is_initialized));
+    RAFT_EXPECTS(mpi_is_initialized, "ERROR: MPI is not initialized!");
+    RAFT_MPI_TRY(MPI_Comm_size(mpi_comm_, &size_));  // size/rank come from the MPI communicator
+    RAFT_MPI_TRY(MPI_Comm_rank(mpi_comm_, &rank_));  // NOTE(review): assumes NCCL rank matches
 
     initialize();
   }
@@ -150,9 +171,11 @@ class mpi_comms : public comms_iface {
 
   std::unique_ptr comm_split(int color, int key) const
   {
-    MPI_Comm new_comm;
-    RAFT_MPI_TRY(MPI_Comm_split(mpi_comm_, color, key, &new_comm));
-    return std::unique_ptr(new mpi_comms(new_comm, true, stream_));
+    MPI_Comm new_mpi_comm;
+    RAFT_MPI_TRY(MPI_Comm_split(mpi_comm_, color, key, &new_mpi_comm));
+    ncclComm_t new_nccl_comm{};  // NOTE(review): new_mpi_comm is not freed if the split below fails
+    RAFT_NCCL_TRY(ncclCommSplit(nccl_comm_, color, key, &new_nccl_comm, nullptr));  // same color/key
+    return std::unique_ptr(new mpi_comms(new_mpi_comm, true, new_nccl_comm, stream_));  // true = owns_mpi_comm
   }
 
   void barrier() const
diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp
index ed869e6cae..a5214fcf8b 100644
--- a/cpp/include/raft/comms/detail/std_comms.hpp
+++ 
b/cpp/include/raft/comms/detail/std_comms.hpp
@@ -140,55 +140,18 @@ class std_comms : public comms_iface {
 
   std::unique_ptr comm_split(int color, int key) const
   {
-    rmm::device_uvector d_colors(get_size(), stream_);
-    rmm::device_uvector d_keys(get_size(), stream_);
-
-    update_device(d_colors.data() + get_rank(), &color, 1, stream_);
-    update_device(d_keys.data() + get_rank(), &key, 1, stream_);
-
-    allgather(d_colors.data() + get_rank(), d_colors.data(), 1, datatype_t::INT32, stream_);
-    allgather(d_keys.data() + get_rank(), d_keys.data(), 1, datatype_t::INT32, stream_);
-    this->sync_stream(stream_);
-
-    std::vector h_colors(get_size());
-    std::vector h_keys(get_size());
-
-    update_host(h_colors.data(), d_colors.data(), get_size(), stream_);
-    update_host(h_keys.data(), d_keys.data(), get_size(), stream_);
-
-    this->sync_stream(stream_);
-
-    ncclComm_t nccl_comm;
-
-    // Create a structure to allgather...
-    ncclUniqueId id{};
-    rmm::device_uvector d_nccl_ids(get_size(), stream_);
-
-    if (key == 0) { RAFT_NCCL_TRY(ncclGetUniqueId(&id)); }
-
-    update_device(d_nccl_ids.data() + get_rank(), &id, 1, stream_);
-
-    allgather(d_nccl_ids.data() + get_rank(),
-              d_nccl_ids.data(),
-              sizeof(ncclUniqueId),
-              datatype_t::UINT8,
-              stream_);
-
-    auto offset =
-      std::distance(thrust::make_zip_iterator(h_colors.begin(), h_keys.begin()),
-                    std::find_if(thrust::make_zip_iterator(h_colors.begin(), h_keys.begin()),
-                                 thrust::make_zip_iterator(h_colors.end(), h_keys.end()),
-                                 [color](auto tuple) { return thrust::get<0>(tuple) == color; }));
-
-    auto subcomm_size = std::count(h_colors.begin(), h_colors.end(), color);
-
-    update_host(&id, d_nccl_ids.data() + offset, 1, stream_);
-
-    this->sync_stream(stream_);
-
-    RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm, subcomm_size, id, key));
-
-    return std::unique_ptr(new std_comms(nccl_comm, subcomm_size, key, stream_, true));
+    // Split the parent NCCL communicator directly: ranks that pass the same `color` land in the
+    // same sub-communicator and are ordered within it by `key`.
+    ncclComm_t new_nccl_comm{};
+    RAFT_NCCL_TRY(ncclCommSplit(nccl_comm_, color, key, &new_nccl_comm, nullptr));
+    // Ask NCCL for the resulting size and rank rather than assuming `key` is the rank: `key` only
+    // determines ordering, so the two differ whenever the keys are not exactly 0..size-1.
+    int new_nccl_comm_size{};
+    RAFT_NCCL_TRY(ncclCommCount(new_nccl_comm, &new_nccl_comm_size));
+    int new_nccl_comm_rank{};
+    RAFT_NCCL_TRY(ncclCommUserRank(new_nccl_comm, &new_nccl_comm_rank));
+    return std::unique_ptr(
+      new std_comms(new_nccl_comm, new_nccl_comm_size, new_nccl_comm_rank, stream_, true));
   }
 
   void barrier() const