Skip to content

Commit

Permalink
fix manual_seed for cugraph-pyg tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tingyu66 committed Sep 21, 2023
1 parent 51dccd2 commit f5b5caf
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_gat_conv_equality(
import torch
from torch_geometric.nn import GATConv

torch.manual_seed(12345)
edge_index, size = request.getfixturevalue(graph)
edge_index = edge_index.cuda()

Expand Down
1 change: 1 addition & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def test_gatv2_conv_equality(bipartite, concat, heads, use_edge_attr, graph, req
import torch
from torch_geometric.nn import GATv2Conv

torch.manual_seed(12345)
edge_index, size = request.getfixturevalue(graph)
edge_index = edge_index.cuda()

Expand Down
1 change: 1 addition & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/nn/test_rgcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_rgcn_conv_equality(
import torch
from torch_geometric.nn import FastRGCNConv as RGCNConv

torch.manual_seed(12345)
in_channels, out_channels, num_relations = (4, 2, 3)
kwargs = dict(aggr=aggr, bias=bias, num_bases=num_bases, root_weight=root_weight)

Expand Down
1 change: 1 addition & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/nn/test_sage_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_sage_conv_equality(
import torch
from torch_geometric.nn import SAGEConv

torch.manual_seed(12345)
edge_index, size = request.getfixturevalue(graph)
edge_index = edge_index.cuda()
csc = CuGraphSAGEConv.to_csc(edge_index, size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def test_transformer_conv_equality(bipartite, concat, heads, graph, request):
import torch
from torch_geometric.nn import TransformerConv

torch.manual_seed(12345)
edge_index, size = request.getfixturevalue(graph)
edge_index = edge_index.cuda()
csc = CuGraphTransformerConv.to_csc(edge_index, size)
Expand Down

0 comments on commit f5b5caf

Please sign in to comment.