Assign a random rank to each vertex id/number
Naim committed Jan 6, 2024
1 parent 81e18e5 commit c8ad2d5
Showing 1 changed file with 60 additions and 41 deletions.
cpp/src/detail/permute_range.cu: 101 changes (60 additions & 41 deletions)
@@ -14,6 +14,8 @@
 #include <rmm/device_uvector.hpp>
 #include <rmm/exec_policy.hpp>

+#include <thrust/iterator/constant_iterator.h>
+#include <thrust/reduce.h>
 #include <thrust/sort.h>

 namespace cugraph {
@@ -42,27 +44,26 @@ rmm::device_uvector<vertex_t> permute_range(raft::handle_t const& handle,
       sub_range_sizes[comm_rank] == local_range_start,
       "Invalid input arguments: a rage must have contiguous and non-overlapping values");
   }
-  rmm::device_uvector<vertex_t> random_cluster_assignments(local_range_size, handle.get_stream());
+  rmm::device_uvector<vertex_t> permuted_intergers(local_range_size, handle.get_stream());

   // generate as many number as #local_vertices on each GPU
-  detail::sequence_fill(handle.get_stream(),
-                        random_cluster_assignments.begin(),
-                        random_cluster_assignments.size(),
-                        local_range_start);
+  detail::sequence_fill(
+    handle.get_stream(), permuted_intergers.begin(), permuted_intergers.size(), local_range_start);

   // shuffle/permute locally
-  rmm::device_uvector<float> random_numbers(random_cluster_assignments.size(), handle.get_stream());
+  rmm::device_uvector<float> fractional_random_numbers(permuted_intergers.size(),
+                                                       handle.get_stream());

   cugraph::detail::uniform_random_fill(handle.get_stream(),
-                                       random_numbers.data(),
-                                       random_numbers.size(),
+                                       fractional_random_numbers.data(),
+                                       fractional_random_numbers.size(),
                                        float{0.0},
                                        float{1.0},
                                        rng_state);
   thrust::sort_by_key(handle.get_thrust_policy(),
-                      random_numbers.begin(),
-                      random_numbers.end(),
-                      random_cluster_assignments.begin());
+                      fractional_random_numbers.begin(),
+                      fractional_random_numbers.end(),
+                      permuted_intergers.begin());

   if (multi_gpu) {
     // distribute shuffled/permuted numbers to other GPUs
@@ -71,76 +72,94 @@
     auto const comm_rank = comm.get_rank();

     std::vector<size_t> tx_value_counts(comm_size);
-    std::fill(tx_value_counts.begin(),
-              tx_value_counts.end(),
-              random_cluster_assignments.size() / comm_size);
+    std::fill(tx_value_counts.begin(), tx_value_counts.end(), 0);

-    std::vector<vertex_t> h_random_gpu_ranks;
     {
-      rmm::device_uvector<vertex_t> d_random_numbers(random_cluster_assignments.size() % comm_size,
-                                                     handle.get_stream());
+      rmm::device_uvector<vertex_t> d_target_ranks(permuted_intergers.size(), handle.get_stream());

       cugraph::detail::uniform_random_fill(handle.get_stream(),
-                                           d_random_numbers.data(),
-                                           d_random_numbers.size(),
+                                           d_target_ranks.data(),
+                                           d_target_ranks.size(),
                                            vertex_t{0},
                                            vertex_t{comm_size},
                                            rng_state);

-      h_random_gpu_ranks.resize(d_random_numbers.size());
+      thrust::sort_by_key(handle.get_thrust_policy(),
+                          d_target_ranks.begin(),
+                          d_target_ranks.end(),
+                          permuted_intergers.begin());

-      raft::update_host(h_random_gpu_ranks.data(),
-                        d_random_numbers.data(),
-                        d_random_numbers.size(),
-                        handle.get_stream());
-    }
+      rmm::device_uvector<vertex_t> d_reduced_ranks(comm_size, handle.get_stream());
+      rmm::device_uvector<vertex_t> d_reduced_counts(comm_size, handle.get_stream());
+
+      auto output_end = thrust::reduce_by_key(handle.get_thrust_policy(),
+                                              d_target_ranks.begin(),
+                                              d_target_ranks.end(),
+                                              thrust::make_constant_iterator(1),
+                                              d_reduced_ranks.begin(),
+                                              d_reduced_counts.begin(),
+                                              thrust::equal_to<int>());
+
+      auto nr_output_pairs =
+        static_cast<vertex_t>(thrust::distance(d_reduced_ranks.begin(), output_end.first));
+
+      std::vector<vertex_t> h_reduced_ranks(comm_size);
+      std::vector<vertex_t> h_reduced_counts(comm_size);
+
+      raft::update_host(
+        h_reduced_ranks.data(), d_reduced_ranks.data(), nr_output_pairs, handle.get_stream());
+
+      raft::update_host(
+        h_reduced_counts.data(), d_reduced_counts.data(), nr_output_pairs, handle.get_stream());

-    for (int i = 0; i < static_cast<int>(random_cluster_assignments.size() % comm_size); i++) {
-      tx_value_counts[h_random_gpu_ranks[i]]++;
+      for (int i = 0; i < static_cast<int>(nr_output_pairs); i++) {
+        tx_value_counts[h_reduced_ranks[i]] = static_cast<size_t>(h_reduced_counts[i]);
+      }
     }

-    std::tie(random_cluster_assignments, std::ignore) = cugraph::shuffle_values(
-      handle.get_comms(), random_cluster_assignments.begin(), tx_value_counts, handle.get_stream());
+    std::tie(permuted_intergers, std::ignore) = cugraph::shuffle_values(
+      handle.get_comms(), permuted_intergers.begin(), tx_value_counts, handle.get_stream());

     // shuffle/permute locally again
-    random_numbers.resize(random_cluster_assignments.size(), handle.get_stream());
+    fractional_random_numbers.resize(permuted_intergers.size(), handle.get_stream());

     cugraph::detail::uniform_random_fill(handle.get_stream(),
-                                         random_numbers.data(),
-                                         random_numbers.size(),
+                                         fractional_random_numbers.data(),
+                                         fractional_random_numbers.size(),
                                          float{0.0},
                                          float{1.0},
                                          rng_state);
     thrust::sort_by_key(handle.get_thrust_policy(),
-                        random_numbers.begin(),
-                        random_numbers.end(),
-                        random_cluster_assignments.begin());
+                        fractional_random_numbers.begin(),
+                        fractional_random_numbers.end(),
+                        permuted_intergers.begin());

     // take care of deficits and extras numbers

     int nr_extras =
-      static_cast<int>(random_cluster_assignments.size()) - static_cast<int>(local_range_size);
+      static_cast<int>(permuted_intergers.size()) - static_cast<int>(local_range_size);
     int nr_deficits = nr_extras >= 0 ? 0 : -nr_extras;

     auto extra_cluster_ids = cugraph::detail::device_allgatherv(
       handle,
       comm,
-      raft::device_span<vertex_t const>(random_cluster_assignments.data() + local_range_size,
+      raft::device_span<vertex_t const>(permuted_intergers.data() + local_range_size,
                                         nr_extras > 0 ? nr_extras : 0));

-    random_cluster_assignments.resize(local_range_size, handle.get_stream());
+    permuted_intergers.resize(local_range_size, handle.get_stream());
     auto deficits =
       cugraph::host_scalar_allgather(handle.get_comms(), nr_deficits, handle.get_stream());

     std::exclusive_scan(deficits.begin(), deficits.end(), deficits.begin(), vertex_t{0});

-    raft::copy(random_cluster_assignments.data() + local_range_size - nr_deficits,
+    raft::copy(permuted_intergers.data() + local_range_size - nr_deficits,
                extra_cluster_ids.begin() + deficits[comm_rank],
                nr_deficits,
                handle.get_stream());
   }

-  assert(random_cluster_assignments.size() == local_range_size);
-  return random_cluster_assignments;
+  assert(permuted_intergers.size() == local_range_size);
+  return permuted_intergers;
 }

 template rmm::device_uvector<int32_t> permute_range(raft::handle_t const& handle,
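For reference, the core pattern this commit switches to (assign every local integer a uniformly random target rank, group the integers by that rank with sort_by_key, then derive the per-rank send counts with reduce_by_key instead of assuming a fixed size / comm_size split) can be sketched with plain Thrust. This is a minimal, self-contained illustration, not cuGraph's API: comm_size, n, and the seed are made-up values, and the real code's shuffle_values and deficit/extra balancing steps are left out.

// Standalone Thrust sketch of the random-rank counting scheme (compile as a .cu file).
#include <thrust/device_vector.h>
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/reduce.h>
#include <thrust/sequence.h>
#include <thrust/sort.h>
#include <thrust/transform.h>

#include <cstdio>

// Maps an element index to a uniformly random rank in [0, comm_size).
struct random_rank {
  int comm_size;
  unsigned long long seed;

  __host__ __device__ int operator()(unsigned long long idx) const
  {
    thrust::default_random_engine rng(seed);
    rng.discard(idx);  // decorrelate successive indices
    thrust::uniform_int_distribution<int> dist(0, comm_size - 1);
    return dist(rng);
  }
};

int main()
{
  int const comm_size = 4;     // pretend number of GPUs
  int const n         = 1000;  // pretend local range size

  // the local integers 0, 1, ..., n-1
  thrust::device_vector<int> integers(n);
  thrust::sequence(integers.begin(), integers.end(), 0);

  // assign a random target rank to every integer
  thrust::device_vector<int> target_ranks(n);
  thrust::transform(thrust::counting_iterator<unsigned long long>(0),
                    thrust::counting_iterator<unsigned long long>(n),
                    target_ranks.begin(),
                    random_rank{comm_size, 2024ULL});

  // group the integers by destination rank (this also permutes them)
  thrust::sort_by_key(target_ranks.begin(), target_ranks.end(), integers.begin());

  // count how many integers are headed to each rank
  thrust::device_vector<int> reduced_ranks(comm_size);
  thrust::device_vector<int> reduced_counts(comm_size);
  auto output_end = thrust::reduce_by_key(target_ranks.begin(),
                                          target_ranks.end(),
                                          thrust::make_constant_iterator(1),
                                          reduced_ranks.begin(),
                                          reduced_counts.begin());

  int nr_pairs = static_cast<int>(output_end.first - reduced_ranks.begin());
  for (int i = 0; i < nr_pairs; ++i) {
    std::printf("rank %d receives %d integers\n",
                static_cast<int>(reduced_ranks[i]),
                static_cast<int>(reduced_counts[i]));
  }
  return 0;
}

Because the ranks are drawn uniformly, each of the comm_size destinations receives roughly n / comm_size integers, but the exact counts vary from run to run, which is why the new code computes tx_value_counts from the reduce_by_key output rather than hard-coding an even split.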
