-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathplaces2.py
36 lines (28 loc) · 1.4 KB
/
places2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import random
import torch
from PIL import Image
from glob import glob
class Places2(torch.utils.data.Dataset):
    """Dataset pairing Places2 images with randomly drawn masks.

    Every item is the tuple ``(gt_img * mask, mask, gt_img)``, built from the
    transformed image at the requested index and a transformed mask picked at
    random from the mask directory.

    Args:
        img_root: Directory that contains the image splits. May be a ``str``
            or any object whose ``str()`` form is the path (e.g.
            ``pathlib.Path``).
        mask_root: Directory whose ``*.jpg`` files are used as masks.
        img_transform: Callable applied to each image after it has been
            converted to RGB.
        mask_transform: Callable applied to each mask after it has been
            converted to RGB.
        split: ``'train'`` collects ``<img_root>/data_large/**/*.jpg``
            recursively; any other value collects every entry directly
            inside ``<img_root>/<split>_large``.

    Attributes:
        paths: Sorted list of image file paths for the chosen split.
        mask_paths: Sorted list of mask file paths.
        N_mask: Number of mask files found when the dataset was built.
    """

    def __init__(self, img_root, mask_root, img_transform, mask_transform,
                 split='train'):
        super().__init__()
        self.img_transform = img_transform
        self.mask_transform = mask_transform

        # use about 8M images in the challenge dataset
        if split == 'train':
            found = glob('{}/data_large/**/*.jpg'.format(img_root),
                         recursive=True)
        else:
            found = glob('{}/{}_large/*'.format(img_root, split))

        # glob() gives no ordering guarantee; sorting keeps the mapping from
        # index to file identical across runs and machines.
        self.paths = sorted(found)
        self.mask_paths = sorted(glob('{}/*.jpg'.format(mask_root)))
        self.N_mask = len(self.mask_paths)

    def __getitem__(self, index):
        """Return ``(gt_img * mask, mask, gt_img)`` for the image at ``index``.

        The mask is drawn with ``random.randint``, so repeated calls with the
        same index generally pair the image with different masks.

        Raises:
            IndexError: If ``index`` is outside ``paths``.
            ValueError: If no mask files were found when the dataset was
                built.
        """
        # Look the path up first so an out-of-range index still surfaces as
        # IndexError, exactly as plain list indexing would.
        img_path = self.paths[index]
        if self.N_mask == 0:
            raise ValueError(
                'no "*.jpg" mask files were found under mask_root; '
                'cannot draw a random mask')

        # convert() returns an independent, fully loaded image, so the file
        # handle can be released before the transform runs.
        with Image.open(img_path) as img_file:
            rgb_img = img_file.convert('RGB')
        gt_img = self.img_transform(rgb_img)

        mask_path = self.mask_paths[random.randint(0, self.N_mask - 1)]
        with Image.open(mask_path) as mask_file:
            rgb_mask = mask_file.convert('RGB')
        mask = self.mask_transform(rgb_mask)

        # NOTE(review): the product assumes both transforms return objects
        # that support element-wise ``*`` (presumably same-shaped tensors) —
        # confirm against the transforms passed in by the caller.
        return gt_img * mask, mask, gt_img

    def __len__(self):
        """Return the number of images found for the chosen split."""
        return len(self.paths)