Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating augmentations, esp randaug to support full torch.Tensor pipeline #2372

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 109 additions & 89 deletions timm/data/auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,18 @@
import math
import re
from functools import partial
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter
import torch
import PIL
import numpy as np

from PIL import Image, ImageFilter
from torchvision.transforms import InterpolationMode
import torchvision.transforms.functional as TF
try:
import torchvision.transforms.v2.functional as TF2
except ImportError:
TF2 = None

_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])

Expand All @@ -42,160 +48,162 @@
img_mean=_FILL,
)

if hasattr(Image, "Resampling"):
_RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC)
_DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC
else:
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
_DEFAULT_INTERPOLATION = Image.BICUBIC

_RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
_DEFAULT_INTERPOLATION = InterpolationMode.BICUBIC


def _interpolation(kwargs):
interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
def _interpolation(kwargs, basic_only=False):
interpolation = kwargs.pop('interpolation', _DEFAULT_INTERPOLATION)
if isinstance(interpolation, (list, tuple)):
return random.choice(interpolation)
interpolation = random.choice(interpolation)
if basic_only:
if interpolation not in (InterpolationMode.NEAREST, InterpolationMode.BILINEAR):
interpolation = InterpolationMode.BILINEAR
return interpolation


def _check_args_tf(kwargs):
if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
kwargs.pop('fillcolor')
kwargs['resample'] = _interpolation(kwargs)
kwargs['interpolation'] = _interpolation(kwargs)


def _check_args_affine(img, kwargs):
if isinstance(img, torch.Tensor):
kwargs['interpolation'] = _interpolation(kwargs, basic_only=True)
else:
kwargs['interpolation'] = _interpolation(kwargs)


def shear_x(img, factor, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
_check_args_affine(img, kwargs)
return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[math.degrees(math.atan(factor)), 0], **kwargs)


def shear_y(img, factor, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
_check_args_affine(img, kwargs)
return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[0, math.degrees(math.atan(factor))], **kwargs)


def translate_x_rel(img, pct, **kwargs):
pixels = pct * img.size[0]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
def translate_x_abs(img, pixels, **kwargs):
_check_args_affine(img, kwargs)
return TF.affine(img, angle=0, translate=[pixels, 0], scale=1, shear=[0, 0], **kwargs)


def translate_y_rel(img, pct, **kwargs):
pixels = pct * img.size[1]
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
def translate_y_abs(img, pixels, **kwargs):
_check_args_affine(img, kwargs)
return TF.affine(img, angle=0, translate=[0, pixels], scale=1, shear=[0, 0], **kwargs)


def translate_x_abs(img, pixels, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
def translate_x_rel(img, pct, **kwargs):
pixels = pct * TF.get_image_size(img)[0]
return translate_x_abs(img, pixels, **kwargs)


def translate_y_abs(img, pixels, **kwargs):
_check_args_tf(kwargs)
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
def translate_y_rel(img, pct, **kwargs):
pixels = pct * TF.get_image_size(img)[1]
return translate_y_abs(img, pixels, **kwargs)


def rotate(img, degrees, **kwargs):
_check_args_tf(kwargs)
if _PIL_VER >= (5, 2):
return img.rotate(degrees, **kwargs)
if _PIL_VER >= (5, 0):
w, h = img.size
post_trans = (0, 0)
rotn_center = (w / 2.0, h / 2.0)
angle = -math.radians(degrees)
matrix = [
round(math.cos(angle), 15),
round(math.sin(angle), 15),
0.0,
round(-math.sin(angle), 15),
round(math.cos(angle), 15),
0.0,
]

def transform(x, y, matrix):
(a, b, c, d, e, f) = matrix
return a * x + b * y + c, d * x + e * y + f

matrix[2], matrix[5] = transform(
-rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
)
matrix[2] += rotn_center[0]
matrix[5] += rotn_center[1]
return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
return img.rotate(degrees, resample=kwargs['resample'])
_check_args_affine(img, kwargs)
return TF.rotate(img, degrees, **kwargs)


def auto_contrast(img, **__):
return ImageOps.autocontrast(img)
return TF.autocontrast(img)


def invert(img, **__):
return ImageOps.invert(img)
return TF.invert(img)


def equalize(img, **__):
return ImageOps.equalize(img)
if isinstance(img, torch.Tensor) and img.is_floating_point():
if TF2 is None:
# FIXME warn / assert?
return img
return TF2.equalize(img)
return TF.equalize(img)


def solarize(img, thresh, **__):
    """Invert every pixel value that is at or above ``thresh``.

    Args:
        img: PIL image or ``torch.Tensor`` image.
        thresh: Threshold expressed on the 0-255 scale. For floating point
            tensors it is rescaled by 1/255 (float images are presumably in
            [0, 1] -- confirm against the calling pipeline).

    Returns:
        The solarized image. A tensor is returned unchanged when the threshold
        lies above the largest value its dtype can represent.
    """
    if isinstance(img, torch.Tensor):
        if img.is_floating_point():
            thresh = thresh / 255
            bound = 1.0
        else:
            bound = torch.iinfo(img.dtype).max
        if thresh > bound:
            # No pixel can reach the threshold, so nothing would be inverted.
            # Returning early mirrors the PIL behaviour (thresh=256 is a no-op)
            # and avoids the tensor backend rejecting an out-of-range threshold.
            return img
    return TF.solarize(img, thresh)


def solarize_add(img, add, thresh=128, **__):
    """Add ``add`` to every pixel darker than ``thresh``, saturating at white.

    Pixels at or above ``thresh`` are left untouched, for both tensor and PIL
    inputs.

    Args:
        img: PIL image or ``torch.Tensor`` image.
        add: Amount to add, expressed on the 0-255 scale.
        thresh: Threshold on the 0-255 scale; only values strictly below it
            are modified.

    Returns:
        A new image of the same type as ``img``. PIL images whose mode is
        neither "L" nor "RGB" are returned unchanged.
    """
    if isinstance(img, torch.Tensor):
        if img.is_floating_point():
            # Rescale the 0-255 arguments for float images
            # (presumably in [0, 1] -- confirm against the calling pipeline).
            thresh = thresh / 255
            add = add / 255
            img_sum = (img + add).clamp(max=1.0)
        else:
            # Widen before adding so narrow integer types (e.g. uint8) cannot
            # wrap around, then saturate and cast back to the input dtype.
            # NOTE(review): assumes integer images use the 0-255 range.
            img_sum = (img.to(torch.int32) + add).clamp(min=0, max=255).to(img.dtype)
        # Only pixels *below* the threshold receive the offset, matching the
        # lookup table used for PIL images below.
        return torch.where(img < thresh, img_sum, img)

    lut = [min(255, i + add) if i < thresh else i for i in range(256)]
    if img.mode in ("L", "RGB"):
        if img.mode == "RGB":
            # One 256-entry table per channel.
            lut = lut * 3
        return img.point(lut)

    return img


def posterize(img, bits_to_keep, **__):
if bits_to_keep >= 8:
return img
return ImageOps.posterize(img, bits_to_keep)
if isinstance(img, torch.Tensor) and img.is_floating_point():
if TF2 is None:
# FIXME warn / assert?
return img
return TF2.posterize(img, bits_to_keep)
return TF.posterize(img, bits_to_keep)


def contrast(img, factor, **__):
return ImageEnhance.Contrast(img).enhance(factor)
return TF.adjust_contrast(img, factor)


def color(img, factor, **__):
return ImageEnhance.Color(img).enhance(factor)
return TF.adjust_saturation(img, factor)


def brightness(img, factor, **__):
return ImageEnhance.Brightness(img).enhance(factor)
return TF.adjust_brightness(img, factor)


def sharpness(img, factor, **__):
return ImageEnhance.Sharpness(img).enhance(factor)
return TF.adjust_sharpness(img, factor)


def gaussian_blur(img, factor, **__):
img = img.filter(ImageFilter.GaussianBlur(radius=factor))
if isinstance(img, torch.Tensor):
kernel_size = 2 * int(3 * factor) + 1 # could be bigger, but more expensive
img = TF.gaussian_blur(img, kernel_size=kernel_size, sigma=factor)
else:
img = img.filter(ImageFilter.GaussianBlur(radius=factor))
return img


def gaussian_blur_rand(img, factor, **__):
radius_min = 0.1
radius_max = 2.0
img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor)))
return img
radius = random.uniform(radius_min, radius_max * factor)
return gaussian_blur(img, radius)


def desaturate(img, factor, **_):
factor = min(1., max(0., 1. - factor))
# enhance factor 0 = grayscale, 1.0 = no-change
return ImageEnhance.Color(img).enhance(factor)
return TF.adjust_saturation(img, factor)


def _randomly_negate(v):
Expand Down Expand Up @@ -356,7 +364,13 @@ def _solarize_add_level_to_arg(level, _hparams):

class AugmentOp:

def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
def __init__(
self,
name: str,
prob: float = 0.5,
magnitude: float = 10,
hparams: Optional[Dict[str, Any]] = None
):
hparams = hparams or _HPARAMS_DEFAULT
self.name = name
self.aug_fn = NAME_TO_OP[name]
Expand All @@ -365,8 +379,8 @@ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
self.magnitude = magnitude
self.hparams = hparams.copy()
self.kwargs = dict(
fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
fill=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
interpolation=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
)

# If magnitude_std is > 0, we introduce some randomness
Expand Down Expand Up @@ -564,7 +578,7 @@ def auto_augment_policy(name='v0', hparams=None):

class AutoAugment:

def __init__(self, policy):
def __init__(self, policy: List):
self.policy = policy

def __call__(self, img):
Expand Down Expand Up @@ -729,8 +743,14 @@ def rand_augment_ops(
):
hparams = hparams or _HPARAMS_DEFAULT
transforms = transforms or _RAND_TRANSFORMS
return [AugmentOp(
name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms]
return [
AugmentOp(
name,
prob=prob,
magnitude=magnitude,
hparams=hparams
) for name in transforms
]


class RandAugment:
Expand Down
3 changes: 2 additions & 1 deletion timm/data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def __init__(
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
re_num_splits=0,
):

mean = adapt_to_chs(mean, channels)
std = adapt_to_chs(std, channels)
Expand Down
Loading
Loading