Skip to content

Commit

Permalink
Better handle cudf.pandas in from_pandas_edgelist
Browse files Browse the repository at this point in the history
Optimistically use cupy, but fall back to numpy if necessary
  • Loading branch information
eriknw committed Jul 8, 2024
1 parent 42c7ad7 commit f2529ae
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 16 deletions.
14 changes: 7 additions & 7 deletions python/nx-cugraph/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.17
rev: v0.18
hooks:
- id: validate-pyproject
name: Validate pyproject.toml
Expand All @@ -40,7 +40,7 @@ repos:
hooks:
- id: isort
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.2
rev: v3.16.0
hooks:
- id: pyupgrade
args: [--py39-plus]
Expand All @@ -50,18 +50,18 @@ repos:
- id: black
# - id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
rev: v0.5.1
hooks:
- id: ruff
args: [--fix-only, --show-fixes] # --unsafe-fixes]
- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
rev: 7.1.0
hooks:
- id: flake8
args: ['--per-file-ignores=_nx_cugraph/__init__.py:E501', '--extend-ignore=SIM105'] # Why is this necessary?
additional_dependencies: &flake8_dependencies
# These versions need updated manually
- flake8==7.0.0
- flake8==7.1.0
- flake8-bugbear==24.4.26
- flake8-simplify==0.21.0
- repo: https://github.com/asottile/yesqa
Expand All @@ -70,14 +70,14 @@ repos:
- id: yesqa
additional_dependencies: *flake8_dependencies
- repo: https://github.com/codespell-project/codespell
rev: v2.2.6
rev: v2.3.0
hooks:
- id: codespell
types_or: [python, rst, markdown]
additional_dependencies: [tomli]
files: ^(nx_cugraph|docs)/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
rev: v0.5.1
hooks:
- id: ruff
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down
8 changes: 4 additions & 4 deletions python/nx-cugraph/nx_cugraph/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,13 @@ def from_coo(
# Easy and fast sanity checks
if size != new_graph.dst_indices.size:
raise ValueError
for attr in ["edge_values", "edge_masks"]:
if datadict := getattr(new_graph, attr):
for edge_attr in ["edge_values", "edge_masks"]:
if datadict := getattr(new_graph, edge_attr):
for key, val in datadict.items():
if val.shape[0] != size:
raise ValueError(key)
for attr in ["node_values", "node_masks"]:
if datadict := getattr(new_graph, attr):
for node_attr in ["node_values", "node_masks"]:
if datadict := getattr(new_graph, node_attr):
for key, val in datadict.items():
if val.shape[0] != N:
raise ValueError(key)
Expand Down
20 changes: 15 additions & 5 deletions python/nx-cugraph/nx_cugraph/convert_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,35 @@ def from_pandas_edgelist(
):
"""cudf.DataFrame inputs also supported; value columns with str is unsuppported."""
graph_class, inplace = _create_using_class(create_using)
# Try to be optimal whether using pandas, cudf, or cudf.pandas
src_array = df[source].to_numpy()
dst_array = df[target].to_numpy()
try:
# Optimistically try to use cupy, but fall back to numpy if necessary
src_array = cp.asarray(src_array)
dst_array = cp.asarray(dst_array)
np_or_cp = cp
except ValueError:
src_array = np.asarray(src_array)
dst_array = np.asarray(dst_array)
np_or_cp = np
# TODO: create renumbering helper function(s)
# Renumber step 0: node keys
nodes = np.unique(np.concatenate([src_array, dst_array]))
nodes = np_or_cp.unique(np_or_cp.concatenate([src_array, dst_array]))
N = nodes.size
kwargs = {}
if N > 0 and (
nodes[0] != 0
or nodes[N - 1] != N - 1
or (
nodes.dtype.kind not in {"i", "u"}
and not (nodes == np.arange(N, dtype=np.int64)).all()
and not (nodes == np_or_cp.arange(N, dtype=np.int64)).all()
)
):
# We need to renumber indices--np.searchsorted to the rescue!
# We need to renumber indices--np_or_cp.searchsorted to the rescue!
kwargs["id_to_key"] = nodes.tolist()
src_indices = cp.array(np.searchsorted(nodes, src_array), index_dtype)
dst_indices = cp.array(np.searchsorted(nodes, dst_array), index_dtype)
src_indices = cp.asarray(np_or_cp.searchsorted(nodes, src_array), index_dtype)
dst_indices = cp.asarray(np_or_cp.searchsorted(nodes, dst_array), index_dtype)
else:
src_indices = cp.array(src_array)
dst_indices = cp.array(dst_array)
Expand Down

0 comments on commit f2529ae

Please sign in to comment.