Skip to content

Commit

Permalink
add arguments and type check
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Aug 23, 2024
1 parent 1e0ef27 commit de79620
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions cpp/src/c_api/neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ cugraph_error_code_t cugraph_neighbor_sample(
const cugraph_type_erased_device_array_view_t* label_to_comm_rank,
const cugraph_type_erased_device_array_view_t* label_offsets,
const cugraph_type_erased_host_array_view_t* fan_out,
const cugraph_sample_heterogeneous_fanout_t* heterogeneous_fanout,
const cugraph_sample_heterogeneous_fanout_t* heterogeneous_fan_out,
const cugraph_sampling_options_t* options,
bool_t is_biased,
bool_t do_expensive_check,
Expand Down Expand Up @@ -1047,6 +1047,36 @@ cugraph_error_code_t cugraph_neighbor_sample(
CUGRAPH_INVALID_INPUT,
"cannot specify label_to_comm_rank unless label_list is also specified",
*error);

CAPI_EXPECTS(!((fan_out != nullptr) && (heterogeneous_fan_out != nullptr)),
CUGRAPH_INVALID_INPUT,
"cannot specify both fan_out and heterogeneous_fan_out",
*error);

if (fan_out != nullptr) {
CAPI_EXPECTS(reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(
fan_out)
->type_ == INT32,
CUGRAPH_INVALID_INPUT,
"fan_out type must be INT32",
*error);

} else {

CAPI_EXPECTS(reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(
std::get<0>(*reinterpret_cast<cugraph::c_api::cugraph_sample_heterogeneous_fanout_t const*>(heterogeneous_fan_out)))
->type_ == INT32,
CUGRAPH_INVALID_INPUT,
"edge type offsets type must be INT32",
*error);

CAPI_EXPECTS(reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(
std::get<0>(*reinterpret_cast<cugraph::c_api::cugraph_sample_heterogeneous_fanout_t const*>(heterogeneous_fan_out)))
->type_ == INT32,
CUGRAPH_INVALID_INPUT,
"fan_out values type must be INT32",
*error);
}

CAPI_EXPECTS(reinterpret_cast<cugraph::c_api::cugraph_graph_t*>(graph)->vertex_type_ ==
reinterpret_cast<cugraph::c_api::cugraph_type_erased_device_array_view_t const*>(
Expand All @@ -1056,12 +1086,6 @@ cugraph_error_code_t cugraph_neighbor_sample(
"vertex type of graph and start_vertices must match",
*error);

CAPI_EXPECTS(
reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(fan_out)
->type_ == INT32,
CUGRAPH_INVALID_INPUT,
"fan_out should be of type int",
*error);

neighbor_sampling_functor functor{handle,
rng_state,
Expand All @@ -1073,7 +1097,7 @@ cugraph_error_code_t cugraph_neighbor_sample(
label_to_comm_rank,
label_offsets,
fan_out,
heterogeneous_fanout,
heterogeneous_fan_out,
std::move(options_cpp),
is_biased,
do_expensive_check};
Expand Down Expand Up @@ -1166,14 +1190,14 @@ extern "C" cugraph_error_code_t cugraph_create_heterogeneous_fanout(
const cugraph_resource_handle_t* handle,
cugraph_graph_t* graph,
const cugraph_type_erased_host_array_view_t* edge_type_offsets,
const cugraph_type_erased_host_array_view_t* fanout,
cugraph_sample_heterogeneous_fanout_t** heterogeneous_fanout,
const cugraph_type_erased_host_array_view_t* fan_out,
cugraph_sample_heterogeneous_fanout_t** heterogeneous_fan_out,
cugraph_error_t** error)
{

*heterogeneous_fanout = reinterpret_cast<cugraph_sample_heterogeneous_fanout_t*> (new std::tuple<cugraph::c_api::cugraph_type_erased_host_array_t *, cugraph::c_api::cugraph_type_erased_host_array_t*> {
*heterogeneous_fan_out = reinterpret_cast<cugraph_sample_heterogeneous_fanout_t*> (new std::tuple<cugraph::c_api::cugraph_type_erased_host_array_t *, cugraph::c_api::cugraph_type_erased_host_array_t*> {
new cugraph::c_api::cugraph_type_erased_host_array_t(reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(edge_type_offsets)),
new cugraph::c_api::cugraph_type_erased_host_array_t(reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(fanout))});
new cugraph::c_api::cugraph_type_erased_host_array_t(reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(fan_out))});

return CUGRAPH_SUCCESS;
}

0 comments on commit de79620

Please sign in to comment.