Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nx-cugraph: Updates nxcg.Graph classes for API-compatibility with NetworkX Graph classes, needed for zero code change graph generators #4629

Merged
merged 12 commits into from
Sep 24, 2024
Merged
7 changes: 6 additions & 1 deletion python/nx-cugraph/_nx_cugraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

$ python _nx_cugraph/__init__.py
"""
import os

from _nx_cugraph._version import __version__

Expand Down Expand Up @@ -293,12 +294,16 @@ def get_info():

for key in info_keys:
del d[key]

d["default_config"] = {
"zero": os.environ.get("NX_CUGRAPH_ZERO", "true").strip().lower() == "true",
}
return d


def _check_networkx_version():
import warnings
import re
import warnings

import networkx as nx

Expand Down
8 changes: 8 additions & 0 deletions python/nx-cugraph/nx_cugraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,15 @@
from . import algorithms
from .algorithms import *

from .interface import BackendInterface

from _nx_cugraph._version import __git_commit__, __version__
from _nx_cugraph import _check_networkx_version

_check_networkx_version()

BackendInterface.Graph = classes.ZeroGraph
BackendInterface.DiGraph = classes.ZeroDiGraph
BackendInterface.MultiGraph = classes.ZeroMultiGraph
BackendInterface.MultiDiGraph = classes.ZeroMultiDiGraph
del BackendInterface
4 changes: 4 additions & 0 deletions python/nx-cugraph/nx_cugraph/algorithms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def _(G):
@networkx_algorithm(is_incomplete=True, version_added="23.12", _plc="k_truss_subgraph")
def k_truss(G, k):
if is_nx := isinstance(G, nx.Graph):
zero = isinstance(G, nxcg.ZeroGraph)
G = nxcg.from_networkx(G, preserve_all_attrs=True)
else:
zero = False
if nxcg.number_of_selfloops(G) > 0:
if nx.__version__[:3] <= "3.2":
exc_class = nx.NetworkXError
Expand Down Expand Up @@ -128,6 +131,7 @@ def k_truss(G, k):
node_values,
node_masks,
key_to_id=key_to_id,
zero=zero,
)
new_graph.graph.update(G.graph)
return new_graph
12 changes: 10 additions & 2 deletions python/nx-cugraph/nx_cugraph/algorithms/operators/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

@networkx_algorithm(version_added="24.02")
def complement(G):
zero = isinstance(G, nxcg.ZeroGraph)
G = _to_graph(G)
N = G._N
# Upcast to int64 so indices don't overflow.
Expand All @@ -43,6 +44,7 @@ def complement(G):
src_indices.astype(index_dtype),
dst_indices.astype(index_dtype),
key_to_id=G.key_to_id,
zero=zero,
)


Expand All @@ -51,10 +53,16 @@ def reverse(G, copy=True):
if not G.is_directed():
raise nx.NetworkXError("Cannot reverse an undirected graph.")
if isinstance(G, nx.Graph):
if not copy:
zero = isinstance(G, nxcg.ZeroGraph)
if not copy and not zero:
raise RuntimeError(
"Using `copy=False` is invalid when using a NetworkX graph "
"as input to `nx_cugraph.reverse`"
)
G = nxcg.from_networkx(G, preserve_all_attrs=True)
return G.reverse(copy=copy)
else:
zero = False
rv = G.reverse(copy=copy)
if zero:
return rv.to_zero()
return rv
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,15 @@ def bfs_tree(G, source, reverse=False, depth_limit=None, sort_neighbors=None):
raise NotImplementedError(
"sort_neighbors argument in bfs_tree is not currently supported"
)
zero = isinstance(G, nxcg.ZeroGraph)
G = _check_G_and_source(G, source)
if depth_limit is not None and depth_limit < 1:
return nxcg.DiGraph.from_coo(
1,
cp.array([], dtype=index_dtype),
cp.array([], dtype=index_dtype),
id_to_key=[source],
zero=zero,
)

distances, predecessors, node_ids = _bfs(
Expand All @@ -153,6 +155,7 @@ def bfs_tree(G, source, reverse=False, depth_limit=None, sort_neighbors=None):
cp.array([], dtype=index_dtype),
cp.array([], dtype=index_dtype),
id_to_key=[source],
zero=zero,
)
# TODO: create renumbering helper function(s)
unique_node_ids = cp.unique(cp.hstack((predecessors, node_ids)))
Expand All @@ -175,6 +178,7 @@ def bfs_tree(G, source, reverse=False, depth_limit=None, sort_neighbors=None):
src_indices,
dst_indices,
key_to_id=key_to_id,
zero=zero,
)


Expand Down
3 changes: 2 additions & 1 deletion python/nx-cugraph/nx_cugraph/classes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -10,6 +10,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .zero import ZeroGraph, ZeroDiGraph, ZeroMultiGraph, ZeroMultiDiGraph
from .graph import Graph
from .digraph import DiGraph
from .multigraph import MultiGraph
Expand Down
6 changes: 6 additions & 0 deletions python/nx-cugraph/nx_cugraph/classes/digraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from ..utils import index_dtype
from .graph import Graph
from .zero import ZeroDiGraph

if TYPE_CHECKING: # pragma: no cover
from nx_cugraph.typing import AttrKey
Expand All @@ -46,6 +47,10 @@ def is_directed(cls) -> bool:
def to_networkx_class(cls) -> type[nx.DiGraph]:
return nx.DiGraph

@classmethod
def to_zero_class(cls) -> type[ZeroDiGraph]:
return ZeroDiGraph

@networkx_api
def size(self, weight: AttrKey | None = None) -> int:
if weight is not None:
Expand Down Expand Up @@ -162,6 +167,7 @@ def to_undirected(self, reciprocal=False, as_view=False):
node_masks,
key_to_id=key_to_id,
id_to_key=id_to_key,
zero=False,
)
if as_view:
rv.graph = self.graph
Expand Down
26 changes: 24 additions & 2 deletions python/nx-cugraph/nx_cugraph/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import nx_cugraph as nxcg

from ..utils import index_dtype
from .zero import ZeroGraph

if TYPE_CHECKING: # pragma: no cover
from collections.abc import Iterable, Iterator
Expand Down Expand Up @@ -109,6 +110,7 @@ def from_coo(
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
zero: bool | None = None,
**attr,
) -> Graph:
new_graph = object.__new__(cls)
Expand Down Expand Up @@ -173,7 +175,8 @@ def from_coo(
isolates = nxcg.algorithms.isolate._isolates(new_graph)
if len(isolates) > 0:
new_graph._node_ids = cp.arange(new_graph._N, dtype=index_dtype)

if zero or zero is None and nx.config.backends.cugraph.zero:
new_graph = new_graph.to_zero()
return new_graph

@classmethod
Expand All @@ -188,6 +191,7 @@ def from_csr(
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
zero: bool | None = None,
**attr,
) -> Graph:
N = indptr.size - 1
Expand All @@ -205,6 +209,7 @@ def from_csr(
node_masks,
key_to_id=key_to_id,
id_to_key=id_to_key,
zero=zero,
**attr,
)

Expand All @@ -220,6 +225,7 @@ def from_csc(
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
zero: bool | None = None,
**attr,
) -> Graph:
N = indptr.size - 1
Expand All @@ -237,6 +243,7 @@ def from_csc(
node_masks,
key_to_id=key_to_id,
id_to_key=id_to_key,
zero=zero,
**attr,
)

Expand All @@ -254,6 +261,7 @@ def from_dcsr(
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
zero: bool | None = None,
**attr,
) -> Graph:
src_indices = cp.array(
Expand All @@ -270,6 +278,7 @@ def from_dcsr(
node_masks,
key_to_id=key_to_id,
id_to_key=id_to_key,
zero=zero,
**attr,
)

Expand All @@ -287,6 +296,7 @@ def from_dcsc(
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
zero: bool | None = None,
**attr,
) -> Graph:
dst_indices = cp.array(
Expand All @@ -303,13 +313,14 @@ def from_dcsc(
node_masks,
key_to_id=key_to_id,
id_to_key=id_to_key,
zero=zero,
**attr,
)

def __new__(cls, incoming_graph_data=None, **attr) -> Graph:
if incoming_graph_data is None:
new_graph = cls.from_coo(
0, cp.empty(0, index_dtype), cp.empty(0, index_dtype)
0, cp.empty(0, index_dtype), cp.empty(0, index_dtype), zero=False
)
elif incoming_graph_data.__class__ is cls:
new_graph = incoming_graph_data.copy()
Expand All @@ -318,6 +329,7 @@ def __new__(cls, incoming_graph_data=None, **attr) -> Graph:
else:
raise NotImplementedError
new_graph.graph.update(attr)
# XXX: we could return ZeroGraph here, but let's not for now
return new_graph

#################
Expand Down Expand Up @@ -348,6 +360,10 @@ def to_networkx_class(cls) -> type[nx.Graph]:
def to_undirected_class(cls) -> type[Graph]:
return Graph

@classmethod
def to_zero_class(cls) -> type[ZeroGraph]:
return ZeroGraph

##############
# Properties #
##############
Expand Down Expand Up @@ -542,6 +558,11 @@ def to_undirected(self, as_view: bool = False) -> Graph:
# Does deep copy in networkx
return self._copy(as_view, self.to_undirected_class())

def to_zero(self) -> ZeroGraph:
rv = self.to_zero_class()()
rv._cugraph = self
return rv

# Not implemented...
# adj, adjacency, add_edge, add_edges_from, add_node,
# add_nodes_from, add_weighted_edges_from, degree,
Expand Down Expand Up @@ -593,6 +614,7 @@ def _copy(self, as_view: bool, cls: type[Graph], reverse: bool = False):
node_masks,
key_to_id=key_to_id,
id_to_key=id_to_key,
zero=False,
)
if as_view:
rv.graph = self.graph
Expand Down
5 changes: 5 additions & 0 deletions python/nx-cugraph/nx_cugraph/classes/multidigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from .digraph import DiGraph
from .multigraph import MultiGraph
from .zero import ZeroMultiDiGraph

__all__ = ["MultiDiGraph"]

Expand All @@ -34,6 +35,10 @@ def is_directed(cls) -> bool:
def to_networkx_class(cls) -> type[nx.MultiDiGraph]:
return nx.MultiDiGraph

@classmethod
def to_zero_class(cls) -> type[ZeroMultiDiGraph]:
return ZeroMultiDiGraph

##########################
# NetworkX graph methods #
##########################
Expand Down
Loading
Loading