diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py
index 94438a0e0..f1feada3a 100644
--- a/timm/data/auto_augment.py
+++ b/timm/data/auto_augment.py
@@ -24,12 +24,18 @@
 import math
 import re
 from functools import partial
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
-from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter
+import torch
 import PIL
 import numpy as np
-
+from PIL import Image, ImageFilter
+from torchvision.transforms import InterpolationMode
+import torchvision.transforms.functional as TF
+try:
+    import torchvision.transforms.v2.functional as TF2
+except ImportError:
+    TF2 = None
 
 _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
@@ -42,118 +48,111 @@
     img_mean=_FILL,
 )
 
-if hasattr(Image, "Resampling"):
-    _RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC)
-    _DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC
-else:
-    _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
-    _DEFAULT_INTERPOLATION = Image.BICUBIC
+
+_RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
+_DEFAULT_INTERPOLATION = InterpolationMode.BICUBIC
 
 
-def _interpolation(kwargs):
-    interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
+def _interpolation(kwargs, basic_only=False):
+    interpolation = kwargs.pop('interpolation', _DEFAULT_INTERPOLATION)
     if isinstance(interpolation, (list, tuple)):
-        return random.choice(interpolation)
+        interpolation = random.choice(interpolation)
+    if basic_only:
+        if interpolation not in (InterpolationMode.NEAREST, InterpolationMode.BILINEAR):
+            interpolation = InterpolationMode.BILINEAR
     return interpolation
 
 
 def _check_args_tf(kwargs):
-    if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
-        kwargs.pop('fillcolor')
-    kwargs['resample'] = _interpolation(kwargs)
+    kwargs['interpolation'] = _interpolation(kwargs)
+
+
+def _check_args_affine(img, kwargs):
+    if isinstance(img, torch.Tensor):
+        # tensor affine / rotate only support nearest & bilinear modes
+        kwargs['interpolation'] = _interpolation(kwargs, basic_only=True)
+    else:
+        kwargs['interpolation'] = _interpolation(kwargs)
 
 
 def shear_x(img, factor, **kwargs):
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
+    _check_args_affine(img, kwargs)
+    return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[math.degrees(math.atan(factor)), 0], **kwargs)
 
 
 def shear_y(img, factor, **kwargs):
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
+    _check_args_affine(img, kwargs)
+    return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[0, math.degrees(math.atan(factor))], **kwargs)
 
 
-def translate_x_rel(img, pct, **kwargs):
-    pixels = pct * img.size[0]
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+def translate_x_abs(img, pixels, **kwargs):
+    _check_args_affine(img, kwargs)
+    return TF.affine(img, angle=0, translate=[pixels, 0], scale=1, shear=[0, 0], **kwargs)
 
 
-def translate_y_rel(img, pct, **kwargs):
-    pixels = pct * img.size[1]
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+def translate_y_abs(img, pixels, **kwargs):
+    _check_args_affine(img, kwargs)
+    return TF.affine(img, angle=0, translate=[0, pixels], scale=1, shear=[0, 0], **kwargs)
 
 
-def translate_x_abs(img, pixels, **kwargs):
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+def translate_x_rel(img, pct, **kwargs):
+    pixels = pct * TF.get_image_size(img)[0]
+    return translate_x_abs(img, pixels, **kwargs)
 
 
-def translate_y_abs(img, pixels, **kwargs):
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+def translate_y_rel(img, pct, **kwargs):
+    pixels = pct * TF.get_image_size(img)[1]
+    return translate_y_abs(img, pixels, **kwargs)
 
 
 def rotate(img, degrees, **kwargs):
-    _check_args_tf(kwargs)
-    if _PIL_VER >= (5, 2):
-        return img.rotate(degrees, **kwargs)
-    if _PIL_VER >= (5, 0):
-        w, h = img.size
-        post_trans = (0, 0)
-        rotn_center = (w / 2.0, h / 2.0)
-        angle = -math.radians(degrees)
-        matrix = [
-            round(math.cos(angle), 15),
-            round(math.sin(angle), 15),
-            0.0,
-            round(-math.sin(angle), 15),
-            round(math.cos(angle), 15),
-            0.0,
-        ]
-
-        def transform(x, y, matrix):
-            (a, b, c, d, e, f) = matrix
-            return a * x + b * y + c, d * x + e * y + f
-
-        matrix[2], matrix[5] = transform(
-            -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
-        )
-        matrix[2] += rotn_center[0]
-        matrix[5] += rotn_center[1]
-        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
-    return img.rotate(degrees, resample=kwargs['resample'])
+    _check_args_affine(img, kwargs)
+    return TF.rotate(img, degrees, **kwargs)
 
 
 def auto_contrast(img, **__):
-    return ImageOps.autocontrast(img)
+    return TF.autocontrast(img)
 
 
 def invert(img, **__):
-    return ImageOps.invert(img)
+    return TF.invert(img)
 
 
 def equalize(img, **__):
-    return ImageOps.equalize(img)
+    if isinstance(img, torch.Tensor) and img.is_floating_point():
+        if TF2 is None:
+            # FIXME warn / assert?
+            return img
+        return TF2.equalize(img)
+    return TF.equalize(img)
 
 
 def solarize(img, thresh, **__):
-    return ImageOps.solarize(img, thresh)
+    if isinstance(img, torch.Tensor) and img.is_floating_point():
+        thresh = min(thresh / 255, 1.0)
+    return TF.solarize(img, thresh)
 
 
 def solarize_add(img, add, thresh=128, **__):
-    lut = []
-    for i in range(256):
-        if i < thresh:
-            lut.append(min(255, i + add))
+    if isinstance(img, torch.Tensor):
+        if img.is_floating_point():
+            thresh = thresh / 255
+            add = add / 255
+            img_sum = (img + add).clamp_(max=1.0)
         else:
-            lut.append(i)
+            img_sum = (img.int() + add).clamp_(max=255).to(img.dtype)  # int() avoids uint8 overflow before clamp
+        return torch.where(img >= thresh, img_sum, img)
+    else:
+        lut = []
+        for i in range(256):
+            if i < thresh:
+                lut.append(min(255, i + add))
+            else:
+                lut.append(i)
 
-    if img.mode in ("L", "RGB"):
-        if img.mode == "RGB" and len(lut) == 256:
-            lut = lut + lut + lut
-        return img.point(lut)
+        if img.mode in ("L", "RGB"):
+            if img.mode == "RGB" and len(lut) == 256:
+                lut = lut + lut + lut
+            return img.point(lut)
 
     return img
@@ -161,41 +160,50 @@ def solarize_add(img, add, thresh=128, **__):
 def posterize(img, bits_to_keep, **__):
     if bits_to_keep >= 8:
         return img
-    return ImageOps.posterize(img, bits_to_keep)
+    if isinstance(img, torch.Tensor) and img.is_floating_point():
+        if TF2 is None:
+            # FIXME warn / assert?
+            return img
+        return TF2.posterize(img, bits_to_keep)
+    return TF.posterize(img, bits_to_keep)
 
 
 def contrast(img, factor, **__):
-    return ImageEnhance.Contrast(img).enhance(factor)
+    return TF.adjust_contrast(img, factor)
 
 
 def color(img, factor, **__):
-    return ImageEnhance.Color(img).enhance(factor)
+    return TF.adjust_saturation(img, factor)
 
 
 def brightness(img, factor, **__):
-    return ImageEnhance.Brightness(img).enhance(factor)
+    return TF.adjust_brightness(img, factor)
 
 
 def sharpness(img, factor, **__):
-    return ImageEnhance.Sharpness(img).enhance(factor)
+    return TF.adjust_sharpness(img, factor)
 
 
 def gaussian_blur(img, factor, **__):
-    img = img.filter(ImageFilter.GaussianBlur(radius=factor))
+    if isinstance(img, torch.Tensor):
+        kernel_size = 2 * int(3 * factor) + 1  # could be bigger, but more expensive
+        img = TF.gaussian_blur(img, kernel_size=kernel_size, sigma=factor)
+    else:
+        img = img.filter(ImageFilter.GaussianBlur(radius=factor))
    return img
 
 
 def gaussian_blur_rand(img, factor, **__):
     radius_min = 0.1
     radius_max = 2.0
-    img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor)))
-    return img
+    radius = random.uniform(radius_min, radius_max * factor)
+    return gaussian_blur(img, radius)
 
 
 def desaturate(img, factor, **_):
     factor = min(1., max(0., 1. - factor))  # enhance factor 0 = grayscale, 1.0 = no-change
-    return ImageEnhance.Color(img).enhance(factor)
+    return TF.adjust_saturation(img, factor)
@@ -356,7 +364,13 @@ def _solarize_add_level_to_arg(level, _hparams):
 
 class AugmentOp:
 
-    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
+    def __init__(
+            self,
+            name: str,
+            prob: float = 0.5,
+            magnitude: float = 10,
+            hparams: Optional[Dict[str, Any]] = None
+    ):
         hparams = hparams or _HPARAMS_DEFAULT
         self.name = name
         self.aug_fn = NAME_TO_OP[name]
@@ -365,8 +379,8 @@ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
         self.magnitude = magnitude
         self.hparams = hparams.copy()
         self.kwargs = dict(
-            fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
-            resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
+            fill=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
+            interpolation=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
         )
 
         # If magnitude_std is > 0, we introduce some randomness
@@ -564,7 +578,7 @@ def auto_augment_policy(name='v0', hparams=None):
 
 class AutoAugment:
 
-    def __init__(self, policy):
+    def __init__(self, policy: List):
         self.policy = policy
 
     def __call__(self, img):
@@ -729,8 +743,14 @@ def rand_augment_ops(
 ):
     hparams = hparams or _HPARAMS_DEFAULT
     transforms = transforms or _RAND_TRANSFORMS
-    return [AugmentOp(
-        name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms]
+    return [
+        AugmentOp(
+            name,
+            prob=prob,
+            magnitude=magnitude,
+            hparams=hparams
+        ) for name in transforms
+    ]
 
 
 class RandAugment:
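Not part of the patch, but a quick sketch of what the rework above buys: each op now takes either a PIL Image or a torch.Tensor through the same code path, with math.degrees(math.atan(factor)) translating the PIL affine-matrix shear factor into the angle TF.affine expects. Image sizes and values below are illustrative only:

    import math
    import torch
    from PIL import Image

    from timm.data.auto_augment import shear_x, solarize

    factor = 0.3  # AutoAugment shear magnitudes are affine-matrix factors
    # shear_x converts this internally via math.degrees(math.atan(factor)),
    # i.e. the angle-in-degrees form that TF.affine expects
    print(round(math.degrees(math.atan(factor)), 1))  # 16.7

    pil_img = Image.new('RGB', (64, 64), (128, 64, 32))
    float_img = torch.rand(3, 64, 64)  # CHW float tensor in [0, 1]

    sheared_pil = shear_x(pil_img, factor)       # PIL path, bicubic default
    sheared_tensor = shear_x(float_img, factor)  # tensor path, restricted to bilinear
    solarized = solarize(float_img, 128)         # thresh rescaled to 128/255 for float input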
diff --git a/timm/data/loader.py b/timm/data/loader.py
index 3b4a6d0ed..8253b2e0e 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -87,7 +87,8 @@ def __init__(
             re_prob=0.,
             re_mode='const',
             re_count=1,
-            re_num_splits=0):
+            re_num_splits=0,
+    ):
         mean = adapt_to_chs(mean, channels)
         std = adapt_to_chs(std, channels)
diff --git a/timm/data/readers/reader_tfds.py b/timm/data/readers/reader_tfds.py
index a33bd5059..c224a9678 100644
--- a/timm/data/readers/reader_tfds.py
+++ b/timm/data/readers/reader_tfds.py
@@ -15,25 +15,36 @@
 import torch.distributed as dist
 from PIL import Image
 
-try:
-    import tensorflow as tf
-    tf.config.set_visible_devices([], 'GPU')  # Hands off my GPU! (or pip install tensorflow-cpu)
-    import tensorflow_datasets as tfds
-    try:
-        tfds.even_splits('', 1, drop_remainder=False)  # non-buggy even_splits has drop_remainder arg
-        has_buggy_even_splits = False
-    except TypeError:
-        print("Warning: This version of tfds doesn't have the latest even_splits impl. "
-              "Please update or use tfds-nightly for better fine-grained split behaviour.")
-        has_buggy_even_splits = True
-    # NOTE uncomment below if having file limit issues on dataset build (or alter your OS defaults)
-    # import resource
-    # low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
-    # resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
-except ImportError as e:
-    print(e)
-    print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
-    raise e
+import importlib
+
+
+class LazyTfLoader:
+    def __init__(self):
+        self._tf = None
+
+    def __getattr__(self, name):
+        if self._tf is None:
+            self._tf = importlib.import_module('tensorflow')
+            self._tf.config.set_visible_devices([], 'GPU')  # Hands off my GPU! (or pip install tensorflow-cpu)
+        return getattr(self._tf, name)
+
+
+class LazyTfdsLoader:
+    def __init__(self):
+        self._tfds = None
+        self.has_buggy_even_splits = False
+
+    def __getattr__(self, name):
+        if self._tfds is None:
+            self._tfds = importlib.import_module('tensorflow_datasets')
+            try:
+                self._tfds.even_splits('', 1, drop_remainder=False)  # non-buggy even_splits has drop_remainder arg
+            except TypeError:
+                print("Warning: This version of tfds doesn't have the latest even_splits impl. "
+                      "Please update or use tfds-nightly for better fine-grained split behaviour.")
+                self.has_buggy_even_splits = True
+        return getattr(self._tfds, name)
+
+
+tf = LazyTfLoader()
+tfds = LazyTfdsLoader()
 
 from .class_map import load_class_map
 from .reader import Reader
@@ -45,7 +56,6 @@
 PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048))  # samples to prefetch
 
 
-@tfds.decode.make_decoder()
 def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE', channels=3):
     return tf.image.decode_jpeg(
         serialized_image,
@@ -231,7 +241,7 @@ def _lazy_init(self):
             if should_subsplit:
                 # split the dataset w/o using sharding for more even samples / worker, can result in less optimal
                 # read patterns for distributed training (overlap across shards) so better to use InputContext there
-                if has_buggy_even_splits:
+                if tfds.has_buggy_even_splits:
                     # my even_split workaround doesn't work on subsplits, upgrade tfds!
                     if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
                         subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
@@ -253,10 +263,11 @@ def _lazy_init(self):
             shuffle_reshuffle_each_iteration=True,
             input_context=input_context,
         )
+        decode_fn = tfds.decode.make_decoder()(decode_example)
         ds = self.builder.as_dataset(
             split=self.subsplit or self.split,
             shuffle_files=self.is_training,
-            decoders=dict(image=decode_example(channels=1 if self.input_img_mode == 'L' else 3)),
+            decoders=dict(image=decode_fn(channels=1 if self.input_img_mode == 'L' else 3)),
             read_config=read_config,
         )
         # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
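The lazy wrappers change when TensorFlow gets imported, not whether it is needed: the reader module itself now imports cleanly on machines without TF/TFDS, and the real import (plus the even_splits probe) happens on first attribute access. A minimal illustration, assuming neither package has been imported elsewhere in the process:

    import sys

    from timm.data.readers import reader_tfds  # module import no longer pulls in TF

    assert 'tensorflow' not in sys.modules
    assert 'tensorflow_datasets' not in sys.modules

    # The first attribute access triggers the real import; only after that is
    # tfds.has_buggy_even_splits meaningful (it stays False until probed).
    # builder = reader_tfds.tfds.builder('imagenet2012')  # requires tfds installed

One subtlety: has_buggy_even_splits is a plain instance attribute set in __init__, so reading it never goes through __getattr__ and never triggers the import; it only reflects reality once some other tfds attribute has been accessed, which is the case by the time _lazy_init checks it.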
diff --git a/timm/data/transforms.py b/timm/data/transforms.py
index 215b7b5b6..f82a8cf21 100644
--- a/timm/data/transforms.py
+++ b/timm/data/transforms.py
@@ -12,7 +12,8 @@
     has_interpolation_mode = True
 except ImportError:
     has_interpolation_mode = False
-from PIL import Image
+from PIL import Image, ImageCms
+
 import numpy as np
 
 __all__ = [
@@ -89,6 +90,141 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
 
 
+class ToLabPIL:
+
+    def __init__(self) -> None:
+        super().__init__()
+        rgb_profile = ImageCms.createProfile(colorSpace='sRGB')
+        lab_profile = ImageCms.createProfile(colorSpace='LAB')
+        # Create a transform object from the input and output profiles
+        self.rgb_to_lab_transform = ImageCms.buildTransform(
+            inputProfile=rgb_profile,
+            outputProfile=lab_profile,
+            inMode='RGB',
+            outMode='LAB'
+        )
+
+    def __call__(self, pic) -> Image.Image:
+        lab_image = ImageCms.applyTransform(
+            im=pic,
+            transform=self.rgb_to_lab_transform
+        )
+        return lab_image
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
+
+def srgb_to_linear(srgb_image: torch.Tensor) -> torch.Tensor:
+    return torch.where(
+        srgb_image <= 0.04045,
+        srgb_image / 12.92,
+        ((srgb_image + 0.055) / 1.055) ** 2.4
+    )
+
+
+def rgb_to_lab_tensor(
+        rgb_img: torch.Tensor,
+        normalized: bool = True,
+        srgb_input: bool = True,
+        split_channels: bool = False,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+    """
+    Convert RGB image to LAB color space using tensor operations.
+
+    Args:
+        rgb_img: Tensor of shape (..., 3) with float values in range [0, 1]
+        normalized: If True, outputs L,a,b in [0, 1] range instead of native LAB ranges
+        srgb_input: Input is gamma corrected sRGB, otherwise linear RGB is assumed (rare unless part of a pipeline)
+        split_channels: If True, outputs a tuple of flattened colour channels instead of stacked image
+
+    Returns:
+        lab_img: Tensor of same shape with either:
+            - normalized=False: L in [0, 100] and a,b in [-128, 127]
+            - normalized=True: L,a,b in [0, 1]
+    """
+    # Constants
+    epsilon = 216 / 24389
+    kappa = 24389 / 27
+    xn = 0.95047
+    yn = 1.0
+    zn = 1.08883
+
+    # Convert sRGB to linear RGB
+    if srgb_input:
+        rgb_img = srgb_to_linear(rgb_img)
+
+    # FIXME transforms before this are causing -ve values, can have a large impact on this conversion
+    rgb_img = rgb_img.clamp(0, 1.0)
+
+    # Convert to XYZ using matrix multiplication
+    rgb_to_xyz = torch.tensor([
+        #    X         Y         Z
+        [0.412453, 0.212671, 0.019334],  # R
+        [0.357580, 0.715160, 0.119193],  # G
+        [0.180423, 0.072169, 0.950227],  # B
+    ], device=rgb_img.device)
+
+    # Reshape input for matrix multiplication if needed
+    original_shape = rgb_img.shape
+    if len(original_shape) > 2:
+        rgb_img = rgb_img.reshape(-1, 3)
+
+    # Perform matrix multiplication
+    xyz = rgb_img @ rgb_to_xyz
+
+    # Adjust XYZ values by the D65 reference white
+    xyz.div_(torch.tensor([xn, yn, zn], device=xyz.device))
+
+    # XYZ to LAB
+    fxfyfz = torch.where(
+        xyz > epsilon,
+        torch.pow(xyz, 1 / 3),
+        (kappa * xyz + 16) / 116
+    )
+
+    L = 116 * fxfyfz[..., 1] - 16
+    a = 500 * (fxfyfz[..., 0] - fxfyfz[..., 1])
+    b = 200 * (fxfyfz[..., 1] - fxfyfz[..., 2])
+
+    if normalized:
+        # output in range [0, 1] for each channel
+        L.div_(100)
+        a.add_(128).div_(255)
+        b.add_(128).div_(255)
+
+    if split_channels:
+        return L, a, b
+
+    lab = torch.stack([L, a, b], dim=-1)
+
+    # Restore original shape if needed
+    if len(original_shape) > 2:
+        lab = lab.reshape(original_shape)
+
+    return lab
+
+
+class ToLabTensor:
+    def __init__(self, srgb_input=False, normalized=True) -> None:
+        self.srgb_input = srgb_input
+        self.normalized = normalized
+
+    def __call__(self, pic) -> torch.Tensor:
+        return rgb_to_lab_tensor(
+            pic,
+            normalized=self.normalized,
+            srgb_input=self.srgb_input,
+        )
+
+
+class ToLinearRgb:
+    def __init__(self):
+        pass
+
+    def __call__(self, pic) -> torch.Tensor:
+        assert isinstance(pic, torch.Tensor)
+        return srgb_to_linear(pic)
+
+
 # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
 # favor of the Image.Resampling enum. The top-level resampling attributes will be
 # removed in Pillow 10.
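A sanity check for the tensor-based conversion (my sketch, not part of the patch): with the D65 constants above, pure sRGB white should land at L* near 100 with a* and b* near 0, and normalized=True should rescale each channel into [0, 1]:

    import torch

    from timm.data.transforms import rgb_to_lab_tensor

    white = torch.ones(2, 2, 3)  # channels-last (..., 3) layout, float [0, 1]

    L, a, b = rgb_to_lab_tensor(white, normalized=False, split_channels=True)
    assert torch.allclose(L, torch.full_like(L, 100.0), atol=1e-2)
    assert a.abs().max() < 1e-2 and b.abs().max() < 1e-2

    lab = rgb_to_lab_tensor(white)  # defaults: srgb_input=True, normalized=True
    print(lab[0, 0])  # ~[1.0, 0.502, 0.502] -> ready for transforms.Normalize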
diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py
index 9be0e3bf3..2f387ca5d 100644
--- a/timm/data/transforms_factory.py
+++ b/timm/data/transforms_factory.py
@@ -4,6 +4,7 @@
 Hacked together by / Copyright 2019, Ross Wightman
 """
 import math
+from copy import deepcopy
 from typing import Optional, Tuple, Union
 
 import torch
@@ -13,6 +14,7 @@
 from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
 from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
     ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor
+from timm.data.transforms import ToLabTensor, ToLinearRgb
 from timm.data.random_erasing import RandomErasing
 
 
@@ -84,6 +86,7 @@ def transforms_imagenet_train(
         use_prefetcher: bool = False,
         normalize: bool = True,
         separate: bool = False,
+        use_tensor: bool = False,
 ):
     """ ImageNet-oriented image transforms for training.
 
@@ -111,6 +114,7 @@ def transforms_imagenet_train(
         use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
         normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
         separate: Output transforms in 3-stage tuple.
+        use_tensor: Convert images to float [0, 1] tensors up front and use tensor-based transforms.
 
     Returns:
         If separate==True, the transforms are returned as a tuple of 3 separate transforms
@@ -119,13 +123,21 @@ def transforms_imagenet_train(
         * a portion of the data through the secondary transform
         * normalizes and converts the branches above with the third, final transform
     """
+    if use_tensor:
+        primary_tfl = [
+            MaybeToTensor(),
+            ToLinearRgb(),  # FIXME
+        ]
+    else:
+        primary_tfl = []
+
     train_crop_mode = train_crop_mode or 'rrc'
     assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}
     if train_crop_mode in ('rkrc', 'rkrr'):
         # FIXME integration of RKR is a WIP
         scale = tuple(scale or (0.8, 1.00))
         ratio = tuple(ratio or (0.9, 1/.9))
-        primary_tfl = [
+        primary_tfl += [
             ResizeKeepRatio(
                 img_size,
                 interpolation=interpolation,
@@ -142,7 +154,7 @@ def transforms_imagenet_train(
     else:
         scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
         ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
-        primary_tfl = [
+        primary_tfl += [
             RandomResizedCropAndInterpolation(
                 img_size,
                 scale=scale,
@@ -166,9 +178,13 @@ def transforms_imagenet_train(
             img_size_min = min(img_size)
         else:
             img_size_min = img_size
+        if use_tensor:
+            aa_mean = deepcopy(mean)
+        else:
+            aa_mean = tuple([min(255, round(255 * x)) for x in mean])
         aa_params = dict(
             translate_const=int(img_size_min * 0.45),
-            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
+            img_mean=aa_mean,
         )
         if interpolation and interpolation != 'random':
             aa_params['interpolation'] = str_to_pil_interp(interpolation)
@@ -218,10 +234,13 @@ def transforms_imagenet_train(
         final_tfl += [ToNumpy()]
     elif not normalize:
         # when normalize disabled, converted to tensor without scaling, keeps original dtype
-        final_tfl += [MaybePILToTensor()]
+        if not use_tensor:
+            final_tfl += [MaybePILToTensor()]
     else:
+        if not use_tensor:
+            final_tfl += [MaybeToTensor()]
         final_tfl += [
-            MaybeToTensor(),
+            ToLabTensor(),  # FIXME
             transforms.Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std),
@@ -254,6 +273,7 @@ def transforms_imagenet_eval(
         std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
         use_prefetcher: bool = False,
         normalize: bool = True,
+        use_tensor: bool = False,
 ):
     """ ImageNet-oriented image transform for evaluation and inference.
 
@@ -280,7 +300,13 @@ def transforms_imagenet_eval(
         scale_size = math.floor(img_size / crop_pct)
         scale_size = (scale_size, scale_size)
 
-    tfl = []
+    if use_tensor:
+        tfl = [
+            MaybeToTensor(),
+            ToLinearRgb(),  # FIXME
+        ]
+    else:
+        tfl = []
 
     if crop_border_pixels:
         tfl += [TrimBorder(crop_border_pixels)]
@@ -318,10 +344,13 @@ def transforms_imagenet_eval(
         tfl += [ToNumpy()]
     elif not normalize:
         # when normalize disabled, converted to tensor without scaling, keeps original dtype
-        tfl += [MaybePILToTensor()]
+        if not use_tensor:
+            tfl += [MaybePILToTensor()]
     else:
+        if not use_tensor:
+            tfl += [MaybeToTensor()]
         tfl += [
-            MaybeToTensor(),
+            ToLabTensor(),  # FIXME
             transforms.Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std),
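To see where the new pieces slot in, a sketch of composing the pipelines with the new flag (illustrative only; the ToLinearRgb/ToLabTensor stages are explicitly FIXME-tagged above, and rgb_to_lab_tensor expects channels-last input while the pipeline produces CHW tensors, so this inspects the composed transform rather than treating its output as final):

    from timm.data.transforms_factory import transforms_imagenet_train, transforms_imagenet_eval

    train_tf = transforms_imagenet_train(img_size=224, use_tensor=True)
    eval_tf = transforms_imagenet_eval(img_size=224, use_tensor=True)

    # With use_tensor=True both pipelines convert to tensor immediately:
    #   MaybeToTensor -> ToLinearRgb -> (crop / flip / jitter as configured)
    #   -> ToLabTensor -> Normalize
    # instead of running PIL ops throughout and converting at the end.
    print(train_tf)
    print(eval_tf)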