From 3652c33a19644d993150436f432800fbb1beaec8 Mon Sep 17 00:00:00 2001 From: Seunghwa Kang Date: Wed, 4 Sep 2024 18:57:07 -0700 Subject: [PATCH] update copy_if_nosync to take a pointer to store the counter --- cpp/src/prims/detail/multi_stream_utils.cuh | 6 +-- .../prims/detail/per_v_transform_reduce_e.cuh | 46 ++++++++++++------- cpp/src/prims/vertex_frontier.cuh | 2 + 3 files changed, 34 insertions(+), 20 deletions(-) diff --git a/cpp/src/prims/detail/multi_stream_utils.cuh b/cpp/src/prims/detail/multi_stream_utils.cuh index 2f03a22bff5..7a370d5e49f 100644 --- a/cpp/src/prims/detail/multi_stream_utils.cuh +++ b/cpp/src/prims/detail/multi_stream_utils.cuh @@ -62,6 +62,7 @@ void copy_if_nosync(InputIterator input_first, InputIterator input_last, FlagIterator flag_first, OutputIterator output_first, + raft::device_span count /* size = 1 */, rmm::cuda_stream_view stream_view) { CUGRAPH_EXPECTS( @@ -72,14 +73,13 @@ void copy_if_nosync(InputIterator input_first, size_t tmp_storage_bytes{0}; size_t input_size = static_cast(thrust::distance(input_first, input_last)); - rmm::device_scalar num_copied(stream_view); cub::DeviceSelect::Flagged(static_cast(nullptr), tmp_storage_bytes, input_first, flag_first, output_first, - num_copied.data(), + count.data(), input_size, stream_view); @@ -90,7 +90,7 @@ void copy_if_nosync(InputIterator input_first, input_first, flag_first, output_first, - num_copied.data(), + count.data(), input_size, stream_view); } diff --git a/cpp/src/prims/detail/per_v_transform_reduce_e.cuh b/cpp/src/prims/detail/per_v_transform_reduce_e.cuh index 0719d9a8df3..bc15c39a4dc 100644 --- a/cpp/src/prims/detail/per_v_transform_reduce_e.cuh +++ b/cpp/src/prims/detail/per_v_transform_reduce_e.cuh @@ -54,6 +54,7 @@ #include #include #include +#include #include #include #include @@ -1582,21 +1583,24 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, auto const minor_comm_size = minor_comm.get_size(); rmm::device_uvector tmps(2, handle.get_stream()); - thrust::tabulate(handle.get_thrust_policy(), - tmps.begin(), - tmps.end(), - [sorted_unique_key_first, v_list_size = static_cast(thrust::distance(sorted_unique_key_first, sorted_unique_nzd_key_last))]__device__(size_t i) { - return (i == 0) ? *sorted_unique_key_first : (*(sorted_unique_key_first + (v_list_size - 1)) + 1); - }); - raft::update_host(v_list_range.data(), tmps.data(), 2, handle.get_stream()); + thrust::tabulate( + handle.get_thrust_policy(), + tmps.begin(), + tmps.end(), + [sorted_unique_key_first, + v_list_size = static_cast(thrust::distance( + sorted_unique_key_first, sorted_unique_nzd_key_last))] __device__(size_t i) { + return (i == 0) ? *sorted_unique_key_first + : (*(sorted_unique_key_first + (v_list_size - 1)) + 1); + }); + raft::update_host(v_list_range.data(), tmps.data(), 2, handle.get_stream()); if (minor_comm_size > 1) { - key_list_bitmap = - compute_vertex_list_bitmap_info(sorted_unique_key_first, - sorted_unique_nzd_key_last, - v_list_range[0], - v_list_range[1], - handle.get_stream()); + key_list_bitmap = compute_vertex_list_bitmap_info(sorted_unique_key_first, + sorted_unique_nzd_key_last, + v_list_range[0], + v_list_range[1], + handle.get_stream()); } } @@ -1604,8 +1608,10 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, std::conditional_t, std::byte /* dummy */> local_key_list_sizes{}; - std::conditional_t, std::byte /* dummy */> local_key_list_range_firsts{}; - std::conditional_t, std::byte /* dummy */> local_key_list_range_lasts{}; + std::conditional_t, std::byte /* dummy */> + local_key_list_range_firsts{}; + std::conditional_t, std::byte /* dummy */> + local_key_list_range_lasts{}; std::conditional_t, std::byte /* dummy */> use_bitmap_flags{}; std::conditional_t>>, @@ -1620,8 +1626,10 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, static_cast(thrust::distance(sorted_unique_key_first, sorted_unique_nzd_key_last)), handle.get_stream()); if constexpr (try_bitmap) { - local_key_list_range_firsts = host_scalar_allgather(minor_comm, v_list_range[0], handle.get_stream()); - local_key_list_range_lasts = host_scalar_allgather(minor_comm, v_list_range[1], handle.get_stream()); + local_key_list_range_firsts = + host_scalar_allgather(minor_comm, v_list_range[0], handle.get_stream()); + local_key_list_range_lasts = + host_scalar_allgather(minor_comm, v_list_range[1], handle.get_stream()); auto tmp_flags = host_scalar_allgather( minor_comm, key_list_bitmap ? uint8_t{1} : uint8_t{0}, handle.get_stream()); use_bitmap_flags.resize(tmp_flags.size()); @@ -2146,6 +2154,7 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, assert(edge_partition_selected_ranks_or_flags[j].index() == 0); auto const& selected_ranks = std::get<0>(edge_partition_selected_ranks_or_flags[j]); resize_dataframe_buffer(values, copy_size, loop_stream); + rmm::device_scalar dummy(size_t{0}, loop_stream); // we already know the count copy_if_nosync( get_dataframe_buffer_begin(output_buffer), get_dataframe_buffer_end(output_buffer), @@ -2154,6 +2163,7 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, cuda::proclaim_return_type( [minor_comm_rank] __device__(auto rank) { return rank == minor_comm_rank; })), get_dataframe_buffer_begin(values), + raft::device_span(dummy.data(), size_t{1}), loop_stream); } } else { @@ -2161,10 +2171,12 @@ void per_v_transform_reduce_e(raft::handle_t const& handle, assert(edge_partition_selected_ranks_or_flags[j].index() == 1); auto& selected_flags = std::get<1>(edge_partition_selected_ranks_or_flags[j]); resize_dataframe_buffer(values, copy_size, loop_stream); + rmm::device_scalar dummy(size_t{0}, loop_stream); // we already know the count copy_if_nosync(get_dataframe_buffer_begin(output_buffer), get_dataframe_buffer_end(output_buffer), (*selected_flags).begin(), get_dataframe_buffer_begin(values), + raft::device_span(dummy.data(), size_t{1}), loop_stream); (*selected_flags).resize(0, loop_stream); (*selected_flags).shrink_to_fit(loop_stream); diff --git a/cpp/src/prims/vertex_frontier.cuh b/cpp/src/prims/vertex_frontier.cuh index 9c7c84e9719..f92aec680a9 100644 --- a/cpp/src/prims/vertex_frontier.cuh +++ b/cpp/src/prims/vertex_frontier.cuh @@ -188,6 +188,7 @@ void device_bcast_vertex_list( assert((comm.get_rank() != root) || (std::get<0>(v_list).size() == tmp_bitmap.size())); device_bcast( comm, std::get<0>(v_list).data(), tmp_bitmap.data(), tmp_bitmap.size(), root, stream_view); + rmm::device_scalar dummy(size_t{0}, stream_view); // we already know the count detail::copy_if_nosync( thrust::make_counting_iterator(vertex_range_first), thrust::make_counting_iterator(vertex_range_last), @@ -200,6 +201,7 @@ void device_bcast_vertex_list( packed_bool_empty_mask()); })), output_v_first, + raft::device_span(dummy.data(), size_t{1}), stream_view); } else { device_bcast(comm, std::get<1>(v_list), output_v_first, v_list_size, root, stream_view);