Skip to content

Commit

Permalink
Update cugraph-dgl conv layers to use improved graph class (#3849)
Browse files Browse the repository at this point in the history
This PR:
- Removes the usage of the deprecated `StaticCSC` and `SampledCSC`
- Support creating CSR and storing edge information in SparseGraph
- clean up unit tests
- Adds GATv2Conv layer
- Adds `pylibcugraphops` as a dependency of `cugraph-dgl` conda package

Authors:
  - Tingyu Wang (https://github.com/tingyu66)

Approvers:
  - Jake Awe (https://github.com/AyodeAwe)
  - Vibhu Jawa (https://github.com/VibhuJawa)
  - Brad Rees (https://github.com/BradReesWork)

URL: #3849
  • Loading branch information
tingyu66 authored Sep 19, 2023
1 parent 5f76161 commit b2e85bf
Show file tree
Hide file tree
Showing 17 changed files with 978 additions and 345 deletions.
1 change: 1 addition & 0 deletions conda/recipes/cugraph-dgl/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ requirements:
- dgl >=1.1.0.cu*
- numba >=0.57
- numpy >=1.21
- pylibcugraphops ={{ version }}
- python
- pytorch

Expand Down
2 changes: 2 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@

from .base import SparseGraph
from .gatconv import GATConv
from .gatv2conv import GATv2Conv
from .relgraphconv import RelGraphConv
from .sageconv import SAGEConv
from .transformerconv import TransformerConv

__all__ = [
"SparseGraph",
"GATConv",
"GATv2Conv",
"RelGraphConv",
"SAGEConv",
"TransformerConv",
Expand Down
262 changes: 195 additions & 67 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,38 +17,7 @@

torch = import_optional("torch")
ops_torch = import_optional("pylibcugraphops.pytorch")


class BaseConv(torch.nn.Module):
r"""An abstract base class for cugraph-ops nn module."""

def __init__(self):
super().__init__()
self._cached_offsets_fg = None

def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
raise NotImplementedError

def forward(self, *args):
r"""Runs the forward pass of the module."""
raise NotImplementedError

def pad_offsets(self, offsets: torch.Tensor, size: int) -> torch.Tensor:
r"""Pad zero-in-degree nodes to the end of offsets to reach size. This
is used to augment offset tensors from DGL blocks (MFGs) to be
compatible with cugraph-ops full-graph primitives."""
if self._cached_offsets_fg is None:
self._cached_offsets_fg = torch.empty(
size, dtype=offsets.dtype, device=offsets.device
)
elif self._cached_offsets_fg.numel() < size:
self._cached_offsets_fg.resize_(size)

self._cached_offsets_fg[: offsets.numel()] = offsets
self._cached_offsets_fg[offsets.numel() : size] = offsets[-1]

return self._cached_offsets_fg[:size]
dgl = import_optional("dgl")


def compress_ids(ids: torch.Tensor, size: int) -> torch.Tensor:
Expand All @@ -63,8 +32,9 @@ def decompress_ids(c_ids: torch.Tensor) -> torch.Tensor:


class SparseGraph(object):
r"""A god-class to store different sparse formats needed by cugraph-ops
and facilitate sparse format conversions.
r"""A class to create and store different sparse formats needed by
cugraph-ops. It always creates a CSC representation and can provide COO- or
CSR-format if needed.
Parameters
----------
Expand All @@ -89,25 +59,43 @@ class SparseGraph(object):
consists of the sources between `src_indices[cdst_indices[k]]` and
`src_indices[cdst_indices[k+1]]`.
dst_ids_is_sorted: bool
Whether `dst_ids` has been sorted in an ascending order. When sorted,
creating CSC layout is much faster.
values: torch.Tensor, optional
Values on the edges.
is_sorted: bool
Whether the COO inputs (src_ids, dst_ids, values) have been sorted by
`dst_ids` in an ascending order. CSC layout creation is much faster
when sorted.
formats: str or tuple of str, optional
The desired sparse formats to create for the graph.
The desired sparse formats to create for the graph. The formats tuple
must include "csc". Default: "csc".
reduce_memory: bool, optional
When set, the tensors are not required by the desired formats will be
set to `None`.
set to `None`. Default: True.
Notes
-----
For MFGs (sampled graphs), the node ids must have been renumbered.
"""

supported_formats = {"coo": ("src_ids", "dst_ids"), "csc": ("cdst_ids", "src_ids")}

all_tensors = set(["src_ids", "dst_ids", "csrc_ids", "cdst_ids"])
supported_formats = {
"coo": ("_src_ids", "_dst_ids"),
"csc": ("_cdst_ids", "_src_ids"),
"csr": ("_csrc_ids", "_dst_ids", "_perm_csc2csr"),
}

all_tensors = set(
[
"_src_ids",
"_dst_ids",
"_csrc_ids",
"_cdst_ids",
"_perm_coo2csc",
"_perm_csc2csr",
]
)

def __init__(
self,
Expand All @@ -116,15 +104,19 @@ def __init__(
dst_ids: Optional[torch.Tensor] = None,
csrc_ids: Optional[torch.Tensor] = None,
cdst_ids: Optional[torch.Tensor] = None,
dst_ids_is_sorted: bool = False,
formats: Optional[Union[str, Tuple[str]]] = None,
values: Optional[torch.Tensor] = None,
is_sorted: bool = False,
formats: Union[str, Tuple[str]] = "csc",
reduce_memory: bool = True,
):
self._num_src_nodes, self._num_dst_nodes = size
self._dst_ids_is_sorted = dst_ids_is_sorted
self._is_sorted = is_sorted

if dst_ids is None and cdst_ids is None:
raise ValueError("One of 'dst_ids' and 'cdst_ids' must be given.")
raise ValueError(
"One of 'dst_ids' and 'cdst_ids' must be given "
"to create a SparseGraph."
)

if src_ids is not None:
src_ids = src_ids.contiguous()
Expand All @@ -148,30 +140,47 @@ def __init__(
)
cdst_ids = cdst_ids.contiguous()

if values is not None:
values = values.contiguous()

self._src_ids = src_ids
self._dst_ids = dst_ids
self._csrc_ids = csrc_ids
self._cdst_ids = cdst_ids
self._perm = None
self._values = values
self._perm_coo2csc = None
self._perm_csc2csr = None

if isinstance(formats, str):
formats = (formats,)

if formats is not None:
for format_ in formats:
assert format_ in SparseGraph.supported_formats
self.__getattribute__(f"_create_{format_}")()
self._formats = formats

if "csc" not in formats:
raise ValueError(
f"{self.__class__.__name__}.formats must contain "
f"'csc', but got {formats}."
)

# always create csc first
if self._cdst_ids is None:
if not self._is_sorted:
self._dst_ids, self._perm_coo2csc = torch.sort(self._dst_ids)
self._src_ids = self._src_ids[self._perm_coo2csc]
if self._values is not None:
self._values = self._values[self._perm_coo2csc]
self._cdst_ids = compress_ids(self._dst_ids, self._num_dst_nodes)

for format_ in formats:
assert format_ in SparseGraph.supported_formats
self.__getattribute__(f"{format_}")()

self._reduce_memory = reduce_memory
if reduce_memory:
self.reduce_memory()

def reduce_memory(self):
"""Remove the tensors that are not necessary to create the desired sparse
formats to reduce memory footprint."""

self._perm = None
if self._formats is None:
return

Expand All @@ -181,38 +190,157 @@ def reduce_memory(self):
for t in SparseGraph.all_tensors.difference(set(tensors_needed)):
self.__dict__[t] = None

def _create_coo(self):
def src_ids(self) -> torch.Tensor:
return self._src_ids

def cdst_ids(self) -> torch.Tensor:
return self._cdst_ids

def dst_ids(self) -> torch.Tensor:
if self._dst_ids is None:
self._dst_ids = decompress_ids(self._cdst_ids)
return self._dst_ids

def _create_csc(self):
if self._cdst_ids is None:
if not self._dst_ids_is_sorted:
self._dst_ids, self._perm = torch.sort(self._dst_ids)
self._src_ids = self._src_ids[self._perm]
self._cdst_ids = compress_ids(self._dst_ids, self._num_dst_nodes)
def csrc_ids(self) -> torch.Tensor:
if self._csrc_ids is None:
src_ids, self._perm_csc2csr = torch.sort(self._src_ids)
self._csrc_ids = compress_ids(src_ids, self._num_src_nodes)
return self._csrc_ids

def num_src_nodes(self):
return self._num_src_nodes

def num_dst_nodes(self):
return self._num_dst_nodes

def values(self):
return self._values

def formats(self):
return self._formats

def coo(self) -> Tuple[torch.Tensor, torch.Tensor]:
def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if "coo" not in self.formats():
raise RuntimeError(
"The SparseGraph did not create a COO layout. "
"Set 'formats' to include 'coo' when creating the graph."
"Set 'formats' list to include 'coo' when creating the graph."
)
return (self._src_ids, self._dst_ids)
return self.src_ids(), self.dst_ids(), self._values

def csc(self) -> Tuple[torch.Tensor, torch.Tensor]:
def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if "csc" not in self.formats():
raise RuntimeError(
"The SparseGraph did not create a CSC layout. "
"Set 'formats' to include 'csc' when creating the graph."
"Set 'formats' list to include 'csc' when creating the graph."
)
return self.cdst_ids(), self.src_ids(), self._values

def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
if "csr" not in self.formats():
raise RuntimeError(
"The SparseGraph did not create a CSR layout. "
"Set 'formats' list to include 'csr' when creating the graph."
)
csrc_ids = self.csrc_ids()
dst_ids = self.dst_ids()[self._perm_csc2csr]
value = self._values
if value is not None:
value = value[self._perm_csc2csr]
return csrc_ids, dst_ids, value


class BaseConv(torch.nn.Module):
r"""An abstract base class for cugraph-ops nn module."""

def __init__(self):
super().__init__()

def reset_parameters(self):
r"""Resets all learnable parameters of the module."""
raise NotImplementedError

def forward(self, *args):
r"""Runs the forward pass of the module."""
raise NotImplementedError

def get_cugraph_ops_CSC(
self,
g: Union[SparseGraph, dgl.DGLHeteroGraph],
is_bipartite: bool = False,
max_in_degree: Optional[int] = None,
) -> ops_torch.CSC:
"""Create CSC structure needed by cugraph-ops."""

if not isinstance(g, (SparseGraph, dgl.DGLHeteroGraph)):
raise TypeError(
f"The graph has to be either a 'cugraph_dgl.nn.SparseGraph' or "
f"'dgl.DGLHeteroGraph', but got '{type(g)}'."
)
return (self._cdst_ids, self._src_ids)

# TODO: max_in_degree should default to None in pylibcugraphops
if max_in_degree is None:
max_in_degree = -1

if isinstance(g, SparseGraph):
offsets, indices, _ = g.csc()
else:
offsets, indices, _ = g.adj_tensors("csc")

graph = ops_torch.CSC(
offsets=offsets,
indices=indices,
num_src_nodes=g.num_src_nodes(),
dst_max_in_degree=max_in_degree,
is_bipartite=is_bipartite,
)

return graph

def get_cugraph_ops_HeteroCSC(
self,
g: Union[SparseGraph, dgl.DGLHeteroGraph],
num_edge_types: int,
etypes: Optional[torch.Tensor] = None,
is_bipartite: bool = False,
max_in_degree: Optional[int] = None,
) -> ops_torch.HeteroCSC:
"""Create HeteroCSC structure needed by cugraph-ops."""

if not isinstance(g, (SparseGraph, dgl.DGLHeteroGraph)):
raise TypeError(
f"The graph has to be either a 'cugraph_dgl.nn.SparseGraph' or "
f"'dgl.DGLHeteroGraph', but got '{type(g)}'."
)

# TODO: max_in_degree should default to None in pylibcugraphops
if max_in_degree is None:
max_in_degree = -1

if isinstance(g, SparseGraph):
offsets, indices, etypes = g.csc()
if etypes is None:
raise ValueError(
"SparseGraph must have 'values' to create HeteroCSC. "
"Pass in edge types as 'values' when creating the SparseGraph."
)
etypes = etypes.int()
else:
if etypes is None:
raise ValueError(
"'etypes' is required when creating HeteroCSC "
"from dgl.DGLHeteroGraph."
)
offsets, indices, perm = g.adj_tensors("csc")
etypes = etypes[perm].int()

graph = ops_torch.HeteroCSC(
offsets=offsets,
indices=indices,
edge_types=etypes,
num_src_nodes=g.num_src_nodes(),
num_edge_types=num_edge_types,
dst_max_in_degree=max_in_degree,
is_bipartite=is_bipartite,
)

return graph
Loading

0 comments on commit b2e85bf

Please sign in to comment.