Skip to content

Commit

Permalink
nx-cugraph: automatically generate trees in README.md (#4156)
Browse files Browse the repository at this point in the history
This updates how we create trees. Also, CI now tests that auto-generated files are up-to-date (not updating these has gotten me a couple of times).

Authors:
  - Erik Welch (https://github.com/eriknw)

Approvers:
  - Rick Ratzel (https://github.com/rlratzel)
  - Jake Awe (https://github.com/AyodeAwe)

URL: #4156
  • Loading branch information
eriknw authored Feb 27, 2024
1 parent d1ed728 commit 6efb1c6
Show file tree
Hide file tree
Showing 15 changed files with 460 additions and 213 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ datasets/*
!datasets/karate-disjoint.csv
!datasets/netscience.csv

# nx-cugraph side effects
python/nx-cugraph/objects.inv

.pydevproject

Expand Down
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ repos:
hooks:
- id: black
language_version: python3
args: [--target-version=py38]
args: [--target-version=py39]
files: ^(python/.*|benchmarks/.*)$
exclude: ^python/nx-cugraph/
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
Expand Down
7 changes: 7 additions & 0 deletions ci/test_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,13 @@ python -m nx_cugraph.scripts.print_tree --dispatch-name --plc --incomplete --dif
python -m nx_cugraph.scripts.print_table
popd

rapids-logger "ensure nx-cugraph autogenerated files are up to date"
pushd python/nx-cugraph
make || true
git diff --exit-code .
git checkout .
popd

rapids-logger "pytest cugraph-service (single GPU)"
./ci/run_cugraph_service_pytests.sh \
--verbose \
Expand Down
13 changes: 10 additions & 3 deletions python/nx-cugraph/Makefile
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
SHELL= /bin/bash

.PHONY: all
all: plugin-info lint
all: plugin-info lint readme

.PHONY: lint
lint:
git ls-files | xargs pre-commit run --config lint.yaml --files
git ls-files | xargs pre-commit run --config lint.yaml --files || true

.PHONY: lint-update
lint-update:
Expand All @@ -15,3 +15,10 @@ lint-update:
.PHONY: plugin-info
plugin-info:
python _nx_cugraph/__init__.py

objects.inv:
wget https://networkx.org/documentation/stable/objects.inv

.PHONY: readme
readme: objects.inv
python scripts/update_readme.py README.md objects.inv
270 changes: 135 additions & 135 deletions python/nx-cugraph/README.md

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions python/nx-cugraph/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ repos:
- id: validate-pyproject
name: Validate pyproject.toml
- repo: https://github.com/PyCQA/autoflake
rev: v2.2.1
rev: v2.3.0
hooks:
- id: autoflake
args: [--in-place]
Expand All @@ -40,17 +40,17 @@ repos:
hooks:
- id: isort
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
rev: v3.15.1
hooks:
- id: pyupgrade
args: [--py39-plus]
- repo: https://github.com/psf/black
rev: 23.12.1
rev: 24.2.0
hooks:
- id: black
# - id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.14
rev: v0.2.2
hooks:
- id: ruff
args: [--fix-only, --show-fixes] # --unsafe-fixes]
Expand All @@ -62,7 +62,7 @@ repos:
additional_dependencies: &flake8_dependencies
# These versions need updated manually
- flake8==7.0.0
- flake8-bugbear==24.1.17
- flake8-bugbear==24.2.6
- flake8-simplify==0.21.0
- repo: https://github.com/asottile/yesqa
rev: v1.5.0
Expand All @@ -77,7 +77,7 @@ repos:
additional_dependencies: [tomli]
files: ^(nx_cugraph|docs)/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.14
rev: v0.2.2
hooks:
- id: ruff
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def hits(
resource_handle=plc.ResourceHandle(),
graph=G._get_plc_graph(weight, 1, dtype, store_transposed=True),
tol=tol,
initial_hubs_guess_vertices=None
if nstart is None
else cp.arange(N, dtype=index_dtype),
initial_hubs_guess_vertices=(
None if nstart is None else cp.arange(N, dtype=index_dtype)
),
initial_hubs_guess_values=nstart,
max_iter=max_iter,
normalized=normalized,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def pagerank(
"graph": G._get_plc_graph(weight, 1, dtype, store_transposed=True),
"precomputed_vertex_out_weight_vertices": None,
"precomputed_vertex_out_weight_sums": None,
"initial_guess_vertices": None
if nstart is None
else cp.arange(N, dtype=index_dtype),
"initial_guess_vertices": (
None if nstart is None else cp.arange(N, dtype=index_dtype)
),
"initial_guess_values": nstart,
"alpha": alpha,
"epsilon": N * tol,
Expand Down
3 changes: 1 addition & 2 deletions python/nx-cugraph/nx_cugraph/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,7 @@ def _get_plc_graph(
"pylibcugraph only supports float16 and float32 dtypes."
)
elif (
edge_array.dtype == np.uint64
and edge_array.max().tolist() > 2**53
edge_array.dtype == np.uint64 and edge_array.max().tolist() > 2**53
):
raise ValueError(
f"Integer value of value is too large (> 2**53): {val}; "
Expand Down
4 changes: 1 addition & 3 deletions python/nx-cugraph/nx_cugraph/classes/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,7 @@ def get_edge_data(
if k not in self.edge_masks or self.edge_masks[k][index]
}
return {
edge_keys[index]
if edge_keys is not None
else index: {
edge_keys[index] if edge_keys is not None else index: {
k: v[index].tolist()
for k, v in self.edge_values.items()
if k not in self.edge_masks or self.edge_masks[k][index]
Expand Down
6 changes: 3 additions & 3 deletions python/nx-cugraph/nx_cugraph/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,9 @@ def key(testpath):
)
if sys.version_info[:2] == (3, 9):
# This test is sensitive to RNG, which depends on Python version
xfail[
key("test_louvain.py:test_threshold")
] = "Louvain does not support seed parameter"
xfail[key("test_louvain.py:test_threshold")] = (
"Louvain does not support seed parameter"
)
if nxver.major == 3 and nxver.minor >= 2:
xfail.update(
{
Expand Down
3 changes: 2 additions & 1 deletion python/nx-cugraph/nx_cugraph/scripts/print_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def main(path_to_info=None, *, file=sys.stdout):
lines = ["networkx_path,dispatch_name,version_added,plc,is_incomplete,is_different"]
lines.extend(",".join(map(str, info)) for info in path_to_info.values())
text = "\n".join(lines)
print(text, file=file)
if file is not None:
print(text, file=file)
return text


Expand Down
122 changes: 75 additions & 47 deletions python/nx-cugraph/nx_cugraph/scripts/print_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,58 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import re
import sys

import networkx as nx
from nx_cugraph.scripts.print_table import Info, get_path_to_info

from nx_cugraph.scripts.print_table import get_path_to_info

def assoc_in(d, keys, value):
"""Like Clojure's assoc-in, but modifies d in-place."""
inner = d
keys = iter(keys)
key = next(keys)
for next_key in keys:
if key not in inner:
inner[key] = {}
inner = inner[key]
key = next_key
inner[key] = value
return d

def add_branch(G, path, extra="", *, skip=0):
branch = path.split(".")
prev = ".".join(branch[: skip + 1])
for i in range(skip + 2, len(branch)):
cur = ".".join(branch[:i])
G.add_edge(prev, cur)
prev = cur
if extra:
if not isinstance(extra, str):
extra = ", ".join(extra)
path += f" ({extra})"
G.add_edge(prev, path)

def default_get_payload_internal(keys):
return keys[-1]


def tree_lines(
tree,
parents=(),
are_levels_closing=(),
get_payload_internal=default_get_payload_internal,
):
pre = "".join(
" " if is_level_closing else " │ "
for is_level_closing in are_levels_closing
)
c = "├"
are_levels_closing += (False,)
for i, (key, val) in enumerate(tree.items(), 1):
if i == len(tree): # Last item
c = "└"
are_levels_closing = are_levels_closing[:-1] + (True,)
if isinstance(val, str):
yield pre + f" {c}─ " + val
else:
yield pre + f" {c}─ " + get_payload_internal((*parents, key))
yield from tree_lines(
val,
(*parents, key),
are_levels_closing,
get_payload_internal=get_payload_internal,
)


def get_extra(
def get_payload(
info,
*,
networkx_path=False,
Expand Down Expand Up @@ -64,7 +93,10 @@ def get_extra(
extra.append("is-incomplete")
if different and info.is_different:
extra.append("is-different")
return extra
extra = ", ".join(extra)
if extra:
extra = f" ({extra})"
return info.networkx_path.rsplit(".", 1)[-1] + extra


def create_tree(
Expand All @@ -80,20 +112,28 @@ def create_tree(
incomplete=False,
different=False,
prefix="",
strip_networkx=True,
get_payload=get_payload,
):
if path_to_info is None:
path_to_info = get_path_to_info()
if strip_networkx:
path_to_info = {
key: Info(info.networkx_path.replace("networkx.", "", 1), *info[1:])
for key, info in path_to_info.items()
}
if isinstance(by, str):
by = [by]
G = nx.DiGraph()
# We rely on the fact that dicts maintain order
tree = {}
for info in sorted(
path_to_info.values(),
key=lambda x: (*(getattr(x, b) for b in by), x.networkx_path),
):
if not all(getattr(info, b) for b in by):
continue
path = prefix + ".".join(getattr(info, b) for b in by)
extra = get_extra(
payload = get_payload(
info,
networkx_path=networkx_path,
dispatch_name=dispatch_name,
Expand All @@ -103,8 +143,8 @@ def create_tree(
incomplete=incomplete,
different=different,
)
add_branch(G, path, extra=extra, skip=skip)
return G
assoc_in(tree, path.split("."), payload)
return tree


def main(
Expand Down Expand Up @@ -132,45 +172,33 @@ def main(
"different": different,
}
if by == "networkx_path":
G = create_tree(path_to_info, by="networkx_path", **kwargs)
text = re.sub(
r" [A-Za-z_\./]+\.", " ", ("\n".join(nx.generate_network_text(G)))
)
tree = create_tree(path_to_info, by="networkx_path", **kwargs)
text = "\n".join(tree_lines(tree))
elif by == "plc":
G = create_tree(
path_to_info, by=["plc", "networkx_path"], prefix="plc-", **kwargs
)
text = re.sub(
"plc-",
"plc.",
re.sub(
r" plc-[A-Za-z_\./]*\.",
" ",
"\n".join(nx.generate_network_text(G)),
),
tree = create_tree(
path_to_info,
by=["plc", "networkx_path"],
prefix="plc-",
**kwargs,
)
text = "\n".join(tree_lines(tree)).replace("plc-", "plc.")
elif by == "version_added":
G = create_tree(
tree = create_tree(
path_to_info,
by=["version_added", "networkx_path"],
prefix="version_added-",
**kwargs,
)
text = re.sub(
"version_added-",
"version: ",
re.sub(
r" version_added-[-0-9A-Za-z_\./]*\.",
" ",
"\n".join(nx.generate_network_text(G)),
),
).replace("-", ".")
text = "\n".join(tree_lines(tree)).replace("version_added-", "version: ")
for digit in "0123456789":
text = text.replace(f"2{digit}-", f"2{digit}.")
else:
raise ValueError(
"`by` argument should be one of {'networkx_path', 'plc', 'version_added' "
f"got: {by}"
)
print(text, file=file)
if file is not None:
print(text, file=file)
return text


Expand Down
Loading

0 comments on commit 6efb1c6

Please sign in to comment.