From 30465c2a6d053d57a8a75951656d54a416e402be Mon Sep 17 00:00:00 2001 From: Naim <110031745+naimnv@users.noreply.github.com> Date: Sat, 25 May 2024 02:48:50 +0200 Subject: [PATCH 1/7] Fix bug in kv_store_t's insertion methods (#4444) Update size_ field of kv_cuco_store_t with correct values. Authors: - Naim (https://github.com/naimnv) Approvers: - Chuck Hastings (https://github.com/ChuckHastings) - Seunghwa Kang (https://github.com/seunghwak) URL: https://github.com/rapidsai/cugraph/pull/4444 --- cpp/src/prims/kv_store.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/prims/kv_store.cuh b/cpp/src/prims/kv_store.cuh index 76b64b5692b..5001a20bb83 100644 --- a/cpp/src/prims/kv_store.cuh +++ b/cpp/src/prims/kv_store.cuh @@ -584,7 +584,7 @@ class kv_cuco_store_t { store_value_offsets.end(), kv_cuco_insert_and_increment_t{ mutable_device_ref, key_first, counter.data(), std::numeric_limits::max()}); - size_ += counter.value(stream); + size_ = counter.value(stream); resize_optional_dataframe_buffer(store_values_, size_, stream); thrust::scatter_if(rmm::exec_policy(stream), value_first, @@ -636,7 +636,7 @@ class kv_cuco_store_t { pred_op, counter.data(), std::numeric_limits::max()}); - size_ += counter.value(stream); + size_ = counter.value(stream); resize_optional_dataframe_buffer(store_values_, size_, stream); thrust::scatter_if(rmm::exec_policy(stream), value_first, @@ -688,7 +688,7 @@ class kv_cuco_store_t { store_value_offsets.end(), kv_cuco_insert_and_increment_t{ mutable_device_ref, key_first, counter.data(), std::numeric_limits::max()}); - size_ += counter.value(stream); + size_ = counter.value(stream); resize_optional_dataframe_buffer(store_values_, size_, stream); thrust::scatter_if(rmm::exec_policy(stream), value_first, From 1c3f3a8ffb0e22ab0674aff79e675706bbba5f2c Mon Sep 17 00:00:00 2001 From: Joseph Nke <76006812+jnke2016@users.noreply.github.com> Date: Tue, 28 May 2024 14:40:27 +0100 Subject: [PATCH 2/7] Move edge triangle 
count to the stable API (#4382) This PR 1. Performs edge triangle count in chunk 2. Enables k - 1 core optimization 3. Add C++ tests for edge triangle count 4. Move edge triangle count to the stable API 5. Implement MG edge triangle count and add tests 6. Update 'mg_graph_to_sg_graph' to support 'edge_ids' along with tests closes #4370 closes #4371 Authors: - Joseph Nke (https://github.com/jnke2016) - Rick Ratzel (https://github.com/rlratzel) Approvers: - Chuck Hastings (https://github.com/ChuckHastings) - Seunghwa Kang (https://github.com/seunghwak) URL: https://github.com/rapidsai/cugraph/pull/4382 --- cpp/CMakeLists.txt | 1 + cpp/include/cugraph/algorithms.hpp | 18 + .../community/edge_triangle_count_impl.cuh | 361 ++++++++++++++---- cpp/src/community/edge_triangle_count_mg.cu | 33 ++ cpp/src/community/edge_triangle_count_sg.cu | 18 +- cpp/src/community/k_truss_impl.cuh | 292 +++++++------- cpp/tests/CMakeLists.txt | 9 + .../mg_betweenness_centrality_test.cpp | 16 +- .../mg_edge_betweenness_centrality_test.cpp | 14 +- .../mg_eigenvector_centrality_test.cpp | 16 +- .../centrality/mg_katz_centrality_test.cpp | 16 +- .../community/edge_triangle_count_test.cpp | 260 +++++++++++++ cpp/tests/community/mg_ecg_test.cpp | 14 +- .../community/mg_edge_triangle_count_test.cpp | 253 ++++++++++++ cpp/tests/community/mg_egonet_test.cu | 16 +- cpp/tests/community/mg_leiden_test.cpp | 14 +- cpp/tests/community/mg_louvain_test.cpp | 14 +- .../community/mg_triangle_count_test.cpp | 16 +- .../community/mg_weighted_matching_test.cpp | 14 +- .../mg_weakly_connected_components_test.cpp | 16 +- cpp/tests/cores/mg_core_number_test.cpp | 16 +- cpp/tests/cores/mg_k_core_test.cpp | 16 +- cpp/tests/link_analysis/mg_hits_test.cpp | 16 +- cpp/tests/link_analysis/mg_pagerank_test.cpp | 16 +- cpp/tests/mtmg/threaded_test_louvain.cu | 4 +- cpp/tests/prims/mg_count_if_e.cu | 16 +- cpp/tests/prims/mg_count_if_v.cu | 16 +- cpp/tests/prims/mg_extract_transform_e.cu | 16 +- 
...extract_transform_v_frontier_outgoing_e.cu | 16 +- ...r_v_pair_transform_dst_nbr_intersection.cu | 16 +- ...transform_dst_nbr_weighted_intersection.cu | 20 +- ...er_v_random_select_transform_outgoing_e.cu | 16 +- ...rm_reduce_dst_key_aggregated_outgoing_e.cu | 16 +- ..._v_transform_reduce_incoming_outgoing_e.cu | 16 +- cpp/tests/prims/mg_reduce_v.cu | 16 +- ...st_nbr_intersection_of_e_endpoints_by_v.cu | 16 +- cpp/tests/prims/mg_transform_reduce_e.cu | 16 +- .../mg_transform_reduce_e_by_src_dst_key.cu | 16 +- cpp/tests/prims/mg_transform_reduce_v.cu | 16 +- ...orm_reduce_v_frontier_outgoing_e_by_dst.cu | 16 +- cpp/tests/structure/mg_coarsen_graph_test.cpp | 17 +- ..._count_self_loops_and_multi_edges_test.cpp | 16 +- ...has_edge_and_compute_multiplicity_test.cpp | 16 +- .../structure/mg_induced_subgraph_test.cu | 14 +- cpp/tests/structure/mg_symmetrize_test.cpp | 16 +- .../structure/mg_transpose_storage_test.cpp | 16 +- cpp/tests/structure/mg_transpose_test.cpp | 16 +- cpp/tests/traversal/mg_bfs_test.cpp | 16 +- .../traversal/mg_extract_bfs_paths_test.cu | 16 +- cpp/tests/traversal/mg_k_hop_nbrs_test.cpp | 16 +- cpp/tests/traversal/mg_sssp_test.cpp | 16 +- cpp/tests/utilities/conversion_utilities.hpp | 15 +- .../utilities/conversion_utilities_impl.cuh | 25 +- .../utilities/conversion_utilities_mg.cu | 24 ++ 54 files changed, 1448 insertions(+), 514 deletions(-) create mode 100644 cpp/src/community/edge_triangle_count_mg.cu create mode 100644 cpp/tests/community/edge_triangle_count_test.cpp create mode 100644 cpp/tests/community/mg_edge_triangle_count_test.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 57e0aa2d078..2527599fece 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -180,6 +180,7 @@ set(CUGRAPH_SOURCES src/community/detail/refine_sg.cu src/community/detail/refine_mg.cu src/community/edge_triangle_count_sg.cu + src/community/edge_triangle_count_mg.cu src/community/detail/maximal_independent_moves_sg.cu 
src/community/detail/maximal_independent_moves_mg.cu src/detail/utility_wrappers.cu diff --git a/cpp/include/cugraph/algorithms.hpp b/cpp/include/cugraph/algorithms.hpp index 7c4a978c4b4..cc42399f091 100644 --- a/cpp/include/cugraph/algorithms.hpp +++ b/cpp/include/cugraph/algorithms.hpp @@ -2007,6 +2007,24 @@ void triangle_count(raft::handle_t const& handle, raft::device_span counts, bool do_expensive_check = false); +/* + * @brief Compute edge triangle counts. + * + * Compute edge triangle counts for the entire set of edges. + * + * @tparam vertex_t Type of vertex identifiers. Needs to be an integral type. + * @tparam edge_t Type of edge identifiers. Needs to be an integral type. + * @tparam multi_gpu Flag indicating whether template instantiation should target single-GPU (false) or multi-GPU (true). + * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and + * handles to various CUDA libraries) to run graph algorithms. + * @param graph_view Graph view object. + * + * @return edge_property_t containing the edge triangle count + */ +template +edge_property_t, edge_t> edge_triangle_count( + raft::handle_t const& handle, graph_view_t const& graph_view); + /* * @brief Compute K-Truss. 
* diff --git a/cpp/src/community/edge_triangle_count_impl.cuh b/cpp/src/community/edge_triangle_count_impl.cuh index 1370c1a17f2..c4277e240be 100644 --- a/cpp/src/community/edge_triangle_count_impl.cuh +++ b/cpp/src/community/edge_triangle_count_impl.cuh @@ -17,12 +17,17 @@ #pragma once #include "detail/graph_partition_utils.cuh" +#include "prims/edge_bucket.cuh" +#include "prims/transform_e.cuh" #include "prims/transform_reduce_dst_nbr_intersection_of_e_endpoints_by_v.cuh" +#include #include #include #include +#include + #include #include #include @@ -34,8 +39,9 @@ namespace detail { template struct update_edges_p_r_q_r_num_triangles { - size_t num_edges{}; // rename to num_edges + size_t num_edges{}; const edge_t edge_first_or_second{}; + size_t chunk_start{}; raft::device_span intersection_offsets{}; raft::device_span intersection_indices{}; raft::device_span num_triangles{}; @@ -48,28 +54,22 @@ struct update_edges_p_r_q_r_num_triangles { thrust::seq, intersection_offsets.begin() + 1, intersection_offsets.end(), i); auto idx = thrust::distance(intersection_offsets.begin() + 1, itr); if (edge_first_or_second == 0) { - auto p_r_pair = - thrust::make_tuple(thrust::get<0>(*(edge_first + idx)), intersection_indices[i]); + auto p_r_pair = thrust::make_tuple(thrust::get<0>(*(edge_first + chunk_start + idx)), + intersection_indices[i]); // Find its position in 'edges' auto itr_p_r_p_q = - thrust::lower_bound(thrust::seq, - edge_first, - edge_first + num_edges, // pass the number of vertex pairs - p_r_pair); + thrust::lower_bound(thrust::seq, edge_first, edge_first + num_edges, p_r_pair); assert(*itr_p_r_p_q == p_r_pair); idx = thrust::distance(edge_first, itr_p_r_p_q); } else { - auto p_r_pair = - thrust::make_tuple(thrust::get<1>(*(edge_first + idx)), intersection_indices[i]); + auto p_r_pair = thrust::make_tuple(thrust::get<1>(*(edge_first + chunk_start + idx)), + intersection_indices[i]); // Find its position in 'edges' auto itr_p_r_p_q = - 
thrust::lower_bound(thrust::seq, - edge_first, - edge_first + num_edges, // pass the number of vertex pairs - p_r_pair); + thrust::lower_bound(thrust::seq, edge_first, edge_first + num_edges, p_r_pair); assert(*itr_p_r_p_q == p_r_pair); idx = thrust::distance(edge_first, itr_p_r_p_q); } @@ -78,77 +78,296 @@ struct update_edges_p_r_q_r_num_triangles { } }; +template +struct extract_p_r_q_r { + size_t chunk_start{}; + size_t p_r_or_q_r{}; + raft::device_span intersection_offsets{}; + raft::device_span intersection_indices{}; + EdgeIterator edge_first; + + __device__ thrust::tuple operator()(edge_t i) const + { + auto itr = thrust::upper_bound( + thrust::seq, intersection_offsets.begin() + 1, intersection_offsets.end(), i); + auto idx = thrust::distance(intersection_offsets.begin() + 1, itr); + + if (p_r_or_q_r == 0) { + return thrust::make_tuple(thrust::get<0>(*(edge_first + chunk_start + idx)), + intersection_indices[i]); + } else { + return thrust::make_tuple(thrust::get<1>(*(edge_first + chunk_start + idx)), + intersection_indices[i]); + } + } +}; + +template +struct extract_q_r { + size_t chunk_start{}; + raft::device_span intersection_offsets{}; + raft::device_span intersection_indices{}; + EdgeIterator edge_first; + + __device__ thrust::tuple operator()(edge_t i) const + { + auto itr = thrust::upper_bound( + thrust::seq, intersection_offsets.begin() + 1, intersection_offsets.end(), i); + auto idx = thrust::distance(intersection_offsets.begin() + 1, itr); + auto pair = thrust::make_tuple(thrust::get<1>(*(edge_first + chunk_start + idx)), + intersection_indices[i]); + + return pair; + } +}; + template -std::enable_if_t> edge_triangle_count_impl( +edge_property_t, edge_t> edge_triangle_count_impl( raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span edgelist_srcs, - raft::device_span edgelist_dsts) + graph_view_t const& graph_view) { - auto edge_first = thrust::make_zip_iterator(edgelist_srcs.begin(), edgelist_dsts.begin()); + using 
weight_t = float; + rmm::device_uvector edgelist_srcs(0, handle.get_stream()); + rmm::device_uvector edgelist_dsts(0, handle.get_stream()); + std::tie(edgelist_srcs, edgelist_dsts, std::ignore, std::ignore, std::ignore) = + decompress_to_edgelist( + handle, graph_view, std::nullopt, std::nullopt, std::nullopt, std::nullopt); - thrust::sort(handle.get_thrust_policy(), edge_first, edge_first + edgelist_srcs.size()); + auto edge_first = thrust::make_zip_iterator(edgelist_srcs.begin(), edgelist_dsts.begin()); - // FIXME: Perform 'nbr_intersection' in chunks to reduce peak memory. - auto [intersection_offsets, intersection_indices] = - detail::nbr_intersection(handle, - graph_view, - cugraph::edge_dummy_property_t{}.view(), - edge_first, - edge_first + edgelist_srcs.size(), - std::array{true, true}, - false /*FIXME: pass 'do_expensive_check' as argument*/); + size_t edges_to_intersect_per_iteration = + static_cast(handle.get_device_properties().multiProcessorCount) * (1 << 17); + auto num_chunks = + raft::div_rounding_up_safe(edgelist_srcs.size(), edges_to_intersect_per_iteration); + size_t prev_chunk_size = 0; + auto num_remaining_edges = edgelist_srcs.size(); rmm::device_uvector num_triangles(edgelist_srcs.size(), handle.get_stream()); - // Update the number of triangles of each (p, q) edges by looking at their intersection - // size - thrust::adjacent_difference(handle.get_thrust_policy(), - intersection_offsets.begin() + 1, - intersection_offsets.end(), - num_triangles.begin()); - - // Given intersection offsets and indices that are used to update the number of - // triangles of (p, q) edges where `r`s are the intersection indices, update - // the number of triangles of the pairs (p, r) and (q, r). 
- - thrust::for_each( - handle.get_thrust_policy(), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(intersection_indices.size()), - update_edges_p_r_q_r_num_triangles{ - edgelist_srcs.size(), - 0, - raft::device_span(intersection_offsets.data(), intersection_offsets.size()), - raft::device_span(intersection_indices.data(), intersection_indices.size()), - raft::device_span(num_triangles.data(), num_triangles.size()), - edge_first}); - - thrust::for_each( - handle.get_thrust_policy(), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(intersection_indices.size()), - update_edges_p_r_q_r_num_triangles{ - edgelist_srcs.size(), - 1, - raft::device_span(intersection_offsets.data(), intersection_offsets.size()), - raft::device_span(intersection_indices.data(), intersection_indices.size()), - raft::device_span(num_triangles.data(), num_triangles.size()), - edge_first}); - - return num_triangles; + // auto my_rank = handle.get_comms().get_rank(); + if constexpr (multi_gpu) { + num_chunks = host_scalar_allreduce( + handle.get_comms(), num_chunks, raft::comms::op_t::MAX, handle.get_stream()); + } + + // Need to ensure that the vector has its values initialized to 0 before incrementing + thrust::fill(handle.get_thrust_policy(), num_triangles.begin(), num_triangles.end(), 0); + + for (size_t i = 0; i < num_chunks; ++i) { + auto chunk_size = std::min(edges_to_intersect_per_iteration, num_remaining_edges); + num_remaining_edges -= chunk_size; + // Perform 'nbr_intersection' in chunks to reduce peak memory. 
+ auto [intersection_offsets, intersection_indices] = + detail::nbr_intersection(handle, + graph_view, + cugraph::edge_dummy_property_t{}.view(), + edge_first + prev_chunk_size, + edge_first + prev_chunk_size + chunk_size, + std::array{true, true}, + false /*FIXME: pass 'do_expensive_check' as argument*/); + + // Update the number of triangles of each (p, q) edges by looking at their intersection + // size + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(chunk_size), + [chunk_start = prev_chunk_size, + num_triangles = raft::device_span(num_triangles.data(), num_triangles.size()), + intersection_offsets = raft::device_span( + intersection_offsets.data(), intersection_offsets.size())] __device__(auto i) { + num_triangles[chunk_start + i] += (intersection_offsets[i + 1] - intersection_offsets[i]); + }); + + if constexpr (multi_gpu) { + // stores all the pairs (p, r) and (q, r) + auto vertex_pair_buffer_tmp = allocate_dataframe_buffer>( + intersection_indices.size() * 2, handle.get_stream()); + + // tabulate with the size of intersection_indices, and call binary search on + // intersection_offsets to get (p, r). 
+ thrust::tabulate( + handle.get_thrust_policy(), + get_dataframe_buffer_begin(vertex_pair_buffer_tmp), + get_dataframe_buffer_begin(vertex_pair_buffer_tmp) + intersection_indices.size(), + extract_p_r_q_r{ + prev_chunk_size, + 0, + raft::device_span(intersection_offsets.data(), intersection_offsets.size()), + raft::device_span(intersection_indices.data(), + intersection_indices.size()), + edge_first}); + // FIXME: Consolidate both functions + thrust::tabulate( + handle.get_thrust_policy(), + get_dataframe_buffer_begin(vertex_pair_buffer_tmp) + intersection_indices.size(), + get_dataframe_buffer_begin(vertex_pair_buffer_tmp) + (2 * intersection_indices.size()), + extract_p_r_q_r{ + prev_chunk_size, + 1, + raft::device_span(intersection_offsets.data(), intersection_offsets.size()), + raft::device_span(intersection_indices.data(), + intersection_indices.size()), + edge_first}); + + thrust::sort(handle.get_thrust_policy(), + get_dataframe_buffer_begin(vertex_pair_buffer_tmp), + get_dataframe_buffer_end(vertex_pair_buffer_tmp)); + + rmm::device_uvector increase_count_tmp(2 * intersection_indices.size(), + handle.get_stream()); + thrust::fill(handle.get_thrust_policy(), + increase_count_tmp.begin(), + increase_count_tmp.end(), + size_t{1}); + + auto count_p_r_q_r = thrust::unique_count(handle.get_thrust_policy(), + get_dataframe_buffer_begin(vertex_pair_buffer_tmp), + get_dataframe_buffer_end(vertex_pair_buffer_tmp)); + + rmm::device_uvector increase_count(count_p_r_q_r, handle.get_stream()); + + auto vertex_pair_buffer = allocate_dataframe_buffer>( + count_p_r_q_r, handle.get_stream()); + thrust::reduce_by_key(handle.get_thrust_policy(), + get_dataframe_buffer_begin(vertex_pair_buffer_tmp), + get_dataframe_buffer_end(vertex_pair_buffer_tmp), + increase_count_tmp.begin(), + get_dataframe_buffer_begin(vertex_pair_buffer), + increase_count.begin(), + thrust::equal_to>{}); + + rmm::device_uvector pair_srcs(0, handle.get_stream()); + rmm::device_uvector pair_dsts(0, 
handle.get_stream()); + std::optional> pair_count{std::nullopt}; + + std::optional> opt_increase_count = + std::make_optional(rmm::device_uvector(increase_count.size(), handle.get_stream())); + + raft::copy((*opt_increase_count).begin(), + increase_count.begin(), + increase_count.size(), + handle.get_stream()); + + // There are still multiple copies here but is it worth sorting and reducing again? + std::tie(pair_srcs, pair_dsts, std::ignore, pair_count, std::ignore) = + shuffle_int_vertex_pairs_with_values_to_local_gpu_by_edge_partitioning( + handle, + std::move(std::get<0>(vertex_pair_buffer)), + std::move(std::get<1>(vertex_pair_buffer)), + std::nullopt, + // FIXME: Add general purpose function for shuffling vertex pairs and arbitrary attributes + std::move(opt_increase_count), + std::nullopt, + graph_view.vertex_partition_range_lasts()); + + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(pair_srcs.size()), + [num_edges = edgelist_srcs.size(), + num_triangles = num_triangles.data(), + pair_srcs = pair_srcs.data(), + pair_dsts = pair_dsts.data(), + pair_count = (*pair_count).data(), + edge_first] __device__(auto idx) { + auto src = pair_srcs[idx]; + auto dst = pair_dsts[idx]; + auto p_r_q_r_pair = thrust::make_tuple(src, dst); + + // Find its position in 'edges' + auto itr_p_r_q_r = + thrust::lower_bound(thrust::seq, edge_first, edge_first + num_edges, p_r_q_r_pair); + + assert(*itr_p_r_q_r == p_r_q_r_pair); + auto idx_p_r_q_r = thrust::distance(edge_first, itr_p_r_q_r); + + cuda::atomic_ref atomic_counter( + num_triangles[idx_p_r_q_r]); + auto r = atomic_counter.fetch_add(pair_count[idx], cuda::std::memory_order_relaxed); + }); + + } else { + // Given intersection offsets and indices that are used to update the number of + // triangles of (p, q) edges where `r`s are the intersection indices, update + // the number of triangles of the pairs (p, r) and (q, r). 
+ thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(intersection_indices.size()), + update_edges_p_r_q_r_num_triangles{ + edgelist_srcs.size(), + 0, + prev_chunk_size, + raft::device_span(intersection_offsets.data(), intersection_offsets.size()), + raft::device_span(intersection_indices.data(), + intersection_indices.size()), + raft::device_span(num_triangles.data(), num_triangles.size()), + edge_first}); + + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(intersection_indices.size()), + update_edges_p_r_q_r_num_triangles{ + edgelist_srcs.size(), + 1, + prev_chunk_size, + raft::device_span(intersection_offsets.data(), intersection_offsets.size()), + raft::device_span(intersection_indices.data(), + intersection_indices.size()), + raft::device_span(num_triangles.data(), num_triangles.size()), + edge_first}); + } + prev_chunk_size += chunk_size; + } + + cugraph::edge_property_t, edge_t> counts( + handle, graph_view); + + cugraph::edge_bucket_t valid_edges(handle); + valid_edges.insert(edgelist_srcs.begin(), edgelist_srcs.end(), edgelist_dsts.begin()); + + auto cur_graph_view = graph_view; + + cugraph::transform_e( + handle, + graph_view, + valid_edges, + cugraph::edge_src_dummy_property_t{}.view(), + cugraph::edge_dst_dummy_property_t{}.view(), + cugraph::edge_dummy_property_t{}.view(), + [edge_first, + edge_last = edge_first + edgelist_srcs.size(), + num_edges = edgelist_srcs.size(), + num_triangles = num_triangles.data()] __device__(auto src, + auto dst, + thrust::nullopt_t, + thrust::nullopt_t, + thrust::nullopt_t) { + auto pair = thrust::make_tuple(src, dst); + + // Find its position in 'edges' + auto itr_pair = thrust::lower_bound(thrust::seq, edge_first, edge_last, pair); + auto idx_pair = thrust::distance(edge_first, itr_pair); + return num_triangles[idx_pair]; + }, + counts.mutable_view(), + false); + + return counts; } } // 
namespace detail -template -rmm::device_uvector edge_triangle_count( - raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span edgelist_srcs, - raft::device_span edgelist_dsts) +template +edge_property_t, edge_t> edge_triangle_count( + raft::handle_t const& handle, graph_view_t const& graph_view) { - return detail::edge_triangle_count_impl(handle, graph_view, edgelist_srcs, edgelist_dsts); + return detail::edge_triangle_count_impl(handle, graph_view); } } // namespace cugraph diff --git a/cpp/src/community/edge_triangle_count_mg.cu b/cpp/src/community/edge_triangle_count_mg.cu new file mode 100644 index 00000000000..254a0807e56 --- /dev/null +++ b/cpp/src/community/edge_triangle_count_mg.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "community/edge_triangle_count_impl.cuh" + +namespace cugraph { + +// MG instantiation +template edge_property_t, int32_t> edge_triangle_count( + raft::handle_t const& handle, + cugraph::graph_view_t const& graph_view); + +template edge_property_t, int64_t> edge_triangle_count( + raft::handle_t const& handle, + cugraph::graph_view_t const& graph_view); + +template edge_property_t, int64_t> edge_triangle_count( + raft::handle_t const& handle, + cugraph::graph_view_t const& graph_view); + +} // namespace cugraph diff --git a/cpp/src/community/edge_triangle_count_sg.cu b/cpp/src/community/edge_triangle_count_sg.cu index c4b7e71b967..4ccb968458d 100644 --- a/cpp/src/community/edge_triangle_count_sg.cu +++ b/cpp/src/community/edge_triangle_count_sg.cu @@ -18,22 +18,16 @@ namespace cugraph { // SG instantiation -template rmm::device_uvector edge_triangle_count( +template edge_property_t, int32_t> edge_triangle_count( raft::handle_t const& handle, - cugraph::graph_view_t const& graph_view, - raft::device_span edgelist_srcs, - raft::device_span edgelist_dsts); + cugraph::graph_view_t const& graph_view); -template rmm::device_uvector edge_triangle_count( +template edge_property_t, int64_t> edge_triangle_count( raft::handle_t const& handle, - cugraph::graph_view_t const& graph_view, - raft::device_span edgelist_srcs, - raft::device_span edgelist_dsts); + cugraph::graph_view_t const& graph_view); -template rmm::device_uvector edge_triangle_count( +template edge_property_t, int64_t> edge_triangle_count( raft::handle_t const& handle, - cugraph::graph_view_t const& graph_view, - raft::device_span edgelist_srcs, - raft::device_span edgelist_dsts); + cugraph::graph_view_t const& graph_view); } // namespace cugraph diff --git a/cpp/src/community/k_truss_impl.cuh b/cpp/src/community/k_truss_impl.cuh index 7f96312703d..f830e6a7700 100644 --- a/cpp/src/community/k_truss_impl.cuh +++ b/cpp/src/community/k_truss_impl.cuh @@ -27,6 +27,8 @@ #include #include +#include + 
#include #include #include @@ -39,14 +41,6 @@ namespace cugraph { -// FIXME : This will be deleted once edge_triangle_count becomes public -template -rmm::device_uvector edge_triangle_count( - raft::handle_t const& handle, - graph_view_t const& graph_view, - raft::device_span edgelist_srcs, - raft::device_span edgelist_dsts); - template struct unroll_edge { size_t num_valid_edges{}; @@ -442,6 +436,7 @@ struct extract_low_to_high_degree_edges_t { template struct generate_p_r_or_q_r_from_p_q { + size_t chunk_start{}; raft::device_span intersection_offsets{}; raft::device_span intersection_indices{}; raft::device_span invalid_srcs{}; @@ -454,10 +449,10 @@ struct generate_p_r_or_q_r_from_p_q { auto idx = thrust::distance(intersection_offsets.begin() + 1, itr); if constexpr (generate_p_r) { - return thrust::make_tuple(invalid_srcs[idx], intersection_indices[i]); + return thrust::make_tuple(invalid_srcs[chunk_start + idx], intersection_indices[i]); } else { - return thrust::make_tuple(invalid_dsts[idx], intersection_indices[i]); + return thrust::make_tuple(invalid_dsts[chunk_start + idx], intersection_indices[i]); } } }; @@ -491,6 +486,7 @@ k_truss(raft::handle_t const& handle, std::optional> renumber_map{std::nullopt}; std::optional, weight_t>> edge_weight{std::nullopt}; + std::optional> wgts{std::nullopt}; if (graph_view.count_self_loops(handle) > edge_t{0}) { auto [srcs, dsts] = extract_transform_e(handle, @@ -524,31 +520,30 @@ k_truss(raft::handle_t const& handle, modified_graph_view = (*modified_graph).view(); } - // FIXME: Investigate k-1 core failure to yield correct results. // 3. Find (k-1)-core and exclude edges that do not belong to (k-1)-core - /* { auto cur_graph_view = modified_graph_view ? *modified_graph_view : graph_view; + auto vertex_partition_range_lasts = renumber_map ? 
std::make_optional>(cur_graph_view.vertex_partition_range_lasts()) : std::nullopt; - rmm::device_uvector d_core_numbers(cur_graph_view.local_vertex_partition_range_size(), - handle.get_stream()); - raft::device_span core_number_span{d_core_numbers.data(), d_core_numbers.size()}; + rmm::device_uvector core_numbers(cur_graph_view.number_of_vertices(), + handle.get_stream()); + core_number( + handle, cur_graph_view, core_numbers.data(), k_core_degree_type_t::OUT, size_t{2}, size_t{2}); + + raft::device_span core_number_span{core_numbers.data(), core_numbers.size()}; rmm::device_uvector srcs{0, handle.get_stream()}; rmm::device_uvector dsts{0, handle.get_stream()}; - std::tie(srcs, dsts, std::ignore) = - k_core(handle, - cur_graph_view, - std::optional>{std::nullopt}, - size_t{k - 1}, - std::make_optional(k_core_degree_type_t::OUT), - // Seems like the below argument is required. passing a std::nullopt - // create a compiler error - std::make_optional(core_number_span)); + std::tie(srcs, dsts, wgts) = k_core(handle, + cur_graph_view, + edge_weight_view, + k - 1, + std::make_optional(k_core_degree_type_t::OUT), + std::make_optional(core_number_span)); if constexpr (multi_gpu) { std::tie(srcs, dsts, std::ignore, std::ignore, std::ignore) = @@ -561,17 +556,17 @@ k_truss(raft::handle_t const& handle, std::optional> tmp_renumber_map{std::nullopt}; - std::tie(*modified_graph, std::ignore, std::ignore, std::ignore, tmp_renumber_map) = + std::tie(*modified_graph, edge_weight, std::ignore, std::ignore, tmp_renumber_map) = create_graph_from_edgelist( handle, std::nullopt, std::move(srcs), std::move(dsts), - std::nullopt, + std::move(wgts), std::nullopt, std::nullopt, cugraph::graph_properties_t{true, graph_view.is_multigraph()}, - true); + false); modified_graph_view = (*modified_graph).view(); @@ -584,7 +579,6 @@ k_truss(raft::handle_t const& handle, } renumber_map = std::move(tmp_renumber_map); } - */ // 4. Keep only the edges from a low-degree vertex to a high-degree vertex. 
@@ -606,7 +600,10 @@ k_truss(raft::handle_t const& handle, rmm::device_uvector srcs(0, handle.get_stream()); rmm::device_uvector dsts(0, handle.get_stream()); - std::optional> wgts{std::nullopt}; + + edge_weight_view = + edge_weight ? std::make_optional((*edge_weight).view()) + : std::optional>{std::nullopt}; if (edge_weight_view) { std::tie(srcs, dsts, wgts) = extract_transform_e( handle, @@ -666,38 +663,36 @@ k_truss(raft::handle_t const& handle, auto cur_graph_view = modified_graph_view ? *modified_graph_view : graph_view; rmm::device_uvector edgelist_srcs(0, handle.get_stream()); rmm::device_uvector edgelist_dsts(0, handle.get_stream()); + std::optional> num_triangles{std::nullopt}; std::optional> edgelist_wgts{std::nullopt}; edge_weight_view = edge_weight ? std::make_optional((*edge_weight).view()) : std::optional>{std::nullopt}; - std::tie(edgelist_srcs, edgelist_dsts, edgelist_wgts, std::ignore, std::ignore) = + + auto prop_num_triangles = edge_triangle_count(handle, cur_graph_view); + + std::tie(edgelist_srcs, edgelist_dsts, edgelist_wgts, num_triangles, std::ignore) = decompress_to_edgelist( handle, cur_graph_view, edge_weight_view, - std::optional>{std::nullopt}, + // FIXME: Update 'decompress_edgelist' to support int32_t and int64_t values + std::make_optional(prop_num_triangles.view()), std::optional>{std::nullopt}, std::optional>(std::nullopt)); - - auto num_triangles = edge_triangle_count( - handle, - cur_graph_view, - raft::device_span(edgelist_srcs.data(), edgelist_srcs.size()), - raft::device_span(edgelist_dsts.data(), edgelist_dsts.size())); - auto transposed_edge_first = thrust::make_zip_iterator(edgelist_dsts.begin(), edgelist_srcs.begin()); auto edge_first = thrust::make_zip_iterator(edgelist_srcs.begin(), edgelist_dsts.begin()); auto transposed_edge_triangle_count_pair_first = - thrust::make_zip_iterator(transposed_edge_first, num_triangles.begin()); + thrust::make_zip_iterator(transposed_edge_first, (*num_triangles).begin()); 
thrust::sort_by_key(handle.get_thrust_policy(), transposed_edge_first, transposed_edge_first + edgelist_srcs.size(), - num_triangles.begin()); + (*num_triangles).begin()); cugraph::edge_property_t edge_mask(handle, cur_graph_view); cugraph::fill_edge_property(handle, cur_graph_view, true, edge_mask); @@ -728,92 +723,115 @@ k_truss(raft::handle_t const& handle, // nbr_intersection requires the edges to be sort by 'src' // sort the invalid edges by src for nbr intersection - thrust::sort_by_key(handle.get_thrust_policy(), - edge_first + num_valid_edges, - edge_first + edgelist_srcs.size(), - num_triangles.begin() + num_valid_edges); - - auto [intersection_offsets, intersection_indices] = - detail::nbr_intersection(handle, - cur_graph_view, - cugraph::edge_dummy_property_t{}.view(), - edge_first + num_valid_edges, - edge_first + edgelist_srcs.size(), - std::array{true, true}, - do_expensive_check); - - // Update the number of triangles of each (p, q) edges by looking at their intersection - // size. 
- thrust::for_each( - handle.get_thrust_policy(), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_invalid_edges), - [num_triangles = - raft::device_span(num_triangles.data() + num_valid_edges, num_invalid_edges), - intersection_offsets = raft::device_span( - intersection_offsets.data(), intersection_offsets.size())] __device__(auto i) { - num_triangles[i] -= intersection_offsets[i + 1] - intersection_offsets[i]; - }); - - // FIXME: Find a way to not have to maintain a dataframe_buffer - auto vertex_pair_buffer_p_r_edge_p_q = - allocate_dataframe_buffer>(intersection_indices.size(), - handle.get_stream()); - - thrust::tabulate( - handle.get_thrust_policy(), - get_dataframe_buffer_begin(vertex_pair_buffer_p_r_edge_p_q), - get_dataframe_buffer_end(vertex_pair_buffer_p_r_edge_p_q), - generate_p_r_or_q_r_from_p_q{ - raft::device_span(intersection_offsets.data(), intersection_offsets.size()), - raft::device_span(intersection_indices.data(), - intersection_indices.size()), - raft::device_span(edgelist_srcs.data() + num_valid_edges, num_invalid_edges), - raft::device_span(edgelist_dsts.data() + num_valid_edges, num_invalid_edges)}); - - auto vertex_pair_buffer_q_r_edge_p_q = - allocate_dataframe_buffer>(intersection_indices.size(), - handle.get_stream()); - thrust::tabulate( - handle.get_thrust_policy(), - get_dataframe_buffer_begin(vertex_pair_buffer_q_r_edge_p_q), - get_dataframe_buffer_end(vertex_pair_buffer_q_r_edge_p_q), - generate_p_r_or_q_r_from_p_q{ - raft::device_span(intersection_offsets.data(), intersection_offsets.size()), - raft::device_span(intersection_indices.data(), - intersection_indices.size()), - raft::device_span(edgelist_srcs.data() + num_valid_edges, num_invalid_edges), - raft::device_span(edgelist_dsts.data() + num_valid_edges, num_invalid_edges)}); - - // Unrolling the edges require the edges to be sorted by destination - // re-sort the invalid edges by 'dst' - thrust::sort_by_key(handle.get_thrust_policy(), - 
transposed_edge_first + num_valid_edges, - transposed_edge_first + edgelist_srcs.size(), - num_triangles.begin() + num_valid_edges); - - thrust::for_each(handle.get_thrust_policy(), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(intersection_indices.size()), - unroll_edge{ - num_valid_edges, - raft::device_span(num_triangles.data(), num_triangles.size()), - get_dataframe_buffer_begin(vertex_pair_buffer_p_r_edge_p_q), - transposed_edge_first, - transposed_edge_first + num_valid_edges, - transposed_edge_first + edgelist_srcs.size()}); - - thrust::for_each(handle.get_thrust_policy(), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(intersection_indices.size()), - unroll_edge{ - num_valid_edges, - raft::device_span(num_triangles.data(), num_triangles.size()), - get_dataframe_buffer_begin(vertex_pair_buffer_q_r_edge_p_q), - transposed_edge_first, - transposed_edge_first + num_valid_edges, - transposed_edge_first + edgelist_srcs.size()}); - + size_t edges_to_intersect_per_iteration = + static_cast(handle.get_device_properties().multiProcessorCount) * (1 << 17); + + size_t prev_chunk_size = 0; + size_t chunk_num_invalid_edges = num_invalid_edges; + + auto num_chunks = + raft::div_rounding_up_safe(edgelist_srcs.size(), edges_to_intersect_per_iteration); + + for (size_t i = 0; i < num_chunks; ++i) { + auto chunk_size = std::min(edges_to_intersect_per_iteration, chunk_num_invalid_edges); + thrust::sort_by_key(handle.get_thrust_policy(), + edge_first + num_valid_edges, + edge_first + edgelist_srcs.size(), + (*num_triangles).begin() + num_valid_edges); + + auto [intersection_offsets, intersection_indices] = + detail::nbr_intersection(handle, + cur_graph_view, + cugraph::edge_dummy_property_t{}.view(), + edge_first + num_valid_edges + prev_chunk_size, + edge_first + num_valid_edges + prev_chunk_size + chunk_size, + std::array{true, true}, + do_expensive_check); + + // Update the number of triangles of each (p, q) edges by looking at 
their intersection + // size. + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(chunk_size), + [chunk_start = prev_chunk_size, + num_triangles = raft::device_span((*num_triangles).data() + num_valid_edges, + num_invalid_edges), + intersection_offsets = raft::device_span( + intersection_offsets.data(), intersection_offsets.size())] __device__(auto i) { + num_triangles[chunk_start + i] -= + (intersection_offsets[i + 1] - intersection_offsets[i]); + }); + + // FIXME: Find a way to not have to maintain a dataframe_buffer + auto vertex_pair_buffer_p_r_edge_p_q = + allocate_dataframe_buffer>(intersection_indices.size(), + handle.get_stream()); + thrust::tabulate( + handle.get_thrust_policy(), + get_dataframe_buffer_begin(vertex_pair_buffer_p_r_edge_p_q), + get_dataframe_buffer_end(vertex_pair_buffer_p_r_edge_p_q), + generate_p_r_or_q_r_from_p_q{ + prev_chunk_size, + raft::device_span(intersection_offsets.data(), + intersection_offsets.size()), + raft::device_span(intersection_indices.data(), + intersection_indices.size()), + raft::device_span(edgelist_srcs.data() + num_valid_edges, num_invalid_edges), + raft::device_span(edgelist_dsts.data() + num_valid_edges, + num_invalid_edges)}); + + auto vertex_pair_buffer_q_r_edge_p_q = + allocate_dataframe_buffer>(intersection_indices.size(), + handle.get_stream()); + thrust::tabulate( + handle.get_thrust_policy(), + get_dataframe_buffer_begin(vertex_pair_buffer_q_r_edge_p_q), + get_dataframe_buffer_end(vertex_pair_buffer_q_r_edge_p_q), + generate_p_r_or_q_r_from_p_q{ + prev_chunk_size, + raft::device_span(intersection_offsets.data(), + intersection_offsets.size()), + raft::device_span(intersection_indices.data(), + intersection_indices.size()), + raft::device_span(edgelist_srcs.data() + num_valid_edges, num_invalid_edges), + raft::device_span(edgelist_dsts.data() + num_valid_edges, + num_invalid_edges)}); + + // Unrolling the edges require the edges to be 
sorted by destination + // re-sort the invalid edges by 'dst' + thrust::sort_by_key(handle.get_thrust_policy(), + transposed_edge_first + num_valid_edges, + transposed_edge_first + edgelist_srcs.size(), + (*num_triangles).begin() + num_valid_edges); + + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(intersection_indices.size()), + unroll_edge{ + num_valid_edges, + raft::device_span((*num_triangles).data(), (*num_triangles).size()), + get_dataframe_buffer_begin(vertex_pair_buffer_p_r_edge_p_q), + transposed_edge_first, + transposed_edge_first + num_valid_edges, + transposed_edge_first + edgelist_srcs.size()}); + + thrust::for_each( + handle.get_thrust_policy(), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(intersection_indices.size()), + unroll_edge{ + num_valid_edges, + raft::device_span((*num_triangles).data(), (*num_triangles).size()), + get_dataframe_buffer_begin(vertex_pair_buffer_q_r_edge_p_q), + transposed_edge_first, + transposed_edge_first + num_valid_edges, + transposed_edge_first + edgelist_srcs.size()}); + + prev_chunk_size += chunk_size; + chunk_num_invalid_edges -= chunk_size; + } // case 2: unroll (q, r) // For each (q, r) edges to unroll, find the incoming edges to 'r' let's say from 'p' and // create the pair (p, q) @@ -824,7 +842,7 @@ k_truss(raft::handle_t const& handle, num_valid_edges, raft::device_span(edgelist_srcs.data(), edgelist_srcs.size()), raft::device_span(edgelist_dsts.data(), edgelist_dsts.size()), - raft::device_span(num_triangles.data(), num_triangles.size())); + raft::device_span((*num_triangles).data(), (*num_triangles).size())); // case 3: unroll (p, r) cugraph::unroll_p_r_or_q_r_edges( @@ -834,18 +852,18 @@ k_truss(raft::handle_t const& handle, num_valid_edges, raft::device_span(edgelist_srcs.data(), edgelist_srcs.size()), raft::device_span(edgelist_dsts.data(), edgelist_dsts.size()), - raft::device_span(num_triangles.data(), 
num_triangles.size())); + raft::device_span((*num_triangles).data(), (*num_triangles).size())); // Remove edges that have a triangle count of zero. Those should not be accounted // for during the unroling phase. - auto edges_with_triangle_last = - thrust::stable_partition(handle.get_thrust_policy(), - transposed_edge_triangle_count_pair_first, - transposed_edge_triangle_count_pair_first + num_triangles.size(), - [] __device__(auto e) { - auto num_triangles = thrust::get<1>(e); - return num_triangles > 0; - }); + auto edges_with_triangle_last = thrust::stable_partition( + handle.get_thrust_policy(), + transposed_edge_triangle_count_pair_first, + transposed_edge_triangle_count_pair_first + (*num_triangles).size(), + [] __device__(auto e) { + auto num_triangles = thrust::get<1>(e); + return num_triangles > 0; + }); auto num_edges_with_triangles = static_cast( thrust::distance(transposed_edge_triangle_count_pair_first, edges_with_triangle_last)); @@ -893,7 +911,7 @@ k_truss(raft::handle_t const& handle, edgelist_srcs.resize(num_edges_with_triangles, handle.get_stream()); edgelist_dsts.resize(num_edges_with_triangles, handle.get_stream()); - num_triangles.resize(num_edges_with_triangles, handle.get_stream()); + (*num_triangles).resize(num_edges_with_triangles, handle.get_stream()); } std::tie(edgelist_srcs, edgelist_dsts, edgelist_wgts, std::ignore, std::ignore) = diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index ced3b7bedb1..d1dd2dec069 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -490,6 +490,11 @@ ConfigureTest(K_TRUSS_TEST community/k_truss_test.cpp) # - Triangle Count tests -------------------------------------------------------------------------- ConfigureTest(TRIANGLE_COUNT_TEST community/triangle_count_test.cpp) +################################################################################################### +# - Edge Triangle Count tests -------------------------------------------------------------------------- 
+ConfigureTest(EDGE_TRIANGLE_COUNT_TEST community/edge_triangle_count_test.cpp) + + ################################################################################################### # - K-hop Neighbors tests ------------------------------------------------------------------------- ConfigureTest(K_HOP_NBRS_TEST traversal/k_hop_nbrs_test.cpp) @@ -590,6 +595,10 @@ if(BUILD_CUGRAPH_MG_TESTS) # - MG LOUVAIN tests -------------------------------------------------------------------------- ConfigureTestMG(MG_EGONET_TEST community/mg_egonet_test.cu) + ############################################################################################### + # - MG EDGE TRIANGLE COUNT tests -------------------------------------------------------------------------- + ConfigureTestMG(MG_EDGE_TRIANGLE_COUNT_TEST community/mg_edge_triangle_count_test.cpp) + ############################################################################################### # - MG WEAKLY CONNECTED COMPONENTS tests ------------------------------------------------------ ConfigureTestMG(MG_WEAKLY_CONNECTED_COMPONENTS_TEST diff --git a/cpp/tests/centrality/mg_betweenness_centrality_test.cpp b/cpp/tests/centrality/mg_betweenness_centrality_test.cpp index 7924d449897..798e767085e 100644 --- a/cpp/tests/centrality/mg_betweenness_centrality_test.cpp +++ b/cpp/tests/centrality/mg_betweenness_centrality_test.cpp @@ -152,13 +152,15 @@ class Tests_MGBetweennessCentrality std::optional< cugraph::edge_property_t, weight_t>> sg_edge_weights{std::nullopt}; - std::tie(sg_graph, sg_edge_weights, std::ignore) = - cugraph::test::mg_graph_to_sg_graph(*handle_, - mg_graph_view, - mg_edge_weight_view, - std::make_optional>( - (*mg_renumber_map).data(), (*mg_renumber_map).size()), - false); + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + 
(*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == 0) { auto sg_graph_view = sg_graph.view(); diff --git a/cpp/tests/centrality/mg_edge_betweenness_centrality_test.cpp b/cpp/tests/centrality/mg_edge_betweenness_centrality_test.cpp index c3417e96c03..1703f198a4c 100644 --- a/cpp/tests/centrality/mg_edge_betweenness_centrality_test.cpp +++ b/cpp/tests/centrality/mg_edge_betweenness_centrality_test.cpp @@ -142,12 +142,14 @@ class Tests_MGEdgeBetweennessCentrality std::optional< cugraph::edge_property_t, weight_t>> sg_edge_weights{std::nullopt}; - std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - mg_edge_weight_view, - std::optional>{std::nullopt}, - false); + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + false); if (handle_->get_comms().get_rank() == 0) { auto sg_edge_weights_view = diff --git a/cpp/tests/centrality/mg_eigenvector_centrality_test.cpp b/cpp/tests/centrality/mg_eigenvector_centrality_test.cpp index ed24bee0923..76c52d52bfd 100644 --- a/cpp/tests/centrality/mg_eigenvector_centrality_test.cpp +++ b/cpp/tests/centrality/mg_eigenvector_centrality_test.cpp @@ -144,13 +144,15 @@ class Tests_MGEigenvectorCentrality std::optional< cugraph::edge_property_t, weight_t>> sg_edge_weights{std::nullopt}; - std::tie(sg_graph, sg_edge_weights, std::ignore) = - cugraph::test::mg_graph_to_sg_graph(*handle_, - mg_graph_view, - mg_edge_weight_view, - std::make_optional>( - (*mg_renumber_map).data(), (*mg_renumber_map).size()), - false); + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); 
if (handle_->get_comms().get_rank() == int{0}) { // 3-2. run SG Eigenvector Centrality diff --git a/cpp/tests/centrality/mg_katz_centrality_test.cpp b/cpp/tests/centrality/mg_katz_centrality_test.cpp index abe02b2287b..e38f87749b8 100644 --- a/cpp/tests/centrality/mg_katz_centrality_test.cpp +++ b/cpp/tests/centrality/mg_katz_centrality_test.cpp @@ -151,13 +151,15 @@ class Tests_MGKatzCentrality std::optional< cugraph::edge_property_t, weight_t>> sg_edge_weights{std::nullopt}; - std::tie(sg_graph, sg_edge_weights, std::ignore) = - cugraph::test::mg_graph_to_sg_graph(*handle_, - mg_graph_view, - mg_edge_weight_view, - std::make_optional>( - (*mg_renumber_map).data(), (*mg_renumber_map).size()), - false); + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == int{0}) { // 4-2. run SG Katz Centrality diff --git a/cpp/tests/community/edge_triangle_count_test.cpp b/cpp/tests/community/edge_triangle_count_test.cpp new file mode 100644 index 00000000000..8cefc2c31f4 --- /dev/null +++ b/cpp/tests/community/edge_triangle_count_test.cpp @@ -0,0 +1,260 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "utilities/base_fixture.hpp" +#include "utilities/check_utilities.hpp" +#include "utilities/conversion_utilities.hpp" +#include "utilities/property_generator_utilities.hpp" +#include "utilities/test_graphs.hpp" +#include "utilities/thrust_wrapper.hpp" + +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +struct EdgeTriangleCount_Usecase { + bool edge_masking_{false}; + bool check_correctness_{true}; +}; + +template +class Tests_EdgeTriangleCount + : public ::testing::TestWithParam> { + public: + Tests_EdgeTriangleCount() {} + + static void SetUpTestCase() {} + static void TearDownTestCase() {} + + virtual void SetUp() {} + virtual void TearDown() {} + + // FIXME: There is an utility equivalent functor not + // supporting host vectors. + template + struct host_nearly_equal { + const type_t threshold_ratio; + const type_t threshold_magnitude; + + bool operator()(type_t lhs, type_t rhs) const + { + return std::abs(lhs - rhs) < + std::max(std::max(lhs, rhs) * threshold_ratio, threshold_magnitude); + } + }; + + template + std::vector edge_triangle_count_reference(std::vector h_srcs, + std::vector h_dsts) + { + std::vector edge_triangle_counts(h_srcs.size()); + std::uninitialized_fill(edge_triangle_counts.begin(), edge_triangle_counts.end(), 0); + + for (int i = 0; i < h_srcs.size(); ++i) { // edge centric implementation + // for each edge, find the intersection + auto src = h_srcs[i]; + auto dst = h_dsts[i]; + auto it_src_start = std::lower_bound(h_srcs.begin(), h_srcs.end(), src); + auto src_start = std::distance(h_srcs.begin(), it_src_start); + + auto src_end = + src_start + std::distance(it_src_start, std::upper_bound(it_src_start, h_srcs.end(), src)); + + auto it_dst_start = std::lower_bound(h_srcs.begin(), h_srcs.end(), dst); + auto dst_start = std::distance(h_srcs.begin(), it_dst_start); + auto dst_end = + dst_start + 
std::distance(it_dst_start, std::upper_bound(it_dst_start, h_srcs.end(), dst)); + + std::set nbr_intersection; + std::set_intersection(h_dsts.begin() + src_start, + h_dsts.begin() + src_end, + h_dsts.begin() + dst_start, + h_dsts.begin() + dst_end, + std::inserter(nbr_intersection, nbr_intersection.end())); + // Find the supporting edges + for (auto v : nbr_intersection) { + auto it_edge = std::lower_bound(h_dsts.begin() + src_start, h_dsts.begin() + src_end, v); + auto idx_edge = std::distance(h_dsts.begin(), it_edge); + edge_triangle_counts[idx_edge] += 1; + + it_edge = std::lower_bound(h_dsts.begin() + dst_start, h_dsts.begin() + dst_end, v); + idx_edge = std::distance(h_dsts.begin(), it_edge); + } + } + + std::transform(edge_triangle_counts.begin(), + edge_triangle_counts.end(), + edge_triangle_counts.begin(), + [](auto count) { return count * 3; }); + return std::move(edge_triangle_counts); + } + + template + void run_current_test( + std::tuple const& param) + { + constexpr bool renumber = false; + auto [edge_triangle_count_usecase, input_usecase] = param; + raft::handle_t handle{}; + HighResTimer hr_timer{}; + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("SG Construct graph"); + } + + auto [graph, edge_weight, d_renumber_map_labels] = + cugraph::test::construct_graph( + handle, input_usecase, false, renumber, true, true); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + auto graph_view = graph.view(); + + std::optional> edge_mask{std::nullopt}; + if (edge_triangle_count_usecase.edge_masking_) { + edge_mask = + cugraph::test::generate::edge_property(handle, graph_view, 2); + graph_view.attach_edge_mask((*edge_mask).view()); + } + + rmm::device_uvector edgelist_srcs(0, handle.get_stream()); + rmm::device_uvector edgelist_dsts(0, 
handle.get_stream()); + std::optional> d_edge_triangle_counts{std::nullopt}; + + auto d_cugraph_results = + cugraph::edge_triangle_count(handle, graph_view); + + std::tie(edgelist_srcs, edgelist_dsts, std::ignore, d_edge_triangle_counts, std::ignore) = + cugraph::decompress_to_edgelist( + handle, + graph_view, + std::optional>{std::nullopt}, + std::make_optional(d_cugraph_results.view()), + std::optional>{std::nullopt}, + std::optional>{std::nullopt}); // FIXME: No longer needed + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.start("EdgeTriangleCount"); + } + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + if (edge_triangle_count_usecase.check_correctness_) { + std::optional> modified_graph{std::nullopt}; + std::vector h_srcs(edgelist_srcs.size()); + std::vector h_dsts(edgelist_dsts.size()); + std::tie(h_srcs, h_dsts, std::ignore) = cugraph::test::graph_to_host_coo( + handle, + graph_view, + edge_weight ? 
std::make_optional((*edge_weight).view()) : std::nullopt, + std::optional>(std::nullopt)); + + auto h_cugraph_edge_triangle_counts = cugraph::test::to_host(handle, *d_edge_triangle_counts); + + auto h_reference_edge_triangle_counts = + edge_triangle_count_reference(h_srcs, h_dsts); + + for (size_t i = 0; i < h_srcs.size(); ++i) { + ASSERT_EQ(h_cugraph_edge_triangle_counts[i], h_reference_edge_triangle_counts[i]) + << "Edge triangle count values do not match with the reference values."; + } + } + } +}; + +using Tests_EdgeTriangleCount_File = Tests_EdgeTriangleCount; +using Tests_EdgeTriangleCount_Rmat = Tests_EdgeTriangleCount; + +TEST_P(Tests_EdgeTriangleCount_File, CheckInt32Int32Float) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} +TEST_P(Tests_EdgeTriangleCount_Rmat, CheckInt32Int32Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} +TEST_P(Tests_EdgeTriangleCount_File, CheckInt64Int64Float) +{ + run_current_test( + override_File_Usecase_with_cmd_line_arguments(GetParam())); +} +TEST_P(Tests_EdgeTriangleCount_Rmat, CheckInt64Int64Float) +{ + run_current_test( + override_Rmat_Usecase_with_cmd_line_arguments(GetParam())); +} + +INSTANTIATE_TEST_SUITE_P( + simple_test, + Tests_EdgeTriangleCount_File, + ::testing::Combine( + // enable correctness checks + ::testing::Values(EdgeTriangleCount_Usecase{false, true}, + EdgeTriangleCount_Usecase{true, true}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), + cugraph::test::File_Usecase("test/datasets/dolphins.mtx")))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_test, + Tests_EdgeTriangleCount_Rmat, + // enable correctness checks + ::testing::Combine( + ::testing::Values(EdgeTriangleCount_Usecase{false, true}, + EdgeTriangleCount_Usecase{true, true}), + ::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, true, false)))); + +INSTANTIATE_TEST_SUITE_P( + rmat_benchmark_test, /* note that 
scale & edge factor can be overridden in benchmarking (with + --gtest_filter to select only the rmat_benchmark_test with a specific + vertex & edge type combination) by command line arguments and do not + include more than one Rmat_Usecase that differ only in scale or edge + factor (to avoid running same benchmarks more than once) */ + Tests_EdgeTriangleCount_Rmat, + // disable correctness checks for large graphs + // FIXME: High memory footprint. Perform nbr_intersection in chunks. + ::testing::Combine( + ::testing::Values(EdgeTriangleCount_Usecase{false, false}, + EdgeTriangleCount_Usecase{true, false}), + ::testing::Values(cugraph::test::Rmat_Usecase(16, 16, 0.57, 0.19, 0.19, 0, true, false)))); + +CUGRAPH_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/community/mg_ecg_test.cpp b/cpp/tests/community/mg_ecg_test.cpp index a5e02c4f532..c99f83fa2e8 100644 --- a/cpp/tests/community/mg_ecg_test.cpp +++ b/cpp/tests/community/mg_ecg_test.cpp @@ -127,12 +127,14 @@ class Tests_MGEcg : public ::testing::TestWithParam, weight_t>> sg_edge_weights{std::nullopt}; - std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - mg_edge_weight_view, - std::optional>{std::nullopt}, - false); // crate a SG graph with MG graph vertex IDs + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + false); // crate a SG graph with MG graph vertex IDs auto const comm_rank = handle_->get_comms().get_rank(); if (comm_rank == 0) { diff --git a/cpp/tests/community/mg_edge_triangle_count_test.cpp b/cpp/tests/community/mg_edge_triangle_count_test.cpp new file mode 100644 index 00000000000..89bdf870ccd --- /dev/null +++ b/cpp/tests/community/mg_edge_triangle_count_test.cpp @@ -0,0 +1,253 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utilities/base_fixture.hpp" +#include "utilities/conversion_utilities.hpp" +#include "utilities/device_comm_wrapper.hpp" +#include "utilities/mg_utilities.hpp" +#include "utilities/property_generator_utilities.hpp" +#include "utilities/test_graphs.hpp" +#include "utilities/thrust_wrapper.hpp" + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +#include + +#include + +struct EdgeTriangleCount_Usecase { + bool edge_masking_{false}; + bool check_correctness_{true}; +}; + +template +class Tests_MGEdgeTriangleCount + : public ::testing::TestWithParam> { + public: + Tests_MGEdgeTriangleCount() {} + + static void SetUpTestCase() { handle_ = cugraph::test::initialize_mg_handle(); } + + static void TearDownTestCase() { handle_.reset(); } + + virtual void SetUp() {} + virtual void TearDown() {} + + // Compare the results of running EdgeTriangleCount on multiple GPUs to that of a single-GPU run + template + void run_current_test(EdgeTriangleCount_Usecase const& edge_triangle_count_usecase, + input_usecase_t const& input_usecase) + { + using weight_t = float; + + HighResTimer hr_timer{}; + + // 1. 
create MG graph + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.start("MG Construct graph"); + } + + cugraph::graph_t mg_graph(*handle_); + std::optional> mg_renumber_map{std::nullopt}; + std::tie(mg_graph, std::ignore, mg_renumber_map) = + cugraph::test::construct_graph( + *handle_, input_usecase, false, true, false, true); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + auto mg_graph_view = mg_graph.view(); + + std::optional> edge_mask{std::nullopt}; + if (edge_triangle_count_usecase.edge_masking_) { + edge_mask = cugraph::test::generate::edge_property( + *handle_, mg_graph_view, 2); + mg_graph_view.attach_edge_mask((*edge_mask).view()); + } + + // 2. run MG EdgeTriangleCount + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.start("MG EdgeTriangleCount"); + } + + auto d_mg_cugraph_results = + cugraph::edge_triangle_count(*handle_, mg_graph_view); + + if (cugraph::test::g_perf) { + RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement + handle_->get_comms().barrier(); + hr_timer.stop(); + hr_timer.display_and_clear(std::cout); + } + + // 3. Compare SG & MG results + + if (edge_triangle_count_usecase.check_correctness_) { + // 3-1. 
Convert to SG graph + + cugraph::graph_t sg_graph(*handle_); + std::optional< + cugraph::edge_property_t, edge_t>> + d_sg_cugraph_results{std::nullopt}; + std::tie(sg_graph, std::ignore, d_sg_cugraph_results, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + // FIXME: Update 'create_graph_from_edgelist' to support int32_t and int64_t values + std::make_optional(d_mg_cugraph_results.view()), + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); + + if (handle_->get_comms().get_rank() == int{0}) { + // 3-2. Convert the MG triangle counts stored as 'edge_property_t' to device vector + + auto [edgelist_srcs, + edgelist_dsts, + d_edgelist_weights, + d_edge_triangle_counts, + d_edgelist_type] = + cugraph::decompress_to_edgelist( + *handle_, + sg_graph.view(), + std::optional>{std::nullopt}, + // FIXME: Update 'decompress_edgelist' to support int32_t and int64_t values + std::make_optional((*d_sg_cugraph_results).view()), + std::optional>{std::nullopt}, + std::optional>{ + std::nullopt}); // FIXME: No longer needed + + // 3-3. Run SG EdgeTriangleCount + + auto ref_d_sg_cugraph_results = + cugraph::edge_triangle_count(*handle_, sg_graph.view()); + auto [ref_edgelist_srcs, + ref_edgelist_dsts, + ref_d_edgelist_weights, + ref_d_edge_triangle_counts, + ref_d_edgelist_type] = + cugraph::decompress_to_edgelist( + *handle_, + sg_graph.view(), + std::optional>{std::nullopt}, + std::make_optional(ref_d_sg_cugraph_results.view()), + std::optional>{std::nullopt}, + std::optional>{ + std::nullopt}); // FIXME: No longer needed + + // 3-4. 
Compare + + auto h_mg_edge_triangle_counts = cugraph::test::to_host(*handle_, *d_edge_triangle_counts); + auto h_sg_edge_triangle_counts = + cugraph::test::to_host(*handle_, *ref_d_edge_triangle_counts); + + ASSERT_TRUE(std::equal(h_mg_edge_triangle_counts.begin(), + h_mg_edge_triangle_counts.end(), + h_sg_edge_triangle_counts.begin())); + } + } + } + + private: + static std::unique_ptr handle_; +}; + +template +std::unique_ptr Tests_MGEdgeTriangleCount::handle_ = nullptr; + +using Tests_MGEdgeTriangleCount_File = Tests_MGEdgeTriangleCount; +using Tests_MGEdgeTriangleCount_Rmat = Tests_MGEdgeTriangleCount; + +TEST_P(Tests_MGEdgeTriangleCount_File, CheckInt32Int32) +{ + auto param = GetParam(); + run_current_test(std::get<0>(param), std::get<1>(param)); +} + +TEST_P(Tests_MGEdgeTriangleCount_Rmat, CheckInt32Int32) +{ + auto param = GetParam(); + run_current_test( + std::get<0>(param), override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); +} + +TEST_P(Tests_MGEdgeTriangleCount_Rmat, CheckInt32Int64) +{ + auto param = GetParam(); + run_current_test( + std::get<0>(param), override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); +} + +TEST_P(Tests_MGEdgeTriangleCount_Rmat, CheckInt64Int64) +{ + auto param = GetParam(); + run_current_test( + std::get<0>(param), override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); +} + +INSTANTIATE_TEST_SUITE_P( + file_tests, + Tests_MGEdgeTriangleCount_File, + ::testing::Combine( + // enable correctness checks + ::testing::Values(EdgeTriangleCount_Usecase{false, true}, + EdgeTriangleCount_Usecase{true, true}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx"), + cugraph::test::File_Usecase("test/datasets/dolphins.mtx")))); + +INSTANTIATE_TEST_SUITE_P( + rmat_small_tests, + Tests_MGEdgeTriangleCount_Rmat, + ::testing::Combine( + ::testing::Values(EdgeTriangleCount_Usecase{false, true}, + EdgeTriangleCount_Usecase{true, true}), + 
::testing::Values(cugraph::test::Rmat_Usecase(10, 16, 0.57, 0.19, 0.19, 0, true, false)))); + +INSTANTIATE_TEST_SUITE_P( + rmat_benchmark_test, /* note that scale & edge factor can be overridden in benchmarking (with + --gtest_filter to select only the rmat_benchmark_test with a specific + vertex & edge type combination) by command line arguments and do not + include more than one Rmat_Usecase that differ only in scale or edge + factor (to avoid running same benchmarks more than once) */ + Tests_MGEdgeTriangleCount_Rmat, + ::testing::Combine( + ::testing::Values(EdgeTriangleCount_Usecase{false, false}, + EdgeTriangleCount_Usecase{true, false}), + ::testing::Values(cugraph::test::Rmat_Usecase(20, 32, 0.57, 0.19, 0.19, 0, true, false)))); + +CUGRAPH_MG_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/community/mg_egonet_test.cu b/cpp/tests/community/mg_egonet_test.cu index 66ab1f47312..ac363df3ec5 100644 --- a/cpp/tests/community/mg_egonet_test.cu +++ b/cpp/tests/community/mg_egonet_test.cu @@ -199,13 +199,15 @@ class Tests_MGEgonet triplet_first + d_mg_aggregate_edgelist_src.size()); } - auto [sg_graph, sg_edge_weights, sg_number_map] = - cugraph::test::mg_graph_to_sg_graph(*handle_, - mg_graph_view, - mg_edge_weight_view, - std::make_optional>( - (*mg_renumber_map).data(), (*mg_renumber_map).size()), - false); + auto [sg_graph, sg_edge_weights, sg_edge_ids, sg_number_map] = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == 0) { auto d_mg_aggregate_edgelist_offsets = diff --git a/cpp/tests/community/mg_leiden_test.cpp b/cpp/tests/community/mg_leiden_test.cpp index f1a2fc83192..65f4827ba06 100644 --- a/cpp/tests/community/mg_leiden_test.cpp +++ b/cpp/tests/community/mg_leiden_test.cpp @@ -87,12 +87,14 @@ class Tests_MGLeiden std::optional< cugraph::edge_property_t, weight_t>> 
sg_edge_weights{std::nullopt}; - std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - mg_edge_weight_view, - std::optional>{std::nullopt}, - false); // crate an SG graph with MG graph vertex IDs + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + false); // create an SG graph with MG graph vertex IDs // FIXME: We need to figure out how to test each iteration of // SG vs MG Leiden, possibly by passing results of refinement phase diff --git a/cpp/tests/community/mg_louvain_test.cpp b/cpp/tests/community/mg_louvain_test.cpp index 733ee9368ac..106ad2562f7 100644 --- a/cpp/tests/community/mg_louvain_test.cpp +++ b/cpp/tests/community/mg_louvain_test.cpp @@ -85,12 +85,14 @@ class Tests_MGLouvain std::optional< cugraph::edge_property_t, weight_t>> sg_edge_weights{std::nullopt}; - std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - mg_edge_weight_view, - std::optional>{std::nullopt}, - false); // crate an SG graph with MG graph vertex IDs + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + false); // create an SG graph with MG graph vertex IDs weight_t sg_modularity{-1.0}; diff --git a/cpp/tests/community/mg_triangle_count_test.cpp b/cpp/tests/community/mg_triangle_count_test.cpp index ca3e0b2ac8f..932ff5050f1 100644 --- a/cpp/tests/community/mg_triangle_count_test.cpp +++ b/cpp/tests/community/mg_triangle_count_test.cpp @@ -178,13 +178,15 @@ class Tests_MGTriangleCount d_mg_triangle_counts.size())); cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) =
cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == int{0}) { // 4-2. run SG TriangleCount diff --git a/cpp/tests/community/mg_weighted_matching_test.cpp b/cpp/tests/community/mg_weighted_matching_test.cpp index 21963922ab1..4f36ee36902 100644 --- a/cpp/tests/community/mg_weighted_matching_test.cpp +++ b/cpp/tests/community/mg_weighted_matching_test.cpp @@ -130,12 +130,14 @@ class Tests_MGWeightedMatching std::optional< cugraph::edge_property_t, weight_t>> sg_edge_weights{std::nullopt}; - std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - mg_edge_weight_view, - std::optional>(std::nullopt), - false); + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::optional>(std::nullopt), + false); if (handle_->get_comms().get_rank() == 0) { auto sg_graph_view = sg_graph.view(); diff --git a/cpp/tests/components/mg_weakly_connected_components_test.cpp b/cpp/tests/components/mg_weakly_connected_components_test.cpp index c510e3139fb..368fea68877 100644 --- a/cpp/tests/components/mg_weakly_connected_components_test.cpp +++ b/cpp/tests/components/mg_weakly_connected_components_test.cpp @@ -125,13 +125,15 @@ class Tests_MGWeaklyConnectedComponents raft::device_span(d_mg_components.data(), d_mg_components.size())); cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = 
cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == int{0}) { // 3-2. run SG weakly connected components diff --git a/cpp/tests/cores/mg_core_number_test.cpp b/cpp/tests/cores/mg_core_number_test.cpp index ac99d7d4a93..f8294d81fdf 100644 --- a/cpp/tests/cores/mg_core_number_test.cpp +++ b/cpp/tests/cores/mg_core_number_test.cpp @@ -143,13 +143,15 @@ class Tests_MGCoreNumber raft::device_span(d_mg_core_numbers.data(), d_mg_core_numbers.size())); cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == int{0}) { // 3-2. 
run SG CoreNumber diff --git a/cpp/tests/cores/mg_k_core_test.cpp b/cpp/tests/cores/mg_k_core_test.cpp index 100c7fa3bcf..28bc445bda8 100644 --- a/cpp/tests/cores/mg_k_core_test.cpp +++ b/cpp/tests/cores/mg_k_core_test.cpp @@ -160,13 +160,15 @@ class Tests_MGKCore : public ::testing::TestWithParam>{std::nullopt}, raft::device_span(d_mg_core_numbers.data(), d_mg_core_numbers.size())); - auto [sg_graph, sg_edge_weights, sg_number_map] = - cugraph::test::mg_graph_to_sg_graph(*handle_, - mg_graph_view, - mg_edge_weight_view, - std::make_optional>( - (*mg_renumber_map).data(), (*mg_renumber_map).size()), - false); + auto [sg_graph, sg_edge_weights, sg_edge_ids, sg_number_map] = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == 0) { auto sg_graph_view = sg_graph.view(); diff --git a/cpp/tests/link_analysis/mg_hits_test.cpp b/cpp/tests/link_analysis/mg_hits_test.cpp index 101a4fe1557..40a439ffc4c 100644 --- a/cpp/tests/link_analysis/mg_hits_test.cpp +++ b/cpp/tests/link_analysis/mg_hits_test.cpp @@ -186,13 +186,15 @@ class Tests_MGHits : public ::testing::TestWithParam, weight_t>> sg_edge_weights{std::nullopt}; - std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == int{0}) { // 3-3. 
run SG Hits diff --git a/cpp/tests/link_analysis/mg_pagerank_test.cpp b/cpp/tests/link_analysis/mg_pagerank_test.cpp index 6be451ac5fd..26136c8c9d2 100644 --- a/cpp/tests/link_analysis/mg_pagerank_test.cpp +++ b/cpp/tests/link_analysis/mg_pagerank_test.cpp @@ -202,13 +202,15 @@ class Tests_MGPageRank std::optional< cugraph::edge_property_t, weight_t>> sg_edge_weights{std::nullopt}; - std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - mg_edge_weight_view, - std::make_optional>((*d_mg_renumber_map).data(), - (*d_mg_renumber_map).size()), - false); + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::make_optional>((*d_mg_renumber_map).data(), + (*d_mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == int{0}) { // 4-2. run SG PageRank diff --git a/cpp/tests/mtmg/threaded_test_louvain.cu b/cpp/tests/mtmg/threaded_test_louvain.cu index ab51d701b57..b9c8f621ab8 100644 --- a/cpp/tests/mtmg/threaded_test_louvain.cu +++ b/cpp/tests/mtmg/threaded_test_louvain.cu @@ -384,12 +384,13 @@ class Tests_Multithreaded auto thread_handle = instance_manager->get_handle(); if (thread_handle.get_rank() == 0) { - std::tie(sg_graph, sg_edge_weights, std::ignore) = + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( thread_handle.raft_handle(), graph_view.get(thread_handle), edge_weights ? std::make_optional(edge_weights->get(thread_handle).view()) : std::nullopt, + std::optional>{std::nullopt}, std::optional>{std::nullopt}, false); // create an SG graph with MG graph vertex IDs } else { @@ -398,6 +399,7 @@ class Tests_Multithreaded graph_view.get(thread_handle), edge_weights ? 
std::make_optional(edge_weights->get(thread_handle).view()) : std::nullopt, + std::optional>{std::nullopt}, std::optional>{std::nullopt}, false); // create an SG graph with MG graph vertex IDs } diff --git a/cpp/tests/prims/mg_count_if_e.cu b/cpp/tests/prims/mg_count_if_e.cu index 8ad1a20e585..137f7db8625 100644 --- a/cpp/tests/prims/mg_count_if_e.cu +++ b/cpp/tests/prims/mg_count_if_e.cu @@ -149,13 +149,15 @@ class Tests_MGCountIfE if (prims_usecase.check_correctness) { cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == 0) { auto sg_graph_view = sg_graph.view(); diff --git a/cpp/tests/prims/mg_count_if_v.cu b/cpp/tests/prims/mg_count_if_v.cu index eb0e8cf9835..e3f30e37729 100644 --- a/cpp/tests/prims/mg_count_if_v.cu +++ b/cpp/tests/prims/mg_count_if_v.cu @@ -123,13 +123,15 @@ class Tests_MGCountIfV if (prims_usecase.check_correctness) { cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == 0) { auto 
sg_graph_view = sg_graph.view(); diff --git a/cpp/tests/prims/mg_extract_transform_e.cu b/cpp/tests/prims/mg_extract_transform_e.cu index 48b893f6fea..20e87070fa5 100644 --- a/cpp/tests/prims/mg_extract_transform_e.cu +++ b/cpp/tests/prims/mg_extract_transform_e.cu @@ -253,13 +253,15 @@ class Tests_MGExtractTransformE } cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*d_mg_renumber_map_labels).data(), - (*d_mg_renumber_map_labels).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*d_mg_renumber_map_labels).data(), + (*d_mg_renumber_map_labels).size()), + false); rmm::device_uvector sg_vertex_prop(0, handle_->get_stream()); std::tie(std::ignore, sg_vertex_prop) = cugraph::test::mg_vertex_property_values_to_sg_vertex_property_values( diff --git a/cpp/tests/prims/mg_extract_transform_v_frontier_outgoing_e.cu b/cpp/tests/prims/mg_extract_transform_v_frontier_outgoing_e.cu index 3611a250afd..9e7611190ae 100644 --- a/cpp/tests/prims/mg_extract_transform_v_frontier_outgoing_e.cu +++ b/cpp/tests/prims/mg_extract_transform_v_frontier_outgoing_e.cu @@ -283,13 +283,15 @@ class Tests_MGExtractTransformVFrontierOutgoingE } cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*d_mg_renumber_map_labels).data(), - (*d_mg_renumber_map_labels).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + 
std::make_optional>((*d_mg_renumber_map_labels).data(), + (*d_mg_renumber_map_labels).size()), + false); rmm::device_uvector sg_vertex_prop(0, handle_->get_stream()); std::tie(std::ignore, sg_vertex_prop) = cugraph::test::mg_vertex_property_values_to_sg_vertex_property_values( diff --git a/cpp/tests/prims/mg_per_v_pair_transform_dst_nbr_intersection.cu b/cpp/tests/prims/mg_per_v_pair_transform_dst_nbr_intersection.cu index 762da62eeb8..75b711fbd9c 100644 --- a/cpp/tests/prims/mg_per_v_pair_transform_dst_nbr_intersection.cu +++ b/cpp/tests/prims/mg_per_v_pair_transform_dst_nbr_intersection.cu @@ -226,13 +226,15 @@ class Tests_MGPerVPairTransformDstNbrIntersection *handle_, std::get<1>(mg_result_buffer).data(), std::get<1>(mg_result_buffer).size()); cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == 0) { auto sg_graph_view = sg_graph.view(); diff --git a/cpp/tests/prims/mg_per_v_pair_transform_dst_nbr_weighted_intersection.cu b/cpp/tests/prims/mg_per_v_pair_transform_dst_nbr_weighted_intersection.cu index de78b42603d..48bbc6176d8 100644 --- a/cpp/tests/prims/mg_per_v_pair_transform_dst_nbr_weighted_intersection.cu +++ b/cpp/tests/prims/mg_per_v_pair_transform_dst_nbr_weighted_intersection.cu @@ -258,15 +258,17 @@ class Tests_MGPerVPairTransformDstNbrIntersection weight_t>> sg_edge_weight{std::nullopt}; - std::tie(sg_graph, sg_edge_weight, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - 
mg_edge_weight - ? std::make_optional(mg_edge_weight_view) - : std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, sg_edge_weight, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight + ? std::make_optional(mg_edge_weight_view) + : std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == 0) { auto sg_graph_view = sg_graph.view(); diff --git a/cpp/tests/prims/mg_per_v_random_select_transform_outgoing_e.cu b/cpp/tests/prims/mg_per_v_random_select_transform_outgoing_e.cu index 97c7333cd2e..b99dbf16107 100644 --- a/cpp/tests/prims/mg_per_v_random_select_transform_outgoing_e.cu +++ b/cpp/tests/prims/mg_per_v_random_select_transform_outgoing_e.cu @@ -320,13 +320,15 @@ class Tests_MGPerVRandomSelectTransformOutgoingE } cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == 0) { std::optional> mg_aggregate_sample_offsets{std::nullopt}; diff --git a/cpp/tests/prims/mg_per_v_transform_reduce_dst_key_aggregated_outgoing_e.cu b/cpp/tests/prims/mg_per_v_transform_reduce_dst_key_aggregated_outgoing_e.cu index efcfee9fc66..fd9192dcce5 100644 --- a/cpp/tests/prims/mg_per_v_transform_reduce_dst_key_aggregated_outgoing_e.cu +++ 
b/cpp/tests/prims/mg_per_v_transform_reduce_dst_key_aggregated_outgoing_e.cu @@ -297,13 +297,15 @@ class Tests_MGPerVTransformReduceDstKeyAggregatedOutgoingE std::optional< cugraph::edge_property_t, weight_t>> sg_edge_weights{std::nullopt}; - std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); for (size_t i = 0; i < reduction_types.size(); ++i) { auto mg_aggregate_results = diff --git a/cpp/tests/prims/mg_per_v_transform_reduce_incoming_outgoing_e.cu b/cpp/tests/prims/mg_per_v_transform_reduce_incoming_outgoing_e.cu index e3eb56d5a6e..be29c793ad5 100644 --- a/cpp/tests/prims/mg_per_v_transform_reduce_incoming_outgoing_e.cu +++ b/cpp/tests/prims/mg_per_v_transform_reduce_incoming_outgoing_e.cu @@ -271,13 +271,15 @@ class Tests_MGPerVTransformReduceIncomingOutgoingE if (prims_usecase.check_correctness) { cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); for (size_t i = 0; i < reduction_types.size(); ++i) { auto mg_aggregate_in_results = diff --git a/cpp/tests/prims/mg_reduce_v.cu 
b/cpp/tests/prims/mg_reduce_v.cu index 1449e8f9910..e91db5fa6ad 100644 --- a/cpp/tests/prims/mg_reduce_v.cu +++ b/cpp/tests/prims/mg_reduce_v.cu @@ -163,13 +163,15 @@ class Tests_MGReduceV if (prims_usecase.check_correctness) { cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == 0) { auto sg_graph_view = sg_graph.view(); diff --git a/cpp/tests/prims/mg_transform_reduce_dst_nbr_intersection_of_e_endpoints_by_v.cu b/cpp/tests/prims/mg_transform_reduce_dst_nbr_intersection_of_e_endpoints_by_v.cu index 71cdf27fda1..4fac6ef3be7 100644 --- a/cpp/tests/prims/mg_transform_reduce_dst_nbr_intersection_of_e_endpoints_by_v.cu +++ b/cpp/tests/prims/mg_transform_reduce_dst_nbr_intersection_of_e_endpoints_by_v.cu @@ -174,13 +174,15 @@ class Tests_MGTransformReduceDstNbrIntersectionOfEEndpointsByV raft::device_span(mg_result_buffer.data(), mg_result_buffer.size())); cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() 
== 0) { auto sg_graph_view = sg_graph.view(); diff --git a/cpp/tests/prims/mg_transform_reduce_e.cu b/cpp/tests/prims/mg_transform_reduce_e.cu index a086571d6e0..4785a8bb01b 100644 --- a/cpp/tests/prims/mg_transform_reduce_e.cu +++ b/cpp/tests/prims/mg_transform_reduce_e.cu @@ -159,13 +159,15 @@ class Tests_MGTransformReduceE if (prims_usecase.check_correctness) { cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == 0) { auto sg_graph_view = sg_graph.view(); diff --git a/cpp/tests/prims/mg_transform_reduce_e_by_src_dst_key.cu b/cpp/tests/prims/mg_transform_reduce_e_by_src_dst_key.cu index a66c70ff586..9950b5bdbf4 100644 --- a/cpp/tests/prims/mg_transform_reduce_e_by_src_dst_key.cu +++ b/cpp/tests/prims/mg_transform_reduce_e_by_src_dst_key.cu @@ -237,13 +237,15 @@ class Tests_MGTransformReduceEBySrcDstKey cugraph::get_dataframe_buffer_begin(mg_aggregate_by_dst_values)); cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if 
(handle_->get_comms().get_rank() == 0) { auto sg_graph_view = sg_graph.view(); diff --git a/cpp/tests/prims/mg_transform_reduce_v.cu b/cpp/tests/prims/mg_transform_reduce_v.cu index c26085a55c4..f6f07bc03ab 100644 --- a/cpp/tests/prims/mg_transform_reduce_v.cu +++ b/cpp/tests/prims/mg_transform_reduce_v.cu @@ -169,13 +169,15 @@ class Tests_MGTransformReduceV if (prims_usecase.check_correctness) { cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == 0) { auto sg_graph_view = sg_graph.view(); diff --git a/cpp/tests/prims/mg_transform_reduce_v_frontier_outgoing_e_by_dst.cu b/cpp/tests/prims/mg_transform_reduce_v_frontier_outgoing_e_by_dst.cu index 07a0f7e7aab..335a7ec879c 100644 --- a/cpp/tests/prims/mg_transform_reduce_v_frontier_outgoing_e_by_dst.cu +++ b/cpp/tests/prims/mg_transform_reduce_v_frontier_outgoing_e_by_dst.cu @@ -292,13 +292,15 @@ class Tests_MGTransformReduceVFrontierOutgoingEByDst } cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + 
false); if (handle_->get_comms().get_rank() == int{0}) { if constexpr (std::is_same_v) { diff --git a/cpp/tests/structure/mg_coarsen_graph_test.cpp b/cpp/tests/structure/mg_coarsen_graph_test.cpp index 1da30869545..471773d71bd 100644 --- a/cpp/tests/structure/mg_coarsen_graph_test.cpp +++ b/cpp/tests/structure/mg_coarsen_graph_test.cpp @@ -330,23 +330,26 @@ class Tests_MGCoarsenGraph cugraph::edge_property_t, weight_t>> sg_edge_weights{std::nullopt}; - std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - mg_edge_weight_view, - std::optional>{std::nullopt}, - false); + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + false); cugraph::graph_t sg_coarse_graph(*handle_); std::optional< cugraph::edge_property_t, weight_t>> sg_coarse_edge_weights{std::nullopt}; - std::tie(sg_coarse_graph, sg_coarse_edge_weights, std::ignore) = + std::tie(sg_coarse_graph, sg_coarse_edge_weights, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( *handle_, mg_coarse_graph_view, mg_coarse_edge_weight_view, + std::optional>{std::nullopt}, std::optional>{std::nullopt}, false); diff --git a/cpp/tests/structure/mg_count_self_loops_and_multi_edges_test.cpp b/cpp/tests/structure/mg_count_self_loops_and_multi_edges_test.cpp index 45fac884f49..61f40049e31 100644 --- a/cpp/tests/structure/mg_count_self_loops_and_multi_edges_test.cpp +++ b/cpp/tests/structure/mg_count_self_loops_and_multi_edges_test.cpp @@ -126,13 +126,15 @@ class Tests_MGCountSelfLoopsAndMultiEdges // 3-1. 
aggregate MG results cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == 0) { auto sg_graph_view = sg_graph.view(); diff --git a/cpp/tests/structure/mg_has_edge_and_compute_multiplicity_test.cpp b/cpp/tests/structure/mg_has_edge_and_compute_multiplicity_test.cpp index 0ee72726294..3d3d881fb23 100644 --- a/cpp/tests/structure/mg_has_edge_and_compute_multiplicity_test.cpp +++ b/cpp/tests/structure/mg_has_edge_and_compute_multiplicity_test.cpp @@ -204,13 +204,15 @@ class Tests_MGHasEdgeAndComputeMultiplicity d_mg_edge_multiplicities.size())); cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == 0) { auto sg_graph_view = sg_graph.view(); diff --git a/cpp/tests/structure/mg_induced_subgraph_test.cu b/cpp/tests/structure/mg_induced_subgraph_test.cu index 3b32c15bf9f..2ed909b9955 100644 --- a/cpp/tests/structure/mg_induced_subgraph_test.cu +++ b/cpp/tests/structure/mg_induced_subgraph_test.cu @@ -214,12 
+214,14 @@ class Tests_MGInducedSubgraph true, handle_->get_stream()); - auto [sg_graph, sg_edge_weights, sg_number_map] = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - mg_edge_weight_view, - std::optional>{std::nullopt}, - false); + auto [sg_graph, sg_edge_weights, sg_edge_ids, sg_number_map] = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + false); if (my_rank == 0) { auto d_sg_subgraph_offsets = cugraph::test::to_device(*handle_, h_sg_subgraph_offsets); diff --git a/cpp/tests/structure/mg_symmetrize_test.cpp b/cpp/tests/structure/mg_symmetrize_test.cpp index e607370f62a..7f1e4f04dc7 100644 --- a/cpp/tests/structure/mg_symmetrize_test.cpp +++ b/cpp/tests/structure/mg_symmetrize_test.cpp @@ -88,13 +88,15 @@ class Tests_MGSymmetrize weight_t>> sg_edge_weights{std::nullopt}; if (symmetrize_usecase.check_correctness) { - std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph.view(), - mg_edge_weights ? std::make_optional((*mg_edge_weights).view()) : std::nullopt, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph.view(), + mg_edge_weights ? std::make_optional((*mg_edge_weights).view()) : std::nullopt, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); } // 3. 
run MG symmetrize diff --git a/cpp/tests/structure/mg_transpose_storage_test.cpp b/cpp/tests/structure/mg_transpose_storage_test.cpp index c8b4f70f1e2..e870f648039 100644 --- a/cpp/tests/structure/mg_transpose_storage_test.cpp +++ b/cpp/tests/structure/mg_transpose_storage_test.cpp @@ -87,13 +87,15 @@ class Tests_MGTransposeStorage weight_t>> sg_edge_weights{std::nullopt}; if (transpose_storage_usecase.check_correctness) { - std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph.view(), - mg_edge_weights ? std::make_optional((*mg_edge_weights).view()) : std::nullopt, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph.view(), + mg_edge_weights ? std::make_optional((*mg_edge_weights).view()) : std::nullopt, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); } // 2. run MG transpose storage diff --git a/cpp/tests/structure/mg_transpose_test.cpp b/cpp/tests/structure/mg_transpose_test.cpp index 4428f8430d5..921cef42595 100644 --- a/cpp/tests/structure/mg_transpose_test.cpp +++ b/cpp/tests/structure/mg_transpose_test.cpp @@ -87,13 +87,15 @@ class Tests_MGTranspose weight_t>> sg_edge_weights{std::nullopt}; if (transpose_usecase.check_correctness) { - std::tie(sg_graph, sg_edge_weights, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph.view(), - mg_edge_weights ? std::make_optional((*mg_edge_weights).view()) : std::nullopt, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph.view(), + mg_edge_weights ? 
std::make_optional((*mg_edge_weights).view()) : std::nullopt, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); } // 3. run MG transpose diff --git a/cpp/tests/traversal/mg_bfs_test.cpp b/cpp/tests/traversal/mg_bfs_test.cpp index 431ed75c82d..1b63ad3b085 100644 --- a/cpp/tests/traversal/mg_bfs_test.cpp +++ b/cpp/tests/traversal/mg_bfs_test.cpp @@ -183,13 +183,15 @@ class Tests_MGBFS : public ::testing::TestWithParam sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == int{0}) { // 3-3. 
run SG BFS diff --git a/cpp/tests/traversal/mg_extract_bfs_paths_test.cu b/cpp/tests/traversal/mg_extract_bfs_paths_test.cu index 8484066c6a0..476a6ffab8f 100644 --- a/cpp/tests/traversal/mg_extract_bfs_paths_test.cu +++ b/cpp/tests/traversal/mg_extract_bfs_paths_test.cu @@ -237,13 +237,15 @@ class Tests_MGExtractBFSPaths cugraph::test::device_gatherv(*handle_, d_mg_paths.data(), d_mg_paths.size()); cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == int{0}) { // run SG extract_bfs_paths diff --git a/cpp/tests/traversal/mg_k_hop_nbrs_test.cpp b/cpp/tests/traversal/mg_k_hop_nbrs_test.cpp index 07ea107a2ed..64674fb3799 100644 --- a/cpp/tests/traversal/mg_k_hop_nbrs_test.cpp +++ b/cpp/tests/traversal/mg_k_hop_nbrs_test.cpp @@ -178,13 +178,15 @@ class Tests_MGKHopNbrs *handle_, raft::device_span(d_mg_nbrs.data(), d_mg_nbrs.size())); cugraph::graph_t sg_graph(*handle_); - std::tie(sg_graph, std::ignore, std::ignore) = cugraph::test::mg_graph_to_sg_graph( - *handle_, - mg_graph_view, - std::optional>{std::nullopt}, - std::make_optional>((*mg_renumber_map).data(), - (*mg_renumber_map).size()), - false); + std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + std::optional>{std::nullopt}, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == int{0}) { // 3-3. 
run SG K-hop neighbors diff --git a/cpp/tests/traversal/mg_sssp_test.cpp b/cpp/tests/traversal/mg_sssp_test.cpp index 188d0eca115..9ad16d1c947 100644 --- a/cpp/tests/traversal/mg_sssp_test.cpp +++ b/cpp/tests/traversal/mg_sssp_test.cpp @@ -176,13 +176,15 @@ class Tests_MGSSSP : public ::testing::TestWithParam, weight_t>> sg_edge_weights{std::nullopt}; - std::tie(sg_graph, sg_edge_weights, std::ignore) = - cugraph::test::mg_graph_to_sg_graph(*handle_, - mg_graph_view, - mg_edge_weight_view, - std::make_optional>( - (*mg_renumber_map).data(), (*mg_renumber_map).size()), - false); + std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore) = + cugraph::test::mg_graph_to_sg_graph( + *handle_, + mg_graph_view, + mg_edge_weight_view, + std::optional>{std::nullopt}, + std::make_optional>((*mg_renumber_map).data(), + (*mg_renumber_map).size()), + false); if (handle_->get_comms().get_rank() == int{0}) { // 3-3. run SG SSSP diff --git a/cpp/tests/utilities/conversion_utilities.hpp b/cpp/tests/utilities/conversion_utilities.hpp index 9b55f45d5bd..24a8ecbe4fd 100644 --- a/cpp/tests/utilities/conversion_utilities.hpp +++ b/cpp/tests/utilities/conversion_utilities.hpp @@ -216,15 +216,20 @@ graph_to_host_csc( // Only the rank 0 GPU holds the valid data template -std::tuple, - std::optional, - weight_t>>, - std::optional>> +std::tuple< + cugraph::graph_t, + std::optional< + cugraph::edge_property_t, + weight_t>>, + std::optional< + cugraph::edge_property_t, + edge_t>>, + std::optional>> mg_graph_to_sg_graph( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, + std::optional> edge_id_view, std::optional> renumber_map, bool renumber); diff --git a/cpp/tests/utilities/conversion_utilities_impl.cuh b/cpp/tests/utilities/conversion_utilities_impl.cuh index 6eb7357eedd..748a5731b89 100644 --- a/cpp/tests/utilities/conversion_utilities_impl.cuh +++ b/cpp/tests/utilities/conversion_utilities_impl.cuh @@ -283,23 +283,26 @@ template , 
std::optional, weight_t>>, + std::optional, edge_t>>, std::optional>> mg_graph_to_sg_graph( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, + std::optional> edge_id_view, std::optional> renumber_map, bool renumber) { rmm::device_uvector d_src(0, handle.get_stream()); rmm::device_uvector d_dst(0, handle.get_stream()); std::optional> d_wgt{std::nullopt}; + std::optional> d_edge_id{std::nullopt}; - std::tie(d_src, d_dst, d_wgt, std::ignore, std::ignore) = cugraph::decompress_to_edgelist( + std::tie(d_src, d_dst, d_wgt, d_edge_id, std::ignore) = cugraph::decompress_to_edgelist( handle, graph_view, edge_weight_view, - std::optional>{std::nullopt}, + edge_id_view, std::optional>{std::nullopt}, renumber_map); @@ -310,6 +313,9 @@ mg_graph_to_sg_graph( if (d_wgt) *d_wgt = cugraph::test::device_gatherv( handle, raft::device_span{d_wgt->data(), d_wgt->size()}); + if (d_edge_id) + *d_edge_id = cugraph::test::device_gatherv( + handle, raft::device_span{d_edge_id->data(), d_edge_id->size()}); rmm::device_uvector vertices(0, handle.get_stream()); if (renumber_map) { vertices = cugraph::test::device_gatherv(handle, *renumber_map); } @@ -317,6 +323,8 @@ mg_graph_to_sg_graph( graph_t sg_graph(handle); std::optional, weight_t>> sg_edge_weights{std::nullopt}; + std::optional, edge_t>> + sg_edge_ids{std::nullopt}; std::optional> sg_number_map; if (handle.get_comms().get_rank() == 0) { if (!renumber_map) { @@ -325,7 +333,7 @@ mg_graph_to_sg_graph( handle.get_stream(), vertices.data(), vertices.size(), vertex_t{0}); } - std::tie(sg_graph, sg_edge_weights, std::ignore, std::ignore, sg_number_map) = + std::tie(sg_graph, sg_edge_weights, sg_edge_ids, std::ignore, sg_number_map) = cugraph::create_graph_from_edgelist diff --git a/cpp/tests/utilities/conversion_utilities_mg.cu b/cpp/tests/utilities/conversion_utilities_mg.cu index d657f868497..cb4703ec89b 100644 --- a/cpp/tests/utilities/conversion_utilities_mg.cu +++ 
b/cpp/tests/utilities/conversion_utilities_mg.cu @@ -381,132 +381,156 @@ graph_to_host_csc( template std::tuple< cugraph::graph_t, std::optional, float>>, + std::optional, int32_t>>, std::optional>> mg_graph_to_sg_graph( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, + std::optional> edge_id_view, std::optional> renumber_map, bool renumber); template std::tuple< cugraph::graph_t, std::optional, float>>, + std::optional, int64_t>>, std::optional>> mg_graph_to_sg_graph( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, + std::optional> edge_id_view, std::optional> renumber_map, bool renumber); template std::tuple< cugraph::graph_t, std::optional, float>>, + std::optional, int64_t>>, std::optional>> mg_graph_to_sg_graph( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, + std::optional> edge_id_view, std::optional> renumber_map, bool renumber); template std::tuple< cugraph::graph_t, std::optional, double>>, + std::optional, int32_t>>, std::optional>> mg_graph_to_sg_graph( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, + std::optional> edge_id_view, std::optional> renumber_map, bool renumber); template std::tuple< cugraph::graph_t, std::optional, double>>, + std::optional, int64_t>>, std::optional>> mg_graph_to_sg_graph( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, + std::optional> edge_id_view, std::optional> renumber_map, bool renumber); template std::tuple< cugraph::graph_t, std::optional, double>>, + std::optional, int64_t>>, std::optional>> mg_graph_to_sg_graph( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, + std::optional> edge_id_view, std::optional> renumber_map, bool renumber); template std::tuple< cugraph::graph_t, 
std::optional, float>>, + std::optional, int32_t>>, std::optional>> mg_graph_to_sg_graph( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, + std::optional> edge_id_view, std::optional> renumber_map, bool renumber); template std::tuple< cugraph::graph_t, std::optional, float>>, + std::optional, int64_t>>, std::optional>> mg_graph_to_sg_graph( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, + std::optional> edge_id_view, std::optional> renumber_map, bool renumber); template std::tuple< cugraph::graph_t, std::optional, float>>, + std::optional, int64_t>>, std::optional>> mg_graph_to_sg_graph( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, + std::optional> edge_id_view, std::optional> renumber_map, bool renumber); template std::tuple< cugraph::graph_t, std::optional, double>>, + std::optional, int32_t>>, std::optional>> mg_graph_to_sg_graph( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, + std::optional> edge_id_view, std::optional> renumber_map, bool renumber); template std::tuple< cugraph::graph_t, std::optional, double>>, + std::optional, int64_t>>, std::optional>> mg_graph_to_sg_graph( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, + std::optional> edge_id_view, std::optional> renumber_map, bool renumber); template std::tuple< cugraph::graph_t, std::optional, double>>, + std::optional, int64_t>>, std::optional>> mg_graph_to_sg_graph( raft::handle_t const& handle, cugraph::graph_view_t const& graph_view, std::optional> edge_weight_view, + std::optional> edge_id_view, std::optional> renumber_map, bool renumber); From 31565696d420ae661aa75db84113ad65104c8da9 Mon Sep 17 00:00:00 2001 From: "Richard (Rick) Zamora" Date: Tue, 28 May 2024 08:58:51 -0500 Subject: [PATCH 3/7] Enable 
expression-based Dask Dataframe support (#4325) **[WIP]** I'm using this PR to debug/add support for `DASK_DATAFRAME__QUERY_PLANNING=True`. **NOTES**: - Depends on https://github.com/dask/dask-expr/pull/1041 [Merged] - Depends on https://github.com/dask/dask-expr/pull/1044 Authors: - Richard (Rick) Zamora (https://github.com/rjzamora) Approvers: - Rick Ratzel (https://github.com/rlratzel) - Ray Douglass (https://github.com/raydouglass) URL: https://github.com/rapidsai/cugraph/pull/4325 --- .../bulk_sampling/cugraph_bulk_sampling.py | 2 +- ci/test_python.sh | 4 ---- ci/test_wheel.sh | 4 ---- python/cugraph/cugraph/dask/__init__.py | 7 ++++++- .../cugraph/dask/common/input_utils.py | 6 +++--- .../cugraph/cugraph/dask/common/part_utils.py | 4 ++-- .../cugraph/structure/convert_matrix.py | 4 ++-- .../simpleDistributedGraph.py | 19 ++++++++++--------- .../data_store/test_property_graph_mg.py | 4 ++-- .../tests/internals/test_symmetrize_mg.py | 13 ++++++++----- .../cugraph/tests/structure/test_graph_mg.py | 6 +++--- .../cugraph/tests/utils/test_dataset.py | 2 +- 12 files changed, 38 insertions(+), 37 deletions(-) diff --git a/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py b/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py index 95e1afcb28b..578e2520765 100644 --- a/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py +++ b/benchmarks/cugraph/standalone/bulk_sampling/cugraph_bulk_sampling.py @@ -344,7 +344,7 @@ def generate_rmat_dataset( del label_df gc.collect() - dask_label_df = dask_cudf.from_dask_dataframe(dask_label_df) + dask_label_df = dask_label_df.to_backend("cudf") node_offsets = {"paper": 0} edge_offsets = {("paper", "cites", "paper"): 0} diff --git a/ci/test_python.sh b/ci/test_python.sh index 5ea893eca60..9537f66e825 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -3,10 +3,6 @@ set -euo pipefail -# TODO: Enable dask query planning (by default) once some bugs are fixed. 
-# xref: https://github.com/rapidsai/cudf/issues/15027 -export DASK_DATAFRAME__QUERY_PLANNING=False - # Support invoking test_python.sh outside the script directory cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../ diff --git a/ci/test_wheel.sh b/ci/test_wheel.sh index cda40d92c74..158704e08d1 100755 --- a/ci/test_wheel.sh +++ b/ci/test_wheel.sh @@ -3,10 +3,6 @@ set -eoxu pipefail -# TODO: Enable dask query planning (by default) once some bugs are fixed. -# xref: https://github.com/rapidsai/cudf/issues/15027 -export DASK_DATAFRAME__QUERY_PLANNING=False - package_name=$1 package_dir=$2 diff --git a/python/cugraph/cugraph/dask/__init__.py b/python/cugraph/cugraph/dask/__init__.py index a6958aaaf49..a76f1460575 100644 --- a/python/cugraph/cugraph/dask/__init__.py +++ b/python/cugraph/cugraph/dask/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,6 +11,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dask import config + from .link_analysis.pagerank import pagerank from .link_analysis.hits import hits from .traversal.bfs import bfs @@ -34,3 +36,6 @@ from .link_prediction.sorensen import sorensen from .link_prediction.overlap import overlap from .community.leiden import leiden + +# Avoid "p2p" shuffling in dask for now +config.set({"dataframe.shuffle.method": "tasks"}) diff --git a/python/cugraph/cugraph/dask/common/input_utils.py b/python/cugraph/cugraph/dask/common/input_utils.py index dcbd811562b..db70b7b089f 100644 --- a/python/cugraph/cugraph/dask/common/input_utils.py +++ b/python/cugraph/cugraph/dask/common/input_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,8 +15,8 @@ from collections.abc import Sequence from collections import OrderedDict -from dask_cudf.core import DataFrame as dcDataFrame -from dask_cudf.core import Series as daskSeries +from dask_cudf import DataFrame as dcDataFrame +from dask_cudf import Series as daskSeries import cugraph.dask.comms.comms as Comms diff --git a/python/cugraph/cugraph/dask/common/part_utils.py b/python/cugraph/cugraph/dask/common/part_utils.py index d362502f239..19c429bb7be 100644 --- a/python/cugraph/cugraph/dask/common/part_utils.py +++ b/python/cugraph/cugraph/dask/common/part_utils.py @@ -18,8 +18,8 @@ import collections import dask_cudf from dask.array.core import Array as daskArray -from dask_cudf.core import DataFrame as daskDataFrame -from dask_cudf.core import Series as daskSeries +from dask_cudf import DataFrame as daskDataFrame +from dask_cudf import Series as daskSeries from functools import reduce import cugraph.dask.comms.comms as Comms from dask.delayed import delayed diff --git a/python/cugraph/cugraph/structure/convert_matrix.py b/python/cugraph/cugraph/structure/convert_matrix.py index b9b9554b870..024b9ddfba2 100644 --- a/python/cugraph/cugraph/structure/convert_matrix.py +++ b/python/cugraph/cugraph/structure/convert_matrix.py @@ -40,7 +40,7 @@ def from_edgelist( Parameters ---------- - df : cudf.DataFrame, pandas.DataFrame, dask_cudf.core.DataFrame + df : cudf.DataFrame, pandas.DataFrame, dask_cudf.DataFrame This DataFrame contains columns storing edge source vertices, destination (or target following NetworkX's terminology) vertices, and (optional) weights. 
@@ -95,7 +95,7 @@ def from_edgelist( renumber=renumber, ) - elif df_type is dask_cudf.core.DataFrame: + elif df_type is dask_cudf.DataFrame: if create_using is None: G = Graph() elif isinstance(create_using, Graph): diff --git a/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py b/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py index 0ef5eaf1b9e..3fa92bb5e67 100644 --- a/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py +++ b/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py @@ -285,19 +285,20 @@ def __from_edgelist( symmetrize=not self.properties.directed, ) + # Create a dask_cudf dataframe from the cudf series + # or dataframe objects obtained from symmetrization if isinstance(source_col, dask_cudf.Series): - # Create a dask_cudf dataframe from the cudf series obtained - # from symmetrization - input_ddf = source_col.to_frame() - input_ddf = input_ddf.rename(columns={source_col.name: source}) - input_ddf[destination] = dest_col + frames = [ + source_col.to_frame(name=source), + dest_col.to_frame(name=destination), + ] else: - # Multi column dask_cudf dataframe - input_ddf = dask_cudf.concat([source_col, dest_col], axis=1) + frames = [source_col, dest_col] if value_col is not None: - for vc in value_col_names: - input_ddf[vc] = value_col[vc] + frames.append(value_col[value_col_names]) + + input_ddf = dask_cudf.concat(frames, axis=1) self.input_df = input_ddf diff --git a/python/cugraph/cugraph/tests/data_store/test_property_graph_mg.py b/python/cugraph/cugraph/tests/data_store/test_property_graph_mg.py index db4ab0a2ac1..42cb0f232bf 100644 --- a/python/cugraph/cugraph/tests/data_store/test_property_graph_mg.py +++ b/python/cugraph/cugraph/tests/data_store/test_property_graph_mg.py @@ -159,8 +159,8 @@ def df_type_id(dataframe_type): return s + "cudf.DataFrame" if dataframe_type == pd.DataFrame: return s + "pandas.DataFrame" - if dataframe_type == 
dask_cudf.core.DataFrame: - return s + "dask_cudf.core.DataFrame" + if dataframe_type == dask_cudf.DataFrame: + return s + "dask_cudf.DataFrame" return s + "?" diff --git a/python/cugraph/cugraph/tests/internals/test_symmetrize_mg.py b/python/cugraph/cugraph/tests/internals/test_symmetrize_mg.py index 913443fe400..9091ab7df57 100644 --- a/python/cugraph/cugraph/tests/internals/test_symmetrize_mg.py +++ b/python/cugraph/cugraph/tests/internals/test_symmetrize_mg.py @@ -232,14 +232,17 @@ def test_mg_symmetrize(dask_client, read_datasets): # create a dask DataFrame from the dask Series if isinstance(sym_src, dask_cudf.Series): - ddf2 = sym_src.to_frame() - ddf2 = ddf2.rename(columns={sym_src.name: "src"}) - ddf2["dst"] = sym_dst + frames = [ + sym_src.to_frame(name="src"), + sym_dst.to_frame(name="dst"), + ] else: - ddf2 = dask_cudf.concat([sym_src, sym_dst], axis=1) + frames = [sym_src, sym_dst] if val_col_name is not None: - ddf2["weight"] = sym_val + frames.append(sym_val.to_frame(name="weight")) + + ddf2 = dask_cudf.concat(frames, axis=1) compare(ddf, ddf2, src_col_name, dst_col_name, val_col_name) diff --git a/python/cugraph/cugraph/tests/structure/test_graph_mg.py b/python/cugraph/cugraph/tests/structure/test_graph_mg.py index f23d4ec026d..cba61731e9a 100644 --- a/python/cugraph/cugraph/tests/structure/test_graph_mg.py +++ b/python/cugraph/cugraph/tests/structure/test_graph_mg.py @@ -99,13 +99,13 @@ def test_nodes_functionality(dask_client, input_combo): expected_nodes = ( dask_cudf.concat([ddf["src"], ddf["dst"]]) .drop_duplicates() - .to_frame() - .sort_values(0) + .to_frame(name="0") + .sort_values("0") ) expected_nodes = expected_nodes.compute().reset_index(drop=True) - result_nodes["expected_nodes"] = expected_nodes[0] + result_nodes["expected_nodes"] = expected_nodes["0"] compare = result_nodes.query("result_nodes != expected_nodes") diff --git a/python/cugraph/cugraph/tests/utils/test_dataset.py b/python/cugraph/cugraph/tests/utils/test_dataset.py index 
fae89e02002..a52b99dabfe 100644 --- a/python/cugraph/cugraph/tests/utils/test_dataset.py +++ b/python/cugraph/cugraph/tests/utils/test_dataset.py @@ -198,7 +198,7 @@ def test_reader_dask(dask_client, dataset): E = dataset.get_dask_edgelist(download=True) assert E is not None - assert isinstance(E, dask_cudf.core.DataFrame) + assert isinstance(E, dask_cudf.DataFrame) dataset.unload() From 562b5a5b9f3db29184390c319468ccb488d21056 Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Tue, 28 May 2024 16:48:30 -0400 Subject: [PATCH 4/7] Pin torch version in `cugraph-dgl` wheel test (#4447) To fix the [CI nightly issue](https://github.com/rapidsai/cugraph/actions/runs/9188624604/job/25298484338#step:8:837) in cugraph-dgl wheel test for CUDA 11. Authors: - Tingyu Wang (https://github.com/tingyu66) Approvers: - Rick Ratzel (https://github.com/rlratzel) - Brad Rees (https://github.com/BradReesWork) - Jake Awe (https://github.com/AyodeAwe) URL: https://github.com/rapidsai/cugraph/pull/4447 --- ci/test_wheel_cugraph-dgl.sh | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/ci/test_wheel_cugraph-dgl.sh b/ci/test_wheel_cugraph-dgl.sh index 827ad487115..564b46cb07e 100755 --- a/ci/test_wheel_cugraph-dgl.sh +++ b/ci/test_wheel_cugraph-dgl.sh @@ -32,8 +32,18 @@ fi PYTORCH_URL="https://download.pytorch.org/whl/cu${PYTORCH_CUDA_VER}" DGL_URL="https://data.dgl.ai/wheels/cu${PYTORCH_CUDA_VER}/repo.html" +# Starting from 2.2, PyTorch wheels depend on nvidia-nccl-cuxx>=2.19 wheel and +# dynamically link to NCCL. RAPIDS CUDA 11 CI images have an older NCCL version that +# might shadow the newer NCCL required by PyTorch during import (when importing +# `cupy` before `torch`). 
+if [[ "${NCCL_VERSION}" < "2.19" ]]; then + PYTORCH_VER="2.1.0" +else + PYTORCH_VER="2.3.0" +fi + rapids-logger "Installing PyTorch and DGL" -rapids-retry python -m pip install torch --index-url ${PYTORCH_URL} +rapids-retry python -m pip install "torch==${PYTORCH_VER}" --index-url ${PYTORCH_URL} rapids-retry python -m pip install dgl==2.0.0 --find-links ${DGL_URL} python -m pytest python/cugraph-dgl/tests From 169d1625fd93d99c481051f78047464e43fdee02 Mon Sep 17 00:00:00 2001 From: Don Acosta <97529984+acostadon@users.noreply.github.com> Date: Wed, 29 May 2024 08:43:47 -0400 Subject: [PATCH 5/7] adding notebook to demo nx_cugraph (#4366) This notebook will be used to demonstrate how to use nx-cugraph and show the speed-up. Authors: - Don Acosta (https://github.com/acostadon) Approvers: - Brad Rees (https://github.com/BradReesWork) URL: https://github.com/rapidsai/cugraph/pull/4366 --- docs/cugraph/source/nx_cugraph/nx_cugraph.md | 159 +++--------- .../nx_cugraph_codeless_switching.ipynb | 244 ++++++++++++++++++ 2 files changed, 275 insertions(+), 128 deletions(-) create mode 100644 notebooks/cugraph_benchmarks/nx_cugraph_codeless_switching.ipynb diff --git a/docs/cugraph/source/nx_cugraph/nx_cugraph.md b/docs/cugraph/source/nx_cugraph/nx_cugraph.md index 92fbf90a43b..ff2fc3d1da8 100644 --- a/docs/cugraph/source/nx_cugraph/nx_cugraph.md +++ b/docs/cugraph/source/nx_cugraph/nx_cugraph.md @@ -24,142 +24,45 @@ Each chart has three measurements. ![Single Source Shortest Path](../images/sssp.png) ![Weakly Connected Components](../images/wcc.png) +### Command line example +Open bc_demo.ipy and paste the code below. -The following algorithms are supported and automatically dispatched to nx-cuGraph for acceleration. 
+``` +import pandas as pd +import networkx as nx + +url = "https://data.rapids.ai/cugraph/datasets/cit-Patents.csv" +df = pd.read_csv(url, sep=" ", names=["src", "dst"], dtype="int32") +G = nx.from_pandas_edgelist(df, source="src", target="dst") -#### Algorithms +%time result = nx.betweenness_centrality(G, k=10) +``` +Run the command: ``` -bipartite - ├─ basic - │ └─ is_bipartite - └─ generators - └─ complete_bipartite_graph -centrality - ├─ betweenness - │ ├─ betweenness_centrality - │ └─ edge_betweenness_centrality - ├─ degree_alg - │ ├─ degree_centrality - │ ├─ in_degree_centrality - │ └─ out_degree_centrality - ├─ eigenvector - │ └─ eigenvector_centrality - └─ katz - └─ katz_centrality -cluster - ├─ average_clustering - ├─ clustering - ├─ transitivity - └─ triangles -community - └─ louvain - └─ louvain_communities -components - ├─ connected - │ ├─ connected_components - │ ├─ is_connected - │ ├─ node_connected_component - │ └─ number_connected_components - └─ weakly_connected - ├─ is_weakly_connected - ├─ number_weakly_connected_components - └─ weakly_connected_components -core - ├─ core_number - └─ k_truss -dag - ├─ ancestors - └─ descendants -isolate - ├─ is_isolate - ├─ isolates - └─ number_of_isolates -link_analysis - ├─ hits_alg - │ └─ hits - └─ pagerank_alg - └─ pagerank -operators - └─ unary - ├─ complement - └─ reverse -reciprocity - ├─ overall_reciprocity - └─ reciprocity -shortest_paths - └─ unweighted - ├─ single_source_shortest_path_length - └─ single_target_shortest_path_length -traversal - └─ breadth_first_search - ├─ bfs_edges - ├─ bfs_layers - ├─ bfs_predecessors - ├─ bfs_successors - ├─ bfs_tree - ├─ descendants_at_distance - └─ generic_bfs_edges -tree - └─ recognition - ├─ is_arborescence - ├─ is_branching - ├─ is_forest - └─ is_tree +user@machine:/# ipython bc_demo.ipy ``` -#### Generators +You will observe a run time of approximately 7 minutes...more or less depending on your cpu. 
+ +Run the command again, this time specifying cugraph as the NetworkX backend of choice. +``` +user@machine:/# NETWORKX_BACKEND_PRIORITY=cugraph ipython bc_demo.ipy ``` -classic - ├─ barbell_graph - ├─ circular_ladder_graph - ├─ complete_graph - ├─ complete_multipartite_graph - ├─ cycle_graph - ├─ empty_graph - ├─ ladder_graph - ├─ lollipop_graph - ├─ null_graph - ├─ path_graph - ├─ star_graph - ├─ tadpole_graph - ├─ trivial_graph - ├─ turan_graph - └─ wheel_graph -community - └─ caveman_graph -small - ├─ bull_graph - ├─ chvatal_graph - ├─ cubical_graph - ├─ desargues_graph - ├─ diamond_graph - ├─ dodecahedral_graph - ├─ frucht_graph - ├─ heawood_graph - ├─ house_graph - ├─ house_x_graph - ├─ icosahedral_graph - ├─ krackhardt_kite_graph - ├─ moebius_kantor_graph - ├─ octahedral_graph - ├─ pappus_graph - ├─ petersen_graph - ├─ sedgewick_maze_graph - ├─ tetrahedral_graph - ├─ truncated_cube_graph - ├─ truncated_tetrahedron_graph - └─ tutte_graph -social - ├─ davis_southern_women_graph - ├─ florentine_families_graph - ├─ karate_club_graph - └─ les_miserables_graph +This run will be much faster, typically around 20 seconds depending on your GPU. +``` +user@machine:/# NETWORKX_BACKEND_PRIORITY=cugraph ipython bc_demo.ipy +``` +There is also an option to add caching. This will dramatically help performance when running multiple algorithms on the same graph. +``` +NETWORKX_BACKEND_PRIORITY=cugraph CACHE_CONVERTED_GRAPH=True ipython bc_demo.ipy ``` -#### Other +When running Python interactively, cugraph backend can be specified as an argument in the algorithm call. +For example: ``` -convert_matrix - ├─ from_pandas_edgelist - └─ from_scipy_sparse_array +nx.betweenness_centrality(cit_patents_graph, k=k, backend="cugraph") ``` + + +The latest list of algorithms that can be dispatched to nx-cuGraph for acceleration is found [here](https://github.com/rapidsai/cugraph/blob/main/python/nx-cugraph/README.md#algorithms). 
diff --git a/notebooks/cugraph_benchmarks/nx_cugraph_codeless_switching.ipynb b/notebooks/cugraph_benchmarks/nx_cugraph_codeless_switching.ipynb new file mode 100644 index 00000000000..e05544448b1 --- /dev/null +++ b/notebooks/cugraph_benchmarks/nx_cugraph_codeless_switching.ipynb @@ -0,0 +1,244 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Benchmarking Performance of NetworkX with Rapids GPU-based nx_cugraph backend vs on cpu\n", + "# Skip notebook test\n", + "This notebook demonstrates and compares the performance of nx_cugraph as a dispatcher for NetworkX algorithms. \n", + "\n", + "We do this by executing Betweenness Centrality, Breadth First Search and Louvain Community Detection, collecting run times with and without nx_cugraph backend and graph caching enabled. nx_cugraph is a registered NetworkX backend. Using it is a zero code change solution.\n", + "\n", + "In the notebook switching to the nx-cugraph backend is done via variables set using the [NetworkX config package](https://networkx.org/documentation/stable/reference/backends.html#networkx.utils.configs.NetworkXConfig) **which requires networkX 3.3 or later !!**\n", + "\n", + "\n", + "They can be set at the command line as well.\n", + "\n", + "### See this example from GTC Spring 2024\n", + "\n", + "\n", + "\n", + "Here is a sample minimal script to demonstrate No-code-change GPU acceleration using nx-cugraph.\n", + "\n", + "----\n", + "bc_demo.ipy:\n", + "\n", + "```\n", + "import pandas as pd\n", + "import networkx as nx\n", + "\n", + "url = \"https://data.rapids.ai/cugraph/datasets/cit-Patents.csv\"\n", + "df = pd.read_csv(url, sep=\" \", names=[\"src\", \"dst\"], dtype=\"int32\")\n", + "G = nx.from_pandas_edgelist(df, source=\"src\", target=\"dst\")\n", + "\n", + "%time result = nx.betweenness_centrality(G, k=10)\n", + "```\n", + "----\n", + "Running it with the nx-cugraph backend looks like this:\n", + "```\n", + "user@machine:/# ipython bc_demo.ipy\n", + "CPU 
times: user 7min 38s, sys: 5.6 s, total: 7min 44s\n", + "Wall time: 7min 44s\n", + "\n", + "user@machine:/# NETWORKX_BACKEND_PRIORITY=cugraph ipython bc_demo.ipy\n", + "CPU times: user 18.4 s, sys: 1.44 s, total: 19.9 s\n", + "Wall time: 20 s\n", + "```\n", + "----\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First import the needed packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import networkx as nx\n", + "import time\n", + "import os" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This installs the NetworkX cuGraph dispatcher if not already present." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "try: \n", + " import nx_cugraph\n", + "except ModuleNotFoundError:\n", + " os.system('conda install -c rapidsai -c conda-forge -c nvidia nx-cugraph')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This is boilerplate NetworkX code to run:\n", + "* Betweenness Centrality\n", + "* Breadth First Search\n", + "* Louvain Community Detection\n", + "\n", + "and report times. It is completely unaware of cugraph or GPU-based tools.\n", + "[NetworkX configurations](https://networkx.org/documentation/stable/reference/utils.html#backends) can determine how they are run."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def run_algos(G):\n", + " runtime = time.time()\n", + " result = nx.betweenness_centrality(G, k=10)\n", + " print (\"Betweenness Centrality time: \" + str(round(time.time() - runtime))+ \" seconds\")\n", + " runtime = time.time()\n", + " result = nx.bfs_tree(G,source=1)\n", + " print (\"Breadth First Search time: \" + str(round(time.time() - runtime))+ \" seconds\")\n", + " runtime = time.time()\n", + " result = nx.community.louvain_communities(G,threshold=1e-04)\n", + " print (\"Louvain time: \" + str(round(time.time() - runtime))+ \" seconds\")\n", + " return" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Downloads a patent citation dataset containing 3774768 nodes and 16518948 edges and loads it into a NetworkX graph." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "filepath = \"./data/cit-Patents.csv\"\n", + "\n", + "if os.path.exists(filepath):\n", + " print(\"File found\")\n", + " url = filepath\n", + "else:\n", + " url = \"https://data.rapids.ai/cugraph/datasets/cit-Patents.csv\"\n", + "df = pd.read_csv(url, sep=\" \", names=[\"src\", \"dst\"], dtype=\"int32\")\n", + "G = nx.from_pandas_edgelist(df, source=\"src\", target=\"dst\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The NetworkX dispatcher can be set with an environment variable, or in code using the NetworkX config package, which is new in [NetworkX 3.3 config](https://networkx.org/documentation/stable/reference/backends.html#networkx.utils.configs.NetworkXConfig).\n", + "\n", + "These convenience settings allow turning off caching and cugraph dispatching if you want to see how long cpu-only takes.\n", + "This example, run on an AMD Ryzen Threadripper PRO 3975WX 32-Cores cpu, completed in slightly over 40 minutes."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "use_cugraph = True\n", + "cache_graph = True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if use_cugraph:\n", + " nx.config[\"backend_priority\"]=['cugraph']\n", + "else:\n", + " # Use this setting to turn off the cugraph dispatcher and run in legacy cpu mode.\n", + " nx.config[\"backend_priority\"]=[]\n", + "if cache_graph:\n", + " nx.config[\"cache_converted_graphs\"]= True\n", + "else:\n", + " # Use this setting to turn off graph caching which will convert the NetworkX graph to a gpu-resident graph each time an algorithm is run.\n", + " nx.config[\"cache_converted_graphs\"]= False\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Run the algorithms on GPU. \n", + "\n", + "**Note the messages NetworkX generates to remind us that the cached graph shouldn't be modified.**\n", + "\n", + "```\n", + "For the cache to be consistent (i.e., correct), the input graph must not have been manually mutated since the cached graph was created.\n", + "\n", + "Using cached graph for 'cugraph' backend in call to bfs_edges.\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "run_algos(G)\n", + "print (\"Total Algorithm run time\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "___\n", + "Copyright (c) 2024, NVIDIA CORPORATION.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and limitations under the License.\n", + "___" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.8" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 4c797bfa251d36f57870cc9ca8636d3098be964c Mon Sep 17 00:00:00 2001 From: Paul Taylor <178183+trxcllnt@users.noreply.github.com> Date: Wed, 29 May 2024 11:07:57 -0700 Subject: [PATCH 6/7] Fix building cugraph with CCCL main (#4404) Similar to https://github.com/rapidsai/cudf/pull/15552, we are testing [building RAPIDS with CCCL's main branch](https://github.com/NVIDIA/cccl/pull/1667) to get ahead of any breaking changes. Authors: - Paul Taylor (https://github.com/trxcllnt) - Ralph Liu (https://github.com/nv-rliu) - Seunghwa Kang (https://github.com/seunghwak) - Ray Bell (https://github.com/raybellwaves) Approvers: - Chuck Hastings (https://github.com/ChuckHastings) - Seunghwa Kang (https://github.com/seunghwak) - Jake Awe (https://github.com/AyodeAwe) URL: https://github.com/rapidsai/cugraph/pull/4404 --- .devcontainer/Dockerfile | 5 + .../cuda11.8-conda/devcontainer.json | 2 +- .devcontainer/cuda11.8-pip/devcontainer.json | 8 +- .../cuda12.2-conda/devcontainer.json | 2 +- .devcontainer/cuda12.2-pip/devcontainer.json | 8 +- .github/workflows/pr.yaml | 2 +- cpp/CMakeLists.txt | 4 +- .../cugraph/utilities/device_functors.cuh | 9 +- cpp/include/cugraph/utilities/mask_utils.cuh | 5 +- cpp/src/community/detail/common_methods.cuh | 3 +- cpp/src/community/legacy/louvain.cuh | 15 ++- .../weakly_connected_components_impl.cuh | 15 ++- cpp/src/detail/utility_wrappers.cu | 4 +- cpp/src/prims/kv_store.cuh | 1 + 
...m_reduce_dst_key_aggregated_outgoing_e.cuh | 2 +- cpp/src/structure/graph_view_impl.cuh | 36 +++--- cpp/tests/CMakeLists.txt | 24 +++- cpp/tests/prims/mg_extract_transform_e.cu | 109 +++++------------- .../sampling/sampling_post_processing_test.cu | 38 +++--- 19 files changed, 138 insertions(+), 154 deletions(-) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index 3d0ac075be3..190003dd7af 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -7,6 +7,11 @@ FROM ${BASE} as pip-base ENV DEFAULT_VIRTUAL_ENV=rapids +RUN apt update -y \ + && DEBIAN_FRONTEND=noninteractive apt install -y \ + libblas-dev liblapack-dev \ + && rm -rf /tmp/* /var/tmp/* /var/cache/apt/* /var/lib/apt/lists/*; + FROM ${BASE} as conda-base ENV DEFAULT_CONDA_ENV=rapids diff --git a/.devcontainer/cuda11.8-conda/devcontainer.json b/.devcontainer/cuda11.8-conda/devcontainer.json index 7c9cd0258a4..d878f2d6584 100644 --- a/.devcontainer/cuda11.8-conda/devcontainer.json +++ b/.devcontainer/cuda11.8-conda/devcontainer.json @@ -11,7 +11,7 @@ "runArgs": [ "--rm", "--name", - "${localEnv:USER}-rapids-${localWorkspaceFolderBasename}-24.06-cuda11.8-conda" + "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-24.06-cuda11.8-conda" ], "hostRequirements": {"gpu": "optional"}, "features": { diff --git a/.devcontainer/cuda11.8-pip/devcontainer.json b/.devcontainer/cuda11.8-pip/devcontainer.json index a4dc168505b..a0edcb27df8 100644 --- a/.devcontainer/cuda11.8-pip/devcontainer.json +++ b/.devcontainer/cuda11.8-pip/devcontainer.json @@ -5,19 +5,16 @@ "args": { "CUDA": "11.8", "PYTHON_PACKAGE_MANAGER": "pip", - "BASE": "rapidsai/devcontainers:24.06-cpp-cuda11.8-ubuntu22.04" + "BASE": "rapidsai/devcontainers:24.06-cpp-cuda11.8-ucx1.15.0-openmpi-ubuntu22.04" } }, "runArgs": [ "--rm", "--name", - "${localEnv:USER}-rapids-${localWorkspaceFolderBasename}-24.06-cuda11.8-pip" + "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-24.06-cuda11.8-pip" ], 
"hostRequirements": {"gpu": "optional"}, "features": { - "ghcr.io/rapidsai/devcontainers/features/ucx:24.6": { - "version": "1.15.0" - }, "ghcr.io/rapidsai/devcontainers/features/cuda:24.6": { "version": "11.8", "installcuBLAS": true, @@ -28,7 +25,6 @@ "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:24.6": {} }, "overrideFeatureInstallOrder": [ - "ghcr.io/rapidsai/devcontainers/features/ucx", "ghcr.io/rapidsai/devcontainers/features/cuda", "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils" ], diff --git a/.devcontainer/cuda12.2-conda/devcontainer.json b/.devcontainer/cuda12.2-conda/devcontainer.json index eae4967f3b2..8a095d9b934 100644 --- a/.devcontainer/cuda12.2-conda/devcontainer.json +++ b/.devcontainer/cuda12.2-conda/devcontainer.json @@ -11,7 +11,7 @@ "runArgs": [ "--rm", "--name", - "${localEnv:USER}-rapids-${localWorkspaceFolderBasename}-24.06-cuda12.2-conda" + "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-24.06-cuda12.2-conda" ], "hostRequirements": {"gpu": "optional"}, "features": { diff --git a/.devcontainer/cuda12.2-pip/devcontainer.json b/.devcontainer/cuda12.2-pip/devcontainer.json index 393a5c63d23..10436f8b28d 100644 --- a/.devcontainer/cuda12.2-pip/devcontainer.json +++ b/.devcontainer/cuda12.2-pip/devcontainer.json @@ -5,19 +5,16 @@ "args": { "CUDA": "12.2", "PYTHON_PACKAGE_MANAGER": "pip", - "BASE": "rapidsai/devcontainers:24.06-cpp-cuda12.2-ubuntu22.04" + "BASE": "rapidsai/devcontainers:24.06-cpp-cuda12.2-ucx1.15.0-openmpi-ubuntu22.04" } }, "runArgs": [ "--rm", "--name", - "${localEnv:USER}-rapids-${localWorkspaceFolderBasename}-24.06-cuda12.2-pip" + "${localEnv:USER:anon}-rapids-${localWorkspaceFolderBasename}-24.06-cuda12.2-pip" ], "hostRequirements": {"gpu": "optional"}, "features": { - "ghcr.io/rapidsai/devcontainers/features/ucx:24.6": { - "version": "1.15.0" - }, "ghcr.io/rapidsai/devcontainers/features/cuda:24.6": { "version": "12.2", "installcuBLAS": true, @@ -28,7 +25,6 @@ 
"ghcr.io/rapidsai/devcontainers/features/rapids-build-utils:24.6": {} }, "overrideFeatureInstallOrder": [ - "ghcr.io/rapidsai/devcontainers/features/ucx", "ghcr.io/rapidsai/devcontainers/features/cuda", "ghcr.io/rapidsai/devcontainers/features/rapids-build-utils" ], diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index c04e0e879d2..5733646a8b9 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -196,5 +196,5 @@ jobs: extra-repo-deploy-key: CUGRAPH_OPS_SSH_PRIVATE_DEPLOY_KEY build_command: | sccache -z; - build-all --verbose -j$(nproc --ignore=1); + build-all --verbose -j$(nproc --ignore=1) -DBUILD_CUGRAPH_MG_TESTS=ON; sccache -s; diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 2527599fece..7dca3d983a5 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -92,14 +92,14 @@ set(CUGRAPH_CXX_FLAGS "") set(CUGRAPH_CUDA_FLAGS "") if(CMAKE_COMPILER_IS_GNUCXX) - list(APPEND CUGRAPH_CXX_FLAGS -Werror -Wno-error=deprecated-declarations) + list(APPEND CUGRAPH_CXX_FLAGS -Werror -Wno-error=deprecated-declarations -Wno-deprecated-declarations -DRAFT_HIDE_DEPRECATION_WARNINGS) endif(CMAKE_COMPILER_IS_GNUCXX) message("-- Building for GPU_ARCHS = ${CMAKE_CUDA_ARCHITECTURES}") list(APPEND CUGRAPH_CUDA_FLAGS --expt-extended-lambda --expt-relaxed-constexpr) -list(APPEND CUGRAPH_CUDA_FLAGS -Werror=cross-execution-space-call -Wno-deprecated-declarations -Xptxas=--disable-warnings) +list(APPEND CUGRAPH_CUDA_FLAGS -Werror=cross-execution-space-call -Wno-deprecated-declarations -DRAFT_HIDE_DEPRECATION_WARNINGS -Xptxas=--disable-warnings) list(APPEND CUGRAPH_CUDA_FLAGS -Xcompiler=-Wall,-Wno-error=sign-compare,-Wno-error=unused-but-set-variable) list(APPEND CUGRAPH_CUDA_FLAGS -Xfatbin=-compress-all) diff --git a/cpp/include/cugraph/utilities/device_functors.cuh b/cpp/include/cugraph/utilities/device_functors.cuh index 3af8ed1dd19..20cf98f7e6d 100644 --- a/cpp/include/cugraph/utilities/device_functors.cuh +++ 
b/cpp/include/cugraph/utilities/device_functors.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -78,13 +78,14 @@ struct indirection_t { template struct indirection_if_idx_valid_t { + using value_type = typename thrust::iterator_traits::value_type; Iterator first{}; index_t invalid_idx{}; - typename thrust::iterator_traits::value_type invalid_value{}; + value_type invalid_value{}; - __device__ typename thrust::iterator_traits::value_type operator()(index_t i) const + __device__ value_type operator()(index_t i) const { - return (i != invalid_idx) ? *(first + i) : invalid_value; + return (i != invalid_idx) ? static_cast(*(first + i)) : invalid_value; } }; diff --git a/cpp/include/cugraph/utilities/mask_utils.cuh b/cpp/include/cugraph/utilities/mask_utils.cuh index 7b69ea3fe3a..1d86eef0ed1 100644 --- a/cpp/include/cugraph/utilities/mask_utils.cuh +++ b/cpp/include/cugraph/utilities/mask_utils.cuh @@ -20,6 +20,7 @@ #include +#include #include #include #include @@ -160,13 +161,13 @@ size_t count_set_bits(raft::handle_t const& handle, MaskIterator mask_first, siz handle.get_thrust_policy(), thrust::make_counting_iterator(size_t{0}), thrust::make_counting_iterator(packed_bool_size(num_bits)), - [mask_first, num_bits] __device__(size_t i) { + cuda::proclaim_return_type([mask_first, num_bits] __device__(size_t i) -> size_t { auto word = *(mask_first + i); if ((i + 1) * packed_bools_per_word() > num_bits) { word &= packed_bool_partial_mask(num_bits % packed_bools_per_word()); } return static_cast(__popc(word)); - }, + }), size_t{0}, thrust::plus{}); } diff --git a/cpp/src/community/detail/common_methods.cuh b/cpp/src/community/detail/common_methods.cuh index fe0a415db30..dcad4e92b95 100644 --- a/cpp/src/community/detail/common_methods.cuh +++ 
b/cpp/src/community/detail/common_methods.cuh @@ -29,6 +29,7 @@ #include #include +#include #include #include #include @@ -178,7 +179,7 @@ weight_t compute_modularity( handle.get_thrust_policy(), cluster_weights.begin(), cluster_weights.end(), - [] __device__(weight_t p) { return p * p; }, + cuda::proclaim_return_type([] __device__(weight_t p) -> weight_t { return p * p; }), weight_t{0}, thrust::plus()); diff --git a/cpp/src/community/legacy/louvain.cuh b/cpp/src/community/legacy/louvain.cuh index 6cf5bbdc3c6..53d0b231c03 100644 --- a/cpp/src/community/legacy/louvain.cuh +++ b/cpp/src/community/legacy/louvain.cuh @@ -22,6 +22,7 @@ #include #include + #ifdef TIMING #include #endif @@ -29,6 +30,7 @@ #include #include +#include #include #include #include @@ -141,12 +143,13 @@ class Louvain { handle_.get_thrust_policy(), thrust::make_counting_iterator(0), thrust::make_counting_iterator(graph.number_of_vertices), - [d_deg = deg.data(), d_inc = inc.data(), total_edge_weight, resolution] __device__( - vertex_t community) { - return ((d_inc[community] / total_edge_weight) - resolution * - (d_deg[community] * d_deg[community]) / - (total_edge_weight * total_edge_weight)); - }, + cuda::proclaim_return_type( + [d_deg = deg.data(), d_inc = inc.data(), total_edge_weight, resolution] __device__( + vertex_t community) -> weight_t { + return ((d_inc[community] / total_edge_weight) - + resolution * (d_deg[community] * d_deg[community]) / + (total_edge_weight * total_edge_weight)); + }), weight_t{0.0}, thrust::plus()); diff --git a/cpp/src/components/weakly_connected_components_impl.cuh b/cpp/src/components/weakly_connected_components_impl.cuh index d4d6d842951..f63f28210d8 100644 --- a/cpp/src/components/weakly_connected_components_impl.cuh +++ b/cpp/src/components/weakly_connected_components_impl.cuh @@ -34,6 +34,7 @@ #include +#include #include #include #include @@ -400,9 +401,10 @@ void weakly_connected_components_impl(raft::handle_t const& handle, handle.get_thrust_policy(), 
new_root_candidates.begin(), new_root_candidates.begin() + (new_root_candidates.size() > 0 ? 1 : 0), - [vertex_partition, degrees = degrees.data()] __device__(auto v) { - return degrees[vertex_partition.local_vertex_partition_offset_from_vertex_nocheck(v)]; - }, + cuda::proclaim_return_type( + [vertex_partition, degrees = degrees.data()] __device__(auto v) -> edge_t { + return degrees[vertex_partition.local_vertex_partition_offset_from_vertex_nocheck(v)]; + }), edge_t{0}, thrust::plus{}); @@ -642,9 +644,10 @@ void weakly_connected_components_impl(raft::handle_t const& handle, handle.get_thrust_policy(), thrust::get<0>(vertex_frontier.bucket(bucket_idx_cur).begin().get_iterator_tuple()), thrust::get<0>(vertex_frontier.bucket(bucket_idx_cur).end().get_iterator_tuple()), - [vertex_partition, degrees = degrees.data()] __device__(auto v) { - return degrees[vertex_partition.local_vertex_partition_offset_from_vertex_nocheck(v)]; - }, + cuda::proclaim_return_type( + [vertex_partition, degrees = degrees.data()] __device__(auto v) -> edge_t { + return degrees[vertex_partition.local_vertex_partition_offset_from_vertex_nocheck(v)]; + }), edge_t{0}, thrust::plus()); diff --git a/cpp/src/detail/utility_wrappers.cu b/cpp/src/detail/utility_wrappers.cu index 9100ecbd5e1..6d6158a16e7 100644 --- a/cpp/src/detail/utility_wrappers.cu +++ b/cpp/src/detail/utility_wrappers.cu @@ -21,6 +21,7 @@ #include +#include #include #include #include @@ -139,7 +140,8 @@ vertex_t compute_maximum_vertex_id(rmm::cuda_stream_view const& stream_view, rmm::exec_policy(stream_view), edge_first, edge_first + num_edges, - [] __device__(auto e) { return std::max(thrust::get<0>(e), thrust::get<1>(e)); }, + cuda::proclaim_return_type( + [] __device__(auto e) -> vertex_t { return std::max(thrust::get<0>(e), thrust::get<1>(e)); }), vertex_t{0}, thrust::maximum()); } diff --git a/cpp/src/prims/kv_store.cuh b/cpp/src/prims/kv_store.cuh index 5001a20bb83..de233fd583b 100644 --- a/cpp/src/prims/kv_store.cuh +++ 
b/cpp/src/prims/kv_store.cuh @@ -17,6 +17,7 @@ #include "prims/detail/optional_dataframe_buffer.hpp" +#include #include #include diff --git a/cpp/src/prims/per_v_transform_reduce_dst_key_aggregated_outgoing_e.cuh b/cpp/src/prims/per_v_transform_reduce_dst_key_aggregated_outgoing_e.cuh index 006d7760666..7be30b0a5f0 100644 --- a/cpp/src/prims/per_v_transform_reduce_dst_key_aggregated_outgoing_e.cuh +++ b/cpp/src/prims/per_v_transform_reduce_dst_key_aggregated_outgoing_e.cuh @@ -754,7 +754,7 @@ void per_v_transform_reduce_dst_key_aggregated_outgoing_e( std::make_unique>( std::move(majors), std::move(edge_major_values), - invalid_vertex_id::value, + edge_src_value_t{}, true, handle.get_stream()); } diff --git a/cpp/src/structure/graph_view_impl.cuh b/cpp/src/structure/graph_view_impl.cuh index 29dca6ef409..7097349dce5 100644 --- a/cpp/src/structure/graph_view_impl.cuh +++ b/cpp/src/structure/graph_view_impl.cuh @@ -353,7 +353,7 @@ edge_t count_edge_partition_multi_edges( execution_policy, thrust::make_counting_iterator(edge_partition.major_range_first()) + (*segment_offsets)[2], thrust::make_counting_iterator(edge_partition.major_range_first()) + (*segment_offsets)[3], - [edge_partition] __device__(auto major) { + cuda::proclaim_return_type([edge_partition] __device__(auto major) -> edge_t { auto major_offset = edge_partition.major_offset_from_major_nocheck(major); vertex_t const* indices{nullptr}; [[maybe_unused]] edge_t edge_offset{}; @@ -365,7 +365,7 @@ edge_t count_edge_partition_multi_edges( if (indices[i - 1] == indices[i]) { ++count; } } return count; - }, + }), edge_t{0}, thrust::plus{}); } @@ -374,19 +374,21 @@ edge_t count_edge_partition_multi_edges( execution_policy, thrust::make_counting_iterator(vertex_t{0}), thrust::make_counting_iterator(*(edge_partition.dcs_nzd_vertex_count())), - [edge_partition, major_start_offset = (*segment_offsets)[3]] __device__(auto idx) { - auto major_idx = - major_start_offset + idx; // major_offset != major_idx in the 
hypersparse region - vertex_t const* indices{nullptr}; - [[maybe_unused]] edge_t edge_offset{}; - edge_t local_degree{}; - thrust::tie(indices, edge_offset, local_degree) = edge_partition.local_edges(major_idx); - edge_t count{0}; - for (edge_t i = 1; i < local_degree; ++i) { // assumes neighbors are sorted - if (indices[i - 1] == indices[i]) { ++count; } - } - return count; - }, + cuda::proclaim_return_type( + [edge_partition, + major_start_offset = (*segment_offsets)[3]] __device__(auto idx) -> edge_t { + auto major_idx = + major_start_offset + idx; // major_offset != major_idx in the hypersparse region + vertex_t const* indices{nullptr}; + [[maybe_unused]] edge_t edge_offset{}; + edge_t local_degree{}; + thrust::tie(indices, edge_offset, local_degree) = edge_partition.local_edges(major_idx); + edge_t count{0}; + for (edge_t i = 1; i < local_degree; ++i) { // assumes neighbors are sorted + if (indices[i - 1] == indices[i]) { ++count; } + } + return count; + }), edge_t{0}, thrust::plus{}); } @@ -398,7 +400,7 @@ edge_t count_edge_partition_multi_edges( thrust::make_counting_iterator(edge_partition.major_range_first()), thrust::make_counting_iterator(edge_partition.major_range_first()) + edge_partition.major_range_size(), - [edge_partition] __device__(auto major) { + cuda::proclaim_return_type([edge_partition] __device__(auto major) -> edge_t { auto major_offset = edge_partition.major_offset_from_major_nocheck(major); vertex_t const* indices{nullptr}; [[maybe_unused]] edge_t edge_offset{}; @@ -409,7 +411,7 @@ edge_t count_edge_partition_multi_edges( if (indices[i - 1] == indices[i]) { ++count; } } return count; - }, + }), edge_t{0}, thrust::plus{}); } diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index d1dd2dec069..2152de28ff9 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -169,7 +169,11 @@ function(ConfigureTest CMAKE_TEST_NAME) ) set_target_properties( ${CMAKE_TEST_NAME} - PROPERTIES INSTALL_RPATH 
"\$ORIGIN/../../../lib") + PROPERTIES INSTALL_RPATH "\$ORIGIN/../../../lib" + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON) rapids_test_add( NAME ${CMAKE_TEST_NAME} @@ -195,7 +199,11 @@ function(ConfigureTestMG CMAKE_TEST_NAME) ) set_target_properties( ${CMAKE_TEST_NAME} - PROPERTIES INSTALL_RPATH "\$ORIGIN/../../../lib") + PROPERTIES INSTALL_RPATH "\$ORIGIN/../../../lib" + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON) rapids_test_add( NAME ${CMAKE_TEST_NAME} @@ -241,7 +249,11 @@ function(ConfigureCTest CMAKE_TEST_NAME) ) set_target_properties( ${CMAKE_TEST_NAME} - PROPERTIES INSTALL_RPATH "\$ORIGIN/../../../lib") + PROPERTIES INSTALL_RPATH "\$ORIGIN/../../../lib" + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON) rapids_test_add( NAME ${CMAKE_TEST_NAME} @@ -269,7 +281,11 @@ function(ConfigureCTestMG CMAKE_TEST_NAME) ) set_target_properties( ${CMAKE_TEST_NAME} - PROPERTIES INSTALL_RPATH "\$ORIGIN/../../../lib") + PROPERTIES INSTALL_RPATH "\$ORIGIN/../../../lib" + CXX_STANDARD 17 + CXX_STANDARD_REQUIRED ON + CUDA_STANDARD 17 + CUDA_STANDARD_REQUIRED ON) rapids_test_add( NAME ${CMAKE_TEST_NAME} diff --git a/cpp/tests/prims/mg_extract_transform_e.cu b/cpp/tests/prims/mg_extract_transform_e.cu index 20e87070fa5..d7aa953ef7c 100644 --- a/cpp/tests/prims/mg_extract_transform_e.cu +++ b/cpp/tests/prims/mg_extract_transform_e.cu @@ -59,55 +59,27 @@ #include #include -template +template struct e_op_t { - static_assert(std::is_same_v || - std::is_same_v>); static_assert(std::is_same_v || std::is_same_v>); - using return_type = thrust::optional, - std::conditional_t, - thrust::tuple, - thrust::tuple>, - std::conditional_t, - thrust::tuple, - thrust::tuple>>>; - - __device__ return_type operator()(key_t optionally_tagged_src, - vertex_t dst, - property_t src_val, - property_t dst_val, - thrust::nullopt_t) const + using return_type = + 
thrust::optional, + thrust::tuple, + thrust::tuple>>; + + __device__ return_type operator()( + vertex_t src, vertex_t dst, property_t src_val, property_t dst_val, thrust::nullopt_t) const { auto output_payload = static_cast(1); if (src_val < dst_val) { - if constexpr (std::is_same_v) { - if constexpr (std::is_arithmetic_v) { - return thrust::make_tuple(optionally_tagged_src, dst, output_payload); - } else { - static_assert(thrust::tuple_size::value == size_t{2}); - return thrust::make_tuple(optionally_tagged_src, - dst, - thrust::get<0>(output_payload), - thrust::get<1>(output_payload)); - } + if constexpr (std::is_arithmetic_v) { + return thrust::make_tuple(src, dst, output_payload); } else { - static_assert(thrust::tuple_size::value == size_t{2}); - if constexpr (std::is_arithmetic_v) { - return thrust::make_tuple(thrust::get<0>(optionally_tagged_src), - thrust::get<1>(optionally_tagged_src), - dst, - output_payload); - } else { - static_assert(thrust::tuple_size::value == size_t{2}); - return thrust::make_tuple(thrust::get<0>(optionally_tagged_src), - thrust::get<1>(optionally_tagged_src), - dst, - thrust::get<0>(output_payload), - thrust::get<1>(output_payload)); - } + static_assert(thrust::tuple_size::value == size_t{2}); + return thrust::make_tuple( + src, dst, thrust::get<0>(output_payload), thrust::get<1>(output_payload)); } } else { return thrust::nullopt; @@ -134,19 +106,11 @@ class Tests_MGExtractTransformE virtual void TearDown() {} // Compare the results of extract_transform_e primitive - template + template void run_current_test(Prims_Usecase const& prims_usecase, input_usecase_t const& input_usecase) { using result_t = int32_t; - using key_t = - std::conditional_t, vertex_t, thrust::tuple>; - - static_assert(std::is_same_v || std::is_arithmetic_v); static_assert(std::is_same_v || cugraph::is_arithmetic_or_thrust_tuple_of_arithmetic::value); if constexpr (cugraph::is_thrust_tuple::value) { @@ -212,7 +176,7 @@ class Tests_MGExtractTransformE 
mg_src_prop.view(), mg_dst_prop.view(), cugraph::edge_dummy_property_t{}.view(), - e_op_t{}); + e_op_t{}); if (cugraph::test::g_perf) { RAFT_CUDA_TRY(cudaDeviceSynchronize()); // for consistent performance measurement @@ -225,7 +189,7 @@ class Tests_MGExtractTransformE if (prims_usecase.check_correctness) { auto mg_aggregate_extract_transform_output_buffer = cugraph::allocate_dataframe_buffer< - typename e_op_t::return_type::value_type>( + typename e_op_t::return_type::value_type>( size_t{0}, handle_->get_stream()); std::get<0>(mg_aggregate_extract_transform_output_buffer) = cugraph::test::device_gatherv(*handle_, @@ -239,18 +203,12 @@ class Tests_MGExtractTransformE cugraph::test::device_gatherv(*handle_, std::get<2>(mg_extract_transform_output_buffer).data(), std::get<2>(mg_extract_transform_output_buffer).size()); - if constexpr (!std::is_same_v || !std::is_arithmetic_v) { + if constexpr (!std::is_arithmetic_v) { std::get<3>(mg_aggregate_extract_transform_output_buffer) = cugraph::test::device_gatherv(*handle_, std::get<3>(mg_extract_transform_output_buffer).data(), std::get<3>(mg_extract_transform_output_buffer).size()); } - if constexpr (!std::is_same_v && !std::is_arithmetic_v) { - std::get<4>(mg_aggregate_extract_transform_output_buffer) = - cugraph::test::device_gatherv(*handle_, - std::get<4>(mg_extract_transform_output_buffer).data(), - std::get<4>(mg_extract_transform_output_buffer).size()); - } cugraph::graph_t sg_graph(*handle_); std::tie(sg_graph, std::ignore, std::ignore, std::ignore) = @@ -292,7 +250,7 @@ class Tests_MGExtractTransformE sg_src_prop.view(), sg_dst_prop.view(), cugraph::edge_dummy_property_t{}.view(), - e_op_t{}); + e_op_t{}); thrust::sort(handle_->get_thrust_policy(), cugraph::get_dataframe_buffer_begin(sg_extract_transform_output_buffer), @@ -321,13 +279,13 @@ using Tests_MGExtractTransformE_Rmat = Tests_MGExtractTransformE(std::get<0>(param), std::get<1>(param)); + run_current_test(std::get<0>(param), std::get<1>(param)); } 
TEST_P(Tests_MGExtractTransformE_Rmat, CheckInt32Int32FloatVoidInt32) { auto param = GetParam(); - run_current_test( + run_current_test( std::get<0>(param), cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } @@ -335,14 +293,14 @@ TEST_P(Tests_MGExtractTransformE_Rmat, CheckInt32Int32FloatVoidInt32) TEST_P(Tests_MGExtractTransformE_File, CheckInt32Int32FloatVoidTupleFloatInt32) { auto param = GetParam(); - run_current_test>( - std::get<0>(param), std::get<1>(param)); + run_current_test>(std::get<0>(param), + std::get<1>(param)); } TEST_P(Tests_MGExtractTransformE_Rmat, CheckInt32Int32FloatVoidTupleFloatInt32) { auto param = GetParam(); - run_current_test>( + run_current_test>( std::get<0>(param), cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } @@ -350,14 +308,13 @@ TEST_P(Tests_MGExtractTransformE_Rmat, CheckInt32Int32FloatVoidTupleFloatInt32) TEST_P(Tests_MGExtractTransformE_File, CheckInt32Int32FloatInt32Int32) { auto param = GetParam(); - run_current_test(std::get<0>(param), - std::get<1>(param)); + run_current_test(std::get<0>(param), std::get<1>(param)); } TEST_P(Tests_MGExtractTransformE_Rmat, CheckInt32Int32FloatInt32Int32) { auto param = GetParam(); - run_current_test( + run_current_test( std::get<0>(param), cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } @@ -365,14 +322,14 @@ TEST_P(Tests_MGExtractTransformE_Rmat, CheckInt32Int32FloatInt32Int32) TEST_P(Tests_MGExtractTransformE_File, CheckInt32Int32FloatInt32TupleFloatInt32) { auto param = GetParam(); - run_current_test>( - std::get<0>(param), std::get<1>(param)); + run_current_test>(std::get<0>(param), + std::get<1>(param)); } TEST_P(Tests_MGExtractTransformE_Rmat, CheckInt32Int32FloatInt32TupleFloatInt32) { auto param = GetParam(); - run_current_test>( + run_current_test>( std::get<0>(param), cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } @@ -380,14 +337,13 @@ 
TEST_P(Tests_MGExtractTransformE_Rmat, CheckInt32Int32FloatInt32TupleFloatInt32) TEST_P(Tests_MGExtractTransformE_File, CheckInt32Int64FloatInt32Int32) { auto param = GetParam(); - run_current_test(std::get<0>(param), - std::get<1>(param)); + run_current_test(std::get<0>(param), std::get<1>(param)); } TEST_P(Tests_MGExtractTransformE_Rmat, CheckInt32Int64FloatInt32Int32) { auto param = GetParam(); - run_current_test( + run_current_test( std::get<0>(param), cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } @@ -395,14 +351,13 @@ TEST_P(Tests_MGExtractTransformE_Rmat, CheckInt32Int64FloatInt32Int32) TEST_P(Tests_MGExtractTransformE_File, CheckInt64Int64FloatInt32Int32) { auto param = GetParam(); - run_current_test(std::get<0>(param), - std::get<1>(param)); + run_current_test(std::get<0>(param), std::get<1>(param)); } TEST_P(Tests_MGExtractTransformE_Rmat, CheckInt64Int64FloatInt32Int32) { auto param = GetParam(); - run_current_test( + run_current_test( std::get<0>(param), cugraph::test::override_Rmat_Usecase_with_cmd_line_arguments(std::get<1>(param))); } diff --git a/cpp/tests/sampling/sampling_post_processing_test.cu b/cpp/tests/sampling/sampling_post_processing_test.cu index c87cc5b960b..3bca382a2eb 100644 --- a/cpp/tests/sampling/sampling_post_processing_test.cu +++ b/cpp/tests/sampling/sampling_post_processing_test.cu @@ -398,15 +398,16 @@ bool check_renumber_map_invariants( handle.get_thrust_policy(), unique_majors.begin(), unique_majors.end(), - [sorted_org_vertices = - raft::device_span(sorted_org_vertices.data(), sorted_org_vertices.size()), - matching_renumbered_vertices = raft::device_span( - matching_renumbered_vertices.data(), - matching_renumbered_vertices.size())] __device__(vertex_t major) { - auto it = thrust::lower_bound( - thrust::seq, sorted_org_vertices.begin(), sorted_org_vertices.end(), major); - return matching_renumbered_vertices[thrust::distance(sorted_org_vertices.begin(), it)]; - }, + 
cuda::proclaim_return_type( + [sorted_org_vertices = raft::device_span(sorted_org_vertices.data(), + sorted_org_vertices.size()), + matching_renumbered_vertices = raft::device_span( + matching_renumbered_vertices.data(), + matching_renumbered_vertices.size())] __device__(vertex_t major) -> vertex_t { + auto it = thrust::lower_bound( + thrust::seq, sorted_org_vertices.begin(), sorted_org_vertices.end(), major); + return matching_renumbered_vertices[thrust::distance(sorted_org_vertices.begin(), it)]; + }), std::numeric_limits::lowest(), thrust::maximum{}); @@ -414,15 +415,16 @@ bool check_renumber_map_invariants( handle.get_thrust_policy(), unique_minors.begin(), unique_minors.end(), - [sorted_org_vertices = - raft::device_span(sorted_org_vertices.data(), sorted_org_vertices.size()), - matching_renumbered_vertices = raft::device_span( - matching_renumbered_vertices.data(), - matching_renumbered_vertices.size())] __device__(vertex_t minor) { - auto it = thrust::lower_bound( - thrust::seq, sorted_org_vertices.begin(), sorted_org_vertices.end(), minor); - return matching_renumbered_vertices[thrust::distance(sorted_org_vertices.begin(), it)]; - }, + cuda::proclaim_return_type( + [sorted_org_vertices = raft::device_span(sorted_org_vertices.data(), + sorted_org_vertices.size()), + matching_renumbered_vertices = raft::device_span( + matching_renumbered_vertices.data(), + matching_renumbered_vertices.size())] __device__(vertex_t minor) -> vertex_t { + auto it = thrust::lower_bound( + thrust::seq, sorted_org_vertices.begin(), sorted_org_vertices.end(), minor); + return matching_renumbered_vertices[thrust::distance(sorted_org_vertices.begin(), it)]; + }), std::numeric_limits::max(), thrust::minimum{}); From 04e80008180656da050e37a3a4b04c47ab015de9 Mon Sep 17 00:00:00 2001 From: Don Acosta <97529984+acostadon@users.noreply.github.com> Date: Wed, 29 May 2024 15:16:07 -0400 Subject: [PATCH 7/7] Fixed links and added c++ docs per issue 4431 (#4435) added content to document c++ 
algorithms and fixed links that were pointing to the previously removed content resolves #4431 Resolves #4116 Authors: - Don Acosta (https://github.com/acostadon) Approvers: - Rick Ratzel (https://github.com/rlratzel) - Brad Rees (https://github.com/BradReesWork) - Chuck Hastings (https://github.com/ChuckHastings) URL: https://github.com/rapidsai/cugraph/pull/4435 --- .../source/graph_support/algorithms.md | 10 +-- .../cpp_algorithms/centrality_cpp.md | 81 +++++++++++++++++++ .../algorithms/cpp_algorithms/linear_cpp.md | 37 +++++++++ .../cpp_algorithms/traversal_cpp.md | 56 +++++++++++++ 4 files changed, 179 insertions(+), 5 deletions(-) create mode 100644 docs/cugraph/source/graph_support/algorithms/cpp_algorithms/centrality_cpp.md create mode 100644 docs/cugraph/source/graph_support/algorithms/cpp_algorithms/linear_cpp.md create mode 100644 docs/cugraph/source/graph_support/algorithms/cpp_algorithms/traversal_cpp.md diff --git a/docs/cugraph/source/graph_support/algorithms.md b/docs/cugraph/source/graph_support/algorithms.md index 8a5158f2f56..2aac61325e0 100644 --- a/docs/cugraph/source/graph_support/algorithms.md +++ b/docs/cugraph/source/graph_support/algorithms.md @@ -50,10 +50,10 @@ Note: Multi-GPU, or MG, includes support for Multi-Node Multi-GPU (also called M | Layout | | | | | | [Force Atlas 2](https://github.com/rapidsai/cugraph/blob/main/notebooks/algorithms/layout/Force-Atlas2.ipynb) | Single-GPU | | | Linear Assignment | | | | -| | [Hungarian]() | Single-GPU | [README](cpp/src/linear_assignment/README-hungarian.md) | +| | [Hungarian](https://docs.rapids.ai/api/cugraph/nightly/api_docs/cugraph/linear_assignment/#hungarian) | Single-GPU | [README](./algorithms/cpp_algorithms/linear_cpp.html) | | Link Analysis | | | | -| | [Pagerank](https://github.com/rapidsai/cugraph/blob/main/notebooks/algorithms/link_analysis/Pagerank.ipynb) | __Multi-GPU__ | [C++ README](cpp/src/centrality/README.md#Pagerank) | -| | [Personal Pagerank]() | __Multi-GPU__ | [C++ 
README](cpp/src/centrality/README.md#Personalized-Pagerank) | +| | [Pagerank](https://github.com/rapidsai/cugraph/blob/main/notebooks/algorithms/link_analysis/Pagerank.ipynb) | __Multi-GPU__ | [C++ README](./algorithms/cpp_algorithms/centrality_cpp.html#Pagerank) | +| | [Personal Pagerank](https://github.com/rapidsai/cugraph/blob/main/notebooks/algorithms/link_analysis/Pagerank.ipynb) | __Multi-GPU__ | [C++ README](./algorithms/cpp_algorithms/centrality_cpp.html#Personalized-Pagerank) | | | [HITS](https://github.com/rapidsai/cugraph/blob/main/notebooks/algorithms/link_analysis/HITS.ipynb) | __Multi-GPU__ | | | [Link Prediction](algorithms/Similarity.html) | | | | | | [Jaccard Similarity](https://github.com/rapidsai/cugraph/blob/main/notebooks/algorithms/link_prediction/Jaccard-Similarity.ipynb) | __Multi-GPU__ | Directed graph only | @@ -68,8 +68,8 @@ Note: Multi-GPU, or MG, includes support for Multi-Node Multi-GPU (also called M | | Node2Vec | __Multi-GPU__ | | | | Neighborhood sampling | __Multi-GPU__ | | | Traversal | | | | -| | Breadth First Search (BFS) | __Multi-GPU__ | with cutoff support [C++ README](cpp/src/traversal/README.md#BFS) | -| | Single Source Shortest Path (SSSP) | __Multi-GPU__ | [C++ README](cpp/src/traversal/README.md#SSSP) | +| | Breadth First Search (BFS) | __Multi-GPU__ | [C++ README](algorithms/cpp_algorithms/traversal_cpp.html#BFS) | +| | Single Source Shortest Path (SSSP) | __Multi-GPU__ | [C++ README](algorithms/cpp_algorithms/traversal_cpp.html#SSSP) | | | _ASSP / APSP_ | --- | | | Tree | | | | | | Minimum Spanning Tree | Single-GPU | | diff --git a/docs/cugraph/source/graph_support/algorithms/cpp_algorithms/centrality_cpp.md b/docs/cugraph/source/graph_support/algorithms/cpp_algorithms/centrality_cpp.md new file mode 100644 index 00000000000..b3f7ac17d1a --- /dev/null +++ b/docs/cugraph/source/graph_support/algorithms/cpp_algorithms/centrality_cpp.md @@ -0,0 +1,81 @@ +# Centrality algorithms +cuGraph Pagerank is implemented using our 
graph primitive library + +## Pagerank + +The unit test code is the best place to search for examples on calling pagerank. + + * [SG Implementation](https://github.com/rapidsai/cugraph/blob/main/cpp/tests/link_analysis/pagerank_test.cpp) + * [MG Implementation](https://github.com/rapidsai/cugraph/blob/main/cpp/tests/link_analysis/mg_pagerank_test.cpp) + +## Simple pagerank + +The example assumes that you create an SG or MG graph somehow. The caller must create the pageranks vector in device memory and pass the raw pointer to that vector into the pagerank function. + +```cpp +#include +... +using vertex_t = int32_t; // or int64_t, whichever is appropriate +using weight_t = float; // or double, whichever is appropriate +using result_t = weight_t; // could specify float or double also +raft::handle_t handle; // Must be configured if MG +auto graph_view = graph.view(); // assumes you have created a graph somehow + +result_t constexpr alpha{0.85}; +result_t constexpr epsilon{1e-6}; + +rmm::device_uvector pageranks_v(graph_view.number_of_vertices(), handle.get_stream()); + +// pagerank optionally supports three additional parameters: +// max_iterations - maximum number of iterations, if pagerank doesn't converge by +// then we abort +// has_initial_guess - if true, values in the pagerank array when the call is initiated +// will be used as the initial pagerank values. These values will +// be normalized before use. If false (the default), the values +// in the pagerank array will be set to 1/num_vertices before +// starting the computation. +// do_expensive_check - perform extensive validation of the input data before +// executing algorithm. Off by default. Note: turning this on +// is expensive +cugraph::pagerank(handle, graph_view, nullptr, nullptr, nullptr, vertex_t{0}, + pageranks_v.data(), alpha, epsilon); +``` + +## Personalized Pagerank + +The example assumes that you create an SG or MG graph somehow. 
The caller must create the pageranks vector in device memory and pass the raw pointer to that vector into the pagerank function. Additionally, the caller must create personalization_vertices and personalization_values vectors in device memory, populate them and pass in the raw pointers to those vectors. + +```cpp +#include +... +using vertex_t = int32_t; // or int64_t, whichever is appropriate +using weight_t = float; // or double, whichever is appropriate +using result_t = weight_t; // could specify float or double also +raft::handle_t handle; // Must be configured if MG +auto graph_view = graph.view(); // assumes you have created a graph somehow +vertex_t number_of_personalization_vertices; // Provided by caller + +result_t constexpr alpha{0.85}; +result_t constexpr epsilon{1e-6}; + +rmm::device_uvector pageranks_v(graph_view.number_of_vertices(), handle.get_stream()); +rmm::device_uvector personalization_vertices(number_of_personalization_vertices, handle.get_stream()); +rmm::device_uvector personalization_values(number_of_personalization_vertices, handle.get_stream()); + +// Populate personalization_vertices, personalization_values with user provided data + +// pagerank optionally supports three additional parameters: +// max_iterations - maximum number of iterations, if pagerank doesn't converge by +// then we abort +// has_initial_guess - if true, values in the pagerank array when the call is initiated +// will be used as the initial pagerank values. These values will +// be normalized before use. If false (the default), the values +// in the pagerank array will be set to 1/num_vertices before +// starting the computation. +// do_expensive_check - perform extensive validation of the input data before +// executing algorithm. Off by default. 
Note: turning this on +// is expensive +cugraph::pagerank(handle, graph_view, nullptr, personalization_vertices.data(), + personalization_values.data(), number_of_personalization_vertices, + pageranks_v.data(), alpha, epsilon); +``` diff --git a/docs/cugraph/source/graph_support/algorithms/cpp_algorithms/linear_cpp.md b/docs/cugraph/source/graph_support/algorithms/cpp_algorithms/linear_cpp.md new file mode 100644 index 00000000000..8af4a5042f6 --- /dev/null +++ b/docs/cugraph/source/graph_support/algorithms/cpp_algorithms/linear_cpp.md @@ -0,0 +1,37 @@ +# LAP + +Implementation of ***O(n^3) Alternating Tree Variant*** of Hungarian Algorithm on NVIDIA CUDA-enabled GPU. + +This implementation solves a batch of ***k*** **Linear Assignment Problems (LAP)**, each with an ***nxn*** matrix of single floating point cost values. At optimality, the algorithm produces an assignment with ***minimum*** cost. + +The API can be used to query optimal primal and dual costs, optimal assignment vector, and optimal row/column dual vectors for each subproblem in the batch. + +cuGraph exposes the Hungarian algorithm; the actual implementation is contained in the RAFT library, which contains some common tools and kernels shared between cuGraph and cuML. + +The following parameters can be used to tune the performance of the algorithm: + +1. epsilon: (in raft/lap/lap_kernels.cuh) This parameter controls the tolerance on the floating point precision. Setting this too small will result in increased solution time because the algorithm will search for precise solutions. Setting it too high may cause some inaccuracies. + +2. BLOCKDIMX, BLOCKDIMY: (in raft/lap/lap_functions.cuh) These parameters control threads_per_block to be used along the given dimension. Set these according to the device specifications and occupancy calculation. + +***This library is licensed under Apache License 2.0. Please cite our paper, if this library helps you in your research.*** + +- Harvard citation style + + Date, K. 
and Nagi, R., 2016. GPU-accelerated Hungarian algorithms for the Linear Assignment Problem. Parallel Computing, 57, pp.52-72. + +- BibTeX Citation block to be used in LaTeX bibliography file: + +``` +@article{date2016gpu, + title={GPU-accelerated Hungarian algorithms for the Linear Assignment Problem}, + author={Date, Ketan and Nagi, Rakesh}, + journal={Parallel Computing}, + volume={57}, + pages={52--72}, + year={2016}, + publisher={Elsevier} +} +``` + +The paper is available online on [ScienceDirect](https://www.sciencedirect.com/science/article/abs/pii/S016781911630045X). diff --git a/docs/cugraph/source/graph_support/algorithms/cpp_algorithms/traversal_cpp.md b/docs/cugraph/source/graph_support/algorithms/cpp_algorithms/traversal_cpp.md new file mode 100644 index 00000000000..6480d885a38 --- /dev/null +++ b/docs/cugraph/source/graph_support/algorithms/cpp_algorithms/traversal_cpp.md @@ -0,0 +1,56 @@ +# Traversal +cuGraph traversal algorithms are contained in this directory + +## SSSP + +The unit test code is the best place to search for examples on calling SSSP. + + * [SG Implementation](https://github.com/rapidsai/cugraph/blob/main/cpp/tests/traversal/sssp_test.cpp) + * [MG Implementation](https://github.com/rapidsai/cugraph/blob/main/cpp/tests/traversal/mg_sssp_test.cpp) + +## Simple SSSP + +The example assumes that you create an SG or MG graph somehow. The caller must create the distances and predecessors vectors in device memory and pass in the raw pointers to those vectors into the SSSP function. + +```cpp +#include +... 
+using vertex_t = int32_t; // or int64_t, whichever is appropriate +using weight_t = float; // or double, whichever is appropriate +using result_t = weight_t; // could specify float or double also +raft::handle_t handle; // Must be configured if MG +auto graph_view = graph.view(); // assumes you have created a graph somehow +vertex_t source; // Initialized by user + +rmm::device_uvector distances_v(graph_view.number_of_vertices(), handle.get_stream()); +rmm::device_uvector predecessors_v(graph_view.number_of_vertices(), handle.get_stream()); + +cugraph::sssp(handle, graph_view, distances_v.begin(), predecessors_v.begin(), source, std::numeric_limits::max(), false); +``` + +## BFS + +The unit test code is the best place to search for examples on calling BFS. + + * [SG Implementation](https://github.com/rapidsai/cugraph/blob/main/cpp/tests/traversal/bfs_test.cpp) + * [MG Implementation](https://github.com/rapidsai/cugraph/blob/main/cpp/tests/traversal/mg_bfs_test.cpp) + +## Simple BFS + +The example assumes that you create an SG or MG graph somehow. The caller must create the distances and predecessors vectors in device memory and pass the raw pointers to those vectors into the BFS function. + +```cpp +#include +... +using vertex_t = int32_t; // or int64_t, whichever is appropriate +using weight_t = float; // or double, whichever is appropriate +using result_t = weight_t; // could specify float or double also +raft::handle_t handle; // Must be configured if MG +auto graph_view = graph.view(); // assumes you have created a graph somehow +vertex_t source; // Initialized by user + +rmm::device_uvector distances_v(graph_view.number_of_vertices(), handle.get_stream()); +rmm::device_uvector predecessors_v(graph_view.number_of_vertices(), handle.get_stream()); + +cugraph::bfs(handle, graph_view, distances_v.begin(), predecessors_v.begin(), source, false, std::numeric_limits::max(), false); +```