Lazy loader for TF, more LAB fiddling
rwightman committed Dec 23, 2024
1 parent 3fbbd51 commit d285526
Showing 3 changed files with 57 additions and 51 deletions.
55 changes: 33 additions & 22 deletions timm/data/readers/reader_tfds.py
@@ -15,25 +15,36 @@
import torch.distributed as dist
from PIL import Image

-try:
-    import tensorflow as tf
-    tf.config.set_visible_devices([], 'GPU')  # Hands off my GPU! (or pip install tensorflow-cpu)
-    import tensorflow_datasets as tfds
-    try:
-        tfds.even_splits('', 1, drop_remainder=False)  # non-buggy even_splits has drop_remainder arg
-        has_buggy_even_splits = False
-    except TypeError:
-        print("Warning: This version of tfds doesn't have the latest even_splits impl. "
-              "Please update or use tfds-nightly for better fine-grained split behaviour.")
-        has_buggy_even_splits = True
-    # NOTE uncomment below if having file limit issues on dataset build (or alter your OS defaults)
-    # import resource
-    # low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
-    # resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
-except ImportError as e:
-    print(e)
-    print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
-    raise e
+import importlib
+
+class LazyTfLoader:
+    def __init__(self):
+        self._tf = None
+
+    def __getattr__(self, name):
+        if self._tf is None:
+            self._tf = importlib.import_module('tensorflow')
+            self._tf.config.set_visible_devices([], 'GPU')  # Hands off my GPU! (or pip install tensorflow-cpu)
+        return getattr(self._tf, name)
+
+class LazyTfdsLoader:
+    def __init__(self):
+        self._tfds = None
+        self.has_buggy_even_splits = False
+
+    def __getattr__(self, name):
+        if self._tfds is None:
+            self._tfds = importlib.import_module('tensorflow_datasets')
+            try:
+                self._tfds.even_splits('', 1, drop_remainder=False)  # non-buggy even_splits has drop_remainder arg
+            except TypeError:
+                print("Warning: This version of tfds doesn't have the latest even_splits impl. "
+                      "Please update or use tfds-nightly for better fine-grained split behaviour.")
+                self.has_buggy_even_splits = True
+        return getattr(self._tfds, name)
+
+tf = LazyTfLoader()
+tfds = LazyTfdsLoader()

from .class_map import load_class_map
from .reader import Reader
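
The lazy-loader classes above work because `__getattr__` only fires for attributes not found on the instance, so the heavy import runs once, on first real use. A minimal sketch of the same pattern with a generic stand-in (`LazyModule` and the `numpy` example are illustrative, not part of this change):

import importlib

class LazyModule:
    # generic form of LazyTfLoader: defer a heavy import until first attribute access
    def __init__(self, module_name):
        self._module_name = module_name
        self._module = None

    def __getattr__(self, name):
        if self._module is None:
            self._module = importlib.import_module(self._module_name)
        return getattr(self._module, name)

np = LazyModule('numpy')  # nothing imported yet, module load cost deferred
print(np.pi)              # first attribute access triggers the real import
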
@@ -45,7 +45,6 @@
PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048)) # samples to prefetch


-@tfds.decode.make_decoder()
def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE', channels=3):
    return tf.image.decode_jpeg(
        serialized_image,
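
Dropping the module-level `@tfds.decode.make_decoder()` decorator is what makes the lazy `tfds` object viable: evaluating the decorator at import time would call into `tensorflow_datasets` immediately, defeating the deferral. A sketch of the difference, using a hypothetical stand-in for `make_decoder` (the real one wraps the function into a TFDS decoder object):

def make_decoder():  # stand-in for tfds.decode.make_decoder, illustrative only
    def wrap(fn):
        return fn
    return wrap

# eager: the decorator call runs as soon as the module is imported
# @make_decoder()
# def decode_example(...): ...

# deferred: decode_example stays a plain function at import time and is
# wrapped only when the dataset is actually built (see _lazy_init below)
def decode_example(serialized_image):
    return serialized_image

decode_fn = make_decoder()(decode_example)
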
@@ -231,7 +241,7 @@ def _lazy_init(self):
            if should_subsplit:
                # split the dataset w/o using sharding for more even samples / worker, can result in less optimal
                # read patterns for distributed training (overlap across shards) so better to use InputContext there
-                if has_buggy_even_splits:
+                if tfds.has_buggy_even_splits:
                    # my even_split workaround doesn't work on subsplits, upgrade tfds!
                    if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
                        subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
@@ -253,10 +263,11 @@ def _lazy_init(self):
            shuffle_reshuffle_each_iteration=True,
            input_context=input_context,
        )
+        decode_fn = tfds.decode.make_decoder()(decode_example)
        ds = self.builder.as_dataset(
            split=self.subsplit or self.split,
            shuffle_files=self.is_training,
-            decoders=dict(image=decode_example(channels=1 if self.input_img_mode == 'L' else 3)),
+            decoders=dict(image=decode_fn(channels=1 if self.input_img_mode == 'L' else 3)),
            read_config=read_config,
        )
        # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
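One subtlety of the new flag worth noting: `has_buggy_even_splits` is a plain instance attribute, so reading it never routes through `__getattr__` and never triggers the import; it is only meaningful after some other `tfds` attribute has been touched. In `_lazy_init` this appears to hold because the builder is created before the subsplit logic runs, but a sketch of the edge case (assuming the loader classes above):

tfds = LazyTfdsLoader()
print(tfds.has_buggy_even_splits)  # always False here: module not imported yet
_ = tfds.core                      # any real attribute forces the import + version probe
print(tfds.has_buggy_even_splits)  # now reflects the installed tfds version
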
49 changes: 22 additions & 27 deletions timm/data/transforms.py
@@ -127,14 +127,16 @@ def rgb_to_lab_tensor(
        rgb_img: torch.Tensor,
        normalized: bool = True,
        srgb_input: bool = True,
-) -> torch.Tensor:
+        split_channels: bool = False,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """
    Convert RGB image to LAB color space using tensor operations.

    Args:
        rgb_img: Tensor of shape (..., 3) with values in range [0, 255]
        normalized: If True, outputs L,a,b in [0, 1] range instead of native LAB ranges
        srgb_input: Input is gamma corrected sRGB, otherwise linear RGB is assumed (rare unless part of a pipeline)
+        split_channels: If True, outputs a tuple of flattened colour channels instead of stacked image

    Returns:
        lab_img: Tensor of same shape with either:
            - normalized=False: L in [0, 100] and a,b in [-128, 127]
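
A usage sketch of the widened signature (shapes and values are illustrative):

import torch

img = torch.randint(0, 256, (224, 224, 3)).float()    # (..., 3) in [0, 255]

lab = rgb_to_lab_tensor(img)                           # stacked (224, 224, 3), channels in [0, 1]
L, a, b = rgb_to_lab_tensor(img, split_channels=True)  # tuple of flattened channel tensors
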
@@ -152,13 +154,14 @@ def rgb_to_lab_tensor(
        rgb_img = srgb_to_linear(rgb_img)

    # FIXME transforms before this are causing -ve values, can have a large impact on this conversion
-    rgb_img.clamp_(0, 1.0)
+    rgb_img = rgb_img.clamp(0, 1.0)

    # Convert to XYZ using matrix multiplication
    rgb_to_xyz = torch.tensor([
-        [0.412453, 0.357580, 0.180423],
-        [0.212671, 0.715160, 0.072169],
-        [0.019334, 0.119193, 0.950227]
+        #     X         Y         Z
+        [0.412453, 0.212671, 0.019334],  # R
+        [0.357580, 0.715160, 0.119193],  # G
+        [0.180423, 0.072169, 0.950227],  # B
    ], device=rgb_img.device)

    # Reshape input for matrix multiplication if needed
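
The new constant is simply the old row-major RGB→XYZ matrix stored transposed (rows now index R, G, B contributions), which lets the `.T` be dropped from the matmul below. A quick equivalence check of the two layouts:

import torch

m_rows = torch.tensor([  # old layout: rows are X, Y, Z
    [0.412453, 0.357580, 0.180423],
    [0.212671, 0.715160, 0.072169],
    [0.019334, 0.119193, 0.950227],
])
m_cols = m_rows.T        # new layout: rows are R, G, B

rgb = torch.rand(8, 3)
assert torch.allclose(torch.matmul(rgb, m_rows.T), rgb @ m_cols)
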
@@ -167,38 +170,30 @@
        rgb_img = rgb_img.reshape(-1, 3)

    # Perform matrix multiplication
-    xyz = torch.matmul(rgb_img, rgb_to_xyz.T)
+    xyz = rgb_img @ rgb_to_xyz

    # Adjust XYZ values
-    xyz[..., 0].div_(xn)
-    xyz[..., 1].div_(yn)
-    xyz[..., 2].div_(zn)
+    xyz.div_(torch.tensor([xn, yn, zn], device=xyz.device))

    # Step 4: XYZ to LAB
-    lab = torch.where(
+    fxfyfz = torch.where(
        xyz > epsilon,
        torch.pow(xyz, 1 / 3),
        (kappa * xyz + 16) / 116
    )

+    L = 116 * fxfyfz[..., 1] - 16
+    a = 500 * (fxfyfz[..., 0] - fxfyfz[..., 1])
+    b = 200 * (fxfyfz[..., 1] - fxfyfz[..., 2])
    if normalized:
-        # Calculate normalized [0,1] L,a,b values directly
-        # L: map [0,100] to [0,1] : (116y - 16)/100 = 1.16y - 0.16
-        # a: map [-128,127] to [0,1] : (500(x-y) + 128)/255 ≈ 1.96(x-y) + 0.502
-        # b: map [-128,127] to [0,1] : (200(y-z) + 128)/255 ≈ 0.784(y-z) + 0.502
-        shift_128 = 128 / 255
-        a_scale = 500 / 255
-        b_scale = 200 / 255
-        L = 1.16 * lab[..., 1] - 0.16
-        a = a_scale * (lab[..., 0] - lab[..., 1]) + shift_128
-        b = b_scale * (lab[..., 1] - lab[..., 2]) + shift_128
-    else:
-        # Calculate native range L,a,b values
-        L = 116 * lab[..., 1] - 16
-        a = 500 * (lab[..., 0] - lab[..., 1])
-        b = 200 * (lab[..., 1] - lab[..., 2])
+        # output in range [0, 1] for each channel
+        L.div_(100)
+        a.add_(128).div_(255)
+        b.add_(128).div_(255)

+    if split_channels:
+        return L, a, b
+
    # Stack the results
    lab = torch.stack([L, a, b], dim=-1)

    # Restore original shape if needed
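
The rewrite computes native-range L, a, b once and normalizes in place afterwards, rather than folding the scaling into the fxfyfz terms; the two formulations agree, since (116y - 16)/100 = 1.16y - 0.16 and (500(x - y) + 128)/255 = (500/255)(x - y) + 128/255. A worked check of the affine maps (L: [0, 100] -> [0, 1]; a, b: [-128, 127] -> [0, 1]):

import torch

L_native = torch.tensor([0., 50., 100.])     # native L range [0, 100]
ab_native = torch.tensor([-128., 0., 127.])  # native a/b range [-128, 127]

print(L_native / 100)           # tensor([0.0000, 0.5000, 1.0000])
print((ab_native + 128) / 255)  # tensor([0.0000, 0.5020, 1.0000])
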
4 changes: 2 additions & 2 deletions timm/data/transforms_factory.py
@@ -86,7 +86,7 @@ def transforms_imagenet_train(
        use_prefetcher: bool = False,
        normalize: bool = True,
        separate: bool = False,
-        use_tensor: Optional[bool] = True,  # FIXME forced True for testing
+        use_tensor: Optional[bool] = False,
):
    """ ImageNet-oriented image transforms for training.
@@ -273,7 +273,7 @@ def transforms_imagenet_eval(
        std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
        use_prefetcher: bool = False,
        normalize: bool = True,
-        use_tensor: bool = True,
+        use_tensor: bool = False,
):
    """ ImageNet-oriented image transform for evaluation and inference.
