From 2458149859f6e666c6742d5ba91e51960a8c54da Mon Sep 17 00:00:00 2001 From: jnke2016 Date: Fri, 27 Sep 2024 15:13:01 -0700 Subject: [PATCH] update type combination --- .../sampling/neighbor_sampling_mg_v32_e32.cu | 88 +++++++++++++++++++ .../sampling/neighbor_sampling_mg_v32_e64.cu | 88 +++++++++++++++++++ .../sampling/neighbor_sampling_mg_v64_e64.cu | 88 +++++++++++++++++++ .../sampling/neighbor_sampling_sg_v32_e32.cu | 88 +++++++++++++++++++ .../sampling/neighbor_sampling_sg_v32_e64.cu | 88 +++++++++++++++++++ .../sampling/neighbor_sampling_sg_v64_e64.cu | 88 +++++++++++++++++++ 6 files changed, 528 insertions(+) diff --git a/cpp/src/sampling/neighbor_sampling_mg_v32_e32.cu b/cpp/src/sampling/neighbor_sampling_mg_v32_e32.cu index ea2444c1e3c..b96a73b722d 100644 --- a/cpp/src/sampling/neighbor_sampling_mg_v32_e32.cu +++ b/cpp/src/sampling/neighbor_sampling_mg_v32_e32.cu @@ -149,6 +149,28 @@ heterogeneous_uniform_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +heterogeneous_uniform_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + int32_t num_edge_types, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -172,6 +194,29 @@ heterogeneous_biased_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +heterogeneous_biased_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + int32_t num_edge_types, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -193,6 +238,27 @@ homogeneous_uniform_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +homogeneous_uniform_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -215,6 +281,28 @@ homogeneous_biased_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +homogeneous_biased_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + sampling_flags_t sampling_flags, + bool do_expensive_check); + diff --git a/cpp/src/sampling/neighbor_sampling_mg_v32_e64.cu b/cpp/src/sampling/neighbor_sampling_mg_v32_e64.cu index 846fb33eb21..2a387405900 100644 --- a/cpp/src/sampling/neighbor_sampling_mg_v32_e64.cu +++ b/cpp/src/sampling/neighbor_sampling_mg_v32_e64.cu @@ -150,6 +150,28 @@ heterogeneous_uniform_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +heterogeneous_uniform_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + int32_t num_edge_types, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -173,6 +195,29 @@ heterogeneous_biased_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +heterogeneous_biased_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + int32_t num_edge_types, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -194,6 +239,27 @@ homogeneous_uniform_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +homogeneous_uniform_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -216,4 +282,26 @@ homogeneous_biased_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +homogeneous_biased_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + sampling_flags_t sampling_flags, + bool do_expensive_check); + } // namespace cugraph diff --git a/cpp/src/sampling/neighbor_sampling_mg_v64_e64.cu b/cpp/src/sampling/neighbor_sampling_mg_v64_e64.cu index 72eec5f6782..505deec51f5 100644 --- a/cpp/src/sampling/neighbor_sampling_mg_v64_e64.cu +++ b/cpp/src/sampling/neighbor_sampling_mg_v64_e64.cu @@ -149,6 +149,28 @@ heterogeneous_uniform_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +heterogeneous_uniform_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + int32_t num_edge_types, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -172,6 +194,29 @@ heterogeneous_biased_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +heterogeneous_biased_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + int32_t num_edge_types, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -193,6 +238,27 @@ homogeneous_uniform_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +homogeneous_uniform_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -215,4 +281,26 @@ homogeneous_biased_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +homogeneous_biased_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + sampling_flags_t sampling_flags, + bool do_expensive_check); + } // namespace cugraph diff --git a/cpp/src/sampling/neighbor_sampling_sg_v32_e32.cu b/cpp/src/sampling/neighbor_sampling_sg_v32_e32.cu index 46db410f5ae..5cb3fd75ce6 100644 --- a/cpp/src/sampling/neighbor_sampling_sg_v32_e32.cu +++ b/cpp/src/sampling/neighbor_sampling_sg_v32_e32.cu @@ -149,6 +149,28 @@ heterogeneous_uniform_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +heterogeneous_uniform_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + int32_t num_edge_types, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -172,6 +194,29 @@ heterogeneous_biased_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +heterogeneous_biased_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + int32_t num_edge_types, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -193,6 +238,27 @@ homogeneous_uniform_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +homogeneous_uniform_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -215,4 +281,26 @@ homogeneous_biased_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +homogeneous_biased_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + sampling_flags_t sampling_flags, + bool do_expensive_check); + } // namespace cugraph diff --git a/cpp/src/sampling/neighbor_sampling_sg_v32_e64.cu b/cpp/src/sampling/neighbor_sampling_sg_v32_e64.cu index 93c0fe176f4..47cb064cd37 100644 --- a/cpp/src/sampling/neighbor_sampling_sg_v32_e64.cu +++ b/cpp/src/sampling/neighbor_sampling_sg_v32_e64.cu @@ -149,6 +149,28 @@ heterogeneous_uniform_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +heterogeneous_uniform_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + int32_t num_edge_types, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -172,6 +194,29 @@ heterogeneous_biased_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +heterogeneous_biased_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + int32_t num_edge_types, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -193,6 +238,27 @@ homogeneous_uniform_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +homogeneous_uniform_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -215,4 +281,26 @@ homogeneous_biased_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +homogeneous_biased_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + sampling_flags_t sampling_flags, + bool do_expensive_check); + } // namespace cugraph diff --git a/cpp/src/sampling/neighbor_sampling_sg_v64_e64.cu b/cpp/src/sampling/neighbor_sampling_sg_v64_e64.cu index c3764aa0852..6aa8c71429a 100644 --- a/cpp/src/sampling/neighbor_sampling_sg_v64_e64.cu +++ b/cpp/src/sampling/neighbor_sampling_sg_v64_e64.cu @@ -149,6 +149,28 @@ heterogeneous_uniform_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +heterogeneous_uniform_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + int32_t num_edge_types, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -172,6 +194,29 @@ heterogeneous_biased_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +heterogeneous_biased_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + int32_t num_edge_types, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -193,6 +238,27 @@ homogeneous_uniform_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +homogeneous_uniform_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + sampling_flags_t sampling_flags, + bool do_expensive_check); + template std::tuple, rmm::device_uvector, std::optional>, @@ -215,4 +281,26 @@ homogeneous_biased_neighbor_sample( sampling_flags_t sampling_flags, bool do_expensive_check); +template std::tuple, + rmm::device_uvector, + std::optional>, + std::optional>, + std::optional>, + std::optional>, + std::optional>> +homogeneous_biased_neighbor_sample( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> edge_weight_view, + std::optional> edge_id_view, + std::optional> edge_type_view, + edge_property_view_t edge_bias_view, + raft::device_span starting_vertices, + std::optional> starting_vertex_labels, + std::optional> label_to_output_comm_rank, + raft::host_span fan_out, + sampling_flags_t sampling_flags, + bool do_expensive_check); + } // namespace cugraph