Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nx-cugraph: Fixes dependency on missing feature, adds attrs to allow auto-dispatch for generators #4558

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions python/nx-cugraph/nx_cugraph/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@
from .utils import index_dtype, networkx_algorithm
from .utils.misc import _And_NotImplementedError, pairwise

if _nxver >= (3, 4):
from networkx.utils.backends import _get_cache_key, _get_from_cache, _set_to_cache

if TYPE_CHECKING: # pragma: no cover
from nx_cugraph.typing import AttrKey, Dtype, EdgeValue, NodeValue, any_ndarray

Expand Down Expand Up @@ -195,24 +192,6 @@ def from_networkx(
"you have found a bug, please report a minimum reproducible example to "
"https://github.com/rapidsai/cugraph/issues/new/choose"
)
if _nxver >= (3, 4):
cache_key = _get_cache_key(
edge_attrs=edge_attrs,
node_attrs=node_attrs,
preserve_edge_attrs=preserve_edge_attrs,
preserve_node_attrs=preserve_node_attrs,
preserve_graph_attrs=preserve_graph_attrs,
)
cache = getattr(graph, "__networkx_cache__", None)
if cache is not None:
cache = cache.setdefault("backends", {}).setdefault("cugraph", {})
compat_key, rv = _get_from_cache(cache, cache_key)
if rv is not None:
if isinstance(rv, nxcg.Graph):
# This shouldn't happen during normal use, but be extra-careful
rv = rv._cudagraph
if rv is not None:
return rv

if preserve_all_attrs:
preserve_edge_attrs = True
Expand Down Expand Up @@ -562,11 +541,6 @@ def func(it, edge_attr=edge_attr, dtype=dtype):
)
if preserve_graph_attrs:
rv.graph.update(graph.graph) # deepcopy?
if _nxver >= (3, 4) and isinstance(graph, nxcg.Graph) and cache is not None:
# Make sure this conversion is added to the cache, and make all of
# our graphs share the same `.graph` attribute for consistency.
rv.graph = graph.graph
_set_to_cache(cache, cache_key, rv)
if (
use_compat_graph
# Use compat graphs by default
Expand Down
5 changes: 4 additions & 1 deletion python/nx-cugraph/nx_cugraph/convert_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@


# Value columns with string dtype is not supported
@networkx_algorithm(is_incomplete=True, version_added="23.12", fallback=True)
@networkx_algorithm(is_incomplete=True,
returns_networkx_compatible_graph=True,
fallback=True,
version_added="23.12")
def from_pandas_edgelist(
df,
source="source",
Expand Down
5 changes: 5 additions & 0 deletions python/nx-cugraph/nx_cugraph/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class networkx_algorithm:
is_different: bool
_fallback: bool
_plc_names: set[str] | None
returns_networkx_compatible_graph: bool

def __new__(
cls,
Expand All @@ -66,6 +67,8 @@ def __new__(
is_different: bool = False, # See self.extra_doc for details if True
fallback: bool = False, # Change non-nx exceptions to NotImplementedError
_plc: str | set[str] | None = None, # Hidden from user, may be removed someday
# True if function returns G where isinstance(G, nx.Graph) == True
returns_networkx_compatible_graph: bool = False,
):
if func is None:
return partial(
Expand All @@ -78,6 +81,7 @@ def __new__(
is_different=is_different,
fallback=fallback,
_plc=_plc,
returns_networkx_compatible_graph=returns_networkx_compatible_graph,
)
instance = object.__new__(cls)
if nodes_or_number is not None and _nxver > (3, 2):
Expand Down Expand Up @@ -107,6 +111,7 @@ def __new__(
instance.version_added = version_added
instance.is_incomplete = is_incomplete
instance.is_different = is_different
instance.returns_networkx_compatible_graph = returns_networkx_compatible_graph
instance.fallback = fallback
# The docstring on our function is added to the NetworkX docstring.
instance.extra_doc = (
Expand Down
Loading