Skip to content

Commit

Permalink
Ensure edge data are contiguous cupy arrays when creating Graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Sep 18, 2024
1 parent 63cc689 commit 36faa76
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
18 changes: 16 additions & 2 deletions python/nx-cugraph/nx_cugraph/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,19 @@ def from_coo(
) -> Graph:
new_graph = object.__new__(cls)
new_graph.__networkx_cache__ = {}
# Ensure edge data is contiguous; don't copy if they are. Indices handled below.
new_graph.src_indices = src_indices
new_graph.dst_indices = dst_indices
new_graph.edge_values = {} if edge_values is None else dict(edge_values)
new_graph.edge_masks = {} if edge_masks is None else dict(edge_masks)
new_graph.edge_values = (
{}
if edge_values is None
else {key: cp.asarray(val, order="C") for key, val in edge_values.items()}
)
new_graph.edge_masks = (
{}
if edge_masks is None
else {key: cp.asarray(val, order="C") for key, val in edge_masks.items()}
)
new_graph.node_values = {} if node_values is None else dict(node_values)
new_graph.node_masks = {} if node_masks is None else dict(node_masks)
new_graph.key_to_id = None if key_to_id is None else dict(key_to_id)
Expand Down Expand Up @@ -165,6 +174,11 @@ def from_coo(
)
new_graph.dst_indices = dst_indices

# Ensure edge data is contiguous; don't copy if they are. Values handled above.
# Do this now so we don't copy every time we create a PLC graph.
new_graph.src_indices = cp.asarray(new_graph.src_indices, order="C")
new_graph.dst_indices = cp.asarray(new_graph.dst_indices, order="C")

# If the graph contains isolates, plc.SGGraph() must be passed a value
# for vertices_array that contains every vertex ID, since the
# src/dst_indices arrays will not contain IDs for isolates. Create this
Expand Down
5 changes: 4 additions & 1 deletion python/nx-cugraph/nx_cugraph/classes/multigraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def from_coo(
id_to_key=id_to_key,
**attr,
)
new_graph.edge_indices = edge_indices
# Ensure edge data is contiguous; don't copy if they are.
new_graph.edge_indices = (
None if edge_indices is None else cp.asarray(edge_indices, order="C")
)
new_graph.edge_keys = edge_keys
# Easy and fast sanity checks
if (
Expand Down

0 comments on commit 36faa76

Please sign in to comment.