Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbarghi-nv committed Sep 22, 2023
1 parent ae94c35 commit 7beba4b
Show file tree
Hide file tree
Showing 12 changed files with 387 additions and 275 deletions.
7 changes: 4 additions & 3 deletions cpp/include/cugraph_c/sampling_algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ void cugraph_sampling_set_renumber_results(cugraph_sampling_options_t* options,

/**
* @brief Set whether to compress per-hop (True) or globally (False)
*
*
* @param options - opaque pointer to the sampling options
* @param value - Boolean value to assign to the option
*/
Expand All @@ -262,11 +262,12 @@ void cugraph_sampling_set_return_hops(cugraph_sampling_options_t* options, bool_

/**
* @brief Set compression type
*
*
* @param options - opaque pointer to the sampling options
* @param value - Enum defining the compresion type
*/
void cugraph_sampling_set_compression_type(cugraph_sampling_options_t* options, cugraph_compression_type_t value);
void cugraph_sampling_set_compression_type(cugraph_sampling_options_t* options,
cugraph_compression_type_t value);

/**
* @brief Set prior sources behavior
Expand Down
64 changes: 29 additions & 35 deletions cpp/src/c_api/uniform_neighbor_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,12 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct
std::optional<rmm::device_uvector<size_t>> renumber_map_offsets{std::nullopt};

bool src_is_major = (options_.compression_type_ == cugraph::compression_type_t::CSR) ||
(options_.compression_type_ == cugraph::compression_type_t::DCSR);
(options_.compression_type_ == cugraph::compression_type_t::DCSR);

if (options_.renumber_results_) {
if (options_.compression_type_ == cugraph::compression_type_t::COO) {
// COO

rmm::device_uvector<vertex_t> output_majors(0, handle_.get_stream());
rmm::device_uvector<vertex_t> output_renumber_map(0, handle_.get_stream());
std::tie(output_majors,
Expand All @@ -275,15 +275,15 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct
: std::nullopt,
src_is_major,
do_expensive_check_);

majors.emplace(std::move(output_majors));
renumber_map.emplace(std::move(output_renumber_map));
} else {
// (D)CSC, (D)CSR

bool doubly_compress =
(options_.compression_type_ == cugraph::compression_type_t::DCSR) ||
(options_.compression_type_ == cugraph::compression_type_t::DCSC);
(options_.compression_type_ == cugraph::compression_type_t::DCSR) ||
(options_.compression_type_ == cugraph::compression_type_t::DCSC);

rmm::device_uvector<size_t> output_major_offsets(0, handle_.get_stream());
rmm::device_uvector<vertex_t> output_renumber_map(0, handle_.get_stream());
Expand Down Expand Up @@ -335,18 +335,17 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct
edge_id ? std::move(edge_id) : std::nullopt,
edge_type ? std::move(edge_type) : std::nullopt,
hop ? std::make_optional(std::make_tuple(std::move(*hop), fan_out_->size_))
: std::nullopt,
: std::nullopt,
offsets ? std::make_optional(std::make_tuple(
raft::device_span<size_t const>{offsets->data(), offsets->size()},
edge_label->size()))
: std::nullopt,
raft::device_span<size_t const>{offsets->data(), offsets->size()},
edge_label->size()))
: std::nullopt,
src_is_major,
do_expensive_check_
);
do_expensive_check_);

majors.emplace(std::move(src));
minors = std::move(dst);

hop.reset();
offsets.reset();
}
Expand All @@ -367,9 +366,11 @@ struct uniform_neighbor_sampling_functor : public cugraph::c_api::abstract_funct
: nullptr,
(wgt) ? new cugraph::c_api::cugraph_type_erased_device_array_t(*wgt, graph_->weight_type_)
: nullptr,
(hop) ? new cugraph::c_api::cugraph_type_erased_device_array_t(*hop, INT32) : nullptr, // FIXME get rid of this
(label_hop_offsets) ? new cugraph::c_api::cugraph_type_erased_device_array_t(*label_hop_offsets, SIZE_T)
: nullptr,
(hop) ? new cugraph::c_api::cugraph_type_erased_device_array_t(*hop, INT32)
: nullptr, // FIXME get rid of this
(label_hop_offsets)
? new cugraph::c_api::cugraph_type_erased_device_array_t(*label_hop_offsets, SIZE_T)
: nullptr,
(edge_label)
? new cugraph::c_api::cugraph_type_erased_device_array_t(edge_label.value(), INT32)
: nullptr,
Expand Down Expand Up @@ -406,7 +407,9 @@ extern "C" void cugraph_sampling_set_renumber_results(cugraph_sampling_options_t
internal_pointer->renumber_results_ = value;
}

extern "C" void cugraph_sampling_set_compress_per_hop(cugraph_sampling_options_t* options, bool_t value) {
extern "C" void cugraph_sampling_set_compress_per_hop(cugraph_sampling_options_t* options,
bool_t value)
{
auto internal_pointer = reinterpret_cast<cugraph::c_api::cugraph_sampling_options_t*>(options);
internal_pointer->compress_per_hop_ = value;
}
Expand All @@ -424,26 +427,17 @@ extern "C" void cugraph_sampling_set_return_hops(cugraph_sampling_options_t* opt
internal_pointer->return_hops_ = value;
}

extern "C" void cugraph_sampling_set_compression_type(cugraph_sampling_options_t* options, cugraph_compression_type_t value) {
extern "C" void cugraph_sampling_set_compression_type(cugraph_sampling_options_t* options,
cugraph_compression_type_t value)
{
auto internal_pointer = reinterpret_cast<cugraph::c_api::cugraph_sampling_options_t*>(options);
switch(value) {
case COO:
internal_pointer->compression_type_ = cugraph::compression_type_t::COO;
break;
case CSR:
internal_pointer->compression_type_ = cugraph::compression_type_t::CSR;
break;
case CSC:
internal_pointer->compression_type_ = cugraph::compression_type_t::CSC;
break;
case DCSR:
internal_pointer->compression_type_ = cugraph::compression_type_t::DCSR;
break;
case DCSC:
internal_pointer->compression_type_ = cugraph::compression_type_t::DCSC;
break;
default:
CUGRAPH_FAIL("Invalid compression type");
switch (value) {
case COO: internal_pointer->compression_type_ = cugraph::compression_type_t::COO; break;
case CSR: internal_pointer->compression_type_ = cugraph::compression_type_t::CSR; break;
case CSC: internal_pointer->compression_type_ = cugraph::compression_type_t::CSC; break;
case DCSR: internal_pointer->compression_type_ = cugraph::compression_type_t::DCSR; break;
case DCSC: internal_pointer->compression_type_ = cugraph::compression_type_t::DCSC; break;
default: CUGRAPH_FAIL("Invalid compression type");
}
}

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/sampling/sampling_post_processing_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ void check_input_edges(
"Invlaid input arguments: there should be 1 or more labels if "
"edgelist_label_offsets.has_value() is true.");
*/

CUGRAPH_EXPECTS(
!edgelist_label_offsets.has_value() ||
(std::get<0>(*edgelist_label_offsets).size() == std::get<1>(*edgelist_label_offsets) + 1),
Expand Down
Loading

0 comments on commit 7beba4b

Please sign in to comment.