Skip to content

Commit

Permalink
Add batch_size argument for fps, knn, radius functions (#175)
Browse files Browse the repository at this point in the history
* Add batch_size argument for fps, knn, radius functions.

It can be used to avoid additional calculations if a user is using
fixed-size batch.

* update

* update

* update

* update

---------

Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
piotrchmiel and rusty1s authored Apr 28, 2023
1 parent 84bbb71 commit 32bee64
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 40 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ jobs:
- name: Install main package
run: |
pip install -e .[test]
python setup.py develop
- name: Run test-suite
run: |
pip install pytest pytest-cov
pytest --cov --cov-report=xml
- name: Upload coverage
Expand Down
24 changes: 16 additions & 8 deletions torch_cluster/fps.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
from typing import Optional
from typing import Optional, Union

import torch
from torch import Tensor


@torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool) -> Tensor
def fps(src, batch, ratio, random_start, batch_size): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int]) -> Tensor # noqa
pass # pragma: no cover


@torch.jit._overload # noqa
def fps(src, batch=None, ratio=None, random_start=True): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool) -> Tensor
def fps(src, batch, ratio, random_start, batch_size): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int]) -> Tensor # noqa
pass # pragma: no cover


def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
def fps( # noqa
src: torch.Tensor,
batch: Optional[Tensor] = None,
ratio: Optional[Union[torch.Tensor, float]] = None,
random_start: bool = True,
batch_size: Optional[int] = None,
):
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
Learning on Point Sets in a Metric Space"
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
Expand All @@ -32,10 +38,11 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa
(default: :obj:`0.5`)
random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor`
.. code-block:: python
import torch
Expand All @@ -57,7 +64,8 @@ def fps(src: torch.Tensor, batch=None, ratio=None, random_start=True): # noqa

if batch is not None:
assert src.size(0) == batch.numel()
batch_size = int(batch.max()) + 1
if batch_size is None:
batch_size = int(batch.max()) + 1

deg = src.new_zeros(batch_size, dtype=torch.long)
deg.scatter_add_(0, batch, torch.ones_like(batch))
Expand Down
49 changes: 34 additions & 15 deletions torch_cluster/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@


@torch.jit.script
def knn(x: torch.Tensor, y: torch.Tensor, k: int,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None, cosine: bool = False,
num_workers: int = 1) -> torch.Tensor:
def knn(
x: torch.Tensor,
y: torch.Tensor,
k: int,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None,
cosine: bool = False,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
:obj:`x`.
Expand All @@ -31,6 +37,8 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor`
Expand All @@ -52,13 +60,15 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,
y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous()

batch_size = 1
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
if batch_size is None:
batch_size = 1
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
assert batch_size > 0

ptr_x: Optional[torch.Tensor] = None
ptr_y: Optional[torch.Tensor] = None
Expand All @@ -74,9 +84,16 @@ def knn(x: torch.Tensor, y: torch.Tensor, k: int,


@torch.jit.script
def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
loop: bool = False, flow: str = 'source_to_target',
cosine: bool = False, num_workers: int = 1) -> torch.Tensor:
def knn_graph(
x: torch.Tensor,
k: int,
batch: Optional[torch.Tensor] = None,
loop: bool = False,
flow: str = 'source_to_target',
cosine: bool = False,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Computes graph edges to the nearest :obj:`k` points.
Args:
Expand All @@ -98,6 +115,8 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor`
Expand All @@ -113,7 +132,7 @@ def knn_graph(x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None,

assert flow in ['source_to_target', 'target_to_source']
edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine,
num_workers)
num_workers, batch_size)

if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0]
Expand Down
51 changes: 35 additions & 16 deletions torch_cluster/radius.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@


@torch.jit.script
def radius(x: torch.Tensor, y: torch.Tensor, r: float,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None, max_num_neighbors: int = 32,
num_workers: int = 1) -> torch.Tensor:
def radius(
x: torch.Tensor,
y: torch.Tensor,
r: float,
batch_x: Optional[torch.Tensor] = None,
batch_y: Optional[torch.Tensor] = None,
max_num_neighbors: int = 32,
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Finds for each element in :obj:`y` all points in :obj:`x` within
distance :obj:`r`.
Expand All @@ -33,6 +39,8 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch_x` or :obj:`batch_y` is not
:obj:`None`, or the input lies on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
.. code-block:: python
Expand All @@ -52,16 +60,19 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,
y = y.view(-1, 1) if y.dim() == 1 else y
x, y = x.contiguous(), y.contiguous()

batch_size = 1
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
if batch_size is None:
batch_size = 1
if batch_x is not None:
assert x.size(0) == batch_x.numel()
batch_size = int(batch_x.max()) + 1
if batch_y is not None:
assert y.size(0) == batch_y.numel()
batch_size = max(batch_size, int(batch_y.max()) + 1)
assert batch_size > 0

ptr_x: Optional[torch.Tensor] = None
ptr_y: Optional[torch.Tensor] = None

if batch_size > 1:
assert batch_x is not None
assert batch_y is not None
Expand All @@ -74,10 +85,16 @@ def radius(x: torch.Tensor, y: torch.Tensor, r: float,


@torch.jit.script
def radius_graph(x: torch.Tensor, r: float,
batch: Optional[torch.Tensor] = None, loop: bool = False,
max_num_neighbors: int = 32, flow: str = 'source_to_target',
num_workers: int = 1) -> torch.Tensor:
def radius_graph(
x: torch.Tensor,
r: float,
batch: Optional[torch.Tensor] = None,
loop: bool = False,
max_num_neighbors: int = 32,
flow: str = 'source_to_target',
num_workers: int = 1,
batch_size: Optional[int] = None,
) -> torch.Tensor:
r"""Computes graph edges to all points within a given distance.
Args:
Expand All @@ -101,6 +118,8 @@ def radius_graph(x: torch.Tensor, r: float,
num_workers (int): Number of workers to use for computation. Has no
effect in case :obj:`batch` is not :obj:`None`, or the input lies
on the GPU. (default: :obj:`1`)
batch_size (int, optional): The number of examples :math:`B`.
Automatically calculated if not given. (default: :obj:`None`)
:rtype: :class:`LongTensor`
Expand All @@ -117,7 +136,7 @@ def radius_graph(x: torch.Tensor, r: float,
assert flow in ['source_to_target', 'target_to_source']
edge_index = radius(x, x, r, batch, batch,
max_num_neighbors if loop else max_num_neighbors + 1,
num_workers)
num_workers, batch_size)
if flow == 'source_to_target':
row, col = edge_index[1], edge_index[0]
else:
Expand Down

0 comments on commit 32bee64

Please sign in to comment.