-
Notifications
You must be signed in to change notification settings - Fork 4
/
celeba.py
112 lines (83 loc) · 2.88 KB
/
celeba.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from uvcgan.consts import SPLIT_TRAIN, SPLIT_VAL, SPLIT_TEST
from uvcgan.utils.funcs import check_value_in_range
# File names looked up directly under the dataset root directory.
# FNAME_ATTRS is the per-image attribute table, FNAME_SPLIT is the
# per-image split assignment.
FNAME_ATTRS = 'list_attr_celeba.txt'
FNAME_SPLIT = 'list_eval_partition.txt'

# Subdirectory of the dataset root that holds the image files.
SUBDIR_IMG = 'img_align_celeba'

# Split name -> integer code that is compared against the 'partition'
# column loaded from FNAME_SPLIT.
SPLITS = {
    SPLIT_TRAIN: 0,
    SPLIT_VAL: 1,
    SPLIT_TEST: 2,
}

# Accepted domain labels. 'a' selects images whose attribute value is
# positive; 'b' selects images whose attribute value is negative.
DOMAINS = ['a', 'b']
class CelebaDataset(Dataset):
    """Dataset of CelebA images filtered by split and, optionally, attribute.

    The list of image paths is built once at construction time from the
    split file (``FNAME_SPLIT``) and the attribute file (``FNAME_ATTRS``)
    found under ``path``. Images themselves are loaded lazily on indexing.

    Parameters
    ----------
    path : dataset root directory; images are expected under
        ``path/SUBDIR_IMG``.
    attr : name of the attribute column used to split images into
        domains, or ``None`` to keep every image of the split.
    domain : ``'a'`` keeps images whose ``attr`` value is > 0, ``'b'``
        keeps images whose ``attr`` value is < 0. Must be ``None`` when
        ``attr`` is ``None``.
    split : one of the keys of ``SPLITS``.
    transform : optional callable applied to each loaded image.
    **kwargs : forwarded to the base class constructor.
    """

    def __init__(
        self, path,
        attr      = 'Young',
        domain    = 'a',
        split     = SPLIT_TRAIN,
        transform = None,
        **kwargs
    ):
        # pylint: disable=too-many-arguments
        check_value_in_range(split, SPLITS, 'CelebaDataset: split')

        if attr is not None:
            check_value_in_range(domain, DOMAINS, 'CelebaDataset: domain')
        else:
            # Without an attribute there is nothing to split into domains,
            # so a domain must not be supplied.
            assert domain is None

        super().__init__(**kwargs)

        self._path      = path
        self._root_imgs = os.path.join(path, SUBDIR_IMG)
        self._split     = split
        self._attr      = attr
        self._domain    = domain
        self._imgs      = []
        self._transform = transform

        self._collect_files()

    def _collect_files(self):
        """Populate ``self._imgs`` with the full paths of selected images."""
        specs = CelebaDataset.load_image_specs(self._path)
        names = CelebaDataset.partition_images(
            specs, self._split, self._attr, self._domain
        )

        img_dir    = self._root_imgs
        self._imgs = [ os.path.join(img_dir, name) for name in names ]

    @staticmethod
    def load_image_partition(root):
        """Read the split file under ``root`` into a DataFrame.

        The file is parsed as whitespace-separated text without a header
        row. The result has a single ``'partition'`` column and is indexed
        by the first column of the file.
        """
        return pd.read_csv(
            os.path.join(root, FNAME_SPLIT),
            sep       = r'\s+',
            header    = None,
            names     = [ 'partition', ],
            index_col = 0,
        )

    @staticmethod
    def load_image_attrs(root):
        """Read the attribute file under ``root`` into a DataFrame.

        The first line of the file is skipped and the following line is
        used as the header. The result is indexed by the first column.
        """
        return pd.read_csv(
            os.path.join(root, FNAME_ATTRS),
            sep       = r'\s+',
            skiprows  = 1,
            header    = 0,
            index_col = 0,
        )

    @staticmethod
    def load_image_specs(root):
        """Return split assignments joined with attributes, by index."""
        partition = CelebaDataset.load_image_partition(root)
        attrs     = CelebaDataset.load_image_attrs(root)

        return partition.join(attrs)

    @staticmethod
    def partition_images(image_specs, split, attr, domain):
        """Return the index labels of images matching split and domain.

        Rows are kept when their ``partition`` value equals
        ``SPLITS[split]``. When ``attr`` is not ``None`` rows are filtered
        further: ``domain == 'a'`` keeps rows with ``attr`` > 0, any other
        ``domain`` keeps rows with ``attr`` < 0.
        """
        selected = (image_specs.partition == SPLITS[split])

        if attr is not None:
            values = image_specs[attr]
            in_domain = (values > 0) if domain == 'a' else (values < 0)
            selected = selected & in_domain

        return image_specs[selected].index.to_list()

    def __len__(self):
        """Return the number of selected images."""
        return len(self._imgs)

    def __getitem__(self, index):
        """Load image number ``index`` and apply the transform, if any."""
        image = default_loader(self._imgs[index])

        if self._transform is None:
            return image

        return self._transform(image)