-
Notifications
You must be signed in to change notification settings - Fork 31
/
dataset_sr.py
134 lines (103 loc) · 4.81 KB
/
dataset_sr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
import cv2
import numpy as np
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import PIL.Image as Image
from randomcrop import RandomHorizontallyFlip
class TrainValDataset(Dataset):
    """Training dataset of (input, label) image pairs converted to LAB color space.

    Each non-blank line of the list file ``name`` holds two whitespace-separated
    paths, relative to ``root``: the input image and its label image. Both are
    read with OpenCV, converted BGR -> LAB, resized to ``size`` x ``size`` and
    then passed jointly through ``RandomHorizontallyFlip``.

    The pair list is repeated ``repeat`` times per epoch: ``len()`` is
    ``repeat * <number of pairs>`` and indices wrap around the pair list.

    Each sample is a dict with keys:
        'O': input after ``self.trans`` (ToTensor + Normalize), shape (3, size, size).
        'B': label as a float32 (3, size, size) array of unscaled 8-bit LAB values.
        'image': input as a float32 (3, size, size) array scaled by 1/255.
        'image_ori': input as a float32 (3, size, size) array, unscaled.
    """

    def __init__(self, name, root='/home/zhxing/Datasets/SRD_inpaint4shadow_fix/',
                 size=512, repeat=100):
        """
        Args:
            name: path of the list file containing "<image> <label>" lines.
            root: prefix prepended (by plain string concatenation) to every
                path in the list file.
            size: side length in pixels that both images are resized to.
            repeat: number of times the pair list is repeated per epoch.
        """
        super().__init__()
        self.dataset = name
        self.root = root
        # Alternative root used previously: '/home/zhxing/Datasets/ISTD+/'
        self.size = size
        self.repeat = repeat
        self.imgs = self._read_list(name)
        self.file_num = len(self.imgs)
        self.hflip = RandomHorizontallyFlip()
        # NOTE(review): these are ImageNet RGB mean/std, applied here to LAB
        # channels -- confirm this is intended.
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _read_list(list_path):
        """Return the non-blank lines of ``list_path``, closing the file afterwards."""
        with open(list_path) as f:
            return [line for line in f if line.strip()]

    @staticmethod
    def _parse_line(line):
        """Split one list-file line into ``(image_path, label_path)``.

        Raises:
            ValueError: if the line has fewer than two whitespace-separated fields.
        """
        parts = line.split()
        if len(parts) < 2:
            raise ValueError(f"expected '<image> <label>' in list line: {line!r}")
        return parts[0], parts[1]

    @staticmethod
    def _to_chw(array_hwc):
        """Convert an (H, W, C) image or array into a float32 (C, H, W) array."""
        return np.array(array_hwc, dtype='float32').transpose(2, 0, 1)

    def _load_lab(self, rel_path):
        """Read ``root + rel_path`` as BGR, convert it to LAB and resize to size x size.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        path = self.root + rel_path
        bgr = cv2.imread(path)
        if bgr is None:  # cv2.imread reports unreadable/missing files by returning None
            raise FileNotFoundError(f"could not read image: {path}")
        lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
        return cv2.resize(lab, (self.size, self.size))

    def __len__(self):
        return self.file_num * self.repeat

    def __getitem__(self, index):
        image_path, label_path = self._parse_line(self.imgs[index % self.file_num])
        # Converted to PIL before the paired flip, as in the original pipeline
        # (RandomHorizontallyFlip presumably expects PIL images -- confirm).
        image_lab = Image.fromarray(self._load_lab(image_path))
        label_lab = Image.fromarray(self._load_lab(label_path))
        image_lab, label_lab = self.hflip(image_lab, label_lab)

        image_nom = self.trans(image_lab)
        # NOTE(review): TestDataset rescales its label by 1/255 but this class
        # does not (the /255 was commented out upstream) -- confirm the loss
        # expects unscaled 0-255 LAB labels during training.
        label_chw = self._to_chw(label_lab)
        image_chw = self._to_chw(image_lab)
        return {'O': image_nom, 'B': label_chw, 'image': image_chw / 255, "image_ori": image_chw}
class TestDataset(Dataset):
    """Evaluation dataset of (input, label) image pairs in LAB color space.

    Pairs are read from ``root + list_file`` (``'test_dsc.txt'`` by default).
    The ``name`` argument is stored as ``self.dataset`` but is not used to
    locate the list. Each non-blank line holds two whitespace-separated paths
    relative to ``root``. Both images are converted BGR -> LAB and resized to
    ``size`` x ``size``; no flipping is applied.

    Each sample is a dict with keys:
        'O': input after ``self.trans`` (ToTensor + Normalize), shape (3, size, size).
        'B': label as a float32 (3, size, size) array scaled by 1/255.
        'image': input as a uint8 (size, size, 3) array (HWC, unscaled).
        'image_ori': input as a float32 (3, size, size) array, unscaled.
    """

    def __init__(self, name, root='/home/zhxing/Datasets/SRD_inpaint4shadow_fix/',
                 list_file='test_dsc.txt', size=512):
        """
        Args:
            name: stored as ``self.dataset``; not used to find the pair list.
            root: prefix prepended (by plain string concatenation) to the list
                file name and to every path inside it.
            list_file: name of the pair list, relative to ``root``.
            size: side length in pixels that both images are resized to.
        """
        super().__init__()
        self.dataset = name
        self.root = root
        # Alternative roots used previously:
        #   '/home/zhxing/Datasets/ISTD+/'
        #   '/home/zhxing/Datasets/DESOBA_xvision/'
        self.size = size
        self.imgs = self._read_list(self.root + list_file)
        self.file_num = len(self.imgs)
        # NOTE(review): ImageNet RGB mean/std applied to LAB channels -- confirm.
        self.trans = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    @staticmethod
    def _read_list(list_path):
        """Return the non-blank lines of ``list_path``, closing the file afterwards."""
        with open(list_path) as f:
            return [line for line in f if line.strip()]

    @staticmethod
    def _parse_line(line):
        """Split one list-file line into ``(image_path, label_path)``.

        Uses whitespace splitting, so the result does not depend on a trailing
        '\\n' being present (the last line of a file may lack one) and tolerates
        CRLF endings, tabs and repeated spaces.

        Raises:
            ValueError: if the line has fewer than two whitespace-separated fields.
        """
        parts = line.split()
        if len(parts) < 2:
            raise ValueError(f"expected '<image> <label>' in list line: {line!r}")
        return parts[0], parts[1]

    @staticmethod
    def _to_chw(array_hwc):
        """Convert an (H, W, C) image or array into a float32 (C, H, W) array."""
        return np.array(array_hwc, dtype='float32').transpose(2, 0, 1)

    def _load_lab(self, rel_path):
        """Read ``root + rel_path`` as BGR, convert it to LAB and resize to size x size.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        path = self.root + rel_path
        bgr = cv2.imread(path)
        if bgr is None:  # cv2.imread reports unreadable/missing files by returning None
            raise FileNotFoundError(f"could not read image: {path}")
        lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
        return cv2.resize(lab, (self.size, self.size))

    def __len__(self):
        return self.file_num

    def __getitem__(self, index):
        image_path, label_path = self._parse_line(self.imgs[index % self.file_num])
        image_lab = Image.fromarray(self._load_lab(image_path))
        label_lab = self._load_lab(label_path)

        image_nom = self.trans(image_lab)
        label_chw = self._to_chw(label_lab) / 255.0
        image_ori = self._to_chw(image_lab)
        # NOTE(review): 'image' is HWC uint8 here, whereas TrainValDataset
        # returns CHW float32 scaled by 1/255 under the same key -- confirm the
        # evaluation code relies on this layout.
        return {'O': image_nom, 'B': label_chw, 'image': np.array(image_lab), "image_ori": image_ori}