Skip to content

Commit

Permalink
Revert "Remove AugmentationSequential wrapper"
Browse files Browse the repository at this point in the history
This reverts commit 6457c71.
  • Loading branch information
ashnair1 committed Dec 28, 2024
1 parent 92c6b63 commit dadec30
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 46 deletions.
123 changes: 77 additions & 46 deletions tests/transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import kornia.augmentation as K
import pytest
import torch
from kornia.contrib import ExtractTensorPatches
from torch import Tensor

from torchgeo.transforms import indices
from torchgeo.transforms import indices, transforms
from torchgeo.transforms.transforms import _ExtractPatches

# Kornia is very particular about its boxes:
#
Expand All @@ -23,7 +23,7 @@ def batch_gray() -> dict[str, Tensor]:
return {
'image': torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float),
'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
'bbox_xyxy': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}

Expand All @@ -42,7 +42,7 @@ def batch_rgb() -> dict[str, Tensor]:
dtype=torch.float,
),
'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
'bbox_xyxy': torch.tensor([[0.0, 1.0, 1.0, 2.0]], dtype=torch.float),
'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}

Expand All @@ -63,7 +63,7 @@ def batch_multispectral() -> dict[str, Tensor]:
dtype=torch.float,
),
'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
'bbox_xyxy': torch.tensor([[0.0, 1.0, 1.0, 2.0]], dtype=torch.float),
'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}

Expand All @@ -79,10 +79,12 @@ def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None:
expected = {
'image': torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float),
'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
'bbox_xyxy': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}
augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=None)
augs = transforms.AugmentationSequential(
K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes']
)
output = augs(batch_gray)
assert_matching(output, expected)

Expand All @@ -100,10 +102,12 @@ def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None:
dtype=torch.float,
),
'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
'bbox_xyxy': torch.tensor([[1.0, 1.0, 2.0, 2.0]], dtype=torch.float),
'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}
augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=None)
augs = transforms.AugmentationSequential(
K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes']
)
output = augs(batch_rgb)
assert_matching(output, expected)

Expand All @@ -115,43 +119,51 @@ def test_augmentation_sequential_multispectral(
'image': torch.tensor(
[
[
[[7, 8, 9], [4, 5, 6], [1, 2, 3]],
[[7, 8, 9], [4, 5, 6], [1, 2, 3]],
[[7, 8, 9], [4, 5, 6], [1, 2, 3]],
[[7, 8, 9], [4, 5, 6], [1, 2, 3]],
[[7, 8, 9], [4, 5, 6], [1, 2, 3]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
]
],
dtype=torch.float,
),
'mask': torch.tensor([[[1, 1, 1], [0, 1, 1], [0, 0, 1]]], dtype=torch.long),
'bbox_xyxy': torch.tensor([[0.0, 0.0, 1.0, 1.0]], dtype=torch.float),
'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}
augs = K.AugmentationSequential(K.RandomVerticalFlip(p=1.0), data_keys=None)
augs = transforms.AugmentationSequential(
K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes']
)
output = augs(batch_multispectral)
assert_matching(output, expected)


def test_augmentation_sequential_image_only(
batch_multispectral: dict[str, Tensor],
) -> None:
expected_image = torch.tensor(
[
expected = {
'image': torch.tensor(
[
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
]
],
dtype=torch.float,
[
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
]
],
dtype=torch.float,
),
'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}
augs = transforms.AugmentationSequential(
K.RandomHorizontalFlip(p=1.0), data_keys=['image']
)

augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=['image'])
aug_image = augs(batch_multispectral['image'])
assert torch.allclose(aug_image, expected_image)
output = augs(batch_multispectral)
assert_matching(output, expected)


def test_sequential_transforms_augmentations(
Expand All @@ -176,17 +188,17 @@ def test_sequential_transforms_augmentations(
dtype=torch.float,
),
'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
'bbox_xyxy': torch.tensor([[1.0, 1.0, 2.0, 2.0]], dtype=torch.float),
'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}
train_transforms = K.AugmentationSequential(
train_transforms = transforms.AugmentationSequential(
indices.AppendNBR(index_nir=0, index_swir=0),
indices.AppendNDBI(index_swir=0, index_nir=0),
indices.AppendNDSI(index_green=0, index_swir=0),
indices.AppendNDVI(index_red=0, index_nir=0),
indices.AppendNDWI(index_green=0, index_nir=0),
K.RandomHorizontalFlip(p=1.0),
data_keys=None,
data_keys=['image', 'mask', 'boxes'],
)
output = train_transforms(batch_multispectral)
assert_matching(output, expected)
Expand All @@ -203,12 +215,12 @@ def test_extract_patches() -> None:
'image': torch.randn(size=(b, c, h, w)),
'mask': torch.randint(low=0, high=2, size=(b, h, w)),
}
train_transforms = ExtractTensorPatches(p, s)
output = {}
output['image'] = train_transforms(batch['image'])
output['mask'] = train_transforms(batch['mask'].unsqueeze(1)).squeeze(2)
assert output['image'].shape == (b, num_patches, c, p, p)
assert output['mask'].shape == (b, num_patches, p, p)
train_transforms = transforms.AugmentationSequential(
_ExtractPatches(window_size=p), same_on_batch=True, data_keys=['image', 'mask']
)
output = train_transforms(batch)
assert batch['image'].shape == (b * num_patches, c, p, p)
assert batch['mask'].shape == (b * num_patches, p, p)

# Test different stride
s = 16
Expand All @@ -217,10 +229,29 @@ def test_extract_patches() -> None:
'image': torch.randn(size=(b, c, h, w)),
'mask': torch.randint(low=0, high=2, size=(b, h, w)),
}
train_transforms = transforms.AugmentationSequential(
_ExtractPatches(window_size=p, stride=s),
same_on_batch=True,
data_keys=['image', 'mask'],
)
output = train_transforms(batch)
assert batch['image'].shape == (b * num_patches, c, p, p)
assert batch['mask'].shape == (b * num_patches, p, p)

train_transforms = ExtractTensorPatches(p, stride=16)
output = {}
output['image'] = train_transforms(batch['image'])
output['mask'] = train_transforms(batch['mask'].unsqueeze(1)).squeeze(2)
assert output['image'].shape == (b, num_patches, c, p, p)
assert output['mask'].shape == (b, num_patches, p, p)
# Test keepdim=False
s = p
num_patches = ((h - p + s) // s) * ((w - p + s) // s)
batch = {
'image': torch.randn(size=(b, c, h, w)),
'mask': torch.randint(low=0, high=2, size=(b, h, w)),
}
train_transforms = transforms.AugmentationSequential(
_ExtractPatches(window_size=p, stride=s, keepdim=False),
same_on_batch=True,
data_keys=['image', 'mask'],
)
output = train_transforms(batch)
for k, v in output.items():
print(k, v.shape, v.dtype)
assert batch['image'].shape == (b, num_patches, c, p, p)
assert batch['mask'].shape == (b, num_patches, 1, p, p)
2 changes: 2 additions & 0 deletions torchgeo/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
AppendSWI,
AppendTriBandNormalizedDifferenceIndex,
)
from .transforms import AugmentationSequential

__all__ = (
'AppendBNDVI',
Expand All @@ -36,5 +37,6 @@
'AppendRBNDVI',
'AppendSWI',
'AppendTriBandNormalizedDifferenceIndex',
'AugmentationSequential',
'RandomGrayscale',
)
97 changes: 97 additions & 0 deletions torchgeo/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,104 @@
import kornia.augmentation as K
import torch
from kornia.geometry import crop_by_indices
from kornia.geometry.boxes import Boxes
from torch import Tensor
from torch.nn.modules import Module


# TODO: contribute these to Kornia and delete this file
class AugmentationSequential(Module):
"""Wrapper around kornia AugmentationSequential to handle input dicts.
.. deprecated:: 0.4
Use :class:`kornia.augmentation.container.AugmentationSequential` instead.
"""

def __init__(
self,
*args: K.base._AugmentationBase | K.ImageSequential,
data_keys: list[str],
**kwargs: Any,
) -> None:
"""Initialize a new augmentation sequential instance.
Args:
*args: Sequence of kornia augmentations
data_keys: List of inputs to augment (e.g., ["image", "mask", "boxes"])
**kwargs: Keyword arguments passed to ``K.AugmentationSequential``
.. versionadded:: 0.5
The ``**kwargs`` parameter.
"""
super().__init__()
self.data_keys = data_keys

keys: list[str] = []
for key in data_keys:
if key.startswith('image'):
keys.append('input')
elif key == 'boxes':
keys.append('bbox')
elif key == 'masks':
keys.append('mask')
else:
keys.append(key)

self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs)

def forward(self, batch: dict[str, Any]) -> dict[str, Any]:
"""Perform augmentations and update data dict.
Args:
batch: the input
Returns:
the augmented input
"""
# Kornia augmentations require all inputs to be float
dtype = {}
for key in self.data_keys:
dtype[key] = batch[key].dtype
batch[key] = batch[key].float()

# Convert shape of boxes from [N, 4] to [N, 4, 2]
if 'boxes' in batch and (
isinstance(batch['boxes'], list) or batch['boxes'].ndim == 2
):
batch['boxes'] = Boxes.from_tensor(batch['boxes']).data

# Kornia requires masks to have a channel dimension
if 'mask' in batch and batch['mask'].ndim == 3:
batch['mask'] = rearrange(batch['mask'], 'b h w -> b () h w')

Check failure on line 79 in torchgeo/transforms/transforms.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

torchgeo/transforms/transforms.py:79:29: F821 Undefined name `rearrange`

if 'masks' in batch and batch['masks'].ndim == 3:
batch['masks'] = rearrange(batch['masks'], 'c h w -> () c h w')

Check failure on line 82 in torchgeo/transforms/transforms.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

torchgeo/transforms/transforms.py:82:30: F821 Undefined name `rearrange`

inputs = [batch[k] for k in self.data_keys]
outputs_list: Tensor | list[Tensor] = self.augs(*inputs)
outputs_list = (
outputs_list if isinstance(outputs_list, list) else [outputs_list]
)
outputs: dict[str, Tensor] = {
k: v for k, v in zip(self.data_keys, outputs_list)
}
batch.update(outputs)

# Convert all inputs back to their previous dtype
for key in self.data_keys:
batch[key] = batch[key].to(dtype[key])

# Convert boxes to default [N, 4]
if 'boxes' in batch:
batch['boxes'] = Boxes(batch['boxes']).to_tensor(mode='xyxy')

# Torchmetrics does not support masks with a channel dimension
if 'mask' in batch and batch['mask'].shape[1] == 1:
batch['mask'] = rearrange(batch['mask'], 'b () h w -> b h w')

Check failure on line 104 in torchgeo/transforms/transforms.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

torchgeo/transforms/transforms.py:104:29: F821 Undefined name `rearrange`
if 'masks' in batch and batch['masks'].ndim == 4:
batch['masks'] = rearrange(batch['masks'], '() c h w -> c h w')

Check failure on line 106 in torchgeo/transforms/transforms.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F821)

torchgeo/transforms/transforms.py:106:30: F821 Undefined name `rearrange`

return batch


class _RandomNCrop(K.GeometricAugmentationBase2D):
Expand Down

0 comments on commit dadec30

Please sign in to comment.