Skip to content

Commit

Permalink
MG uniform random walk implementation (#2585)
Browse files Browse the repository at this point in the history
This PR defines a uniform random walk implementation using the neighborhood sampling functions.

This will be refactored once the new sampling primitive (#2580) is implemented, but should provide a stronger starting point than the original code.

Partially addresses #2555

Authors:
  - Chuck Hastings (https://github.com/ChuckHastings)

Approvers:
  - Seunghwa Kang (https://github.com/seunghwak)

URL: #2585
  • Loading branch information
ChuckHastings authored Sep 14, 2022
1 parent 4657c68 commit 5863be2
Show file tree
Hide file tree
Showing 16 changed files with 1,348 additions and 418 deletions.
3 changes: 2 additions & 1 deletion cpp/src/sampling/detail/graph_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ gather_local_edges(
const rmm::device_uvector<typename GraphViewType::vertex_type>& active_majors,
rmm::device_uvector<typename GraphViewType::edge_type>&& minor_map,
typename GraphViewType::edge_type indices_per_major,
const rmm::device_uvector<typename GraphViewType::edge_type>& global_degree_offsets);
const rmm::device_uvector<typename GraphViewType::edge_type>& global_degree_offsets,
bool remove_invalid_vertices = true);

/**
* @brief Gather edge list for specified vertices
Expand Down
82 changes: 44 additions & 38 deletions cpp/src/sampling/detail/sampling_utils_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,8 @@ gather_local_edges(
const rmm::device_uvector<typename GraphViewType::vertex_type>& active_majors,
rmm::device_uvector<typename GraphViewType::edge_type>&& minor_map,
typename GraphViewType::edge_type indices_per_major,
const rmm::device_uvector<typename GraphViewType::edge_type>& global_degree_offsets)
const rmm::device_uvector<typename GraphViewType::edge_type>& global_degree_offsets,
bool remove_invalid_vertices)
{
using vertex_t = typename GraphViewType::vertex_type;
using edge_t = typename GraphViewType::edge_type;
Expand Down Expand Up @@ -417,6 +418,7 @@ gather_local_edges(
}
} else {
minors[index] = invalid_vertex_id;
if (weights != nullptr) { weights[index] = weight_t{0}; }
}
});
} else {
Expand Down Expand Up @@ -485,52 +487,56 @@ gather_local_edges(
edge_index_first[index] = g_dst_index;
} else {
minors[index] = invalid_vertex_id;
if (weights != nullptr) { weights[index] = weight_t{0}; }
}
});
}

if (weights) {
auto input_iter = thrust::make_zip_iterator(
thrust::make_tuple(majors.begin(), minors.begin(), weights->begin()));
if (remove_invalid_vertices) {
if (weights) {
auto input_iter = thrust::make_zip_iterator(
thrust::make_tuple(majors.begin(), minors.begin(), weights->begin()));

CUGRAPH_EXPECTS(minors.size() < static_cast<size_t>(std::numeric_limits<int32_t>::max()),
"remove_if will fail, minors.size() is too large");
CUGRAPH_EXPECTS(minors.size() < std::numeric_limits<int32_t>::max(),
"remove_if will fail, minors.size() is too large");

// FIXME: remove_if has a 32-bit overflow issue (https://github.com/NVIDIA/thrust/issues/1302)
// Seems unlikely here (the goal of sampling is to extract small graphs)
// so not going to work around this for now.
auto compacted_length = thrust::distance(
input_iter,
thrust::remove_if(
handle.get_thrust_policy(),
// FIXME: remove_if has a 32-bit overflow issue
// (https://github.com/NVIDIA/thrust/issues/1302). Overflow seems unlikely here (the goal of
// sampling is to extract small graphs), so it is not worked around for now.
auto compacted_length = thrust::distance(
input_iter,
input_iter + minors.size(),
minors.begin(),
[invalid_vertex_id] __device__(auto dst) { return (dst == invalid_vertex_id); }));
thrust::remove_if(
handle.get_thrust_policy(),
input_iter,
input_iter + minors.size(),
minors.begin(),
[invalid_vertex_id] __device__(auto dst) { return (dst == invalid_vertex_id); }));

majors.resize(compacted_length, handle.get_stream());
minors.resize(compacted_length, handle.get_stream());
weights->resize(compacted_length, handle.get_stream());
} else {
auto input_iter =
thrust::make_zip_iterator(thrust::make_tuple(majors.begin(), minors.begin()));

majors.resize(compacted_length, handle.get_stream());
minors.resize(compacted_length, handle.get_stream());
weights->resize(compacted_length, handle.get_stream());
} else {
auto input_iter = thrust::make_zip_iterator(thrust::make_tuple(majors.begin(), minors.begin()));

CUGRAPH_EXPECTS(minors.size() < static_cast<size_t>(std::numeric_limits<int32_t>::max()),
"remove_if will fail, minors.size() is too large");

auto compacted_length = thrust::distance(
input_iter,
// FIXME: remove_if has a 32-bit overflow issue (https://github.com/NVIDIA/thrust/issues/1302)
// Seems unlikely here (the goal of sampling is to extract small graphs)
// so not going to work around this for now.
thrust::remove_if(
handle.get_thrust_policy(),
input_iter,
input_iter + minors.size(),
minors.begin(),
[invalid_vertex_id] __device__(auto dst) { return (dst == invalid_vertex_id); }));
CUGRAPH_EXPECTS(minors.size() < std::numeric_limits<int32_t>::max(),
"remove_if will fail, minors.size() is too large");

majors.resize(compacted_length, handle.get_stream());
minors.resize(compacted_length, handle.get_stream());
auto compacted_length = thrust::distance(
input_iter,
// FIXME: remove_if has a 32-bit overflow issue
// (https://github.com/NVIDIA/thrust/issues/1302). Overflow seems unlikely here (the goal of
// sampling is to extract small graphs), so it is not worked around for now.
thrust::remove_if(
handle.get_thrust_policy(),
input_iter,
input_iter + minors.size(),
minors.begin(),
[invalid_vertex_id] __device__(auto dst) { return (dst == invalid_vertex_id); }));

majors.resize(compacted_length, handle.get_stream());
minors.resize(compacted_length, handle.get_stream());
}
}

return std::make_tuple(std::move(majors), std::move(minors), std::move(weights));
Expand Down
18 changes: 12 additions & 6 deletions cpp/src/sampling/detail/sampling_utils_mg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ gather_local_edges(raft::handle_t const& handle,
const rmm::device_uvector<int32_t>& active_majors,
rmm::device_uvector<int32_t>&& minor_map,
int32_t indices_per_major,
const rmm::device_uvector<int32_t>& global_degree_offsets);
const rmm::device_uvector<int32_t>& global_degree_offsets,
bool remove_invalid_vertices);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
Expand All @@ -160,7 +161,8 @@ gather_local_edges(raft::handle_t const& handle,
const rmm::device_uvector<int32_t>& active_majors,
rmm::device_uvector<int64_t>&& minor_map,
int64_t indices_per_major,
const rmm::device_uvector<int64_t>& global_degree_offsets);
const rmm::device_uvector<int64_t>& global_degree_offsets,
bool remove_invalid_vertices);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
Expand All @@ -170,7 +172,8 @@ gather_local_edges(raft::handle_t const& handle,
const rmm::device_uvector<int64_t>& active_majors,
rmm::device_uvector<int64_t>&& minor_map,
int64_t indices_per_major,
const rmm::device_uvector<int64_t>& global_degree_offsets);
const rmm::device_uvector<int64_t>& global_degree_offsets,
bool remove_invalid_vertices);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
Expand All @@ -180,7 +183,8 @@ gather_local_edges(raft::handle_t const& handle,
const rmm::device_uvector<int32_t>& active_majors,
rmm::device_uvector<int32_t>&& minor_map,
int32_t indices_per_major,
const rmm::device_uvector<int32_t>& global_degree_offsets);
const rmm::device_uvector<int32_t>& global_degree_offsets,
bool remove_invalid_vertices);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
Expand All @@ -190,7 +194,8 @@ gather_local_edges(raft::handle_t const& handle,
const rmm::device_uvector<int32_t>& active_majors,
rmm::device_uvector<int64_t>&& minor_map,
int64_t indices_per_major,
const rmm::device_uvector<int64_t>& global_degree_offsets);
const rmm::device_uvector<int64_t>& global_degree_offsets,
bool remove_invalid_vertices);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
Expand All @@ -200,7 +205,8 @@ gather_local_edges(raft::handle_t const& handle,
const rmm::device_uvector<int64_t>& active_majors,
rmm::device_uvector<int64_t>&& minor_map,
int64_t indices_per_major,
const rmm::device_uvector<int64_t>& global_degree_offsets);
const rmm::device_uvector<int64_t>& global_degree_offsets,
bool remove_invalid_vertices);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
Expand Down
18 changes: 12 additions & 6 deletions cpp/src/sampling/detail/sampling_utils_sg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ gather_local_edges(raft::handle_t const& handle,
const rmm::device_uvector<int32_t>& active_majors,
rmm::device_uvector<int32_t>&& minor_map,
int32_t indices_per_major,
const rmm::device_uvector<int32_t>& global_degree_offsets);
const rmm::device_uvector<int32_t>& global_degree_offsets,
bool remove_invalid_vertices);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
Expand All @@ -103,7 +104,8 @@ gather_local_edges(raft::handle_t const& handle,
const rmm::device_uvector<int32_t>& active_majors,
rmm::device_uvector<int64_t>&& minor_map,
int64_t indices_per_major,
const rmm::device_uvector<int64_t>& global_degree_offsets);
const rmm::device_uvector<int64_t>& global_degree_offsets,
bool remove_invalid_vertices);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
Expand All @@ -113,7 +115,8 @@ gather_local_edges(raft::handle_t const& handle,
const rmm::device_uvector<int64_t>& active_majors,
rmm::device_uvector<int64_t>&& minor_map,
int64_t indices_per_major,
const rmm::device_uvector<int64_t>& global_degree_offsets);
const rmm::device_uvector<int64_t>& global_degree_offsets,
bool remove_invalid_vertices);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
Expand All @@ -123,7 +126,8 @@ gather_local_edges(raft::handle_t const& handle,
const rmm::device_uvector<int32_t>& active_majors,
rmm::device_uvector<int32_t>&& minor_map,
int32_t indices_per_major,
const rmm::device_uvector<int32_t>& global_degree_offsets);
const rmm::device_uvector<int32_t>& global_degree_offsets,
bool remove_invalid_vertices);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
Expand All @@ -133,7 +137,8 @@ gather_local_edges(raft::handle_t const& handle,
const rmm::device_uvector<int32_t>& active_majors,
rmm::device_uvector<int64_t>&& minor_map,
int64_t indices_per_major,
const rmm::device_uvector<int64_t>& global_degree_offsets);
const rmm::device_uvector<int64_t>& global_degree_offsets,
bool remove_invalid_vertices);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
Expand All @@ -143,7 +148,8 @@ gather_local_edges(raft::handle_t const& handle,
const rmm::device_uvector<int64_t>& active_majors,
rmm::device_uvector<int64_t>&& minor_map,
int64_t indices_per_major,
const rmm::device_uvector<int64_t>& global_degree_offsets);
const rmm::device_uvector<int64_t>& global_degree_offsets,
bool remove_invalid_vertices);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
Expand Down
Loading

0 comments on commit 5863be2

Please sign in to comment.