Commit d72bded: a little code cleanup

ChuckHastings committed Oct 12, 2023
1 parent 3d57693 commit d72bded

Showing 2 changed files with 16 additions and 17 deletions.
31 changes: 15 additions & 16 deletions cpp/include/cugraph/mtmg/resource_manager.hpp
@@ -114,17 +114,16 @@ class resource_manager_t {
   /**
    * @brief add a remote GPU to the resource manager.
    *
-   * @param rank              The rank to assign to the local GPU
-   * @param remode_node_rank  The rank assigned to the remote node
+   * @param rank  The rank to assign to the remote GPU
    */
-  void register_remote_gpu(int rank, int remote_node_rank)
+  void register_remote_gpu(int rank)
   {
     std::lock_guard<std::mutex> lock(lock_);

-    CUGRAPH_EXPECTS(remote_rank_map_.find(rank) == remote_rank_map_.end(),
+    CUGRAPH_EXPECTS(remote_rank_set_.find(rank) == remote_rank_set_.end(),
                     "cannot register same rank multiple times");

-    remote_rank_map_.insert(std::pair(rank, remote_node_rank));
+    remote_rank_set_.insert(rank);
   }

   /**
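Note on the change above: the old std::map was only ever used for membership tests on its key; the mapped remote_node_rank value was never read, so a plain std::set<int> of ranks gives the same duplicate detection with less state and a simpler signature. A minimal standalone sketch of that idea (registry_t is a hypothetical stand-in, not the cuGraph class; the real method also holds lock_ and throws via CUGRAPH_EXPECTS instead of returning a bool):

#include <cassert>
#include <set>

// Hypothetical stand-in for resource_manager_t; illustration only.
class registry_t {
 public:
  // Returns false on duplicate registration instead of throwing.
  bool register_remote_gpu(int rank) { return remote_rank_set_.insert(rank).second; }

 private:
  std::set<int> remote_rank_set_{};
};

int main()
{
  registry_t registry;
  assert(registry.register_remote_gpu(4));   // first registration succeeds
  assert(!registry.register_remote_gpu(4));  // same rank again is rejected
}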
@@ -151,7 +150,7 @@ class resource_manager_t {
     std::for_each(ranks_to_include.begin(),
                   ranks_to_include.end(),
                   [&local_ranks = local_rank_map_,
-                   &remote_ranks = remote_rank_map_,
+                   &remote_ranks = remote_rank_set_,
                    &local_ranks_to_include](int rank) {
                     if (local_ranks.find(rank) == local_ranks.end()) {
                       CUGRAPH_EXPECTS(remote_ranks.find(rank) != remote_ranks.end(),
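This hunk only needs a one-token change because std::map and std::set expose the same find()/end() interface, so the validation lambda body is untouched; only the init-captured alias is rebound. A standalone sketch of that pattern (simplified from the diff above, with assumed container contents):

#include <cstdio>
#include <map>
#include <set>

int main()
{
  std::map<int, int> local_ranks{{0, 0}, {1, 1}};
  std::set<int> remote_ranks{2, 3};

  // Init-capture binds references under new names; the body works for either container.
  auto check = [&local = local_ranks, &remote = remote_ranks](int rank) {
    if (local.find(rank) == local.end() && remote.find(rank) == remote.end()) {
      std::printf("rank %d unknown\n", rank);
    }
  };

  check(1);  // found locally
  check(3);  // found remotely
  check(7);  // prints: rank 7 unknown
}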
@@ -231,21 +230,21 @@ class resource_manager_t {
   {
     std::lock_guard<std::mutex> lock(lock_);

-    //
-    // C++20 mechanism:
-    //   return std::vector<int>{ std::views::keys(local_rank_map_).begin(),
-    //                            std::views::keys(local_rank_map_).end() };
-    // Would need a bit more complicated to handle remote_rank_map_ also
-    //
-    std::vector<int> registered_ranks(local_rank_map_.size() + remote_rank_map_.size());
+    std::vector<int> registered_ranks(local_rank_map_.size() + remote_rank_set_.size());
     std::transform(
       local_rank_map_.begin(), local_rank_map_.end(), registered_ranks.begin(), [](auto pair) {
         return pair.first;
       });
-    std::transform(remote_rank_map_.begin(),
-                   remote_rank_map_.end(),
+#if 0
+    std::transform(remote_rank_set_.begin(),
+                   remote_rank_set_.end(),
                    registered_ranks.begin() + local_rank_map_.size(),
                    [](auto pair) { return pair.first; });
+#else
+    std::copy(remote_rank_set_.begin(),
+              remote_rank_set_.end(),
+              registered_ranks.begin() + local_rank_map_.size());
+#endif

     std::sort(registered_ranks.begin(), registered_ranks.end());
     return registered_ranks;
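The shape of registered_ranks() after this hunk: map keys are extracted with std::transform, while the set can be appended with a plain std::copy, since its elements already are the ranks; a final std::sort orders the concatenation of the two ranges. A self-contained sketch with assumed simplified types (the real local map stores rmm::cuda_device_id values, not int):

#include <algorithm>
#include <cassert>
#include <map>
#include <set>
#include <vector>

std::vector<int> merged_ranks(std::map<int, int> const& local, std::set<int> const& remote)
{
  std::vector<int> ranks(local.size() + remote.size());
  // Keys of the map land in the front of the vector...
  std::transform(
    local.begin(), local.end(), ranks.begin(), [](auto pair) { return pair.first; });
  // ...set elements are copied directly behind them...
  std::copy(remote.begin(), remote.end(), ranks.begin() + local.size());
  // ...and one sort orders the concatenated ranges.
  std::sort(ranks.begin(), ranks.end());
  return ranks;
}

int main()
{
  auto ranks = merged_ranks({{0, 0}, {2, 0}}, {1, 3});
  assert((ranks == std::vector<int>{0, 1, 2, 3}));
}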
Expand All @@ -254,7 +253,7 @@ class resource_manager_t {
private:
mutable std::mutex lock_{};
std::map<int, rmm::cuda_device_id> local_rank_map_{};
std::map<int, int> remote_rank_map_{};
std::set<int> remote_rank_set_{};
std::map<int, std::shared_ptr<rmm::mr::device_memory_resource>> per_device_rmm_resources_{};
};

2 changes: 1 addition & 1 deletion cpp/tests/mtmg/multi_node_threaded_test.cu
@@ -185,7 +185,7 @@ class Tests_Multithreaded
         num_gpus_file.close();

         for (int j = 0; j < num_gpus_this_node; ++j) {
-          resource_manager.register_remote_gpu(node_rank++, i);
+          resource_manager.register_remote_gpu(node_rank++);
         }
       } else {
         std::for_each(
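Caller-side effect of the simplified signature: the test no longer threads the node index i into the call, because ranks are assigned densely by the running node_rank counter alone. A hypothetical standalone driver loop illustrating this (register_remote_gpu_stub stands in for resource_manager_t::register_remote_gpu; the real test reads num_gpus_this_node from a per-node file):

#include <cstdio>

void register_remote_gpu_stub(int rank) { std::printf("registered remote rank %d\n", rank); }

int main()
{
  int const num_nodes          = 2;  // assumed values for illustration
  int const num_gpus_this_node = 4;

  int node_rank = 0;
  for (int i = 0; i < num_nodes; ++i) {
    for (int j = 0; j < num_gpus_this_node; ++j) {
      register_remote_gpu_stub(node_rank++);  // node index i is no longer needed
    }
  }
}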
