Skip to content

Commit

Permalink
fix manual_seed for cugraph-dgl tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tingyu66 committed Sep 21, 2023
1 parent d930321 commit 51dccd2
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 4 deletions.
2 changes: 2 additions & 0 deletions python/cugraph-dgl/tests/nn/test_gatconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_gatconv_equality(
):
from dgl.nn.pytorch import GATConv

torch.manual_seed(12345)
g = create_graph1().to("cuda")

if idtype_int:
Expand Down Expand Up @@ -121,6 +122,7 @@ def test_gatconv_equality(
def test_gatconv_edge_feats(
bias, bipartite, concat, max_in_degree, num_heads, to_block, use_edge_feats
):
torch.manual_seed(12345)
g = create_graph1().to("cuda")

if to_block:
Expand Down
2 changes: 2 additions & 0 deletions python/cugraph-dgl/tests/nn/test_gatv2conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_gatv2conv_equality(
):
from dgl.nn.pytorch import GATv2Conv

torch.manual_seed(12345)
g = create_graph1().to("cuda")

if idtype_int:
Expand Down Expand Up @@ -109,6 +110,7 @@ def test_gatv2conv_equality(
def test_gatv2conv_edge_feats(
bias, bipartite, concat, max_in_degree, num_heads, to_block, use_edge_feats
):
torch.manual_seed(12345)
g = create_graph1().to("cuda")

if to_block:
Expand Down
15 changes: 11 additions & 4 deletions python/cugraph-dgl/tests/nn/test_relgraphconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_relgraphconv_equality(
):
from dgl.nn.pytorch import RelGraphConv

torch.manual_seed(12345)
in_feat, out_feat, num_rels = 10, 2, 3
args = (in_feat, out_feat, num_rels)
kwargs = {
Expand Down Expand Up @@ -75,12 +76,18 @@ def test_relgraphconv_equality(
size=size, src_ids=indices, cdst_ids=offsets, values=etypes, formats="csc"
)

torch.manual_seed(0)
conv1 = RelGraphConv(*args, **kwargs).cuda()
conv2 = CuGraphRelGraphConv(*args, **kwargs, apply_norm=False).cuda()

torch.manual_seed(0)
kwargs["apply_norm"] = False
conv2 = CuGraphRelGraphConv(*args, **kwargs).cuda()
with torch.no_grad():
if self_loop:
conv2.W.data[:-1] = conv1.linear_r.W.data
conv2.W.data[-1] = conv1.loop_weight.data
else:
conv2.W.data = conv1.linear_r.W.data.detach().clone()

if regularizer is not None:
conv2.coeff.data = conv1.linear_r.coeff.data.detach().clone()

out1 = conv1(g, feat, g.edata[dgl.ETYPE])

Expand Down
1 change: 1 addition & 0 deletions python/cugraph-dgl/tests/nn/test_sageconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def test_sageconv_equality(
):
from dgl.nn.pytorch import SAGEConv

torch.manual_seed(12345)
kwargs = {"aggregator_type": aggr, "bias": bias}
g = create_graph1().to("cuda")

Expand Down
1 change: 1 addition & 0 deletions python/cugraph-dgl/tests/nn/test_transformerconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_transformerconv(
use_edge_feats,
sparse_format,
):
torch.manual_seed(12345)
device = "cuda"
g = create_graph1().to(device)

Expand Down

0 comments on commit 51dccd2

Please sign in to comment.