Skip to content

Commit

Permalink
fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Nov 27, 2024
1 parent 580d3e9 commit 529955a
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 124 deletions.
4 changes: 2 additions & 2 deletions cpp/src/c_api/neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1284,8 +1284,8 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {

edge_label ? (*offsets).size() - 1 : size_t{1},
hop ? (((fan_out_->size_ % num_edge_types_) == 0)
? (fan_out_->size_ / num_edge_types_)
: ((fan_out_->size_ / num_edge_types_) + 1))
? (fan_out_->size_ / num_edge_types_)
: ((fan_out_->size_ / num_edge_types_) + 1))
: size_t{1},
(vertex_type_offsets_ != nullptr) ? vertex_type_offsets_->size_ - 1
: vertex_type_offsets.size() - 1,
Expand Down
258 changes: 136 additions & 122 deletions cpp/src/sampling/neighbor_sampling_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <raft/core/handle.hpp>

#include <rmm/device_uvector.hpp>

#include <thrust/unique.h>

namespace cugraph {
Expand Down Expand Up @@ -105,26 +106,25 @@ neighbor_sample_impl(raft::handle_t const& handle,
edge_masks_vector{};
graph_view_t<vertex_t, edge_t, false, multi_gpu> modified_graph_view = graph_view;
edge_masks_vector.reserve(num_edge_types);

label_t num_unique_labels = 0;

std::optional<rmm::device_uvector<label_t>> cp_starting_vertex_labels{std::nullopt};

if (starting_vertex_labels) {
// Find the number of unique lables
cp_starting_vertex_labels = rmm::device_uvector<label_t>(starting_vertex_labels->size(), handle.get_stream());

thrust::copy(
handle.get_thrust_policy(),
starting_vertex_labels->begin(),
starting_vertex_labels->end(),
cp_starting_vertex_labels->begin());

thrust::sort(
handle.get_thrust_policy(),
cp_starting_vertex_labels->begin(),
cp_starting_vertex_labels->end());

cp_starting_vertex_labels =
rmm::device_uvector<label_t>(starting_vertex_labels->size(), handle.get_stream());

thrust::copy(handle.get_thrust_policy(),
starting_vertex_labels->begin(),
starting_vertex_labels->end(),
cp_starting_vertex_labels->begin());

thrust::sort(handle.get_thrust_policy(),
cp_starting_vertex_labels->begin(),
cp_starting_vertex_labels->end());

num_unique_labels = thrust::unique_count(handle.get_thrust_policy(),
cp_starting_vertex_labels->begin(),
cp_starting_vertex_labels->end());
Expand Down Expand Up @@ -169,7 +169,6 @@ neighbor_sample_impl(raft::handle_t const& handle,
? (fan_out.size() / num_edge_types)
: ((fan_out.size() / num_edge_types) + 1);


auto level_result_weight_vectors =
edge_weight_view ? std::make_optional(std::vector<rmm::device_uvector<weight_t>>{})
: std::nullopt;
Expand All @@ -190,13 +189,14 @@ neighbor_sample_impl(raft::handle_t const& handle,
: std::nullopt;
auto level_result_edge_id =
edge_id_view ? std::make_optional(rmm::device_uvector<edge_t>(0, handle.get_stream()))
: std::nullopt;
: std::nullopt;
auto level_result_edge_type =
edge_type_view ? std::make_optional(rmm::device_uvector<edge_type_t>(0, handle.get_stream()))
: std::nullopt;
: std::nullopt;
auto level_result_label =
starting_vertex_labels ? std::make_optional(rmm::device_uvector<label_t>(0, handle.get_stream()))
: std::nullopt;
starting_vertex_labels
? std::make_optional(rmm::device_uvector<label_t>(0, handle.get_stream()))
: std::nullopt;

if (level_result_weight_vectors) { (*level_result_weight_vectors).reserve(num_hops); }
if (level_result_edge_id_vectors) { (*level_result_edge_id_vectors).reserve(num_hops); }
Expand Down Expand Up @@ -224,7 +224,6 @@ neighbor_sample_impl(raft::handle_t const& handle,
std::vector<size_t> level_sizes{};
std::vector<size_t> level_sizes_edge_types{};


for (auto hop = 0; hop < num_hops; hop++) {
for (auto edge_type_id = 0; edge_type_id < num_edge_types; edge_type_id++) {
auto k_level = fan_out[(hop * num_edge_types) + edge_type_id];
Expand Down Expand Up @@ -265,53 +264,69 @@ neighbor_sample_impl(raft::handle_t const& handle,

level_sizes_edge_types.push_back(srcs.size());

level_result_src.resize(std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end()), handle.get_stream());
level_result_dst.resize(std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end()), handle.get_stream());
level_result_src.resize(
std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end()),
handle.get_stream());
level_result_dst.resize(
std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end()),
handle.get_stream());

raft::copy(level_result_src.begin() + std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end() - 1),
raft::copy(level_result_src.begin() +
std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end() - 1),
srcs.begin(),
srcs.size(),
handle.get_stream());

raft::copy(level_result_dst.begin() + std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end() - 1),

raft::copy(level_result_dst.begin() +
std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end() - 1),
dsts.begin(),
srcs.size(),
handle.get_stream());

if (weights) {
(*level_result_weight).resize(std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end()), handle.get_stream());

raft::copy(level_result_weight->begin() + std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end() - 1),
weights->begin(),
srcs.size(),
handle.get_stream());
(*level_result_weight)
.resize(std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end()),
handle.get_stream());

raft::copy(level_result_weight->begin() +
std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end() - 1),
weights->begin(),
srcs.size(),
handle.get_stream());
}

if (edge_ids) {
(*level_result_edge_id).resize(std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end()), handle.get_stream());
raft::copy(level_result_edge_id->begin() + std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end() - 1),
edge_ids->begin(),
srcs.size(),
handle.get_stream());
(*level_result_edge_id)
.resize(std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end()),
handle.get_stream());
raft::copy(level_result_edge_id->begin() +
std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end() - 1),
edge_ids->begin(),
srcs.size(),
handle.get_stream());
}
if (edge_types) {
(*level_result_edge_type).resize(std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end()), handle.get_stream());


raft::copy(level_result_edge_type->begin() + std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end() - 1),
edge_types->begin(),
srcs.size(),
handle.get_stream());
(*level_result_edge_type)
.resize(std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end()),
handle.get_stream());

raft::copy(level_result_edge_type->begin() +
std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end() - 1),
edge_types->begin(),
srcs.size(),
handle.get_stream());
}

if (labels) {
(*level_result_label).resize(std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end()), handle.get_stream());

raft::copy(level_result_label->begin() + std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end() - 1),
labels->begin(),
srcs.size(),
handle.get_stream());

if (labels) {
(*level_result_label)
.resize(std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end()),
handle.get_stream());

raft::copy(level_result_label->begin() +
std::reduce(level_sizes_edge_types.begin(), level_sizes_edge_types.end() - 1),
labels->begin(),
srcs.size(),
handle.get_stream());
}

if (num_edge_types > 1) { modified_graph_view.clear_edge_mask(); }
Expand All @@ -321,11 +336,18 @@ neighbor_sample_impl(raft::handle_t const& handle,
level_result_src_vectors.push_back(std::move(level_result_src));
level_result_dst_vectors.push_back(std::move(level_result_dst));

if (level_result_weight) { (*level_result_weight_vectors).push_back(std::move(*level_result_weight)); }
if (level_result_edge_id) { (*level_result_edge_id_vectors).push_back(std::move(*level_result_edge_id)); }
if (level_result_edge_type) { (*level_result_edge_type_vectors).push_back(std::move(*level_result_edge_type)); }
if (level_result_label) { (*level_result_label_vectors).push_back(std::move(*level_result_label)); }

if (level_result_weight) {
(*level_result_weight_vectors).push_back(std::move(*level_result_weight));
}
if (level_result_edge_id) {
(*level_result_edge_id_vectors).push_back(std::move(*level_result_edge_id));
}
if (level_result_edge_type) {
(*level_result_edge_type_vectors).push_back(std::move(*level_result_edge_type));
}
if (level_result_label) {
(*level_result_label_vectors).push_back(std::move(*level_result_label));
}

// FIXME: We should modify vertex_partition_range_lasts to return a raft::host_span
// rather than making a copy.
Expand All @@ -335,12 +357,10 @@ neighbor_sample_impl(raft::handle_t const& handle,
handle,
starting_vertices,
starting_vertex_labels,
raft::device_span<vertex_t const>{level_result_dst.data(),
level_result_dst.size()},
frontier_vertex_labels
? std::make_optional(raft::device_span<label_t const>(
level_result_label->data(), level_result_label->size()))
: std::nullopt,
raft::device_span<vertex_t const>{level_result_dst.data(), level_result_dst.size()},
frontier_vertex_labels ? std::make_optional(raft::device_span<label_t const>(
level_result_label->data(), level_result_label->size()))
: std::nullopt,
std::move(vertex_used_as_source),
modified_graph_view.local_vertex_partition_view(),
vertex_partition_range_lasts,
Expand Down Expand Up @@ -437,14 +457,13 @@ neighbor_sample_impl(raft::handle_t const& handle,
if (return_hops) {
result_hops = rmm::device_uvector<int32_t>(result_size, handle.get_stream());
output_offset = 0;
for (size_t i = 0; i < num_hops; ++i) { // FIXME: replace this by the number of hops
for (size_t i = 0; i < num_hops; ++i) { // FIXME: replace this by the number of hops
scalar_fill(
handle, result_hops->data() + output_offset, level_sizes[i], static_cast<int32_t>(i));
output_offset += level_sizes[i];
}
}


auto result_labels =
level_result_label_vectors
? std::make_optional(rmm::device_uvector<label_t>(result_size, handle.get_stream()))
Expand All @@ -466,65 +485,60 @@ neighbor_sample_impl(raft::handle_t const& handle,
if (result_labels) {
cp_result_labels = rmm::device_uvector<label_t>(result_labels->size(), handle.get_stream());

thrust::copy(
handle.get_thrust_policy(),
result_labels->begin(),
result_labels->end(),
cp_result_labels->begin());
thrust::copy(handle.get_thrust_policy(),
result_labels->begin(),
result_labels->end(),
cp_result_labels->begin());
}

std::tie(result_srcs,
result_dsts,
result_weights,
result_edge_ids,
result_edge_types,
result_hops,
result_labels,
result_offsets) = detail::shuffle_and_organize_output(handle,
std::move(result_srcs),
std::move(result_dsts),
std::move(result_weights),
std::move(result_edge_ids),
std::move(result_edge_types),
std::move(result_hops),
std::move(result_labels),
label_to_output_comm_rank);

if (result_labels && (result_offsets->size() != num_unique_labels + 1)) {
result_offsets = rmm::device_uvector<size_t>(num_unique_labels + 1, handle.get_stream());

// Sort labels
thrust::sort(handle.get_thrust_policy(), cp_result_labels->begin(), cp_result_labels->end());

thrust::transform(handle.get_thrust_policy(),
thrust::make_counting_iterator<edge_t>(0),
thrust::make_counting_iterator<edge_t>(result_offsets->size() - 1),
result_offsets->begin() + 1,
[result_labels = raft::device_span<label_t const>(
cp_result_labels->data(), cp_result_labels->size())] __device__(auto idx) {
auto itr_lower = thrust::lower_bound(
thrust::seq, result_labels.begin(), result_labels.end(), idx);

auto itr_upper = thrust::upper_bound(
thrust::seq, result_labels.begin(), result_labels.end(), idx);

auto sampled_label_size = thrust::distance(itr_lower, itr_upper);

// return thrust::distance(itr_lower, itr_upper);
return sampled_label_size;
});

// Run inclusive scan
thrust::inclusive_scan(handle.get_thrust_policy(),
result_offsets->begin() + 1,
result_offsets->end(),
result_offsets->begin() + 1);
}

std::tie(result_srcs, result_dsts, result_weights, result_edge_ids,
result_edge_types, result_hops, result_labels, result_offsets)
= detail::shuffle_and_organize_output(handle,
std::move(result_srcs),
std::move(result_dsts),
std::move(result_weights),
std::move(result_edge_ids),
std::move(result_edge_types),
std::move(result_hops),
std::move(result_labels),
label_to_output_comm_rank);


if (result_labels && (result_offsets->size() != num_unique_labels + 1)) {
result_offsets = rmm::device_uvector<size_t>(num_unique_labels + 1, handle.get_stream());

// Sort labels
thrust::sort(
handle.get_thrust_policy(),
cp_result_labels->begin(),
cp_result_labels->end());


thrust::transform(
handle.get_thrust_policy(),
thrust::make_counting_iterator<edge_t>(0),
thrust::make_counting_iterator<edge_t>(result_offsets->size() - 1),
result_offsets->begin() + 1,
[
result_labels = raft::device_span<label_t const>(
cp_result_labels->data(),
cp_result_labels->size())
] __device__(auto idx) {
auto itr_lower = thrust::lower_bound(
thrust::seq, result_labels.begin(), result_labels.end(), idx);

auto itr_upper = thrust::upper_bound(
thrust::seq, result_labels.begin(), result_labels.end(), idx);

auto sampled_label_size = thrust::distance(itr_lower, itr_upper);

//return thrust::distance(itr_lower, itr_upper);
return sampled_label_size;
});

// Run inclusive scan
thrust::inclusive_scan(handle.get_thrust_policy(),
result_offsets->begin() + 1,
result_offsets->end(),
result_offsets->begin() + 1);
}

return std::make_tuple(std::move(result_srcs),
std::move(result_dsts),
std::move(result_weights),
Expand Down

0 comments on commit 529955a

Please sign in to comment.