Skip to content

Commit

Permalink
fix linting issue
Browse files Browse the repository at this point in the history
  • Loading branch information
stslxg-nv committed Aug 28, 2024
1 parent 9efe68e commit 5dc216e
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions test/test_radius.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,16 @@
def to_set(edge_index):
return set([(i, j) for i, j in edge_index.t().tolist()])


def to_degree(edge_index):
_, counts = torch.unique(edge_index[1], return_counts=True)
return counts.tolist()


def to_batch(nodes):
return [int(i / 4) for i in nodes]


@pytest.mark.parametrize('dtype,device', product(floating_dtypes, devices))
def test_radius(dtype, device):
x = tensor([
Expand Down Expand Up @@ -79,8 +82,9 @@ def test_radius_graph(dtype, device):
edge_index = jit(x, r=2.5, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])

edge_index = radius_graph(x, r=100, flow='source_to_target', max_num_neighbors=1)

edge_index = radius_graph(x, r=100, flow='source_to_target',
max_num_neighbors=1)
assert set(to_degree(edge_index)) == set([1])

x = tensor([
Expand All @@ -90,7 +94,8 @@ def test_radius_graph(dtype, device):
[-1, -1],
], dtype, device)

edge_index = radius_graph(x, r=100, flow='source_to_target', max_num_neighbors=1)
edge_index = radius_graph(x, r=100, flow='source_to_target',
max_num_neighbors=1)
assert set(to_degree(edge_index)) == set([1])

x = tensor([
Expand All @@ -105,10 +110,12 @@ def test_radius_graph(dtype, device):
], dtype, device)
batch_x = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device)

edge_index = radius_graph(x, r=100, batch=batch_x, flow='source_to_target', max_num_neighbors=1)
edge_index = radius_graph(x, r=100, batch=batch_x, flow='source_to_target',
max_num_neighbors=1)
assert set(to_degree(edge_index)) == set([1])
assert to_batch(edge_index[0]) == batch_x.tolist()


@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_radius_graph_large(dtype, device):
x = torch.randn(1000, 3, dtype=dtype, device=device)
Expand Down

0 comments on commit 5dc216e

Please sign in to comment.