-
Notifications
You must be signed in to change notification settings - Fork 5
/
msd_dataloader.py
105 lines (101 loc) · 4.92 KB
/
msd_dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import os
def msd_get_dataloaders(args):
    """Build the train / validation / test DataLoaders for MSDNet.

    Fields read from ``args``:
        data: ``'cifar10'`` or ``'cifar100'``; any other value is treated as
            ImageNet stored in ImageFolder layout under ``data_root``.
        data_root: dataset root directory.
        use_valid: if true, a fixed random subset of the *training* set is
            held out as the validation split and the official test/val set
            is used as the test split. If false, the official test/val set
            backs both ``val_loader`` and ``test_loader``.
        save: directory holding ``index.pth``, the persisted permutation
            that defines the held-out split (only used when ``use_valid``).
        splits: container of requested split names; only ``'train'``,
            ``'val'`` and ``'test'`` are checked (presumably a list of
            strings -- TODO confirm against the argument parser).
        batch_size, workers: forwarded to every ``DataLoader``.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Loaders for splits that
        were not requested are ``None``. When ``use_valid`` is false,
        ``val_loader`` and ``test_loader`` are the same object and are built
        if either ``'val'`` or ``'test'`` is requested.

    Raises:
        ValueError: if a previously saved ``index.pth`` does not match the
            size of the current training set.
    """
    train_set, val_set = _msd_build_datasets(args)
    train_loader, val_loader, test_loader = None, None, None
    # Settings shared by every loader built below.
    common = dict(batch_size=args.batch_size, num_workers=args.workers,
                  pin_memory=True)

    if args.use_valid:
        train_set_index = _msd_load_or_create_split_index(args, len(train_set))
        # Number of training samples held out for validation.
        num_sample_valid = 5000 if args.data.startswith('cifar') else 50000
        print("------------------------------------")
        print("split num_sample_valid: %d" % num_sample_valid)
        print("------------------------------------")
        # NOTE(review): the held-out validation samples are drawn from
        # ``train_set`` and therefore go through the *training* augmentation
        # (random crop / flip). Confirm this is intended before relying on
        # validation metrics.
        if 'train' in args.splits:
            train_loader = torch.utils.data.DataLoader(
                train_set,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(
                    train_set_index[:-num_sample_valid]),
                **common)
        if 'val' in args.splits:
            val_loader = torch.utils.data.DataLoader(
                train_set,
                sampler=torch.utils.data.sampler.SubsetRandomSampler(
                    train_set_index[-num_sample_valid:]),
                **common)
        if 'test' in args.splits:
            test_loader = torch.utils.data.DataLoader(
                val_set, shuffle=False, **common)
    else:
        if 'train' in args.splits:
            train_loader = torch.utils.data.DataLoader(
                train_set, shuffle=True, **common)
        # The official test/val set serves both splits; build it only when
        # at least one of them is requested.
        if 'val' in args.splits or 'test' in args.splits:
            val_loader = torch.utils.data.DataLoader(
                val_set, shuffle=False, **common)
            test_loader = val_loader
    return train_loader, val_loader, test_loader


def _msd_build_datasets(args):
    """Return ``(train_set, val_set)`` with MSDNet's transforms for ``args.data``."""
    if args.data in ('cifar10', 'cifar100'):
        if args.data == 'cifar10':
            dataset_cls = datasets.CIFAR10
            normalize = transforms.Normalize(mean=[0.4914, 0.4824, 0.4467],
                                             std=[0.2471, 0.2435, 0.2616])
        else:
            dataset_cls = datasets.CIFAR100
            normalize = transforms.Normalize(mean=[0.5071, 0.4867, 0.4408],
                                             std=[0.2675, 0.2565, 0.2761])
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        eval_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        # Both splits pass download=True so CIFAR-10 and CIFAR-100 behave
        # identically on a fresh data_root.
        train_set = dataset_cls(args.data_root, train=True, download=True,
                                transform=train_transform)
        val_set = dataset_cls(args.data_root, train=False, download=True,
                              transform=eval_transform)
        return train_set, val_set

    # Any other value: ImageNet in ImageFolder layout
    # (<data_root>/train and <data_root>/val).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_set = datasets.ImageFolder(
        os.path.join(args.data_root, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_set = datasets.ImageFolder(
        os.path.join(args.data_root, 'val'),
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    return train_set, val_set


def _msd_load_or_create_split_index(args, num_train):
    """Return a permutation of ``range(num_train)`` persisted in ``args.save``.

    The first run saves a fresh permutation to ``<args.save>/index.pth``.
    Later runs reload it, so the train/validation split stays the same.
    """
    # Drawn unconditionally (even when a saved permutation is reused) so the
    # default torch RNG advances the same way on every run.
    index = torch.randperm(num_train)
    index_path = os.path.join(args.save, 'index.pth')
    if os.path.exists(index_path):
        print('!!!!!! Load train_set_index !!!!!!')
        index = torch.load(index_path)
        if len(index) != num_train:
            raise ValueError(
                "%s holds a permutation of %d samples but the training set "
                "has %d; delete it or use a different save directory."
                % (index_path, len(index), num_train))
    else:
        print('!!!!!! Save train_set_index !!!!!!')
        # torch.save does not create missing parent directories.
        os.makedirs(args.save, exist_ok=True)
        torch.save(index, index_path)
    return index