Skip to content

Commit

Permalink
cleanup code and remove tmp buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Feb 20, 2024
1 parent d2694b7 commit 1130354
Showing 1 changed file with 281 additions and 11 deletions.
292 changes: 281 additions & 11 deletions cpp/src/community/k_truss_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,27 @@ struct unroll_edge {
};


template <typename vertex_t, typename edge_t, typename VertexPairIterator>
struct generate_pr {
raft::device_span<size_t const> intersection_offsets{};
raft::device_span<vertex_t const> intersection_indices{};

VertexPairIterator vertex_pairs_begin{};

__device__ thrust::tuple<vertex_t, vertex_t> operator()(edge_t i) const
{
auto itr = thrust::upper_bound(
thrust::seq, intersection_offsets.begin() + 1, intersection_offsets.end(), i);
auto idx = thrust::distance(intersection_offsets.begin() + 1, itr);
auto pair =
thrust::make_tuple(thrust::get<0>(*(vertex_pairs_begin + idx)), intersection_indices[i]);

return pair;
}
};



template <typename vertex_t, typename edge_t, typename VertexPairIterator>
struct generate_qr {
raft::device_span<size_t const> intersection_offsets{};
Expand Down Expand Up @@ -950,6 +971,7 @@ void k_truss(raft::handle_t const& handle,
// Put edges with triangle count == 0 in the second partition
// FIXME: revisit all the 'stable_partition' and only used them
// when necessary otherwise simply call 'thrust::partition'
// Stable_parition is needed because we want to keep src and dst sorted
auto edges_to_num_triangles_last =
thrust::stable_partition(handle.get_thrust_policy(),
edges_to_num_triangles,
Expand Down Expand Up @@ -998,6 +1020,7 @@ void k_truss(raft::handle_t const& handle,
// resize the 'edgelist_srcs' and 'edgelsit_dst'
edgelist_srcs.resize(last_edge_idx, handle.get_stream());
edgelist_dsts.resize(last_edge_idx, handle.get_stream());
num_triangles.resize(last_edge_idx, handle.get_stream());

num_vertex_pairs = edgelist_srcs.size();

Expand Down Expand Up @@ -1053,9 +1076,6 @@ void k_truss(raft::handle_t const& handle,
// FIXME: Among the invalid edges, identify those that were removed to
// avoid extra panalization. One way to achieve it is by calling thrust::set_intersection
// to filter out the removed edges. However this will require another array.

vertex_pairs_begin =
thrust::make_zip_iterator(edgelist_srcs.begin(), edgelist_dsts.begin());

raft::print_device_vector("zip - after removing edges", edgelist_srcs.data(), edgelist_srcs.size(), std::cout);
raft::print_device_vector("zip - after removing edges", edgelist_dsts.data(), edgelist_dsts.size(), std::cout);
Expand Down Expand Up @@ -1115,10 +1135,7 @@ void k_truss(raft::handle_t const& handle,

raft::print_device_vector("prefix_sum", prefix_sum.data(), prefix_sum.size(), std::cout);

num_invalid_edges = 0;

// case 3 unroll (p, r)
/*
vertex_pair_buffer_p_q = allocate_dataframe_buffer<thrust::tuple<vertex_t, vertex_t>>(
prefix_sum.back_element(handle.get_stream()), handle.get_stream());

Expand All @@ -1129,10 +1146,12 @@ void k_truss(raft::handle_t const& handle,
handle.get_thrust_policy(),
indices.begin(),
indices.end(),
[invalid_first_dst = std::get<1>(vertex_pair_buffer).begin(),
invalid_first_src = std::get<0>(vertex_pair_buffer).begin(),
[invalid_first_dst = std::get<1>(invalid_edges_buffer).begin(),
//invalid_first_dst = std::get<1>(vertex_pair_buffer).begin(),
invalid_first_src = std::get<0>(invalid_edges_buffer).begin(),
//invalid_first_src = std::get<0>(vertex_pair_buffer).begin(),
prefix_sum = prefix_sum.data(),
incoming_vertex_pairs = incoming_vertex_pairs,
incoming_vertex_pairs = get_dataframe_buffer_begin(incoming_vertex_pairs),
vertex_pair_buffer_p_q = get_dataframe_buffer_begin(vertex_pair_buffer_p_q),
vertex_pair_buffer_q_r = get_dataframe_buffer_begin(vertex_pair_buffer_q_r),
num_edges = num_vertex_pairs] __device__(auto idx) {
Expand Down Expand Up @@ -1169,6 +1188,15 @@ void k_truss(raft::handle_t const& handle,
thrust::get<1>(*(incoming_vertex_pairs + idx_lower + idx_in_segment)), dst);
});
});


printf("\ngetting all possible incomming edges\n");
raft::print_device_vector("p_q - src", std::get<0>(vertex_pair_buffer_p_q).data(), std::get<0>(vertex_pair_buffer_p_q).size(), std::cout);
raft::print_device_vector("p_q - dst", std::get<1>(vertex_pair_buffer_p_q).data(), std::get<1>(vertex_pair_buffer_p_q).size(), std::cout);


raft::print_device_vector("q_r - src", std::get<0>(vertex_pair_buffer_q_r).data(), std::get<0>(vertex_pair_buffer_q_r).size(), std::cout);
raft::print_device_vector("q_r - dst", std::get<1>(vertex_pair_buffer_q_r).data(), std::get<1>(vertex_pair_buffer_q_r).size(), std::cout);

edge_exists = cur_graph_view.has_edge(
handle,
Expand All @@ -1181,7 +1209,176 @@ void k_truss(raft::handle_t const& handle,
thrust::make_zip_iterator(get_dataframe_buffer_begin(vertex_pair_buffer_p_q),
get_dataframe_buffer_begin(vertex_pair_buffer_q_r)),
edge_exists.begin());

has_edge_last = thrust::stable_partition(handle.get_thrust_policy(),
edge_to_existance,
edge_to_existance + edge_exists.size(),
[] __device__(auto e) {
auto edge_exists = thrust::get<1>(e);
return edge_exists;
});

num_edge_exists = thrust::distance(edge_to_existance, has_edge_last);

// After pushing the non-existant edges to the second partition,
// remove them by resizing both vertex pair buffer
resize_dataframe_buffer(vertex_pair_buffer_p_q, num_edge_exists, handle.get_stream());
resize_dataframe_buffer(vertex_pair_buffer_q_r, num_edge_exists, handle.get_stream());

raft::print_device_vector("***p_q - src", std::get<0>(vertex_pair_buffer_p_q).data(), std::get<0>(vertex_pair_buffer_p_q).size(), std::cout);
raft::print_device_vector("***p_q - dst", std::get<1>(vertex_pair_buffer_p_q).data(), std::get<1>(vertex_pair_buffer_p_q).size(), std::cout);
raft::print_device_vector("***q_r - src", std::get<0>(vertex_pair_buffer_q_r).data(), std::get<0>(vertex_pair_buffer_q_r).size(), std::cout);
raft::print_device_vector("***q_r - dst", std::get<1>(vertex_pair_buffer_q_r).data(), std::get<1>(vertex_pair_buffer_q_r).size(), std::cout);

raft::print_device_vector("before unrolling - invalid_srcs", edgelist_srcs_.data(), edgelist_srcs_.size(), std::cout);
raft::print_device_vector("before unrolling - invalid_dsts", edgelist_dsts_.data(), edgelist_dsts_.size(), std::cout);
vertex_pairs_end = vertex_pairs_begin + num_vertex_pairs;
thrust::for_each(handle.get_thrust_policy(),
thrust::make_counting_iterator<edge_t>(0),
thrust::make_counting_iterator<edge_t>(num_edge_exists),
unroll_edge<vertex_t, edge_t, decltype(vertex_pairs_begin)>{
//num_vertex_pairs, FIXME: Passing the 'num_vertex_pairs' instead of 'vertex_pairs_end_' yield wrong results
raft::device_span<vertex_t>(num_triangles.data(), num_triangles.size()),
get_dataframe_buffer_begin(vertex_pair_buffer_p_q),
vertex_pairs_begin,
vertex_pairs_end,
});

raft::print_device_vector("num_triangles after unrolling p_q edges", num_triangles.data(), num_triangles.size(), std::cout);

thrust::for_each(handle.get_thrust_policy(),
thrust::make_counting_iterator<edge_t>(0),
thrust::make_counting_iterator<edge_t>(num_edge_exists),
unroll_edge<vertex_t, edge_t, decltype(vertex_pairs_begin)>{
//num_vertex_pairs, FIXME: Passing the 'num_vertex_pairs' instead of 'vertex_pairs_end_' yield wrong results
raft::device_span<vertex_t>(num_triangles.data(), num_triangles.size()),
get_dataframe_buffer_begin(vertex_pair_buffer_q_r),
vertex_pairs_begin,
vertex_pairs_end,
});

raft::print_device_vector("num_triangles after unrolling q_r edges", num_triangles.data(), num_triangles.size(), std::cout);



























// Put edges with triangle count == 0 in the second partition
// FIXME: revisit all the 'stable_partition' and only used them
// when necessary otherwise simply call 'thrust::partition'
// Stable_parition is needed because we want to keep src and dst sorted
edges_to_num_triangles_last =
thrust::stable_partition(handle.get_thrust_policy(),
edges_to_num_triangles,
edges_to_num_triangles + num_vertex_pairs,
[] __device__(auto edge_to_num_triangles) {
return thrust::get<1>(edge_to_num_triangles) > 0;
});

last_edge_idx = thrust::distance(edges_to_num_triangles, edges_to_num_triangles_last);
// rename the above it to last_edge_with_triangles
/*
edges_to_num_triangles = thrust::make_zip_iterator(
get_dataframe_buffer_begin(vertex_pair_buffer), num_triangles.begin());
*/

/*
edge_list.insert(std::get<0>(vertex_pair_buffer).begin(),
std::get<0>(vertex_pair_buffer).begin() + last_edge_idx,
std::get<1>(vertex_pair_buffer).begin());
*/
// rename the below to edges_with_triangles
edge_list.clear(); // FIXME: is this needed?

cugraph::edge_property_t<decltype(cur_graph_view), bool> edge_value_output_p_r(
handle, cur_graph_view);
edge_list.insert(edgelist_srcs.begin(),
edgelist_srcs.begin() + last_edge_idx,
edgelist_dsts.begin());

cugraph::transform_e(
handle,
cur_graph_view,
edge_list,
cugraph::edge_src_dummy_property_t{}.view(),
cugraph::edge_dst_dummy_property_t{}.view(),
cugraph::edge_dummy_property_t{}.view(),
[] __device__(auto src, auto dst, thrust::nullopt_t, thrust::nullopt_t, thrust::nullopt_t) {
return true;
},
edge_value_output_p_r.mutable_view(),
false);

cur_graph_view.attach_edge_mask(edge_value_output_p_r.view());

// resize the 'edgelist_srcs' and 'edgelsit_dst'
edgelist_srcs.resize(last_edge_idx, handle.get_stream());
edgelist_dsts.resize(last_edge_idx, handle.get_stream());
num_triangles.resize(last_edge_idx, handle.get_stream());

num_vertex_pairs = edgelist_srcs.size();

raft::print_device_vector("after removing edges", edgelist_srcs.data(), edgelist_srcs.size(), std::cout);
raft::print_device_vector("after removing edges", edgelist_dsts.data(), edgelist_dsts.size(), std::cout);




































/*
auto edges_to_num_triangles_p_r_last =
thrust::stable_partition(handle.get_thrust_policy(),
edge_to_existance,
Expand Down Expand Up @@ -1343,19 +1540,92 @@ void k_truss(raft::handle_t const& handle,
// Get the new pair of incoming edges
incoming_vertex_pairs =
thrust::make_zip_iterator(edgelist_dsts.begin(), edgelist_srcs.begin());
*/

// case 1. For the (p, q), find intersection 'r' to create (p, r, -1) and (q, r, -1)
// FIXME: check if 'invalid_edge_first' is necessery as I operate on 'vertex_pair_buffer'
// which contains the ordering with the number of triangles.

invalid_edge_last =
thrust::partition(handle.get_thrust_policy(),
get_dataframe_buffer_begin(invalid_edges_buffer),
get_dataframe_buffer_end(invalid_edges_buffer),
[edge_first = vertex_pairs_begin, // rename to 'edge'
edge_last = vertex_pairs_begin + num_vertex_pairs,
num_edges = num_vertex_pairs]
__device__(auto invalid_edge) {

auto itr = thrust::find(thrust::seq, edge_first, edge_last, invalid_edge);
auto idx = thrust::distance(edge_first, itr);
printf("\n src = %d, dst = %d, idx_lower = %d", thrust::get<0>(invalid_edge), thrust::get<1>(invalid_edge), idx);
return idx < num_edges;
});

// get_dataframe_buffer_begin(invalid_edges_buffer) + 3
num_invalid_edges = thrust::distance(get_dataframe_buffer_begin(invalid_edges_buffer), invalid_edge_last);


resize_dataframe_buffer(
invalid_edges_buffer, num_vertex_pairs, handle.get_stream());

printf("\n number of invalid edges = %d\n", num_invalid_edges); //L1084
raft::print_device_vector("p->q invalid - src", std::get<0>(invalid_edges_buffer).data(), std::get<0>(invalid_edges_buffer).size(), std::cout);
raft::print_device_vector("p->q invalid - dst", std::get<1>(invalid_edges_buffer).data(), std::get<1>(invalid_edges_buffer).size(), std::cout);



auto [intersection_offsets, intersection_indices] =
detail::nbr_intersection(handle,
cur_graph_view,
cugraph::edge_dummy_property_t{}.view(),
get_dataframe_buffer_begin(vertex_pair_buffer),
get_dataframe_buffer_end(vertex_pair_buffer),
get_dataframe_buffer_begin(invalid_edges_buffer),
get_dataframe_buffer_end(invalid_edges_buffer),
std::array<bool, 2>{true, true},
do_expensive_check);

printf("\nintersection size = %d\n", intersection_indices.size());
if (intersection_indices.size() > 0) {
size_t accumulate_pair_size =
intersection_indices.size(); // rename this var as accumulate_pair_size

auto vertex_pair_buffer_p_r_edge_p_q =
allocate_dataframe_buffer<thrust::tuple<vertex_t, vertex_t>>(
accumulate_pair_size, handle.get_stream());

thrust::tabulate(
handle.get_thrust_policy(),
get_dataframe_buffer_begin(vertex_pair_buffer_p_r_edge_p_q),
get_dataframe_buffer_end(vertex_pair_buffer_p_r_edge_p_q)
generate_pr<vertex_t, edge_t, decltype(get_dataframe_buffer_begin(invalid_edges_buffer))>{
raft::device_span<size_t const>(intersection_offsets.data(), intersection_offsets.size()),
raft::device_span<vertex_t const>(intersection_indices.data(),
intersection_indices.size()),
get_dataframe_buffer_begin(invalid_edges_buffer) // FIXME: verify this is accurate
});

// unroll set of edges one at a time to reduce peak memory

auto vertex_pair_buffer_q_r_edge_p_q =
allocate_dataframe_buffer<thrust::tuple<vertex_t, vertex_t>>(
accumulate_pair_size, handle.get_stream());

thrust::tabulate(
handle.get_thrust_policy(),
get_dataframe_buffer_begin(vertex_pair_buffer_q_r_edge_p_q),
get_dataframe_buffer_begin(vertex_pair_buffer_q_r_edge_p_q) +
accumulate_pair_size,
generate_qr<vertex_t, edge_t, decltype(get_dataframe_buffer_begin(invalid_edges_buffer))>{
raft::device_span<size_t const>(intersection_offsets.data(), intersection_offsets.size()),
raft::device_span<vertex_t const>(intersection_indices.data(),
intersection_indices.size()),
get_dataframe_buffer_begin(invalid_edges_buffer) // FIXME: verify this is accurate
});

}

num_invalid_edges = 0; //****************** debugging purposes

/*
// generating (p, r)
edge_t vertex_pair_buffer_p_r_edge_p_q_size =
intersection_indices.size(); // rename this var as accumulate_pair_size
Expand Down

0 comments on commit 1130354

Please sign in to comment.