From d285526dc9d083b28329802883b4fc966345982a Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 23 Dec 2024 13:24:11 -0800 Subject: [PATCH] Lazy loader for TF, more LAB fiddling --- timm/data/readers/reader_tfds.py | 55 +++++++++++++++++++------------- timm/data/transforms.py | 49 +++++++++++++--------------- timm/data/transforms_factory.py | 4 +-- 3 files changed, 57 insertions(+), 51 deletions(-) diff --git a/timm/data/readers/reader_tfds.py b/timm/data/readers/reader_tfds.py index a33bd5059a..c224a96787 100644 --- a/timm/data/readers/reader_tfds.py +++ b/timm/data/readers/reader_tfds.py @@ -15,25 +15,36 @@ import torch.distributed as dist from PIL import Image -try: - import tensorflow as tf - tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu) - import tensorflow_datasets as tfds - try: - tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg - has_buggy_even_splits = False - except TypeError: - print("Warning: This version of tfds doesn't have the latest even_splits impl. " - "Please update or use tfds-nightly for better fine-grained split behaviour.") - has_buggy_even_splits = True - # NOTE uncomment below if having file limit issues on dataset build (or alter your OS defaults) - # import resource - # low, high = resource.getrlimit(resource.RLIMIT_NOFILE) - # resource.setrlimit(resource.RLIMIT_NOFILE, (high, high)) -except ImportError as e: - print(e) - print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.") - raise e +import importlib + +class LazyTfLoader: + def __init__(self): + self._tf = None + + def __getattr__(self, name): + if self._tf is None: + self._tf = importlib.import_module('tensorflow') + self._tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu) + return getattr(self._tf, name) + +class LazyTfdsLoader: + def __init__(self): + self._tfds = None + self.has_buggy_even_splits = False + + def __getattr__(self, name): + if self._tfds is None: + self._tfds = importlib.import_module('tensorflow_datasets') + try: + self._tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg + except TypeError: + print("Warning: This version of tfds doesn't have the latest even_splits impl. " + "Please update or use tfds-nightly for better fine-grained split behaviour.") + self.has_buggy_even_splits = True + return getattr(self._tfds, name) + +tf = LazyTfLoader() +tfds = LazyTfdsLoader() from .class_map import load_class_map from .reader import Reader @@ -45,7 +56,6 @@ PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048)) # samples to prefetch -@tfds.decode.make_decoder() def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE', channels=3): return tf.image.decode_jpeg( serialized_image, @@ -231,7 +241,7 @@ def _lazy_init(self): if should_subsplit: # split the dataset w/o using sharding for more even samples / worker, can result in less optimal # read patterns for distributed training (overlap across shards) so better to use InputContext there - if has_buggy_even_splits: + if tfds.has_buggy_even_splits: # my even_split workaround doesn't work on subsplits, upgrade tfds! if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo): subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples) @@ -253,10 +263,11 @@ def _lazy_init(self): shuffle_reshuffle_each_iteration=True, input_context=input_context, ) + decode_fn = tfds.decode.make_decoder()(decode_example) ds = self.builder.as_dataset( split=self.subsplit or self.split, shuffle_files=self.is_training, - decoders=dict(image=decode_example(channels=1 if self.input_img_mode == 'L' else 3)), + decoders=dict(image=decode_fn(channels=1 if self.input_img_mode == 'L' else 3)), read_config=read_config, ) # avoid overloading threading w/ combo of TF ds threads + PyTorch workers diff --git a/timm/data/transforms.py b/timm/data/transforms.py index f318b9a289..f82a8cf215 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -127,14 +127,16 @@ def rgb_to_lab_tensor( rgb_img: torch.Tensor, normalized: bool = True, srgb_input: bool = True, -) -> torch.Tensor: + split_channels: bool = False, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """ Convert RGB image to LAB color space using tensor operations. Args: rgb_img: Tensor of shape (..., 3) with values in range [0, 255] normalized: If True, outputs L,a,b in [0, 1] range instead of native LAB ranges - + srgb_input: Input is gamma corrected sRGB, otherwise linear RGB is assumed (rare unless part of a pipeline) + split_channels: If True, outputs a tuple of flattened colour channels instead of stacked image Returns: lab_img: Tensor of same shape with either: - normalized=False: L in [0, 100] and a,b in [-128, 127] @@ -152,13 +154,14 @@ def rgb_to_lab_tensor( rgb_img = srgb_to_linear(rgb_img) # FIXME transforms before this are causing -ve values, can have a large impact on this conversion - rgb_img.clamp_(0, 1.0) + rgb_img = rgb_img.clamp(0, 1.0) # Convert to XYZ using matrix multiplication rgb_to_xyz = torch.tensor([ - [0.412453, 0.357580, 0.180423], - [0.212671, 0.715160, 0.072169], - [0.019334, 0.119193, 0.950227] + # X Y Z + [0.412453, 0.212671, 0.019334], # R + [0.357580, 0.715160, 0.119193], # G + [0.180423, 0.072169, 0.950227], # B ], device=rgb_img.device) # Reshape input for matrix multiplication if needed @@ -167,38 +170,30 @@ def rgb_to_lab_tensor( rgb_img = rgb_img.reshape(-1, 3) # Perform matrix multiplication - xyz = torch.matmul(rgb_img, rgb_to_xyz.T) + xyz = rgb_img @ rgb_to_xyz # Adjust XYZ values - xyz[..., 0].div_(xn) - xyz[..., 1].div_(yn) - xyz[..., 2].div_(zn) + xyz.div_(torch.tensor([xn, yn, zn], device=xyz.device)) # Step 4: XYZ to LAB - lab = torch.where( + fxfyfz = torch.where( xyz > epsilon, torch.pow(xyz, 1 / 3), (kappa * xyz + 16) / 116 ) + L = 116 * fxfyfz[..., 1] - 16 + a = 500 * (fxfyfz[..., 0] - fxfyfz[..., 1]) + b = 200 * (fxfyfz[..., 1] - fxfyfz[..., 2]) if normalized: - # Calculate normalized [0,1] L,a,b values directly - # L: map [0,100] to [0,1] : (116y - 16)/100 = 1.16y - 0.16 - # a: map [-128,127] to [0,1] : (500(x-y) + 128)/255 ≈ 1.96(x-y) + 0.502 - # b: map [-128,127] to [0,1] : (200(y-z) + 128)/255 ≈ 0.784(y-z) + 0.502 - shift_128 = 128 / 255 - a_scale = 500 / 255 - b_scale = 200 / 255 - L = 1.16 * lab[..., 1] - 0.16 - a = a_scale * (lab[..., 0] - lab[..., 1]) + shift_128 - b = b_scale * (lab[..., 1] - lab[..., 2]) + shift_128 - else: - # Calculate native range L,a,b values - L = 116 * lab[..., 1] - 16 - a = 500 * (lab[..., 0] - lab[..., 1]) - b = 200 * (lab[..., 1] - lab[..., 2]) + # output in rage [0, 1] for each channel + L.div_(100) + a.add_(128).div_(255) + b.add_(128).div_(255) + + if split_channels: + return L, a, b - # Stack the results lab = torch.stack([L, a, b], dim=-1) # Restore original shape if needed diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index a363a4bbe7..2f387ca5d9 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -86,7 +86,7 @@ def transforms_imagenet_train( use_prefetcher: bool = False, normalize: bool = True, separate: bool = False, - use_tensor: Optional[bool] = True, # FIXME forced True for testing + use_tensor: Optional[bool] = False, ): """ ImageNet-oriented image transforms for training. @@ -273,7 +273,7 @@ def transforms_imagenet_eval( std: Tuple[float, ...] = IMAGENET_DEFAULT_STD, use_prefetcher: bool = False, normalize: bool = True, - use_tensor: bool = True, + use_tensor: bool = False, ): """ ImageNet-oriented image transform for evaluation and inference.