Skip to content

Commit

Permalink
Detects and adds isolates when creating a plc.SGGraph, removes specia…
Browse files Browse the repository at this point in the history
…l case testing and skipped tests from prior PR needed due to the bug this PR fixes.
  • Loading branch information
rlratzel committed Jan 6, 2024
1 parent 0f808e3 commit bb6e744
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 16 deletions.
5 changes: 0 additions & 5 deletions ci/test_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,6 @@ popd
rapids-logger "pytest networkx using nx-cugraph backend"
pushd python/nx-cugraph
./run_nx_tests.sh
# Individually run tests that are skipped above b/c they may run out of memory
PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestDAG and test_antichains"
PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestMultiDiGraph_DAGLCA and test_all_pairs_lca_pairs_without_lca"
PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestDAGLCA and test_all_pairs_lca_pairs_without_lca"
PYTEST_NO_SKIP=True ./run_nx_tests.sh --cov-append -k "TestEfficiency and test_using_ego_graph"
# run_nx_tests.sh outputs coverage data, so check that total coverage is >0.0%
# in case nx-cugraph failed to load but fallback mode allowed the run to pass.
_coverage=$(coverage report|grep "^TOTAL")
Expand Down
16 changes: 16 additions & 0 deletions python/nx-cugraph/nx_cugraph/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class Graph:
key_to_id: dict[NodeKey, IndexValue] | None
_id_to_key: list[NodeKey] | None
_N: int
_node_ids: cp.ndarray[IndexValue] | None # holds plc.SGGraph.vertices_array data

# Used by graph._get_plc_graph
_plc_type_map: ClassVar[dict[np.dtype, np.dtype]] = {
Expand Down Expand Up @@ -116,6 +117,7 @@ def from_coo(
new_graph.key_to_id = None if key_to_id is None else dict(key_to_id)
new_graph._id_to_key = None if id_to_key is None else list(id_to_key)
new_graph._N = op.index(N) # Ensure N is integral
new_graph._node_ids = None
new_graph.graph = new_graph.graph_attr_dict_factory()
new_graph.graph.update(attr)
size = new_graph.src_indices.size
Expand Down Expand Up @@ -405,6 +407,7 @@ def clear(self) -> None:
self.src_indices = cp.empty(0, self.src_indices.dtype)
self.dst_indices = cp.empty(0, self.dst_indices.dtype)
self._N = 0
self._node_ids = None
self.key_to_id = None
self._id_to_key = None

Expand Down Expand Up @@ -637,6 +640,18 @@ def _get_plc_graph(
dst_indices = self.dst_indices
if switch_indices:
src_indices, dst_indices = dst_indices, src_indices

# If the graph contains isolates, plc.SGGraph() must be passed a
# value for vertices_array that contains every vertex ID, since the
# src/dst_indices arrays will not contain IDs for isolates.
all_node_ids = cp.arange(self._N, dtype=index_dtype)
isolates = cp.setdiff1d(cp.setdiff1d(all_node_ids, src_indices), dst_indices)
if len(isolates) == 0:
all_node_ids = None
# like self.src/dst_indices, the _node_ids array must be maintained for
# the lifetime of the plc.SGGraph
self._node_ids = all_node_ids

return plc.SGGraph(
resource_handle=plc.ResourceHandle(),
graph_properties=plc.GraphProperties(
Expand All @@ -649,6 +664,7 @@ def _get_plc_graph(
store_transposed=store_transposed,
renumber=False,
do_expensive_check=False,
vertices_array=self._node_ids,
)

def _sort_edge_indices(self, primary="src"):
Expand Down
11 changes: 0 additions & 11 deletions python/nx-cugraph/nx_cugraph/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,20 +242,9 @@ def key(testpath):
)

too_slow = "Too slow to run"
maybe_oom = "out of memory in CI"
skip = {
key("test_tree_isomorphism.py:test_positive"): too_slow,
key("test_tree_isomorphism.py:test_negative"): too_slow,
key("test_efficiency.py:TestEfficiency.test_using_ego_graph"): maybe_oom,
key("test_dag.py:TestDAG.test_antichains"): maybe_oom,
key(
"test_lowest_common_ancestors.py:"
"TestDAGLCA.test_all_pairs_lca_pairs_without_lca"
): maybe_oom,
key(
"test_lowest_common_ancestors.py:"
"TestMultiDiGraph_DAGLCA.test_all_pairs_lca_pairs_without_lca"
): maybe_oom,
# These repeatedly call `bfs_layers`, which converts the graph every call
key(
"test_vf2pp.py:TestGraphISOVF2pp.test_custom_graph2_different_labels"
Expand Down

0 comments on commit bb6e744

Please sign in to comment.