diff --git a/cpp/include/cugraph/sampling_functions.hpp b/cpp/include/cugraph/sampling_functions.hpp index 5b884fa611e..7acda3e8d4c 100644 --- a/cpp/include/cugraph/sampling_functions.hpp +++ b/cpp/include/cugraph/sampling_functions.hpp @@ -130,7 +130,8 @@ uniform_neighbor_sample( std::optional, raft::device_span>> label_to_output_comm_rank, std::optional> fan_out, - std::optional, raft::host_span>> heterogeneous_fan_out, + std::optional, raft::host_span>> + heterogeneous_fan_out, raft::random::RngState& rng_state, bool return_hops, bool with_replacement = true, @@ -211,7 +212,7 @@ uniform_neighbor_sample( * optional int32_t hop, optional label_t label, optional size_t offsets) */ -// tuple with 3 elements. 1 - edge_type, 2- host span of size_t, 3 - fanout vector +// tuple with 3 elements. 1 - edge_type, 2- host span of size_t, 3 - fanout vector template #include @@ -113,7 +114,8 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct fan_out_( reinterpret_cast(fan_out)), heterogeneous_fan_out_( - reinterpret_cast(heterogeneous_fan_out)), + reinterpret_cast( + heterogeneous_fan_out)), rng_state_(reinterpret_cast(rng_state)), options_(options), do_expensive_check_(do_expensive_check) @@ -221,16 +223,16 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct raft::device_span{label_to_comm_rank_->as_type(), label_to_comm_rank_->size_})) : std::nullopt, - (fan_out_ != nullptr) - ? std::make_optional>(fan_out_->as_type(), fan_out_->size_) - : std::nullopt, - + (fan_out_ != nullptr) ? std::make_optional>( + fan_out_->as_type(), fan_out_->size_) + : std::nullopt, + (heterogeneous_fan_out_ != nullptr) ? std::make_optional(std::make_tuple( - raft::host_span{heterogeneous_fan_out_->edge_type_id->as_type(), - heterogeneous_fan_out_->edge_type_id->size_}, - raft::host_span{heterogeneous_fan_out_->fanout->as_type(), - heterogeneous_fan_out_->fanout->size_})) + raft::host_span{heterogeneous_fan_out_->edge_type_id->as_type(), + heterogeneous_fan_out_->edge_type_id->size_}, + raft::host_span{heterogeneous_fan_out_->fanout->as_type(), + heterogeneous_fan_out_->fanout->size_})) : std::nullopt, rng_state_->rng_state_, options_.return_hops_, @@ -769,58 +771,61 @@ struct biased_neighbor_sampling_functor : public cugraph::c_api::abstract_functo } }; - struct create_heterogeneous_fanout_functor : public cugraph::c_api::abstract_functor { - raft::handle_t const& handle_; - cugraph::c_api::cugraph_graph_t* graph_; - cugraph::c_api::cugraph_type_erased_host_array_view_t const* edge_type_size_; - cugraph::c_api::cugraph_type_erased_host_array_view_t const* fanout_; - // FIXME: This type doesn't exist: instead create an 'std::tuple' - cugraph::c_api::cugraph_sample_heterogeneous_fanout_t* result_{}; - - create_heterogeneous_fanout_functor(::cugraph_resource_handle_t const* handle, - ::cugraph_graph_t* graph, - ::cugraph_type_erased_host_array_view_t const* edge_type_size, - ::cugraph_type_erased_host_array_view_t const* fanout) - : abstract_functor(), - handle_(*reinterpret_cast(handle)->handle_), - graph_(reinterpret_cast(graph)), - edge_type_size_( - reinterpret_cast(edge_type_size)), - fanout_( - reinterpret_cast(fanout)) - { - } +struct create_heterogeneous_fanout_functor : public cugraph::c_api::abstract_functor { + raft::handle_t const& handle_; + cugraph::c_api::cugraph_graph_t* graph_; + cugraph::c_api::cugraph_type_erased_host_array_view_t const* edge_type_size_; + cugraph::c_api::cugraph_type_erased_host_array_view_t const* fanout_; + // FIXME: This type doesn't exist: instead create an + // 'std::tuple' + cugraph::c_api::cugraph_sample_heterogeneous_fanout_t* result_{}; + + create_heterogeneous_fanout_functor(::cugraph_resource_handle_t const* handle, + ::cugraph_graph_t* graph, + ::cugraph_type_erased_host_array_view_t const* edge_type_size, + ::cugraph_type_erased_host_array_view_t const* fanout) + : abstract_functor(), + handle_(*reinterpret_cast(handle)->handle_), + graph_(reinterpret_cast(graph)), + edge_type_size_( + reinterpret_cast( + edge_type_size)), + fanout_( + reinterpret_cast(fanout)) + { + } - template - void operator()() - { - // FIXME: Remove this check as it is not necessary - if constexpr (!cugraph::is_candidate::value) { - unsupported(); - } else { - std::vector edge_type_size_copy{(int32_t) edge_type_size_->size_}; - std::vector fanout_copy{(int32_t) fanout_->size_}; - - raft::copy( - edge_type_size_copy.data(), edge_type_size_->as_type(), edge_type_size_->size_, handle_.get_stream()); - - raft::copy( - fanout_copy.data(), fanout_->as_type(), fanout_->size_, handle_.get_stream()); - - // std::tuple (template) - // result_ = new std::tuple