Skip to content

Commit

Permalink
leverage new constructor and remove unnecessary code
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Aug 21, 2024
1 parent 9f455bf commit d114534
Showing 1 changed file with 14 additions and 59 deletions.
73 changes: 14 additions & 59 deletions cpp/src/c_api/neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,10 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {
// FIXME: Consolidate 'fan_out_' and 'heterogeneous_fan_out_' into one
// argument with std::variant

//cugraph_type_erased_host_array_t* x = std::get<0>(heterogeneous_fan_out_);

//cugraph::c_api::cugraph_type_erased_host_array_t* x = std::get<0>(*heterogeneous_fan_out_);

auto&& [src, dst, wgt, edge_id, edge_type, hop, edge_label, offsets] =
cugraph::neighbor_sample(
handle_,
Expand All @@ -245,16 +249,14 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {
(fan_out_ != nullptr) ? std::make_optional<raft::host_span<const int>>(
fan_out_->as_type<const int>(), fan_out_->size_)
: std::nullopt,
/*

(heterogeneous_fan_out_ != nullptr)
? std::make_optional(std::make_tuple(
raft::host_span<const int>{heterogeneous_fan_out_->edge_type_offsets->as_type<int>(),
heterogeneous_fan_out_->edge_type_offsets->size_},
raft::host_span<const int>{heterogeneous_fan_out_->fanout->as_type<int>(),
heterogeneous_fan_out_->fanout->size_}))
raft::host_span<const int>{std::get<0>(*heterogeneous_fan_out_)->as_type<int>(),
std::get<0>(*heterogeneous_fan_out_)->size_},
raft::host_span<const int>{std::get<1>(*heterogeneous_fan_out_)->as_type<int>(),
std::get<1>(*heterogeneous_fan_out_)->size_}))
: std::nullopt,
*/
std::nullopt,
options_.return_hops_,
options_.with_replacement_,
options_.prior_sources_behavior_,
Expand Down Expand Up @@ -441,56 +443,6 @@ struct neighbor_sampling_functor : public cugraph::c_api::abstract_functor {
}
};


struct create_heterogeneous_fanout_functor : public cugraph::c_api::abstract_functor {
raft::handle_t const& handle_;
cugraph::c_api::cugraph_graph_t* graph_;
cugraph::c_api::cugraph_type_erased_host_array_view_t const* edge_type_offsets_;
cugraph::c_api::cugraph_type_erased_host_array_view_t const* fanout_;
cugraph::c_api::cugraph_sample_heterogeneous_fanout_t* result_{};
create_heterogeneous_fanout_functor(::cugraph_resource_handle_t const* handle,
::cugraph_graph_t* graph,
::cugraph_type_erased_host_array_view_t const* edge_type_offsets,
::cugraph_type_erased_host_array_view_t const* fanout)
: abstract_functor(),
handle_(*reinterpret_cast<cugraph::c_api::cugraph_resource_handle_t const*>(handle)->handle_),
graph_(reinterpret_cast<cugraph::c_api::cugraph_graph_t*>(graph)),
edge_type_offsets_(
reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(
edge_type_offsets)),
fanout_(
reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(fanout))
{
}

template <typename vertex_t,
typename edge_t,
typename weight_t,
typename edge_type_type_t,
bool store_transposed,
bool multi_gpu>
void operator()()
{

std::vector<int32_t> edge_type_offsets_copy{(int32_t)edge_type_offsets_->size_};
std::vector<int32_t> fanout_copy{(int32_t)fanout_->size_};

raft::copy(edge_type_offsets_copy.data(),
edge_type_offsets_->as_type<int32_t>(),
edge_type_offsets_->size_,
handle_.get_stream());

raft::copy(
fanout_copy.data(), fanout_->as_type<int32_t>(), fanout_->size_, handle_.get_stream());

auto result_tuple = std::make_tuple(
new cugraph::c_api::cugraph_type_erased_host_array_t(edge_type_offsets_copy, INT32),
new cugraph::c_api::cugraph_type_erased_host_array_t(fanout_copy, INT32)
);

result_ = &result_tuple;
}
};
} // namespace

extern "C" cugraph_error_code_t cugraph_sampling_options_create(
Expand Down Expand Up @@ -1224,7 +1176,10 @@ extern "C" cugraph_error_code_t cugraph_create_heterogeneous_fanout(
cugraph_sample_heterogeneous_fanout_t** heterogeneous_fanout,
cugraph_error_t** error)
{
create_heterogeneous_fanout_functor functor(handle, graph, edge_type_offsets, fanout);

return cugraph::c_api::run_algorithm(graph, functor, heterogeneous_fanout, error);
*heterogeneous_fanout = reinterpret_cast<cugraph_sample_heterogeneous_fanout_t*> (new std::tuple<cugraph::c_api::cugraph_type_erased_host_array_t *, cugraph::c_api::cugraph_type_erased_host_array_t*> {
new cugraph::c_api::cugraph_type_erased_host_array_t(reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(edge_type_offsets)),
new cugraph::c_api::cugraph_type_erased_host_array_t(reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(fanout))});

return CUGRAPH_SUCCESS;
}

0 comments on commit d114534

Please sign in to comment.