From 02655455fa5332ff9535d5776f63bc08f0d23690 Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Fri, 8 Sep 2023 15:22:40 -0400 Subject: [PATCH] support share_weights option, test edge_feat --- .../cugraph_dgl/nn/conv/gatv2conv.py | 29 ++++++---- python/cugraph-dgl/tests/nn/test_gatv2conv.py | 56 +++++++++++++++++-- 2 files changed, 70 insertions(+), 15 deletions(-) diff --git a/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py b/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py index c19152f733b..9b431c5407d 100644 --- a/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py +++ b/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py @@ -93,12 +93,18 @@ def __init__( self.share_weights = share_weights self.lin_src = nn.Linear(self.in_feats_src, num_heads * out_feats, bias=bias) - if isinstance(in_feats, (list, tuple)): + if share_weights: + if self.in_feats_src != self.in_feats_dst: + raise ValueError( + f"Input feature size of source and destination " + f"nodes must be identical when share_weights is enabled, " + f"but got {self.in_feats_src} and {self.in_feats_dst}." + ) + self.lin_dst = self.lin_src + else: self.lin_dst = nn.Linear( self.in_feats_dst, num_heads * out_feats, bias=bias ) - else: - self.lin_dst = self.lin_src self.attn = nn.Parameter(torch.Tensor(num_heads * out_feats)) @@ -108,11 +114,11 @@ def __init__( self.register_parameter("lin_edge", None) if bias and concat: - self.bias = nn.Parameter(torch.Tensor(num_heads * out_feats)) + self.bias = nn.Parameter(torch.Tensor(num_heads, out_feats)) elif bias and not concat: self.bias = nn.Parameter(torch.Tensor(out_feats)) else: - self.register_parameter("bias", None) + self.register_buffer("bias", None) self.residual = residual and self.in_feats_dst != out_feats * num_heads if self.residual: @@ -175,7 +181,8 @@ def forward( if max_in_degree is None: max_in_degree = -1 - bipartite = isinstance(nfeat, (list, tuple)) + nfeat_bipartite = isinstance(nfeat, (list, tuple)) + graph_bipartite = nfeat_bipartite or self.share_weights is False if isinstance(g, SparseGraph): assert "csc" in g.formats() @@ -185,7 +192,7 @@ def forward( indices=indices, num_src_nodes=g.num_src_nodes(), dst_max_in_degree=max_in_degree, - is_bipartite=bipartite, + is_bipartite=graph_bipartite, ) elif isinstance(g, dgl.DGLHeteroGraph): if not self.allow_zero_in_degree: @@ -207,7 +214,7 @@ def forward( indices=indices, num_src_nodes=g.num_src_nodes(), dst_max_in_degree=max_in_degree, - is_bipartite=bipartite, + is_bipartite=graph_bipartite, ) else: raise TypeError( @@ -215,7 +222,7 @@ def forward( f"'dgl.DGLHeteroGraph', but got '{type(g)}'." ) - if bipartite: + if nfeat_bipartite: nfeat = (self.feat_drop(nfeat[0]), self.feat_drop(nfeat[1])) nfeat_dst_orig = nfeat[1] else: @@ -230,8 +237,10 @@ def forward( ) efeat = self.lin_edge(efeat) - if bipartite: + if nfeat_bipartite: nfeat = (self.lin_src(nfeat[0]), self.lin_dst(nfeat[1])) + elif graph_bipartite: + nfeat = (self.lin_src(nfeat), self.lin_dst(nfeat[: g.num_dst_nodes()])) else: nfeat = self.lin_src(nfeat) diff --git a/python/cugraph-dgl/tests/nn/test_gatv2conv.py b/python/cugraph-dgl/tests/nn/test_gatv2conv.py index 1157fc0d913..cc46a6e4b39 100644 --- a/python/cugraph-dgl/tests/nn/test_gatv2conv.py +++ b/python/cugraph-dgl/tests/nn/test_gatv2conv.py @@ -23,14 +23,13 @@ ATOL = 1e-6 -@pytest.mark.parametrize("bipartite", [False]) +@pytest.mark.parametrize("bipartite", [False, True]) @pytest.mark.parametrize("idtype_int", [False, True]) @pytest.mark.parametrize("max_in_degree", [None, 8]) @pytest.mark.parametrize("num_heads", [1, 2, 7]) @pytest.mark.parametrize("residual", [False, True]) @pytest.mark.parametrize("to_block", [False, True]) -# @pytest.mark.parametrize("sparse_format", ["coo", "csc", None]) -@pytest.mark.parametrize("sparse_format", [None]) +@pytest.mark.parametrize("sparse_format", ["coo", "csc", None]) def test_gatv2conv_equality( bipartite, idtype_int, max_in_degree, num_heads, residual, to_block, sparse_format ): @@ -74,8 +73,7 @@ def test_gatv2conv_equality( with torch.no_grad(): conv2.attn.data = conv1.attn.data.flatten() conv2.lin_src.weight.data = conv1.fc_src.weight.data.detach().clone() - if bipartite: - conv2.lin_dst.weight.data = conv1.fc_dst.weight.data.detach().clone() + conv2.lin_dst.weight.data = conv1.fc_dst.weight.data.detach().clone() if residual and conv2.residual: conv2.lin_res.weight.data = conv1.fc_res.weight.data.detach().clone() @@ -99,3 +97,51 @@ def test_gatv2conv_equality( ) assert torch.allclose(conv1.attn.grad, conv1.attn.grad, atol=ATOL) + + +@pytest.mark.parametrize("bias", [False, True]) +@pytest.mark.parametrize("bipartite", [False, True]) +@pytest.mark.parametrize("concat", [False, True]) +@pytest.mark.parametrize("max_in_degree", [None, 8, 800]) +@pytest.mark.parametrize("num_heads", [1, 2, 7]) +@pytest.mark.parametrize("to_block", [False, True]) +@pytest.mark.parametrize("use_edge_feats", [False, True]) +def test_gatv2conv_edge_feats( + bias, bipartite, concat, max_in_degree, num_heads, to_block, use_edge_feats +): + g = create_graph1().to("cuda") + + if to_block: + g = dgl.to_block(g) + + if bipartite: + in_feats = (10, 3) + nfeat = ( + torch.rand(g.num_src_nodes(), in_feats[0]).cuda(), + torch.rand(g.num_dst_nodes(), in_feats[1]).cuda(), + ) + else: + in_feats = 10 + nfeat = torch.rand(g.num_src_nodes(), in_feats).cuda() + out_feats = 2 + + if use_edge_feats: + edge_feats = 3 + efeat = torch.rand(g.num_edges(), edge_feats).cuda() + else: + edge_feats = None + efeat = None + + conv = CuGraphGATv2Conv( + in_feats, + out_feats, + num_heads, + concat=concat, + edge_feats=edge_feats, + bias=bias, + allow_zero_in_degree=True, + ).cuda() + out = conv(g, nfeat, efeat=efeat, max_in_degree=max_in_degree) + + grad_out = torch.rand_like(out) + out.backward(grad_out)