Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Oct 11, 2023
1 parent 89b74f0 commit 8453c82
Show file tree
Hide file tree
Showing 12 changed files with 47 additions and 13 deletions.
4 changes: 4 additions & 0 deletions test/test_graclus.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,7 @@ def test_graclus_cluster(test, dtype, device):

cluster = graclus_cluster(row, col, weight)
assert_correct(row, col, cluster)

jit = torch.jit.script(graclus_cluster)
cluster = jit(row, col, weight)
assert_correct(row, col, cluster)
3 changes: 3 additions & 0 deletions test/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,6 @@ def test_grid_cluster(test, dtype, device):

cluster = grid_cluster(pos, size, start, end)
assert cluster.tolist() == test['cluster']

jit = torch.jit.script(grid_cluster)
assert torch.equal(jit(pos, size, start, end), cluster)
9 changes: 9 additions & 0 deletions test/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def test_knn(dtype, device):
edge_index = knn(x, y, 2)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])

jit = torch.jit.script(knn)
edge_index = jit(x, y, 2)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)])

edge_index = knn(x, y, 2, batch_x, batch_y)
assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)])

Expand Down Expand Up @@ -65,6 +69,11 @@ def test_knn_graph(dtype, device):
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])

jit = torch.jit.script(knn_graph)
edge_index = jit(x, k=2, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])


@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_knn_graph_large(dtype, device):
Expand Down
15 changes: 14 additions & 1 deletion test/test_radius.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def test_radius(dtype, device):
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
(1, 2), (1, 5), (1, 6)])

jit = torch.jit.script(radius)
edge_index = jit(x, y, 2, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1),
(1, 2), (1, 5), (1, 6)])

edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4)
assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5),
(1, 6)])
Expand Down Expand Up @@ -64,12 +69,20 @@ def test_radius_graph(dtype, device):
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])

jit = torch.jit.script(radius_graph)
edge_index = jit(x, r=2.5, flow='source_to_target')
assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),
(3, 2), (0, 3), (2, 3)])


@pytest.mark.parametrize('dtype,device', product([torch.float], devices))
def test_radius_graph_large(dtype, device):
x = torch.randn(1000, 3, dtype=dtype, device=device)

edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True,
edge_index = radius_graph(x,
r=0.5,
flow='target_to_source',
loop=True,
max_num_neighbors=2000)

tree = scipy.spatial.cKDTree(x.cpu().numpy())
Expand Down
3 changes: 3 additions & 0 deletions test/test_rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def test_rw_small(device):
out = random_walk(row, col, start, walk_length, num_nodes=3)
assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]]

jit = torch.jit.script(random_walk)
assert torch.equal(jit(row, col, start, walk_length, num_nodes=3), out)


@pytest.mark.parametrize('device', devices)
def test_rw_large_with_edge_indices(device):
Expand Down
4 changes: 2 additions & 2 deletions torch_cluster/graclus.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch


@torch.jit.script
def graclus_cluster(row: torch.Tensor, col: torch.Tensor,
def graclus_cluster(row: torch.Tensor,
col: torch.Tensor,
weight: Optional[torch.Tensor] = None,
num_nodes: Optional[int] = None) -> torch.Tensor:
"""A greedy clustering algorithm of picking an unmarked vertex and matching
Expand Down
4 changes: 2 additions & 2 deletions torch_cluster/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch


@torch.jit.script
def grid_cluster(pos: torch.Tensor, size: torch.Tensor,
def grid_cluster(pos: torch.Tensor,
size: torch.Tensor,
start: Optional[torch.Tensor] = None,
end: Optional[torch.Tensor] = None) -> torch.Tensor:
"""A clustering algorithm, which overlays a regular grid of user-defined
Expand Down
2 changes: 0 additions & 2 deletions torch_cluster/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch


@torch.jit.script
def knn(
x: torch.Tensor,
y: torch.Tensor,
Expand Down Expand Up @@ -83,7 +82,6 @@ def knn(
num_workers)


@torch.jit.script
def knn_graph(
x: torch.Tensor,
k: int,
Expand Down
2 changes: 0 additions & 2 deletions torch_cluster/radius.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch


@torch.jit.script
def radius(
x: torch.Tensor,
y: torch.Tensor,
Expand Down Expand Up @@ -84,7 +83,6 @@ def radius(
max_num_neighbors, num_workers)


@torch.jit.script
def radius_graph(
x: torch.Tensor,
r: float,
Expand Down
8 changes: 6 additions & 2 deletions torch_cluster/rw.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from torch import Tensor


@torch.jit.script
def random_walk(
row: Tensor,
col: Tensor,
Expand Down Expand Up @@ -55,7 +54,12 @@ def random_walk(
torch.cumsum(deg, 0, out=rowptr[1:])

node_seq, edge_seq = torch.ops.torch_cluster.random_walk(
rowptr, col, start, walk_length, p, q,
rowptr,
col,
start,
walk_length,
p,
q,
)

if return_edge_indices:
Expand Down
1 change: 0 additions & 1 deletion torch_cluster/sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch


@torch.jit.script
def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float):
assert not start.is_cuda

Expand Down
5 changes: 4 additions & 1 deletion torch_cluster/typing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import torch

WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list')
try:
WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list')
except Exception:
WITH_PTR_LIST = False

0 comments on commit 8453c82

Please sign in to comment.