-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Feature: resize function (== torch.interpolate, scipy.zoom)
- Loading branch information
Showing
5 changed files
with
244 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .api import * | ||
from .resize import * | ||
from . import _version | ||
__version__ = _version.get_versions()['version'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
""" | ||
Resize functions (equivalent to scipy's zoom, pytorch's interpolate) | ||
based on grid_pull. | ||
""" | ||
from .api import grid_pull | ||
from .utils import make_list, meshgrid_ij | ||
import torch | ||
|
||
|
||
def resize(image, factor=None, shape=None, anchor='c', | ||
interpolation=1, prefilter=True, **kwargs): | ||
"""Resize an image by a factor or to a specific shape. | ||
Notes | ||
----- | ||
.. A least one of `factor` and `shape` must be specified | ||
.. If `anchor in ('centers', 'edges')`, exactly one of `factor` or | ||
`shape must be specified. | ||
.. If `anchor in ('first', 'last')`, `factor` must be provided even | ||
if `shape` is specified. | ||
.. Because of rounding, it is in general not assured that | ||
`resize(resize(x, f), 1/f)` returns a tensor with the same shape as x. | ||
edges centers first last | ||
e - + - + - e + - + - + - + + - + - + - + + - + - + - + | ||
| . | . | . | | c | . | c | | f | . | . | | . | . | . | | ||
+ _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + | ||
| . | . | . | | . | . | . | | . | . | . | | . | . | . | | ||
+ _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + | ||
| . | . | . | | c | . | c | | . | . | . | | . | . | l | | ||
e _ + _ + _ e + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + | ||
Parameters | ||
---------- | ||
image : (batch, channel, *inshape) tensor | ||
Image to resize | ||
factor : float or list[float], optional | ||
Resizing factor | ||
* > 1 : larger image <-> smaller voxels | ||
* < 1 : smaller image <-> larger voxels | ||
shape : (ndim,) list[int], optional | ||
Output shape | ||
anchor : {'centers', 'edges', 'first', 'last'} or list, default='centers' | ||
* In cases 'c' and 'e', the volume shape is multiplied by the | ||
zoom factor (and eventually truncated), and two anchor points | ||
are used to determine the voxel size. | ||
* In cases 'f' and 'l', a single anchor point is used so that | ||
the voxel size is exactly divided by the zoom factor. | ||
This case with an integer factor corresponds to subslicing | ||
the volume (e.g., `vol[::f, ::f, ::f]`). | ||
* A list of anchors (one per dimension) can also be provided. | ||
interpolation : int or sequence[int], default=1 | ||
Interpolation order. | ||
prefilter : bool, default=True | ||
Apply spline pre-filter (= interpolates the input) | ||
Returns | ||
------- | ||
resized : (batch, channel, *shape) tensor | ||
Resized image | ||
""" | ||
factor = make_list(factor) if factor else [] | ||
shape = make_list(shape) if shape else [] | ||
anchor = make_list(anchor) | ||
nb_dim = max(len(factor), len(shape), len(anchor)) or (image.dim() - 2) | ||
anchor = [a[0].lower() for a in make_list(anchor, nb_dim)] | ||
backend = dict(dtype=image.dtype, device=image.device) | ||
|
||
# compute output shape | ||
inshape = image.shape[-nb_dim:] | ||
if factor: | ||
factor = make_list(factor, nb_dim) | ||
elif not shape: | ||
raise ValueError('One of `factor` or `shape` must be provided') | ||
if shape: | ||
shape = make_list(shape, nb_dim) | ||
else: | ||
shape = [int(i*f) for i, f in zip(inshape, factor)] | ||
|
||
if not factor: | ||
factor = [o/i for o, i in zip(shape, inshape)] | ||
|
||
# compute transformation grid | ||
lin = [] | ||
for anch, f, inshp, outshp in zip(anchor, factor, inshape, shape): | ||
if anch == 'c': # centers | ||
lin.append(torch.linspace(0, inshp - 1, outshp, **backend)) | ||
elif anch == 'e': # edges | ||
scale = inshp / outshp | ||
shift = 0.5 * (scale - 1) | ||
lin.append(torch.arange(0., outshp, **backend) * scale + shift) | ||
elif anch == 'f': # first voxel | ||
# scale = 1/f | ||
# shift = 0 | ||
lin.append(torch.arange(0., outshp, **backend) / f) | ||
elif anch == 'l': # last voxel | ||
# scale = 1/f | ||
shift = (inshp - 1) - (outshp - 1) / f | ||
lin.append(torch.arange(0., outshp, **backend) / f + shift) | ||
else: | ||
raise ValueError('Unknown anchor {}'.format(anch)) | ||
|
||
# interpolate | ||
kwargs.setdefault('bound', 'nearest') | ||
kwargs.setdefault('extrapolate', True) | ||
kwargs.setdefault('interpolation', interpolation) | ||
kwargs.setdefault('prefilter', prefilter) | ||
grid = torch.stack(meshgrid_ij(*lin), dim=-1) | ||
resized = grid_pull(image, grid, **kwargs) | ||
|
||
return resized | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters