Fix torch seed in cugraph-dgl and -pyg tests for conv layers (#3869)
Fixes rapidsai/graph_dl#325

Recently, a few CI runs (e.g. [1](https://github.com/rapidsai/cugraph/actions/runs/6254253684/job/16983164330?pr=3828#step:7:5078), [2](https://github.com/rapidsai/cugraph/actions/runs/6224345348/job/16896416094?pr=3843)) failed when comparing results from cugraph-ops-based conv layers against results from the upstream frameworks. The tests pass most of the time but occasionally fail due to the combination of a strict tolerance and floating-point error. This PR fixes the seed used to generate the random feature tensors so that CI behaves consistently across runs.
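The pattern applied across all ten test files is sketched below; the tensor shape and the tolerance are illustrative assumptions, not the exact test code. Seeding torch's global RNG at the top of a test makes every random feature tensor, and therefore every tolerance comparison, identical from run to run:

import torch

def test_conv_equality_sketch():
    # Fix the global RNG so each CI run draws identical random features.
    torch.manual_seed(12345)
    feat = torch.rand(34, 16, device="cuda")  # illustrative shape
    # ... build the cugraph-ops layer and the upstream layer, run both on
    # `feat`, then compare under a strict tolerance, e.g.:
    # assert torch.allclose(out1, out2, atol=1e-6)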

Authors:
  - Tingyu Wang (https://github.com/tingyu66)

Approvers:
  - Alex Barghi (https://github.com/alexbarghi-nv)

URL: #3869
tingyu66 authored Sep 22, 2023
1 parent 367f36c commit f53bb56
Showing 10 changed files with 22 additions and 4 deletions.
2 changes: 2 additions & 0 deletions python/cugraph-dgl/tests/nn/test_gatconv.py
@@ -35,6 +35,7 @@ def test_gatconv_equality(
 ):
     from dgl.nn.pytorch import GATConv
 
+    torch.manual_seed(12345)
     g = create_graph1().to("cuda")
 
     if idtype_int:
@@ -121,6 +122,7 @@ def test_gatconv_equality(
 def test_gatconv_edge_feats(
     bias, bipartite, concat, max_in_degree, num_heads, to_block, use_edge_feats
 ):
+    torch.manual_seed(12345)
     g = create_graph1().to("cuda")
 
     if to_block:
2 changes: 2 additions & 0 deletions python/cugraph-dgl/tests/nn/test_gatv2conv.py
@@ -35,6 +35,7 @@ def test_gatv2conv_equality(
 ):
     from dgl.nn.pytorch import GATv2Conv
 
+    torch.manual_seed(12345)
     g = create_graph1().to("cuda")
 
     if idtype_int:
@@ -109,6 +110,7 @@ def test_gatv2conv_equality(
 def test_gatv2conv_edge_feats(
     bias, bipartite, concat, max_in_degree, num_heads, to_block, use_edge_feats
 ):
+    torch.manual_seed(12345)
     g = create_graph1().to("cuda")
 
     if to_block:
15 changes: 11 additions & 4 deletions python/cugraph-dgl/tests/nn/test_relgraphconv.py
@@ -41,6 +41,7 @@ def test_relgraphconv_equality(
 ):
     from dgl.nn.pytorch import RelGraphConv
 
+    torch.manual_seed(12345)
     in_feat, out_feat, num_rels = 10, 2, 3
     args = (in_feat, out_feat, num_rels)
     kwargs = {
@@ -75,12 +76,18 @@ def test_relgraphconv_equality(
         size=size, src_ids=indices, cdst_ids=offsets, values=etypes, formats="csc"
     )
 
-    torch.manual_seed(0)
     conv1 = RelGraphConv(*args, **kwargs).cuda()
+    conv2 = CuGraphRelGraphConv(*args, **kwargs, apply_norm=False).cuda()
 
-    torch.manual_seed(0)
-    kwargs["apply_norm"] = False
-    conv2 = CuGraphRelGraphConv(*args, **kwargs).cuda()
+    with torch.no_grad():
+        if self_loop:
+            conv2.W.data[:-1] = conv1.linear_r.W.data
+            conv2.W.data[-1] = conv1.loop_weight.data
+        else:
+            conv2.W.data = conv1.linear_r.W.data.detach().clone()
+
+        if regularizer is not None:
+            conv2.coeff.data = conv1.linear_r.coeff.data.detach().clone()
 
     out1 = conv1(g, feat, g.edata[dgl.ETYPE])
 
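This file also drops the second torch.manual_seed(0) call that previously tried to make conv2 initialize exactly like conv1, and instead copies conv1's parameters into conv2 under torch.no_grad(). A minimal sketch of why seed-matching is fragile, using plain PyTorch layers purely for illustration: two modules draw identical initial weights only if they consume the RNG in exactly the same order.

import torch

torch.manual_seed(0)
a = torch.nn.Linear(4, 2)

torch.manual_seed(0)
b = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))

# b's second layer initializes from an RNG state already advanced by the
# first layer, so re-seeding cannot align it with `a`; explicitly copying
# parameters, as the diff above does, is deterministic.
assert not torch.equal(a.weight, b[1].weight)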
1 change: 1 addition & 0 deletions python/cugraph-dgl/tests/nn/test_sageconv.py
@@ -35,6 +35,7 @@ def test_sageconv_equality(
 ):
     from dgl.nn.pytorch import SAGEConv
 
+    torch.manual_seed(12345)
     kwargs = {"aggregator_type": aggr, "bias": bias}
     g = create_graph1().to("cuda")
 
1 change: 1 addition & 0 deletions python/cugraph-dgl/tests/nn/test_transformerconv.py
@@ -41,6 +41,7 @@ def test_transformerconv(
     use_edge_feats,
     sparse_format,
 ):
+    torch.manual_seed(12345)
     device = "cuda"
     g = create_graph1().to(device)
 
1 change: 1 addition & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py
@@ -32,6 +32,7 @@ def test_gat_conv_equality(
     import torch
     from torch_geometric.nn import GATConv
 
+    torch.manual_seed(12345)
     edge_index, size = request.getfixturevalue(graph)
     edge_index = edge_index.cuda()
 
1 change: 1 addition & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py
@@ -28,6 +28,7 @@ def test_gatv2_conv_equality(bipartite, concat, heads, use_edge_attr, graph, request):
     import torch
     from torch_geometric.nn import GATv2Conv
 
+    torch.manual_seed(12345)
     edge_index, size = request.getfixturevalue(graph)
     edge_index = edge_index.cuda()
 
1 change: 1 addition & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/nn/test_rgcn_conv.py
@@ -31,6 +31,7 @@ def test_rgcn_conv_equality(
     import torch
     from torch_geometric.nn import FastRGCNConv as RGCNConv
 
+    torch.manual_seed(12345)
     in_channels, out_channels, num_relations = (4, 2, 3)
     kwargs = dict(aggr=aggr, bias=bias, num_bases=num_bases, root_weight=root_weight)
 
1 change: 1 addition & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/nn/test_sage_conv.py
@@ -32,6 +32,7 @@ def test_sage_conv_equality(
     import torch
     from torch_geometric.nn import SAGEConv
 
+    torch.manual_seed(12345)
     edge_index, size = request.getfixturevalue(graph)
     edge_index = edge_index.cuda()
     csc = CuGraphSAGEConv.to_csc(edge_index, size)
1 change: 1 addition & 0 deletions python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py
@@ -27,6 +27,7 @@ def test_transformer_conv_equality(bipartite, concat, heads, graph, request):
     import torch
     from torch_geometric.nn import TransformerConv
 
+    torch.manual_seed(12345)
     edge_index, size = request.getfixturevalue(graph)
     edge_index = edge_index.cuda()
     csc = CuGraphTransformerConv.to_csc(edge_index, size)
