Skip to content

Commit

Permalink
Add multigraph support to nx-cugraph
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Oct 13, 2023
1 parent 2c1626f commit 48dcf73
Show file tree
Hide file tree
Showing 9 changed files with 468 additions and 85 deletions.
14 changes: 7 additions & 7 deletions python/nx-cugraph/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ default_language_version:
python: python3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: check-added-large-files
- id: check-case-conflict
Expand All @@ -26,7 +26,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.14
rev: v0.15
hooks:
- id: validate-pyproject
name: Validate pyproject.toml
Expand All @@ -40,7 +40,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/asottile/pyupgrade
rev: v3.13.0
rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py39-plus]
Expand All @@ -50,7 +50,7 @@ repos:
- id: black
# - id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.291
rev: v0.0.292
hooks:
- id: ruff
args: [--fix-only, --show-fixes]
Expand All @@ -70,18 +70,18 @@ repos:
- id: yesqa
additional_dependencies: *flake8_dependencies
- repo: https://github.com/codespell-project/codespell
rev: v2.2.5
rev: v2.2.6
hooks:
- id: codespell
types_or: [python, rst, markdown]
additional_dependencies: [tomli]
files: ^(nx_cugraph|docs)/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.291
rev: v0.0.292
hooks:
- id: ruff
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v4.5.0
hooks:
- id: no-commit-to-branch
args: [-p, "^branch-2....$"]
2 changes: 2 additions & 0 deletions python/nx-cugraph/nx_cugraph/classes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .graph import Graph
from .multigraph import MultiGraph

from .digraph import DiGraph # isort:skip
from .multidigraph import MultiDiGraph # isort:skip
22 changes: 12 additions & 10 deletions python/nx-cugraph/nx_cugraph/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@

class Graph:
# Tell networkx to dispatch calls with this object to nx-cugraph
__networkx_plugin__: ClassVar[str] = "cugraph"
__networkx_backend__: ClassVar[str] = "cugraph" # nx >=3.2
__networkx_plugin__: ClassVar[str] = "cugraph" # nx <3.2

# networkx properties
graph: dict
Expand All @@ -58,7 +59,7 @@ class Graph:
node_values: dict[AttrKey, cp.ndarray[NodeValue]]
node_masks: dict[AttrKey, cp.ndarray[bool]]
key_to_id: dict[NodeKey, IndexValue] | None
_id_to_key: dict[IndexValue, NodeKey] | None
_id_to_key: list[NodeKey] | None
_N: int

####################
Expand All @@ -77,7 +78,7 @@ def from_coo(
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: dict[IndexValue, NodeKey] | None = None,
id_to_key: list[NodeKey] | None = None,
**attr,
) -> Graph:
new_graph = object.__new__(cls)
Expand All @@ -88,7 +89,7 @@ def from_coo(
new_graph.node_values = {} if node_values is None else dict(node_values)
new_graph.node_masks = {} if node_masks is None else dict(node_masks)
new_graph.key_to_id = None if key_to_id is None else dict(key_to_id)
new_graph._id_to_key = None if id_to_key is None else dict(id_to_key)
new_graph._id_to_key = None if id_to_key is None else list(id_to_key)
new_graph._N = op.index(N) # Ensure N is integral
new_graph.graph = new_graph.graph_attr_dict_factory()
new_graph.graph.update(attr)
Expand Down Expand Up @@ -123,7 +124,7 @@ def from_csr(
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: dict[IndexValue, NodeKey] | None = None,
id_to_key: list[NodeKey] | None = None,
**attr,
) -> Graph:
N = indptr.size - 1
Expand Down Expand Up @@ -155,7 +156,7 @@ def from_csc(
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: dict[IndexValue, NodeKey] | None = None,
id_to_key: list[NodeKey] | None = None,
**attr,
) -> Graph:
N = indptr.size - 1
Expand Down Expand Up @@ -189,7 +190,7 @@ def from_dcsr(
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: dict[IndexValue, NodeKey] | None = None,
id_to_key: list[NodeKey] | None = None,
**attr,
) -> Graph:
row_indices = cp.array(
Expand Down Expand Up @@ -222,7 +223,7 @@ def from_dcsc(
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: dict[IndexValue, NodeKey] | None = None,
id_to_key: list[NodeKey] | None = None,
**attr,
) -> Graph:
col_indices = cp.array(
Expand Down Expand Up @@ -295,11 +296,11 @@ def node_dtypes(self) -> dict[AttrKey, Dtype]:
return {key: val.dtype for key, val in self.node_values.items()}

@property
def id_to_key(self) -> dict[IndexValue, NodeKey] | None:
def id_to_key(self) -> [NodeKey] | None:
if self.key_to_id is None:
return None
if self._id_to_key is None:
self._id_to_key = {val: key for key, val in self.key_to_id.items()}
self._id_to_key = sorted(self.key_to_id, key=self.key_to_id.__getitem__)
return self._id_to_key

name = nx.Graph.name
Expand Down Expand Up @@ -447,6 +448,7 @@ def to_undirected(self, as_view: bool = False) -> Graph:
###################

def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
# DRY warning: see also MultiGraph._copy
indptr = self.indptr
row_indices = self.row_indices
col_indices = self.col_indices
Expand Down
30 changes: 30 additions & 0 deletions python/nx-cugraph/nx_cugraph/classes/multidigraph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import networkx as nx

import nx_cugraph as nxcg

from .digraph import DiGraph
from .multigraph import MultiGraph

__all__ = ["MultiDiGraph"]

networkx_api = nxcg.utils.decorators.networkx_class(nx.MultiDiGraph)


class MultiDiGraph(MultiGraph, DiGraph):
@classmethod
def to_networkx_class(cls) -> type[nx.MultiDiGraph]:
return nx.MultiDiGraph
Loading

0 comments on commit 48dcf73

Please sign in to comment.