diff --git a/.github/workflows/building-conda.yml b/.github/workflows/building-conda.yml index 5c7e961..982c81d 100644 --- a/.github/workflows/building-conda.yml +++ b/.github/workflows/building-conda.yml @@ -11,21 +11,14 @@ jobs: fail-fast: false matrix: # We have trouble building for Windows - drop for now. - os: [ubuntu-20.04, macos-11] # windows-2019 - python-version: ['3.8', '3.9', '3.10', '3.11'] - torch-version: [2.0.0, 2.1.0] - cuda-version: ['cpu', 'cu117', 'cu118', 'cu121'] + os: [ubuntu-20.04] # windows-2019 + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + torch-version: [2.3.0] # [2.1.0, 2.2.0, 2.3.0] + cuda-version: ['cpu', 'cu118', 'cu121'] exclude: - - torch-version: 2.0.0 - cuda-version: 'cu121' + - python-version: '3.12' # Python 3.12 not yet supported in `conda-build`. - torch-version: 2.1.0 - cuda-version: 'cu117' - - os: macos-11 - cuda-version: 'cu117' - - os: macos-11 - cuda-version: 'cu118' - - os: macos-11 - cuda-version: 'cu121' + python-version: '3.12' steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/building.yml b/.github/workflows/building.yml index 9e00feb..00629c1 100644 --- a/.github/workflows/building.yml +++ b/.github/workflows/building.yml @@ -10,20 +10,16 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-20.04, macos-11, windows-2019] - python-version: ['3.8', '3.9', '3.10', '3.11'] - torch-version: [2.0.0, 2.1.0] - cuda-version: ['cpu', 'cu117', 'cu118', 'cu121'] + os: [ubuntu-20.04, macos-14, windows-2019] + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + torch-version: [2.3.0] # [2.1.0, 2.2.0, 2.3.0] + cuda-version: ['cpu', 'cu118', 'cu121'] exclude: - - torch-version: 2.0.0 - cuda-version: 'cu121' - torch-version: 2.1.0 - cuda-version: 'cu117' - - os: macos-11 - cuda-version: 'cu117' - - os: macos-11 + python-version: '3.12' + - os: macos-14 cuda-version: 'cu118' - - os: macos-11 + - os: macos-14 cuda-version: 'cu121' steps: @@ -36,8 +32,11 @@ jobs: - name: Upgrade pip run: | pip install --upgrade setuptools - pip install scipy==1.10.1 # Python 3.8 support - pip list + + - name: Install scipy + if: ${{ matrix.python-version == '3.8' }} + run: | + pip install scipy==1.10.1 - name: Free Disk Space (Ubuntu) if: ${{ runner.os == 'Linux' }} diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 2a08d95..dd517b0 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -16,7 +16,7 @@ jobs: matrix: os: [ubuntu-latest, windows-latest] python-version: [3.8] - torch-version: [2.0.0, 2.1.0] + torch-version: [2.2.0, 2.3.0] steps: - uses: actions/checkout@v2 @@ -29,9 +29,13 @@ jobs: run: | pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu + - name: Install scipy + if: ${{ matrix.python-version == '3.8' }} + run: | + pip install scipy==1.10.1 + - name: Install main package run: | - pip install scipy==1.10.1 # Python 3.8 support python setup.py develop - name: Run test-suite @@ -40,7 +44,7 @@ jobs: pytest --cov --cov-report=xml - name: Upload coverage - uses: codecov/codecov-action@v1 + uses: codecov/codecov-action@v4 if: success() with: fail_ci_if_error: false diff --git a/CMakeLists.txt b/CMakeLists.txt index 2edd49b..baea7a1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.0) project(torchcluster) set(CMAKE_CXX_STANDARD 14) -set(TORCHCLUSTER_VERSION 1.6.2) +set(TORCHCLUSTER_VERSION 1.6.3) option(WITH_CUDA "Enable CUDA support" OFF) option(WITH_PYTHON "Link to Python when building" ON) diff --git a/README.md b/README.md index ce55d79..5479e7d 100644 --- a/README.md +++ b/README.md @@ -43,12 +43,12 @@ conda install pytorch-cluster -c pyg We alternatively provide pip wheels for all major OS/PyTorch/CUDA combinations, see [here](https://data.pyg.org/whl). -#### PyTorch 2.1 +#### PyTorch 2.3 -To install the binaries for PyTorch 2.1.0, simply run +To install the binaries for PyTorch 2.3.0, simply run ``` -pip install torch-cluster -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html +pip install torch-cluster -f https://data.pyg.org/whl/torch-2.3.0+${CUDA}.html ``` where `${CUDA}` should be replaced by either `cpu`, `cu118`, or `cu121` depending on your PyTorch installation. @@ -59,23 +59,23 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, or `cu121` dependin | **Windows** | ✅ | ✅ | ✅ | | **macOS** | ✅ | | | -#### PyTorch 2.0 +#### PyTorch 2.2 -To install the binaries for PyTorch 2.0.0, simply run +To install the binaries for PyTorch 2.2.0, simply run ``` -pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html +pip install torch-cluster -f https://data.pyg.org/whl/torch-2.2.0+${CUDA}.html ``` -where `${CUDA}` should be replaced by either `cpu`, `cu117`, or `cu118` depending on your PyTorch installation. +where `${CUDA}` should be replaced by either `cpu`, `cu118`, or `cu121` depending on your PyTorch installation. -| | `cpu` | `cu117` | `cu118` | +| | `cpu` | `cu118` | `cu121` | |-------------|-------|---------|---------| | **Linux** | ✅ | ✅ | ✅ | | **Windows** | ✅ | ✅ | ✅ | | **macOS** | ✅ | | | -**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1 and PyTorch 1.13.0/1.13.1 (following the same procedure). +**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, and PyTorch 2.1.0/2.1.1/2.1.2 (following the same procedure). For older versions, you need to explicitly specify the latest supported version number or install via `pip install --no-index` in order to prevent a manual installation from source. You can look up the latest supported version number [here](https://data.pyg.org/whl). diff --git a/conda/pytorch-cluster/README.md b/conda/pytorch-cluster/README.md index 6816209..c562ff4 100644 --- a/conda/pytorch-cluster/README.md +++ b/conda/pytorch-cluster/README.md @@ -1,3 +1,3 @@ ``` -./build_conda.sh 3.9 2.1.0 cu118 # python, pytorch and cuda version +./build_conda.sh 3.11 2.3.0 cu118 # python, pytorch and cuda version ``` diff --git a/conda/pytorch-cluster/meta.yaml b/conda/pytorch-cluster/meta.yaml index dfb2bdb..8e106b0 100644 --- a/conda/pytorch-cluster/meta.yaml +++ b/conda/pytorch-cluster/meta.yaml @@ -1,6 +1,6 @@ package: name: pytorch-cluster - version: 1.6.2 + version: 1.6.3 source: path: ../.. diff --git a/csrc/cuda/fps_cuda.cu b/csrc/cuda/fps_cuda.cu index dd3671a..38195fc 100644 --- a/csrc/cuda/fps_cuda.cu +++ b/csrc/cuda/fps_cuda.cu @@ -71,7 +71,7 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, CHECK_CUDA(ptr); CHECK_CUDA(ratio); CHECK_INPUT(ptr.dim() == 1); - cudaSetDevice(src.get_device()); + c10::cuda::MaybeSetDevice(src.get_device()); src = src.view({src.size(0), -1}).contiguous(); ptr = ptr.contiguous(); diff --git a/csrc/cuda/graclus_cuda.cu b/csrc/cuda/graclus_cuda.cu index 61e7d70..3bb118b 100644 --- a/csrc/cuda/graclus_cuda.cu +++ b/csrc/cuda/graclus_cuda.cu @@ -223,7 +223,7 @@ torch::Tensor graclus_cuda(torch::Tensor rowptr, torch::Tensor col, CHECK_INPUT(optional_weight.value().dim() == 1); CHECK_INPUT(optional_weight.value().numel() == col.numel()); } - cudaSetDevice(rowptr.get_device()); + c10::cuda::MaybeSetDevice(rowptr.get_device()); int64_t num_nodes = rowptr.numel() - 1; auto out = torch::full(num_nodes, -1, rowptr.options()); diff --git a/csrc/cuda/grid_cuda.cu b/csrc/cuda/grid_cuda.cu index 8696b9f..64037bd 100644 --- a/csrc/cuda/grid_cuda.cu +++ b/csrc/cuda/grid_cuda.cu @@ -29,7 +29,7 @@ torch::Tensor grid_cuda(torch::Tensor pos, torch::Tensor size, torch::optional optional_end) { CHECK_CUDA(pos); CHECK_CUDA(size); - cudaSetDevice(pos.get_device()); + c10::cuda::MaybeSetDevice(pos.get_device()); if (optional_start.has_value()) CHECK_CUDA(optional_start.value()); diff --git a/csrc/cuda/knn_cuda.cu b/csrc/cuda/knn_cuda.cu index caa5c96..c4dac2a 100644 --- a/csrc/cuda/knn_cuda.cu +++ b/csrc/cuda/knn_cuda.cu @@ -113,7 +113,7 @@ torch::Tensor knn_cuda(const torch::Tensor x, const torch::Tensor y, CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel()); - cudaSetDevice(x.get_device()); + c10::cuda::MaybeSetDevice(x.get_device()); auto row = torch::empty({y.size(0) * k}, ptr_y.value().options()); auto col = torch::full(y.size(0) * k, -1, ptr_y.value().options()); diff --git a/csrc/cuda/nearest_cuda.cu b/csrc/cuda/nearest_cuda.cu index 81eef92..7a3458e 100644 --- a/csrc/cuda/nearest_cuda.cu +++ b/csrc/cuda/nearest_cuda.cu @@ -71,7 +71,7 @@ torch::Tensor nearest_cuda(torch::Tensor x, torch::Tensor y, CHECK_CUDA(y); CHECK_CUDA(ptr_x); CHECK_CUDA(ptr_y); - cudaSetDevice(x.get_device()); + c10::cuda::MaybeSetDevice(x.get_device()); x = x.view({x.size(0), -1}).contiguous(); y = y.view({y.size(0), -1}).contiguous(); diff --git a/csrc/cuda/radius_cuda.cu b/csrc/cuda/radius_cuda.cu index a4f0283..7efb2ff 100644 --- a/csrc/cuda/radius_cuda.cu +++ b/csrc/cuda/radius_cuda.cu @@ -52,7 +52,7 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y, CHECK_INPUT(y.dim() == 2); CHECK_INPUT(x.size(1) == y.size(1)); - cudaSetDevice(x.get_device()); + c10::cuda::MaybeSetDevice(x.get_device()); if (ptr_x.has_value()) { CHECK_CUDA(ptr_x.value()); @@ -70,8 +70,6 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y, CHECK_INPUT(ptr_x.value().numel() == ptr_y.value().numel()); - cudaSetDevice(x.get_device()); - auto row = torch::full(y.size(0) * max_num_neighbors, -1, ptr_y.value().options()); auto col = @@ -81,13 +79,15 @@ torch::Tensor radius_cuda(const torch::Tensor x, const torch::Tensor y, auto stream = at::cuda::getCurrentCUDAStream(); auto scalar_type = x.scalar_type(); - AT_DISPATCH_FLOATING_TYPES_AND(at::ScalarType::Half, scalar_type, "_", [&] { - radius_kernel<<>>( - x.data_ptr(), y.data_ptr(), - ptr_x.value().data_ptr(), ptr_y.value().data_ptr(), - row.data_ptr(), col.data_ptr(), r * r, x.size(0), - y.size(0), x.size(1), ptr_x.value().numel() - 1, max_num_neighbors); - }); + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, scalar_type, "_", [&] { + radius_kernel<<>>( + x.data_ptr(), y.data_ptr(), + ptr_x.value().data_ptr(), + ptr_y.value().data_ptr(), row.data_ptr(), + col.data_ptr(), r * r, x.size(0), y.size(0), x.size(1), + ptr_x.value().numel() - 1, max_num_neighbors); + }); auto mask = row != -1; return torch::stack({row.masked_select(mask), col.masked_select(mask)}, 0); diff --git a/csrc/cuda/rw_cuda.cu b/csrc/cuda/rw_cuda.cu index dc89063..6008743 100644 --- a/csrc/cuda/rw_cuda.cu +++ b/csrc/cuda/rw_cuda.cu @@ -285,7 +285,7 @@ random_walk_cuda(torch::Tensor rowptr, torch::Tensor col, torch::Tensor start, CHECK_CUDA(rowptr); CHECK_CUDA(col); CHECK_CUDA(start); - cudaSetDevice(rowptr.get_device()); + c10::cuda::MaybeSetDevice(rowptr.get_device()); CHECK_INPUT(rowptr.dim() == 1); CHECK_INPUT(col.dim() == 1); diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..dd14ceb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools", "torch"] +build-backend = "setuptools.build_meta" diff --git a/setup.cfg b/setup.cfg index 1f21bb7..9a0eaf6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,6 +10,7 @@ classifiers = Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 Programming Language :: Python :: 3.11 + Programming Language :: Python :: 3.12 Programming Language :: Python :: 3 :: Only [aliases] diff --git a/setup.py b/setup.py index 4bc1a8c..1ef90ab 100644 --- a/setup.py +++ b/setup.py @@ -11,7 +11,7 @@ from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension, CUDAExtension) -__version__ = '1.6.2' +__version__ = '1.6.3' URL = 'https://github.com/rusty1s/pytorch_cluster' WITH_CUDA = False @@ -61,9 +61,11 @@ def get_extensions(): print('Compiling without OpenMP...') # Compile for mac arm64 - if (sys.platform == 'darwin' and platform.machine() == 'arm64'): - extra_compile_args['cxx'] += ['-arch', 'arm64'] - extra_link_args += ['-arch', 'arm64'] + if sys.platform == 'darwin': + extra_compile_args['cxx'] += ['-D_LIBCPP_DISABLE_AVAILABILITY'] + if platform.machine == 'arm64': + extra_compile_args['cxx'] += ['-arch', 'arm64'] + extra_link_args += ['-arch', 'arm64'] if suffix == 'cuda': define_macros += [('WITH_CUDA', None)] diff --git a/test/test_graclus.py b/test/test_graclus.py index b892330..c8e1f39 100644 --- a/test/test_graclus.py +++ b/test/test_graclus.py @@ -50,3 +50,7 @@ def test_graclus_cluster(test, dtype, device): cluster = graclus_cluster(row, col, weight) assert_correct(row, col, cluster) + + jit = torch.jit.script(graclus_cluster) + cluster = jit(row, col, weight) + assert_correct(row, col, cluster) diff --git a/test/test_grid.py b/test/test_grid.py index c297f33..2d53220 100644 --- a/test/test_grid.py +++ b/test/test_grid.py @@ -38,3 +38,6 @@ def test_grid_cluster(test, dtype, device): cluster = grid_cluster(pos, size, start, end) assert cluster.tolist() == test['cluster'] + + jit = torch.jit.script(grid_cluster) + assert torch.equal(jit(pos, size, start, end), cluster) diff --git a/test/test_knn.py b/test/test_knn.py index 8113a54..32852fe 100644 --- a/test/test_knn.py +++ b/test/test_knn.py @@ -34,6 +34,10 @@ def test_knn(dtype, device): edge_index = knn(x, y, 2) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)]) + jit = torch.jit.script(knn) + edge_index = jit(x, y, 2) + assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)]) + edge_index = knn(x, y, 2, batch_x, batch_y) assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) @@ -65,6 +69,11 @@ def test_knn_graph(dtype, device): assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), (3, 2), (0, 3), (2, 3)]) + jit = torch.jit.script(knn_graph) + edge_index = jit(x, k=2, flow='source_to_target') + assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), + (3, 2), (0, 3), (2, 3)]) + @pytest.mark.parametrize('dtype,device', product([torch.float], devices)) def test_knn_graph_large(dtype, device): diff --git a/test/test_radius.py b/test/test_radius.py index 34c4ad9..b20b2bf 100644 --- a/test/test_radius.py +++ b/test/test_radius.py @@ -4,14 +4,14 @@ import scipy.spatial import torch from torch_cluster import radius, radius_graph -from torch_cluster.testing import devices, grad_dtypes, tensor +from torch_cluster.testing import devices, floating_dtypes, tensor def to_set(edge_index): return set([(i, j) for i, j in edge_index.t().tolist()]) -@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) +@pytest.mark.parametrize('dtype,device', product(floating_dtypes, devices)) def test_radius(dtype, device): x = tensor([ [-1, -1], @@ -35,6 +35,11 @@ def test_radius(dtype, device): assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1), (1, 2), (1, 5), (1, 6)]) + jit = torch.jit.script(radius) + edge_index = jit(x, y, 2, max_num_neighbors=4) + assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1), + (1, 2), (1, 5), (1, 6)]) + edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4) assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5), (1, 6)]) @@ -47,7 +52,7 @@ def test_radius(dtype, device): (1, 6)]) -@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) +@pytest.mark.parametrize('dtype,device', product(floating_dtypes, devices)) def test_radius_graph(dtype, device): x = tensor([ [-1, -1], @@ -64,12 +69,20 @@ def test_radius_graph(dtype, device): assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), (3, 2), (0, 3), (2, 3)]) + jit = torch.jit.script(radius_graph) + edge_index = jit(x, r=2.5, flow='source_to_target') + assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), + (3, 2), (0, 3), (2, 3)]) + @pytest.mark.parametrize('dtype,device', product([torch.float], devices)) def test_radius_graph_large(dtype, device): x = torch.randn(1000, 3, dtype=dtype, device=device) - edge_index = radius_graph(x, r=0.5, flow='target_to_source', loop=True, + edge_index = radius_graph(x, + r=0.5, + flow='target_to_source', + loop=True, max_num_neighbors=2000) tree = scipy.spatial.cKDTree(x.cpu().numpy()) diff --git a/test/test_rw.py b/test/test_rw.py index 27bebc4..7168c36 100644 --- a/test/test_rw.py +++ b/test/test_rw.py @@ -31,6 +31,9 @@ def test_rw_small(device): out = random_walk(row, col, start, walk_length, num_nodes=3) assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]] + jit = torch.jit.script(random_walk) + assert torch.equal(jit(row, col, start, walk_length, num_nodes=3), out) + @pytest.mark.parametrize('device', devices) def test_rw_large_with_edge_indices(device): diff --git a/torch_cluster/__init__.py b/torch_cluster/__init__.py index 30aea66..39d1d67 100644 --- a/torch_cluster/__init__.py +++ b/torch_cluster/__init__.py @@ -3,7 +3,7 @@ import torch -__version__ = '1.6.2' +__version__ = '1.6.3' for library in [ '_version', '_grid', '_graclus', '_fps', '_rw', '_sampler', '_nearest', diff --git a/torch_cluster/graclus.py b/torch_cluster/graclus.py index cdac892..7fa834d 100644 --- a/torch_cluster/graclus.py +++ b/torch_cluster/graclus.py @@ -3,10 +3,12 @@ import torch -@torch.jit.script -def graclus_cluster(row: torch.Tensor, col: torch.Tensor, - weight: Optional[torch.Tensor] = None, - num_nodes: Optional[int] = None) -> torch.Tensor: +def graclus_cluster( + row: torch.Tensor, + col: torch.Tensor, + weight: Optional[torch.Tensor] = None, + num_nodes: Optional[int] = None, +) -> torch.Tensor: """A greedy clustering algorithm of picking an unmarked vertex and matching it with one its unmarked neighbors (that maximizes its edge weight). diff --git a/torch_cluster/grid.py b/torch_cluster/grid.py index 1dbacb9..da59d51 100644 --- a/torch_cluster/grid.py +++ b/torch_cluster/grid.py @@ -3,10 +3,12 @@ import torch -@torch.jit.script -def grid_cluster(pos: torch.Tensor, size: torch.Tensor, - start: Optional[torch.Tensor] = None, - end: Optional[torch.Tensor] = None) -> torch.Tensor: +def grid_cluster( + pos: torch.Tensor, + size: torch.Tensor, + start: Optional[torch.Tensor] = None, + end: Optional[torch.Tensor] = None, +) -> torch.Tensor: """A clustering algorithm, which overlays a regular grid of user-defined size over a point cloud and clusters all points within a voxel. diff --git a/torch_cluster/knn.py b/torch_cluster/knn.py index 4eace5e..cf8f087 100644 --- a/torch_cluster/knn.py +++ b/torch_cluster/knn.py @@ -3,7 +3,6 @@ import torch -@torch.jit.script def knn( x: torch.Tensor, y: torch.Tensor, @@ -83,7 +82,6 @@ def knn( num_workers) -@torch.jit.script def knn_graph( x: torch.Tensor, k: int, diff --git a/torch_cluster/radius.py b/torch_cluster/radius.py index de35298..069824a 100644 --- a/torch_cluster/radius.py +++ b/torch_cluster/radius.py @@ -3,7 +3,6 @@ import torch -@torch.jit.script def radius( x: torch.Tensor, y: torch.Tensor, @@ -84,7 +83,6 @@ def radius( max_num_neighbors, num_workers) -@torch.jit.script def radius_graph( x: torch.Tensor, r: float, diff --git a/torch_cluster/rw.py b/torch_cluster/rw.py index 4f0e0f9..b6ff8d9 100644 --- a/torch_cluster/rw.py +++ b/torch_cluster/rw.py @@ -4,7 +4,6 @@ from torch import Tensor -@torch.jit.script def random_walk( row: Tensor, col: Tensor, @@ -66,6 +65,7 @@ def random_walk( node_seq, edge_seq = torch.ops.torch_cluster.random_walk_weighted( rowptr, col, edge_weight, start, walk_length, p, q, ) + if return_edge_indices: return node_seq, edge_seq diff --git a/torch_cluster/sampler.py b/torch_cluster/sampler.py index 9d2e08e..1b68de0 100644 --- a/torch_cluster/sampler.py +++ b/torch_cluster/sampler.py @@ -1,7 +1,6 @@ import torch -@torch.jit.script def neighbor_sampler(start: torch.Tensor, rowptr: torch.Tensor, size: float): assert not start.is_cuda diff --git a/torch_cluster/testing.py b/torch_cluster/testing.py index a124eda..68949fa 100644 --- a/torch_cluster/testing.py +++ b/torch_cluster/testing.py @@ -6,7 +6,11 @@ torch.half, torch.bfloat16, torch.float, torch.double, torch.int, torch.long ] -grad_dtypes = [torch.half, torch.float, torch.double] +if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): + grad_dtypes = [torch.float, torch.double] +else: + grad_dtypes = [torch.half, torch.float, torch.double] +floating_dtypes = grad_dtypes + [torch.bfloat16] devices = [torch.device('cpu')] if torch.cuda.is_available(): diff --git a/torch_cluster/typing.py b/torch_cluster/typing.py index d570684..f57544a 100644 --- a/torch_cluster/typing.py +++ b/torch_cluster/typing.py @@ -1,3 +1,6 @@ import torch -WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list') +try: + WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list') +except Exception: + WITH_PTR_LIST = False