From 32bee6460cb2398a4e5f1e76c622817147658931 Mon Sep 17 00:00:00 2001 From: Piotr Chmiel Date: Fri, 28 Apr 2023 13:43:32 +0200 Subject: [PATCH] Add `batch_size` argument for `fps`, `knn`, `radius` functions (#175) * Add batch_size argument for fps, knn, radius functions. It can be used to avoid additional calculations if a user is using fixed-size batch. * update * update * update * update --------- Co-authored-by: rusty1s --- .github/workflows/testing.yml | 3 ++- torch_cluster/fps.py | 24 +++++++++++------ torch_cluster/knn.py | 49 ++++++++++++++++++++++----------- torch_cluster/radius.py | 51 ++++++++++++++++++++++++----------- 4 files changed, 87 insertions(+), 40 deletions(-) diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index 6b177df..4a87490 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -31,10 +31,11 @@ jobs: - name: Install main package run: | - pip install -e .[test] + python setup.py develop - name: Run test-suite run: | + pip install pytest pytest-cov pytest --cov --cov-report=xml - name: Upload coverage diff --git a/torch_cluster/fps.py b/torch_cluster/fps.py index 7901dd5..e0d2782 100644 --- a/torch_cluster/fps.py +++ b/torch_cluster/fps.py @@ -1,22 +1,28 @@ -from typing import Optional +from typing import Optional, Union import torch from torch import Tensor @torch.jit._overload # noqa -def fps(src, batch=None, ratio=None, random_start=True): # noqa - # type: (Tensor, Optional[Tensor], Optional[float], bool) -> Tensor +def fps(src, batch, ratio, random_start, batch_size): # noqa + # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int]) -> Tensor # noqa pass # pragma: no cover @torch.jit._overload # noqa -def fps(src, batch=None, ratio=None, random_start=True): # noqa - # type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor +def fps(src, batch, ratio, random_start, batch_size): # noqa + # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int]) -> Tensor # noqa pass # pragma: no cover -def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa +def fps( # noqa + src: torch.Tensor, + batch: Optional[Tensor] = None, + ratio: Optional[Union[torch.Tensor, float]] = None, + random_start: bool = True, + batch_size: Optional[int] = None, +): r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space" `_ paper, which iteratively samples the @@ -32,10 +38,11 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa (default: :obj:`0.5`) random_start (bool, optional): If set to :obj:`False`, use the first node in :math:`\mathbf{X}` as starting node. (default: obj:`True`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. (default: :obj:`None`) :rtype: :class:`LongTensor` - .. code-block:: python import torch @@ -57,7 +64,8 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa if batch is not None: assert src.size(0) == batch.numel() - batch_size = int(batch.max()) + 1 + if batch_size is None: + batch_size = int(batch.max()) + 1 deg = src.new_zeros(batch_size, dtype=torch.long) deg.scatter_add_(0, batch, torch.ones_like(batch)) diff --git a/torch_cluster/knn.py b/torch_cluster/knn.py index b981c46..4eace5e 100644 --- a/torch_cluster/knn.py +++ b/torch_cluster/knn.py @@ -4,10 +4,16 @@ @torch.jit.script -def knn(x: torch.Tensor, y: torch.Tensor, k: int, - batch_x: Optional[torch.Tensor] = None, - batch_y: Optional[torch.Tensor] = None, cosine: bool = False, - num_workers: int = 1) -> torch.Tensor: +def knn( + x: torch.Tensor, + y: torch.Tensor, + k: int, + batch_x: Optional[torch.Tensor] = None, + batch_y: Optional[torch.Tensor] = None, + cosine: bool = False, + num_workers: int = 1, + batch_size: Optional[int] = None, +) -> torch.Tensor: r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in :obj:`x`. @@ -31,6 +37,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, num_workers (int): Number of workers to use for computation. Has no effect in case :obj:`batch_x` or :obj:`batch_y` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. (default: :obj:`None`) :rtype: :class:`LongTensor` @@ -52,13 +60,15 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, y = y.view(-1, 1) if y.dim() == 1 else y x, y = x.contiguous(), y.contiguous() - batch_size = 1 - if batch_x is not None: - assert x.size(0) == batch_x.numel() - batch_size = int(batch_x.max()) + 1 - if batch_y is not None: - assert y.size(0) == batch_y.numel() - batch_size = max(batch_size, int(batch_y.max()) + 1) + if batch_size is None: + batch_size = 1 + if batch_x is not None: + assert x.size(0) == batch_x.numel() + batch_size = int(batch_x.max()) + 1 + if batch_y is not None: + assert y.size(0) == batch_y.numel() + batch_size = max(batch_size, int(batch_y.max()) + 1) + assert batch_size > 0 ptr_x: Optional[torch.Tensor] = None ptr_y: Optional[torch.Tensor] = None @@ -74,9 +84,16 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int, @torch.jit.script -def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, - loop: bool = False, flow: str = 'source_to_target', - cosine: bool = False, num_workers: int = 1) -> torch.Tensor: +def knn_graph( + x: torch.Tensor, + k: int, + batch: Optional[torch.Tensor] = None, + loop: bool = False, + flow: str = 'source_to_target', + cosine: bool = False, + num_workers: int = 1, + batch_size: Optional[int] = None, +) -> torch.Tensor: r"""Computes graph edges to the nearest :obj:`k` points. Args: @@ -98,6 +115,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, num_workers (int): Number of workers to use for computation. Has no effect in case :obj:`batch` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. (default: :obj:`None`) :rtype: :class:`LongTensor` @@ -113,7 +132,7 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None, assert flow in ['source_to_target', 'target_to_source'] edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine, - num_workers) + num_workers, batch_size) if flow == 'source_to_target': row, col = edge_index[1], edge_index[0] diff --git a/torch_cluster/radius.py b/torch_cluster/radius.py index fd73b75..de35298 100644 --- a/torch_cluster/radius.py +++ b/torch_cluster/radius.py @@ -4,10 +4,16 @@ @torch.jit.script -def radius(x: torch.Tensor, y: torch.Tensor, r: float, - batch_x: Optional[torch.Tensor] = None, - batch_y: Optional[torch.Tensor] = None, max_num_neighbors: int = 32, - num_workers: int = 1) -> torch.Tensor: +def radius( + x: torch.Tensor, + y: torch.Tensor, + r: float, + batch_x: Optional[torch.Tensor] = None, + batch_y: Optional[torch.Tensor] = None, + max_num_neighbors: int = 32, + num_workers: int = 1, + batch_size: Optional[int] = None, +) -> torch.Tensor: r"""Finds for each element in :obj:`y` all points in :obj:`x` within distance :obj:`r`. @@ -33,6 +39,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, num_workers (int): Number of workers to use for computation. Has no effect in case :obj:`batch_x` or :obj:`batch_y` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. (default: :obj:`None`) .. code-block:: python @@ -52,16 +60,19 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, y = y.view(-1, 1) if y.dim() == 1 else y x, y = x.contiguous(), y.contiguous() - batch_size = 1 - if batch_x is not None: - assert x.size(0) == batch_x.numel() - batch_size = int(batch_x.max()) + 1 - if batch_y is not None: - assert y.size(0) == batch_y.numel() - batch_size = max(batch_size, int(batch_y.max()) + 1) + if batch_size is None: + batch_size = 1 + if batch_x is not None: + assert x.size(0) == batch_x.numel() + batch_size = int(batch_x.max()) + 1 + if batch_y is not None: + assert y.size(0) == batch_y.numel() + batch_size = max(batch_size, int(batch_y.max()) + 1) + assert batch_size > 0 ptr_x: Optional[torch.Tensor] = None ptr_y: Optional[torch.Tensor] = None + if batch_size > 1: assert batch_x is not None assert batch_y is not None @@ -74,10 +85,16 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float, @torch.jit.script -def radius_graph(x: torch.Tensor, r: float, - batch: Optional[torch.Tensor] = None, loop: bool = False, - max_num_neighbors: int = 32, flow: str = 'source_to_target', - num_workers: int = 1) -> torch.Tensor: +def radius_graph( + x: torch.Tensor, + r: float, + batch: Optional[torch.Tensor] = None, + loop: bool = False, + max_num_neighbors: int = 32, + flow: str = 'source_to_target', + num_workers: int = 1, + batch_size: Optional[int] = None, +) -> torch.Tensor: r"""Computes graph edges to all points within a given distance. Args: @@ -101,6 +118,8 @@ def radius_graph(x: torch.Tensor, r: float, num_workers (int): Number of workers to use for computation. Has no effect in case :obj:`batch` is not :obj:`None`, or the input lies on the GPU. (default: :obj:`1`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. (default: :obj:`None`) :rtype: :class:`LongTensor` @@ -117,7 +136,7 @@ def radius_graph(x: torch.Tensor, r: float, assert flow in ['source_to_target', 'target_to_source'] edge_index = radius(x, x, r, batch, batch, max_num_neighbors if loop else max_num_neighbors + 1, - num_workers) + num_workers, batch_size) if flow == 'source_to_target': row, col = edge_index[1], edge_index[0] else: