From afc00eee1410211bba1e44cb1cc2a704f7d17367 Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Fri, 1 Dec 2023 12:53:41 -0800 Subject: [PATCH] There are mask utilities that perform some of the functions here, use those instead of replicating --- cpp/src/structure/detail/structure_utils.cuh | 33 +++++------------ cpp/src/structure/remove_multi_edges_impl.cuh | 37 +++++++++---------- cpp/src/structure/remove_self_loops_impl.cuh | 36 +++++++++--------- 3 files changed, 45 insertions(+), 61 deletions(-) diff --git a/cpp/src/structure/detail/structure_utils.cuh b/cpp/src/structure/detail/structure_utils.cuh index 7630d5855a0..f0f729bce18 100644 --- a/cpp/src/structure/detail/structure_utils.cuh +++ b/cpp/src/structure/detail/structure_utils.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -524,35 +525,21 @@ std::tuple> mark_entries(raft::handle_t co return word; }); - // FIXME: use detail::count_set_bits - size_t bit_count = thrust::transform_reduce( - handle.get_thrust_policy(), - marked_entries.begin(), - marked_entries.end(), - [] __device__(auto word) { return __popc(word); }, - size_t{0}, - thrust::plus()); + size_t bit_count = detail::count_set_bits(handle, marked_entries.begin(), num_entries); return std::make_tuple(bit_count, std::move(marked_entries)); } template -rmm::device_uvector remove_flagged_elements(raft::handle_t const& handle, - rmm::device_uvector&& vector, - raft::device_span remove_flags, - size_t remove_count) +rmm::device_uvector keep_flagged_elements(raft::handle_t const& handle, + rmm::device_uvector&& vector, + raft::device_span keep_flags, + size_t keep_count) { - rmm::device_uvector result(vector.size() - remove_count, handle.get_stream()); - - thrust::copy_if( - handle.get_thrust_policy(), - thrust::make_counting_iterator(size_t{0}), - thrust::make_counting_iterator(vector.size()), - thrust::make_transform_output_iterator(result.begin(), - indirection_t{vector.data()}), - [remove_flags] __device__(size_t i) { - return !(remove_flags[cugraph::packed_bool_offset(i)] & cugraph::packed_bool_mask(i)); - }); + rmm::device_uvector result(keep_count, handle.get_stream()); + + detail::copy_if_mask_set( + handle, vector.begin(), vector.end(), keep_flags.begin(), result.begin()); return result; } diff --git a/cpp/src/structure/remove_multi_edges_impl.cuh b/cpp/src/structure/remove_multi_edges_impl.cuh index ab6b1fba8eb..6df82990c4e 100644 --- a/cpp/src/structure/remove_multi_edges_impl.cuh +++ b/cpp/src/structure/remove_multi_edges_impl.cuh @@ -254,7 +254,7 @@ remove_multi_edges(raft::handle_t const& handle, } } - auto [multi_edge_count, multi_edges_to_delete] = + auto [keep_count, keep_flags] = detail::mark_entries(handle, edgelist_srcs.size(), [d_edgelist_srcs = edgelist_srcs.data(), @@ -263,41 +263,38 @@ remove_multi_edges(raft::handle_t const& handle, (d_edgelist_dsts[idx - 1] == d_edgelist_dsts[idx]); }); - if (multi_edge_count > 0) { - edgelist_srcs = detail::remove_flagged_elements( + if (keep_count < edgelist_srcs.size()) { + edgelist_srcs = detail::keep_flagged_elements( handle, std::move(edgelist_srcs), - raft::device_span{multi_edges_to_delete.data(), multi_edges_to_delete.size()}, - multi_edge_count); - edgelist_dsts = detail::remove_flagged_elements( + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); + edgelist_dsts = detail::keep_flagged_elements( handle, std::move(edgelist_dsts), - raft::device_span{multi_edges_to_delete.data(), multi_edges_to_delete.size()}, - multi_edge_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); if (edgelist_weights) - edgelist_weights = detail::remove_flagged_elements( + edgelist_weights = detail::keep_flagged_elements( handle, std::move(*edgelist_weights), - raft::device_span{multi_edges_to_delete.data(), - multi_edges_to_delete.size()}, - multi_edge_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); if (edgelist_edge_ids) - edgelist_edge_ids = detail::remove_flagged_elements( + edgelist_edge_ids = detail::keep_flagged_elements( handle, std::move(*edgelist_edge_ids), - raft::device_span{multi_edges_to_delete.data(), - multi_edges_to_delete.size()}, - multi_edge_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); if (edgelist_edge_types) - edgelist_edge_types = detail::remove_flagged_elements( + edgelist_edge_types = detail::keep_flagged_elements( handle, std::move(*edgelist_edge_types), - raft::device_span{multi_edges_to_delete.data(), - multi_edges_to_delete.size()}, - multi_edge_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); } return std::make_tuple(std::move(edgelist_srcs), diff --git a/cpp/src/structure/remove_self_loops_impl.cuh b/cpp/src/structure/remove_self_loops_impl.cuh index 161ffeae28e..dafe26cd1c5 100644 --- a/cpp/src/structure/remove_self_loops_impl.cuh +++ b/cpp/src/structure/remove_self_loops_impl.cuh @@ -44,44 +44,44 @@ remove_self_loops(raft::handle_t const& handle, std::optional>&& edgelist_edge_ids, std::optional>&& edgelist_edge_types) { - auto [self_loop_count, self_loops_to_delete] = + auto [keep_count, keep_flags] = detail::mark_entries(handle, edgelist_srcs.size(), [d_srcs = edgelist_srcs.data(), d_dsts = edgelist_dsts.data()] __device__( - size_t i) { return d_srcs[i] == d_dsts[i]; }); + size_t i) { return d_srcs[i] != d_dsts[i]; }); - if (self_loop_count > 0) { - edgelist_srcs = detail::remove_flagged_elements( + if (keep_count < edgelist_srcs.size()) { + edgelist_srcs = detail::keep_flagged_elements( handle, std::move(edgelist_srcs), - raft::device_span{self_loops_to_delete.data(), self_loops_to_delete.size()}, - self_loop_count); - edgelist_dsts = detail::remove_flagged_elements( + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); + edgelist_dsts = detail::keep_flagged_elements( handle, std::move(edgelist_dsts), - raft::device_span{self_loops_to_delete.data(), self_loops_to_delete.size()}, - self_loop_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); if (edgelist_weights) - edgelist_weights = detail::remove_flagged_elements( + edgelist_weights = detail::keep_flagged_elements( handle, std::move(*edgelist_weights), - raft::device_span{self_loops_to_delete.data(), self_loops_to_delete.size()}, - self_loop_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); if (edgelist_edge_ids) - edgelist_edge_ids = detail::remove_flagged_elements( + edgelist_edge_ids = detail::keep_flagged_elements( handle, std::move(*edgelist_edge_ids), - raft::device_span{self_loops_to_delete.data(), self_loops_to_delete.size()}, - self_loop_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); if (edgelist_edge_types) - edgelist_edge_types = detail::remove_flagged_elements( + edgelist_edge_types = detail::keep_flagged_elements( handle, std::move(*edgelist_edge_types), - raft::device_span{self_loops_to_delete.data(), self_loops_to_delete.size()}, - self_loop_count); + raft::device_span{keep_flags.data(), keep_flags.size()}, + keep_count); } return std::make_tuple(std::move(edgelist_srcs),