Skip to content

Commit

Permalink
Testing some LAB stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Dec 19, 2024
1 parent 3b181b7 commit 3fbbd51
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 5 deletions.
117 changes: 116 additions & 1 deletion timm/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}()"


class ToLab(transforms.ToTensor):
class ToLabPIL:

def __init__(self) -> None:
super().__init__()
Expand All @@ -115,6 +115,121 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}()"


def srgb_to_linear(srgb_image: torch.Tensor) -> torch.Tensor:
    """Apply the piecewise sRGB decoding curve element-wise.

    Values at or below 0.04045 go through the linear segment (divide by 12.92);
    larger values go through the power segment with exponent 2.4.

    Args:
        srgb_image: Tensor of sRGB-encoded values (any shape).

    Returns:
        New tensor with the decoding curve applied to every element.
    """
    # Both segments are evaluated for every element; the mask picks one per element.
    linear_segment = srgb_image / 12.92
    power_segment = ((srgb_image + 0.055) / 1.055) ** 2.4
    use_linear_segment = srgb_image <= 0.04045
    return torch.where(use_linear_segment, linear_segment, power_segment)


def rgb_to_lab_tensor(
        rgb_img: torch.Tensor,
        normalized: bool = True,
        srgb_input: bool = True,
) -> torch.Tensor:
    """Convert an RGB image to LAB color space using tensor operations.

    The input tensor is never modified; all clamping and scaling is done on
    copies.

    Args:
        rgb_img: Floating point tensor of shape (..., 3) with the color channels
            in the LAST dimension. Values are clamped to [0, 1] before the
            conversion, so inputs are expected to already be scaled to that range.
        normalized: If True, outputs L,a,b rescaled to a nominal [0, 1] range
            instead of the native LAB ranges.
        srgb_input: If True, the input is first passed through ``srgb_to_linear``.
            If False, the input is used as-is (i.e. treated as already linear).

    Returns:
        lab_img: Tensor of the same shape as the input with either:
            - normalized=False: L nominally in [0, 100] and a,b in [-128, 127]
            - normalized=True: L,a,b nominally in [0, 1]
    """
    # Constants: thresholds of the piecewise LAB transfer function and the
    # reference white used to scale XYZ (values look like the D65 white point --
    # TODO confirm).
    epsilon = 216 / 24389
    kappa = 24389 / 27
    xn = 0.95047
    yn = 1.0
    zn = 1.08883

    # Convert sRGB to linear RGB
    if srgb_input:
        rgb_img = srgb_to_linear(rgb_img)

    # FIXME transforms before this are causing -ve values, can have a large impact on this conversion
    # Out-of-place clamp: with srgb_input=False `rgb_img` is still the caller's
    # tensor, and an in-place clamp would silently modify it.
    rgb_img = rgb_img.clamp(0, 1.0)

    # Build the conversion matrix in the input's floating dtype so that
    # float64 / half inputs do not fail in matmul with a dtype mismatch.
    matrix_dtype = rgb_img.dtype if rgb_img.is_floating_point() else torch.float32
    rgb_to_xyz = torch.tensor(
        [
            [0.412453, 0.357580, 0.180423],
            [0.212671, 0.715160, 0.072169],
            [0.019334, 0.119193, 0.950227],
        ],
        dtype=matrix_dtype,
        device=rgb_img.device,
    )

    # Flatten leading dimensions so the image becomes a (N, 3) list of pixels
    original_shape = rgb_img.shape
    if len(original_shape) > 2:
        rgb_img = rgb_img.reshape(-1, 3)

    # Linear RGB -> XYZ
    xyz = torch.matmul(rgb_img, rgb_to_xyz.T)

    # Scale by the reference white (in-place is safe: `xyz` is a fresh tensor)
    xyz[..., 0].div_(xn)
    xyz[..., 1].div_(yn)
    xyz[..., 2].div_(zn)

    # XYZ -> LAB transfer function: cube root above epsilon, linear below
    lab = torch.where(
        xyz > epsilon,
        torch.pow(xyz, 1 / 3),
        (kappa * xyz + 16) / 116
    )

    if normalized:
        # Calculate normalized [0,1] L,a,b values directly
        # L: map [0,100] to [0,1]    : (116y - 16)/100 = 1.16y - 0.16
        # a: map [-128,127] to [0,1] : (500(x-y) + 128)/255
        # b: map [-128,127] to [0,1] : (200(y-z) + 128)/255
        shift_128 = 128 / 255
        a_scale = 500 / 255
        b_scale = 200 / 255
        L = 1.16 * lab[..., 1] - 0.16
        a = a_scale * (lab[..., 0] - lab[..., 1]) + shift_128
        b = b_scale * (lab[..., 1] - lab[..., 2]) + shift_128
    else:
        # Calculate native range L,a,b values
        L = 116 * lab[..., 1] - 16
        a = 500 * (lab[..., 0] - lab[..., 1])
        b = 200 * (lab[..., 1] - lab[..., 2])

    # Stack the results back into a trailing channel dimension
    lab = torch.stack([L, a, b], dim=-1)

    # Restore original shape if it was flattened above
    if len(original_shape) > 2:
        lab = lab.reshape(original_shape)

    return lab


class ToLabTensor:
    """Transform that converts an RGB tensor to LAB via ``rgb_to_lab_tensor``.

    Args:
        srgb_input: Forwarded to ``rgb_to_lab_tensor``. Note the default here is
            False, whereas the function itself defaults to True.
        normalized: Forwarded to ``rgb_to_lab_tensor``.
    """

    def __init__(self, srgb_input: bool = False, normalized: bool = True) -> None:
        self.srgb_input = srgb_input
        self.normalized = normalized

    def __call__(self, pic: torch.Tensor) -> torch.Tensor:
        """Convert ``pic`` to LAB using the options stored on this transform."""
        return rgb_to_lab_tensor(
            pic,
            normalized=self.normalized,
            srgb_input=self.srgb_input,
        )

    def __repr__(self) -> str:
        # Match the other transforms in this module, which all define __repr__
        return (
            f"{self.__class__.__name__}("
            f"srgb_input={self.srgb_input}, normalized={self.normalized})"
        )


class ToLinearRgb:
    """Transform that passes a tensor through ``srgb_to_linear``."""

    def __call__(self, pic: torch.Tensor) -> torch.Tensor:
        """Return ``srgb_to_linear(pic)``.

        Raises:
            TypeError: If ``pic`` is not a ``torch.Tensor``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`
        if not isinstance(pic, torch.Tensor):
            raise TypeError(
                f"{self.__class__.__name__} expects a torch.Tensor, got {type(pic).__name__}"
            )
        return srgb_to_linear(pic)

    def __repr__(self) -> str:
        # Match the other transforms in this module, which all define __repr__
        return f"{self.__class__.__name__}()"


# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
# favor of the Image.Resampling enum. The top-level resampling attributes will be
# removed in Pillow 10.
Expand Down
23 changes: 19 additions & 4 deletions timm/data/transforms_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor
from timm.data.transforms import ToLabTensor, ToLinearRgb
from timm.data.random_erasing import RandomErasing


Expand Down Expand Up @@ -123,7 +124,10 @@ def transforms_imagenet_train(
* normalizes and converts the branches above with the third, final transform
"""
if use_tensor:
primary_tfl = [MaybeToTensor()]
primary_tfl = [
MaybeToTensor(),
ToLinearRgb(), # FIXME
]
else:
primary_tfl = []

Expand Down Expand Up @@ -236,6 +240,7 @@ def transforms_imagenet_train(
if not use_tensor:
final_tfl += [MaybeToTensor()]
final_tfl += [
ToLabTensor(), # FIXME
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std),
Expand Down Expand Up @@ -268,6 +273,7 @@ def transforms_imagenet_eval(
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
use_prefetcher: bool = False,
normalize: bool = True,
use_tensor: bool = True,
):
""" ImageNet-oriented image transform for evaluation and inference.
Expand All @@ -294,7 +300,13 @@ def transforms_imagenet_eval(
scale_size = math.floor(img_size / crop_pct)
scale_size = (scale_size, scale_size)

tfl = []
if use_tensor:
tfl = [
MaybeToTensor(),
ToLinearRgb(), # FIXME
]
else:
tfl = []

if crop_border_pixels:
tfl += [TrimBorder(crop_border_pixels)]
Expand Down Expand Up @@ -332,10 +344,13 @@ def transforms_imagenet_eval(
tfl += [ToNumpy()]
elif not normalize:
# when normalize disabled, converted to tensor without scaling, keeps original dtype
tfl += [MaybePILToTensor()]
if not use_tensor:
tfl += [MaybePILToTensor()]
else:
if not use_tensor:
tfl += [MaybeToTensor()]
tfl += [
MaybeToTensor(),
ToLabTensor(), # FIXME
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std),
Expand Down

0 comments on commit 3fbbd51

Please sign in to comment.