
Usage with torchgeo and kornia #19

Open
Geethen opened this issue Jul 20, 2024 · 0 comments

Geethen commented Jul 20, 2024

I was initially struggling to figure out how to use flipnslide with torchgeo data loaders and kornia-powered transforms. Please see the reproducible example below, for anyone interested. This might be a useful addition to the docs.

```python
import torch
from typing import Any, Optional

import kornia as K
from kornia.augmentation import AugmentationSequential
from torch import Tensor

from flipnslide.tiling import FlipnSlide


class _FlipnSlide(K.augmentation.GeometricAugmentationBase2D):
    """Flip and slide a tensor into tiles with FlipnSlide."""

    def __init__(self, tilesize: int, viz: bool = False) -> None:
        """Initialize a new _FlipnSlide instance.

        Args:
            tilesize: desired tile size
            viz: visualization flag
        """
        super().__init__(same_on_batch=True)
        # Static parameters; kornia passes these to apply_transform as `flags`
        self.flags = {'tilesize': tilesize, 'viz': viz}

    def compute_transformation(
        self, input: Tensor, params: dict[str, Tensor], flags: dict[str, Any]
    ) -> Tensor:
        """Compute the transformation.

        Args:
            input: the input tensor
            params: generated parameters
            flags: static parameters

        Returns:
            the transformation
        """
        # Tiling is not an affine warp, so report an identity matrix to kornia
        out: Tensor = self.identity_matrix(input)
        return out

    def apply_transform(
        self,
        input: Tensor,
        params: dict[str, Tensor],
        flags: dict[str, Any],
        transform: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply the transform.

        Args:
            input: the input tensor
            params: generated parameters
            flags: static parameters
            transform: the geometric transformation tensor

        Returns:
            the augmented input
        """
        # Drop the batch dimension (assumes a batch size of 1) to get a (C, H, W) array
        np_array = input.squeeze(0).detach().cpu().numpy()
        sample_tiled = FlipnSlide(
            tile_size=flags['tilesize'],
            data_type='tensor',  # request tensor output
            save=False,
            image=np_array,
            viz=flags['viz'],
        )
        # Return the tiles produced by FlipnSlide
        return sample_tiled.tiles


# Usage with AugmentationSequential; data_keys=None lets kornia read the keys from the dict
flipnslide = _FlipnSlide(tilesize=64, viz=False)
tfms = AugmentationSequential(
    flipnslide,
    data_keys=None,
)

# Example usage (batch size 1, as apply_transform expects)
train_batch = {
    'image': torch.rand(1, 3, 256, 256),
    'mask': torch.rand(1, 1, 256, 256),
}

transformed = tfms(train_batch)

print("Transformed image shape:", transformed['image'].shape)
print("Transformed mask shape:", transformed['mask'].shape)
print(transformed)
```
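The `apply_transform` method above calls `input.squeeze(0)`, so it only handles a batch size of 1. One possible way to support larger batches is to run the tiler on each sample separately and concatenate the resulting tile stacks along the batch axis. The sketch below shows that pattern with numpy. `toy_tiler` and `tile_batch` are hypothetical helpers, not part of flipnslide, torchgeo, or kornia: `toy_tiler` is a plain non-overlapping tiler standing in for FlipnSlide (it does none of FlipnSlide's flips or overlapping slides), so in practice you would swap in the real `FlipnSlide(...).tiles` call.

```python
import numpy as np


def toy_tiler(image: np.ndarray, tile_size: int) -> np.ndarray:
    """Hypothetical stand-in for FlipnSlide: non-overlapping tiles, no flips.

    Takes a (C, H, W) array and returns (N, C, tile_size, tile_size).
    """
    c, h, w = image.shape
    rows, cols = h // tile_size, w // tile_size
    # Crop to a multiple of tile_size, then split H and W into a grid of tiles
    cropped = image[:, : rows * tile_size, : cols * tile_size]
    tiles = cropped.reshape(c, rows, tile_size, cols, tile_size)
    # Reorder to (rows, cols, C, tile, tile) and flatten the grid, row by row
    return tiles.transpose(1, 3, 0, 2, 4).reshape(rows * cols, c, tile_size, tile_size)


def tile_batch(batch: np.ndarray, tile_size: int, tiler=toy_tiler) -> np.ndarray:
    """Apply a per-sample tiler to every item of a (B, C, H, W) batch."""
    per_sample = [tiler(sample, tile_size) for sample in batch]
    # Each sample yields N tiles, so the result has shape (B * N, C, t, t)
    return np.concatenate(per_sample, axis=0)


image = np.random.rand(2, 3, 256, 256).astype(np.float32)
mask = np.random.rand(2, 1, 256, 256).astype(np.float32)
print(tile_batch(image, 64).shape)  # (32, 3, 64, 64)
print(tile_batch(mask, 64).shape)   # (32, 1, 64, 64)
```

Because the toy tiler is deterministic, tile `k` of the image and tile `k` of the mask cover the same pixels, which keeps image/mask pairs aligned after tiling.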