Skip to content

Commit

Permalink
Remove _ExtractPatches
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Dec 27, 2024
1 parent 6457c71 commit 92c6b63
Showing 1 changed file with 0 additions and 76 deletions.
76 changes: 0 additions & 76 deletions torchgeo/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

import kornia.augmentation as K
import torch
from einops import rearrange
from kornia.contrib import extract_tensor_patches
from kornia.geometry import crop_by_indices
from torch import Tensor

Expand Down Expand Up @@ -103,80 +101,6 @@ def forward(
}


class _ExtractPatches(K.GeometricAugmentationBase2D):
"""Extract patches from an image or mask."""

def __init__(
self,
window_size: int | tuple[int, int],
stride: int | tuple[int, int] | None = None,
padding: int | tuple[int, int] | None = 0,
keepdim: bool = True,
) -> None:
"""Initialize a new _ExtractPatches instance.
Args:
window_size: desired output size (out_h, out_w) of the crop
stride: stride of window to extract patches. Defaults to non-overlapping
patches (stride=window_size)
padding: zero padding added to the height and width dimensions
keepdim: Combine the patch dimension into the batch dimension
"""
super().__init__(p=1)
self.flags = {
'window_size': window_size,
'stride': stride if stride is not None else window_size,
'padding': padding,
'keepdim': keepdim,
}

def compute_transformation(
self, input: Tensor, params: dict[str, Tensor], flags: dict[str, Any]
) -> Tensor:
"""Compute the transformation.
Args:
input: the input tensor
params: generated parameters
flags: static parameters
Returns:
the transformation
"""
out: Tensor = self.identity_matrix(input)
return out

def apply_transform(
self,
input: Tensor,
params: dict[str, Tensor],
flags: dict[str, Any],
transform: Tensor | None = None,
) -> Tensor:
"""Apply the transform.
Args:
input: the input tensor
params: generated parameters
flags: static parameters
transform: the geometric transformation tensor
Returns:
the augmented input
"""
out = extract_tensor_patches(
input,
window_size=flags['window_size'],
stride=flags['stride'],
padding=flags['padding'],
)

if flags['keepdim']:
out = rearrange(out, 'b t c h w -> (b t) c h w')

return out


class _Clamp(K.IntensityAugmentationBase2D):
"""Clamp images to a specific range."""

Expand Down

0 comments on commit 92c6b63

Please sign in to comment.