From b3d70bd53f59ca3279b7191973791354b805117a Mon Sep 17 00:00:00 2001 From: Dariusz Sciebura Date: Mon, 5 Jun 2023 09:08:03 +0200 Subject: [PATCH] Extend FPS with an extra ptr argument (#180) * Extend FPS with an extra ptr argument * update * update * update --------- Co-authored-by: rusty1s --- .github/workflows/testing.yml | 1 + test/test_fps.py | 16 ++++++++--- torch_cluster/fps.py | 51 +++++++++++++++++++++++++++-------- torch_cluster/typing.py | 3 +++ 4 files changed, 56 insertions(+), 15 deletions(-) create mode 100644 torch_cluster/typing.py diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 4a87490..ee6a580 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -31,6 +31,7 @@ jobs: - name: Install main package run: | + pip install scipy==1.10.1 # Python 3.8 support python setup.py develop - name: Run test-suite diff --git a/test/test_fps.py b/test/test_fps.py index 52b689d..0f10e51 100644 --- a/test/test_fps.py +++ b/test/test_fps.py @@ -25,6 +25,8 @@ def test_fps(dtype, device): [+2, -2], ], dtype, device) batch = tensor([0, 0, 0, 0, 1, 1, 1, 1], torch.long, device) + ptr_list = [0, 4, 8] + ptr = torch.tensor(ptr_list, device=device) out = fps(x, batch, random_start=False) assert out.tolist() == [0, 2, 4, 6] @@ -32,12 +34,18 @@ def test_fps(dtype, device): out = fps(x, batch, ratio=0.5, random_start=False) assert out.tolist() == [0, 2, 4, 6] - out = fps(x, batch, ratio=torch.tensor(0.5, device=device), - random_start=False) + ratio = torch.tensor(0.5, device=device) + out = fps(x, batch, ratio=ratio, random_start=False) assert out.tolist() == [0, 2, 4, 6] - out = fps(x, batch, ratio=torch.tensor([0.5, 0.5], device=device), - random_start=False) + out = fps(x, ptr=ptr_list, ratio=0.5, random_start=False) + assert out.tolist() == [0, 2, 4, 6] + + out = fps(x, ptr=ptr, ratio=0.5, random_start=False) + assert out.tolist() == [0, 2, 4, 6] + + ratio = torch.tensor([0.5, 0.5], device=device) + out = fps(x, batch, ratio=ratio, random_start=False) assert out.tolist() == [0, 2, 4, 6] out = fps(x, random_start=False) diff --git a/torch_cluster/fps.py b/torch_cluster/fps.py index e0d2782..7baf981 100644 --- a/torch_cluster/fps.py +++ b/torch_cluster/fps.py @@ -1,27 +1,42 @@ -from typing import Optional, Union +from typing import List, Optional, Union import torch from torch import Tensor +import torch_cluster.typing + + +@torch.jit._overload # noqa +def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa + # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa + pass # pragma: no cover + + +@torch.jit._overload # noqa +def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa + # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa + pass # pragma: no cover + @torch.jit._overload # noqa -def fps(src, batch, ratio, random_start, batch_size): # noqa - # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int]) -> Tensor # noqa +def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa + # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa pass # pragma: no cover @torch.jit._overload # noqa -def fps(src, batch, ratio, random_start, batch_size): # noqa - # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int]) -> Tensor # noqa +def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa + # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa pass # pragma: no cover def fps( # noqa src: torch.Tensor, batch: Optional[Tensor] = None, - ratio: Optional[Union[torch.Tensor, float]] = None, + ratio: Optional[Union[Tensor, float]] = None, random_start: bool = True, batch_size: Optional[int] = None, + ptr: Optional[Union[Tensor, List[int]]] = None, ): r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space" @@ -40,6 +55,10 @@ def fps( # noqa node in :math:`\mathbf{X}` as starting node. (default: obj:`True`) batch_size (int, optional): The number of examples :math:`B`. Automatically calculated if not given. (default: :obj:`None`) + ptr (torch.Tensor or [int], optional): If given, batch assignment will + be determined based on boundaries in CSR representation, *e.g.*, + :obj:`batch=[0,0,1,1,1,2]` translates to :obj:`ptr=[0,2,5,6]`. + (default: :obj:`None`) :rtype: :class:`LongTensor` @@ -52,7 +71,6 @@ def fps( # noqa batch = torch.tensor([0, 0, 0, 0]) index = fps(src, batch, ratio=0.5) """ - r: Optional[Tensor] = None if ratio is None: r = torch.tensor(0.5, dtype=src.dtype, device=src.device) @@ -62,6 +80,17 @@ def fps( # noqa r = ratio assert r is not None + if ptr is not None: + if isinstance(ptr, list) and torch_cluster.typing.WITH_PTR_LIST: + return torch.ops.torch_cluster.fps_ptr_list( + src, ptr, r, random_start) + + if isinstance(ptr, list): + return torch.ops.torch_cluster.fps( + src, torch.tensor(ptr, device=src.device), r, random_start) + else: + return torch.ops.torch_cluster.fps(src, ptr, r, random_start) + if batch is not None: assert src.size(0) == batch.numel() if batch_size is None: @@ -70,9 +99,9 @@ def fps( # noqa deg = src.new_zeros(batch_size, dtype=torch.long) deg.scatter_add_(0, batch, torch.ones_like(batch)) - ptr = deg.new_zeros(batch_size + 1) - torch.cumsum(deg, 0, out=ptr[1:]) + ptr_vec = deg.new_zeros(batch_size + 1) + torch.cumsum(deg, 0, out=ptr_vec[1:]) else: - ptr = torch.tensor([0, src.size(0)], device=src.device) + ptr_vec = torch.tensor([0, src.size(0)], device=src.device) - return torch.ops.torch_cluster.fps(src, ptr, r, random_start) + return torch.ops.torch_cluster.fps(src, ptr_vec, r, random_start) diff --git a/torch_cluster/typing.py b/torch_cluster/typing.py new file mode 100644 index 0000000..d570684 --- /dev/null +++ b/torch_cluster/typing.py @@ -0,0 +1,3 @@ +import torch + +WITH_PTR_LIST = hasattr(torch.ops.torch_cluster, 'fps_ptr_list')