Commit e601dd2

clean up

tingyu66 committed Dec 17, 2023 · 1 parent 9bd1dc3
Showing 7 changed files with 57 additions and 57 deletions.
8 changes: 4 additions & 4 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/relgraphconv.py
@@ -100,16 +100,16 @@ def __init__(
         self.self_loop = self_loop
         if regularizer is None:
             self.W = nn.Parameter(
-                torch.Tensor(num_rels + dim_self_loop, in_feats, out_feats)
+                torch.empty(num_rels + dim_self_loop, in_feats, out_feats)
             )
             self.coeff = None
         elif regularizer == "basis":
             if num_bases is None:
                 raise ValueError('Missing "num_bases" for basis regularization.')
             self.W = nn.Parameter(
-                torch.Tensor(num_bases + dim_self_loop, in_feats, out_feats)
+                torch.empty(num_bases + dim_self_loop, in_feats, out_feats)
             )
-            self.coeff = nn.Parameter(torch.Tensor(num_rels, num_bases))
+            self.coeff = nn.Parameter(torch.empty(num_rels, num_bases))
             self.num_bases = num_bases
         else:
             raise ValueError(
@@ -119,7 +119,7 @@ def __init__(
         self.regularizer = regularizer

         if bias:
-            self.bias = nn.Parameter(torch.Tensor(out_feats))
+            self.bias = nn.Parameter(torch.empty(out_feats))
         else:
             self.register_parameter("bias", None)
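Note on the change itself: the legacy torch.Tensor(sizes) constructor also returns uninitialized memory, but PyTorch's documentation steers users toward explicit factory functions such as torch.empty, which accept dtype/device arguments and make the uninitialized intent obvious. A minimal sketch of the allocate-then-reset pattern these layers follow (the layer below is illustrative, not taken from the diff):

import torch
import torch.nn as nn

class TinyLayer(nn.Module):
    """Illustrative layer using the empty-then-reset pattern."""

    def __init__(self, in_feats: int, out_feats: int):
        super().__init__()
        # Allocate uninitialized storage; values are garbage until reset.
        self.W = nn.Parameter(torch.empty(in_feats, out_feats))
        self.bias = nn.Parameter(torch.empty(out_feats))
        self.reset_parameters()

    def reset_parameters(self):
        # Give every parameter a well-defined starting value.
        nn.init.xavier_uniform_(self.W)
        nn.init.zeros_(self.bias)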
6 changes: 3 additions & 3 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/sageconv.py
@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Tuple, Union
+from typing import Optional, Union

 from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph
 from cugraph.utilities.utils import import_optional
@@ -65,7 +65,7 @@ class SAGEConv(BaseConv):

     def __init__(
         self,
-        in_feats: Union[int, Tuple[int, int]],
+        in_feats: Union[int, tuple[int, int]],
         out_feats: int,
         aggregator_type: str = "mean",
         feat_drop: float = 0.0,
@@ -111,7 +111,7 @@ def reset_parameters(self):
     def forward(
         self,
         g: Union[SparseGraph, dgl.DGLHeteroGraph],
-        feat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        feat: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
         max_in_degree: Optional[int] = None,
     ) -> torch.Tensor:
         r"""Forward computation.
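The typing change swaps typing.Tuple for the builtin tuple, i.e. PEP 585 generics: builtin generics evaluate at runtime from Python 3.9, and on earlier 3.x they are safe inside annotations when postponed evaluation is enabled. A small sketch of the same signature style, independent of the cugraph-dgl code:

from __future__ import annotations  # lets tuple[int, int] appear in annotations pre-3.9

from typing import Optional, Union


def describe_feats(in_feats: Union[int, tuple[int, int]], out_feats: Optional[int] = None) -> str:
    # Bipartite layers pass per-side sizes as a plain (src, dst) tuple.
    if isinstance(in_feats, tuple):
        src, dst = in_feats
        return f"bipartite {src}->{dst}, out={out_feats}"
    return f"homogeneous {in_feats}, out={out_feats}"


print(describe_feats((5, 3), 2))  # bipartite 5->3, out=2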
6 changes: 3 additions & 3 deletions python/cugraph-dgl/cugraph_dgl/nn/conv/transformerconv.py
@@ -11,7 +11,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Tuple, Union
+from typing import Optional, Union

 from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph
 from cugraph.utilities.utils import import_optional
@@ -51,7 +51,7 @@ class TransformerConv(BaseConv):

     def __init__(
         self,
-        in_node_feats: Union[int, Tuple[int, int]],
+        in_node_feats: Union[int, tuple[int, int]],
         out_node_feats: int,
         num_heads: int,
         concat: bool = True,
@@ -116,7 +116,7 @@ def reset_parameters(self):
     def forward(
         self,
         g: Union[SparseGraph, dgl.DGLHeteroGraph],
-        nfeat: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+        nfeat: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
         efeat: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward computation.
7 changes: 6 additions & 1 deletion python/cugraph-dgl/tests/nn/test_gatv2conv.py
@@ -76,7 +76,12 @@ def test_gatv2conv_equality(
     sg = SparseGraph(size=size, src_ids=indices, cdst_ids=offsets, formats="csc")

     args = (in_feats, out_feats, num_heads)
-    kwargs = {"bias": False, "allow_zero_in_degree": True, "residual": residual, "share_weights": mode=="share_weights"}
+    kwargs = {
+        "bias": False,
+        "allow_zero_in_degree": True,
+        "residual": residual,
+        "share_weights": mode == "share_weights",
+    }

     conv1 = GATv2Conv(*args, **kwargs).to(device)
     conv2 = CuGraphGATv2Conv(*args, **kwargs).to(device)
36 changes: 18 additions & 18 deletions python/cugraph-dgl/tests/nn/test_relgraphconv.py
@@ -15,23 +15,23 @@

 from cugraph_dgl.nn.conv.base import SparseGraph
 from cugraph_dgl.nn import RelGraphConv as CuGraphRelGraphConv
-from .common import create_graph1

 dgl = pytest.importorskip("dgl", reason="DGL not available")
 torch = pytest.importorskip("torch", reason="PyTorch not available")

 ATOL = 1e-6


-@pytest.mark.parametrize("idtype_int", [False, True])
+@pytest.mark.parametrize("idx_type", [torch.int32, torch.int64])
 @pytest.mark.parametrize("max_in_degree", [None, 8])
 @pytest.mark.parametrize("num_bases", [1, 2, 5])
 @pytest.mark.parametrize("regularizer", [None, "basis"])
 @pytest.mark.parametrize("self_loop", [False, True])
 @pytest.mark.parametrize("to_block", [False, True])
 @pytest.mark.parametrize("sparse_format", ["coo", "csc", None])
 def test_relgraphconv_equality(
-    idtype_int,
+    dgl_graph_1,
+    idx_type,
     max_in_degree,
     num_bases,
     regularizer,
@@ -42,6 +42,12 @@ def test_relgraphconv_equality(
     from dgl.nn.pytorch import RelGraphConv

     torch.manual_seed(12345)
+    device = torch.device("cuda:0")
+    g = dgl_graph_1.to(device).astype(idx_type)
+
+    if to_block:
+        g = dgl.to_block(g)
+
     in_feat, out_feat, num_rels = 10, 2, 3
     args = (in_feat, out_feat, num_rels)
     kwargs = {
@@ -50,16 +56,10 @@ def test_relgraphconv_equality(
         "bias": False,
         "self_loop": self_loop,
     }
-    g = create_graph1().to("cuda")
-    g.edata[dgl.ETYPE] = torch.randint(num_rels, (g.num_edges(),)).cuda()
-
-    if idtype_int:
-        g = g.int()
-    if to_block:
-        g = dgl.to_block(g)
-
+    g.edata[dgl.ETYPE] = torch.randint(num_rels, (g.num_edges(),)).to(device)
     size = (g.num_src_nodes(), g.num_dst_nodes())
-    feat = torch.rand(g.num_src_nodes(), in_feat).cuda()
+    feat = torch.rand(g.num_src_nodes(), in_feat).to(device)

     if sparse_format == "coo":
         sg = SparseGraph(
@@ -76,18 +76,18 @@ def test_relgraphconv_equality(
             size=size, src_ids=indices, cdst_ids=offsets, values=etypes, formats="csc"
         )

-    conv1 = RelGraphConv(*args, **kwargs).cuda()
-    conv2 = CuGraphRelGraphConv(*args, **kwargs, apply_norm=False).cuda()
+    conv1 = RelGraphConv(*args, **kwargs).to(device)
+    conv2 = CuGraphRelGraphConv(*args, **kwargs, apply_norm=False).to(device)

     with torch.no_grad():
         if self_loop:
-            conv2.W.data[:-1] = conv1.linear_r.W.data
-            conv2.W.data[-1] = conv1.loop_weight.data
+            conv2.W[:-1].copy_(conv1.linear_r.W)
+            conv2.W[-1].copy_(conv1.loop_weight)
         else:
-            conv2.W.data = conv1.linear_r.W.data.detach().clone()
+            conv2.W.copy_(conv1.linear_r.W)

         if regularizer is not None:
-            conv2.coeff.data = conv1.linear_r.coeff.data.detach().clone()
+            conv2.coeff.copy_(conv1.linear_r.coeff)

     out1 = conv1(g, feat, g.edata[dgl.ETYPE])

@@ -98,7 +98,7 @@ def test_relgraphconv_equality(

     assert torch.allclose(out1, out2, atol=ATOL)

-    grad_out = torch.rand_like(out1)
+    grad_out = torch.randn_like(out1)
     out1.backward(grad_out)
     out2.backward(grad_out)
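The weight-syncing rewrite here is more than cosmetic: assigning through .data bypasses autograd's version tracking and is discouraged, whereas calling Tensor.copy_ on the parameter inside torch.no_grad() performs the same in-place write through a supported API. A minimal sketch of the pattern, with made-up modules in place of the conv layers:

import torch
import torch.nn as nn

src = nn.Linear(4, 2)  # reference module
dst = nn.Linear(4, 2)  # module to initialize identically

# Copy weights in place without recording the writes on the autograd tape.
with torch.no_grad():
    dst.weight.copy_(src.weight)
    dst.bias.copy_(src.bias)

assert torch.equal(dst.weight, src.weight)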
34 changes: 16 additions & 18 deletions python/cugraph-dgl/tests/nn/test_sageconv.py
@@ -15,7 +15,6 @@

 from cugraph_dgl.nn.conv.base import SparseGraph
 from cugraph_dgl.nn import SAGEConv as CuGraphSAGEConv
-from .common import create_graph1

 dgl = pytest.importorskip("dgl", reason="DGL not available")
 torch = pytest.importorskip("torch", reason="PyTorch not available")
@@ -26,21 +25,19 @@
 @pytest.mark.parametrize("aggr", ["mean", "pool"])
 @pytest.mark.parametrize("bias", [False, True])
 @pytest.mark.parametrize("bipartite", [False, True])
-@pytest.mark.parametrize("idtype_int", [False, True])
+@pytest.mark.parametrize("idx_type", [torch.int32, torch.int64])
 @pytest.mark.parametrize("max_in_degree", [None, 8])
 @pytest.mark.parametrize("to_block", [False, True])
 @pytest.mark.parametrize("sparse_format", ["coo", "csc", None])
 def test_sageconv_equality(
-    aggr, bias, bipartite, idtype_int, max_in_degree, to_block, sparse_format
+    dgl_graph_1, aggr, bias, bipartite, idx_type, max_in_degree, to_block, sparse_format
 ):
     from dgl.nn.pytorch import SAGEConv

     torch.manual_seed(12345)
-    kwargs = {"aggregator_type": aggr, "bias": bias}
-    g = create_graph1().to("cuda")
+    device = torch.device("cuda:0")
+    g = dgl_graph_1.to(device).astype(idx_type)

-    if idtype_int:
-        g = g.int()
     if to_block:
         g = dgl.to_block(g)

@@ -49,12 +46,12 @@ def test_sageconv_equality(
     if bipartite:
         in_feats = (5, 3)
         feat = (
-            torch.rand(size[0], in_feats[0], requires_grad=True).cuda(),
-            torch.rand(size[1], in_feats[1], requires_grad=True).cuda(),
+            torch.rand(size[0], in_feats[0], requires_grad=True).to(device),
+            torch.rand(size[1], in_feats[1], requires_grad=True).to(device),
         )
     else:
         in_feats = 5
-        feat = torch.rand(size[0], in_feats).cuda()
+        feat = torch.rand(size[0], in_feats).to(device)
     out_feats = 2

@@ -65,18 +62,19 @@ def test_sageconv_equality(
         offsets, indices, _ = g.adj_tensors("csc")
         sg = SparseGraph(size=size, src_ids=indices, cdst_ids=offsets, formats="csc")

-    conv1 = SAGEConv(in_feats, out_feats, **kwargs).cuda()
-    conv2 = CuGraphSAGEConv(in_feats, out_feats, **kwargs).cuda()
+    kwargs = {"aggregator_type": aggr, "bias": bias}
+    conv1 = SAGEConv(in_feats, out_feats, **kwargs).to(device)
+    conv2 = CuGraphSAGEConv(in_feats, out_feats, **kwargs).to(device)

     in_feats_src = conv2.in_feats_src
     with torch.no_grad():
-        conv2.lin.weight.data[:, :in_feats_src] = conv1.fc_neigh.weight.data
-        conv2.lin.weight.data[:, in_feats_src:] = conv1.fc_self.weight.data
+        conv2.lin.weight[:, :in_feats_src].copy_(conv1.fc_neigh.weight)
+        conv2.lin.weight[:, in_feats_src:].copy_(conv1.fc_self.weight)
         if bias:
-            conv2.lin.bias.data[:] = conv1.fc_self.bias.data
+            conv2.lin.bias.copy_(conv1.fc_self.bias)
         if aggr == "pool":
-            conv2.pre_lin.weight.data[:] = conv1.fc_pool.weight.data
-            conv2.pre_lin.bias.data[:] = conv1.fc_pool.bias.data
+            conv2.pre_lin.weight.copy_(conv1.fc_pool.weight)
+            conv2.pre_lin.bias.copy_(conv1.fc_pool.bias)

     out1 = conv1(g, feat)
     if sparse_format is not None:
@@ -85,7 +83,7 @@ def test_sageconv_equality(
         out2 = conv2(g, feat, max_in_degree=max_in_degree)
     assert torch.allclose(out1, out2, atol=ATOL)

-    grad_out = torch.rand_like(out1)
+    grad_out = torch.randn_like(out1)
     out1.backward(grad_out)
     out2.backward(grad_out)
     assert torch.allclose(
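Swapping torch.rand_like for torch.randn_like changes the synthetic upstream gradient from uniform over [0, 1) to standard normal. A plausible motivation, though the commit does not state one, is that sign-mixed gradients exercise the backward kernels more thoroughly than an all-positive gradient would:

import torch

out = torch.zeros(3, 2)
uniform = torch.rand_like(out)   # uniform in [0, 1): never negative
normal = torch.randn_like(out)   # standard normal: zero mean, both signs

assert (uniform >= 0).all()
# A sign-sensitive bug in a backward kernel is more likely to surface
# when the upstream gradient contains negative entries.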
17 changes: 7 additions & 10 deletions python/cugraph-dgl/tests/nn/test_transformerconv.py
@@ -15,7 +15,6 @@

 from cugraph_dgl.nn.conv.base import SparseGraph
 from cugraph_dgl.nn import TransformerConv
-from .common import create_graph1

 dgl = pytest.importorskip("dgl", reason="DGL not available")
 torch = pytest.importorskip("torch", reason="PyTorch not available")
@@ -26,27 +25,25 @@
 @pytest.mark.parametrize("beta", [False, True])
 @pytest.mark.parametrize("bipartite_node_feats", [False, True])
 @pytest.mark.parametrize("concat", [False, True])
-@pytest.mark.parametrize("idtype_int", [False, True])
-@pytest.mark.parametrize("num_heads", [1, 2, 3, 4])
+@pytest.mark.parametrize("idx_type", [torch.int32, torch.int64])
+@pytest.mark.parametrize("num_heads", [1, 3, 4])
 @pytest.mark.parametrize("to_block", [False, True])
 @pytest.mark.parametrize("use_edge_feats", [False, True])
 @pytest.mark.parametrize("sparse_format", ["coo", "csc", None])
 def test_transformerconv(
+    dgl_graph_1,
     beta,
     bipartite_node_feats,
     concat,
-    idtype_int,
+    idx_type,
     num_heads,
     to_block,
     use_edge_feats,
     sparse_format,
 ):
     torch.manual_seed(12345)
-    device = "cuda"
-    g = create_graph1().to(device)
-
-    if idtype_int:
-        g = g.int()
+    device = torch.device("cuda:0")
+    g = dgl_graph_1.to(device).astype(idx_type)

     if to_block:
         g = dgl.to_block(g)
@@ -92,5 +89,5 @@ def test_transformerconv(
     else:
         out = conv(g, nfeat, efeat)

-    grad_out = torch.rand_like(out)
+    grad_out = torch.randn_like(out)
     out.backward(grad_out)
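Across all three test files, the tests now receive dgl_graph_1 as a pytest fixture instead of importing create_graph1 from .common. The fixture definition is not part of this diff; a hypothetical conftest.py sketch, in which every detail (file location, graph size, edge list) is an assumption, might look like:

# conftest.py -- hypothetical sketch; the real fixture lives elsewhere in the repo
import pytest

dgl = pytest.importorskip("dgl", reason="DGL not available")
torch = pytest.importorskip("torch", reason="PyTorch not available")


@pytest.fixture
def dgl_graph_1():
    # A small fixed graph so every parametrized case sees the same topology.
    src = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    dst = torch.tensor([1, 2, 3, 4, 0, 6, 7, 8, 9, 5])
    return dgl.graph((src, dst))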
