From 355efc1b450cdd75b73ca42cb1bf725695ac09af Mon Sep 17 00:00:00 2001 From: Erik Welch Date: Tue, 9 Jul 2024 01:25:11 +0200 Subject: [PATCH] Better handle cudf.pandas in `from_pandas_edgelist` (#4525) Optimistically use cupy, but fall back to numpy if necessary. Also, bump lint versions. CC @rlratzel Authors: - Erik Welch (https://github.com/eriknw) Approvers: - Rick Ratzel (https://github.com/rlratzel) URL: https://github.com/rapidsai/cugraph/pull/4525 --- python/nx-cugraph/lint.yaml | 14 ++++++------- python/nx-cugraph/nx_cugraph/classes/graph.py | 8 ++++---- .../nx-cugraph/nx_cugraph/convert_matrix.py | 20 ++++++++++++++----- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/python/nx-cugraph/lint.yaml b/python/nx-cugraph/lint.yaml index c4422ffb97d..317d5b8d481 100644 --- a/python/nx-cugraph/lint.yaml +++ b/python/nx-cugraph/lint.yaml @@ -26,7 +26,7 @@ repos: - id: mixed-line-ending - id: trailing-whitespace - repo: https://github.com/abravalheri/validate-pyproject - rev: v0.17 + rev: v0.18 hooks: - id: validate-pyproject name: Validate pyproject.toml @@ -40,7 +40,7 @@ repos: hooks: - id: isort - repo: https://github.com/asottile/pyupgrade - rev: v3.15.2 + rev: v3.16.0 hooks: - id: pyupgrade args: [--py39-plus] @@ -50,18 +50,18 @@ repos: - id: black # - id: black-jupyter - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.4 + rev: v0.5.1 hooks: - id: ruff args: [--fix-only, --show-fixes] # --unsafe-fixes] - repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 + rev: 7.1.0 hooks: - id: flake8 args: ['--per-file-ignores=_nx_cugraph/__init__.py:E501', '--extend-ignore=SIM105'] # Why is this necessary? additional_dependencies: &flake8_dependencies # These versions need updated manually - - flake8==7.0.0 + - flake8==7.1.0 - flake8-bugbear==24.4.26 - flake8-simplify==0.21.0 - repo: https://github.com/asottile/yesqa @@ -70,14 +70,14 @@ repos: - id: yesqa additional_dependencies: *flake8_dependencies - repo: https://github.com/codespell-project/codespell - rev: v2.2.6 + rev: v2.3.0 hooks: - id: codespell types_or: [python, rst, markdown] additional_dependencies: [tomli] files: ^(nx_cugraph|docs)/ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.4 + rev: v0.5.1 hooks: - id: ruff - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/python/nx-cugraph/nx_cugraph/classes/graph.py b/python/nx-cugraph/nx_cugraph/classes/graph.py index 5132e6a547b..ad7cf319139 100644 --- a/python/nx-cugraph/nx_cugraph/classes/graph.py +++ b/python/nx-cugraph/nx_cugraph/classes/graph.py @@ -124,13 +124,13 @@ def from_coo( # Easy and fast sanity checks if size != new_graph.dst_indices.size: raise ValueError - for attr in ["edge_values", "edge_masks"]: - if datadict := getattr(new_graph, attr): + for edge_attr in ["edge_values", "edge_masks"]: + if datadict := getattr(new_graph, edge_attr): for key, val in datadict.items(): if val.shape[0] != size: raise ValueError(key) - for attr in ["node_values", "node_masks"]: - if datadict := getattr(new_graph, attr): + for node_attr in ["node_values", "node_masks"]: + if datadict := getattr(new_graph, node_attr): for key, val in datadict.items(): if val.shape[0] != N: raise ValueError(key) diff --git a/python/nx-cugraph/nx_cugraph/convert_matrix.py b/python/nx-cugraph/nx_cugraph/convert_matrix.py index 1a2ecde9b8c..67f6386987b 100644 --- a/python/nx-cugraph/nx_cugraph/convert_matrix.py +++ b/python/nx-cugraph/nx_cugraph/convert_matrix.py @@ -35,11 +35,21 @@ def from_pandas_edgelist( ): """cudf.DataFrame inputs also supported; value columns with str is unsuppported.""" graph_class, inplace = _create_using_class(create_using) + # Try to be optimal whether using pandas, cudf, or cudf.pandas src_array = df[source].to_numpy() dst_array = df[target].to_numpy() + try: + # Optimistically try to use cupy, but fall back to numpy if necessary + src_array = cp.asarray(src_array) + dst_array = cp.asarray(dst_array) + np_or_cp = cp + except ValueError: + src_array = np.asarray(src_array) + dst_array = np.asarray(dst_array) + np_or_cp = np # TODO: create renumbering helper function(s) # Renumber step 0: node keys - nodes = np.unique(np.concatenate([src_array, dst_array])) + nodes = np_or_cp.unique(np_or_cp.concatenate([src_array, dst_array])) N = nodes.size kwargs = {} if N > 0 and ( @@ -47,13 +57,13 @@ def from_pandas_edgelist( or nodes[N - 1] != N - 1 or ( nodes.dtype.kind not in {"i", "u"} - and not (nodes == np.arange(N, dtype=np.int64)).all() + and not (nodes == np_or_cp.arange(N, dtype=np.int64)).all() ) ): - # We need to renumber indices--np.searchsorted to the rescue! + # We need to renumber indices--np_or_cp.searchsorted to the rescue! kwargs["id_to_key"] = nodes.tolist() - src_indices = cp.array(np.searchsorted(nodes, src_array), index_dtype) - dst_indices = cp.array(np.searchsorted(nodes, dst_array), index_dtype) + src_indices = cp.asarray(np_or_cp.searchsorted(nodes, src_array), index_dtype) + dst_indices = cp.asarray(np_or_cp.searchsorted(nodes, dst_array), index_dtype) else: src_indices = cp.array(src_array) dst_indices = cp.array(dst_array)