From 29cbe978f7624b1e1d516e1fcfb2d9eef7c0d681 Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Fri, 5 Jul 2024 13:09:23 -0700 Subject: [PATCH 01/18] Define C++ API for negative sampling --- cpp/include/cugraph/sampling_functions.hpp | 55 ++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/cpp/include/cugraph/sampling_functions.hpp b/cpp/include/cugraph/sampling_functions.hpp index fec1a07604e..24ea1423eba 100644 --- a/cpp/include/cugraph/sampling_functions.hpp +++ b/cpp/include/cugraph/sampling_functions.hpp @@ -743,4 +743,59 @@ lookup_endpoints_from_edge_ids_and_types( raft::device_span edge_ids_to_lookup, raft::device_span edge_types_to_lookup); +/** + * @brief Negative Sampling + * + * This function generates negative samples for graph. + * + * Negative sampling is done by generating a random graph according to the specified + * parameters and optionally removing the false negatives. + * + * Sampling occurs by creating a list of source vertex ids from biased samping + * of the source vertex space, and destination vertex ids from biased sampling of the + * destination vertex space, and using this as the putative list of edges. We + * then can optionally remove duplicates and remove false negatives to generate + * the final list. If necessary we will repeat the process to end with a resulting + * edge list of the appropriate size. + * + * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. + * @tparam edge_t Type of edge identifiers. Needs to be an integral type. + * @tparam store_transposed Flag indicating whether sources (if false) or destinations (if + * true) are major indices + * @tparam multi_gpu Flag indicating whether template instantiation should target single-GPU (false) + * + * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and + * handles to various CUDA libraries) to run graph algorithms. + * @param graph_view optional Graph View object to generate NBR Sampling on, only required if + * remove_false_negatives is true + * @param rng_state RNG state + * @param num_samples Number of negative samples to generate + * @param src_bias Optional bias for randomly selecting source vertices. If std::nullopt vertices + * will be selected uniformly + * @param dst_bias Optional bias for randomly selecting destination vertices. If std::nullopt + * vertices will be selected uniformly + * @param remove_duplicates If true, remove duplicate samples + * @param remove_false_negatives If true, remove false negatives (samples that are actually edges in + * the graph + * @param exact_number_of_samples If true, repeat generation until we get the exact number of + * negative samples + * + * @return tuple containing source vertex ids and destination vertex ids for the negative samples + */ +template +std::tuple, rmm::device_uvector> negative_sampling( + raft::handle_t const& handle, + std::optional const&> graph_view, + raft::random::rng_state& rng_state, + size_t num_samples, + std::optional> src_bias, + std::optional> dst_bias, + bool remove_duplicates, + bool remove_false_negatives, + bool exact_number_of_samples); + } // namespace cugraph From 983a8819902cb397d40c364a2b02ca087bffc62b Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Wed, 10 Jul 2024 11:11:41 -0700 Subject: [PATCH 02/18] first cut at negative sampling implementation (untested)... fixed API --- .../cugraph/detail/utility_wrappers.hpp | 23 +++ cpp/include/cugraph/sampling_functions.hpp | 5 +- cpp/src/detail/utility_wrappers.cuh | 10 + cpp/src/detail/utility_wrappers_32.cu | 10 + cpp/src/detail/utility_wrappers_64.cu | 10 + cpp/src/sampling/negative_sampling_impl.cuh | 181 ++++++++++++++++++ 6 files changed, 236 insertions(+), 3 deletions(-) create mode 100644 cpp/src/sampling/negative_sampling_impl.cuh diff --git a/cpp/include/cugraph/detail/utility_wrappers.hpp b/cpp/include/cugraph/detail/utility_wrappers.hpp index 61ac1bd2804..fc75f06b373 100644 --- a/cpp/include/cugraph/detail/utility_wrappers.hpp +++ b/cpp/include/cugraph/detail/utility_wrappers.hpp @@ -50,6 +50,29 @@ void uniform_random_fill(rmm::cuda_stream_view const& stream_view, value_t max_value, raft::random::RngState& rng_state); +/** + * @brief Fill a buffer with biased random values + * + * Fills a buffer with values based on the specified biases. + * The probability of selecting the value `i` is determined by + * `biases[i] / sum(biases)`. + * + * @tparam value_t type of the value to operate on + * @tparam bias_t type of the bias + * + * @param[in] handle RAFT handle object to encapsulate resources (e.g. CUDA stream, + * communicator, and handles to various CUDA libraries) to run graph algorithms. + * @param[in] rng_state The RngState instance holding pseudo-random number generator state. + * @param[out] output The random values + * @param[in] biases The biased values + * + */ +template +void biased_random_fill(raft::handle_t const& handle, + raft::random::RngState& rng_state, + raft::device_span output, + raft::device_span biases); + /** * @brief Fill a buffer with a constant value * diff --git a/cpp/include/cugraph/sampling_functions.hpp b/cpp/include/cugraph/sampling_functions.hpp index 24ea1423eba..2c303e8204f 100644 --- a/cpp/include/cugraph/sampling_functions.hpp +++ b/cpp/include/cugraph/sampling_functions.hpp @@ -766,8 +766,7 @@ lookup_endpoints_from_edge_ids_and_types( * * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and * handles to various CUDA libraries) to run graph algorithms. - * @param graph_view optional Graph View object to generate NBR Sampling on, only required if - * remove_false_negatives is true + * @param graph_view Graph View object to generate NBR Sampling for * @param rng_state RNG state * @param num_samples Number of negative samples to generate * @param src_bias Optional bias for randomly selecting source vertices. If std::nullopt vertices @@ -789,7 +788,7 @@ template std::tuple, rmm::device_uvector> negative_sampling( raft::handle_t const& handle, - std::optional const&> graph_view, + graph_view_t const& graph_view, raft::random::rng_state& rng_state, size_t num_samples, std::optional> src_bias, diff --git a/cpp/src/detail/utility_wrappers.cuh b/cpp/src/detail/utility_wrappers.cuh index ce8549db9f8..98862358b6e 100644 --- a/cpp/src/detail/utility_wrappers.cuh +++ b/cpp/src/detail/utility_wrappers.cuh @@ -57,6 +57,16 @@ void uniform_random_fill(rmm::cuda_stream_view const& stream_view, } } +template +void biased_random_fill(raft::handle_t const& handle, + raft::random::RngState& rng_state, + raft::device_span output, + raft::device_span biases) +{ + CUGRAPH_EXPECTS(std::is_integral::value); + raft::random::discrete(handle, rng_state, output, biases); +} + template void scalar_fill(raft::handle_t const& handle, value_t* d_value, size_t size, value_t value) { diff --git a/cpp/src/detail/utility_wrappers_32.cu b/cpp/src/detail/utility_wrappers_32.cu index 6ab5ae375ca..7a3b0ceb458 100644 --- a/cpp/src/detail/utility_wrappers_32.cu +++ b/cpp/src/detail/utility_wrappers_32.cu @@ -54,6 +54,16 @@ template void uniform_random_fill(rmm::cuda_stream_view const& stream_view, float max_value, raft::random::RngState& rng_state); +template void biased_random_fill(raft::handle_t const& handle, + raft::random::RngState& rng_state, + raft::device_span output, + raft::device_span biases); + +template void biased_random_fill(raft::handle_t const& handle, + raft::random::RngState& rng_state, + raft::device_span output, + raft::device_span biases); + template void scalar_fill(raft::handle_t const& handle, int32_t* d_value, size_t size, diff --git a/cpp/src/detail/utility_wrappers_64.cu b/cpp/src/detail/utility_wrappers_64.cu index a12bc3e952d..e186ce35473 100644 --- a/cpp/src/detail/utility_wrappers_64.cu +++ b/cpp/src/detail/utility_wrappers_64.cu @@ -54,6 +54,16 @@ template void uniform_random_fill(rmm::cuda_stream_view const& stream_view, double max_value, raft::random::RngState& rng_state); +template void biased_random_fill(raft::handle_t const& handle, + raft::random::RngState& rng_state, + raft::device_span output, + raft::device_span biases); + +template void biased_random_fill(raft::handle_t const& handle, + raft::random::RngState& rng_state, + raft::device_span output, + raft::device_span biases); + template void scalar_fill(raft::handle_t const& handle, int64_t* d_value, size_t size, diff --git a/cpp/src/sampling/negative_sampling_impl.cuh b/cpp/src/sampling/negative_sampling_impl.cuh new file mode 100644 index 00000000000..8ec98b2d7d3 --- /dev/null +++ b/cpp/src/sampling/negative_sampling_impl.cuh @@ -0,0 +1,181 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace cugraph { + +template +std::tuple, rmm::device_uvector> negative_sampling( + raft::handle_t const& handle, + graph_view_t const& graph_view, + raft::random::rng_state& rng_state, + size_t num_samples, + std::optional> src_bias, + std::optional> dst_bias, + bool remove_duplicates, + bool remove_false_negatives, + bool exact_number_of_samples, + bool do_expensive_check) +{ + rmm::device_uvector src(0, handle.get_stream()); + rmm::device_uvector dst(0, handle.get_stream()); + + // Optimistically assume we can do this in one pass + size_t samples_in_this_batch = num_samples; + + while (samples_in_this_batch > 0) { + if constexpr (multi_gpu) { + size_t num_gpus = handle.get_comms().get_size(); + size_t rank = handle.get_comms().get_rank(); + + samples_in_this_batch = + (samples_in_this_batch / num_gpus) + (rank < (samples_in_this_batch % num_gpus) ? 1 : 0); + } + + rmm::device_uvector batch_src(samples_in_this_batch, handle.get_stream()); + rmm::device_uvector batch_dst(samples_in_this_batch, handle.get_stream()); + + if (src_bias) { + biased_random_fill(handle, + rng_state, + raft::device_span{batch_src.data(), batch_src.size()}, + *src_bias); + } else { + uniform_random_fill(handle.get_stream(), + batch_src.data(), + batch_src.size(), + vertex_t{0}, + graph_view.number_of_vertices(), + rng_state); + } + + if (dst_bias) { + biased_random_fill(handle, + rng_state, + raft::device_span{batch_dst.data(), batch_dst.size()}, + *dst_bias); + } else { + uniform_random_fill(handle.get_stream(), + batch_dst.data(), + batch_dst.size(), + vertex_t{0}, + graph_view.number_of_vertices(), + rng_state); + } + + if constexpr (multi_gpu) { + auto vertex_partition_range_lasts = graph_view.vertex_partition_range_lasts(); + + std::tie(batch_src, batch_dst, std::ignore, std::ignore, std::ignore) = + detail::shuffle_int_vertex_pairs_with_values_to_local_gpu_by_edge_partitioning( + handle, + std::move(batch_src), + std::move(batch_dst), + std::nullopt, + std::nullopt, + std::nullopt, + vertex_partition_range_lasts); + } + + if (remove_false_negatives) { + auto has_edge_flags = + graph_view.has_edge(handle, + raft::device_span{batch_src.data(), batch_src.size()}, + raft::device_span{batch_dst.data(), batch_dst.size()}, + do_expensive_check); + + auto begin_iter = + thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin(), has_edge_flags.begin()); + auto new_end = thrust::remove_if(handle.get_thrust_policy(), + begin_iter, + begin_iter + batch_src.size(), + [] __device__(auto tuple) { return thrust::get<2>(tuple); }); + batch_src.resize(thrust::distance(begin_iter, new_end), handle.get_stream()); + batch_dst.resize(thrust::distance(begin_iter, new_end), handle.get_stream()); + } + + if (remove_duplicates) { + auto begin_iter = thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin()); + thrust::sort(handle.get_thrust_policy(), begin_iter, begin_iter + batch_src.size()); + + auto new_end = thrust::unique(handle.get_thrust_policy(), begin_iter, end_iter); + + size_t unique_size = thrust::distance(begin_iter, new_end); + + if (src.size() > 0) { + new_end = thrust::remove_if( + handle.get_thrust_policy(), + begin_iter, + begin_iter + unique_size, + [local_src = raft::device_span{src.data(), src.size()}, + local_dst = raft::device_span{dst.data(), dst.size()}] __device__(auto tuple) { + return thrust::binary_search( + thrust::seq, + thrust::make_zip_iterator(local_src.begin(), local_dst.begin()), + thrust::make_zip_iterator(local_src.end(), local_dst.end()), + tuple); + }); + + unique_size = thrust::distance(begin_iter, new_end); + } + + batch_src.resize(unique_size, handle.get_stream()); + batch_dst.resize(unique_size, handle.get_stream()); + } + + if (src.size() > 0) { + size_t current_end = src.size(); + + src.resize(src.size() + batch_src.size(), handle.get_stream()); + dst.resize(dst.size() + batch_dst.size(), handle.get_stream()); + + thrust::copy(handle.get_thrust_policy(), + thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin()), + thrust::make_zip_iterator(batch_src.end(), batch_dst.end()), + thrust::make_zip_iterator(src.begin(), dst.begin()) + current_end); + } else { + src = std::move(batch_src); + dst = std::move(batch_dst); + } + + if (exact_number_of_samples) { + size_t num_batch_samples = src.size(); + if constexpr (multi_gpu) { + num_batch_samples = cugraph::host_scalar_allreduce( + handle.get_comms(), num_batch_samples, raft::comms::op_t::SUM, handle.get_stream()); + } + + // FIXME: We could oversample and discard the unnecessary samples + // to reduce the number of iterations in the outer loop, but it seems like + // exact_number_of_samples is an edge case not worth optimizing for at this time. + samples_in_this_batch = num_samples - num_batch_samples; + } else { + samples_in_this_batch = 0; + } + } + + return std::make_tuple(std::move(src), std::move(dst)); +} + +} // namespace cugraph From 5504c748f0d61317ea394b87cf01f2c75249aa39 Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Wed, 10 Jul 2024 11:13:08 -0700 Subject: [PATCH 03/18] rename utility_wrapper.cuh --- cpp/src/detail/utility_wrappers_32.cu | 2 +- cpp/src/detail/utility_wrappers_64.cu | 2 +- .../detail/{utility_wrappers.cuh => utility_wrappers_impl.cuh} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename cpp/src/detail/{utility_wrappers.cuh => utility_wrappers_impl.cuh} (100%) diff --git a/cpp/src/detail/utility_wrappers_32.cu b/cpp/src/detail/utility_wrappers_32.cu index 7a3b0ceb458..35dc15079b2 100644 --- a/cpp/src/detail/utility_wrappers_32.cu +++ b/cpp/src/detail/utility_wrappers_32.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "detail/utility_wrappers.cuh" +#include "detail/utility_wrappers_impl.cuh" #include #include diff --git a/cpp/src/detail/utility_wrappers_64.cu b/cpp/src/detail/utility_wrappers_64.cu index e186ce35473..a6dfb5d768c 100644 --- a/cpp/src/detail/utility_wrappers_64.cu +++ b/cpp/src/detail/utility_wrappers_64.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "detail/utility_wrappers.cuh" +#include "detail/utility_wrappers_impl.cuh" #include #include diff --git a/cpp/src/detail/utility_wrappers.cuh b/cpp/src/detail/utility_wrappers_impl.cuh similarity index 100% rename from cpp/src/detail/utility_wrappers.cuh rename to cpp/src/detail/utility_wrappers_impl.cuh From 912ae6fdfdcaff2a227612fe576204e380d764f1 Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Mon, 15 Jul 2024 11:56:19 -0700 Subject: [PATCH 04/18] Working SG negative sampling tests --- cpp/CMakeLists.txt | 3 + cpp/include/cugraph/graph_view.hpp | 4 +- cpp/include/cugraph/sampling_functions.hpp | 10 +- cpp/src/detail/utility_wrappers_impl.cuh | 8 +- cpp/src/generators/erdos_renyi_generator.cuh | 10 + cpp/src/sampling/negative_sampling_impl.cuh | 93 ++-- .../sampling/negative_sampling_sg_v32_e32.cu | 48 +++ .../sampling/negative_sampling_sg_v32_e64.cu | 48 +++ .../sampling/negative_sampling_sg_v64_e64.cu | 48 +++ cpp/src/structure/graph_view_impl.cuh | 4 +- cpp/tests/CMakeLists.txt | 4 + cpp/tests/sampling/negative_sampling.cu | 400 ++++++++++++++++++ 12 files changed, 632 insertions(+), 48 deletions(-) create mode 100644 cpp/src/sampling/negative_sampling_sg_v32_e32.cu create mode 100644 cpp/src/sampling/negative_sampling_sg_v32_e64.cu create mode 100644 cpp/src/sampling/negative_sampling_sg_v64_e64.cu create mode 100644 cpp/tests/sampling/negative_sampling.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 9a9c445ed54..9503e1cf1ad 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -331,6 +331,9 @@ set(CUGRAPH_SOURCES src/sampling/neighbor_sampling_sg_v32_e64.cpp src/sampling/neighbor_sampling_sg_v32_e32.cpp src/sampling/neighbor_sampling_sg_v64_e64.cpp + src/sampling/negative_sampling_sg_v32_e64.cu + src/sampling/negative_sampling_sg_v32_e32.cu + src/sampling/negative_sampling_sg_v64_e64.cu src/sampling/renumber_sampled_edgelist_sg_v64_e64.cu src/sampling/renumber_sampled_edgelist_sg_v32_e32.cu src/sampling/sampling_post_processing_sg_v64_e64.cu diff --git a/cpp/include/cugraph/graph_view.hpp b/cpp/include/cugraph/graph_view.hpp index cbb52ef3b1e..a2ff3166fa4 100644 --- a/cpp/include/cugraph/graph_view.hpp +++ b/cpp/include/cugraph/graph_view.hpp @@ -636,7 +636,7 @@ class graph_view_t edge_srcs, raft::device_span edge_dsts, - bool do_expensive_check = false); + bool do_expensive_check = false) const; rmm::device_uvector compute_multiplicity( raft::handle_t const& handle, @@ -945,7 +945,7 @@ class graph_view_t has_edge(raft::handle_t const& handle, raft::device_span edge_srcs, raft::device_span edge_dsts, - bool do_expensive_check = false); + bool do_expensive_check = false) const; rmm::device_uvector compute_multiplicity(raft::handle_t const& handle, raft::device_span edge_srcs, diff --git a/cpp/include/cugraph/sampling_functions.hpp b/cpp/include/cugraph/sampling_functions.hpp index 2c303e8204f..88854ecc0ea 100644 --- a/cpp/include/cugraph/sampling_functions.hpp +++ b/cpp/include/cugraph/sampling_functions.hpp @@ -778,6 +778,7 @@ lookup_endpoints_from_edge_ids_and_types( * the graph * @param exact_number_of_samples If true, repeat generation until we get the exact number of * negative samples + * @param do_expensive_check A flag to run expensive checks for input arguments (if set to `true`). * * @return tuple containing source vertex ids and destination vertex ids for the negative samples */ @@ -788,13 +789,14 @@ template std::tuple, rmm::device_uvector> negative_sampling( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, - raft::random::rng_state& rng_state, size_t num_samples, - std::optional> src_bias, - std::optional> dst_bias, + std::optional> src_bias, + std::optional> dst_bias, bool remove_duplicates, bool remove_false_negatives, - bool exact_number_of_samples); + bool exact_number_of_samples, + bool do_expensive_check); } // namespace cugraph diff --git a/cpp/src/detail/utility_wrappers_impl.cuh b/cpp/src/detail/utility_wrappers_impl.cuh index 98862358b6e..f6023c650b8 100644 --- a/cpp/src/detail/utility_wrappers_impl.cuh +++ b/cpp/src/detail/utility_wrappers_impl.cuh @@ -63,8 +63,12 @@ void biased_random_fill(raft::handle_t const& handle, raft::device_span output, raft::device_span biases) { - CUGRAPH_EXPECTS(std::is_integral::value); - raft::random::discrete(handle, rng_state, output, biases); + CUGRAPH_EXPECTS(std::is_integral::value, + "biased_random_fill can only output integral values"); + raft::random::discrete(handle, + rng_state, + raft::make_device_vector_view(output.data(), output.size()), + raft::make_device_vector_view(biases.data(), biases.size())); } template diff --git a/cpp/src/generators/erdos_renyi_generator.cuh b/cpp/src/generators/erdos_renyi_generator.cuh index cd461ee1aa2..10573ddb0d0 100644 --- a/cpp/src/generators/erdos_renyi_generator.cuh +++ b/cpp/src/generators/erdos_renyi_generator.cuh @@ -40,6 +40,11 @@ generate_erdos_renyi_graph_edgelist_gnp(raft::handle_t const& handle, vertex_t base_vertex_id, uint64_t seed) { + // NOTE: + // https://networkx.org/documentation/stable/_modules/networkx/generators/random_graphs.html#fast_gnp_random_graph + // identifies a faster algorithm that I think would be very efficient on the GPU. I believe we + // could just compute lr/lp in that code for a batch of values, use prefix sums to generate edge + // ids and then convert the generated values to a batch of edges. CUGRAPH_EXPECTS(num_vertices < std::numeric_limits::max(), "Implementation cannot support specified value"); @@ -88,6 +93,11 @@ generate_erdos_renyi_graph_edgelist_gnm(raft::handle_t const& handle, uint64_t seed) { CUGRAPH_FAIL("Not implemented"); + + // To implement: + // Use sampling function to select `m` unique edge ids from the + // (num_vertices ^ 2) possible edges. Convert these to vertex + // ids. } } // namespace cugraph diff --git a/cpp/src/sampling/negative_sampling_impl.cuh b/cpp/src/sampling/negative_sampling_impl.cuh index 8ec98b2d7d3..dc174098b59 100644 --- a/cpp/src/sampling/negative_sampling_impl.cuh +++ b/cpp/src/sampling/negative_sampling_impl.cuh @@ -16,6 +16,13 @@ #pragma once +#include +#include +#include + +#include +#include + namespace cugraph { template std::tuple, rmm::device_uvector> negative_sampling( raft::handle_t const& handle, + raft::random::RngState& rng_state, graph_view_t const& graph_view, - raft::random::rng_state& rng_state, size_t num_samples, - std::optional> src_bias, - std::optional> dst_bias, + std::optional> src_bias, + std::optional> dst_bias, bool remove_duplicates, bool remove_false_negatives, bool exact_number_of_samples, @@ -54,31 +61,31 @@ std::tuple, rmm::device_uvector> negativ rmm::device_uvector batch_dst(samples_in_this_batch, handle.get_stream()); if (src_bias) { - biased_random_fill(handle, - rng_state, - raft::device_span{batch_src.data(), batch_src.size()}, - *src_bias); + detail::biased_random_fill(handle, + rng_state, + raft::device_span{batch_src.data(), batch_src.size()}, + *src_bias); } else { - uniform_random_fill(handle.get_stream(), - batch_src.data(), - batch_src.size(), - vertex_t{0}, - graph_view.number_of_vertices(), - rng_state); + detail::uniform_random_fill(handle.get_stream(), + batch_src.data(), + batch_src.size(), + vertex_t{0}, + graph_view.number_of_vertices(), + rng_state); } if (dst_bias) { - biased_random_fill(handle, - rng_state, - raft::device_span{batch_dst.data(), batch_dst.size()}, - *dst_bias); + detail::biased_random_fill(handle, + rng_state, + raft::device_span{batch_dst.data(), batch_dst.size()}, + *dst_bias); } else { - uniform_random_fill(handle.get_stream(), - batch_dst.data(), - batch_dst.size(), - vertex_t{0}, - graph_view.number_of_vertices(), - rng_state); + detail::uniform_random_fill(handle.get_stream(), + batch_dst.data(), + batch_dst.size(), + vertex_t{0}, + graph_view.number_of_vertices(), + rng_state); } if constexpr (multi_gpu) { @@ -101,8 +108,8 @@ std::tuple, rmm::device_uvector> negativ if (remove_false_negatives) { auto has_edge_flags = graph_view.has_edge(handle, - raft::device_span{batch_src.data(), batch_src.size()}, - raft::device_span{batch_dst.data(), batch_dst.size()}, + raft::device_span{batch_src.data(), batch_src.size()}, + raft::device_span{batch_dst.data(), batch_dst.size()}, do_expensive_check); auto begin_iter = @@ -119,23 +126,25 @@ std::tuple, rmm::device_uvector> negativ auto begin_iter = thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin()); thrust::sort(handle.get_thrust_policy(), begin_iter, begin_iter + batch_src.size()); - auto new_end = thrust::unique(handle.get_thrust_policy(), begin_iter, end_iter); + auto new_end = + thrust::unique(handle.get_thrust_policy(), begin_iter, begin_iter + batch_src.size()); size_t unique_size = thrust::distance(begin_iter, new_end); if (src.size() > 0) { - new_end = thrust::remove_if( - handle.get_thrust_policy(), - begin_iter, - begin_iter + unique_size, - [local_src = raft::device_span{src.data(), src.size()}, - local_dst = raft::device_span{dst.data(), dst.size()}] __device__(auto tuple) { - return thrust::binary_search( - thrust::seq, - thrust::make_zip_iterator(local_src.begin(), local_dst.begin()), - thrust::make_zip_iterator(local_src.end(), local_dst.end()), - tuple); - }); + new_end = + thrust::remove_if(handle.get_thrust_policy(), + begin_iter, + begin_iter + unique_size, + [local_src = raft::device_span{src.data(), src.size()}, + local_dst = raft::device_span{ + dst.data(), dst.size()}] __device__(auto tuple) { + return thrust::binary_search( + thrust::seq, + thrust::make_zip_iterator(local_src.begin(), local_dst.begin()), + thrust::make_zip_iterator(local_src.end(), local_dst.end()), + tuple); + }); unique_size = thrust::distance(begin_iter, new_end); } @@ -154,9 +163,17 @@ std::tuple, rmm::device_uvector> negativ thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin()), thrust::make_zip_iterator(batch_src.end(), batch_dst.end()), thrust::make_zip_iterator(src.begin(), dst.begin()) + current_end); + + auto begin_iter = thrust::make_zip_iterator(src.begin(), dst.begin()); + thrust::sort(handle.get_thrust_policy(), begin_iter, begin_iter + src.size()); } else { src = std::move(batch_src); dst = std::move(batch_dst); + + if (!remove_duplicates) { + auto begin_iter = thrust::make_zip_iterator(src.begin(), dst.begin()); + thrust::sort(handle.get_thrust_policy(), begin_iter, begin_iter + src.size()); + } } if (exact_number_of_samples) { diff --git a/cpp/src/sampling/negative_sampling_sg_v32_e32.cu b/cpp/src/sampling/negative_sampling_sg_v32_e32.cu new file mode 100644 index 00000000000..b9fbaba76be --- /dev/null +++ b/cpp/src/sampling/negative_sampling_sg_v32_e32.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "negative_sampling_impl.cuh" + +#include +#include + +namespace cugraph { + +template std::tuple, rmm::device_uvector> negative_sampling( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + size_t num_samples, + std::optional> src_bias, + std::optional> dst_bias, + bool remove_duplicates, + bool remove_false_negatives, + bool exact_number_of_samples, + bool do_expensive_check); + +template std::tuple, rmm::device_uvector> negative_sampling( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + size_t num_samples, + std::optional> src_bias, + std::optional> dst_bias, + bool remove_duplicates, + bool remove_false_negatives, + bool exact_number_of_samples, + bool do_expensive_check); + +} // namespace cugraph diff --git a/cpp/src/sampling/negative_sampling_sg_v32_e64.cu b/cpp/src/sampling/negative_sampling_sg_v32_e64.cu new file mode 100644 index 00000000000..6db40b327af --- /dev/null +++ b/cpp/src/sampling/negative_sampling_sg_v32_e64.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "negative_sampling_impl.cuh" + +#include +#include + +namespace cugraph { + +template std::tuple, rmm::device_uvector> negative_sampling( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + size_t num_samples, + std::optional> src_bias, + std::optional> dst_bias, + bool remove_duplicates, + bool remove_false_negatives, + bool exact_number_of_samples, + bool do_expensive_check); + +template std::tuple, rmm::device_uvector> negative_sampling( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + size_t num_samples, + std::optional> src_bias, + std::optional> dst_bias, + bool remove_duplicates, + bool remove_false_negatives, + bool exact_number_of_samples, + bool do_expensive_check); + +} // namespace cugraph diff --git a/cpp/src/sampling/negative_sampling_sg_v64_e64.cu b/cpp/src/sampling/negative_sampling_sg_v64_e64.cu new file mode 100644 index 00000000000..0c5152b21c5 --- /dev/null +++ b/cpp/src/sampling/negative_sampling_sg_v64_e64.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "negative_sampling_impl.cuh" + +#include +#include + +namespace cugraph { + +template std::tuple, rmm::device_uvector> negative_sampling( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + size_t num_samples, + std::optional> src_bias, + std::optional> dst_bias, + bool remove_duplicates, + bool remove_false_negatives, + bool exact_number_of_samples, + bool do_expensive_check); + +template std::tuple, rmm::device_uvector> negative_sampling( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + size_t num_samples, + std::optional> src_bias, + std::optional> dst_bias, + bool remove_duplicates, + bool remove_false_negatives, + bool exact_number_of_samples, + bool do_expensive_check); + +} // namespace cugraph diff --git a/cpp/src/structure/graph_view_impl.cuh b/cpp/src/structure/graph_view_impl.cuh index 7097349dce5..755210a586f 100644 --- a/cpp/src/structure/graph_view_impl.cuh +++ b/cpp/src/structure/graph_view_impl.cuh @@ -793,7 +793,7 @@ graph_view_t edge_srcs, raft::device_span edge_dsts, - bool do_expensive_check) + bool do_expensive_check) const { CUGRAPH_EXPECTS( edge_srcs.size() == edge_dsts.size(), @@ -873,7 +873,7 @@ graph_view_t edge_srcs, raft::device_span edge_dsts, - bool do_expensive_check) + bool do_expensive_check) const { CUGRAPH_EXPECTS( edge_srcs.size() == edge_dsts.size(), diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 3ad27b503a4..4bcb2769b65 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -491,6 +491,10 @@ ConfigureTest(BIASED_NEIGHBOR_SAMPLING_TEST sampling/biased_neighbor_sampling.cp # - SAMPLING_POST_PROCESSING tests ---------------------------------------------------------------- ConfigureTest(SAMPLING_POST_PROCESSING_TEST sampling/sampling_post_processing_test.cu) +################################################################################################### +# - NEGATVIE SAMPLING tests -------------------------------------------------------------------- +ConfigureTest(NEGATIVE_SAMPLING_TEST sampling/negative_sampling.cu) + ################################################################################################### # - Renumber tests -------------------------------------------------------------------------------- ConfigureTest(RENUMBERING_TEST structure/renumbering_test.cpp) diff --git a/cpp/tests/sampling/negative_sampling.cu b/cpp/tests/sampling/negative_sampling.cu new file mode 100644 index 00000000000..1d714b85271 --- /dev/null +++ b/cpp/tests/sampling/negative_sampling.cu @@ -0,0 +1,400 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utilities/base_fixture.hpp" +#include "utilities/conversion_utilities.hpp" +#include "utilities/property_generator_utilities.hpp" + +#include +#include +#include + +#include + +struct Negative_Sampling_Usecase { + float sample_multiplier{2}; + bool use_src_bias{false}; + bool use_dst_bias{false}; + bool remove_duplicates{false}; + bool remove_false_negatives{false}; + bool exact_number_of_samples{false}; + bool check_correctness{true}; +}; + +template +class Tests_Negative_Sampling : public ::testing::TestWithParam { + public: + using graph_t = cugraph::graph_t; + using graph_view_t = cugraph::graph_view_t; + + Tests_Negative_Sampling() : graph(raft::handle_t{}) {} + + static void SetUpTestCase() {} + static void TearDownTestCase() {} + + template + void load_graph(input_t const& param) + { + raft::handle_t handle{}; + HighResTimer hr_timer{}; + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("Construct graph"); + } + + std::tie(graph, edge_weights, renumber_map_labels) = + cugraph::test::construct_graph( + handle, param, true, true); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + } + + virtual void SetUp() {} + virtual void TearDown() {} + + void run_current_test(raft::random::RngState& rng_state, + Negative_Sampling_Usecase const& negative_sampling_usecase) + { + constexpr bool do_expensive_check{false}; + + raft::handle_t handle{}; + HighResTimer hr_timer{}; + + auto graph_view = graph.view(); + + size_t num_samples = graph_view.number_of_edges() * negative_sampling_usecase.sample_multiplier; + + rmm::device_uvector src_bias_v(0, handle.get_stream()); + rmm::device_uvector dst_bias_v(0, handle.get_stream()); + + std::optional> src_bias{std::nullopt}; + std::optional> dst_bias{std::nullopt}; + + if (negative_sampling_usecase.use_src_bias) { + src_bias_v.resize(graph_view.number_of_vertices(), handle.get_stream()); + + cugraph::detail::uniform_random_fill(handle.get_stream(), + src_bias_v.data(), + src_bias_v.size(), + weight_t{1}, + weight_t{10}, + rng_state); + + src_bias = raft::device_span{src_bias_v.data(), src_bias_v.size()}; + } + + if (negative_sampling_usecase.use_dst_bias) { + dst_bias_v.resize(graph_view.number_of_vertices(), handle.get_stream()); + + cugraph::detail::uniform_random_fill(handle.get_stream(), + dst_bias_v.data(), + dst_bias_v.size(), + weight_t{1}, + weight_t{10}, + rng_state); + + dst_bias = raft::device_span{dst_bias_v.data(), dst_bias_v.size()}; + } + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("Negative sampling"); + } + + auto&& [src_out, dst_out] = + cugraph::negative_sampling(handle, + rng_state, + graph_view, + num_samples, + src_bias, + dst_bias, + negative_sampling_usecase.remove_duplicates, + negative_sampling_usecase.remove_false_negatives, + negative_sampling_usecase.exact_number_of_samples, + do_expensive_check); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + if (negative_sampling_usecase.check_correctness) { + ASSERT_EQ(src_out.size(), dst_out.size()) << "Result size (src, dst) mismatch"; + + auto vertex_partition = cugraph::vertex_partition_device_view_t( + graph_view.local_vertex_partition_view()); + + size_t count = + thrust::count_if(handle.get_thrust_policy(), + src_out.begin(), + src_out.end(), + [vertex_partition] __device__(auto val) { + return !(vertex_partition.is_valid_vertex(val) && + vertex_partition.in_local_vertex_partition_range_nocheck(val)); + }); + + ASSERT_EQ(count, 0) << "Source vertices out of range > 0"; + + count = + thrust::count_if(handle.get_thrust_policy(), + dst_out.begin(), + dst_out.end(), + [vertex_partition] __device__(auto val) { + return !(vertex_partition.is_valid_vertex(val) && + vertex_partition.in_local_vertex_partition_range_nocheck(val)); + }); + ASSERT_EQ(count, 0) << "Dest vertices out of range > 0"; + + if (negative_sampling_usecase.remove_duplicates) { + count = thrust::count_if( + handle.get_thrust_policy(), + thrust::make_counting_iterator(1), + thrust::make_counting_iterator(src_out.size()), + [src = src_out.data(), dst = dst_out.data()] __device__(size_t index) { + return (src[index - 1] == src[index]) && (dst[index - 1] == dst[index]); + }); + ASSERT_EQ(count, 0) << "Remove duplicates specified, found duplicate entries"; + } + + if (negative_sampling_usecase.remove_false_negatives) { + rmm::device_uvector graph_src(0, handle.get_stream()); + rmm::device_uvector graph_dst(0, handle.get_stream()); + + std::tie(graph_src, graph_dst, std::ignore, std::ignore, std::ignore) = + cugraph::decompress_to_edgelist( + handle, graph_view, std::nullopt, std::nullopt, std::nullopt, std::nullopt); + + count = thrust::count_if( + handle.get_thrust_policy(), + thrust::make_zip_iterator(src_out.begin(), dst_out.begin()), + thrust::make_zip_iterator(src_out.end(), dst_out.end()), + [src = graph_src.data(), dst = graph_dst.data(), size = graph_dst.size()] __device__( + auto tuple) { + return thrust::binary_search(thrust::seq, + thrust::make_zip_iterator(src, dst), + thrust::make_zip_iterator(src, dst) + size, + tuple); + }); + + ASSERT_EQ(count, 0) << "Remove false negatives specified, found false negatives"; + } + + if (negative_sampling_usecase.exact_number_of_samples) { + ASSERT_EQ(src_out.size(), num_samples) << "Expected exact number of samples"; + } + + // TBD: How do we determine if we have properly reflected the biases? + } + } + + private: + graph_t graph; + std::optional> edge_weights{std::nullopt}; + std::optional> renumber_map_labels{std::nullopt}; +}; + +using Tests_Negative_Sampling_File_i32_i32_float = + Tests_Negative_Sampling; + +using Tests_Negative_Sampling_File_i32_i64_float = + Tests_Negative_Sampling; + +using Tests_Negative_Sampling_File_i64_i64_float = + Tests_Negative_Sampling; + +using Tests_Negative_Sampling_Rmat_i32_i32_float = + Tests_Negative_Sampling; + +using Tests_Negative_Sampling_Rmat_i32_i64_float = + Tests_Negative_Sampling; + +using Tests_Negative_Sampling_Rmat_i64_i64_float = + Tests_Negative_Sampling; + +template +void run_all_tests(CurrentTest* current_test) +{ + raft::random::RngState rng_state{0}; + + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, false, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, false, false, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, true, false, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, true, false, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, true, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, false, true, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, true, true, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, true, true, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, false, true, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, false, false, true, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, true, false, true, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, true, false, true, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, true, true, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, false, true, true, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, true, true, true, false, true}); + current_test->run_current_test(rng_state, + Negative_Sampling_Usecase{2, true, true, true, true, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, false, false, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, false, false, false, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, true, false, false, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, true, false, false, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, true, false, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, false, true, false, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, true, true, false, true, true}); + current_test->run_current_test(rng_state, + Negative_Sampling_Usecase{2, true, true, true, false, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, false, true, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, false, false, true, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, true, false, true, true, true}); + current_test->run_current_test(rng_state, + Negative_Sampling_Usecase{2, true, true, false, true, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, true, true, true, true}); + current_test->run_current_test(rng_state, + Negative_Sampling_Usecase{2, true, false, true, true, true, true}); + current_test->run_current_test(rng_state, + Negative_Sampling_Usecase{2, false, true, true, true, true, true}); + current_test->run_current_test(rng_state, + Negative_Sampling_Usecase{2, true, true, true, true, true, true}); +} + +TEST_P(Tests_Negative_Sampling_File_i32_i32_float, CheckInt32Int32Float) +{ + load_graph(override_File_Usecase_with_cmd_line_arguments(GetParam())); + run_all_tests(this); +} + +TEST_P(Tests_Negative_Sampling_File_i32_i64_float, CheckInt32Int64Float) +{ + load_graph(override_File_Usecase_with_cmd_line_arguments(GetParam())); + run_all_tests(this); +} + +TEST_P(Tests_Negative_Sampling_File_i64_i64_float, CheckInt64Int64Float) +{ + load_graph(override_File_Usecase_with_cmd_line_arguments(GetParam())); + run_all_tests(this); +} + +TEST_P(Tests_Negative_Sampling_Rmat_i32_i32_float, CheckInt32Int32Float) +{ + load_graph(override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); + run_all_tests(this); +} + +TEST_P(Tests_Negative_Sampling_Rmat_i32_i64_float, CheckInt32Int64Float) +{ + load_graph(override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); + run_all_tests(this); +} + +TEST_P(Tests_Negative_Sampling_Rmat_i64_i64_float, CheckInt64Int64Float) +{ + load_graph(override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); + run_all_tests(this); +} + +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_Negative_Sampling_File_i32_i32_float, + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"))); + +INSTANTIATE_TEST_SUITE_P( + file_large_test, + Tests_Negative_Sampling_File_i32_i32_float, + ::testing::Values(cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), + cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), + cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx"))); + +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_Negative_Sampling_File_i32_i64_float, + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"))); + +INSTANTIATE_TEST_SUITE_P( + file_large_test, + Tests_Negative_Sampling_File_i32_i64_float, + ::testing::Values(cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), + cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), + cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx"))); + +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_Negative_Sampling_File_i64_i64_float, + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"))); + +INSTANTIATE_TEST_SUITE_P( + file_large_test, + Tests_Negative_Sampling_File_i64_i64_float, + ::testing::Values(cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), + cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), + cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx"))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_test, + Tests_Negative_Sampling_Rmat_i32_i32_float, + ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false, 0))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_test, + Tests_Negative_Sampling_Rmat_i32_i64_float, + ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false, 0))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_test, + Tests_Negative_Sampling_Rmat_i64_i64_float, + ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false, 0))); + +INSTANTIATE_TEST_SUITE_P( + rmat_benchmark_test, /* note that scale & edge factor can be overridden in benchmarking (with + --gtest_filter to select only the rmat_benchmark_test with a specific + vertex & edge type combination) by command line arguments and do not + include more than one Rmat_Usecase that differ only in scale or edge + factor (to avoid running same benchmarks more than once) */ + Tests_Negative_Sampling_Rmat_i64_i64_float, + ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false, 0))); + +CUGRAPH_TEST_PROGRAM_MAIN() From 0ce07129f4dab6ba9e7d605c892964bfd10528df Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Wed, 17 Jul 2024 13:27:13 -0700 Subject: [PATCH 05/18] add MG tests --- cpp/CMakeLists.txt | 3 + .../sampling/negative_sampling_mg_v32_e32.cu | 48 ++ .../sampling/negative_sampling_mg_v32_e64.cu | 48 ++ .../sampling/negative_sampling_mg_v64_e64.cu | 48 ++ cpp/tests/CMakeLists.txt | 7 +- cpp/tests/sampling/mg_negative_sampling.cu | 411 ++++++++++++++++++ 6 files changed, 564 insertions(+), 1 deletion(-) create mode 100644 cpp/src/sampling/negative_sampling_mg_v32_e32.cu create mode 100644 cpp/src/sampling/negative_sampling_mg_v32_e64.cu create mode 100644 cpp/src/sampling/negative_sampling_mg_v64_e64.cu create mode 100644 cpp/tests/sampling/mg_negative_sampling.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 9503e1cf1ad..bbbc1d69093 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -334,6 +334,9 @@ set(CUGRAPH_SOURCES src/sampling/negative_sampling_sg_v32_e64.cu src/sampling/negative_sampling_sg_v32_e32.cu src/sampling/negative_sampling_sg_v64_e64.cu + src/sampling/negative_sampling_mg_v32_e64.cu + src/sampling/negative_sampling_mg_v32_e32.cu + src/sampling/negative_sampling_mg_v64_e64.cu src/sampling/renumber_sampled_edgelist_sg_v64_e64.cu src/sampling/renumber_sampled_edgelist_sg_v32_e32.cu src/sampling/sampling_post_processing_sg_v64_e64.cu diff --git a/cpp/src/sampling/negative_sampling_mg_v32_e32.cu b/cpp/src/sampling/negative_sampling_mg_v32_e32.cu new file mode 100644 index 00000000000..fe00bb16747 --- /dev/null +++ b/cpp/src/sampling/negative_sampling_mg_v32_e32.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "negative_sampling_impl.cuh" + +#include +#include + +namespace cugraph { + +template std::tuple, rmm::device_uvector> negative_sampling( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + size_t num_samples, + std::optional> src_bias, + std::optional> dst_bias, + bool remove_duplicates, + bool remove_false_negatives, + bool exact_number_of_samples, + bool do_expensive_check); + +template std::tuple, rmm::device_uvector> negative_sampling( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + size_t num_samples, + std::optional> src_bias, + std::optional> dst_bias, + bool remove_duplicates, + bool remove_false_negatives, + bool exact_number_of_samples, + bool do_expensive_check); + +} // namespace cugraph diff --git a/cpp/src/sampling/negative_sampling_mg_v32_e64.cu b/cpp/src/sampling/negative_sampling_mg_v32_e64.cu new file mode 100644 index 00000000000..403257103f8 --- /dev/null +++ b/cpp/src/sampling/negative_sampling_mg_v32_e64.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "negative_sampling_impl.cuh" + +#include +#include + +namespace cugraph { + +template std::tuple, rmm::device_uvector> negative_sampling( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + size_t num_samples, + std::optional> src_bias, + std::optional> dst_bias, + bool remove_duplicates, + bool remove_false_negatives, + bool exact_number_of_samples, + bool do_expensive_check); + +template std::tuple, rmm::device_uvector> negative_sampling( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + size_t num_samples, + std::optional> src_bias, + std::optional> dst_bias, + bool remove_duplicates, + bool remove_false_negatives, + bool exact_number_of_samples, + bool do_expensive_check); + +} // namespace cugraph diff --git a/cpp/src/sampling/negative_sampling_mg_v64_e64.cu b/cpp/src/sampling/negative_sampling_mg_v64_e64.cu new file mode 100644 index 00000000000..b3941b9db13 --- /dev/null +++ b/cpp/src/sampling/negative_sampling_mg_v64_e64.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "negative_sampling_impl.cuh" + +#include +#include + +namespace cugraph { + +template std::tuple, rmm::device_uvector> negative_sampling( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + size_t num_samples, + std::optional> src_bias, + std::optional> dst_bias, + bool remove_duplicates, + bool remove_false_negatives, + bool exact_number_of_samples, + bool do_expensive_check); + +template std::tuple, rmm::device_uvector> negative_sampling( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + size_t num_samples, + std::optional> src_bias, + std::optional> dst_bias, + bool remove_duplicates, + bool remove_false_negatives, + bool exact_number_of_samples, + bool do_expensive_check); + +} // namespace cugraph diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 4bcb2769b65..c62c794e7e3 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -492,7 +492,7 @@ ConfigureTest(BIASED_NEIGHBOR_SAMPLING_TEST sampling/biased_neighbor_sampling.cp ConfigureTest(SAMPLING_POST_PROCESSING_TEST sampling/sampling_post_processing_test.cu) ################################################################################################### -# - NEGATVIE SAMPLING tests -------------------------------------------------------------------- +# - NEGATIVE SAMPLING tests -------------------------------------------------------------------- ConfigureTest(NEGATIVE_SAMPLING_TEST sampling/negative_sampling.cu) ################################################################################################### @@ -745,6 +745,11 @@ if(BUILD_CUGRAPH_MG_TESTS) # - MG BIASED NBR SAMPLING tests -------------------------------------------------------------- ConfigureTestMG(MG_BIASED_NEIGHBOR_SAMPLING_TEST sampling/mg_biased_neighbor_sampling.cpp) + ################################################################################################### + # - NEGATIVE SAMPLING tests -------------------------------------------------------------------- + ConfigureTestMG(MG_NEGATIVE_SAMPLING_TEST sampling/mg_negative_sampling.cu) + + ############################################################################################### # - MG RANDOM_WALKS tests --------------------------------------------------------------------- ConfigureTestMG(MG_RANDOM_WALKS_TEST sampling/mg_random_walks_test.cpp) diff --git a/cpp/tests/sampling/mg_negative_sampling.cu b/cpp/tests/sampling/mg_negative_sampling.cu new file mode 100644 index 00000000000..e180594f87b --- /dev/null +++ b/cpp/tests/sampling/mg_negative_sampling.cu @@ -0,0 +1,411 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "detail/graph_partition_utils.cuh" +#include "utilities/base_fixture.hpp" +#include "utilities/conversion_utilities.hpp" +#include "utilities/property_generator_utilities.hpp" + +#include +#include +#include + +#include + +struct Negative_Sampling_Usecase { + float sample_multiplier{2}; + bool use_src_bias{false}; + bool use_dst_bias{false}; + bool remove_duplicates{false}; + bool remove_false_negatives{false}; + bool exact_number_of_samples{false}; + bool check_correctness{true}; +}; + +template +class Tests_MGNegative_Sampling : public ::testing::TestWithParam { + public: + using graph_t = cugraph::graph_t; + using graph_view_t = cugraph::graph_view_t; + + Tests_MGNegative_Sampling() : graph_(*handle_) {} + + static void SetUpTestCase() { handle_ = cugraph::test::initialize_mg_handle(); } + + static void TearDownTestCase() { handle_.reset(); } + + template + void load_graph(input_t const& param) + { + HighResTimer hr_timer{}; + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("Construct graph"); + } + + std::tie(graph_, edge_weights_, renumber_map_labels_) = + cugraph::test::construct_graph( + *handle_, param, true, true); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + } + + virtual void SetUp() {} + virtual void TearDown() {} + + void run_current_test(raft::random::RngState& rng_state, + Negative_Sampling_Usecase const& negative_sampling_usecase) + { + constexpr bool do_expensive_check{false}; + + HighResTimer hr_timer{}; + + auto graph_view = graph_.view(); + + size_t num_samples = graph_view.number_of_edges() * negative_sampling_usecase.sample_multiplier; + + rmm::device_uvector src_bias_v(0, handle_->get_stream()); + rmm::device_uvector dst_bias_v(0, handle_->get_stream()); + + std::optional> src_bias{std::nullopt}; + std::optional> dst_bias{std::nullopt}; + + if (negative_sampling_usecase.use_src_bias) { + src_bias_v.resize(graph_view.number_of_vertices(), handle_->get_stream()); + + cugraph::detail::uniform_random_fill(handle_->get_stream(), + src_bias_v.data(), + src_bias_v.size(), + weight_t{1}, + weight_t{10}, + rng_state); + + src_bias = raft::device_span{src_bias_v.data(), src_bias_v.size()}; + } + + if (negative_sampling_usecase.use_dst_bias) { + dst_bias_v.resize(graph_view.number_of_vertices(), handle_->get_stream()); + + cugraph::detail::uniform_random_fill(handle_->get_stream(), + dst_bias_v.data(), + dst_bias_v.size(), + weight_t{1}, + weight_t{10}, + rng_state); + + dst_bias = raft::device_span{dst_bias_v.data(), dst_bias_v.size()}; + } + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("Negative sampling"); + } + + auto&& [src_out, dst_out] = + cugraph::negative_sampling(*handle_, + rng_state, + graph_view, + num_samples, + src_bias, + dst_bias, + negative_sampling_usecase.remove_duplicates, + negative_sampling_usecase.remove_false_negatives, + negative_sampling_usecase.exact_number_of_samples, + do_expensive_check); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + if (negative_sampling_usecase.check_correctness) { + ASSERT_EQ(src_out.size(), dst_out.size()) << "Result size (src, dst) mismatch"; + + auto h_vertex_partition_range_lasts = graph_view.vertex_partition_range_lasts(); + rmm::device_uvector d_vertex_partition_range_lasts( + h_vertex_partition_range_lasts.size(), handle_->get_stream()); + raft::update_device(d_vertex_partition_range_lasts.data(), + h_vertex_partition_range_lasts.data(), + h_vertex_partition_range_lasts.size(), + handle_->get_stream()); + + size_t error_count = thrust::count_if( + handle_->get_thrust_policy(), + thrust::make_zip_iterator(src_out.begin(), dst_out.begin()), + thrust::make_zip_iterator(src_out.end(), dst_out.end()), + [comm_rank = handle_->get_comms().get_rank(), + gpu_id_key_func = cugraph::detail::compute_gpu_id_from_int_edge_endpoints_t{ + raft::device_span{d_vertex_partition_range_lasts.data(), + d_vertex_partition_range_lasts.size()}, + handle_->get_comms().get_size(), + handle_->get_subcomm(cugraph::partition_manager::major_comm_name()).get_size(), + handle_->get_subcomm(cugraph::partition_manager::minor_comm_name()) + .get_size()}] __device__(auto e) { + return (gpu_id_key_func(thrust::get<0>(e), thrust::get<1>(e)) != comm_rank); + }); + + ASSERT_EQ(error_count, 0) << "generate edges out of range > 0"; + + if (negative_sampling_usecase.remove_duplicates) { + error_count = thrust::count_if( + handle_->get_thrust_policy(), + thrust::make_counting_iterator(1), + thrust::make_counting_iterator(src_out.size()), + [src = src_out.data(), dst = dst_out.data()] __device__(size_t index) { + return (src[index - 1] == src[index]) && (dst[index - 1] == dst[index]); + }); + ASSERT_EQ(error_count, 0) << "Remove duplicates specified, found duplicate entries"; + } + + if (negative_sampling_usecase.remove_false_negatives) { + rmm::device_uvector graph_src(0, handle_->get_stream()); + rmm::device_uvector graph_dst(0, handle_->get_stream()); + + std::tie(graph_src, graph_dst, std::ignore, std::ignore, std::ignore) = + cugraph::decompress_to_edgelist( + *handle_, graph_view, std::nullopt, std::nullopt, std::nullopt, std::nullopt); + + error_count = thrust::count_if( + handle_->get_thrust_policy(), + thrust::make_zip_iterator(src_out.begin(), dst_out.begin()), + thrust::make_zip_iterator(src_out.end(), dst_out.end()), + [src = graph_src.data(), dst = graph_dst.data(), size = graph_dst.size()] __device__( + auto tuple) { + return thrust::binary_search(thrust::seq, + thrust::make_zip_iterator(src, dst), + thrust::make_zip_iterator(src, dst) + size, + tuple); + }); + + ASSERT_EQ(error_count, 0) << "Remove false negatives specified, found false negatives"; + } + + if (negative_sampling_usecase.exact_number_of_samples) { + size_t sz = cugraph::host_scalar_allreduce( + handle_->get_comms(), src_out.size(), raft::comms::op_t::SUM, handle_->get_stream()); + ASSERT_EQ(sz, num_samples) << "Expected exact number of samples"; + } + + // TBD: How do we determine if we have properly reflected the biases? + } + } + + public: + static std::unique_ptr handle_; + + private: + graph_t graph_; + std::optional> edge_weights_{std::nullopt}; + std::optional> renumber_map_labels_{std::nullopt}; +}; + +template +std::unique_ptr + Tests_MGNegative_Sampling::handle_ = nullptr; + +using Tests_MGNegative_Sampling_File_i32_i32_float = + Tests_MGNegative_Sampling; + +using Tests_MGNegative_Sampling_File_i32_i64_float = + Tests_MGNegative_Sampling; + +using Tests_MGNegative_Sampling_File_i64_i64_float = + Tests_MGNegative_Sampling; + +using Tests_MGNegative_Sampling_Rmat_i32_i32_float = + Tests_MGNegative_Sampling; + +using Tests_MGNegative_Sampling_Rmat_i32_i64_float = + Tests_MGNegative_Sampling; + +using Tests_MGNegative_Sampling_Rmat_i64_i64_float = + Tests_MGNegative_Sampling; + +template +void run_all_tests(CurrentTest* current_test) +{ + raft::random::RngState rng_state{ + static_cast(current_test->handle_->get_comms().get_rank())}; + + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, false, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, false, false, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, true, false, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, true, false, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, true, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, false, true, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, true, true, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, true, true, false, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, false, true, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, false, false, true, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, true, false, true, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, true, false, true, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, true, true, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, false, true, true, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, true, true, true, false, true}); + current_test->run_current_test(rng_state, + Negative_Sampling_Usecase{2, true, true, true, true, false, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, false, false, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, false, false, false, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, true, false, false, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, true, false, false, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, true, false, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, false, true, false, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, true, true, false, true, true}); + current_test->run_current_test(rng_state, + Negative_Sampling_Usecase{2, true, true, true, false, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, false, true, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, true, false, false, true, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, true, false, true, true, true}); + current_test->run_current_test(rng_state, + Negative_Sampling_Usecase{2, true, true, false, true, true, true}); + current_test->run_current_test( + rng_state, Negative_Sampling_Usecase{2, false, false, true, true, true, true}); + current_test->run_current_test(rng_state, + Negative_Sampling_Usecase{2, true, false, true, true, true, true}); + current_test->run_current_test(rng_state, + Negative_Sampling_Usecase{2, false, true, true, true, true, true}); + current_test->run_current_test(rng_state, + Negative_Sampling_Usecase{2, true, true, true, true, true, true}); +} + +TEST_P(Tests_MGNegative_Sampling_File_i32_i32_float, CheckInt32Int32Float) +{ + load_graph(override_File_Usecase_with_cmd_line_arguments(GetParam())); + run_all_tests(this); +} + +TEST_P(Tests_MGNegative_Sampling_File_i32_i64_float, CheckInt32Int64Float) +{ + load_graph(override_File_Usecase_with_cmd_line_arguments(GetParam())); + run_all_tests(this); +} + +TEST_P(Tests_MGNegative_Sampling_File_i64_i64_float, CheckInt64Int64Float) +{ + load_graph(override_File_Usecase_with_cmd_line_arguments(GetParam())); + run_all_tests(this); +} + +TEST_P(Tests_MGNegative_Sampling_Rmat_i32_i32_float, CheckInt32Int32Float) +{ + load_graph(override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); + run_all_tests(this); +} + +TEST_P(Tests_MGNegative_Sampling_Rmat_i32_i64_float, CheckInt32Int64Float) +{ + load_graph(override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); + run_all_tests(this); +} + +TEST_P(Tests_MGNegative_Sampling_Rmat_i64_i64_float, CheckInt64Int64Float) +{ + load_graph(override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); + run_all_tests(this); +} + +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_MGNegative_Sampling_File_i32_i32_float, + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"))); + +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_MGNegative_Sampling_File_i32_i64_float, + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"))); + +INSTANTIATE_TEST_SUITE_P( + file_test, + Tests_MGNegative_Sampling_File_i64_i64_float, + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"))); + +INSTANTIATE_TEST_SUITE_P( + file_large_test, + Tests_MGNegative_Sampling_File_i32_i32_float, + ::testing::Values(cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), + cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), + cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx"))); + +INSTANTIATE_TEST_SUITE_P( + file_large_test, + Tests_MGNegative_Sampling_File_i32_i64_float, + ::testing::Values(cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), + cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), + cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx"))); + +INSTANTIATE_TEST_SUITE_P( + file_large_test, + Tests_MGNegative_Sampling_File_i64_i64_float, + ::testing::Values(cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), + cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), + cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx"))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_test, + Tests_MGNegative_Sampling_Rmat_i32_i32_float, + ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false, 0))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_test, + Tests_MGNegative_Sampling_Rmat_i32_i64_float, + ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false, 0))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_test, + Tests_MGNegative_Sampling_Rmat_i64_i64_float, + ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false, 0))); + +INSTANTIATE_TEST_SUITE_P( + rmat_benchmark_test, /* note that scale & edge factor can be overridden in benchmarking (with + --gtest_filter to select only the rmat_benchmark_test with a specific + vertex & edge type combination) by command line arguments and do not + include more than one Rmat_Usecase that differ only in scale or edge + factor (to avoid running same benchmarks more than once) */ + Tests_MGNegative_Sampling_Rmat_i64_i64_float, + ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, false, false, 0))); + +CUGRAPH_MG_TEST_PROGRAM_MAIN() From a31f5a9aa1bf4996e7f6fdadda602a918aed8961 Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Tue, 23 Jul 2024 13:38:02 -0700 Subject: [PATCH 06/18] Add C API and PLC for negative sampling --- cpp/CMakeLists.txt | 1 + cpp/include/cugraph_c/coo.h | 115 +++++++ cpp/include/cugraph_c/graph_generators.h | 86 +---- cpp/include/cugraph_c/sampling_algorithms.h | 51 +++ cpp/src/c_api/coo.hpp | 37 +++ cpp/src/c_api/graph_generators.cpp | 19 +- cpp/src/c_api/negative_sampling.cpp | 226 ++++++++++++++ cpp/tests/CMakeLists.txt | 2 + cpp/tests/c_api/mg_negative_sampling_test.c | 295 ++++++++++++++++++ cpp/tests/c_api/negative_sampling_test.c | 284 +++++++++++++++++ .../pylibcugraph/_cugraph_c/coo.pxd | 71 +++++ .../_cugraph_c/graph_generators.pxd | 58 +--- .../_cugraph_c/sampling_algorithms.pxd | 22 ++ .../pylibcugraph/generate_rmat_edgelist.pyx | 12 +- .../pylibcugraph/generate_rmat_edgelists.pyx | 14 +- 15 files changed, 1127 insertions(+), 166 deletions(-) create mode 100644 cpp/include/cugraph_c/coo.h create mode 100644 cpp/src/c_api/coo.hpp create mode 100644 cpp/src/c_api/negative_sampling.cpp create mode 100644 cpp/tests/c_api/mg_negative_sampling_test.c create mode 100644 cpp/tests/c_api/negative_sampling_test.c create mode 100644 python/pylibcugraph/pylibcugraph/_cugraph_c/coo.pxd diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 65fa451e6d0..68537900ca2 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -659,6 +659,7 @@ add_library(cugraph_c src/c_api/louvain.cpp src/c_api/triangle_count.cpp src/c_api/uniform_neighbor_sampling.cpp + src/c_api/negative_sampling.cpp src/c_api/labeling_result.cpp src/c_api/weakly_connected_components.cpp src/c_api/strongly_connected_components.cpp diff --git a/cpp/include/cugraph_c/coo.h b/cpp/include/cugraph_c/coo.h new file mode 100644 index 00000000000..ef746c6ed6a --- /dev/null +++ b/cpp/include/cugraph_c/coo.h @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Opaque COO definition + */ +typedef struct { + int32_t align_; +} cugraph_coo_t; + +/** + * @brief Opaque COO list definition + */ +typedef struct { + int32_t align_; +} cugraph_coo_list_t; + +/** + * @brief Get the source vertex ids + * + * @param [in] coo Opaque pointer to COO + * @return type erased array view of source vertex ids + */ +cugraph_type_erased_device_array_view_t* cugraph_coo_get_sources(cugraph_coo_t* coo); + +/** + * @brief Get the destination vertex ids + * + * @param [in] coo Opaque pointer to COO + * @return type erased array view of destination vertex ids + */ +cugraph_type_erased_device_array_view_t* cugraph_coo_get_destinations(cugraph_coo_t* coo); + +/** + * @brief Get the edge weights + * + * @param [in] coo Opaque pointer to COO + * @return type erased array view of edge weights, NULL if no edge weights in COO + */ +cugraph_type_erased_device_array_view_t* cugraph_coo_get_edge_weights(cugraph_coo_t* coo); + +/** + * @brief Get the edge id + * + * @param [in] coo Opaque pointer to COO + * @return type erased array view of edge id, NULL if no edge ids in COO + */ +cugraph_type_erased_device_array_view_t* cugraph_coo_get_edge_id(cugraph_coo_t* coo); + +/** + * @brief Get the edge type + * + * @param [in] coo Opaque pointer to COO + * @return type erased array view of edge type, NULL if no edge types in COO + */ +cugraph_type_erased_device_array_view_t* cugraph_coo_get_edge_type(cugraph_coo_t* coo); + +/** + * @brief Get the number of coo object in the list + * + * @param [in] coo_list Opaque pointer to COO list + * @return number of elements + */ +size_t cugraph_coo_list_size(const cugraph_coo_list_t* coo_list); + +/** + * @brief Get a COO from the list + * + * @param [in] coo_list Opaque pointer to COO list + * @param [in] index Index of desired COO from list + * @return a cugraph_coo_t* object from the list + */ +cugraph_coo_t* cugraph_coo_list_element(cugraph_coo_list_t* coo_list, size_t index); + +/** + * @brief Free coo object + * + * @param [in] coo Opaque pointer to COO + */ +void cugraph_coo_free(cugraph_coo_t* coo); + +/** + * @brief Free coo list + * + * @param [in] coo_list Opaque pointer to list of COO objects + */ +void cugraph_coo_list_free(cugraph_coo_list_t* coo_list); + +#ifdef __cplusplus +} +#endif diff --git a/cpp/include/cugraph_c/graph_generators.h b/cpp/include/cugraph_c/graph_generators.h index 272131d2aab..553be530e95 100644 --- a/cpp/include/cugraph_c/graph_generators.h +++ b/cpp/include/cugraph_c/graph_generators.h @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -27,91 +28,6 @@ extern "C" { typedef enum { POWER_LAW = 0, UNIFORM } cugraph_generator_distribution_t; -/** - * @brief Opaque COO definition - */ -typedef struct { - int32_t align_; -} cugraph_coo_t; - -/** - * @brief Opaque COO list definition - */ -typedef struct { - int32_t align_; -} cugraph_coo_list_t; - -/** - * @brief Get the source vertex ids - * - * @param [in] coo Opaque pointer to COO - * @return type erased array view of source vertex ids - */ -cugraph_type_erased_device_array_view_t* cugraph_coo_get_sources(cugraph_coo_t* coo); - -/** - * @brief Get the destination vertex ids - * - * @param [in] coo Opaque pointer to COO - * @return type erased array view of destination vertex ids - */ -cugraph_type_erased_device_array_view_t* cugraph_coo_get_destinations(cugraph_coo_t* coo); - -/** - * @brief Get the edge weights - * - * @param [in] coo Opaque pointer to COO - * @return type erased array view of edge weights, NULL if no edge weights in COO - */ -cugraph_type_erased_device_array_view_t* cugraph_coo_get_edge_weights(cugraph_coo_t* coo); - -/** - * @brief Get the edge id - * - * @param [in] coo Opaque pointer to COO - * @return type erased array view of edge id, NULL if no edge ids in COO - */ -cugraph_type_erased_device_array_view_t* cugraph_coo_get_edge_id(cugraph_coo_t* coo); - -/** - * @brief Get the edge type - * - * @param [in] coo Opaque pointer to COO - * @return type erased array view of edge type, NULL if no edge types in COO - */ -cugraph_type_erased_device_array_view_t* cugraph_coo_get_edge_type(cugraph_coo_t* coo); - -/** - * @brief Get the number of coo object in the list - * - * @param [in] coo_list Opaque pointer to COO list - * @return number of elements - */ -size_t cugraph_coo_list_size(const cugraph_coo_list_t* coo_list); - -/** - * @brief Get a COO from the list - * - * @param [in] coo_list Opaque pointer to COO list - * @param [in] index Index of desired COO from list - * @return a cugraph_coo_t* object from the list - */ -cugraph_coo_t* cugraph_coo_list_element(cugraph_coo_list_t* coo_list, size_t index); - -/** - * @brief Free coo object - * - * @param [in] coo Opaque pointer to COO - */ -void cugraph_coo_free(cugraph_coo_t* coo); - -/** - * @brief Free coo list - * - * @param [in] coo_list Opaque pointer to list of COO objects - */ -void cugraph_coo_list_free(cugraph_coo_list_t* coo_list); - /** * @brief Generate RMAT edge list * diff --git a/cpp/include/cugraph_c/sampling_algorithms.h b/cpp/include/cugraph_c/sampling_algorithms.h index a7490ad2c63..4dfc146ee72 100644 --- a/cpp/include/cugraph_c/sampling_algorithms.h +++ b/cpp/include/cugraph_c/sampling_algorithms.h @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -614,6 +615,56 @@ cugraph_error_code_t cugraph_select_random_vertices(const cugraph_resource_handl cugraph_type_erased_device_array_t** vertices, cugraph_error_t** error); +/** + * @ingroup samplingC + * @brief Perform negative sampling + * + * Negative sampling generates a COO structure defining edges according to the specified parameters + * + * @param [in] handle Handle for accessing resources + * @param [in,out] rng_state State of the random number generator, updated with each + * call + * @param [in] graph Pointer to graph + * @param [in] num_samples Number of negative samples to generate + * @param [in] vertices Vertex ids for the source biases. If @p src_bias and + * @p dst_bias are not specified this is ignored. If + * @p vertices is specified then vertices[i] is the vertex + * id of src_bias[i] and dst_bias[i]. If @p vertices is not specified then i is the vertex id if + * src_bias[i] and dst_bias[i] + * @param [in] src_bias Bias for selecting source vertices. If NULL, do uniform + * sampling, if provided probability of vertex i will be + * src_bias[i] / (sum of all source biases) + * @param [in] dst_bias Bias for selecting destination vertices. If NULL, do + * uniform sampling, if provided probability of vertex i + * will be dst_bias[i] / (sum of all destination biases) + * @param [in] remove_duplicates If true, remove duplicates from sampled edges + * @param [in] remove_false_negatives If true, remove sampled edges that actually exist in the + * graph + * @param [in] exact_number_of_samples If true, result should contain exactly @p num_samples. If + * false the code will generate @p num_samples and then do + * any filtering as specified + * @param [in] do_expensive_check A flag to run expensive checks for input arguments (if + * set to true) + * @param [out] result Opaque pointer to generated coo list + * @param [out] error Pointer to an error object storing details of any error. + * Will be populated if error code is not CUGRAPH_SUCCESS + * @return error code + */ +cugraph_error_code_t cugraph_negative_sampling( + const cugraph_resource_handle_t* handle, + cugraph_rng_state_t* rng_state, + cugraph_graph_t* graph, + size_t num_samples, + const cugraph_type_erased_device_array_view_t* vertices, + const cugraph_type_erased_device_array_view_t* src_bias, + const cugraph_type_erased_device_array_view_t* dst_bias, + bool_t remove_duplicates, + bool_t remove_false_negatives, + bool_t exact_number_of_samples, + bool_t do_expensive_check, + cugraph_coo_t** result, + cugraph_error_t** error); + #ifdef __cplusplus } #endif diff --git a/cpp/src/c_api/coo.hpp b/cpp/src/c_api/coo.hpp new file mode 100644 index 00000000000..a83a3af375a --- /dev/null +++ b/cpp/src/c_api/coo.hpp @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "c_api/array.hpp" + +#include + +namespace cugraph { +namespace c_api { + +struct cugraph_coo_t { + std::unique_ptr src_{}; + std::unique_ptr dst_{}; + std::unique_ptr wgt_{}; + std::unique_ptr id_{}; + std::unique_ptr type_{}; +}; + +struct cugraph_coo_list_t { + std::vector> list_; +}; + +} // namespace c_api +} // namespace cugraph diff --git a/cpp/src/c_api/graph_generators.cpp b/cpp/src/c_api/graph_generators.cpp index ef478e57098..7601f1508f9 100644 --- a/cpp/src/c_api/graph_generators.cpp +++ b/cpp/src/c_api/graph_generators.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ #include "c_api/array.hpp" +#include "c_api/coo.hpp" #include "c_api/error.hpp" #include "c_api/random.hpp" #include "c_api/resource_handle.hpp" @@ -26,24 +27,6 @@ #include -namespace cugraph { -namespace c_api { - -struct cugraph_coo_t { - std::unique_ptr src_{}; - std::unique_ptr dst_{}; - std::unique_ptr wgt_{}; - std::unique_ptr id_{}; - std::unique_ptr type_{}; -}; - -struct cugraph_coo_list_t { - std::vector> list_; -}; - -} // namespace c_api -} // namespace cugraph - namespace { template diff --git a/cpp/src/c_api/negative_sampling.cpp b/cpp/src/c_api/negative_sampling.cpp new file mode 100644 index 00000000000..1996755e536 --- /dev/null +++ b/cpp/src/c_api/negative_sampling.cpp @@ -0,0 +1,226 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "c_api/abstract_functor.hpp" +#include "c_api/coo.hpp" +#include "c_api/graph.hpp" +#include "c_api/random.hpp" +#include "c_api/resource_handle.hpp" +#include "c_api/utils.hpp" + +#include + +#include +#include +#include +#include + +#include + +namespace { + +struct negative_sampling_functor : public cugraph::c_api::abstract_functor { + raft::handle_t const& handle_; + cugraph::c_api::cugraph_rng_state_t* rng_state_{nullptr}; + cugraph::c_api::cugraph_graph_t* graph_{nullptr}; + size_t num_samples_; + cugraph::c_api::cugraph_type_erased_device_array_view_t const* vertices_{nullptr}; + cugraph::c_api::cugraph_type_erased_device_array_view_t const* src_bias_{nullptr}; + cugraph::c_api::cugraph_type_erased_device_array_view_t const* dst_bias_{nullptr}; + bool remove_duplicates_{false}; + bool remove_false_negatives_{false}; + bool exact_number_of_samples_{false}; + bool do_expensive_check_{false}; + cugraph::c_api::cugraph_coo_t* result_{nullptr}; + + negative_sampling_functor(const cugraph_resource_handle_t* handle, + cugraph_rng_state_t* rng_state, + cugraph_graph_t* graph, + size_t num_samples, + const cugraph_type_erased_device_array_view_t* vertices, + const cugraph_type_erased_device_array_view_t* src_bias, + const cugraph_type_erased_device_array_view_t* dst_bias, + bool_t remove_duplicates, + bool_t remove_false_negatives, + bool_t exact_number_of_samples, + bool_t do_expensive_check) + : abstract_functor(), + handle_(*reinterpret_cast(handle)->handle_), + rng_state_(reinterpret_cast(rng_state)), + graph_(reinterpret_cast(graph)), + num_samples_(num_samples), + vertices_( + reinterpret_cast(vertices)), + src_bias_( + reinterpret_cast(src_bias)), + dst_bias_( + reinterpret_cast(dst_bias)), + remove_duplicates_(remove_duplicates), + remove_false_negatives_(remove_false_negatives), + exact_number_of_samples_(exact_number_of_samples), + do_expensive_check_(do_expensive_check) + { + } + + template + void operator()() + { + // FIXME: Think about how to handle SG vice MG + if constexpr (!cugraph::is_candidate::value) { + unsupported(); + } else { + // uniform_nbr_sample expects store_transposed == false + if constexpr (store_transposed) { + error_code_ = cugraph::c_api:: + transpose_storage( + handle_, graph_, error_.get()); + if (error_code_ != CUGRAPH_SUCCESS) return; + } + + auto graph = + reinterpret_cast*>(graph_->graph_); + + auto graph_view = graph->view(); + + auto number_map = reinterpret_cast*>(graph_->number_map_); + + rmm::device_uvector vertices(0, handle_.get_stream()); + rmm::device_uvector src_bias(0, handle_.get_stream()); + rmm::device_uvector dst_bias(0, handle_.get_stream()); + + // TODO: What is required here? + + if (src_bias_ != nullptr) { + vertices.resize(vertices_->size_, handle_.get_stream()); + src_bias.resize(src_bias_->size_, handle_.get_stream()); + + raft::copy( + vertices.data(), vertices_->as_type(), vertices.size(), handle_.get_stream()); + raft::copy( + src_bias.data(), src_bias_->as_type(), src_bias.size(), handle_.get_stream()); + + src_bias = cugraph::detail:: + collect_local_vertex_values_from_ext_vertex_value_pairs( + handle_, + std::move(vertices), + std::move(src_bias), + *number_map, + graph_view.local_vertex_partition_range_first(), + graph_view.local_vertex_partition_range_last(), + weight_t{0}, + do_expensive_check_); + } + + if (dst_bias_ != nullptr) { + vertices.resize(vertices_->size_, handle_.get_stream()); + dst_bias.resize(dst_bias_->size_, handle_.get_stream()); + + raft::copy( + vertices.data(), vertices_->as_type(), vertices.size(), handle_.get_stream()); + raft::copy( + dst_bias.data(), dst_bias_->as_type(), dst_bias.size(), handle_.get_stream()); + + dst_bias = cugraph::detail:: + collect_local_vertex_values_from_ext_vertex_value_pairs( + handle_, + std::move(vertices), + std::move(dst_bias), + *number_map, + graph_view.local_vertex_partition_range_first(), + graph_view.local_vertex_partition_range_last(), + weight_t{0}, + do_expensive_check_); + } + + auto&& [src, dst] = cugraph::negative_sampling( + handle_, + rng_state_->rng_state_, + graph_view, + num_samples_, + (src_bias_ != nullptr) + ? std::make_optional(raft::device_span{src_bias.data(), src_bias.size()}) + : std::nullopt, + (dst_bias_ != nullptr) + ? std::make_optional(raft::device_span{dst_bias.data(), dst_bias.size()}) + : std::nullopt, + remove_duplicates_, + remove_false_negatives_, + exact_number_of_samples_, + do_expensive_check_); + + std::vector vertex_partition_lasts = graph_view.vertex_partition_range_lasts(); + + cugraph::unrenumber_int_vertices(handle_, + src.data(), + src.size(), + number_map->data(), + vertex_partition_lasts, + do_expensive_check_); + + cugraph::unrenumber_int_vertices(handle_, + dst.data(), + dst.size(), + number_map->data(), + vertex_partition_lasts, + do_expensive_check_); + + result_ = new cugraph::c_api::cugraph_coo_t{ + std::make_unique(src, + graph_->vertex_type_), + std::make_unique(dst, + graph_->vertex_type_), + nullptr, + nullptr, + nullptr}; + } + } +}; + +} // namespace + +cugraph_error_code_t cugraph_negative_sampling( + const cugraph_resource_handle_t* handle, + cugraph_rng_state_t* rng_state, + cugraph_graph_t* graph, + size_t num_samples, + const cugraph_type_erased_device_array_view_t* vertices, + const cugraph_type_erased_device_array_view_t* src_bias, + const cugraph_type_erased_device_array_view_t* dst_bias, + bool_t remove_duplicates, + bool_t remove_false_negatives, + bool_t exact_number_of_samples, + bool_t do_expensive_check, + cugraph_coo_t** result, + cugraph_error_t** error) +{ + negative_sampling_functor functor{handle, + rng_state, + graph, + num_samples, + vertices, + src_bias, + dst_bias, + remove_duplicates, + remove_false_negatives, + exact_number_of_samples, + do_expensive_check}; + return cugraph::c_api::run_algorithm(graph, functor, result, error); +} diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index ebd2457c471..cf42a6143c6 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -777,6 +777,7 @@ if(BUILD_CUGRAPH_MG_TESTS) ConfigureCTestMG(MG_CAPI_EDGE_BETWEENNESS_CENTRALITY_TEST c_api/mg_edge_betweenness_centrality_test.c) ConfigureCTestMG(MG_CAPI_HITS_TEST c_api/mg_hits_test.c) ConfigureCTestMG(MG_CAPI_UNIFORM_NEIGHBOR_SAMPLE_TEST c_api/mg_uniform_neighbor_sample_test.c) + ConfigureCTestMG(MG_CAPI_NEGATIVE_SAMPLING_TEST c_api/mg_negative_sampling_test.c) ConfigureCTestMG(MG_CAPI_LOOKUP_SRC_DST_TEST c_api/mg_lookup_src_dst_test.c) ConfigureCTestMG(MG_CAPI_RANDOM_WALKS_TEST c_api/mg_random_walks_test.c) ConfigureCTestMG(MG_CAPI_TRIANGLE_COUNT_TEST c_api/mg_triangle_count_test.c) @@ -814,6 +815,7 @@ ConfigureCTest(CAPI_NODE2VEC_TEST c_api/node2vec_test.c) ConfigureCTest(CAPI_WEAKLY_CONNECTED_COMPONENTS_TEST c_api/weakly_connected_components_test.c) ConfigureCTest(CAPI_STRONGLY_CONNECTED_COMPONENTS_TEST c_api/strongly_connected_components_test.c) ConfigureCTest(CAPI_UNIFORM_NEIGHBOR_SAMPLE_TEST c_api/uniform_neighbor_sample_test.c) +ConfigureCTest(CAPI_NEGATIVE_SAMPLING_TEST c_api/negative_sampling_test.c) ConfigureCTest(CAPI_RANDOM_WALKS_TEST c_api/sg_random_walks_test.c) ConfigureCTest(CAPI_TRIANGLE_COUNT_TEST c_api/triangle_count_test.c) ConfigureCTest(CAPI_LOUVAIN_TEST c_api/louvain_test.c) diff --git a/cpp/tests/c_api/mg_negative_sampling_test.c b/cpp/tests/c_api/mg_negative_sampling_test.c new file mode 100644 index 00000000000..566524251ed --- /dev/null +++ b/cpp/tests/c_api/mg_negative_sampling_test.c @@ -0,0 +1,295 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mg_test_utils.h" /* RUN_MG_TEST */ + +#include +#include + +#include +#include +#include + +typedef int32_t vertex_t; +typedef int32_t edge_t; +typedef float weight_t; + +data_type_id_t vertex_tid = INT32; +data_type_id_t edge_tid = INT32; +data_type_id_t weight_tid = FLOAT32; +data_type_id_t edge_id_tid = INT32; +data_type_id_t edge_type_tid = INT32; + +int generic_negative_sampling_test(const cugraph_resource_handle_t* handle, + vertex_t* h_src, + vertex_t* h_dst, + size_t num_vertices, + size_t num_edges, + size_t num_samples, + vertex_t* h_vertices, + weight_t* h_src_bias, + weight_t* h_dst_bias, + size_t num_biases, + bool_t remove_duplicates, + bool_t remove_false_negatives, + bool_t exact_number_of_samples) +{ + // Create graph + int test_ret_value = 0; + cugraph_error_code_t ret_code = CUGRAPH_SUCCESS; + cugraph_error_t* ret_error = NULL; + cugraph_graph_t* graph = NULL; + cugraph_coo_t* result = NULL; + + ret_code = create_mg_test_graph_new(handle, + vertex_tid, + edge_tid, + h_src, + h_dst, + weight_tid, + NULL, + edge_type_tid, + NULL, + edge_id_tid, + NULL, + num_edges, + FALSE, + TRUE, + FALSE, + FALSE, + &graph, + &ret_error); + + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "graph creation failed."); + + cugraph_type_erased_device_array_t* d_vertices = NULL; + cugraph_type_erased_device_array_view_t* d_vertices_view = NULL; + cugraph_type_erased_device_array_t* d_src_bias = NULL; + cugraph_type_erased_device_array_view_t* d_src_bias_view = NULL; + cugraph_type_erased_device_array_t* d_dst_bias = NULL; + cugraph_type_erased_device_array_view_t* d_dst_bias_view = NULL; + + int rank = cugraph_resource_handle_get_rank(handle); + + if (num_biases > 0) { + if (rank == 0) { + ret_code = cugraph_type_erased_device_array_create( + handle, num_biases, vertex_tid, &d_vertices, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "d_vertices create failed."); + + d_vertices_view = cugraph_type_erased_device_array_view(d_vertices); + + ret_code = cugraph_type_erased_device_array_view_copy_from_host( + handle, d_vertices_view, (byte_t*)h_vertices, &ret_error); + + ret_code = cugraph_type_erased_device_array_create( + handle, num_biases, weight_tid, &d_src_bias, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "d_src_bias create failed."); + + d_src_bias_view = cugraph_type_erased_device_array_view(d_src_bias); + + ret_code = cugraph_type_erased_device_array_view_copy_from_host( + handle, d_src_bias_view, (byte_t*)h_src_bias, &ret_error); + + ret_code = cugraph_type_erased_device_array_create( + handle, num_biases, weight_tid, &d_dst_bias, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "d_dst_bias create failed."); + + d_dst_bias_view = cugraph_type_erased_device_array_view(d_dst_bias); + + ret_code = cugraph_type_erased_device_array_view_copy_from_host( + handle, d_dst_bias_view, (byte_t*)h_dst_bias, &ret_error); + } else { + d_vertices_view = cugraph_type_erased_device_array_view_create(NULL, 0, vertex_tid); + d_src_bias_view = cugraph_type_erased_device_array_view_create(NULL, 0, weight_tid); + d_dst_bias_view = cugraph_type_erased_device_array_view_create(NULL, 0, weight_tid); + } + } + + cugraph_rng_state_t* rng_state; + ret_code = cugraph_rng_state_create(handle, rank, &rng_state, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "rng_state create failed."); + + ret_code = cugraph_negative_sampling(handle, + rng_state, + graph, + num_samples, + d_vertices_view, + d_src_bias_view, + d_dst_bias_view, + remove_duplicates, + remove_false_negatives, + exact_number_of_samples, + FALSE, + &result, + &ret_error); + + cugraph_type_erased_device_array_view_t* result_srcs = NULL; + cugraph_type_erased_device_array_view_t* result_dsts = NULL; + + result_srcs = cugraph_coo_get_sources(result); + result_dsts = cugraph_coo_get_destinations(result); + + size_t result_size = cugraph_type_erased_device_array_view_size(result_srcs); + + vertex_t h_result_srcs[result_size]; + vertex_t h_result_dsts[result_size]; + + ret_code = cugraph_type_erased_device_array_view_copy_to_host( + handle, (byte_t*)h_result_srcs, result_srcs, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "copy_to_host failed."); + + ret_code = cugraph_type_erased_device_array_view_copy_to_host( + handle, (byte_t*)h_result_dsts, result_dsts, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "copy_to_host failed."); + + // First, check that all edges are actually part of the graph + int32_t M_exists[num_vertices][num_vertices]; + int32_t M_duplicates[num_vertices][num_vertices]; + + for (int i = 0; i < num_vertices; ++i) + for (int j = 0; j < num_vertices; ++j) { + M_exists[i][j] = 0; + M_duplicates[i][j] = 0; + } + + for (int i = 0; i < num_edges; ++i) { + M_exists[h_src[i]][h_dst[i]] = 1; + } + + for (int i = 0; (i < result_size) && (test_ret_value == 0); ++i) { + TEST_ASSERT(test_ret_value, + (h_result_srcs[i] >= 0) && (h_result_srcs[i] < num_vertices), + "negative_sampling generated an edge that with an invalid vertex"); + TEST_ASSERT(test_ret_value, + (h_result_dsts[i] >= 0) && (h_result_dsts[i] < num_vertices), + "negative_sampling generated an edge that with an invalid vertex"); + if (remove_false_negatives == TRUE) { + TEST_ASSERT(test_ret_value, + M_exists[h_result_srcs[i]][h_result_dsts[i]] == 0, + "negative_sampling generated a false negative edge that should be suppressed"); + } + + if (remove_duplicates == TRUE) { + TEST_ASSERT(test_ret_value, + M_duplicates[h_result_srcs[i]][h_result_dsts[i]] == 0, + "negative_sampling generated a duplicate edge that should be suppressed"); + M_duplicates[h_result_srcs[i]][h_result_dsts[i]] = 1; + } + } + + if (exact_number_of_samples == TRUE) + TEST_ASSERT(test_ret_value, + result_size == num_samples, + "negative_sampling generated a result with an incorrect number of samples"); + + cugraph_type_erased_device_array_view_free(d_vertices_view); + cugraph_type_erased_device_array_view_free(d_src_bias_view); + cugraph_type_erased_device_array_view_free(d_dst_bias_view); + cugraph_type_erased_device_array_free(d_vertices); + cugraph_type_erased_device_array_free(d_src_bias); + cugraph_type_erased_device_array_free(d_dst_bias); + cugraph_coo_free(result); + cugraph_mg_graph_free(graph); + cugraph_error_free(ret_error); + return test_ret_value; +} + +int test_negative_sampling_uniform(const cugraph_resource_handle_t* handle) +{ + data_type_id_t vertex_tid = INT32; + data_type_id_t edge_tid = INT32; + data_type_id_t weight_tid = FLOAT32; + + size_t num_edges = 9; + size_t num_vertices = 6; + size_t num_biases = 0; + size_t num_samples = 10; + + vertex_t src[] = {0, 0, 1, 1, 2, 2, 2, 3, 4}; + vertex_t dst[] = {1, 2, 3, 4, 0, 1, 3, 5, 5}; + + bool_t remove_duplicates = FALSE; + bool_t remove_false_negatives = TRUE; + bool_t exact_number_of_samples = FALSE; + + return generic_negative_sampling_test(handle, + src, + dst, + num_vertices, + num_edges, + num_samples, + NULL, + NULL, + NULL, + num_biases, + remove_duplicates, + remove_false_negatives, + exact_number_of_samples); +} + +int test_negative_sampling_biased(const cugraph_resource_handle_t* handle) +{ + data_type_id_t vertex_tid = INT32; + data_type_id_t edge_tid = INT32; + data_type_id_t weight_tid = FLOAT32; + + size_t num_edges = 9; + size_t num_vertices = 6; + size_t num_biases = 6; + size_t num_samples = 10; + + vertex_t src[] = {0, 0, 1, 1, 2, 2, 2, 3, 4}; + vertex_t dst[] = {1, 2, 3, 4, 0, 1, 3, 5, 5}; + weight_t src_bias[] = {1, 1, 2, 2, 1, 1}; + weight_t dst_bias[] = {2, 2, 1, 1, 1, 1}; + vertex_t vertices[] = {0, 1, 2, 3, 4, 5}; + + bool_t remove_duplicates = FALSE; + bool_t remove_false_negatives = TRUE; + bool_t exact_number_of_samples = FALSE; + + return generic_negative_sampling_test(handle, + src, + dst, + num_vertices, + num_edges, + num_samples, + vertices, + src_bias, + dst_bias, + num_biases, + remove_duplicates, + remove_false_negatives, + exact_number_of_samples); +} + +/******************************************************************************/ + +int main(int argc, char** argv) +{ + void* raft_handle = create_mg_raft_handle(argc, argv); + cugraph_resource_handle_t* handle = cugraph_create_resource_handle(raft_handle); + + int result = 0; + result |= RUN_MG_TEST(test_negative_sampling_uniform, handle); + result |= RUN_MG_TEST(test_negative_sampling_biased, handle); + + cugraph_free_resource_handle(handle); + free_mg_raft_handle(raft_handle); + + return result; +} diff --git a/cpp/tests/c_api/negative_sampling_test.c b/cpp/tests/c_api/negative_sampling_test.c new file mode 100644 index 00000000000..abea4028061 --- /dev/null +++ b/cpp/tests/c_api/negative_sampling_test.c @@ -0,0 +1,284 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "c_test_utils.h" /* RUN_TEST */ + +#include +#include + +#include +#include +#include + +typedef int32_t vertex_t; +typedef int32_t edge_t; +typedef float weight_t; + +data_type_id_t vertex_tid = INT32; +data_type_id_t edge_tid = INT32; +data_type_id_t weight_tid = FLOAT32; +data_type_id_t edge_id_tid = INT32; +data_type_id_t edge_type_tid = INT32; + +int generic_negative_sampling_test(const cugraph_resource_handle_t* handle, + vertex_t* h_src, + vertex_t* h_dst, + size_t num_vertices, + size_t num_edges, + size_t num_samples, + vertex_t* h_vertices, + weight_t* h_src_bias, + weight_t* h_dst_bias, + size_t num_biases, + bool_t remove_duplicates, + bool_t remove_false_negatives, + bool_t exact_number_of_samples) +{ + // Create graph + int test_ret_value = 0; + cugraph_error_code_t ret_code = CUGRAPH_SUCCESS; + cugraph_error_t* ret_error = NULL; + cugraph_graph_t* graph = NULL; + cugraph_coo_t* result = NULL; + + ret_code = create_sg_test_graph(handle, + vertex_tid, + edge_tid, + h_src, + h_dst, + weight_tid, + NULL, + edge_type_tid, + NULL, + edge_id_tid, + NULL, + num_edges, + FALSE, + TRUE, + FALSE, + FALSE, + &graph, + &ret_error); + + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "graph creation failed."); + + cugraph_type_erased_device_array_t* d_vertices = NULL; + cugraph_type_erased_device_array_view_t* d_vertices_view = NULL; + cugraph_type_erased_device_array_t* d_src_bias = NULL; + cugraph_type_erased_device_array_view_t* d_src_bias_view = NULL; + cugraph_type_erased_device_array_t* d_dst_bias = NULL; + cugraph_type_erased_device_array_view_t* d_dst_bias_view = NULL; + + if (num_biases > 0) { + ret_code = cugraph_type_erased_device_array_create( + handle, num_biases, vertex_tid, &d_vertices, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "d_vertices create failed."); + + d_vertices_view = cugraph_type_erased_device_array_view(d_vertices); + + ret_code = cugraph_type_erased_device_array_view_copy_from_host( + handle, d_vertices_view, (byte_t*)h_vertices, &ret_error); + + ret_code = cugraph_type_erased_device_array_create( + handle, num_biases, weight_tid, &d_src_bias, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "d_src_bias create failed."); + + d_src_bias_view = cugraph_type_erased_device_array_view(d_src_bias); + + ret_code = cugraph_type_erased_device_array_view_copy_from_host( + handle, d_src_bias_view, (byte_t*)h_src_bias, &ret_error); + + ret_code = cugraph_type_erased_device_array_create( + handle, num_biases, weight_tid, &d_dst_bias, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "d_dst_bias create failed."); + + d_dst_bias_view = cugraph_type_erased_device_array_view(d_dst_bias); + + ret_code = cugraph_type_erased_device_array_view_copy_from_host( + handle, d_dst_bias_view, (byte_t*)h_dst_bias, &ret_error); + } + + cugraph_rng_state_t* rng_state; + ret_code = cugraph_rng_state_create(handle, 0, &rng_state, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "rng_state create failed."); + + ret_code = cugraph_negative_sampling(handle, + rng_state, + graph, + num_samples, + d_vertices_view, + d_src_bias_view, + d_dst_bias_view, + remove_duplicates, + remove_false_negatives, + exact_number_of_samples, + FALSE, + &result, + &ret_error); + + cugraph_type_erased_device_array_view_t* result_srcs = NULL; + cugraph_type_erased_device_array_view_t* result_dsts = NULL; + + result_srcs = cugraph_coo_get_sources(result); + result_dsts = cugraph_coo_get_destinations(result); + + size_t result_size = cugraph_type_erased_device_array_view_size(result_srcs); + + vertex_t h_result_srcs[result_size]; + vertex_t h_result_dsts[result_size]; + + ret_code = cugraph_type_erased_device_array_view_copy_to_host( + handle, (byte_t*)h_result_srcs, result_srcs, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "copy_to_host failed."); + + ret_code = cugraph_type_erased_device_array_view_copy_to_host( + handle, (byte_t*)h_result_dsts, result_dsts, &ret_error); + TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "copy_to_host failed."); + + // First, check that all edges are actually part of the graph + int32_t M_exists[num_vertices][num_vertices]; + int32_t M_duplicates[num_vertices][num_vertices]; + + for (int i = 0; i < num_vertices; ++i) + for (int j = 0; j < num_vertices; ++j) { + M_exists[i][j] = 0; + M_duplicates[i][j] = 0; + } + + for (int i = 0; i < num_edges; ++i) { + M_exists[h_src[i]][h_dst[i]] = 1; + } + + for (int i = 0; (i < result_size) && (test_ret_value == 0); ++i) { + TEST_ASSERT(test_ret_value, + (h_result_srcs[i] >= 0) && (h_result_srcs[i] < num_vertices), + "negative_sampling generated an edge that with an invalid vertex"); + TEST_ASSERT(test_ret_value, + (h_result_dsts[i] >= 0) && (h_result_dsts[i] < num_vertices), + "negative_sampling generated an edge that with an invalid vertex"); + if (remove_false_negatives == TRUE) { + TEST_ASSERT(test_ret_value, + M_exists[h_result_srcs[i]][h_result_dsts[i]] == 0, + "negative_sampling generated a false negative edge that should be suppressed"); + } + + if (remove_duplicates == TRUE) { + TEST_ASSERT(test_ret_value, + M_duplicates[h_result_srcs[i]][h_result_dsts[i]] == 0, + "negative_sampling generated a duplicate edge that should be suppressed"); + M_duplicates[h_result_srcs[i]][h_result_dsts[i]] = 1; + } + } + + if (exact_number_of_samples == TRUE) + TEST_ASSERT(test_ret_value, + result_size == num_samples, + "negative_sampling generated a result with an incorrect number of samples"); + + cugraph_type_erased_device_array_view_free(d_vertices_view); + cugraph_type_erased_device_array_view_free(d_src_bias_view); + cugraph_type_erased_device_array_view_free(d_dst_bias_view); + cugraph_type_erased_device_array_free(d_vertices); + cugraph_type_erased_device_array_free(d_src_bias); + cugraph_coo_free(result); + cugraph_sg_graph_free(graph); + cugraph_error_free(ret_error); + return test_ret_value; +} + +int test_negative_sampling_uniform(const cugraph_resource_handle_t* handle) +{ + data_type_id_t vertex_tid = INT32; + data_type_id_t edge_tid = INT32; + data_type_id_t weight_tid = FLOAT32; + + size_t num_edges = 9; + size_t num_vertices = 6; + size_t num_biases = 0; + size_t num_samples = 10; + + vertex_t src[] = {0, 0, 1, 1, 2, 2, 2, 3, 4}; + vertex_t dst[] = {1, 2, 3, 4, 0, 1, 3, 5, 5}; + + bool_t remove_duplicates = FALSE; + bool_t remove_false_negatives = TRUE; + bool_t exact_number_of_samples = FALSE; + + return generic_negative_sampling_test(handle, + src, + dst, + num_vertices, + num_edges, + num_samples, + NULL, + NULL, + NULL, + num_biases, + remove_duplicates, + remove_false_negatives, + exact_number_of_samples); +} + +int test_negative_sampling_biased(const cugraph_resource_handle_t* handle) +{ + data_type_id_t vertex_tid = INT32; + data_type_id_t edge_tid = INT32; + data_type_id_t weight_tid = FLOAT32; + + size_t num_edges = 9; + size_t num_vertices = 6; + size_t num_biases = 6; + size_t num_samples = 10; + + vertex_t src[] = {0, 0, 1, 1, 2, 2, 2, 3, 4}; + vertex_t dst[] = {1, 2, 3, 4, 0, 1, 3, 5, 5}; + weight_t src_bias[] = {1, 1, 2, 2, 1, 1}; + weight_t dst_bias[] = {2, 2, 1, 1, 1, 1}; + vertex_t vertices[] = {0, 1, 2, 3, 4, 5}; + + bool_t remove_duplicates = FALSE; + bool_t remove_false_negatives = TRUE; + bool_t exact_number_of_samples = FALSE; + + return generic_negative_sampling_test(handle, + src, + dst, + num_vertices, + num_edges, + num_samples, + vertices, + src_bias, + dst_bias, + num_biases, + remove_duplicates, + remove_false_negatives, + exact_number_of_samples); +} + +int main(int argc, char** argv) +{ + cugraph_resource_handle_t* handle = NULL; + + handle = cugraph_create_resource_handle(NULL); + + int result = 0; + result |= RUN_TEST_NEW(test_negative_sampling_uniform, handle); + result |= RUN_TEST_NEW(test_negative_sampling_biased, handle); + + cugraph_free_resource_handle(handle); + + return result; +} diff --git a/python/pylibcugraph/pylibcugraph/_cugraph_c/coo.pxd b/python/pylibcugraph/pylibcugraph/_cugraph_c/coo.pxd new file mode 100644 index 00000000000..e466e6ee5a0 --- /dev/null +++ b/python/pylibcugraph/pylibcugraph/_cugraph_c/coo.pxd @@ -0,0 +1,71 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Have cython use python 3 syntax +# cython: language_level = 3 + +from pylibcugraph._cugraph_c.array cimport ( + cugraph_type_erased_device_array_view_t, +) + +cdef extern from "cugraph_c/coo.h": + ctypedef struct cugraph_coo_t: + pass + + ctypedef struct cugraph_coo_list_t: + pass + + cdef cugraph_type_erased_device_array_view_t* \ + cugraph_coo_get_sources( + cugraph_coo_t* coo + ) + + cdef cugraph_type_erased_device_array_view_t* \ + cugraph_coo_get_destinations( + cugraph_coo_t* coo + ) + + cdef cugraph_type_erased_device_array_view_t* \ + cugraph_coo_get_edge_weights( + cugraph_coo_t* coo + ) + + cdef cugraph_type_erased_device_array_view_t* \ + cugraph_coo_get_edge_id( + cugraph_coo_t* coo + ) + + cdef cugraph_type_erased_device_array_view_t* \ + cugraph_coo_get_edge_type( + cugraph_coo_t* coo + ) + + cdef size_t \ + cugraph_coo_list_size( + const cugraph_coo_list_t* coo_list + ) + + cdef cugraph_coo_t* \ + cugraph_coo_list_element( + cugraph_coo_list_t* coo_list, + size_t index) + + cdef void \ + cugraph_coo_free( + cugraph_coo_t* coo + ) + + cdef void \ + cugraph_coo_list_free( + cugraph_coo_list_t* coo_list + ) diff --git a/python/pylibcugraph/pylibcugraph/_cugraph_c/graph_generators.pxd b/python/pylibcugraph/pylibcugraph/_cugraph_c/graph_generators.pxd index f6d62377443..cda47e55f77 100644 --- a/python/pylibcugraph/pylibcugraph/_cugraph_c/graph_generators.pxd +++ b/python/pylibcugraph/pylibcugraph/_cugraph_c/graph_generators.pxd @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -31,62 +31,16 @@ from pylibcugraph._cugraph_c.random cimport ( cugraph_rng_state_t, ) +from pylibcugraph._cugraph_c.coo cimport ( + cugraph_coo_t, + cugraph_coo_list_t, +) + cdef extern from "cugraph_c/graph_generators.h": ctypedef enum cugraph_generator_distribution_t: POWER_LAW UNIFORM - ctypedef struct cugraph_coo_t: - pass - - ctypedef struct cugraph_coo_list_t: - pass - - cdef cugraph_type_erased_device_array_view_t* \ - cugraph_coo_get_sources( - cugraph_coo_t* coo - ) - - cdef cugraph_type_erased_device_array_view_t* \ - cugraph_coo_get_destinations( - cugraph_coo_t* coo - ) - - cdef cugraph_type_erased_device_array_view_t* \ - cugraph_coo_get_edge_weights( - cugraph_coo_t* coo - ) - - cdef cugraph_type_erased_device_array_view_t* \ - cugraph_coo_get_edge_id( - cugraph_coo_t* coo - ) - - cdef cugraph_type_erased_device_array_view_t* \ - cugraph_coo_get_edge_type( - cugraph_coo_t* coo - ) - - cdef size_t \ - cugraph_coo_list_size( - const cugraph_coo_list_t* coo_list - ) - - cdef cugraph_coo_t* \ - cugraph_coo_list_element( - cugraph_coo_list_t* coo_list, - size_t index) - - cdef void \ - cugraph_coo_free( - cugraph_coo_t* coo - ) - - cdef void \ - cugraph_coo_list_free( - cugraph_coo_list_t* coo_list - ) - cdef cugraph_error_code_t \ cugraph_generate_rmat_edgelist( const cugraph_resource_handle_t* handle, diff --git a/python/pylibcugraph/pylibcugraph/_cugraph_c/sampling_algorithms.pxd b/python/pylibcugraph/pylibcugraph/_cugraph_c/sampling_algorithms.pxd index dbd3ef4b7e1..fc28a2d86b8 100644 --- a/python/pylibcugraph/pylibcugraph/_cugraph_c/sampling_algorithms.pxd +++ b/python/pylibcugraph/pylibcugraph/_cugraph_c/sampling_algorithms.pxd @@ -40,6 +40,10 @@ from pylibcugraph._cugraph_c.array cimport ( cugraph_type_erased_device_array_t, ) +from pylibcugraph._cugraph_c.coo cimport ( + cugraph_coo_t, +) + cdef extern from "cugraph_c/sampling_algorithms.h": ########################################################################### @@ -82,3 +86,21 @@ cdef extern from "cugraph_c/sampling_algorithms.h": cugraph_type_erased_device_array_t** vertices, cugraph_error_t** error ) + + # negative sampling + cdef cugraph_error_code_t \ + cugraph_negative_sampling( + const cugraph_resource_handle_t* handle, + cugraph_rng_state_t* rng_state, + cugraph_graph_t* graph, + size_t num_samples, + const cugraph_type_erased_device_array_view_t* vertices, + const cugraph_type_erased_device_array_view_t* src_bias, + const cugraph_type_erased_device_array_view_t* dst_bias, + bool_t remove_duplicates, + bool_t remove_false_negatives, + bool_t exact_number_of_samples, + bool_t do_expensive_check, + cugraph_coo_t **result, + cugraph_error_t **error + ) diff --git a/python/pylibcugraph/pylibcugraph/generate_rmat_edgelist.pyx b/python/pylibcugraph/pylibcugraph/generate_rmat_edgelist.pyx index f38ad21d3b0..4ea96920e61 100644 --- a/python/pylibcugraph/pylibcugraph/generate_rmat_edgelist.pyx +++ b/python/pylibcugraph/pylibcugraph/generate_rmat_edgelist.pyx @@ -26,11 +26,7 @@ from pylibcugraph._cugraph_c.error cimport ( from pylibcugraph._cugraph_c.array cimport ( cugraph_type_erased_device_array_view_t, ) -from pylibcugraph._cugraph_c.graph_generators cimport ( - cugraph_generate_rmat_edgelist, - cugraph_generate_edge_weights, - cugraph_generate_edge_ids, - cugraph_generate_edge_types, +from pylibcugraph._cugraph_c.coo cimport ( cugraph_coo_t, cugraph_coo_get_sources, cugraph_coo_get_destinations, @@ -39,6 +35,12 @@ from pylibcugraph._cugraph_c.graph_generators cimport ( cugraph_coo_get_edge_type, cugraph_coo_free, ) +from pylibcugraph._cugraph_c.graph_generators cimport ( + cugraph_generate_rmat_edgelist, + cugraph_generate_edge_weights, + cugraph_generate_edge_ids, + cugraph_generate_edge_types, +) from pylibcugraph.resource_handle cimport ( ResourceHandle, ) diff --git a/python/pylibcugraph/pylibcugraph/generate_rmat_edgelists.pyx b/python/pylibcugraph/pylibcugraph/generate_rmat_edgelists.pyx index 32af0c13fc0..7de48708f80 100644 --- a/python/pylibcugraph/pylibcugraph/generate_rmat_edgelists.pyx +++ b/python/pylibcugraph/pylibcugraph/generate_rmat_edgelists.pyx @@ -26,14 +26,9 @@ from pylibcugraph._cugraph_c.error cimport ( from pylibcugraph._cugraph_c.array cimport ( cugraph_type_erased_device_array_view_t, ) -from pylibcugraph._cugraph_c.graph_generators cimport ( - cugraph_generate_rmat_edgelists, - cugraph_generate_edge_weights, - cugraph_generate_edge_ids, - cugraph_generate_edge_types, +from pylibcugraph._cugraph_c.coo cimport ( cugraph_coo_t, cugraph_coo_list_t, - cugraph_generator_distribution_t, cugraph_coo_get_sources, cugraph_coo_get_destinations, cugraph_coo_get_edge_weights, @@ -44,6 +39,13 @@ from pylibcugraph._cugraph_c.graph_generators cimport ( cugraph_coo_free, cugraph_coo_list_free, ) +from pylibcugraph._cugraph_c.graph_generators cimport ( + cugraph_generate_rmat_edgelists, + cugraph_generate_edge_weights, + cugraph_generate_edge_ids, + cugraph_generate_edge_types, + cugraph_generator_distribution_t, +) from pylibcugraph.resource_handle cimport ( ResourceHandle, ) From 6a908446c1a60548b5c795cba892ab3a90fe735c Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Tue, 23 Jul 2024 14:28:50 -0700 Subject: [PATCH 07/18] Fix filename change lost in merge --- cpp/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 68537900ca2..555bc44eb26 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -658,7 +658,7 @@ add_library(cugraph_c src/c_api/lookup_src_dst.cpp src/c_api/louvain.cpp src/c_api/triangle_count.cpp - src/c_api/uniform_neighbor_sampling.cpp + src/c_api/neighbor_sampling.cpp src/c_api/negative_sampling.cpp src/c_api/labeling_result.cpp src/c_api/weakly_connected_components.cpp From 2f23ac19139fb8c0723fddcff04be8c73a656bf3 Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Mon, 5 Aug 2024 13:50:57 -0700 Subject: [PATCH 08/18] Negative sampling now working for SG, MG with 1/2/4 GPUs --- cpp/src/sampling/negative_sampling_impl.cuh | 424 +++++++++++++++++--- cpp/tests/sampling/mg_negative_sampling.cu | 83 +--- 2 files changed, 382 insertions(+), 125 deletions(-) diff --git a/cpp/src/sampling/negative_sampling_impl.cuh b/cpp/src/sampling/negative_sampling_impl.cuh index dc174098b59..a0dc500d9e6 100644 --- a/cpp/src/sampling/negative_sampling_impl.cuh +++ b/cpp/src/sampling/negative_sampling_impl.cuh @@ -16,15 +16,376 @@ #pragma once +#include "prims/reduce_v.cuh" +#include "prims/update_edge_src_dst_property.cuh" +#include "utilities/collect_comm.cuh" + #include #include #include +#include +#include + +#include +#include +#include +#include +#include #include +#include +#include #include namespace cugraph { +namespace detail { + +template +class negative_sampling_impl_t { + private: + static const bool store_transposed = false; + + public: + negative_sampling_impl_t( + raft::handle_t const& handle, + graph_view_t const& graph_view, + std::optional> src_bias, + std::optional> dst_bias) + : gpu_bias_v_(0, handle.get_stream()), + src_bias_v_(0, handle.get_stream()), + dst_bias_v_(0, handle.get_stream()), + src_bias_cache_(std::nullopt), + dst_bias_cache_(std::nullopt) + { + // Need to normalize the src_bias + if (src_bias) { + // Normalize the src bias. + rmm::device_uvector normalized_bias(graph_view.local_vertex_partition_range_size(), + handle.get_stream()); + + weight_t sum = reduce_v(handle, graph_view, src_bias->begin()); + + if constexpr (multi_gpu) { + sum = host_scalar_allreduce( + handle.get_comms(), sum, raft::comms::op_t::SUM, handle.get_stream()); + } + + thrust::transform(handle.get_thrust_policy(), + src_bias->begin(), + src_bias->end(), + normalized_bias.begin(), + divider_t{sum}); + + // Distribute the src bias around the edge partitions + src_bias_cache_ = std::make_optional< + edge_src_property_t, weight_t>>( + handle, graph_view); + update_edge_src_property( + handle, graph_view, normalized_bias.begin(), src_bias_cache_->mutable_view()); + } + + if (dst_bias) { + // Normalize the dst bias. + rmm::device_uvector normalized_bias(graph_view.local_vertex_partition_range_size(), + handle.get_stream()); + + weight_t sum = reduce_v(handle, graph_view, dst_bias->begin()); + + if constexpr (multi_gpu) { + sum = host_scalar_allreduce( + handle.get_comms(), sum, raft::comms::op_t::SUM, handle.get_stream()); + } + + thrust::transform(handle.get_thrust_policy(), + dst_bias->begin(), + dst_bias->end(), + normalized_bias.begin(), + divider_t{sum}); + + dst_bias_cache_ = std::make_optional< + edge_dst_property_t, weight_t>>( + handle, graph_view); + update_edge_dst_property( + handle, graph_view, normalized_bias.begin(), dst_bias_cache_->mutable_view()); + } + + if constexpr (multi_gpu) { + weight_t dst_bias_sum{0}; + + if (dst_bias) { + // Compute the dst_bias sum for this partition and normalize cached values + dst_bias_sum = thrust::reduce( + handle.get_thrust_policy(), + dst_bias_cache_->view().value_first(), + dst_bias_cache_->view().value_first() + graph_view.local_edge_partition_dst_range_size(), + weight_t{0}); + + thrust::transform(handle.get_thrust_policy(), + dst_bias_cache_->mutable_view().value_first(), + dst_bias_cache_->mutable_view().value_first() + + graph_view.local_edge_partition_dst_range_size(), + dst_bias_cache_->mutable_view().value_first(), + divider_t{dst_bias_sum}); + + thrust::inclusive_scan(handle.get_thrust_policy(), + dst_bias_cache_->mutable_view().value_first(), + dst_bias_cache_->mutable_view().value_first() + + graph_view.local_edge_partition_dst_range_size(), + dst_bias_cache_->mutable_view().value_first()); + } else { + dst_bias_sum = static_cast(graph_view.local_edge_partition_dst_range_size()) / + static_cast(graph_view.number_of_vertices()); + } + + std::vector h_gpu_bias; + h_gpu_bias.reserve(graph_view.number_of_local_edge_partitions()); + + for (size_t partition_idx = 0; partition_idx < graph_view.number_of_local_edge_partitions(); + ++partition_idx) { + weight_t src_bias_sum{ + static_cast(graph_view.local_edge_partition_src_range_size(partition_idx)) / + static_cast(graph_view.number_of_vertices())}; + + if (src_bias) { + // Normalize each batch of biases and compute the inclusive prefix sum + src_bias_sum = + thrust::reduce(handle.get_thrust_policy(), + src_bias_cache_->view().value_firsts()[partition_idx], + src_bias_cache_->view().value_firsts()[partition_idx] + + graph_view.local_edge_partition_src_range_size(partition_idx), + weight_t{0}); + + thrust::transform(handle.get_thrust_policy(), + src_bias_cache_->mutable_view().value_firsts()[partition_idx], + src_bias_cache_->mutable_view().value_firsts()[partition_idx] + + graph_view.local_edge_partition_src_range_size(partition_idx), + src_bias_cache_->mutable_view().value_firsts()[partition_idx], + divider_t{src_bias_sum}); + + thrust::inclusive_scan(handle.get_thrust_policy(), + src_bias_cache_->mutable_view().value_firsts()[partition_idx], + src_bias_cache_->mutable_view().value_firsts()[partition_idx] + + graph_view.local_edge_partition_src_range_size(partition_idx), + src_bias_cache_->mutable_view().value_firsts()[partition_idx]); + } + + // Because src_bias and dst_bias are normalized, the probability of a random edge appearing + // on this partition is (src_bias_sum * dst_bias_sum) + h_gpu_bias.push_back(src_bias_sum * dst_bias_sum); + } + + rmm::device_uvector d_gpu_bias(h_gpu_bias.size(), handle.get_stream()); + raft::update_device( + d_gpu_bias.data(), h_gpu_bias.data(), h_gpu_bias.size(), handle.get_stream()); + + gpu_bias_v_ = cugraph::device_allgatherv( + handle, + handle.get_comms(), + raft::device_span{d_gpu_bias.data(), d_gpu_bias.size()}); + + thrust::inclusive_scan( + handle.get_thrust_policy(), gpu_bias_v_.begin(), gpu_bias_v_.end(), gpu_bias_v_.begin()); + } else { + if (dst_bias_cache_) + thrust::inclusive_scan(handle.get_thrust_policy(), + dst_bias_cache_->mutable_view().value_first(), + dst_bias_cache_->mutable_view().value_first() + + graph_view.local_edge_partition_dst_range_size(), + dst_bias_cache_->mutable_view().value_first()); + + if (src_bias_cache_) + thrust::inclusive_scan(handle.get_thrust_policy(), + src_bias_cache_->mutable_view().value_firsts()[0], + src_bias_cache_->mutable_view().value_firsts()[0] + + graph_view.local_edge_partition_src_range_size(0), + src_bias_cache_->mutable_view().value_firsts()[0]); + } + } + + std::tuple, rmm::device_uvector> create_local_samples( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + size_t num_samples) + { + rmm::device_uvector src(0, handle.get_stream()); + rmm::device_uvector dst(0, handle.get_stream()); + + std::vector sample_counts; + + // Determine sample counts per GPU edge partition + if constexpr (multi_gpu) { + auto const comm_size = handle.get_comms().get_size(); + auto const rank = handle.get_comms().get_rank(); + auto& major_comm = handle.get_subcomm(cugraph::partition_manager::major_comm_name()); + auto const major_comm_size = major_comm.get_size(); + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + auto const minor_comm_size = minor_comm.get_size(); + + // First step is to count how many go on each edge_partition + rmm::device_uvector gpu_counts(gpu_bias_v_.size(), handle.get_stream()); + thrust::fill(handle.get_thrust_policy(), gpu_counts.begin(), gpu_counts.end(), int{0}); + + rmm::device_uvector random_values(num_samples, handle.get_stream()); + detail::uniform_random_fill(handle.get_stream(), + random_values.data(), + random_values.size(), + weight_t{0}, + weight_t{1}, + rng_state); + + thrust::sort(handle.get_thrust_policy(), random_values.begin(), random_values.end()); + + thrust::upper_bound(handle.get_thrust_policy(), + random_values.begin(), + random_values.end(), + gpu_bias_v_.begin(), + gpu_bias_v_.end(), + gpu_counts.begin()); + + thrust::adjacent_difference( + handle.get_thrust_policy(), gpu_counts.begin(), gpu_counts.end(), gpu_counts.begin()); + + device_allreduce(handle.get_comms(), + gpu_counts.begin(), + gpu_counts.begin(), + gpu_counts.size(), + raft::comms::op_t::SUM, + handle.get_stream()); + + num_samples = thrust::reduce(handle.get_thrust_policy(), + gpu_counts.begin() + rank * minor_comm_size, + gpu_counts.begin() + rank * minor_comm_size + minor_comm_size, + size_t{0}); + + sample_counts.resize(minor_comm_size); + raft::update_host(sample_counts.data(), + gpu_counts.data() + rank * minor_comm_size, + minor_comm_size, + handle.get_stream()); + + } else { + // SG is only one partition + sample_counts.push_back(num_samples); + } + + src.resize(num_samples, handle.get_stream()); + dst.resize(num_samples, handle.get_stream()); + + size_t current_pos{0}; + + for (size_t partition_idx = 0; partition_idx < graph_view.number_of_local_edge_partitions(); + ++partition_idx) { + if (sample_counts[partition_idx] > 0) { + if (src_bias_cache_) { + rmm::device_uvector random_values(sample_counts[partition_idx], + handle.get_stream()); + detail::uniform_random_fill(handle.get_stream(), + random_values.data(), + random_values.size(), + weight_t{0}, + weight_t{1}, + rng_state); + + thrust::transform( + handle.get_thrust_policy(), + random_values.begin(), + random_values.end(), + src.begin() + current_pos, + [biases = + raft::device_span{ + src_bias_cache_->view().value_firsts()[partition_idx], + static_cast( + graph_view.local_edge_partition_src_range_size(partition_idx))}, + offset = graph_view.local_edge_partition_src_range_first( + partition_idx)] __device__(weight_t r) { + size_t result = + offset + static_cast(thrust::distance( + biases.begin(), + thrust::lower_bound(thrust::seq, biases.begin(), biases.end(), r))); + + // FIXME: https://github.com/rapidsai/raft/issues/2400 + // results in the possibility that 1 can appear as a + // random floating point value, which results in the sampling + // algorithm below generating a value that's OOB. + if (result == (offset + biases.size())) --result; + + return result; + }); + } else { + detail::uniform_random_fill( + handle.get_stream(), + src.data() + current_pos, + sample_counts[partition_idx], + graph_view.local_edge_partition_src_range_first(partition_idx), + graph_view.local_edge_partition_src_range_last(partition_idx), + rng_state); + } + + if (dst_bias_cache_) { + rmm::device_uvector random_values(sample_counts[partition_idx], + handle.get_stream()); + detail::uniform_random_fill(handle.get_stream(), + random_values.data(), + random_values.size(), + weight_t{0}, + weight_t{1}, + rng_state); + + thrust::transform( + handle.get_thrust_policy(), + random_values.begin(), + random_values.end(), + dst.begin() + current_pos, + [biases = + raft::device_span{ + dst_bias_cache_->view().value_first(), + static_cast(graph_view.local_edge_partition_dst_range_size())}, + offset = graph_view.local_edge_partition_dst_range_first()] __device__(weight_t r) { + size_t result = + offset + static_cast(thrust::distance( + biases.begin(), + thrust::lower_bound(thrust::seq, biases.begin(), biases.end(), r))); + + // FIXME: https://github.com/rapidsai/raft/issues/2400 + // results in the possibility that 1 can appear as a + // random floating point value, which results in the sampling + // algorithm below generating a value that's OOB. + if (result == (offset + biases.size())) --result; + + return result; + }); + } else { + detail::uniform_random_fill(handle.get_stream(), + dst.data() + current_pos, + sample_counts[partition_idx], + graph_view.local_edge_partition_dst_range_first(), + graph_view.local_edge_partition_dst_range_last(), + rng_state); + } + + current_pos += sample_counts[partition_idx]; + } + } + + return std::make_tuple(std::move(src), std::move(dst)); + } + + private: + rmm::device_uvector gpu_bias_v_; + rmm::device_uvector src_bias_v_; + rmm::device_uvector dst_bias_v_; + std::optional< + edge_src_property_t, weight_t>> + src_bias_cache_; + std::optional< + edge_dst_property_t, weight_t>> + dst_bias_cache_; +}; + +} // namespace detail + template , rmm::device_uvector> negativ bool exact_number_of_samples, bool do_expensive_check) { + detail::negative_sampling_impl_t impl( + handle, graph_view, src_bias, dst_bias); + rmm::device_uvector src(0, handle.get_stream()); rmm::device_uvector dst(0, handle.get_stream()); @@ -57,60 +421,16 @@ std::tuple, rmm::device_uvector> negativ (samples_in_this_batch / num_gpus) + (rank < (samples_in_this_batch % num_gpus) ? 1 : 0); } - rmm::device_uvector batch_src(samples_in_this_batch, handle.get_stream()); - rmm::device_uvector batch_dst(samples_in_this_batch, handle.get_stream()); - - if (src_bias) { - detail::biased_random_fill(handle, - rng_state, - raft::device_span{batch_src.data(), batch_src.size()}, - *src_bias); - } else { - detail::uniform_random_fill(handle.get_stream(), - batch_src.data(), - batch_src.size(), - vertex_t{0}, - graph_view.number_of_vertices(), - rng_state); - } - - if (dst_bias) { - detail::biased_random_fill(handle, - rng_state, - raft::device_span{batch_dst.data(), batch_dst.size()}, - *dst_bias); - } else { - detail::uniform_random_fill(handle.get_stream(), - batch_dst.data(), - batch_dst.size(), - vertex_t{0}, - graph_view.number_of_vertices(), - rng_state); - } - - if constexpr (multi_gpu) { - auto vertex_partition_range_lasts = graph_view.vertex_partition_range_lasts(); - - std::tie(batch_src, batch_dst, std::ignore, std::ignore, std::ignore) = - detail::shuffle_int_vertex_pairs_with_values_to_local_gpu_by_edge_partitioning( - handle, - std::move(batch_src), - std::move(batch_dst), - std::nullopt, - std::nullopt, - std::nullopt, - vertex_partition_range_lasts); - } + auto [batch_src, batch_dst] = + impl.create_local_samples(handle, rng_state, graph_view, samples_in_this_batch); if (remove_false_negatives) { auto has_edge_flags = graph_view.has_edge(handle, raft::device_span{batch_src.data(), batch_src.size()}, raft::device_span{batch_dst.data(), batch_dst.size()}, - do_expensive_check); + // do_expensive_check); + true); auto begin_iter = thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin(), has_edge_flags.begin()); @@ -177,16 +497,16 @@ std::tuple, rmm::device_uvector> negativ } if (exact_number_of_samples) { - size_t num_batch_samples = src.size(); + size_t current_sample_size = src.size(); if constexpr (multi_gpu) { - num_batch_samples = cugraph::host_scalar_allreduce( - handle.get_comms(), num_batch_samples, raft::comms::op_t::SUM, handle.get_stream()); + current_sample_size = cugraph::host_scalar_allreduce( + handle.get_comms(), current_sample_size, raft::comms::op_t::SUM, handle.get_stream()); } // FIXME: We could oversample and discard the unnecessary samples // to reduce the number of iterations in the outer loop, but it seems like // exact_number_of_samples is an edge case not worth optimizing for at this time. - samples_in_this_batch = num_samples - num_batch_samples; + samples_in_this_batch = num_samples - current_sample_size; } else { samples_in_this_batch = 0; } diff --git a/cpp/tests/sampling/mg_negative_sampling.cu b/cpp/tests/sampling/mg_negative_sampling.cu index e180594f87b..0bc6bc2e737 100644 --- a/cpp/tests/sampling/mg_negative_sampling.cu +++ b/cpp/tests/sampling/mg_negative_sampling.cu @@ -89,7 +89,7 @@ class Tests_MGNegative_Sampling : public ::testing::TestWithParam> dst_bias{std::nullopt}; if (negative_sampling_usecase.use_src_bias) { - src_bias_v.resize(graph_view.number_of_vertices(), handle_->get_stream()); + src_bias_v.resize(graph_view.local_vertex_partition_range_size(), handle_->get_stream()); cugraph::detail::uniform_random_fill(handle_->get_stream(), src_bias_v.data(), @@ -102,7 +102,7 @@ class Tests_MGNegative_Sampling : public ::testing::TestWithParamget_stream()); + dst_bias_v.resize(graph_view.local_vertex_partition_range_size(), handle_->get_stream()); cugraph::detail::uniform_random_fill(handle_->get_stream(), dst_bias_v.data(), @@ -160,12 +160,19 @@ class Tests_MGNegative_Sampling : public ::testing::TestWithParamget_subcomm(cugraph::partition_manager::major_comm_name()).get_size(), handle_->get_subcomm(cugraph::partition_manager::minor_comm_name()) .get_size()}] __device__(auto e) { + if (gpu_id_key_func(thrust::get<0>(e), thrust::get<1>(e)) != comm_rank) + printf(" gpu_id(%d,%d) = %d, expected %d\n", + (int)thrust::get<0>(e), + (int)thrust::get<1>(e), + gpu_id_key_func(thrust::get<0>(e), thrust::get<1>(e)), + comm_rank); + return (gpu_id_key_func(thrust::get<0>(e), thrust::get<1>(e)) != comm_rank); }); ASSERT_EQ(error_count, 0) << "generate edges out of range > 0"; - if (negative_sampling_usecase.remove_duplicates) { + if ((negative_sampling_usecase.remove_duplicates) && (src_out.size() > 0)) { error_count = thrust::count_if( handle_->get_thrust_policy(), thrust::make_counting_iterator(1), @@ -222,21 +229,9 @@ template Tests_MGNegative_Sampling::handle_ = nullptr; -using Tests_MGNegative_Sampling_File_i32_i32_float = - Tests_MGNegative_Sampling; - -using Tests_MGNegative_Sampling_File_i32_i64_float = - Tests_MGNegative_Sampling; - using Tests_MGNegative_Sampling_File_i64_i64_float = Tests_MGNegative_Sampling; -using Tests_MGNegative_Sampling_Rmat_i32_i32_float = - Tests_MGNegative_Sampling; - -using Tests_MGNegative_Sampling_Rmat_i32_i64_float = - Tests_MGNegative_Sampling; - using Tests_MGNegative_Sampling_Rmat_i64_i64_float = Tests_MGNegative_Sampling; @@ -312,71 +307,23 @@ void run_all_tests(CurrentTest* current_test) Negative_Sampling_Usecase{2, true, true, true, true, true, true}); } -TEST_P(Tests_MGNegative_Sampling_File_i32_i32_float, CheckInt32Int32Float) -{ - load_graph(override_File_Usecase_with_cmd_line_arguments(GetParam())); - run_all_tests(this); -} - -TEST_P(Tests_MGNegative_Sampling_File_i32_i64_float, CheckInt32Int64Float) -{ - load_graph(override_File_Usecase_with_cmd_line_arguments(GetParam())); - run_all_tests(this); -} - TEST_P(Tests_MGNegative_Sampling_File_i64_i64_float, CheckInt64Int64Float) { load_graph(override_File_Usecase_with_cmd_line_arguments(GetParam())); run_all_tests(this); } -TEST_P(Tests_MGNegative_Sampling_Rmat_i32_i32_float, CheckInt32Int32Float) -{ - load_graph(override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); - run_all_tests(this); -} - -TEST_P(Tests_MGNegative_Sampling_Rmat_i32_i64_float, CheckInt32Int64Float) -{ - load_graph(override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); - run_all_tests(this); -} - TEST_P(Tests_MGNegative_Sampling_Rmat_i64_i64_float, CheckInt64Int64Float) { load_graph(override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); run_all_tests(this); } -INSTANTIATE_TEST_SUITE_P( - file_test, - Tests_MGNegative_Sampling_File_i32_i32_float, - ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"))); - -INSTANTIATE_TEST_SUITE_P( - file_test, - Tests_MGNegative_Sampling_File_i32_i64_float, - ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"))); - INSTANTIATE_TEST_SUITE_P( file_test, Tests_MGNegative_Sampling_File_i64_i64_float, ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"))); -INSTANTIATE_TEST_SUITE_P( - file_large_test, - Tests_MGNegative_Sampling_File_i32_i32_float, - ::testing::Values(cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), - cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), - cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx"))); - -INSTANTIATE_TEST_SUITE_P( - file_large_test, - Tests_MGNegative_Sampling_File_i32_i64_float, - ::testing::Values(cugraph::test::File_Usecase("test/datasets/web-Google.mtx"), - cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), - cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx"))); - INSTANTIATE_TEST_SUITE_P( file_large_test, Tests_MGNegative_Sampling_File_i64_i64_float, @@ -384,16 +331,6 @@ INSTANTIATE_TEST_SUITE_P( cugraph::test::File_Usecase("test/datasets/ljournal-2008.mtx"), cugraph::test::File_Usecase("test/datasets/webbase-1M.mtx"))); -INSTANTIATE_TEST_SUITE_P( - rmat_small_test, - Tests_MGNegative_Sampling_Rmat_i32_i32_float, - ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false, 0))); - -INSTANTIATE_TEST_SUITE_P( - rmat_small_test, - Tests_MGNegative_Sampling_Rmat_i32_i64_float, - ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, false, false, 0))); - INSTANTIATE_TEST_SUITE_P( rmat_small_test, Tests_MGNegative_Sampling_Rmat_i64_i64_float, From 06c3d5d6a7edcee10c8c84b96d198582f69da7c8 Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Fri, 9 Aug 2024 15:09:28 -0700 Subject: [PATCH 09/18] Refactor to do biased sampling by vertex partitions instead of exposing inner details of edge partitioning --- cpp/src/detail/utility_wrappers_32.cu | 10 - cpp/src/detail/utility_wrappers_64.cu | 10 - cpp/src/detail/utility_wrappers_impl.cuh | 14 - cpp/src/sampling/negative_sampling_impl.cuh | 494 +++++++----------- .../sampling/detail/nbr_sampling_validate.cu | 2 + 5 files changed, 190 insertions(+), 340 deletions(-) diff --git a/cpp/src/detail/utility_wrappers_32.cu b/cpp/src/detail/utility_wrappers_32.cu index 35dc15079b2..72dee4a19a5 100644 --- a/cpp/src/detail/utility_wrappers_32.cu +++ b/cpp/src/detail/utility_wrappers_32.cu @@ -54,16 +54,6 @@ template void uniform_random_fill(rmm::cuda_stream_view const& stream_view, float max_value, raft::random::RngState& rng_state); -template void biased_random_fill(raft::handle_t const& handle, - raft::random::RngState& rng_state, - raft::device_span output, - raft::device_span biases); - -template void biased_random_fill(raft::handle_t const& handle, - raft::random::RngState& rng_state, - raft::device_span output, - raft::device_span biases); - template void scalar_fill(raft::handle_t const& handle, int32_t* d_value, size_t size, diff --git a/cpp/src/detail/utility_wrappers_64.cu b/cpp/src/detail/utility_wrappers_64.cu index a6dfb5d768c..e7254d97c4d 100644 --- a/cpp/src/detail/utility_wrappers_64.cu +++ b/cpp/src/detail/utility_wrappers_64.cu @@ -54,16 +54,6 @@ template void uniform_random_fill(rmm::cuda_stream_view const& stream_view, double max_value, raft::random::RngState& rng_state); -template void biased_random_fill(raft::handle_t const& handle, - raft::random::RngState& rng_state, - raft::device_span output, - raft::device_span biases); - -template void biased_random_fill(raft::handle_t const& handle, - raft::random::RngState& rng_state, - raft::device_span output, - raft::device_span biases); - template void scalar_fill(raft::handle_t const& handle, int64_t* d_value, size_t size, diff --git a/cpp/src/detail/utility_wrappers_impl.cuh b/cpp/src/detail/utility_wrappers_impl.cuh index f6023c650b8..ce8549db9f8 100644 --- a/cpp/src/detail/utility_wrappers_impl.cuh +++ b/cpp/src/detail/utility_wrappers_impl.cuh @@ -57,20 +57,6 @@ void uniform_random_fill(rmm::cuda_stream_view const& stream_view, } } -template -void biased_random_fill(raft::handle_t const& handle, - raft::random::RngState& rng_state, - raft::device_span output, - raft::device_span biases) -{ - CUGRAPH_EXPECTS(std::is_integral::value, - "biased_random_fill can only output integral values"); - raft::random::discrete(handle, - rng_state, - raft::make_device_vector_view(output.data(), output.size()), - raft::make_device_vector_view(biases.data(), biases.size())); -} - template void scalar_fill(raft::handle_t const& handle, value_t* d_value, size_t size, value_t value) { diff --git a/cpp/src/sampling/negative_sampling_impl.cuh b/cpp/src/sampling/negative_sampling_impl.cuh index a0dc500d9e6..fd14b7f5fef 100644 --- a/cpp/src/sampling/negative_sampling_impl.cuh +++ b/cpp/src/sampling/negative_sampling_impl.cuh @@ -41,181 +41,86 @@ namespace cugraph { namespace detail { -template -class negative_sampling_impl_t { - private: - static const bool store_transposed = false; - - public: - negative_sampling_impl_t( - raft::handle_t const& handle, - graph_view_t const& graph_view, - std::optional> src_bias, - std::optional> dst_bias) - : gpu_bias_v_(0, handle.get_stream()), - src_bias_v_(0, handle.get_stream()), - dst_bias_v_(0, handle.get_stream()), - src_bias_cache_(std::nullopt), - dst_bias_cache_(std::nullopt) - { - // Need to normalize the src_bias - if (src_bias) { - // Normalize the src bias. - rmm::device_uvector normalized_bias(graph_view.local_vertex_partition_range_size(), - handle.get_stream()); - - weight_t sum = reduce_v(handle, graph_view, src_bias->begin()); +template +std::tuple>, + std::optional>> +normalize_biases(raft::handle_t const& handle, + graph_view_t const& graph_view, + std::optional> biases) +{ + std::optional> normalized_biases{std::nullopt}; + std::optional> gpu_biases{std::nullopt}; - if constexpr (multi_gpu) { - sum = host_scalar_allreduce( - handle.get_comms(), sum, raft::comms::op_t::SUM, handle.get_stream()); - } + if (biases) { + // Need to normalize the biases + normalized_biases = + std::make_optional>(biases->size(), handle.get_stream()); - thrust::transform(handle.get_thrust_policy(), - src_bias->begin(), - src_bias->end(), - normalized_bias.begin(), - divider_t{sum}); - - // Distribute the src bias around the edge partitions - src_bias_cache_ = std::make_optional< - edge_src_property_t, weight_t>>( - handle, graph_view); - update_edge_src_property( - handle, graph_view, normalized_bias.begin(), src_bias_cache_->mutable_view()); - } + weight_t sum = + thrust::reduce(handle.get_thrust_policy(), biases->begin(), biases->end(), weight_t{0}); - if (dst_bias) { - // Normalize the dst bias. - rmm::device_uvector normalized_bias(graph_view.local_vertex_partition_range_size(), - handle.get_stream()); + weight_t aggregate_sum{sum}; - weight_t sum = reduce_v(handle, graph_view, dst_bias->begin()); + if constexpr (multi_gpu) { + aggregate_sum = + host_scalar_allreduce(handle.get_comms(), sum, raft::comms::op_t::SUM, handle.get_stream()); + } - if constexpr (multi_gpu) { - sum = host_scalar_allreduce( - handle.get_comms(), sum, raft::comms::op_t::SUM, handle.get_stream()); - } + thrust::transform(handle.get_thrust_policy(), + biases->begin(), + biases->end(), + normalized_biases->begin(), + divider_t{sum}); - thrust::transform(handle.get_thrust_policy(), - dst_bias->begin(), - dst_bias->end(), - normalized_bias.begin(), - divider_t{sum}); - - dst_bias_cache_ = std::make_optional< - edge_dst_property_t, weight_t>>( - handle, graph_view); - update_edge_dst_property( - handle, graph_view, normalized_bias.begin(), dst_bias_cache_->mutable_view()); - } + thrust::inclusive_scan(handle.get_thrust_policy(), + normalized_biases->begin(), + normalized_biases->end(), + normalized_biases->begin()); if constexpr (multi_gpu) { - weight_t dst_bias_sum{0}; - - if (dst_bias) { - // Compute the dst_bias sum for this partition and normalize cached values - dst_bias_sum = thrust::reduce( - handle.get_thrust_policy(), - dst_bias_cache_->view().value_first(), - dst_bias_cache_->view().value_first() + graph_view.local_edge_partition_dst_range_size(), - weight_t{0}); - - thrust::transform(handle.get_thrust_policy(), - dst_bias_cache_->mutable_view().value_first(), - dst_bias_cache_->mutable_view().value_first() + - graph_view.local_edge_partition_dst_range_size(), - dst_bias_cache_->mutable_view().value_first(), - divider_t{dst_bias_sum}); - - thrust::inclusive_scan(handle.get_thrust_policy(), - dst_bias_cache_->mutable_view().value_first(), - dst_bias_cache_->mutable_view().value_first() + - graph_view.local_edge_partition_dst_range_size(), - dst_bias_cache_->mutable_view().value_first()); - } else { - dst_bias_sum = static_cast(graph_view.local_edge_partition_dst_range_size()) / - static_cast(graph_view.number_of_vertices()); - } + rmm::device_scalar d_sum((sum / aggregate_sum), handle.get_stream()); + gpu_biases = cugraph::device_allgatherv( + handle, handle.get_comms(), raft::device_span{d_sum.data(), d_sum.size()}); - std::vector h_gpu_bias; - h_gpu_bias.reserve(graph_view.number_of_local_edge_partitions()); - - for (size_t partition_idx = 0; partition_idx < graph_view.number_of_local_edge_partitions(); - ++partition_idx) { - weight_t src_bias_sum{ - static_cast(graph_view.local_edge_partition_src_range_size(partition_idx)) / - static_cast(graph_view.number_of_vertices())}; - - if (src_bias) { - // Normalize each batch of biases and compute the inclusive prefix sum - src_bias_sum = - thrust::reduce(handle.get_thrust_policy(), - src_bias_cache_->view().value_firsts()[partition_idx], - src_bias_cache_->view().value_firsts()[partition_idx] + - graph_view.local_edge_partition_src_range_size(partition_idx), - weight_t{0}); - - thrust::transform(handle.get_thrust_policy(), - src_bias_cache_->mutable_view().value_firsts()[partition_idx], - src_bias_cache_->mutable_view().value_firsts()[partition_idx] + - graph_view.local_edge_partition_src_range_size(partition_idx), - src_bias_cache_->mutable_view().value_firsts()[partition_idx], - divider_t{src_bias_sum}); - - thrust::inclusive_scan(handle.get_thrust_policy(), - src_bias_cache_->mutable_view().value_firsts()[partition_idx], - src_bias_cache_->mutable_view().value_firsts()[partition_idx] + - graph_view.local_edge_partition_src_range_size(partition_idx), - src_bias_cache_->mutable_view().value_firsts()[partition_idx]); - } - - // Because src_bias and dst_bias are normalized, the probability of a random edge appearing - // on this partition is (src_bias_sum * dst_bias_sum) - h_gpu_bias.push_back(src_bias_sum * dst_bias_sum); - } + thrust::inclusive_scan( + handle.get_thrust_policy(), gpu_biases->begin(), gpu_biases->end(), gpu_biases->begin()); - rmm::device_uvector d_gpu_bias(h_gpu_bias.size(), handle.get_stream()); + weight_t force_to_one{1.1}; raft::update_device( - d_gpu_bias.data(), h_gpu_bias.data(), h_gpu_bias.size(), handle.get_stream()); - - gpu_bias_v_ = cugraph::device_allgatherv( - handle, - handle.get_comms(), - raft::device_span{d_gpu_bias.data(), d_gpu_bias.size()}); - - thrust::inclusive_scan( - handle.get_thrust_policy(), gpu_bias_v_.begin(), gpu_bias_v_.end(), gpu_bias_v_.begin()); - } else { - if (dst_bias_cache_) - thrust::inclusive_scan(handle.get_thrust_policy(), - dst_bias_cache_->mutable_view().value_first(), - dst_bias_cache_->mutable_view().value_first() + - graph_view.local_edge_partition_dst_range_size(), - dst_bias_cache_->mutable_view().value_first()); - - if (src_bias_cache_) - thrust::inclusive_scan(handle.get_thrust_policy(), - src_bias_cache_->mutable_view().value_firsts()[0], - src_bias_cache_->mutable_view().value_firsts()[0] + - graph_view.local_edge_partition_src_range_size(0), - src_bias_cache_->mutable_view().value_firsts()[0]); + gpu_biases->data() + gpu_biases->size() - 1, &force_to_one, 1, handle.get_stream()); } } - std::tuple, rmm::device_uvector> create_local_samples( - raft::handle_t const& handle, - raft::random::RngState& rng_state, - graph_view_t const& graph_view, - size_t num_samples) - { - rmm::device_uvector src(0, handle.get_stream()); - rmm::device_uvector dst(0, handle.get_stream()); + return std::make_tuple(std::move(normalized_biases), std::move(gpu_biases)); +} + +template +rmm::device_uvector create_local_samples( + raft::handle_t const& handle, + raft::random::RngState& rng_state, + graph_view_t const& graph_view, + std::optional> const& normalized_biases, + std::optional> const& gpu_biases, + size_t samples_in_this_batch) +{ + rmm::device_uvector samples(0, handle.get_stream()); + + if (normalized_biases) { + size_t samples_to_generate{samples_in_this_batch}; + std::vector sample_count_from_each_gpu; - std::vector sample_counts; + rmm::device_uvector position(0, handle.get_stream()); - // Determine sample counts per GPU edge partition if constexpr (multi_gpu) { + // Determine how many vertices are generated on each GPU auto const comm_size = handle.get_comms().get_size(); auto const rank = handle.get_comms().get_rank(); auto& major_comm = handle.get_subcomm(cugraph::partition_manager::major_comm_name()); @@ -223,11 +128,15 @@ class negative_sampling_impl_t { auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); auto const minor_comm_size = minor_comm.get_size(); - // First step is to count how many go on each edge_partition - rmm::device_uvector gpu_counts(gpu_bias_v_.size(), handle.get_stream()); - thrust::fill(handle.get_thrust_policy(), gpu_counts.begin(), gpu_counts.end(), int{0}); + sample_count_from_each_gpu.resize(comm_size); - rmm::device_uvector random_values(num_samples, handle.get_stream()); + rmm::device_uvector gpu_counts(comm_size, handle.get_stream()); + position.resize(samples_in_this_batch, handle.get_stream()); + + thrust::fill(handle.get_thrust_policy(), gpu_counts.begin(), gpu_counts.end(), size_t{0}); + thrust::sequence(handle.get_thrust_policy(), position.begin(), position.end()); + + rmm::device_uvector random_values(samples_in_this_batch, handle.get_stream()); detail::uniform_random_fill(handle.get_stream(), random_values.data(), random_values.size(), @@ -235,154 +144,107 @@ class negative_sampling_impl_t { weight_t{1}, rng_state); - thrust::sort(handle.get_thrust_policy(), random_values.begin(), random_values.end()); + thrust::sort(handle.get_thrust_policy(), + thrust::make_zip_iterator(random_values.begin(), position.begin()), + thrust::make_zip_iterator(random_values.end(), position.end())); thrust::upper_bound(handle.get_thrust_policy(), random_values.begin(), random_values.end(), - gpu_bias_v_.begin(), - gpu_bias_v_.end(), + gpu_biases->begin(), + gpu_biases->end(), gpu_counts.begin()); thrust::adjacent_difference( handle.get_thrust_policy(), gpu_counts.begin(), gpu_counts.end(), gpu_counts.begin()); - device_allreduce(handle.get_comms(), - gpu_counts.begin(), - gpu_counts.begin(), - gpu_counts.size(), - raft::comms::op_t::SUM, - handle.get_stream()); - - num_samples = thrust::reduce(handle.get_thrust_policy(), - gpu_counts.begin() + rank * minor_comm_size, - gpu_counts.begin() + rank * minor_comm_size + minor_comm_size, - size_t{0}); - - sample_counts.resize(minor_comm_size); - raft::update_host(sample_counts.data(), - gpu_counts.data() + rank * minor_comm_size, - minor_comm_size, - handle.get_stream()); + // all_gpu_counts[i][j] will be how many vertices need to be generated on GPU j to be sent to + // GPU i + auto all_gpu_counts = cugraph::device_allgatherv( + handle, + handle.get_comms(), + raft::device_span{gpu_counts.data(), gpu_counts.size()}); - } else { - // SG is only one partition - sample_counts.push_back(num_samples); - } + auto begin_iter = thrust::make_transform_iterator( + thrust::make_counting_iterator(0), + cuda::proclaim_return_type( + [rank, stride = comm_size, counts = all_gpu_counts.data()] __device__(size_t idx) { + return counts[idx * stride + rank]; + })); - src.resize(num_samples, handle.get_stream()); - dst.resize(num_samples, handle.get_stream()); - - size_t current_pos{0}; - - for (size_t partition_idx = 0; partition_idx < graph_view.number_of_local_edge_partitions(); - ++partition_idx) { - if (sample_counts[partition_idx] > 0) { - if (src_bias_cache_) { - rmm::device_uvector random_values(sample_counts[partition_idx], - handle.get_stream()); - detail::uniform_random_fill(handle.get_stream(), - random_values.data(), - random_values.size(), - weight_t{0}, - weight_t{1}, - rng_state); - - thrust::transform( - handle.get_thrust_policy(), - random_values.begin(), - random_values.end(), - src.begin() + current_pos, - [biases = - raft::device_span{ - src_bias_cache_->view().value_firsts()[partition_idx], - static_cast( - graph_view.local_edge_partition_src_range_size(partition_idx))}, - offset = graph_view.local_edge_partition_src_range_first( - partition_idx)] __device__(weight_t r) { - size_t result = - offset + static_cast(thrust::distance( - biases.begin(), - thrust::lower_bound(thrust::seq, biases.begin(), biases.end(), r))); - - // FIXME: https://github.com/rapidsai/raft/issues/2400 - // results in the possibility that 1 can appear as a - // random floating point value, which results in the sampling - // algorithm below generating a value that's OOB. - if (result == (offset + biases.size())) --result; - - return result; - }); - } else { - detail::uniform_random_fill( - handle.get_stream(), - src.data() + current_pos, - sample_counts[partition_idx], - graph_view.local_edge_partition_src_range_first(partition_idx), - graph_view.local_edge_partition_src_range_last(partition_idx), - rng_state); - } - - if (dst_bias_cache_) { - rmm::device_uvector random_values(sample_counts[partition_idx], - handle.get_stream()); - detail::uniform_random_fill(handle.get_stream(), - random_values.data(), - random_values.size(), - weight_t{0}, - weight_t{1}, - rng_state); - - thrust::transform( - handle.get_thrust_policy(), - random_values.begin(), - random_values.end(), - dst.begin() + current_pos, - [biases = - raft::device_span{ - dst_bias_cache_->view().value_first(), - static_cast(graph_view.local_edge_partition_dst_range_size())}, - offset = graph_view.local_edge_partition_dst_range_first()] __device__(weight_t r) { - size_t result = - offset + static_cast(thrust::distance( - biases.begin(), - thrust::lower_bound(thrust::seq, biases.begin(), biases.end(), r))); - - // FIXME: https://github.com/rapidsai/raft/issues/2400 - // results in the possibility that 1 can appear as a - // random floating point value, which results in the sampling - // algorithm below generating a value that's OOB. - if (result == (offset + biases.size())) --result; - - return result; - }); - } else { - detail::uniform_random_fill(handle.get_stream(), - dst.data() + current_pos, - sample_counts[partition_idx], - graph_view.local_edge_partition_dst_range_first(), - graph_view.local_edge_partition_dst_range_last(), - rng_state); - } - - current_pos += sample_counts[partition_idx]; - } + samples_to_generate = + thrust::reduce(handle.get_thrust_policy(), begin_iter, begin_iter + comm_size, size_t{0}); + + rmm::device_uvector d_sample_count_from_each_gpu(comm_size, handle.get_stream()); + + thrust::copy(handle.get_thrust_policy(), + begin_iter, + begin_iter + comm_size, + d_sample_count_from_each_gpu.begin()); + + raft::update_host(sample_count_from_each_gpu.data(), + d_sample_count_from_each_gpu.data(), + d_sample_count_from_each_gpu.size(), + handle.get_stream()); } - return std::make_tuple(std::move(src), std::move(dst)); + // Generate samples + // FIXME: We could save this memory if we had an iterator that + // generated random values. + rmm::device_uvector random_values(samples_to_generate, handle.get_stream()); + samples.resize(samples_to_generate, handle.get_stream()); + detail::uniform_random_fill(handle.get_stream(), + random_values.data(), + random_values.size(), + weight_t{0}, + weight_t{1}, + rng_state); + + thrust::transform( + handle.get_thrust_policy(), + random_values.begin(), + random_values.end(), + samples.begin(), + [biases = + raft::device_span{normalized_biases->data(), normalized_biases->size()}, + offset = graph_view.local_vertex_partition_range_first()] __device__(weight_t r) { + size_t result = + offset + + static_cast(thrust::distance( + biases.begin(), thrust::lower_bound(thrust::seq, biases.begin(), biases.end(), r))); + + // FIXME: https://github.com/rapidsai/raft/issues/2400 + // results in the possibility that 1 can appear as a + // random floating point value, which results in the sampling + // algorithm below generating a value that's OOB. + if (result == (offset + biases.size())) --result; + + return result; + }); + + // Shuffle them back + if constexpr (multi_gpu) { + std::tie(samples, std::ignore) = shuffle_values( + handle.get_comms(), samples.begin(), sample_count_from_each_gpu, handle.get_stream()); + + thrust::sort(handle.get_thrust_policy(), + thrust::make_zip_iterator(position.begin(), samples.begin()), + thrust::make_zip_iterator(position.end(), samples.begin())); + } + } else { + samples.resize(samples_in_this_batch, handle.get_stream()); + + // Uniformly select a vertex from any GPU + detail::uniform_random_fill(handle.get_stream(), + samples.data(), + samples.size(), + vertex_t{0}, + graph_view.number_of_vertices(), + rng_state); } - private: - rmm::device_uvector gpu_bias_v_; - rmm::device_uvector src_bias_v_; - rmm::device_uvector dst_bias_v_; - std::optional< - edge_src_property_t, weight_t>> - src_bias_cache_; - std::optional< - edge_dst_property_t, weight_t>> - dst_bias_cache_; -}; + return samples; +} } // namespace detail @@ -396,22 +258,26 @@ std::tuple, rmm::device_uvector> negativ raft::random::RngState& rng_state, graph_view_t const& graph_view, size_t num_samples, - std::optional> src_bias, - std::optional> dst_bias, + std::optional> src_biases, + std::optional> dst_biases, bool remove_duplicates, - bool remove_false_negatives, + bool remove_existing_edges, bool exact_number_of_samples, bool do_expensive_check) { - detail::negative_sampling_impl_t impl( - handle, graph_view, src_bias, dst_bias); - rmm::device_uvector src(0, handle.get_stream()); rmm::device_uvector dst(0, handle.get_stream()); // Optimistically assume we can do this in one pass size_t samples_in_this_batch = num_samples; + // Normalize the biases and (for MG) determine how the biases are + // distributed across the GPUs. + auto [normalized_src_biases, gpu_src_biases] = + detail::normalize_biases(handle, graph_view, src_biases); + auto [normalized_dst_biases, gpu_dst_biases] = + detail::normalize_biases(handle, graph_view, dst_biases); + while (samples_in_this_batch > 0) { if constexpr (multi_gpu) { size_t num_gpus = handle.get_comms().get_size(); @@ -421,16 +287,32 @@ std::tuple, rmm::device_uvector> negativ (samples_in_this_batch / num_gpus) + (rank < (samples_in_this_batch % num_gpus) ? 1 : 0); } - auto [batch_src, batch_dst] = - impl.create_local_samples(handle, rng_state, graph_view, samples_in_this_batch); + auto batch_src = create_local_samples( + handle, rng_state, graph_view, normalized_src_biases, gpu_src_biases, samples_in_this_batch); + auto batch_dst = create_local_samples( + handle, rng_state, graph_view, normalized_dst_biases, gpu_dst_biases, samples_in_this_batch); + + auto vertex_partition_range_lasts = graph_view.vertex_partition_range_lasts(); - if (remove_false_negatives) { + std::tie(batch_src, batch_dst, std::ignore, std::ignore, std::ignore, std::ignore) = + detail::shuffle_int_vertex_pairs_with_values_to_local_gpu_by_edge_partitioning( + handle, + std::move(batch_src), + std::move(batch_dst), + std::nullopt, + std::nullopt, + std::nullopt, + vertex_partition_range_lasts); + + if (remove_existing_edges) { auto has_edge_flags = graph_view.has_edge(handle, raft::device_span{batch_src.data(), batch_src.size()}, raft::device_span{batch_dst.data(), batch_dst.size()}, - // do_expensive_check); - true); + do_expensive_check); auto begin_iter = thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin(), has_edge_flags.begin()); diff --git a/cpp/tests/sampling/detail/nbr_sampling_validate.cu b/cpp/tests/sampling/detail/nbr_sampling_validate.cu index 61731e2e15c..70828e559f1 100644 --- a/cpp/tests/sampling/detail/nbr_sampling_validate.cu +++ b/cpp/tests/sampling/detail/nbr_sampling_validate.cu @@ -75,6 +75,8 @@ struct ArithmeticZipLess { } else { return thrust::get<1>(left) < thrust::get<1>(right); } + } else { + return false; } } }; From 4de3bda15857cd988fbd594d102d90f7cdeca59b Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Fri, 9 Aug 2024 21:06:46 -0700 Subject: [PATCH 10/18] address other PR comments --- .../cugraph/detail/utility_wrappers.hpp | 23 ----- cpp/include/cugraph/sampling_functions.hpp | 19 +++-- cpp/include/cugraph_c/sampling_algorithms.h | 29 +++---- cpp/src/c_api/negative_sampling.cpp | 82 +++++++++--------- cpp/src/sampling/negative_sampling_impl.cuh | 83 ++++++++++--------- cpp/tests/sampling/mg_negative_sampling.cu | 19 ++++- 6 files changed, 127 insertions(+), 128 deletions(-) diff --git a/cpp/include/cugraph/detail/utility_wrappers.hpp b/cpp/include/cugraph/detail/utility_wrappers.hpp index fc75f06b373..61ac1bd2804 100644 --- a/cpp/include/cugraph/detail/utility_wrappers.hpp +++ b/cpp/include/cugraph/detail/utility_wrappers.hpp @@ -50,29 +50,6 @@ void uniform_random_fill(rmm::cuda_stream_view const& stream_view, value_t max_value, raft::random::RngState& rng_state); -/** - * @brief Fill a buffer with biased random values - * - * Fills a buffer with values based on the specified biases. - * The probability of selecting the value `i` is determined by - * `biases[i] / sum(biases)`. - * - * @tparam value_t type of the value to operate on - * @tparam bias_t type of the bias - * - * @param[in] handle RAFT handle object to encapsulate resources (e.g. CUDA stream, - * communicator, and handles to various CUDA libraries) to run graph algorithms. - * @param[in] rng_state The RngState instance holding pseudo-random number generator state. - * @param[out] output The random values - * @param[in] biases The biased values - * - */ -template -void biased_random_fill(raft::handle_t const& handle, - raft::random::RngState& rng_state, - raft::device_span output, - raft::device_span biases); - /** * @brief Fill a buffer with a constant value * diff --git a/cpp/include/cugraph/sampling_functions.hpp b/cpp/include/cugraph/sampling_functions.hpp index 88854ecc0ea..9c747c54745 100644 --- a/cpp/include/cugraph/sampling_functions.hpp +++ b/cpp/include/cugraph/sampling_functions.hpp @@ -769,13 +769,14 @@ lookup_endpoints_from_edge_ids_and_types( * @param graph_view Graph View object to generate NBR Sampling for * @param rng_state RNG state * @param num_samples Number of negative samples to generate - * @param src_bias Optional bias for randomly selecting source vertices. If std::nullopt vertices - * will be selected uniformly - * @param dst_bias Optional bias for randomly selecting destination vertices. If std::nullopt - * vertices will be selected uniformly + * @param src_biases Optional bias for randomly selecting source vertices. If std::nullopt vertices + * will be selected uniformly. In multi-GPU environment the biases should be partitioned based + * on the vertex partitions. + * @param dst_biases Optional bias for randomly selecting destination vertices. If std::nullopt + * vertices will be selected uniformly. In multi-GPU environment the biases should be partitioned + * based on the vertex partitions. * @param remove_duplicates If true, remove duplicate samples - * @param remove_false_negatives If true, remove false negatives (samples that are actually edges in - * the graph + * @param remove_existing_edges If true, remove samples that are actually edges in the graph * @param exact_number_of_samples If true, repeat generation until we get the exact number of * negative samples * @param do_expensive_check A flag to run expensive checks for input arguments (if set to `true`). @@ -792,10 +793,10 @@ std::tuple, rmm::device_uvector> negativ raft::random::RngState& rng_state, graph_view_t const& graph_view, size_t num_samples, - std::optional> src_bias, - std::optional> dst_bias, + std::optional> src_biases, + std::optional> dst_biases, bool remove_duplicates, - bool remove_false_negatives, + bool remove_existing_edges, bool exact_number_of_samples, bool do_expensive_check); diff --git a/cpp/include/cugraph_c/sampling_algorithms.h b/cpp/include/cugraph_c/sampling_algorithms.h index ae2bd277787..ad3f8b9ea2c 100644 --- a/cpp/include/cugraph_c/sampling_algorithms.h +++ b/cpp/include/cugraph_c/sampling_algorithms.h @@ -683,31 +683,32 @@ cugraph_error_code_t cugraph_select_random_vertices(const cugraph_resource_handl * * @param [in] handle Handle for accessing resources * @param [in,out] rng_state State of the random number generator, updated with each - * call + * call * @param [in] graph Pointer to graph * @param [in] num_samples Number of negative samples to generate * @param [in] vertices Vertex ids for the source biases. If @p src_bias and * @p dst_bias are not specified this is ignored. If * @p vertices is specified then vertices[i] is the vertex - * id of src_bias[i] and dst_bias[i]. If @p vertices is not specified then i is the vertex id if - * src_bias[i] and dst_bias[i] - * @param [in] src_bias Bias for selecting source vertices. If NULL, do uniform + * id of src_biases[i] and dst_biases[i]. If @p vertices + * is not specified then i is the vertex id if src_biases[i] + * and dst_biases[i] + * @param [in] src_biases Bias for selecting source vertices. If NULL, do uniform * sampling, if provided probability of vertex i will be * src_bias[i] / (sum of all source biases) - * @param [in] dst_bias Bias for selecting destination vertices. If NULL, do + * @param [in] dst_biases Bias for selecting destination vertices. If NULL, do * uniform sampling, if provided probability of vertex i - * will be dst_bias[i] / (sum of all destination biases) + * will be dst_bias[i] / (sum of all destination biases) * @param [in] remove_duplicates If true, remove duplicates from sampled edges - * @param [in] remove_false_negatives If true, remove sampled edges that actually exist in the - * graph + * @param [in] remove_existing_edges If true, remove sampled edges that actually exist in + * the graph * @param [in] exact_number_of_samples If true, result should contain exactly @p num_samples. If * false the code will generate @p num_samples and then do - * any filtering as specified + * any filtering as specified * @param [in] do_expensive_check A flag to run expensive checks for input arguments (if - * set to true) + * set to true) * @param [out] result Opaque pointer to generated coo list * @param [out] error Pointer to an error object storing details of any error. - * Will be populated if error code is not CUGRAPH_SUCCESS + * Will be populated if error code is not CUGRAPH_SUCCESS * @return error code */ cugraph_error_code_t cugraph_negative_sampling( @@ -716,10 +717,10 @@ cugraph_error_code_t cugraph_negative_sampling( cugraph_graph_t* graph, size_t num_samples, const cugraph_type_erased_device_array_view_t* vertices, - const cugraph_type_erased_device_array_view_t* src_bias, - const cugraph_type_erased_device_array_view_t* dst_bias, + const cugraph_type_erased_device_array_view_t* src_biases, + const cugraph_type_erased_device_array_view_t* dst_biases, bool_t remove_duplicates, - bool_t remove_false_negatives, + bool_t remove_existing_edges, bool_t exact_number_of_samples, bool_t do_expensive_check, cugraph_coo_t** result, diff --git a/cpp/src/c_api/negative_sampling.cpp b/cpp/src/c_api/negative_sampling.cpp index 1996755e536..4db5a8b8535 100644 --- a/cpp/src/c_api/negative_sampling.cpp +++ b/cpp/src/c_api/negative_sampling.cpp @@ -38,10 +38,10 @@ struct negative_sampling_functor : public cugraph::c_api::abstract_functor { cugraph::c_api::cugraph_graph_t* graph_{nullptr}; size_t num_samples_; cugraph::c_api::cugraph_type_erased_device_array_view_t const* vertices_{nullptr}; - cugraph::c_api::cugraph_type_erased_device_array_view_t const* src_bias_{nullptr}; - cugraph::c_api::cugraph_type_erased_device_array_view_t const* dst_bias_{nullptr}; + cugraph::c_api::cugraph_type_erased_device_array_view_t const* src_biases_{nullptr}; + cugraph::c_api::cugraph_type_erased_device_array_view_t const* dst_biases_{nullptr}; bool remove_duplicates_{false}; - bool remove_false_negatives_{false}; + bool remove_existing_edges_{false}; bool exact_number_of_samples_{false}; bool do_expensive_check_{false}; cugraph::c_api::cugraph_coo_t* result_{nullptr}; @@ -51,10 +51,10 @@ struct negative_sampling_functor : public cugraph::c_api::abstract_functor { cugraph_graph_t* graph, size_t num_samples, const cugraph_type_erased_device_array_view_t* vertices, - const cugraph_type_erased_device_array_view_t* src_bias, - const cugraph_type_erased_device_array_view_t* dst_bias, + const cugraph_type_erased_device_array_view_t* src_biases, + const cugraph_type_erased_device_array_view_t* dst_biases, bool_t remove_duplicates, - bool_t remove_false_negatives, + bool_t remove_existing_edges, bool_t exact_number_of_samples, bool_t do_expensive_check) : abstract_functor(), @@ -64,12 +64,12 @@ struct negative_sampling_functor : public cugraph::c_api::abstract_functor { num_samples_(num_samples), vertices_( reinterpret_cast(vertices)), - src_bias_( - reinterpret_cast(src_bias)), - dst_bias_( - reinterpret_cast(dst_bias)), + src_biases_(reinterpret_cast( + src_biases)), + dst_biases_(reinterpret_cast( + dst_biases)), remove_duplicates_(remove_duplicates), - remove_false_negatives_(remove_false_negatives), + remove_existing_edges_(remove_existing_edges), exact_number_of_samples_(exact_number_of_samples), do_expensive_check_(do_expensive_check) { @@ -103,25 +103,25 @@ struct negative_sampling_functor : public cugraph::c_api::abstract_functor { auto number_map = reinterpret_cast*>(graph_->number_map_); rmm::device_uvector vertices(0, handle_.get_stream()); - rmm::device_uvector src_bias(0, handle_.get_stream()); - rmm::device_uvector dst_bias(0, handle_.get_stream()); + rmm::device_uvector src_biases(0, handle_.get_stream()); + rmm::device_uvector dst_biases(0, handle_.get_stream()); - // TODO: What is required here? - - if (src_bias_ != nullptr) { + if (src_biases_ != nullptr) { vertices.resize(vertices_->size_, handle_.get_stream()); - src_bias.resize(src_bias_->size_, handle_.get_stream()); + src_biases.resize(src_biases_->size_, handle_.get_stream()); raft::copy( vertices.data(), vertices_->as_type(), vertices.size(), handle_.get_stream()); - raft::copy( - src_bias.data(), src_bias_->as_type(), src_bias.size(), handle_.get_stream()); + raft::copy(src_biases.data(), + src_biases_->as_type(), + src_biases.size(), + handle_.get_stream()); - src_bias = cugraph::detail:: + src_biases = cugraph::detail:: collect_local_vertex_values_from_ext_vertex_value_pairs( handle_, std::move(vertices), - std::move(src_bias), + std::move(src_biases), *number_map, graph_view.local_vertex_partition_range_first(), graph_view.local_vertex_partition_range_last(), @@ -129,20 +129,22 @@ struct negative_sampling_functor : public cugraph::c_api::abstract_functor { do_expensive_check_); } - if (dst_bias_ != nullptr) { + if (dst_biases_ != nullptr) { vertices.resize(vertices_->size_, handle_.get_stream()); - dst_bias.resize(dst_bias_->size_, handle_.get_stream()); + dst_biases.resize(dst_biases_->size_, handle_.get_stream()); raft::copy( vertices.data(), vertices_->as_type(), vertices.size(), handle_.get_stream()); - raft::copy( - dst_bias.data(), dst_bias_->as_type(), dst_bias.size(), handle_.get_stream()); + raft::copy(dst_biases.data(), + dst_biases_->as_type(), + dst_biases.size(), + handle_.get_stream()); - dst_bias = cugraph::detail:: + dst_biases = cugraph::detail:: collect_local_vertex_values_from_ext_vertex_value_pairs( handle_, std::move(vertices), - std::move(dst_bias), + std::move(dst_biases), *number_map, graph_view.local_vertex_partition_range_first(), graph_view.local_vertex_partition_range_last(), @@ -155,14 +157,14 @@ struct negative_sampling_functor : public cugraph::c_api::abstract_functor { rng_state_->rng_state_, graph_view, num_samples_, - (src_bias_ != nullptr) - ? std::make_optional(raft::device_span{src_bias.data(), src_bias.size()}) - : std::nullopt, - (dst_bias_ != nullptr) - ? std::make_optional(raft::device_span{dst_bias.data(), dst_bias.size()}) - : std::nullopt, + (src_biases_ != nullptr) ? std::make_optional(raft::device_span{ + src_biases.data(), src_biases.size()}) + : std::nullopt, + (dst_biases_ != nullptr) ? std::make_optional(raft::device_span{ + dst_biases.data(), dst_biases.size()}) + : std::nullopt, remove_duplicates_, - remove_false_negatives_, + remove_existing_edges_, exact_number_of_samples_, do_expensive_check_); @@ -202,10 +204,10 @@ cugraph_error_code_t cugraph_negative_sampling( cugraph_graph_t* graph, size_t num_samples, const cugraph_type_erased_device_array_view_t* vertices, - const cugraph_type_erased_device_array_view_t* src_bias, - const cugraph_type_erased_device_array_view_t* dst_bias, + const cugraph_type_erased_device_array_view_t* src_biases, + const cugraph_type_erased_device_array_view_t* dst_biases, bool_t remove_duplicates, - bool_t remove_false_negatives, + bool_t remove_existing_edges, bool_t exact_number_of_samples, bool_t do_expensive_check, cugraph_coo_t** result, @@ -216,10 +218,10 @@ cugraph_error_code_t cugraph_negative_sampling( graph, num_samples, vertices, - src_bias, - dst_bias, + src_biases, + dst_biases, remove_duplicates, - remove_false_negatives, + remove_existing_edges, exact_number_of_samples, do_expensive_check}; return cugraph::c_api::run_algorithm(graph, functor, result, error); diff --git a/cpp/src/sampling/negative_sampling_impl.cuh b/cpp/src/sampling/negative_sampling_impl.cuh index fd14b7f5fef..881e2691ca5 100644 --- a/cpp/src/sampling/negative_sampling_impl.cuh +++ b/cpp/src/sampling/negative_sampling_impl.cuh @@ -280,11 +280,11 @@ std::tuple, rmm::device_uvector> negativ while (samples_in_this_batch > 0) { if constexpr (multi_gpu) { - size_t num_gpus = handle.get_comms().get_size(); - size_t rank = handle.get_comms().get_rank(); + size_t comm_size = handle.get_comms().get_size(); + size_t comm_rank = handle.get_comms().get_rank(); - samples_in_this_batch = - (samples_in_this_batch / num_gpus) + (rank < (samples_in_this_batch % num_gpus) ? 1 : 0); + samples_in_this_batch = (samples_in_this_batch / comm_size) + + (comm_rank < (samples_in_this_batch % comm_size) ? 1 : 0); } auto batch_src = create_local_samples( @@ -292,20 +292,22 @@ std::tuple, rmm::device_uvector> negativ auto batch_dst = create_local_samples( handle, rng_state, graph_view, normalized_dst_biases, gpu_dst_biases, samples_in_this_batch); - auto vertex_partition_range_lasts = graph_view.vertex_partition_range_lasts(); - - std::tie(batch_src, batch_dst, std::ignore, std::ignore, std::ignore, std::ignore) = - detail::shuffle_int_vertex_pairs_with_values_to_local_gpu_by_edge_partitioning( - handle, - std::move(batch_src), - std::move(batch_dst), - std::nullopt, - std::nullopt, - std::nullopt, - vertex_partition_range_lasts); + if constexpr (multi_gpu) { + auto vertex_partition_range_lasts = graph_view.vertex_partition_range_lasts(); + + std::tie(batch_src, batch_dst, std::ignore, std::ignore, std::ignore, std::ignore) = + detail::shuffle_int_vertex_pairs_with_values_to_local_gpu_by_edge_partitioning( + handle, + std::move(batch_src), + std::move(batch_dst), + std::nullopt, + std::nullopt, + std::nullopt, + vertex_partition_range_lasts); + } if (remove_existing_edges) { auto has_edge_flags = @@ -314,12 +316,13 @@ std::tuple, rmm::device_uvector> negativ raft::device_span{batch_dst.data(), batch_dst.size()}, do_expensive_check); - auto begin_iter = - thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin(), has_edge_flags.begin()); - auto new_end = thrust::remove_if(handle.get_thrust_policy(), + auto begin_iter = thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin()); + auto new_end = thrust::remove_if(handle.get_thrust_policy(), begin_iter, begin_iter + batch_src.size(), - [] __device__(auto tuple) { return thrust::get<2>(tuple); }); + has_edge_flags.begin(), + thrust::identity()); + batch_src.resize(thrust::distance(begin_iter, new_end), handle.get_stream()); batch_dst.resize(thrust::distance(begin_iter, new_end), handle.get_stream()); } @@ -331,13 +334,14 @@ std::tuple, rmm::device_uvector> negativ auto new_end = thrust::unique(handle.get_thrust_policy(), begin_iter, begin_iter + batch_src.size()); - size_t unique_size = thrust::distance(begin_iter, new_end); + batch_src.resize(thrust::distance(begin_iter, new_end), handle.get_stream()); + batch_dst.resize(thrust::distance(begin_iter, new_end), handle.get_stream()); if (src.size() > 0) { new_end = thrust::remove_if(handle.get_thrust_policy(), begin_iter, - begin_iter + unique_size, + begin_iter + batch_src.size(), [local_src = raft::device_span{src.data(), src.size()}, local_dst = raft::device_span{ dst.data(), dst.size()}] __device__(auto tuple) { @@ -348,14 +352,25 @@ std::tuple, rmm::device_uvector> negativ tuple); }); - unique_size = thrust::distance(begin_iter, new_end); - } + size_t unique_size = thrust::distance(begin_iter, new_end); - batch_src.resize(unique_size, handle.get_stream()); - batch_dst.resize(unique_size, handle.get_stream()); - } + rmm::device_uvector new_src(src.size() + unique_size, handle.get_stream()); + rmm::device_uvector new_dst(dst.size() + unique_size, handle.get_stream()); - if (src.size() > 0) { + thrust::merge(handle.get_thrust_policy(), + begin_iter, + begin_iter + unique_size, + thrust::make_zip_iterator(src.begin(), dst.begin()), + thrust::make_zip_iterator(src.end(), dst.end()), + thrust::make_zip_iterator(new_src.begin(), new_dst.begin())); + + src = std::move(new_src); + dst = std::move(new_dst); + } else { + src = std::move(batch_src); + dst = std::move(batch_dst); + } + } else if (src.size() > 0) { size_t current_end = src.size(); src.resize(src.size() + batch_src.size(), handle.get_stream()); @@ -365,17 +380,9 @@ std::tuple, rmm::device_uvector> negativ thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin()), thrust::make_zip_iterator(batch_src.end(), batch_dst.end()), thrust::make_zip_iterator(src.begin(), dst.begin()) + current_end); - - auto begin_iter = thrust::make_zip_iterator(src.begin(), dst.begin()); - thrust::sort(handle.get_thrust_policy(), begin_iter, begin_iter + src.size()); } else { src = std::move(batch_src); dst = std::move(batch_dst); - - if (!remove_duplicates) { - auto begin_iter = thrust::make_zip_iterator(src.begin(), dst.begin()); - thrust::sort(handle.get_thrust_policy(), begin_iter, begin_iter + src.size()); - } } if (exact_number_of_samples) { diff --git a/cpp/tests/sampling/mg_negative_sampling.cu b/cpp/tests/sampling/mg_negative_sampling.cu index 0bc6bc2e737..0d556f11810 100644 --- a/cpp/tests/sampling/mg_negative_sampling.cu +++ b/cpp/tests/sampling/mg_negative_sampling.cu @@ -30,7 +30,7 @@ struct Negative_Sampling_Usecase { bool use_src_bias{false}; bool use_dst_bias{false}; bool remove_duplicates{false}; - bool remove_false_negatives{false}; + bool remove_existing_edges{false}; bool exact_number_of_samples{false}; bool check_correctness{true}; }; @@ -127,7 +127,7 @@ class Tests_MGNegative_Sampling : public ::testing::TestWithParam 0"; if ((negative_sampling_usecase.remove_duplicates) && (src_out.size() > 0)) { +#if 0 + raft::print_device_vector("SRC", src_out.data(), src_out.size(), std::cout); + raft::print_device_vector("DST", dst_out.data(), dst_out.size(), std::cout); +#endif + error_count = thrust::count_if( handle_->get_thrust_policy(), thrust::make_counting_iterator(1), thrust::make_counting_iterator(src_out.size()), [src = src_out.data(), dst = dst_out.data()] __device__(size_t index) { + if ((src[index - 1] == src[index]) && (dst[index - 1] == dst[index])) + printf(" (%d,%d) : (%d, %d) are duplicates\n", + (int)src[index - 1], + (int)dst[index - 1], + (int)src[index], + (int)dst[index]); return (src[index - 1] == src[index]) && (dst[index - 1] == dst[index]); }); ASSERT_EQ(error_count, 0) << "Remove duplicates specified, found duplicate entries"; } - if (negative_sampling_usecase.remove_false_negatives) { + if (negative_sampling_usecase.remove_existing_edges) { rmm::device_uvector graph_src(0, handle_->get_stream()); rmm::device_uvector graph_dst(0, handle_->get_stream()); @@ -203,7 +214,7 @@ class Tests_MGNegative_Sampling : public ::testing::TestWithParam Date: Tue, 13 Aug 2024 11:28:53 -0700 Subject: [PATCH 11/18] Fix a few straggling references to remove_false_negatives, refactor a bit to address CUDA 11.8 failures --- cpp/include/cugraph/sampling_functions.hpp | 4 +- cpp/src/sampling/negative_sampling_impl.cuh | 51 +++++++++---------- .../sampling/negative_sampling_mg_v32_e32.cu | 4 +- .../sampling/negative_sampling_mg_v32_e64.cu | 4 +- .../sampling/negative_sampling_mg_v64_e64.cu | 4 +- .../sampling/negative_sampling_sg_v32_e32.cu | 4 +- .../sampling/negative_sampling_sg_v32_e64.cu | 4 +- .../sampling/negative_sampling_sg_v64_e64.cu | 4 +- cpp/tests/sampling/negative_sampling.cu | 45 +++++++++++----- 9 files changed, 72 insertions(+), 52 deletions(-) diff --git a/cpp/include/cugraph/sampling_functions.hpp b/cpp/include/cugraph/sampling_functions.hpp index 9c747c54745..208964bdf03 100644 --- a/cpp/include/cugraph/sampling_functions.hpp +++ b/cpp/include/cugraph/sampling_functions.hpp @@ -749,12 +749,12 @@ lookup_endpoints_from_edge_ids_and_types( * This function generates negative samples for graph. * * Negative sampling is done by generating a random graph according to the specified - * parameters and optionally removing the false negatives. + * parameters and optionally removing samples that represent actual edges in the graph * * Sampling occurs by creating a list of source vertex ids from biased samping * of the source vertex space, and destination vertex ids from biased sampling of the * destination vertex space, and using this as the putative list of edges. We - * then can optionally remove duplicates and remove false negatives to generate + * then can optionally remove duplicates and remove actual edges in the graph to generate * the final list. If necessary we will repeat the process to end with a resulting * edge list of the appropriate size. * diff --git a/cpp/src/sampling/negative_sampling_impl.cuh b/cpp/src/sampling/negative_sampling_impl.cuh index 881e2691ca5..bdb31bcd364 100644 --- a/cpp/src/sampling/negative_sampling_impl.cuh +++ b/cpp/src/sampling/negative_sampling_impl.cuh @@ -328,48 +328,44 @@ std::tuple, rmm::device_uvector> negativ } if (remove_duplicates) { - auto begin_iter = thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin()); - thrust::sort(handle.get_thrust_policy(), begin_iter, begin_iter + batch_src.size()); + thrust::sort(handle.get_thrust_policy(), + thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin()), + thrust::make_zip_iterator(batch_src.end(), batch_dst.end())); - auto new_end = - thrust::unique(handle.get_thrust_policy(), begin_iter, begin_iter + batch_src.size()); + auto new_end = thrust::unique(handle.get_thrust_policy(), + thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin()), + thrust::make_zip_iterator(batch_src.end(), batch_dst.end())); - batch_src.resize(thrust::distance(begin_iter, new_end), handle.get_stream()); - batch_dst.resize(thrust::distance(begin_iter, new_end), handle.get_stream()); + size_t new_size = + thrust::distance(thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin()), new_end); if (src.size() > 0) { - new_end = - thrust::remove_if(handle.get_thrust_policy(), - begin_iter, - begin_iter + batch_src.size(), - [local_src = raft::device_span{src.data(), src.size()}, - local_dst = raft::device_span{ - dst.data(), dst.size()}] __device__(auto tuple) { - return thrust::binary_search( - thrust::seq, - thrust::make_zip_iterator(local_src.begin(), local_dst.begin()), - thrust::make_zip_iterator(local_src.end(), local_dst.end()), - tuple); - }); - - size_t unique_size = thrust::distance(begin_iter, new_end); - - rmm::device_uvector new_src(src.size() + unique_size, handle.get_stream()); - rmm::device_uvector new_dst(dst.size() + unique_size, handle.get_stream()); + rmm::device_uvector new_src(src.size() + new_size, handle.get_stream()); + rmm::device_uvector new_dst(dst.size() + new_size, handle.get_stream()); thrust::merge(handle.get_thrust_policy(), - begin_iter, - begin_iter + unique_size, + thrust::make_zip_iterator(batch_src.begin(), batch_dst.begin()), + new_end, thrust::make_zip_iterator(src.begin(), dst.begin()), thrust::make_zip_iterator(src.end(), dst.end()), thrust::make_zip_iterator(new_src.begin(), new_dst.begin())); + new_end = thrust::unique(handle.get_thrust_policy(), + thrust::make_zip_iterator(new_src.begin(), new_dst.begin()), + thrust::make_zip_iterator(new_src.end(), new_dst.end())); + + new_size = + thrust::distance(thrust::make_zip_iterator(new_src.begin(), new_dst.begin()), new_end); + src = std::move(new_src); dst = std::move(new_dst); } else { src = std::move(batch_src); dst = std::move(batch_dst); } + + src.resize(new_size, handle.get_stream()); + dst.resize(new_size, handle.get_stream()); } else if (src.size() > 0) { size_t current_end = src.size(); @@ -401,6 +397,9 @@ std::tuple, rmm::device_uvector> negativ } } + src.shrink_to_fit(handle.get_stream()); + dst.shrink_to_fit(handle.get_stream()); + return std::make_tuple(std::move(src), std::move(dst)); } diff --git a/cpp/src/sampling/negative_sampling_mg_v32_e32.cu b/cpp/src/sampling/negative_sampling_mg_v32_e32.cu index fe00bb16747..92ccd5deeb5 100644 --- a/cpp/src/sampling/negative_sampling_mg_v32_e32.cu +++ b/cpp/src/sampling/negative_sampling_mg_v32_e32.cu @@ -29,7 +29,7 @@ template std::tuple, rmm::device_uvector> std::optional> src_bias, std::optional> dst_bias, bool remove_duplicates, - bool remove_false_negatives, + bool remove_existing_edges, bool exact_number_of_samples, bool do_expensive_check); @@ -41,7 +41,7 @@ template std::tuple, rmm::device_uvector> std::optional> src_bias, std::optional> dst_bias, bool remove_duplicates, - bool remove_false_negatives, + bool remove_existing_edges, bool exact_number_of_samples, bool do_expensive_check); diff --git a/cpp/src/sampling/negative_sampling_mg_v32_e64.cu b/cpp/src/sampling/negative_sampling_mg_v32_e64.cu index 403257103f8..83158a07527 100644 --- a/cpp/src/sampling/negative_sampling_mg_v32_e64.cu +++ b/cpp/src/sampling/negative_sampling_mg_v32_e64.cu @@ -29,7 +29,7 @@ template std::tuple, rmm::device_uvector> std::optional> src_bias, std::optional> dst_bias, bool remove_duplicates, - bool remove_false_negatives, + bool remove_existing_edges, bool exact_number_of_samples, bool do_expensive_check); @@ -41,7 +41,7 @@ template std::tuple, rmm::device_uvector> std::optional> src_bias, std::optional> dst_bias, bool remove_duplicates, - bool remove_false_negatives, + bool remove_existing_edges, bool exact_number_of_samples, bool do_expensive_check); diff --git a/cpp/src/sampling/negative_sampling_mg_v64_e64.cu b/cpp/src/sampling/negative_sampling_mg_v64_e64.cu index b3941b9db13..cdca1f078d0 100644 --- a/cpp/src/sampling/negative_sampling_mg_v64_e64.cu +++ b/cpp/src/sampling/negative_sampling_mg_v64_e64.cu @@ -29,7 +29,7 @@ template std::tuple, rmm::device_uvector> std::optional> src_bias, std::optional> dst_bias, bool remove_duplicates, - bool remove_false_negatives, + bool remove_existing_edges, bool exact_number_of_samples, bool do_expensive_check); @@ -41,7 +41,7 @@ template std::tuple, rmm::device_uvector> std::optional> src_bias, std::optional> dst_bias, bool remove_duplicates, - bool remove_false_negatives, + bool remove_existing_edges, bool exact_number_of_samples, bool do_expensive_check); diff --git a/cpp/src/sampling/negative_sampling_sg_v32_e32.cu b/cpp/src/sampling/negative_sampling_sg_v32_e32.cu index b9fbaba76be..1d784b5e408 100644 --- a/cpp/src/sampling/negative_sampling_sg_v32_e32.cu +++ b/cpp/src/sampling/negative_sampling_sg_v32_e32.cu @@ -29,7 +29,7 @@ template std::tuple, rmm::device_uvector> std::optional> src_bias, std::optional> dst_bias, bool remove_duplicates, - bool remove_false_negatives, + bool remove_existing_edges, bool exact_number_of_samples, bool do_expensive_check); @@ -41,7 +41,7 @@ template std::tuple, rmm::device_uvector> std::optional> src_bias, std::optional> dst_bias, bool remove_duplicates, - bool remove_false_negatives, + bool remove_existing_edges, bool exact_number_of_samples, bool do_expensive_check); diff --git a/cpp/src/sampling/negative_sampling_sg_v32_e64.cu b/cpp/src/sampling/negative_sampling_sg_v32_e64.cu index 6db40b327af..d42a5351eb3 100644 --- a/cpp/src/sampling/negative_sampling_sg_v32_e64.cu +++ b/cpp/src/sampling/negative_sampling_sg_v32_e64.cu @@ -29,7 +29,7 @@ template std::tuple, rmm::device_uvector> std::optional> src_bias, std::optional> dst_bias, bool remove_duplicates, - bool remove_false_negatives, + bool remove_existing_edges, bool exact_number_of_samples, bool do_expensive_check); @@ -41,7 +41,7 @@ template std::tuple, rmm::device_uvector> std::optional> src_bias, std::optional> dst_bias, bool remove_duplicates, - bool remove_false_negatives, + bool remove_existing_edges, bool exact_number_of_samples, bool do_expensive_check); diff --git a/cpp/src/sampling/negative_sampling_sg_v64_e64.cu b/cpp/src/sampling/negative_sampling_sg_v64_e64.cu index 0c5152b21c5..d9365a1519d 100644 --- a/cpp/src/sampling/negative_sampling_sg_v64_e64.cu +++ b/cpp/src/sampling/negative_sampling_sg_v64_e64.cu @@ -29,7 +29,7 @@ template std::tuple, rmm::device_uvector> std::optional> src_bias, std::optional> dst_bias, bool remove_duplicates, - bool remove_false_negatives, + bool remove_existing_edges, bool exact_number_of_samples, bool do_expensive_check); @@ -41,7 +41,7 @@ template std::tuple, rmm::device_uvector> std::optional> src_bias, std::optional> dst_bias, bool remove_duplicates, - bool remove_false_negatives, + bool remove_existing_edges, bool exact_number_of_samples, bool do_expensive_check); diff --git a/cpp/tests/sampling/negative_sampling.cu b/cpp/tests/sampling/negative_sampling.cu index 1d714b85271..febab5f389d 100644 --- a/cpp/tests/sampling/negative_sampling.cu +++ b/cpp/tests/sampling/negative_sampling.cu @@ -29,7 +29,7 @@ struct Negative_Sampling_Usecase { bool use_src_bias{false}; bool use_dst_bias{false}; bool remove_duplicates{false}; - bool remove_false_negatives{false}; + bool remove_existing_edges{false}; bool exact_number_of_samples{false}; bool check_correctness{true}; }; @@ -127,7 +127,7 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam src_bias, dst_bias, negative_sampling_usecase.remove_duplicates, - negative_sampling_usecase.remove_false_negatives, + negative_sampling_usecase.remove_existing_edges, negative_sampling_usecase.exact_number_of_samples, do_expensive_check); @@ -175,7 +175,7 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam ASSERT_EQ(count, 0) << "Remove duplicates specified, found duplicate entries"; } - if (negative_sampling_usecase.remove_false_negatives) { + if (negative_sampling_usecase.remove_existing_edges) { rmm::device_uvector graph_src(0, handle.get_stream()); rmm::device_uvector graph_dst(0, handle.get_stream()); @@ -187,15 +187,36 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam handle.get_thrust_policy(), thrust::make_zip_iterator(src_out.begin(), dst_out.begin()), thrust::make_zip_iterator(src_out.end(), dst_out.end()), - [src = graph_src.data(), dst = graph_dst.data(), size = graph_dst.size()] __device__( - auto tuple) { - return thrust::binary_search(thrust::seq, - thrust::make_zip_iterator(src, dst), - thrust::make_zip_iterator(src, dst) + size, - tuple); - }); - - ASSERT_EQ(count, 0) << "Remove false negatives specified, found false negatives"; + cuda::proclaim_return_type( + [src = raft::device_span{graph_src.data(), graph_src.size()}, + dst = raft::device_span{graph_dst.data(), + graph_dst.size()}] __device__(auto tuple) { +#if 0 + // FIXME: This fails on rocky linux CUDA 11.8, works on CUDA 12 + return thrust::binary_search(thrust::seq, + thrust::make_zip_iterator(src.begin(), dst.begin()), + thrust::make_zip_iterator(src.end(), dst.end()), + tuple) ? size_t{1} : size_t{0}; +#else + auto lb = thrust::distance( + src.begin(), + thrust::lower_bound(thrust::seq, src.begin(), src.end(), thrust::get<0>(tuple))); + auto ub = thrust::distance( + src.begin(), + thrust::upper_bound(thrust::seq, src.begin(), src.end(), thrust::get<0>(tuple))); + + if (src.data()[lb] == thrust::get<0>(tuple)) { + return thrust::binary_search( + thrust::seq, dst.begin() + lb, dst.begin() + ub, thrust::get<1>(tuple)) + ? size_t{1} + : size_t{0}; + } else { + return size_t{0}; + } +#endif + })); + + ASSERT_EQ(count, 0) << "Remove existing edges specified, found existing edges"; } if (negative_sampling_usecase.exact_number_of_samples) { From 905d1b6e6e5bd955a3bc851862d2a4e502d36d03 Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Thu, 15 Aug 2024 11:12:43 -0700 Subject: [PATCH 12/18] refactor negative sampling based on PR comments --- cpp/src/sampling/negative_sampling_impl.cuh | 130 ++++++++++---------- 1 file changed, 65 insertions(+), 65 deletions(-) diff --git a/cpp/src/sampling/negative_sampling_impl.cuh b/cpp/src/sampling/negative_sampling_impl.cuh index bdb31bcd364..32b56ce6fee 100644 --- a/cpp/src/sampling/negative_sampling_impl.cuh +++ b/cpp/src/sampling/negative_sampling_impl.cuh @@ -50,49 +50,53 @@ std::tuple>, std::optional>> normalize_biases(raft::handle_t const& handle, graph_view_t const& graph_view, - std::optional> biases) + raft::device_span biases) { std::optional> normalized_biases{std::nullopt}; std::optional> gpu_biases{std::nullopt}; - if (biases) { - // Need to normalize the biases - normalized_biases = - std::make_optional>(biases->size(), handle.get_stream()); + // Need to normalize the biases + normalized_biases = + std::make_optional>(biases.size(), handle.get_stream()); - weight_t sum = - thrust::reduce(handle.get_thrust_policy(), biases->begin(), biases->end(), weight_t{0}); + weight_t sum = + thrust::reduce(handle.get_thrust_policy(), biases.begin(), biases.end(), weight_t{0}); - weight_t aggregate_sum{sum}; + thrust::transform(handle.get_thrust_policy(), + biases.begin(), + biases.end(), + normalized_biases->begin(), + divider_t{sum}); - if constexpr (multi_gpu) { - aggregate_sum = - host_scalar_allreduce(handle.get_comms(), sum, raft::comms::op_t::SUM, handle.get_stream()); - } - - thrust::transform(handle.get_thrust_policy(), - biases->begin(), - biases->end(), - normalized_biases->begin(), - divider_t{sum}); + thrust::inclusive_scan(handle.get_thrust_policy(), + normalized_biases->begin(), + normalized_biases->end(), + normalized_biases->begin()); - thrust::inclusive_scan(handle.get_thrust_policy(), - normalized_biases->begin(), - normalized_biases->end(), - normalized_biases->begin()); + if constexpr (multi_gpu) { + // rmm::device_scalar d_sum((sum / aggregate_sum), handle.get_stream()); + rmm::device_scalar d_sum(sum, handle.get_stream()); - if constexpr (multi_gpu) { - rmm::device_scalar d_sum((sum / aggregate_sum), handle.get_stream()); - gpu_biases = cugraph::device_allgatherv( - handle, handle.get_comms(), raft::device_span{d_sum.data(), d_sum.size()}); + gpu_biases = cugraph::device_allgatherv( + handle, handle.get_comms(), raft::device_span{d_sum.data(), d_sum.size()}); - thrust::inclusive_scan( - handle.get_thrust_policy(), gpu_biases->begin(), gpu_biases->end(), gpu_biases->begin()); + weight_t aggregate_sum = thrust::reduce( + handle.get_thrust_policy(), gpu_biases->begin(), gpu_biases->end(), weight_t{0}); - weight_t force_to_one{1.1}; - raft::update_device( - gpu_biases->data() + gpu_biases->size() - 1, &force_to_one, 1, handle.get_stream()); - } + thrust::transform(handle.get_thrust_policy(), + gpu_biases->begin(), + gpu_biases->end(), + gpu_biases->begin(), + divider_t{aggregate_sum}); + + thrust::inclusive_scan( + handle.get_thrust_policy(), gpu_biases->begin(), gpu_biases->end(), gpu_biases->begin()); + +#if 0 + weight_t force_to_one{1.1}; + raft::update_device( + gpu_biases->data() + gpu_biases->size() - 1, &force_to_one, 1, handle.get_stream()); +#endif } return std::make_tuple(std::move(normalized_biases), std::move(gpu_biases)); @@ -122,11 +126,7 @@ rmm::device_uvector create_local_samples( if constexpr (multi_gpu) { // Determine how many vertices are generated on each GPU auto const comm_size = handle.get_comms().get_size(); - auto const rank = handle.get_comms().get_rank(); - auto& major_comm = handle.get_subcomm(cugraph::partition_manager::major_comm_name()); - auto const major_comm_size = major_comm.get_size(); - auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); - auto const minor_comm_size = minor_comm.get_size(); + auto const comm_rank = handle.get_comms().get_rank(); sample_count_from_each_gpu.resize(comm_size); @@ -158,29 +158,18 @@ rmm::device_uvector create_local_samples( thrust::adjacent_difference( handle.get_thrust_policy(), gpu_counts.begin(), gpu_counts.end(), gpu_counts.begin()); - // all_gpu_counts[i][j] will be how many vertices need to be generated on GPU j to be sent to - // GPU i - auto all_gpu_counts = cugraph::device_allgatherv( - handle, - handle.get_comms(), - raft::device_span{gpu_counts.data(), gpu_counts.size()}); - - auto begin_iter = thrust::make_transform_iterator( - thrust::make_counting_iterator(0), - cuda::proclaim_return_type( - [rank, stride = comm_size, counts = all_gpu_counts.data()] __device__(size_t idx) { - return counts[idx * stride + rank]; - })); + std::vector tx_counts(gpu_counts.size()); + std::fill(tx_counts.begin(), tx_counts.end(), size_t{1}); - samples_to_generate = - thrust::reduce(handle.get_thrust_policy(), begin_iter, begin_iter + comm_size, size_t{0}); + rmm::device_uvector d_sample_count_from_each_gpu(0, handle.get_stream()); - rmm::device_uvector d_sample_count_from_each_gpu(comm_size, handle.get_stream()); + std::tie(d_sample_count_from_each_gpu, std::ignore) = + shuffle_values(handle.get_comms(), gpu_counts.begin(), tx_counts, handle.get_stream()); - thrust::copy(handle.get_thrust_policy(), - begin_iter, - begin_iter + comm_size, - d_sample_count_from_each_gpu.begin()); + samples_to_generate = thrust::reduce(handle.get_thrust_policy(), + d_sample_count_from_each_gpu.begin(), + d_sample_count_from_each_gpu.end(), + size_t{0}); raft::update_host(sample_count_from_each_gpu.data(), d_sample_count_from_each_gpu.data(), @@ -273,18 +262,29 @@ std::tuple, rmm::device_uvector> negativ // Normalize the biases and (for MG) determine how the biases are // distributed across the GPUs. - auto [normalized_src_biases, gpu_src_biases] = - detail::normalize_biases(handle, graph_view, src_biases); - auto [normalized_dst_biases, gpu_dst_biases] = - detail::normalize_biases(handle, graph_view, dst_biases); + std::optional> normalized_src_biases{std::nullopt}; + std::optional> gpu_src_biases{std::nullopt}; + std::optional> normalized_dst_biases{std::nullopt}; + std::optional> gpu_dst_biases{std::nullopt}; + + if (src_biases) + std::tie(normalized_src_biases, gpu_src_biases) = + detail::normalize_biases(handle, graph_view, *src_biases); + + if (dst_biases) + std::tie(normalized_dst_biases, gpu_dst_biases) = + detail::normalize_biases(handle, graph_view, *dst_biases); while (samples_in_this_batch > 0) { if constexpr (multi_gpu) { - size_t comm_size = handle.get_comms().get_size(); - size_t comm_rank = handle.get_comms().get_rank(); + auto const comm_size = handle.get_comms().get_size(); + auto const comm_rank = handle.get_comms().get_rank(); - samples_in_this_batch = (samples_in_this_batch / comm_size) + - (comm_rank < (samples_in_this_batch % comm_size) ? 1 : 0); + samples_in_this_batch = + (samples_in_this_batch / static_cast(comm_size)) + + (static_cast(comm_rank) < (samples_in_this_batch % static_cast(comm_size)) + ? 1 + : 0); } auto batch_src = create_local_samples( From dbc0b3827086e69ec672307bdeb4b58a6e5d127a Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Thu, 15 Aug 2024 11:13:28 -0700 Subject: [PATCH 13/18] start refactoring to make tests .cpp files --- cpp/tests/CMakeLists.txt | 1 + cpp/tests/sampling/negative_sampling.cu | 85 +++---- cpp/tests/utilities/validation_utilities.cu | 239 +++++++++++++++++++ cpp/tests/utilities/validation_utilities.hpp | 58 +++++ 4 files changed, 327 insertions(+), 56 deletions(-) create mode 100644 cpp/tests/utilities/validation_utilities.cu create mode 100644 cpp/tests/utilities/validation_utilities.hpp diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 2289841ff19..13d66b64078 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -39,6 +39,7 @@ add_library(cugraphtestutil STATIC utilities/misc_utilities.cpp utilities/conversion_utilities_sg.cu utilities/debug_utilities_sg.cpp + utilities/validation_utilities.cu link_prediction/similarity_compare.cpp centrality/betweenness_centrality_validate.cu community/egonet_validate.cu diff --git a/cpp/tests/sampling/negative_sampling.cu b/cpp/tests/sampling/negative_sampling.cu index febab5f389d..348cb497b03 100644 --- a/cpp/tests/sampling/negative_sampling.cu +++ b/cpp/tests/sampling/negative_sampling.cu @@ -17,6 +17,7 @@ #include "utilities/base_fixture.hpp" #include "utilities/conversion_utilities.hpp" #include "utilities/property_generator_utilities.hpp" +#include "utilities/validation_utilities.hpp" #include #include @@ -140,38 +141,30 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam if (negative_sampling_usecase.check_correctness) { ASSERT_EQ(src_out.size(), dst_out.size()) << "Result size (src, dst) mismatch"; + cugraph::test::sort(handle, + raft::device_span{src_out.data(), src_out.size()}, + raft::device_span{dst_out.data(), dst_out.size()}); + auto vertex_partition = cugraph::vertex_partition_device_view_t( graph_view.local_vertex_partition_view()); - size_t count = - thrust::count_if(handle.get_thrust_policy(), - src_out.begin(), - src_out.end(), - [vertex_partition] __device__(auto val) { - return !(vertex_partition.is_valid_vertex(val) && - vertex_partition.in_local_vertex_partition_range_nocheck(val)); - }); - + size_t count = cugraph::test::count_invalid_vertices( + handle, + raft::device_span{src_out.data(), src_out.size()}, + vertex_partition); ASSERT_EQ(count, 0) << "Source vertices out of range > 0"; - count = - thrust::count_if(handle.get_thrust_policy(), - dst_out.begin(), - dst_out.end(), - [vertex_partition] __device__(auto val) { - return !(vertex_partition.is_valid_vertex(val) && - vertex_partition.in_local_vertex_partition_range_nocheck(val)); - }); + count = cugraph::test::count_invalid_vertices( + handle, + raft::device_span{dst_out.data(), dst_out.size()}, + vertex_partition); ASSERT_EQ(count, 0) << "Dest vertices out of range > 0"; if (negative_sampling_usecase.remove_duplicates) { - count = thrust::count_if( - handle.get_thrust_policy(), - thrust::make_counting_iterator(1), - thrust::make_counting_iterator(src_out.size()), - [src = src_out.data(), dst = dst_out.data()] __device__(size_t index) { - return (src[index - 1] == src[index]) && (dst[index - 1] == dst[index]); - }); + count = cugraph::test::count_duplicate_vertex_pairs_sorted( + handle, + raft::device_span{src_out.data(), src_out.size()}, + raft::device_span{dst_out.data(), dst_out.size()}); ASSERT_EQ(count, 0) << "Remove duplicates specified, found duplicate entries"; } @@ -183,38 +176,18 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam cugraph::decompress_to_edgelist( handle, graph_view, std::nullopt, std::nullopt, std::nullopt, std::nullopt); - count = thrust::count_if( - handle.get_thrust_policy(), - thrust::make_zip_iterator(src_out.begin(), dst_out.begin()), - thrust::make_zip_iterator(src_out.end(), dst_out.end()), - cuda::proclaim_return_type( - [src = raft::device_span{graph_src.data(), graph_src.size()}, - dst = raft::device_span{graph_dst.data(), - graph_dst.size()}] __device__(auto tuple) { -#if 0 - // FIXME: This fails on rocky linux CUDA 11.8, works on CUDA 12 - return thrust::binary_search(thrust::seq, - thrust::make_zip_iterator(src.begin(), dst.begin()), - thrust::make_zip_iterator(src.end(), dst.end()), - tuple) ? size_t{1} : size_t{0}; -#else - auto lb = thrust::distance( - src.begin(), - thrust::lower_bound(thrust::seq, src.begin(), src.end(), thrust::get<0>(tuple))); - auto ub = thrust::distance( - src.begin(), - thrust::upper_bound(thrust::seq, src.begin(), src.end(), thrust::get<0>(tuple))); - - if (src.data()[lb] == thrust::get<0>(tuple)) { - return thrust::binary_search( - thrust::seq, dst.begin() + lb, dst.begin() + ub, thrust::get<1>(tuple)) - ? size_t{1} - : size_t{0}; - } else { - return size_t{0}; - } -#endif - })); + count = cugraph::test::count_intersection( + handle, + raft::device_span{graph_src.data(), graph_src.size()}, + raft::device_span{graph_dst.data(), graph_dst.size()}, + std::nullopt, + std::nullopt, + std::nullopt, + raft::device_span{src_out.data(), src_out.size()}, + raft::device_span{dst_out.data(), dst_out.size()}, + std::nullopt, + std::nullopt, + std::nullopt); ASSERT_EQ(count, 0) << "Remove existing edges specified, found existing edges"; } diff --git a/cpp/tests/utilities/validation_utilities.cu b/cpp/tests/utilities/validation_utilities.cu new file mode 100644 index 00000000000..973e6a1d2ce --- /dev/null +++ b/cpp/tests/utilities/validation_utilities.cu @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utilities/validation_utilities.hpp" + +#include +#include +#include +#include + +namespace cugraph::test { + +template +size_t count_invalid_vertices( + raft::handle_t const& handle, + raft::device_span vertices, + cugraph::vertex_partition_device_view_t const& vertex_partition) +{ + return thrust::count_if(handle.get_thrust_policy(), + vertices.begin(), + vertices.end(), + [vertex_partition] __device__(auto val) { + return !(vertex_partition.is_valid_vertex(val) && + vertex_partition.in_local_vertex_partition_range_nocheck(val)); + }); +} + +template +size_t count_duplicate_vertex_pairs_sorted(raft::handle_t const& handle, + raft::device_span src, + raft::device_span dst) +{ + return thrust::count_if(handle.get_thrust_policy(), + thrust::make_counting_iterator(1), + thrust::make_counting_iterator(src.size()), + [src, dst] __device__(size_t index) { + return (src[index - 1] == src[index]) && (dst[index - 1] == dst[index]); + }); +} + +// FIXME: Resolve this with dataframe_buffer variations in thrust_wrappers.cu +template +void sort(raft::handle_t const& handle, + raft::device_span srcs, + raft::device_span dsts) +{ + thrust::sort(handle.get_thrust_policy(), + thrust::make_zip_iterator(srcs.begin(), dsts.begin()), + thrust::make_zip_iterator(srcs.end(), dsts.end())); +} + +template +size_t count_intersection(raft::handle_t const& handle, + raft::device_span srcs1, + raft::device_span dsts1, + std::optional> wgts1, + std::optional> edge_ids1, + std::optional> edge_types1, + raft::device_span srcs2, + raft::device_span dsts2, + std::optional> wgts2, + std::optional> edge_ids2, + std::optional> edge_types2) +{ + // FIXME: Add support for wgts, edgeids and edge_types... + // Added to the API for future support. + + auto iter1 = thrust::make_zip_iterator(srcs1.begin(), dsts1.begin()); + auto iter2 = thrust::make_zip_iterator(srcs2.begin(), dsts2.begin()); + auto output_iter = thrust::make_discard_iterator(); + + return thrust::distance(output_iter, + thrust::set_intersection(handle.get_thrust_policy(), + iter1, + iter1 + srcs1.size(), + iter2, + iter2 + srcs2.size(), + output_iter)); +#if 0 + // OLD Approach + return thrust::count_if( + handle.get_thrust_policy(), + thrust::make_zip_iterator(src_out.begin(), dst_out.begin()), + thrust::make_zip_iterator(src_out.end(), dst_out.end()), + cuda::proclaim_return_type( + [src = raft::device_span{graph_src.data(), graph_src.size()}, + dst = raft::device_span{graph_dst.data(), + graph_dst.size()}] __device__(auto tuple) { +#if 0 + // FIXME: This fails on rocky linux CUDA 11.8, works on CUDA 12 + return thrust::binary_search(thrust::seq, + thrust::make_zip_iterator(src.begin(), dst.begin()), + thrust::make_zip_iterator(src.end(), dst.end()), + tuple) ? size_t{1} : size_t{0}; +#else + auto lb = thrust::distance( + src.begin(), + thrust::lower_bound(thrust::seq, src.begin(), src.end(), thrust::get<0>(tuple))); + auto ub = thrust::distance( + src.begin(), + thrust::upper_bound(thrust::seq, src.begin(), src.end(), thrust::get<0>(tuple))); + + if (src.data()[lb] == thrust::get<0>(tuple)) { + return thrust::binary_search( + thrust::seq, dst.begin() + lb, dst.begin() + ub, thrust::get<1>(tuple)) + ? size_t{1} + : size_t{0}; + } else { + return size_t{0}; + } +#endif + })); +#endif +} + +// TODO: Split SG from MG? +template size_t count_invalid_vertices( + raft::handle_t const& handle, + raft::device_span vertices, + cugraph::vertex_partition_device_view_t const& vertex_partition); + +template size_t count_invalid_vertices( + raft::handle_t const& handle, + raft::device_span vertices, + cugraph::vertex_partition_device_view_t const& vertex_partition); + +template size_t count_duplicate_vertex_pairs_sorted( + raft::handle_t const& handle, + raft::device_span src, + raft::device_span dst); + +template size_t count_duplicate_vertex_pairs_sorted( + raft::handle_t const& handle, + raft::device_span src, + raft::device_span dst); + +template void sort(raft::handle_t const& handle, + raft::device_span srcs, + raft::device_span dsts); +template void sort(raft::handle_t const& handle, + raft::device_span srcs, + raft::device_span dsts); + +template size_t count_intersection( + raft::handle_t const& handle, + raft::device_span srcs1, + raft::device_span dsts1, + std::optional> wgts1, + std::optional> edge_ids1, + std::optional> edge_types1, + raft::device_span srcs2, + raft::device_span dsts2, + std::optional> wgts2, + std::optional> edge_ids2, + std::optional> edge_types2); + +template size_t count_intersection( + raft::handle_t const& handle, + raft::device_span srcs1, + raft::device_span dsts1, + std::optional> wgts1, + std::optional> edge_ids1, + std::optional> edge_types1, + raft::device_span srcs2, + raft::device_span dsts2, + std::optional> wgts2, + std::optional> edge_ids2, + std::optional> edge_types2); + +template size_t count_intersection( + raft::handle_t const& handle, + raft::device_span srcs1, + raft::device_span dsts1, + std::optional> wgts1, + std::optional> edge_ids1, + std::optional> edge_types1, + raft::device_span srcs2, + raft::device_span dsts2, + std::optional> wgts2, + std::optional> edge_ids2, + std::optional> edge_types2); + +template size_t count_intersection( + raft::handle_t const& handle, + raft::device_span srcs1, + raft::device_span dsts1, + std::optional> wgts1, + std::optional> edge_ids1, + std::optional> edge_types1, + raft::device_span srcs2, + raft::device_span dsts2, + std::optional> wgts2, + std::optional> edge_ids2, + std::optional> edge_types2); + +template size_t count_intersection( + raft::handle_t const& handle, + raft::device_span srcs1, + raft::device_span dsts1, + std::optional> wgts1, + std::optional> edge_ids1, + std::optional> edge_types1, + raft::device_span srcs2, + raft::device_span dsts2, + std::optional> wgts2, + std::optional> edge_ids2, + std::optional> edge_types2); + +template size_t count_intersection( + raft::handle_t const& handle, + raft::device_span srcs1, + raft::device_span dsts1, + std::optional> wgts1, + std::optional> edge_ids1, + std::optional> edge_types1, + raft::device_span srcs2, + raft::device_span dsts2, + std::optional> wgts2, + std::optional> edge_ids2, + std::optional> edge_types2); + +} // namespace cugraph::test diff --git a/cpp/tests/utilities/validation_utilities.hpp b/cpp/tests/utilities/validation_utilities.hpp new file mode 100644 index 00000000000..10ba264b85f --- /dev/null +++ b/cpp/tests/utilities/validation_utilities.hpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include + +namespace cugraph::test { +template +size_t count_invalid_vertices( + raft::handle_t const& handle, + raft::device_span vertices, + cugraph::vertex_partition_device_view_t const& vertex_partition); + +template +size_t count_duplicate_vertex_pairs_sorted(raft::handle_t const& handle, + raft::device_span src, + raft::device_span dst); + +template +void sort(raft::handle_t const& handle, + raft::device_span srcs, + raft::device_span dsts); + +template +size_t count_intersection(raft::handle_t const& handle, + raft::device_span srcs1, + raft::device_span dsts1, + std::optional> wgts1, + std::optional> edge_ids1, + std::optional> edge_types1, + raft::device_span srcs2, + raft::device_span dsts2, + std::optional> wgts2, + std::optional> edge_ids2, + std::optional> edge_types2); + +} // namespace cugraph::test From 5f359870b1438f52ad41c04221d8806e10d6d059 Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Sat, 17 Aug 2024 14:06:40 -0700 Subject: [PATCH 14/18] move MG validation code into validation_utilitices.cu --- cpp/tests/CMakeLists.txt | 2 +- cpp/tests/sampling/mg_negative_sampling.cu | 81 +++---- ...tive_sampling.cu => negative_sampling.cpp} | 16 +- cpp/tests/utilities/validation_utilities.cu | 210 ++++++++++-------- cpp/tests/utilities/validation_utilities.hpp | 14 +- 5 files changed, 166 insertions(+), 157 deletions(-) rename cpp/tests/sampling/{negative_sampling.cu => negative_sampling.cpp} (96%) diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 13d66b64078..c7661849d8d 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -490,7 +490,7 @@ ConfigureTest(SAMPLING_POST_PROCESSING_TEST sampling/sampling_post_processing_te ################################################################################################### # - NEGATIVE SAMPLING tests -------------------------------------------------------------------- -ConfigureTest(NEGATIVE_SAMPLING_TEST sampling/negative_sampling.cu) +ConfigureTest(NEGATIVE_SAMPLING_TEST sampling/negative_sampling.cpp) ################################################################################################### # - Renumber tests -------------------------------------------------------------------------------- diff --git a/cpp/tests/sampling/mg_negative_sampling.cu b/cpp/tests/sampling/mg_negative_sampling.cu index 0d556f11810..702b83ded21 100644 --- a/cpp/tests/sampling/mg_negative_sampling.cu +++ b/cpp/tests/sampling/mg_negative_sampling.cu @@ -14,14 +14,13 @@ * limitations under the License. */ -#include "detail/graph_partition_utils.cuh" #include "utilities/base_fixture.hpp" #include "utilities/conversion_utilities.hpp" #include "utilities/property_generator_utilities.hpp" +#include "utilities/validation_utilities.hpp" #include #include -#include #include @@ -140,6 +139,11 @@ class Tests_MGNegative_Sampling : public ::testing::TestWithParam{src_out.data(), src_out.size()}, + raft::device_span{dst_out.data(), dst_out.size()}); + + // TODO: Move this to validation_utilities... auto h_vertex_partition_range_lasts = graph_view.vertex_partition_range_lasts(); rmm::device_uvector d_vertex_partition_range_lasts( h_vertex_partition_range_lasts.size(), handle_->get_stream()); @@ -148,49 +152,20 @@ class Tests_MGNegative_Sampling : public ::testing::TestWithParamget_stream()); - size_t error_count = thrust::count_if( - handle_->get_thrust_policy(), - thrust::make_zip_iterator(src_out.begin(), dst_out.begin()), - thrust::make_zip_iterator(src_out.end(), dst_out.end()), - [comm_rank = handle_->get_comms().get_rank(), - gpu_id_key_func = cugraph::detail::compute_gpu_id_from_int_edge_endpoints_t{ - raft::device_span{d_vertex_partition_range_lasts.data(), - d_vertex_partition_range_lasts.size()}, - handle_->get_comms().get_size(), - handle_->get_subcomm(cugraph::partition_manager::major_comm_name()).get_size(), - handle_->get_subcomm(cugraph::partition_manager::minor_comm_name()) - .get_size()}] __device__(auto e) { - if (gpu_id_key_func(thrust::get<0>(e), thrust::get<1>(e)) != comm_rank) - printf(" gpu_id(%d,%d) = %d, expected %d\n", - (int)thrust::get<0>(e), - (int)thrust::get<1>(e), - gpu_id_key_func(thrust::get<0>(e), thrust::get<1>(e)), - comm_rank); - - return (gpu_id_key_func(thrust::get<0>(e), thrust::get<1>(e)) != comm_rank); - }); + size_t error_count = cugraph::test::count_edges_on_wrong_int_gpu( + *handle_, + raft::device_span{src_out.data(), src_out.size()}, + raft::device_span{dst_out.data(), dst_out.size()}, + raft::device_span{d_vertex_partition_range_lasts.data(), + d_vertex_partition_range_lasts.size()}); ASSERT_EQ(error_count, 0) << "generate edges out of range > 0"; if ((negative_sampling_usecase.remove_duplicates) && (src_out.size() > 0)) { -#if 0 - raft::print_device_vector("SRC", src_out.data(), src_out.size(), std::cout); - raft::print_device_vector("DST", dst_out.data(), dst_out.size(), std::cout); -#endif - - error_count = thrust::count_if( - handle_->get_thrust_policy(), - thrust::make_counting_iterator(1), - thrust::make_counting_iterator(src_out.size()), - [src = src_out.data(), dst = dst_out.data()] __device__(size_t index) { - if ((src[index - 1] == src[index]) && (dst[index - 1] == dst[index])) - printf(" (%d,%d) : (%d, %d) are duplicates\n", - (int)src[index - 1], - (int)dst[index - 1], - (int)src[index], - (int)dst[index]); - return (src[index - 1] == src[index]) && (dst[index - 1] == dst[index]); - }); + error_count = cugraph::test::count_duplicate_vertex_pairs_sorted( + *handle_, + raft::device_span{src_out.data(), src_out.size()}, + raft::device_span{dst_out.data(), dst_out.size()}); ASSERT_EQ(error_count, 0) << "Remove duplicates specified, found duplicate entries"; } @@ -202,18 +177,18 @@ class Tests_MGNegative_Sampling : public ::testing::TestWithParam( *handle_, graph_view, std::nullopt, std::nullopt, std::nullopt, std::nullopt); - error_count = thrust::count_if( - handle_->get_thrust_policy(), - thrust::make_zip_iterator(src_out.begin(), dst_out.begin()), - thrust::make_zip_iterator(src_out.end(), dst_out.end()), - [src = graph_src.data(), dst = graph_dst.data(), size = graph_dst.size()] __device__( - auto tuple) { - return thrust::binary_search(thrust::seq, - thrust::make_zip_iterator(src, dst), - thrust::make_zip_iterator(src, dst) + size, - tuple); - }); - + error_count = cugraph::test::count_intersection( + *handle_, + raft::device_span{graph_src.data(), graph_src.size()}, + raft::device_span{graph_dst.data(), graph_dst.size()}, + std::nullopt, + std::nullopt, + std::nullopt, + raft::device_span{src_out.data(), src_out.size()}, + raft::device_span{dst_out.data(), dst_out.size()}, + std::nullopt, + std::nullopt, + std::nullopt); ASSERT_EQ(error_count, 0) << "Remove existing edges specified, found existing edges"; } diff --git a/cpp/tests/sampling/negative_sampling.cu b/cpp/tests/sampling/negative_sampling.cpp similarity index 96% rename from cpp/tests/sampling/negative_sampling.cu rename to cpp/tests/sampling/negative_sampling.cpp index 348cb497b03..63b5dc3442a 100644 --- a/cpp/tests/sampling/negative_sampling.cu +++ b/cpp/tests/sampling/negative_sampling.cpp @@ -148,24 +148,24 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam auto vertex_partition = cugraph::vertex_partition_device_view_t( graph_view.local_vertex_partition_view()); - size_t count = cugraph::test::count_invalid_vertices( + size_t error_count = cugraph::test::count_invalid_vertices( handle, raft::device_span{src_out.data(), src_out.size()}, vertex_partition); - ASSERT_EQ(count, 0) << "Source vertices out of range > 0"; + ASSERT_EQ(error_count, 0) << "Source vertices out of range > 0"; - count = cugraph::test::count_invalid_vertices( + error_count = cugraph::test::count_invalid_vertices( handle, raft::device_span{dst_out.data(), dst_out.size()}, vertex_partition); - ASSERT_EQ(count, 0) << "Dest vertices out of range > 0"; + ASSERT_EQ(error_count, 0) << "Dest vertices out of range > 0"; if (negative_sampling_usecase.remove_duplicates) { - count = cugraph::test::count_duplicate_vertex_pairs_sorted( + error_count = cugraph::test::count_duplicate_vertex_pairs_sorted( handle, raft::device_span{src_out.data(), src_out.size()}, raft::device_span{dst_out.data(), dst_out.size()}); - ASSERT_EQ(count, 0) << "Remove duplicates specified, found duplicate entries"; + ASSERT_EQ(error_count, 0) << "Remove duplicates specified, found duplicate entries"; } if (negative_sampling_usecase.remove_existing_edges) { @@ -176,7 +176,7 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam cugraph::decompress_to_edgelist( handle, graph_view, std::nullopt, std::nullopt, std::nullopt, std::nullopt); - count = cugraph::test::count_intersection( + error_count = cugraph::test::count_intersection( handle, raft::device_span{graph_src.data(), graph_src.size()}, raft::device_span{graph_dst.data(), graph_dst.size()}, @@ -189,7 +189,7 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam std::nullopt, std::nullopt); - ASSERT_EQ(count, 0) << "Remove existing edges specified, found existing edges"; + ASSERT_EQ(error_count, 0) << "Remove existing edges specified, found existing edges"; } if (negative_sampling_usecase.exact_number_of_samples) { diff --git a/cpp/tests/utilities/validation_utilities.cu b/cpp/tests/utilities/validation_utilities.cu index 973e6a1d2ce..b61bd7fef8b 100644 --- a/cpp/tests/utilities/validation_utilities.cu +++ b/cpp/tests/utilities/validation_utilities.cu @@ -14,8 +14,12 @@ * limitations under the License. */ +#include "detail/graph_partition_utils.cuh" #include "utilities/validation_utilities.hpp" +// TODO: Shouldn't use this in the interface... +#include + #include #include #include @@ -38,7 +42,7 @@ size_t count_invalid_vertices( }); } -template +template size_t count_duplicate_vertex_pairs_sorted(raft::handle_t const& handle, raft::device_span src, raft::device_span dst) @@ -62,11 +66,7 @@ void sort(raft::handle_t const& handle, thrust::make_zip_iterator(srcs.end(), dsts.end())); } -template +template size_t count_intersection(raft::handle_t const& handle, raft::device_span srcs1, raft::device_span dsts1, @@ -130,6 +130,34 @@ size_t count_intersection(raft::handle_t const& handle, #endif } +template +size_t count_edges_on_wrong_int_gpu(raft::handle_t const& handle, + raft::device_span srcs, + raft::device_span dsts, + raft::device_span vertex_partition_range_lasts) +{ + return thrust::count_if( + handle.get_thrust_policy(), + thrust::make_zip_iterator(srcs.begin(), dsts.begin()), + thrust::make_zip_iterator(srcs.end(), dsts.end()), + [comm_rank = handle.get_comms().get_rank(), + gpu_id_key_func = cugraph::detail::compute_gpu_id_from_int_edge_endpoints_t{ + vertex_partition_range_lasts, + handle.get_comms().get_size(), + handle.get_subcomm(cugraph::partition_manager::major_comm_name()).get_size(), + handle.get_subcomm(cugraph::partition_manager::minor_comm_name()) + .get_size()}] __device__(auto e) { + if (gpu_id_key_func(thrust::get<0>(e), thrust::get<1>(e)) != comm_rank) + printf(" gpu_id(%d,%d) = %d, expected %d\n", + (int)thrust::get<0>(e), + (int)thrust::get<1>(e), + gpu_id_key_func(thrust::get<0>(e), thrust::get<1>(e)), + comm_rank); + + return (gpu_id_key_func(thrust::get<0>(e), thrust::get<1>(e)) != comm_rank); + }); +} + // TODO: Split SG from MG? template size_t count_invalid_vertices( raft::handle_t const& handle, @@ -141,15 +169,13 @@ template size_t count_invalid_vertices( raft::device_span vertices, cugraph::vertex_partition_device_view_t const& vertex_partition); -template size_t count_duplicate_vertex_pairs_sorted( - raft::handle_t const& handle, - raft::device_span src, - raft::device_span dst); +template size_t count_duplicate_vertex_pairs_sorted(raft::handle_t const& handle, + raft::device_span src, + raft::device_span dst); -template size_t count_duplicate_vertex_pairs_sorted( - raft::handle_t const& handle, - raft::device_span src, - raft::device_span dst); +template size_t count_duplicate_vertex_pairs_sorted(raft::handle_t const& handle, + raft::device_span src, + raft::device_span dst); template void sort(raft::handle_t const& handle, raft::device_span srcs, @@ -158,82 +184,88 @@ template void sort(raft::handle_t const& handle, raft::device_span srcs, raft::device_span dsts); -template size_t count_intersection( - raft::handle_t const& handle, - raft::device_span srcs1, - raft::device_span dsts1, - std::optional> wgts1, - std::optional> edge_ids1, - std::optional> edge_types1, - raft::device_span srcs2, - raft::device_span dsts2, - std::optional> wgts2, - std::optional> edge_ids2, - std::optional> edge_types2); - -template size_t count_intersection( - raft::handle_t const& handle, - raft::device_span srcs1, - raft::device_span dsts1, - std::optional> wgts1, - std::optional> edge_ids1, - std::optional> edge_types1, - raft::device_span srcs2, - raft::device_span dsts2, - std::optional> wgts2, - std::optional> edge_ids2, - std::optional> edge_types2); - -template size_t count_intersection( - raft::handle_t const& handle, - raft::device_span srcs1, - raft::device_span dsts1, - std::optional> wgts1, - std::optional> edge_ids1, - std::optional> edge_types1, - raft::device_span srcs2, - raft::device_span dsts2, - std::optional> wgts2, - std::optional> edge_ids2, - std::optional> edge_types2); - -template size_t count_intersection( - raft::handle_t const& handle, - raft::device_span srcs1, - raft::device_span dsts1, - std::optional> wgts1, - std::optional> edge_ids1, - std::optional> edge_types1, - raft::device_span srcs2, - raft::device_span dsts2, - std::optional> wgts2, - std::optional> edge_ids2, - std::optional> edge_types2); - -template size_t count_intersection( +template size_t count_intersection(raft::handle_t const& handle, + raft::device_span srcs1, + raft::device_span dsts1, + std::optional> wgts1, + std::optional> edge_ids1, + std::optional> edge_types1, + raft::device_span srcs2, + raft::device_span dsts2, + std::optional> wgts2, + std::optional> edge_ids2, + std::optional> edge_types2); + +template size_t count_intersection(raft::handle_t const& handle, + raft::device_span srcs1, + raft::device_span dsts1, + std::optional> wgts1, + std::optional> edge_ids1, + std::optional> edge_types1, + raft::device_span srcs2, + raft::device_span dsts2, + std::optional> wgts2, + std::optional> edge_ids2, + std::optional> edge_types2); + +template size_t count_intersection(raft::handle_t const& handle, + raft::device_span srcs1, + raft::device_span dsts1, + std::optional> wgts1, + std::optional> edge_ids1, + std::optional> edge_types1, + raft::device_span srcs2, + raft::device_span dsts2, + std::optional> wgts2, + std::optional> edge_ids2, + std::optional> edge_types2); + +template size_t count_intersection(raft::handle_t const& handle, + raft::device_span srcs1, + raft::device_span dsts1, + std::optional> wgts1, + std::optional> edge_ids1, + std::optional> edge_types1, + raft::device_span srcs2, + raft::device_span dsts2, + std::optional> wgts2, + std::optional> edge_ids2, + std::optional> edge_types2); + +template size_t count_intersection(raft::handle_t const& handle, + raft::device_span srcs1, + raft::device_span dsts1, + std::optional> wgts1, + std::optional> edge_ids1, + std::optional> edge_types1, + raft::device_span srcs2, + raft::device_span dsts2, + std::optional> wgts2, + std::optional> edge_ids2, + std::optional> edge_types2); + +template size_t count_intersection(raft::handle_t const& handle, + raft::device_span srcs1, + raft::device_span dsts1, + std::optional> wgts1, + std::optional> edge_ids1, + std::optional> edge_types1, + raft::device_span srcs2, + raft::device_span dsts2, + std::optional> wgts2, + std::optional> edge_ids2, + std::optional> edge_types2); + +template size_t count_edges_on_wrong_int_gpu( raft::handle_t const& handle, - raft::device_span srcs1, - raft::device_span dsts1, - std::optional> wgts1, - std::optional> edge_ids1, - std::optional> edge_types1, - raft::device_span srcs2, - raft::device_span dsts2, - std::optional> wgts2, - std::optional> edge_ids2, - std::optional> edge_types2); - -template size_t count_intersection( + raft::device_span srcs, + raft::device_span dsts, + raft::device_span vertex_partition_range_lasts); + +template size_t count_edges_on_wrong_int_gpu( raft::handle_t const& handle, - raft::device_span srcs1, - raft::device_span dsts1, - std::optional> wgts1, - std::optional> edge_ids1, - std::optional> edge_types1, - raft::device_span srcs2, - raft::device_span dsts2, - std::optional> wgts2, - std::optional> edge_ids2, - std::optional> edge_types2); + raft::device_span srcs, + raft::device_span dsts, + raft::device_span vertex_partition_range_lasts); } // namespace cugraph::test diff --git a/cpp/tests/utilities/validation_utilities.hpp b/cpp/tests/utilities/validation_utilities.hpp index 10ba264b85f..29b24ae6d91 100644 --- a/cpp/tests/utilities/validation_utilities.hpp +++ b/cpp/tests/utilities/validation_utilities.hpp @@ -28,7 +28,7 @@ size_t count_invalid_vertices( raft::device_span vertices, cugraph::vertex_partition_device_view_t const& vertex_partition); -template +template size_t count_duplicate_vertex_pairs_sorted(raft::handle_t const& handle, raft::device_span src, raft::device_span dst); @@ -38,11 +38,7 @@ void sort(raft::handle_t const& handle, raft::device_span srcs, raft::device_span dsts); -template +template size_t count_intersection(raft::handle_t const& handle, raft::device_span srcs1, raft::device_span dsts1, @@ -55,4 +51,10 @@ size_t count_intersection(raft::handle_t const& handle, std::optional> edge_ids2, std::optional> edge_types2); +template +size_t count_edges_on_wrong_int_gpu(raft::handle_t const& handle, + raft::device_span srcs, + raft::device_span dsts, + raft::device_span vertex_partition_range_lasts); + } // namespace cugraph::test From 06e71c48b65a5a81e5747cb2c7361c2ea8d36df7 Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Sat, 17 Aug 2024 14:07:35 -0700 Subject: [PATCH 15/18] rename sampling file --- cpp/tests/CMakeLists.txt | 2 +- .../{mg_negative_sampling.cu => mg_negative_sampling.cpp} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename cpp/tests/sampling/{mg_negative_sampling.cu => mg_negative_sampling.cpp} (100%) diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index c7661849d8d..da31f498de1 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -748,7 +748,7 @@ if(BUILD_CUGRAPH_MG_TESTS) ################################################################################################### # - NEGATIVE SAMPLING tests -------------------------------------------------------------------- - ConfigureTestMG(MG_NEGATIVE_SAMPLING_TEST sampling/mg_negative_sampling.cu) + ConfigureTestMG(MG_NEGATIVE_SAMPLING_TEST sampling/mg_negative_sampling.cpp) ############################################################################################### diff --git a/cpp/tests/sampling/mg_negative_sampling.cu b/cpp/tests/sampling/mg_negative_sampling.cpp similarity index 100% rename from cpp/tests/sampling/mg_negative_sampling.cu rename to cpp/tests/sampling/mg_negative_sampling.cpp From 94990e154c8a119092b9e1a565a9e194622b5b04 Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Sat, 17 Aug 2024 14:17:37 -0700 Subject: [PATCH 16/18] remove reference of device structure from host API --- cpp/tests/sampling/negative_sampling.cpp | 8 ++---- cpp/tests/utilities/validation_utilities.cu | 30 ++++++++------------ cpp/tests/utilities/validation_utilities.hpp | 4 +-- 3 files changed, 16 insertions(+), 26 deletions(-) diff --git a/cpp/tests/sampling/negative_sampling.cpp b/cpp/tests/sampling/negative_sampling.cpp index 63b5dc3442a..a1762a2f3fc 100644 --- a/cpp/tests/sampling/negative_sampling.cpp +++ b/cpp/tests/sampling/negative_sampling.cpp @@ -21,7 +21,6 @@ #include #include -#include #include @@ -145,19 +144,16 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam raft::device_span{src_out.data(), src_out.size()}, raft::device_span{dst_out.data(), dst_out.size()}); - auto vertex_partition = cugraph::vertex_partition_device_view_t( - graph_view.local_vertex_partition_view()); - size_t error_count = cugraph::test::count_invalid_vertices( handle, raft::device_span{src_out.data(), src_out.size()}, - vertex_partition); + graph_view.local_vertex_partition_view()); ASSERT_EQ(error_count, 0) << "Source vertices out of range > 0"; error_count = cugraph::test::count_invalid_vertices( handle, raft::device_span{dst_out.data(), dst_out.size()}, - vertex_partition); + graph_view.local_vertex_partition_view()); ASSERT_EQ(error_count, 0) << "Dest vertices out of range > 0"; if (negative_sampling_usecase.remove_duplicates) { diff --git a/cpp/tests/utilities/validation_utilities.cu b/cpp/tests/utilities/validation_utilities.cu index b61bd7fef8b..3da998ad626 100644 --- a/cpp/tests/utilities/validation_utilities.cu +++ b/cpp/tests/utilities/validation_utilities.cu @@ -17,7 +17,6 @@ #include "detail/graph_partition_utils.cuh" #include "utilities/validation_utilities.hpp" -// TODO: Shouldn't use this in the interface... #include #include @@ -31,15 +30,17 @@ template size_t count_invalid_vertices( raft::handle_t const& handle, raft::device_span vertices, - cugraph::vertex_partition_device_view_t const& vertex_partition) + cugraph::vertex_partition_view_t const& vertex_partition_view) { - return thrust::count_if(handle.get_thrust_policy(), - vertices.begin(), - vertices.end(), - [vertex_partition] __device__(auto val) { - return !(vertex_partition.is_valid_vertex(val) && - vertex_partition.in_local_vertex_partition_range_nocheck(val)); - }); + return thrust::count_if( + handle.get_thrust_policy(), + vertices.begin(), + vertices.end(), + [vertex_partition = cugraph::vertex_partition_device_view_t{ + vertex_partition_view}] __device__(auto val) { + return !(vertex_partition.is_valid_vertex(val) && + vertex_partition.in_local_vertex_partition_range_nocheck(val)); + }); } template @@ -147,13 +148,6 @@ size_t count_edges_on_wrong_int_gpu(raft::handle_t const& handle, handle.get_subcomm(cugraph::partition_manager::major_comm_name()).get_size(), handle.get_subcomm(cugraph::partition_manager::minor_comm_name()) .get_size()}] __device__(auto e) { - if (gpu_id_key_func(thrust::get<0>(e), thrust::get<1>(e)) != comm_rank) - printf(" gpu_id(%d,%d) = %d, expected %d\n", - (int)thrust::get<0>(e), - (int)thrust::get<1>(e), - gpu_id_key_func(thrust::get<0>(e), thrust::get<1>(e)), - comm_rank); - return (gpu_id_key_func(thrust::get<0>(e), thrust::get<1>(e)) != comm_rank); }); } @@ -162,12 +156,12 @@ size_t count_edges_on_wrong_int_gpu(raft::handle_t const& handle, template size_t count_invalid_vertices( raft::handle_t const& handle, raft::device_span vertices, - cugraph::vertex_partition_device_view_t const& vertex_partition); + cugraph::vertex_partition_view_t const& vertex_partition_view); template size_t count_invalid_vertices( raft::handle_t const& handle, raft::device_span vertices, - cugraph::vertex_partition_device_view_t const& vertex_partition); + cugraph::vertex_partition_view_t const& vertex_partition_view); template size_t count_duplicate_vertex_pairs_sorted(raft::handle_t const& handle, raft::device_span src, diff --git a/cpp/tests/utilities/validation_utilities.hpp b/cpp/tests/utilities/validation_utilities.hpp index 29b24ae6d91..b94ceaf68be 100644 --- a/cpp/tests/utilities/validation_utilities.hpp +++ b/cpp/tests/utilities/validation_utilities.hpp @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include @@ -26,7 +26,7 @@ template size_t count_invalid_vertices( raft::handle_t const& handle, raft::device_span vertices, - cugraph::vertex_partition_device_view_t const& vertex_partition); + cugraph::vertex_partition_view_t const& vertex_partition); template size_t count_duplicate_vertex_pairs_sorted(raft::handle_t const& handle, From 996b9acca236b19e01d476d84387f47b30c7bbf6 Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Mon, 19 Aug 2024 13:27:53 -0700 Subject: [PATCH 17/18] update to accomodate GPUs with no bias --- cpp/src/sampling/negative_sampling_impl.cuh | 22 ++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/cpp/src/sampling/negative_sampling_impl.cuh b/cpp/src/sampling/negative_sampling_impl.cuh index 32b56ce6fee..b110d4d09c6 100644 --- a/cpp/src/sampling/negative_sampling_impl.cuh +++ b/cpp/src/sampling/negative_sampling_impl.cuh @@ -83,6 +83,18 @@ normalize_biases(raft::handle_t const& handle, weight_t aggregate_sum = thrust::reduce( handle.get_thrust_policy(), gpu_biases->begin(), gpu_biases->end(), weight_t{0}); + // FIXME: https://github.com/rapidsai/raft/issues/2400 results in the possibility + // that 1 can appear as a random floating point value. We're going to use + // thrust::upper_bound to assign random values to GPUs, we need the value 1.0 to + // be part of the upper-most range. We'll compute the last non-zero value in the + // gpu_biases array here and below we will fill it with a value larger than 1.0 + size_t trailing_zeros = thrust::distance( + thrust::make_reverse_iterator(gpu_biases->end()), + thrust::find_if(handle.get_thrust_policy(), + thrust::make_reverse_iterator(gpu_biases->end()), + thrust::make_reverse_iterator(gpu_biases->begin()), + [] __device__(weight_t bias) { return bias > weight_t{0}; })); + thrust::transform(handle.get_thrust_policy(), gpu_biases->begin(), gpu_biases->end(), @@ -92,11 +104,11 @@ normalize_biases(raft::handle_t const& handle, thrust::inclusive_scan( handle.get_thrust_policy(), gpu_biases->begin(), gpu_biases->end(), gpu_biases->begin()); -#if 0 - weight_t force_to_one{1.1}; - raft::update_device( - gpu_biases->data() + gpu_biases->size() - 1, &force_to_one, 1, handle.get_stream()); -#endif + // FIXME: conclusion of above. Using 1.1 since it is > 1.0 and easy to type + thrust::copy_n(handle.get_thrust_policy(), + thrust::make_constant_iterator(1.1), + trailing_zeros + 1, + gpu_biases->begin() + gpu_biases->size() - trailing_zeros - 1); } return std::make_tuple(std::move(normalized_biases), std::move(gpu_biases)); From 8a28b0b251ef12ae687999f5e8c3f6a7d51bc773 Mon Sep 17 00:00:00 2001 From: Charles Hastings Date: Tue, 20 Aug 2024 11:47:40 -0700 Subject: [PATCH 18/18] move num_samples parameter, add tests for edge masking, some cosmetic cleanup --- cpp/include/cugraph/sampling_functions.hpp | 4 +- cpp/include/cugraph_c/sampling_algorithms.h | 4 +- cpp/src/c_api/negative_sampling.cpp | 14 +-- cpp/src/sampling/negative_sampling_impl.cuh | 3 +- .../sampling/negative_sampling_mg_v32_e32.cu | 4 +- .../sampling/negative_sampling_mg_v32_e64.cu | 4 +- .../sampling/negative_sampling_mg_v64_e64.cu | 4 +- .../sampling/negative_sampling_sg_v32_e32.cu | 4 +- .../sampling/negative_sampling_sg_v32_e64.cu | 4 +- .../sampling/negative_sampling_sg_v64_e64.cu | 4 +- cpp/tests/c_api/mg_negative_sampling_test.c | 2 +- cpp/tests/c_api/negative_sampling_test.c | 2 +- cpp/tests/sampling/mg_negative_sampling.cpp | 90 +++++----------- cpp/tests/sampling/negative_sampling.cpp | 102 ++++++------------ 14 files changed, 80 insertions(+), 165 deletions(-) diff --git a/cpp/include/cugraph/sampling_functions.hpp b/cpp/include/cugraph/sampling_functions.hpp index 208964bdf03..4e5596d06e0 100644 --- a/cpp/include/cugraph/sampling_functions.hpp +++ b/cpp/include/cugraph/sampling_functions.hpp @@ -768,13 +768,13 @@ lookup_endpoints_from_edge_ids_and_types( * handles to various CUDA libraries) to run graph algorithms. * @param graph_view Graph View object to generate NBR Sampling for * @param rng_state RNG state - * @param num_samples Number of negative samples to generate * @param src_biases Optional bias for randomly selecting source vertices. If std::nullopt vertices * will be selected uniformly. In multi-GPU environment the biases should be partitioned based * on the vertex partitions. * @param dst_biases Optional bias for randomly selecting destination vertices. If std::nullopt * vertices will be selected uniformly. In multi-GPU environment the biases should be partitioned * based on the vertex partitions. + * @param num_samples Number of negative samples to generate * @param remove_duplicates If true, remove duplicate samples * @param remove_existing_edges If true, remove samples that are actually edges in the graph * @param exact_number_of_samples If true, repeat generation until we get the exact number of @@ -792,9 +792,9 @@ std::tuple, rmm::device_uvector> negativ raft::handle_t const& handle, raft::random::RngState& rng_state, graph_view_t const& graph_view, - size_t num_samples, std::optional> src_biases, std::optional> dst_biases, + size_t num_samples, bool remove_duplicates, bool remove_existing_edges, bool exact_number_of_samples, diff --git a/cpp/include/cugraph_c/sampling_algorithms.h b/cpp/include/cugraph_c/sampling_algorithms.h index ad3f8b9ea2c..bb26e577915 100644 --- a/cpp/include/cugraph_c/sampling_algorithms.h +++ b/cpp/include/cugraph_c/sampling_algorithms.h @@ -685,7 +685,6 @@ cugraph_error_code_t cugraph_select_random_vertices(const cugraph_resource_handl * @param [in,out] rng_state State of the random number generator, updated with each * call * @param [in] graph Pointer to graph - * @param [in] num_samples Number of negative samples to generate * @param [in] vertices Vertex ids for the source biases. If @p src_bias and * @p dst_bias are not specified this is ignored. If * @p vertices is specified then vertices[i] is the vertex @@ -698,6 +697,7 @@ cugraph_error_code_t cugraph_select_random_vertices(const cugraph_resource_handl * @param [in] dst_biases Bias for selecting destination vertices. If NULL, do * uniform sampling, if provided probability of vertex i * will be dst_bias[i] / (sum of all destination biases) + * @param [in] num_samples Number of negative samples to generate * @param [in] remove_duplicates If true, remove duplicates from sampled edges * @param [in] remove_existing_edges If true, remove sampled edges that actually exist in * the graph @@ -715,10 +715,10 @@ cugraph_error_code_t cugraph_negative_sampling( const cugraph_resource_handle_t* handle, cugraph_rng_state_t* rng_state, cugraph_graph_t* graph, - size_t num_samples, const cugraph_type_erased_device_array_view_t* vertices, const cugraph_type_erased_device_array_view_t* src_biases, const cugraph_type_erased_device_array_view_t* dst_biases, + size_t num_samples, bool_t remove_duplicates, bool_t remove_existing_edges, bool_t exact_number_of_samples, diff --git a/cpp/src/c_api/negative_sampling.cpp b/cpp/src/c_api/negative_sampling.cpp index 4db5a8b8535..54f465d67b4 100644 --- a/cpp/src/c_api/negative_sampling.cpp +++ b/cpp/src/c_api/negative_sampling.cpp @@ -36,10 +36,10 @@ struct negative_sampling_functor : public cugraph::c_api::abstract_functor { raft::handle_t const& handle_; cugraph::c_api::cugraph_rng_state_t* rng_state_{nullptr}; cugraph::c_api::cugraph_graph_t* graph_{nullptr}; - size_t num_samples_; cugraph::c_api::cugraph_type_erased_device_array_view_t const* vertices_{nullptr}; cugraph::c_api::cugraph_type_erased_device_array_view_t const* src_biases_{nullptr}; cugraph::c_api::cugraph_type_erased_device_array_view_t const* dst_biases_{nullptr}; + size_t num_samples_; bool remove_duplicates_{false}; bool remove_existing_edges_{false}; bool exact_number_of_samples_{false}; @@ -49,10 +49,10 @@ struct negative_sampling_functor : public cugraph::c_api::abstract_functor { negative_sampling_functor(const cugraph_resource_handle_t* handle, cugraph_rng_state_t* rng_state, cugraph_graph_t* graph, - size_t num_samples, const cugraph_type_erased_device_array_view_t* vertices, const cugraph_type_erased_device_array_view_t* src_biases, const cugraph_type_erased_device_array_view_t* dst_biases, + size_t num_samples, bool_t remove_duplicates, bool_t remove_existing_edges, bool_t exact_number_of_samples, @@ -61,13 +61,13 @@ struct negative_sampling_functor : public cugraph::c_api::abstract_functor { handle_(*reinterpret_cast(handle)->handle_), rng_state_(reinterpret_cast(rng_state)), graph_(reinterpret_cast(graph)), - num_samples_(num_samples), vertices_( reinterpret_cast(vertices)), src_biases_(reinterpret_cast( src_biases)), dst_biases_(reinterpret_cast( dst_biases)), + num_samples_(num_samples), remove_duplicates_(remove_duplicates), remove_existing_edges_(remove_existing_edges), exact_number_of_samples_(exact_number_of_samples), @@ -87,7 +87,7 @@ struct negative_sampling_functor : public cugraph::c_api::abstract_functor { if constexpr (!cugraph::is_candidate::value) { unsupported(); } else { - // uniform_nbr_sample expects store_transposed == false + // negative_sampling expects store_transposed == false if constexpr (store_transposed) { error_code_ = cugraph::c_api:: transpose_storage( @@ -156,13 +156,13 @@ struct negative_sampling_functor : public cugraph::c_api::abstract_functor { handle_, rng_state_->rng_state_, graph_view, - num_samples_, (src_biases_ != nullptr) ? std::make_optional(raft::device_span{ src_biases.data(), src_biases.size()}) : std::nullopt, (dst_biases_ != nullptr) ? std::make_optional(raft::device_span{ dst_biases.data(), dst_biases.size()}) : std::nullopt, + num_samples_, remove_duplicates_, remove_existing_edges_, exact_number_of_samples_, @@ -202,10 +202,10 @@ cugraph_error_code_t cugraph_negative_sampling( const cugraph_resource_handle_t* handle, cugraph_rng_state_t* rng_state, cugraph_graph_t* graph, - size_t num_samples, const cugraph_type_erased_device_array_view_t* vertices, const cugraph_type_erased_device_array_view_t* src_biases, const cugraph_type_erased_device_array_view_t* dst_biases, + size_t num_samples, bool_t remove_duplicates, bool_t remove_existing_edges, bool_t exact_number_of_samples, @@ -216,10 +216,10 @@ cugraph_error_code_t cugraph_negative_sampling( negative_sampling_functor functor{handle, rng_state, graph, - num_samples, vertices, src_biases, dst_biases, + num_samples, remove_duplicates, remove_existing_edges, exact_number_of_samples, diff --git a/cpp/src/sampling/negative_sampling_impl.cuh b/cpp/src/sampling/negative_sampling_impl.cuh index b110d4d09c6..93bb03077bc 100644 --- a/cpp/src/sampling/negative_sampling_impl.cuh +++ b/cpp/src/sampling/negative_sampling_impl.cuh @@ -74,7 +74,6 @@ normalize_biases(raft::handle_t const& handle, normalized_biases->begin()); if constexpr (multi_gpu) { - // rmm::device_scalar d_sum((sum / aggregate_sum), handle.get_stream()); rmm::device_scalar d_sum(sum, handle.get_stream()); gpu_biases = cugraph::device_allgatherv( @@ -258,9 +257,9 @@ std::tuple, rmm::device_uvector> negativ raft::handle_t const& handle, raft::random::RngState& rng_state, graph_view_t const& graph_view, - size_t num_samples, std::optional> src_biases, std::optional> dst_biases, + size_t num_samples, bool remove_duplicates, bool remove_existing_edges, bool exact_number_of_samples, diff --git a/cpp/src/sampling/negative_sampling_mg_v32_e32.cu b/cpp/src/sampling/negative_sampling_mg_v32_e32.cu index 92ccd5deeb5..ce54d54d319 100644 --- a/cpp/src/sampling/negative_sampling_mg_v32_e32.cu +++ b/cpp/src/sampling/negative_sampling_mg_v32_e32.cu @@ -25,9 +25,9 @@ template std::tuple, rmm::device_uvector> raft::handle_t const& handle, raft::random::RngState& rng_state, graph_view_t const& graph_view, - size_t num_samples, std::optional> src_bias, std::optional> dst_bias, + size_t num_samples, bool remove_duplicates, bool remove_existing_edges, bool exact_number_of_samples, @@ -37,9 +37,9 @@ template std::tuple, rmm::device_uvector> raft::handle_t const& handle, raft::random::RngState& rng_state, graph_view_t const& graph_view, - size_t num_samples, std::optional> src_bias, std::optional> dst_bias, + size_t num_samples, bool remove_duplicates, bool remove_existing_edges, bool exact_number_of_samples, diff --git a/cpp/src/sampling/negative_sampling_mg_v32_e64.cu b/cpp/src/sampling/negative_sampling_mg_v32_e64.cu index 83158a07527..af4c28c0f1a 100644 --- a/cpp/src/sampling/negative_sampling_mg_v32_e64.cu +++ b/cpp/src/sampling/negative_sampling_mg_v32_e64.cu @@ -25,9 +25,9 @@ template std::tuple, rmm::device_uvector> raft::handle_t const& handle, raft::random::RngState& rng_state, graph_view_t const& graph_view, - size_t num_samples, std::optional> src_bias, std::optional> dst_bias, + size_t num_samples, bool remove_duplicates, bool remove_existing_edges, bool exact_number_of_samples, @@ -37,9 +37,9 @@ template std::tuple, rmm::device_uvector> raft::handle_t const& handle, raft::random::RngState& rng_state, graph_view_t const& graph_view, - size_t num_samples, std::optional> src_bias, std::optional> dst_bias, + size_t num_samples, bool remove_duplicates, bool remove_existing_edges, bool exact_number_of_samples, diff --git a/cpp/src/sampling/negative_sampling_mg_v64_e64.cu b/cpp/src/sampling/negative_sampling_mg_v64_e64.cu index cdca1f078d0..c5691fb4644 100644 --- a/cpp/src/sampling/negative_sampling_mg_v64_e64.cu +++ b/cpp/src/sampling/negative_sampling_mg_v64_e64.cu @@ -25,9 +25,9 @@ template std::tuple, rmm::device_uvector> raft::handle_t const& handle, raft::random::RngState& rng_state, graph_view_t const& graph_view, - size_t num_samples, std::optional> src_bias, std::optional> dst_bias, + size_t num_samples, bool remove_duplicates, bool remove_existing_edges, bool exact_number_of_samples, @@ -37,9 +37,9 @@ template std::tuple, rmm::device_uvector> raft::handle_t const& handle, raft::random::RngState& rng_state, graph_view_t const& graph_view, - size_t num_samples, std::optional> src_bias, std::optional> dst_bias, + size_t num_samples, bool remove_duplicates, bool remove_existing_edges, bool exact_number_of_samples, diff --git a/cpp/src/sampling/negative_sampling_sg_v32_e32.cu b/cpp/src/sampling/negative_sampling_sg_v32_e32.cu index 1d784b5e408..3712414e4ec 100644 --- a/cpp/src/sampling/negative_sampling_sg_v32_e32.cu +++ b/cpp/src/sampling/negative_sampling_sg_v32_e32.cu @@ -25,9 +25,9 @@ template std::tuple, rmm::device_uvector> raft::handle_t const& handle, raft::random::RngState& rng_state, graph_view_t const& graph_view, - size_t num_samples, std::optional> src_bias, std::optional> dst_bias, + size_t num_samples, bool remove_duplicates, bool remove_existing_edges, bool exact_number_of_samples, @@ -37,9 +37,9 @@ template std::tuple, rmm::device_uvector> raft::handle_t const& handle, raft::random::RngState& rng_state, graph_view_t const& graph_view, - size_t num_samples, std::optional> src_bias, std::optional> dst_bias, + size_t num_samples, bool remove_duplicates, bool remove_existing_edges, bool exact_number_of_samples, diff --git a/cpp/src/sampling/negative_sampling_sg_v32_e64.cu b/cpp/src/sampling/negative_sampling_sg_v32_e64.cu index d42a5351eb3..c66c31a4258 100644 --- a/cpp/src/sampling/negative_sampling_sg_v32_e64.cu +++ b/cpp/src/sampling/negative_sampling_sg_v32_e64.cu @@ -25,9 +25,9 @@ template std::tuple, rmm::device_uvector> raft::handle_t const& handle, raft::random::RngState& rng_state, graph_view_t const& graph_view, - size_t num_samples, std::optional> src_bias, std::optional> dst_bias, + size_t num_samples, bool remove_duplicates, bool remove_existing_edges, bool exact_number_of_samples, @@ -37,9 +37,9 @@ template std::tuple, rmm::device_uvector> raft::handle_t const& handle, raft::random::RngState& rng_state, graph_view_t const& graph_view, - size_t num_samples, std::optional> src_bias, std::optional> dst_bias, + size_t num_samples, bool remove_duplicates, bool remove_existing_edges, bool exact_number_of_samples, diff --git a/cpp/src/sampling/negative_sampling_sg_v64_e64.cu b/cpp/src/sampling/negative_sampling_sg_v64_e64.cu index d9365a1519d..e4fc50890e4 100644 --- a/cpp/src/sampling/negative_sampling_sg_v64_e64.cu +++ b/cpp/src/sampling/negative_sampling_sg_v64_e64.cu @@ -25,9 +25,9 @@ template std::tuple, rmm::device_uvector> raft::handle_t const& handle, raft::random::RngState& rng_state, graph_view_t const& graph_view, - size_t num_samples, std::optional> src_bias, std::optional> dst_bias, + size_t num_samples, bool remove_duplicates, bool remove_existing_edges, bool exact_number_of_samples, @@ -37,9 +37,9 @@ template std::tuple, rmm::device_uvector> raft::handle_t const& handle, raft::random::RngState& rng_state, graph_view_t const& graph_view, - size_t num_samples, std::optional> src_bias, std::optional> dst_bias, + size_t num_samples, bool remove_duplicates, bool remove_existing_edges, bool exact_number_of_samples, diff --git a/cpp/tests/c_api/mg_negative_sampling_test.c b/cpp/tests/c_api/mg_negative_sampling_test.c index 566524251ed..3289206d8db 100644 --- a/cpp/tests/c_api/mg_negative_sampling_test.c +++ b/cpp/tests/c_api/mg_negative_sampling_test.c @@ -126,10 +126,10 @@ int generic_negative_sampling_test(const cugraph_resource_handle_t* handle, ret_code = cugraph_negative_sampling(handle, rng_state, graph, - num_samples, d_vertices_view, d_src_bias_view, d_dst_bias_view, + num_samples, remove_duplicates, remove_false_negatives, exact_number_of_samples, diff --git a/cpp/tests/c_api/negative_sampling_test.c b/cpp/tests/c_api/negative_sampling_test.c index abea4028061..5e8d3f7e765 100644 --- a/cpp/tests/c_api/negative_sampling_test.c +++ b/cpp/tests/c_api/negative_sampling_test.c @@ -118,10 +118,10 @@ int generic_negative_sampling_test(const cugraph_resource_handle_t* handle, ret_code = cugraph_negative_sampling(handle, rng_state, graph, - num_samples, d_vertices_view, d_src_bias_view, d_dst_bias_view, + num_samples, remove_duplicates, remove_false_negatives, exact_number_of_samples, diff --git a/cpp/tests/sampling/mg_negative_sampling.cpp b/cpp/tests/sampling/mg_negative_sampling.cpp index 702b83ded21..7c64bb7fbbb 100644 --- a/cpp/tests/sampling/mg_negative_sampling.cpp +++ b/cpp/tests/sampling/mg_negative_sampling.cpp @@ -31,6 +31,7 @@ struct Negative_Sampling_Usecase { bool remove_duplicates{false}; bool remove_existing_edges{false}; bool exact_number_of_samples{false}; + bool edge_masking{false}; bool check_correctness{true}; }; @@ -65,6 +66,9 @@ class Tests_MGNegative_Sampling : public ::testing::TestWithParam::edge_property(*handle_, graph_.view(), 2); } virtual void SetUp() {} @@ -79,7 +83,10 @@ class Tests_MGNegative_Sampling : public ::testing::TestWithParamview()); } + + size_t num_samples = + graph_view.compute_number_of_edges(*handle_) * negative_sampling_usecase.sample_multiplier; rmm::device_uvector src_bias_v(0, handle_->get_stream()); rmm::device_uvector dst_bias_v(0, handle_->get_stream()); @@ -122,9 +129,9 @@ class Tests_MGNegative_Sampling : public ::testing::TestWithParam> edge_weights_{std::nullopt}; + std::optional> edge_mask_{std::nullopt}; std::optional> renumber_map_labels_{std::nullopt}; }; @@ -227,70 +235,20 @@ void run_all_tests(CurrentTest* current_test) raft::random::RngState rng_state{ static_cast(current_test->handle_->get_comms().get_rank())}; - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, false, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, false, false, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, true, false, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, true, false, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, true, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, false, true, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, true, true, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, true, true, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, false, true, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, false, false, true, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, true, false, true, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, true, false, true, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, true, true, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, false, true, true, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, true, true, true, false, true}); - current_test->run_current_test(rng_state, - Negative_Sampling_Usecase{2, true, true, true, true, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, false, false, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, false, false, false, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, true, false, false, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, true, false, false, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, true, false, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, false, true, false, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, true, true, false, true, true}); - current_test->run_current_test(rng_state, - Negative_Sampling_Usecase{2, true, true, true, false, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, false, true, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, false, false, true, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, true, false, true, true, true}); - current_test->run_current_test(rng_state, - Negative_Sampling_Usecase{2, true, true, false, true, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, true, true, true, true}); - current_test->run_current_test(rng_state, - Negative_Sampling_Usecase{2, true, false, true, true, true, true}); - current_test->run_current_test(rng_state, - Negative_Sampling_Usecase{2, false, true, true, true, true, true}); - current_test->run_current_test(rng_state, - Negative_Sampling_Usecase{2, true, true, true, true, true, true}); + for (bool use_src_bias : {false, true}) + for (bool use_dst_bias : {false, true}) + for (bool remove_duplicates : {false, true}) + for (bool remove_existing_edges : {false, true}) + for (bool exact_number_of_samples : {false, true}) + for (bool edge_masking : {false, true}) + current_test->run_current_test(rng_state, + Negative_Sampling_Usecase{2, + use_src_bias, + use_dst_bias, + remove_duplicates, + remove_existing_edges, + exact_number_of_samples, + edge_masking}); } TEST_P(Tests_MGNegative_Sampling_File_i64_i64_float, CheckInt64Int64Float) diff --git a/cpp/tests/sampling/negative_sampling.cpp b/cpp/tests/sampling/negative_sampling.cpp index a1762a2f3fc..ba929c63e9b 100644 --- a/cpp/tests/sampling/negative_sampling.cpp +++ b/cpp/tests/sampling/negative_sampling.cpp @@ -31,6 +31,7 @@ struct Negative_Sampling_Usecase { bool remove_duplicates{false}; bool remove_existing_edges{false}; bool exact_number_of_samples{false}; + bool edge_masking{false}; bool check_correctness{true}; }; @@ -40,7 +41,7 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam using graph_t = cugraph::graph_t; using graph_view_t = cugraph::graph_view_t; - Tests_Negative_Sampling() : graph(raft::handle_t{}) {} + Tests_Negative_Sampling() : graph_(raft::handle_t{}) {} static void SetUpTestCase() {} static void TearDownTestCase() {} @@ -56,7 +57,7 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam hr_timer.start("Construct graph"); } - std::tie(graph, edge_weights, renumber_map_labels) = + std::tie(graph_, edge_weights_, renumber_map_labels_) = cugraph::test::construct_graph( handle, param, true, true); @@ -65,6 +66,9 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam hr_timer.stop(); hr_timer.display_and_clear(std::cout); } + + edge_mask_ = + cugraph::test::generate::edge_property(handle, graph_.view(), 2); } virtual void SetUp() {} @@ -78,9 +82,12 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam raft::handle_t handle{}; HighResTimer hr_timer{}; - auto graph_view = graph.view(); + auto graph_view = graph_.view(); + + if (negative_sampling_usecase.edge_masking) { graph_view.attach_edge_mask(edge_mask_->view()); } - size_t num_samples = graph_view.number_of_edges() * negative_sampling_usecase.sample_multiplier; + size_t num_samples = + graph_view.compute_number_of_edges(handle) * negative_sampling_usecase.sample_multiplier; rmm::device_uvector src_bias_v(0, handle.get_stream()); rmm::device_uvector dst_bias_v(0, handle.get_stream()); @@ -123,9 +130,9 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam cugraph::negative_sampling(handle, rng_state, graph_view, - num_samples, src_bias, dst_bias, + num_samples, negative_sampling_usecase.remove_duplicates, negative_sampling_usecase.remove_existing_edges, negative_sampling_usecase.exact_number_of_samples, @@ -197,9 +204,10 @@ class Tests_Negative_Sampling : public ::testing::TestWithParam } private: - graph_t graph; - std::optional> edge_weights{std::nullopt}; - std::optional> renumber_map_labels{std::nullopt}; + graph_t graph_; + std::optional> edge_weights_{std::nullopt}; + std::optional> edge_mask_{std::nullopt}; + std::optional> renumber_map_labels_{std::nullopt}; }; using Tests_Negative_Sampling_File_i32_i32_float = @@ -225,70 +233,20 @@ void run_all_tests(CurrentTest* current_test) { raft::random::RngState rng_state{0}; - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, false, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, false, false, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, true, false, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, true, false, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, true, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, false, true, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, true, true, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, true, true, false, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, false, true, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, false, false, true, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, true, false, true, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, true, false, true, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, true, true, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, false, true, true, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, true, true, true, false, true}); - current_test->run_current_test(rng_state, - Negative_Sampling_Usecase{2, true, true, true, true, false, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, false, false, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, false, false, false, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, true, false, false, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, true, false, false, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, true, false, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, false, true, false, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, true, true, false, true, true}); - current_test->run_current_test(rng_state, - Negative_Sampling_Usecase{2, true, true, true, false, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, false, true, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, true, false, false, true, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, true, false, true, true, true}); - current_test->run_current_test(rng_state, - Negative_Sampling_Usecase{2, true, true, false, true, true, true}); - current_test->run_current_test( - rng_state, Negative_Sampling_Usecase{2, false, false, true, true, true, true}); - current_test->run_current_test(rng_state, - Negative_Sampling_Usecase{2, true, false, true, true, true, true}); - current_test->run_current_test(rng_state, - Negative_Sampling_Usecase{2, false, true, true, true, true, true}); - current_test->run_current_test(rng_state, - Negative_Sampling_Usecase{2, true, true, true, true, true, true}); + for (bool use_src_bias : {false, true}) + for (bool use_dst_bias : {false, true}) + for (bool remove_duplicates : {false, true}) + for (bool remove_existing_edges : {false, true}) + for (bool exact_number_of_samples : {false, true}) + for (bool edge_masking : {false, true}) + current_test->run_current_test(rng_state, + Negative_Sampling_Usecase{2, + use_src_bias, + use_dst_bias, + remove_duplicates, + remove_existing_edges, + exact_number_of_samples, + edge_masking}); } TEST_P(Tests_Negative_Sampling_File_i32_i32_float, CheckInt32Int32Float)