diff --git a/cpp/src/community/k_truss_impl.cuh b/cpp/src/community/k_truss_impl.cuh index b0e23e9497c..5b4601397d6 100644 --- a/cpp/src/community/k_truss_impl.cuh +++ b/cpp/src/community/k_truss_impl.cuh @@ -1115,7 +1115,34 @@ void k_truss(raft::handle_t const& handle, num_edges = edgelist_srcs.size(); - //num_invalid_edges = 0; //****************** debugging purposes + // FIXME: Rename the variable below. Can't reuse 'invalid_edge_last' + // because the call below returns a type different than 'edges_to_num_triangles' + auto invalid_edge_last_1 = + thrust::stable_partition(handle.get_thrust_policy(), + edges_to_num_triangles, + edges_to_num_triangles + num_triangles.size(), + [k] __device__(auto e) { + auto num_triangles = thrust::get<1>(e); + return num_triangles < k - 2; + }); + + num_invalid_edges = static_cast( + thrust::distance(edges_to_num_triangles, invalid_edge_last_1)); + + + // copy invalid edges + resize_dataframe_buffer(invalid_edges_buffer, num_invalid_edges, handle.get_stream()); + + thrust::copy(handle.get_thrust_policy(), + edges, + edges + num_invalid_edges, + get_dataframe_buffer_begin(invalid_edges_buffer)); + + // sort back the edges as those are needed later when running a binary tree + thrust::sort_by_key(handle.get_thrust_policy(), + edges, + edges + num_edges, + num_triangles.begin()); } if (num_invalid_edges == num_edges) {