Skip to content

Commit

Permalink
Allow node values to be numpy arrays or cupy arrays
Browse files Browse the repository at this point in the history
Currently, node values aren't used for any values, the only thing they
are used for is converting to and from networkx, which we do just fine.
  • Loading branch information
eriknw committed Oct 25, 2023
1 parent 03623a9 commit 13c8a6c
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 33 deletions.
1 change: 0 additions & 1 deletion python/nx-cugraph/.flake8
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,5 @@ extend-ignore =
# E203 whitespace before ':' (to be compatible with black)
per-file-ignores =
nx_cugraph/tests/*.py:T201,
nx_cugraph/generators/community.py:E741,
__init__.py:F401,F403,
_nx_cugraph/__init__.py:E501,
43 changes: 30 additions & 13 deletions python/nx-cugraph/nx_cugraph/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
IndexValue,
NodeKey,
NodeValue,
any_ndarray,
)

__all__ = ["Graph"]
Expand All @@ -58,8 +59,8 @@ class Graph:
col_indices: cp.ndarray[IndexValue]
edge_values: dict[AttrKey, cp.ndarray[EdgeValue]]
edge_masks: dict[AttrKey, cp.ndarray[bool]]
node_values: dict[AttrKey, cp.ndarray[NodeValue]]
node_masks: dict[AttrKey, cp.ndarray[bool]]
node_values: dict[AttrKey, any_ndarray[NodeValue]]
node_masks: dict[AttrKey, any_ndarray[bool]]
key_to_id: dict[NodeKey, IndexValue] | None
_id_to_key: list[NodeKey] | None
_N: int
Expand Down Expand Up @@ -97,8 +98,8 @@ def from_coo(
col_indices: cp.ndarray[IndexValue],
edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None,
edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, any_ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, any_ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
Expand Down Expand Up @@ -139,6 +140,22 @@ def from_coo(
new_graph.key_to_id = dict(zip(new_graph._id_to_key, range(N)))
except TypeError as exc:
raise ValueError("Bad type of a node value") from exc
if new_graph.row_indices.dtype != index_dtype:
row_indices = new_graph.row_indices.astype(index_dtype)
if not (new_graph.row_indices == row_indices).all():
raise ValueError(
f"Unable to convert row_indices to {row_indices.dtype.name} "
f"(got {new_graph.row_indices.dtype.name})."
)
new_graph.row_indices = row_indices
if new_graph.col_indices.dtype != index_dtype:
col_indices = new_graph.col_indices.astype(index_dtype)
if not (new_graph.col_indices == col_indices).all():
raise ValueError(
f"Unable to convert col_indices to {col_indices.dtype.name} "
f"(got {new_graph.col_indices.dtype.name})."
)
new_graph.col_indices = col_indices
return new_graph

@classmethod
Expand All @@ -148,8 +165,8 @@ def from_csr(
col_indices: cp.ndarray[IndexValue],
edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None,
edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, any_ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, any_ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
Expand Down Expand Up @@ -180,8 +197,8 @@ def from_csc(
row_indices: cp.ndarray[IndexValue],
edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None,
edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, any_ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, any_ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
Expand Down Expand Up @@ -214,8 +231,8 @@ def from_dcsr(
col_indices: cp.ndarray[IndexValue],
edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None,
edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, any_ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, any_ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
Expand Down Expand Up @@ -247,8 +264,8 @@ def from_dcsc(
row_indices: cp.ndarray[IndexValue],
edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None,
edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, any_ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, any_ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
Expand Down Expand Up @@ -666,7 +683,7 @@ def _nodearray_to_list(self, node_ids: cp.ndarray[IndexValue]) -> list[NodeKey]:
return list(self._nodeiter_to_iter(node_ids.tolist()))

def _nodearrays_to_dict(
self, node_ids: cp.ndarray[IndexValue], values: cp.ndarray[NodeValue]
self, node_ids: cp.ndarray[IndexValue], values: any_ndarray[NodeValue]
) -> dict[NodeKey, NodeValue]:
it = zip(node_ids.tolist(), values.tolist())
if (id_to_key := self.id_to_key) is not None:
Expand Down
21 changes: 11 additions & 10 deletions python/nx-cugraph/nx_cugraph/classes/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
IndexValue,
NodeKey,
NodeValue,
any_ndarray,
)
__all__ = ["MultiGraph"]

Expand Down Expand Up @@ -73,8 +74,8 @@ def from_coo(
edge_indices: cp.ndarray[IndexValue] | None = None,
edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None,
edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, any_ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, any_ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
Expand Down Expand Up @@ -111,8 +112,8 @@ def from_csr(
edge_indices: cp.ndarray[IndexValue] | None = None,
edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None,
edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, any_ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, any_ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
Expand Down Expand Up @@ -147,8 +148,8 @@ def from_csc(
edge_indices: cp.ndarray[IndexValue] | None = None,
edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None,
edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, any_ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, any_ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
Expand Down Expand Up @@ -185,8 +186,8 @@ def from_dcsr(
edge_indices: cp.ndarray[IndexValue] | None = None,
edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None,
edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, any_ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, any_ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
Expand Down Expand Up @@ -222,8 +223,8 @@ def from_dcsc(
edge_indices: cp.ndarray[IndexValue] | None = None,
edge_values: dict[AttrKey, cp.ndarray[EdgeValue]] | None = None,
edge_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, cp.ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, cp.ndarray[bool]] | None = None,
node_values: dict[AttrKey, any_ndarray[NodeValue]] | None = None,
node_masks: dict[AttrKey, any_ndarray[bool]] | None = None,
*,
key_to_id: dict[NodeKey, IndexValue] | None = None,
id_to_key: list[NodeKey] | None = None,
Expand Down
31 changes: 24 additions & 7 deletions python/nx-cugraph/nx_cugraph/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .utils import index_dtype

if TYPE_CHECKING: # pragma: no cover
from nx_cugraph.typing import AttrKey, Dtype, EdgeValue, NodeValue
from nx_cugraph.typing import AttrKey, Dtype, EdgeValue, NodeValue, any_ndarray

__all__ = [
"from_networkx",
Expand Down Expand Up @@ -386,8 +386,18 @@ def from_networkx(
or present
for node_id in adj
)
node_masks[node_attr] = cp.fromiter(iter_mask, bool)
node_values[node_attr] = cp.array(vals, dtype)
# Node values may be numpy or cupy arrays (useful for str, object, etc).
# Someday we'll let the user choose np or cp, and support edge values.
node_mask = np.fromiter(iter_mask, bool)
node_value = np.array(vals, dtype)
try:
node_value = cp.array(node_value)
except ValueError:
pass
else:
node_mask = cp.array(node_mask)
node_values[node_attr] = node_value
node_masks[node_attr] = node_mask
# if vals.ndim > 1: ...
else:
if node_default is REQUIRED:
Expand All @@ -396,10 +406,17 @@ def from_networkx(
iter_values = (
nodes[node_id].get(node_attr, node_default) for node_id in adj
)
# Node values may be numpy or cupy arrays (useful for str, object, etc).
# Someday we'll let the user choose np or cp, and support edge values.
if dtype is None:
node_values[node_attr] = cp.array(list(iter_values))
node_value = np.array(list(iter_values))
else:
node_values[node_attr] = cp.fromiter(iter_values, dtype)
node_value = np.fromiter(iter_values, dtype)
try:
node_value = cp.array(node_value)
except ValueError:
pass
node_values[node_attr] = node_value
# if vals.ndim > 1: ...
if graph.is_multigraph():
if graph.is_directed() or as_directed:
Expand Down Expand Up @@ -439,8 +456,8 @@ def from_networkx(


def _iter_attr_dicts(
values: dict[AttrKey, cp.ndarray[EdgeValue | NodeValue]],
masks: dict[AttrKey, cp.ndarray[bool]],
values: dict[AttrKey, any_ndarray[EdgeValue | NodeValue]],
masks: dict[AttrKey, any_ndarray[bool]],
):
full_attrs = list(values.keys() - masks.keys())
if full_attrs:
Expand Down
4 changes: 2 additions & 2 deletions python/nx-cugraph/nx_cugraph/generators/social.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@ def karate_club_graph():
],
np.int8,
)
# For now, cupy doesn't handle str dtypes and we only handle cupy arrays.
# This means we are definitely cheating by using a numpy array here! FIXME
# For now, cupy doesn't handle str dtypes and we primarily handle cupy arrays.
# We try to support numpy arrays for node values, so let's use numpy here.
clubs = np.array([
"Mr. Hi", "Mr. Hi", "Mr. Hi", "Mr. Hi", "Mr. Hi", "Mr. Hi", "Mr. Hi",
"Mr. Hi", "Mr. Hi", "Officer", "Mr. Hi", "Mr. Hi", "Mr. Hi", "Mr. Hi",
Expand Down
21 changes: 21 additions & 0 deletions python/nx-cugraph/nx_cugraph/tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import networkx as nx
import numpy as np
import pytest

import nx_cugraph as nxcg
Expand Down Expand Up @@ -234,3 +235,23 @@ def test_generator_m_n_complete_vanilla(name, m, n):

def test_bad_lollipop_graph():
compare("lollipop_graph", None, [0, 1], [1, 2])


def test_can_convert_karate_club():
# Karate club graph has string node values.
# This really tests conversions, but it's here so we can use `assert_graphs_equal`.
G = nx.karate_club_graph()
G.add_node(0, foo="bar") # string dtype with a mask
G.add_node(1, object=object()) # haha
Gcg = nxcg.from_networkx(G, preserve_all_attrs=True)
assert_graphs_equal(G, Gcg)
Gnx = nxcg.to_networkx(Gcg)
assert nx.utils.graphs_equal(G, Gnx)
assert isinstance(Gcg.node_values["club"], np.ndarray)
assert Gcg.node_values["club"].dtype.kind == "U"
assert isinstance(Gcg.node_values["foo"], np.ndarray)
assert isinstance(Gcg.node_masks["foo"], np.ndarray)
assert Gcg.node_values["foo"].dtype.kind == "U"
assert isinstance(Gcg.node_values["object"], np.ndarray)
assert Gcg.node_values["object"].dtype.kind == "O"
assert isinstance(Gcg.node_masks["object"], np.ndarray)
8 changes: 8 additions & 0 deletions python/nx-cugraph/nx_cugraph/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from collections.abc import Hashable
from typing import TypeVar

import cupy as cp
import numpy as np

AttrKey = TypeVar("AttrKey", bound=Hashable)
EdgeKey = TypeVar("EdgeKey", bound=Hashable)
NodeKey = TypeVar("NodeKey", bound=Hashable)
Expand All @@ -23,3 +26,8 @@
NodeValue = TypeVar("NodeValue")
IndexValue = TypeVar("IndexValue")
Dtype = TypeVar("Dtype")


class any_ndarray:
def __class_getitem__(cls, item):
return cp.ndarray[item] | np.ndarray[item]

0 comments on commit 13c8a6c

Please sign in to comment.