Skip to content

Commit

Permalink
fix typos, refactor code SG/MG code path
Browse files Browse the repository at this point in the history
  • Loading branch information
Naim committed Jan 9, 2024
1 parent b89dc11 commit 50c852a
Showing 1 changed file with 32 additions and 44 deletions.
76 changes: 32 additions & 44 deletions cpp/src/detail/permute_range.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -57,29 +57,14 @@ rmm::device_uvector<vertex_t> permute_range(raft::handle_t const& handle,
sub_range_sizes[comm_rank] == local_range_start,
"Invalid input arguments: a rage must have contiguous and non-overlapping values");
}
rmm::device_uvector<vertex_t> permuted_intergers(local_range_size, handle.get_stream());
rmm::device_uvector<vertex_t> permuted_integers(local_range_size, handle.get_stream());

// generate as many number as #local_vertices on each GPU
// generate as many integers as #local_range_size on each GPU
detail::sequence_fill(
handle.get_stream(), permuted_intergers.begin(), permuted_intergers.size(), local_range_start);

// shuffle/permute locally
rmm::device_uvector<float> fractional_random_numbers(permuted_intergers.size(),
handle.get_stream());

cugraph::detail::uniform_random_fill(handle.get_stream(),
fractional_random_numbers.data(),
fractional_random_numbers.size(),
float{0.0},
float{1.0},
rng_state);
thrust::sort_by_key(handle.get_thrust_policy(),
fractional_random_numbers.begin(),
fractional_random_numbers.end(),
permuted_intergers.begin());
handle.get_stream(), permuted_integers.begin(), permuted_integers.size(), local_range_start);

if (multi_gpu) {
// distribute shuffled/permuted numbers to other GPUs
// randomly distribute integers to all GPUs
auto& comm = handle.get_comms();
auto const comm_size = comm.get_size();
auto const comm_rank = comm.get_rank();
Expand All @@ -88,7 +73,7 @@ rmm::device_uvector<vertex_t> permute_range(raft::handle_t const& handle,
std::fill(tx_value_counts.begin(), tx_value_counts.end(), 0);

{
rmm::device_uvector<vertex_t> d_target_ranks(permuted_intergers.size(), handle.get_stream());
rmm::device_uvector<vertex_t> d_target_ranks(permuted_integers.size(), handle.get_stream());

cugraph::detail::uniform_random_fill(handle.get_stream(),
d_target_ranks.data(),
Expand All @@ -100,7 +85,7 @@ rmm::device_uvector<vertex_t> permute_range(raft::handle_t const& handle,
thrust::sort_by_key(handle.get_thrust_policy(),
d_target_ranks.begin(),
d_target_ranks.end(),
permuted_intergers.begin());
permuted_integers.begin());

rmm::device_uvector<vertex_t> d_reduced_ranks(comm_size, handle.get_stream());
rmm::device_uvector<vertex_t> d_reduced_counts(comm_size, handle.get_stream());
Expand Down Expand Up @@ -130,49 +115,52 @@ rmm::device_uvector<vertex_t> permute_range(raft::handle_t const& handle,
}
}

std::tie(permuted_intergers, std::ignore) = cugraph::shuffle_values(
handle.get_comms(), permuted_intergers.begin(), tx_value_counts, handle.get_stream());
std::tie(permuted_integers, std::ignore) = cugraph::shuffle_values(
handle.get_comms(), permuted_integers.begin(), tx_value_counts, handle.get_stream());
}

// shuffle/permute locally again
fractional_random_numbers.resize(permuted_intergers.size(), handle.get_stream());
// permute locally
rmm::device_uvector<float> fractional_random_numbers(permuted_integers.size(),
handle.get_stream());

cugraph::detail::uniform_random_fill(handle.get_stream(),
fractional_random_numbers.data(),
fractional_random_numbers.size(),
float{0.0},
float{1.0},
rng_state);
thrust::sort_by_key(handle.get_thrust_policy(),
fractional_random_numbers.begin(),
fractional_random_numbers.end(),
permuted_intergers.begin());
cugraph::detail::uniform_random_fill(handle.get_stream(),
fractional_random_numbers.data(),
fractional_random_numbers.size(),
float{0.0},
float{1.0},
rng_state);
thrust::sort_by_key(handle.get_thrust_policy(),
fractional_random_numbers.begin(),
fractional_random_numbers.end(),
permuted_integers.begin());

if (multi_gpu) {
// take care of deficits and extras numbers

int nr_extras =
static_cast<int>(permuted_intergers.size()) - static_cast<int>(local_range_size);
auto& comm = handle.get_comms();
auto const comm_rank = comm.get_rank();
int nr_extras = static_cast<int>(permuted_integers.size()) - static_cast<int>(local_range_size);
int nr_deficits = nr_extras >= 0 ? 0 : -nr_extras;

auto extra_cluster_ids = cugraph::detail::device_allgatherv(
handle,
comm,
raft::device_span<vertex_t const>(permuted_intergers.data() + local_range_size,
raft::device_span<vertex_t const>(permuted_integers.data() + local_range_size,
nr_extras > 0 ? nr_extras : 0));

permuted_intergers.resize(local_range_size, handle.get_stream());
permuted_integers.resize(local_range_size, handle.get_stream());
auto deficits =
cugraph::host_scalar_allgather(handle.get_comms(), nr_deficits, handle.get_stream());

std::exclusive_scan(deficits.begin(), deficits.end(), deficits.begin(), vertex_t{0});

raft::copy(permuted_intergers.data() + local_range_size - nr_deficits,
raft::copy(permuted_integers.data() + local_range_size - nr_deficits,
extra_cluster_ids.begin() + deficits[comm_rank],
nr_deficits,
handle.get_stream());
}

assert(permuted_intergers.size() == local_range_size);
return permuted_intergers;
assert(permuted_integers.size() == local_range_size);
return permuted_integers;
}

template rmm::device_uvector<int32_t> permute_range(raft::handle_t const& handle,
Expand Down

0 comments on commit 50c852a

Please sign in to comment.