Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add entry point to tell NetworkX about nx-cugraph without importing it. #3848

Merged
merged 10 commits into from
Sep 28, 2023
1 change: 1 addition & 0 deletions python/nx-cugraph/.flake8
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ extend-ignore =
per-file-ignores =
nx_cugraph/tests/*.py:T201,
__init__.py:F401,F403,
_nx_cugraph/__init__.py:E501,
10 changes: 10 additions & 0 deletions python/nx-cugraph/Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
SHELL= /bin/bash

.PHONY: all
all: plugin-info lint

.PHONY: lint
lint:
git ls-files | xargs pre-commit run --config lint.yaml --files

.PHONY: lint-update
lint-update:
pre-commit autoupdate --config lint.yaml

.PHONY: plugin-info
plugin-info:
python _nx_cugraph/__init__.py
70 changes: 70 additions & 0 deletions python/nx-cugraph/_nx_cugraph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tell NetworkX about the cugraph backend. This file can update itself:

$ make plugin-info # Recommended method for development

or

$ python _nx_cugraph/__init__.py
"""


def get_info():
"""Target of ``networkx.plugin_info`` entry point.

This tells NetworkX about the cugraph backend without importing nx_cugraph.
"""
# Entries between BEGIN and END are automatically generated
return {
"backend_name": "cugraph",
"project": "nx-cugraph",
"package": "nx_cugraph",
"url": "https://github.com/rapidsai/cugraph/tree/branch-23.10/python/nx-cugraph",
"short_summary": "GPU-accelerated backend.",
# "description": "TODO",
"functions": {
# BEGIN: functions
"betweenness_centrality",
"edge_betweenness_centrality",
"louvain_communities",
# END: functions
},
"extra_docstrings": {
# BEGIN: extra_docstrings
"betweenness_centrality": "`weight` parameter is not yet supported.",
"edge_betweenness_centrality": "`weight` parameter is not yet supported.",
"louvain_communities": "`threshold` and `seed` parameters are currently ignored.",
# END: extra_docstrings
},
"extra_parameters": {
# BEGIN: extra_parameters
"louvain_communities": {
"max_level : int, optional": "Upper limit of the number of macro-iterations.",
},
# END: extra_parameters
},
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, we could organize info like:

        "functions": {
            # BEGIN: functions
            "betweenness_centrality": {
                "extra_docstring": "`weight` parameter is not yet supported.",
            },
            "edge_betweenness_centrality" : {
                "extra_docstring": "`weight` parameter is not yet supported.",
            },
            "louvain_communities": {
                "extra_docstring": "louvain_communities": "`threshold` and `seed` parameters are currently ignored.",
                "extra_parameters": {
                    "max_level : int, optional": "Upper limit of the number of macro-iterations.",
                },
            },
            # END: functions
        },

}


__version__ = "23.10.00"
eriknw marked this conversation as resolved.
Show resolved Hide resolved

if __name__ == "__main__":
from pathlib import Path

from _nx_cugraph.core import main

filepath = Path(__file__)
text = main(filepath)
with filepath.open("w") as f:
f.write(text)
90 changes: 90 additions & 0 deletions python/nx-cugraph/_nx_cugraph/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to help keep _nx_cugraph up to date."""


def get_functions():
from nx_cugraph.interface import BackendInterface
from nx_cugraph.utils import networkx_algorithm

return {
key: val
for key, val in vars(BackendInterface).items()
if isinstance(val, networkx_algorithm)
}


def get_extra_docstrings(functions=None):
if functions is None:
functions = get_functions()
return {key: val.extra_doc for key, val in functions.items() if val.extra_doc}


def get_extra_parameters(functions=None):
if functions is None:
functions = get_functions()
return {key: val.extra_params for key, val in functions.items() if val.extra_params}


def update_text(text, lines_to_add, target, indent=" " * 12):
begin = f"# BEGIN: {target}\n"
end = f"# END: {target}\n"
start = text.index(begin)
stop = text.index(end)
to_add = "\n".join([f"{indent}{line}" for line in lines_to_add])
return f"{text[:start]}{begin}{to_add}\n{indent}{text[stop:]}"


def dict_to_lines(d, *, indent=""):
for key in sorted(d):
val = d[key]
if "\n" not in val:
yield f"{indent}{key!r}: {val!r},"
else:
yield f"{indent}{key!r}: ("
*lines, last_line = val.split("\n")
for line in lines:
line += "\n"
yield f" {indent}{line!r}"
yield f" {indent}{last_line!r}"
yield f"{indent}),"


def main(filepath):
from pathlib import Path

filepath = Path(filepath)
with filepath.open() as f:
orig_text = f.read()
text = orig_text

# Update functions
functions = get_functions()
to_add = [f'"{name}",' for name in sorted(functions)]
text = update_text(text, to_add, "functions")

# Update extra_docstrings
extra_docstrings = get_extra_docstrings(functions)
to_add = list(dict_to_lines(extra_docstrings))
text = update_text(text, to_add, "extra_docstrings")

# Update extra_parameters
extra_parameters = get_extra_parameters(functions)
to_add = []
for name in sorted(extra_parameters):
params = extra_parameters[name]
to_add.append(f"{name!r}: {{")
to_add.extend(dict_to_lines(params, indent=" " * 4))
to_add.append("},")
text = update_text(text, to_add, "extra_parameters")
return text
15 changes: 8 additions & 7 deletions python/nx-cugraph/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ repos:
- id: validate-pyproject
name: Validate pyproject.toml
- repo: https://github.com/PyCQA/autoflake
rev: v2.2.0
rev: v2.2.1
hooks:
- id: autoflake
args: [--in-place]
Expand All @@ -50,19 +50,20 @@ repos:
- id: black
# - id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.286
rev: v0.0.287
hooks:
- id: ruff
args: [--fix-only, --show-fixes]
- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
hooks:
- id: flake8
args: ['--per-file-ignores=_nx_cugraph/__init__.py:E501'] # Why is this necessary?
additional_dependencies: &flake8_dependencies
# These versions need updated manually
- flake8==6.1.0
- flake8-bugbear==23.7.10
- flake8-simplify==0.20.0
# These versions need updated manually
- flake8==6.1.0
- flake8-bugbear==23.7.10
- flake8-simplify==0.20.0
- repo: https://github.com/asottile/yesqa
rev: v1.5.0
hooks:
Expand All @@ -76,7 +77,7 @@ repos:
additional_dependencies: [tomli]
files: ^(nx_cugraph|docs)/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.286
rev: v0.0.287
hooks:
- id: ruff
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down
16 changes: 14 additions & 2 deletions python/nx-cugraph/nx_cugraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,21 @@
# limitations under the License.
from networkx.exception import *

from . import algorithms, classes, convert, utils
from .algorithms import *
from . import utils

from . import classes
from .classes import *

from . import convert
from .convert import *

# from . import convert_matrix
# from .convert_matrix import *

# from . import generators
# from .generators import *
Comment on lines +23 to +27
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this dead code we can remove?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is future code, yet to be born


from . import algorithms
from .algorithms import *

__version__ = "23.10.00"
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pylibcugraph as plc

from nx_cugraph.convert import _to_graph
from nx_cugraph.utils import _handle_seed, networkx_algorithm
from nx_cugraph.utils import _seed_to_int, networkx_algorithm

__all__ = ["betweenness_centrality", "edge_betweenness_centrality"]

Expand All @@ -22,11 +22,12 @@
def betweenness_centrality(
G, k=None, normalized=True, weight=None, endpoints=False, seed=None
):
"""`weight` parameter is not yet supported."""
if weight is not None:
raise NotImplementedError(
"Weighted implementation of betweenness centrality not currently supported"
)
seed = _handle_seed(seed)
seed = _seed_to_int(seed)
G = _to_graph(G, weight)
node_ids, values = plc.betweenness_centrality(
resource_handle=plc.ResourceHandle(),
Expand All @@ -47,6 +48,7 @@ def _(G, k=None, normalized=True, weight=None, endpoints=False, seed=None):

@networkx_algorithm
def edge_betweenness_centrality(G, k=None, normalized=True, weight=None, seed=None):
"""`weight` parameter is not yet supported."""
if weight is not None:
raise NotImplementedError(
"Weighted implementation of betweenness centrality not currently supported"
Expand Down
15 changes: 8 additions & 7 deletions python/nx-cugraph/nx_cugraph/algorithms/community/louvain.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from nx_cugraph.convert import _to_undirected_graph
from nx_cugraph.utils import (
_groupby,
_handle_seed,
_seed_to_int,
networkx_algorithm,
not_implemented_for,
)
Expand All @@ -26,16 +26,17 @@


@not_implemented_for("directed")
@networkx_algorithm(extra_params="max_level")
@networkx_algorithm(
extra_params={
"max_level : int, optional": "Upper limit of the number of macro-iterations."
}
)
def louvain_communities(
G, weight="weight", resolution=1, threshold=0.0000001, seed=None, *, max_level=None
):
"""`threshold` and `seed` parameters are currently ignored.

Extra parameter: `max_level` controls the maximum number of levels of the algorithm.
"""
"""`threshold` and `seed` parameters are currently ignored."""
# NetworkX allows both directed and undirected, but cugraph only allows undirected.
seed = _handle_seed(seed) # Unused, but ensure it's valid for future compatibility
seed = _seed_to_int(seed) # Unused, but ensure it's valid for future compatibility
G = _to_undirected_graph(G, weight)
if G.row_indices.size == 0:
# TODO: PLC doesn't handle empty graphs gracefully!
Expand Down
5 changes: 4 additions & 1 deletion python/nx-cugraph/nx_cugraph/tests/test_match_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@ def test_match_signature_and_names():
assert orig_sig == func_sig
else:
# Ignore extra parameters added to nx-cugraph algorithm
# The key of func.extra_params may be like "max_level : int, optional",
# but we only want "max_level" here.
extra_params = {name.split(" ")[0] for name in func.extra_params}
assert orig_sig == func_sig.replace(
parameters=[
p
for name, p in func_sig.parameters.items()
if name not in func.extra_params
if name not in extra_params
]
)
if func.can_run is not nxcg.utils.decorators._default_can_run:
Expand Down
Loading