diff --git a/cpp/include/cugraph/mtmg/resource_manager.hpp b/cpp/include/cugraph/mtmg/resource_manager.hpp index bc312c9ae77..a9e4b81f894 100644 --- a/cpp/include/cugraph/mtmg/resource_manager.hpp +++ b/cpp/include/cugraph/mtmg/resource_manager.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -106,9 +106,9 @@ class resource_manager_t { auto per_device_it = per_device_rmm_resources_.insert( std::pair{global_rank, std::make_shared()}); #else - auto const [free, total] = rmm::detail::available_device_memory(); + auto const [free, total] = rmm::available_device_memory(); auto const min_alloc = - rmm::detail::align_down(std::min(free, total / 6), rmm::detail::CUDA_ALLOCATION_ALIGNMENT); + rmm::align_down(std::min(free, total / 6), rmm::CUDA_ALLOCATION_ALIGNMENT); auto per_device_it = per_device_rmm_resources_.insert( std::pair{global_rank, diff --git a/cpp/tests/utilities/base_fixture.hpp b/cpp/tests/utilities/base_fixture.hpp index 3fe70081614..7471e005ca0 100644 --- a/cpp/tests/utilities/base_fixture.hpp +++ b/cpp/tests/utilities/base_fixture.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -73,9 +73,8 @@ inline auto make_pool() // run more than 2 tests in parallel at the same time. Changes to this value could // effect the maximum amount of parallel tests, and therefore `tests/CMakeLists.txt` // `_CUGRAPH_TEST_PERCENT` default value will need to be audited. - auto const [free, total] = rmm::detail::available_device_memory(); - auto const min_alloc = - rmm::detail::align_down(std::min(free, total / 6), rmm::detail::CUDA_ALLOCATION_ALIGNMENT); + auto const [free, total] = rmm::available_device_memory(); + auto const min_alloc = rmm::align_down(std::min(free, total / 6), rmm::CUDA_ALLOCATION_ALIGNMENT); return rmm::mr::make_owning_wrapper(make_cuda(), min_alloc); } diff --git a/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py b/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py index 6b95d4d30ab..ed3824d708b 100644 --- a/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py +++ b/python/cugraph/cugraph/structure/graph_implementation/simpleDistributedGraph.py @@ -35,11 +35,9 @@ from cugraph.structure.number_map import NumberMap from cugraph.structure.symmetrize import symmetrize from cugraph.dask.common.part_utils import ( - get_persisted_df_worker_map, persist_dask_df_equal_parts_per_worker, ) from cugraph.dask.common.mg_utils import run_gc_on_dask_cluster -from cugraph.dask import get_n_workers import cugraph.dask.comms.comms as Comms @@ -828,12 +826,13 @@ def get_two_hop_neighbors(self, start_vertices=None): _client = default_client() def _call_plc_two_hop_neighbors(sID, mg_graph_x, start_vertices): - return pylibcugraph_get_two_hop_neighbors( + results_ = pylibcugraph_get_two_hop_neighbors( resource_handle=ResourceHandle(Comms.get_handle(sID).getHandle()), graph=mg_graph_x, start_vertices=start_vertices, do_expensive_check=False, ) + return results_ if isinstance(start_vertices, int): start_vertices = [start_vertices] @@ -848,31 +847,31 @@ def _call_plc_two_hop_neighbors(sID, mg_graph_x, start_vertices): else: start_vertices_type = self.input_df.dtypes[0] - if not isinstance(start_vertices, (dask_cudf.Series)): - start_vertices = dask_cudf.from_cudf( + start_vertices = start_vertices.astype(start_vertices_type) + + def create_iterable_args( + session_id, input_graph, start_vertices=None, npartitions=None + ): + session_id_it = [session_id] * npartitions + graph_it = input_graph.values() + start_vertices = cp.array_split(start_vertices.values, npartitions) + return [ + session_id_it, + graph_it, start_vertices, - npartitions=min(self._npartitions, len(start_vertices)), - ) - start_vertices = start_vertices.astype(start_vertices_type) + ] - n_workers = get_n_workers() - start_vertices = start_vertices.repartition(npartitions=n_workers) - start_vertices = persist_dask_df_equal_parts_per_worker( - start_vertices, _client + result = _client.map( + _call_plc_two_hop_neighbors, + *create_iterable_args( + Comms.get_session_id(), + self._plc_graph, + start_vertices, + self._npartitions, + ), + pure=False, ) - start_vertices = get_persisted_df_worker_map(start_vertices, _client) - result = [ - _client.submit( - _call_plc_two_hop_neighbors, - Comms.get_session_id(), - self._plc_graph[w], - start_vertices[w][0], - workers=[w], - allow_other_workers=False, - ) - for w in start_vertices.keys() - ] else: result = [ _client.submit( @@ -899,7 +898,8 @@ def convert_to_cudf(cp_arrays): return df cudf_result = [ - _client.submit(convert_to_cudf, cp_arrays) for cp_arrays in result + _client.submit(convert_to_cudf, cp_arrays, pure=False) + for cp_arrays in result ] wait(cudf_result) diff --git a/python/nx-cugraph/_nx_cugraph/__init__.py b/python/nx-cugraph/_nx_cugraph/__init__.py index 4e869c76b7a..8deac55f4ad 100644 --- a/python/nx-cugraph/_nx_cugraph/__init__.py +++ b/python/nx-cugraph/_nx_cugraph/__init__.py @@ -30,6 +30,7 @@ "functions": { # BEGIN: functions "ancestors", + "average_clustering", "barbell_graph", "betweenness_centrality", "bfs_edges", @@ -41,10 +42,12 @@ "caveman_graph", "chvatal_graph", "circular_ladder_graph", + "clustering", "complete_bipartite_graph", "complete_graph", "complete_multipartite_graph", "connected_components", + "core_number", "cubical_graph", "cycle_graph", "davis_southern_women_graph", @@ -68,8 +71,11 @@ "house_x_graph", "icosahedral_graph", "in_degree_centrality", + "is_bipartite", "is_connected", "is_isolate", + "is_strongly_connected", + "is_weakly_connected", "isolates", "k_truss", "karate_club_graph", @@ -85,33 +91,44 @@ "number_connected_components", "number_of_isolates", "number_of_selfloops", + "number_strongly_connected_components", + "number_weakly_connected_components", "octahedral_graph", "out_degree_centrality", + "overall_reciprocity", "pagerank", "pappus_graph", "path_graph", "petersen_graph", + "reciprocity", "sedgewick_maze_graph", "single_source_shortest_path_length", "single_target_shortest_path_length", "star_graph", + "strongly_connected_components", "tadpole_graph", "tetrahedral_graph", + "transitivity", + "triangles", "trivial_graph", "truncated_cube_graph", "truncated_tetrahedron_graph", "turan_graph", "tutte_graph", + "weakly_connected_components", "wheel_graph", # END: functions }, "extra_docstrings": { # BEGIN: extra_docstrings + "average_clustering": "Directed graphs and `weight` parameter are not yet supported.", "betweenness_centrality": "`weight` parameter is not yet supported, and RNG with seed may be different.", "bfs_edges": "`sort_neighbors` parameter is not yet supported.", "bfs_predecessors": "`sort_neighbors` parameter is not yet supported.", "bfs_successors": "`sort_neighbors` parameter is not yet supported.", "bfs_tree": "`sort_neighbors` parameter is not yet supported.", + "clustering": "Directed graphs and `weight` parameter are not yet supported.", + "core_number": "Directed graphs are not yet supported.", "edge_betweenness_centrality": "`weight` parameter is not yet supported, and RNG with seed may be different.", "eigenvector_centrality": "`nstart` parameter is not used, but it is checked for validity.", "from_pandas_edgelist": "cudf.DataFrame inputs also supported; value columns with str is unsuppported.", @@ -123,6 +140,7 @@ "katz_centrality": "`nstart` isn't used (but is checked), and `normalized=False` is not supported.", "louvain_communities": "`seed` parameter is currently ignored, and self-loops are not yet supported.", "pagerank": "`dangling` parameter is not supported, but it is checked for validity.", + "transitivity": "Directed graphs are not yet supported.", # END: extra_docstrings }, "extra_parameters": { diff --git a/python/nx-cugraph/lint.yaml b/python/nx-cugraph/lint.yaml index de6f20bc439..0d4f0b59413 100644 --- a/python/nx-cugraph/lint.yaml +++ b/python/nx-cugraph/lint.yaml @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # # https://pre-commit.com/ # @@ -36,7 +36,7 @@ repos: - id: autoflake args: [--in-place] - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort - repo: https://github.com/asottile/pyupgrade @@ -45,23 +45,23 @@ repos: - id: pyupgrade args: [--py39-plus] - repo: https://github.com/psf/black - rev: 23.11.0 + rev: 23.12.1 hooks: - id: black # - id: black-jupyter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.7 + rev: v0.1.13 hooks: - id: ruff args: [--fix-only, --show-fixes] # --unsafe-fixes] - repo: https://github.com/PyCQA/flake8 - rev: 6.1.0 + rev: 7.0.0 hooks: - id: flake8 args: ['--per-file-ignores=_nx_cugraph/__init__.py:E501', '--extend-ignore=SIM105'] # Why is this necessary? additional_dependencies: &flake8_dependencies # These versions need updated manually - - flake8==6.1.0 + - flake8==7.0.0 - flake8-bugbear==23.12.2 - flake8-simplify==0.21.0 - repo: https://github.com/asottile/yesqa @@ -77,7 +77,7 @@ repos: additional_dependencies: [tomli] files: ^(nx_cugraph|docs)/ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.7 + rev: v0.1.13 hooks: - id: ruff - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/python/nx-cugraph/nx_cugraph/algorithms/__init__.py b/python/nx-cugraph/nx_cugraph/algorithms/__init__.py index d28a629fe63..de4e9466ba0 100644 --- a/python/nx-cugraph/nx_cugraph/algorithms/__init__.py +++ b/python/nx-cugraph/nx_cugraph/algorithms/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -13,18 +13,21 @@ from . import ( bipartite, centrality, + cluster, community, components, link_analysis, shortest_paths, traversal, ) -from .bipartite import complete_bipartite_graph +from .bipartite import complete_bipartite_graph, is_bipartite from .centrality import * +from .cluster import * from .components import * from .core import * from .dag import * from .isolate import * from .link_analysis import * +from .reciprocity import * from .shortest_paths import * from .traversal import * diff --git a/python/nx-cugraph/nx_cugraph/algorithms/bipartite/__init__.py b/python/nx-cugraph/nx_cugraph/algorithms/bipartite/__init__.py index 062be973d55..e028299c675 100644 --- a/python/nx-cugraph/nx_cugraph/algorithms/bipartite/__init__.py +++ b/python/nx-cugraph/nx_cugraph/algorithms/bipartite/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -10,4 +10,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from .basic import * from .generators import * diff --git a/python/nx-cugraph/nx_cugraph/algorithms/bipartite/basic.py b/python/nx-cugraph/nx_cugraph/algorithms/bipartite/basic.py new file mode 100644 index 00000000000..d0e9a5c7f1b --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/algorithms/bipartite/basic.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import cupy as cp + +from nx_cugraph.algorithms.cluster import _triangles +from nx_cugraph.convert import _to_graph +from nx_cugraph.utils import networkx_algorithm + +__all__ = [ + "is_bipartite", +] + + +@networkx_algorithm(plc="triangle_count", version_added="24.02") +def is_bipartite(G): + G = _to_graph(G) + # Counting triangles may not be the fastest way to do this, but it is simple. + node_ids, triangles, is_single_node = _triangles( + G, None, symmetrize="union" if G.is_directed() else None + ) + return int(cp.count_nonzero(triangles)) == 0 diff --git a/python/nx-cugraph/nx_cugraph/algorithms/cluster.py b/python/nx-cugraph/nx_cugraph/algorithms/cluster.py new file mode 100644 index 00000000000..951c358ff26 --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/algorithms/cluster.py @@ -0,0 +1,136 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import cupy as cp +import pylibcugraph as plc + +from nx_cugraph.convert import _to_undirected_graph +from nx_cugraph.utils import networkx_algorithm, not_implemented_for + +__all__ = [ + "triangles", + "average_clustering", + "clustering", + "transitivity", +] + + +def _triangles(G, nodes, symmetrize=None): + if nodes is not None: + if is_single_node := (nodes in G): + nodes = [nodes if G.key_to_id is None else G.key_to_id[nodes]] + else: + nodes = list(nodes) + nodes = G._list_to_nodearray(nodes) + else: + is_single_node = False + if len(G) == 0: + return None, None, is_single_node + node_ids, triangles = plc.triangle_count( + resource_handle=plc.ResourceHandle(), + graph=G._get_plc_graph(symmetrize=symmetrize), + start_list=nodes, + do_expensive_check=False, + ) + return node_ids, triangles, is_single_node + + +@not_implemented_for("directed") +@networkx_algorithm(plc="triangle_count", version_added="24.02") +def triangles(G, nodes=None): + G = _to_undirected_graph(G) + node_ids, triangles, is_single_node = _triangles(G, nodes) + if len(G) == 0: + return {} + if is_single_node: + return int(triangles[0]) + return G._nodearrays_to_dict(node_ids, triangles) + + +@not_implemented_for("directed") +@networkx_algorithm(is_incomplete=True, plc="triangle_count", version_added="24.02") +def clustering(G, nodes=None, weight=None): + """Directed graphs and `weight` parameter are not yet supported.""" + G = _to_undirected_graph(G) + node_ids, triangles, is_single_node = _triangles(G, nodes) + if len(G) == 0: + return {} + if is_single_node: + numer = int(triangles[0]) + if numer == 0: + return 0 + degree = int((G.src_indices == nodes).sum()) + return 2 * numer / (degree * (degree - 1)) + degrees = G._degrees_array(ignore_selfloops=True)[node_ids] + denom = degrees * (degrees - 1) + results = 2 * triangles / denom + results = cp.where(denom, results, 0) # 0 where we divided by 0 + return G._nodearrays_to_dict(node_ids, results) + + +@clustering._can_run +def _(G, nodes=None, weight=None): + return weight is None and not G.is_directed() + + +@not_implemented_for("directed") +@networkx_algorithm(is_incomplete=True, plc="triangle_count", version_added="24.02") +def average_clustering(G, nodes=None, weight=None, count_zeros=True): + """Directed graphs and `weight` parameter are not yet supported.""" + G = _to_undirected_graph(G) + node_ids, triangles, is_single_node = _triangles(G, nodes) + if len(G) == 0: + raise ZeroDivisionError + degrees = G._degrees_array(ignore_selfloops=True)[node_ids] + if not count_zeros: + mask = triangles != 0 + triangles = triangles[mask] + if triangles.size == 0: + raise ZeroDivisionError + degrees = degrees[mask] + denom = degrees * (degrees - 1) + results = 2 * triangles / denom + if count_zeros: + results = cp.where(denom, results, 0) # 0 where we divided by 0 + return float(results.mean()) + + +@average_clustering._can_run +def _(G, nodes=None, weight=None, count_zeros=True): + return weight is None and not G.is_directed() + + +@not_implemented_for("directed") +@networkx_algorithm(is_incomplete=True, plc="triangle_count", version_added="24.02") +def transitivity(G): + """Directed graphs are not yet supported.""" + G = _to_undirected_graph(G) + if len(G) == 0: + return 0 + node_ids, triangles = plc.triangle_count( + resource_handle=plc.ResourceHandle(), + graph=G._get_plc_graph(), + start_list=None, + do_expensive_check=False, + ) + numer = int(triangles.sum()) + if numer == 0: + return 0 + degrees = G._degrees_array(ignore_selfloops=True)[node_ids] + denom = int((degrees * (degrees - 1)).sum()) + return 2 * numer / denom + + +@transitivity._can_run +def _(G): + # Is transitivity supposed to work on directed graphs? + return not G.is_directed() diff --git a/python/nx-cugraph/nx_cugraph/algorithms/components/__init__.py b/python/nx-cugraph/nx_cugraph/algorithms/components/__init__.py index 26816ef3692..12a09b535c0 100644 --- a/python/nx-cugraph/nx_cugraph/algorithms/components/__init__.py +++ b/python/nx-cugraph/nx_cugraph/algorithms/components/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -11,3 +11,5 @@ # See the License for the specific language governing permissions and # limitations under the License. from .connected import * +from .strongly_connected import * +from .weakly_connected import * diff --git a/python/nx-cugraph/nx_cugraph/algorithms/components/connected.py b/python/nx-cugraph/nx_cugraph/algorithms/components/connected.py index 95cc907a82b..cdb9f54f6c4 100644 --- a/python/nx-cugraph/nx_cugraph/algorithms/components/connected.py +++ b/python/nx-cugraph/nx_cugraph/algorithms/components/connected.py @@ -29,9 +29,15 @@ @networkx_algorithm(plc="weakly_connected_components", version_added="23.12") def number_connected_components(G): G = _to_undirected_graph(G) + return _number_connected_components(G) + + +def _number_connected_components(G, symmetrize=None): + if G.src_indices.size == 0: + return len(G) unused_node_ids, labels = plc.weakly_connected_components( resource_handle=plc.ResourceHandle(), - graph=G._get_plc_graph(), + graph=G._get_plc_graph(symmetrize=symmetrize), offsets=None, indices=None, weights=None, @@ -54,11 +60,15 @@ def _(G): @networkx_algorithm(plc="weakly_connected_components", version_added="23.12") def connected_components(G): G = _to_undirected_graph(G) + return _connected_components(G) + + +def _connected_components(G, symmetrize=None): if G.src_indices.size == 0: return [{key} for key in G._nodeiter_to_iter(range(len(G)))] node_ids, labels = plc.weakly_connected_components( resource_handle=plc.ResourceHandle(), - graph=G._get_plc_graph(), + graph=G._get_plc_graph(symmetrize=symmetrize), offsets=None, indices=None, weights=None, @@ -73,13 +83,19 @@ def connected_components(G): @networkx_algorithm(plc="weakly_connected_components", version_added="23.12") def is_connected(G): G = _to_undirected_graph(G) + return _is_connected(G) + + +def _is_connected(G, symmetrize=None): if len(G) == 0: raise nx.NetworkXPointlessConcept( "Connectivity is undefined for the null graph." ) + if G.src_indices.size == 0: + return len(G) == 1 unused_node_ids, labels = plc.weakly_connected_components( resource_handle=plc.ResourceHandle(), - graph=G._get_plc_graph(), + graph=G._get_plc_graph(symmetrize=symmetrize), offsets=None, indices=None, weights=None, diff --git a/python/nx-cugraph/nx_cugraph/algorithms/components/strongly_connected.py b/python/nx-cugraph/nx_cugraph/algorithms/components/strongly_connected.py new file mode 100644 index 00000000000..8fdf99ed5ea --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/algorithms/components/strongly_connected.py @@ -0,0 +1,85 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import cupy as cp +import networkx as nx +import pylibcugraph as plc + +from nx_cugraph.convert import _to_directed_graph +from nx_cugraph.utils import ( + _groupby, + index_dtype, + networkx_algorithm, + not_implemented_for, +) + +__all__ = [ + "number_strongly_connected_components", + "strongly_connected_components", + "is_strongly_connected", +] + + +def _strongly_connected_components(G): + # TODO: create utility function to convert just the indices to CSR + # TODO: this uses a legacy PLC function (strongly_connected_components) + N = len(G) + indices = cp.lexsort(cp.vstack((G.dst_indices, G.src_indices))) + dst_indices = G.dst_indices[indices] + offsets = cp.searchsorted( + G.src_indices, cp.arange(N + 1, dtype=index_dtype), sorter=indices + ).astype(index_dtype) + labels = cp.zeros(N, dtype=index_dtype) + plc.strongly_connected_components( + offsets=offsets, + indices=dst_indices, + weights=None, + num_verts=N, + num_edges=dst_indices.size, + labels=labels, + ) + return labels + + +@not_implemented_for("undirected") +@networkx_algorithm(version_added="24.02", plc="strongly_connected_components") +def strongly_connected_components(G): + G = _to_directed_graph(G) + if G.src_indices.size == 0: + return [{key} for key in G._nodeiter_to_iter(range(len(G)))] + labels = _strongly_connected_components(G) + groups = _groupby(labels, cp.arange(len(G), dtype=index_dtype)) + return (G._nodearray_to_set(connected_ids) for connected_ids in groups.values()) + + +@not_implemented_for("undirected") +@networkx_algorithm(version_added="24.02", plc="strongly_connected_components") +def number_strongly_connected_components(G): + G = _to_directed_graph(G) + if G.src_indices.size == 0: + return len(G) + labels = _strongly_connected_components(G) + return cp.unique(labels).size + + +@not_implemented_for("undirected") +@networkx_algorithm(version_added="24.02", plc="strongly_connected_components") +def is_strongly_connected(G): + G = _to_directed_graph(G) + if len(G) == 0: + raise nx.NetworkXPointlessConcept( + "Connectivity is undefined for the null graph." + ) + if G.src_indices.size == 0: + return len(G) == 1 + labels = _strongly_connected_components(G) + return bool((labels == labels[0]).all()) diff --git a/python/nx-cugraph/nx_cugraph/algorithms/components/weakly_connected.py b/python/nx-cugraph/nx_cugraph/algorithms/components/weakly_connected.py new file mode 100644 index 00000000000..5b797b39118 --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/algorithms/components/weakly_connected.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nx_cugraph.convert import _to_directed_graph +from nx_cugraph.utils import networkx_algorithm, not_implemented_for + +from .connected import ( + _connected_components, + _is_connected, + _number_connected_components, +) + +__all__ = [ + "number_weakly_connected_components", + "weakly_connected_components", + "is_weakly_connected", +] + + +@not_implemented_for("undirected") +@networkx_algorithm(plc="weakly_connected_components", version_added="24.02") +def weakly_connected_components(G): + G = _to_directed_graph(G) + return _connected_components(G, symmetrize="union") + + +@not_implemented_for("undirected") +@networkx_algorithm(plc="weakly_connected_components", version_added="24.02") +def number_weakly_connected_components(G): + G = _to_directed_graph(G) + return _number_connected_components(G, symmetrize="union") + + +@not_implemented_for("undirected") +@networkx_algorithm(plc="weakly_connected_components", version_added="24.02") +def is_weakly_connected(G): + G = _to_directed_graph(G) + return _is_connected(G, symmetrize="union") diff --git a/python/nx-cugraph/nx_cugraph/algorithms/core.py b/python/nx-cugraph/nx_cugraph/algorithms/core.py index e4520c2713b..f323cdf6004 100644 --- a/python/nx-cugraph/nx_cugraph/algorithms/core.py +++ b/python/nx-cugraph/nx_cugraph/algorithms/core.py @@ -15,6 +15,7 @@ import pylibcugraph as plc import nx_cugraph as nxcg +from nx_cugraph.convert import _to_undirected_graph from nx_cugraph.utils import ( _get_int_dtype, index_dtype, @@ -22,7 +23,34 @@ not_implemented_for, ) -__all__ = ["k_truss"] +__all__ = ["core_number", "k_truss"] + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@networkx_algorithm(is_incomplete=True, plc="core_number", version_added="24.02") +def core_number(G): + """Directed graphs are not yet supported.""" + G = _to_undirected_graph(G) + if len(G) == 0: + return {} + if nxcg.number_of_selfloops(G) > 0: + raise nx.NetworkXNotImplemented( + "Input graph has self loops which is not permitted; " + "Consider using G.remove_edges_from(nx.selfloop_edges(G))." + ) + node_ids, core_numbers = plc.core_number( + resource_handle=plc.ResourceHandle(), + graph=G._get_plc_graph(), + degree_type="bidirectional", + do_expensive_check=False, + ) + return G._nodearrays_to_dict(node_ids, core_numbers) + + +@core_number._can_run +def _(G): + return not G.is_directed() @not_implemented_for("directed") diff --git a/python/nx-cugraph/nx_cugraph/algorithms/isolate.py b/python/nx-cugraph/nx_cugraph/algorithms/isolate.py index c7e5d7113de..62b47a9b354 100644 --- a/python/nx-cugraph/nx_cugraph/algorithms/isolate.py +++ b/python/nx-cugraph/nx_cugraph/algorithms/isolate.py @@ -15,9 +15,10 @@ from typing import TYPE_CHECKING import cupy as cp +import numpy as np from nx_cugraph.convert import _to_graph -from nx_cugraph.utils import networkx_algorithm +from nx_cugraph.utils import index_dtype, networkx_algorithm if TYPE_CHECKING: # pragma: no cover from nx_cugraph.typing import IndexValue @@ -36,19 +37,28 @@ def is_isolate(G, n): ) -def _mark_isolates(G) -> cp.ndarray[bool]: +def _mark_isolates(G, symmetrize=None) -> cp.ndarray[bool]: """Return a boolean mask array indicating indices of isolated nodes.""" mark_isolates = cp.ones(len(G), bool) - mark_isolates[G.src_indices] = False - if G.is_directed(): - mark_isolates[G.dst_indices] = False + if G.is_directed() and symmetrize == "intersection": + N = G._N + # Upcast to int64 so indices don't overflow + src_dst = N * G.src_indices.astype(np.int64) + G.dst_indices + src_dst_T = G.src_indices + N * G.dst_indices.astype(np.int64) + src_dst_new = cp.intersect1d(src_dst, src_dst_T) + new_indices = cp.floor_divide(src_dst_new, N, dtype=index_dtype) + mark_isolates[new_indices] = False + else: + mark_isolates[G.src_indices] = False + if G.is_directed(): + mark_isolates[G.dst_indices] = False return mark_isolates -def _isolates(G) -> cp.ndarray[IndexValue]: +def _isolates(G, symmetrize=None) -> cp.ndarray[IndexValue]: """Like isolates, but return an array of indices instead of an iterator of nodes.""" G = _to_graph(G) - return cp.nonzero(_mark_isolates(G))[0] + return cp.nonzero(_mark_isolates(G, symmetrize=symmetrize))[0] @networkx_algorithm(version_added="23.10") diff --git a/python/nx-cugraph/nx_cugraph/algorithms/reciprocity.py b/python/nx-cugraph/nx_cugraph/algorithms/reciprocity.py new file mode 100644 index 00000000000..c87abdf9fa7 --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/algorithms/reciprocity.py @@ -0,0 +1,93 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import cupy as cp +import networkx as nx +import numpy as np + +from nx_cugraph.convert import _to_directed_graph +from nx_cugraph.utils import networkx_algorithm, not_implemented_for + +__all__ = ["reciprocity", "overall_reciprocity"] + + +@not_implemented_for("undirected", "multigraph") +@networkx_algorithm(version_added="24.02") +def reciprocity(G, nodes=None): + if nodes is None: + return overall_reciprocity(G) + G = _to_directed_graph(G) + N = G._N + # 'nodes' can also be a single node identifier + if nodes in G: + index = nodes if G.key_to_id is None else G.key_to_id[nodes] + mask = (G.src_indices == index) | (G.dst_indices == index) + src_indices = G.src_indices[mask] + if src_indices.size == 0: + raise nx.NetworkXError("Not defined for isolated nodes.") + dst_indices = G.dst_indices[mask] + # Create two lists of edge identifiers, one for each direction. + # Edge identifiers can be created from a pair of node + # identifiers. Simply adding src IDs to dst IDs is not adequate, so + # make one set of values (either src or dst depending on direction) + # unique by multiplying values by N. + # Upcast to int64 so indices don't overflow. + edges_a_b = N * src_indices.astype(np.int64) + dst_indices + edges_b_a = src_indices + N * dst_indices.astype(np.int64) + # Find the matching edge identifiers in each list. The edge identifier + # generation ensures the ID for A->B == the ID for B->A + recip_indices = cp.intersect1d( + edges_a_b, + edges_b_a, + # assume_unique=True, # cupy <= 12.2.0 also assumes sorted + ) + num_selfloops = (src_indices == dst_indices).sum().tolist() + return (recip_indices.size - num_selfloops) / edges_a_b.size + + # Don't include self-loops + mask = G.src_indices != G.dst_indices + src_indices = G.src_indices[mask] + dst_indices = G.dst_indices[mask] + # Create two lists of edges, one for each direction, and find the matching + # IDs in each list (see description above). + edges_a_b = N * src_indices.astype(np.int64) + dst_indices + edges_b_a = src_indices + N * dst_indices.astype(np.int64) + recip_indices = cp.intersect1d( + edges_a_b, + edges_b_a, + # assume_unique=True, # cupy <= 12.2.0 also assumes sorted + ) + numer = cp.bincount(recip_indices // N, minlength=N) + denom = cp.bincount(src_indices, minlength=N) + denom += cp.bincount(dst_indices, minlength=N) + recip = 2 * numer / denom + node_ids = G._nodekeys_to_nodearray(nodes) + return G._nodearrays_to_dict(node_ids, recip[node_ids]) + + +@not_implemented_for("undirected", "multigraph") +@networkx_algorithm(version_added="24.02") +def overall_reciprocity(G): + G = _to_directed_graph(G) + if G.number_of_edges() == 0: + raise nx.NetworkXError("Not defined for empty graphs") + # Create two lists of edges, one for each direction, and find the matching + # IDs in each list (see description in reciprocity()). + edges_a_b = G._N * G.src_indices.astype(np.int64) + G.dst_indices + edges_b_a = G.src_indices + G._N * G.dst_indices.astype(np.int64) + recip_indices = cp.intersect1d( + edges_a_b, + edges_b_a, + # assume_unique=True, # cupy <= 12.2.0 also assumes sorted + ) + num_selfloops = (G.src_indices == G.dst_indices).sum().tolist() + return (recip_indices.size - num_selfloops) / edges_a_b.size diff --git a/python/nx-cugraph/nx_cugraph/classes/digraph.py b/python/nx-cugraph/nx_cugraph/classes/digraph.py index 52ea2334c85..f8217a2c79f 100644 --- a/python/nx-cugraph/nx_cugraph/classes/digraph.py +++ b/python/nx-cugraph/nx_cugraph/classes/digraph.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -12,13 +12,16 @@ # limitations under the License. from __future__ import annotations +from copy import deepcopy from typing import TYPE_CHECKING import cupy as cp import networkx as nx +import numpy as np import nx_cugraph as nxcg +from ..utils import index_dtype from .graph import Graph if TYPE_CHECKING: # pragma: no cover @@ -59,14 +62,131 @@ def number_of_edges( def reverse(self, copy: bool = True) -> DiGraph: return self._copy(not copy, self.__class__, reverse=True) + @networkx_api + def to_undirected(self, reciprocal=False, as_view=False): + N = self._N + # Upcast to int64 so indices don't overflow + src_dst_indices_old = N * self.src_indices.astype(np.int64) + self.dst_indices + if reciprocal: + src_dst_indices_new = cp.intersect1d( + src_dst_indices_old, + self.src_indices + N * self.dst_indices.astype(np.int64), + # assume_unique=True, # cupy <= 12.2.0 also assumes sorted + ) + if self.edge_values: + sorter = cp.argsort(src_dst_indices_old) + idx = cp.searchsorted( + src_dst_indices_old, src_dst_indices_new, sorter=sorter + ) + indices = sorter[idx] + src_indices = self.src_indices[indices].copy() + dst_indices = self.dst_indices[indices].copy() + edge_values = { + key: val[indices].copy() for key, val in self.edge_values.items() + } + edge_masks = { + key: val[indices].copy() for key, val in self.edge_masks.items() + } + else: + src_indices, dst_indices = cp.divmod( + src_dst_indices_new, N, dtype=index_dtype + ) + else: + src_dst_indices_old_T = self.src_indices + N * self.dst_indices.astype( + np.int64 + ) + if self.edge_values: + src_dst_extra = cp.setdiff1d( + src_dst_indices_old_T, src_dst_indices_old, assume_unique=True + ) + sorter = cp.argsort(src_dst_indices_old_T) + idx = cp.searchsorted( + src_dst_indices_old_T, src_dst_extra, sorter=sorter + ) + indices = sorter[idx] + src_indices = cp.hstack((self.src_indices, self.dst_indices[indices])) + dst_indices = cp.hstack((self.dst_indices, self.src_indices[indices])) + edge_values = { + key: cp.hstack((val, val[indices])) + for key, val in self.edge_values.items() + } + edge_masks = { + key: cp.hstack((val, val[indices])) + for key, val in self.edge_masks.items() + } + else: + src_dst_indices_new = cp.union1d( + src_dst_indices_old, src_dst_indices_old_T + ) + src_indices, dst_indices = cp.divmod( + src_dst_indices_new, N, dtype=index_dtype + ) + + if self.edge_values: + recip_indices = cp.lexsort(cp.vstack((src_indices, dst_indices))) + for key, mask in edge_masks.items(): + # Make sure we choose a value that isn't masked out + val = edge_values[key] + rmask = mask[recip_indices] + recip_only = rmask & ~mask + val[recip_only] = val[recip_indices[recip_only]] + only = mask & ~rmask + val[recip_indices[only]] = val[only] + mask |= mask[recip_indices] + # Arbitrarily choose to use value from (j > i) edge + mask = src_indices < dst_indices + left_idx = cp.nonzero(mask)[0] + right_idx = recip_indices[mask] + for val in edge_values.values(): + val[left_idx] = val[right_idx] + else: + edge_values = {} + edge_masks = {} + + node_values = self.node_values + node_masks = self.node_masks + key_to_id = self.key_to_id + id_to_key = None if key_to_id is None else self._id_to_key + if not as_view: + node_values = {key: val.copy() for key, val in node_values.items()} + node_masks = {key: val.copy() for key, val in node_masks.items()} + if key_to_id is not None: + key_to_id = key_to_id.copy() + if id_to_key is not None: + id_to_key = id_to_key.copy() + rv = self.to_undirected_class().from_coo( + N, + src_indices, + dst_indices, + edge_values, + edge_masks, + node_values, + node_masks, + key_to_id=key_to_id, + id_to_key=id_to_key, + ) + if as_view: + rv.graph = self.graph + else: + rv.graph.update(deepcopy(self.graph)) + return rv + # Many more methods to implement... ################### # Private methods # ################### - def _in_degrees_array(self): - return cp.bincount(self.dst_indices, minlength=self._N) - - def _out_degrees_array(self): - return cp.bincount(self.src_indices, minlength=self._N) + def _in_degrees_array(self, *, ignore_selfloops=False): + dst_indices = self.dst_indices + if ignore_selfloops: + not_selfloops = self.src_indices != dst_indices + dst_indices = dst_indices[not_selfloops] + return cp.bincount(dst_indices, minlength=self._N) + + def _out_degrees_array(self, *, ignore_selfloops=False): + src_indices = self.src_indices + if ignore_selfloops: + not_selfloops = src_indices != self.dst_indices + src_indices = src_indices[not_selfloops] + return cp.bincount(src_indices, minlength=self._N) diff --git a/python/nx-cugraph/nx_cugraph/classes/graph.py b/python/nx-cugraph/nx_cugraph/classes/graph.py index cb6b4e7ae42..4aa2de1538e 100644 --- a/python/nx-cugraph/nx_cugraph/classes/graph.py +++ b/python/nx-cugraph/nx_cugraph/classes/graph.py @@ -531,7 +531,7 @@ def to_directed(self, as_view: bool = False) -> nxcg.DiGraph: @networkx_api def to_undirected(self, as_view: bool = False) -> Graph: # Does deep copy in networkx - return self.copy(as_view) + return self._copy(as_view, self.to_undirected_class()) # Not implemented... # adj, adjacency, add_edge, add_edges_from, add_node, @@ -592,6 +592,7 @@ def _get_plc_graph( store_transposed: bool = False, switch_indices: bool = False, edge_array: cp.ndarray[EdgeValue] | None = None, + symmetrize: str | None = None, ): if edge_array is not None or edge_attr is None: pass @@ -650,12 +651,30 @@ def _get_plc_graph( dst_indices = self.dst_indices if switch_indices: src_indices, dst_indices = dst_indices, src_indices + if symmetrize is not None: + if edge_array is not None: + raise NotImplementedError( + "edge_array must be None when symmetrizing the graph" + ) + N = self._N + # Upcast to int64 so indices don't overflow + src_dst = N * src_indices.astype(np.int64) + dst_indices + src_dst_T = src_indices + N * dst_indices.astype(np.int64) + if symmetrize == "union": + src_dst_new = cp.union1d(src_dst, src_dst_T) + elif symmetrize == "intersection": + src_dst_new = cp.intersect1d(src_dst, src_dst_T) + else: + raise ValueError( + f'symmetrize must be "union" or "intersection"; got "{symmetrize}"' + ) + src_indices, dst_indices = cp.divmod(src_dst_new, N, dtype=index_dtype) return plc.SGGraph( resource_handle=plc.ResourceHandle(), graph_properties=plc.GraphProperties( - is_multigraph=self.is_multigraph(), - is_symmetric=not self.is_directed(), + is_multigraph=self.is_multigraph() and symmetrize is None, + is_symmetric=not self.is_directed() or symmetrize is not None, ), src_or_offset_array=src_indices, dst_or_index_array=dst_indices, @@ -713,16 +732,28 @@ def _become(self, other: Graph): self.graph = graph return self - def _degrees_array(self): - degrees = cp.bincount(self.src_indices, minlength=self._N) + def _degrees_array(self, *, ignore_selfloops=False): + src_indices = self.src_indices + dst_indices = self.dst_indices + if ignore_selfloops: + not_selfloops = src_indices != dst_indices + src_indices = src_indices[not_selfloops] + if self.is_directed(): + dst_indices = dst_indices[not_selfloops] + degrees = cp.bincount(src_indices, minlength=self._N) if self.is_directed(): - degrees += cp.bincount(self.dst_indices, minlength=self._N) + degrees += cp.bincount(dst_indices, minlength=self._N) return degrees _in_degrees_array = _degrees_array _out_degrees_array = _degrees_array # Data conversions + def _nodekeys_to_nodearray(self, nodes: Iterable[NodeKey]) -> cp.array[IndexValue]: + if self.key_to_id is None: + return cp.fromiter(nodes, dtype=index_dtype) + return cp.fromiter(map(self.key_to_id.__getitem__, nodes), dtype=index_dtype) + def _nodeiter_to_iter(self, node_ids: Iterable[IndexValue]) -> Iterable[NodeKey]: """Convert an iterable of node IDs to an iterable of node keys.""" if (id_to_key := self.id_to_key) is not None: diff --git a/python/nx-cugraph/nx_cugraph/classes/multidigraph.py b/python/nx-cugraph/nx_cugraph/classes/multidigraph.py index 2c7bfc00752..2e7a55a9eb1 100644 --- a/python/nx-cugraph/nx_cugraph/classes/multidigraph.py +++ b/python/nx-cugraph/nx_cugraph/classes/multidigraph.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -33,3 +33,11 @@ def is_directed(cls) -> bool: @classmethod def to_networkx_class(cls) -> type[nx.MultiDiGraph]: return nx.MultiDiGraph + + ########################## + # NetworkX graph methods # + ########################## + + @networkx_api + def to_undirected(self, reciprocal=False, as_view=False): + raise NotImplementedError diff --git a/python/nx-cugraph/nx_cugraph/classes/multigraph.py b/python/nx-cugraph/nx_cugraph/classes/multigraph.py index 23466dc7dd4..fb787369e58 100644 --- a/python/nx-cugraph/nx_cugraph/classes/multigraph.py +++ b/python/nx-cugraph/nx_cugraph/classes/multigraph.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -399,7 +399,7 @@ def to_directed(self, as_view: bool = False) -> nxcg.MultiDiGraph: @networkx_api def to_undirected(self, as_view: bool = False) -> MultiGraph: # Does deep copy in networkx - return self.copy(as_view) + return self._copy(as_view, self.to_undirected_class()) ################### # Private methods # diff --git a/python/nx-cugraph/nx_cugraph/interface.py b/python/nx-cugraph/nx_cugraph/interface.py index 34eb5969869..04591c0e9e3 100644 --- a/python/nx-cugraph/nx_cugraph/interface.py +++ b/python/nx-cugraph/nx_cugraph/interface.py @@ -68,7 +68,12 @@ def key(testpath): louvain_different = "Louvain may be different due to RNG" no_string_dtype = "string edge values not currently supported" - xfail = {} + xfail = { + key( + "test_strongly_connected.py:" + "TestStronglyConnected.test_condensation_mapping_and_members" + ): "Strongly connected groups in different iteration order", + } from packaging.version import parse diff --git a/python/nx-cugraph/nx_cugraph/tests/test_cluster.py b/python/nx-cugraph/nx_cugraph/tests/test_cluster.py new file mode 100644 index 00000000000..ad4770f1ab8 --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/tests/test_cluster.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import networkx as nx +import pytest +from packaging.version import parse + +nxver = parse(nx.__version__) + +if nxver.major == 3 and nxver.minor < 2: + pytest.skip("Need NetworkX >=3.2 to test clustering", allow_module_level=True) + + +def test_selfloops(): + G = nx.complete_graph(5) + H = nx.complete_graph(5) + H.add_edge(0, 0) + H.add_edge(1, 1) + H.add_edge(2, 2) + # triangles + expected = nx.triangles(G) + assert expected == nx.triangles(H) + assert expected == nx.triangles(G, backend="cugraph") + assert expected == nx.triangles(H, backend="cugraph") + # average_clustering + expected = nx.average_clustering(G) + assert expected == nx.average_clustering(H) + assert expected == nx.average_clustering(G, backend="cugraph") + assert expected == nx.average_clustering(H, backend="cugraph") + # clustering + expected = nx.clustering(G) + assert expected == nx.clustering(H) + assert expected == nx.clustering(G, backend="cugraph") + assert expected == nx.clustering(H, backend="cugraph") + # transitivity + expected = nx.transitivity(G) + assert expected == nx.transitivity(H) + assert expected == nx.transitivity(G, backend="cugraph") + assert expected == nx.transitivity(H, backend="cugraph") diff --git a/python/nx-cugraph/nx_cugraph/tests/test_generators.py b/python/nx-cugraph/nx_cugraph/tests/test_generators.py index 511f8dcd8e2..c751b0fe2b3 100644 --- a/python/nx-cugraph/nx_cugraph/tests/test_generators.py +++ b/python/nx-cugraph/nx_cugraph/tests/test_generators.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. +# Copyright (c) 2023-2024, NVIDIA CORPORATION. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at @@ -17,30 +17,9 @@ import nx_cugraph as nxcg -nxver = parse(nx.__version__) - +from .testing_utils import assert_graphs_equal -def assert_graphs_equal(Gnx, Gcg): - assert isinstance(Gnx, nx.Graph) - assert isinstance(Gcg, nxcg.Graph) - assert Gnx.number_of_nodes() == Gcg.number_of_nodes() - assert Gnx.number_of_edges() == Gcg.number_of_edges() - assert Gnx.is_directed() == Gcg.is_directed() - assert Gnx.is_multigraph() == Gcg.is_multigraph() - G = nxcg.to_networkx(Gcg) - rv = nx.utils.graphs_equal(G, Gnx) - if not rv: - print("GRAPHS ARE NOT EQUAL!") - assert sorted(G) == sorted(Gnx) - assert sorted(G._adj) == sorted(Gnx._adj) - assert sorted(G._node) == sorted(Gnx._node) - for k in sorted(G._adj): - print(k, sorted(G._adj[k]), sorted(Gnx._adj[k])) - print(nx.to_scipy_sparse_array(G).todense()) - print(nx.to_scipy_sparse_array(Gnx).todense()) - print(G.graph) - print(Gnx.graph) - assert rv +nxver = parse(nx.__version__) if nxver.major == 3 and nxver.minor < 2: diff --git a/python/nx-cugraph/nx_cugraph/tests/test_graph_methods.py b/python/nx-cugraph/nx_cugraph/tests/test_graph_methods.py new file mode 100644 index 00000000000..3120995a2b2 --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/tests/test_graph_methods.py @@ -0,0 +1,67 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import networkx as nx +import pytest + +import nx_cugraph as nxcg + +from .testing_utils import assert_graphs_equal + + +def _create_Gs(): + rv = [] + rv.append(nx.DiGraph()) + G = nx.DiGraph() + G.add_edge(0, 1) + G.add_edge(1, 0) + rv.append(G) + G = G.copy() + G.add_edge(0, 2) + rv.append(G) + G = G.copy() + G.add_edge(1, 1) + rv.append(G) + G = nx.DiGraph() + G.add_edge(0, 1, x=1, y=2) + G.add_edge(1, 0, x=10, z=3) + rv.append(G) + G = G.copy() + G.add_edge(0, 2, a=42) + rv.append(G) + G = G.copy() + G.add_edge(1, 1, a=4) + rv.append(G) + return rv + + +@pytest.mark.parametrize("Gnx", _create_Gs()) +@pytest.mark.parametrize("reciprocal", [False, True]) +def test_to_undirected_directed(Gnx, reciprocal): + Gcg = nxcg.DiGraph(Gnx) + assert_graphs_equal(Gnx, Gcg) + Hnx1 = Gnx.to_undirected(reciprocal=reciprocal) + Hcg1 = Gcg.to_undirected(reciprocal=reciprocal) + assert_graphs_equal(Hnx1, Hcg1) + Hnx2 = Hnx1.to_directed() + Hcg2 = Hcg1.to_directed() + assert_graphs_equal(Hnx2, Hcg2) + + +def test_multidigraph_to_undirected(): + Gnx = nx.MultiDiGraph() + Gnx.add_edge(0, 1) + Gnx.add_edge(0, 1) + Gnx.add_edge(1, 0) + Gcg = nxcg.MultiDiGraph(Gnx) + with pytest.raises(NotImplementedError): + Gcg.to_undirected() diff --git a/python/nx-cugraph/nx_cugraph/tests/testing_utils.py b/python/nx-cugraph/nx_cugraph/tests/testing_utils.py new file mode 100644 index 00000000000..6d4741c9ca6 --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/tests/testing_utils.py @@ -0,0 +1,38 @@ +# Copyright (c) 2023-2024, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import networkx as nx + +import nx_cugraph as nxcg + + +def assert_graphs_equal(Gnx, Gcg): + assert isinstance(Gnx, nx.Graph) + assert isinstance(Gcg, nxcg.Graph) + assert Gnx.number_of_nodes() == Gcg.number_of_nodes() + assert Gnx.number_of_edges() == Gcg.number_of_edges() + assert Gnx.is_directed() == Gcg.is_directed() + assert Gnx.is_multigraph() == Gcg.is_multigraph() + G = nxcg.to_networkx(Gcg) + rv = nx.utils.graphs_equal(G, Gnx) + if not rv: + print("GRAPHS ARE NOT EQUAL!") + assert sorted(G) == sorted(Gnx) + assert sorted(G._adj) == sorted(Gnx._adj) + assert sorted(G._node) == sorted(Gnx._node) + for k in sorted(G._adj): + print(k, sorted(G._adj[k]), sorted(Gnx._adj[k])) + print(nx.to_scipy_sparse_array(G).todense()) + print(nx.to_scipy_sparse_array(Gnx).todense()) + print(G.graph) + print(Gnx.graph) + assert rv