From dc6136b7188ea24184f006201a62e645ff80dd80 Mon Sep 17 00:00:00 2001 From: jnke2016 Date: Fri, 12 Jul 2024 20:31:15 -0700 Subject: [PATCH] unroll edges without using global comms --- cpp/src/community/k_truss_impl.cuh | 107 ++++++++++++++++++++++++----- 1 file changed, 90 insertions(+), 17 deletions(-) diff --git a/cpp/src/community/k_truss_impl.cuh b/cpp/src/community/k_truss_impl.cuh index 696d79e8dc6..cc6099df2ae 100644 --- a/cpp/src/community/k_truss_impl.cuh +++ b/cpp/src/community/k_truss_impl.cuh @@ -52,6 +52,7 @@ edge_t remove_overcompensating_edges(raft::handle_t const& handle, raft::device_span global_set_c_weak_edges_dsts, raft::device_span set_c_weak_edges_srcs, raft::device_span set_c_weak_edges_dsts, + vertex_t number_of_local_edge_partitions, std::vector vertex_partition_range_lasts) { @@ -86,9 +87,13 @@ edge_t remove_overcompensating_edges(raft::handle_t const& handle, edges_not_overcomp); return dist; } else { + + auto& comm = handle.get_comms(); + auto const comm_rank = comm.get_rank(); + rmm::device_uvector set_a_query_edges_srcs(buffer_size, handle.get_stream()); rmm::device_uvector set_a_query_edges_dsts(buffer_size, handle.get_stream()); - std::vector rx_count{}; + std::vector rx_counts{}; thrust::copy(handle.get_thrust_policy(), set_a_query_edges, @@ -96,7 +101,7 @@ edge_t remove_overcompensating_edges(raft::handle_t const& handle, thrust::make_zip_iterator(set_a_query_edges_srcs.begin(), set_a_query_edges_dsts.begin())); // group_by_count to get the destination of each edges - std::tie(set_a_query_edges_srcs, set_a_query_edges_dsts, std::ignore, std::ignore, std::ignore, rx_count) = + std::tie(set_a_query_edges_srcs, set_a_query_edges_dsts, std::ignore, std::ignore, std::ignore, rx_counts) = detail::shuffle_int_vertex_pairs_with_values_to_local_gpu_by_edge_partitioning has_edge(set_a_query_edges_srcs.size(), handle.get_stream()); // type should be size_t - auto set_c_weak_edges_first = thrust::make_zip_iterator(global_set_c_weak_edges_srcs.begin(), global_set_c_weak_edges_dsts.begin()); // setBedges - auto set_c_weak_edges_last = thrust::make_zip_iterator(global_set_c_weak_edges_srcs.end(), global_set_c_weak_edges_dsts.end()); + auto set_c_weak_edges_first = thrust::make_zip_iterator(set_c_weak_edges_srcs.begin(), set_c_weak_edges_dsts.begin()); // setBedges + auto set_c_weak_edges_last = thrust::make_zip_iterator(set_c_weak_edges_srcs.end(), set_c_weak_edges_dsts.end()); auto set_a_query_edges_first = thrust::make_zip_iterator(set_a_query_edges_srcs.begin(), set_a_query_edges_dsts.begin()); - // FIXME: Was recommended to use thrust::transform instead but how ? + // FIXME: Use thrust::transform instead thrust::tabulate( handle.get_thrust_policy(), has_edge.begin(), @@ -124,32 +129,80 @@ edge_t remove_overcompensating_edges(raft::handle_t const& handle, thrust::seq, set_c_weak_edges_first, set_c_weak_edges_last, set_a_query_edges_first[i]); }); + //auto& comm = handle.get_comms(); + auto const comm_size = comm.get_size(); + auto& major_comm = handle.get_subcomm(cugraph::partition_manager::major_comm_name()); + auto const major_comm_size = major_comm.get_size(); + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + auto const minor_comm_size = minor_comm.get_size(); + + rmm::device_uvector d_vertex_partition_range_lasts(vertex_partition_range_lasts.size(), + handle.get_stream()); + + raft::update_device(d_vertex_partition_range_lasts.data(), + vertex_partition_range_lasts.data(), + vertex_partition_range_lasts.size(), + handle.get_stream()); + + auto func = cugraph::detail::compute_gpu_id_from_int_edge_endpoints_t{ + raft::device_span(d_vertex_partition_range_lasts.data(), + d_vertex_partition_range_lasts.size()), + comm_size, + major_comm_size, + minor_comm_size}; + + auto d_tx_counts = cugraph::groupby_and_count( + thrust::make_zip_iterator(set_a_query_edges_srcs.begin(), set_a_query_edges_dsts.begin()), + thrust::make_zip_iterator(set_a_query_edges_srcs.end(), set_a_query_edges_dsts.end()), + [func, major_comm_size]__device__(auto val) { + return func(val) % major_comm_size; + }, + major_comm_size, + std::numeric_limits::max(), + handle.get_stream()); + + std::vector h_tx_counts{d_tx_counts.size()}; + + raft::update_host(h_tx_counts.data(), + d_tx_counts.data(), + d_tx_counts.size(), + handle.get_stream()); + std::tie(has_edge, std::ignore) = - shuffle_values(handle.get_comms(), has_edge.begin(), rx_count, handle.get_stream()); + shuffle_values(handle.get_comms(), has_edge.begin(), h_tx_counts, handle.get_stream()); auto set_a_and_b_query_edges_first = thrust::make_zip_iterator(set_a_query_edges, set_b_query_edges); auto set_a_and_b_query_edges_last = thrust::make_zip_iterator( set_a_query_edges + buffer_size, set_b_query_edges + buffer_size); + + thrust::sort_by_key(handle.get_thrust_policy(), + set_a_query_edges, + set_a_query_edges + buffer_size, + thrust::make_zip_iterator(set_b_query_edges, has_edge.begin()) + ); + + + auto edges_not_overcomp = thrust::remove_if( handle.get_thrust_policy(), set_a_and_b_query_edges_first, set_a_and_b_query_edges_last, [ - set_a_and_b_query_edges_first, - set_a_and_b_query_edges_last, + set_a_query_edges, + buffer_size, has_edge = raft::device_span(has_edge.data(), has_edge.size()) ] __device__(auto pair_set) { + //auto set_a_query_edge = thrust::get<0>(pair_set) auto itr = thrust::lower_bound( - thrust::seq, set_a_and_b_query_edges_first, set_a_and_b_query_edges_last, pair_set); + thrust::seq, set_a_query_edges, set_a_query_edges + buffer_size, thrust::get<0>(pair_set)); - auto idx = thrust::distance(set_a_and_b_query_edges_first, itr); + auto idx = thrust::distance(set_a_query_edges, itr); return has_edge[idx]; }); - auto dist = thrust::distance(thrust::make_zip_iterator(set_a_query_edges, - set_b_query_edges), edges_not_overcomp); + auto dist = thrust::distance(set_a_and_b_query_edges_first, edges_not_overcomp); return dist; } @@ -167,6 +220,16 @@ struct extract_weak_edges { } }; +template +struct extract_edges { // FIXME: ******************************Remove this functor. For testing purposes only******************* + __device__ thrust::optional> operator()( + + auto src, auto dst, thrust::nullopt_t, thrust::nullopt_t, auto count) const + { + return thrust::make_tuple(src, dst, count); + } +}; + template struct extract_edges_and_triangle_counts { __device__ thrust::optional> operator()( @@ -890,7 +953,7 @@ k_truss(raft::handle_t const& handle, auto& major_comm = handle.get_subcomm(cugraph::partition_manager::major_comm_name()); // Perform all-to-all in chunks across minor comm auto major_vertex_q_r_set = cugraph::detail::device_allgatherv( - handle, major_comm, raft::device_span(vertex_q_r_set.data(), vertex_q_r_set.size())); + handle, handle.get_comms(), raft::device_span(vertex_q_r_set.data(), vertex_q_r_set.size())); thrust::sort(handle.get_thrust_policy(), major_vertex_q_r_set.begin(), major_vertex_q_r_set.end()); @@ -1114,6 +1177,7 @@ k_truss(raft::handle_t const& handle, raft::device_span(global_weak_edgelist_dsts.data(), global_weak_edgelist_dsts.size()), raft::device_span(weak_edgelist_srcs.data(), weak_edgelist_srcs.size()), raft::device_span(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()), + cur_graph_view.number_of_local_edge_partitions(), cur_graph_view.vertex_partition_range_lasts() ); @@ -1147,7 +1211,7 @@ k_truss(raft::handle_t const& handle, decltype(get_dataframe_buffer_begin(vertex_pair_buffer_q_r)), true, multi_gpu, - false + true >( handle, size_dataframe_buffer(vertex_pair_buffer_p_q_edge_q_r), @@ -1157,6 +1221,7 @@ k_truss(raft::handle_t const& handle, raft::device_span(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()), raft::device_span(weak_edgelist_srcs.data(), weak_edgelist_srcs.size()), // FIXME: Only for MG validation purposes raft::device_span(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()), // FIXME: Only for MG validation purposes + cur_graph_view.number_of_local_edge_partitions(), cur_graph_view.vertex_partition_range_lasts() // Not needed for SG ); @@ -1524,7 +1589,7 @@ k_truss(raft::handle_t const& handle, thrust::sort(handle.get_thrust_policy(), chunk_global_weak_edgelist_first, chunk_global_weak_edgelist_first + global_weak_edgelist_srcs.size()); - + auto num_edges_not_overcomp_p_q = remove_overcompensating_edges(global_weak_edgelist_dsts.data(), global_weak_edgelist_dsts.size()), raft::device_span(weak_edgelist_srcs.data(), weak_edgelist_srcs.size()), raft::device_span(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()), + cur_graph_view.number_of_local_edge_partitions(), cur_graph_view.vertex_partition_range_lasts() ); + + // FIXME: No need to resize the dataframes buffer now. + resize_dataframe_buffer(vertex_pair_buffer_p_q_edge_p_r, num_edges_not_overcomp_p_q, handle.get_stream()); + resize_dataframe_buffer(vertex_pair_buffer_q_r_edge_p_r, num_edges_not_overcomp_p_q, handle.get_stream()); auto num_edges_not_overcomp_q_r = remove_overcompensating_edges(global_weak_edgelist_dsts.data(), global_weak_edgelist_dsts.size()), raft::device_span(weak_edgelist_srcs.data(), weak_edgelist_srcs.size()), raft::device_span(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()), + cur_graph_view.number_of_local_edge_partitions(), cur_graph_view.vertex_partition_range_lasts()); resize_dataframe_buffer(vertex_pair_buffer_q_r_edge_p_r, num_edges_not_overcomp_q_r, handle.get_stream()); @@ -1645,7 +1716,7 @@ k_truss(raft::handle_t const& handle, decltype(get_dataframe_buffer_begin(vertex_pair_buffer_p_q_edge_p_r)), false, multi_gpu, - false + true >( handle, q_closing.size(), @@ -1655,6 +1726,7 @@ k_truss(raft::handle_t const& handle, raft::device_span(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()), raft::device_span(weak_edgelist_srcs.data(), weak_edgelist_srcs.size()), // FIXME: Only for MG validation purposes raft::device_span(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()), // FIXME: Only for MG validation purposes + cur_graph_view.number_of_local_edge_partitions(), cur_graph_view.vertex_partition_range_lasts()); resize_dataframe_buffer(vertex_pair_buffer_p_q_edge_p_r, num_edges_not_overcomp_p_q, handle.get_stream()); @@ -1666,7 +1738,7 @@ k_truss(raft::handle_t const& handle, decltype(get_dataframe_buffer_begin(vertex_pair_buffer_p_q_edge_p_r)), false, multi_gpu, - false + true >( handle, num_edges_not_overcomp_p_q, @@ -1676,6 +1748,7 @@ k_truss(raft::handle_t const& handle, raft::device_span(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()), raft::device_span(weak_edgelist_srcs.data(), weak_edgelist_srcs.size()), // FIXME: Only for MG validation purposes raft::device_span(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()), // FIXME: Only for MG validation purposes + cur_graph_view.number_of_local_edge_partitions(), cur_graph_view.vertex_partition_range_lasts()); resize_dataframe_buffer(vertex_pair_buffer_p_q_edge_p_r, num_edges_not_overcomp_q_r, handle.get_stream());