diff --git a/cpp/src/community/k_truss_impl.cuh b/cpp/src/community/k_truss_impl.cuh index ec0f761885b..1fa7a14d6a8 100644 --- a/cpp/src/community/k_truss_impl.cuh +++ b/cpp/src/community/k_truss_impl.cuh @@ -176,13 +176,13 @@ edge_t remove_overcompensating_edges(raft::handle_t const& handle, } template -void find_unroll_p_r_or_q_r_edges( +void unroll_p_r_or_q_r_edges( raft::handle_t const& handle, graph_view_t& graph_view, edge_t num_invalid_edges, edge_t num_valid_edges, - raft::device_span edgelist_srcs, // FIXME: Use device_span instead - raft::device_span edgelist_dsts, + raft::device_span edgelist_srcs, + raft::device_span edgelist_dsts, raft::device_span num_triangles) { auto prefix_sum_valid = prefix_sum_valid_and_invalid_edges( @@ -823,38 +823,24 @@ std::tuple, rmm::device_uvector> k_truss // case 2: unroll (q, r) // For each (q, r) edges to unroll, find the incoming edges to 'r' let's say from 'p' and // create the pair (p, q) - cugraph::find_unroll_p_r_or_q_r_edges( + cugraph::unroll_p_r_or_q_r_edges( handle, cur_graph_view, num_invalid_edges, num_valid_edges, - raft::device_span(edgelist_srcs.data(), edgelist_srcs.size()), - raft::device_span(edgelist_dsts.data(), edgelist_dsts.size()), + raft::device_span(edgelist_srcs.data(), edgelist_srcs.size()), + raft::device_span(edgelist_dsts.data(), edgelist_dsts.size()), raft::device_span(num_triangles.data(), num_triangles.size())); // case 3: unroll (p, r) - cugraph::find_unroll_p_r_or_q_r_edges( + cugraph::unroll_p_r_or_q_r_edges( handle, cur_graph_view, num_invalid_edges, num_valid_edges, - raft::device_span(edgelist_srcs.data(), edgelist_srcs.size()), - raft::device_span(edgelist_dsts.data(), edgelist_dsts.size()), + raft::device_span(edgelist_srcs.data(), edgelist_srcs.size()), + raft::device_span(edgelist_dsts.data(), edgelist_dsts.size()), raft::device_span(num_triangles.data(), num_triangles.size())); - - auto edges_with_triangle_last = - thrust::stable_partition(handle.get_thrust_policy(), - transposed_edge_triangle_count_pair_first, - transposed_edge_triangle_count_pair_first + num_triangles.size(), - [] __device__(auto e) { - auto num_triangles = thrust::get<1>(e); - return num_triangles > 0; - }); - - - - size_t num_edges_with_triangles = static_cast( - thrust::distance(transposed_edge_triangle_count_pair_first, edges_with_triangle_last)); //cugraph::edge_property_t edge_mask(handle, cur_graph_view); // Set edge property to 'True' for all edges then mask out invalid edges which can be @@ -870,8 +856,6 @@ std::tuple, rmm::device_uvector> k_truss edgelist_srcs.end(), edgelist_dsts.begin() + num_edges_with_triangles); - // FIXME: Cannot modify an edgemask that is still attached. - // This can lead to race conditions cur_graph_view.clear_edge_mask(); cugraph::transform_e( handle,