From 5dc216e68118ab36fcdb139246c5d7f234ee76ab Mon Sep 17 00:00:00 2001 From: Xuangui Huang Date: Wed, 28 Aug 2024 12:11:33 -0700 Subject: [PATCH] fix linting issue --- test/test_radius.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/test/test_radius.py b/test/test_radius.py index cd42b70..4289bfc 100644 --- a/test/test_radius.py +++ b/test/test_radius.py @@ -10,13 +10,16 @@ def to_set(edge_index): return set([(i, j) for i, j in edge_index.t().tolist()]) + def to_degree(edge_index): _, counts = torch.unique(edge_index[1], return_counts=True) return counts.tolist() + def to_batch(nodes): return [int(i / 4) for i in nodes] + @pytest.mark.parametrize('dtype,device', product(floating_dtypes, devices)) def test_radius(dtype, device): x = tensor([ @@ -79,8 +82,9 @@ def test_radius_graph(dtype, device): edge_index = jit(x, r=2.5, flow='source_to_target') assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), (3, 2), (0, 3), (2, 3)]) - - edge_index = radius_graph(x, r=100, flow='source_to_target', max_num_neighbors=1) + + edge_index = radius_graph(x, r=100, flow='source_to_target', + max_num_neighbors=1) assert set(to_degree(edge_index)) == set([1]) x = tensor([ @@ -90,7 +94,8 @@ def test_radius_graph(dtype, device): [-1, -1], ], dtype, device) - edge_index = radius_graph(x, r=100, flow='source_to_target', max_num_neighbors=1) + edge_index = radius_graph(x, r=100, flow='source_to_target', + max_num_neighbors=1) assert set(to_degree(edge_index)) == set([1]) x = tensor([ @@ -105,10 +110,12 @@ def test_radius_graph(dtype, device): ], dtype, device) batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device) - edge_index = radius_graph(x, r=100, batch=batch_x, flow='source_to_target', max_num_neighbors=1) + edge_index = radius_graph(x, r=100, batch=batch_x, flow='source_to_target', + max_num_neighbors=1) assert set(to_degree(edge_index)) == set([1]) assert to_batch(edge_index[0]) == batch_x.tolist() + @pytest.mark.parametrize('dtype,device', product([torch.float], devices)) def test_radius_graph_large(dtype, device): x = torch.randn(1000, 3, dtype=dtype, device=device)