From b11af9a2ab0017c3a873f67d7fb8ce06b7e6e1ed Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Thu, 16 Nov 2023 22:41:36 -0500 Subject: [PATCH] fix test for pyg>=2.4 --- .../tests/nn/test_hetero_gat_conv.py | 17 +++++---- .../cugraph-pyg/cugraph_pyg/tests/nn/utils.py | 36 +++++++++++++++++++ 2 files changed, 47 insertions(+), 6 deletions(-) create mode 100644 python/cugraph-pyg/cugraph_pyg/tests/nn/utils.py diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py index a2300d6df95..5b450bc8da6 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py @@ -15,6 +15,7 @@ from cugraph_pyg.nn import HeteroGATConv as CuGraphHeteroGATConv from cugraph.utilities.utils import import_optional, MissingModule +from utils import convert_edge_type_key torch = import_optional("torch") torch_geometric = import_optional("torch_geometric") @@ -68,15 +69,16 @@ def test_hetero_gat_conv_equality(sample_pyg_hetero_data, aggr, heads): with torch.no_grad(): for etype_str in conv2.edge_types_str: src_t, _, dst_t = etype_str.split("__") - w_src[etype_str][:, :] = conv1.convs[etype_str].lin_src.weight[:, :] + pyg_etype_str = convert_edge_type_key(etype_str) + w_src[etype_str][:, :] = conv1.convs[pyg_etype_str].lin_src.weight[:, :] if w_dst[etype_str] is not None: - w_dst[etype_str][:, :] = conv1.convs[etype_str].lin_dst.weight[:, :] + w_dst[etype_str][:, :] = conv1.convs[pyg_etype_str].lin_dst.weight[:, :] conv2.attn_weights[etype_str][: heads * out_channels] = conv1.convs[ - etype_str + pyg_etype_str ].att_src.data.flatten() conv2.attn_weights[etype_str][heads * out_channels :] = conv1.convs[ - etype_str + pyg_etype_str ].att_dst.data.flatten() out1 = conv1(data.x_dict, data.edge_index_dict) @@ -97,13 +99,14 @@ def test_hetero_gat_conv_equality(sample_pyg_hetero_data, aggr, heads): # check gradient w.r.t attention weights out_dim = heads * out_channels for etype_str in conv2.edge_types_str: + pyg_etype_str = convert_edge_type_key(etype_str) assert torch.allclose( - conv1.convs[etype_str].att_src.grad.flatten(), + conv1.convs[pyg_etype_str].att_src.grad.flatten(), conv2.attn_weights[etype_str].grad[:out_dim], atol=ATOL, ) assert torch.allclose( - conv1.convs[etype_str].att_dst.grad.flatten(), + conv1.convs[pyg_etype_str].att_dst.grad.flatten(), conv2.attn_weights[etype_str].grad[out_dim:], atol=ATOL, ) @@ -113,8 +116,10 @@ def test_hetero_gat_conv_equality(sample_pyg_hetero_data, aggr, heads): for node_t, (rels_as_src, rels_as_dst) in conv2.relations_per_ntype.items(): grad_list = [] for rel_t in rels_as_src: + rel_t = convert_edge_type_key(rel_t) grad_list.append(conv1.convs[rel_t].lin_src.weight.grad.clone()) for rel_t in rels_as_dst: + rel_t = convert_edge_type_key(rel_t) grad_list.append(conv1.convs[rel_t].lin_dst.weight.grad.clone()) assert len(grad_list) > 0 grad_lin_weights_ref[node_t] = torch.vstack(grad_list) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/utils.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/utils.py new file mode 100644 index 00000000000..58d787f745d --- /dev/null +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/utils.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cugraph.utilities.utils import import_optional, MissingModule + +torch_geometric = import_optional("torch_geometric") + +HAS_PYG_24 = None +if not isinstance(torch_geometric, MissingModule): + major, minor, patch = torch_geometric.__version__.split(".")[:3] + pyg_version = tuple(map(int, [major, minor, patch])) + HAS_PYG_24 = pyg_version >= (2, 4, 0) + + +# TODO: Remove this function when dropping support to pyg 2.3 +def convert_edge_type_key(edge_type_str): + """Convert an edge_type string to one that follows PyG's convention. + + Pre v2.4.0, the keys of nn.ModuleDict in HeteroConv use + "author__writes__paper" style." It has been changed to + "" since 2.4.0. + """ + if HAS_PYG_24: + return f"<{'___'.join(edge_type_str.split('__'))}>" + + return edge_type_str