Skip to content

Commit

Permalink
properly handle missing edge types
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Nov 22, 2024
1 parent 223a73b commit 061c8cc
Showing 1 changed file with 81 additions and 56 deletions.
137 changes: 81 additions & 56 deletions cpp/src/sampling/neighbor_sampling_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <raft/core/handle.hpp>

#include <rmm/device_uvector.hpp>
#include <thrust/unique.h>

namespace cugraph {
namespace detail {
Expand Down Expand Up @@ -104,14 +105,38 @@ neighbor_sample_impl(raft::handle_t const& handle,
edge_masks_vector{};
graph_view_t<vertex_t, edge_t, false, multi_gpu> modified_graph_view = graph_view;
edge_masks_vector.reserve(num_edge_types);

label_t num_unique_labels = 0;

std::optional<rmm::device_uvector<label_t>> cp_starting_vertex_labels{std::nullopt};

label_t num_labels = 0;

if (starting_vertex_labels) {
// Initial number of labels. Will be leveraged if there is no sampling result
num_labels = starting_vertex_labels->size();
// Find the number of unique labels
cp_starting_vertex_labels = rmm::device_uvector<label_t>(starting_vertex_labels->size(), handle.get_stream());

thrust::copy(
handle.get_thrust_policy(),
starting_vertex_labels->begin(),
starting_vertex_labels->end(),
cp_starting_vertex_labels->begin());

thrust::sort(
handle.get_thrust_policy(),
cp_starting_vertex_labels->begin(),
cp_starting_vertex_labels->end());

num_unique_labels = thrust::unique_count(handle.get_thrust_policy(),
cp_starting_vertex_labels->begin(),
cp_starting_vertex_labels->end());


}





if (num_edge_types > 1) {
for (int i = 0; i < num_edge_types; i++) {
cugraph::edge_property_t<graph_view_t<vertex_t, edge_t, store_transposed, multi_gpu>, bool>
Expand Down Expand Up @@ -374,60 +399,60 @@ neighbor_sample_impl(raft::handle_t const& handle,
if (result_labels) {
cp_result_labels = rmm::device_uvector<label_t>(result_labels->size(), handle.get_stream());

thrust::copy(handle.get_thrust_policy(),
result_labels->begin(),
result_labels->end(),
cp_result_labels->begin());
thrust::copy(
handle.get_thrust_policy(),
result_labels->begin(),
result_labels->end(),
cp_result_labels->begin());
}

// FIXME: remove the offsets computation in 'shuffle_and_organize_output' as it doesn't
// account for missing labels that are not sampled.
std::tie(result_srcs,
result_dsts,
result_weights,
result_edge_ids,
result_edge_types,
result_hops,
result_labels,
result_offsets) = detail::shuffle_and_organize_output(handle,
std::move(result_srcs),
std::move(result_dsts),
std::move(result_weights),
std::move(result_edge_ids),
std::move(result_edge_types),
std::move(result_hops),
std::move(result_labels),
label_to_output_comm_rank);

if (result_labels) {
// Re-compute the result_offsets and account for missing labels
result_offsets = rmm::device_uvector<size_t>(num_labels + 1, handle.get_stream());

// Sort labels
thrust::sort(handle.get_thrust_policy(), cp_result_labels->begin(), cp_result_labels->end());

thrust::transform(handle.get_thrust_policy(),
thrust::make_counting_iterator<edge_t>(0),
thrust::make_counting_iterator<edge_t>(result_offsets->size() - 1),
result_offsets->begin() + 1,
[result_labels = raft::device_span<label_t const>(
cp_result_labels->data(), cp_result_labels->size())] __device__(auto idx) {
auto itr_lower = thrust::lower_bound(
thrust::seq, result_labels.begin(), result_labels.end(), idx);

auto itr_upper = thrust::upper_bound(
thrust::seq, result_labels.begin(), result_labels.end(), idx);

return thrust::distance(itr_lower, itr_upper);
});

// Run inclusive scan
thrust::inclusive_scan(handle.get_thrust_policy(),
result_offsets->begin() + 1,
result_offsets->end(),
result_offsets->begin() + 1);
}

std::tie(result_srcs, result_dsts, result_weights, result_edge_ids,
result_edge_types, result_hops, result_labels, result_offsets)
= detail::shuffle_and_organize_output(handle,
std::move(result_srcs),
std::move(result_dsts),
std::move(result_weights),
std::move(result_edge_ids),
std::move(result_edge_types),
std::move(result_hops),
std::move(result_labels),
label_to_output_comm_rank);

if (result_labels && (result_offsets->size() != num_unique_labels + 1)) {
// There are missing labels not sampled.
result_offsets = rmm::device_uvector<size_t>(num_unique_labels + 1, handle.get_stream());

// Sort labels
thrust::sort(
handle.get_thrust_policy(),
cp_result_labels->begin(),
cp_result_labels->end());

thrust::transform(
handle.get_thrust_policy(),
thrust::make_counting_iterator<edge_t>(0),
thrust::make_counting_iterator<edge_t>(result_offsets->size() - 1),
result_offsets->begin() + 1,
[
result_labels = raft::device_span<label_t const>(
cp_result_labels->data(),
cp_result_labels->size())
] __device__(auto idx) {
auto itr_lower = thrust::lower_bound(
thrust::seq, result_labels.begin(), result_labels.end(), idx);

auto itr_upper = thrust::upper_bound(
thrust::seq, result_labels.begin(), result_labels.end(), idx);

auto sampled_label_size = thrust::distance(itr_lower, itr_upper);
return sampled_label_size;
});

// Run inclusive scan
thrust::inclusive_scan(handle.get_thrust_policy(),
result_offsets->begin() + 1,
result_offsets->end(),
result_offsets->begin() + 1);
}
return std::make_tuple(std::move(result_srcs),
std::move(result_dsts),
std::move(result_weights),
Expand Down

0 comments on commit 061c8cc

Please sign in to comment.