Skip to content

Commit

Permalink
nx-cugraph: handle louvain with isolated nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Sep 28, 2023
1 parent 91fbcca commit 063c7a8
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 7 deletions.
3 changes: 2 additions & 1 deletion python/nx-cugraph/_nx_cugraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@
# BEGIN: functions
"betweenness_centrality",
"edge_betweenness_centrality",
"isolates",
"louvain_communities",
# END: functions
},
"extra_docstrings": {
# BEGIN: extra_docstrings
"betweenness_centrality": "`weight` parameter is not yet supported.",
"edge_betweenness_centrality": "`weight` parameter is not yet supported.",
"louvain_communities": "`threshold` and `seed` parameters are currently ignored.",
"louvain_communities": "`seed` parameter is currently ignored.",
# END: extra_docstrings
},
"extra_parameters": {
Expand Down
1 change: 1 addition & 0 deletions python/nx-cugraph/nx_cugraph/algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
# limitations under the License.
from . import centrality, community
from .centrality import *
from .isolate import *
24 changes: 20 additions & 4 deletions python/nx-cugraph/nx_cugraph/algorithms/community/louvain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import warnings

import pylibcugraph as plc

Expand All @@ -22,6 +22,8 @@
not_implemented_for,
)

from ..isolate import _isolates

__all__ = ["louvain_communities"]


Expand All @@ -34,15 +36,22 @@
def louvain_communities(
G, weight="weight", resolution=1, threshold=0.0000001, seed=None, *, max_level=None
):
"""`threshold` and `seed` parameters are currently ignored."""
"""`seed` parameter is currently ignored."""
# NetworkX allows both directed and undirected, but cugraph only allows undirected.
seed = _seed_to_int(seed) # Unused, but ensure it's valid for future compatibility
G = _to_undirected_graph(G, weight)
if G.row_indices.size == 0:
# TODO: PLC doesn't handle empty graphs gracefully!
return [{key} for key in G._nodeiter_to_iter(range(len(G)))]
if max_level is None:
max_level = sys.maxsize
max_level = 500
elif max_level > 500:
warnings.warn(
f"max_level is set too high (={max_level}), setting it to 500.",
UserWarning,
stacklevel=2,
)
max_level = 500
vertices, clusters, modularity = plc.louvain(
resource_handle=plc.ResourceHandle(),
graph=G._get_plc_graph(),
Expand All @@ -52,7 +61,14 @@ def louvain_communities(
do_expensive_check=False,
)
groups = _groupby(clusters, vertices)
return [set(G._nodearray_to_list(node_ids)) for node_ids in groups.values()]
rv = [set(G._nodearray_to_list(node_ids)) for node_ids in groups.values()]
# TODO: PLC doesn't handle isolated vertices yet, so this is a temporary fix
isolates = _isolates(G)
if isolates.size > 0:
isolates = isolates[isolates > vertices.max()]
if isolates.size > 0:
rv.extend({node} for node in G._nodearray_to_list(isolates))
return rv


@louvain_communities._can_run
Expand Down
33 changes: 33 additions & 0 deletions python/nx-cugraph/nx_cugraph/algorithms/isolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cupy as cp

from nx_cugraph.convert import _to_graph
from nx_cugraph.utils import networkx_algorithm

__all__ = ["isolates"]


def _isolates(G) -> cp.ndarray:
G = _to_graph(G)
mark_isolates = cp.ones(len(G), bool)
mark_isolates[G.row_indices] = False
if G.is_directed():
mark_isolates[G.col_indices] = False
return cp.nonzero(mark_isolates)[0]


@networkx_algorithm
def isolates(G):
G = _to_graph(G)
return G._nodeiter_to_iter(iter(_isolates(G).tolist()))
4 changes: 2 additions & 2 deletions python/nx-cugraph/nx_cugraph/classes/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,9 @@ def from_dcsc(
def __new__(cls, incoming_graph_data=None, **attr) -> Graph:
if incoming_graph_data is None:
new_graph = cls.from_coo(0, cp.empty(0, np.int32), cp.empty(0, np.int32))
elif incoming_graph_data.__class__ is new_graph.__class__:
elif incoming_graph_data.__class__ is cls:
new_graph = incoming_graph_data.copy()
elif incoming_graph_data.__class__ is new_graph.to_networkx_class():
elif incoming_graph_data.__class__ is cls.to_networkx_class():
new_graph = nxcg.from_networkx(incoming_graph_data, preserve_all_attrs=True)
else:
raise NotImplementedError
Expand Down
42 changes: 42 additions & 0 deletions python/nx-cugraph/nx_cugraph/tests/test_community.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import networkx as nx

import nx_cugraph as nxcg


def test_louvain_isolated_nodes():
def check(left, right):
assert len(left) == len(right)
assert set(map(frozenset, left)) == set(map(frozenset, right))

# Empty graph (no nodes)
G = nx.Graph()
nx_result = nx.community.louvain_communities(G)
cg_result = nxcg.community.louvain_communities(G)
check(nx_result, cg_result)
# Graph with no edges
G.add_nodes_from(range(5))
nx_result = nx.community.louvain_communities(G)
cg_result = nxcg.community.louvain_communities(G)
check(nx_result, cg_result)
# Graph with isolated nodes
G.add_edge(1, 2)
nx_result = nx.community.louvain_communities(G)
cg_result = nxcg.community.louvain_communities(G)
check(nx_result, cg_result)
# Another one
G.add_edge(4, 4)
nx_result = nx.community.louvain_communities(G)
cg_result = nxcg.community.louvain_communities(G)
check(nx_result, cg_result)

0 comments on commit 063c7a8

Please sign in to comment.