diff --git a/cpp/src/detail/permute_range.cu b/cpp/src/detail/permute_range.cu index e497e002f31..009f3578189 100644 --- a/cpp/src/detail/permute_range.cu +++ b/cpp/src/detail/permute_range.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -57,29 +57,14 @@ rmm::device_uvector permute_range(raft::handle_t const& handle, sub_range_sizes[comm_rank] == local_range_start, "Invalid input arguments: a rage must have contiguous and non-overlapping values"); } - rmm::device_uvector permuted_intergers(local_range_size, handle.get_stream()); + rmm::device_uvector permuted_integers(local_range_size, handle.get_stream()); - // generate as many number as #local_vertices on each GPU + // generate as many integers as #local_range_size on each GPU detail::sequence_fill( - handle.get_stream(), permuted_intergers.begin(), permuted_intergers.size(), local_range_start); - - // shuffle/permute locally - rmm::device_uvector fractional_random_numbers(permuted_intergers.size(), - handle.get_stream()); - - cugraph::detail::uniform_random_fill(handle.get_stream(), - fractional_random_numbers.data(), - fractional_random_numbers.size(), - float{0.0}, - float{1.0}, - rng_state); - thrust::sort_by_key(handle.get_thrust_policy(), - fractional_random_numbers.begin(), - fractional_random_numbers.end(), - permuted_intergers.begin()); + handle.get_stream(), permuted_integers.begin(), permuted_integers.size(), local_range_start); if (multi_gpu) { - // distribute shuffled/permuted numbers to other GPUs + // randomly distribute integers to all GPUs auto& comm = handle.get_comms(); auto const comm_size = comm.get_size(); auto const comm_rank = comm.get_rank(); @@ -88,7 +73,7 @@ rmm::device_uvector permute_range(raft::handle_t const& handle, std::fill(tx_value_counts.begin(), tx_value_counts.end(), 0); { - rmm::device_uvector d_target_ranks(permuted_intergers.size(), handle.get_stream()); + rmm::device_uvector d_target_ranks(permuted_integers.size(), handle.get_stream()); cugraph::detail::uniform_random_fill(handle.get_stream(), d_target_ranks.data(), @@ -100,7 +85,7 @@ rmm::device_uvector permute_range(raft::handle_t const& handle, thrust::sort_by_key(handle.get_thrust_policy(), d_target_ranks.begin(), d_target_ranks.end(), - permuted_intergers.begin()); + permuted_integers.begin()); rmm::device_uvector d_reduced_ranks(comm_size, handle.get_stream()); rmm::device_uvector d_reduced_counts(comm_size, handle.get_stream()); @@ -130,49 +115,52 @@ rmm::device_uvector permute_range(raft::handle_t const& handle, } } - std::tie(permuted_intergers, std::ignore) = cugraph::shuffle_values( - handle.get_comms(), permuted_intergers.begin(), tx_value_counts, handle.get_stream()); + std::tie(permuted_integers, std::ignore) = cugraph::shuffle_values( + handle.get_comms(), permuted_integers.begin(), tx_value_counts, handle.get_stream()); + } - // shuffle/permute locally again - fractional_random_numbers.resize(permuted_intergers.size(), handle.get_stream()); + // permute locally + rmm::device_uvector fractional_random_numbers(permuted_integers.size(), + handle.get_stream()); - cugraph::detail::uniform_random_fill(handle.get_stream(), - fractional_random_numbers.data(), - fractional_random_numbers.size(), - float{0.0}, - float{1.0}, - rng_state); - thrust::sort_by_key(handle.get_thrust_policy(), - fractional_random_numbers.begin(), - fractional_random_numbers.end(), - permuted_intergers.begin()); + cugraph::detail::uniform_random_fill(handle.get_stream(), + fractional_random_numbers.data(), + fractional_random_numbers.size(), + float{0.0}, + float{1.0}, + rng_state); + thrust::sort_by_key(handle.get_thrust_policy(), + fractional_random_numbers.begin(), + fractional_random_numbers.end(), + permuted_integers.begin()); + if (multi_gpu) { // take care of deficits and extras numbers - - int nr_extras = - static_cast(permuted_intergers.size()) - static_cast(local_range_size); + auto& comm = handle.get_comms(); + auto const comm_rank = comm.get_rank(); + int nr_extras = static_cast(permuted_integers.size()) - static_cast(local_range_size); int nr_deficits = nr_extras >= 0 ? 0 : -nr_extras; auto extra_cluster_ids = cugraph::detail::device_allgatherv( handle, comm, - raft::device_span(permuted_intergers.data() + local_range_size, + raft::device_span(permuted_integers.data() + local_range_size, nr_extras > 0 ? nr_extras : 0)); - permuted_intergers.resize(local_range_size, handle.get_stream()); + permuted_integers.resize(local_range_size, handle.get_stream()); auto deficits = cugraph::host_scalar_allgather(handle.get_comms(), nr_deficits, handle.get_stream()); std::exclusive_scan(deficits.begin(), deficits.end(), deficits.begin(), vertex_t{0}); - raft::copy(permuted_intergers.data() + local_range_size - nr_deficits, + raft::copy(permuted_integers.data() + local_range_size - nr_deficits, extra_cluster_ids.begin() + deficits[comm_rank], nr_deficits, handle.get_stream()); } - assert(permuted_intergers.size() == local_range_size); - return permuted_intergers; + assert(permuted_integers.size() == local_range_size); + return permuted_integers; } template rmm::device_uvector permute_range(raft::handle_t const& handle,