Skip to content

Commit

Permalink
Add scripts to display information about functions in nx-cugraph
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Dec 27, 2023
1 parent 440b852 commit b051bae
Show file tree
Hide file tree
Showing 5 changed files with 344 additions and 0 deletions.
12 changes: 12 additions & 0 deletions python/nx-cugraph/nx_cugraph/scripts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
38 changes: 38 additions & 0 deletions python/nx-cugraph/nx_cugraph/scripts/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if __name__ == "__main__":
import argparse

from nx_cugraph.scripts import print_table, print_tree

parser = argparse.ArgumentParser(
parents=[
print_table.get_argumentparser(add_help=False),
print_tree.get_argumentparser(add_help=False),
],
description="Print info about functions implemented by nx-cugraph",
)
parser.add_argument("action", choices=["print_table", "print_tree"])
args = parser.parse_args()
if args.action == "print_table":
print_table.main()
else:
print_tree.main(
by=args.by,
networkx_path=args.networkx_path,
dispatch_name=args.dispatch_name or args.dispatch_name_always,
version_added=args.version_added,
plc=args.plc,
dispatch_name_if_different=not args.dispatch_name_always,
)
71 changes: 71 additions & 0 deletions python/nx-cugraph/nx_cugraph/scripts/print_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import sys
from collections import namedtuple

from networkx.utils.backends import _registered_algorithms as algos

from _nx_cugraph import get_info
from nx_cugraph.interface import BackendInterface


def get_funcpath(func):
return f"{func.__module__}.{func.__name__}"


def get_path_to_name():
return {
get_funcpath(algos[funcname]): funcname
for funcname in get_info()["functions"].keys() & algos.keys()
}


Info = namedtuple("Info", "networkx_path, dispatch_name, version_added, plc")


def get_path_to_info(path_to_name=None, version_added_sep=".", plc_sep="/"):
if path_to_name is None:
path_to_name = get_path_to_name()
rv = {}
for funcpath in sorted(path_to_name):
funcname = path_to_name[funcpath]
cufunc = getattr(BackendInterface, funcname)
plc = plc_sep.join(sorted(cufunc._plc_names)) if cufunc._plc_names else ""
version_added = cufunc.version_added.replace(".", version_added_sep)
rv[funcpath] = Info(funcpath, funcname, version_added, plc)
return rv


def main(path_to_info=None, *, file=sys.stdout):
if path_to_info is None:
path_to_info = get_path_to_info(version_added_sep=".")
lines = ["networkx_path,dispatch_name,version_added,plc"]
lines.extend(",".join(info) for info in path_to_info.values())
text = "\n".join(lines)
print(text, file=file)
return text


def get_argumentparser(add_help=True):
return argparse.ArgumentParser(
description="Print info about functions implemented by nx-cugraph as CSV",
add_help=add_help,
)


if __name__ == "__main__":
parser = get_argumentparser()
args = parser.parse_args()
main()
215 changes: 215 additions & 0 deletions python/nx-cugraph/nx_cugraph/scripts/print_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
#!/usr/bin/env python
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import re
import sys

import networkx as nx

from nx_cugraph.scripts.print_table import get_path_to_info


def add_branch(G, path, extra="", *, skip=0):
branch = path.split(".")
prev = ".".join(branch[: skip + 1])
for i in range(skip + 2, len(branch)):
cur = ".".join(branch[:i])
G.add_edge(prev, cur)
prev = cur
if extra:
if not isinstance(extra, str):
extra = ", ".join(extra)
path += f" ({extra})"
G.add_edge(prev, path)


def get_extra(
info,
*,
networkx_path=False,
dispatch_name=False,
version_added=False,
plc=False,
dispatch_name_if_different=False,
):
extra = []
if networkx_path:
extra.append(info.networkx_path)
if dispatch_name and (
not dispatch_name_if_different
or info.dispatch_name != info.networkx_path.rsplit(".", 1)[-1]
):
extra.append(info.dispatch_name)
if version_added:
v = info.version_added
if len(v) != 5:
raise ValueError(f"Is there something wrong with version: {v!r}?")
extra.append(v[:2] + "." + v[-2:])
if plc and info.plc:
extra.append(info.plc)
return extra


def create_tree(
path_to_info=None,
*,
by="networkx_path",
skip=0,
networkx_path=False,
dispatch_name=False,
version_added=False,
plc=False,
dispatch_name_if_different=False,
prefix="",
):
if path_to_info is None:
path_to_info = get_path_to_info()
if isinstance(by, str):
by = [by]
G = nx.DiGraph()
for info in sorted(
path_to_info.values(),
key=lambda x: (*(getattr(x, b) for b in by), x.networkx_path),
):
if not all(getattr(info, b) for b in by):
continue
path = prefix + ".".join(getattr(info, b) for b in by)
extra = get_extra(
info,
networkx_path=networkx_path,
dispatch_name=dispatch_name,
version_added=version_added,
plc=plc,
dispatch_name_if_different=dispatch_name_if_different,
)
add_branch(G, path, extra=extra, skip=skip)
return G


def main(
path_to_info=None,
*,
by="networkx_path",
networkx_path=False,
dispatch_name=False,
version_added=False,
plc=False,
dispatch_name_if_different=True,
file=sys.stdout,
):
if path_to_info is None:
path_to_info = get_path_to_info(version_added_sep="-")
kwargs = {
"networkx_path": networkx_path,
"dispatch_name": dispatch_name,
"version_added": version_added,
"plc": plc,
"dispatch_name_if_different": dispatch_name_if_different,
}
if by == "networkx_path":
G = create_tree(path_to_info, by="networkx_path", **kwargs)
text = re.sub(r"[A-Za-z_\./]+\.", "", ("\n".join(nx.generate_network_text(G))))
elif by == "plc":
G = create_tree(
path_to_info, by=["plc", "networkx_path"], prefix="plc-", **kwargs
)
text = re.sub(
"plc-",
"plc.",
re.sub(
r" plc-[A-Za-z_\./]*\.",
" ",
"\n".join(nx.generate_network_text(G)),
),
)
elif by == "version_added":
G = create_tree(
path_to_info,
by=["version_added", "networkx_path"],
prefix="version_added-",
**kwargs,
)
text = re.sub(
"version_added-",
"version: ",
re.sub(
r" version_added-[-0-9A-Za-z_\./]*\.",
" ",
"\n".join(nx.generate_network_text(G)),
),
).replace("-", ".")
else:
raise ValueError(
"`by` argument should be one of {'networkx_path', 'plc', 'version_added' "
f"got: {by}"
)
print(text, file=file)
return text


def get_argumentparser(add_help=True):
parser = argparse.ArgumentParser(
"Print a tree showing NetworkX functions implemented by nx-cugraph",
add_help=add_help,
)
parser.add_argument(
"--by",
choices=["networkx_path", "plc", "version_added"],
default="networkx_path",
help="How to group functions",
)
parser.add_argument(
"--dispatch-name",
"--dispatch_name",
action="store_true",
help="Show the dispatch name in parentheses if different from NetworkX name",
)
parser.add_argument(
"--dispatch-name-always",
"--dispatch_name_always",
action="store_true",
help="Always show the dispatch name in parentheses",
)
parser.add_argument(
"--plc",
"--pylibcugraph",
action="store_true",
help="Show the used pylibcugraph function in parentheses",
)
parser.add_argument(
"--version-added",
"--version_added",
action="store_true",
help="Show the version added in parentheses",
)
parser.add_argument(
"--networkx-path",
"--networkx_path",
action="store_true",
help="Show the full networkx path in parentheses",
)
return parser


if __name__ == "__main__":
parser = get_argumentparser()
args = parser.parse_args()
main(
by=args.by,
networkx_path=args.networkx_path,
dispatch_name=args.dispatch_name or args.dispatch_name_always,
version_added=args.version_added,
plc=args.plc,
dispatch_name_if_different=not args.dispatch_name_always,
)
8 changes: 8 additions & 0 deletions python/nx-cugraph/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,20 @@ test = [
Homepage = "https://github.com/rapidsai/cugraph"
Documentation = "https://docs.rapids.ai/api/cugraph/stable/"

# "plugin" used in nx version < 3.2
[project.entry-points."networkx.plugins"]
cugraph = "nx_cugraph.interface:BackendInterface"

[project.entry-points."networkx.plugin_info"]
cugraph = "_nx_cugraph:get_info"

# "backend" used in nx version >= 3.2
[project.entry-points."networkx.backends"]
cugraph = "nx_cugraph.interface:BackendInterface"

[project.entry-points."networkx.backend_info"]
cugraph = "_nx_cugraph:get_info"

[tool.setuptools]
license-files = ["LICENSE"]

Expand Down

0 comments on commit b051bae

Please sign in to comment.