diff --git a/conda/recipes/cugraph-dgl/meta.yaml b/conda/recipes/cugraph-dgl/meta.yaml index 2fbc6360c04..9e9fcd2faf1 100644 --- a/conda/recipes/cugraph-dgl/meta.yaml +++ b/conda/recipes/cugraph-dgl/meta.yaml @@ -26,6 +26,7 @@ requirements: - dgl >=1.1.0.cu* - numba >=0.57 - numpy >=1.21 + - pylibcugraphops ={{ version }} - python - pytorch diff --git a/python/cugraph-dgl/cugraph_dgl/nn/conv/__init__.py b/python/cugraph-dgl/cugraph_dgl/nn/conv/__init__.py index e5acbf34478..3e7f2f076f0 100644 --- a/python/cugraph-dgl/cugraph_dgl/nn/conv/__init__.py +++ b/python/cugraph-dgl/cugraph_dgl/nn/conv/__init__.py @@ -13,6 +13,7 @@ from .base import SparseGraph from .gatconv import GATConv +from .gatv2conv import GATv2Conv from .relgraphconv import RelGraphConv from .sageconv import SAGEConv from .transformerconv import TransformerConv @@ -20,6 +21,7 @@ __all__ = [ "SparseGraph", "GATConv", + "GATv2Conv", "RelGraphConv", "SAGEConv", "TransformerConv", diff --git a/python/cugraph-dgl/cugraph_dgl/nn/conv/base.py b/python/cugraph-dgl/cugraph_dgl/nn/conv/base.py index 0eeaed29d86..307eb33078e 100644 --- a/python/cugraph-dgl/cugraph_dgl/nn/conv/base.py +++ b/python/cugraph-dgl/cugraph_dgl/nn/conv/base.py @@ -17,38 +17,7 @@ torch = import_optional("torch") ops_torch = import_optional("pylibcugraphops.pytorch") - - -class BaseConv(torch.nn.Module): - r"""An abstract base class for cugraph-ops nn module.""" - - def __init__(self): - super().__init__() - self._cached_offsets_fg = None - - def reset_parameters(self): - r"""Resets all learnable parameters of the module.""" - raise NotImplementedError - - def forward(self, *args): - r"""Runs the forward pass of the module.""" - raise NotImplementedError - - def pad_offsets(self, offsets: torch.Tensor, size: int) -> torch.Tensor: - r"""Pad zero-in-degree nodes to the end of offsets to reach size. This - is used to augment offset tensors from DGL blocks (MFGs) to be - compatible with cugraph-ops full-graph primitives.""" - if self._cached_offsets_fg is None: - self._cached_offsets_fg = torch.empty( - size, dtype=offsets.dtype, device=offsets.device - ) - elif self._cached_offsets_fg.numel() < size: - self._cached_offsets_fg.resize_(size) - - self._cached_offsets_fg[: offsets.numel()] = offsets - self._cached_offsets_fg[offsets.numel() : size] = offsets[-1] - - return self._cached_offsets_fg[:size] +dgl = import_optional("dgl") def compress_ids(ids: torch.Tensor, size: int) -> torch.Tensor: @@ -63,8 +32,9 @@ def decompress_ids(c_ids: torch.Tensor) -> torch.Tensor: class SparseGraph(object): - r"""A god-class to store different sparse formats needed by cugraph-ops - and facilitate sparse format conversions. + r"""A class to create and store different sparse formats needed by + cugraph-ops. It always creates a CSC representation and can provide COO- or + CSR-format if needed. Parameters ---------- @@ -89,25 +59,43 @@ class SparseGraph(object): consists of the sources between `src_indices[cdst_indices[k]]` and `src_indices[cdst_indices[k+1]]`. - dst_ids_is_sorted: bool - Whether `dst_ids` has been sorted in an ascending order. When sorted, - creating CSC layout is much faster. + values: torch.Tensor, optional + Values on the edges. + + is_sorted: bool + Whether the COO inputs (src_ids, dst_ids, values) have been sorted by + `dst_ids` in an ascending order. CSC layout creation is much faster + when sorted. formats: str or tuple of str, optional - The desired sparse formats to create for the graph. + The desired sparse formats to create for the graph. 
The formats tuple + must include "csc". Default: "csc". reduce_memory: bool, optional When set, the tensors are not required by the desired formats will be - set to `None`. + set to `None`. Default: True. Notes ----- For MFGs (sampled graphs), the node ids must have been renumbered. """ - supported_formats = {"coo": ("src_ids", "dst_ids"), "csc": ("cdst_ids", "src_ids")} - - all_tensors = set(["src_ids", "dst_ids", "csrc_ids", "cdst_ids"]) + supported_formats = { + "coo": ("_src_ids", "_dst_ids"), + "csc": ("_cdst_ids", "_src_ids"), + "csr": ("_csrc_ids", "_dst_ids", "_perm_csc2csr"), + } + + all_tensors = set( + [ + "_src_ids", + "_dst_ids", + "_csrc_ids", + "_cdst_ids", + "_perm_coo2csc", + "_perm_csc2csr", + ] + ) def __init__( self, @@ -116,15 +104,19 @@ def __init__( dst_ids: Optional[torch.Tensor] = None, csrc_ids: Optional[torch.Tensor] = None, cdst_ids: Optional[torch.Tensor] = None, - dst_ids_is_sorted: bool = False, - formats: Optional[Union[str, Tuple[str]]] = None, + values: Optional[torch.Tensor] = None, + is_sorted: bool = False, + formats: Union[str, Tuple[str]] = "csc", reduce_memory: bool = True, ): self._num_src_nodes, self._num_dst_nodes = size - self._dst_ids_is_sorted = dst_ids_is_sorted + self._is_sorted = is_sorted if dst_ids is None and cdst_ids is None: - raise ValueError("One of 'dst_ids' and 'cdst_ids' must be given.") + raise ValueError( + "One of 'dst_ids' and 'cdst_ids' must be given " + "to create a SparseGraph." + ) if src_ids is not None: src_ids = src_ids.contiguous() @@ -148,21 +140,40 @@ def __init__( ) cdst_ids = cdst_ids.contiguous() + if values is not None: + values = values.contiguous() + self._src_ids = src_ids self._dst_ids = dst_ids self._csrc_ids = csrc_ids self._cdst_ids = cdst_ids - self._perm = None + self._values = values + self._perm_coo2csc = None + self._perm_csc2csr = None if isinstance(formats, str): formats = (formats,) - - if formats is not None: - for format_ in formats: - assert format_ in SparseGraph.supported_formats - self.__getattribute__(f"_create_{format_}")() self._formats = formats + if "csc" not in formats: + raise ValueError( + f"{self.__class__.__name__}.formats must contain " + f"'csc', but got {formats}." 
+ ) + + # always create csc first + if self._cdst_ids is None: + if not self._is_sorted: + self._dst_ids, self._perm_coo2csc = torch.sort(self._dst_ids) + self._src_ids = self._src_ids[self._perm_coo2csc] + if self._values is not None: + self._values = self._values[self._perm_coo2csc] + self._cdst_ids = compress_ids(self._dst_ids, self._num_dst_nodes) + + for format_ in formats: + assert format_ in SparseGraph.supported_formats + self.__getattribute__(f"{format_}")() + self._reduce_memory = reduce_memory if reduce_memory: self.reduce_memory() @@ -170,8 +181,6 @@ def __init__( def reduce_memory(self): """Remove the tensors that are not necessary to create the desired sparse formats to reduce memory footprint.""" - - self._perm = None if self._formats is None: return @@ -181,16 +190,22 @@ def reduce_memory(self): for t in SparseGraph.all_tensors.difference(set(tensors_needed)): self.__dict__[t] = None - def _create_coo(self): + def src_ids(self) -> torch.Tensor: + return self._src_ids + + def cdst_ids(self) -> torch.Tensor: + return self._cdst_ids + + def dst_ids(self) -> torch.Tensor: if self._dst_ids is None: self._dst_ids = decompress_ids(self._cdst_ids) + return self._dst_ids - def _create_csc(self): - if self._cdst_ids is None: - if not self._dst_ids_is_sorted: - self._dst_ids, self._perm = torch.sort(self._dst_ids) - self._src_ids = self._src_ids[self._perm] - self._cdst_ids = compress_ids(self._dst_ids, self._num_dst_nodes) + def csrc_ids(self) -> torch.Tensor: + if self._csrc_ids is None: + src_ids, self._perm_csc2csr = torch.sort(self._src_ids) + self._csrc_ids = compress_ids(src_ids, self._num_src_nodes) + return self._csrc_ids def num_src_nodes(self): return self._num_src_nodes @@ -198,21 +213,134 @@ def num_src_nodes(self): def num_dst_nodes(self): return self._num_dst_nodes + def values(self): + return self._values + def formats(self): return self._formats - def coo(self) -> Tuple[torch.Tensor, torch.Tensor]: + def coo(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if "coo" not in self.formats(): raise RuntimeError( "The SparseGraph did not create a COO layout. " - "Set 'formats' to include 'coo' when creating the graph." + "Set 'formats' list to include 'coo' when creating the graph." ) - return (self._src_ids, self._dst_ids) + return self.src_ids(), self.dst_ids(), self._values - def csc(self) -> Tuple[torch.Tensor, torch.Tensor]: + def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: if "csc" not in self.formats(): raise RuntimeError( "The SparseGraph did not create a CSC layout. " - "Set 'formats' to include 'csc' when creating the graph." + "Set 'formats' list to include 'csc' when creating the graph." + ) + return self.cdst_ids(), self.src_ids(), self._values + + def csr(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + if "csr" not in self.formats(): + raise RuntimeError( + "The SparseGraph did not create a CSR layout. " + "Set 'formats' list to include 'csr' when creating the graph." 
+ ) + csrc_ids = self.csrc_ids() + dst_ids = self.dst_ids()[self._perm_csc2csr] + value = self._values + if value is not None: + value = value[self._perm_csc2csr] + return csrc_ids, dst_ids, value + + +class BaseConv(torch.nn.Module): + r"""An abstract base class for cugraph-ops nn module.""" + + def __init__(self): + super().__init__() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + raise NotImplementedError + + def forward(self, *args): + r"""Runs the forward pass of the module.""" + raise NotImplementedError + + def get_cugraph_ops_CSC( + self, + g: Union[SparseGraph, dgl.DGLHeteroGraph], + is_bipartite: bool = False, + max_in_degree: Optional[int] = None, + ) -> ops_torch.CSC: + """Create CSC structure needed by cugraph-ops.""" + + if not isinstance(g, (SparseGraph, dgl.DGLHeteroGraph)): + raise TypeError( + f"The graph has to be either a 'cugraph_dgl.nn.SparseGraph' or " + f"'dgl.DGLHeteroGraph', but got '{type(g)}'." ) - return (self._cdst_ids, self._src_ids) + + # TODO: max_in_degree should default to None in pylibcugraphops + if max_in_degree is None: + max_in_degree = -1 + + if isinstance(g, SparseGraph): + offsets, indices, _ = g.csc() + else: + offsets, indices, _ = g.adj_tensors("csc") + + graph = ops_torch.CSC( + offsets=offsets, + indices=indices, + num_src_nodes=g.num_src_nodes(), + dst_max_in_degree=max_in_degree, + is_bipartite=is_bipartite, + ) + + return graph + + def get_cugraph_ops_HeteroCSC( + self, + g: Union[SparseGraph, dgl.DGLHeteroGraph], + num_edge_types: int, + etypes: Optional[torch.Tensor] = None, + is_bipartite: bool = False, + max_in_degree: Optional[int] = None, + ) -> ops_torch.HeteroCSC: + """Create HeteroCSC structure needed by cugraph-ops.""" + + if not isinstance(g, (SparseGraph, dgl.DGLHeteroGraph)): + raise TypeError( + f"The graph has to be either a 'cugraph_dgl.nn.SparseGraph' or " + f"'dgl.DGLHeteroGraph', but got '{type(g)}'." + ) + + # TODO: max_in_degree should default to None in pylibcugraphops + if max_in_degree is None: + max_in_degree = -1 + + if isinstance(g, SparseGraph): + offsets, indices, etypes = g.csc() + if etypes is None: + raise ValueError( + "SparseGraph must have 'values' to create HeteroCSC. " + "Pass in edge types as 'values' when creating the SparseGraph." + ) + etypes = etypes.int() + else: + if etypes is None: + raise ValueError( + "'etypes' is required when creating HeteroCSC " + "from dgl.DGLHeteroGraph." + ) + offsets, indices, perm = g.adj_tensors("csc") + etypes = etypes[perm].int() + + graph = ops_torch.HeteroCSC( + offsets=offsets, + indices=indices, + edge_types=etypes, + num_src_nodes=g.num_src_nodes(), + num_edge_types=num_edge_types, + dst_max_in_degree=max_in_degree, + is_bipartite=is_bipartite, + ) + + return graph diff --git a/python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py b/python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py index 239def5b677..8843e61ad89 100644 --- a/python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py +++ b/python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py @@ -10,13 +10,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Torch Module for graph attention network layer using the aggregation -primitives in cugraph-ops""" -# pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments -from __future__ import annotations + from typing import Optional, Tuple, Union -from cugraph_dgl.nn.conv.base import BaseConv +from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph from cugraph.utilities.utils import import_optional dgl = import_optional("dgl") @@ -32,13 +29,15 @@ class GATConv(BaseConv): Parameters ---------- - in_feats : int, pair of ints + in_feats : int or tuple Input feature size. A pair denotes feature sizes of source and destination nodes. out_feats : int Output feature size. num_heads : int - Number of heads in Multi-Head Attention. + Number of heads in multi-head attention. + feat_drop : float, optional + Dropout rate on feature. Defaults: ``0``. concat : bool, optional If False, the multi-head attentions are averaged instead of concatenated. Default: ``True``. @@ -46,6 +45,15 @@ class GATConv(BaseConv): Edge feature size. Default: ``None``. negative_slope : float, optional LeakyReLU angle of negative slope. Defaults: ``0.2``. + residual : bool, optional + If True, use residual connection. Defaults: ``False``. + allow_zero_in_degree : bool, optional + If there are 0-in-degree nodes in the graph, output for those nodes will + be invalid since no message will be passed to those nodes. This is + harmful for some applications causing silent performance regression. + This module will raise a DGLError if it detects 0-in-degree nodes in + input graph. By setting ``True``, it will suppress the check and let the + users handle it by themselves. Defaults: ``False``. bias : bool, optional If True, learns a bias term. Defaults: ``True``. @@ -81,37 +89,46 @@ class GATConv(BaseConv): [ 1.6477, -1.9986], [ 1.1138, -1.9302]]], device='cuda:0', grad_fn=) """ - MAX_IN_DEGREE_MFG = 200 def __init__( self, in_feats: Union[int, Tuple[int, int]], out_feats: int, num_heads: int, + feat_drop: float = 0.0, concat: bool = True, edge_feats: Optional[int] = None, negative_slope: float = 0.2, + residual: bool = False, + allow_zero_in_degree: bool = False, bias: bool = True, ): super().__init__() self.in_feats = in_feats self.out_feats = out_feats + self.in_feats_src, self.in_feats_dst = dgl.utils.expand_as_pair(in_feats) self.num_heads = num_heads + self.feat_drop = nn.Dropout(feat_drop) self.concat = concat self.edge_feats = edge_feats self.negative_slope = negative_slope + self.allow_zero_in_degree = allow_zero_in_degree if isinstance(in_feats, int): - self.fc = nn.Linear(in_feats, num_heads * out_feats, bias=False) + self.lin = nn.Linear(in_feats, num_heads * out_feats, bias=False) else: - self.fc_src = nn.Linear(in_feats[0], num_heads * out_feats, bias=False) - self.fc_dst = nn.Linear(in_feats[1], num_heads * out_feats, bias=False) + self.lin_src = nn.Linear( + self.in_feats_src, num_heads * out_feats, bias=False + ) + self.lin_dst = nn.Linear( + self.in_feats_dst, num_heads * out_feats, bias=False + ) if edge_feats is not None: - self.fc_edge = nn.Linear(edge_feats, num_heads * out_feats, bias=False) + self.lin_edge = nn.Linear(edge_feats, num_heads * out_feats, bias=False) self.attn_weights = nn.Parameter(torch.Tensor(3 * num_heads * out_feats)) else: - self.register_parameter("fc_edge", None) + self.register_parameter("lin_edge", None) self.attn_weights = nn.Parameter(torch.Tensor(2 * num_heads * out_feats)) if bias and concat: @@ -121,28 +138,40 @@ def __init__( else: self.register_buffer("bias", 
None) + self.residual = residual and self.in_feats_dst != out_feats * num_heads + if self.residual: + self.lin_res = nn.Linear( + self.in_feats_dst, num_heads * out_feats, bias=bias + ) + else: + self.register_buffer("lin_res", None) + self.reset_parameters() def reset_parameters(self): r"""Reinitialize learnable parameters.""" gain = nn.init.calculate_gain("relu") - if hasattr(self, "fc"): - nn.init.xavier_normal_(self.fc.weight, gain=gain) + if hasattr(self, "lin"): + nn.init.xavier_normal_(self.lin.weight, gain=gain) else: - nn.init.xavier_normal_(self.fc_src.weight, gain=gain) - nn.init.xavier_normal_(self.fc_dst.weight, gain=gain) + nn.init.xavier_normal_(self.lin_src.weight, gain=gain) + nn.init.xavier_normal_(self.lin_dst.weight, gain=gain) nn.init.xavier_normal_( self.attn_weights.view(-1, self.num_heads, self.out_feats), gain=gain ) - if self.fc_edge is not None: - self.fc_edge.reset_parameters() + if self.lin_edge is not None: + self.lin_edge.reset_parameters() + + if self.lin_res is not None: + self.lin_res.reset_parameters() + if self.bias is not None: nn.init.zeros_(self.bias) def forward( self, - g: dgl.DGLHeteroGraph, + g: Union[SparseGraph, dgl.DGLHeteroGraph], nfeat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], efeat: Optional[torch.Tensor] = None, max_in_degree: Optional[int] = None, @@ -151,18 +180,17 @@ def forward( Parameters ---------- - graph : DGLGraph + graph : DGLGraph or SparseGraph The graph. nfeat : torch.Tensor Input features of shape :math:`(N, D_{in})`. efeat: torch.Tensor, optional Optional edge features. max_in_degree : int - Maximum in-degree of destination nodes. It is only effective when - :attr:`g` is a :class:`DGLBlock`, i.e., bipartite graph. When - :attr:`g` is generated from a neighbor sampler, the value should be - set to the corresponding :attr:`fanout`. If not given, - :attr:`max_in_degree` will be calculated on-the-fly. + Maximum in-degree of destination nodes. When :attr:`g` is generated + from a neighbor sampler, the value should be set to the corresponding + :attr:`fanout`. This option is used to invoke the MFG-variant of + cugraph-ops kernel. Returns ------- @@ -171,49 +199,63 @@ def forward( :math:`H` is the number of heads, and :math:`D_{out}` is size of output feature. """ - if max_in_degree is None: - max_in_degree = -1 - - bipartite = not isinstance(nfeat, torch.Tensor) - offsets, indices, _ = g.adj_tensors("csc") - - graph = ops_torch.CSC( - offsets=offsets, - indices=indices, - num_src_nodes=g.num_src_nodes(), - dst_max_in_degree=max_in_degree, - is_bipartite=bipartite, + if isinstance(g, dgl.DGLHeteroGraph): + if not self.allow_zero_in_degree: + if (g.in_degrees() == 0).any(): + raise dgl.base.DGLError( + "There are 0-in-degree nodes in the graph, " + "output for those nodes will be invalid. " + "This is harmful for some applications, " + "causing silent performance regression. " + "Adding self-loop on the input graph by " + "calling `g = dgl.add_self_loop(g)` will resolve " + "the issue. Setting ``allow_zero_in_degree`` " + "to be `True` when constructing this module will " + "suppress the check and let the code run." 
+ ) + + bipartite = isinstance(nfeat, (list, tuple)) + + _graph = self.get_cugraph_ops_CSC( + g, is_bipartite=bipartite, max_in_degree=max_in_degree ) + if bipartite: + nfeat = (self.feat_drop(nfeat[0]), self.feat_drop(nfeat[1])) + nfeat_dst_orig = nfeat[1] + else: + nfeat = self.feat_drop(nfeat) + nfeat_dst_orig = nfeat[: g.num_dst_nodes()] + if efeat is not None: - if self.fc_edge is None: + if self.lin_edge is None: raise RuntimeError( f"{self.__class__.__name__}.edge_feats must be set to " f"accept edge features." ) - efeat = self.fc_edge(efeat) + efeat = self.lin_edge(efeat) if bipartite: - if not hasattr(self, "fc_src"): + if not hasattr(self, "lin_src"): raise RuntimeError( f"{self.__class__.__name__}.in_feats must be a pair of " f"integers to allow bipartite node features, but got " f"{self.in_feats}." ) - nfeat_src = self.fc_src(nfeat[0]) - nfeat_dst = self.fc_dst(nfeat[1]) + nfeat_src = self.lin_src(nfeat[0]) + nfeat_dst = self.lin_dst(nfeat[1]) else: - if not hasattr(self, "fc"): + if not hasattr(self, "lin"): raise RuntimeError( f"{self.__class__.__name__}.in_feats is expected to be an " f"integer, but got {self.in_feats}." ) - nfeat = self.fc(nfeat) + nfeat = self.lin(nfeat) out = ops_torch.operators.mha_gat_n2n( (nfeat_src, nfeat_dst) if bipartite else nfeat, self.attn_weights, - graph, + _graph, num_heads=self.num_heads, activation="LeakyReLU", negative_slope=self.negative_slope, @@ -224,6 +266,12 @@ def forward( if self.concat: out = out.view(-1, self.num_heads, self.out_feats) + if self.residual: + res = self.lin_res(nfeat_dst_orig).view(-1, self.num_heads, self.out_feats) + if not self.concat: + res = res.mean(dim=1) + out = out + res + if self.bias is not None: out = out + self.bias diff --git a/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py b/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py new file mode 100644 index 00000000000..209a5fe1a8d --- /dev/null +++ b/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py @@ -0,0 +1,249 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph +from cugraph.utilities.utils import import_optional + +dgl = import_optional("dgl") +torch = import_optional("torch") +nn = import_optional("torch.nn") +ops_torch = import_optional("pylibcugraphops.pytorch") + + +class GATv2Conv(BaseConv): + r"""GATv2 from `How Attentive are Graph Attention Networks? + `__, with the sparse aggregation + accelerated by cugraph-ops. + + Parameters + ---------- + in_feats : int, or pair of ints + Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`. + If the layer is to be applied to a unidirectional bipartite graph, `in_feats` + specifies the input feature size on both the source and destination nodes. + If a scalar is given, the source and destination node feature size + would take the same value. + out_feats : int + Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`. 
+ num_heads : int + Number of heads in Multi-Head Attention. + feat_drop : float, optional + Dropout rate on feature. Defaults: ``0``. + concat : bool, optional + If False, the multi-head attentions are averaged instead of concatenated. + Default: ``True``. + edge_feats : int, optional + Edge feature size. Default: ``None``. + negative_slope : float, optional + LeakyReLU angle of negative slope. Defaults: ``0.2``. + residual : bool, optional + If True, use residual connection. Defaults: ``False``. + allow_zero_in_degree : bool, optional + If there are 0-in-degree nodes in the graph, output for those nodes will + be invalid since no message will be passed to those nodes. This is + harmful for some applications causing silent performance regression. + This module will raise a DGLError if it detects 0-in-degree nodes in + input graph. By setting ``True``, it will suppress the check and let the + users handle it by themselves. Defaults: ``False``. + bias : bool, optional + If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + share_weights : bool, optional + If set to :obj:`True`, the same matrix for :math:`W_{left}` and + :math:`W_{right}` in the above equations, will be applied to the source + and the target node of every edge. (default: :obj:`False`) + """ + + def __init__( + self, + in_feats: Union[int, Tuple[int, int]], + out_feats: int, + num_heads: int, + feat_drop: float = 0.0, + concat: bool = True, + edge_feats: Optional[int] = None, + negative_slope: float = 0.2, + residual: bool = False, + allow_zero_in_degree: bool = False, + bias: bool = True, + share_weights: bool = False, + ): + super().__init__() + self.in_feats = in_feats + self.out_feats = out_feats + self.in_feats_src, self.in_feats_dst = dgl.utils.expand_as_pair(in_feats) + self.num_heads = num_heads + self.feat_drop = nn.Dropout(feat_drop) + self.concat = concat + self.edge_feats = edge_feats + self.negative_slope = negative_slope + self.allow_zero_in_degree = allow_zero_in_degree + self.share_weights = share_weights + + self.lin_src = nn.Linear(self.in_feats_src, num_heads * out_feats, bias=bias) + if share_weights: + if self.in_feats_src != self.in_feats_dst: + raise ValueError( + f"Input feature size of source and destination " + f"nodes must be identical when share_weights is enabled, " + f"but got {self.in_feats_src} and {self.in_feats_dst}." 
+ ) + self.lin_dst = self.lin_src + else: + self.lin_dst = nn.Linear( + self.in_feats_dst, num_heads * out_feats, bias=bias + ) + + self.attn = nn.Parameter(torch.Tensor(num_heads * out_feats)) + + if edge_feats is not None: + self.lin_edge = nn.Linear(edge_feats, num_heads * out_feats, bias=False) + else: + self.register_parameter("lin_edge", None) + + if bias and concat: + self.bias = nn.Parameter(torch.Tensor(num_heads, out_feats)) + elif bias and not concat: + self.bias = nn.Parameter(torch.Tensor(out_feats)) + else: + self.register_buffer("bias", None) + + self.residual = residual and self.in_feats_dst != out_feats * num_heads + if self.residual: + self.lin_res = nn.Linear( + self.in_feats_dst, num_heads * out_feats, bias=bias + ) + else: + self.register_buffer("lin_res", None) + + self.reset_parameters() + + def reset_parameters(self): + r"""Reinitialize learnable parameters.""" + gain = nn.init.calculate_gain("relu") + nn.init.xavier_normal_(self.lin_src.weight, gain=gain) + nn.init.xavier_normal_(self.lin_dst.weight, gain=gain) + + nn.init.xavier_normal_( + self.attn.view(-1, self.num_heads, self.out_feats), gain=gain + ) + if self.lin_edge is not None: + self.lin_edge.reset_parameters() + + if self.lin_res is not None: + self.lin_res.reset_parameters() + + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward( + self, + g: Union[SparseGraph, dgl.DGLHeteroGraph], + nfeat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + efeat: Optional[torch.Tensor] = None, + max_in_degree: Optional[int] = None, + ) -> torch.Tensor: + r"""Forward computation. + + Parameters + ---------- + graph : DGLGraph or SparseGraph + The graph. + nfeat : torch.Tensor + Input features of shape :math:`(N, D_{in})`. + efeat: torch.Tensor, optional + Optional edge features. + max_in_degree : int + Maximum in-degree of destination nodes. When :attr:`g` is generated + from a neighbor sampler, the value should be set to the corresponding + :attr:`fanout`. This option is used to invoke the MFG-variant of + cugraph-ops kernel. + + Returns + ------- + torch.Tensor + The output feature of shape :math:`(N, H, D_{out})` where + :math:`H` is the number of heads, and :math:`D_{out}` is size of + output feature. + """ + + if isinstance(g, dgl.DGLHeteroGraph): + if not self.allow_zero_in_degree: + if (g.in_degrees() == 0).any(): + raise dgl.base.DGLError( + "There are 0-in-degree nodes in the graph, " + "output for those nodes will be invalid. " + "This is harmful for some applications, " + "causing silent performance regression. " + "Adding self-loop on the input graph by " + "calling `g = dgl.add_self_loop(g)` will resolve " + "the issue. Setting ``allow_zero_in_degree`` " + "to be `True` when constructing this module will " + "suppress the check and let the code run." + ) + + nfeat_bipartite = isinstance(nfeat, (list, tuple)) + graph_bipartite = nfeat_bipartite or self.share_weights is False + + _graph = self.get_cugraph_ops_CSC( + g, is_bipartite=graph_bipartite, max_in_degree=max_in_degree + ) + + if nfeat_bipartite: + nfeat = (self.feat_drop(nfeat[0]), self.feat_drop(nfeat[1])) + nfeat_dst_orig = nfeat[1] + else: + nfeat = self.feat_drop(nfeat) + nfeat_dst_orig = nfeat[: g.num_dst_nodes()] + + if efeat is not None: + if self.lin_edge is None: + raise RuntimeError( + f"{self.__class__.__name__}.edge_feats must be set to " + f"accept edge features." 
+ ) + efeat = self.lin_edge(efeat) + + if nfeat_bipartite: + nfeat = (self.lin_src(nfeat[0]), self.lin_dst(nfeat[1])) + elif graph_bipartite: + nfeat = (self.lin_src(nfeat), self.lin_dst(nfeat[: g.num_dst_nodes()])) + else: + nfeat = self.lin_src(nfeat) + + out = ops_torch.operators.mha_gat_v2_n2n( + nfeat, + self.attn, + _graph, + num_heads=self.num_heads, + activation="LeakyReLU", + negative_slope=self.negative_slope, + concat_heads=self.concat, + edge_feat=efeat, + )[: g.num_dst_nodes()] + + if self.concat: + out = out.view(-1, self.num_heads, self.out_feats) + + if self.residual: + res = self.lin_res(nfeat_dst_orig).view(-1, self.num_heads, self.out_feats) + if not self.concat: + res = res.mean(dim=1) + out = out + res + + if self.bias is not None: + out = out + self.bias + + return out diff --git a/python/cugraph-dgl/cugraph_dgl/nn/conv/relgraphconv.py b/python/cugraph-dgl/cugraph_dgl/nn/conv/relgraphconv.py index 89e49011cf7..54916674210 100644 --- a/python/cugraph-dgl/cugraph_dgl/nn/conv/relgraphconv.py +++ b/python/cugraph-dgl/cugraph_dgl/nn/conv/relgraphconv.py @@ -10,14 +10,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Torch Module for Relational graph convolution layer using the aggregation -primitives in cugraph-ops""" -# pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments -from __future__ import annotations + import math -from typing import Optional +from typing import Optional, Union -from cugraph_dgl.nn.conv.base import BaseConv +from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph from cugraph.utilities.utils import import_optional dgl = import_optional("dgl") @@ -29,13 +26,8 @@ class RelGraphConv(BaseConv): r"""An accelerated relational graph convolution layer from `Modeling Relational Data with Graph Convolutional Networks - `__ that leverages the highly-optimized - aggregation primitives in cugraph-ops. - - See :class:`dgl.nn.pytorch.conv.RelGraphConv` for mathematical model. - - This module depends on :code:`pylibcugraphops` package, which can be - installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`. + `__, with the sparse aggregation + accelerated by cugraph-ops. Parameters ---------- @@ -84,7 +76,6 @@ class RelGraphConv(BaseConv): [-1.4335, -2.3758], [-1.4331, -2.3295]], device='cuda:0', grad_fn=) """ - MAX_IN_DEGREE_MFG = 500 def __init__( self, @@ -148,7 +139,7 @@ def reset_parameters(self): def forward( self, - g: dgl.DGLHeteroGraph, + g: Union[SparseGraph, dgl.DGLHeteroGraph], feat: torch.Tensor, etypes: torch.Tensor, max_in_degree: Optional[int] = None, @@ -167,49 +158,24 @@ def forward( so any input of other integer types will be casted into int32, thus introducing some overhead. Pass in int32 tensors directly for best performance. - max_in_degree : int, optional - Maximum in-degree of destination nodes. It is only effective when - :attr:`g` is a :class:`DGLBlock`, i.e., bipartite graph. When - :attr:`g` is generated from a neighbor sampler, the value should be - set to the corresponding :attr:`fanout`. If not given, - :attr:`max_in_degree` will be calculated on-the-fly. + max_in_degree : int + Maximum in-degree of destination nodes. When :attr:`g` is generated + from a neighbor sampler, the value should be set to the corresponding + :attr:`fanout`. This option is used to invoke the MFG-variant of + cugraph-ops kernel. Returns ------- torch.Tensor New node features. 
Shape: :math:`(|V|, D_{out})`. """ - offsets, indices, edge_ids = g.adj_tensors("csc") - edge_types_perm = etypes[edge_ids.long()].int() - - if g.is_block: - if max_in_degree is None: - max_in_degree = g.in_degrees().max().item() - - if max_in_degree < self.MAX_IN_DEGREE_MFG: - _graph = ops_torch.SampledHeteroCSC( - offsets, - indices, - edge_types_perm, - max_in_degree, - g.num_src_nodes(), - self.num_rels, - ) - else: - offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1) - _graph = ops_torch.StaticHeteroCSC( - offsets_fg, - indices, - edge_types_perm, - self.num_rels, - ) - else: - _graph = ops_torch.StaticHeteroCSC( - offsets, - indices, - edge_types_perm, - self.num_rels, - ) + _graph = self.get_cugraph_ops_HeteroCSC( + g, + num_edge_types=self.num_rels, + etypes=etypes, + is_bipartite=False, + max_in_degree=max_in_degree, + ) h = ops_torch.operators.agg_hg_basis_n2n_post( feat, diff --git a/python/cugraph-dgl/cugraph_dgl/nn/conv/sageconv.py b/python/cugraph-dgl/cugraph_dgl/nn/conv/sageconv.py index 60f4c505e19..a3f946d7cb4 100644 --- a/python/cugraph-dgl/cugraph_dgl/nn/conv/sageconv.py +++ b/python/cugraph-dgl/cugraph_dgl/nn/conv/sageconv.py @@ -10,11 +10,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Torch Module for GraphSAGE layer using the aggregation primitives in -cugraph-ops""" -# pylint: disable=no-member, arguments-differ, invalid-name, too-many-arguments -from __future__ import annotations -from typing import Optional, Union + +from typing import Optional, Tuple, Union from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph from cugraph.utilities.utils import import_optional @@ -27,22 +24,18 @@ class SAGEConv(BaseConv): r"""An accelerated GraphSAGE layer from `Inductive Representation Learning - on Large Graphs `__ that leverages the - highly-optimized aggregation primitives in cugraph-ops. - - See :class:`dgl.nn.pytorch.conv.SAGEConv` for mathematical model. - - This module depends on :code:`pylibcugraphops` package, which can be - installed via :code:`conda install -c nvidia pylibcugraphops>=23.02`. + on Large Graphs `, with the sparse + aggregation accelerated by cugraph-ops. Parameters ---------- - in_feats : int - Input feature size. + in_feats : int or tuple + Input feature size. If a scalar is given, the source and destination + nodes are required to be the same. out_feats : int Output feature size. aggregator_type : str - Aggregator type to use (``mean``, ``sum``, ``min``, ``max``). + Aggregator type to use ("mean", "sum", "min", "max", "pool", "gcn"). feat_drop : float Dropout rate on features, default: ``0``. bias : bool @@ -68,38 +61,57 @@ class SAGEConv(BaseConv): [-1.1690, 0.1952], [-1.1690, 0.1952]], device='cuda:0', grad_fn=) """ - MAX_IN_DEGREE_MFG = 500 + valid_aggr_types = {"mean", "sum", "min", "max", "pool", "gcn"} def __init__( self, - in_feats: int, + in_feats: Union[int, Tuple[int, int]], out_feats: int, aggregator_type: str = "mean", feat_drop: float = 0.0, bias: bool = True, ): super().__init__() - self.in_feats = in_feats - self.out_feats = out_feats - valid_aggr_types = {"max", "min", "mean", "sum"} - if aggregator_type not in valid_aggr_types: + + if aggregator_type not in self.valid_aggr_types: raise ValueError( - f"Invalid aggregator_type. Must be one of {valid_aggr_types}. " + f"Invalid aggregator_type. Must be one of {self.valid_aggr_types}. " f"But got '{aggregator_type}' instead." 
) - self.aggr = aggregator_type + + self.aggregator_type = aggregator_type + self._aggr = aggregator_type + self.in_feats = in_feats + self.out_feats = out_feats + self.in_feats_src, self.in_feats_dst = dgl.utils.expand_as_pair(in_feats) self.feat_drop = nn.Dropout(feat_drop) - self.linear = nn.Linear(2 * in_feats, out_feats, bias=bias) + if self.aggregator_type == "gcn": + self._aggr = "mean" + self.lin = nn.Linear(self.in_feats_src, out_feats, bias=bias) + else: + self.lin = nn.Linear( + self.in_feats_src + self.in_feats_dst, out_feats, bias=bias + ) + + if self.aggregator_type == "pool": + self._aggr = "max" + self.pre_lin = nn.Linear(self.in_feats_src, self.in_feats_src) + else: + self.register_parameter("pre_lin", None) + + self.reset_parameters() def reset_parameters(self): r"""Reinitialize learnable parameters.""" - self.linear.reset_parameters() + self.lin.reset_parameters() + if self.pre_lin is not None: + self.pre_lin.reset_parameters() def forward( self, g: Union[SparseGraph, dgl.DGLHeteroGraph], - feat: torch.Tensor, + feat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], max_in_degree: Optional[int] = None, ) -> torch.Tensor: r"""Forward computation. @@ -108,7 +120,7 @@ def forward( ---------- g : DGLGraph or SparseGraph The graph. - feat : torch.Tensor + feat : torch.Tensor or tuple Node features. Shape: :math:`(|V|, D_{in})`. max_in_degree : int Maximum in-degree of destination nodes. When :attr:`g` is generated @@ -121,36 +133,34 @@ def forward( torch.Tensor Output node features. Shape: :math:`(|V|, D_{out})`. """ - if max_in_degree is None: - max_in_degree = -1 - - if isinstance(g, SparseGraph): - assert "csc" in g.formats() - offsets, indices = g.csc() - _graph = ops_torch.CSC( - offsets=offsets, - indices=indices, - num_src_nodes=g.num_src_nodes(), - dst_max_in_degree=max_in_degree, - ) - elif isinstance(g, dgl.DGLHeteroGraph): - offsets, indices, _ = g.adj_tensors("csc") - _graph = ops_torch.CSC( - offsets=offsets, - indices=indices, - num_src_nodes=g.num_src_nodes(), - dst_max_in_degree=max_in_degree, - ) - else: - raise TypeError( - f"The graph has to be either a 'SparseGraph' or " - f"'dgl.DGLHeteroGraph', but got '{type(g)}'." 
- ) + feat_bipartite = isinstance(feat, (list, tuple)) + graph_bipartite = feat_bipartite or self.aggregator_type == "pool" + + _graph = self.get_cugraph_ops_CSC( + g, is_bipartite=graph_bipartite, max_in_degree=max_in_degree + ) - feat = self.feat_drop(feat) - h = ops_torch.operators.agg_concat_n2n(feat, _graph, self.aggr)[ + if feat_bipartite: + feat = (self.feat_drop(feat[0]), self.feat_drop(feat[1])) + else: + feat = self.feat_drop(feat) + + if self.aggregator_type == "pool": + if feat_bipartite: + feat = (self.pre_lin(feat[0]).relu(), feat[1]) + else: + feat = (self.pre_lin(feat).relu(), feat[: g.num_dst_nodes()]) + # force ctx.needs_input_grad=True in cugraph-ops autograd function + feat[0].requires_grad_() + feat[1].requires_grad_() + + out = ops_torch.operators.agg_concat_n2n(feat, _graph, self._aggr)[ : g.num_dst_nodes() ] - h = self.linear(h) - return h + if self.aggregator_type == "gcn": + out = out[:, : self.in_feats_src] + + out = self.lin(out) + + return out diff --git a/python/cugraph-dgl/cugraph_dgl/nn/conv/transformerconv.py b/python/cugraph-dgl/cugraph_dgl/nn/conv/transformerconv.py index 5cd5fbbaebe..8481b9ee265 100644 --- a/python/cugraph-dgl/cugraph_dgl/nn/conv/transformerconv.py +++ b/python/cugraph-dgl/cugraph_dgl/nn/conv/transformerconv.py @@ -10,9 +10,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import Optional, Tuple, Union -from cugraph_dgl.nn.conv.base import BaseConv +from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph from cugraph.utilities.utils import import_optional dgl = import_optional("dgl") @@ -114,7 +115,7 @@ def reset_parameters(self): def forward( self, - g: dgl.DGLHeteroGraph, + g: Union[SparseGraph, dgl.DGLHeteroGraph], nfeat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], efeat: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -130,17 +131,12 @@ def forward( efeat: torch.Tensor, optional Edge feature tensor. Default: ``None``. 
""" - offsets, indices, _ = g.adj_tensors("csc") - graph = ops_torch.CSC( - offsets=offsets, - indices=indices, - num_src_nodes=g.num_src_nodes(), - is_bipartite=True, - ) - - if isinstance(nfeat, torch.Tensor): + feat_bipartite = isinstance(nfeat, (list, tuple)) + if not feat_bipartite: nfeat = (nfeat, nfeat) + _graph = self.get_cugraph_ops_CSC(g, is_bipartite=True) + query = self.lin_query(nfeat[1][: g.num_dst_nodes()]) key = self.lin_key(nfeat[0]) value = self.lin_value(nfeat[0]) @@ -157,7 +153,7 @@ def forward( key_emb=key, query_emb=query, value_emb=value, - graph=graph, + graph=_graph, num_heads=self.num_heads, concat_heads=self.concat, edge_emb=efeat, diff --git a/python/cugraph-dgl/tests/conftest.py b/python/cugraph-dgl/tests/conftest.py index 6f8690d1140..a3863ed81fa 100644 --- a/python/cugraph-dgl/tests/conftest.py +++ b/python/cugraph-dgl/tests/conftest.py @@ -40,16 +40,19 @@ class SparseGraphData1: nnz = 6 src_ids = torch.IntTensor([0, 1, 2, 3, 2, 5]).cuda() dst_ids = torch.IntTensor([1, 2, 3, 4, 0, 3]).cuda() + values = torch.IntTensor([10, 20, 30, 40, 50, 60]).cuda() # CSR src_ids_sorted_by_src = torch.IntTensor([0, 1, 2, 2, 3, 5]).cuda() dst_ids_sorted_by_src = torch.IntTensor([1, 2, 0, 3, 4, 3]).cuda() csrc_ids = torch.IntTensor([0, 1, 2, 4, 5, 5, 6]).cuda() + values_csr = torch.IntTensor([10, 20, 50, 30, 40, 60]).cuda() # CSC src_ids_sorted_by_dst = torch.IntTensor([2, 0, 1, 5, 2, 3]).cuda() dst_ids_sorted_by_dst = torch.IntTensor([0, 1, 2, 3, 3, 4]).cuda() cdst_ids = torch.IntTensor([0, 1, 2, 3, 5, 6]).cuda() + values_csc = torch.IntTensor([50, 10, 20, 60, 30, 40]).cuda() @pytest.fixture diff --git a/python/cugraph-dgl/tests/nn/test_gatconv.py b/python/cugraph-dgl/tests/nn/test_gatconv.py index 7ed65645a28..ef3047dc2cd 100644 --- a/python/cugraph-dgl/tests/nn/test_gatconv.py +++ b/python/cugraph-dgl/tests/nn/test_gatconv.py @@ -10,69 +10,84 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# pylint: disable=too-many-arguments, too-many-locals import pytest -try: - import cugraph_dgl -except ModuleNotFoundError: - pytest.skip("cugraph_dgl not available", allow_module_level=True) - -from cugraph.utilities.utils import import_optional +from cugraph_dgl.nn.conv.base import SparseGraph +from cugraph_dgl.nn import GATConv as CuGraphGATConv from .common import create_graph1 -torch = import_optional("torch") -dgl = import_optional("dgl") +dgl = pytest.importorskip("dgl", reason="DGL not available") +torch = pytest.importorskip("torch", reason="PyTorch not available") + +ATOL = 1e-6 @pytest.mark.parametrize("bipartite", [False, True]) @pytest.mark.parametrize("idtype_int", [False, True]) @pytest.mark.parametrize("max_in_degree", [None, 8]) @pytest.mark.parametrize("num_heads", [1, 2, 7]) +@pytest.mark.parametrize("residual", [False, True]) @pytest.mark.parametrize("to_block", [False, True]) -def test_gatconv_equality(bipartite, idtype_int, max_in_degree, num_heads, to_block): - GATConv = dgl.nn.GATConv - CuGraphGATConv = cugraph_dgl.nn.GATConv - device = "cuda" - g = create_graph1().to(device) +@pytest.mark.parametrize("sparse_format", ["coo", "csc", None]) +def test_gatconv_equality( + bipartite, idtype_int, max_in_degree, num_heads, residual, to_block, sparse_format +): + from dgl.nn.pytorch import GATConv + + g = create_graph1().to("cuda") if idtype_int: g = g.int() - if to_block: g = dgl.to_block(g) + size = (g.num_src_nodes(), g.num_dst_nodes()) + if bipartite: in_feats = (10, 3) nfeat = ( - torch.rand(g.num_src_nodes(), in_feats[0], device=device), - torch.rand(g.num_dst_nodes(), in_feats[1], device=device), + torch.rand(g.num_src_nodes(), in_feats[0]).cuda(), + torch.rand(g.num_dst_nodes(), in_feats[1]).cuda(), ) else: in_feats = 10 - nfeat = torch.rand(g.num_src_nodes(), in_feats, device=device) + nfeat = torch.rand(g.num_src_nodes(), in_feats).cuda() out_feats = 2 + if sparse_format == "coo": + sg = SparseGraph( + size=size, src_ids=g.edges()[0], dst_ids=g.edges()[1], formats="csc" + ) + elif sparse_format == "csc": + offsets, indices, _ = g.adj_tensors("csc") + sg = SparseGraph(size=size, src_ids=indices, cdst_ids=offsets, formats="csc") + args = (in_feats, out_feats, num_heads) - kwargs = {"bias": False} + kwargs = {"bias": False, "allow_zero_in_degree": True} - conv1 = GATConv(*args, **kwargs, allow_zero_in_degree=True).to(device) + conv1 = GATConv(*args, **kwargs).cuda() out1 = conv1(g, nfeat) - conv2 = CuGraphGATConv(*args, **kwargs).to(device) + conv2 = CuGraphGATConv(*args, **kwargs).cuda() dim = num_heads * out_feats with torch.no_grad(): conv2.attn_weights.data[:dim] = conv1.attn_l.data.flatten() conv2.attn_weights.data[dim:] = conv1.attn_r.data.flatten() if bipartite: - conv2.fc_src.weight.data = conv1.fc_src.weight.data.detach().clone() - conv2.fc_dst.weight.data = conv1.fc_dst.weight.data.detach().clone() + conv2.lin_src.weight.data = conv1.fc_src.weight.data.detach().clone() + conv2.lin_dst.weight.data = conv1.fc_dst.weight.data.detach().clone() else: - conv2.fc.weight.data = conv1.fc.weight.data.detach().clone() - out2 = conv2(g, nfeat, max_in_degree=max_in_degree) + conv2.lin.weight.data = conv1.fc.weight.data.detach().clone() + if residual and conv2.residual: + conv2.lin_res.weight.data = conv1.fc_res.weight.data.detach().clone() - assert torch.allclose(out1, out2, atol=1e-6) + if sparse_format is not None: + out2 = conv2(sg, nfeat, max_in_degree=max_in_degree) + else: + out2 = conv2(g, nfeat, max_in_degree=max_in_degree) + + assert torch.allclose(out1, 
out2, atol=ATOL) grad_out1 = torch.rand_like(out1) grad_out2 = grad_out1.clone().detach() @@ -81,18 +96,18 @@ def test_gatconv_equality(bipartite, idtype_int, max_in_degree, num_heads, to_bl if bipartite: assert torch.allclose( - conv1.fc_src.weight.grad, conv2.fc_src.weight.grad, atol=1e-6 + conv1.fc_src.weight.grad, conv2.lin_src.weight.grad, atol=ATOL ) assert torch.allclose( - conv1.fc_dst.weight.grad, conv2.fc_dst.weight.grad, atol=1e-6 + conv1.fc_dst.weight.grad, conv2.lin_dst.weight.grad, atol=ATOL ) else: - assert torch.allclose(conv1.fc.weight.grad, conv2.fc.weight.grad, atol=1e-6) + assert torch.allclose(conv1.fc.weight.grad, conv2.lin.weight.grad, atol=ATOL) assert torch.allclose( torch.cat((conv1.attn_l.grad, conv1.attn_r.grad), dim=0), conv2.attn_weights.grad.view(2, num_heads, out_feats), - atol=1e-6, + atol=ATOL, ) @@ -106,10 +121,7 @@ def test_gatconv_equality(bipartite, idtype_int, max_in_degree, num_heads, to_bl def test_gatconv_edge_feats( bias, bipartite, concat, max_in_degree, num_heads, to_block, use_edge_feats ): - from cugraph_dgl.nn import GATConv - - device = "cuda" - g = create_graph1().to(device) + g = create_graph1().to("cuda") if to_block: g = dgl.to_block(g) @@ -117,24 +129,30 @@ def test_gatconv_edge_feats( if bipartite: in_feats = (10, 3) nfeat = ( - torch.rand(g.num_src_nodes(), in_feats[0], device=device), - torch.rand(g.num_dst_nodes(), in_feats[1], device=device), + torch.rand(g.num_src_nodes(), in_feats[0]).cuda(), + torch.rand(g.num_dst_nodes(), in_feats[1]).cuda(), ) else: in_feats = 10 - nfeat = torch.rand(g.num_src_nodes(), in_feats, device=device) + nfeat = torch.rand(g.num_src_nodes(), in_feats).cuda() out_feats = 2 if use_edge_feats: edge_feats = 3 - efeat = torch.rand(g.num_edges(), edge_feats, device=device) + efeat = torch.rand(g.num_edges(), edge_feats).cuda() else: edge_feats = None efeat = None - conv = GATConv( - in_feats, out_feats, num_heads, concat=concat, edge_feats=edge_feats, bias=bias - ).to(device) + conv = CuGraphGATConv( + in_feats, + out_feats, + num_heads, + concat=concat, + edge_feats=edge_feats, + bias=bias, + allow_zero_in_degree=True, + ).cuda() out = conv(g, nfeat, efeat=efeat, max_in_degree=max_in_degree) grad_out = torch.rand_like(out) diff --git a/python/cugraph-dgl/tests/nn/test_gatv2conv.py b/python/cugraph-dgl/tests/nn/test_gatv2conv.py new file mode 100644 index 00000000000..cc46a6e4b39 --- /dev/null +++ b/python/cugraph-dgl/tests/nn/test_gatv2conv.py @@ -0,0 +1,147 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from cugraph_dgl.nn.conv.base import SparseGraph +from cugraph_dgl.nn import GATv2Conv as CuGraphGATv2Conv +from .common import create_graph1 + +dgl = pytest.importorskip("dgl", reason="DGL not available") +torch = pytest.importorskip("torch", reason="PyTorch not available") + +ATOL = 1e-6 + + +@pytest.mark.parametrize("bipartite", [False, True]) +@pytest.mark.parametrize("idtype_int", [False, True]) +@pytest.mark.parametrize("max_in_degree", [None, 8]) +@pytest.mark.parametrize("num_heads", [1, 2, 7]) +@pytest.mark.parametrize("residual", [False, True]) +@pytest.mark.parametrize("to_block", [False, True]) +@pytest.mark.parametrize("sparse_format", ["coo", "csc", None]) +def test_gatv2conv_equality( + bipartite, idtype_int, max_in_degree, num_heads, residual, to_block, sparse_format +): + from dgl.nn.pytorch import GATv2Conv + + g = create_graph1().to("cuda") + + if idtype_int: + g = g.int() + if to_block: + g = dgl.to_block(g) + + size = (g.num_src_nodes(), g.num_dst_nodes()) + + if bipartite: + in_feats = (10, 3) + nfeat = ( + torch.rand(g.num_src_nodes(), in_feats[0]).cuda(), + torch.rand(g.num_dst_nodes(), in_feats[1]).cuda(), + ) + else: + in_feats = 10 + nfeat = torch.rand(g.num_src_nodes(), in_feats).cuda() + out_feats = 2 + + if sparse_format == "coo": + sg = SparseGraph( + size=size, src_ids=g.edges()[0], dst_ids=g.edges()[1], formats="csc" + ) + elif sparse_format == "csc": + offsets, indices, _ = g.adj_tensors("csc") + sg = SparseGraph(size=size, src_ids=indices, cdst_ids=offsets, formats="csc") + + args = (in_feats, out_feats, num_heads) + kwargs = {"bias": False, "allow_zero_in_degree": True} + + conv1 = GATv2Conv(*args, **kwargs).cuda() + out1 = conv1(g, nfeat) + + conv2 = CuGraphGATv2Conv(*args, **kwargs).cuda() + with torch.no_grad(): + conv2.attn.data = conv1.attn.data.flatten() + conv2.lin_src.weight.data = conv1.fc_src.weight.data.detach().clone() + conv2.lin_dst.weight.data = conv1.fc_dst.weight.data.detach().clone() + if residual and conv2.residual: + conv2.lin_res.weight.data = conv1.fc_res.weight.data.detach().clone() + + if sparse_format is not None: + out2 = conv2(sg, nfeat, max_in_degree=max_in_degree) + else: + out2 = conv2(g, nfeat, max_in_degree=max_in_degree) + + assert torch.allclose(out1, out2, atol=ATOL) + + grad_out1 = torch.rand_like(out1) + grad_out2 = grad_out1.clone().detach() + out1.backward(grad_out1) + out2.backward(grad_out2) + + assert torch.allclose( + conv1.fc_src.weight.grad, conv2.lin_src.weight.grad, atol=ATOL + ) + assert torch.allclose( + conv1.fc_dst.weight.grad, conv2.lin_dst.weight.grad, atol=ATOL + ) + + assert torch.allclose(conv1.attn.grad, conv1.attn.grad, atol=ATOL) + + +@pytest.mark.parametrize("bias", [False, True]) +@pytest.mark.parametrize("bipartite", [False, True]) +@pytest.mark.parametrize("concat", [False, True]) +@pytest.mark.parametrize("max_in_degree", [None, 8, 800]) +@pytest.mark.parametrize("num_heads", [1, 2, 7]) +@pytest.mark.parametrize("to_block", [False, True]) +@pytest.mark.parametrize("use_edge_feats", [False, True]) +def test_gatv2conv_edge_feats( + bias, bipartite, concat, max_in_degree, num_heads, to_block, use_edge_feats +): + g = create_graph1().to("cuda") + + if to_block: + g = dgl.to_block(g) + + if bipartite: + in_feats = (10, 3) + nfeat = ( + torch.rand(g.num_src_nodes(), in_feats[0]).cuda(), + torch.rand(g.num_dst_nodes(), in_feats[1]).cuda(), + ) + else: + in_feats = 10 + nfeat = torch.rand(g.num_src_nodes(), in_feats).cuda() + out_feats = 2 + + if use_edge_feats: + edge_feats = 3 
+        efeat = torch.rand(g.num_edges(), edge_feats).cuda()
+    else:
+        edge_feats = None
+        efeat = None
+
+    conv = CuGraphGATv2Conv(
+        in_feats,
+        out_feats,
+        num_heads,
+        concat=concat,
+        edge_feats=edge_feats,
+        bias=bias,
+        allow_zero_in_degree=True,
+    ).cuda()
+    out = conv(g, nfeat, efeat=efeat, max_in_degree=max_in_degree)
+
+    grad_out = torch.rand_like(out)
+    out.backward(grad_out)
diff --git a/python/cugraph-dgl/tests/nn/test_relgraphconv.py b/python/cugraph-dgl/tests/nn/test_relgraphconv.py
index d2ae6a23978..901f9ba1433 100644
--- a/python/cugraph-dgl/tests/nn/test_relgraphconv.py
+++ b/python/cugraph-dgl/tests/nn/test_relgraphconv.py
@@ -10,20 +10,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# pylint: disable=too-many-arguments, too-many-locals
 
 import pytest
 
-try:
-    import cugraph_dgl
-except ModuleNotFoundError:
-    pytest.skip("cugraph_dgl not available", allow_module_level=True)
-
-from cugraph.utilities.utils import import_optional
+from cugraph_dgl.nn.conv.base import SparseGraph
+from cugraph_dgl.nn import RelGraphConv as CuGraphRelGraphConv
 from .common import create_graph1
 
-torch = import_optional("torch")
-dgl = import_optional("dgl")
+dgl = pytest.importorskip("dgl", reason="DGL not available")
+torch = pytest.importorskip("torch", reason="PyTorch not available")
+
+ATOL = 1e-6
 
 
 @pytest.mark.parametrize("idtype_int", [False, True])
@@ -32,12 +29,17 @@
 @pytest.mark.parametrize("regularizer", [None, "basis"])
 @pytest.mark.parametrize("self_loop", [False, True])
 @pytest.mark.parametrize("to_block", [False, True])
+@pytest.mark.parametrize("sparse_format", ["coo", "csc", None])
 def test_relgraphconv_equality(
-    idtype_int, max_in_degree, num_bases, regularizer, self_loop, to_block
+    idtype_int,
+    max_in_degree,
+    num_bases,
+    regularizer,
+    self_loop,
+    to_block,
+    sparse_format,
 ):
-    RelGraphConv = dgl.nn.RelGraphConv
-    CuGraphRelGraphConv = cugraph_dgl.nn.RelGraphConv
-    device = "cuda"
+    from dgl.nn.pytorch import RelGraphConv
 
     in_feat, out_feat, num_rels = 10, 2, 3
     args = (in_feat, out_feat, num_rels)
@@ -47,34 +49,57 @@ def test_relgraphconv_equality(
         "bias": False,
         "self_loop": self_loop,
     }
-    g = create_graph1().to(device)
-    g.edata[dgl.ETYPE] = torch.randint(num_rels, (g.num_edges(),)).to(device)
+    g = create_graph1().to("cuda")
+    g.edata[dgl.ETYPE] = torch.randint(num_rels, (g.num_edges(),)).cuda()
+
     if idtype_int:
         g = g.int()
     if to_block:
         g = dgl.to_block(g)
-    feat = torch.rand(g.num_src_nodes(), in_feat).to(device)
+
+    size = (g.num_src_nodes(), g.num_dst_nodes())
+    feat = torch.rand(g.num_src_nodes(), in_feat).cuda()
+
+    if sparse_format == "coo":
+        sg = SparseGraph(
+            size=size,
+            src_ids=g.edges()[0],
+            dst_ids=g.edges()[1],
+            values=g.edata[dgl.ETYPE],
+            formats="csc",
+        )
+    elif sparse_format == "csc":
+        offsets, indices, perm = g.adj_tensors("csc")
+        etypes = g.edata[dgl.ETYPE][perm]
+        sg = SparseGraph(
+            size=size, src_ids=indices, cdst_ids=offsets, values=etypes, formats="csc"
+        )
 
     torch.manual_seed(0)
-    conv1 = RelGraphConv(*args, **kwargs).to(device)
+    conv1 = RelGraphConv(*args, **kwargs).cuda()
 
     torch.manual_seed(0)
     kwargs["apply_norm"] = False
-    conv2 = CuGraphRelGraphConv(*args, **kwargs).to(device)
+    conv2 = CuGraphRelGraphConv(*args, **kwargs).cuda()
 
     out1 = conv1(g, feat, g.edata[dgl.ETYPE])
-    out2 = conv2(g, feat, g.edata[dgl.ETYPE], max_in_degree=max_in_degree)
-    assert torch.allclose(out1, out2, atol=1e-06)
+
+    if sparse_format is not None:
+        out2 = conv2(sg, feat, sg.values(), max_in_degree=max_in_degree)
+    else:
+        out2 = conv2(g, feat, g.edata[dgl.ETYPE], max_in_degree=max_in_degree)
+
+    assert torch.allclose(out1, out2, atol=ATOL)
 
     grad_out = torch.rand_like(out1)
     out1.backward(grad_out)
     out2.backward(grad_out)
 
     end = -1 if self_loop else None
-    assert torch.allclose(conv1.linear_r.W.grad, conv2.W.grad[:end], atol=1e-6)
+    assert torch.allclose(conv1.linear_r.W.grad, conv2.W.grad[:end], atol=ATOL)
 
     if self_loop:
-        assert torch.allclose(conv1.loop_weight.grad, conv2.W.grad[-1], atol=1e-6)
+        assert torch.allclose(conv1.loop_weight.grad, conv2.W.grad[-1], atol=ATOL)
 
     if regularizer is not None:
-        assert torch.allclose(conv1.linear_r.coeff.grad, conv2.coeff.grad, atol=1e-6)
+        assert torch.allclose(conv1.linear_r.coeff.grad, conv2.coeff.grad, atol=ATOL)
diff --git a/python/cugraph-dgl/tests/nn/test_sageconv.py b/python/cugraph-dgl/tests/nn/test_sageconv.py
index 447bbe49460..e2acf9e6596 100644
--- a/python/cugraph-dgl/tests/nn/test_sageconv.py
+++ b/python/cugraph-dgl/tests/nn/test_sageconv.py
@@ -10,31 +10,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# pylint: disable=too-many-arguments, too-many-locals
 
 import pytest
 
-from cugraph.utilities.utils import import_optional
 from cugraph_dgl.nn.conv.base import SparseGraph
 from cugraph_dgl.nn import SAGEConv as CuGraphSAGEConv
 from .common import create_graph1
 
-torch = import_optional("torch")
-dgl = import_optional("dgl")
+dgl = pytest.importorskip("dgl", reason="DGL not available")
+torch = pytest.importorskip("torch", reason="PyTorch not available")
 
+ATOL = 1e-6
 
+
+@pytest.mark.parametrize("aggr", ["mean", "pool"])
 @pytest.mark.parametrize("bias", [False, True])
+@pytest.mark.parametrize("bipartite", [False, True])
 @pytest.mark.parametrize("idtype_int", [False, True])
 @pytest.mark.parametrize("max_in_degree", [None, 8])
 @pytest.mark.parametrize("to_block", [False, True])
 @pytest.mark.parametrize("sparse_format", ["coo", "csc", None])
-def test_SAGEConv_equality(bias, idtype_int, max_in_degree, to_block, sparse_format):
-    SAGEConv = dgl.nn.SAGEConv
-    device = "cuda"
+def test_sageconv_equality(
+    aggr, bias, bipartite, idtype_int, max_in_degree, to_block, sparse_format
+):
+    from dgl.nn.pytorch import SAGEConv
 
-    in_feat, out_feat = 5, 2
-    kwargs = {"aggregator_type": "mean", "bias": bias}
-    g = create_graph1().to(device)
+    kwargs = {"aggregator_type": aggr, "bias": bias}
+    g = create_graph1().to("cuda")
 
     if idtype_int:
         g = g.int()
@@ -42,7 +44,17 @@ def test_SAGEConv_equality(bias, idtype_int, max_in_degree, to_block, sparse_for
         g = dgl.to_block(g)
 
     size = (g.num_src_nodes(), g.num_dst_nodes())
-    feat = torch.rand(g.num_src_nodes(), in_feat).to(device)
+
+    if bipartite:
+        in_feats = (5, 3)
+        feat = (
+            torch.rand(size[0], in_feats[0], requires_grad=True).cuda(),
+            torch.rand(size[1], in_feats[1], requires_grad=True).cuda(),
+        )
+    else:
+        in_feats = 5
+        feat = torch.rand(size[0], in_feats).cuda()
+    out_feats = 2
 
     if sparse_format == "coo":
         sg = SparseGraph(
@@ -52,39 +64,38 @@ def test_SAGEConv_equality(bias, idtype_int, max_in_degree, to_block, sparse_for
         offsets, indices, _ = g.adj_tensors("csc")
         sg = SparseGraph(size=size, src_ids=indices, cdst_ids=offsets, formats="csc")
 
-    torch.manual_seed(0)
-    conv1 = SAGEConv(in_feat, out_feat, **kwargs).to(device)
-
-    torch.manual_seed(0)
-    conv2 = CuGraphSAGEConv(in_feat, out_feat, **kwargs).to(device)
+    conv1 = SAGEConv(in_feats, out_feats, **kwargs).cuda()
+    conv2 = CuGraphSAGEConv(in_feats, out_feats, **kwargs).cuda()
 
+    in_feats_src = conv2.in_feats_src
     with torch.no_grad():
-        conv2.linear.weight.data[:, :in_feat] = conv1.fc_neigh.weight.data
-        conv2.linear.weight.data[:, in_feat:] = conv1.fc_self.weight.data
+        conv2.lin.weight.data[:, :in_feats_src] = conv1.fc_neigh.weight.data
+        conv2.lin.weight.data[:, in_feats_src:] = conv1.fc_self.weight.data
         if bias:
-            conv2.linear.bias.data[:] = conv1.fc_self.bias.data
+            conv2.lin.bias.data[:] = conv1.fc_self.bias.data
+        if aggr == "pool":
+            conv2.pre_lin.weight.data[:] = conv1.fc_pool.weight.data
+            conv2.pre_lin.bias.data[:] = conv1.fc_pool.bias.data
 
     out1 = conv1(g, feat)
     if sparse_format is not None:
         out2 = conv2(sg, feat, max_in_degree=max_in_degree)
     else:
         out2 = conv2(g, feat, max_in_degree=max_in_degree)
-    assert torch.allclose(out1, out2, atol=1e-06)
+    assert torch.allclose(out1, out2, atol=ATOL)
 
     grad_out = torch.rand_like(out1)
     out1.backward(grad_out)
     out2.backward(grad_out)
 
     assert torch.allclose(
         conv1.fc_neigh.weight.grad,
-        conv2.linear.weight.grad[:, :in_feat],
-        atol=1e-6,
+        conv2.lin.weight.grad[:, :in_feats_src],
+        atol=ATOL,
     )
     assert torch.allclose(
         conv1.fc_self.weight.grad,
-        conv2.linear.weight.grad[:, in_feat:],
-        atol=1e-6,
+        conv2.lin.weight.grad[:, in_feats_src:],
+        atol=ATOL,
     )
     if bias:
-        assert torch.allclose(
-            conv1.fc_self.bias.grad, conv2.linear.bias.grad, atol=1e-6
-        )
+        assert torch.allclose(conv1.fc_self.bias.grad, conv2.lin.bias.grad, atol=ATOL)
diff --git a/python/cugraph-dgl/tests/nn/test_sparsegraph.py b/python/cugraph-dgl/tests/nn/test_sparsegraph.py
index 3fb01575d66..09c0df202ff 100644
--- a/python/cugraph-dgl/tests/nn/test_sparsegraph.py
+++ b/python/cugraph-dgl/tests/nn/test_sparsegraph.py
@@ -19,32 +19,42 @@
 
 def test_coo2csc(sparse_graph_1):
     data = sparse_graph_1
-    values = torch.ones(data.nnz).cuda()
+
     g = SparseGraph(
-        size=data.size, src_ids=data.src_ids, dst_ids=data.dst_ids, formats="csc"
+        size=data.size,
+        src_ids=data.src_ids,
+        dst_ids=data.dst_ids,
+        values=data.values,
+        formats=["csc"],
     )
-    cdst_ids, src_ids = g.csc()
+    cdst_ids, src_ids, values = g.csc()
 
     new = torch.sparse_csc_tensor(cdst_ids, src_ids, values).cuda()
     old = torch.sparse_coo_tensor(
-        torch.vstack((data.src_ids, data.dst_ids)), values
+        torch.vstack((data.src_ids, data.dst_ids)), data.values
     ).cuda()
     torch.allclose(new.to_dense(), old.to_dense())
 
 
-def test_csc2coo(sparse_graph_1):
+def test_csc_input(sparse_graph_1):
     data = sparse_graph_1
-    values = torch.ones(data.nnz).cuda()
+
     g = SparseGraph(
         size=data.size,
         src_ids=data.src_ids_sorted_by_dst,
         cdst_ids=data.cdst_ids,
-        formats="coo",
+        values=data.values_csc,
+        formats=["coo", "csc", "csr"],
     )
-    src_ids, dst_ids = g.coo()
+    src_ids, dst_ids, values = g.coo()
 
     new = torch.sparse_coo_tensor(torch.vstack((src_ids, dst_ids)), values).cuda()
     old = torch.sparse_csc_tensor(
-        data.cdst_ids, data.src_ids_sorted_by_dst, values
+        data.cdst_ids, data.src_ids_sorted_by_dst, data.values_csc
     ).cuda()
     torch.allclose(new.to_dense(), old.to_dense())
+
+    csrc_ids, dst_ids, values = g.csr()
+
+    new = torch.sparse_csr_tensor(csrc_ids, dst_ids, values).cuda()
+    torch.allclose(new.to_dense(), old.to_dense())
diff --git a/python/cugraph-dgl/tests/nn/test_transformerconv.py b/python/cugraph-dgl/tests/nn/test_transformerconv.py
index 00476b9f0bb..b2b69cb35ab 100644
--- a/python/cugraph-dgl/tests/nn/test_transformerconv.py
+++ b/python/cugraph-dgl/tests/nn/test_transformerconv.py
@@ -13,16 +13,14 @@
 
 import pytest
 
-try:
-    from cugraph_dgl.nn import TransformerConv
-except ModuleNotFoundError:
-    pytest.skip("cugraph_dgl not available", allow_module_level=True)
-
-from cugraph.utilities.utils import import_optional
+from cugraph_dgl.nn.conv.base import SparseGraph
+from cugraph_dgl.nn import TransformerConv
 from .common import create_graph1
 
-torch = import_optional("torch")
-dgl = import_optional("dgl")
+dgl = pytest.importorskip("dgl", reason="DGL not available")
+torch = pytest.importorskip("torch", reason="PyTorch not available")
+
+ATOL = 1e-6
 
 
 @pytest.mark.parametrize("beta", [False, True])
@@ -32,8 +30,16 @@
 @pytest.mark.parametrize("num_heads", [1, 2, 3, 4])
 @pytest.mark.parametrize("to_block", [False, True])
 @pytest.mark.parametrize("use_edge_feats", [False, True])
-def test_TransformerConv(
-    beta, bipartite_node_feats, concat, idtype_int, num_heads, to_block, use_edge_feats
+@pytest.mark.parametrize("sparse_format", ["coo", "csc", None])
+def test_transformerconv(
+    beta,
+    bipartite_node_feats,
+    concat,
+    idtype_int,
+    num_heads,
+    to_block,
+    use_edge_feats,
+    sparse_format,
 ):
     device = "cuda"
     g = create_graph1().to(device)
@@ -44,6 +50,15 @@ def test_TransformerConv(
     if to_block:
         g = dgl.to_block(g)
 
+    size = (g.num_src_nodes(), g.num_dst_nodes())
+    if sparse_format == "coo":
+        sg = SparseGraph(
+            size=size, src_ids=g.edges()[0], dst_ids=g.edges()[1], formats="csc"
+        )
+    elif sparse_format == "csc":
+        offsets, indices, _ = g.adj_tensors("csc")
+        sg = SparseGraph(size=size, src_ids=indices, cdst_ids=offsets, formats="csc")
+
     if bipartite_node_feats:
         in_node_feats = (5, 3)
         nfeat = (
@@ -71,6 +86,10 @@ def test_TransformerConv(
         edge_feats=edge_feats,
     ).to(device)
 
-    out = conv(g, nfeat, efeat)
+    if sparse_format is not None:
+        out = conv(sg, nfeat, efeat)
+    else:
+        out = conv(g, nfeat, efeat)
+
     grad_out = torch.rand_like(out)
     out.backward(grad_out)
diff --git a/python/cugraph-dgl/tests/test_dataset.py b/python/cugraph-dgl/tests/test_dataset.py
index 69d50261e55..5db443dc0d8 100644
--- a/python/cugraph-dgl/tests/test_dataset.py
+++ b/python/cugraph-dgl/tests/test_dataset.py
@@ -123,6 +123,6 @@ def test_homogeneous_sampled_graphs_from_dataframe(return_type, seed_node):
         assert dgl_block.num_src_nodes() == cugraph_dgl_graph.num_src_nodes()
         assert dgl_block.num_dst_nodes() == cugraph_dgl_graph.num_dst_nodes()
         dgl_offsets, dgl_indices, _ = dgl_block.adj_tensors("csc")
-        cugraph_offsets, cugraph_indices = cugraph_dgl_graph.csc()
+        cugraph_offsets, cugraph_indices, _ = cugraph_dgl_graph.csc()
         assert torch.equal(dgl_offsets.to("cpu"), cugraph_offsets.to("cpu"))
         assert torch.equal(dgl_indices.to("cpu"), cugraph_indices.to("cpu"))
diff --git a/python/cugraph-dgl/tests/test_from_dgl_hetrograph.py b/python/cugraph-dgl/tests/test_from_dgl_heterograph.py
similarity index 100%
rename from python/cugraph-dgl/tests/test_from_dgl_hetrograph.py
rename to python/cugraph-dgl/tests/test_from_dgl_heterograph.py