datasets.py
import os
import cv2
import torch
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2


class PoseDataset(torch.utils.data.Dataset):
    """Segmentation-style dataset: images with per-pixel masks for train/val, images only for test."""

    def __init__(self, data_dir, mode, transform=None):
        self.data_dir = data_dir
        self.mode = mode
        self.transform = transform
        if mode == 'train' or mode == 'val':
            # Image and mask file names are sorted so they pair up by index.
            self.images = sorted(os.listdir(os.path.join(data_dir, 'images')))
            self.masks = sorted(os.listdir(os.path.join(data_dir, 'masks')))
        elif mode == 'test':
            # Test images live directly under data_dir; sorted for a deterministic order.
            self.images = sorted(os.listdir(data_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, item):
        if self.mode == 'train' or self.mode == 'val':
            # OpenCV loads images as BGR uint8; the mask is read as single-channel grayscale.
            img = cv2.imread(os.path.join(self.data_dir, 'images', self.images[item]))
            mask = cv2.imread(os.path.join(self.data_dir, 'masks', self.masks[item]), 0)
            if self.transform is not None:
                # Albumentations applies the same spatial transform to image and mask.
                transformed = self.transform(image=img, mask=mask)
                img = transformed["image"]
                mask = transformed["mask"]
                # After ToTensorV2 the mask is a tensor, so cast to long for loss functions.
                return img, mask.long(), self.images[item]
            return img, mask, self.images[item]
        elif self.mode == 'test':
            img = cv2.imread(os.path.join(self.data_dir, self.images[item]))
            if self.transform is not None:
                transformed = self.transform(image=img)
                img = transformed["image"]
            return img, self.images[item]

# Leftover (commented-out) workaround that padded odd image dimensions to even
# before transforming. Note that A.Resize takes (height, width), so the
# arguments below appear to be swapped.
# h, w, c = img.shape
# h_flag = h % 2 == 1
# w_flag = w % 2 == 1
# if h_flag:
#     temp_transformed = A.Resize(w, h + 1)(image=img, mask=mask)
#     img = temp_transformed["image"]
#     mask = temp_transformed["mask"]
# elif w_flag:
#     temp_transformed = A.Resize(w + 1, h)(image=img, mask=mask)
#     img = temp_transformed["image"]
#     mask = temp_transformed["mask"]
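

# Illustrative sketch (not part of the original file): a train-time transform with
# light augmentation. Albumentations applies the spatial ops to the image and mask
# together, which is what PoseDataset.__getitem__ expects. The 256x256 size is an
# assumed placeholder, not a value from this repository.
def build_train_transform(size=256):
    return A.Compose([
        A.Resize(size, size),
        A.HorizontalFlip(p=0.5),
        A.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2471, 0.2435, 0.2616],
        ),
        ToTensorV2(),
    ])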


if __name__ == '__main__':
    # Quick smoke test: normalization only (mean/std match the commonly used
    # CIFAR-10 statistics), followed by conversion to a CHW tensor.
    transforms = A.Compose([
        A.Normalize(
            mean=[0.4914, 0.4822, 0.4465],
            std=[0.2471, 0.2435, 0.2616],
        ),
        ToTensorV2(),
    ])
    ds = PoseDataset('data/train', mode='train', transform=transforms)
    print(f'{len(ds)} training samples')
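
    # Sketch of downstream usage (an assumption, not in the original file): wrap the
    # dataset in a DataLoader. batch_size stays at 1 because the transform above does
    # not resize, so images may have different shapes and cannot be stacked into a batch.
    loader = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=True, num_workers=0)
    img, mask, name = next(iter(loader))
    print(img.shape, mask.shape, name)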