diff --git a/csrc/cluster.h b/csrc/cluster.h index ac8315b..c33fede 100644 --- a/csrc/cluster.h +++ b/csrc/cluster.h @@ -11,7 +11,7 @@ CLUSTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version(); } // namespace cluster CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio, - bool random_start); + int64_t num_points, bool random_start); CLUSTER_API torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col, torch::optional optional_weight); diff --git a/csrc/cpu/fps_cpu.cpp b/csrc/cpu/fps_cpu.cpp index d725ca8..24e42a0 100644 --- a/csrc/cpu/fps_cpu.cpp +++ b/csrc/cpu/fps_cpu.cpp @@ -9,11 +9,12 @@ inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) { } torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio, - bool random_start) { + torch::Tensor num_points, bool random_start) { CHECK_CPU(src); CHECK_CPU(ptr); CHECK_CPU(ratio); + CHECK_CPU(num_points); CHECK_INPUT(ptr.dim() == 1); src = src.view({src.size(0), -1}).contiguous(); @@ -21,9 +22,17 @@ torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio, auto batch_size = ptr.numel() - 1; auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size); - auto out_ptr = deg.toType(torch::kFloat) * ratio; - out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0); - + torch::Tensor out_ptr; + if (num_points.sum().item() == 0) { + out_ptr = deg.toType(torch::kFloat) * ratio; + out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0); + } else { + TORCH_CHECK((deg.toType(torch::kLong) >= num_points.toType(torch::kLong)).all().item(), + "Passed tensor has fewer elements than requested number of returned points.") + out_ptr = deg.toType(torch::kLong) + .minimum(num_points.toType(torch::kLong)) + .cumsum(0); + } auto out = torch::empty({out_ptr[-1].data_ptr()[0]}, ptr.options()); auto ptr_data = ptr.data_ptr(); diff --git a/csrc/cpu/fps_cpu.h b/csrc/cpu/fps_cpu.h index d94292d..16c2b32 100644 --- a/csrc/cpu/fps_cpu.h +++ b/csrc/cpu/fps_cpu.h @@ -3,4 +3,4 @@ #include "../extensions.h" torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio, - bool random_start); + torch::Tensor num_points, bool random_start); diff --git a/csrc/cuda/fps_cuda.cu b/csrc/cuda/fps_cuda.cu index 38195fc..9ba4151 100644 --- a/csrc/cuda/fps_cuda.cu +++ b/csrc/cuda/fps_cuda.cu @@ -65,11 +65,13 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr, } torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, - torch::Tensor ratio, bool random_start) { + torch::Tensor ratio, torch::Tensor num_points, + bool random_start) { CHECK_CUDA(src); CHECK_CUDA(ptr); CHECK_CUDA(ratio); + CHECK_CUDA(num_points); CHECK_INPUT(ptr.dim() == 1); c10::cuda::MaybeSetDevice(src.get_device()); @@ -78,8 +80,18 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, auto batch_size = ptr.numel() - 1; auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size); - auto out_ptr = deg.toType(ratio.scalar_type()) * ratio; - out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0); + torch::Tensor out_ptr; + if (num_points.sum().item() == 0) { + out_ptr = deg.toType(ratio.scalar_type()) * ratio; + out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0); + } else { + TORCH_CHECK((deg.toType(torch::kLong) >= num_points.toType(torch::kLong)).all().item(), + "Passed tensor has fewer elements than requested number of returned points.") + out_ptr = deg.toType(torch::kLong) + .minimum(num_points.toType(torch::kLong)) + .cumsum(0); + } + out_ptr = torch::cat({torch::zeros({1}, ptr.options()), out_ptr}, 0); torch::Tensor start; diff --git a/csrc/cuda/fps_cuda.h b/csrc/cuda/fps_cuda.h index 9a7a164..f0e3071 100644 --- a/csrc/cuda/fps_cuda.h +++ b/csrc/cuda/fps_cuda.h @@ -3,4 +3,5 @@ #include "../extensions.h" torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr, - torch::Tensor ratio, bool random_start); + torch::Tensor ratio, torch::Tensor num_points, + bool random_start); diff --git a/csrc/fps.cpp b/csrc/fps.cpp index db39533..0ab599e 100644 --- a/csrc/fps.cpp +++ b/csrc/fps.cpp @@ -19,16 +19,17 @@ PyMODINIT_FUNC PyInit__fps_cpu(void) { return NULL; } #endif #endif -CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio, - bool random_start) { +CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, + torch::Tensor ratio, torch::Tensor num_points, + bool random_start) { if (src.device().is_cuda()) { #ifdef WITH_CUDA - return fps_cuda(src, ptr, ratio, random_start); + return fps_cuda(src, ptr, ratio, num_points, random_start); #else AT_ERROR("Not compiled with CUDA support"); #endif } else { - return fps_cpu(src, ptr, ratio, random_start); + return fps_cpu(src, ptr, ratio, num_points, random_start); } } diff --git a/test/test_fps.py b/test/test_fps.py index 0f10e51..f2b9b70 100644 --- a/test/test_fps.py +++ b/test/test_fps.py @@ -9,7 +9,7 @@ @torch.jit.script def fps2(x: Tensor, ratio: Tensor) -> Tensor: - return fps(x, None, ratio, False) + return fps(x, None, ratio, None, False) @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) @@ -33,6 +33,11 @@ def test_fps(dtype, device): out = fps(x, batch, ratio=0.5, random_start=False) assert out.tolist() == [0, 2, 4, 6] + out = fps(x, batch, num_points=2, random_start=False) + assert out.tolist() == [0, 2, 4, 6] + + out = fps(x, batch, num_points=4, random_start=False) + assert out.tolist() == [0, 2, 1, 3, 4, 6, 5, 7] ratio = torch.tensor(0.5, device=device) out = fps(x, batch, ratio=ratio, random_start=False) @@ -40,7 +45,6 @@ def test_fps(dtype, device): out = fps(x, ptr=ptr_list, ratio=0.5, random_start=False) assert out.tolist() == [0, 2, 4, 6] - out = fps(x, ptr=ptr, ratio=0.5, random_start=False) assert out.tolist() == [0, 2, 4, 6] @@ -48,11 +52,17 @@ def test_fps(dtype, device): out = fps(x, batch, ratio=ratio, random_start=False) assert out.tolist() == [0, 2, 4, 6] + num = torch.tensor([2, 2], device=device) + out = fps(x, batch, num_points=num, random_start=False) + assert out.tolist() == [0, 2, 4, 6] + out = fps(x, random_start=False) assert out.sort()[0].tolist() == [0, 5, 6, 7] out = fps(x, ratio=0.5, random_start=False) assert out.sort()[0].tolist() == [0, 5, 6, 7] + out = fps(x, num_points=4, random_start=False) + assert out.sort()[0].tolist() == [0, 5, 6, 7] out = fps(x, ratio=torch.tensor(0.5, device=device), random_start=False) assert out.sort()[0].tolist() == [0, 5, 6, 7] @@ -63,6 +73,17 @@ def test_fps(dtype, device): out = fps2(x, torch.tensor([0.5], device=device)) assert out.sort()[0].tolist() == [0, 5, 6, 7] + # requesting too many points + with pytest.raises(RuntimeError): + out = fps(x, batch, num_points=100, random_start=False) + + with pytest.raises(RuntimeError): + out = fps(x, batch, num_points=5, random_start=False) + + # invalid argument combination + with pytest.raises(ValueError): + out = fps(x, batch, ratio=0.0, num_points=0, random_start=False) + @pytest.mark.parametrize('device', devices) def test_random_fps(device): diff --git a/torch_cluster/fps.py b/torch_cluster/fps.py index 7baf981..a5f00a9 100644 --- a/torch_cluster/fps.py +++ b/torch_cluster/fps.py @@ -7,26 +7,49 @@ @torch.jit._overload # noqa -def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa - # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa +def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa + # type: (Tensor, Optional[Tensor], Optional[float], Optional[int], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa pass # pragma: no cover @torch.jit._overload # noqa -def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa - # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa +def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa + # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[int], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa pass # pragma: no cover @torch.jit._overload # noqa -def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa - # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa +def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa + # type: (Tensor, Optional[Tensor], Optional[float], Optional[int], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa pass # pragma: no cover @torch.jit._overload # noqa -def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa - # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa +def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa + # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[int], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa + pass # pragma: no cover + +@torch.jit._overload # noqa +def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa + # type: (Tensor, Optional[Tensor], Optional[float], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa + pass # pragma: no cover + + +@torch.jit._overload # noqa +def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa + # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa + pass # pragma: no cover + + +@torch.jit._overload # noqa +def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa + # type: (Tensor, Optional[Tensor], Optional[float], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa + pass # pragma: no cover + + +@torch.jit._overload # noqa +def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa + # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa pass # pragma: no cover @@ -34,6 +57,7 @@ def fps( # noqa src: torch.Tensor, batch: Optional[Tensor] = None, ratio: Optional[Union[Tensor, float]] = None, + num_points: Optional[Union[Tensor, int]] = None, random_start: bool = True, batch_size: Optional[int] = None, ptr: Optional[Union[Tensor, List[int]]] = None, @@ -50,7 +74,11 @@ def fps( # noqa :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each node to a specific example. (default: :obj:`None`) ratio (float or Tensor, optional): Sampling ratio. + Only ratio or num_points can be specified. (default: :obj:`0.5`) + num_points (int, optional): Number of returned points. + Only ratio or num_points can be specified. + (default: :obj:`None`) random_start (bool, optional): If set to :obj:`False`, use the first node in :math:`\mathbf{X}` as starting node. (default: obj:`True`) batch_size (int, optional): The number of examples :math:`B`. @@ -71,25 +99,45 @@ def fps( # noqa batch = torch.tensor([0, 0, 0, 0]) index = fps(src, batch, ratio=0.5) """ + # check if only of of ratio or num_points is set + # if no one is set, fallback to ratio = 0.5 + if ratio is not None and num_points is not None: + raise ValueError("Only one of ratio and num_points can be specified.") + r: Optional[Tensor] = None if ratio is None: - r = torch.tensor(0.5, dtype=src.dtype, device=src.device) + if num_points is None: + r = torch.tensor(0.5, dtype=src.dtype, device=src.device) + else: + r = torch.tensor(0.0, dtype=src.dtype, device=src.device) elif isinstance(ratio, float): r = torch.tensor(ratio, dtype=src.dtype, device=src.device) else: r = ratio - assert r is not None + + num: Optional[Tensor] = None + if num_points is None: + num = torch.tensor(0, dtype=torch.long, device=src.device) + elif isinstance(num_points, int): + num = torch.tensor(num_points, dtype=torch.long, device=src.device) + else: + num = num_points + + assert r is not None and num is not None + + if r.sum() == 0 and num.sum() == 0: + raise ValueError("At least one of ratio or num_points should be > 0") if ptr is not None: if isinstance(ptr, list) and torch_cluster.typing.WITH_PTR_LIST: - return torch.ops.torch_cluster.fps_ptr_list( - src, ptr, r, random_start) + return torch.ops.torch_cluster.fps_ptr_list(src, ptr, r, random_start) if isinstance(ptr, list): return torch.ops.torch_cluster.fps( - src, torch.tensor(ptr, device=src.device), r, random_start) + src, torch.tensor(ptr, device=src.device), r, num, random_start + ) else: - return torch.ops.torch_cluster.fps(src, ptr, r, random_start) + return torch.ops.torch_cluster.fps(src, ptr, r, num, random_start) if batch is not None: assert src.size(0) == batch.numel() @@ -104,4 +152,4 @@ def fps( # noqa else: ptr_vec = torch.tensor([0, src.size(0)], device=src.device) - return torch.ops.torch_cluster.fps(src, ptr_vec, r, random_start) + return torch.ops.torch_cluster.fps(src, ptr_vec, r, num, random_start)