From 629e63c164732fe8da690fa42d2bd9b990131b09 Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Wed, 25 Oct 2023 16:16:00 -0500 Subject: [PATCH] nx-cugraph: add k_truss and degree centralities (#3945) New algorithms: - `degree_centrality` - `in_degree_centrality` - `k_truss` - `number_of_selfloops` - `out_degree_centrality` Also, rename `row_indices, col_indices` to `src_indices, dst_indices` Authors: - Erik Welch (https://github.com/eriknw) Approvers: - Rick Ratzel (https://github.com/rlratzel) URL: https://github.com/rapidsai/cugraph/pull/3945 --- python/nx-cugraph/_nx_cugraph/__init__.py | 5 + python/nx-cugraph/lint.yaml | 6 +- .../nx_cugraph/algorithms/__init__.py | 1 + .../algorithms/centrality/__init__.py | 1 + .../algorithms/centrality/degree_alg.py | 48 ++++++++ .../algorithms/community/louvain.py | 2 +- .../nx-cugraph/nx_cugraph/algorithms/core.py | 96 ++++++++++++++++ .../nx_cugraph/algorithms/isolate.py | 8 +- .../nx-cugraph/nx_cugraph/classes/__init__.py | 5 +- .../nx-cugraph/nx_cugraph/classes/digraph.py | 13 ++- .../nx-cugraph/nx_cugraph/classes/function.py | 23 ++++ python/nx-cugraph/nx_cugraph/classes/graph.py | 105 ++++++++++-------- .../nx_cugraph/classes/multigraph.py | 64 ++++++----- python/nx-cugraph/nx_cugraph/convert.py | 38 +++---- .../nx_cugraph/tests/test_convert.py | 32 +++--- .../nx_cugraph/tests/test_match_api.py | 3 + python/nx-cugraph/pyproject.toml | 5 +- 17 files changed, 330 insertions(+), 125 deletions(-) create mode 100644 python/nx-cugraph/nx_cugraph/algorithms/centrality/degree_alg.py create mode 100644 python/nx-cugraph/nx_cugraph/algorithms/core.py create mode 100644 python/nx-cugraph/nx_cugraph/classes/function.py diff --git a/python/nx-cugraph/_nx_cugraph/__init__.py b/python/nx-cugraph/_nx_cugraph/__init__.py index 886d7a19f74..965b5b232ab 100644 --- a/python/nx-cugraph/_nx_cugraph/__init__.py +++ b/python/nx-cugraph/_nx_cugraph/__init__.py @@ -30,11 +30,16 @@ "functions": { # BEGIN: functions "betweenness_centrality", + 
"degree_centrality", "edge_betweenness_centrality", + "in_degree_centrality", "is_isolate", "isolates", + "k_truss", "louvain_communities", "number_of_isolates", + "number_of_selfloops", + "out_degree_centrality", # END: functions }, "extra_docstrings": { diff --git a/python/nx-cugraph/lint.yaml b/python/nx-cugraph/lint.yaml index 4f604fbeb6e..fef2cebc7f5 100644 --- a/python/nx-cugraph/lint.yaml +++ b/python/nx-cugraph/lint.yaml @@ -45,12 +45,12 @@ repos: - id: pyupgrade args: [--py39-plus] - repo: https://github.com/psf/black - rev: 23.9.1 + rev: 23.10.0 hooks: - id: black # - id: black-jupyter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.292 + rev: v0.1.1 hooks: - id: ruff args: [--fix-only, --show-fixes] @@ -77,7 +77,7 @@ repos: additional_dependencies: [tomli] files: ^(nx_cugraph|docs)/ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.292 + rev: v0.1.1 hooks: - id: ruff - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/python/nx-cugraph/nx_cugraph/algorithms/__init__.py b/python/nx-cugraph/nx_cugraph/algorithms/__init__.py index dfd9adfc61a..22600bfdc2d 100644 --- a/python/nx-cugraph/nx_cugraph/algorithms/__init__.py +++ b/python/nx-cugraph/nx_cugraph/algorithms/__init__.py @@ -12,4 +12,5 @@ # limitations under the License. from . import centrality, community from .centrality import * +from .core import * from .isolate import * diff --git a/python/nx-cugraph/nx_cugraph/algorithms/centrality/__init__.py b/python/nx-cugraph/nx_cugraph/algorithms/centrality/__init__.py index 2ac6242e8a4..af91f227843 100644 --- a/python/nx-cugraph/nx_cugraph/algorithms/centrality/__init__.py +++ b/python/nx-cugraph/nx_cugraph/algorithms/centrality/__init__.py @@ -11,3 +11,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from .betweenness import * +from .degree_alg import * diff --git a/python/nx-cugraph/nx_cugraph/algorithms/centrality/degree_alg.py b/python/nx-cugraph/nx_cugraph/algorithms/centrality/degree_alg.py new file mode 100644 index 00000000000..0b2fd24af79 --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/algorithms/centrality/degree_alg.py @@ -0,0 +1,48 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nx_cugraph.convert import _to_directed_graph, _to_graph +from nx_cugraph.utils import networkx_algorithm, not_implemented_for + +__all__ = ["degree_centrality", "in_degree_centrality", "out_degree_centrality"] + + +@networkx_algorithm +def degree_centrality(G): + G = _to_graph(G) + if len(G) <= 1: + return dict.fromkeys(G, 1) + deg = G._degrees_array() + centrality = deg * (1 / (len(G) - 1)) + return G._nodearray_to_dict(centrality) + + +@not_implemented_for("undirected") +@networkx_algorithm +def in_degree_centrality(G): + G = _to_directed_graph(G) + if len(G) <= 1: + return dict.fromkeys(G, 1) + deg = G._in_degrees_array() + centrality = deg * (1 / (len(G) - 1)) + return G._nodearray_to_dict(centrality) + + +@not_implemented_for("undirected") +@networkx_algorithm +def out_degree_centrality(G): + G = _to_directed_graph(G) + if len(G) <= 1: + return dict.fromkeys(G, 1) + deg = G._out_degrees_array() + centrality = deg * (1 / (len(G) - 1)) + return G._nodearray_to_dict(centrality) diff --git 
a/python/nx-cugraph/nx_cugraph/algorithms/community/louvain.py b/python/nx-cugraph/nx_cugraph/algorithms/community/louvain.py index dc209870c89..62261d109a2 100644 --- a/python/nx-cugraph/nx_cugraph/algorithms/community/louvain.py +++ b/python/nx-cugraph/nx_cugraph/algorithms/community/louvain.py @@ -42,7 +42,7 @@ def louvain_communities( # NetworkX allows both directed and undirected, but cugraph only allows undirected. seed = _seed_to_int(seed) # Unused, but ensure it's valid for future compatibility G = _to_undirected_graph(G, weight) - if G.row_indices.size == 0: + if G.src_indices.size == 0: # TODO: PLC doesn't handle empty graphs gracefully! return [{key} for key in G._nodeiter_to_iter(range(len(G)))] if max_level is None: diff --git a/python/nx-cugraph/nx_cugraph/algorithms/core.py b/python/nx-cugraph/nx_cugraph/algorithms/core.py new file mode 100644 index 00000000000..0a64dd71c69 --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/algorithms/core.py @@ -0,0 +1,96 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import cupy as cp +import networkx as nx +import numpy as np +import pylibcugraph as plc + +import nx_cugraph as nxcg +from nx_cugraph.utils import networkx_algorithm, not_implemented_for + +__all__ = ["k_truss"] + + +@not_implemented_for("directed") +@not_implemented_for("multigraph") +@networkx_algorithm +def k_truss(G, k): + if is_nx := isinstance(G, nx.Graph): + G = nxcg.from_networkx(G, preserve_all_attrs=True) + if nxcg.number_of_selfloops(G) > 0: + raise nx.NetworkXError( + "Input graph has self loops which is not permitted; " + "Consider using G.remove_edges_from(nx.selfloop_edges(G))." + ) + # TODO: create renumbering helper function(s) + if k < 3: + # Every edge of a k-truss is in at least k-2 triangles, so k<3 is a + # boundary condition. Here, all we need to do is drop nodes with zero degree. + # Technically, it would be okay to delete this branch of code, because + # plc.k_truss_subgraph behaves the same for 0 <= k < 3. We keep this branch b/c + # it's faster and has an "early return" if there are no nodes with zero degree. 
+ degrees = G._degrees_array() + # Renumber step 0: node indices + node_indices = degrees.nonzero()[0] + if degrees.size == node_indices.size: + # No change + return G if is_nx else G.copy() + src_indices = G.src_indices + dst_indices = G.dst_indices + # Renumber step 1: edge values (no changes needed) + edge_values = {key: val.copy() for key, val in G.edge_values.items()} + edge_masks = {key: val.copy() for key, val in G.edge_masks.items()} + else: + # int dtype for edge_indices would be preferred + edge_indices = cp.arange(G.src_indices.size, dtype=np.float64) + src_indices, dst_indices, edge_indices, _ = plc.k_truss_subgraph( + resource_handle=plc.ResourceHandle(), + graph=G._get_plc_graph(edge_array=edge_indices), + k=k, + do_expensive_check=False, + ) + # Renumber step 0: node indices + node_indices = cp.unique(cp.concatenate([src_indices, dst_indices])) + # Renumber step 1: edge values + edge_indices = edge_indices.astype(np.int64) + edge_values = {key: val[edge_indices] for key, val in G.edge_values.items()} + edge_masks = {key: val[edge_indices] for key, val in G.edge_masks.items()} + # Renumber step 2: edge indices + mapper = cp.zeros(len(G), src_indices.dtype) + mapper[node_indices] = cp.arange(node_indices.size, dtype=mapper.dtype) + src_indices = mapper[src_indices] + dst_indices = mapper[dst_indices] + # Renumber step 3: node values + node_values = {key: val[node_indices] for key, val in G.node_values.items()} + node_masks = {key: val[node_indices] for key, val in G.node_masks.items()} + # Renumber step 4: key_to_id + if (id_to_key := G.id_to_key) is not None: + key_to_id = { + id_to_key[old_index]: new_index + for new_index, old_index in enumerate(node_indices.tolist()) + } + else: + key_to_id = None + # Same as calling `G.from_coo`, but use __class__ to indicate it's a classmethod. 
+ new_graph = G.__class__.from_coo( + node_indices.size, + src_indices, + dst_indices, + edge_values, + edge_masks, + node_values, + node_masks, + key_to_id=key_to_id, + ) + new_graph.graph.update(G.graph) + return new_graph diff --git a/python/nx-cugraph/nx_cugraph/algorithms/isolate.py b/python/nx-cugraph/nx_cugraph/algorithms/isolate.py index 774627e84f6..d32223fb3ed 100644 --- a/python/nx-cugraph/nx_cugraph/algorithms/isolate.py +++ b/python/nx-cugraph/nx_cugraph/algorithms/isolate.py @@ -30,18 +30,18 @@ def is_isolate(G, n): G = _to_graph(G) index = n if G.key_to_id is None else G.key_to_id[n] return not ( - (G.row_indices == index).any().tolist() + (G.src_indices == index).any().tolist() or G.is_directed() - and (G.col_indices == index).any().tolist() + and (G.dst_indices == index).any().tolist() ) def _mark_isolates(G) -> cp.ndarray[bool]: """Return a boolean mask array indicating indices of isolated nodes.""" mark_isolates = cp.ones(len(G), bool) - mark_isolates[G.row_indices] = False + mark_isolates[G.src_indices] = False if G.is_directed(): - mark_isolates[G.col_indices] = False + mark_isolates[G.dst_indices] = False return mark_isolates diff --git a/python/nx-cugraph/nx_cugraph/classes/__init__.py b/python/nx-cugraph/nx_cugraph/classes/__init__.py index 9916bcbe241..19a5357da55 100644 --- a/python/nx-cugraph/nx_cugraph/classes/__init__.py +++ b/python/nx-cugraph/nx_cugraph/classes/__init__.py @@ -11,7 +11,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from .graph import Graph +from .digraph import DiGraph from .multigraph import MultiGraph +from .multidigraph import MultiDiGraph -from .digraph import DiGraph # isort:skip -from .multidigraph import MultiDiGraph # isort:skip +from .function import * diff --git a/python/nx-cugraph/nx_cugraph/classes/digraph.py b/python/nx-cugraph/nx_cugraph/classes/digraph.py index 72a1bff21a9..52ea2334c85 100644 --- a/python/nx-cugraph/nx_cugraph/classes/digraph.py +++ b/python/nx-cugraph/nx_cugraph/classes/digraph.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING +import cupy as cp import networkx as nx import nx_cugraph as nxcg @@ -48,7 +49,7 @@ def number_of_edges( ) -> int: if u is not None or v is not None: raise NotImplementedError - return self.row_indices.size + return self.src_indices.size ########################## # NetworkX graph methods # @@ -59,3 +60,13 @@ def reverse(self, copy: bool = True) -> DiGraph: return self._copy(not copy, self.__class__, reverse=True) # Many more methods to implement... + + ################### + # Private methods # + ################### + + def _in_degrees_array(self): + return cp.bincount(self.dst_indices, minlength=self._N) + + def _out_degrees_array(self): + return cp.bincount(self.src_indices, minlength=self._N) diff --git a/python/nx-cugraph/nx_cugraph/classes/function.py b/python/nx-cugraph/nx_cugraph/classes/function.py new file mode 100644 index 00000000000..633e4abd7f4 --- /dev/null +++ b/python/nx-cugraph/nx_cugraph/classes/function.py @@ -0,0 +1,23 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nx_cugraph.convert import _to_graph +from nx_cugraph.utils import networkx_algorithm + +__all__ = ["number_of_selfloops"] + + +@networkx_algorithm +def number_of_selfloops(G): + G = _to_graph(G) + is_selfloop = G.src_indices == G.dst_indices + return is_selfloop.sum().tolist() diff --git a/python/nx-cugraph/nx_cugraph/classes/graph.py b/python/nx-cugraph/nx_cugraph/classes/graph.py index f1e85c836e3..166b6b9dc6b 100644 --- a/python/nx-cugraph/nx_cugraph/classes/graph.py +++ b/python/nx-cugraph/nx_cugraph/classes/graph.py @@ -52,8 +52,8 @@ class Graph: # Not networkx properties - # We store edge data in COO format with {row,col}_indices and edge_values. + # We store edge data in COO format with {src,dst}_indices and edge_values. 
- row_indices: cp.ndarray[IndexValue] - col_indices: cp.ndarray[IndexValue] + src_indices: cp.ndarray[IndexValue] + dst_indices: cp.ndarray[IndexValue] edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] edge_masks: dict[AttrKey, cp.ndarray[bool]] node_values: dict[AttrKey, cp.ndarray[NodeValue]] @@ -70,8 +70,8 @@ class Graph: def from_coo( cls, N: int, - row_indices: cp.ndarray[IndexValue], - col_indices: cp.ndarray[IndexValue], + src_indices: cp.ndarray[IndexValue], + dst_indices: cp.ndarray[IndexValue], edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None, edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None, node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None, @@ -82,8 +82,8 @@ def from_coo( **attr, ) -> Graph: new_graph = object.__new__(cls) - new_graph.row_indices = row_indices - new_graph.col_indices = col_indices + new_graph.src_indices = src_indices + new_graph.dst_indices = dst_indices new_graph.edge_values = {} if edge_values is None else dict(edge_values) new_graph.edge_masks = {} if edge_masks is None else dict(edge_masks) new_graph.node_values = {} if node_values is None else dict(node_values) @@ -93,9 +93,9 @@ def from_coo( new_graph._N = op.index(N) # Ensure N is integral new_graph.graph = new_graph.graph_attr_dict_factory() new_graph.graph.update(attr) - size = new_graph.row_indices.size + size = new_graph.src_indices.size # Easy and fast sanity checks - if size != new_graph.col_indices.size: + if size != new_graph.dst_indices.size: raise ValueError for attr in ["edge_values", "edge_masks"]: if datadict := getattr(new_graph, attr): @@ -117,7 +117,7 @@ def from_coo( def from_csr( cls, indptr: cp.ndarray[IndexValue], - col_indices: cp.ndarray[IndexValue], + dst_indices: cp.ndarray[IndexValue], edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None, edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None, node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None, @@ -128,14 +128,14 @@ def from_csr( **attr, ) 
-> Graph: N = indptr.size - 1 - row_indices = cp.array( + src_indices = cp.array( # cp.repeat is slow to use here, so use numpy instead np.repeat(np.arange(N, dtype=np.int32), cp.diff(indptr).get()) ) return cls.from_coo( N, - row_indices, - col_indices, + src_indices, + dst_indices, edge_values, edge_masks, node_values, @@ -149,7 +149,7 @@ def from_csr( def from_csc( cls, indptr: cp.ndarray[IndexValue], - row_indices: cp.ndarray[IndexValue], + src_indices: cp.ndarray[IndexValue], edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None, edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None, node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None, @@ -160,14 +160,14 @@ def from_csc( **attr, ) -> Graph: N = indptr.size - 1 - col_indices = cp.array( + dst_indices = cp.array( # cp.repeat is slow to use here, so use numpy instead np.repeat(np.arange(N, dtype=np.int32), cp.diff(indptr).get()) ) return cls.from_coo( N, - row_indices, - col_indices, + src_indices, + dst_indices, edge_values, edge_masks, node_values, @@ -183,7 +183,7 @@ def from_dcsr( N: int, compressed_rows: cp.ndarray[IndexValue], indptr: cp.ndarray[IndexValue], - col_indices: cp.ndarray[IndexValue], + dst_indices: cp.ndarray[IndexValue], edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None, edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None, node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None, @@ -193,14 +193,14 @@ def from_dcsr( id_to_key: list[NodeKey] | None = None, **attr, ) -> Graph: - row_indices = cp.array( + src_indices = cp.array( # cp.repeat is slow to use here, so use numpy instead np.repeat(compressed_rows.get(), cp.diff(indptr).get()) ) return cls.from_coo( N, - row_indices, - col_indices, + src_indices, + dst_indices, edge_values, edge_masks, node_values, @@ -216,7 +216,7 @@ def from_dcsc( N: int, compressed_cols: cp.ndarray[IndexValue], indptr: cp.ndarray[IndexValue], - row_indices: cp.ndarray[IndexValue], + src_indices: 
cp.ndarray[IndexValue], edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None, edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None, node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None, @@ -226,14 +226,14 @@ def from_dcsc( id_to_key: list[NodeKey] | None = None, **attr, ) -> Graph: - col_indices = cp.array( + dst_indices = cp.array( # cp.repeat is slow to use here, so use numpy instead np.repeat(compressed_cols.get(), cp.diff(indptr).get()) ) return cls.from_coo( N, - row_indices, - col_indices, + src_indices, + dst_indices, edge_values, edge_masks, node_values, @@ -343,8 +343,8 @@ def clear(self) -> None: self.node_values.clear() self.node_masks.clear() self.graph.clear() - self.row_indices = cp.empty(0, self.row_indices.dtype) - self.col_indices = cp.empty(0, self.col_indices.dtype) + self.src_indices = cp.empty(0, self.src_indices.dtype) + self.dst_indices = cp.empty(0, self.dst_indices.dtype) self._N = 0 self.key_to_id = None self._id_to_key = None @@ -353,8 +353,8 @@ def clear(self) -> None: def clear_edges(self) -> None: self.edge_values.clear() self.edge_masks.clear() - self.row_indices = cp.empty(0, self.row_indices.dtype) - self.col_indices = cp.empty(0, self.col_indices.dtype) + self.src_indices = cp.empty(0, self.src_indices.dtype) + self.dst_indices = cp.empty(0, self.dst_indices.dtype) @networkx_api def copy(self, as_view: bool = False) -> Graph: @@ -377,7 +377,7 @@ def get_edge_data( return default except TypeError: return default - index = cp.nonzero((self.row_indices == u) & (self.col_indices == v))[0] + index = cp.nonzero((self.src_indices == u) & (self.dst_indices == v))[0] if index.size == 0: return default [index] = index.tolist() @@ -397,7 +397,7 @@ def has_edge(self, u: NodeKey, v: NodeKey) -> bool: v = self.key_to_id[v] except KeyError: return False - return bool(((self.row_indices == u) & (self.col_indices == v)).any()) + return bool(((self.src_indices == u) & (self.dst_indices == v)).any()) @networkx_api def 
has_node(self, n: NodeKey) -> bool: @@ -431,8 +431,8 @@ def order(self) -> int: def size(self, weight: AttrKey | None = None) -> int: if weight is not None: raise NotImplementedError - # If no self-edges, then `self.row_indices.size // 2` - return int((self.row_indices <= self.col_indices).sum()) + # If no self-edges, then `self.src_indices.size // 2` + return int((self.src_indices <= self.dst_indices).sum()) @networkx_api def to_directed(self, as_view: bool = False) -> nxcg.DiGraph: @@ -455,9 +455,8 @@ def to_undirected(self, as_view: bool = False) -> Graph: def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False): # DRY warning: see also MultiGraph._copy - indptr = self.indptr - row_indices = self.row_indices - col_indices = self.col_indices + src_indices = self.src_indices + dst_indices = self.dst_indices edge_values = self.edge_values edge_masks = self.edge_masks node_values = self.node_values @@ -465,9 +464,8 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False): key_to_id = self.key_to_id id_to_key = None if key_to_id is None else self._id_to_key if not as_view: - indptr = indptr.copy() - row_indices = row_indices.copy() - col_indices = col_indices.copy() + src_indices = src_indices.copy() + dst_indices = dst_indices.copy() edge_values = {key: val.copy() for key, val in edge_values.items()} edge_masks = {key: val.copy() for key, val in edge_masks.items()} node_values = {key: val.copy() for key, val in node_values.items()} @@ -477,11 +475,11 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False): if id_to_key is not None: id_to_key = id_to_key.copy() if reverse: - row_indices, col_indices = col_indices, row_indices + src_indices, dst_indices = dst_indices, src_indices rv = cls.from_coo( - indptr, - row_indices, - col_indices, + self._N, + src_indices, + dst_indices, edge_values, edge_masks, node_values, @@ -502,8 +500,11 @@ def _get_plc_graph( edge_dtype: Dtype | None = None, *, store_transposed: bool = 
False, + edge_array: cp.ndarray[EdgeValue] | None = None, ): - if edge_attr is None: + if edge_array is not None: + pass + elif edge_attr is None: edge_array = None elif edge_attr not in self.edge_values: raise KeyError("Graph has no edge attribute {edge_attr!r}") @@ -532,14 +533,20 @@ def _get_plc_graph( is_multigraph=self.is_multigraph(), is_symmetric=not self.is_directed(), ), - src_or_offset_array=self.row_indices, - dst_or_index_array=self.col_indices, + src_or_offset_array=self.src_indices, + dst_or_index_array=self.dst_indices, weight_array=edge_array, store_transposed=store_transposed, renumber=False, do_expensive_check=False, ) + def _degrees_array(self): + degrees = cp.bincount(self.src_indices, minlength=self._N) + if self.is_directed(): + degrees += cp.bincount(self.dst_indices, minlength=self._N) + return degrees + def _nodeiter_to_iter(self, node_ids: Iterable[IndexValue]) -> Iterable[NodeKey]: """Convert an iterable of node IDs to an iterable of node keys.""" if (id_to_key := self.id_to_key) is not None: @@ -551,6 +558,14 @@ def _nodearray_to_list(self, node_ids: cp.ndarray[IndexValue]) -> list[NodeKey]: return node_ids.tolist() return list(self._nodeiter_to_iter(node_ids.tolist())) + def _nodearray_to_dict( + self, values: cp.ndarray[NodeValue] + ) -> dict[NodeKey, NodeValue]: + it = enumerate(values.tolist()) + if (id_to_key := self.id_to_key) is not None: + return {id_to_key[key]: val for key, val in it} + return dict(it) + def _nodearrays_to_dict( self, node_ids: cp.ndarray[IndexValue], values: cp.ndarray[NodeValue] ) -> dict[NodeKey, NodeValue]: diff --git a/python/nx-cugraph/nx_cugraph/classes/multigraph.py b/python/nx-cugraph/nx_cugraph/classes/multigraph.py index 499ca7ce212..3d90861a328 100644 --- a/python/nx-cugraph/nx_cugraph/classes/multigraph.py +++ b/python/nx-cugraph/nx_cugraph/classes/multigraph.py @@ -67,8 +67,8 @@ class MultiGraph(Graph): def from_coo( cls, N: int, - row_indices: cp.ndarray[IndexValue], - col_indices: 
cp.ndarray[IndexValue], + src_indices: cp.ndarray[IndexValue], + dst_indices: cp.ndarray[IndexValue], edge_indices: cp.ndarray[IndexValue] | None = None, edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None, edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None, @@ -82,8 +82,8 @@ def from_coo( ) -> MultiGraph: new_graph = super().from_coo( N, - row_indices, - col_indices, + src_indices, + dst_indices, edge_values, edge_masks, node_values, @@ -97,7 +97,7 @@ def from_coo( # Easy and fast sanity checks if ( new_graph.edge_keys is not None - and len(new_graph.edge_keys) != row_indices.size + and len(new_graph.edge_keys) != src_indices.size ): raise ValueError return new_graph @@ -106,7 +106,7 @@ def from_coo( def from_csr( cls, indptr: cp.ndarray[IndexValue], - col_indices: cp.ndarray[IndexValue], + dst_indices: cp.ndarray[IndexValue], edge_indices: cp.ndarray[IndexValue] | None = None, edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None, edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None, @@ -119,14 +119,14 @@ def from_csr( **attr, ) -> MultiGraph: N = indptr.size - 1 - row_indices = cp.array( + src_indices = cp.array( # cp.repeat is slow to use here, so use numpy instead np.repeat(np.arange(N, dtype=np.int32), cp.diff(indptr).get()) ) return cls.from_coo( N, - row_indices, - col_indices, + src_indices, + dst_indices, edge_indices, edge_values, edge_masks, @@ -142,7 +142,7 @@ def from_csr( def from_csc( cls, indptr: cp.ndarray[IndexValue], - row_indices: cp.ndarray[IndexValue], + src_indices: cp.ndarray[IndexValue], edge_indices: cp.ndarray[IndexValue] | None = None, edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None, edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None, @@ -155,14 +155,14 @@ def from_csc( **attr, ) -> MultiGraph: N = indptr.size - 1 - col_indices = cp.array( + dst_indices = cp.array( # cp.repeat is slow to use here, so use numpy instead np.repeat(np.arange(N, dtype=np.int32), cp.diff(indptr).get()) 
) return cls.from_coo( N, - row_indices, - col_indices, + src_indices, + dst_indices, edge_indices, edge_values, edge_masks, @@ -180,7 +180,7 @@ def from_dcsr( N: int, compressed_rows: cp.ndarray[IndexValue], indptr: cp.ndarray[IndexValue], - col_indices: cp.ndarray[IndexValue], + dst_indices: cp.ndarray[IndexValue], edge_indices: cp.ndarray[IndexValue] | None = None, edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None, edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None, @@ -192,14 +192,14 @@ def from_dcsr( edge_keys: list[EdgeKey] | None = None, **attr, ) -> MultiGraph: - row_indices = cp.array( + src_indices = cp.array( # cp.repeat is slow to use here, so use numpy instead np.repeat(compressed_rows.get(), cp.diff(indptr).get()) ) return cls.from_coo( N, - row_indices, - col_indices, + src_indices, + dst_indices, edge_indices, edge_values, edge_masks, @@ -217,7 +217,7 @@ def from_dcsc( N: int, compressed_cols: cp.ndarray[IndexValue], indptr: cp.ndarray[IndexValue], - row_indices: cp.ndarray[IndexValue], + src_indices: cp.ndarray[IndexValue], edge_indices: cp.ndarray[IndexValue] | None = None, edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None, edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None, @@ -229,14 +229,14 @@ def from_dcsc( edge_keys: list[EdgeKey] | None = None, **attr, ) -> Graph: - col_indices = cp.array( + dst_indices = cp.array( # cp.repeat is slow to use here, so use numpy instead np.repeat(compressed_cols.get(), cp.diff(indptr).get()) ) return cls.from_coo( N, - row_indices, - col_indices, + src_indices, + dst_indices, edge_indices, edge_values, edge_masks, @@ -330,7 +330,7 @@ def get_edge_data( return default except TypeError: return default - mask = (self.row_indices == u) & (self.col_indices == v) + mask = (self.src_indices == u) & (self.dst_indices == v) if not mask.any(): return default if self.edge_keys is None: @@ -376,7 +376,7 @@ def has_edge(self, u: NodeKey, v: NodeKey, key: EdgeKey | None = None) -> 
bool: v = self.key_to_id[v] except KeyError: return False - mask = (self.row_indices == u) & (self.col_indices == v) + mask = (self.src_indices == u) & (self.dst_indices == v) if key is None or (self.edge_indices is None and self.edge_keys is None): return bool(mask.any()) if self.edge_keys is None: @@ -405,9 +405,8 @@ def to_undirected(self, as_view: bool = False) -> MultiGraph: def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False): # DRY warning: see also Graph._copy - indptr = self.indptr - row_indices = self.row_indices - col_indices = self.col_indices + src_indices = self.src_indices + dst_indices = self.dst_indices edge_indices = self.edge_indices edge_values = self.edge_values edge_masks = self.edge_masks @@ -417,9 +416,8 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False): id_to_key = None if key_to_id is None else self._id_to_key edge_keys = self.edge_keys if not as_view: - indptr = indptr.copy() - row_indices = row_indices.copy() - col_indices = col_indices.copy() + src_indices = src_indices.copy() + dst_indices = dst_indices.copy() edge_indices = edge_indices.copy() edge_values = {key: val.copy() for key, val in edge_values.items()} edge_masks = {key: val.copy() for key, val in edge_masks.items()} @@ -432,11 +430,11 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False): if edge_keys is not None: edge_keys = edge_keys.copy() if reverse: - row_indices, col_indices = col_indices, row_indices + src_indices, dst_indices = dst_indices, src_indices rv = cls.from_coo( - indptr, - row_indices, - col_indices, + self._N, + src_indices, + dst_indices, edge_indices, edge_values, edge_masks, diff --git a/python/nx-cugraph/nx_cugraph/convert.py b/python/nx-cugraph/nx_cugraph/convert.py index e82286e5e29..d117c8e5c03 100644 --- a/python/nx-cugraph/nx_cugraph/convert.py +++ b/python/nx-cugraph/nx_cugraph/convert.py @@ -266,12 +266,12 @@ def from_networkx( else: col_iter = map(key_to_id.__getitem__, col_iter) 
if graph.is_multigraph(): - col_indices = np.fromiter(col_iter, np.int32) + dst_indices = np.fromiter(col_iter, np.int32) num_multiedges = np.fromiter( map(len, concat(map(dict.values, adj.values()))), np.int32 ) # cp.repeat is slow to use here, so use numpy instead - col_indices = cp.array(np.repeat(col_indices, num_multiedges)) + dst_indices = cp.array(np.repeat(dst_indices, num_multiedges)) # Determine edge keys and edge ids for multigraphs edge_keys = list(concat(concat(map(dict.values, adj.values())))) edge_indices = cp.fromiter( @@ -281,7 +281,7 @@ def from_networkx( if edge_keys == edge_indices.tolist(): edge_keys = None # Prefer edge_indices else: - col_indices = cp.fromiter(col_iter, np.int32) + dst_indices = cp.fromiter(col_iter, np.int32) edge_values = {} edge_masks = {} @@ -353,12 +353,12 @@ def from_networkx( # if vals.ndim > 1: ... # cp.repeat is slow to use here, so use numpy instead - row_indices = np.repeat( + src_indices = np.repeat( np.arange(N, dtype=np.int32), np.fromiter(map(len, adj.values()), np.int32) ) if graph.is_multigraph(): - row_indices = np.repeat(row_indices, num_multiedges) - row_indices = cp.array(row_indices) + src_indices = np.repeat(src_indices, num_multiedges) + src_indices = cp.array(src_indices) node_values = {} node_masks = {} @@ -405,8 +405,8 @@ def from_networkx( klass = nxcg.MultiGraph rv = klass.from_coo( N, - row_indices, - col_indices, + src_indices, + dst_indices, edge_indices, edge_values, edge_masks, @@ -422,8 +422,8 @@ def from_networkx( klass = nxcg.Graph rv = klass.from_coo( N, - row_indices, - col_indices, + src_indices, + dst_indices, edge_values, edge_masks, node_values, @@ -496,23 +496,23 @@ def to_networkx(G: nxcg.Graph) -> nx.Graph: else: rv.add_nodes_from(range(len(G))) - row_indices = G.row_indices - col_indices = G.col_indices + src_indices = G.src_indices + dst_indices = G.dst_indices edge_values = G.edge_values edge_masks = G.edge_masks if edge_values and not G.is_directed(): # Only add upper triangle 
of the adjacency matrix so we don't double-add edges - mask = row_indices <= col_indices - row_indices = row_indices[mask] - col_indices = col_indices[mask] + mask = src_indices <= dst_indices + src_indices = src_indices[mask] + dst_indices = dst_indices[mask] edge_values = {k: v[mask] for k, v in edge_values.items()} if edge_masks: edge_masks = {k: v[mask] for k, v in edge_masks.items()} - row_indices = row_iter = row_indices.tolist() - col_indices = col_iter = col_indices.tolist() + src_indices = row_iter = src_indices.tolist() + dst_indices = col_iter = dst_indices.tolist() if id_to_key is not None: - row_iter = map(id_to_key.__getitem__, row_indices) - col_iter = map(id_to_key.__getitem__, col_indices) + row_iter = map(id_to_key.__getitem__, src_indices) + col_iter = map(id_to_key.__getitem__, dst_indices) if G.is_multigraph() and (G.edge_keys is not None or G.edge_indices is not None): if G.edge_keys is not None: edge_keys = G.edge_keys diff --git a/python/nx-cugraph/nx_cugraph/tests/test_convert.py b/python/nx-cugraph/nx_cugraph/tests/test_convert.py index 83820f7621f..1a71b796861 100644 --- a/python/nx-cugraph/nx_cugraph/tests/test_convert.py +++ b/python/nx-cugraph/nx_cugraph/tests/test_convert.py @@ -71,8 +71,8 @@ def test_convert(graph_class): ]: # All edges have "x" attribute, so all kwargs are equivalent Gcg = nxcg.from_networkx(G, **kwargs) - cp.testing.assert_array_equal(Gcg.row_indices, [0, 1]) - cp.testing.assert_array_equal(Gcg.col_indices, [1, 0]) + cp.testing.assert_array_equal(Gcg.src_indices, [0, 1]) + cp.testing.assert_array_equal(Gcg.dst_indices, [1, 0]) cp.testing.assert_array_equal(Gcg.edge_values["x"], [2, 2]) assert len(Gcg.edge_values) == 1 assert Gcg.edge_masks == {} @@ -86,8 +86,8 @@ def test_convert(graph_class): # Structure-only graph (no edge attributes) Gcg = nxcg.from_networkx(G, preserve_node_attrs=True) - cp.testing.assert_array_equal(Gcg.row_indices, [0, 1]) - cp.testing.assert_array_equal(Gcg.col_indices, [1, 0]) + 
cp.testing.assert_array_equal(Gcg.src_indices, [0, 1]) + cp.testing.assert_array_equal(Gcg.dst_indices, [1, 0]) cp.testing.assert_array_equal(Gcg.node_values["foo"], [10, 20]) assert Gcg.edge_values == Gcg.edge_masks == {} H = nxcg.to_networkx(Gcg) @@ -99,8 +99,8 @@ def test_convert(graph_class): # Fill completely missing attribute with default value Gcg = nxcg.from_networkx(G, edge_attrs={"y": 0}) - cp.testing.assert_array_equal(Gcg.row_indices, [0, 1]) - cp.testing.assert_array_equal(Gcg.col_indices, [1, 0]) + cp.testing.assert_array_equal(Gcg.src_indices, [0, 1]) + cp.testing.assert_array_equal(Gcg.dst_indices, [1, 0]) cp.testing.assert_array_equal(Gcg.edge_values["y"], [0, 0]) assert len(Gcg.edge_values) == 1 assert Gcg.edge_masks == Gcg.node_values == Gcg.node_masks == {} @@ -111,8 +111,8 @@ def test_convert(graph_class): # If attribute is completely missing (and no default), then just ignore it Gcg = nxcg.from_networkx(G, edge_attrs={"y": None}) - cp.testing.assert_array_equal(Gcg.row_indices, [0, 1]) - cp.testing.assert_array_equal(Gcg.col_indices, [1, 0]) + cp.testing.assert_array_equal(Gcg.src_indices, [0, 1]) + cp.testing.assert_array_equal(Gcg.dst_indices, [1, 0]) assert sorted(Gcg.edge_values) == sorted(Gcg.edge_masks) == [] H = nxcg.to_networkx(Gcg) assert list(H.edges(data=True)) == [(0, 1, {})] @@ -123,8 +123,8 @@ def test_convert(graph_class): # Some edges are missing 'x' attribute; need to use a mask for kwargs in [{"preserve_edge_attrs": True}, {"edge_attrs": {"x": None}}]: Gcg = nxcg.from_networkx(G, **kwargs) - cp.testing.assert_array_equal(Gcg.row_indices, [0, 0, 1, 2]) - cp.testing.assert_array_equal(Gcg.col_indices, [1, 2, 0, 0]) + cp.testing.assert_array_equal(Gcg.src_indices, [0, 0, 1, 2]) + cp.testing.assert_array_equal(Gcg.dst_indices, [1, 2, 0, 0]) assert sorted(Gcg.edge_values) == sorted(Gcg.edge_masks) == ["x"] cp.testing.assert_array_equal(Gcg.edge_masks["x"], [True, False, True, False]) 
cp.testing.assert_array_equal(Gcg.edge_values["x"][Gcg.edge_masks["x"]], [2, 2]) @@ -160,8 +160,8 @@ def test_convert(graph_class): ]: Gcg = nxcg.from_networkx(G, **kwargs) assert Gcg.id_to_key == [10, 20, 30] # Remap node IDs to 0, 1, ... - cp.testing.assert_array_equal(Gcg.row_indices, [0, 0, 1, 2]) - cp.testing.assert_array_equal(Gcg.col_indices, [1, 2, 0, 0]) + cp.testing.assert_array_equal(Gcg.src_indices, [0, 0, 1, 2]) + cp.testing.assert_array_equal(Gcg.dst_indices, [1, 2, 0, 0]) cp.testing.assert_array_equal(Gcg.edge_values["x"], [1, 2, 1, 2]) assert sorted(Gcg.edge_masks) == ["y"] cp.testing.assert_array_equal(Gcg.edge_masks["y"], [False, True, False, True]) @@ -181,8 +181,8 @@ def test_convert(graph_class): ]: Gcg = nxcg.from_networkx(G, **kwargs) assert Gcg.id_to_key == [10, 20, 30] # Remap node IDs to 0, 1, ... - cp.testing.assert_array_equal(Gcg.row_indices, [0, 0, 1, 2]) - cp.testing.assert_array_equal(Gcg.col_indices, [1, 2, 0, 0]) + cp.testing.assert_array_equal(Gcg.src_indices, [0, 0, 1, 2]) + cp.testing.assert_array_equal(Gcg.dst_indices, [1, 2, 0, 0]) cp.testing.assert_array_equal(Gcg.node_values["foo"], [100, 200, 300]) assert sorted(Gcg.node_masks) == ["bar"] cp.testing.assert_array_equal(Gcg.node_masks["bar"], [False, True, False]) @@ -202,8 +202,8 @@ def test_convert(graph_class): ]: Gcg = nxcg.from_networkx(G, **kwargs) assert Gcg.id_to_key == [10, 20, 30] # Remap node IDs to 0, 1, ... 
- cp.testing.assert_array_equal(Gcg.row_indices, [0, 0, 1, 2]) - cp.testing.assert_array_equal(Gcg.col_indices, [1, 2, 0, 0]) + cp.testing.assert_array_equal(Gcg.src_indices, [0, 0, 1, 2]) + cp.testing.assert_array_equal(Gcg.dst_indices, [1, 2, 0, 0]) cp.testing.assert_array_equal(Gcg.node_values["bar"], [0, 1000, 0]) assert Gcg.node_masks == {} diff --git a/python/nx-cugraph/nx_cugraph/tests/test_match_api.py b/python/nx-cugraph/nx_cugraph/tests/test_match_api.py index ecfda1397db..a654ff343ed 100644 --- a/python/nx-cugraph/nx_cugraph/tests/test_match_api.py +++ b/python/nx-cugraph/nx_cugraph/tests/test_match_api.py @@ -31,6 +31,9 @@ def test_match_signature_and_names(): if is_nx_30_or_31 and name in {"louvain_communities"}: continue + if name not in nx_backends._registered_algorithms: + print(f"{name} not dispatched from networkx") + continue dispatchable_func = nx_backends._registered_algorithms[name] # nx version >=3.2 uses orig_func, version >=3.0,<3.2 uses _orig_func if is_nx_30_or_31: diff --git a/python/nx-cugraph/pyproject.toml b/python/nx-cugraph/pyproject.toml index 2478a02df9b..9fec8fa0242 100644 --- a/python/nx-cugraph/pyproject.toml +++ b/python/nx-cugraph/pyproject.toml @@ -81,7 +81,10 @@ float_to_top = true default_section = "THIRDPARTY" known_first_party = "nx_cugraph" line_length = 88 -extend_skip_glob = ["nx_cugraph/__init__.py"] +extend_skip_glob = [ + "nx_cugraph/__init__.py", + "nx_cugraph/classes/__init__.py", +] [tool.pytest.ini_options] minversion = "6.0"