Use RMM features for managing devices in RAII manner
ChuckHastings committed Nov 29, 2023
1 parent 96c39cc commit 6b0c710
Showing 5 changed files with 34 additions and 37 deletions.
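This commit replaces hand-rolled cudaGetDevice / cudaSetDevice save-and-restore sequences with RMM's scoped device guard. As a rough sketch of the two styles (my illustration, not code from this commit; it assumes rmm::cuda_device_id and rmm::cuda_set_device_raii from <rmm/cuda_device.hpp>, and omits error-checking macros):

    #include <rmm/cuda_device.hpp>
    #include <cuda_runtime_api.h>

    void run_on_device_manually(rmm::cuda_device_id target)
    {
      // Before: remember the current device, switch, work, then switch back by hand.
      int current_device{};
      cudaGetDevice(&current_device);
      cudaSetDevice(target.value());
      // ... launch work on `target` ...
      cudaSetDevice(current_device);
    }

    void run_on_device_raii(rmm::cuda_device_id target)
    {
      // After: the guard records the current device, switches to `target`, and
      // restores the previous device when it goes out of scope, even on early return.
      rmm::cuda_set_device_raii set_device{target};
      // ... launch work on `target` ...
    }

With the guard there is no trailing cudaSetDevice to forget on each exit path, which is what most of the cudaGetDevice / cudaSetDevice deletions below are about.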
8 changes: 4 additions & 4 deletions cpp/include/cugraph/mtmg/detail/device_shared_wrapper.hpp
@@ -57,10 +57,10 @@ class device_shared_wrapper_t {
   {
     std::lock_guard<std::mutex> lock(lock_);
 
-    auto pos = objects_.find(handle.get_local_rank());
+    auto pos = objects_.find(handle.get_rank());
     CUGRAPH_EXPECTS(pos == objects_.end(), "Cannot overwrite wrapped object");
 
-    objects_.insert(std::make_pair(handle.get_local_rank(), std::move(obj)));
+    objects_.insert(std::make_pair(handle.get_rank(), std::move(obj)));
   }

/**
@@ -90,7 +90,7 @@ class device_shared_wrapper_t {
   {
     std::lock_guard<std::mutex> lock(lock_);
 
-    auto pos = objects_.find(handle.get_local_rank());
+    auto pos = objects_.find(handle.get_rank());
     CUGRAPH_EXPECTS(pos != objects_.end(), "Uninitialized wrapped object");
 
     return pos->second;
@@ -106,7 +106,7 @@
   {
     std::lock_guard<std::mutex> lock(lock_);
 
-    auto pos = objects_.find(handle.get_local_rank());
+    auto pos = objects_.find(handle.get_rank());
 
     CUGRAPH_EXPECTS(pos != objects_.end(), "Uninitialized wrapped object");
 
21 changes: 7 additions & 14 deletions cpp/include/cugraph/mtmg/handle.hpp
@@ -32,18 +32,19 @@ namespace mtmg {
  *
  */
 class handle_t {
+  handle_t(handle_t const&) = delete;
+  handle_t operator=(handle_t const&) = delete;
+
  public:
   /**
    * @brief Constructor
    *
    * @param raft_handle Raft handle for the resources
    * @param thread_rank Rank for this thread
+   * @param device_id Device id for the device this handle operates on
    */
-  handle_t(raft::handle_t const& raft_handle, int thread_rank, size_t device_id)
-    : raft_handle_(raft_handle),
-      thread_rank_(thread_rank),
-      local_rank_(raft_handle.get_comms().get_rank()), // FIXME: update for multi-node
-      device_id_(device_id)
+  handle_t(raft::handle_t const& raft_handle, int thread_rank, rmm::cuda_device_id device_id)
+    : raft_handle_(raft_handle), thread_rank_(thread_rank), device_id_raii_(device_id)
   {
   }
 
@@ -118,18 +119,10 @@
    */
   int get_rank() const { return raft_handle_.get_comms().get_rank(); }
 
-  /**
-   * @brief Get local gpu rank
-   *
-   * @return local gpu rank
-   */
-  int get_local_rank() const { return local_rank_; }
-
  private:
   raft::handle_t const& raft_handle_;
   int thread_rank_;
-  int local_rank_;
-  size_t device_id_;
+  rmm::cuda_set_device_raii device_id_raii_;
 };
 
 } // namespace mtmg
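Holding the guard as a data member ties the device switch to the handle's lifetime: constructing a handle_t switches the calling thread to the handle's GPU, and destroying the handle restores whatever device was current before. A minimal sketch of that member-RAII idea (my illustration; gpu_worker_t is a hypothetical class, and it assumes rmm::cuda_set_device_raii behaves as a set-on-construct / restore-on-destruct guard):

    #include <rmm/cuda_device.hpp>

    class gpu_worker_t {
     public:
      explicit gpu_worker_t(rmm::cuda_device_id device_id) : device_id_raii_{device_id} {}

      // Copying is disabled so the saved "previous device" is restored exactly once.
      gpu_worker_t(gpu_worker_t const&)            = delete;
      gpu_worker_t& operator=(gpu_worker_t const&) = delete;

     private:
      // Sets `device_id` when the object is constructed and restores the
      // previously current device when the object is destroyed.
      rmm::cuda_set_device_raii device_id_raii_;
    };

This also fits the newly deleted copy constructor and copy assignment in handle_t above: a handle that owns a device guard is not meaningfully copyable.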
10 changes: 2 additions & 8 deletions cpp/include/cugraph/mtmg/instance_manager.hpp
@@ -47,15 +47,10 @@ class instance_manager_t {
 
   ~instance_manager_t()
   {
-    int current_device{};
-    RAFT_CUDA_TRY(cudaGetDevice(&current_device));
-
     for (size_t i = 0; i < nccl_comms_.size(); ++i) {
-      RAFT_CUDA_TRY(cudaSetDevice(device_ids_[i].value()));
+      rmm::cuda_set_device_raii local_set_device(device_ids_[i]);
       RAFT_NCCL_TRY(ncclCommDestroy(*nccl_comms_[i]));
     }
-
-    RAFT_CUDA_TRY(cudaSetDevice(current_device));
   }
 
   /**
@@ -75,8 +70,7 @@
     int gpu_id = local_id % raft_handle_.size();
     int thread_id = local_id / raft_handle_.size();
 
-    RAFT_CUDA_TRY(cudaSetDevice(device_ids_[gpu_id].value()));
-    return handle_t(*raft_handle_[gpu_id], thread_id, static_cast<size_t>(gpu_id));
+    return handle_t(*raft_handle_[gpu_id], thread_id, device_ids_[gpu_id]);
   }
 
   /**
11 changes: 3 additions & 8 deletions cpp/include/cugraph/mtmg/resource_manager.hpp
@@ -89,7 +89,7 @@ class resource_manager_t {
 
     local_rank_map_.insert(std::pair(global_rank, local_device_id));
 
-    RAFT_CUDA_TRY(cudaSetDevice(local_device_id.value()));
+    rmm::cuda_set_device_raii local_set_device(local_device_id);
 
     // FIXME: There is a bug in the cuda_memory_resource that results in a Hang.
     //   using the pool resource as a work-around.
@@ -182,14 +182,12 @@
       --gpu_row_comm_size;
     }
 
-    int current_device{};
-    RAFT_CUDA_TRY(cudaGetDevice(&current_device));
     RAFT_NCCL_TRY(ncclGroupStart());
 
     for (size_t i = 0; i < local_ranks_to_include.size(); ++i) {
       int rank = local_ranks_to_include[i];
       auto pos = local_rank_map_.find(rank);
-      RAFT_CUDA_TRY(cudaSetDevice(pos->second.value()));
+      rmm::cuda_set_device_raii local_set_device(pos->second);
 
       nccl_comms.push_back(std::make_unique<ncclComm_t>());
       handles.push_back(
@@ -204,7 +202,6 @@
         handles[i].get(), *nccl_comms[i], ranks_to_include.size(), rank);
     }
     RAFT_NCCL_TRY(ncclGroupEnd());
-    RAFT_CUDA_TRY(cudaSetDevice(current_device));
 
     std::vector<std::thread> running_threads;
 
@@ -217,9 +214,7 @@
                                     &device_ids,
                                     &nccl_comms,
                                     &handles]() {
-      int rank = local_ranks_to_include[idx];
-      RAFT_CUDA_TRY(cudaSetDevice(device_ids[idx].value()));
-
+      rmm::cuda_set_device_raii local_set_device(device_ids[idx]);
       cugraph::partition_manager::init_subcomm(*handles[idx], gpu_row_comm_size);
     });
   }
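The same guard also scopes cleanly to loop iterations and worker threads: declaring it as the first statement of a loop body or a std::thread lambda pins just that scope to the right GPU and restores the thread's previous device on exit. A short sketch of the per-thread pattern (my illustration, not code from this commit; do_per_gpu_setup is a hypothetical stand-in for per-GPU work such as init_subcomm):

    #include <rmm/cuda_device.hpp>

    #include <thread>
    #include <vector>

    void do_per_gpu_setup(rmm::cuda_device_id) { /* per-GPU initialization would go here */ }

    void setup_all_gpus(std::vector<rmm::cuda_device_id> const& device_ids)
    {
      std::vector<std::thread> running_threads;

      for (auto device_id : device_ids) {
        running_threads.emplace_back([device_id]() {
          // Each worker switches its own thread to the target GPU; the previous
          // device is restored automatically when the lambda returns.
          rmm::cuda_set_device_raii local_set_device(device_id);
          do_per_gpu_setup(device_id);
        });
      }

      for (auto& t : running_threads) {
        t.join();
      }
    }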
21 changes: 18 additions & 3 deletions cpp/tests/mtmg/threaded_test.cu
@@ -155,10 +155,25 @@ class Tests_Multithreaded
       input_usecase.template construct_edgelist<vertex_t, weight_t>(
         handle, multithreaded_usecase.test_weighted, false, false);
 
+    rmm::device_uvector<vertex_t> d_unique_vertices(2 * d_src_v.size(), handle.get_stream());
+    thrust::copy(
+      handle.get_thrust_policy(), d_src_v.begin(), d_src_v.end(), d_unique_vertices.begin());
+    thrust::copy(handle.get_thrust_policy(),
+                 d_dst_v.begin(),
+                 d_dst_v.end(),
+                 d_unique_vertices.begin() + d_src_v.size());
+    thrust::sort(handle.get_thrust_policy(), d_unique_vertices.begin(), d_unique_vertices.end());
+
+    d_unique_vertices.resize(thrust::distance(d_unique_vertices.begin(),
+                                              thrust::unique(handle.get_thrust_policy(),
+                                                             d_unique_vertices.begin(),
+                                                             d_unique_vertices.end())),
+                             handle.get_stream());
+
     auto h_src_v = cugraph::test::to_host(handle, d_src_v);
     auto h_dst_v = cugraph::test::to_host(handle, d_dst_v);
     auto h_weights_v = cugraph::test::to_host(handle, d_weights_v);
-    auto unique_vertices = cugraph::test::to_host(handle, d_vertices_v);
+    auto unique_vertices = cugraph::test::to_host(handle, d_unique_vertices);
 
     // Load edgelist from different threads. We'll use more threads than GPUs here
     for (int i = 0; i < num_threads; ++i) {
@@ -293,13 +308,13 @@
                            num_threads]() {
         auto thread_handle = instance_manager->get_handle();
 
-        auto number_of_vertices = unique_vertices->size();
+        auto number_of_vertices = unique_vertices.size();
 
         std::vector<vertex_t> my_vertex_list;
         my_vertex_list.reserve((number_of_vertices + num_threads - 1) / num_threads);
 
         for (size_t j = i; j < number_of_vertices; j += num_threads) {
-          my_vertex_list.push_back((*unique_vertices)[j]);
+          my_vertex_list.push_back(unique_vertices[j]);
         }
 
         rmm::device_uvector<vertex_t> d_my_vertex_list(my_vertex_list.size(),
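The test change above builds the unique-vertex list on the device instead of relying on a precomputed d_vertices_v: copy the source and destination vertex arrays into one buffer, sort it, then shrink it to its unique prefix. A condensed sketch of that copy / sort / unique / resize idiom (my illustration; it uses plain Thrust device_vectors rather than cuGraph's handle, stream, and rmm::device_uvector):

    #include <thrust/copy.h>
    #include <thrust/device_vector.h>
    #include <thrust/distance.h>
    #include <thrust/sort.h>
    #include <thrust/unique.h>

    // Returns the sorted set of vertex ids appearing in either endpoint array.
    thrust::device_vector<int> unique_vertices(thrust::device_vector<int> const& srcs,
                                               thrust::device_vector<int> const& dsts)
    {
      thrust::device_vector<int> verts(srcs.size() + dsts.size());
      thrust::copy(srcs.begin(), srcs.end(), verts.begin());
      thrust::copy(dsts.begin(), dsts.end(), verts.begin() + srcs.size());

      thrust::sort(verts.begin(), verts.end());
      // unique() compacts adjacent duplicates and returns the new logical end;
      // resize() then drops the leftover tail.
      verts.resize(thrust::distance(verts.begin(), thrust::unique(verts.begin(), verts.end())));
      return verts;
    }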
