diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py
index 08c51400e7f..49fd4bb84bf 100644
--- a/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py
+++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py
@@ -84,7 +84,6 @@ def __init__(
 
         self.node_types = node_types
         self.edge_types = edge_types
-        self.edge_types_str = ["__".join(etype) for etype in self.edge_types]
 
         self.num_heads = heads
         self.concat_heads = concat
@@ -95,27 +94,28 @@ def __init__(
 
         lin_weights = dict.fromkeys(self.node_types)
-        attn_weights = dict.fromkeys(self.edge_types_str)
+        attn_weights = dict.fromkeys(self.edge_types)
-        biases = dict.fromkeys(self.edge_types_str)
+        biases = dict.fromkeys(self.edge_types)
+
+        ParameterDict = torch_geometric.nn.parameter_dict.ParameterDict
 
         for edge_type in self.edge_types:
             src_type, _, dst_type = edge_type
-            etype_str = "__".join(edge_type)
-            self.relations_per_ntype[src_type][0].append(etype_str)
+            self.relations_per_ntype[src_type][0].append(edge_type)
             if src_type != dst_type:
-                self.relations_per_ntype[dst_type][1].append(etype_str)
+                self.relations_per_ntype[dst_type][1].append(edge_type)
 
-            attn_weights[etype_str] = torch.empty(
+            attn_weights[edge_type] = torch.empty(
                 2 * self.num_heads * self.out_channels
             )
 
             if bias and concat:
-                biases[etype_str] = torch.empty(self.num_heads * out_channels)
+                biases[edge_type] = torch.empty(self.num_heads * out_channels)
             elif bias:
-                biases[etype_str] = torch.empty(out_channels)
+                biases[edge_type] = torch.empty(out_channels)
             else:
-                biases[etype_str] = None
+                biases[edge_type] = None
 
         for ntype in self.node_types:
             n_src_rel = len(self.relations_per_ntype[ntype][0])
@@ -126,11 +126,11 @@ def __init__(
                 (n_rel * self.num_heads * self.out_channels, self.in_channels[ntype])
             )
 
-        self.lin_weights = nn.ParameterDict(lin_weights)
-        self.attn_weights = nn.ParameterDict(attn_weights)
+        self.lin_weights = ParameterDict(lin_weights)
+        self.attn_weights = ParameterDict(attn_weights)
 
         if bias:
-            self.bias = nn.ParameterDict(biases)
+            self.bias = ParameterDict(biases)
         else:
             self.register_parameter("bias", None)
 
@@ -159,8 +159,8 @@ def split_tensors(
         x_dst_dict : dict[str, torch.Tensor]
             A dictionary to hold destination node feature for each relation graph.
         """
-        x_src_dict = dict.fromkeys(self.edge_types_str)
-        x_dst_dict = dict.fromkeys(self.edge_types_str)
+        x_src_dict = dict.fromkeys(self.edge_types)
+        x_dst_dict = dict.fromkeys(self.edge_types)
 
         for ntype, t in x_fused_dict.items():
             n_src_rel = len(self.relations_per_ntype[ntype][0])
@@ -182,24 +182,24 @@ def reset_parameters(self, seed: Optional[int] = None):
 
         w_src, w_dst = self.split_tensors(self.lin_weights, dim=0)
 
-        for _, edge_type in enumerate(self.edge_types):
+        for edge_type in self.edge_types:
             src_type, _, dst_type = edge_type
-            etype_str = "__".join(edge_type)
+
             # lin_src
-            torch_geometric.nn.inits.glorot(w_src[etype_str])
+            torch_geometric.nn.inits.glorot(w_src[edge_type])
 
             # lin_dst
             if src_type != dst_type:
-                torch_geometric.nn.inits.glorot(w_dst[etype_str])
+                torch_geometric.nn.inits.glorot(w_dst[edge_type])
 
             # attn_weights
             torch_geometric.nn.inits.glorot(
-                self.attn_weights[etype_str].view(-1, self.num_heads, self.out_channels)
+                self.attn_weights[edge_type].view(-1, self.num_heads, self.out_channels)
             )
 
             # bias
             if self.bias is not None:
-                torch_geometric.nn.inits.zeros(self.bias[etype_str])
+                torch_geometric.nn.inits.zeros(self.bias[edge_type])
 
     def forward(
         self,
@@ -217,7 +217,6 @@ def forward(
         for edge_type, edge_index in edge_index_dict.items():
             src_type, _, dst_type = edge_type
-            etype_str = "__".join(edge_type)
 
             csc = BaseConv.to_csc(
                 edge_index, (x_dict[src_type].size(0), x_dict[dst_type].size(0))
             )
@@ -229,8 +228,8 @@ def forward(
                     bipartite=False,
                 )
                 out = mha_gat_n2n(
-                    x_src_dict[etype_str],
-                    self.attn_weights[etype_str],
+                    x_src_dict[edge_type],
+                    self.attn_weights[edge_type],
                     graph,
                     num_heads=self.num_heads,
                     activation="LeakyReLU",
@@ -244,8 +243,8 @@ def forward(
                     bipartite=True,
                 )
                 out = mha_gat_n2n(
-                    (x_src_dict[etype_str], x_dst_dict[etype_str]),
-                    self.attn_weights[etype_str],
+                    (x_src_dict[edge_type], x_dst_dict[edge_type]),
+                    self.attn_weights[edge_type],
                     graph,
                     num_heads=self.num_heads,
                     activation="LeakyReLU",
@@ -254,7 +253,7 @@ def forward(
                 )
 
             if self.bias is not None:
-                out = out + self.bias[etype_str]
+                out = out + self.bias[edge_type]
 
             out_dict[dst_type].append(out)
 
diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py
index 5b450bc8da6..0eaf2e103ee 100644
--- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py
+++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py
@@ -15,7 +15,6 @@
 
 from cugraph_pyg.nn import HeteroGATConv as CuGraphHeteroGATConv
 from cugraph.utilities.utils import import_optional, MissingModule
-from utils import convert_edge_type_key
 
 torch = import_optional("torch")
 torch_geometric = import_optional("torch_geometric")
@@ -67,18 +66,17 @@ def test_hetero_gat_conv_equality(sample_pyg_hetero_data, aggr, heads):
 
     # copy over linear and attention weights
     w_src, w_dst = conv2.split_tensors(conv2.lin_weights, dim=0)
     with torch.no_grad():
-        for etype_str in conv2.edge_types_str:
-            src_t, _, dst_t = etype_str.split("__")
-            pyg_etype_str = convert_edge_type_key(etype_str)
-            w_src[etype_str][:, :] = conv1.convs[pyg_etype_str].lin_src.weight[:, :]
-            if w_dst[etype_str] is not None:
-                w_dst[etype_str][:, :] = conv1.convs[pyg_etype_str].lin_dst.weight[:, :]
-
-            conv2.attn_weights[etype_str][: heads * out_channels] = conv1.convs[
-                pyg_etype_str
+        for edge_type in conv2.edge_types:
+            src_t, _, dst_t = edge_type
+            w_src[edge_type][:, :] = conv1.convs[edge_type].lin_src.weight[:, :]
+            if w_dst[edge_type] is not None:
+                w_dst[edge_type][:, :] = conv1.convs[edge_type].lin_dst.weight[:, :]
+
+            conv2.attn_weights[edge_type][: heads * out_channels] = conv1.convs[
+                edge_type
             ].att_src.data.flatten()
-            conv2.attn_weights[etype_str][heads * out_channels :] = conv1.convs[
-                pyg_etype_str
+            conv2.attn_weights[edge_type][heads * out_channels :] = conv1.convs[
+                edge_type
             ].att_dst.data.flatten()
 
     out1 = conv1(data.x_dict, data.edge_index_dict)
@@ -98,16 +96,15 @@ def test_hetero_gat_conv_equality(sample_pyg_hetero_data, aggr, heads):
 
     # check gradient w.r.t attention weights
     out_dim = heads * out_channels
-    for etype_str in conv2.edge_types_str:
-        pyg_etype_str = convert_edge_type_key(etype_str)
+    for edge_type in conv2.edge_types:
         assert torch.allclose(
-            conv1.convs[pyg_etype_str].att_src.grad.flatten(),
-            conv2.attn_weights[etype_str].grad[:out_dim],
+            conv1.convs[edge_type].att_src.grad.flatten(),
+            conv2.attn_weights[edge_type].grad[:out_dim],
             atol=ATOL,
         )
         assert torch.allclose(
-            conv1.convs[pyg_etype_str].att_dst.grad.flatten(),
-            conv2.attn_weights[etype_str].grad[out_dim:],
+            conv1.convs[edge_type].att_dst.grad.flatten(),
+            conv2.attn_weights[edge_type].grad[out_dim:],
             atol=ATOL,
         )
 
@@ -116,10 +113,8 @@ def test_hetero_gat_conv_equality(sample_pyg_hetero_data, aggr, heads):
     for node_t, (rels_as_src, rels_as_dst) in conv2.relations_per_ntype.items():
         grad_list = []
         for rel_t in rels_as_src:
-            rel_t = convert_edge_type_key(rel_t)
             grad_list.append(conv1.convs[rel_t].lin_src.weight.grad.clone())
         for rel_t in rels_as_dst:
-            rel_t = convert_edge_type_key(rel_t)
             grad_list.append(conv1.convs[rel_t].lin_dst.weight.grad.clone())
         assert len(grad_list) > 0
         grad_lin_weights_ref[node_t] = torch.vstack(grad_list)
diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/utils.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/utils.py
deleted file mode 100644
index 58d787f745d..00000000000
--- a/python/cugraph-pyg/cugraph_pyg/tests/nn/utils.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from cugraph.utilities.utils import import_optional, MissingModule
-
-torch_geometric = import_optional("torch_geometric")
-
-HAS_PYG_24 = None
-if not isinstance(torch_geometric, MissingModule):
-    major, minor, patch = torch_geometric.__version__.split(".")[:3]
-    pyg_version = tuple(map(int, [major, minor, patch]))
-    HAS_PYG_24 = pyg_version >= (2, 4, 0)
-
-
-# TODO: Remove this function when dropping support to pyg 2.3
-def convert_edge_type_key(edge_type_str):
-    """Convert an edge_type string to one that follows PyG's convention.
-
-    Pre v2.4.0, the keys of nn.ModuleDict in HeteroConv use the
-    "author__writes__paper" style. It has been changed to
-    "<author___writes___paper>" since 2.4.0.
-    """
-    if HAS_PYG_24:
-        return f"<{'___'.join(edge_type_str.split('__'))}>"
-
-    return edge_type_str
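Note on the change above: the whole refactor hinges on torch_geometric.nn.parameter_dict.ParameterDict, which the patch aliases in __init__. Unlike torch.nn.ParameterDict, it accepts tuple keys such as canonical PyG edge types and maps them to valid internal attribute names, which is what lets the patch drop the "__"-joined string keys and delete the convert_edge_type_key() shim. A minimal sketch of that behavior, assuming PyG >= 2.4 is installed; the edge type and tensor sizes below are illustrative, not taken from the patch:

    import torch
    from torch_geometric.nn.parameter_dict import ParameterDict

    # A canonical PyG edge type: a (src_type, relation, dst_type) tuple.
    edge_type = ("author", "writes", "paper")

    params = ParameterDict()
    # Tuple keys are accepted directly; PyG converts them to a string-safe
    # internal representation, so no manual "__".join(edge_type) is needed.
    params[edge_type] = torch.nn.Parameter(torch.empty(2 * 4 * 8))

    # Lookup also uses the tuple, mirroring self.attn_weights[edge_type]
    # in the patched HeteroGATConv.forward().
    assert params[edge_type].shape == (64,)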