-
Notifications
You must be signed in to change notification settings - Fork 4
/
data.py
123 lines (98 loc) · 3.78 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import os
import torch
import torchvision
from toytools.datasets import get_toyzero_dataset_torch
from uvcgan.consts import (
ROOT_DATA, SPLIT_TRAIN, MERGE_PAIRED, MERGE_UNPAIRED
)
from uvcgan.torch.select import extract_name_kwargs
from .datasets.celeba import CelebaDataset
from .datasets.image_domain_folder import ImageDomainFolder
from .datasets.image_domain_hierarchy import ImageDomainHierarchy
from .datasets.ndarray_domain_hierarchy import NDArrayDomainHierarchy
from .datasets.zipper import DatasetZipper
from .datasets.custom_dataset import custom_dataset
from .loader_zipper import DataLoaderZipper
from .transforms import select_transform
def select_dataset(name, path, split, transform, **kwargs):
    """Construct a dataset selected by its registered `name`.

    Parameters
    ----------
    name : str
        Dataset identifier. Known names are dispatched to the matching
        dataset class below; any unrecognized name falls through to the
        toyzero dataset family.
    path : str
        Root directory of the dataset on disk.
    split : str
        Data split to load (e.g. train/test).
    transform : callable or None
        Transformation applied to each sample.
    kwargs : dict
        Extra keyword arguments forwarded to the dataset constructor.
        For ``name == 'custom'`` this must contain a 'dataset' entry
        pointing at the user's dataset API.

    Returns
    -------
    torch.utils.data.Dataset

    Raises
    ------
    ValueError
        If ``name == 'custom'`` and no 'dataset' entry is provided.
    """
    if name == 'celeba':
        return CelebaDataset(
            path, transform = transform, split = split, **kwargs
        )

    if name in [ 'cyclegan', 'image-domain-folder' ]:
        return ImageDomainFolder(
            path, transform = transform, split = split, **kwargs
        )

    if name in [ 'image-domain-hierarchy' ]:
        return ImageDomainHierarchy(
            path, transform = transform, split = split, **kwargs
        )

    if name == 'ndarray-domain-hierarchy':
        return NDArrayDomainHierarchy(
            path, transform = transform, split = split, **kwargs
        )

    if name == 'imagenet':
        return torchvision.datasets.ImageNet(
            path, transform = transform, split = split, **kwargs
        )

    if name in [ 'imagedir', 'image-folder' ]:
        # ImageFolder has no `split` kwarg -- the split selects a subdir.
        return torchvision.datasets.ImageFolder(
            os.path.join(path, split), transform = transform, **kwargs
        )

    if name == 'custom':
        # Raise explicitly instead of `assert`: asserts are stripped
        # when python runs with -O, silently skipping this validation.
        if 'dataset' not in kwargs:
            raise ValueError('a path to your dataset API must be provided')

        dataset = kwargs.pop('dataset')
        return custom_dataset(dataset, path, split = split, **kwargs)

    # Fallback: treat the name as one of the toyzero datasets.
    return get_toyzero_dataset_torch(
        name, path, transform = transform, split = split, **kwargs
    )
def construct_single_dataset(dataset_config, split):
    """Build one dataset described by `dataset_config` for the given split.

    The dataset spec is unpacked into a name and keyword arguments; an
    optional 'path' kwarg overrides the default data directory (which is
    `ROOT_DATA/<name>`). Training splits get the train transform, every
    other split gets the test transform.
    """
    name, kwargs = extract_name_kwargs(dataset_config.dataset)
    data_path    = os.path.join(ROOT_DATA, kwargs.pop('path', name))

    transform_config = (
        dataset_config.transform_train if split == SPLIT_TRAIN
        else dataset_config.transform_test
    )
    transform = select_transform(transform_config)

    return select_dataset(name, data_path, split, transform, **kwargs)
def construct_datasets(data_config, split):
    """Construct every dataset listed in `data_config.datasets`.

    Returns a list with one dataset per configuration entry, all built
    for the same `split`.
    """
    return [
        construct_single_dataset(cfg, split) for cfg in data_config.datasets
    ]
def construct_single_loader(
    dataset, batch_size, shuffle,
    workers         = None,
    prefetch_factor = 20,
    **kwargs
):
    """Wrap `dataset` in a pinned-memory DataLoader.

    Parameters
    ----------
    dataset : torch.utils.data.Dataset
        Dataset to load batches from.
    batch_size : int
        Number of samples per batch.
    shuffle : bool
        Whether to reshuffle the data every epoch.
    workers : int or None
        Number of loader worker processes. If None, defaults to the
        torch thread count, capped at 20.
    prefetch_factor : int
        Batches prefetched per worker. Ignored when `workers == 0`,
        since DataLoader rejects it without multiprocessing workers.
    kwargs : dict
        Extra keyword arguments forwarded to DataLoader.

    Returns
    -------
    torch.utils.data.DataLoader
    """
    if workers is None:
        # Cap at 20 so machines with many cores do not spawn an
        # excessive number of loader processes.
        workers = min(torch.get_num_threads(), 20)

    # `prefetch_factor` may only be specified when num_workers > 0;
    # passing it with workers == 0 makes DataLoader raise ValueError.
    prefetch_kwargs = (
        { } if workers == 0 else { 'prefetch_factor' : prefetch_factor }
    )

    return torch.utils.data.DataLoader(
        dataset, batch_size,
        shuffle     = shuffle,
        num_workers = workers,
        pin_memory  = True,
        **prefetch_kwargs,
        **kwargs
    )
def construct_data_loaders(data_config, batch_size, split):
    """Build the data loader(s) described by `data_config`.

    Behavior depends on `data_config.merge_type`:
      * MERGE_PAIRED   -- datasets are zipped sample-wise into one
        dataset and served by a single loader.
      * MERGE_UNPAIRED -- each dataset gets its own loader (with
        `drop_last` enabled) and the loaders are zipped together.
      * otherwise      -- a single loader is returned when there is
        exactly one dataset, else the list of loaders.

    Shuffling is enabled only for the training split.
    """
    datasets = construct_datasets(data_config, split)
    shuffle  = (split == SPLIT_TRAIN)
    merge    = data_config.merge_type

    if merge == MERGE_PAIRED:
        return construct_single_loader(
            DatasetZipper(datasets), batch_size, shuffle,
            data_config.workers, drop_last = False
        )

    drop_last = (merge == MERGE_UNPAIRED)

    loaders = [
        construct_single_loader(
            ds, batch_size, shuffle, data_config.workers,
            drop_last = drop_last
        ) for ds in datasets
    ]

    if merge == MERGE_UNPAIRED:
        return DataLoaderZipper(loaders)

    return loaders[0] if len(loaders) == 1 else loaders