Clearer, more maintainable; tests stay on CPU too
eriknw committed Nov 12, 2024
1 parent 2a4ff72 commit 5fa886f
Showing 6 changed files with 122 additions and 109 deletions.
45 changes: 31 additions & 14 deletions python/nx-cugraph/nx_cugraph/classes/digraph.py
@@ -34,6 +34,7 @@
__all__ = ["CudaDiGraph", "DiGraph"]

networkx_api = nxcg.utils.decorators.networkx_class(nx.DiGraph)
gpu_cpu_api = nxcg.utils.decorators._gpu_cpu_api(nx.DiGraph, __name__)


class DiGraph(nx.DiGraph, Graph):
@@ -110,21 +111,37 @@ def to_networkx_class(cls) -> type[nx.DiGraph]:
##########################

# Dispatch to nx.DiGraph or CudaDiGraph
__contains__ = Graph.__dict__["__contains__"]
__len__ = Graph.__dict__["__len__"]
__iter__ = Graph.__dict__["__iter__"]
get_edge_data = Graph.__dict__["get_edge_data"]
has_edge = Graph.__dict__["has_edge"]
neighbors = Graph.__dict__["neighbors"]
has_node = Graph.__dict__["has_node"]
nbunch_iter = Graph.__dict__["nbunch_iter"]
number_of_nodes = Graph.__dict__["number_of_nodes"]
order = Graph.__dict__["order"]
successors = Graph.__dict__["neighbors"] # Alias

clear = Graph.clear
clear_edges = Graph.clear_edges
__contains__ = gpu_cpu_api("__contains__")
__len__ = gpu_cpu_api("__len__")
__iter__ = gpu_cpu_api("__iter__")

@networkx_api
def clear(self) -> None:
cudagraph = self._cudagraph if self._is_on_gpu else None
if self._is_on_cpu:
super().clear()
if cudagraph is not None:
cudagraph.clear()
self._set_cudagraph(cudagraph, clear_cpu=False)

@networkx_api
def clear_edges(self) -> None:
cudagraph = self._cudagraph if self._is_on_gpu else None
if self._is_on_cpu:
super().clear_edges()
if cudagraph is not None:
cudagraph.clear_edges()
self._set_cudagraph(cudagraph, clear_cpu=False)

get_edge_data = gpu_cpu_api("get_edge_data", edge_data=True)
has_edge = gpu_cpu_api("has_edge")
neighbors = gpu_cpu_api("neighbors")
has_node = gpu_cpu_api("has_node")
nbunch_iter = gpu_cpu_api("nbunch_iter")
number_of_edges = Graph.number_of_edges
number_of_nodes = gpu_cpu_api("number_of_nodes")
order = gpu_cpu_api("order")
successors = gpu_cpu_api("successors")


class CudaDiGraph(CudaGraph):
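The new `clear`/`clear_edges` overrides clear whichever representations already exist in place, so a GPU-only graph stays GPU-only. A minimal behavioral sketch (not part of the commit; assumes nx-cugraph with a working CUDA GPU, and that `_set_cudagraph` preserves residency as the code above suggests):

# Sketch only; assumes nx-cugraph + CUDA GPU.
import nx_cugraph as nxcg

G = nxcg.complete_graph(3, create_using=nxcg.DiGraph)
assert G._is_on_gpu and not G._is_on_cpu
G.clear()  # clears the cached CudaDiGraph in place
assert G._is_on_gpu and not G._is_on_cpu  # residency unchanged
assert len(G) == 0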
74 changes: 13 additions & 61 deletions python/nx-cugraph/nx_cugraph/classes/graph.py
@@ -48,6 +48,7 @@
__all__ = ["CudaGraph", "Graph"]

networkx_api = nxcg.utils.decorators.networkx_class(nx.Graph)
gpu_cpu_api = nxcg.utils.decorators._gpu_cpu_api(nx.Graph, __name__)

# The "everything" cache key is an internal implementation detail of NetworkX
# that may change between releases.
@@ -86,56 +87,6 @@ def clear(self) -> None:
super().clear()


class _graph_property:
"""Dispatch property to NetworkX or CudaGraph based on cache.
For example, this will use any cached CudaGraph for ``len(G)``, which
prevents creating NetworkX data structures.
"""

def __init__(self, attr, *, edge_data=False, node_data=False):
self._attr = attr
self._edge_data = edge_data
self._node_data = node_data

def __get__(self, instance, owner=None):
nx_class = owner.to_networkx_class()
if instance is None:
# Let's handle e.g. `nxcg.Graph.__len__` to look and behave correctly.
#
# If you want the instance of `_graph_property`, get it from the class dict:
# >>> nxcg.Graph.__dict__["__len__"]
#
# Alternatives:
# - `return op.methodcaller(self._attr)`
# - This dispatches, but does not have e.g. __name__
# - `return getattr(nx_class, self._attr)`
# - This does not dispatch--it always uses networkx--but does have attrs
prop = owner.__dict__[self._attr]

def inner(self, *args, **kwargs):
return prop.__get__(self, owner)(*args, **kwargs)

# Standard function-wrapping
nx_func = getattr(nx_class, self._attr)
inner.__name__ = nx_func.__name__
inner.__doc__ = nx_func.__doc__
inner.__qualname__ = nx_func.__qualname__
inner.__defaults__ = nx_func.__defaults__
inner.__kwdefaults__ = nx_func.__kwdefaults__
inner.__dict__.update(nx_func.__dict__)
inner.__module__ = owner.__module__
inner.__wrapped__ = nx_func
return inner

cuda_graph = instance._get_cudagraph(
edge_data=self._edge_data, node_data=self._node_data
)
if cuda_graph is not None:
return getattr(cuda_graph, self._attr)
return getattr(nx_class, self._attr).__get__(instance, owner)


class Graph(nx.Graph):
# Tell networkx to dispatch calls with this object to nx-cugraph
__networkx_backend__: ClassVar[str] = "cugraph" # nx >=3.2
@@ -592,9 +543,9 @@ def from_dcsc(
##########################

# Dispatch to nx.Graph or CudaGraph
__contains__ = _graph_property("__contains__")
__len__ = _graph_property("__len__")
__iter__ = _graph_property("__iter__")
__contains__ = gpu_cpu_api("__contains__")
__len__ = gpu_cpu_api("__len__")
__iter__ = gpu_cpu_api("__iter__")

@networkx_api
def clear(self) -> None:
@@ -614,11 +565,11 @@ def clear_edges(self) -> None:
cudagraph.clear_edges()
self._set_cudagraph(cudagraph, clear_cpu=False)

get_edge_data = _graph_property("get_edge_data", edge_data=True)
has_edge = _graph_property("has_edge")
neighbors = _graph_property("neighbors")
has_node = _graph_property("has_node")
nbunch_iter = _graph_property("nbunch_iter")
get_edge_data = gpu_cpu_api("get_edge_data", edge_data=True)
has_edge = gpu_cpu_api("has_edge")
neighbors = gpu_cpu_api("neighbors")
has_node = gpu_cpu_api("has_node")
nbunch_iter = gpu_cpu_api("nbunch_iter")

@networkx_api
def number_of_edges(
@@ -628,10 +579,11 @@ def number_of_edges(
# NotImplemented by CudaGraph
nx_class = self.to_networkx_class()
return nx_class.number_of_edges(self, u, v)
return _graph_property("number_of_edges").__get__(self, self.__class__)()
return self._number_of_edges(u, v)

number_of_nodes = _graph_property("number_of_nodes")
order = _graph_property("order")
_number_of_edges = gpu_cpu_api("number_of_edges")
number_of_nodes = gpu_cpu_api("number_of_nodes")
order = gpu_cpu_api("order")
# Future work: implement more graph methods, and handle e.g. `copy`


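In the `number_of_edges` override above, the endpoint-filtered form must go through NetworkX (it is not implemented by CudaGraph), while the plain form dispatches through the new `_number_of_edges = gpu_cpu_api("number_of_edges")`. A behavioral sketch, under the same nx-cugraph + CUDA GPU assumptions:

# Sketch only; not part of the commit.
import nx_cugraph as nxcg

G = nxcg.complete_graph(3, create_using=nxcg.Graph)
assert G._is_on_gpu and not G._is_on_cpu
assert G.number_of_edges() == 3  # dispatches to the cached CudaGraph
assert G._is_on_gpu and not G._is_on_cpu  # no CPU structures created
G.number_of_edges(0, 1)  # endpoints given: handled by nx.Graph's method,
                         # which may materialize CPU data structures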
30 changes: 15 additions & 15 deletions python/nx-cugraph/nx_cugraph/classes/multidigraph.py
@@ -23,6 +23,7 @@
__all__ = ["CudaMultiDiGraph", "MultiDiGraph"]

networkx_api = nxcg.utils.decorators.networkx_class(nx.MultiDiGraph)
gpu_cpu_api = nxcg.utils.decorators._gpu_cpu_api(nx.MultiDiGraph, __name__)


class MultiDiGraph(nx.MultiDiGraph, MultiGraph, DiGraph):
@@ -55,21 +56,20 @@ def to_networkx_class(cls) -> type[nx.MultiDiGraph]:
##########################

# Dispatch to nx.MultiDiGraph or CudaMultiDiGraph
__contains__ = Graph.__dict__["__contains__"]
__len__ = Graph.__dict__["__len__"]
__iter__ = Graph.__dict__["__iter__"]
get_edge_data = Graph.__dict__["get_edge_data"]
has_edge = Graph.__dict__["has_edge"]
neighbors = Graph.__dict__["neighbors"]
has_node = Graph.__dict__["has_node"]
nbunch_iter = Graph.__dict__["nbunch_iter"]
number_of_nodes = Graph.__dict__["number_of_nodes"]
order = Graph.__dict__["order"]
successors = Graph.__dict__["neighbors"] # Alias

clear = Graph.clear
clear_edges = Graph.clear_edges
number_of_edges = Graph.number_of_edges
__contains__ = gpu_cpu_api("__contains__")
__len__ = gpu_cpu_api("__len__")
__iter__ = gpu_cpu_api("__iter__")
clear = DiGraph.clear
clear_edges = DiGraph.clear_edges
get_edge_data = gpu_cpu_api("get_edge_data", edge_data=True)
has_edge = gpu_cpu_api("has_edge")
neighbors = gpu_cpu_api("neighbors")
has_node = gpu_cpu_api("has_node")
nbunch_iter = gpu_cpu_api("nbunch_iter")
number_of_edges = MultiGraph.number_of_edges
number_of_nodes = gpu_cpu_api("number_of_nodes")
order = gpu_cpu_api("order")
successors = gpu_cpu_api("successors")


class CudaMultiDiGraph(CudaMultiGraph, CudaDiGraph):
35 changes: 23 additions & 12 deletions python/nx-cugraph/nx_cugraph/classes/multigraph.py
@@ -37,6 +37,7 @@
__all__ = ["MultiGraph", "CudaMultiGraph"]

networkx_api = nxcg.utils.decorators.networkx_class(nx.MultiGraph)
gpu_cpu_api = nxcg.utils.decorators._gpu_cpu_api(nx.MultiGraph, __name__)


class MultiGraph(nx.MultiGraph, Graph):
@@ -282,20 +283,30 @@ def from_dcsc(
##########################

# Dispatch to nx.MultiGraph or CudaMultiGraph
__contains__ = Graph.__dict__["__contains__"]
__len__ = Graph.__dict__["__len__"]
__iter__ = Graph.__dict__["__iter__"]
get_edge_data = Graph.__dict__["get_edge_data"]
has_edge = Graph.__dict__["has_edge"]
neighbors = Graph.__dict__["neighbors"]
has_node = Graph.__dict__["has_node"]
nbunch_iter = Graph.__dict__["nbunch_iter"]
number_of_nodes = Graph.__dict__["number_of_nodes"]
order = Graph.__dict__["order"]

__contains__ = gpu_cpu_api("__contains__")
__len__ = gpu_cpu_api("__len__")
__iter__ = gpu_cpu_api("__iter__")
clear = Graph.clear
clear_edges = Graph.clear_edges
number_of_edges = Graph.number_of_edges
get_edge_data = gpu_cpu_api("get_edge_data", edge_data=True)
has_edge = gpu_cpu_api("has_edge")
neighbors = gpu_cpu_api("neighbors")
has_node = gpu_cpu_api("has_node")
nbunch_iter = gpu_cpu_api("nbunch_iter")

@networkx_api
def number_of_edges(
self, u: NodeKey | None = None, v: NodeKey | None = None
) -> int:
if u is not None or v is not None:
# NotImplemented by CudaGraph
nx_class = self.to_networkx_class()
return nx_class.number_of_edges(self, u, v)
return self._number_of_edges(u, v)

_number_of_edges = gpu_cpu_api("number_of_edges")
number_of_nodes = gpu_cpu_api("number_of_nodes")
order = gpu_cpu_api("order")


class CudaMultiGraph(CudaGraph):
19 changes: 12 additions & 7 deletions python/nx-cugraph/nx_cugraph/tests/test_graph_methods.py
@@ -90,21 +90,26 @@ def test_multidigraph_to_undirected():
("nbunch_iter", ([0, 1],)),
],
)
def test_method_does_not_create_host_data(create_using, method):
@pytest.mark.parametrize("where", ["gpu", "cpu"])
def test_method_does_not_convert_to_cpu_or_gpu(create_using, method, where):
attr, args = method
if attr == "successors" and not create_using.is_directed():
return
G = nxcg.complete_graph(3, create_using=create_using)
assert G._is_on_gpu
assert not G._is_on_cpu
is_on_gpu = where == "gpu"
is_on_cpu = where == "cpu"
if is_on_cpu:
G.add_edge(10, 20)
assert G._is_on_gpu == is_on_gpu
assert G._is_on_cpu == is_on_cpu
getattr(G, attr)(*args)
assert G._is_on_gpu
assert not G._is_on_cpu
assert G._is_on_gpu == is_on_gpu
assert G._is_on_cpu == is_on_cpu
# Also usable from the class and dispatches correctly
func = getattr(create_using, attr)
func(G, *args)
assert G._is_on_gpu
assert not G._is_on_cpu
assert G._is_on_gpu == is_on_gpu
assert G._is_on_cpu == is_on_cpu
# Basic "looks like networkx" checks
nx_class = create_using.to_networkx_class()
nx_func = getattr(nx_class, attr)
28 changes: 28 additions & 0 deletions python/nx-cugraph/nx_cugraph/utils/decorators.py
@@ -167,3 +167,31 @@ def _default_should_run(*args, **kwargs):

def _restore_networkx_dispatched(name):
return getattr(BackendInterface, name)


def _gpu_cpu_api(nx_class, module_name):
def _gpu_cpu_graph_method(attr, *, edge_data=False, node_data=False):
"""Dispatch property to NetworkX or CudaGraph based on cache.
For example, this will use any cached CudaGraph for ``len(G)``, which
prevents creating NetworkX data structures.
"""
nx_func = getattr(nx_class, attr)

def inner(self, *args, **kwargs):
cuda_graph = self._get_cudagraph(edge_data=edge_data, node_data=node_data)
if cuda_graph is None:
return nx_func(self, *args, **kwargs)
return getattr(cuda_graph, attr)(*args, **kwargs)

inner.__name__ = nx_func.__name__
inner.__doc__ = nx_func.__doc__
inner.__qualname__ = nx_func.__qualname__
inner.__defaults__ = nx_func.__defaults__
inner.__kwdefaults__ = nx_func.__kwdefaults__
inner.__module__ = module_name
inner.__dict__.update(nx_func.__dict__)
inner.__wrapped__ = nx_func
return inner

return _gpu_cpu_graph_method
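To make the two dispatch paths of `_gpu_cpu_api` concrete, here is a small usage sketch mirroring the parametrized test above (not part of the commit; assumes nx-cugraph with a working CUDA GPU, where `_is_on_gpu`/`_is_on_cpu` are the internal residency flags the tests check):

# GPU path: the cached CudaGraph answers, so no CPU structures are created.
import nx_cugraph as nxcg

G = nxcg.complete_graph(3, create_using=nxcg.Graph)
assert G._is_on_gpu and not G._is_on_cpu
assert G.has_node(0) and len(G) == 3
assert G._is_on_gpu and not G._is_on_cpu  # still GPU-only

# CPU path: a mutation drops the GPU cache, so the same calls now go to
# nx.Graph and the graph stays on the CPU.
G.add_edge(10, 20)
assert G._is_on_cpu and not G._is_on_gpu
assert G.has_node(10)
assert G._is_on_cpu and not G._is_on_gpu  # still CPU-only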
