diff --git a/cpp/include/cugraph/algorithms.hpp b/cpp/include/cugraph/algorithms.hpp index faeb7ad8f83..ed42460ed8e 100644 --- a/cpp/include/cugraph/algorithms.hpp +++ b/cpp/include/cugraph/algorithms.hpp @@ -1873,12 +1873,16 @@ void triangle_count(raft::handle_t const& handle, * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and * handles to various CUDA libraries) to run graph algorithms. * @param graph_view Graph view object. + * * @param do_expensive_check A flag to run expensive checks for input arguments (if set to + * `true`). * * @return edge_property_t containing the edge triangle count */ template edge_property_t, edge_t> edge_triangle_count( - raft::handle_t const& handle, graph_view_t const& graph_view); + raft::handle_t const& handle, + graph_view_t const& graph_view, + bool do_expensive_check = false); /* * @brief Compute K-Truss. diff --git a/cpp/src/community/edge_triangle_count_impl.cuh b/cpp/src/community/edge_triangle_count_impl.cuh index 225687c4cf0..e3501065008 100644 --- a/cpp/src/community/edge_triangle_count_impl.cuh +++ b/cpp/src/community/edge_triangle_count_impl.cuh @@ -18,8 +18,8 @@ #include "detail/graph_partition_utils.cuh" #include "prims/edge_bucket.cuh" +#include "prims/per_v_pair_dst_nbr_intersection.cuh" #include "prims/transform_e.cuh" -#include "prims/transform_reduce_dst_nbr_intersection_of_e_endpoints_by_v.cuh" #include #include @@ -124,7 +124,8 @@ struct extract_q_r { template edge_property_t, edge_t> edge_triangle_count_impl( raft::handle_t const& handle, - graph_view_t const& graph_view) + graph_view_t const& graph_view, + bool do_expensive_check) { using weight_t = float; rmm::device_uvector edgelist_srcs(0, handle.get_stream()); @@ -158,14 +159,11 @@ edge_property_t, edge_t> edge_t num_remaining_edges -= chunk_size; // Perform 'nbr_intersection' in chunks to reduce peak memory. auto [intersection_offsets, intersection_indices] = - detail::nbr_intersection(handle, - graph_view, - cugraph::edge_dummy_property_t{}.view(), - edge_first + prev_chunk_size, - edge_first + prev_chunk_size + chunk_size, - std::array{true, true}, - false /*FIXME: pass 'do_expensive_check' as argument*/); - + per_v_pair_dst_nbr_intersection(handle, + graph_view, + edge_first + prev_chunk_size, + edge_first + prev_chunk_size + chunk_size, + do_expensive_check); // Update the number of triangles of each (p, q) edges by looking at their intersection // size thrust::for_each( @@ -365,9 +363,11 @@ edge_property_t, edge_t> edge_t template edge_property_t, edge_t> edge_triangle_count( - raft::handle_t const& handle, graph_view_t const& graph_view) + raft::handle_t const& handle, + graph_view_t const& graph_view, + bool do_expensive_check) { - return detail::edge_triangle_count_impl(handle, graph_view); + return detail::edge_triangle_count_impl(handle, graph_view, do_expensive_check); } } // namespace cugraph diff --git a/cpp/src/community/edge_triangle_count_mg_v32_e32.cu b/cpp/src/community/edge_triangle_count_mg_v32_e32.cu index 1212a13323b..5e333139ddf 100644 --- a/cpp/src/community/edge_triangle_count_mg_v32_e32.cu +++ b/cpp/src/community/edge_triangle_count_mg_v32_e32.cu @@ -20,6 +20,7 @@ namespace cugraph { // SG instantiation template edge_property_t, int32_t> edge_triangle_count( raft::handle_t const& handle, - cugraph::graph_view_t const& graph_view); + cugraph::graph_view_t const& graph_view, + bool do_expensive_check); } // namespace cugraph diff --git a/cpp/src/community/edge_triangle_count_mg_v32_e64.cu b/cpp/src/community/edge_triangle_count_mg_v32_e64.cu index 64ee195c7ee..adab2d1fede 100644 --- a/cpp/src/community/edge_triangle_count_mg_v32_e64.cu +++ b/cpp/src/community/edge_triangle_count_mg_v32_e64.cu @@ -20,6 +20,7 @@ namespace cugraph { // SG instantiation template edge_property_t, int64_t> edge_triangle_count( raft::handle_t const& handle, - cugraph::graph_view_t const& graph_view); + cugraph::graph_view_t const& graph_view, + bool do_expensive_check); } // namespace cugraph diff --git a/cpp/src/community/edge_triangle_count_mg_v64_e64.cu b/cpp/src/community/edge_triangle_count_mg_v64_e64.cu index 67c19e5ac52..1f321b2149f 100644 --- a/cpp/src/community/edge_triangle_count_mg_v64_e64.cu +++ b/cpp/src/community/edge_triangle_count_mg_v64_e64.cu @@ -20,6 +20,7 @@ namespace cugraph { // SG instantiation template edge_property_t, int64_t> edge_triangle_count( raft::handle_t const& handle, - cugraph::graph_view_t const& graph_view); + cugraph::graph_view_t const& graph_view, + bool do_expensive_check); } // namespace cugraph diff --git a/cpp/src/community/edge_triangle_count_sg_v32_e32.cu b/cpp/src/community/edge_triangle_count_sg_v32_e32.cu index d6a215aa456..3e16a2cf7ef 100644 --- a/cpp/src/community/edge_triangle_count_sg_v32_e32.cu +++ b/cpp/src/community/edge_triangle_count_sg_v32_e32.cu @@ -20,6 +20,7 @@ namespace cugraph { // SG instantiation template edge_property_t, int32_t> edge_triangle_count( raft::handle_t const& handle, - cugraph::graph_view_t const& graph_view); + cugraph::graph_view_t const& graph_view, + bool do_expensive_check); } // namespace cugraph diff --git a/cpp/src/community/edge_triangle_count_sg_v32_e64.cu b/cpp/src/community/edge_triangle_count_sg_v32_e64.cu index e70fa45c257..24a8de868e0 100644 --- a/cpp/src/community/edge_triangle_count_sg_v32_e64.cu +++ b/cpp/src/community/edge_triangle_count_sg_v32_e64.cu @@ -20,6 +20,7 @@ namespace cugraph { // SG instantiation template edge_property_t, int64_t> edge_triangle_count( raft::handle_t const& handle, - cugraph::graph_view_t const& graph_view); + cugraph::graph_view_t const& graph_view, + bool do_expensive_check); } // namespace cugraph diff --git a/cpp/src/community/edge_triangle_count_sg_v64_e64.cu b/cpp/src/community/edge_triangle_count_sg_v64_e64.cu index 849603f781b..81f814df713 100644 --- a/cpp/src/community/edge_triangle_count_sg_v64_e64.cu +++ b/cpp/src/community/edge_triangle_count_sg_v64_e64.cu @@ -20,6 +20,7 @@ namespace cugraph { // SG instantiation template edge_property_t, int64_t> edge_triangle_count( raft::handle_t const& handle, - cugraph::graph_view_t const& graph_view); + cugraph::graph_view_t const& graph_view, + bool do_expensive_check); } // namespace cugraph diff --git a/cpp/src/prims/per_v_pair_dst_nbr_intersection.cuh b/cpp/src/prims/per_v_pair_dst_nbr_intersection.cuh new file mode 100644 index 00000000000..01c76e5085a --- /dev/null +++ b/cpp/src/prims/per_v_pair_dst_nbr_intersection.cuh @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "prims/detail/nbr_intersection.cuh" + +#include + +#include + +#include + +namespace cugraph { + +/** + * @brief Iterate over each input vertex pair and returns the common destination neighbor list + * pair in a CSR-like format + * + * Iterate over every vertex pair; intersect destination neighbor lists of the two vertices in the + * pair and store the result in a CSR-like format + * + * @tparam GraphViewType Type of the passed non-owning graph object. + * @tparam VertexPairIterator Type of the iterator for input vertex pairs. + * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and + * handles to various CUDA libraries) to run graph algorithms. + * @param graph_view Non-owning graph object. + * @param vertex_pair_first Iterator pointing to the first (inclusive) input vertex pair. + * @param vertex_pair_last Iterator pointing to the last (exclusive) input vertex pair. + * @param do_expensive_check A flag to run expensive checks for input arguments (if set to `true`). + * @return std::tuple Tuple of intersection offsets and indices. + */ +template +std::tuple, rmm::device_uvector> +per_v_pair_dst_nbr_intersection(raft::handle_t const& handle, + GraphViewType const& graph_view, + VertexPairIterator vertex_pair_first, + VertexPairIterator vertex_pair_last, + bool do_expensive_check = false) +{ + static_assert(!GraphViewType::is_storage_transposed); + + return detail::nbr_intersection(handle, + graph_view, + cugraph::edge_dummy_property_t{}.view(), + vertex_pair_first, + vertex_pair_last, + std::array{true, true}, + do_expensive_check); +} + +} // namespace cugraph