Skip to content

Commit

Permalink
update type combination
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Sep 27, 2024
1 parent 4e2c8cf commit 2458149
Show file tree
Hide file tree
Showing 6 changed files with 528 additions and 0 deletions.
88 changes: 88 additions & 0 deletions cpp/src/sampling/neighbor_sampling_mg_v32_e32.cu
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,28 @@ heterogeneous_uniform_neighbor_sample(
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<double>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<size_t>>>
heterogeneous_uniform_neighbor_sample(
raft::handle_t const& handle,
raft::random::RngState& rng_state,
graph_view_t<int32_t, int32_t, false, true> const& graph_view,
std::optional<edge_property_view_t<int32_t, double const*>> edge_weight_view,
std::optional<edge_property_view_t<int32_t, int32_t const*>> edge_id_view,
std::optional<edge_property_view_t<int32_t, int32_t const*>> edge_type_view,
raft::device_span<int32_t const> starting_vertices,
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<raft::device_span<int32_t const>> label_to_output_comm_rank,
raft::host_span<int32_t const> fan_out,
int32_t num_edge_types,
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<double>>,
Expand All @@ -172,6 +194,29 @@ heterogeneous_biased_neighbor_sample(
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<float>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<size_t>>>
heterogeneous_biased_neighbor_sample(
raft::handle_t const& handle,
raft::random::RngState& rng_state,
graph_view_t<int32_t, int32_t, false, true> const& graph_view,
std::optional<edge_property_view_t<int32_t, float const*>> edge_weight_view,
std::optional<edge_property_view_t<int32_t, int32_t const*>> edge_id_view,
std::optional<edge_property_view_t<int32_t, int32_t const*>> edge_type_view,
edge_property_view_t<int32_t, float const*> edge_bias_view,
raft::device_span<int32_t const> starting_vertices,
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<raft::device_span<int32_t const>> label_to_output_comm_rank,
raft::host_span<int32_t const> fan_out,
int32_t num_edge_types,
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<float>>,
Expand All @@ -193,6 +238,27 @@ homogeneous_uniform_neighbor_sample(
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<double>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<size_t>>>
homogeneous_uniform_neighbor_sample(
raft::handle_t const& handle,
raft::random::RngState& rng_state,
graph_view_t<int32_t, int32_t, false, true> const& graph_view,
std::optional<edge_property_view_t<int32_t, double const*>> edge_weight_view,
std::optional<edge_property_view_t<int32_t, int32_t const*>> edge_id_view,
std::optional<edge_property_view_t<int32_t, int32_t const*>> edge_type_view,
raft::device_span<int32_t const> starting_vertices,
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<raft::device_span<int32_t const>> label_to_output_comm_rank,
raft::host_span<int32_t const> fan_out,
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<double>>,
Expand All @@ -215,6 +281,28 @@ homogeneous_biased_neighbor_sample(
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<float>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<size_t>>>
homogeneous_biased_neighbor_sample(
raft::handle_t const& handle,
raft::random::RngState& rng_state,
graph_view_t<int32_t, int32_t, false, true> const& graph_view,
std::optional<edge_property_view_t<int32_t, float const*>> edge_weight_view,
std::optional<edge_property_view_t<int32_t, int32_t const*>> edge_id_view,
std::optional<edge_property_view_t<int32_t, int32_t const*>> edge_type_view,
edge_property_view_t<int32_t, float const*> edge_bias_view,
raft::device_span<int32_t const> starting_vertices,
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<raft::device_span<int32_t const>> label_to_output_comm_rank,
raft::host_span<int32_t const> fan_out,
sampling_flags_t sampling_flags,
bool do_expensive_check);




Expand Down
88 changes: 88 additions & 0 deletions cpp/src/sampling/neighbor_sampling_mg_v32_e64.cu
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,28 @@ heterogeneous_uniform_neighbor_sample(
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<double>>,
std::optional<rmm::device_uvector<int64_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<size_t>>>
heterogeneous_uniform_neighbor_sample(
raft::handle_t const& handle,
raft::random::RngState& rng_state,
graph_view_t<int32_t, int64_t, false, true> const& graph_view,
std::optional<edge_property_view_t<int64_t, double const*>> edge_weight_view,
std::optional<edge_property_view_t<int64_t, int64_t const*>> edge_id_view,
std::optional<edge_property_view_t<int64_t, int32_t const*>> edge_type_view,
raft::device_span<int32_t const> starting_vertices,
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<raft::device_span<int32_t const>> label_to_output_comm_rank,
raft::host_span<int32_t const> fan_out,
int32_t num_edge_types,
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<double>>,
Expand All @@ -173,6 +195,29 @@ heterogeneous_biased_neighbor_sample(
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<float>>,
std::optional<rmm::device_uvector<int64_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<size_t>>>
heterogeneous_biased_neighbor_sample(
raft::handle_t const& handle,
raft::random::RngState& rng_state,
graph_view_t<int32_t, int64_t, false, true> const& graph_view,
std::optional<edge_property_view_t<int64_t, float const*>> edge_weight_view,
std::optional<edge_property_view_t<int64_t, int64_t const*>> edge_id_view,
std::optional<edge_property_view_t<int64_t, int32_t const*>> edge_type_view,
edge_property_view_t<int64_t, float const*> edge_bias_view,
raft::device_span<int32_t const> starting_vertices,
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<raft::device_span<int32_t const>> label_to_output_comm_rank,
raft::host_span<int32_t const> fan_out,
int32_t num_edge_types,
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<float>>,
Expand All @@ -194,6 +239,27 @@ homogeneous_uniform_neighbor_sample(
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<double>>,
std::optional<rmm::device_uvector<int64_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<size_t>>>
homogeneous_uniform_neighbor_sample(
raft::handle_t const& handle,
raft::random::RngState& rng_state,
graph_view_t<int32_t, int64_t, false, true> const& graph_view,
std::optional<edge_property_view_t<int64_t, double const*>> edge_weight_view,
std::optional<edge_property_view_t<int64_t, int64_t const*>> edge_id_view,
std::optional<edge_property_view_t<int64_t, int32_t const*>> edge_type_view,
raft::device_span<int32_t const> starting_vertices,
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<raft::device_span<int32_t const>> label_to_output_comm_rank,
raft::host_span<int32_t const> fan_out,
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<double>>,
Expand All @@ -216,4 +282,26 @@ homogeneous_biased_neighbor_sample(
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int32_t>,
rmm::device_uvector<int32_t>,
std::optional<rmm::device_uvector<float>>,
std::optional<rmm::device_uvector<int64_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<size_t>>>
homogeneous_biased_neighbor_sample(
raft::handle_t const& handle,
raft::random::RngState& rng_state,
graph_view_t<int32_t, int64_t, false, true> const& graph_view,
std::optional<edge_property_view_t<int64_t, float const*>> edge_weight_view,
std::optional<edge_property_view_t<int64_t, int64_t const*>> edge_id_view,
std::optional<edge_property_view_t<int64_t, int32_t const*>> edge_type_view,
edge_property_view_t<int64_t, float const*> edge_bias_view,
raft::device_span<int32_t const> starting_vertices,
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<raft::device_span<int32_t const>> label_to_output_comm_rank,
raft::host_span<int32_t const> fan_out,
sampling_flags_t sampling_flags,
bool do_expensive_check);

} // namespace cugraph
88 changes: 88 additions & 0 deletions cpp/src/sampling/neighbor_sampling_mg_v64_e64.cu
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,28 @@ heterogeneous_uniform_neighbor_sample(
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
std::optional<rmm::device_uvector<double>>,
std::optional<rmm::device_uvector<int64_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<size_t>>>
heterogeneous_uniform_neighbor_sample(
raft::handle_t const& handle,
raft::random::RngState& rng_state,
graph_view_t<int64_t, int64_t, false, true> const& graph_view,
std::optional<edge_property_view_t<int64_t, double const*>> edge_weight_view,
std::optional<edge_property_view_t<int64_t, int64_t const*>> edge_id_view,
std::optional<edge_property_view_t<int64_t, int32_t const*>> edge_type_view,
raft::device_span<int64_t const> starting_vertices,
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<raft::device_span<int32_t const>> label_to_output_comm_rank,
raft::host_span<int32_t const> fan_out,
int32_t num_edge_types,
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
std::optional<rmm::device_uvector<double>>,
Expand All @@ -172,6 +194,29 @@ heterogeneous_biased_neighbor_sample(
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
std::optional<rmm::device_uvector<float>>,
std::optional<rmm::device_uvector<int64_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<size_t>>>
heterogeneous_biased_neighbor_sample(
raft::handle_t const& handle,
raft::random::RngState& rng_state,
graph_view_t<int64_t, int64_t, false, true> const& graph_view,
std::optional<edge_property_view_t<int64_t, float const*>> edge_weight_view,
std::optional<edge_property_view_t<int64_t, int64_t const*>> edge_id_view,
std::optional<edge_property_view_t<int64_t, int32_t const*>> edge_type_view,
edge_property_view_t<int64_t, float const*> edge_bias_view,
raft::device_span<int64_t const> starting_vertices,
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<raft::device_span<int32_t const>> label_to_output_comm_rank,
raft::host_span<int32_t const> fan_out,
int32_t num_edge_types,
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
std::optional<rmm::device_uvector<float>>,
Expand All @@ -193,6 +238,27 @@ homogeneous_uniform_neighbor_sample(
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
std::optional<rmm::device_uvector<double>>,
std::optional<rmm::device_uvector<int64_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<size_t>>>
homogeneous_uniform_neighbor_sample(
raft::handle_t const& handle,
raft::random::RngState& rng_state,
graph_view_t<int64_t, int64_t, false, true> const& graph_view,
std::optional<edge_property_view_t<int64_t, double const*>> edge_weight_view,
std::optional<edge_property_view_t<int64_t, int64_t const*>> edge_id_view,
std::optional<edge_property_view_t<int64_t, int32_t const*>> edge_type_view,
raft::device_span<int64_t const> starting_vertices,
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<raft::device_span<int32_t const>> label_to_output_comm_rank,
raft::host_span<int32_t const> fan_out,
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
std::optional<rmm::device_uvector<double>>,
Expand All @@ -215,4 +281,26 @@ homogeneous_biased_neighbor_sample(
sampling_flags_t sampling_flags,
bool do_expensive_check);

template std::tuple<rmm::device_uvector<int64_t>,
rmm::device_uvector<int64_t>,
std::optional<rmm::device_uvector<float>>,
std::optional<rmm::device_uvector<int64_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<int32_t>>,
std::optional<rmm::device_uvector<size_t>>>
homogeneous_biased_neighbor_sample(
raft::handle_t const& handle,
raft::random::RngState& rng_state,
graph_view_t<int64_t, int64_t, false, true> const& graph_view,
std::optional<edge_property_view_t<int64_t, float const*>> edge_weight_view,
std::optional<edge_property_view_t<int64_t, int64_t const*>> edge_id_view,
std::optional<edge_property_view_t<int64_t, int32_t const*>> edge_type_view,
edge_property_view_t<int64_t, float const*> edge_bias_view,
raft::device_span<int64_t const> starting_vertices,
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<raft::device_span<int32_t const>> label_to_output_comm_rank,
raft::host_span<int32_t const> fan_out,
sampling_flags_t sampling_flags,
bool do_expensive_check);

} // namespace cugraph
Loading

0 comments on commit 2458149

Please sign in to comment.