From dadec3023dc2c07fde791615567eab3290da3e3b Mon Sep 17 00:00:00 2001
From: Ashwin Nair
Date: Sat, 28 Dec 2024 18:47:05 +0400
Subject: [PATCH] Revert "Remove AugmentationSequential wrapper"

This reverts commit 6457c712831b9c23cd541d5e09d18876777ef295.
---
Review notes (ignored by `git am`; not part of the commit message):

* AugmentationSequential.forward() adds and removes the mask channel
  dimension with Tensor.unsqueeze()/Tensor.squeeze() instead of
  einops.rearrange(). The import hunk for transforms.py neither adds nor
  shows an import of `rearrange`, so the reverted code would raise
  NameError for any batch containing a 3-D 'mask' or 'masks' entry.
  One behavioural difference: a 4-D 'masks' tensor whose leading
  dimension is not 1 is now left unchanged instead of raising.
* The added line count is unchanged, so the hunk headers and diffstat
  still hold. The post-image blob id on the transforms.py `index` line
  no longer matches the edited content.
* TODO(review): the tests import `_ExtractPatches` from
  torchgeo.transforms.transforms, which this patch does not define;
  confirm it exists in the target tree.
* TODO(review): the keepdim=False case in test_extract_patches still
  contains a debugging `print` loop.

 tests/transforms/test_transforms.py | 123 +++++++++++++++++-----------
 torchgeo/transforms/__init__.py     |   2 +
 torchgeo/transforms/transforms.py   |  97 ++++++++++++++++++++++
 3 files changed, 176 insertions(+), 46 deletions(-)

diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py
index 1c8e7e274eb..1f2071ae812 100644
--- a/tests/transforms/test_transforms.py
+++ b/tests/transforms/test_transforms.py
@@ -4,10 +4,10 @@
 import kornia.augmentation as K
 import pytest
 import torch
-from kornia.contrib import ExtractTensorPatches
 from torch import Tensor
 
-from torchgeo.transforms import indices
+from torchgeo.transforms import indices, transforms
+from torchgeo.transforms.transforms import _ExtractPatches
 
 # Kornia is very particular about its boxes:
 #
@@ -23,7 +23,7 @@ def batch_gray() -> dict[str, Tensor]:
     return {
         'image': torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float),
         'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
-        'bbox_xyxy': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
+        'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
         'labels': torch.tensor([[0, 1]]),
     }
 
@@ -42,7 +42,7 @@ def batch_rgb() -> dict[str, Tensor]:
             dtype=torch.float,
         ),
         'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
-        'bbox_xyxy': torch.tensor([[0.0, 1.0, 1.0, 2.0]], dtype=torch.float),
+        'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
         'labels': torch.tensor([[0, 1]]),
     }
 
@@ -63,7 +63,7 @@ def batch_multispectral() -> dict[str, Tensor]:
             dtype=torch.float,
         ),
         'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
-        'bbox_xyxy': torch.tensor([[0.0, 1.0, 1.0, 2.0]], dtype=torch.float),
+        'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
         'labels': torch.tensor([[0, 1]]),
     }
 
@@ -79,10 +79,12 @@ def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None:
     expected = {
         'image': torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float),
         'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
-        'bbox_xyxy': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
+        'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
         'labels': torch.tensor([[0, 1]]),
     }
-    augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=None)
+    augs = transforms.AugmentationSequential(
+        K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes']
+    )
     output = augs(batch_gray)
     assert_matching(output, expected)
 
@@ -100,10 +102,12 @@ def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None:
             dtype=torch.float,
         ),
         'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
-        'bbox_xyxy': torch.tensor([[1.0, 1.0, 2.0, 2.0]], dtype=torch.float),
+        'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
         'labels': torch.tensor([[0, 1]]),
     }
-    augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=None)
+    augs = transforms.AugmentationSequential(
+        K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes']
+    )
     output = augs(batch_rgb)
     assert_matching(output, expected)
 
@@ -115,20 +119,22 @@ def test_augmentation_sequential_multispectral(
         'image': torch.tensor(
             [
                 [
-                    [[7, 8, 9], [4, 5, 6], [1, 2, 3]],
-                    [[7, 8, 9], [4, 5, 6], [1, 2, 3]],
-                    [[7, 8, 9], [4, 5, 6], [1, 2, 3]],
-                    [[7, 8, 9], [4, 5, 6], [1, 2, 3]],
-                    [[7, 8, 9], [4, 5, 6], [1, 2, 3]],
+                    [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
+                    [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
+                    [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
+                    [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
+                    [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
                 ]
             ],
             dtype=torch.float,
         ),
-        'mask': torch.tensor([[[1, 1, 1], [0, 1, 1], [0, 0, 1]]], dtype=torch.long),
-        'bbox_xyxy': torch.tensor([[0.0, 0.0, 1.0, 1.0]], dtype=torch.float),
+        'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
+        'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
         'labels': torch.tensor([[0, 1]]),
     }
-    augs = K.AugmentationSequential(K.RandomVerticalFlip(p=1.0), data_keys=None)
+    augs = transforms.AugmentationSequential(
+        K.RandomHorizontalFlip(p=1.0), data_keys=['image', 'mask', 'boxes']
+    )
     output = augs(batch_multispectral)
     assert_matching(output, expected)
 
@@ -136,22 +142,28 @@ def test_augmentation_sequential_multispectral(
 def test_augmentation_sequential_image_only(
     batch_multispectral: dict[str, Tensor],
 ) -> None:
-    expected_image = torch.tensor(
-        [
+    expected = {
+        'image': torch.tensor(
             [
-                [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
-                [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
-                [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
-                [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
-                [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
-            ]
-        ],
-        dtype=torch.float,
+                [
+                    [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
+                    [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
+                    [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
+                    [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
+                    [[3, 2, 1], [6, 5, 4], [9, 8, 7]],
+                ]
+            ],
+            dtype=torch.float,
+        ),
+        'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
+        'boxes': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
+        'labels': torch.tensor([[0, 1]]),
+    }
+    augs = transforms.AugmentationSequential(
+        K.RandomHorizontalFlip(p=1.0), data_keys=['image']
     )
-
-    augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=['image'])
-    aug_image = augs(batch_multispectral['image'])
-    assert torch.allclose(aug_image, expected_image)
+    output = augs(batch_multispectral)
+    assert_matching(output, expected)
 
 
 def test_sequential_transforms_augmentations(
@@ -176,17 +188,17 @@ def test_sequential_transforms_augmentations(
             dtype=torch.float,
         ),
         'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
-        'bbox_xyxy': torch.tensor([[1.0, 1.0, 2.0, 2.0]], dtype=torch.float),
+        'boxes': torch.tensor([[1.0, 0.0, 3.0, 2.0]], dtype=torch.float),
         'labels': torch.tensor([[0, 1]]),
     }
-    train_transforms = K.AugmentationSequential(
+    train_transforms = transforms.AugmentationSequential(
         indices.AppendNBR(index_nir=0, index_swir=0),
         indices.AppendNDBI(index_swir=0, index_nir=0),
         indices.AppendNDSI(index_green=0, index_swir=0),
         indices.AppendNDVI(index_red=0, index_nir=0),
         indices.AppendNDWI(index_green=0, index_nir=0),
         K.RandomHorizontalFlip(p=1.0),
-        data_keys=None,
+        data_keys=['image', 'mask', 'boxes'],
     )
     output = train_transforms(batch_multispectral)
     assert_matching(output, expected)
@@ -203,12 +215,12 @@ def test_extract_patches() -> None:
         'image': torch.randn(size=(b, c, h, w)),
         'mask': torch.randint(low=0, high=2, size=(b, h, w)),
     }
-    train_transforms = ExtractTensorPatches(p, s)
-    output = {}
-    output['image'] = train_transforms(batch['image'])
-    output['mask'] = train_transforms(batch['mask'].unsqueeze(1)).squeeze(2)
-    assert output['image'].shape == (b, num_patches, c, p, p)
-    assert output['mask'].shape == (b, num_patches, p, p)
+    train_transforms = transforms.AugmentationSequential(
+        _ExtractPatches(window_size=p), same_on_batch=True, data_keys=['image', 'mask']
+    )
+    output = train_transforms(batch)
+    assert batch['image'].shape == (b * num_patches, c, p, p)
+    assert batch['mask'].shape == (b * num_patches, p, p)
 
     # Test different stride
     s = 16
@@ -217,10 +229,29 @@ def test_extract_patches() -> None:
         'image': torch.randn(size=(b, c, h, w)),
         'mask': torch.randint(low=0, high=2, size=(b, h, w)),
     }
+    train_transforms = transforms.AugmentationSequential(
+        _ExtractPatches(window_size=p, stride=s),
+        same_on_batch=True,
+        data_keys=['image', 'mask'],
+    )
+    output = train_transforms(batch)
+    assert batch['image'].shape == (b * num_patches, c, p, p)
+    assert batch['mask'].shape == (b * num_patches, p, p)
 
-    train_transforms = ExtractTensorPatches(p, stride=16)
-    output = {}
-    output['image'] = train_transforms(batch['image'])
-    output['mask'] = train_transforms(batch['mask'].unsqueeze(1)).squeeze(2)
-    assert output['image'].shape == (b, num_patches, c, p, p)
-    assert output['mask'].shape == (b, num_patches, p, p)
+    # Test keepdim=False
+    s = p
+    num_patches = ((h - p + s) // s) * ((w - p + s) // s)
+    batch = {
+        'image': torch.randn(size=(b, c, h, w)),
+        'mask': torch.randint(low=0, high=2, size=(b, h, w)),
+    }
+    train_transforms = transforms.AugmentationSequential(
+        _ExtractPatches(window_size=p, stride=s, keepdim=False),
+        same_on_batch=True,
+        data_keys=['image', 'mask'],
+    )
+    output = train_transforms(batch)
+    for k, v in output.items():
+        print(k, v.shape, v.dtype)
+    assert batch['image'].shape == (b, num_patches, c, p, p)
+    assert batch['mask'].shape == (b, num_patches, 1, p, p)
diff --git a/torchgeo/transforms/__init__.py b/torchgeo/transforms/__init__.py
index 34291d71345..5a0f9ee3392 100644
--- a/torchgeo/transforms/__init__.py
+++ b/torchgeo/transforms/__init__.py
@@ -20,6 +20,7 @@
     AppendSWI,
     AppendTriBandNormalizedDifferenceIndex,
 )
+from .transforms import AugmentationSequential
 
 __all__ = (
     'AppendBNDVI',
@@ -36,5 +37,6 @@
     'AppendRBNDVI',
     'AppendSWI',
     'AppendTriBandNormalizedDifferenceIndex',
+    'AugmentationSequential',
     'RandomGrayscale',
 )
diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py
index a98fe995295..3cc816d4538 100644
--- a/torchgeo/transforms/transforms.py
+++ b/torchgeo/transforms/transforms.py
@@ -8,7 +8,104 @@
 import kornia.augmentation as K
 import torch
 from kornia.geometry import crop_by_indices
+from kornia.geometry.boxes import Boxes
 from torch import Tensor
+from torch.nn.modules import Module
+
+
+# TODO: contribute these to Kornia and delete this file
+class AugmentationSequential(Module):
+    """Wrapper around kornia AugmentationSequential to handle input dicts.
+
+    .. deprecated:: 0.4
+       Use :class:`kornia.augmentation.container.AugmentationSequential` instead.
+    """
+
+    def __init__(
+        self,
+        *args: K.base._AugmentationBase | K.ImageSequential,
+        data_keys: list[str],
+        **kwargs: Any,
+    ) -> None:
+        """Initialize a new augmentation sequential instance.
+
+        Args:
+            *args: Sequence of kornia augmentations
+            data_keys: List of inputs to augment (e.g., ["image", "mask", "boxes"])
+            **kwargs: Keyword arguments passed to ``K.AugmentationSequential``
+
+        .. versionadded:: 0.5
+           The ``**kwargs`` parameter.
+        """
+        super().__init__()
+        self.data_keys = data_keys
+
+        keys: list[str] = []
+        for key in data_keys:
+            if key.startswith('image'):
+                keys.append('input')
+            elif key == 'boxes':
+                keys.append('bbox')
+            elif key == 'masks':
+                keys.append('mask')
+            else:
+                keys.append(key)
+
+        self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs)
+
+    def forward(self, batch: dict[str, Any]) -> dict[str, Any]:
+        """Perform augmentations and update data dict.
+
+        Args:
+            batch: the input
+
+        Returns:
+            the augmented input
+        """
+        # Kornia augmentations require all inputs to be float
+        dtype = {}
+        for key in self.data_keys:
+            dtype[key] = batch[key].dtype
+            batch[key] = batch[key].float()
+
+        # Convert shape of boxes from [N, 4] to [N, 4, 2]
+        if 'boxes' in batch and (
+            isinstance(batch['boxes'], list) or batch['boxes'].ndim == 2
+        ):
+            batch['boxes'] = Boxes.from_tensor(batch['boxes']).data
+
+        # Kornia requires masks to have a channel dimension
+        if 'mask' in batch and batch['mask'].ndim == 3:
+            batch['mask'] = batch['mask'].unsqueeze(1)
+
+        if 'masks' in batch and batch['masks'].ndim == 3:
+            batch['masks'] = batch['masks'].unsqueeze(0)
+
+        inputs = [batch[k] for k in self.data_keys]
+        outputs_list: Tensor | list[Tensor] = self.augs(*inputs)
+        outputs_list = (
+            outputs_list if isinstance(outputs_list, list) else [outputs_list]
+        )
+        outputs: dict[str, Tensor] = {
+            k: v for k, v in zip(self.data_keys, outputs_list)
+        }
+        batch.update(outputs)
+
+        # Convert all inputs back to their previous dtype
+        for key in self.data_keys:
+            batch[key] = batch[key].to(dtype[key])
+
+        # Convert boxes to default [N, 4]
+        if 'boxes' in batch:
+            batch['boxes'] = Boxes(batch['boxes']).to_tensor(mode='xyxy')
+
+        # Torchmetrics does not support masks with a channel dimension
+        if 'mask' in batch and batch['mask'].shape[1] == 1:
+            batch['mask'] = batch['mask'].squeeze(1)
+        if 'masks' in batch and batch['masks'].ndim == 4:
+            batch['masks'] = batch['masks'].squeeze(0)
+
+        return batch
 
 
 class _RandomNCrop(K.GeometricAugmentationBase2D):