Skip to content

Commit

Permalink
unroll edges without using global comms
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Jul 13, 2024
1 parent d956d22 commit dc6136b
Showing 1 changed file with 90 additions and 17 deletions.
107 changes: 90 additions & 17 deletions cpp/src/community/k_truss_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ edge_t remove_overcompensating_edges(raft::handle_t const& handle,
raft::device_span<vertex_t const> global_set_c_weak_edges_dsts,
raft::device_span<vertex_t const> set_c_weak_edges_srcs,
raft::device_span<vertex_t const> set_c_weak_edges_dsts,
vertex_t number_of_local_edge_partitions,
std::vector<vertex_t> vertex_partition_range_lasts)
{

Expand Down Expand Up @@ -86,17 +87,21 @@ edge_t remove_overcompensating_edges(raft::handle_t const& handle,
edges_not_overcomp);
return dist;
} else {

auto& comm = handle.get_comms();
auto const comm_rank = comm.get_rank();

rmm::device_uvector<vertex_t> set_a_query_edges_srcs(buffer_size, handle.get_stream());
rmm::device_uvector<vertex_t> set_a_query_edges_dsts(buffer_size, handle.get_stream());
std::vector<size_t> rx_count{};
std::vector<size_t> rx_counts{};

thrust::copy(handle.get_thrust_policy(),
set_a_query_edges,
set_a_query_edges + buffer_size,
thrust::make_zip_iterator(set_a_query_edges_srcs.begin(), set_a_query_edges_dsts.begin()));

// group_by_count to get the destination of each edges
std::tie(set_a_query_edges_srcs, set_a_query_edges_dsts, std::ignore, std::ignore, std::ignore, rx_count) =
std::tie(set_a_query_edges_srcs, set_a_query_edges_dsts, std::ignore, std::ignore, std::ignore, rx_counts) =
detail::shuffle_int_vertex_pairs_with_values_to_local_gpu_by_edge_partitioning<vertex_t,
edge_t,
float,
Expand All @@ -106,11 +111,11 @@ edge_t remove_overcompensating_edges(raft::handle_t const& handle,

rmm::device_uvector<vertex_t> has_edge(set_a_query_edges_srcs.size(), handle.get_stream()); // type should be size_t

auto set_c_weak_edges_first = thrust::make_zip_iterator(global_set_c_weak_edges_srcs.begin(), global_set_c_weak_edges_dsts.begin()); // setBedges
auto set_c_weak_edges_last = thrust::make_zip_iterator(global_set_c_weak_edges_srcs.end(), global_set_c_weak_edges_dsts.end());
auto set_c_weak_edges_first = thrust::make_zip_iterator(set_c_weak_edges_srcs.begin(), set_c_weak_edges_dsts.begin()); // setBedges
auto set_c_weak_edges_last = thrust::make_zip_iterator(set_c_weak_edges_srcs.end(), set_c_weak_edges_dsts.end());
auto set_a_query_edges_first = thrust::make_zip_iterator(set_a_query_edges_srcs.begin(), set_a_query_edges_dsts.begin());

// FIXME: Was recommended to use thrust::transform instead but how ?
// FIXME: Use thrust::transform instead
thrust::tabulate(
handle.get_thrust_policy(),
has_edge.begin(),
Expand All @@ -124,32 +129,80 @@ edge_t remove_overcompensating_edges(raft::handle_t const& handle,
thrust::seq, set_c_weak_edges_first, set_c_weak_edges_last, set_a_query_edges_first[i]);
});

//auto& comm = handle.get_comms();
auto const comm_size = comm.get_size();
auto& major_comm = handle.get_subcomm(cugraph::partition_manager::major_comm_name());
auto const major_comm_size = major_comm.get_size();
auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name());
auto const minor_comm_size = minor_comm.get_size();

rmm::device_uvector<vertex_t> d_vertex_partition_range_lasts(vertex_partition_range_lasts.size(),
handle.get_stream());

raft::update_device(d_vertex_partition_range_lasts.data(),
vertex_partition_range_lasts.data(),
vertex_partition_range_lasts.size(),
handle.get_stream());

auto func = cugraph::detail::compute_gpu_id_from_int_edge_endpoints_t<vertex_t>{
raft::device_span<vertex_t const>(d_vertex_partition_range_lasts.data(),
d_vertex_partition_range_lasts.size()),
comm_size,
major_comm_size,
minor_comm_size};

auto d_tx_counts = cugraph::groupby_and_count(
thrust::make_zip_iterator(set_a_query_edges_srcs.begin(), set_a_query_edges_dsts.begin()),
thrust::make_zip_iterator(set_a_query_edges_srcs.end(), set_a_query_edges_dsts.end()),
[func, major_comm_size]__device__(auto val) {
return func(val) % major_comm_size;
},
major_comm_size,
std::numeric_limits<size_t>::max(),
handle.get_stream());

std::vector<size_t> h_tx_counts{d_tx_counts.size()};

raft::update_host(h_tx_counts.data(),
d_tx_counts.data(),
d_tx_counts.size(),
handle.get_stream());

std::tie(has_edge, std::ignore) =
shuffle_values(handle.get_comms(), has_edge.begin(), rx_count, handle.get_stream());
shuffle_values(handle.get_comms(), has_edge.begin(), h_tx_counts, handle.get_stream());

auto set_a_and_b_query_edges_first = thrust::make_zip_iterator(set_a_query_edges, set_b_query_edges);
auto set_a_and_b_query_edges_last = thrust::make_zip_iterator(
set_a_query_edges + buffer_size, set_b_query_edges + buffer_size);


thrust::sort_by_key(handle.get_thrust_policy(),
set_a_query_edges,
set_a_query_edges + buffer_size,
thrust::make_zip_iterator(set_b_query_edges, has_edge.begin())
);



auto edges_not_overcomp = thrust::remove_if(
handle.get_thrust_policy(),
set_a_and_b_query_edges_first,
set_a_and_b_query_edges_last,
[
set_a_and_b_query_edges_first,
set_a_and_b_query_edges_last,
set_a_query_edges,
buffer_size,
has_edge = raft::device_span<vertex_t const>(has_edge.data(), has_edge.size())
] __device__(auto pair_set) {
//auto set_a_query_edge = thrust::get<0>(pair_set)
auto itr = thrust::lower_bound(
thrust::seq, set_a_and_b_query_edges_first, set_a_and_b_query_edges_last, pair_set);
thrust::seq, set_a_query_edges, set_a_query_edges + buffer_size, thrust::get<0>(pair_set));

auto idx = thrust::distance(set_a_and_b_query_edges_first, itr);
auto idx = thrust::distance(set_a_query_edges, itr);
return has_edge[idx];

});

auto dist = thrust::distance(thrust::make_zip_iterator(set_a_query_edges,
set_b_query_edges), edges_not_overcomp);
auto dist = thrust::distance(set_a_and_b_query_edges_first, edges_not_overcomp);
return dist;

}
Expand All @@ -167,6 +220,16 @@ struct extract_weak_edges {
}
};

template <typename vertex_t, typename edge_t>
struct extract_edges { // FIXME: ******************************Remove this functor. For testing purposes only*******************
__device__ thrust::optional<thrust::tuple<vertex_t, vertex_t, edge_t>> operator()(

auto src, auto dst, thrust::nullopt_t, thrust::nullopt_t, auto count) const
{
return thrust::make_tuple(src, dst, count);
}
};

template <typename vertex_t, typename edge_t>
struct extract_edges_and_triangle_counts {
__device__ thrust::optional<thrust::tuple<vertex_t, vertex_t, edge_t>> operator()(
Expand Down Expand Up @@ -890,7 +953,7 @@ k_truss(raft::handle_t const& handle,
auto& major_comm = handle.get_subcomm(cugraph::partition_manager::major_comm_name());
// Perform all-to-all in chunks across minor comm
auto major_vertex_q_r_set = cugraph::detail::device_allgatherv(
handle, major_comm, raft::device_span<vertex_t const>(vertex_q_r_set.data(), vertex_q_r_set.size()));
handle, handle.get_comms(), raft::device_span<vertex_t const>(vertex_q_r_set.data(), vertex_q_r_set.size()));

thrust::sort(handle.get_thrust_policy(), major_vertex_q_r_set.begin(), major_vertex_q_r_set.end());

Expand Down Expand Up @@ -1114,6 +1177,7 @@ k_truss(raft::handle_t const& handle,
raft::device_span<vertex_t const>(global_weak_edgelist_dsts.data(), global_weak_edgelist_dsts.size()),
raft::device_span<vertex_t const>(weak_edgelist_srcs.data(), weak_edgelist_srcs.size()),
raft::device_span<vertex_t const>(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()),
cur_graph_view.number_of_local_edge_partitions(),
cur_graph_view.vertex_partition_range_lasts()
);

Expand Down Expand Up @@ -1147,7 +1211,7 @@ k_truss(raft::handle_t const& handle,
decltype(get_dataframe_buffer_begin(vertex_pair_buffer_q_r)),
true,
multi_gpu,
false
true
>(
handle,
size_dataframe_buffer(vertex_pair_buffer_p_q_edge_q_r),
Expand All @@ -1157,6 +1221,7 @@ k_truss(raft::handle_t const& handle,
raft::device_span<vertex_t const>(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()),
raft::device_span<vertex_t const>(weak_edgelist_srcs.data(), weak_edgelist_srcs.size()), // FIXME: Only for MG validation purposes
raft::device_span<vertex_t const>(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()), // FIXME: Only for MG validation purposes
cur_graph_view.number_of_local_edge_partitions(),
cur_graph_view.vertex_partition_range_lasts() // Not needed for SG
);

Expand Down Expand Up @@ -1524,7 +1589,7 @@ k_truss(raft::handle_t const& handle,
thrust::sort(handle.get_thrust_policy(),
chunk_global_weak_edgelist_first,
chunk_global_weak_edgelist_first + global_weak_edgelist_srcs.size());

auto num_edges_not_overcomp_p_q =
remove_overcompensating_edges<vertex_t,
edge_t,
Expand All @@ -1541,8 +1606,13 @@ k_truss(raft::handle_t const& handle,
raft::device_span<vertex_t const>(global_weak_edgelist_dsts.data(), global_weak_edgelist_dsts.size()),
raft::device_span<vertex_t const>(weak_edgelist_srcs.data(), weak_edgelist_srcs.size()),
raft::device_span<vertex_t const>(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()),
cur_graph_view.number_of_local_edge_partitions(),
cur_graph_view.vertex_partition_range_lasts()
);

// FIXME: No need to resize the dataframes buffer now.
resize_dataframe_buffer(vertex_pair_buffer_p_q_edge_p_r, num_edges_not_overcomp_p_q, handle.get_stream());
resize_dataframe_buffer(vertex_pair_buffer_q_r_edge_p_r, num_edges_not_overcomp_p_q, handle.get_stream());

auto num_edges_not_overcomp_q_r =
remove_overcompensating_edges<vertex_t,
Expand All @@ -1560,6 +1630,7 @@ k_truss(raft::handle_t const& handle,
raft::device_span<vertex_t const>(global_weak_edgelist_dsts.data(), global_weak_edgelist_dsts.size()),
raft::device_span<vertex_t const>(weak_edgelist_srcs.data(), weak_edgelist_srcs.size()),
raft::device_span<vertex_t const>(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()),
cur_graph_view.number_of_local_edge_partitions(),
cur_graph_view.vertex_partition_range_lasts());

resize_dataframe_buffer(vertex_pair_buffer_q_r_edge_p_r, num_edges_not_overcomp_q_r, handle.get_stream());
Expand Down Expand Up @@ -1645,7 +1716,7 @@ k_truss(raft::handle_t const& handle,
decltype(get_dataframe_buffer_begin(vertex_pair_buffer_p_q_edge_p_r)),
false,
multi_gpu,
false
true
>(
handle,
q_closing.size(),
Expand All @@ -1655,6 +1726,7 @@ k_truss(raft::handle_t const& handle,
raft::device_span<vertex_t const>(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()),
raft::device_span<vertex_t const>(weak_edgelist_srcs.data(), weak_edgelist_srcs.size()), // FIXME: Only for MG validation purposes
raft::device_span<vertex_t const>(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()), // FIXME: Only for MG validation purposes
cur_graph_view.number_of_local_edge_partitions(),
cur_graph_view.vertex_partition_range_lasts());

resize_dataframe_buffer(vertex_pair_buffer_p_q_edge_p_r, num_edges_not_overcomp_p_q, handle.get_stream());
Expand All @@ -1666,7 +1738,7 @@ k_truss(raft::handle_t const& handle,
decltype(get_dataframe_buffer_begin(vertex_pair_buffer_p_q_edge_p_r)),
false,
multi_gpu,
false
true
>(
handle,
num_edges_not_overcomp_p_q,
Expand All @@ -1676,6 +1748,7 @@ k_truss(raft::handle_t const& handle,
raft::device_span<vertex_t const>(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()),
raft::device_span<vertex_t const>(weak_edgelist_srcs.data(), weak_edgelist_srcs.size()), // FIXME: Only for MG validation purposes
raft::device_span<vertex_t const>(weak_edgelist_dsts.data(), weak_edgelist_dsts.size()), // FIXME: Only for MG validation purposes
cur_graph_view.number_of_local_edge_partitions(),
cur_graph_view.vertex_partition_range_lasts());

resize_dataframe_buffer(vertex_pair_buffer_p_q_edge_p_r, num_edges_not_overcomp_q_r, handle.get_stream());
Expand Down

0 comments on commit dc6136b

Please sign in to comment.