Skip to content

Commit

Permalink
fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
jnke2016 committed Aug 13, 2024
1 parent bb5a3e2 commit 10fa86d
Show file tree
Hide file tree
Showing 14 changed files with 149 additions and 122 deletions.
5 changes: 3 additions & 2 deletions cpp/include/cugraph/sampling_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ uniform_neighbor_sample(
std::optional<std::tuple<raft::device_span<label_t const>, raft::device_span<int32_t const>>>
label_to_output_comm_rank,
std::optional<raft::host_span<int32_t const>> fan_out,
std::optional<std::tuple<raft::host_span<int32_t const>, raft::host_span<int32_t const>>> heterogeneous_fan_out,
std::optional<std::tuple<raft::host_span<int32_t const>, raft::host_span<int32_t const>>>
heterogeneous_fan_out,
raft::random::RngState& rng_state,
bool return_hops,
bool with_replacement = true,
Expand Down Expand Up @@ -211,7 +212,7 @@ uniform_neighbor_sample(
* optional int32_t hop, optional label_t label, optional size_t offsets)
*/

// tuple with 3 elements. 1 - edge_type, 2- host span of size_t, 3 - fanout vector
// tuple with 3 elements. 1 - edge_type, 2- host span of size_t, 3 - fanout vector
template <typename vertex_t,
typename edge_t,
typename weight_t,
Expand Down
9 changes: 3 additions & 6 deletions cpp/include/cugraph_c/sampling_algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,10 +324,10 @@ void cugraph_sampling_options_free(cugraph_sampling_options_t* options);
*/
// FIXME: internal representation should be tuple instead of pairs - Make it more generic (tuple)
// cugraph_device_tuple_t, host_device_tuple_t,
//dictionary, key and array
// dictionary, key and array
// translate dictionary to a tuple. Add to the draft PR the PLC layer.
// Concatenate to build the 3 arrays from the PLC layer
/// mimic
/// mimic
typedef struct {
int32_t align_;
} cugraph_sample_heterogeneous_fanout_t;
Expand Down Expand Up @@ -359,8 +359,7 @@ cugraph_error_code_t cugraph_create_heterogeneous_fanout(
*
* @param [in] heterogeneous_fanout The edge type size and fanout values
*/
void cugraph_heterogeneous_fanout_free(
cugraph_sample_heterogeneous_fanout_t* heterogeneous_fanout);
void cugraph_heterogeneous_fanout_free(cugraph_sample_heterogeneous_fanout_t* heterogeneous_fanout);

/**
* @brief Uniform Neighborhood Sampling
Expand Down Expand Up @@ -712,15 +711,13 @@ cugraph_error_code_t cugraph_test_uniform_neighborhood_sample_result_create(
* @return error code
*/


cugraph_error_code_t cugraph_select_random_vertices(const cugraph_resource_handle_t* handle,
const cugraph_graph_t* graph,
cugraph_rng_state_t* rng_state,
size_t num_vertices,
cugraph_type_erased_device_array_t** vertices,
cugraph_error_t** error);


#ifdef __cplusplus
}
#endif
131 changes: 67 additions & 64 deletions cpp/src/c_api/neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@
* limitations under the License.
*/

#include "c_api/neighbor_sampling.hpp" // FIXME: Remove this and instead use std::tuple

#include "c_api/abstract_functor.hpp"
#include "c_api/graph.hpp"
#include "c_api/properties.hpp"
#include "c_api/random.hpp"
#include "c_api/resource_handle.hpp"
#include "c_api/utils.hpp"
#include "c_api/neighbor_sampling.hpp" // FIXME: Remove this and instead use std::tuple

#include <cugraph_c/algorithms.h>
#include <cugraph_c/sampling_algorithms.h>
Expand Down Expand Up @@ -113,7 +114,8 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct
fan_out_(
reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(fan_out)),
heterogeneous_fan_out_(
reinterpret_cast<cugraph::c_api::cugraph_sample_heterogeneous_fanout_t const*>(heterogeneous_fan_out)),
reinterpret_cast<cugraph::c_api::cugraph_sample_heterogeneous_fanout_t const*>(
heterogeneous_fan_out)),
rng_state_(reinterpret_cast<cugraph::c_api::cugraph_rng_state_t*>(rng_state)),
options_(options),
do_expensive_check_(do_expensive_check)
Expand Down Expand Up @@ -221,16 +223,16 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct
raft::device_span<label_t const>{label_to_comm_rank_->as_type<label_t>(),
label_to_comm_rank_->size_}))
: std::nullopt,
(fan_out_ != nullptr)
? std::make_optional<raft::host_span<const int>>(fan_out_->as_type<const int>(), fan_out_->size_)
: std::nullopt,
(fan_out_ != nullptr) ? std::make_optional<raft::host_span<const int>>(
fan_out_->as_type<const int>(), fan_out_->size_)
: std::nullopt,

(heterogeneous_fan_out_ != nullptr)
? std::make_optional(std::make_tuple(
raft::host_span<const int>{heterogeneous_fan_out_->edge_type_id->as_type<int>(),
heterogeneous_fan_out_->edge_type_id->size_},
raft::host_span<const int>{heterogeneous_fan_out_->fanout->as_type<int>(),
heterogeneous_fan_out_->fanout->size_}))
raft::host_span<const int>{heterogeneous_fan_out_->edge_type_id->as_type<int>(),
heterogeneous_fan_out_->edge_type_id->size_},
raft::host_span<const int>{heterogeneous_fan_out_->fanout->as_type<int>(),
heterogeneous_fan_out_->fanout->size_}))
: std::nullopt,
rng_state_->rng_state_,
options_.return_hops_,
Expand Down Expand Up @@ -769,58 +771,61 @@ struct biased_neighbor_sampling_functor : public cugraph::c_api::abstract_functo
}
};

struct create_heterogeneous_fanout_functor : public cugraph::c_api::abstract_functor {
raft::handle_t const& handle_;
cugraph::c_api::cugraph_graph_t* graph_;
cugraph::c_api::cugraph_type_erased_host_array_view_t const* edge_type_size_;
cugraph::c_api::cugraph_type_erased_host_array_view_t const* fanout_;
// FIXME: This type doesn't exist: instead create an 'std::tuple<cugraph_type_erased_host_array_t*>'
cugraph::c_api::cugraph_sample_heterogeneous_fanout_t* result_{};

create_heterogeneous_fanout_functor(::cugraph_resource_handle_t const* handle,
::cugraph_graph_t* graph,
::cugraph_type_erased_host_array_view_t const* edge_type_size,
::cugraph_type_erased_host_array_view_t const* fanout)
: abstract_functor(),
handle_(*reinterpret_cast<cugraph::c_api::cugraph_resource_handle_t const*>(handle)->handle_),
graph_(reinterpret_cast<cugraph::c_api::cugraph_graph_t*>(graph)),
edge_type_size_(
reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(edge_type_size)),
fanout_(
reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(fanout))
{
}
struct create_heterogeneous_fanout_functor : public cugraph::c_api::abstract_functor {
raft::handle_t const& handle_;
cugraph::c_api::cugraph_graph_t* graph_;
cugraph::c_api::cugraph_type_erased_host_array_view_t const* edge_type_size_;
cugraph::c_api::cugraph_type_erased_host_array_view_t const* fanout_;
// FIXME: This type doesn't exist: instead create an
// 'std::tuple<cugraph_type_erased_host_array_t*>'
cugraph::c_api::cugraph_sample_heterogeneous_fanout_t* result_{};

create_heterogeneous_fanout_functor(::cugraph_resource_handle_t const* handle,
::cugraph_graph_t* graph,
::cugraph_type_erased_host_array_view_t const* edge_type_size,
::cugraph_type_erased_host_array_view_t const* fanout)
: abstract_functor(),
handle_(*reinterpret_cast<cugraph::c_api::cugraph_resource_handle_t const*>(handle)->handle_),
graph_(reinterpret_cast<cugraph::c_api::cugraph_graph_t*>(graph)),
edge_type_size_(
reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(
edge_type_size)),
fanout_(
reinterpret_cast<cugraph::c_api::cugraph_type_erased_host_array_view_t const*>(fanout))
{
}

template <typename vertex_t,
typename edge_t,
typename weight_t,
typename edge_type_type_t,
bool store_transposed,
bool multi_gpu>
void operator()()
{
// FIXME: Remove this check as it is not necessary
if constexpr (!cugraph::is_candidate<vertex_t, edge_t, weight_t>::value) {
unsupported();
} else {
std::vector<int32_t> edge_type_size_copy{(int32_t) edge_type_size_->size_};
std::vector<int32_t> fanout_copy{(int32_t) fanout_->size_};
raft::copy(
edge_type_size_copy.data(), edge_type_size_->as_type<int32_t>(), edge_type_size_->size_, handle_.get_stream());

raft::copy(
fanout_copy.data(), fanout_->as_type<int32_t>(), fanout_->size_, handle_.get_stream());
// std::tuple (template)
// result_ = new std::tuple <template type of 2 cugraph_type_erased_host_array_t>
result_ = new cugraph::c_api::cugraph_sample_heterogeneous_fanout_t{
new cugraph::c_api::cugraph_type_erased_host_array_t(edge_type_size_copy, INT32),
new cugraph::c_api::cugraph_type_erased_host_array_t(fanout_copy, INT32)};

}
template <typename vertex_t,
typename edge_t,
typename weight_t,
typename edge_type_type_t,
bool store_transposed,
bool multi_gpu>
void operator()()
{
// FIXME: Remove this check as it is not necessary
if constexpr (!cugraph::is_candidate<vertex_t, edge_t, weight_t>::value) {
unsupported();
} else {
std::vector<int32_t> edge_type_size_copy{(int32_t)edge_type_size_->size_};
std::vector<int32_t> fanout_copy{(int32_t)fanout_->size_};

raft::copy(edge_type_size_copy.data(),
edge_type_size_->as_type<int32_t>(),
edge_type_size_->size_,
handle_.get_stream());

raft::copy(
fanout_copy.data(), fanout_->as_type<int32_t>(), fanout_->size_, handle_.get_stream());

// std::tuple (template)
// result_ = new std::tuple <template type of 2 cugraph_type_erased_host_array_t>
result_ = new cugraph::c_api::cugraph_sample_heterogeneous_fanout_t{
new cugraph::c_api::cugraph_type_erased_host_array_t(edge_type_size_copy, INT32),
new cugraph::c_api::cugraph_type_erased_host_array_t(fanout_copy, INT32)};
}
};
}
};
} // namespace

extern "C" cugraph_error_code_t cugraph_sampling_options_create(
Expand Down Expand Up @@ -1454,8 +1459,6 @@ cugraph_error_code_t cugraph_biased_neighbor_sample(
return cugraph::c_api::run_algorithm(graph, functor, result, error);
}



extern "C" cugraph_error_code_t cugraph_create_heterogeneous_fanout(
const cugraph_resource_handle_t* handle,
cugraph_graph_t* graph,
Expand All @@ -1464,7 +1467,7 @@ extern "C" cugraph_error_code_t cugraph_create_heterogeneous_fanout(
cugraph_sample_heterogeneous_fanout_t** heterogeneous_fanout,
cugraph_error_t** error)
{
create_heterogeneous_fanout_functor functor(handle, graph, edge_type_size, fanout);
create_heterogeneous_fanout_functor functor(handle, graph, edge_type_size, fanout);

return cugraph::c_api::run_algorithm(graph, functor, heterogeneous_fanout, error);
}
}
43 changes: 26 additions & 17 deletions cpp/src/sampling/neighbor_sampling_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ neighbor_sample_impl(
std::optional<std::tuple<raft::device_span<label_t const>, raft::device_span<int32_t const>>>
label_to_output_comm_rank,
std::optional<raft::host_span<int32_t const>> fan_out,
std::optional<std::tuple<raft::host_span<int32_t const>, raft::host_span<int32_t const>>> heterogeneous_fan_out,
//raft::host_span<int32_t const> fan_out,
std::optional<std::tuple<raft::host_span<int32_t const>, raft::host_span<int32_t const>>>
heterogeneous_fan_out,
// raft::host_span<int32_t const> fan_out,
bool return_hops,
bool with_replacement,
prior_sources_behavior_t prior_sources_behavior,
Expand All @@ -78,20 +79,25 @@ neighbor_sample_impl(
static_assert(std::is_floating_point_v<bias_t>);

if (fan_out) {
CUGRAPH_EXPECTS((*fan_out).size() > 0, "Invalid input argument: number of levels must be non-zero.");
CUGRAPH_EXPECTS((*fan_out).size() > 0,
"Invalid input argument: number of levels must be non-zero.");
CUGRAPH_EXPECTS(
(*fan_out).size() <= static_cast<size_t>(std::numeric_limits<int32_t>::max()),
"Invalid input argument: number of levels should not overflow int32_t"); // as we use int32_t
// to store hops
} else {
CUGRAPH_EXPECTS(std::accumulate(
std::get<0>(*heterogeneous_fan_out).begin(),
std::get<0>(*heterogeneous_fan_out).end(), 0) == std::get<1>(*heterogeneous_fan_out).size() && std::get<1>(*heterogeneous_fan_out).size() != 0,
"Invalid input argument: number of levels and size must match and should be non zero.");

(*fan_out).size() <= static_cast<size_t>(std::numeric_limits<int32_t>::max()),
"Invalid input argument: number of levels should not overflow int32_t"); // as we use int32_t
// to store hops
} else {
CUGRAPH_EXPECTS(
std::get<0>(*heterogeneous_fan_out).size() <= static_cast<size_t>(std::numeric_limits<int32_t>::max())
&& std::get<1>(*heterogeneous_fan_out).size() <= static_cast<size_t>(std::numeric_limits<int32_t>::max()),
std::accumulate(std::get<0>(*heterogeneous_fan_out).begin(),
std::get<0>(*heterogeneous_fan_out).end(),
0) == std::get<1>(*heterogeneous_fan_out).size() &&
std::get<1>(*heterogeneous_fan_out).size() != 0,
"Invalid input argument: number of levels and size must match and should be non zero.");

CUGRAPH_EXPECTS(
std::get<0>(*heterogeneous_fan_out).size() <=
static_cast<size_t>(std::numeric_limits<int32_t>::max()) &&
std::get<1>(*heterogeneous_fan_out).size() <=
static_cast<size_t>(std::numeric_limits<int32_t>::max()),
"Invalid input argument: number of levels should not overflow int32_t"); // as we use int32_t
// to store hops
}
Expand Down Expand Up @@ -142,7 +148,9 @@ neighbor_sample_impl(
level_result_dst_vectors.reserve((*fan_out).size());
if (level_result_weight_vectors) { (*level_result_weight_vectors).reserve((*fan_out).size()); }
if (level_result_edge_id_vectors) { (*level_result_edge_id_vectors).reserve((*fan_out).size()); }
if (level_result_edge_type_vectors) { (*level_result_edge_type_vectors).reserve((*fan_out).size()); }
if (level_result_edge_type_vectors) {
(*level_result_edge_type_vectors).reserve((*fan_out).size());
}
if (level_result_label_vectors) { (*level_result_label_vectors).reserve((*fan_out).size()); }

rmm::device_uvector<vertex_t> frontier_vertices(0, handle.get_stream());
Expand Down Expand Up @@ -382,8 +390,9 @@ uniform_neighbor_sample(
std::optional<std::tuple<raft::device_span<label_t const>, raft::device_span<int32_t const>>>
label_to_output_comm_rank,
std::optional<raft::host_span<int32_t const>> fan_out,
std::optional<std::tuple<raft::host_span<int32_t const>, raft::host_span<int32_t const>>> heterogeneous_fan_out,
//raft::host_span<int32_t const> fan_out,
std::optional<std::tuple<raft::host_span<int32_t const>, raft::host_span<int32_t const>>>
heterogeneous_fan_out,
// raft::host_span<int32_t const> fan_out,
raft::random::RngState& rng_state,
bool return_hops,
bool with_replacement,
Expand Down
10 changes: 6 additions & 4 deletions cpp/src/sampling/neighbor_sampling_mg_v32_e32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ uniform_neighbor_sample(
std::optional<std::tuple<raft::device_span<int32_t const>, raft::device_span<int32_t const>>>
label_to_output_comm_rank,
std::optional<raft::host_span<int32_t const>> fan_out,
std::optional<std::tuple<raft::host_span<int32_t const>, raft::host_span<int32_t const>>> heterogeneous_fan_out,
std::optional<std::tuple<raft::host_span<int32_t const>, raft::host_span<int32_t const>>>
heterogeneous_fan_out,
raft::random::RngState& rng_state,
bool return_hops,
bool with_replacement,
Expand All @@ -67,7 +68,8 @@ uniform_neighbor_sample(
std::optional<std::tuple<raft::device_span<int32_t const>, raft::device_span<int32_t const>>>
label_to_output_comm_rank,
std::optional<raft::host_span<int32_t const>> fan_out,
std::optional<std::tuple<raft::host_span<int32_t const>, raft::host_span<int32_t const>>> heterogeneous_fan_out,
std::optional<std::tuple<raft::host_span<int32_t const>, raft::host_span<int32_t const>>>
heterogeneous_fan_out,
raft::random::RngState& rng_state,
bool return_hops,
bool with_replacement,
Expand All @@ -94,7 +96,7 @@ biased_neighbor_sample(
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<std::tuple<raft::device_span<int32_t const>, raft::device_span<int32_t const>>>
label_to_output_comm_rank,
//std::optional<raft::host_span<int32_t const>> fan_out,
// std::optional<raft::host_span<int32_t const>> fan_out,
raft::host_span<int32_t const> fan_out,
raft::random::RngState& rng_state,
bool return_hops,
Expand Down Expand Up @@ -122,7 +124,7 @@ biased_neighbor_sample(
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<std::tuple<raft::device_span<int32_t const>, raft::device_span<int32_t const>>>
label_to_output_comm_rank,
//std::optional<raft::host_span<int32_t const>> fan_out,
// std::optional<raft::host_span<int32_t const>> fan_out,
raft::host_span<int32_t const> fan_out,
raft::random::RngState& rng_state,
bool return_hops,
Expand Down
10 changes: 6 additions & 4 deletions cpp/src/sampling/neighbor_sampling_mg_v32_e64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ uniform_neighbor_sample(
std::optional<std::tuple<raft::device_span<int32_t const>, raft::device_span<int32_t const>>>
label_to_output_comm_rank,
std::optional<raft::host_span<int32_t const>> fan_out,
std::optional<std::tuple<raft::host_span<int32_t const>, raft::host_span<int32_t const>>> heterogeneous_fan_out,
std::optional<std::tuple<raft::host_span<int32_t const>, raft::host_span<int32_t const>>>
heterogeneous_fan_out,
raft::random::RngState& rng_state,
bool return_hops,
bool with_replacement,
Expand All @@ -67,7 +68,8 @@ uniform_neighbor_sample(
std::optional<std::tuple<raft::device_span<int32_t const>, raft::device_span<int32_t const>>>
label_to_output_comm_rank,
std::optional<raft::host_span<int32_t const>> fan_out,
std::optional<std::tuple<raft::host_span<int32_t const>, raft::host_span<int32_t const>>> heterogeneous_fan_out,
std::optional<std::tuple<raft::host_span<int32_t const>, raft::host_span<int32_t const>>>
heterogeneous_fan_out,
raft::random::RngState& rng_state,
bool return_hops,
bool with_replacement,
Expand All @@ -94,7 +96,7 @@ biased_neighbor_sample(
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<std::tuple<raft::device_span<int32_t const>, raft::device_span<int32_t const>>>
label_to_output_comm_rank,
//std::optional<raft::host_span<int32_t const>> fan_out,
// std::optional<raft::host_span<int32_t const>> fan_out,
raft::host_span<int32_t const> fan_out,
raft::random::RngState& rng_state,
bool return_hops,
Expand Down Expand Up @@ -122,7 +124,7 @@ biased_neighbor_sample(
std::optional<raft::device_span<int32_t const>> starting_vertex_labels,
std::optional<std::tuple<raft::device_span<int32_t const>, raft::device_span<int32_t const>>>
label_to_output_comm_rank,
//std::optional<raft::host_span<int32_t const>> fan_out,
// std::optional<raft::host_span<int32_t const>> fan_out,
raft::host_span<int32_t const> fan_out,
raft::random::RngState& rng_state,
bool return_hops,
Expand Down
Loading

0 comments on commit 10fa86d

Please sign in to comment.