From 0f28b2ee45130486ca891b757574780ac58dd720 Mon Sep 17 00:00:00 2001
From: Alex Barghi <105237337+alexbarghi-nv@users.noreply.github.com>
Date: Mon, 20 Nov 2023 13:04:06 -0500
Subject: [PATCH 1/3] [BUG] Fix Graph Construction From Pandas in cuGraph-PyG (#3985)

The current graph construction creates a single pandas dataframe, which for larger datasets (e.g., ogbn-papers100M) cannot be serialized. This PR resolves the issue by breaking the dataframe up into scattered numpy arrays that are then reassembled into a distributed dataframe.

Merge after #3978

Authors:
  - Alex Barghi (https://github.com/alexbarghi-nv)
  - Naim (https://github.com/naimnv)

Approvers:
  - Vibhu Jawa (https://github.com/VibhuJawa)
  - Brad Rees (https://github.com/BradReesWork)
  - Tingyu Wang (https://github.com/tingyu66)

URL: https://github.com/rapidsai/cugraph/pull/3985
---
 .../cugraph_pyg/data/cugraph_store.py         | 75 +++++++++++++------
 .../tests/mg/test_mg_cugraph_loader.py        |  1 -
 .../tests/mg/test_mg_cugraph_store.py         | 26 +++++++
 3 files changed, 80 insertions(+), 22 deletions(-)

diff --git a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py
index edeeface4c4..14dc5d84f90 100644
--- a/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py
+++ b/python/cugraph-pyg/cugraph_pyg/data/cugraph_store.py
@@ -27,11 +27,12 @@ import cugraph
 import warnings
 
-from cugraph.utilities.utils import import_optional, MissingModule
+import dask.array as dar
+import dask.dataframe as dd
+import dask.distributed as distributed
+import dask_cudf
 
-dd = import_optional("dask.dataframe")
-distributed = import_optional("dask.distributed")
-dask_cudf = import_optional("dask_cudf")
+from cugraph.utilities.utils import import_optional, MissingModule
 
 torch = import_optional("torch")
 torch_geometric = import_optional("torch_geometric")
@@ -367,6 +368,13 @@ def __infer_offsets(
             }
         )
 
+    def __dask_array_from_numpy(self, array: np.ndarray, npartitions: int):
+        return dar.from_array(
+            array,
+            meta=np.array([], dtype=array.dtype),
+            chunks=max(1, len(array) // npartitions),
+        )
+
     def __construct_graph(
         self,
         edge_info: Dict[Tuple[str, str, str], List[TensorType]],
@@ -464,22 +472,32 @@ def __construct_graph(
             ]
         )
 
-        df = pandas.DataFrame(
-            {
-                "src": pandas.Series(na_dst)
-                if order == "CSC"
-                else pandas.Series(na_src),
-                "dst": pandas.Series(na_src)
-                if order == "CSC"
-                else pandas.Series(na_dst),
-                "etp": pandas.Series(na_etp),
-            }
-        )
-        vertex_dtype = df.src.dtype
+        vertex_dtype = na_src.dtype
 
         if multi_gpu:
-            nworkers = len(distributed.get_client().scheduler_info()["workers"])
-            df = dd.from_pandas(df, npartitions=nworkers if len(df) > 32 else 1)
+            client = distributed.get_client()
+            nworkers = len(client.scheduler_info()["workers"])
+            npartitions = nworkers * 4
+
+            src_dar = self.__dask_array_from_numpy(na_src, npartitions)
+            del na_src
+
+            dst_dar = self.__dask_array_from_numpy(na_dst, npartitions)
+            del na_dst
+
+            etp_dar = self.__dask_array_from_numpy(na_etp, npartitions)
+            del na_etp
+
+            df = dd.from_dask_array(etp_dar, columns=["etp"])
+            df["src"] = dst_dar if order == "CSC" else src_dar
+            df["dst"] = src_dar if order == "CSC" else dst_dar
+
+            del src_dar
+            del dst_dar
+            del etp_dar
+
+            if df.etp.dtype != "int32":
+                raise ValueError("Edge type must be int32!")
 
         # Ensure the dataframe is constructed on each partition
         # instead of adding additional synchronization head from potential
@@ -487,9 +505,9 @@ def get_empty_df():
             def get_empty_df():
                 return cudf.DataFrame(
                     {
+                        "etp": cudf.Series([], dtype="int32"),
                         "src": cudf.Series([], dtype=vertex_dtype),
                         "dst": cudf.Series([], dtype=vertex_dtype),
-                        "etp": cudf.Series([], dtype="int32"),
                     }
                 )
 
@@ -500,9 +518,23 @@ def get_empty_df():
                     if len(f) > 0
                     else get_empty_df(),
                     meta=get_empty_df(),
-                ).reset_index(drop=True)
+                ).reset_index(
+                    drop=True
+                )  # should be ok for dask
         else:
-            df = cudf.from_pandas(df).reset_index(drop=True)
+            df = pandas.DataFrame(
+                {
+                    "src": pandas.Series(na_dst)
+                    if order == "CSC"
+                    else pandas.Series(na_src),
+                    "dst": pandas.Series(na_src)
+                    if order == "CSC"
+                    else pandas.Series(na_dst),
+                    "etp": pandas.Series(na_etp),
+                }
+            )
+            df = cudf.from_pandas(df)
+            df.reset_index(drop=True, inplace=True)
 
         graph = cugraph.MultiGraph(directed=True)
         if multi_gpu:
@@ -521,6 +553,7 @@ def get_empty_df():
                 edge_type="etp",
             )
 
+        del df
         return graph
 
     @property
diff --git a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_loader.py b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_loader.py
index 55aebf305da..f5035a38621 100644
--- a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_loader.py
+++ b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_loader.py
@@ -15,7 +15,6 @@
 
 from cugraph_pyg.loader import CuGraphNeighborLoader
 from cugraph_pyg.data import CuGraphStore
-
 from cugraph.utilities.utils import import_optional, MissingModule
 
 torch = import_optional("torch")
diff --git a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py
index 13c9c90c7c2..be8f8245807 100644
--- a/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py
+++ b/python/cugraph-pyg/cugraph_pyg/tests/mg/test_mg_cugraph_store.py
@@ -386,3 +386,29 @@ def test_mg_frame_handle(graph, dask_client):
     F, G, N = graph
     cugraph_store = CuGraphStore(F, G, N, multi_gpu=True)
     assert isinstance(cugraph_store._EXPERIMENTAL__CuGraphStore__graph._plc_graph, dict)
+
+
+@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
+def test_cugraph_loader_large_index(dask_client):
+    large_index = (
+        np.random.randint(0, 1_000_000, (100_000_000,)),
+        np.random.randint(0, 1_000_000, (100_000_000,)),
+    )
+
+    large_features = np.random.randint(0, 50, (1_000_000,))
+    F = cugraph.gnn.FeatureStore(backend="torch")
+    F.add_data(large_features, "N", "f")
+
+    store = CuGraphStore(
+        F,
+        {("N", "e", "N"): large_index},
+        {"N": 1_000_000},
+        multi_gpu=True,
+    )
+
+    graph = store._subgraph()
+    assert isinstance(graph, cugraph.Graph)
+
+    el = graph.view_edge_list().compute()
+    assert (el["src"].values_host - large_index[0]).sum() == 0
+    assert (el["dst"].values_host - large_index[1]).sum() == 0

From d34e3d6522f1f3d8e9fbea6581b7ce37de7e1005 Mon Sep 17 00:00:00 2001
From: Seunghwa Kang <45857425+seunghwak@users.noreply.github.com>
Date: Mon, 20 Nov 2023 12:33:35 -0800
Subject: [PATCH 2/3] Address FIXMEs (#3988)

This PR addresses several FIXMEs and reduces the number of outstanding FIXMEs.
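
One of the FIXMEs is resolved by adding a host_scalar_scatter primitive to host_scalar_comm.hpp (see the diff below). A minimal usage sketch of the new primitive follows; the wrapper function and its name are illustrative assumptions, not part of this patch, and only the host_scalar_scatter call itself mirrors the added API:

    // Hypothetical usage sketch of the new host_scalar_scatter primitive.
    // Rank 0 (the root) supplies one value per rank; every other rank passes
    // an empty vector, matching the CUGRAPH_EXPECTS precondition in the diff.
    #include <cugraph/utilities/host_scalar_comm.hpp>

    #include <raft/core/handle.hpp>

    #include <vector>

    size_t scatter_per_rank_count(raft::handle_t const& handle,
                                  std::vector<size_t> const& counts /* root only */)
    {
      auto& comm   = handle.get_comms();
      auto  inputs = (comm.get_rank() == 0) ? counts : std::vector<size_t>{};
      // Internally the primitive stages the inputs on the device at the root,
      // broadcasts the staged buffer, and copies back only this rank's slot.
      return cugraph::host_scalar_scatter(comm, inputs, int{0} /* root */, handle.get_stream());
    }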
Authors: - Seunghwa Kang (https://github.com/seunghwak) - Naim (https://github.com/naimnv) - Ralph Liu (https://github.com/nv-rliu) Approvers: - Naim (https://github.com/naimnv) - Joseph Nke (https://github.com/jnke2016) - Chuck Hastings (https://github.com/ChuckHastings) URL: https://github.com/rapidsai/cugraph/pull/3988 --- cpp/include/cugraph/algorithms.hpp | 45 --------- cpp/include/cugraph/utilities/device_comm.hpp | 8 +- .../cugraph/utilities/host_scalar_comm.hpp | 98 ++++++++++++++----- .../cugraph/utilities/shuffle_comm.cuh | 5 - cpp/src/centrality/katz_centrality_impl.cuh | 2 - .../weakly_connected_components_impl.cuh | 40 ++------ 6 files changed, 83 insertions(+), 115 deletions(-) diff --git a/cpp/include/cugraph/algorithms.hpp b/cpp/include/cugraph/algorithms.hpp index 78846bc5766..8501eedce5c 100644 --- a/cpp/include/cugraph/algorithms.hpp +++ b/cpp/include/cugraph/algorithms.hpp @@ -464,51 +464,6 @@ k_truss_subgraph(raft::handle_t const& handle, size_t number_of_vertices, int k); -// FIXME: Internally distances is of int (signed 32-bit) data type, but current -// template uses data from VT, ET, WT from the legacy::GraphCSR View even if weights -// are not considered -/** - * @Synopsis Performs a breadth first search traversal of a graph starting from a vertex. - * - * @throws cugraph::logic_error with a custom message when an error occurs. - * - * @tparam VT Type of vertex identifiers. Supported value : int (signed, - * 32-bit) - * @tparam ET Type of edge identifiers. Supported value : int (signed, - * 32-bit) - * @tparam WT Type of edge weights. Supported values : int (signed, 32-bit) - * - * @param[in] handle Library handle (RAFT). If a communicator is set in the handle, - the multi GPU version will be selected. - * @param[in] graph cuGraph graph descriptor, should contain the connectivity - * information as a CSR - * - * @param[out] distances If set to a valid pointer, this is populated by distance of - * every vertex in the graph from the starting vertex - * - * @param[out] predecessors If set to a valid pointer, this is populated by bfs traversal - * predecessor of every vertex - * - * @param[out] sp_counters If set to a valid pointer, this is populated by bfs traversal - * shortest_path counter of every vertex - * - * @param[in] start_vertex The starting vertex for breadth first search traversal - * - * @param[in] directed Treat the input graph as directed - * - * @param[in] mg_batch If set to true use SG BFS path when comms are initialized. - * - */ -template -void bfs(raft::handle_t const& handle, - legacy::GraphCSRView const& graph, - VT* distances, - VT* predecessors, - double* sp_counters, - const VT start_vertex, - bool directed = true, - bool mg_batch = false); - /** * @brief Compute Hungarian algorithm on a weighted bipartite graph * diff --git a/cpp/include/cugraph/utilities/device_comm.hpp b/cpp/include/cugraph/utilities/device_comm.hpp index 7087724921a..990074e781b 100644 --- a/cpp/include/cugraph/utilities/device_comm.hpp +++ b/cpp/include/cugraph/utilities/device_comm.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -806,9 +806,6 @@ device_sendrecv(raft::comms::comms_t const& comm, size_t constexpr tuple_size = thrust::tuple_size::value_type>::value; - // FIXME: NCCL 2.7 supports only one ncclSend and one ncclRecv for a source rank and destination - // rank inside ncclGroupStart/ncclGroupEnd, so we cannot place this inside - // ncclGroupStart/ncclGroupEnd, this restriction will be lifted in NCCL 2.8 detail::device_sendrecv_tuple_iterator_element_impl::value_type>::value; - // FIXME: NCCL 2.7 supports only one ncclSend and one ncclRecv for a source rank and destination - // rank inside ncclGroupStart/ncclGroupEnd, so we cannot place this inside - // ncclGroupStart/ncclGroupEnd, this restriction will be lifted in NCCL 2.8 detail::device_multicast_sendrecv_tuple_iterator_element_impl std::enable_if_t::value, std::vector> host_scalar_allgather( raft::comms::comms_t const& comm, T input, cudaStream_t stream) { - std::vector rx_counts(comm.get_size(), size_t{1}); - std::vector displacements(rx_counts.size(), size_t{0}); - std::iota(displacements.begin(), displacements.end(), size_t{0}); - rmm::device_uvector d_outputs(rx_counts.size(), stream); + rmm::device_uvector d_outputs(comm.get_size(), stream); raft::update_device(d_outputs.data() + comm.get_rank(), &input, 1, stream); - // FIXME: better use allgather - comm.allgatherv(d_outputs.data() + comm.get_rank(), - d_outputs.data(), - rx_counts.data(), - displacements.data(), - stream); - std::vector h_outputs(rx_counts.size()); - raft::update_host(h_outputs.data(), d_outputs.data(), rx_counts.size(), stream); + comm.allgather(d_outputs.data() + comm.get_rank(), d_outputs.data(), size_t{1}, stream); + std::vector h_outputs(d_outputs.size()); + raft::update_host(h_outputs.data(), d_outputs.data(), d_outputs.size(), stream); auto status = comm.sync_stream(stream); CUGRAPH_EXPECTS(status == raft::comms::status_t::SUCCESS, "sync_stream() failure."); return h_outputs; @@ -277,11 +269,6 @@ std::enable_if_t::value, std::vector::value; - std::vector rx_counts(comm.get_size(), tuple_size); - std::vector displacements(rx_counts.size(), size_t{0}); - for (size_t i = 0; i < displacements.size(); ++i) { - displacements[i] = i * tuple_size; - } std::vector h_tuple_scalar_elements(tuple_size); rmm::device_uvector d_allgathered_tuple_scalar_elements(comm.get_size() * tuple_size, stream); @@ -292,12 +279,10 @@ host_scalar_allgather(raft::comms::comms_t const& comm, T input, cudaStream_t st h_tuple_scalar_elements.data(), tuple_size, stream); - // FIXME: better use allgather - comm.allgatherv(d_allgathered_tuple_scalar_elements.data() + comm.get_rank() * tuple_size, - d_allgathered_tuple_scalar_elements.data(), - rx_counts.data(), - displacements.data(), - stream); + comm.allgather(d_allgathered_tuple_scalar_elements.data() + comm.get_rank() * tuple_size, + d_allgathered_tuple_scalar_elements.data(), + tuple_size, + stream); std::vector h_allgathered_tuple_scalar_elements(comm.get_size() * tuple_size); raft::update_host(h_allgathered_tuple_scalar_elements.data(), d_allgathered_tuple_scalar_elements.data(), @@ -318,6 +303,71 @@ host_scalar_allgather(raft::comms::comms_t const& comm, T input, cudaStream_t st return ret; } +template +std::enable_if_t::value, T> host_scalar_scatter( + raft::comms::comms_t const& comm, + std::vector const& inputs, // relevant only in root + int root, + cudaStream_t stream) +{ + CUGRAPH_EXPECTS( + ((comm.get_rank() == root) && (inputs.size() == static_cast(comm.get_size()))) || + ((comm.get_rank() != root) && (inputs.size() == 0)), + 
"inputs.size() should match with comm.get_size() in root and should be 0 otherwise."); + rmm::device_uvector d_outputs(comm.get_size(), stream); + if (comm.get_rank() == root) { + raft::update_device(d_outputs.data(), inputs.data(), inputs.size(), stream); + } + comm.bcast(d_outputs.data(), d_outputs.size(), root, stream); + T h_output{}; + raft::update_host(&h_output, d_outputs.data() + comm.get_rank(), 1, stream); + auto status = comm.sync_stream(stream); + CUGRAPH_EXPECTS(status == raft::comms::status_t::SUCCESS, "sync_stream() failure."); + return h_output; +} + +template +std::enable_if_t::value, T> host_scalar_scatter( + raft::comms::comms_t const& comm, + std::vector const& inputs, // relevant only in root + int root, + cudaStream_t stream) +{ + CUGRAPH_EXPECTS( + ((comm.get_rank() == root) && (inputs.size() == static_cast(comm.get_size()))) || + ((comm.get_rank() != root) && (inputs.size() == 0)), + "inputs.size() should match with comm.get_size() in root and should be 0 otherwise."); + size_t constexpr tuple_size = thrust::tuple_size::value; + rmm::device_uvector d_scatter_tuple_scalar_elements(comm.get_size() * tuple_size, + stream); + if (comm.get_rank() == root) { + for (int i = 0; i < comm.get_size(); ++i) { + std::vector h_tuple_scalar_elements(tuple_size); + detail::update_vector_of_tuple_scalar_elements_from_tuple_impl() + .update(h_tuple_scalar_elements, inputs[i]); + raft::update_device(d_scatter_tuple_scalar_elements.data() + i * tuple_size, + h_tuple_scalar_elements.data(), + tuple_size, + stream); + } + } + comm.bcast( + d_scatter_tuple_scalar_elements.data(), d_scatter_tuple_scalar_elements.size(), root, stream); + std::vector h_tuple_scalar_elements(tuple_size); + raft::update_host(h_tuple_scalar_elements.data(), + d_scatter_tuple_scalar_elements.data() + comm.get_rank() * tuple_size, + tuple_size, + stream); + auto status = comm.sync_stream(stream); + CUGRAPH_EXPECTS(status == raft::comms::status_t::SUCCESS, "sync_stream() failure."); + + T ret{}; + detail::update_tuple_from_vector_of_tuple_scalar_elements_impl().update( + ret, h_tuple_scalar_elements); + + return ret; +} + // Return value is valid only in root (return value may better be std::optional in C++17 or later) template std::enable_if_t::value, std::vector> host_scalar_gather( diff --git a/cpp/include/cugraph/utilities/shuffle_comm.cuh b/cpp/include/cugraph/utilities/shuffle_comm.cuh index 6a260144324..ab6a54cc1c0 100644 --- a/cpp/include/cugraph/utilities/shuffle_comm.cuh +++ b/cpp/include/cugraph/utilities/shuffle_comm.cuh @@ -80,7 +80,6 @@ compute_tx_rx_counts_offsets_ranks(raft::comms::comms_t const& comm, rmm::device_uvector d_rx_value_counts(comm_size, stream_view); - // FIXME: this needs to be replaced with AlltoAll once NCCL 2.8 is released. std::vector tx_counts(comm_size, size_t{1}); std::vector tx_offsets(comm_size); std::iota(tx_offsets.begin(), tx_offsets.end(), size_t{0}); @@ -835,7 +834,6 @@ auto shuffle_values(raft::comms::comms_t const& comm, allocate_dataframe_buffer::value_type>( rx_offsets.size() > 0 ? rx_offsets.back() + rx_counts.back() : size_t{0}, stream_view); - // FIXME: this needs to be replaced with AlltoAll once NCCL 2.8 is released // (if num_tx_dst_ranks == num_rx_src_ranks == comm_size). device_multicast_sendrecv(comm, tx_value_first, @@ -889,7 +887,6 @@ auto groupby_gpu_id_and_shuffle_values(raft::comms::comms_t const& comm, allocate_dataframe_buffer::value_type>( rx_offsets.size() > 0 ? 
rx_offsets.back() + rx_counts.back() : size_t{0}, stream_view); - // FIXME: this needs to be replaced with AlltoAll once NCCL 2.8 is released // (if num_tx_dst_ranks == num_rx_src_ranks == comm_size). device_multicast_sendrecv(comm, tx_value_first, @@ -946,7 +943,6 @@ auto groupby_gpu_id_and_shuffle_kv_pairs(raft::comms::comms_t const& comm, allocate_dataframe_buffer::value_type>( rx_keys.size(), stream_view); - // FIXME: this needs to be replaced with AlltoAll once NCCL 2.8 is released // (if num_tx_dst_ranks == num_rx_src_ranks == comm_size). device_multicast_sendrecv(comm, tx_key_first, @@ -959,7 +955,6 @@ auto groupby_gpu_id_and_shuffle_kv_pairs(raft::comms::comms_t const& comm, rx_src_ranks, stream_view); - // FIXME: this needs to be replaced with AlltoAll once NCCL 2.8 is released // (if num_tx_dst_ranks == num_rx_src_ranks == comm_size). device_multicast_sendrecv(comm, tx_value_first, diff --git a/cpp/src/centrality/katz_centrality_impl.cuh b/cpp/src/centrality/katz_centrality_impl.cuh index 202d00a5771..ac31043d862 100644 --- a/cpp/src/centrality/katz_centrality_impl.cuh +++ b/cpp/src/centrality/katz_centrality_impl.cuh @@ -74,8 +74,6 @@ void katz_centrality( CUGRAPH_EXPECTS(epsilon >= 0.0, "Invalid input argument: epsilon should be non-negative."); if (do_expensive_check) { - // FIXME: should I check for betas? - if (has_initial_guess) { auto num_negative_values = count_if_v(handle, pull_graph_view, katz_centralities, [] __device__(auto, auto val) { diff --git a/cpp/src/components/weakly_connected_components_impl.cuh b/cpp/src/components/weakly_connected_components_impl.cuh index 615a50ded54..b7b6e139cfa 100644 --- a/cpp/src/components/weakly_connected_components_impl.cuh +++ b/cpp/src/components/weakly_connected_components_impl.cuh @@ -236,18 +236,16 @@ struct v_op_t { auto tag = thrust::get<1>(tagged_v); auto v_offset = vertex_partition.local_vertex_partition_offset_from_vertex_nocheck(thrust::get<0>(tagged_v)); - // FIXME: better switch to atomic_ref after - // https://github.com/nvidia/libcudacxx/milestone/2 - auto old = - atomicCAS(level_components + v_offset, invalid_component_id::value, tag); - if (old != invalid_component_id::value && old != tag) { // conflict + cuda::atomic_ref v_component(*(level_components + v_offset)); + auto old = invalid_component_id::value; + bool success = v_component.compare_exchange_strong(old, tag, cuda::std::memory_order_relaxed); + if (!success && (old != tag)) { // conflict return thrust::make_tuple(thrust::optional{bucket_idx_conflict}, thrust::optional{std::byte{0}} /* dummy */); } else { - auto update = (old == invalid_component_id::value); return thrust::make_tuple( - update ? thrust::optional{bucket_idx_next} : thrust::nullopt, - update ? thrust::optional{std::byte{0}} /* dummy */ : thrust::nullopt); + success ? thrust::optional{bucket_idx_next} : thrust::nullopt, + success ? 
thrust::optional<std::byte>{std::byte{0}} /* dummy */ : thrust::nullopt);
     }
   }
@@ -457,33 +455,11 @@ void weakly_connected_components_impl(raft::handle_t const& handle,
                      std::numeric_limits<vertex_t>::max());
       }
 
-      // FIXME: we need to add host_scalar_scatter
-#if 1
-      rmm::device_uvector<vertex_t> d_counts(comm_size, handle.get_stream());
-      raft::update_device(d_counts.data(),
-                          init_max_new_root_counts.data(),
-                          init_max_new_root_counts.size(),
-                          handle.get_stream());
-      device_bcast(
-        comm, d_counts.data(), d_counts.data(), d_counts.size(), int{0}, handle.get_stream());
-      raft::update_host(
-        &init_max_new_roots, d_counts.data() + comm_rank, size_t{1}, handle.get_stream());
-#else
       init_max_new_roots =
-        host_scalar_scatter(comm, init_max_new_root_counts.data(), int{0}, handle.get_stream());
-#endif
+        host_scalar_scatter(comm, init_max_new_root_counts, int{0}, handle.get_stream());
     } else {
-      // FIXME: we need to add host_scalar_scatter
-#if 1
-      rmm::device_uvector<vertex_t> d_counts(comm_size, handle.get_stream());
-      device_bcast(
-        comm, d_counts.data(), d_counts.data(), d_counts.size(), int{0}, handle.get_stream());
-      raft::update_host(
-        &init_max_new_roots, d_counts.data() + comm_rank, size_t{1}, handle.get_stream());
-#else
       init_max_new_roots =
-        host_scalar_scatter(comm, init_max_new_root_counts.data(), int{0}, handle.get_stream());
-#endif
+        host_scalar_scatter(comm, std::vector<vertex_t>{}, int{0}, handle.get_stream());
     }
 
     handle.sync_stream();

From 8549b546ef1a97b4c25a0f25b73700802d563d17 Mon Sep 17 00:00:00 2001
From: Naim <110031745+naimnv@users.noreply.github.com>
Date: Mon, 20 Nov 2023 21:39:53 +0100
Subject: [PATCH 3/3] Fix Leiden refinement phase (#3990)

- The normalization factor was missing from the equation that decides whether a node and a refined community are strongly connected inside their Louvain community. This PR adds that factor.
- Disables random moves in the refinement phase. We plan to expose a flag to enable/disable random moves in a future PR.
- Adds a new function to flatten the Leiden dendrogram, as the flattening process needs additional information to unroll hierarchical Leiden clustering

Closes #3850
Closes #3749

Authors:
  - Naim (https://github.com/naimnv)
  - Alex Barghi (https://github.com/alexbarghi-nv)

Approvers:
  - Chuck Hastings (https://github.com/ChuckHastings)
  - Seunghwa Kang (https://github.com/seunghwak)
  - Brad Rees (https://github.com/BradReesWork)

URL: https://github.com/rapidsai/cugraph/pull/3990
---
 cpp/src/community/detail/common_methods.cuh   |  46 +++++-
 cpp/src/community/detail/refine_impl.cuh      |  22 +-
 cpp/src/community/flatten_dendrogram.hpp      |  29 ++-
 cpp/src/community/leiden_impl.cuh             | 200 ++++++++++--------
 cpp/tests/c_api/leiden_test.c                 |   4 +-
 cpp/tests/c_api/louvain_test.c                |  39 +++-
 cpp/tests/community/louvain_test.cpp          |  81 +------
 .../cugraph/tests/community/test_leiden.py    |  28 +--
 8 files changed, 242 insertions(+), 207 deletions(-)

diff --git a/cpp/src/community/detail/common_methods.cuh b/cpp/src/community/detail/common_methods.cuh
index b388ba53e81..f67d4d939ad 100644
--- a/cpp/src/community/detail/common_methods.cuh
+++ b/cpp/src/community/detail/common_methods.cuh
@@ -52,7 +52,7 @@ struct is_bitwise_comparable<thrust::tuple<vertex_t, weight_t>> : std::true_type {};
 namespace cugraph {
 namespace detail {
 
-// a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used
+// FIXME: a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used
 template <typename vertex_t, typename weight_t>
 struct key_aggregated_edge_op_t {
   weight_t total_edge_weight{};
@@ -80,7 +80,7 @@ struct key_aggregated_edge_op_t {
   }
 };
 
-// a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used
+// FIXME: a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used
 template <typename vertex_t, typename weight_t>
 struct reduce_op_t {
   using type = thrust::tuple<vertex_t, weight_t>;
@@ -100,7 +100,28 @@ struct reduce_op_t {
   }
 };
 
-// a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used
+// FIXME: a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used
+template <typename vertex_t, typename weight_t>
+struct count_updown_moves_op_t {
+  bool up_down{};
+  __device__ auto operator()(thrust::tuple<vertex_t, thrust::tuple<vertex_t, weight_t>> p) const
+  {
+    vertex_t old_cluster       = thrust::get<0>(p);
+    auto new_cluster_gain_pair = thrust::get<1>(p);
+    vertex_t new_cluster       = thrust::get<0>(new_cluster_gain_pair);
+    weight_t delta_modularity  = thrust::get<1>(new_cluster_gain_pair);
+
+    auto result_assignment =
+      (delta_modularity > weight_t{0})
+        ? (((new_cluster > old_cluster) != up_down) ? old_cluster : new_cluster)
+        : old_cluster;
+
+    return (delta_modularity > weight_t{0})
+             ? (((new_cluster > old_cluster) != up_down) ? 
false : true) + : false; + } +}; +// FIXME: a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used template struct cluster_update_op_t { bool up_down{}; @@ -115,7 +136,7 @@ struct cluster_update_op_t { } }; -// a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used +// FIXME: a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used template struct return_edge_weight_t { __device__ auto operator()( @@ -125,7 +146,7 @@ struct return_edge_weight_t { } }; -// a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used +// FIXME: a workaround for cudaErrorInvalidDeviceFunction error when device lambda is used template struct return_one_t { __device__ auto operator()( @@ -394,6 +415,21 @@ rmm::device_uvector update_clustering_by_delta_modularity( detail::reduce_op_t{}, cugraph::get_dataframe_buffer_begin(output_buffer)); + int nr_moves = thrust::count_if( + handle.get_thrust_policy(), + thrust::make_zip_iterator(thrust::make_tuple( + next_clusters_v.begin(), cugraph::get_dataframe_buffer_begin(output_buffer))), + thrust::make_zip_iterator( + thrust::make_tuple(next_clusters_v.end(), cugraph::get_dataframe_buffer_end(output_buffer))), + detail::count_updown_moves_op_t{up_down}); + + if (multi_gpu) { + nr_moves = host_scalar_allreduce( + handle.get_comms(), nr_moves, raft::comms::op_t::SUM, handle.get_stream()); + } + + if (nr_moves == 0) { up_down = !up_down; } + thrust::transform(handle.get_thrust_policy(), next_clusters_v.begin(), next_clusters_v.end(), diff --git a/cpp/src/community/detail/refine_impl.cuh b/cpp/src/community/detail/refine_impl.cuh index 6b6470991bb..ebaae498d04 100644 --- a/cpp/src/community/detail/refine_impl.cuh +++ b/cpp/src/community/detail/refine_impl.cuh @@ -89,8 +89,9 @@ struct leiden_key_aggregated_edge_op_t { // E(Cr, S-Cr) > ||Cr||*(||S|| -||Cr||) bool is_dst_leiden_cluster_well_connected = - dst_leiden_cut_to_louvain > - resolution * dst_leiden_volume * (louvain_cluster_volume - dst_leiden_volume); + dst_leiden_cut_to_louvain > resolution * dst_leiden_volume * + (louvain_cluster_volume - dst_leiden_volume) / + total_edge_weight; // E(v, Cr-v) - ||v||* ||Cr-v||/||V(G)|| // aggregated_weight_to_neighboring_leiden_cluster == E(v, Cr-v)? @@ -98,11 +99,11 @@ struct leiden_key_aggregated_edge_op_t { weight_t mod_gain = -1.0; if (is_src_active > 0) { if ((louvain_of_dst_leiden_cluster == src_louvain_cluster) && - is_dst_leiden_cluster_well_connected) { + (dst_leiden_cluster_id != src_leiden_cluster) && is_dst_leiden_cluster_well_connected) { mod_gain = aggregated_weight_to_neighboring_leiden_cluster - - resolution * src_weighted_deg * (dst_leiden_volume - src_weighted_deg) / - total_edge_weight; - + resolution * src_weighted_deg * dst_leiden_volume / total_edge_weight; +// FIXME: Disable random moves in refinement phase for now. +#if 0 weight_t random_number{0.0}; if (mod_gain > 0.0) { auto flat_id = uint64_t{threadIdx.x + blockIdx.x * blockDim.x}; @@ -117,6 +118,8 @@ struct leiden_key_aggregated_edge_op_t { ? __expf(static_cast((2.0 * mod_gain) / (theta * total_edge_weight))) * random_number : -1.0; +#endif + mod_gain = mod_gain > 0.0 ? 
mod_gain : -1.0; } } @@ -240,11 +243,12 @@ refine_clustering( wcut_deg_and_cluster_vol_triple_begin, wcut_deg_and_cluster_vol_triple_end, singleton_and_connected_flags.begin(), - [resolution] __device__(auto wcut_wdeg_and_louvain_volume) { + [resolution, total_edge_weight] __device__(auto wcut_wdeg_and_louvain_volume) { auto wcut = thrust::get<0>(wcut_wdeg_and_louvain_volume); auto wdeg = thrust::get<1>(wcut_wdeg_and_louvain_volume); auto louvain_volume = thrust::get<2>(wcut_wdeg_and_louvain_volume); - return wcut > (resolution * wdeg * (louvain_volume - wdeg)); + return wcut > + (resolution * wdeg * (louvain_volume - wdeg) / total_edge_weight); }); edge_src_property_t src_louvain_cluster_weight_cache(handle); @@ -478,7 +482,7 @@ refine_clustering( auto values_for_leiden_cluster_keys = thrust::make_zip_iterator( thrust::make_tuple(refined_community_volumes.begin(), refined_community_cuts.begin(), - leiden_keys_used_in_edge_reduction.begin(), // redundant + leiden_keys_used_in_edge_reduction.begin(), louvain_of_leiden_keys_used_in_edge_reduction.begin())); using value_t = thrust::tuple; diff --git a/cpp/src/community/flatten_dendrogram.hpp b/cpp/src/community/flatten_dendrogram.hpp index 9a0c103c01f..eac20389765 100644 --- a/cpp/src/community/flatten_dendrogram.hpp +++ b/cpp/src/community/flatten_dendrogram.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,4 +59,31 @@ void partition_at_level(raft::handle_t const& handle, }); } +template +void leiden_partition_at_level(raft::handle_t const& handle, + Dendrogram const& dendrogram, + vertex_t* d_partition, + size_t level) +{ + vertex_t local_num_verts = dendrogram.get_level_size_nocheck(0); + raft::copy( + d_partition, dendrogram.get_level_ptr_nocheck(0), local_num_verts, handle.get_stream()); + + rmm::device_uvector local_vertex_ids_v(local_num_verts, handle.get_stream()); + + std::for_each( + thrust::make_counting_iterator(0), + thrust::make_counting_iterator((level - 1) / 2), + [&handle, &dendrogram, &local_vertex_ids_v, &d_partition, local_num_verts](size_t l) { + cugraph::relabel( + handle, + std::tuple(dendrogram.get_level_ptr_nocheck(2 * l + 1), + dendrogram.get_level_ptr_nocheck(2 * l + 2)), + dendrogram.get_level_size_nocheck(2 * l + 1), + d_partition, + local_num_verts, + false); + }); +} + } // namespace cugraph diff --git a/cpp/src/community/leiden_impl.cuh b/cpp/src/community/leiden_impl.cuh index a9faf2f2d82..b6e20272de9 100644 --- a/cpp/src/community/leiden_impl.cuh +++ b/cpp/src/community/leiden_impl.cuh @@ -43,6 +43,34 @@ void check_clustering(graph_view_t const& gr if (graph_view.local_vertex_partition_range_size() > 0) CUGRAPH_EXPECTS(clustering != nullptr, "Invalid input argument: clustering is null"); } +template +vertex_t remove_duplicates(raft::handle_t const& handle, rmm::device_uvector& input_array) +{ + thrust::sort(handle.get_thrust_policy(), input_array.begin(), input_array.end()); + + auto nr_unique_elements = static_cast(thrust::distance( + input_array.begin(), + thrust::unique(handle.get_thrust_policy(), input_array.begin(), input_array.end()))); + + input_array.resize(nr_unique_elements, handle.get_stream()); + + if constexpr (multi_gpu) { + input_array = cugraph::detail::shuffle_ext_vertices_to_local_gpu_by_vertex_partitioning( + handle, std::move(input_array)); + + 
thrust::sort(handle.get_thrust_policy(), input_array.begin(), input_array.end()); + + nr_unique_elements = static_cast(thrust::distance( + input_array.begin(), + thrust::unique(handle.get_thrust_policy(), input_array.begin(), input_array.end()))); + + input_array.resize(nr_unique_elements, handle.get_stream()); + + nr_unique_elements = host_scalar_allreduce( + handle.get_comms(), nr_unique_elements, raft::comms::op_t::SUM, handle.get_stream()); + } + return nr_unique_elements; +} template >, weight_t> leiden( rmm::device_uvector louvain_of_refined_graph(0, handle.get_stream()); // #V - while (dendrogram->num_levels() < max_level) { + while (dendrogram->num_levels() < 2 * max_level + 1) { // // Initialize every cluster to reference each vertex to itself // @@ -353,40 +381,8 @@ std::pair>, weight_t> leiden( dendrogram->current_level_begin(), dendrogram->current_level_begin() + dendrogram->current_level_size(), copied_louvain_partition.begin()); - - thrust::sort( - handle.get_thrust_policy(), copied_louvain_partition.begin(), copied_louvain_partition.end()); - auto nr_unique_louvain_clusters = - static_cast(thrust::distance(copied_louvain_partition.begin(), - thrust::unique(handle.get_thrust_policy(), - copied_louvain_partition.begin(), - copied_louvain_partition.end()))); - - copied_louvain_partition.resize(nr_unique_louvain_clusters, handle.get_stream()); - - if constexpr (graph_view_t::is_multi_gpu) { - copied_louvain_partition = - cugraph::detail::shuffle_ext_vertices_to_local_gpu_by_vertex_partitioning( - handle, std::move(copied_louvain_partition)); - - thrust::sort(handle.get_thrust_policy(), - copied_louvain_partition.begin(), - copied_louvain_partition.end()); - - nr_unique_louvain_clusters = - static_cast(thrust::distance(copied_louvain_partition.begin(), - thrust::unique(handle.get_thrust_policy(), - copied_louvain_partition.begin(), - copied_louvain_partition.end()))); - - copied_louvain_partition.resize(nr_unique_louvain_clusters, handle.get_stream()); - - nr_unique_louvain_clusters = host_scalar_allreduce(handle.get_comms(), - nr_unique_louvain_clusters, - raft::comms::op_t::SUM, - handle.get_stream()); - } + remove_duplicates(handle, copied_louvain_partition); terminate = terminate || (nr_unique_louvain_clusters == current_graph_view.number_of_vertices()); @@ -481,6 +477,15 @@ std::pair>, weight_t> leiden( (*cluster_assignment).data(), (*cluster_assignment).size(), false); + // louvain assignment of aggregated graph which is necessary to flatten dendrogram + dendrogram->add_level(current_graph_view.local_vertex_partition_range_first(), + current_graph_view.local_vertex_partition_range_size(), + handle.get_stream()); + + raft::copy(dendrogram->current_level_begin(), + (*cluster_assignment).begin(), + (*cluster_assignment).size(), + handle.get_stream()); louvain_of_refined_graph.resize(current_graph_view.local_vertex_partition_range_size(), handle.get_stream()); @@ -492,47 +497,6 @@ std::pair>, weight_t> leiden( } } - // Relabel dendrogram - vertex_t local_cluster_id_first{0}; - if constexpr (multi_gpu) { - auto unique_cluster_range_lasts = cugraph::partition_manager::compute_partition_range_lasts( - handle, static_cast(copied_louvain_partition.size())); - - auto& comm = handle.get_comms(); - auto const comm_size = comm.get_size(); - auto const comm_rank = comm.get_rank(); - auto& major_comm = handle.get_subcomm(cugraph::partition_manager::major_comm_name()); - auto const major_comm_size = major_comm.get_size(); - auto const major_comm_rank = major_comm.get_rank(); - auto& 
minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); - auto const minor_comm_size = minor_comm.get_size(); - auto const minor_comm_rank = minor_comm.get_rank(); - - auto vertex_partition_id = - partition_manager::compute_vertex_partition_id_from_graph_subcomm_ranks( - major_comm_size, minor_comm_size, major_comm_rank, minor_comm_rank); - - local_cluster_id_first = vertex_partition_id == 0 - ? vertex_t{0} - : unique_cluster_range_lasts[vertex_partition_id - 1]; - } - - rmm::device_uvector numbering_indices(copied_louvain_partition.size(), - handle.get_stream()); - detail::sequence_fill(handle.get_stream(), - numbering_indices.data(), - numbering_indices.size(), - local_cluster_id_first); - - relabel( - handle, - std::make_tuple(static_cast(copied_louvain_partition.begin()), - static_cast(numbering_indices.begin())), - copied_louvain_partition.size(), - dendrogram->current_level_begin(), - dendrogram->current_level_size(), - false); - copied_louvain_partition.resize(0, handle.get_stream()); copied_louvain_partition.shrink_to_fit(handle.get_stream()); @@ -550,23 +514,71 @@ std::pair>, weight_t> leiden( return std::make_pair(std::move(dendrogram), best_modularity); } -// FIXME: Can we have a common flatten_dendrogram to be used by both -// Louvain and Leiden, and possibly other clustering methods? +template +void relabel_cluster_ids(raft::handle_t const& handle, + rmm::device_uvector& unique_cluster_ids, + vertex_t* clustering, + size_t num_nodes) +{ + vertex_t local_cluster_id_first{0}; + if constexpr (multi_gpu) { + auto unique_cluster_range_lasts = cugraph::partition_manager::compute_partition_range_lasts( + handle, static_cast(unique_cluster_ids.size())); + + auto& comm = handle.get_comms(); + auto const comm_size = comm.get_size(); + auto const comm_rank = comm.get_rank(); + auto& major_comm = handle.get_subcomm(cugraph::partition_manager::major_comm_name()); + auto const major_comm_size = major_comm.get_size(); + auto const major_comm_rank = major_comm.get_rank(); + auto& minor_comm = handle.get_subcomm(cugraph::partition_manager::minor_comm_name()); + auto const minor_comm_size = minor_comm.get_size(); + auto const minor_comm_rank = minor_comm.get_rank(); + + auto vertex_partition_id = + partition_manager::compute_vertex_partition_id_from_graph_subcomm_ranks( + major_comm_size, minor_comm_size, major_comm_rank, minor_comm_rank); + + local_cluster_id_first = + vertex_partition_id == 0 ? 
vertex_t{0} : unique_cluster_range_lasts[vertex_partition_id - 1]; + } + + rmm::device_uvector numbering_indices(unique_cluster_ids.size(), handle.get_stream()); + detail::sequence_fill(handle.get_stream(), + numbering_indices.data(), + numbering_indices.size(), + local_cluster_id_first); + + relabel( + handle, + std::make_tuple(static_cast(unique_cluster_ids.begin()), + static_cast(numbering_indices.begin())), + unique_cluster_ids.size(), + clustering, + num_nodes, + false); +} + template -void flatten_dendrogram(raft::handle_t const& handle, - graph_view_t const& graph_view, - Dendrogram const& dendrogram, - vertex_t* clustering) +void flatten_leiden_dendrogram(raft::handle_t const& handle, + graph_view_t const& graph_view, + Dendrogram const& dendrogram, + vertex_t* clustering) { - rmm::device_uvector vertex_ids_v(graph_view.number_of_vertices(), handle.get_stream()); + leiden_partition_at_level( + handle, dendrogram, clustering, dendrogram.num_levels()); + + rmm::device_uvector unique_cluster_ids(graph_view.number_of_vertices(), + handle.get_stream()); + thrust::copy(handle.get_thrust_policy(), + clustering, + clustering + graph_view.number_of_vertices(), + unique_cluster_ids.begin()); - thrust::sequence(handle.get_thrust_policy(), - vertex_ids_v.begin(), - vertex_ids_v.end(), - graph_view.local_vertex_partition_range_first()); + remove_duplicates(handle, unique_cluster_ids); - partition_at_level( - handle, dendrogram, vertex_ids_v.data(), clustering, dendrogram.num_levels()); + relabel_cluster_ids( + handle, unique_cluster_ids, clustering, graph_view.number_of_vertices()); } } // namespace detail @@ -588,14 +600,14 @@ std::pair>, weight_t> leiden( } template -void flatten_dendrogram(raft::handle_t const& handle, - graph_view_t const& graph_view, - Dendrogram const& dendrogram, - vertex_t* clustering) +void flatten_leiden_dendrogram(raft::handle_t const& handle, + graph_view_t const& graph_view, + Dendrogram const& dendrogram, + vertex_t* clustering) { CUGRAPH_EXPECTS(!graph_view.has_edge_mask(), "unimplemented."); - detail::flatten_dendrogram(handle, graph_view, dendrogram, clustering); + detail::flatten_leiden_dendrogram(handle, graph_view, dendrogram, clustering); } template @@ -620,7 +632,7 @@ std::pair leiden( std::tie(dendrogram, modularity) = detail::leiden(handle, rng_state, graph_view, edge_weight_view, max_level, resolution, theta); - detail::flatten_dendrogram(handle, graph_view, *dendrogram, clustering); + detail::flatten_leiden_dendrogram(handle, graph_view, *dendrogram, clustering); return std::make_pair(dendrogram->num_levels(), modularity); } diff --git a/cpp/tests/c_api/leiden_test.c b/cpp/tests/c_api/leiden_test.c index 9e91adf9f89..df206ebd1ed 100644 --- a/cpp/tests/c_api/leiden_test.c +++ b/cpp/tests/c_api/leiden_test.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -161,7 +161,7 @@ int test_leiden_no_weights() vertex_t h_src[] = {0, 1, 1, 2, 2, 2, 3, 4, 1, 3, 4, 0, 1, 3, 5, 5}; vertex_t h_dst[] = {1, 3, 4, 0, 1, 3, 5, 5, 0, 1, 1, 2, 2, 2, 3, 4}; vertex_t h_result[] = {1, 1, 1, 2, 0, 0}; - weight_t expected_modularity = 0.0859375; + weight_t expected_modularity = 0.125; // Louvain wants store_transposed = FALSE return generic_leiden_test(h_src, diff --git a/cpp/tests/c_api/louvain_test.c b/cpp/tests/c_api/louvain_test.c index e9ac5c9ff06..41d777545b2 100644 --- a/cpp/tests/c_api/louvain_test.c +++ b/cpp/tests/c_api/louvain_test.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,22 +46,39 @@ int generic_louvain_test(vertex_t* h_src, cugraph_graph_t* p_graph = NULL; cugraph_hierarchical_clustering_result_t* p_result = NULL; - data_type_id_t vertex_tid = INT32; - data_type_id_t edge_tid = INT32; - data_type_id_t weight_tid = FLOAT32; + data_type_id_t vertex_tid = INT32; + data_type_id_t edge_tid = INT32; + data_type_id_t weight_tid = FLOAT32; data_type_id_t edge_id_tid = INT32; data_type_id_t edge_type_tid = INT32; p_handle = cugraph_create_resource_handle(NULL); TEST_ASSERT(test_ret_value, p_handle != NULL, "resource handle creation failed."); - ret_code = create_sg_test_graph(p_handle, vertex_tid, edge_tid, h_src, h_dst, weight_tid, h_wgt, edge_type_tid, NULL, edge_id_tid, NULL, num_edges, store_transposed, FALSE, FALSE, FALSE, &p_graph, &ret_error); + ret_code = create_sg_test_graph(p_handle, + vertex_tid, + edge_tid, + h_src, + h_dst, + weight_tid, + h_wgt, + edge_type_tid, + NULL, + edge_id_tid, + NULL, + num_edges, + store_transposed, + FALSE, + FALSE, + FALSE, + &p_graph, + &ret_error); TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, "create_test_graph failed."); TEST_ALWAYS_ASSERT(ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); - ret_code = - cugraph_louvain(p_handle, p_graph, max_level, threshold, resolution, FALSE, &p_result, &ret_error); + ret_code = cugraph_louvain( + p_handle, p_graph, max_level, threshold, resolution, FALSE, &p_result, &ret_error); TEST_ASSERT(test_ret_value, ret_code == CUGRAPH_SUCCESS, cugraph_error_message(ret_error)); TEST_ALWAYS_ASSERT(ret_code == CUGRAPH_SUCCESS, "cugraph_louvain failed."); @@ -141,10 +158,10 @@ int test_louvain_no_weight() weight_t threshold = 1e-7; weight_t resolution = 1.0; - vertex_t h_src[] = {0, 1, 1, 2, 2, 2, 3, 4, 1, 3, 4, 0, 1, 3, 5, 5}; - vertex_t h_dst[] = {1, 3, 4, 0, 1, 3, 5, 5, 0, 1, 1, 2, 2, 2, 3, 4}; - vertex_t h_result[] = {1, 1, 1, 2, 0, 0}; - weight_t expected_modularity = 0.0859375; + vertex_t h_src[] = {0, 1, 1, 2, 2, 2, 3, 4, 1, 3, 4, 0, 1, 3, 5, 5}; + vertex_t h_dst[] = {1, 3, 4, 0, 1, 3, 5, 5, 0, 1, 1, 2, 2, 2, 3, 4}; + vertex_t h_result[] = {1, 1, 1, 1, 0, 0}; + weight_t expected_modularity = 0.125; // Louvain wants store_transposed = FALSE return generic_louvain_test(h_src, diff --git a/cpp/tests/community/louvain_test.cpp b/cpp/tests/community/louvain_test.cpp index 1e1fb6d4c33..284dcc94b8c 100644 --- a/cpp/tests/community/louvain_test.cpp +++ b/cpp/tests/community/louvain_test.cpp @@ -317,72 +317,6 @@ TEST(louvain_legacy, success) } } -TEST(louvain_legacy_renumbered, success) -{ - raft::handle_t handle; - - auto stream = handle.get_stream(); - - std::vector off_h = {0, 16, 25, 30, 34, 38, 42, 44, 46, 48, 50, 52, - 54, 56, 73, 85, 95, 101, 107, 
112, 117, 121, 125, 129, - 132, 135, 138, 141, 144, 147, 149, 151, 153, 155, 156}; - std::vector ind_h = { - 1, 3, 7, 11, 15, 16, 17, 18, 19, 20, 21, 23, 24, 25, 30, 33, 0, 5, 11, 15, 16, 19, 21, - 25, 30, 4, 13, 14, 22, 27, 0, 9, 20, 24, 2, 13, 15, 26, 1, 13, 14, 18, 13, 15, 0, 16, - 13, 14, 3, 20, 13, 14, 0, 1, 13, 22, 2, 4, 5, 6, 8, 10, 12, 14, 17, 18, 19, 22, 25, - 28, 29, 31, 32, 2, 5, 8, 10, 13, 15, 17, 18, 22, 29, 31, 32, 0, 1, 4, 6, 14, 16, 18, - 19, 21, 28, 0, 1, 7, 15, 19, 21, 0, 13, 14, 26, 27, 28, 0, 5, 13, 14, 15, 0, 1, 13, - 16, 16, 0, 3, 9, 23, 0, 1, 15, 16, 2, 12, 13, 14, 0, 20, 24, 0, 3, 23, 0, 1, 13, - 4, 17, 27, 2, 17, 26, 13, 15, 17, 13, 14, 0, 1, 13, 14, 13, 14, 0}; - - std::vector w_h = { - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}; - - int num_verts = off_h.size() - 1; - int num_edges = ind_h.size(); - - rmm::device_uvector offsets_v(num_verts + 1, stream); - rmm::device_uvector indices_v(num_edges, stream); - rmm::device_uvector weights_v(num_edges, stream); - rmm::device_uvector result_v(num_verts, stream); - - raft::update_device(offsets_v.data(), off_h.data(), off_h.size(), stream); - raft::update_device(indices_v.data(), ind_h.data(), ind_h.size(), stream); - raft::update_device(weights_v.data(), w_h.data(), w_h.size(), stream); - - cugraph::legacy::GraphCSRView G( - offsets_v.data(), indices_v.data(), weights_v.data(), num_verts, num_edges); - - float modularity{0.0}; - size_t num_level = 40; - - // "FIXME": remove this check once we drop support for Pascal - // - // Calling louvain on Pascal will throw an exception, we'll check that - // this is the behavior while we still support Pascal (device_prop.major < 7) - // - if (handle.get_device_properties().major < 7) { - EXPECT_THROW(cugraph::louvain(handle, G, result_v.data()), cugraph::logic_error); - } else { - std::tie(num_level, modularity) = cugraph::louvain(handle, G, result_v.data()); - - auto cluster_id = cugraph::test::to_host(handle, result_v); - - int min = *min_element(cluster_id.begin(), cluster_id.end()); - - ASSERT_GE(min, 0); - ASSERT_FLOAT_EQ(modularity, 0.41880345); - } -} - using Tests_Louvain_File = Tests_Louvain; using Tests_Louvain_File32 = Tests_Louvain; using Tests_Louvain_File64 = Tests_Louvain; @@ -390,11 +324,15 @@ using Tests_Louvain_Rmat = Tests_Louvain; using Tests_Louvain_Rmat32 = Tests_Louvain; using Tests_Louvain_Rmat64 = Tests_Louvain; +#if 0 +// FIXME: Reenable legacy tests once threshold parameter is exposed +// by louvain legacy API. 
TEST_P(Tests_Louvain_File, CheckInt32Int32FloatFloatLegacy) { run_legacy_test( override_File_Usecase_with_cmd_line_arguments(GetParam())); } +#endif TEST_P(Tests_Louvain_File, CheckInt32Int32FloatFloat) { @@ -458,11 +396,12 @@ TEST_P(Tests_Louvain_Rmat64, CheckInt64Int64FloatFloat) INSTANTIATE_TEST_SUITE_P( simple_test, Tests_Louvain_File, - ::testing::Combine( - ::testing::Values(Louvain_Usecase{std::nullopt, std::nullopt, std::nullopt, true, 3, 0.408695}, - Louvain_Usecase{20, double{1e-4}, std::nullopt, true, 3, 0.408695}, - Louvain_Usecase{100, double{1e-4}, double{0.8}, true, 3, 0.48336622}), - ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); + ::testing::Combine(::testing::Values( + Louvain_Usecase{ + std::nullopt, std::nullopt, std::nullopt, true, 3, 0.39907956}, + Louvain_Usecase{20, double{1e-3}, std::nullopt, true, 3, 0.39907956}, + Louvain_Usecase{100, double{1e-3}, double{0.8}, true, 3, 0.47547662}), + ::testing::Values(cugraph::test::File_Usecase("test/datasets/karate.mtx")))); INSTANTIATE_TEST_SUITE_P( file_benchmark_test, /* note that the test filename can be overridden in benchmarking (with diff --git a/python/cugraph/cugraph/tests/community/test_leiden.py b/python/cugraph/cugraph/tests/community/test_leiden.py index a06b0dd22c5..71117c4210f 100644 --- a/python/cugraph/cugraph/tests/community/test_leiden.py +++ b/python/cugraph/cugraph/tests/community/test_leiden.py @@ -22,8 +22,6 @@ from cugraph.testing import utils, UNDIRECTED_DATASETS from cugraph.datasets import karate_asymmetric -from cudf.testing.testing import assert_series_equal - # ============================================================================= # Test data @@ -43,8 +41,8 @@ "resolution": 1.0, "input_type": "COO", "expected_output": { - "partition": [1, 0, 1, 2, 2, 2], - "modularity_score": 0.1757322, + "partition": [0, 0, 0, 1, 1, 1], + "modularity_score": 0.215969, }, }, "data_2": { @@ -85,10 +83,10 @@ "input_type": "CSR", "expected_output": { # fmt: off - "partition": [6, 6, 3, 3, 1, 5, 5, 3, 0, 3, 1, 6, 3, 3, 4, 4, 5, 6, 4, 6, 4, - 6, 4, 4, 2, 2, 4, 4, 2, 4, 0, 2, 4, 4], + "partition": [3, 3, 3, 3, 2, 2, 2, 3, 1, 3, 2, 3, 3, 3, 1, 1, 2, 3, 1, 3, + 1, 3, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1], # fmt: on - "modularity_score": 0.3468113, + "modularity_score": 0.41880345, }, }, } @@ -138,7 +136,7 @@ def input_and_expected_output(request): # Create graph from csr offsets = src_or_offset_array indices = dst_or_index_array - G.from_cudf_adjlist(offsets, indices, weight) + G.from_cudf_adjlist(offsets, indices, weight, renumber=False) parts, mod = cugraph.leiden(G, max_level, resolution) @@ -223,9 +221,7 @@ def test_leiden_directed_graph(): @pytest.mark.sg def test_leiden_golden_results(input_and_expected_output): - expected_partition = cudf.Series( - input_and_expected_output["expected_output"]["partition"] - ) + expected_partition = input_and_expected_output["expected_output"]["partition"] expected_mod = input_and_expected_output["expected_output"]["modularity_score"] result_partition = input_and_expected_output["result_output"]["partition"] @@ -233,6 +229,10 @@ def test_leiden_golden_results(input_and_expected_output): assert abs(expected_mod - result_mod) < 0.0001 - assert_series_equal( - expected_partition, result_partition, check_dtype=False, check_names=False - ) + expected_to_result_map = {} + for e, r in zip(expected_partition, list(result_partition.to_pandas())): + if e in expected_to_result_map.keys(): + assert r == expected_to_result_map[e] + + else: + 
expected_to_result_map[e] = r
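
A note on the dendrogram change in PATCH 3/3: the Leiden dendrogram now stores the refined (Leiden) assignment and the Louvain assignment of the aggregated graph in alternating levels, and leiden_partition_at_level unrolls them pairwise. The following standalone host-side sketch of that pairwise unrolling is illustrative only (plain C++ on std::vector, assuming int cluster ids; not cugraph's device implementation):

    // Illustrative sketch of the pairwise unrolling done by
    // leiden_partition_at_level(): levels[0] is the finest per-vertex
    // assignment, and each pair (levels[2l + 1], levels[2l + 2]) maps
    // old cluster ids to new ones.
    #include <cstddef>
    #include <unordered_map>
    #include <vector>

    std::vector<int> flatten_leiden_levels(std::vector<std::vector<int>> const& levels)
    {
      std::vector<int> partition = levels[0];
      for (std::size_t l = 1; l + 1 < levels.size(); l += 2) {
        std::unordered_map<int, int> relabel_map;
        for (std::size_t i = 0; i < levels[l].size(); ++i) {
          relabel_map[levels[l][i]] = levels[l + 1][i];  // old id -> new id
        }
        for (auto& cluster : partition) {
          cluster = relabel_map.at(cluster);  // follow one level pair
        }
      }
      return partition;
    }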