From e1fead51b743371e9079fab5e0194c4c45079183 Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Fri, 8 Sep 2023 16:29:16 -0400 Subject: [PATCH] allow sparsegraph in transformerconv --- .../cugraph_dgl/nn/conv/transformerconv.py | 41 +++++++++++++------ .../tests/nn/test_transformerconv.py | 41 ++++++++++++++----- 2 files changed, 59 insertions(+), 23 deletions(-) diff --git a/python/cugraph-dgl/cugraph_dgl/nn/conv/transformerconv.py b/python/cugraph-dgl/cugraph_dgl/nn/conv/transformerconv.py index 5cd5fbbaebe..c447cbb83fa 100644 --- a/python/cugraph-dgl/cugraph_dgl/nn/conv/transformerconv.py +++ b/python/cugraph-dgl/cugraph_dgl/nn/conv/transformerconv.py @@ -10,9 +10,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import Optional, Tuple, Union -from cugraph_dgl.nn.conv.base import BaseConv +from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph from cugraph.utilities.utils import import_optional dgl = import_optional("dgl") @@ -114,7 +115,7 @@ def reset_parameters(self): def forward( self, - g: dgl.DGLHeteroGraph, + g: Union[SparseGraph, dgl.DGLHeteroGraph], nfeat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], efeat: Optional[torch.Tensor] = None, ) -> torch.Tensor: @@ -130,17 +131,33 @@ def forward( efeat: torch.Tensor, optional Edge feature tensor. Default: ``None``. """ - offsets, indices, _ = g.adj_tensors("csc") - graph = ops_torch.CSC( - offsets=offsets, - indices=indices, - num_src_nodes=g.num_src_nodes(), - is_bipartite=True, - ) - - if isinstance(nfeat, torch.Tensor): + bipartite = isinstance(nfeat, (list, tuple)) + if not bipartite: nfeat = (nfeat, nfeat) + if isinstance(g, SparseGraph): + assert "csc" in g.formats() + offsets, indices = g.csc() + _graph = ops_torch.CSC( + offsets=offsets, + indices=indices, + num_src_nodes=g.num_src_nodes(), + is_bipartite=True, + ) + elif isinstance(g, dgl.DGLHeteroGraph): + offsets, indices, _ = g.adj_tensors("csc") + _graph = ops_torch.CSC( + offsets=offsets, + indices=indices, + num_src_nodes=g.num_src_nodes(), + is_bipartite=True, + ) + else: + raise TypeError( + f"The graph has to be either a 'SparseGraph' or " + f"'dgl.DGLHeteroGraph', but got '{type(g)}'." + ) + query = self.lin_query(nfeat[1][: g.num_dst_nodes()]) key = self.lin_key(nfeat[0]) value = self.lin_value(nfeat[0]) @@ -157,7 +174,7 @@ def forward( key_emb=key, query_emb=query, value_emb=value, - graph=graph, + graph=_graph, num_heads=self.num_heads, concat_heads=self.concat, edge_emb=efeat, diff --git a/python/cugraph-dgl/tests/nn/test_transformerconv.py b/python/cugraph-dgl/tests/nn/test_transformerconv.py index 00476b9f0bb..b2b69cb35ab 100644 --- a/python/cugraph-dgl/tests/nn/test_transformerconv.py +++ b/python/cugraph-dgl/tests/nn/test_transformerconv.py @@ -13,16 +13,14 @@ import pytest -try: - from cugraph_dgl.nn import TransformerConv -except ModuleNotFoundError: - pytest.skip("cugraph_dgl not available", allow_module_level=True) - -from cugraph.utilities.utils import import_optional +from cugraph_dgl.nn.conv.base import SparseGraph +from cugraph_dgl.nn import TransformerConv from .common import create_graph1 -torch = import_optional("torch") -dgl = import_optional("dgl") +dgl = pytest.importorskip("dgl", reason="DGL not available") +torch = pytest.importorskip("torch", reason="PyTorch not available") + +ATOL = 1e-6 @pytest.mark.parametrize("beta", [False, True]) @@ -32,8 +30,16 @@ @pytest.mark.parametrize("num_heads", [1, 2, 3, 4]) @pytest.mark.parametrize("to_block", [False, True]) @pytest.mark.parametrize("use_edge_feats", [False, True]) -def test_TransformerConv( - beta, bipartite_node_feats, concat, idtype_int, num_heads, to_block, use_edge_feats +@pytest.mark.parametrize("sparse_format", ["coo", "csc", None]) +def test_transformerconv( + beta, + bipartite_node_feats, + concat, + idtype_int, + num_heads, + to_block, + use_edge_feats, + sparse_format, ): device = "cuda" g = create_graph1().to(device) @@ -44,6 +50,15 @@ def test_TransformerConv( if to_block: g = dgl.to_block(g) + size = (g.num_src_nodes(), g.num_dst_nodes()) + if sparse_format == "coo": + sg = SparseGraph( + size=size, src_ids=g.edges()[0], dst_ids=g.edges()[1], formats="csc" + ) + elif sparse_format == "csc": + offsets, indices, _ = g.adj_tensors("csc") + sg = SparseGraph(size=size, src_ids=indices, cdst_ids=offsets, formats="csc") + if bipartite_node_feats: in_node_feats = (5, 3) nfeat = ( @@ -71,6 +86,10 @@ def test_TransformerConv( edge_feats=edge_feats, ).to(device) - out = conv(g, nfeat, efeat) + if sparse_format is not None: + out = conv(sg, nfeat, efeat) + else: + out = conv(g, nfeat, efeat) + grad_out = torch.rand_like(out) out.backward(grad_out)