From 16550034e5f33868bd492fd7ce051e84d2f91240 Mon Sep 17 00:00:00 2001
From: Tingyu Wang
Date: Mon, 11 Dec 2023 13:00:59 -0500
Subject: [PATCH] Add `HeteroGATConv` to `cugraph-pyg` (#3914)

Fixes https://github.com/rapidsai/graph_dl/issues/311
Adding @stadlmax POC code to cugraph-pyg

Authors:
  - Tingyu Wang (https://github.com/tingyu66)
  - Brad Rees (https://github.com/BradReesWork)

Approvers:
  - Brad Rees (https://github.com/BradReesWork)
  - Alex Barghi (https://github.com/alexbarghi-nv)

URL: https://github.com/rapidsai/cugraph/pull/3914
---
 .../cugraph_pyg/nn/conv/__init__.py           |   2 +
 .../cugraph_pyg/nn/conv/hetero_gat_conv.py    | 265 ++++++++++++++++++
 .../cugraph-pyg/cugraph_pyg/tests/conftest.py |  29 ++
 .../tests/nn/test_hetero_gat_conv.py          | 132 +++++++++
 4 files changed, 428 insertions(+)
 create mode 100644 python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py
 create mode 100644 python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py
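A minimal usage sketch of the new layer (hypothetical node/edge types and
shapes, not taken from this PR; assumes a CUDA device and pylibcugraphops are
available, since `HeteroGATConv` dispatches to `mha_gat_n2n`):

    import torch
    from cugraph_pyg.nn import HeteroGATConv

    device = torch.device("cuda:0")

    # Hypothetical graph: 4 "user" nodes with 8 features, 5 "item" nodes
    # with 16 features, and a single ("user", "rates", "item") relation.
    x_dict = {
        "user": torch.randn(4, 8, device=device),
        "item": torch.randn(5, 16, device=device),
    }
    edge_index_dict = {
        ("user", "rates", "item"): torch.tensor(
            [[0, 1, 2, 3], [1, 2, 3, 4]], device=device
        ),
    }

    conv = HeteroGATConv(
        in_channels={"user": 8, "item": 16},
        out_channels=4,
        node_types=["user", "item"],
        edge_types=[("user", "rates", "item")],
        heads=2,
    ).to(device)

    out_dict = conv(x_dict, edge_index_dict)
    # With concat=True (the default), each destination node type receives
    # heads * out_channels features: out_dict["item"].shape == (5, 8).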
diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/__init__.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/__init__.py
index 9c9dcdb43bb..bef3a023b93 100644
--- a/python/cugraph-pyg/cugraph_pyg/nn/conv/__init__.py
+++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/__init__.py
@@ -13,6 +13,7 @@
 from .gat_conv import GATConv
 from .gatv2_conv import GATv2Conv
+from .hetero_gat_conv import HeteroGATConv
 from .rgcn_conv import RGCNConv
 from .sage_conv import SAGEConv
 from .transformer_conv import TransformerConv
@@ -20,6 +21,7 @@
 __all__ = [
     "GATConv",
     "GATv2Conv",
+    "HeteroGATConv",
     "RGCNConv",
     "SAGEConv",
     "TransformerConv",
diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py
new file mode 100644
index 00000000000..3b717552a96
--- /dev/null
+++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py
@@ -0,0 +1,265 @@
+# Copyright (c) 2023, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+from collections import defaultdict
+
+from cugraph.utilities.utils import import_optional
+from pylibcugraphops.pytorch.operators import mha_gat_n2n
+
+from .base import BaseConv
+
+torch = import_optional("torch")
+torch_geometric = import_optional("torch_geometric")
+
+
+class HeteroGATConv(BaseConv):
+    r"""The graph attentional operator on heterogeneous graphs, where a separate
+    `GATConv` is applied to the homogeneous subgraph of each edge type. Compared
+    with directly wrapping `GATConv`s with `HeteroConv`, `HeteroGATConv` fuses
+    all linear transformations associated with each node type into a single
+    GEMM call to improve performance on GPUs.
+
+    Parameters
+    ----------
+    in_channels : int or Dict[str, int]
+        Size of each input sample of every node type.
+
+    out_channels : int
+        Size of each output sample.
+
+    node_types : List[str]
+        List of node types.
+
+    edge_types : List[Tuple[str, str, str]]
+        List of edge types.
+
+    heads : int, optional (default=1)
+        Number of multi-head attentions.
+
+    concat : bool, optional (default=True)
+        If set to :obj:`False`, the multi-head attentions are averaged instead
+        of concatenated.
+
+    negative_slope : float, optional (default=0.2)
+        LeakyReLU angle of the negative slope.
+
+    bias : bool, optional (default=True)
+        If set to :obj:`False`, the layer will not learn an additive bias.
+
+    aggr : str, optional (default="sum")
+        The aggregation scheme to use for grouping node embeddings generated by
+        different relations. Choose from "sum", "mean", "min", "max".
+    """
+
+    def __init__(
+        self,
+        in_channels: Union[int, dict[str, int]],
+        out_channels: int,
+        node_types: list[str],
+        edge_types: list[tuple[str, str, str]],
+        heads: int = 1,
+        concat: bool = True,
+        negative_slope: float = 0.2,
+        bias: bool = True,
+        aggr: str = "sum",
+    ):
+        major, minor, patch = torch_geometric.__version__.split(".")[:3]
+        pyg_version = tuple(map(int, [major, minor, patch]))
+        if pyg_version < (2, 4, 0):
+            raise RuntimeError(f"{self.__class__.__name__} requires pyg >= 2.4.0.")
+
+        super().__init__()
+
+        if isinstance(in_channels, int):
+            in_channels = dict.fromkeys(node_types, in_channels)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.node_types = node_types
+        self.edge_types = edge_types
+        self.num_heads = heads
+        self.concat_heads = concat
+
+        self.negative_slope = negative_slope
+        self.aggr = aggr
+
+        self.relations_per_ntype = defaultdict(lambda: ([], []))
+
+        lin_weights = dict.fromkeys(self.node_types)
+        attn_weights = dict.fromkeys(self.edge_types)
+        biases = dict.fromkeys(self.edge_types)
+
+        ParameterDict = torch_geometric.nn.parameter_dict.ParameterDict
+
+        for edge_type in self.edge_types:
+            src_type, _, dst_type = edge_type
+            self.relations_per_ntype[src_type][0].append(edge_type)
+            if src_type != dst_type:
+                self.relations_per_ntype[dst_type][1].append(edge_type)
+
+            attn_weights[edge_type] = torch.empty(
+                2 * self.num_heads * self.out_channels
+            )
+
+            if bias and concat:
+                biases[edge_type] = torch.empty(self.num_heads * out_channels)
+            elif bias:
+                biases[edge_type] = torch.empty(out_channels)
+            else:
+                biases[edge_type] = None
+
+        for ntype in self.node_types:
+            n_src_rel = len(self.relations_per_ntype[ntype][0])
+            n_dst_rel = len(self.relations_per_ntype[ntype][1])
+            n_rel = n_src_rel + n_dst_rel
+
+            lin_weights[ntype] = torch.empty(
+                (n_rel * self.num_heads * self.out_channels, self.in_channels[ntype])
+            )
+
+        self.lin_weights = ParameterDict(lin_weights)
+        self.attn_weights = ParameterDict(attn_weights)
+
+        if bias:
+            self.bias = ParameterDict(biases)
+        else:
+            self.register_parameter("bias", None)
+
+        self.reset_parameters()
+
+    def split_tensors(
+        self, x_fused_dict: dict[str, torch.Tensor], dim: int
+    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
+        """Split fused tensors into chunks based on edge types.
+
+        Parameters
+        ----------
+        x_fused_dict : dict[str, torch.Tensor]
+            A dictionary holding the node features for each node type. The key
+            is the node type; the value is a fused tensor that accounts for all
+            relations of that node type.
+
+        dim : int
+            Dimension along which to split the fused tensor.
+
+        Returns
+        -------
+        x_src_dict : dict[str, torch.Tensor]
+            A dictionary holding the source node features for each relation
+            graph.
+
+        x_dst_dict : dict[str, torch.Tensor]
+            A dictionary holding the destination node features for each
+            relation graph.
+        """
+        x_src_dict = dict.fromkeys(self.edge_types)
+        x_dst_dict = dict.fromkeys(self.edge_types)
+
+        for ntype, t in x_fused_dict.items():
+            n_src_rel = len(self.relations_per_ntype[ntype][0])
+            n_dst_rel = len(self.relations_per_ntype[ntype][1])
+            n_rel = n_src_rel + n_dst_rel
+            t_list = torch.chunk(t, chunks=n_rel, dim=dim)
+
+            for i, src_rel in enumerate(self.relations_per_ntype[ntype][0]):
+                x_src_dict[src_rel] = t_list[i]
+
+            for i, dst_rel in enumerate(self.relations_per_ntype[ntype][1]):
+                x_dst_dict[dst_rel] = t_list[i + n_src_rel]
+
+        return x_src_dict, x_dst_dict
+
+    def reset_parameters(self, seed: Optional[int] = None):
+        if seed is not None:
+            torch.manual_seed(seed)
+
+        w_src, w_dst = self.split_tensors(self.lin_weights, dim=0)
+
+        for edge_type in self.edge_types:
+            src_type, _, dst_type = edge_type
+
+            # lin_src
+            torch_geometric.nn.inits.glorot(w_src[edge_type])
+
+            # lin_dst
+            if src_type != dst_type:
+                torch_geometric.nn.inits.glorot(w_dst[edge_type])
+
+            # attn_weights
+            torch_geometric.nn.inits.glorot(
+                self.attn_weights[edge_type].view(-1, self.num_heads, self.out_channels)
+            )
+
+            # bias
+            if self.bias is not None:
+                torch_geometric.nn.inits.zeros(self.bias[edge_type])
+
+    def forward(
+        self,
+        x_dict: dict[str, torch.Tensor],
+        edge_index_dict: dict[tuple[str, str, str], torch.Tensor],
+    ) -> dict[str, torch.Tensor]:
+        feat_dict = dict.fromkeys(x_dict.keys())
+
+        for ntype, x in x_dict.items():
+            feat_dict[ntype] = x @ self.lin_weights[ntype].T
+
+        x_src_dict, x_dst_dict = self.split_tensors(feat_dict, dim=1)
+
+        out_dict = defaultdict(list)
+
+        for edge_type, edge_index in edge_index_dict.items():
+            src_type, _, dst_type = edge_type
+
+            csc = BaseConv.to_csc(
+                edge_index, (x_dict[src_type].size(0), x_dict[dst_type].size(0))
+            )
+
+            if src_type == dst_type:
+                graph = self.get_cugraph(
+                    csc,
+                    bipartite=False,
+                )
+                out = mha_gat_n2n(
+                    x_src_dict[edge_type],
+                    self.attn_weights[edge_type],
+                    graph,
+                    num_heads=self.num_heads,
+                    activation="LeakyReLU",
+                    negative_slope=self.negative_slope,
+                    concat_heads=self.concat_heads,
+                )
+
+            else:
+                graph = self.get_cugraph(
+                    csc,
+                    bipartite=True,
+                )
+                out = mha_gat_n2n(
+                    (x_src_dict[edge_type], x_dst_dict[edge_type]),
+                    self.attn_weights[edge_type],
+                    graph,
+                    num_heads=self.num_heads,
+                    activation="LeakyReLU",
+                    negative_slope=self.negative_slope,
+                    concat_heads=self.concat_heads,
+                )
+
+            if self.bias is not None:
+                out = out + self.bias[edge_type]
+
+            out_dict[dst_type].append(out)
+
+        for key, value in out_dict.items():
+            out_dict[key] = torch_geometric.nn.conv.hetero_conv.group(value, self.aggr)
+
+        return out_dict
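The fused-GEMM idea described in the docstring above is easy to verify in
plain PyTorch: the per-relation weight matrices of one node type are stacked
row-wise into a single matrix, one matmul computes all relation-specific
projections at once, and `torch.chunk` recovers the per-relation slices,
which is what `forward` and `split_tensors` do per node type (a standalone
sketch with made-up sizes, not cugraph-pyg API):

    import torch

    torch.manual_seed(0)
    num_nodes, in_dim, out_dim, n_rel = 5, 4, 3, 2  # made-up sizes
    x = torch.randn(num_nodes, in_dim)

    # One weight matrix per relation that this node type participates in.
    w_per_rel = [torch.randn(out_dim, in_dim) for _ in range(n_rel)]

    # Baseline: one GEMM per relation.
    outs_separate = [x @ w.T for w in w_per_rel]

    # Fused: stack weights row-wise, issue a single GEMM, then split.
    w_fused = torch.cat(w_per_rel, dim=0)  # (n_rel * out_dim, in_dim)
    out_fused = x @ w_fused.T              # (num_nodes, n_rel * out_dim)
    outs_chunked = torch.chunk(out_fused, chunks=n_rel, dim=1)

    for a, b in zip(outs_separate, outs_chunked):
        assert torch.allclose(a, b, atol=1e-6)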
+ """ + x_src_dict = dict.fromkeys(self.edge_types) + x_dst_dict = dict.fromkeys(self.edge_types) + + for ntype, t in x_fused_dict.items(): + n_src_rel = len(self.relations_per_ntype[ntype][0]) + n_dst_rel = len(self.relations_per_ntype[ntype][1]) + n_rel = n_src_rel + n_dst_rel + t_list = torch.chunk(t, chunks=n_rel, dim=dim) + + for i, src_rel in enumerate(self.relations_per_ntype[ntype][0]): + x_src_dict[src_rel] = t_list[i] + + for i, dst_rel in enumerate(self.relations_per_ntype[ntype][1]): + x_dst_dict[dst_rel] = t_list[i + n_src_rel] + + return x_src_dict, x_dst_dict + + def reset_parameters(self, seed: Optional[int] = None): + if seed is not None: + torch.manual_seed(seed) + + w_src, w_dst = self.split_tensors(self.lin_weights, dim=0) + + for edge_type in self.edge_types: + src_type, _, dst_type = edge_type + + # lin_src + torch_geometric.nn.inits.glorot(w_src[edge_type]) + + # lin_dst + if src_type != dst_type: + torch_geometric.nn.inits.glorot(w_dst[edge_type]) + + # attn_weights + torch_geometric.nn.inits.glorot( + self.attn_weights[edge_type].view(-1, self.num_heads, self.out_channels) + ) + + # bias + if self.bias is not None: + torch_geometric.nn.inits.zeros(self.bias[edge_type]) + + def forward( + self, + x_dict: dict[str, torch.Tensor], + edge_index_dict: dict[tuple[str, str, str], torch.Tensor], + ) -> dict[str, torch.Tensor]: + feat_dict = dict.fromkeys(x_dict.keys()) + + for ntype, x in x_dict.items(): + feat_dict[ntype] = x @ self.lin_weights[ntype].T + + x_src_dict, x_dst_dict = self.split_tensors(feat_dict, dim=1) + + out_dict = defaultdict(list) + + for edge_type, edge_index in edge_index_dict.items(): + src_type, _, dst_type = edge_type + + csc = BaseConv.to_csc( + edge_index, (x_dict[src_type].size(0), x_dict[dst_type].size(0)) + ) + + if src_type == dst_type: + graph = self.get_cugraph( + csc, + bipartite=False, + ) + out = mha_gat_n2n( + x_src_dict[edge_type], + self.attn_weights[edge_type], + graph, + num_heads=self.num_heads, + activation="LeakyReLU", + negative_slope=self.negative_slope, + concat_heads=self.concat_heads, + ) + + else: + graph = self.get_cugraph( + csc, + bipartite=True, + ) + out = mha_gat_n2n( + (x_src_dict[edge_type], x_dst_dict[edge_type]), + self.attn_weights[edge_type], + graph, + num_heads=self.num_heads, + activation="LeakyReLU", + negative_slope=self.negative_slope, + concat_heads=self.concat_heads, + ) + + if self.bias is not None: + out = out + self.bias[edge_type] + + out_dict[dst_type].append(out) + + for key, value in out_dict.items(): + out_dict[key] = torch_geometric.nn.conv.hetero_conv.group(value, self.aggr) + + return out_dict diff --git a/python/cugraph-pyg/cugraph_pyg/tests/conftest.py b/python/cugraph-pyg/cugraph_pyg/tests/conftest.py index 1512901822a..30994289f9c 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/conftest.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/conftest.py @@ -284,3 +284,32 @@ def basic_pyg_graph_2(): ) size = (10, 10) return edge_index, size + + +@pytest.fixture +def sample_pyg_hetero_data(): + torch.manual_seed(12345) + raw_data_dict = { + "v0": torch.randn(6, 3), + "v1": torch.randn(7, 2), + "v2": torch.randn(5, 4), + ("v2", "e0", "v1"): torch.tensor([[0, 2, 2, 4, 4], [4, 3, 6, 0, 1]]), + ("v1", "e1", "v1"): torch.tensor( + [[0, 2, 2, 2, 3, 5, 5], [4, 0, 4, 5, 3, 0, 1]] + ), + ("v0", "e2", "v0"): torch.tensor([[0, 2, 2, 3, 5, 5], [1, 1, 5, 1, 1, 2]]), + ("v1", "e3", "v2"): torch.tensor( + [[0, 1, 1, 2, 4, 5, 6], [1, 2, 3, 1, 2, 2, 2]] + ), + ("v0", "e4", "v2"): torch.tensor([[1, 1, 3, 3, 4, 4], 
diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py
new file mode 100644
index 00000000000..1c841a17df7
--- /dev/null
+++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py
@@ -0,0 +1,132 @@
+# Copyright (c) 2023, NVIDIA CORPORATION.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from cugraph_pyg.nn import HeteroGATConv as CuGraphHeteroGATConv
+from cugraph.utilities.utils import import_optional, MissingModule
+
+torch = import_optional("torch")
+torch_geometric = import_optional("torch_geometric")
+
+ATOL = 1e-6
+
+
+@pytest.mark.cugraph_ops
+@pytest.mark.skipif(isinstance(torch, MissingModule), reason="torch not available")
+@pytest.mark.skipif(
+    isinstance(torch_geometric, MissingModule), reason="torch_geometric not available"
+)
+@pytest.mark.parametrize("heads", [1, 3, 10])
+@pytest.mark.parametrize("aggr", ["sum", "mean"])
+def test_hetero_gat_conv_equality(sample_pyg_hetero_data, aggr, heads):
+    major, minor, patch = torch_geometric.__version__.split(".")[:3]
+    pyg_version = tuple(map(int, [major, minor, patch]))
+    if pyg_version < (2, 4, 0):
+        pytest.skip("HeteroGATConv requires torch_geometric >= 2.4.0")
+
+    from torch_geometric.data import HeteroData
+    from torch_geometric.nn import HeteroConv, GATConv
+
+    device = torch.device("cuda:0")
+    data = HeteroData(sample_pyg_hetero_data).to(device)
+
+    in_channels_dict = {k: v.size(1) for k, v in data.x_dict.items()}
+    out_channels = 2
+
+    convs_dict = {}
+    kwargs1 = dict(heads=heads, add_self_loops=False, bias=False)
+    for edge_type in data.edge_types:
+        src_t, _, dst_t = edge_type
+        in_channels_src, in_channels_dst = data.x_dict[src_t].size(-1), data.x_dict[
+            dst_t
+        ].size(-1)
+        if src_t == dst_t:
+            convs_dict[edge_type] = GATConv(in_channels_src, out_channels, **kwargs1)
+        else:
+            convs_dict[edge_type] = GATConv(
+                (in_channels_src, in_channels_dst), out_channels, **kwargs1
+            )
+
+    conv1 = HeteroConv(convs_dict, aggr=aggr).to(device)
+    kwargs2 = dict(
+        heads=heads,
+        aggr=aggr,
+        node_types=data.node_types,
+        edge_types=data.edge_types,
+        bias=False,
+    )
+    conv2 = CuGraphHeteroGATConv(in_channels_dict, out_channels, **kwargs2).to(device)
+
+    # copy over linear and attention weights
+    w_src, w_dst = conv2.split_tensors(conv2.lin_weights, dim=0)
+    with torch.no_grad():
+        for edge_type in conv2.edge_types:
+            src_t, _, dst_t = edge_type
+            w_src[edge_type][:, :] = conv1.convs[edge_type].lin_src.weight[:, :]
+            if w_dst[edge_type] is not None:
+                w_dst[edge_type][:, :] = conv1.convs[edge_type].lin_dst.weight[:, :]
+
+            conv2.attn_weights[edge_type][: heads * out_channels] = conv1.convs[
+                edge_type
+            ].att_src.data.flatten()
+            conv2.attn_weights[edge_type][heads * out_channels :] = conv1.convs[
+                edge_type
+            ].att_dst.data.flatten()
+
+    out1 = conv1(data.x_dict, data.edge_index_dict)
+    out2 = conv2(data.x_dict, data.edge_index_dict)
+
+    for node_type in data.node_types:
+        assert torch.allclose(out1[node_type], out2[node_type], atol=ATOL)
+
+    loss1 = 0
+    loss2 = 0
+    for node_type in data.node_types:
+        loss1 += out1[node_type].mean()
+        loss2 += out2[node_type].mean()
+
+    loss1.backward()
+    loss2.backward()
+
+    # check gradient w.r.t attention weights
+    out_dim = heads * out_channels
+    for edge_type in conv2.edge_types:
+        assert torch.allclose(
+            conv1.convs[edge_type].att_src.grad.flatten(),
+            conv2.attn_weights[edge_type].grad[:out_dim],
+            atol=ATOL,
+        )
+        assert torch.allclose(
+            conv1.convs[edge_type].att_dst.grad.flatten(),
+            conv2.attn_weights[edge_type].grad[out_dim:],
+            atol=ATOL,
+        )
+
+    # check gradient w.r.t linear weights
+    grad_lin_weights_ref = dict.fromkeys(out1.keys())
+    for node_t, (rels_as_src, rels_as_dst) in conv2.relations_per_ntype.items():
+        grad_list = []
+        for rel_t in rels_as_src:
+            grad_list.append(conv1.convs[rel_t].lin_src.weight.grad.clone())
+        for rel_t in rels_as_dst:
+            grad_list.append(conv1.convs[rel_t].lin_dst.weight.grad.clone())
+        assert len(grad_list) > 0
+        grad_lin_weights_ref[node_t] = torch.vstack(grad_list)
+
+    for node_type in conv2.lin_weights:
+        assert torch.allclose(
+            grad_lin_weights_ref[node_type],
+            conv2.lin_weights[node_type].grad,
+            atol=ATOL,
+        )
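Because `concat=True` multiplies the output width by `heads`, stacking two of
these layers requires the second layer's `in_channels` to equal
`heads * out_channels` of the first. A sketch of such a model (hypothetical
class, not part of this PR; assumes every node type appears as the
destination of some relation, since `forward` only returns embeddings for
destination types):

    import torch
    from cugraph_pyg.nn import HeteroGATConv


    class TwoLayerHeteroGAT(torch.nn.Module):
        """Hypothetical example model, not part of this PR."""

        def __init__(
            self, in_channels_dict, hidden, out, node_types, edge_types, heads=2
        ):
            super().__init__()
            self.conv1 = HeteroGATConv(
                in_channels_dict, hidden, node_types, edge_types, heads=heads
            )
            # conv1 concatenates heads, so every node type now carries
            # heads * hidden features.
            self.conv2 = HeteroGATConv(
                heads * hidden, out, node_types, edge_types, heads=1
            )

        def forward(self, x_dict, edge_index_dict):
            x_dict = self.conv1(x_dict, edge_index_dict)
            x_dict = {k: torch.relu(v) for k, v in x_dict.items()}
            return self.conv2(x_dict, edge_index_dict)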