From 8453c82e6da949561997f7ba1e9a3279024bfde2 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Wed, 11 Oct 2023 12:22:00 +0000 Subject: [PATCH] update --- test/test_graclus.py | 4 ++++ test/test_grid.py | 3 +++ test/test_knn.py | 9 +++++++++ test/test_radius.py | 15 ++++++++++++++- test/test_rw.py | 3 +++ torch_cluster/graclus.py | 4 ++-- torch_cluster/grid.py | 4 ++-- torch_cluster/knn.py | 2 -- torch_cluster/radius.py | 2 -- torch_cluster/rw.py | 8 ++++++-- torch_cluster/sampler.py | 1 - torch_cluster/typing.py | 5 ++++- 12 files changed, 47 insertions(+), 13 deletions(-) diff --git a/test/test_graclus.py b/test/test_graclus.py index b892330d..c8e1f39f 100644 --- a/test/test_graclus.py +++ b/test/test_graclus.py @@ -50,3 +50,7 @@ def test_graclus_cluster(test, dtype, device): cluster = graclus_cluster(row, col, weight) assert_correct(row, col, cluster) + + jit = torch.jit.script(graclus_cluster) + cluster = jit(row, col, weight) + assert_correct(row, col, cluster) diff --git a/test/test_grid.py b/test/test_grid.py index c297f339..2d53220f 100644 --- a/test/test_grid.py +++ b/test/test_grid.py @@ -38,3 +38,6 @@ def test_grid_cluster(test, dtype, device): cluster = grid_cluster(pos, size, start, end) assert cluster.tolist() == test['cluster'] + + jit = torch.jit.script(grid_cluster) + assert torch.equal(jit(pos, size, start, end), cluster) diff --git a/test/test_knn.py b/test/test_knn.py index 8113a543..32852fe0 100644 --- a/test/test_knn.py +++ b/test/test_knn.py @@ -34,6 +34,10 @@ def test_knn(dtype, device): edge_index = knn(x, y, 2) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)]) + jit = torch.jit.script(knn) + edge_index = jit(x, y, 2) + assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)]) + edge_index = knn(x, y, 2, batch_x, batch_y) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) @@ -65,6 +69,11 @@ def test_knn_graph(dtype, device): assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), (3, 2), (0, 3), (2, 3)]) + jit = torch.jit.script(knn_graph) + edge_index = jit(x, k=2, flow='source_to_target') + assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), + (3, 2), (0, 3), (2, 3)]) + @pytest.mark.parametrize('dtype,device', product([torch.float], devices)) def test_knn_graph_large(dtype, device): diff --git a/test/test_radius.py b/test/test_radius.py index 34c4ad97..078412fe 100644 --- a/test/test_radius.py +++ b/test/test_radius.py @@ -35,6 +35,11 @@ def test_radius(dtype, device): assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1), (1, 2), (1, 5), (1, 6)]) + jit = torch.jit.script(radius) + edge_index = jit(x, y, 2, max_num_neighbors=4) + assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1), + (1, 2), (1, 5), (1, 6)]) + edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4) assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5), (1, 6)]) @@ -64,12 +69,20 @@ def test_radius_graph(dtype, device): assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), (3, 2), (0, 3), (2, 3)]) + jit = torch.jit.script(radius_graph) + edge_index = jit(x, r=2.5, flow='source_to_target') + assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), + (3, 2), (0, 3), (2, 3)]) + @pytest.mark.parametrize('dtype,device', product([torch.float], devices)) def test_radius_graph_large(dtype, device): x = torch.randn(1000, 3, dtype=dtype, device=device) - edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True, + edge_index = radius_graph(x, + r=0.5, + flow='target_to_source', + loop=True, max_num_neighbors=2000) tree = scipy.spatial.cKDTree(x.cpu().numpy()) diff --git a/test/test_rw.py b/test/test_rw.py index 67d935df..82a8b77d 100644 --- a/test/test_rw.py +++ b/test/test_rw.py @@ -31,6 +31,9 @@ def test_rw_small(device): out = random_walk(row, col, start, walk_length, num_nodes=3) assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]] + jit = torch.jit.script(random_walk) + assert torch.equal(jit(row, col, start, walk_length, num_nodes=3), out) + @pytest.mark.parametrize('device', devices) def test_rw_large_with_edge_indices(device): diff --git a/torch_cluster/graclus.py b/torch_cluster/graclus.py index cdac8929..5a7b8806 100644 --- a/torch_cluster/graclus.py +++ b/torch_cluster/graclus.py @@ -3,8 +3,8 @@ import torch -@torch.jit.script -def graclus_cluster(row: torch.Tensor, col: torch.Tensor, +def graclus_cluster(row: torch.Tensor, + col: torch.Tensor, weight: Optional[torch.Tensor] = None, num_nodes: Optional[int] = None) -> torch.Tensor: """A greedy clustering algorithm of picking an unmarked vertex and matching diff --git a/torch_cluster/grid.py b/torch_cluster/grid.py index 1dbacb9f..bf199497 100644 --- a/torch_cluster/grid.py +++ b/torch_cluster/grid.py @@ -3,8 +3,8 @@ import torch -@torch.jit.script -def grid_cluster(pos: torch.Tensor, size: torch.Tensor, +def grid_cluster(pos: torch.Tensor, + size: torch.Tensor, start: Optional[torch.Tensor] = None, end: Optional[torch.Tensor] = None) -> torch.Tensor: """A clustering algorithm, which overlays a regular grid of user-defined diff --git a/torch_cluster/knn.py b/torch_cluster/knn.py index 4eace5e1..cf8f0875 100644 --- a/torch_cluster/knn.py +++ b/torch_cluster/knn.py @@ -3,7 +3,6 @@ import torch -@torch.jit.script def knn( x: torch.Tensor, y: torch.Tensor, @@ -83,7 +82,6 @@ def knn( num_workers) -@torch.jit.script def knn_graph( x: torch.Tensor, k: int, diff --git a/torch_cluster/radius.py b/torch_cluster/radius.py index de352988..069824ab 100644 --- a/torch_cluster/radius.py +++ b/torch_cluster/radius.py @@ -3,7 +3,6 @@ import torch -@torch.jit.script def radius( x: torch.Tensor, y: torch.Tensor, @@ -84,7 +83,6 @@ def radius( max_num_neighbors, num_workers) -@torch.jit.script def radius_graph( x: torch.Tensor, r: float, diff --git a/torch_cluster/rw.py b/torch_cluster/rw.py index 12e06837..cb7bc2c8 100644 --- a/torch_cluster/rw.py +++ b/torch_cluster/rw.py @@ -4,7 +4,6 @@ from torch import Tensor -@torch.jit.script def random_walk( row: Tensor, col: Tensor, @@ -55,7 +54,12 @@ def random_walk( torch.cumsum(deg, 0, out=rowptr[1:]) node_seq, edge_seq = torch.ops.torch_cluster.random_walk( - rowptr, col, start, walk_length, p, q, + rowptr, + col, + start, + walk_length, + p, + q, ) if return_edge_indices: diff --git a/torch_cluster/sampler.py b/torch_cluster/sampler.py index 9d2e08eb..1b68de0a 100644 --- a/torch_cluster/sampler.py +++ b/torch_cluster/sampler.py @@ -1,7 +1,6 @@ import torch -@torch.jit.script def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float): assert not start.is_cuda diff --git a/torch_cluster/typing.py b/torch_cluster/typing.py index d570684e..f57544ac 100644 --- a/torch_cluster/typing.py +++ b/torch_cluster/typing.py @@ -1,3 +1,6 @@ import torch -WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list') +try: + WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list') +except Exception: + WITH_PTR_LIST = False