You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
I initially struggled to figure out how to use FlipnSlide with TorchGeo data loaders and Kornia-powered transforms. Please see the reproducible example below for anyone interested. This might be a useful addition to the docs.
from typing import Any, Optional

import kornia as K
import torch
from flipnslide.tiling import FlipnSlide
from torch import Tensor
class _FlipnSlide(K.augmentation.GeometricAugmentationBase2D):
    """Kornia augmentation that tiles images with ``flipnslide.tiling.FlipnSlide``.

    Each sample in the incoming batch is handed to ``FlipnSlide`` separately,
    and the resulting tiles are concatenated along the first dimension. The
    leading dimension of the output is therefore the total number of tiles,
    not the input batch size.
    """

    def __init__(self, tilesize: int, viz: bool = False) -> None:
        """Initialize a new _FlipnSlide instance.

        Args:
            tilesize: desired tile size, forwarded as ``FlipnSlide(tile_size=...)``
            viz: visualization flag, forwarded as ``FlipnSlide(viz=...)``
        """
        # Kornia's augmentation base defaults to p=0.5. With same_on_batch=True
        # that is a single coin flip for the whole batch, so the tiling would
        # silently be skipped about half the time. Tiling must always run.
        super().__init__(p=1.0, same_on_batch=True)
        self.flags = {'tilesize': tilesize, 'viz': viz}

    def compute_transformation(
        self, input: Tensor, params: dict[str, Tensor], flags: dict[str, Any]
    ) -> Tensor:
        """Compute the transformation.

        Tiling is not an affine mapping, so an identity matrix is returned as a
        placeholder. NOTE(review): this means any inverse transform Kornia
        derives from it will not undo the tiling.

        Args:
            input: the input tensor
            params: generated parameters
            flags: static parameters

        Returns:
            the transformation (identity)
        """
        out: Tensor = self.identity_matrix(input)
        return out

    def apply_transform(
        self,
        input: Tensor,
        params: dict[str, Tensor],
        flags: dict[str, Any],
        transform: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply the transform.

        Args:
            input: the batched input tensor; each element along dim 0 is tiled
                independently (presumably (B, C, H, W) as Kornia supplies it)
            params: generated parameters (unused)
            flags: static parameters; reads ``'tilesize'`` and ``'viz'``
            transform: the geometric transformation tensor (unused)

        Returns:
            the tiles of every sample concatenated along dim 0, on the same
            device and with the same dtype as ``input``
        """
        tiles = []
        for sample in input:
            # FlipnSlide operates on NumPy arrays. Detach first so tensors that
            # require grad can still be converted.
            np_array = sample.detach().cpu().numpy()
            sample_tiled = FlipnSlide(
                tile_size=flags['tilesize'],
                data_type='tensor',
                save=False,
                image=np_array,
                viz=flags['viz'],
            )
            # The NumPy round-trip leaves the result on the CPU; move it back to
            # the caller's device/dtype so the pipeline stays device-consistent.
            tiles.append(
                torch.as_tensor(
                    sample_tiled.tiles, device=input.device, dtype=input.dtype
                )
            )
        return torch.cat(tiles, dim=0)
# Usage with AugmentationSequential.
# `AugmentationSequential` is referenced through the already-imported `kornia`
# module (it was used without being imported). With data_keys=None, the keys
# of the input dict ('image', 'mask') tell Kornia how to treat each tensor.
flipnslide = _FlipnSlide(tilesize=64, viz=False)
tfms = K.augmentation.AugmentationSequential(
    flipnslide,
    data_keys=None,
)

# Example usage: a single 3-band image and its single-channel mask.
train_batch = {
    'image': torch.rand(1, 3, 256, 256),
    'mask': torch.rand(1, 1, 256, 256),
}
transformed = tfms(train_batch)

# The leading dimension of each output is the number of FlipnSlide tiles,
# not the original batch size of 1.
print("Transformed image shape:", transformed['image'].shape)
print("Transformed mask shape:", transformed['mask'].shape)
print(transformed)
The text was updated successfully, but these errors were encountered:
I initially struggled to figure out how to use FlipnSlide with TorchGeo data loaders and Kornia-powered transforms. Please see the reproducible example above for anyone interested. This might be a useful addition to the docs.
The text was updated successfully, but these errors were encountered: