Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HeteroGATConv to cugraph-pyg #3914

Merged
merged 11 commits into from
Dec 11, 2023
2 changes: 2 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/nn/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,15 @@

from .gat_conv import GATConv
from .gatv2_conv import GATv2Conv
from .hetero_gat_conv import HeteroGATConv
from .rgcn_conv import RGCNConv
from .sage_conv import SAGEConv
from .transformer_conv import TransformerConv

__all__ = [
"GATConv",
"GATv2Conv",
"HeteroGATConv",
"RGCNConv",
"SAGEConv",
"TransformerConv",
Expand Down
243 changes: 243 additions & 0 deletions python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from collections import defaultdict

from cugraph.utilities.utils import import_optional
from pylibcugraphops.pytorch.operators import mha_gat_n2n

from .base import BaseConv

torch = import_optional("torch")
nn = import_optional("torch.nn")
torch_geometric = import_optional("torch_geometric")


class HeteroGATConv(BaseConv):
r"""The graph attentional operator on heterogeneous graphs, where a separate
`GATConv` is applied on the homogeneous graph for each edge type. Compared
with directly wrapping `GATConv`s with `HeteroConv`, `HeteroGATConv` fuses
all the linear transformation associated with each node type together into 1
GEMM call, to improve the performance on GPUs.

Parameters
----------
in_channels : int or Dict[str, int])
Size of each input sample of every node type.

out_channels : int
Size of each output sample.

node_types : List[str]
List of Node types.

edge_types : List[Tuple[str, str, str]]
List of Edge types.

heads : int, optional (default=1)
Number of multi-head-attentions.

concat : bool, optional (default=True):
If set to :obj:`False`, the multi-head attentions are averaged instead
of concatenated.

negative_slope : float, optional (default=0.2)
LeakyReLU angle of the negative slope.

bias : bool, optional (default=True)
If set to :obj:`False`, the layer will not learn an additive bias.

aggr : str, optional (default="sum")
The aggregation scheme to use for grouping node embeddings generated by
different relations. Choose from "sum", "mean", "min", "max".
"""

def __init__(
self,
in_channels: int | dict[str, int],
out_channels: int,
node_types: list[str],
edge_types: list[tuple[str, str, str]],
heads: int = 1,
concat: bool = True,
negative_slope: float = 0.2,
bias: bool = True,
aggr: str = "sum",
):
super().__init__()

if isinstance(in_channels, int):
in_channels = dict.fromkeys(node_types, in_channels)
self.in_channels = in_channels
self.out_channels = out_channels

self.node_types = node_types
self.edge_types = edge_types
self.edge_types_str = ["__".join(etype) for etype in self.edge_types]
self.num_heads = heads
self.concat_heads = concat

self.negative_slope = negative_slope
self.aggr = aggr

self.relations_per_ntype = defaultdict(lambda: ([], []))

lin_weights = dict.fromkeys(self.node_types)

attn_weights = dict.fromkeys(self.edge_types_str)

biases = dict.fromkeys(self.edge_types_str)

for edge_type in self.edge_types:
src_type, _, dst_type = edge_type
etype_str = "__".join(edge_type)
self.relations_per_ntype[src_type][0].append(etype_str)
if src_type != dst_type:
self.relations_per_ntype[dst_type][1].append(etype_str)

attn_weights[etype_str] = torch.empty(
2 * self.num_heads * self.out_channels
)

if bias and concat:
biases[etype_str] = torch.empty(self.num_heads * out_channels)
elif bias:
biases[etype_str] = torch.empty(out_channels)
else:
biases[etype_str] = None

for ntype in self.node_types:
n_src_rel = len(self.relations_per_ntype[ntype][0])
n_dst_rel = len(self.relations_per_ntype[ntype][1])
n_rel = n_src_rel + n_dst_rel

lin_weights[ntype] = torch.empty(
(n_rel * self.num_heads * self.out_channels, self.in_channels[ntype])
)

self.lin_weights = nn.ParameterDict(lin_weights)
self.attn_weights = nn.ParameterDict(attn_weights)

if bias:
self.bias = nn.ParameterDict(biases)
else:
self.register_parameter("bias", None)

self.reset_parameters()

def split_tensors(
self, x_dict: torch.Tensor, dim: int
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
"""Split fused tensors into chunks based on edge types."""
x_src_dict = dict.fromkeys(self.edge_types_str)
x_dst_dict = dict.fromkeys(self.edge_types_str)

for ntype, t in x_dict.items():
n_src_rel = len(self.relations_per_ntype[ntype][0])
n_dst_rel = len(self.relations_per_ntype[ntype][1])
n_rel = n_src_rel + n_dst_rel
t_list = torch.chunk(t, chunks=n_rel, dim=dim)

for i, src_rel in enumerate(self.relations_per_ntype[ntype][0]):
x_src_dict[src_rel] = t_list[i]

for i, dst_rel in enumerate(self.relations_per_ntype[ntype][1]):
src_type, _, dst_type = dst_rel.split("__")
if src_type != dst_type:
x_dst_dict[dst_rel] = t_list[i + n_src_rel]

return x_src_dict, x_dst_dict

def reset_parameters(self, seed: Optional[int] = None) -> None:
if seed is not None:
torch.manual_seed(seed)

w_src, w_dst = self.split_tensors(self.lin_weights, dim=0)

for i, edge_type in enumerate(self.edge_types):
src_type, _, dst_type = edge_type
etype_str = "__".join(edge_type)
# lin_src
torch_geometric.nn.inits.glorot(w_src[etype_str])

# lin_dst
if src_type != dst_type:
torch_geometric.nn.inits.glorot(w_dst[etype_str])

# attn_weights
torch_geometric.nn.inits.glorot(
self.attn_weights[etype_str].view(-1, self.num_heads, self.out_channels)
)

# bias
if self.bias is not None:
torch_geometric.nn.inits.zeros(self.bias[etype_str])

def forward(self, x_dict: dict, edge_index_dict: dict) -> dict[str, torch.Tensor]:
feat_dict = dict.fromkeys(x_dict.keys())

for ntype, x in x_dict.items():
feat_dict[ntype] = x @ self.lin_weights[ntype].T

x_src_dict, x_dst_dict = self.split_tensors(feat_dict, dim=1)

out_dict = defaultdict(list)

for edge_type, edge_index in edge_index_dict.items():
src_type, _, dst_type = edge_type
etype_str = "__".join(edge_type)

csc = BaseConv.to_csc(
edge_index, (x_dict[src_type].size(0), x_dict[dst_type].size(0))
)

if src_type == dst_type:
graph = self.get_cugraph(
csc,
bipartite=False,
)
out = mha_gat_n2n(
x_src_dict[etype_str],
self.attn_weights[etype_str],
graph,
num_heads=self.num_heads,
activation="LeakyReLU",
negative_slope=self.negative_slope,
concat_heads=self.concat_heads,
)

else:
graph = self.get_cugraph(
csc,
bipartite=True,
)
out = mha_gat_n2n(
(x_src_dict[etype_str], x_dst_dict[etype_str]),
self.attn_weights[etype_str],
graph,
num_heads=self.num_heads,
activation="LeakyReLU",
negative_slope=self.negative_slope,
concat_heads=self.concat_heads,
)

if self.bias is not None:
out = out + self.bias[etype_str]

out_dict[dst_type].append(out)

for key, value in out_dict.items():
out_dict[key] = torch_geometric.nn.conv.hetero_conv.group(value, self.aggr)

return out_dict
Loading