diff --git a/cpp/src/community/edge_triangle_count_impl.cuh b/cpp/src/community/edge_triangle_count_impl.cuh index 225687c4cf0..c3c6a85d5af 100644 --- a/cpp/src/community/edge_triangle_count_impl.cuh +++ b/cpp/src/community/edge_triangle_count_impl.cuh @@ -19,7 +19,7 @@ #include "detail/graph_partition_utils.cuh" #include "prims/edge_bucket.cuh" #include "prims/transform_e.cuh" -#include "prims/transform_reduce_dst_nbr_intersection_of_e_endpoints_by_v.cuh" +#include "prims/per_v_pair_dst_nbr_intersection.cuh" #include #include @@ -158,14 +158,11 @@ edge_property_t, edge_t> edge_t num_remaining_edges -= chunk_size; // Perform 'nbr_intersection' in chunks to reduce peak memory. auto [intersection_offsets, intersection_indices] = - detail::nbr_intersection(handle, - graph_view, - cugraph::edge_dummy_property_t{}.view(), - edge_first + prev_chunk_size, - edge_first + prev_chunk_size + chunk_size, - std::array{true, true}, - false /*FIXME: pass 'do_expensive_check' as argument*/); - + per_v_pair_dst_nbr_intersection(handle, + graph_view, + edge_first + prev_chunk_size, + edge_first + prev_chunk_size + chunk_size, + false /*FIXME: pass 'do_expensive_check' as argument*/); // Update the number of triangles of each (p, q) edges by looking at their intersection // size thrust::for_each( diff --git a/cpp/src/prims/per_v_pair_dst_nbr_intersection.cuh b/cpp/src/prims/per_v_pair_dst_nbr_intersection.cuh new file mode 100644 index 00000000000..646db6c6788 --- /dev/null +++ b/cpp/src/prims/per_v_pair_dst_nbr_intersection.cuh @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "detail/graph_partition_utils.cuh" +#include "prims/detail/nbr_intersection.cuh" +#include "prims/property_op_utils.cuh" +#include "utilities/collect_comm.cuh" +#include "utilities/error_check_utils.cuh" + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace cugraph { + +namespace detail { + + +} // namespace detail + +/** + * @brief Iterate over each input vertex pair and return the common destination neighbor list + * pair in a CSR-like format + * + * Iterate over every vertex pair; intersect destination neighbor lists of the two vertices in the + * pair and store the result in a CSR-like format + * + * @tparam GraphViewType Type of the passed non-owning graph object. + * @tparam VertexPairIterator Type of the iterator for input vertex pairs. + * @param handle RAFT handle object to encapsulate resources (e.g. CUDA stream, communicator, and + * handles to various CUDA libraries) to run graph algorithms. + * @param graph_view Non-owning graph object. + * @param vertex_pair_first Iterator pointing to the first (inclusive) input vertex pair. + * @param vertex_pair_last Iterator pointing to the last (exclusive) input vertex pair. + * @param A flag to run expensive checks for input arguments (if set to `true`). + */ +template +std::tuple, rmm::device_uvector> per_v_pair_dst_nbr_intersection( + raft::handle_t const& handle, + GraphViewType const& graph_view, + VertexPairIterator vertex_pair_first, + VertexPairIterator vertex_pair_last, + bool do_expensive_check = false) +{ + static_assert(!GraphViewType::is_storage_transposed); + + if (do_expensive_check) { + auto num_invalids = + detail::count_invalid_vertex_pairs(handle, graph_view, vertex_pair_first, vertex_pair_last); + CUGRAPH_EXPECTS(num_invalids == 0, + "Invalid input arguments: there are invalid input vertex pairs."); + } + + auto [intersection_offsets, intersection_indices] = + detail::nbr_intersection(handle, + graph_view, + cugraph::edge_dummy_property_t{}.view(), + vertex_pair_first, + vertex_pair_last, + std::array{true, true}, + false /*FIXME: pass 'do_expensive_check' as argument*/); + + return std::make_tuple(std::move(intersection_offsets), std::move(intersection_indices)); +} + +} // namespace cugraph