
Commit

address PR comments
ChuckHastings committed Oct 19, 2023
1 parent d88082e commit 235f2b6
Showing 5 changed files with 88 additions and 60 deletions.
93 changes: 49 additions & 44 deletions cpp/include/cugraph/mtmg/detail/per_device_edgelist.hpp
@@ -125,11 +125,7 @@ class per_device_edgelist_t {
std::optional<raft::host_span<edge_t const>> edge_id,
std::optional<raft::host_span<edge_type_t const>> edge_type)
{
std::vector<std::tuple<vertex_t*, vertex_t const*, size_t>> src_copies;
std::vector<std::tuple<vertex_t*, vertex_t const*, size_t>> dst_copies;
std::vector<std::tuple<weight_t*, weight_t const*, size_t>> wgt_copies;
std::vector<std::tuple<edge_t*, edge_t const*, size_t>> edge_id_copies;
std::vector<std::tuple<edge_type_t*, edge_type_t const*, size_t>> edge_type_copies;
std::vector<std::tuple<size_t, size_t, size_t, size_t>> copy_positions;

{
std::lock_guard<std::mutex> lock(lock_);
@@ -140,52 +136,61 @@ class per_device_edgelist_t {
while (count > 0) {
size_t copy_count = std::min(count, (src_.back().size() - current_pos_));

src_copies.push_back(
std::make_tuple(src_.back().begin() + current_pos_, src.begin() + pos, copy_count));
dst_copies.push_back(
std::make_tuple(dst_.back().begin() + current_pos_, dst.begin() + pos, copy_count));
if (wgt)
wgt_copies.push_back(
std::make_tuple(wgt_->back().begin() + current_pos_, wgt->begin() + pos, copy_count));
if (edge_id)
edge_id_copies.push_back(std::make_tuple(
edge_id_->back().begin() + current_pos_, edge_id->begin() + pos, copy_count));
if (edge_type)
edge_type_copies.push_back(std::make_tuple(
edge_type_->back().begin() + current_pos_, edge_type->begin() + pos, copy_count));
copy_positions.push_back(std::make_tuple(src_.size() - 1, current_pos_, pos, copy_count));

count -= copy_count;
pos += copy_count;
current_pos_ += copy_count;

if (current_pos_ == src_.back().size()) { create_new_buffers(handle); }
}
}

std::for_each(src_copies.begin(), src_copies.end(), [&handle](auto tuple) {
raft::update_device(
std::get<0>(tuple), std::get<1>(tuple), std::get<2>(tuple), handle.get_stream());
});

std::for_each(dst_copies.begin(), dst_copies.end(), [&handle](auto tuple) {
raft::update_device(
std::get<0>(tuple), std::get<1>(tuple), std::get<2>(tuple), handle.get_stream());
});

std::for_each(wgt_copies.begin(), wgt_copies.end(), [&handle](auto tuple) {
raft::update_device(
std::get<0>(tuple), std::get<1>(tuple), std::get<2>(tuple), handle.get_stream());
});

std::for_each(edge_id_copies.begin(), edge_id_copies.end(), [&handle](auto tuple) {
raft::update_device(
std::get<0>(tuple), std::get<1>(tuple), std::get<2>(tuple), handle.get_stream());
});

std::for_each(edge_type_copies.begin(), edge_type_copies.end(), [&handle](auto tuple) {
raft::update_device(
std::get<0>(tuple), std::get<1>(tuple), std::get<2>(tuple), handle.get_stream());
});

handle.raft_handle().sync_stream(handle.get_stream());
std::for_each(copy_positions.begin(),
copy_positions.end(),
[&handle,
&this_src = src_,
&src,
&this_dst = dst_,
&dst,
&this_wgt = wgt_,
&wgt,
&this_edge_id = edge_id_,
&edge_id,
&this_edge_type = edge_type_,
&edge_type](auto tuple) {
auto [buffer_idx, buffer_pos, input_pos, copy_count] = tuple;

raft::update_device(this_src[buffer_idx].begin() + buffer_pos,
src.begin() + input_pos,
copy_count,
handle.get_stream());

raft::update_device(this_dst[buffer_idx].begin() + buffer_pos,
dst.begin() + input_pos,
copy_count,
handle.get_stream());

if (this_wgt)
raft::update_device((*this_wgt)[buffer_idx].begin() + buffer_pos,
wgt->begin() + input_pos,
copy_count,
handle.get_stream());

if (this_edge_id)
raft::update_device((*this_edge_id)[buffer_idx].begin() + buffer_pos,
edge_id->begin() + input_pos,
copy_count,
handle.get_stream());

if (this_edge_type)
raft::update_device((*this_edge_type)[buffer_idx].begin() + buffer_pos,
edge_type->begin() + input_pos,
copy_count,
handle.get_stream());
});

handle.sync_stream();
}

/**
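The hunks above replace the five per-field staging vectors with a single copy_positions vector of (buffer index, buffer offset, input offset, count) tuples: positions are recorded while the mutex is held, and the raft::update_device copies are replayed afterwards. A minimal host-only sketch of that record-then-replay pattern, using hypothetical stand-ins (std::vector<int> buffers and std::copy_n in place of device buffers and raft::update_device):

#include <algorithm>
#include <cstddef>
#include <mutex>
#include <tuple>
#include <vector>

// Hypothetical stand-in: a list of fixed-size "buffers" plus a write cursor into
// the last buffer, shared across threads and protected by a mutex.
struct edgelist_buffers {
  std::vector<std::vector<int>> buffers{std::vector<int>(8)};
  std::size_t current_pos{0};
  std::mutex lock;
};

void append(edgelist_buffers& el, std::vector<int> const& input)
{
  // (buffer_idx, buffer_pos, input_pos, copy_count), matching the change above
  std::vector<std::tuple<std::size_t, std::size_t, std::size_t, std::size_t>> copy_positions;

  {
    std::lock_guard<std::mutex> guard(el.lock);

    std::size_t count = input.size();
    std::size_t pos   = 0;

    while (count > 0) {
      std::size_t copy_count = std::min(count, el.buffers.back().size() - el.current_pos);

      copy_positions.emplace_back(el.buffers.size() - 1, el.current_pos, pos, copy_count);

      count -= copy_count;
      pos += copy_count;
      el.current_pos += copy_count;

      // buffer full: start a new one (the real code calls create_new_buffers(handle))
      if (el.current_pos == el.buffers.back().size()) {
        el.buffers.emplace_back(el.buffers.back().size());
        el.current_pos = 0;
      }
    }
  }

  // Replay the recorded copies outside the critical section; the real code issues
  // raft::update_device here and then synchronizes the handle's stream.
  std::for_each(copy_positions.begin(), copy_positions.end(), [&](auto position) {
    auto [buffer_idx, buffer_pos, input_pos, copy_count] = position;
    std::copy_n(
      input.begin() + input_pos, copy_count, el.buffers[buffer_idx].begin() + buffer_pos);
  });
}

Recording positions first keeps the critical section short; the copies themselves run outside the lock, and the updated code follows them with a single handle.sync_stream().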
33 changes: 33 additions & 0 deletions cpp/include/cugraph/mtmg/handle.hpp
@@ -18,6 +18,8 @@

#include <raft/core/handle.hpp>

#include <rmm/exec_policy.hpp>

namespace cugraph {
namespace mtmg {

@@ -64,6 +66,37 @@ class handle_t {
: raft_handle_.get_stream();
}

/**
* @brief Sync on the cuda stream
*
* @param stream Which stream to synchronize (defaults to the stream for this handle)
*/
void sync_stream(rmm::cuda_stream_view stream) const { raft_handle_.sync_stream(stream); }

/**
* @brief Sync on the cuda stream for this handle
*/
void sync_stream() const { sync_stream(get_stream()); }

/**
* @brief get thrust policy for the stream
*
* @param stream Which stream to use for this thrust call
*
* @return exec policy using the current stream
*/
rmm::exec_policy get_thrust_policy(rmm::cuda_stream_view stream) const
{
return rmm::exec_policy(stream);
}

/**
* @brief get thrust policy for the stream for this handle
*
* @return exec policy using the current stream
*/
rmm::exec_policy get_thrust_policy() const { return get_thrust_policy(get_stream()); }

/**
* @brief Get thread rank
*
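handle_t gains sync_stream() and get_thrust_policy() wrappers around its current stream, hence the new rmm/exec_policy.hpp include (moved here from resource_manager.hpp below). A hypothetical usage sketch; fill_ones and its arguments are illustrative, not part of the change:

#include <cugraph/mtmg/handle.hpp>

#include <rmm/device_uvector.hpp>

#include <thrust/fill.h>

// Fill a device vector on this handle's stream, then wait for the work to finish.
void fill_ones(cugraph::mtmg::handle_t const& handle, rmm::device_uvector<int>& values)
{
  thrust::fill(handle.get_thrust_policy(), values.begin(), values.end(), 1);
  handle.sync_stream();  // synchronizes handle.get_stream()
}

Callers no longer spell out rmm::exec_policy(handle.get_stream()) or reach through raft_handle() to synchronize, as seen in per_device_edgelist.hpp and vertex_result.cu in this commit.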
7 changes: 3 additions & 4 deletions cpp/include/cugraph/mtmg/resource_manager.hpp
@@ -23,7 +23,6 @@
#include <raft/comms/std_comms.hpp>

#include <rmm/cuda_device.hpp>
#include <rmm/exec_policy.hpp>
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/mr/device/owning_wrapper.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
@@ -125,8 +124,9 @@ class resource_manager_t {
*
* @return unique pointer to instance manager
*/
std::unique_ptr<instance_manager_t> create_instance_manager(
std::vector<int> ranks_to_include, ncclUniqueId instance_manager_id) const
std::unique_ptr<instance_manager_t> create_instance_manager(std::vector<int> ranks_to_include,
ncclUniqueId instance_manager_id,
size_t n_streams = 16) const
{
std::for_each(
ranks_to_include.begin(), ranks_to_include.end(), [local_ranks = local_rank_map_](int rank) {
Expand Down Expand Up @@ -154,7 +154,6 @@ class resource_manager_t {
auto pos = local_rank_map_.find(rank);
RAFT_CUDA_TRY(cudaSetDevice(pos->second.value()));

size_t n_streams{16};
nccl_comms.push_back(std::make_unique<ncclComm_t>());
handles.push_back(
std::make_unique<raft::handle_t>(rmm::cuda_stream_per_thread,
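create_instance_manager now takes the stream-pool size as a defaulted parameter instead of the hard-coded size_t n_streams{16}, so existing callers keep 16 streams per GPU while tests can ask for fewer. A hypothetical wrapper showing how the new parameter is forwarded; make_instances and the instance_manager.hpp include are assumptions for illustration:

#include <cugraph/mtmg/instance_manager.hpp>
#include <cugraph/mtmg/resource_manager.hpp>

#include <nccl.h>

#include <memory>

// Hypothetical helper: build an instance manager over every registered rank,
// forwarding the new stream-pool size (the default of 16 keeps prior behavior).
std::unique_ptr<cugraph::mtmg::instance_manager_t> make_instances(
  cugraph::mtmg::resource_manager_t& resource_manager, size_t n_streams = 16)
{
  ncclUniqueId instance_manager_id;
  ncclGetUniqueId(&instance_manager_id);

  return resource_manager.create_instance_manager(
    resource_manager.registered_ranks(), instance_manager_id, n_streams);
}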
4 changes: 2 additions & 2 deletions cpp/src/mtmg/vertex_result.cu
@@ -97,7 +97,7 @@ rmm::device_uvector<result_t> vertex_result_view_t<result_t>::gather(
return vertex_partition.local_vertex_partition_offset_from_vertex_nocheck(v);
});

thrust::gather(rmm::exec_policy(handle.get_stream()),
thrust::gather(handle.get_thrust_policy(),
iter,
iter + local_vertices.size(),
wrapped.begin(),
@@ -118,7 +118,7 @@ rmm::device_uvector<result_t> vertex_result_view_t<result_t>::gather(
//
// Finally, reorder result
//
thrust::scatter(rmm::exec_policy(handle.get_stream()),
thrust::scatter(handle.get_thrust_policy(),
tmp_result.begin(),
tmp_result.end(),
vertex_pos.begin(),
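Both thrust calls now pull their execution policy from the handle; behavior is unchanged, since get_thrust_policy() wraps the same stream as the old rmm::exec_policy(handle.get_stream()). A small scatter sketch mirroring the reorder step above, with hypothetical buffers rather than the cugraph kernel's data:

#include <cugraph/mtmg/handle.hpp>

#include <rmm/device_uvector.hpp>

#include <thrust/scatter.h>

// Reorder tmp into out so that out[pos[i]] = tmp[i], mirroring the
// "reorder result" step above (hypothetical buffers, not the cugraph kernel's data).
void reorder(cugraph::mtmg::handle_t const& handle,
             rmm::device_uvector<float> const& tmp,
             rmm::device_uvector<int> const& pos,
             rmm::device_uvector<float>& out)
{
  thrust::scatter(handle.get_thrust_policy(), tmp.begin(), tmp.end(), pos.begin(), out.begin());
}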
11 changes: 1 addition & 10 deletions cpp/tests/mtmg/threaded_test.cu
@@ -107,7 +107,7 @@ class Tests_Multithreaded
ncclGetUniqueId(&instance_manager_id);

auto instance_manager = resource_manager.create_instance_manager(
resource_manager.registered_ranks(), instance_manager_id);
resource_manager.registered_ranks(), instance_manager_id, 4);

cugraph::mtmg::edgelist_t<vertex_t, weight_t, edge_t, edge_type_t> edgelist;
cugraph::mtmg::graph_t<vertex_t, edge_t, true, multi_gpu> graph;
@@ -172,15 +172,6 @@ class Tests_Multithreaded
per_thread_edgelist(edgelist.get(thread_handle), thread_buffer_size);

for (size_t j = i; j < h_src_v.size(); j += num_threads) {
#if 0
if (h_weights_v) {
thread_edgelist.append(
thread_handle, h_src_v[j], h_dst_v[j], (*h_weights_v)[j], std::nullopt, std::nullopt);
} else {
thread_edgelist.append(
thread_handle, h_src_v[j], h_dst_v[j], std::nullopt, std::nullopt, std::nullopt);
}
#endif
per_thread_edgelist.append(
thread_handle,
h_src_v[j],
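The #if 0 block duplicated the surviving per_thread_edgelist.append(...) call and is removed; the test also requests a 4-stream pool through the new create_instance_manager parameter. A sketch of how the surviving loop plausibly completes, assuming append accepts optional weight/id/type arguments as the deleted block suggests (illustrative, not the verbatim test code):

// Each thread appends every num_threads-th edge, round-robin by thread index i.
// The weight/id/type arguments below are assumed from the deleted #if 0 block.
for (size_t j = i; j < h_src_v.size(); j += num_threads) {
  per_thread_edgelist.append(
    thread_handle,
    h_src_v[j],
    h_dst_v[j],
    h_weights_v ? std::make_optional((*h_weights_v)[j]) : std::nullopt,
    std::nullopt,   // no edge ids in this test
    std::nullopt);  // no edge types in this test
}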
