-
Notifications
You must be signed in to change notification settings - Fork 0
/
clothcoparse_dataset.py
63 lines (40 loc) · 2.41 KB
/
clothcoparse_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 22 12:10:32 2021
@author: malrawi
"""
import glob
import os
import scipy.io as sio
from torch.utils.data import Dataset # Dataset class from PyTorch
from PIL import Image, ImageChops # PIL is a nice Python Image Library that we can use to handle images
import numpy as np
# https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
class ImageDataset(Dataset):
    """ClothCoParse paired dataset.

    Expects source photos under ``<root>/<mode>/A`` and MATLAB ``.mat``
    annotation files (with a ``groundtruth`` per-pixel label array) under
    ``<root>/<mode>/B``. Each item yields the original image together with
    one masked-out RGB crop per annotated object (background excluded).
    """

    def __init__(self, root, class_names_and_colors, mode="train", HPC_run=False):
        # class_names_and_colors: dict mapping class name -> color; only the
        # keys are used here. Their insertion order is assumed to match the
        # integer label ids stored in the .mat ground-truth files.
        self.class_names = list(class_names_and_colors.keys())
        if HPC_run:
            # Hard-coded dataset location on the HPC cluster.
            root = '/home/malrawi/MyPrograms/Data/ClothCoParse'
        # Source image file names (A) and annotation file names (B); both
        # sorted so the two lists stay aligned pair-wise.
        self.files_A = sorted(glob.glob(os.path.join(root, "%s/A" % mode) + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, "%s/B" % mode) + "/*.*"))

    def number_of_classes(self, opt):
        # `opt` is unused but kept so existing callers keep working.
        return len(self.class_names)

    def __getitem__(self, index):
        """Return (image, masked crops, labels, image_id, binary masks, filename)."""
        annot = sio.loadmat(self.files_B[index % len(self.files_B)])
        mask = annot["groundtruth"]  # 2-D array of per-pixel integer class ids
        image_A = Image.open(self.files_A[index % len(self.files_A)])
        # Objects are encoded as distinct integer ids; id at position 0 is
        # the background, so drop it.
        obj_ids = np.unique(mask)[1:]
        # Split the id-encoded mask into one binary mask per object
        # (broadcasting: (num_objs, 1, 1) against (H, W) -> (num_objs, H, W)).
        masks = mask == obj_ids[:, None, None]
        masked_img = []
        labels = []
        for i in range(len(obj_ids)):
            # Cast explicitly to uint8: under NumPy 2 (NEP 50 promotion),
            # 255 * bool_mask yields int64, which Image.fromarray rejects.
            bin_mask = Image.fromarray(masks[i].astype(np.uint8) * 255).convert('RGB')
            # Black out everything except the i-th object.
            img = ImageChops.multiply(image_A, bin_mask)
            masked_img.append(np.array(img, dtype='uint8'))
            # NOTE(review): assumes class_names is indexed by the raw label id,
            # i.e. includes the background name at index 0 — confirm against
            # the caller's class_names_and_colors dict.
            labels.append(self.class_names[obj_ids[i]])
        image_id = index
        fname = os.path.basename(self.files_A[index % len(self.files_A)])
        return image_A, masked_img, labels, image_id, masks, fname

    def __len__(self):
        # Length follows the annotation list; the source list may differ in
        # size if the data is unaligned.
        return len(self.files_B)