Skip to content

Commit

Permalink
FEAT(fov): random 90 deg rotations (#15)
Browse files Browse the repository at this point in the history
Co-authored-by: Sean I Young <[email protected]>
  • Loading branch information
balbasty and Sean I Young authored Apr 19, 2024
1 parent 2bc24fd commit 861cb4a
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 18 deletions.
2 changes: 1 addition & 1 deletion cornucopia/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def __init__(self, *, shared=False, **kwargs):
super().__init__(**kwargs)
self.shared = self._prepare_shared(shared)

def make_final(self, x, max_depth=float('inf'), *args, **kwargs):
def make_final(self, x, max_depth=float('inf')):
if self.is_final or max_depth == 0:
return self
return NotImplemented
Expand Down
151 changes: 138 additions & 13 deletions cornucopia/fov.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,23 @@
'CropTransform',
'PadTransform',
'PowerTwoTransform',
'Rot90Transform',
'Rot180Transform',
'RandomRot90Transform',
]

import math
from random import shuffle
from .base import FinalTransform, NonFinalTransform
from .base import FinalTransform, NonFinalTransform, PerChannelTransform
from .utils.py import ensure_list
from .utils.padding import pad
from .random import Uniform, RandKFrom, Sampler
from .random import Uniform, RandKFrom, Sampler, RandInt, make_range


class FlipTransform(FinalTransform):
"""Flip one or more axes"""

def __init__(self, axis=None, **kwargs):
"""
Parameters
----------
axis : [list of] int
Expand All @@ -46,24 +47,30 @@ def make_inverse(self):
class RandomFlipTransform(NonFinalTransform):
"""Randomly flip one or more axes"""

def __init__(self, axes=None, **kwargs):
def __init__(self, axes=None, *, shared=True, **kwargs):
"""
Parameters
----------
axes : Sampler or [list of] int
Axes that can be flipped (default: all)
Other Parameters
----------------
shared : {'channels', 'tensors', 'channels+tensors', ''}
Apply the same flip to all channels and/or tensors
"""
axes = kwargs.pop('axis', axes)
kwargs.setdefault('shared', True)
super().__init__(**kwargs)
super().__init__(shared=shared, **kwargs)
self.axes = axes

def make_final(self, x, max_depth=float('inf')):
if max_depth == 0:
return self
if 'channels' not in self.shared and len(x) > 1:
return PerChannelTransform(
[self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
**self.get_prm()
).make_final(x, max_depth-1)
axes = self.axes or range(1, x.ndim)
if not isinstance(axes, Sampler):
rand_axes = RandKFrom(ensure_list(axes))
Expand All @@ -76,7 +83,6 @@ class PermuteAxesTransform(FinalTransform):

def __init__(self, permutation=None, **kwargs):
"""
Parameters
----------
permutation : [list of] int
Expand Down Expand Up @@ -105,30 +111,149 @@ def make_inverse(self):
class RandomPermuteAxesTransform(NonFinalTransform):
"""Randomly permute axes"""

def __init__(self, axes=None, **kwargs):
def __init__(self, axes=None, *, shared=True, **kwargs):
"""
Parameters
----------
axes : [list of] int
Axes that can be permuted (default: all)
Other Parameters
----------------
shared : {'channels', 'tensors', 'channels+tensors', ''}
Apply the same permutation to all channels and/or tensors
"""
kwargs.setdefault('shared', True)
super().__init__(**kwargs)
super().__init__(shared=shared, **kwargs)
self.axes = axes

def make_final(self, x, max_depth=float('inf')):
if max_depth == 0:
return self
if 'channels' not in self.shared and len(x) > 1:
return PerChannelTransform(
[self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
**self.get_prm()
).make_final(x, max_depth-1)
axes = list(self.axes or range(x.ndim-1))
shuffle(axes)
return PermuteAxesTransform(
axes, **self.get_prm()
).make_final(x, max_depth-1)


class Rot90Transform(FinalTransform):
"""
Apply a 90 (or 180) rotation along one or several axes
"""

def __init__(self, axis=0, negative=False, double=False, **kwargs):
"""
Parameters
----------
axis : int or list[int]
Rotation axis (indexing does not account for the channel axis)
negative : bool or list[bool]
Rotate by -90 deg instead of 90 deg
double : bool or list[bool]
Rotate be 180 instead of 90 (`negative` is then unused)
"""
super().__init__(**kwargs)
self.axis = ensure_list(axis)
self.negative = ensure_list(negative, len(self.axis))
self.double = ensure_list(double, len(self.axis))

def apply(self, x):
# this implementation is suboptimal. We should fuse all transpose
# and all flips into a single "transpose + flip" operation so that
# a single allocation happens. This will be fine for now.

ndim = x.ndim - 1
axis = [1 + (ndim + a if a < 0 else a) for a in self.axis]
for ax, neg, dbl in zip(axis, self.negative, self.double):
if dbl:
if ndim == 2:
dims = [1, 2]
else:
assert ndim == 3
dims = [d for d in (1, 2, 3) if d != ax]
x = x.flip(dims)
else:
if ndim == 2:
dims = [1, 2]
else:
assert ndim == 3
dims = [d for d in (1, 2, 3) if d != ax]
x = x.transpose(*dims).flip(dims[1] if neg else dims[0])
return x


class Rot180Transform(Rot90Transform):
"""Apply a 180 deg rotation along one or several axes"""

def __init__(self, axis=0, **kwargs):
"""
Parameters
----------
axis : int or list[int]
Rotation axis (indexing does not account for the channel axis)
"""
super().__init__(axis, double=True, **kwargs)


class RandomRot90Transform(NonFinalTransform):
"""Random set of 90 transforms"""

def __init__(self, axes=None, max_rot=2, negative=True,
*, shared=True, **kwargs):
"""
Parameters
----------
axes : int or list[int]
Axes along which rotations can happen.
If `None`, all axes.
max_rot : int or Sampler
Maximum number of consecutive rotations.
negative : bool
Whether to authorize negative rotations.
Other Parameters
----------------
shared : {'channels', 'tensors', 'channels+tensors', ''}
Apply the same permutation to all channels and/or tensors
"""
super().__init__(shared=shared, **kwargs)
self.axes = axes
self.max_rot = RandInt.make(make_range(1, max_rot))
self.negative = negative

def make_final(self, x, max_depth=float('inf')):
if max_depth == 0:
return self
if 'channels' not in self.shared and len(x) > 1:
return PerChannelTransform(
[self.make_final(x[i:i+1], max_depth) for i in range(len(x))],
**self.get_prm()
).make_final(x, max_depth-1)
ndim = x.ndim - 1
max_rot = self.max_rot
if isinstance(max_rot, Sampler):
max_rot = max_rot()
axes = self.axes
if axes is None:
axes = list(range(ndim))
if isinstance(axes, (int, list, tuple)):
axes = ensure_list(axes, max_rot, crop=False)
if not isinstance(axes, Sampler):
axes = RandKFrom(axes, max_rot, replacement=True)

axes = ensure_list(axes(), max_rot)
negative = RandKFrom([False, True], max_rot, replacement=True)() \
if self.negative else [False] * max_rot
return Rot90Transform(
axes, negative, **self.get_prm()
).make_final(max_depth-1)


class CropPadTransform(FinalTransform):
"""Crop and/or pad a tensor"""

Expand Down
3 changes: 1 addition & 2 deletions cornucopia/geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,6 @@ def make_final(self, x, max_depth=float('inf'), flow=True):
F = torch.eye(ndim+1, **backend)
F[:ndim, -1] = -offsets
Z = E.clone()
print(zooms.shape, Z.shape)
Z.diagonal(0, -1, -2)[:, :-1].copy_(1 + zooms)
T = E.clone()
T[:, :ndim, -1] = translations
Expand Down Expand Up @@ -1362,7 +1361,7 @@ def make_final(self, x, max_depth=float('inf')):
# get slice direction
slice = self.slice
if slice is None:
slice = RandInt(0, ndim)
slice = RandInt(0, ndim - 1)
if isinstance(slice, Sampler):
slice = slice()

Expand Down
4 changes: 2 additions & 2 deletions cornucopia/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,15 +260,15 @@ def __init__(self, range, k=None, replacement=False):
self.replacement = replacement

def __call__(self, n=None, **backend):
k = self.k or RandInt(len(self.range))()
k = self.k or RandInt(1, len(self.range))()
if isinstance(n, (list, tuple)) or n:
raise ValueError('RandKFrom cannot sample multiple elements')
if not self.replacement:
range = list(self.range)
random.shuffle(range)
return range[:k]
else:
index = RandInt(len(self.range))(k)
index = RandInt(0, len(self.range)-1)(k)
return [self.range[i] for i in index]


Expand Down
34 changes: 34 additions & 0 deletions cornucopia/tests/test_run_fov.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
CropTransform,
PadTransform,
PowerTwoTransform,
Rot90Transform,
Rot180Transform,
RandomRot90Transform,
)

SEED = 12345678
Expand Down Expand Up @@ -53,6 +56,37 @@ def test_run_fov_permute_random(size):
assert True


@pytest.mark.parametrize("size", sizes)
@pytest.mark.parametrize("axes", [0, 1, [0, 1], [0, 0]])
@pytest.mark.parametrize("negative", [False, True])
@pytest.mark.parametrize("double", [False, True])
def test_run_rot90_permute(size, axes, negative, double):
random.seed(SEED)
torch.random.manual_seed(SEED)
x = torch.randn(size)
_ = Rot90Transform(axes, negative, double)(x)
assert True


@pytest.mark.parametrize("size", sizes)
@pytest.mark.parametrize("axes", [0, 1, [0, 1], [0, 0]])
def test_run_rot180_permute(size, axes):
random.seed(SEED)
torch.random.manual_seed(SEED)
x = torch.randn(size)
_ = Rot180Transform(axes)(x)
assert True


@pytest.mark.parametrize("size", sizes)
def test_run_rot90_random(size):
random.seed(SEED)
torch.random.manual_seed(SEED)
x = torch.randn(size)
_ = RandomRot90Transform()(x)
assert True


@pytest.mark.parametrize("size", sizes)
def test_run_fov_patch(size):
random.seed(SEED)
Expand Down

0 comments on commit 861cb4a

Please sign in to comment.