forked from davidtvs/PyTorch-ENet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
transforms.py
94 lines (69 loc) · 3.11 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import numpy as np
from PIL import Image
from collections import OrderedDict
from torchvision.transforms import ToPILImage
class PILToLongTensor(object):
    """Converts a ``PIL Image`` or ``numpy.ndarray`` to a ``torch.LongTensor``.

    Code adapted from: http://pytorch.org/docs/master/torchvision/transforms.html?highlight=totensor

    """

    def __call__(self, pic):
        """Performs the conversion to a ``torch.LongTensor``.

        Keyword arguments:
        - pic (``PIL.Image`` or ``numpy.ndarray``): the image to convert to
        ``torch.LongTensor``. An ndarray is expected to be laid out as
        (H, W) or (H, W, C).

        Returns:
        A ``torch.LongTensor``. For a PIL image the result is (C, H, W), with
        the channel dimension dropped when C == 1. For an ndarray the result
        is (H, W) for 2-D input and (C, H, W) for 3-D input.

        Raises:
        TypeError: if ``pic`` is neither a PIL image nor an ndarray.

        """
        # Handle numpy arrays first: the PIL type check below would otherwise
        # reject them and this branch could never run.
        if isinstance(pic, np.ndarray):
            if pic.ndim == 2:
                # Label maps without a channel axis need no reordering
                return torch.from_numpy(pic).long()
            # (H, W, C) -> (C, H, W)
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            # backward compatibility
            return img.long()

        if not isinstance(pic, Image.Image):
            raise TypeError("pic should be PIL Image or ndarray. Got {}".format(
                type(pic)))

        # Convert PIL image to ByteTensor
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

        # Reshape tensor; the number of channels is taken from the length of
        # the mode string (e.g. "L" -> 1, "RGB" -> 3), one byte per channel.
        # NOTE(review): modes that do not follow that rule ("1", "I", "I;16",
        # ...) will not reshape correctly -- confirm they are never used.
        nchannel = len(pic.mode)
        img = img.view(pic.size[1], pic.size[0], nchannel)

        # (H, W, C) -> (C, H, W), then convert to long
        img = img.permute(2, 0, 1).contiguous().long()

        # Squeeze only the channel dimension so that images whose height or
        # width is 1 keep their spatial shape
        return img.squeeze_(0)
class LongTensorToRGBPIL(object):
    """Converts a ``torch.LongTensor`` to a ``PIL image``.

    The input is a ``torch.LongTensor`` where each pixel's value identifies the
    class.

    Keyword arguments:
    - rgb_encoding (``OrderedDict``): An ``OrderedDict`` that relates pixel
    values, class names, and class colors. The position of an entry in the
    dictionary is the pixel value it encodes; its value is the color, one
    component per output channel.

    """

    def __init__(self, rgb_encoding):
        self.rgb_encoding = rgb_encoding

    def __call__(self, tensor):
        """Performs the conversion from ``torch.LongTensor`` to a ``PIL image``

        Keyword arguments:
        - tensor (``torch.LongTensor``): the tensor to convert, shaped (H, W)
        or (1, H, W). The tensor is not modified.

        Returns:
        A ``PIL.Image``. Pixels whose value has no entry in ``rgb_encoding``
        are left black.

        Raises:
        TypeError: if ``tensor`` is not a ``torch.LongTensor`` or if
        ``rgb_encoding`` is not an ``OrderedDict``.

        """
        # Check if label_tensor is a LongTensor
        if not isinstance(tensor, torch.LongTensor):
            raise TypeError("label_tensor should be torch.LongTensor. Got {}"
                            .format(type(tensor)))
        # Check if encoding is a ordered dictionary
        if not isinstance(self.rgb_encoding, OrderedDict):
            raise TypeError("encoding should be an OrderedDict. Got {}".format(
                type(self.rgb_encoding)))

        # label_tensor might be an image without a channel dimension, in this
        # case unsqueeze it. Use the out-of-place version so the caller's
        # tensor keeps its shape.
        if tensor.dim() == 2:
            tensor = tensor.unsqueeze(0)

        # Zero the buffer explicitly: the sized constructor returns
        # uninitialized memory, which would leak into unmapped pixels.
        color_tensor = torch.ByteTensor(
            3, tensor.size(1), tensor.size(2)).zero_()

        for index, color in enumerate(self.rgb_encoding.values()):
            # Get a mask of elements equal to index; drop only the channel
            # dimension so the mask stays (H, W) even when H or W is 1
            mask = torch.eq(tensor, index).squeeze(0)
            # Fill color_tensor with corresponding colors
            for channel, color_value in enumerate(color):
                color_tensor[channel].masked_fill_(mask, color_value)

        return ToPILImage()(color_tensor)