diff --git a/README.md b/README.md index b95aac8..c02ae25 100644 --- a/README.md +++ b/README.md @@ -271,6 +271,69 @@ output : (..., *spatial) tensor Coefficient image. """ ``` + +```python +interpol.resize( + image, + factor=None, + shape=None, + anchor='c', + interpolation=1, + prefilter=True +) +"""Resize an image by a factor or to a specific shape. + +Notes +----- +.. A least one of `factor` and `shape` must be specified +.. If `anchor in ('centers', 'edges')`, exactly one of `factor` or + `shape must be specified. +.. If `anchor in ('first', 'last')`, `factor` must be provided even + if `shape` is specified. +.. Because of rounding, it is in general not assured that + `resize(resize(x, f), 1/f)` returns a tensor with the same shape as x. + + edges centers first last + e - + - + - e + - + - + - + + - + - + - + + - + - + - + + | . | . | . | | c | . | c | | f | . | . | | . | . | . | + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + | . | . | . | | . | . | . | | . | . | . | | . | . | . | + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + | . | . | . | | c | . | c | | . | . | . | | . | . | l | + e _ + _ + _ e + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + +Parameters +---------- +image : (batch, channel, *inshape) tensor + Image to resize +factor : float or list[float], optional + Resizing factor + * > 1 : larger image <-> smaller voxels + * < 1 : smaller image <-> larger voxels +shape : (ndim,) list[int], optional + Output shape +anchor : {'centers', 'edges', 'first', 'last'} or list, default='centers' + * In cases 'c' and 'e', the volume shape is multiplied by the + zoom factor (and eventually truncated), and two anchor points + are used to determine the voxel size. + * In cases 'f' and 'l', a single anchor point is used so that + the voxel size is exactly divided by the zoom factor. + This case with an integer factor corresponds to subslicing + the volume (e.g., `vol[::f, ::f, ::f]`). + * A list of anchors (one per dimension) can also be provided. +interpolation : int or sequence[int], default=1 + Interpolation order. +prefilter : bool, default=True + Apply spline pre-filter (= interpolates the input) + +Returns +------- +resized : (batch, channel, *shape) tensor + Resized image + +""" +``` + ## License torch-interpol is released under the MIT license. diff --git a/interpol/__init__.py b/interpol/__init__.py index 01efa79..45297ae 100644 --- a/interpol/__init__.py +++ b/interpol/__init__.py @@ -1,3 +1,4 @@ from .api import * +from .resize import * from . import _version __version__ = _version.get_versions()['version'] diff --git a/interpol/api.py b/interpol/api.py index df4e2df..e83e435 100755 --- a/interpol/api.py +++ b/interpol/api.py @@ -59,7 +59,6 @@ /!\ Only 'dct1', 'dct2' and 'dft' are implemented for interpolation orders >= 6.""" - _ref_coeff = \ """..[1] M. Unser, A. Aldroubi and M. Eden. "B-Spline Signal Processing: Part I-Theory," @@ -530,4 +529,4 @@ def affine_grid(mat, shape): lin = mat[..., :nb_dim, :nb_dim] off = mat[..., :nb_dim, -1] grid = matvec(lin, grid) + off - return grid \ No newline at end of file + return grid diff --git a/interpol/resize.py b/interpol/resize.py new file mode 100644 index 0000000..e09e043 --- /dev/null +++ b/interpol/resize.py @@ -0,0 +1,113 @@ +""" +Resize functions (equivalent to scipy's zoom, pytorch's interpolate) +based on grid_pull. +""" +from .api import grid_pull +from .utils import make_list, meshgrid_ij +import torch + + +def resize(image, factor=None, shape=None, anchor='c', + interpolation=1, prefilter=True, **kwargs): + """Resize an image by a factor or to a specific shape. + + Notes + ----- + .. A least one of `factor` and `shape` must be specified + .. If `anchor in ('centers', 'edges')`, exactly one of `factor` or + `shape must be specified. + .. If `anchor in ('first', 'last')`, `factor` must be provided even + if `shape` is specified. + .. Because of rounding, it is in general not assured that + `resize(resize(x, f), 1/f)` returns a tensor with the same shape as x. + + edges centers first last + e - + - + - e + - + - + - + + - + - + - + + - + - + - + + | . | . | . | | c | . | c | | f | . | . | | . | . | . | + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + | . | . | . | | . | . | . | | . | . | . | | . | . | . | + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + | . | . | . | | c | . | c | | . | . | . | | . | . | l | + e _ + _ + _ e + _ + _ + _ + + _ + _ + _ + + _ + _ + _ + + + Parameters + ---------- + image : (batch, channel, *inshape) tensor + Image to resize + factor : float or list[float], optional + Resizing factor + * > 1 : larger image <-> smaller voxels + * < 1 : smaller image <-> larger voxels + shape : (ndim,) list[int], optional + Output shape + anchor : {'centers', 'edges', 'first', 'last'} or list, default='centers' + * In cases 'c' and 'e', the volume shape is multiplied by the + zoom factor (and eventually truncated), and two anchor points + are used to determine the voxel size. + * In cases 'f' and 'l', a single anchor point is used so that + the voxel size is exactly divided by the zoom factor. + This case with an integer factor corresponds to subslicing + the volume (e.g., `vol[::f, ::f, ::f]`). + * A list of anchors (one per dimension) can also be provided. + interpolation : int or sequence[int], default=1 + Interpolation order. + prefilter : bool, default=True + Apply spline pre-filter (= interpolates the input) + + Returns + ------- + resized : (batch, channel, *shape) tensor + Resized image + + """ + factor = make_list(factor) if factor else [] + shape = make_list(shape) if shape else [] + anchor = make_list(anchor) + nb_dim = max(len(factor), len(shape), len(anchor)) or (image.dim() - 2) + anchor = [a[0].lower() for a in make_list(anchor, nb_dim)] + backend = dict(dtype=image.dtype, device=image.device) + + # compute output shape + inshape = image.shape[-nb_dim:] + if factor: + factor = make_list(factor, nb_dim) + elif not shape: + raise ValueError('One of `factor` or `shape` must be provided') + if shape: + shape = make_list(shape, nb_dim) + else: + shape = [int(i*f) for i, f in zip(inshape, factor)] + + if not factor: + factor = [o/i for o, i in zip(shape, inshape)] + + # compute transformation grid + lin = [] + for anch, f, inshp, outshp in zip(anchor, factor, inshape, shape): + if anch == 'c': # centers + lin.append(torch.linspace(0, inshp - 1, outshp, **backend)) + elif anch == 'e': # edges + scale = inshp / outshp + shift = 0.5 * (scale - 1) + lin.append(torch.arange(0., outshp, **backend) * scale + shift) + elif anch == 'f': # first voxel + # scale = 1/f + # shift = 0 + lin.append(torch.arange(0., outshp, **backend) / f) + elif anch == 'l': # last voxel + # scale = 1/f + shift = (inshp - 1) - (outshp - 1) / f + lin.append(torch.arange(0., outshp, **backend) / f + shift) + else: + raise ValueError('Unknown anchor {}'.format(anch)) + + # interpolate + kwargs.setdefault('bound', 'nearest') + kwargs.setdefault('extrapolate', True) + kwargs.setdefault('interpolation', interpolation) + kwargs.setdefault('prefilter', prefilter) + grid = torch.stack(meshgrid_ij(*lin), dim=-1) + resized = grid_pull(image, grid, **kwargs) + + return resized + diff --git a/interpol/utils.py b/interpol/utils.py index b58838d..1e4f1b4 100644 --- a/interpol/utils.py +++ b/interpol/utils.py @@ -98,4 +98,69 @@ def matvec(mat, vec, out=None): if out is not None: out = out[..., 0] - return mv \ No newline at end of file + return mv + + +def _compare_versions(version1, mode, version2): + for v1, v2 in zip(version1, version2): + if mode in ('gt', '>'): + if v1 > v2: + return True + elif v1 < v2: + return False + elif mode in ('ge', '>='): + if v1 > v2: + return True + elif v1 < v2: + return False + elif mode in ('lt', '<'): + if v1 < v2: + return True + elif v1 > v2: + return False + elif mode in ('le', '<='): + if v1 < v2: + return True + elif v1 > v2: + return False + if mode in ('gt', 'lt', '>', '<'): + return False + else: + return True + + +def torch_version(mode, version): + """Check torch version + + Parameters + ---------- + mode : {'<', '<=', '>', '>='} + version : tuple[int] + + Returns + ------- + True if "torch.version version" + + """ + current_version, *cuda_variant = torch.__version__.split('+') + major, minor, patch, *_ = current_version.split('.') + # strip alpha tags + for x in 'abcdefghijklmnopqrstuvwxy': + if x in patch: + patch = patch[:patch.index(x)] + current_version = (int(major), int(minor), int(patch)) + version = make_list(version) + return _compare_versions(current_version, mode, version) + + +if torch_version('>=', (1, 10)): + meshgrid_ij = lambda *x: torch.meshgrid(*x, indexing='ij') + meshgrid_xy = lambda *x: torch.meshgrid(*x, indexing='xy') +else: + meshgrid_ij = lambda *x: torch.meshgrid(*x) + def meshgrid_xy(*x): + grid = list(torch.meshgrid(*x)) + if len(grid) > 1: + grid[0] = grid[0].transpose(0, 1) + grid[1] = grid[1].transpose(0, 1) + return grid