diff --git a/test/test_graclus.py b/test/test_graclus.py index b892330..c8e1f39 100644 --- a/test/test_graclus.py +++ b/test/test_graclus.py @@ -50,3 +50,7 @@ def test_graclus_cluster(test, dtype, device): cluster = graclus_cluster(row, col, weight) assert_correct(row, col, cluster) + + jit = torch.jit.script(graclus_cluster) + cluster = jit(row, col, weight) + assert_correct(row, col, cluster) diff --git a/test/test_grid.py b/test/test_grid.py index c297f33..2d53220 100644 --- a/test/test_grid.py +++ b/test/test_grid.py @@ -38,3 +38,6 @@ def test_grid_cluster(test, dtype, device): cluster = grid_cluster(pos, size, start, end) assert cluster.tolist() == test['cluster'] + + jit = torch.jit.script(grid_cluster) + assert torch.equal(jit(pos, size, start, end), cluster) diff --git a/test/test_knn.py b/test/test_knn.py index 8113a54..32852fe 100644 --- a/test/test_knn.py +++ b/test/test_knn.py @@ -34,6 +34,10 @@ def test_knn(dtype, device): edge_index = knn(x, y, 2) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)]) + jit = torch.jit.script(knn) + edge_index = jit(x, y, 2) + assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)]) + edge_index = knn(x, y, 2, batch_x, batch_y) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) @@ -65,6 +69,11 @@ def test_knn_graph(dtype, device): assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), (3, 2), (0, 3), (2, 3)]) + jit = torch.jit.script(knn_graph) + edge_index = jit(x, k=2, flow='source_to_target') + assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), + (3, 2), (0, 3), (2, 3)]) + @pytest.mark.parametrize('dtype,device', product([torch.float], devices)) def test_knn_graph_large(dtype, device): diff --git a/test/test_radius.py b/test/test_radius.py index 34c4ad9..078412f 100644 --- a/test/test_radius.py +++ b/test/test_radius.py @@ -35,6 +35,11 @@ def test_radius(dtype, device): assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1), (1, 2), (1, 5), (1, 6)]) + jit = torch.jit.script(radius) + edge_index = jit(x, y, 2, max_num_neighbors=4) + assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1), + (1, 2), (1, 5), (1, 6)]) + edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4) assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5), (1, 6)]) @@ -64,12 +69,20 @@ def test_radius_graph(dtype, device): assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), (3, 2), (0, 3), (2, 3)]) + jit = torch.jit.script(radius_graph) + edge_index = jit(x, r=2.5, flow='source_to_target') + assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), + (3, 2), (0, 3), (2, 3)]) + @pytest.mark.parametrize('dtype,device', product([torch.float], devices)) def test_radius_graph_large(dtype, device): x = torch.randn(1000, 3, dtype=dtype, device=device) - edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True, + edge_index = radius_graph(x, + r=0.5, + flow='target_to_source', + loop=True, max_num_neighbors=2000) tree = scipy.spatial.cKDTree(x.cpu().numpy()) diff --git a/test/test_rw.py b/test/test_rw.py index 67d935d..82a8b77 100644 --- a/test/test_rw.py +++ b/test/test_rw.py @@ -31,6 +31,9 @@ def test_rw_small(device): out = random_walk(row, col, start, walk_length, num_nodes=3) assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]] + jit = torch.jit.script(random_walk) + assert torch.equal(jit(row, col, start, walk_length, num_nodes=3), out) + @pytest.mark.parametrize('device', devices) def test_rw_large_with_edge_indices(device): diff --git a/torch_cluster/graclus.py b/torch_cluster/graclus.py index cdac892..5a7b880 100644 --- a/torch_cluster/graclus.py +++ b/torch_cluster/graclus.py @@ -3,8 +3,8 @@ import torch -@torch.jit.script -def graclus_cluster(row: torch.Tensor, col: torch.Tensor, +def graclus_cluster(row: torch.Tensor, + col: torch.Tensor, weight: Optional[torch.Tensor] = None, num_nodes: Optional[int] = None) -> torch.Tensor: """A greedy clustering algorithm of picking an unmarked vertex and matching diff --git a/torch_cluster/grid.py b/torch_cluster/grid.py index 1dbacb9..bf19949 100644 --- a/torch_cluster/grid.py +++ b/torch_cluster/grid.py @@ -3,8 +3,8 @@ import torch -@torch.jit.script -def grid_cluster(pos: torch.Tensor, size: torch.Tensor, +def grid_cluster(pos: torch.Tensor, + size: torch.Tensor, start: Optional[torch.Tensor] = None, end: Optional[torch.Tensor] = None) -> torch.Tensor: """A clustering algorithm, which overlays a regular grid of user-defined diff --git a/torch_cluster/knn.py b/torch_cluster/knn.py index 4eace5e..cf8f087 100644 --- a/torch_cluster/knn.py +++ b/torch_cluster/knn.py @@ -3,7 +3,6 @@ import torch -@torch.jit.script def knn( x: torch.Tensor, y: torch.Tensor, @@ -83,7 +82,6 @@ def knn( num_workers) -@torch.jit.script def knn_graph( x: torch.Tensor, k: int, diff --git a/torch_cluster/radius.py b/torch_cluster/radius.py index de35298..069824a 100644 --- a/torch_cluster/radius.py +++ b/torch_cluster/radius.py @@ -3,7 +3,6 @@ import torch -@torch.jit.script def radius( x: torch.Tensor, y: torch.Tensor, @@ -84,7 +83,6 @@ def radius( max_num_neighbors, num_workers) -@torch.jit.script def radius_graph( x: torch.Tensor, r: float, diff --git a/torch_cluster/rw.py b/torch_cluster/rw.py index 12e0683..cb7bc2c 100644 --- a/torch_cluster/rw.py +++ b/torch_cluster/rw.py @@ -4,7 +4,6 @@ from torch import Tensor -@torch.jit.script def random_walk( row: Tensor, col: Tensor, @@ -55,7 +54,12 @@ def random_walk( torch.cumsum(deg, 0, out=rowptr[1:]) node_seq, edge_seq = torch.ops.torch_cluster.random_walk( - rowptr, col, start, walk_length, p, q, + rowptr, + col, + start, + walk_length, + p, + q, ) if return_edge_indices: diff --git a/torch_cluster/sampler.py b/torch_cluster/sampler.py index 9d2e08e..1b68de0 100644 --- a/torch_cluster/sampler.py +++ b/torch_cluster/sampler.py @@ -1,7 +1,6 @@ import torch -@torch.jit.script def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float): assert not start.is_cuda diff --git a/torch_cluster/typing.py b/torch_cluster/typing.py index d570684..f57544a 100644 --- a/torch_cluster/typing.py +++ b/torch_cluster/typing.py @@ -1,3 +1,6 @@ import torch -WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list') +try: + WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list') +except Exception: + WITH_PTR_LIST = False