
Commit 4a51295
Merge branch 'branch-24.02' into complement
eriknw committed Jan 25, 2024
2 parents d56b19f + 82552ab commit 4a51295
Showing 31 changed files with 352 additions and 291 deletions.
70 changes: 41 additions & 29 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Optional, Union
 
 from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph
 from cugraph.utilities.utils import import_optional
@@ -29,7 +29,7 @@ class GATConv(BaseConv):
     Parameters
     ----------
-    in_feats : int or tuple
+    in_feats : int or (int, int)
         Input feature size. A pair denotes feature sizes of source and
         destination nodes.
     out_feats : int
@@ -92,7 +92,7 @@ class GATConv(BaseConv):
 
     def __init__(
         self,
-        in_feats: Union[int, Tuple[int, int]],
+        in_feats: Union[int, tuple[int, int]],
         out_feats: int,
         num_heads: int,
         feat_drop: float = 0.0,
@@ -104,14 +104,19 @@ def __init__(
         bias: bool = True,
     ):
         super().__init__()
+
+        if isinstance(in_feats, int):
+            self.in_feats_src = self.in_feats_dst = in_feats
+        else:
+            self.in_feats_src, self.in_feats_dst = in_feats
         self.in_feats = in_feats
         self.out_feats = out_feats
-        self.in_feats_src, self.in_feats_dst = dgl.utils.expand_as_pair(in_feats)
         self.num_heads = num_heads
         self.feat_drop = nn.Dropout(feat_drop)
         self.concat = concat
         self.edge_feats = edge_feats
         self.negative_slope = negative_slope
+        self.residual = residual
         self.allow_zero_in_degree = allow_zero_in_degree
 
         if isinstance(in_feats, int):
@@ -126,28 +131,34 @@ def __init__(
 
         if edge_feats is not None:
             self.lin_edge = nn.Linear(edge_feats, num_heads * out_feats, bias=False)
-            self.attn_weights = nn.Parameter(torch.Tensor(3 * num_heads * out_feats))
+            self.attn_weights = nn.Parameter(torch.empty(3 * num_heads * out_feats))
         else:
             self.register_parameter("lin_edge", None)
-            self.attn_weights = nn.Parameter(torch.Tensor(2 * num_heads * out_feats))
+            self.attn_weights = nn.Parameter(torch.empty(2 * num_heads * out_feats))
 
-        if bias and concat:
-            self.bias = nn.Parameter(torch.Tensor(num_heads, out_feats))
-        elif bias and not concat:
-            self.bias = nn.Parameter(torch.Tensor(out_feats))
+        out_dim = num_heads * out_feats if concat else out_feats
+        if residual:
+            if self.in_feats_dst != out_dim:
+                self.lin_res = nn.Linear(self.in_feats_dst, out_dim, bias=bias)
+            else:
+                self.lin_res = nn.Identity()
         else:
-            self.register_buffer("bias", None)
+            self.register_buffer("lin_res", None)
 
-        self.residual = residual and self.in_feats_dst != out_feats * num_heads
-        if self.residual:
-            self.lin_res = nn.Linear(
-                self.in_feats_dst, num_heads * out_feats, bias=bias
-            )
+        if bias and not isinstance(self.lin_res, nn.Linear):
+            if concat:
+                self.bias = nn.Parameter(torch.empty(num_heads, out_feats))
+            else:
+                self.bias = nn.Parameter(torch.empty(out_feats))
         else:
-            self.register_buffer("lin_res", None)
+            self.register_buffer("bias", None)
 
         self.reset_parameters()
 
+    def set_allow_zero_in_degree(self, set_value):
+        r"""Set allow_zero_in_degree flag."""
+        self.allow_zero_in_degree = set_value
+
     def reset_parameters(self):
         r"""Reinitialize learnable parameters."""
         gain = nn.init.calculate_gain("relu")
@@ -172,7 +183,7 @@ def reset_parameters(self):
     def forward(
         self,
         g: Union[SparseGraph, dgl.DGLHeteroGraph],
-        nfeat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        nfeat: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
         efeat: Optional[torch.Tensor] = None,
         max_in_degree: Optional[int] = None,
     ) -> torch.Tensor:
@@ -182,8 +193,10 @@ def forward(
         ----------
         graph : DGLGraph or SparseGraph
             The graph.
-        nfeat : torch.Tensor
-            Input features of shape :math:`(N, D_{in})`.
+        nfeat : torch.Tensor or (torch.Tensor, torch.Tensor)
+            Node features. If given as a tuple, the two elements correspond to
+            the source and destination node features, respectively, in a
+            bipartite graph.
         efeat: torch.Tensor, optional
             Optional edge features.
         max_in_degree : int
@@ -237,18 +250,17 @@ def forward(
 
         if bipartite:
             if not hasattr(self, "lin_src"):
-                raise RuntimeError(
-                    f"{self.__class__.__name__}.in_feats must be a pair of "
-                    f"integers to allow bipartite node features, but got "
-                    f"{self.in_feats}."
-                )
-            nfeat_src = self.lin_src(nfeat[0])
-            nfeat_dst = self.lin_dst(nfeat[1])
+                nfeat_src = self.lin(nfeat[0])
+                nfeat_dst = self.lin(nfeat[1])
+            else:
+                nfeat_src = self.lin_src(nfeat[0])
+                nfeat_dst = self.lin_dst(nfeat[1])
         else:
             if not hasattr(self, "lin"):
                 raise RuntimeError(
                     f"{self.__class__.__name__}.in_feats is expected to be an "
-                    f"integer, but got {self.in_feats}."
+                    f"integer when the graph is not bipartite, "
+                    f"but got {self.in_feats}."
                 )
             nfeat = self.lin(nfeat)

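The gatconv.py refactor above changes two behaviors: the residual projection and bias are now sized by concat (out_dim is num_heads * out_feats when head outputs are concatenated, out_feats when they are averaged), and a bipartite feature pair is now accepted even when in_feats is a single int (the shared lin is applied to both sides instead of raising). Below is a minimal, self-contained sketch of the new residual/bias wiring; the sizes and variable names are illustrative assumptions, not values from this diff:

import torch
import torch.nn as nn

# Illustrative hyperparameters (assumed for this sketch).
in_feats_dst, out_feats, num_heads = 64, 32, 4
concat, residual, bias = False, True, True

# Output width depends on whether head outputs are concatenated or averaged.
out_dim = num_heads * out_feats if concat else out_feats  # 32 here

# Residual branch: project only when the sizes differ, else pass through.
if residual:
    if in_feats_dst != out_dim:
        lin_res = nn.Linear(in_feats_dst, out_dim, bias=bias)
    else:
        lin_res = nn.Identity()
else:
    lin_res = None

# A separate bias parameter is created only when lin_res is not an
# nn.Linear (an nn.Linear residual already carries its own bias).
if bias and not isinstance(lin_res, nn.Linear):
    bias_param = nn.Parameter(
        torch.empty(num_heads, out_feats) if concat else torch.empty(out_feats)
    )
else:
    bias_param = None

print(type(lin_res).__name__, bias_param is None)  # prints: Linear True

With these sizes the residual branch gets its own nn.Linear, which carries the bias, so no separate bias parameter is created.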
69 changes: 31 additions & 38 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Optional, Union
 
 from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph
 from cugraph.utilities.utils import import_optional
@@ -29,14 +29,11 @@ class GATv2Conv(BaseConv):
     Parameters
     ----------
-    in_feats : int, or pair of ints
-        Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.
-        If the layer is to be applied to a unidirectional bipartite graph, `in_feats`
-        specifies the input feature size on both the source and destination nodes.
-        If a scalar is given, the source and destination node feature size
-        would take the same value.
+    in_feats : int or (int, int)
+        Input feature size. A pair denotes feature sizes of source and
+        destination nodes.
     out_feats : int
-        Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.
+        Output feature size.
     num_heads : int
         Number of heads in Multi-Head Attention.
     feat_drop : float, optional
@@ -58,17 +55,15 @@ class GATv2Conv(BaseConv):
         input graph. By setting ``True``, it will suppress the check and let the
         users handle it by themselves. Defaults: ``False``.
     bias : bool, optional
-        If set to :obj:`False`, the layer will not learn
-        an additive bias. (default: :obj:`True`)
+        If True, learns a bias term. Defaults: ``True``.
     share_weights : bool, optional
-        If set to :obj:`True`, the same matrix for :math:`W_{left}` and
-        :math:`W_{right}` in the above equations, will be applied to the source
-        and the target node of every edge. (default: :obj:`False`)
+        If ``True``, the same matrix will be applied to the source and the
+        destination node features. Defaults: ``False``.
     """
 
     def __init__(
         self,
-        in_feats: Union[int, Tuple[int, int]],
+        in_feats: Union[int, tuple[int, int]],
         out_feats: int,
         num_heads: int,
         feat_drop: float = 0.0,
@@ -81,16 +76,22 @@ def __init__(
         share_weights: bool = False,
     ):
         super().__init__()
+
+        if isinstance(in_feats, int):
+            self.in_feats_src = self.in_feats_dst = in_feats
+        else:
+            self.in_feats_src, self.in_feats_dst = in_feats
         self.in_feats = in_feats
         self.out_feats = out_feats
-        self.in_feats_src, self.in_feats_dst = dgl.utils.expand_as_pair(in_feats)
         self.num_heads = num_heads
         self.feat_drop = nn.Dropout(feat_drop)
         self.concat = concat
         self.edge_feats = edge_feats
         self.negative_slope = negative_slope
+        self.residual = residual
         self.allow_zero_in_degree = allow_zero_in_degree
         self.share_weights = share_weights
-        self.bias = bias
 
         self.lin_src = nn.Linear(self.in_feats_src, num_heads * out_feats, bias=bias)
         if share_weights:
@@ -106,52 +107,47 @@ def __init__(
                 self.in_feats_dst, num_heads * out_feats, bias=bias
             )
 
-        self.attn = nn.Parameter(torch.Tensor(num_heads * out_feats))
+        self.attn_weights = nn.Parameter(torch.empty(num_heads * out_feats))
 
         if edge_feats is not None:
             self.lin_edge = nn.Linear(edge_feats, num_heads * out_feats, bias=False)
         else:
             self.register_parameter("lin_edge", None)
 
-        if bias and concat:
-            self.bias = nn.Parameter(torch.Tensor(num_heads, out_feats))
-        elif bias and not concat:
-            self.bias = nn.Parameter(torch.Tensor(out_feats))
-        else:
-            self.register_buffer("bias", None)
-
-        self.residual = residual and self.in_feats_dst != out_feats * num_heads
-        if self.residual:
-            self.lin_res = nn.Linear(
-                self.in_feats_dst, num_heads * out_feats, bias=bias
-            )
+        out_dim = num_heads * out_feats if concat else out_feats
+        if residual:
+            if self.in_feats_dst != out_dim:
+                self.lin_res = nn.Linear(self.in_feats_dst, out_dim, bias=bias)
+            else:
+                self.lin_res = nn.Identity()
         else:
             self.register_buffer("lin_res", None)
 
         self.reset_parameters()
 
+    def set_allow_zero_in_degree(self, set_value):
+        r"""Set allow_zero_in_degree flag."""
+        self.allow_zero_in_degree = set_value
+
     def reset_parameters(self):
         r"""Reinitialize learnable parameters."""
         gain = nn.init.calculate_gain("relu")
         nn.init.xavier_normal_(self.lin_src.weight, gain=gain)
         nn.init.xavier_normal_(self.lin_dst.weight, gain=gain)
 
         nn.init.xavier_normal_(
-            self.attn.view(-1, self.num_heads, self.out_feats), gain=gain
+            self.attn_weights.view(-1, self.num_heads, self.out_feats), gain=gain
         )
         if self.lin_edge is not None:
             self.lin_edge.reset_parameters()
 
         if self.lin_res is not None:
             self.lin_res.reset_parameters()
 
-        if self.bias is not None:
-            nn.init.zeros_(self.bias)
-
     def forward(
         self,
         g: Union[SparseGraph, dgl.DGLHeteroGraph],
-        nfeat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        nfeat: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
         efeat: Optional[torch.Tensor] = None,
         max_in_degree: Optional[int] = None,
     ) -> torch.Tensor:
@@ -225,7 +221,7 @@ def forward(
 
         out = ops_torch.operators.mha_gat_v2_n2n(
             nfeat,
-            self.attn,
+            self.attn_weights,
             _graph,
             num_heads=self.num_heads,
             activation="LeakyReLU",
@@ -243,7 +239,4 @@ def forward(
             res = res.mean(dim=1)
         out = out + res
 
-        if self.bias is not None:
-            out = out + self.bias
-
         return out
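Like gatconv.py, this file replaces dgl.utils.expand_as_pair with a plain isinstance check, which removes a DGL call from __init__. A short sketch of the equivalent normalization, with assumed example sizes:

# Sketch of the new in_feats normalization (values are illustrative).
in_feats = (64, 32)  # (source, destination) feature sizes; a single int also works

if isinstance(in_feats, int):
    in_feats_src = in_feats_dst = in_feats
else:
    in_feats_src, in_feats_dst = in_feats

assert (in_feats_src, in_feats_dst) == (64, 32)

The rename of attn to attn_weights, together with dropping the standalone bias add at the end of forward, aligns GATv2Conv with the GATConv layout above: the bias now lives in the linear projections (and, when needed, in lin_res).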
10 changes: 5 additions & 5 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/relgraphconv.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -100,16 +100,16 @@ def __init__(
         self.self_loop = self_loop
         if regularizer is None:
             self.W = nn.Parameter(
-                torch.Tensor(num_rels + dim_self_loop, in_feats, out_feats)
+                torch.empty(num_rels + dim_self_loop, in_feats, out_feats)
             )
             self.coeff = None
         elif regularizer == "basis":
             if num_bases is None:
                 raise ValueError('Missing "num_bases" for basis regularization.')
             self.W = nn.Parameter(
-                torch.Tensor(num_bases + dim_self_loop, in_feats, out_feats)
+                torch.empty(num_bases + dim_self_loop, in_feats, out_feats)
             )
-            self.coeff = nn.Parameter(torch.Tensor(num_rels, num_bases))
+            self.coeff = nn.Parameter(torch.empty(num_rels, num_bases))
             self.num_bases = num_bases
         else:
             raise ValueError(
@@ -119,7 +119,7 @@ def __init__(
         self.regularizer = regularizer
 
         if bias:
-            self.bias = nn.Parameter(torch.Tensor(out_feats))
+            self.bias = nn.Parameter(torch.empty(out_feats))
         else:
             self.register_parameter("bias", None)

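Every change in relgraphconv.py swaps the legacy torch.Tensor(...) constructor for torch.empty(...). Both allocate uninitialized storage of the given shape, but torch.empty is the documented, non-legacy API for that. A sketch with illustrative shapes (not taken from this diff):

import torch
import torch.nn as nn

# Assumed example shapes, mirroring the W/coeff parameters in the diff.
num_rels, num_bases, in_feats, out_feats = 4, 2, 8, 16

# torch.empty allocates uninitialized memory, just like the legacy
# torch.Tensor(...) constructor, but with an explicit, documented API.
W = nn.Parameter(torch.empty(num_bases + 1, in_feats, out_feats))
coeff = nn.Parameter(torch.empty(num_rels, num_bases))

# Parameters created this way must still be initialized before use,
# e.g. in a reset_parameters() method.
nn.init.xavier_uniform_(W)
nn.init.xavier_uniform_(coeff)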
8 changes: 4 additions & 4 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/sageconv.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
+# Copyright (c) 2023-2024, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Optional, Union
 
 from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph
 from cugraph.utilities.utils import import_optional
@@ -65,7 +65,7 @@ class SAGEConv(BaseConv):
 
     def __init__(
         self,
-        in_feats: Union[int, Tuple[int, int]],
+        in_feats: Union[int, tuple[int, int]],
         out_feats: int,
         aggregator_type: str = "mean",
         feat_drop: float = 0.0,
@@ -111,7 +111,7 @@ def reset_parameters(self):
     def forward(
         self,
         g: Union[SparseGraph, dgl.DGLHeteroGraph],
-        feat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        feat: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
         max_in_degree: Optional[int] = None,
     ) -> torch.Tensor:
         r"""Forward computation.
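The sageconv.py diff is purely a typing cleanup: under PEP 585 (Python 3.9+), the builtin tuple is subscriptable in annotations, so typing.Tuple can be dropped. A one-liner showing the equivalent annotation, with a hypothetical alias name:

from typing import Union

import torch

# PEP 585 builtin generics: tuple[...] replaces typing.Tuple[...].
# "FeatType" is a hypothetical alias, not a name from the diff.
FeatType = Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]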
