Skip to content

Commit

Permalink
Feature: resize function (== torch.interpolate, scipy.zoom)
Browse files Browse the repository at this point in the history
  • Loading branch information
balbasty committed Jun 14, 2022
1 parent cb94381 commit 5ac4bd5
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 3 deletions.
63 changes: 63 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,69 @@ output : (..., *spatial) tensor
Coefficient image.
"""
```

```python
interpol.resize(
image,
factor=None,
shape=None,
anchor='c',
interpolation=1,
prefilter=True
)
"""Resize an image by a factor or to a specific shape.
Notes
-----
.. A least one of `factor` and `shape` must be specified
.. If `anchor in ('centers', 'edges')`, exactly one of `factor` or
`shape must be specified.
.. If `anchor in ('first', 'last')`, `factor` must be provided even
if `shape` is specified.
.. Because of rounding, it is in general not assured that
`resize(resize(x, f), 1/f)` returns a tensor with the same shape as x.
edges centers first last
e - + - + - e + - + - + - + + - + - + - + + - + - + - +
| . | . | . | | c | . | c | | f | . | . | | . | . | . |
+ _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ +
| . | . | . | | . | . | . | | . | . | . | | . | . | . |
+ _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ +
| . | . | . | | c | . | c | | . | . | . | | . | . | l |
e _ + _ + _ e + _ + _ + _ + + _ + _ + _ + + _ + _ + _ +
Parameters
----------
image : (batch, channel, *inshape) tensor
Image to resize
factor : float or list[float], optional
Resizing factor
* > 1 : larger image <-> smaller voxels
* < 1 : smaller image <-> larger voxels
shape : (ndim,) list[int], optional
Output shape
anchor : {'centers', 'edges', 'first', 'last'} or list, default='centers'
* In cases 'c' and 'e', the volume shape is multiplied by the
zoom factor (and eventually truncated), and two anchor points
are used to determine the voxel size.
* In cases 'f' and 'l', a single anchor point is used so that
the voxel size is exactly divided by the zoom factor.
This case with an integer factor corresponds to subslicing
the volume (e.g., `vol[::f, ::f, ::f]`).
* A list of anchors (one per dimension) can also be provided.
interpolation : int or sequence[int], default=1
Interpolation order.
prefilter : bool, default=True
Apply spline pre-filter (= interpolates the input)
Returns
-------
resized : (batch, channel, *shape) tensor
Resized image
"""
```

## License

torch-interpol is released under the MIT license.
1 change: 1 addition & 0 deletions interpol/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .api import *
from .resize import *
from . import _version
__version__ = _version.get_versions()['version']
3 changes: 1 addition & 2 deletions interpol/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
/!\ Only 'dct1', 'dct2' and 'dft' are implemented for interpolation
orders >= 6."""


_ref_coeff = \
"""..[1] M. Unser, A. Aldroubi and M. Eden.
"B-Spline Signal Processing: Part I-Theory,"
Expand Down Expand Up @@ -530,4 +529,4 @@ def affine_grid(mat, shape):
lin = mat[..., :nb_dim, :nb_dim]
off = mat[..., :nb_dim, -1]
grid = matvec(lin, grid) + off
return grid
return grid
113 changes: 113 additions & 0 deletions interpol/resize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""
Resize functions (equivalent to scipy's zoom, pytorch's interpolate)
based on grid_pull.
"""
from .api import grid_pull
from .utils import make_list, meshgrid_ij
import torch


def resize(image, factor=None, shape=None, anchor='c',
interpolation=1, prefilter=True, **kwargs):
"""Resize an image by a factor or to a specific shape.
Notes
-----
.. A least one of `factor` and `shape` must be specified
.. If `anchor in ('centers', 'edges')`, exactly one of `factor` or
`shape must be specified.
.. If `anchor in ('first', 'last')`, `factor` must be provided even
if `shape` is specified.
.. Because of rounding, it is in general not assured that
`resize(resize(x, f), 1/f)` returns a tensor with the same shape as x.
edges centers first last
e - + - + - e + - + - + - + + - + - + - + + - + - + - +
| . | . | . | | c | . | c | | f | . | . | | . | . | . |
+ _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ +
| . | . | . | | . | . | . | | . | . | . | | . | . | . |
+ _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ +
| . | . | . | | c | . | c | | . | . | . | | . | . | l |
e _ + _ + _ e + _ + _ + _ + + _ + _ + _ + + _ + _ + _ +
Parameters
----------
image : (batch, channel, *inshape) tensor
Image to resize
factor : float or list[float], optional
Resizing factor
* > 1 : larger image <-> smaller voxels
* < 1 : smaller image <-> larger voxels
shape : (ndim,) list[int], optional
Output shape
anchor : {'centers', 'edges', 'first', 'last'} or list, default='centers'
* In cases 'c' and 'e', the volume shape is multiplied by the
zoom factor (and eventually truncated), and two anchor points
are used to determine the voxel size.
* In cases 'f' and 'l', a single anchor point is used so that
the voxel size is exactly divided by the zoom factor.
This case with an integer factor corresponds to subslicing
the volume (e.g., `vol[::f, ::f, ::f]`).
* A list of anchors (one per dimension) can also be provided.
interpolation : int or sequence[int], default=1
Interpolation order.
prefilter : bool, default=True
Apply spline pre-filter (= interpolates the input)
Returns
-------
resized : (batch, channel, *shape) tensor
Resized image
"""
factor = make_list(factor) if factor else []
shape = make_list(shape) if shape else []
anchor = make_list(anchor)
nb_dim = max(len(factor), len(shape), len(anchor)) or (image.dim() - 2)
anchor = [a[0].lower() for a in make_list(anchor, nb_dim)]
backend = dict(dtype=image.dtype, device=image.device)

# compute output shape
inshape = image.shape[-nb_dim:]
if factor:
factor = make_list(factor, nb_dim)
elif not shape:
raise ValueError('One of `factor` or `shape` must be provided')
if shape:
shape = make_list(shape, nb_dim)
else:
shape = [int(i*f) for i, f in zip(inshape, factor)]

if not factor:
factor = [o/i for o, i in zip(shape, inshape)]

# compute transformation grid
lin = []
for anch, f, inshp, outshp in zip(anchor, factor, inshape, shape):
if anch == 'c': # centers
lin.append(torch.linspace(0, inshp - 1, outshp, **backend))
elif anch == 'e': # edges
scale = inshp / outshp
shift = 0.5 * (scale - 1)
lin.append(torch.arange(0., outshp, **backend) * scale + shift)
elif anch == 'f': # first voxel
# scale = 1/f
# shift = 0
lin.append(torch.arange(0., outshp, **backend) / f)
elif anch == 'l': # last voxel
# scale = 1/f
shift = (inshp - 1) - (outshp - 1) / f
lin.append(torch.arange(0., outshp, **backend) / f + shift)
else:
raise ValueError('Unknown anchor {}'.format(anch))

# interpolate
kwargs.setdefault('bound', 'nearest')
kwargs.setdefault('extrapolate', True)
kwargs.setdefault('interpolation', interpolation)
kwargs.setdefault('prefilter', prefilter)
grid = torch.stack(meshgrid_ij(*lin), dim=-1)
resized = grid_pull(image, grid, **kwargs)

return resized

67 changes: 66 additions & 1 deletion interpol/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,69 @@ def matvec(mat, vec, out=None):
if out is not None:
out = out[..., 0]

return mv
return mv


def _compare_versions(version1, mode, version2):
for v1, v2 in zip(version1, version2):
if mode in ('gt', '>'):
if v1 > v2:
return True
elif v1 < v2:
return False
elif mode in ('ge', '>='):
if v1 > v2:
return True
elif v1 < v2:
return False
elif mode in ('lt', '<'):
if v1 < v2:
return True
elif v1 > v2:
return False
elif mode in ('le', '<='):
if v1 < v2:
return True
elif v1 > v2:
return False
if mode in ('gt', 'lt', '>', '<'):
return False
else:
return True


def torch_version(mode, version):
"""Check torch version
Parameters
----------
mode : {'<', '<=', '>', '>='}
version : tuple[int]
Returns
-------
True if "torch.version <mode> version"
"""
current_version, *cuda_variant = torch.__version__.split('+')
major, minor, patch, *_ = current_version.split('.')
# strip alpha tags
for x in 'abcdefghijklmnopqrstuvwxy':
if x in patch:
patch = patch[:patch.index(x)]
current_version = (int(major), int(minor), int(patch))
version = make_list(version)
return _compare_versions(current_version, mode, version)


if torch_version('>=', (1, 10)):
meshgrid_ij = lambda *x: torch.meshgrid(*x, indexing='ij')
meshgrid_xy = lambda *x: torch.meshgrid(*x, indexing='xy')
else:
meshgrid_ij = lambda *x: torch.meshgrid(*x)
def meshgrid_xy(*x):
grid = list(torch.meshgrid(*x))
if len(grid) > 1:
grid[0] = grid[0].transpose(0, 1)
grid[1] = grid[1].transpose(0, 1)
return grid

0 comments on commit 5ac4bd5

Please sign in to comment.