Skip to content

Commit

Permalink
Update copy_if_nosync to take a size-1 device span in which to store the copied-element count
Browse files Browse the repository at this point in the history
  • Loading branch information
seunghwak committed Sep 5, 2024
1 parent ebcbfb7 commit 3652c33
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 20 deletions.
6 changes: 3 additions & 3 deletions cpp/src/prims/detail/multi_stream_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ void copy_if_nosync(InputIterator input_first,
InputIterator input_last,
FlagIterator flag_first,
OutputIterator output_first,
raft::device_span<size_t> count /* size = 1 */,
rmm::cuda_stream_view stream_view)
{
CUGRAPH_EXPECTS(
Expand All @@ -72,14 +73,13 @@ void copy_if_nosync(InputIterator input_first,

size_t tmp_storage_bytes{0};
size_t input_size = static_cast<int>(thrust::distance(input_first, input_last));
rmm::device_scalar<int> num_copied(stream_view);

cub::DeviceSelect::Flagged(static_cast<void*>(nullptr),
tmp_storage_bytes,
input_first,
flag_first,
output_first,
num_copied.data(),
count.data(),
input_size,
stream_view);

Expand All @@ -90,7 +90,7 @@ void copy_if_nosync(InputIterator input_first,
input_first,
flag_first,
output_first,
num_copied.data(),
count.data(),
input_size,
stream_view);
}
Expand Down
46 changes: 29 additions & 17 deletions cpp/src/prims/detail/per_v_transform_reduce_e.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#include <thrust/iterator/transform_iterator.h>
#include <thrust/optional.h>
#include <thrust/scatter.h>
#include <thrust/set_operations.h>
#include <thrust/transform_reduce.h>
#include <thrust/tuple.h>
#include <thrust/type_traits/integer_sequence.h>
Expand Down Expand Up @@ -1582,30 +1583,35 @@ void per_v_transform_reduce_e(raft::handle_t const& handle,
auto const minor_comm_size = minor_comm.get_size();

rmm::device_uvector<vertex_t> tmps(2, handle.get_stream());
thrust::tabulate(handle.get_thrust_policy(),
tmps.begin(),
tmps.end(),
[sorted_unique_key_first, v_list_size = static_cast<size_t>(thrust::distance(sorted_unique_key_first, sorted_unique_nzd_key_last))]__device__(size_t i) {
return (i == 0) ? *sorted_unique_key_first : (*(sorted_unique_key_first + (v_list_size - 1)) + 1);
});
raft::update_host(v_list_range.data(), tmps.data(), 2, handle.get_stream());
thrust::tabulate(
handle.get_thrust_policy(),
tmps.begin(),
tmps.end(),
[sorted_unique_key_first,
v_list_size = static_cast<size_t>(thrust::distance(
sorted_unique_key_first, sorted_unique_nzd_key_last))] __device__(size_t i) {
return (i == 0) ? *sorted_unique_key_first
: (*(sorted_unique_key_first + (v_list_size - 1)) + 1);
});
raft::update_host(v_list_range.data(), tmps.data(), 2, handle.get_stream());

if (minor_comm_size > 1) {
key_list_bitmap =
compute_vertex_list_bitmap_info(sorted_unique_key_first,
sorted_unique_nzd_key_last,
v_list_range[0],
v_list_range[1],
handle.get_stream());
key_list_bitmap = compute_vertex_list_bitmap_info(sorted_unique_key_first,
sorted_unique_nzd_key_last,
v_list_range[0],
v_list_range[1],
handle.get_stream());
}
}

// 6. collect local_key_list_sizes & use_bitmap_flags & key_segment_offsets

std::conditional_t<use_input_key, std::vector<size_t>, std::byte /* dummy */>
local_key_list_sizes{};
std::conditional_t<try_bitmap, std::vector<vertex_t>, std::byte /* dummy */> local_key_list_range_firsts{};
std::conditional_t<try_bitmap, std::vector<vertex_t>, std::byte /* dummy */> local_key_list_range_lasts{};
std::conditional_t<try_bitmap, std::vector<vertex_t>, std::byte /* dummy */>
local_key_list_range_firsts{};
std::conditional_t<try_bitmap, std::vector<vertex_t>, std::byte /* dummy */>
local_key_list_range_lasts{};
std::conditional_t<try_bitmap, std::vector<bool>, std::byte /* dummy */> use_bitmap_flags{};
std::conditional_t<use_input_key,
std::optional<std::vector<std::vector<size_t>>>,
Expand All @@ -1620,8 +1626,10 @@ void per_v_transform_reduce_e(raft::handle_t const& handle,
static_cast<size_t>(thrust::distance(sorted_unique_key_first, sorted_unique_nzd_key_last)),
handle.get_stream());
if constexpr (try_bitmap) {
local_key_list_range_firsts = host_scalar_allgather(minor_comm, v_list_range[0], handle.get_stream());
local_key_list_range_lasts = host_scalar_allgather(minor_comm, v_list_range[1], handle.get_stream());
local_key_list_range_firsts =
host_scalar_allgather(minor_comm, v_list_range[0], handle.get_stream());
local_key_list_range_lasts =
host_scalar_allgather(minor_comm, v_list_range[1], handle.get_stream());
auto tmp_flags = host_scalar_allgather(
minor_comm, key_list_bitmap ? uint8_t{1} : uint8_t{0}, handle.get_stream());
use_bitmap_flags.resize(tmp_flags.size());
Expand Down Expand Up @@ -2146,6 +2154,7 @@ void per_v_transform_reduce_e(raft::handle_t const& handle,
assert(edge_partition_selected_ranks_or_flags[j].index() == 0);
auto const& selected_ranks = std::get<0>(edge_partition_selected_ranks_or_flags[j]);
resize_dataframe_buffer(values, copy_size, loop_stream);
rmm::device_scalar<size_t> dummy(size_t{0}, loop_stream); // we already know the count
copy_if_nosync(
get_dataframe_buffer_begin(output_buffer),
get_dataframe_buffer_end(output_buffer),
Expand All @@ -2154,17 +2163,20 @@ void per_v_transform_reduce_e(raft::handle_t const& handle,
cuda::proclaim_return_type<bool>(
[minor_comm_rank] __device__(auto rank) { return rank == minor_comm_rank; })),
get_dataframe_buffer_begin(values),
raft::device_span<size_t>(dummy.data(), size_t{1}),
loop_stream);
}
} else {
if (copy_size > 0) {
assert(edge_partition_selected_ranks_or_flags[j].index() == 1);
auto& selected_flags = std::get<1>(edge_partition_selected_ranks_or_flags[j]);
resize_dataframe_buffer(values, copy_size, loop_stream);
rmm::device_scalar<size_t> dummy(size_t{0}, loop_stream); // we already know the count
copy_if_nosync(get_dataframe_buffer_begin(output_buffer),
get_dataframe_buffer_end(output_buffer),
(*selected_flags).begin(),
get_dataframe_buffer_begin(values),
raft::device_span<size_t>(dummy.data(), size_t{1}),
loop_stream);
(*selected_flags).resize(0, loop_stream);
(*selected_flags).shrink_to_fit(loop_stream);
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/prims/vertex_frontier.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ void device_bcast_vertex_list(
assert((comm.get_rank() != root) || (std::get<0>(v_list).size() == tmp_bitmap.size()));
device_bcast(
comm, std::get<0>(v_list).data(), tmp_bitmap.data(), tmp_bitmap.size(), root, stream_view);
rmm::device_scalar<size_t> dummy(size_t{0}, stream_view); // we already know the count
detail::copy_if_nosync(
thrust::make_counting_iterator(vertex_range_first),
thrust::make_counting_iterator(vertex_range_last),
Expand All @@ -200,6 +201,7 @@ void device_bcast_vertex_list(
packed_bool_empty_mask());
})),
output_v_first,
raft::device_span<size_t>(dummy.data(), size_t{1}),
stream_view);
} else {
device_bcast(comm, std::get<1>(v_list), output_v_first, v_list_size, root, stream_view);
Expand Down

0 comments on commit 3652c33

Please sign in to comment.