From 3fbbd511e64c979f555899304e0375bf02c97fd3 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 18 Dec 2024 16:49:17 -0800
Subject: [PATCH] Testing some LAB stuff

---
 timm/data/transforms.py         | 145 +++++++++++++++++++++++++++++++-
 timm/data/transforms_factory.py |  23 +++++--
 2 files changed, 163 insertions(+), 5 deletions(-)

diff --git a/timm/data/transforms.py b/timm/data/transforms.py
index e0c7e7f90..f318b9a28 100644
--- a/timm/data/transforms.py
+++ b/timm/data/transforms.py
@@ -90,7 +90,7 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
 
 
-class ToLab(transforms.ToTensor):
+class ToLabPIL:
 
     def __init__(self) -> None:
         super().__init__()
@@ -115,6 +115,149 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
 
 
+def srgb_to_linear(srgb_image: torch.Tensor) -> torch.Tensor:
+    """Decode sRGB-encoded values to linear RGB.
+
+    Applies the piecewise sRGB curve elementwise: inputs <= 0.04045 are divided by 12.92,
+    all others go through ((x + 0.055) / 1.055) ** 2.4.
+    """
+    return torch.where(
+        srgb_image <= 0.04045,
+        srgb_image / 12.92,
+        ((srgb_image + 0.055) / 1.055) ** 2.4
+    )
+
+
+def rgb_to_lab_tensor(
+        rgb_img: torch.Tensor,
+        normalized: bool = True,
+        srgb_input: bool = True,
+) -> torch.Tensor:
+    """
+    Convert RGB image to LAB color space using tensor operations.
+
+    The input tensor is never modified in place.
+
+    Args:
+        rgb_img: Channels-last tensor of shape (..., 3) with values in range [0, 1]
+            (values outside that range are clamped after the optional sRGB decode).
+            Integer tensors are cast to float32 without rescaling.
+        normalized: If True, outputs L,a,b in [0, 1] range instead of native LAB ranges
+        srgb_input: If True, input is sRGB encoded and is converted to linear RGB first.
+            Pass False when the input is already linear RGB.
+
+    Returns:
+        lab_img: Tensor of same shape with either:
+            - normalized=False: L in [0, 100] and a,b in [-128, 127]
+            - normalized=True: L,a,b in [0, 1]
+    """
+    # Constants
+    epsilon = 216 / 24389
+    kappa = 24389 / 27
+    xn = 0.95047
+    yn = 1.0
+    zn = 1.08883
+
+    # Integer inputs cannot go through the floating point matmul below
+    if not rgb_img.is_floating_point():
+        rgb_img = rgb_img.to(torch.float32)
+
+    # Convert sRGB to linear RGB
+    if srgb_input:
+        rgb_img = srgb_to_linear(rgb_img)
+
+    # FIXME transforms before this are causing -ve values, can have a large impact on this conversion
+    # Out-of-place clamp: with srgb_input=False rgb_img may still be the caller's tensor
+    rgb_img = rgb_img.clamp(0, 1.0)
+
+    # Convert to XYZ using matrix multiplication
+    rgb_to_xyz = torch.tensor([
+        [0.412453, 0.357580, 0.180423],
+        [0.212671, 0.715160, 0.072169],
+        [0.019334, 0.119193, 0.950227]
+    ], device=rgb_img.device, dtype=rgb_img.dtype)
+
+    # Reshape input for matrix multiplication if needed
+    original_shape = rgb_img.shape
+    if len(original_shape) > 2:
+        rgb_img = rgb_img.reshape(-1, 3)
+
+    # Perform matrix multiplication
+    xyz = torch.matmul(rgb_img, rgb_to_xyz.T)
+
+    # Adjust XYZ values
+    xyz[..., 0].div_(xn)
+    xyz[..., 1].div_(yn)
+    xyz[..., 2].div_(zn)
+
+    # Step 4: XYZ to LAB
+    lab = torch.where(
+        xyz > epsilon,
+        torch.pow(xyz, 1 / 3),
+        (kappa * xyz + 16) / 116
+    )
+
+    if normalized:
+        # Calculate normalized [0,1] L,a,b values directly
+        # L: map [0,100] to [0,1] : (116y - 16)/100 = 1.16y - 0.16
+        # a: map [-128,127] to [0,1] : (500(x-y) + 128)/255 ≈ 1.96(x-y) + 0.502
+        # b: map [-128,127] to [0,1] : (200(y-z) + 128)/255 ≈ 0.784(y-z) + 0.502
+        shift_128 = 128 / 255
+        a_scale = 500 / 255
+        b_scale = 200 / 255
+        L = 1.16 * lab[..., 1] - 0.16
+        a = a_scale * (lab[..., 0] - lab[..., 1]) + shift_128
+        b = b_scale * (lab[..., 1] - lab[..., 2]) + shift_128
+    else:
+        # Calculate native range L,a,b values
+        L = 116 * lab[..., 1] - 16
+        a = 500 * (lab[..., 0] - lab[..., 1])
+        b = 200 * (lab[..., 1] - lab[..., 2])
+
+    # Stack the results
+    lab = torch.stack([L, a, b], dim=-1)
+
+    # Restore original shape if needed
+    if len(original_shape) > 2:
+        lab = lab.reshape(original_shape)
+
+    return lab
+
+
+class ToLabTensor:
+    """Convert a channels-first (..., 3, H, W) RGB tensor to LAB, keeping that layout.
+
+    Args:
+        srgb_input: If True, input is sRGB encoded; if False it is already linear RGB.
+        normalized: If True, L,a,b are mapped to [0, 1] instead of native LAB ranges.
+    """
+
+    def __init__(self, srgb_input=False, normalized=True) -> None:
+        self.srgb_input = srgb_input
+        self.normalized = normalized
+
+    def __call__(self, pic) -> torch.Tensor:
+        # rgb_to_lab_tensor treats the last dim as RGB, but this transform sits between
+        # MaybeToTensor and Normalize where tensors are (C, H, W): move channels last and back
+        lab = rgb_to_lab_tensor(
+            pic.movedim(-3, -1),
+            normalized=self.normalized,
+            srgb_input=self.srgb_input,
+        )
+        return lab.movedim(-1, -3)
+
+
+class ToLinearRgb:
+    """Decode an sRGB-encoded tensor to linear RGB via srgb_to_linear."""
+
+    def __init__(self):
+        pass
+
+    def __call__(self, pic) -> torch.Tensor:
+        assert isinstance(pic, torch.Tensor)
+        return srgb_to_linear(pic)
+
+
 # Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
 # favor of the Image.Resampling enum. The top-level resampling attributes will be
 # removed in Pillow 10.
diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py
index 5653109f7..a363a4bbe 100644
--- a/timm/data/transforms_factory.py
+++ b/timm/data/transforms_factory.py
@@ -14,6 +14,7 @@
 from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
 from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
     ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor
+from timm.data.transforms import ToLabTensor, ToLinearRgb
 from timm.data.random_erasing import RandomErasing
 
 
@@ -123,7 +124,10 @@ def transforms_imagenet_train(
          * normalizes and converts the branches above with the third, final transform
     """
     if use_tensor:
-        primary_tfl = [MaybeToTensor()]
+        primary_tfl = [
+            MaybeToTensor(),
+            ToLinearRgb(),  # FIXME
+        ]
     else:
         primary_tfl = []
 
@@ -236,6 +240,7 @@ def transforms_imagenet_train(
         if not use_tensor:
             final_tfl += [MaybeToTensor()]
         final_tfl += [
+            ToLabTensor(),  # FIXME
             transforms.Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std),
@@ -268,6 +273,7 @@ def transforms_imagenet_eval(
         std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
         use_prefetcher: bool = False,
         normalize: bool = True,
+        use_tensor: bool = True,
 ):
     """ ImageNet-oriented image transform for evaluation and inference.
 
@@ -294,7 +300,13 @@ def transforms_imagenet_eval(
         scale_size = math.floor(img_size / crop_pct)
         scale_size = (scale_size, scale_size)
 
-    tfl = []
+    if use_tensor:
+        tfl = [
+            MaybeToTensor(),
+            ToLinearRgb(),  # FIXME
+        ]
+    else:
+        tfl = []
 
     if crop_border_pixels:
         tfl += [TrimBorder(crop_border_pixels)]
@@ -332,10 +344,13 @@ def transforms_imagenet_eval(
         tfl += [ToNumpy()]
     elif not normalize:
         # when normalize disabled, converted to tensor without scaling, keeps original dtype
-        tfl += [MaybePILToTensor()]
+        if not use_tensor:
+            tfl += [MaybePILToTensor()]
     else:
+        if not use_tensor:
+            tfl += [MaybeToTensor()]
         tfl += [
-            MaybeToTensor(),
+            ToLabTensor(),  # FIXME
             transforms.Normalize(
                 mean=torch.tensor(mean),
                 std=torch.tensor(std),