Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mtmg updates for rmm #4031

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cpp/include/cugraph/mtmg/detail/device_shared_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ class device_shared_wrapper_t {
{
std::lock_guard<std::mutex> lock(lock_);

auto pos = objects_.find(handle.get_local_rank());
auto pos = objects_.find(handle.get_rank());
CUGRAPH_EXPECTS(pos == objects_.end(), "Cannot overwrite wrapped object");

objects_.insert(std::make_pair(handle.get_local_rank(), std::move(obj)));
objects_.insert(std::make_pair(handle.get_rank(), std::move(obj)));
}

/**
Expand Down Expand Up @@ -90,7 +90,7 @@ class device_shared_wrapper_t {
{
std::lock_guard<std::mutex> lock(lock_);

auto pos = objects_.find(handle.get_local_rank());
auto pos = objects_.find(handle.get_rank());
CUGRAPH_EXPECTS(pos != objects_.end(), "Uninitialized wrapped object");

return pos->second;
Expand All @@ -106,7 +106,7 @@ class device_shared_wrapper_t {
{
std::lock_guard<std::mutex> lock(lock_);

auto pos = objects_.find(handle.get_local_rank());
auto pos = objects_.find(handle.get_rank());

CUGRAPH_EXPECTS(pos != objects_.end(), "Uninitialized wrapped object");

Expand Down
21 changes: 7 additions & 14 deletions cpp/include/cugraph/mtmg/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,19 @@ namespace mtmg {
*
*/
class handle_t {
handle_t(handle_t const&) = delete;
handle_t operator=(handle_t const&) = delete;

public:
/**
* @brief Constructor
*
* @param raft_handle Raft handle for the resources
* @param thread_rank Rank for this thread
* @param device_id Device id for the device this handle operates on
*/
handle_t(raft::handle_t const& raft_handle, int thread_rank, size_t device_id)
: raft_handle_(raft_handle),
thread_rank_(thread_rank),
local_rank_(raft_handle.get_comms().get_rank()), // FIXME: update for multi-node
device_id_(device_id)
handle_t(raft::handle_t const& raft_handle, int thread_rank, rmm::cuda_device_id device_id)
: raft_handle_(raft_handle), thread_rank_(thread_rank), device_id_raii_(device_id)
{
}

Expand Down Expand Up @@ -118,18 +119,10 @@ class handle_t {
*/
int get_rank() const { return raft_handle_.get_comms().get_rank(); }

/**
* @brief Get local gpu rank
*
* @return local gpu rank
*/
int get_local_rank() const { return local_rank_; }

private:
raft::handle_t const& raft_handle_;
int thread_rank_;
int local_rank_;
size_t device_id_;
rmm::cuda_set_device_raii device_id_raii_;
};

} // namespace mtmg
Expand Down
10 changes: 2 additions & 8 deletions cpp/include/cugraph/mtmg/instance_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,10 @@ class instance_manager_t {

~instance_manager_t()
{
int current_device{};
RAFT_CUDA_TRY(cudaGetDevice(&current_device));

for (size_t i = 0; i < nccl_comms_.size(); ++i) {
RAFT_CUDA_TRY(cudaSetDevice(device_ids_[i].value()));
rmm::cuda_set_device_raii local_set_device(device_ids_[i]);
RAFT_NCCL_TRY(ncclCommDestroy(*nccl_comms_[i]));
}

RAFT_CUDA_TRY(cudaSetDevice(current_device));
}

/**
Expand All @@ -75,8 +70,7 @@ class instance_manager_t {
int gpu_id = local_id % raft_handle_.size();
int thread_id = local_id / raft_handle_.size();

RAFT_CUDA_TRY(cudaSetDevice(device_ids_[gpu_id].value()));
return handle_t(*raft_handle_[gpu_id], thread_id, static_cast<size_t>(gpu_id));
return handle_t(*raft_handle_[gpu_id], thread_id, device_ids_[gpu_id]);
}

/**
Expand Down
11 changes: 3 additions & 8 deletions cpp/include/cugraph/mtmg/resource_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class resource_manager_t {

local_rank_map_.insert(std::pair(global_rank, local_device_id));

RAFT_CUDA_TRY(cudaSetDevice(local_device_id.value()));
rmm::cuda_set_device_raii local_set_device(local_device_id);

// FIXME: There is a bug in the cuda_memory_resource that results in a Hang.
// using the pool resource as a work-around.
Expand Down Expand Up @@ -182,14 +182,12 @@ class resource_manager_t {
--gpu_row_comm_size;
}

int current_device{};
RAFT_CUDA_TRY(cudaGetDevice(&current_device));
RAFT_NCCL_TRY(ncclGroupStart());

for (size_t i = 0; i < local_ranks_to_include.size(); ++i) {
int rank = local_ranks_to_include[i];
auto pos = local_rank_map_.find(rank);
RAFT_CUDA_TRY(cudaSetDevice(pos->second.value()));
rmm::cuda_set_device_raii local_set_device(pos->second);

nccl_comms.push_back(std::make_unique<ncclComm_t>());
handles.push_back(
Expand All @@ -204,7 +202,6 @@ class resource_manager_t {
handles[i].get(), *nccl_comms[i], ranks_to_include.size(), rank);
}
RAFT_NCCL_TRY(ncclGroupEnd());
RAFT_CUDA_TRY(cudaSetDevice(current_device));

std::vector<std::thread> running_threads;

Expand All @@ -217,9 +214,7 @@ class resource_manager_t {
&device_ids,
&nccl_comms,
&handles]() {
int rank = local_ranks_to_include[idx];
RAFT_CUDA_TRY(cudaSetDevice(device_ids[idx].value()));

rmm::cuda_set_device_raii local_set_device(device_ids[idx]);
cugraph::partition_manager::init_subcomm(*handles[idx], gpu_row_comm_size);
});
}
Expand Down
1 change: 1 addition & 0 deletions cpp/src/structure/detail/structure_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ std::tuple<size_t, rmm::device_uvector<uint32_t>> mark_entries(raft::handle_t co
return word;
});

// FIXME: use detail::count_set_bits
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have another PR ready to stage right after this where I've made this change.

size_t bit_count = thrust::transform_reduce(
handle.get_thrust_policy(),
marked_entries.begin(),
Expand Down
21 changes: 18 additions & 3 deletions cpp/tests/mtmg/threaded_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,25 @@ class Tests_Multithreaded
input_usecase.template construct_edgelist<vertex_t, weight_t>(
handle, multithreaded_usecase.test_weighted, false, false);

rmm::device_uvector<vertex_t> d_unique_vertices(2 * d_src_v.size(), handle.get_stream());
thrust::copy(
handle.get_thrust_policy(), d_src_v.begin(), d_src_v.end(), d_unique_vertices.begin());
thrust::copy(handle.get_thrust_policy(),
d_dst_v.begin(),
d_dst_v.end(),
d_unique_vertices.begin() + d_src_v.size());
thrust::sort(handle.get_thrust_policy(), d_unique_vertices.begin(), d_unique_vertices.end());

d_unique_vertices.resize(thrust::distance(d_unique_vertices.begin(),
thrust::unique(handle.get_thrust_policy(),
d_unique_vertices.begin(),
d_unique_vertices.end())),
handle.get_stream());

auto h_src_v = cugraph::test::to_host(handle, d_src_v);
auto h_dst_v = cugraph::test::to_host(handle, d_dst_v);
auto h_weights_v = cugraph::test::to_host(handle, d_weights_v);
auto unique_vertices = cugraph::test::to_host(handle, d_vertices_v);
auto unique_vertices = cugraph::test::to_host(handle, d_unique_vertices);

// Load edgelist from different threads. We'll use more threads than GPUs here
for (int i = 0; i < num_threads; ++i) {
Expand Down Expand Up @@ -293,13 +308,13 @@ class Tests_Multithreaded
num_threads]() {
auto thread_handle = instance_manager->get_handle();

auto number_of_vertices = unique_vertices->size();
auto number_of_vertices = unique_vertices.size();

std::vector<vertex_t> my_vertex_list;
my_vertex_list.reserve((number_of_vertices + num_threads - 1) / num_threads);

for (size_t j = i; j < number_of_vertices; j += num_threads) {
my_vertex_list.push_back((*unique_vertices)[j]);
my_vertex_list.push_back(unique_vertices[j]);
}

rmm::device_uvector<vertex_t> d_my_vertex_list(my_vertex_list.size(),
Expand Down