Skip to content

Commit

Permalink
nx-cugraph: add ZeroGraph for nx-compatibility (zero-code change)
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Aug 22, 2024
1 parent 710f268 commit 7cf65b9
Show file tree
Hide file tree
Showing 16 changed files with 371 additions and 11 deletions.
7 changes: 6 additions & 1 deletion python/nx-cugraph/_nx_cugraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
$ python _nx_cugraph/__init__.py
"""
import os

from _nx_cugraph._version import __version__

Expand Down Expand Up @@ -293,12 +294,16 @@ def get_info():

for key in info_keys:
del d[key]

d["default_config"] = {
"zero": os.environ.get("NX_CUGRAPH_ZERO", "true").strip().lower() == "true",
}
return d


def _check_networkx_version():
import warnings
import re
import warnings

import networkx as nx

Expand Down
8 changes: 8 additions & 0 deletions python/nx-cugraph/nx_cugraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@
from . import algorithms
from .algorithms import *

from .interface import BackendInterface

from _nx_cugraph._version import __git_commit__, __version__
from _nx_cugraph import _check_networkx_version

_check_networkx_version()

BackendInterface.Graph = classes.ZeroGraph
BackendInterface.DiGraph = classes.ZeroDiGraph
BackendInterface.MultiGraph = classes.ZeroMultiGraph
BackendInterface.MultiDiGraph = classes.ZeroMultiDiGraph
del BackendInterface
4 changes: 4 additions & 0 deletions python/nx-cugraph/nx_cugraph/algorithms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def _(G):
@networkx_algorithm(is_incomplete=True, version_added="23.12", _plc="k_truss_subgraph")
def k_truss(G, k):
if is_nx := isinstance(G, nx.Graph):
zero = isinstance(G, nxcg.ZeroGraph)
G = nxcg.from_networkx(G, preserve_all_attrs=True)
else:
zero = False
if nxcg.number_of_selfloops(G) > 0:
if nx.__version__[:3] <= "3.2":
exc_class = nx.NetworkXError
Expand Down Expand Up @@ -128,6 +131,7 @@ def k_truss(G, k):
node_values,
node_masks,
key_to_id=key_to_id,
zero=zero,
)
new_graph.graph.update(G.graph)
return new_graph
12 changes: 10 additions & 2 deletions python/nx-cugraph/nx_cugraph/algorithms/operators/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

@networkx_algorithm(version_added="24.02")
def complement(G):
zero = isinstance(G, nxcg.ZeroGraph)
G = _to_graph(G)
N = G._N
# Upcast to int64 so indices don't overflow.
Expand All @@ -43,6 +44,7 @@ def complement(G):
src_indices.astype(index_dtype),
dst_indices.astype(index_dtype),
key_to_id=G.key_to_id,
zero=zero,
)


Expand All @@ -51,10 +53,16 @@ def reverse(G, copy=True):
if not G.is_directed():
raise nx.NetworkXError("Cannot reverse an undirected graph.")
if isinstance(G, nx.Graph):
if not copy:
zero = isinstance(G, nxcg.ZeroGraph)
if not copy and not zero:
raise RuntimeError(
"Using `copy=False` is invalid when using a NetworkX graph "
"as input to `nx_cugraph.reverse`"
)
G = nxcg.from_networkx(G, preserve_all_attrs=True)
return G.reverse(copy=copy)
else:
zero = False
rv = G.reverse(copy=copy)
if zero:
return rv.to_zero()
return rv
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,15 @@ def bfs_tree(G, source, reverse=False, depth_limit=None, sort_neighbors=None):
raise NotImplementedError(
"sort_neighbors argument in bfs_tree is not currently supported"
)
zero = isinstance(G, nxcg.ZeroGraph)
G = _check_G_and_source(G, source)
if depth_limit is not None and depth_limit < 1:
return nxcg.DiGraph.from_coo(
1,
cp.array([], dtype=index_dtype),
cp.array([], dtype=index_dtype),
id_to_key=[source],
zero=zero,
)

distances, predecessors, node_ids = _bfs(
Expand All @@ -153,6 +155,7 @@ def bfs_tree(G, source, reverse=False, depth_limit=None, sort_neighbors=None):
cp.array([], dtype=index_dtype),
cp.array([], dtype=index_dtype),
id_to_key=[source],
zero=zero,
)
# TODO: create renumbering helper function(s)
unique_node_ids = cp.unique(cp.hstack((predecessors, node_ids)))
Expand All @@ -175,6 +178,7 @@ def bfs_tree(G, source, reverse=False, depth_limit=None, sort_neighbors=None):
src_indices,
dst_indices,
key_to_id=key_to_id,
zero=zero,
)


Expand Down
3 changes: 2 additions & 1 deletion python/nx-cugraph/nx_cugraph/classes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -10,6 +10,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .zero import ZeroGraph, ZeroDiGraph, ZeroMultiGraph, ZeroMultiDiGraph
from .graph import Graph
from .digraph import DiGraph
from .multigraph import MultiGraph
Expand Down
6 changes: 6 additions & 0 deletions python/nx-cugraph/nx_cugraph/classes/digraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from ..utils import index_dtype
from .graph import Graph
from .zero import ZeroDiGraph

if TYPE_CHECKING: # pragma: no cover
from nx_cugraph.typing import AttrKey
Expand All @@ -46,6 +47,10 @@ def is_directed(cls) -> bool:
def to_networkx_class(cls) -> type[nx.DiGraph]:
return nx.DiGraph

@classmethod
def to_zero_class(cls) -> type[ZeroDiGraph]:
return ZeroDiGraph

@networkx_api
def size(self, weight: AttrKey | None = None) -> int:
if weight is not None:
Expand Down Expand Up @@ -162,6 +167,7 @@ def to_undirected(self, reciprocal=False, as_view=False):
node_masks,
key_to_id=key_to_id,
id_to_key=id_to_key,
zero=False,
)
if as_view:
rv.graph = self.graph
Expand Down
26 changes: 24 additions & 2 deletions python/nx-cugraph/nx_cugraph/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import nx_cugraph as nxcg

from ..utils import index_dtype
from .zero import ZeroGraph

if TYPE_CHECKING: # pragma: no cover
from collections.abc import Iterable, Iterator
Expand Down Expand Up @@ -109,6 +110,7 @@ def from_coo(
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
zero: bool | None = None,
**attr,
) -> Graph:
new_graph = object.__new__(cls)
Expand Down Expand Up @@ -173,7 +175,8 @@ def from_coo(
isolates = nxcg.algorithms.isolate._isolates(new_graph)
if len(isolates) > 0:
new_graph._node_ids = cp.arange(new_graph._N, dtype=index_dtype)

if zero or zero is None and nx.config.backends.cugraph.zero:
new_graph = new_graph.to_zero()
return new_graph

@classmethod
Expand All @@ -188,6 +191,7 @@ def from_csr(
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
zero: bool | None = None,
**attr,
) -> Graph:
N = indptr.size - 1
Expand All @@ -205,6 +209,7 @@ def from_csr(
node_masks,
key_to_id=key_to_id,
id_to_key=id_to_key,
zero=zero,
**attr,
)

Expand All @@ -220,6 +225,7 @@ def from_csc(
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
zero: bool | None = None,
**attr,
) -> Graph:
N = indptr.size - 1
Expand All @@ -237,6 +243,7 @@ def from_csc(
node_masks,
key_to_id=key_to_id,
id_to_key=id_to_key,
zero=zero,
**attr,
)

Expand All @@ -254,6 +261,7 @@ def from_dcsr(
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
zero: bool | None = None,
**attr,
) -> Graph:
src_indices = cp.array(
Expand All @@ -270,6 +278,7 @@ def from_dcsr(
node_masks,
key_to_id=key_to_id,
id_to_key=id_to_key,
zero=zero,
**attr,
)

Expand All @@ -287,6 +296,7 @@ def from_dcsc(
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
zero: bool | None = None,
**attr,
) -> Graph:
dst_indices = cp.array(
Expand All @@ -303,13 +313,14 @@ def from_dcsc(
node_masks,
key_to_id=key_to_id,
id_to_key=id_to_key,
zero=zero,
**attr,
)

def __new__(cls, incoming_graph_data=None, **attr) -> Graph:
if incoming_graph_data is None:
new_graph = cls.from_coo(
0, cp.empty(0, index_dtype), cp.empty(0, index_dtype)
0, cp.empty(0, index_dtype), cp.empty(0, index_dtype), zero=False
)
elif incoming_graph_data.__class__ is cls:
new_graph = incoming_graph_data.copy()
Expand All @@ -318,6 +329,7 @@ def __new__(cls, incoming_graph_data=None, **attr) -> Graph:
else:
raise NotImplementedError
new_graph.graph.update(attr)
# XXX: we could return ZeroGraph here, but let's not for now
return new_graph

#################
Expand Down Expand Up @@ -348,6 +360,10 @@ def to_networkx_class(cls) -> type[nx.Graph]:
def to_undirected_class(cls) -> type[Graph]:
return Graph

@classmethod
def to_zero_class(cls) -> type[ZeroGraph]:
return ZeroGraph

##############
# Properties #
##############
Expand Down Expand Up @@ -542,6 +558,11 @@ def to_undirected(self, as_view: bool = False) -> Graph:
# Does deep copy in networkx
return self._copy(as_view, self.to_undirected_class())

def to_zero(self) -> ZeroGraph:
rv = self.to_zero_class()()
rv._cugraph = self
return rv

# Not implemented...
# adj, adjacency, add_edge, add_edges_from, add_node,
# add_nodes_from, add_weighted_edges_from, degree,
Expand Down Expand Up @@ -593,6 +614,7 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
node_masks,
key_to_id=key_to_id,
id_to_key=id_to_key,
zero=False,
)
if as_view:
rv.graph = self.graph
Expand Down
5 changes: 5 additions & 0 deletions python/nx-cugraph/nx_cugraph/classes/multidigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from .digraph import DiGraph
from .multigraph import MultiGraph
from .zero import ZeroMultiDiGraph

__all__ = ["MultiDiGraph"]

Expand All @@ -34,6 +35,10 @@ def is_directed(cls) -> bool:
def to_networkx_class(cls) -> type[nx.MultiDiGraph]:
return nx.MultiDiGraph

@classmethod
def to_zero_class(cls) -> type[ZeroMultiDiGraph]:
return ZeroMultiDiGraph

##########################
# NetworkX graph methods #
##########################
Expand Down
Loading

0 comments on commit 7cf65b9

Please sign in to comment.