Skip to content

Commit

Permalink
nx-cugraph: update usage of nodes_or_number for nx compat (#4028)
Browse files Browse the repository at this point in the history
These changes will be necessary when networkx/networkx#7066 is merged.

Authors:
  - Erik Welch (https://github.com/eriknw)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)

URL: #4028
  • Loading branch information
eriknw authored Dec 12, 2023
1 parent 1655003 commit c637b33
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 28 deletions.
8 changes: 4 additions & 4 deletions python/nx-cugraph/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ repos:
- id: pyupgrade
args: [--py39-plus]
- repo: https://github.com/psf/black
rev: 23.10.1
rev: 23.11.0
hooks:
- id: black
# - id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.3
rev: v0.1.7
hooks:
- id: ruff
args: [--fix-only, --show-fixes] # --unsafe-fixes]
Expand All @@ -62,7 +62,7 @@ repos:
additional_dependencies: &flake8_dependencies
# These versions need updated manually
- flake8==6.1.0
- flake8-bugbear==23.9.16
- flake8-bugbear==23.12.2
- flake8-simplify==0.21.0
- repo: https://github.com/asottile/yesqa
rev: v1.5.0
Expand All @@ -77,7 +77,7 @@ repos:
additional_dependencies: [tomli]
files: ^(nx_cugraph|docs)/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.3
rev: v0.1.7
hooks:
- id: ruff
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down
4 changes: 2 additions & 2 deletions python/nx-cugraph/nx_cugraph/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
centrality,
community,
components,
shortest_paths,
link_analysis,
shortest_paths,
)
from .bipartite import complete_bipartite_graph
from .centrality import *
from .components import *
from .core import *
from .isolate import *
from .shortest_paths import *
from .link_analysis import *
from .shortest_paths import *
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,14 @@
import numpy as np

from nx_cugraph.generators._utils import _create_using_class, _number_and_nodes
from nx_cugraph.utils import index_dtype, networkx_algorithm, nodes_or_number
from nx_cugraph.utils import index_dtype, networkx_algorithm

__all__ = [
"complete_bipartite_graph",
]


@nodes_or_number([0, 1])
@networkx_algorithm
@networkx_algorithm(nodes_or_number=[0, 1])
def complete_bipartite_graph(n1, n2, create_using=None):
graph_class, inplace = _create_using_class(create_using)
if graph_class.is_directed():
Expand Down
6 changes: 5 additions & 1 deletion python/nx-cugraph/nx_cugraph/algorithms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ def k_truss(G, k):
if is_nx := isinstance(G, nx.Graph):
G = nxcg.from_networkx(G, preserve_all_attrs=True)
if nxcg.number_of_selfloops(G) > 0:
raise nx.NetworkXError(
if nx.__version__[:3] <= "3.2":
exc_class = nx.NetworkXError
else:
exc_class = nx.NetworkXNotImplemented
raise exc_class(
"Input graph has self loops which is not permitted; "
"Consider using G.remove_edges_from(nx.selfloop_edges(G))."
)
Expand Down
26 changes: 9 additions & 17 deletions python/nx-cugraph/nx_cugraph/generators/classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import nx_cugraph as nxcg

from ..utils import _get_int_dtype, index_dtype, networkx_algorithm, nodes_or_number
from ..utils import _get_int_dtype, index_dtype, networkx_algorithm
from ._utils import (
_IS_NX32_OR_LESS,
_common_small_graph,
Expand Down Expand Up @@ -86,8 +86,7 @@ def circular_ladder_graph(n, create_using=None):
return _ladder_graph(n, create_using, is_circular=True)


@nodes_or_number(0)
@networkx_algorithm
@networkx_algorithm(nodes_or_number=0)
def complete_graph(n, create_using=None):
n, nodes = _number_and_nodes(n)
if n < 3:
Expand Down Expand Up @@ -143,8 +142,7 @@ def complete_multipartite_graph(*subset_sizes):
)


@nodes_or_number(0)
@networkx_algorithm
@networkx_algorithm(nodes_or_number=0)
def cycle_graph(n, create_using=None):
n, nodes = _number_and_nodes(n)
graph_class, inplace = _create_using_class(create_using)
Expand Down Expand Up @@ -174,8 +172,7 @@ def cycle_graph(n, create_using=None):
return G


@nodes_or_number(0)
@networkx_algorithm
@networkx_algorithm(nodes_or_number=0)
def empty_graph(n=0, create_using=None, default=nx.Graph):
n, nodes = _number_and_nodes(n)
graph_class, inplace = _create_using_class(create_using, default=default)
Expand Down Expand Up @@ -242,8 +239,7 @@ def ladder_graph(n, create_using=None):
return _ladder_graph(n, create_using)


@nodes_or_number([0, 1])
@networkx_algorithm
@networkx_algorithm(nodes_or_number=[0, 1])
def lollipop_graph(m, n, create_using=None):
# Like complete_graph then path_graph
orig_m, unused_nodes_m = m
Expand Down Expand Up @@ -283,8 +279,7 @@ def null_graph(create_using=None):
return _common_small_graph(0, None, create_using)


@nodes_or_number(0)
@networkx_algorithm
@networkx_algorithm(nodes_or_number=0)
def path_graph(n, create_using=None):
n, nodes = _number_and_nodes(n)
graph_class, inplace = _create_using_class(create_using)
Expand All @@ -304,8 +299,7 @@ def path_graph(n, create_using=None):
return G


@nodes_or_number(0)
@networkx_algorithm
@networkx_algorithm(nodes_or_number=0)
def star_graph(n, create_using=None):
orig_n, orig_nodes = n
n, nodes = _number_and_nodes(n)
Expand All @@ -329,8 +323,7 @@ def star_graph(n, create_using=None):
return G


@nodes_or_number([0, 1])
@networkx_algorithm
@networkx_algorithm(nodes_or_number=[0, 1])
def tadpole_graph(m, n, create_using=None):
orig_m, unused_nodes_m = m
orig_n, unused_nodes_n = n
Expand Down Expand Up @@ -382,8 +375,7 @@ def turan_graph(n, r):
return complete_multipartite_graph(*partitions)


@nodes_or_number(0)
@networkx_algorithm
@networkx_algorithm(nodes_or_number=0)
def wheel_graph(n, create_using=None):
n, nodes = _number_and_nodes(n)
graph_class, inplace = _create_using_class(create_using)
Expand Down
13 changes: 12 additions & 1 deletion python/nx-cugraph/nx_cugraph/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from functools import partial, update_wrapper
from textwrap import dedent

import networkx as nx
from networkx.utils.decorators import nodes_or_number, not_implemented_for

from nx_cugraph.interface import BackendInterface
Expand Down Expand Up @@ -47,10 +48,18 @@ def __new__(
*,
name: str | None = None,
extra_params: dict[str, str] | str | None = None,
nodes_or_number: list[int] | int | None = None,
):
if func is None:
return partial(networkx_algorithm, name=name, extra_params=extra_params)
return partial(
networkx_algorithm,
name=name,
extra_params=extra_params,
nodes_or_number=nodes_or_number,
)
instance = object.__new__(cls)
if nodes_or_number is not None and nx.__version__[:3] > "3.2":
func = nx.utils.decorators.nodes_or_number(nodes_or_number)(func)
# update_wrapper sets __wrapped__, which will be used for the signature
update_wrapper(instance, func)
instance.__defaults__ = func.__defaults__
Expand All @@ -76,6 +85,8 @@ def __new__(
setattr(BackendInterface, instance.name, instance)
# Set methods so they are in __dict__
instance._can_run = instance._can_run
if nodes_or_number is not None and nx.__version__[:3] <= "3.2":
instance = nx.utils.decorators.nodes_or_number(nodes_or_number)(instance)
return instance

def _can_run(self, func):
Expand Down

0 comments on commit c637b33

Please sign in to comment.