From 3b181b78d1f6fc1f35848d0f26bb37efb1b12798 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 18 Dec 2024 12:24:04 -0800
Subject: [PATCH 1/3] Updating augmentations, esp randaug to support full torch.Tensor pipeline

---
 timm/data/auto_augment.py       | 198 ++++++++++++++++++--------------
 timm/data/loader.py             |   3 +-
 timm/data/transforms.py         |  28 ++++-
 timm/data/transforms_factory.py |  24 +++-
 4 files changed, 157 insertions(+), 96 deletions(-)

diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py
index 94438a0e06..f1feada3a6 100644
--- a/timm/data/auto_augment.py
+++ b/timm/data/auto_augment.py
@@ -24,12 +24,18 @@
 import math
 import re
 from functools import partial
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
-from PIL import Image, ImageOps, ImageEnhance, ImageChops, ImageFilter
+import torch
 import PIL
 import numpy as np
-
+from PIL import Image, ImageFilter
+from torchvision.transforms import InterpolationMode
+import torchvision.transforms.functional as TF
+try:
+    import torchvision.transforms.v2.functional as TF2
+except ImportError:
+    TF2 = None
 
 _PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]])
 
@@ -42,118 +48,111 @@
     img_mean=_FILL,
 )
 
-if hasattr(Image, "Resampling"):
-    _RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC)
-    _DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC
-else:
-    _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
-    _DEFAULT_INTERPOLATION = Image.BICUBIC
+
+_RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
+_DEFAULT_INTERPOLATION = InterpolationMode.BICUBIC
 
 
-def _interpolation(kwargs):
-    interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
+def _interpolation(kwargs, basic_only=False):
+    interpolation = kwargs.pop('interpolation', _DEFAULT_INTERPOLATION)
     if isinstance(interpolation, (list, tuple)):
-        return random.choice(interpolation)
+        interpolation = random.choice(interpolation)
+    if basic_only:
+        if interpolation not in (InterpolationMode.NEAREST, InterpolationMode.BILINEAR):
+            interpolation = InterpolationMode.BILINEAR
     return interpolation
 
 
 def _check_args_tf(kwargs):
-    if 'fillcolor' in kwargs and _PIL_VER < (5, 0):
-        kwargs.pop('fillcolor')
-    kwargs['resample'] = _interpolation(kwargs)
+    kwargs['interpolation'] = _interpolation(kwargs)
+
+
+def _check_args_affine(img, kwargs):
+    if isinstance(img, torch.Tensor):
+        kwargs['interpolation'] = _interpolation(kwargs, basic_only=True)
+    else:
+        kwargs['interpolation'] = _interpolation(kwargs)
 
 
 def shear_x(img, factor, **kwargs):
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs)
+    _check_args_affine(img, kwargs)
+    return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[math.degrees(math.atan(factor)), 0], **kwargs)
 
 
 def shear_y(img, factor, **kwargs):
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs)
+    _check_args_affine(img, kwargs)
+    return TF.affine(img, angle=0, translate=[0, 0], scale=1, shear=[0, math.degrees(math.atan(factor))], **kwargs)
 
 
-def translate_x_rel(img, pct, **kwargs):
-    pixels = pct * img.size[0]
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+def translate_x_abs(img, pixels, **kwargs):
+    _check_args_affine(img, kwargs)
+    return TF.affine(img, angle=0, translate=[pixels, 0], scale=1, shear=[0, 0], **kwargs)
 
 
-def translate_y_rel(img, pct, **kwargs):
-    pixels = pct * img.size[1]
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+def translate_y_abs(img, pixels, **kwargs):
+    _check_args_affine(img, kwargs)
+    return TF.affine(img, angle=0, translate=[0, pixels], scale=1, shear=[0, 0], **kwargs)
 
 
-def translate_x_abs(img, pixels, **kwargs):
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs)
+def translate_x_rel(img, pct, **kwargs):
+    pixels = pct * TF.get_image_size(img)[0]
+    return translate_x_abs(img, pixels, **kwargs)
 
 
-def translate_y_abs(img, pixels, **kwargs):
-    _check_args_tf(kwargs)
-    return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs)
+def translate_y_rel(img, pct, **kwargs):
+    pixels = pct * TF.get_image_size(img)[1]
+    return translate_y_abs(img, pixels, **kwargs)
 
 
 def rotate(img, degrees, **kwargs):
-    _check_args_tf(kwargs)
-    if _PIL_VER >= (5, 2):
-        return img.rotate(degrees, **kwargs)
-    if _PIL_VER >= (5, 0):
-        w, h = img.size
-        post_trans = (0, 0)
-        rotn_center = (w / 2.0, h / 2.0)
-        angle = -math.radians(degrees)
-        matrix = [
-            round(math.cos(angle), 15),
-            round(math.sin(angle), 15),
-            0.0,
-            round(-math.sin(angle), 15),
-            round(math.cos(angle), 15),
-            0.0,
-        ]
-
-        def transform(x, y, matrix):
-            (a, b, c, d, e, f) = matrix
-            return a * x + b * y + c, d * x + e * y + f
-
-        matrix[2], matrix[5] = transform(
-            -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix
-        )
-        matrix[2] += rotn_center[0]
-        matrix[5] += rotn_center[1]
-        return img.transform(img.size, Image.AFFINE, matrix, **kwargs)
-    return img.rotate(degrees, resample=kwargs['resample'])
+    _check_args_affine(img, kwargs)
+    return TF.rotate(img, degrees, **kwargs)
 
 
 def auto_contrast(img, **__):
-    return ImageOps.autocontrast(img)
+    return TF.autocontrast(img)
 
 
 def invert(img, **__):
-    return ImageOps.invert(img)
+    return TF.invert(img)
 
 
 def equalize(img, **__):
-    return ImageOps.equalize(img)
+    if isinstance(img, torch.Tensor) and img.is_floating_point():
+        if TF2 is None:
+            # FIXME warn / assert?
+            return img
+        return TF2.equalize(img)
+    return TF.equalize(img)
 
 
 def solarize(img, thresh, **__):
-    return ImageOps.solarize(img, thresh)
+    if isinstance(img, torch.Tensor) and img.is_floating_point():
+        thresh = min(thresh / 255, 1.0)
+    return TF.solarize(img, thresh)
 
 
 def solarize_add(img, add, thresh=128, **__):
-    lut = []
-    for i in range(256):
-        if i < thresh:
-            lut.append(min(255, i + add))
+    if isinstance(img, torch.Tensor):
+        if img.is_floating_point():
+            thresh = thresh / 255
+            add = add / 255
+            img_sum = (img + add).clamp_(max=1.0)
         else:
-            lut.append(i)
+            img_sum = (img + add).clamp_(max=255)
+        return torch.where(img >= thresh, img_sum, img)
+    else:
+        lut = []
+        for i in range(256):
+            if i < thresh:
+                lut.append(min(255, i + add))
+            else:
+                lut.append(i)
 
-    if img.mode in ("L", "RGB"):
-        if img.mode == "RGB" and len(lut) == 256:
-            lut = lut + lut + lut
-        return img.point(lut)
+        if img.mode in ("L", "RGB"):
+            if img.mode == "RGB" and len(lut) == 256:
+                lut = lut + lut + lut
+            return img.point(lut)
 
     return img
 
@@ -161,41 +160,50 @@ def solarize_add(img, add, thresh=128, **__):
 def posterize(img, bits_to_keep, **__):
     if bits_to_keep >= 8:
         return img
-    return ImageOps.posterize(img, bits_to_keep)
+    if isinstance(img, torch.Tensor) and img.is_floating_point():
+        if TF2 is None:
+            # FIXME warn / assert?
+            return img
+        return TF2.posterize(img, bits_to_keep)
+    return TF.posterize(img, bits_to_keep)
 
 
 def contrast(img, factor, **__):
-    return ImageEnhance.Contrast(img).enhance(factor)
+    return TF.adjust_contrast(img, factor)
 
 
 def color(img, factor, **__):
-    return ImageEnhance.Color(img).enhance(factor)
+    return TF.adjust_saturation(img, factor)
 
 
 def brightness(img, factor, **__):
-    return ImageEnhance.Brightness(img).enhance(factor)
+    return TF.adjust_brightness(img, factor)
 
 
 def sharpness(img, factor, **__):
-    return ImageEnhance.Sharpness(img).enhance(factor)
+    return TF.adjust_sharpness(img, factor)
 
 
 def gaussian_blur(img, factor, **__):
-    img = img.filter(ImageFilter.GaussianBlur(radius=factor))
+    if isinstance(img, torch.Tensor):
+        kernel_size = 2 * int(3 * factor) + 1  # could be bigger, but more expensive
+        img = TF.gaussian_blur(img, kernel_size=kernel_size, sigma=factor)
+    else:
+        img = img.filter(ImageFilter.GaussianBlur(radius=factor))
     return img
 
 
 def gaussian_blur_rand(img, factor, **__):
     radius_min = 0.1
     radius_max = 2.0
-    img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(radius_min, radius_max * factor)))
-    return img
+    radius = random.uniform(radius_min, radius_max * factor)
+    return gaussian_blur(img, radius)
 
 
 def desaturate(img, factor, **_):
     factor = min(1., max(0., 1. - factor))
     # enhance factor 0 = grayscale, 1.0 = no-change
-    return ImageEnhance.Color(img).enhance(factor)
+    return TF.adjust_saturation(img, factor)
 
 
 def _randomly_negate(v):
@@ -356,7 +364,13 @@ def _solarize_add_level_to_arg(level, _hparams):
 
 
 class AugmentOp:
 
-    def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
+    def __init__(
+            self,
+            name: str,
+            prob: float = 0.5,
+            magnitude: float = 10,
+            hparams: Optional[Dict[str, Any]] = None
+    ):
         hparams = hparams or _HPARAMS_DEFAULT
         self.name = name
         self.aug_fn = NAME_TO_OP[name]
@@ -365,8 +379,8 @@ def __init__(self, name, prob=0.5, magnitude=10, hparams=None):
         self.magnitude = magnitude
         self.hparams = hparams.copy()
         self.kwargs = dict(
-            fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
-            resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
+            fill=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
+            interpolation=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
         )
 
         # If magnitude_std is > 0, we introduce some randomness
@@ -564,7 +578,7 @@ def auto_augment_policy(name='v0', hparams=None):
 
 class AutoAugment:
 
-    def __init__(self, policy):
+    def __init__(self, policy: List):
         self.policy = policy
 
     def __call__(self, img):
@@ -729,8 +743,14 @@ def rand_augment_ops(
 ):
     hparams = hparams or _HPARAMS_DEFAULT
     transforms = transforms or _RAND_TRANSFORMS
-    return [AugmentOp(
-        name, prob=prob, magnitude=magnitude, hparams=hparams) for name in transforms]
+    return [
+        AugmentOp(
+            name,
+            prob=prob,
+            magnitude=magnitude,
+            hparams=hparams
+        ) for name in transforms
+    ]
 
 
 class RandAugment:

diff --git a/timm/data/loader.py b/timm/data/loader.py
index 3b4a6d0ed6..8253b2e0e6 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -87,7 +87,8 @@ def __init__(
             re_prob=0.,
             re_mode='const',
             re_count=1,
-            re_num_splits=0):
+            re_num_splits=0,
+    ):
         mean = adapt_to_chs(mean, channels)
         std = adapt_to_chs(std, channels)

diff --git a/timm/data/transforms.py b/timm/data/transforms.py
index 215b7b5b61..e0c7e7f907 100644
--- a/timm/data/transforms.py
+++ b/timm/data/transforms.py
@@ -12,7 +12,8 @@
     has_interpolation_mode = True
 except ImportError:
     has_interpolation_mode = False
-from PIL import Image
+from PIL import Image, ImageCms
+
 import numpy as np
 
 __all__ = [
@@ -89,6 +90,31 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
 
 
+class ToLab(transforms.ToTensor):
+
+    def __init__(self) -> None:
+        super().__init__()
+        rgb_profile = ImageCms.createProfile(colorSpace='sRGB')
+        lab_profile = ImageCms.createProfile(colorSpace='LAB')
+        # Create a transform object from the input and output profiles
+        self.rgb_to_lab_transform = ImageCms.buildTransform(
+            inputProfile=rgb_profile,
+            outputProfile=lab_profile,
+            inMode='RGB',
+            outMode='LAB'
+        )
+
+    def __call__(self, pic) -> Image.Image:
+        lab_image = ImageCms.applyTransform(
+            im=pic,
+            transform=self.rgb_to_lab_transform
+        )
+        return lab_image
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
+
 # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
 # favor of the Image.Resampling enum. The top-level resampling attributes will be
 # removed in Pillow 10.

diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py
index 9be0e3bf3c..5653109f78 100644
--- a/timm/data/transforms_factory.py
+++ b/timm/data/transforms_factory.py
@@ -4,6 +4,7 @@
 Hacked together by / Copyright 2019, Ross Wightman
 """
 import math
+from copy import deepcopy
 from typing import Optional, Tuple, Union
 
 import torch
@@ -84,6 +85,7 @@ def transforms_imagenet_train(
         use_prefetcher: bool = False,
         normalize: bool = True,
         separate: bool = False,
+        use_tensor: Optional[bool] = True,  # FIXME forced True for testing
 ):
     """ ImageNet-oriented image transforms for training.
 
@@ -111,6 +113,7 @@ def transforms_imagenet_train(
         use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
         normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
         separate: Output transforms in 3-stage tuple.
+        use_tensor: Use float [0, 1.0] tensors for image transforms.
 
     Returns:
         If separate==True, the transforms are returned as a tuple of 3 separate transforms
@@ -119,13 +122,18 @@ def transforms_imagenet_train(
         * a portion of the data through the secondary transform
         * normalizes and converts the branches above with the third, final transform
     """
+    if use_tensor:
+        primary_tfl = [MaybeToTensor()]
+    else:
+        primary_tfl = []
+
     train_crop_mode = train_crop_mode or 'rrc'
     assert train_crop_mode in {'rrc', 'rkrc', 'rkrr'}
     if train_crop_mode in ('rkrc', 'rkrr'):
         # FIXME integration of RKR is a WIP
         scale = tuple(scale or (0.8, 1.00))
         ratio = tuple(ratio or (0.9, 1/.9))
-        primary_tfl = [
+        primary_tfl += [
             ResizeKeepRatio(
                 img_size,
                 interpolation=interpolation,
@@ -142,7 +150,7 @@ def transforms_imagenet_train(
     else:
         scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
         ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
-        primary_tfl = [
+        primary_tfl += [
             RandomResizedCropAndInterpolation(
                 img_size,
                 scale=scale,
@@ -166,9 +174,13 @@ def transforms_imagenet_train(
             img_size_min = min(img_size)
         else:
             img_size_min = img_size
+        if use_tensor:
+            aa_mean = deepcopy(mean)
+        else:
+            aa_mean = tuple([min(255, round(255 * x)) for x in mean])
         aa_params = dict(
             translate_const=int(img_size_min * 0.45),
-            img_mean=tuple([min(255, round(255 * x)) for x in mean]),
+            img_mean=aa_mean,
         )
         if interpolation and interpolation != 'random':
             aa_params['interpolation'] = str_to_pil_interp(interpolation)
@@ -218,10 +230,12 @@ def transforms_imagenet_train(
             final_tfl += [ToNumpy()]
     elif not normalize:
         # when normalize disabled, converted to tensor without scaling, keeps original dtype
-        final_tfl += [MaybePILToTensor()]
+        if not use_tensor:
+            final_tfl += [MaybePILToTensor()]
     else:
+        if not use_tensor:
+            final_tfl += [MaybeToTensor()]
         final_tfl += [
-            MaybeToTensor(),
             transforms.Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std),

From 3fbbd511e64c979f555899304e0375bf02c97fd3 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 18 Dec 2024 16:49:17 -0800
Subject: [PATCH 2/3] Testing some LAB stuff

---
 timm/data/transforms.py         | 117 +++++++++++++++++++++++++++++++-
 timm/data/transforms_factory.py |  23 +++++-
 2 files changed, 135 insertions(+), 5 deletions(-)

diff --git a/timm/data/transforms.py b/timm/data/transforms.py
index e0c7e7f907..f318b9a289 100644
--- a/timm/data/transforms.py
+++ b/timm/data/transforms.py
@@ -90,7 +90,7 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
 
 
-class ToLab(transforms.ToTensor):
+class ToLabPIL:
 
     def __init__(self) -> None:
         super().__init__()
@@ -115,6 +115,121 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
 
 
+def srgb_to_linear(srgb_image: torch.Tensor) -> torch.Tensor:
+    return torch.where(
+        srgb_image <= 0.04045,
+        srgb_image / 12.92,
+        ((srgb_image + 0.055) / 1.055) ** 2.4
+    )
+
+
+def rgb_to_lab_tensor(
+        rgb_img: torch.Tensor,
+        normalized: bool = True,
+        srgb_input: bool = True,
+) -> torch.Tensor:
+    """
+    Convert RGB image to LAB color space using tensor operations.
+
+    Args:
+        rgb_img: Tensor of shape (..., 3) with values in range [0, 1]
+        normalized: If True, outputs L,a,b in [0, 1] range instead of native LAB ranges
+
+    Returns:
+        lab_img: Tensor of same shape with either:
+            - normalized=False: L in [0, 100] and a,b in [-128, 127]
+            - normalized=True: L,a,b in [0, 1]
+    """
+    # Constants
+    epsilon = 216 / 24389
+    kappa = 24389 / 27
+    xn = 0.95047
+    yn = 1.0
+    zn = 1.08883
+
+    # Convert sRGB to linear RGB
+    if srgb_input:
+        rgb_img = srgb_to_linear(rgb_img)
+
+    # FIXME transforms before this are causing -ve values, can have a large impact on this conversion
+    rgb_img.clamp_(0, 1.0)
+
+    # Convert to XYZ using matrix multiplication
+    rgb_to_xyz = torch.tensor([
+        [0.412453, 0.357580, 0.180423],
+        [0.212671, 0.715160, 0.072169],
+        [0.019334, 0.119193, 0.950227]
+    ], device=rgb_img.device)
+
+    # Reshape input for matrix multiplication if needed
+    original_shape = rgb_img.shape
+    if len(original_shape) > 2:
+        rgb_img = rgb_img.reshape(-1, 3)
+
+    # Perform matrix multiplication
+    xyz = torch.matmul(rgb_img, rgb_to_xyz.T)
+
+    # Adjust XYZ values
+    xyz[..., 0].div_(xn)
+    xyz[..., 1].div_(yn)
+    xyz[..., 2].div_(zn)
+
+    # XYZ to LAB
+    lab = torch.where(
+        xyz > epsilon,
+        torch.pow(xyz, 1 / 3),
+        (kappa * xyz + 16) / 116
+    )
+
+    if normalized:
+        # Calculate normalized [0,1] L,a,b values directly
+        # L: map [0,100] to [0,1] : (116y - 16)/100 = 1.16y - 0.16
+        # a: map [-128,127] to [0,1] : (500(x-y) + 128)/255 ≈ 1.96(x-y) + 0.502
+        # b: map [-128,127] to [0,1] : (200(y-z) + 128)/255 ≈ 0.784(y-z) + 0.502
+        shift_128 = 128 / 255
+        a_scale = 500 / 255
+        b_scale = 200 / 255
+        L = 1.16 * lab[..., 1] - 0.16
+        a = a_scale * (lab[..., 0] - lab[..., 1]) + shift_128
+        b = b_scale * (lab[..., 1] - lab[..., 2]) + shift_128
+    else:
+        # Calculate native range L,a,b values
+        L = 116 * lab[..., 1] - 16
+        a = 500 * (lab[..., 0] - lab[..., 1])
+        b = 200 * (lab[..., 1] - lab[..., 2])
+
+    # Stack the results
+    lab = torch.stack([L, a, b], dim=-1)
+
+    # Restore original shape if needed
+    if len(original_shape) > 2:
+        lab = lab.reshape(original_shape)
+
+    return lab
+
+
+class ToLabTensor:
+    def __init__(self, srgb_input=False, normalized=True) -> None:
+        self.srgb_input = srgb_input
+        self.normalized = normalized
+
+    def __call__(self, pic) -> torch.Tensor:
+        return rgb_to_lab_tensor(
+            pic,
+            normalized=self.normalized,
+            srgb_input=self.srgb_input,
+        )
+
+
+class ToLinearRgb:
+    def __init__(self):
+        pass
+
+    def __call__(self, pic) -> torch.Tensor:
+        assert isinstance(pic, torch.Tensor)
+        return srgb_to_linear(pic)
+
+
 # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
 # favor of the Image.Resampling enum. The top-level resampling attributes will be
 # removed in Pillow 10.
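A quick sanity check for the rgb_to_lab_tensor() conversion added above (a minimal sketch, assuming this series is applied; white and 50% grey are standard CIELAB reference points, exact outputs may differ in the last decimals):

    import torch
    from timm.data.transforms import rgb_to_lab_tensor

    rgb = torch.tensor([
        [1.0, 1.0, 1.0],   # sRGB white    -> expect L* = 100, a* ~ 0, b* ~ 0
        [0.5, 0.5, 0.5],   # sRGB 50% grey -> expect L* ~ 53.4, a* ~ 0, b* ~ 0
    ])
    lab = rgb_to_lab_tensor(rgb, normalized=False, srgb_input=True)
    print(lab)  # approximately [[100.0, 0.0, 0.0], [53.4, 0.0, 0.0]]
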
diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py
index 5653109f78..a363a4bbe7 100644
--- a/timm/data/transforms_factory.py
+++ b/timm/data/transforms_factory.py
@@ -14,6 +14,7 @@
 from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
 from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
     ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor
+from timm.data.transforms import ToLabTensor, ToLinearRgb
 from timm.data.random_erasing import RandomErasing
 
 
@@ -123,7 +124,10 @@ def transforms_imagenet_train(
         * normalizes and converts the branches above with the third, final transform
     """
     if use_tensor:
-        primary_tfl = [MaybeToTensor()]
+        primary_tfl = [
+            MaybeToTensor(),
+            ToLinearRgb(),  # FIXME
+        ]
     else:
         primary_tfl = []
 
@@ -236,6 +240,7 @@ def transforms_imagenet_train(
         if not use_tensor:
             final_tfl += [MaybeToTensor()]
         final_tfl += [
+            ToLabTensor(),  # FIXME
             transforms.Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std),
@@ -268,6 +273,7 @@ def transforms_imagenet_eval(
         std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
         use_prefetcher: bool = False,
         normalize: bool = True,
+        use_tensor: bool = True,
 ):
     """ ImageNet-oriented image transform for evaluation and inference.
 
@@ -294,7 +300,13 @@ def transforms_imagenet_eval(
         scale_size = math.floor(img_size / crop_pct)
         scale_size = (scale_size, scale_size)
 
-    tfl = []
+    if use_tensor:
+        tfl = [
+            MaybeToTensor(),
+            ToLinearRgb(),  # FIXME
+        ]
+    else:
+        tfl = []
 
     if crop_border_pixels:
         tfl += [TrimBorder(crop_border_pixels)]
@@ -332,10 +344,13 @@ def transforms_imagenet_eval(
         tfl += [ToNumpy()]
     elif not normalize:
         # when normalize disabled, converted to tensor without scaling, keeps original dtype
-        tfl += [MaybePILToTensor()]
+        if not use_tensor:
+            tfl += [MaybePILToTensor()]
     else:
+        if not use_tensor:
+            tfl += [MaybeToTensor()]
         tfl += [
-            MaybeToTensor(),
+            ToLabTensor(),  # FIXME
             transforms.Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std),

From d285526dc9d083b28329802883b4fc966345982a Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 23 Dec 2024 13:24:11 -0800
Subject: [PATCH 3/3] Lazy loader for TF, more LAB fiddling

---
 timm/data/readers/reader_tfds.py | 55 +++++++++++++++++++-------------
 timm/data/transforms.py          | 49 +++++++++++++---------------
 timm/data/transforms_factory.py  |  4 +--
 3 files changed, 57 insertions(+), 51 deletions(-)

diff --git a/timm/data/readers/reader_tfds.py b/timm/data/readers/reader_tfds.py
index a33bd5059a..c224a96787 100644
--- a/timm/data/readers/reader_tfds.py
+++ b/timm/data/readers/reader_tfds.py
@@ -15,25 +15,36 @@
 import torch.distributed as dist
 from PIL import Image
 
-try:
-    import tensorflow as tf
-    tf.config.set_visible_devices([], 'GPU')  # Hands off my GPU! (or pip install tensorflow-cpu)
-    import tensorflow_datasets as tfds
-    try:
-        tfds.even_splits('', 1, drop_remainder=False)  # non-buggy even_splits has drop_remainder arg
-        has_buggy_even_splits = False
-    except TypeError:
-        print("Warning: This version of tfds doesn't have the latest even_splits impl. "
-              "Please update or use tfds-nightly for better fine-grained split behaviour.")
-        has_buggy_even_splits = True
-    # NOTE uncomment below if having file limit issues on dataset build (or alter your OS defaults)
-    # import resource
-    # low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
-    # resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
-except ImportError as e:
-    print(e)
-    print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
-    raise e
+import importlib
+
+class LazyTfLoader:
+    def __init__(self):
+        self._tf = None
+
+    def __getattr__(self, name):
+        if self._tf is None:
+            self._tf = importlib.import_module('tensorflow')
+            self._tf.config.set_visible_devices([], 'GPU')  # Hands off my GPU! (or pip install tensorflow-cpu)
+        return getattr(self._tf, name)
+
+class LazyTfdsLoader:
+    def __init__(self):
+        self._tfds = None
+        self.has_buggy_even_splits = False
+
+    def __getattr__(self, name):
+        if self._tfds is None:
+            self._tfds = importlib.import_module('tensorflow_datasets')
+            try:
+                self._tfds.even_splits('', 1, drop_remainder=False)  # non-buggy even_splits has drop_remainder arg
+            except TypeError:
+                print("Warning: This version of tfds doesn't have the latest even_splits impl. "
+                      "Please update or use tfds-nightly for better fine-grained split behaviour.")
+                self.has_buggy_even_splits = True
+        return getattr(self._tfds, name)
+
+tf = LazyTfLoader()
+tfds = LazyTfdsLoader()
 
 from .class_map import load_class_map
 from .reader import Reader
@@ -45,7 +56,6 @@
 PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048))  # samples to prefetch
 
 
-@tfds.decode.make_decoder()
 def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE', channels=3):
     return tf.image.decode_jpeg(
         serialized_image,
@@ -231,7 +241,7 @@ def _lazy_init(self):
             if should_subsplit:
                 # split the dataset w/o using sharding for more even samples / worker, can result in less optimal
                 # read patterns for distributed training (overlap across shards) so better to use InputContext there
-                if has_buggy_even_splits:
+                if tfds.has_buggy_even_splits:
                     # my even_split workaround doesn't work on subsplits, upgrade tfds!
                     if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
                         subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
@@ -253,10 +263,11 @@ def _lazy_init(self):
             shuffle_reshuffle_each_iteration=True,
             input_context=input_context,
         )
+        decode_fn = tfds.decode.make_decoder()(decode_example)
         ds = self.builder.as_dataset(
             split=self.subsplit or self.split,
             shuffle_files=self.is_training,
-            decoders=dict(image=decode_example(channels=1 if self.input_img_mode == 'L' else 3)),
+            decoders=dict(image=decode_fn(channels=1 if self.input_img_mode == 'L' else 3)),
             read_config=read_config,
         )
         # avoid overloading threading w/ combo of TF ds threads + PyTorch workers

diff --git a/timm/data/transforms.py b/timm/data/transforms.py
index f318b9a289..f82a8cf215 100644
--- a/timm/data/transforms.py
+++ b/timm/data/transforms.py
@@ -127,14 +127,16 @@
         rgb_img: torch.Tensor,
         normalized: bool = True,
         srgb_input: bool = True,
-) -> torch.Tensor:
+        split_channels: bool = False,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
     """
     Convert RGB image to LAB color space using tensor operations.
 
     Args:
         rgb_img: Tensor of shape (..., 3) with values in range [0, 1]
         normalized: If True, outputs L,a,b in [0, 1] range instead of native LAB ranges
-
+        srgb_input: If True, input is gamma-corrected sRGB, otherwise linear RGB is assumed (rare unless part of a pipeline)
+        split_channels: If True, outputs a tuple of flattened colour channels instead of a stacked image
     Returns:
         lab_img: Tensor of same shape with either:
             - normalized=False: L in [0, 100] and a,b in [-128, 127]
             - normalized=True: L,a,b in [0, 1]
@@ -152,13 +154,14 @@ def rgb_to_lab_tensor(
         rgb_img = srgb_to_linear(rgb_img)
 
     # FIXME transforms before this are causing -ve values, can have a large impact on this conversion
-    rgb_img.clamp_(0, 1.0)
+    rgb_img = rgb_img.clamp(0, 1.0)
 
     # Convert to XYZ using matrix multiplication
     rgb_to_xyz = torch.tensor([
-        [0.412453, 0.357580, 0.180423],
-        [0.212671, 0.715160, 0.072169],
-        [0.019334, 0.119193, 0.950227]
+        # X         Y         Z
+        [0.412453, 0.212671, 0.019334],  # R
+        [0.357580, 0.715160, 0.119193],  # G
+        [0.180423, 0.072169, 0.950227],  # B
     ], device=rgb_img.device)
 
     # Reshape input for matrix multiplication if needed
@@ -167,38 +170,30 @@ def rgb_to_lab_tensor(
         rgb_img = rgb_img.reshape(-1, 3)
 
     # Perform matrix multiplication
-    xyz = torch.matmul(rgb_img, rgb_to_xyz.T)
+    xyz = rgb_img @ rgb_to_xyz
 
     # Adjust XYZ values
-    xyz[..., 0].div_(xn)
-    xyz[..., 1].div_(yn)
-    xyz[..., 2].div_(zn)
+    xyz.div_(torch.tensor([xn, yn, zn], device=xyz.device))
 
     # XYZ to LAB
-    lab = torch.where(
+    fxfyfz = torch.where(
         xyz > epsilon,
         torch.pow(xyz, 1 / 3),
         (kappa * xyz + 16) / 116
     )
+    L = 116 * fxfyfz[..., 1] - 16
+    a = 500 * (fxfyfz[..., 0] - fxfyfz[..., 1])
+    b = 200 * (fxfyfz[..., 1] - fxfyfz[..., 2])
 
     if normalized:
-        # Calculate normalized [0,1] L,a,b values directly
-        # L: map [0,100] to [0,1] : (116y - 16)/100 = 1.16y - 0.16
-        # a: map [-128,127] to [0,1] : (500(x-y) + 128)/255 ≈ 1.96(x-y) + 0.502
-        # b: map [-128,127] to [0,1] : (200(y-z) + 128)/255 ≈ 0.784(y-z) + 0.502
-        shift_128 = 128 / 255
-        a_scale = 500 / 255
-        b_scale = 200 / 255
-        L = 1.16 * lab[..., 1] - 0.16
-        a = a_scale * (lab[..., 0] - lab[..., 1]) + shift_128
-        b = b_scale * (lab[..., 1] - lab[..., 2]) + shift_128
-    else:
-        # Calculate native range L,a,b values
-        L = 116 * lab[..., 1] - 16
-        a = 500 * (lab[..., 0] - lab[..., 1])
-        b = 200 * (lab[..., 1] - lab[..., 2])
+        # output in range [0, 1] for each channel
+        L.div_(100)
+        a.add_(128).div_(255)
+        b.add_(128).div_(255)
+
+    if split_channels:
+        return L, a, b
 
-    # Stack the results
     lab = torch.stack([L, a, b], dim=-1)
 
     # Restore original shape if needed

diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py
index a363a4bbe7..2f387ca5d9 100644
--- a/timm/data/transforms_factory.py
+++ b/timm/data/transforms_factory.py
@@ -86,7 +86,7 @@ def transforms_imagenet_train(
         use_prefetcher: bool = False,
         normalize: bool = True,
         separate: bool = False,
-        use_tensor: Optional[bool] = True,  # FIXME forced True for testing
+        use_tensor: Optional[bool] = False,
 ):
     """ ImageNet-oriented image transforms for training.
 
@@ -273,7 +273,7 @@ def transforms_imagenet_eval(
         std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
         use_prefetcher: bool = False,
         normalize: bool = True,
-        use_tensor: bool = True,
+        use_tensor: bool = False,
 ):
     """ ImageNet-oriented image transform for evaluation and inference.
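
With the series applied, the tensor path of RandAugment can be exercised directly (a minimal sketch: the translate_const and img_mean values below are illustrative stand-ins for what transforms_imagenet_train computes when use_tensor is enabled, and the float equalize/posterize ops additionally require a torchvision build that provides the transforms.v2 functional API):

    import torch
    from timm.data.auto_augment import rand_augment_transform

    tfm = rand_augment_transform(
        'rand-m9-mstd0.5',
        hparams=dict(
            translate_const=100,             # factory uses int(0.45 * img_size)
            img_mean=(0.485, 0.456, 0.406),  # float fill values for tensor mode
        ),
    )
    img = torch.rand(3, 224, 224)  # stand-in for a decoded float image in [0, 1]
    out = tfm(img)
    print(out.shape, out.dtype)  # torch.Size([3, 224, 224]) torch.float32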