diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py index 49fd4bb84bf..3b717552a96 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/hetero_gat_conv.py @@ -20,7 +20,6 @@ from .base import BaseConv torch = import_optional("torch") -nn = import_optional("torch.nn") torch_geometric = import_optional("torch_geometric") @@ -75,6 +74,11 @@ def __init__( bias: bool = True, aggr: str = "sum", ): + major, minor, patch = torch_geometric.__version__.split(".")[:3] + pyg_version = tuple(map(int, [major, minor, patch])) + if pyg_version < (2, 4, 0): + raise RuntimeError(f"{self.__class__.__name__} requires pyg >= 2.4.0.") + super().__init__() if isinstance(in_channels, int): @@ -93,9 +97,7 @@ def __init__( self.relations_per_ntype = defaultdict(lambda: ([], [])) lin_weights = dict.fromkeys(self.node_types) - attn_weights = dict.fromkeys(self.edge_types) - biases = dict.fromkeys(self.edge_types) ParameterDict = torch_geometric.nn.parameter_dict.ParameterDict diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py index 0eaf2e103ee..1c841a17df7 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_hetero_gat_conv.py @@ -30,6 +30,11 @@ @pytest.mark.parametrize("heads", [1, 3, 10]) @pytest.mark.parametrize("aggr", ["sum", "mean"]) def test_hetero_gat_conv_equality(sample_pyg_hetero_data, aggr, heads): + major, minor, patch = torch_geometric.__version__.split(".")[:3] + pyg_version = tuple(map(int, [major, minor, patch])) + if pyg_version < (2, 4, 0): + pytest.skip("Skipping HeteroGATConv test") + from torch_geometric.data import HeteroData from torch_geometric.nn import HeteroConv, GATConv