Skip to content

Commit

Permalink
Change parameter order, add do_expensive_check falg parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Naim committed Jan 9, 2024
1 parent 50c852a commit 79b06fa
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 14 deletions.
5 changes: 3 additions & 2 deletions cpp/include/cugraph/detail/shuffle_wrappers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,10 @@ shuffle_ext_vertex_value_pairs_to_local_gpu_by_vertex_partitioning(
template <typename vertex_t>
rmm::device_uvector<vertex_t> permute_range(raft::handle_t const& handle,
raft::random::RngState& rng_state,
vertex_t local_range_size,
vertex_t local_start,
bool multi_gpu = false);
vertex_t local_range_size,
bool multi_gpu = false,
bool do_expensive_check = false);

/**
* @brief Shuffle internal (i.e. renumbered) vertices to their local GPUs based on vertex
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/community/louvain_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ std::pair<std::unique_ptr<Dendrogram<vertex_t>>, weight_t> louvain(
auto random_cluster_assignments = cugraph::detail::permute_range<vertex_t>(
handle,
*rng_state,
current_graph_view.local_vertex_partition_range_size(),
current_graph_view.local_vertex_partition_range_first(),
current_graph_view.local_vertex_partition_range_size(),
multi_gpu);

raft::copy(dendrogram->current_level_begin(),
Expand Down
30 changes: 19 additions & 11 deletions cpp/src/detail/permute_range.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ namespace detail {
template <typename vertex_t>
rmm::device_uvector<vertex_t> permute_range(raft::handle_t const& handle,
raft::random::RngState& rng_state,
vertex_t local_range_size,
vertex_t local_range_start,
bool multi_gpu)
vertex_t local_range_size,
bool multi_gpu,
bool do_expensive_check)
{
if (multi_gpu) {
if (do_expensive_check && multi_gpu) {
auto& comm = handle.get_comms();
auto const comm_size = comm.get_size();
auto const comm_rank = comm.get_rank();
Expand All @@ -69,8 +70,7 @@ rmm::device_uvector<vertex_t> permute_range(raft::handle_t const& handle,
auto const comm_size = comm.get_size();
auto const comm_rank = comm.get_rank();

std::vector<size_t> tx_value_counts(comm_size);
std::fill(tx_value_counts.begin(), tx_value_counts.end(), 0);
std::vector<size_t> tx_value_counts(comm_size, 0)

{
rmm::device_uvector<vertex_t> d_target_ranks(permuted_integers.size(), handle.get_stream());
Expand Down Expand Up @@ -138,8 +138,14 @@ rmm::device_uvector<vertex_t> permute_range(raft::handle_t const& handle,
// take care of deficits and extras numbers
auto& comm = handle.get_comms();
auto const comm_rank = comm.get_rank();
int nr_extras = static_cast<int>(permuted_integers.size()) - static_cast<int>(local_range_size);
int nr_deficits = nr_extras >= 0 ? 0 : -nr_extras;

size_t nr_extras{0};
size_t nr_deficits{0};
if (permuted_integers.size() > static_cast<size_t>(local_range_size)) {
nr_extras = permuted_integers.size() - static_cast<size_t>(local_range_size);
} else {
nr_deficits = static_cast<size_t>(local_range_size) - permuted_integers.size();
}

auto extra_cluster_ids = cugraph::detail::device_allgatherv(
handle,
Expand All @@ -165,15 +171,17 @@ rmm::device_uvector<vertex_t> permute_range(raft::handle_t const& handle,

template rmm::device_uvector<int32_t> permute_range(raft::handle_t const& handle,
raft::random::RngState& rng_state,
int32_t local_range_size,
int32_t local_range_start,
bool multi_gpu);
int32_t local_range_size,
bool multi_gpu,
bool do_expensive_check);

template rmm::device_uvector<int64_t> permute_range(raft::handle_t const& handle,
raft::random::RngState& rng_state,
int64_t local_range_size,
int64_t local_range_start,
bool multi_gpu);
int64_t local_range_size,
bool multi_gpu,
bool do_expensive_check);

} // namespace detail
} // namespace cugraph

0 comments on commit 79b06fa

Please sign in to comment.