From f5b5cafa451f0baf9c12d036687379e23e36e4c0 Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Thu, 21 Sep 2023 12:47:25 -0400 Subject: [PATCH] fix manual_seed for cugraph-pyg tests --- python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py | 1 + python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py | 1 + python/cugraph-pyg/cugraph_pyg/tests/nn/test_rgcn_conv.py | 1 + python/cugraph-pyg/cugraph_pyg/tests/nn/test_sage_conv.py | 1 + python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py | 1 + 5 files changed, 5 insertions(+) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py index 21c43bad38c..62bebb9211d 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gat_conv.py @@ -32,6 +32,7 @@ def test_gat_conv_equality( import torch from torch_geometric.nn import GATConv + torch.manual_seed(12345) edge_index, size = request.getfixturevalue(graph) edge_index = edge_index.cuda() diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py index 6b11e87154a..a4794628410 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_gatv2_conv.py @@ -28,6 +28,7 @@ def test_gatv2_conv_equality(bipartite, concat, heads, use_edge_attr, graph, req import torch from torch_geometric.nn import GATv2Conv + torch.manual_seed(12345) edge_index, size = request.getfixturevalue(graph) edge_index = edge_index.cuda() diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_rgcn_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_rgcn_conv.py index 233c6aa2836..ded4f300c0c 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_rgcn_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_rgcn_conv.py @@ -31,6 +31,7 @@ def test_rgcn_conv_equality( import torch from torch_geometric.nn import FastRGCNConv as RGCNConv + torch.manual_seed(12345) in_channels, out_channels, num_relations = (4, 2, 3) kwargs = dict(aggr=aggr, bias=bias, num_bases=num_bases, root_weight=root_weight) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_sage_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_sage_conv.py index 7f73cddbdbb..b2977d1d175 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_sage_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_sage_conv.py @@ -32,6 +32,7 @@ def test_sage_conv_equality( import torch from torch_geometric.nn import SAGEConv + torch.manual_seed(12345) edge_index, size = request.getfixturevalue(graph) edge_index = edge_index.cuda() csc = CuGraphSAGEConv.to_csc(edge_index, size) diff --git a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py index 7dba1a6d515..fbdb244898b 100644 --- a/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/tests/nn/test_transformer_conv.py @@ -27,6 +27,7 @@ def test_transformer_conv_equality(bipartite, concat, heads, graph, request): import torch from torch_geometric.nn import TransformerConv + torch.manual_seed(12345) edge_index, size = request.getfixturevalue(graph) edge_index = edge_index.cuda() csc = CuGraphTransformerConv.to_csc(edge_index, size)