From 861cb4a5732ad59feec0f267920b71261a685966 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ya=C3=ABl=20Balbastre?= Date: Fri, 19 Apr 2024 11:18:31 +0100 Subject: [PATCH] FEAT(fov): random 90 deg rotations (#15) Co-authored-by: Sean I Young --- cornucopia/base.py | 2 +- cornucopia/fov.py | 151 ++++++++++++++++++++++++++++--- cornucopia/geometric.py | 3 +- cornucopia/random.py | 4 +- cornucopia/tests/test_run_fov.py | 34 +++++++ 5 files changed, 176 insertions(+), 18 deletions(-) diff --git a/cornucopia/base.py b/cornucopia/base.py index a93554a..b8c6691 100755 --- a/cornucopia/base.py +++ b/cornucopia/base.py @@ -359,7 +359,7 @@ def __init__(self, *, shared=False, **kwargs): super().__init__(**kwargs) self.shared = self._prepare_shared(shared) - def make_final(self, x, max_depth=float('inf'), *args, **kwargs): + def make_final(self, x, max_depth=float('inf')): if self.is_final or max_depth == 0: return self return NotImplemented diff --git a/cornucopia/fov.py b/cornucopia/fov.py index 19db6fd..2884031 100755 --- a/cornucopia/fov.py +++ b/cornucopia/fov.py @@ -8,14 +8,16 @@ 'CropTransform', 'PadTransform', 'PowerTwoTransform', + 'Rot90Transform', + 'Rot180Transform', + 'RandomRot90Transform', ] - import math from random import shuffle -from .base import FinalTransform, NonFinalTransform +from .base import FinalTransform, NonFinalTransform, PerChannelTransform from .utils.py import ensure_list from .utils.padding import pad -from .random import Uniform, RandKFrom, Sampler +from .random import Uniform, RandKFrom, Sampler, RandInt, make_range class FlipTransform(FinalTransform): @@ -23,7 +25,6 @@ class FlipTransform(FinalTransform): def __init__(self, axis=None, **kwargs): """ - Parameters ---------- axis : [list of] int @@ -46,24 +47,30 @@ def make_inverse(self): class RandomFlipTransform(NonFinalTransform): """Randomly flip one or more axes""" - def __init__(self, axes=None, **kwargs): + def __init__(self, axes=None, *, shared=True, **kwargs): """ - Parameters ---------- axes : Sampler or [list of] int Axes that can be flipped (default: all) + + Other Parameters + ---------------- shared : {'channels', 'tensors', 'channels+tensors', ''} Apply the same flip to all channels and/or tensors """ axes = kwargs.pop('axis', axes) - kwargs.setdefault('shared', True) - super().__init__(**kwargs) + super().__init__(shared=shared, **kwargs) self.axes = axes def make_final(self, x, max_depth=float('inf')): if max_depth == 0: return self + if 'channels' not in self.shared and len(x) > 1: + return PerChannelTransform( + [self.make_final(x[i:i+1], max_depth) for i in range(len(x))], + **self.get_prm() + ).make_final(x, max_depth-1) axes = self.axes or range(1, x.ndim) if not isinstance(axes, Sampler): rand_axes = RandKFrom(ensure_list(axes)) @@ -76,7 +83,6 @@ class PermuteAxesTransform(FinalTransform): def __init__(self, permutation=None, **kwargs): """ - Parameters ---------- permutation : [list of] int @@ -105,23 +111,29 @@ def make_inverse(self): class RandomPermuteAxesTransform(NonFinalTransform): """Randomly permute axes""" - def __init__(self, axes=None, **kwargs): + def __init__(self, axes=None, *, shared=True, **kwargs): """ - Parameters ---------- axes : [list of] int Axes that can be permuted (default: all) + + Other Parameters + ---------------- shared : {'channels', 'tensors', 'channels+tensors', ''} Apply the same permutation to all channels and/or tensors """ - kwargs.setdefault('shared', True) - super().__init__(**kwargs) + super().__init__(shared=shared, **kwargs) self.axes = axes def make_final(self, x, max_depth=float('inf')): if max_depth == 0: return self + if 'channels' not in self.shared and len(x) > 1: + return PerChannelTransform( + [self.make_final(x[i:i+1], max_depth) for i in range(len(x))], + **self.get_prm() + ).make_final(x, max_depth-1) axes = list(self.axes or range(x.ndim-1)) shuffle(axes) return PermuteAxesTransform( @@ -129,6 +141,119 @@ def make_final(self, x, max_depth=float('inf')): ).make_final(x, max_depth-1) +class Rot90Transform(FinalTransform): + """ + Apply a 90 (or 180) rotation along one or several axes + """ + + def __init__(self, axis=0, negative=False, double=False, **kwargs): + """ + Parameters + ---------- + axis : int or list[int] + Rotation axis (indexing does not account for the channel axis) + negative : bool or list[bool] + Rotate by -90 deg instead of 90 deg + double : bool or list[bool] + Rotate be 180 instead of 90 (`negative` is then unused) + """ + super().__init__(**kwargs) + self.axis = ensure_list(axis) + self.negative = ensure_list(negative, len(self.axis)) + self.double = ensure_list(double, len(self.axis)) + + def apply(self, x): + # this implementation is suboptimal. We should fuse all transpose + # and all flips into a single "transpose + flip" operation so that + # a single allocation happens. This will be fine for now. + + ndim = x.ndim - 1 + axis = [1 + (ndim + a if a < 0 else a) for a in self.axis] + for ax, neg, dbl in zip(axis, self.negative, self.double): + if dbl: + if ndim == 2: + dims = [1, 2] + else: + assert ndim == 3 + dims = [d for d in (1, 2, 3) if d != ax] + x = x.flip(dims) + else: + if ndim == 2: + dims = [1, 2] + else: + assert ndim == 3 + dims = [d for d in (1, 2, 3) if d != ax] + x = x.transpose(*dims).flip(dims[1] if neg else dims[0]) + return x + + +class Rot180Transform(Rot90Transform): + """Apply a 180 deg rotation along one or several axes""" + + def __init__(self, axis=0, **kwargs): + """ + Parameters + ---------- + axis : int or list[int] + Rotation axis (indexing does not account for the channel axis) + """ + super().__init__(axis, double=True, **kwargs) + + +class RandomRot90Transform(NonFinalTransform): + """Random set of 90 transforms""" + + def __init__(self, axes=None, max_rot=2, negative=True, + *, shared=True, **kwargs): + """ + Parameters + ---------- + axes : int or list[int] + Axes along which rotations can happen. + If `None`, all axes. + max_rot : int or Sampler + Maximum number of consecutive rotations. + negative : bool + Whether to authorize negative rotations. + + Other Parameters + ---------------- + shared : {'channels', 'tensors', 'channels+tensors', ''} + Apply the same permutation to all channels and/or tensors + """ + super().__init__(shared=shared, **kwargs) + self.axes = axes + self.max_rot = RandInt.make(make_range(1, max_rot)) + self.negative = negative + + def make_final(self, x, max_depth=float('inf')): + if max_depth == 0: + return self + if 'channels' not in self.shared and len(x) > 1: + return PerChannelTransform( + [self.make_final(x[i:i+1], max_depth) for i in range(len(x))], + **self.get_prm() + ).make_final(x, max_depth-1) + ndim = x.ndim - 1 + max_rot = self.max_rot + if isinstance(max_rot, Sampler): + max_rot = max_rot() + axes = self.axes + if axes is None: + axes = list(range(ndim)) + if isinstance(axes, (int, list, tuple)): + axes = ensure_list(axes, max_rot, crop=False) + if not isinstance(axes, Sampler): + axes = RandKFrom(axes, max_rot, replacement=True) + + axes = ensure_list(axes(), max_rot) + negative = RandKFrom([False, True], max_rot, replacement=True)() \ + if self.negative else [False] * max_rot + return Rot90Transform( + axes, negative, **self.get_prm() + ).make_final(max_depth-1) + + class CropPadTransform(FinalTransform): """Crop and/or pad a tensor""" diff --git a/cornucopia/geometric.py b/cornucopia/geometric.py index 0e5747d..aa2554c 100755 --- a/cornucopia/geometric.py +++ b/cornucopia/geometric.py @@ -1119,7 +1119,6 @@ def make_final(self, x, max_depth=float('inf'), flow=True): F = torch.eye(ndim+1, **backend) F[:ndim, -1] = -offsets Z = E.clone() - print(zooms.shape, Z.shape) Z.diagonal(0, -1, -2)[:, :-1].copy_(1 + zooms) T = E.clone() T[:, :ndim, -1] = translations @@ -1362,7 +1361,7 @@ def make_final(self, x, max_depth=float('inf')): # get slice direction slice = self.slice if slice is None: - slice = RandInt(0, ndim) + slice = RandInt(0, ndim - 1) if isinstance(slice, Sampler): slice = slice() diff --git a/cornucopia/random.py b/cornucopia/random.py index 724f165..1f86124 100755 --- a/cornucopia/random.py +++ b/cornucopia/random.py @@ -260,7 +260,7 @@ def __init__(self, range, k=None, replacement=False): self.replacement = replacement def __call__(self, n=None, **backend): - k = self.k or RandInt(len(self.range))() + k = self.k or RandInt(1, len(self.range))() if isinstance(n, (list, tuple)) or n: raise ValueError('RandKFrom cannot sample multiple elements') if not self.replacement: @@ -268,7 +268,7 @@ def __call__(self, n=None, **backend): random.shuffle(range) return range[:k] else: - index = RandInt(len(self.range))(k) + index = RandInt(0, len(self.range)-1)(k) return [self.range[i] for i in index] diff --git a/cornucopia/tests/test_run_fov.py b/cornucopia/tests/test_run_fov.py index 7f69a68..808b667 100755 --- a/cornucopia/tests/test_run_fov.py +++ b/cornucopia/tests/test_run_fov.py @@ -11,6 +11,9 @@ CropTransform, PadTransform, PowerTwoTransform, + Rot90Transform, + Rot180Transform, + RandomRot90Transform, ) SEED = 12345678 @@ -53,6 +56,37 @@ def test_run_fov_permute_random(size): assert True +@pytest.mark.parametrize("size", sizes) +@pytest.mark.parametrize("axes", [0, 1, [0, 1], [0, 0]]) +@pytest.mark.parametrize("negative", [False, True]) +@pytest.mark.parametrize("double", [False, True]) +def test_run_rot90_permute(size, axes, negative, double): + random.seed(SEED) + torch.random.manual_seed(SEED) + x = torch.randn(size) + _ = Rot90Transform(axes, negative, double)(x) + assert True + + +@pytest.mark.parametrize("size", sizes) +@pytest.mark.parametrize("axes", [0, 1, [0, 1], [0, 0]]) +def test_run_rot180_permute(size, axes): + random.seed(SEED) + torch.random.manual_seed(SEED) + x = torch.randn(size) + _ = Rot180Transform(axes)(x) + assert True + + +@pytest.mark.parametrize("size", sizes) +def test_run_rot90_random(size): + random.seed(SEED) + torch.random.manual_seed(SEED) + x = torch.randn(size) + _ = RandomRot90Transform()(x) + assert True + + @pytest.mark.parametrize("size", sizes) def test_run_fov_patch(size): random.seed(SEED)