Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Match weight-sharing option of GATConv in DGL #4074

Merged
merged 10 commits into from
Jan 23, 2024
70 changes: 41 additions & 29 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union
from typing import Optional, Union

from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph
from cugraph.utilities.utils import import_optional
Expand All @@ -29,7 +29,7 @@ class GATConv(BaseConv):

Parameters
----------
in_feats : int or tuple
in_feats : int or (int, int)
Input feature size. A pair denotes feature sizes of source and
destination nodes.
out_feats : int
Expand Down Expand Up @@ -92,7 +92,7 @@ class GATConv(BaseConv):

def __init__(
self,
in_feats: Union[int, Tuple[int, int]],
in_feats: Union[int, tuple[int, int]],
out_feats: int,
num_heads: int,
feat_drop: float = 0.0,
Expand All @@ -104,14 +104,19 @@ def __init__(
bias: bool = True,
):
super().__init__()

if isinstance(in_feats, int):
self.in_feats_src = self.in_feats_dst = in_feats
else:
self.in_feats_src, self.in_feats_dst = in_feats
self.in_feats = in_feats
self.out_feats = out_feats
self.in_feats_src, self.in_feats_dst = dgl.utils.expand_as_pair(in_feats)
self.num_heads = num_heads
self.feat_drop = nn.Dropout(feat_drop)
self.concat = concat
self.edge_feats = edge_feats
self.negative_slope = negative_slope
self.residual = residual
self.allow_zero_in_degree = allow_zero_in_degree

if isinstance(in_feats, int):
Expand All @@ -126,28 +131,34 @@ def __init__(

if edge_feats is not None:
self.lin_edge = nn.Linear(edge_feats, num_heads * out_feats, bias=False)
self.attn_weights = nn.Parameter(torch.Tensor(3 * num_heads * out_feats))
self.attn_weights = nn.Parameter(torch.empty(3 * num_heads * out_feats))
else:
self.register_parameter("lin_edge", None)
self.attn_weights = nn.Parameter(torch.Tensor(2 * num_heads * out_feats))
self.attn_weights = nn.Parameter(torch.empty(2 * num_heads * out_feats))

if bias and concat:
self.bias = nn.Parameter(torch.Tensor(num_heads, out_feats))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_feats))
out_dim = num_heads * out_feats if concat else out_feats
if residual:
if self.in_feats_dst != out_dim:
self.lin_res = nn.Linear(self.in_feats_dst, out_dim, bias=bias)
else:
self.lin_res = nn.Identity()
else:
self.register_buffer("bias", None)
self.register_buffer("lin_res", None)

self.residual = residual and self.in_feats_dst != out_feats * num_heads
if self.residual:
self.lin_res = nn.Linear(
self.in_feats_dst, num_heads * out_feats, bias=bias
)
if bias and not isinstance(self.lin_res, nn.Linear):
if concat:
self.bias = nn.Parameter(torch.empty(num_heads, out_feats))
else:
self.bias = nn.Parameter(torch.empty(out_feats))
else:
self.register_buffer("lin_res", None)
self.register_buffer("bias", None)

self.reset_parameters()

def set_allow_zero_in_degree(self, set_value):
r"""Set allow_zero_in_degree flag."""
self.allow_zero_in_degree = set_value

def reset_parameters(self):
r"""Reinitialize learnable parameters."""
gain = nn.init.calculate_gain("relu")
Expand All @@ -172,7 +183,7 @@ def reset_parameters(self):
def forward(
self,
g: Union[SparseGraph, dgl.DGLHeteroGraph],
nfeat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
nfeat: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
efeat: Optional[torch.Tensor] = None,
max_in_degree: Optional[int] = None,
) -> torch.Tensor:
Expand All @@ -182,8 +193,10 @@ def forward(
----------
graph : DGLGraph or SparseGraph
The graph.
nfeat : torch.Tensor
Input features of shape :math:`(N, D_{in})`.
nfeat : torch.Tensor or (torch.Tensor, torch.Tensor)
Node features. If given as a tuple, the two elements correspond to
the source and destination node features, respectively, in a
bipartite graph.
efeat: torch.Tensor, optional
Optional edge features.
max_in_degree : int
Expand Down Expand Up @@ -237,18 +250,17 @@ def forward(

if bipartite:
if not hasattr(self, "lin_src"):
raise RuntimeError(
f"{self.__class__.__name__}.in_feats must be a pair of "
f"integers to allow bipartite node features, but got "
f"{self.in_feats}."
)
nfeat_src = self.lin_src(nfeat[0])
nfeat_dst = self.lin_dst(nfeat[1])
nfeat_src = self.lin(nfeat[0])
nfeat_dst = self.lin(nfeat[1])
else:
nfeat_src = self.lin_src(nfeat[0])
nfeat_dst = self.lin_dst(nfeat[1])
else:
if not hasattr(self, "lin"):
raise RuntimeError(
f"{self.__class__.__name__}.in_feats is expected to be an "
f"integer, but got {self.in_feats}."
f"integer when the graph is not bipartite, "
f"but got {self.in_feats}."
)
nfeat = self.lin(nfeat)

Expand Down
69 changes: 31 additions & 38 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union
from typing import Optional, Union

from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph
from cugraph.utilities.utils import import_optional
Expand All @@ -29,14 +29,11 @@ class GATv2Conv(BaseConv):

Parameters
----------
in_feats : int, or pair of ints
Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.
If the layer is to be applied to a unidirectional bipartite graph, `in_feats`
specifies the input feature size on both the source and destination nodes.
If a scalar is given, the source and destination node feature size
would take the same value.
in_feats : int or (int, int)
Input feature size. A pair denotes feature sizes of source and
destination nodes.
out_feats : int
Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.
Output feature size.
num_heads : int
Number of heads in Multi-Head Attention.
feat_drop : float, optional
Expand All @@ -58,17 +55,15 @@ class GATv2Conv(BaseConv):
input graph. By setting ``True``, it will suppress the check and let the
users handle it by themselves. Defaults: ``False``.
bias : bool, optional
If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
If True, learns a bias term. Defaults: ``True``.
share_weights : bool, optional
If set to :obj:`True`, the same matrix for :math:`W_{left}` and
:math:`W_{right}` in the above equations, will be applied to the source
and the target node of every edge. (default: :obj:`False`)
If ``True``, the same matrix will be applied to the source and the
destination node features. Defaults: ``False``.
"""

def __init__(
self,
in_feats: Union[int, Tuple[int, int]],
in_feats: Union[int, tuple[int, int]],
out_feats: int,
num_heads: int,
feat_drop: float = 0.0,
Expand All @@ -81,16 +76,22 @@ def __init__(
share_weights: bool = False,
):
super().__init__()

if isinstance(in_feats, int):
self.in_feats_src = self.in_feats_dst = in_feats
else:
self.in_feats_src, self.in_feats_dst = in_feats
self.in_feats = in_feats
self.out_feats = out_feats
self.in_feats_src, self.in_feats_dst = dgl.utils.expand_as_pair(in_feats)
self.num_heads = num_heads
self.feat_drop = nn.Dropout(feat_drop)
self.concat = concat
self.edge_feats = edge_feats
self.negative_slope = negative_slope
self.residual = residual
self.allow_zero_in_degree = allow_zero_in_degree
self.share_weights = share_weights
self.bias = bias

self.lin_src = nn.Linear(self.in_feats_src, num_heads * out_feats, bias=bias)
if share_weights:
Expand All @@ -106,52 +107,47 @@ def __init__(
self.in_feats_dst, num_heads * out_feats, bias=bias
)

self.attn = nn.Parameter(torch.Tensor(num_heads * out_feats))
self.attn_weights = nn.Parameter(torch.empty(num_heads * out_feats))

if edge_feats is not None:
self.lin_edge = nn.Linear(edge_feats, num_heads * out_feats, bias=False)
else:
self.register_parameter("lin_edge", None)

if bias and concat:
self.bias = nn.Parameter(torch.Tensor(num_heads, out_feats))
elif bias and not concat:
self.bias = nn.Parameter(torch.Tensor(out_feats))
else:
self.register_buffer("bias", None)

self.residual = residual and self.in_feats_dst != out_feats * num_heads
if self.residual:
self.lin_res = nn.Linear(
self.in_feats_dst, num_heads * out_feats, bias=bias
)
out_dim = num_heads * out_feats if concat else out_feats
if residual:
if self.in_feats_dst != out_dim:
self.lin_res = nn.Linear(self.in_feats_dst, out_dim, bias=bias)
else:
self.lin_res = nn.Identity()
else:
self.register_buffer("lin_res", None)

self.reset_parameters()

def set_allow_zero_in_degree(self, set_value):
r"""Set allow_zero_in_degree flag."""
self.allow_zero_in_degree = set_value

def reset_parameters(self):
r"""Reinitialize learnable parameters."""
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.lin_src.weight, gain=gain)
nn.init.xavier_normal_(self.lin_dst.weight, gain=gain)

nn.init.xavier_normal_(
self.attn.view(-1, self.num_heads, self.out_feats), gain=gain
self.attn_weights.view(-1, self.num_heads, self.out_feats), gain=gain
)
if self.lin_edge is not None:
self.lin_edge.reset_parameters()

if self.lin_res is not None:
self.lin_res.reset_parameters()

if self.bias is not None:
nn.init.zeros_(self.bias)

def forward(
self,
g: Union[SparseGraph, dgl.DGLHeteroGraph],
nfeat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
nfeat: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
efeat: Optional[torch.Tensor] = None,
max_in_degree: Optional[int] = None,
) -> torch.Tensor:
Expand Down Expand Up @@ -225,7 +221,7 @@ def forward(

out = ops_torch.operators.mha_gat_v2_n2n(
nfeat,
self.attn,
self.attn_weights,
_graph,
num_heads=self.num_heads,
activation="LeakyReLU",
Expand All @@ -243,7 +239,4 @@ def forward(
res = res.mean(dim=1)
out = out + res

if self.bias is not None:
out = out + self.bias

return out
10 changes: 5 additions & 5 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/relgraphconv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand Down Expand Up @@ -100,16 +100,16 @@ def __init__(
self.self_loop = self_loop
if regularizer is None:
self.W = nn.Parameter(
torch.Tensor(num_rels + dim_self_loop, in_feats, out_feats)
torch.empty(num_rels + dim_self_loop, in_feats, out_feats)
)
self.coeff = None
elif regularizer == "basis":
if num_bases is None:
raise ValueError('Missing "num_bases" for basis regularization.')
self.W = nn.Parameter(
torch.Tensor(num_bases + dim_self_loop, in_feats, out_feats)
torch.empty(num_bases + dim_self_loop, in_feats, out_feats)
)
self.coeff = nn.Parameter(torch.Tensor(num_rels, num_bases))
self.coeff = nn.Parameter(torch.empty(num_rels, num_bases))
self.num_bases = num_bases
else:
raise ValueError(
Expand All @@ -119,7 +119,7 @@ def __init__(
self.regularizer = regularizer

if bias:
self.bias = nn.Parameter(torch.Tensor(out_feats))
self.bias = nn.Parameter(torch.empty(out_feats))
else:
self.register_parameter("bias", None)

Expand Down
8 changes: 4 additions & 4 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/sageconv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
Expand All @@ -11,7 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union
from typing import Optional, Union

from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph
from cugraph.utilities.utils import import_optional
Expand Down Expand Up @@ -65,7 +65,7 @@ class SAGEConv(BaseConv):

def __init__(
self,
in_feats: Union[int, Tuple[int, int]],
in_feats: Union[int, tuple[int, int]],
out_feats: int,
aggregator_type: str = "mean",
feat_drop: float = 0.0,
Expand Down Expand Up @@ -111,7 +111,7 @@ def reset_parameters(self):
def forward(
self,
g: Union[SparseGraph, dgl.DGLHeteroGraph],
feat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
feat: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
max_in_degree: Optional[int] = None,
) -> torch.Tensor:
r"""Forward computation.
Expand Down
Loading