Skip to content

Commit

Permalink
allow sparsegraph in transformerconv
Browse files Browse the repository at this point in the history
  • Loading branch information
tingyu66 committed Sep 8, 2023
1 parent 89fdca1 commit e1fead5
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 23 deletions.
41 changes: 29 additions & 12 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/transformerconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

from cugraph_dgl.nn.conv.base import BaseConv
from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph
from cugraph.utilities.utils import import_optional

dgl = import_optional("dgl")
Expand Down Expand Up @@ -114,7 +115,7 @@ def reset_parameters(self):

def forward(
self,
g: dgl.DGLHeteroGraph,
g: Union[SparseGraph, dgl.DGLHeteroGraph],
nfeat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
efeat: Optional[torch.Tensor] = None,
) -> torch.Tensor:
Expand All @@ -130,17 +131,33 @@ def forward(
efeat: torch.Tensor, optional
Edge feature tensor. Default: ``None``.
"""
offsets, indices, _ = g.adj_tensors("csc")
graph = ops_torch.CSC(
offsets=offsets,
indices=indices,
num_src_nodes=g.num_src_nodes(),
is_bipartite=True,
)

if isinstance(nfeat, torch.Tensor):
bipartite = isinstance(nfeat, (list, tuple))
if not bipartite:
nfeat = (nfeat, nfeat)

if isinstance(g, SparseGraph):
assert "csc" in g.formats()
offsets, indices = g.csc()
_graph = ops_torch.CSC(
offsets=offsets,
indices=indices,
num_src_nodes=g.num_src_nodes(),
is_bipartite=True,
)
elif isinstance(g, dgl.DGLHeteroGraph):
offsets, indices, _ = g.adj_tensors("csc")
_graph = ops_torch.CSC(
offsets=offsets,
indices=indices,
num_src_nodes=g.num_src_nodes(),
is_bipartite=True,
)
else:
raise TypeError(
f"The graph has to be either a 'SparseGraph' or "
f"'dgl.DGLHeteroGraph', but got '{type(g)}'."
)

query = self.lin_query(nfeat[1][: g.num_dst_nodes()])
key = self.lin_key(nfeat[0])
value = self.lin_value(nfeat[0])
Expand All @@ -157,7 +174,7 @@ def forward(
key_emb=key,
query_emb=query,
value_emb=value,
graph=graph,
graph=_graph,
num_heads=self.num_heads,
concat_heads=self.concat,
edge_emb=efeat,
Expand Down
41 changes: 30 additions & 11 deletions python/cugraph-dgl/tests/nn/test_transformerconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,14 @@

import pytest

try:
from cugraph_dgl.nn import TransformerConv
except ModuleNotFoundError:
pytest.skip("cugraph_dgl not available", allow_module_level=True)

from cugraph.utilities.utils import import_optional
from cugraph_dgl.nn.conv.base import SparseGraph
from cugraph_dgl.nn import TransformerConv
from .common import create_graph1

torch = import_optional("torch")
dgl = import_optional("dgl")
dgl = pytest.importorskip("dgl", reason="DGL not available")
torch = pytest.importorskip("torch", reason="PyTorch not available")

ATOL = 1e-6


@pytest.mark.parametrize("beta", [False, True])
Expand All @@ -32,8 +30,16 @@
@pytest.mark.parametrize("num_heads", [1, 2, 3, 4])
@pytest.mark.parametrize("to_block", [False, True])
@pytest.mark.parametrize("use_edge_feats", [False, True])
def test_TransformerConv(
beta, bipartite_node_feats, concat, idtype_int, num_heads, to_block, use_edge_feats
@pytest.mark.parametrize("sparse_format", ["coo", "csc", None])
def test_transformerconv(
beta,
bipartite_node_feats,
concat,
idtype_int,
num_heads,
to_block,
use_edge_feats,
sparse_format,
):
device = "cuda"
g = create_graph1().to(device)
Expand All @@ -44,6 +50,15 @@ def test_TransformerConv(
if to_block:
g = dgl.to_block(g)

size = (g.num_src_nodes(), g.num_dst_nodes())
if sparse_format == "coo":
sg = SparseGraph(
size=size, src_ids=g.edges()[0], dst_ids=g.edges()[1], formats="csc"
)
elif sparse_format == "csc":
offsets, indices, _ = g.adj_tensors("csc")
sg = SparseGraph(size=size, src_ids=indices, cdst_ids=offsets, formats="csc")

if bipartite_node_feats:
in_node_feats = (5, 3)
nfeat = (
Expand Down Expand Up @@ -71,6 +86,10 @@ def test_TransformerConv(
edge_feats=edge_feats,
).to(device)

out = conv(g, nfeat, efeat)
if sparse_format is not None:
out = conv(sg, nfeat, efeat)
else:
out = conv(g, nfeat, efeat)

grad_out = torch.rand_like(out)
out.backward(grad_out)

0 comments on commit e1fead5

Please sign in to comment.