Skip to content

Commit

Permalink
nx-cugraph: indicate which plc algorithms are used
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Dec 21, 2023
1 parent 8d5bba3 commit 799c864
Show file tree
Hide file tree
Showing 12 changed files with 41 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
__all__ = ["betweenness_centrality", "edge_betweenness_centrality"]


@networkx_algorithm
@networkx_algorithm(plc="betweenness_centrality")
def betweenness_centrality(
G, k=None, normalized=True, weight=None, endpoints=False, seed=None
):
Expand Down Expand Up @@ -46,7 +46,7 @@ def _(G, k=None, normalized=True, weight=None, endpoints=False, seed=None):
return weight is None


@networkx_algorithm
@networkx_algorithm(plc="edge_betweenness_centrality")
def edge_betweenness_centrality(G, k=None, normalized=True, weight=None, seed=None):
"""`weight` parameter is not yet supported."""
if weight is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


@not_implemented_for("multigraph")
@networkx_algorithm(extra_params=_dtype_param)
@networkx_algorithm(extra_params=_dtype_param, plc="eigenvector_centrality")
def eigenvector_centrality(
G, max_iter=100, tol=1.0e-6, nstart=None, weight=None, *, dtype=None
):
Expand Down
2 changes: 1 addition & 1 deletion python/nx-cugraph/nx_cugraph/algorithms/centrality/katz.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@


@not_implemented_for("multigraph")
@networkx_algorithm(extra_params=_dtype_param)
@networkx_algorithm(extra_params=_dtype_param, plc="katz_centrality")
def katz_centrality(
G,
alpha=0.1,
Expand Down
3 changes: 2 additions & 1 deletion python/nx-cugraph/nx_cugraph/algorithms/community/louvain.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
"Upper limit of the number of macro-iterations (max: 500)."
),
**_dtype_param,
}
},
plc="louvain",
)
def louvain_communities(
G,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _(G):


@not_implemented_for("directed")
@networkx_algorithm
@networkx_algorithm(plc="weakly_connected_components")
def connected_components(G):
G = _to_undirected_graph(G)
if G.src_indices.size == 0:
Expand Down Expand Up @@ -86,7 +86,7 @@ def connected_components(G):


@not_implemented_for("directed")
@networkx_algorithm
@networkx_algorithm(plc="weakly_connected_components")
def is_connected(G):
G = _to_undirected_graph(G)
if len(G) == 0:
Expand All @@ -110,7 +110,7 @@ def is_connected(G):


@not_implemented_for("directed")
@networkx_algorithm
@networkx_algorithm(plc="weakly_connected_components")
def node_connected_component(G, n):
# We could also do plain BFS from n
G = _to_undirected_graph(G)
Expand Down
2 changes: 1 addition & 1 deletion python/nx-cugraph/nx_cugraph/algorithms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

@not_implemented_for("directed")
@not_implemented_for("multigraph")
@networkx_algorithm
@networkx_algorithm(plc="k_truss_subgraph")
def k_truss(G, k):
"""
Currently raises `NotImplementedError` for graphs with more than one connected
Expand Down
4 changes: 2 additions & 2 deletions python/nx-cugraph/nx_cugraph/algorithms/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def _ancestors_and_descendants(G, source, *, is_ancestors):
return G._nodearray_to_set(node_ids[mask])


@networkx_algorithm
@networkx_algorithm(plc="bfs")
def descendants(G, source):
return _ancestors_and_descendants(G, source, is_ancestors=False)


@networkx_algorithm
@networkx_algorithm(plc="bfs")
def ancestors(G, source):
return _ancestors_and_descendants(G, source, is_ancestors=True)
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
"The edge attribute to use as the edge weight."
),
**_dtype_param,
}
},
plc="hits",
)
def hits(
G,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
__all__ = ["pagerank"]


@networkx_algorithm(extra_params=_dtype_param)
@networkx_algorithm(
extra_params=_dtype_param, plc={"pagerank", "personalized_pagerank"}
)
def pagerank(
G,
alpha=0.85,
Expand Down Expand Up @@ -97,7 +99,7 @@ def pagerank(


@pagerank._can_run
def pagerank(
def _(
G,
alpha=0.85,
personalization=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
__all__ = ["single_source_shortest_path_length", "single_target_shortest_path_length"]


@networkx_algorithm
@networkx_algorithm(plc="bfs")
def single_source_shortest_path_length(G, source, cutoff=None):
return _single_shortest_path_length(G, source, cutoff, "Source")


@networkx_algorithm
@networkx_algorithm(plc="bfs")
def single_target_shortest_path_length(G, target, cutoff=None):
return _single_shortest_path_length(G, target, cutoff, "Target")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _bfs(G, source, *, depth_limit=None, reverse=False):
return distances[mask], predecessors[mask], node_ids[mask]


@networkx_algorithm
@networkx_algorithm(plc="bfs")
def generic_bfs_edges(G, source, neighbors=None, depth_limit=None, sort_neighbors=None):
"""`neighbors` and `sort_neighbors` parameters are not yet supported."""
return bfs_edges(source, depth_limit=depth_limit)
Expand All @@ -68,7 +68,7 @@ def _(G, source, neighbors=None, depth_limit=None, sort_neighbors=None):
return neighbors is None and sort_neighbors is None


@networkx_algorithm
@networkx_algorithm(plc="bfs")
def bfs_edges(G, source, reverse=False, depth_limit=None, sort_neighbors=None):
"""`sort_neighbors` parameter is not yet supported."""
G = _check_G_and_source(G, source)
Expand All @@ -95,7 +95,7 @@ def _(G, source, reverse=False, depth_limit=None, sort_neighbors=None):
return sort_neighbors is None


@networkx_algorithm
@networkx_algorithm(plc="bfs")
def bfs_tree(G, source, reverse=False, depth_limit=None, sort_neighbors=None):
"""`sort_neighbors` parameter is not yet supported."""
G = _check_G_and_source(G, source)
Expand Down Expand Up @@ -149,7 +149,7 @@ def _(G, source, reverse=False, depth_limit=None, sort_neighbors=None):
return sort_neighbors is None


@networkx_algorithm
@networkx_algorithm(plc="bfs")
def bfs_successors(G, source, depth_limit=None, sort_neighbors=None):
"""`sort_neighbors` parameter is not yet supported."""
G = _check_G_and_source(G, source)
Expand All @@ -173,7 +173,7 @@ def _(G, source, depth_limit=None, sort_neighbors=None):
return sort_neighbors is None


@networkx_algorithm
@networkx_algorithm(plc="bfs")
def bfs_layers(G, sources):
G = _to_graph(G)
if sources in G:
Expand Down Expand Up @@ -201,7 +201,7 @@ def bfs_layers(G, sources):
return (G._nodearray_to_list(groups[key]) for key in range(len(groups)))


@networkx_algorithm
@networkx_algorithm(plc="bfs")
def bfs_predecessors(G, source, depth_limit=None, sort_neighbors=None):
"""`sort_neighbors` parameter is not yet supported."""
G = _check_G_and_source(G, source)
Expand All @@ -227,7 +227,7 @@ def _(G, source, depth_limit=None, sort_neighbors=None):
return sort_neighbors is None


@networkx_algorithm
@networkx_algorithm(plc="bfs")
def descendants_at_distance(G, source, distance):
G = _check_G_and_source(G, source)
if distance is None or distance < 0:
Expand Down
14 changes: 14 additions & 0 deletions python/nx-cugraph/nx_cugraph/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class networkx_algorithm:
name: str
extra_doc: str | None
extra_params: dict[str, str] | None
_plc_names: set[str] | None

def __new__(
cls,
Expand All @@ -49,13 +50,15 @@ def __new__(
name: str | None = None,
extra_params: dict[str, str] | str | None = None,
nodes_or_number: list[int] | int | None = None,
plc: str | set[str] | None = None,
):
if func is None:
return partial(
networkx_algorithm,
name=name,
extra_params=extra_params,
nodes_or_number=nodes_or_number,
plc=plc,
)
instance = object.__new__(cls)
if nodes_or_number is not None and nx.__version__[:3] > "3.2":
Expand All @@ -74,6 +77,12 @@ def __new__(
f"extra_params must be dict, str, or None; got {type(extra_params)}"
)
instance.extra_params = extra_params
if plc is None or isinstance(plc, set):
instance._plc_names = plc
elif isinstance(plc, str):
instance._plc_names = {plc}
else:
raise TypeError(f"plc argument must be str, set, or None; got {type(plc)}")
# The docstring on our function is added to the NetworkX docstring.
instance.extra_doc = (
dedent(func.__doc__.lstrip("\n").rstrip()) if func.__doc__ else None
Expand All @@ -91,6 +100,11 @@ def __new__(

def _can_run(self, func):
"""Set the `can_run` attribute to the decorated function."""
if not func.__name__.startswith("_"):
raise ValueError(
"The name of the function used by `_can_run` must begin with '_'; "
f"got: {func.__name__!r}"
)
self.can_run = func

def __call__(self, /, *args, **kwargs):
Expand Down

0 comments on commit 799c864

Please sign in to comment.