
Commit

add GATv2Conv
tingyu66 committed Sep 8, 2023
1 parent 704389b commit 70b4abd
Showing 5 changed files with 371 additions and 9 deletions.
2 changes: 2 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/__init__.py
@@ -13,13 +13,15 @@
 
 from .base import SparseGraph
 from .gatconv import GATConv
+from .gatv2conv import GATv2Conv
 from .relgraphconv import RelGraphConv
 from .sageconv import SAGEConv
 from .transformerconv import TransformerConv
 
 __all__ = [
     "SparseGraph",
     "GATConv",
+    "GATv2Conv",
     "RelGraphConv",
     "SAGEConv",
     "TransformerConv",
5 changes: 2 additions & 3 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py
@@ -138,13 +138,12 @@ def __init__(
         else:
             self.register_buffer("bias", None)
 
-        if residual and self.in_feats_dst != out_feats * num_heads:
-            self.residual = True
+        self.residual = residual and self.in_feats_dst != out_feats * num_heads
+        if self.residual:
             self.lin_res = nn.Linear(
                 self.in_feats_dst, num_heads * out_feats, bias=bias
             )
         else:
-            self.residual = False
             self.register_buffer("lin_res", None)
 
         self.reset_parameters()
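
As a side note on the refactor above, here is a tiny standalone sketch of what the collapsed condition evaluates to; the sizes are made up for illustration and are not part of the diff:

# Hypothetical sizes, chosen so that the destination feature size already
# matches num_heads * out_feats.
in_feats_dst, out_feats, num_heads, residual = 128, 16, 8, True

# The residual projection is only created when a dimension mismatch requires it.
use_residual = residual and in_feats_dst != out_feats * num_heads
print(use_residual)  # False: 128 == 16 * 8, so no lin_res layer is created
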
261 changes: 261 additions & 0 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py
@@ -0,0 +1,261 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph
from cugraph.utilities.utils import import_optional

dgl = import_optional("dgl")
torch = import_optional("torch")
nn = import_optional("torch.nn")
ops_torch = import_optional("pylibcugraphops.pytorch")


class GATv2Conv(BaseConv):
r"""GATv2 from `How Attentive are Graph Attention Networks?
<https://arxiv.org/pdf/2105.14491.pdf>`__, with the sparse aggregation
    accelerated by cugraph-ops.

    Parameters
----------
    in_feats : int, or pair of ints
        Input feature size; i.e., the number of dimensions of :math:`h_i^{(l)}`.
        If the layer is applied to a unidirectional bipartite graph, a pair of
        ints specifies the input feature sizes of the source and destination
        nodes; if a single scalar is given, the source and destination node
        feature sizes take the same value.
out_feats : int
        Output feature size; i.e., the number of dimensions of :math:`h_i^{(l+1)}`.
num_heads : int
Number of heads in Multi-Head Attention.
    feat_drop : float, optional
        Dropout rate on input features. Default: ``0``.
concat : bool, optional
        If False, the multi-head attention outputs are averaged instead of concatenated.
Default: ``True``.
edge_feats : int, optional
Edge feature size. Default: ``None``.
    negative_slope : float, optional
        LeakyReLU angle of negative slope. Default: ``0.2``.
    residual : bool, optional
        If True, use residual connection. Default: ``False``.
    allow_zero_in_degree : bool, optional
        If there are 0-in-degree nodes in the graph, output for those nodes will
        be invalid since no message will be passed to them. This is harmful for
        some applications, causing silent performance regression. This module
        will raise a DGLError if it detects 0-in-degree nodes in the input
        graph. Setting this to ``True`` suppresses the check and lets the users
        handle it by themselves. Default: ``False``.
bias : bool, optional
If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
    share_weights : bool, optional
        If set to :obj:`True`, the same weight matrix is used for
        :math:`W_{left}` and :math:`W_{right}`, i.e., the source and the
        destination node of every edge share the same projection.
        (default: :obj:`False`)
"""

def __init__(
self,
in_feats: Union[int, Tuple[int, int]],
out_feats: int,
num_heads: int,
feat_drop: float = 0.0,
concat: bool = True,
edge_feats: Optional[int] = None,
negative_slope: float = 0.2,
residual: bool = False,
allow_zero_in_degree: bool = False,
bias: bool = True,
share_weights: bool = False,
):
super().__init__()
self.in_feats = in_feats
self.out_feats = out_feats
self.in_feats_src, self.in_feats_dst = dgl.utils.expand_as_pair(in_feats)
self.num_heads = num_heads
self.feat_drop = nn.Dropout(feat_drop)
self.concat = concat
self.edge_feats = edge_feats
self.negative_slope = negative_slope
self.allow_zero_in_degree = allow_zero_in_degree
self.share_weights = share_weights

self.lin_src = nn.Linear(self.in_feats_src, num_heads * out_feats, bias=bias)
if isinstance(in_feats, (list, tuple)):
self.lin_dst = nn.Linear(
self.in_feats_dst, num_heads * out_feats, bias=bias
)
else:
self.lin_dst = self.lin_src

self.attn = nn.Parameter(torch.Tensor(num_heads * out_feats))

if edge_feats is not None:
self.lin_edge = nn.Linear(edge_feats, num_heads * out_feats, bias=False)
else:
self.register_parameter("lin_edge", None)

if bias and concat:
self.bias = nn.Parameter(torch.Tensor(num_heads * out_feats))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_feats))
else:
self.register_parameter("bias", None)

self.residual = residual and self.in_feats_dst != out_feats * num_heads
if self.residual:
self.lin_res = nn.Linear(
self.in_feats_dst, num_heads * out_feats, bias=bias
)
else:
self.register_buffer("lin_res", None)

self.reset_parameters()

def reset_parameters(self):
r"""Reinitialize learnable parameters."""
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.lin_src.weight, gain=gain)
nn.init.xavier_normal_(self.lin_dst.weight, gain=gain)

nn.init.xavier_normal_(
self.attn.view(-1, self.num_heads, self.out_feats), gain=gain
)
if self.lin_edge is not None:
self.lin_edge.reset_parameters()

if self.lin_res is not None:
self.lin_res.reset_parameters()

if self.bias is not None:
nn.init.zeros_(self.bias)

def forward(
self,
g: Union[SparseGraph, dgl.DGLHeteroGraph],
nfeat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
efeat: Optional[torch.Tensor] = None,
max_in_degree: Optional[int] = None,
) -> torch.Tensor:
r"""Forward computation.
Parameters
----------
graph : DGLGraph or SparseGraph
The graph.
nfeat : torch.Tensor
Input features of shape :math:`(N, D_{in})`.
        efeat : torch.Tensor, optional
Optional edge features.
        max_in_degree : int, optional
            Maximum in-degree of destination nodes. When :attr:`g` is generated
            from a neighbor sampler, the value should be set to the corresponding
            :attr:`fanout`. This option is used to invoke the MFG-variant of the
            cugraph-ops kernel.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, H, D_{out})`, where :math:`H`
            is the number of heads and :math:`D_{out}` is the size of the output
            feature. If ``concat`` is False, the heads are averaged and the
            shape is :math:`(N, D_{out})`.
        """
if max_in_degree is None:
max_in_degree = -1

bipartite = isinstance(nfeat, (list, tuple))

if isinstance(g, SparseGraph):
assert "csc" in g.formats()
offsets, indices = g.csc()
_graph = ops_torch.CSC(
offsets=offsets,
indices=indices,
num_src_nodes=g.num_src_nodes(),
dst_max_in_degree=max_in_degree,
is_bipartite=bipartite,
)
elif isinstance(g, dgl.DGLHeteroGraph):
if not self.allow_zero_in_degree:
if (g.in_degrees() == 0).any():
raise dgl.base.DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
offsets, indices, _ = g.adj_tensors("csc")
_graph = ops_torch.CSC(
offsets=offsets,
indices=indices,
num_src_nodes=g.num_src_nodes(),
dst_max_in_degree=max_in_degree,
is_bipartite=bipartite,
)
else:
raise TypeError(
f"The graph has to be either a 'SparseGraph' or "
f"'dgl.DGLHeteroGraph', but got '{type(g)}'."
)

if bipartite:
nfeat = (self.feat_drop(nfeat[0]), self.feat_drop(nfeat[1]))
nfeat_dst_orig = nfeat[1]
else:
nfeat = self.feat_drop(nfeat)
nfeat_dst_orig = nfeat[: g.num_dst_nodes()]

if efeat is not None:
if self.lin_edge is None:
raise RuntimeError(
f"{self.__class__.__name__}.edge_feats must be set to "
f"accept edge features."
)
efeat = self.lin_edge(efeat)

if bipartite:
nfeat = (self.lin_src(nfeat[0]), self.lin_dst(nfeat[1]))
else:
nfeat = self.lin_src(nfeat)

out = ops_torch.operators.mha_gat_v2_n2n(
nfeat,
self.attn,
_graph,
num_heads=self.num_heads,
activation="LeakyReLU",
negative_slope=self.negative_slope,
concat_heads=self.concat,
edge_feat=efeat,
)[: g.num_dst_nodes()]

if self.concat:
out = out.view(-1, self.num_heads, self.out_feats)

if self.residual:
res = self.lin_res(nfeat_dst_orig).view(-1, self.num_heads, self.out_feats)
if not self.concat:
res = res.mean(dim=1)
out = out + res

if self.bias is not None:
out = out + self.bias

return out
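
Below is a minimal usage sketch for the new layer; it is not part of the commit. It assumes a CUDA build of DGL with pylibcugraphops available, and the toy graph, feature sizes, and the `from cugraph_dgl.nn.conv import GATv2Conv` import path (taken from the `__init__.py` change above) are illustrative assumptions.

# Illustrative sketch only: the graph and sizes are assumptions, not from the diff.
import dgl
import torch

from cugraph_dgl.nn.conv import GATv2Conv

# Small graph with no 0-in-degree nodes, placed on the GPU for cugraph-ops.
g = dgl.graph(([0, 1, 2, 0], [1, 2, 0, 2])).to("cuda")
nfeat = torch.randn(g.num_nodes(), 16, device="cuda")

conv = GATv2Conv(in_feats=16, out_feats=8, num_heads=4).to("cuda")
out = conv(g, nfeat)  # shape (3, 4, 8) with the default concat=True

# When the graph comes from a neighbor sampler, passing the fanout as
# max_in_degree invokes the MFG-variant of the cugraph-ops kernel, e.g.:
# out = conv(block, nfeat, max_in_degree=10)
print(out.shape)
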
11 changes: 5 additions & 6 deletions python/cugraph-dgl/tests/nn/test_gatconv.py
@@ -10,7 +10,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# pylint: disable=too-many-arguments, too-many-locals
 
 import pytest
 
@@ -88,7 +87,7 @@ def test_gatconv_equality(
     else:
         out2 = conv2(g, nfeat, max_in_degree=max_in_degree)
 
-    assert torch.allclose(out1, out2, atol=1e-6)
+    assert torch.allclose(out1, out2, atol=ATOL)
 
     grad_out1 = torch.rand_like(out1)
     grad_out2 = grad_out1.clone().detach()
@@ -97,18 +96,18 @@
 
     if bipartite:
         assert torch.allclose(
-            conv1.fc_src.weight.grad, conv2.lin_src.weight.grad, atol=1e-6
+            conv1.fc_src.weight.grad, conv2.lin_src.weight.grad, atol=ATOL
         )
         assert torch.allclose(
-            conv1.fc_dst.weight.grad, conv2.lin_dst.weight.grad, atol=1e-6
+            conv1.fc_dst.weight.grad, conv2.lin_dst.weight.grad, atol=ATOL
         )
     else:
-        assert torch.allclose(conv1.fc.weight.grad, conv2.lin.weight.grad, atol=1e-6)
+        assert torch.allclose(conv1.fc.weight.grad, conv2.lin.weight.grad, atol=ATOL)
 
     assert torch.allclose(
         torch.cat((conv1.attn_l.grad, conv1.attn_r.grad), dim=0),
         conv2.attn_weights.grad.view(2, num_heads, out_feats),
-        atol=1e-6,
+        atol=ATOL,
     )

