Skip to content

Commit

Permalink
Rename Graph->CudaGraph and ZeroGraph->Graph; many improvements
Browse files Browse the repository at this point in the history
In addition to renaming and moving graph classes around, this smooths out
*many* rough edges for zero-code change use cases. There is more to be
done, but this is excellent progress and should be good enough to use.
Now to see about getting CI to pass!
  • Loading branch information
eriknw committed Sep 9, 2024
1 parent 04f88ff commit d615fed
Show file tree
Hide file tree
Showing 32 changed files with 848 additions and 550 deletions.
3 changes: 2 additions & 1 deletion ci/run_nx_cugraph_pytests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ set -euo pipefail
# Support invoking run_nx_cugraph_pytests.sh outside the script directory
cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../python/nx-cugraph/nx_cugraph

NX_CUGRAPH_ZERO=False pytest --capture=no --cache-clear --benchmark-disable "$@" tests
NX_CUGRAPH_USE_COMPAT_GRAPHS=False pytest --capture=no --cache-clear --benchmark-disable "$@" tests
NX_CUGRAPH_USE_COMPAT_GRAPHS=True pytest --capture=no --cache-clear --benchmark-disable "$@" tests
2 changes: 1 addition & 1 deletion ci/test_python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ echo "nx-cugraph coverage from networkx tests: $_coverage"
echo $_coverage | awk '{ if ($NF == "0.0%") exit 1 }'
# Ensure all algorithms were called by comparing covered lines to function lines.
# Run our tests again (they're fast enough) to add their coverage, then create coverage.json
NX_CUGRAPH_ZERO=False pytest \
NX_CUGRAPH_USE_COMPAT_GRAPHS=False pytest \
--pyargs nx_cugraph \
--config-file=../pyproject.toml \
--cov-config=../pyproject.toml \
Expand Down
2 changes: 1 addition & 1 deletion ci/test_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ else
DASK_DISTRIBUTED__SCHEDULER__WORKER_TTL="1000s" \
DASK_DISTRIBUTED__COMM__TIMEOUTS__CONNECT="1000s" \
DASK_CUDA_WAIT_WORKERS_MIN_TIMEOUT="1000s" \
NX_CUGRAPH_ZERO=False \
NX_CUGRAPH_USE_COMPAT_GRAPHS=False \
python -m pytest \
-v \
--import-mode=append \
Expand Down
5 changes: 4 additions & 1 deletion python/nx-cugraph/_nx_cugraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,10 @@ def get_info():
del d[key]

d["default_config"] = {
"zero": os.environ.get("NX_CUGRAPH_ZERO", "true").strip().lower() == "true",
"use_compat_graphs": os.environ.get("NX_CUGRAPH_USE_COMPAT_GRAPHS", "true")
.strip()
.lower()
== "true",
}
return d

Expand Down
16 changes: 8 additions & 8 deletions python/nx-cugraph/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/abravalheri/validate-pyproject
rev: v0.18
rev: v0.19
hooks:
- id: validate-pyproject
name: Validate pyproject.toml
Expand All @@ -40,29 +40,29 @@ repos:
hooks:
- id: isort
- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
rev: v3.17.0
hooks:
- id: pyupgrade
args: [--py310-plus]
- repo: https://github.com/psf/black
rev: 24.4.2
rev: 24.8.0
hooks:
- id: black
# - id: black-jupyter
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.4
rev: v0.6.4
hooks:
- id: ruff
args: [--fix-only, --show-fixes] # --unsafe-fixes]
- repo: https://github.com/PyCQA/flake8
rev: 7.1.0
rev: 7.1.1
hooks:
- id: flake8
args: ['--per-file-ignores=_nx_cugraph/__init__.py:E501', '--extend-ignore=B020,SIM105'] # Why is this necessary?
additional_dependencies: &flake8_dependencies
# These versions need updated manually
- flake8==7.1.0
- flake8-bugbear==24.4.26
- flake8==7.1.1
- flake8-bugbear==24.8.19
- flake8-simplify==0.21.0
- repo: https://github.com/asottile/yesqa
rev: v1.5.0
Expand All @@ -77,7 +77,7 @@ repos:
additional_dependencies: [tomli]
files: ^(nx_cugraph|docs)/
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.5.4
rev: v0.6.4
hooks:
- id: ruff
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down
8 changes: 4 additions & 4 deletions python/nx-cugraph/nx_cugraph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@

from .interface import BackendInterface

BackendInterface.Graph = classes.ZeroGraph
BackendInterface.DiGraph = classes.ZeroDiGraph
BackendInterface.MultiGraph = classes.ZeroMultiGraph
BackendInterface.MultiDiGraph = classes.ZeroMultiDiGraph
BackendInterface.Graph = classes.Graph
BackendInterface.DiGraph = classes.DiGraph
BackendInterface.MultiGraph = classes.MultiGraph
BackendInterface.MultiDiGraph = classes.MultiDiGraph
del BackendInterface
6 changes: 3 additions & 3 deletions python/nx-cugraph/nx_cugraph/algorithms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ def _(G):
@networkx_algorithm(is_incomplete=True, version_added="23.12", _plc="k_truss_subgraph")
def k_truss(G, k):
if is_nx := isinstance(G, nx.Graph):
zero = isinstance(G, nxcg.ZeroGraph)
is_compat_graph = isinstance(G, nxcg.Graph)
G = nxcg.from_networkx(G, preserve_all_attrs=True)
else:
zero = False
is_compat_graph = False
if nxcg.number_of_selfloops(G) > 0:
if _nxver <= (3, 2):
exc_class = nx.NetworkXError
Expand Down Expand Up @@ -132,7 +132,7 @@ def k_truss(G, k):
node_values,
node_masks,
key_to_id=key_to_id,
zero=zero,
use_compat_graph=is_compat_graph,
)
new_graph.graph.update(G.graph)
return new_graph
14 changes: 7 additions & 7 deletions python/nx-cugraph/nx_cugraph/algorithms/operators/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

@networkx_algorithm(version_added="24.02")
def complement(G):
zero = isinstance(G, nxcg.ZeroGraph)
is_compat_graph = isinstance(G, nxcg.Graph)
G = _to_graph(G)
N = G._N
# Upcast to int64 so indices don't overflow.
Expand All @@ -44,7 +44,7 @@ def complement(G):
src_indices.astype(index_dtype),
dst_indices.astype(index_dtype),
key_to_id=G.key_to_id,
zero=zero,
use_compat_graph=is_compat_graph,
)


Expand All @@ -53,16 +53,16 @@ def reverse(G, copy=True):
if not G.is_directed():
raise nx.NetworkXError("Cannot reverse an undirected graph.")
if isinstance(G, nx.Graph):
zero = isinstance(G, nxcg.ZeroGraph)
if not copy and not zero:
is_compat_graph = isinstance(G, nxcg.Graph)
if not copy and not is_compat_graph:
raise RuntimeError(
"Using `copy=False` is invalid when using a NetworkX graph "
"as input to `nx_cugraph.reverse`"
)
G = nxcg.from_networkx(G, preserve_all_attrs=True)
else:
zero = False
is_compat_graph = False
rv = G.reverse(copy=copy)
if zero:
return rv.to_zero()
if is_compat_graph:
return rv._to_compat_graph()
return rv
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,15 @@ def bfs_tree(G, source, reverse=False, depth_limit=None, sort_neighbors=None):
raise NotImplementedError(
"sort_neighbors argument in bfs_tree is not currently supported"
)
zero = isinstance(G, nxcg.ZeroGraph)
is_compat_graph = isinstance(G, nxcg.Graph)
G = _check_G_and_source(G, source)
if depth_limit is not None and depth_limit < 1:
return nxcg.DiGraph.from_coo(
return nxcg.CudaDiGraph.from_coo(
1,
cp.array([], dtype=index_dtype),
cp.array([], dtype=index_dtype),
id_to_key=[source],
zero=zero,
use_compat_graph=is_compat_graph,
)

distances, predecessors, node_ids = _bfs(
Expand All @@ -151,12 +151,12 @@ def bfs_tree(G, source, reverse=False, depth_limit=None, sort_neighbors=None):
reverse=reverse,
)
if predecessors.size == 0:
return nxcg.DiGraph.from_coo(
return nxcg.CudaDiGraph.from_coo(
1,
cp.array([], dtype=index_dtype),
cp.array([], dtype=index_dtype),
id_to_key=[source],
zero=zero,
use_compat_graph=is_compat_graph,
)
# TODO: create renumbering helper function(s)
unique_node_ids = cp.unique(cp.hstack((predecessors, node_ids)))
Expand All @@ -174,12 +174,12 @@ def bfs_tree(G, source, reverse=False, depth_limit=None, sort_neighbors=None):
old_index: new_index
for new_index, old_index in enumerate(unique_node_ids.tolist())
}
return nxcg.DiGraph.from_coo(
return nxcg.CudaDiGraph.from_coo(
unique_node_ids.size,
src_indices,
dst_indices,
key_to_id=key_to_id,
zero=zero,
use_compat_graph=is_compat_graph,
)


Expand Down
9 changes: 4 additions & 5 deletions python/nx-cugraph/nx_cugraph/classes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .zero import ZeroGraph, ZeroDiGraph, ZeroMultiGraph, ZeroMultiDiGraph
from .graph import Graph
from .digraph import DiGraph
from .multigraph import MultiGraph
from .multidigraph import MultiDiGraph
from .graph import CudaGraph, Graph
from .digraph import CudaDiGraph, DiGraph
from .multigraph import CudaMultiGraph, MultiGraph
from .multidigraph import CudaMultiDiGraph, MultiDiGraph

from .function import *
91 changes: 80 additions & 11 deletions python/nx-cugraph/nx_cugraph/classes/digraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,107 @@
import cupy as cp
import networkx as nx
import numpy as np
from networkx.classes.digraph import (
_CachedPropertyResetterAdjAndSucc,
_CachedPropertyResetterPred,
)

import nx_cugraph as nxcg

from ..utils import index_dtype
from .graph import Graph
from .zero import ZeroDiGraph
from .graph import CudaGraph, Graph

if TYPE_CHECKING: # pragma: no cover
from nx_cugraph.typing import AttrKey

__all__ = ["DiGraph"]
__all__ = ["CudaDiGraph", "DiGraph"]

networkx_api = nxcg.utils.decorators.networkx_class(nx.DiGraph)


class DiGraph(Graph):
#################
# Class methods #
#################
class DiGraph(nx.DiGraph, Graph):
_nx_attrs = ("_node", "_adj", "_succ", "_pred")

name = Graph.name
_node = Graph._node

@property
@networkx_api
def _adj(self):
if (adj := self.__dict__["_adj"]) is None:
self._reify_networkx()
adj = self.__dict__["_adj"]
return adj

@_adj.setter
def _adj(self, val):
self._prepare_setter()
_CachedPropertyResetterAdjAndSucc.__set__(None, self, val)
if cache := getattr(self, "__networkx_cache__", None):
cache.clear()

@property
@networkx_api
def _succ(self):
if (succ := self.__dict__["_succ"]) is None:
self._reify_networkx()
succ = self.__dict__["_succ"]
return succ

@_succ.setter
def _succ(self, val):
self._prepare_setter()
_CachedPropertyResetterAdjAndSucc.__set__(None, self, val)
if cache := getattr(self, "__networkx_cache__", None):
cache.clear()

@property
@networkx_api
def _pred(self):
if (pred := self.__dict__["_pred"]) is None:
self._reify_networkx()
pred = self.__dict__["_pred"]
return pred

@_pred.setter
def _pred(self, val):
self._prepare_setter()
_CachedPropertyResetterPred.__set__(None, self, val)
if cache := getattr(self, "__networkx_cache__", None):
cache.clear()

@classmethod
@networkx_api
def is_directed(cls) -> bool:
return True

@classmethod
@networkx_api
def is_multigraph(cls) -> bool:
return False

@classmethod
def to_cudagraph_class(cls) -> type[CudaDiGraph]:
return CudaDiGraph

@classmethod
def to_networkx_class(cls) -> type[nx.DiGraph]:
return nx.DiGraph


class CudaDiGraph(CudaGraph):
#################
# Class methods #
#################

is_directed = classmethod(DiGraph.is_directed.__func__)
is_multigraph = classmethod(DiGraph.is_multigraph.__func__)
to_cudagraph_class = classmethod(DiGraph.to_cudagraph_class.__func__)
to_networkx_class = classmethod(DiGraph.to_networkx_class.__func__)

@classmethod
def to_zero_class(cls) -> type[ZeroDiGraph]:
return ZeroDiGraph
def _to_compat_graph_class(cls) -> type[DiGraph]:
return DiGraph

@networkx_api
def size(self, weight: AttrKey | None = None) -> int:
Expand All @@ -62,7 +131,7 @@ def size(self, weight: AttrKey | None = None) -> int:
##########################

@networkx_api
def reverse(self, copy: bool = True) -> DiGraph:
def reverse(self, copy: bool = True) -> CudaDiGraph:
return self._copy(not copy, self.__class__, reverse=True)

@networkx_api
Expand Down Expand Up @@ -167,7 +236,7 @@ def to_undirected(self, reciprocal=False, as_view=False):
node_masks,
key_to_id=key_to_id,
id_to_key=id_to_key,
zero=False,
use_compat_graph=False,
)
if as_view:
rv.graph = self.graph
Expand Down
Loading

0 comments on commit d615fed

Please sign in to comment.