forked from bfortuner/learning_data_aug
-
Notifications
You must be signed in to change notification settings - Fork 0
/
datasets.py
102 lines (84 loc) · 3.9 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import numpy as np
import torch
from torch.autograd import Variable
import torchvision
from torchvision import transforms
# Human-readable CIFAR-100 fine-label names, indexed by label id, so that
# ``CIFAR_CLASSES[target]`` names the class of a CIFAR100 sample.
#
# Order matters: torchvision's CIFAR100 yields fine-label targets 0..99 in the
# order of the dataset's ``fine_label_names`` metadata, which is alphabetical
# (0 = apple, 1 = aquarium_fish, ..., 99 = worm) -- NOT grouped by superclass.
# The strings below are friendlier display names; cross-check against
# ``trainset.classes`` if in doubt.
CIFAR_CLASSES = (
    "apples", "aquarium fish", "baby", "bear", "beaver",
    "bed", "bee", "beetle", "bicycle", "bottles",
    "bowls", "boy", "bridge", "bus", "butterfly",
    "camel", "cans", "castle", "caterpillar", "cattle",
    "chair", "chimpanzee", "clock", "cloud", "cockroach",
    "couch", "crab", "crocodile", "cups", "dinosaur",
    "dolphin", "elephant", "flatfish", "forest", "fox",
    "girl", "hamster", "house", "kangaroo", "computer keyboard",
    "lamp", "lawn-mower", "leopard", "lion", "lizard",
    "lobster", "man", "maple", "motorcycle", "mountain",
    "mouse", "mushrooms", "oak", "oranges", "orchids",
    "otter", "palm", "pears", "pickup truck", "pine",
    "plain", "plates", "poppies", "porcupine", "possum",
    "rabbit", "raccoon", "ray", "road", "rocket",
    "roses", "sea", "seal", "shark", "shrew",
    "skunk", "skyscraper", "snail", "snake", "spider",
    "squirrel", "streetcar", "sunflowers", "sweet peppers", "table",
    "tank", "telephone", "television", "tiger", "tractor",
    "train", "trout", "tulips", "turtle", "wardrobe",
    "whale", "willow", "wolf", "woman", "worm",
)
# Evaluation pipeline: tensor conversion only, no augmentation.
TST_TRANSFORM = transforms.Compose([transforms.ToTensor()])

# Training pipeline: random horizontal flip as augmentation, then tensor
# conversion.
TRN_TRANSFORM = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)
class LearnedTransform:
    """Expose a model's ``transform`` method as a one-argument callable.

    Each call wraps the input in ``Variable`` and returns whatever
    ``model.transform`` produces for it. This makes the model usable wherever
    a plain ``transform(x)`` callable is expected.
    """

    def __init__(self, model):
        # Assumed to be an object (likely an nn.Module) that defines a
        # ``transform`` method -- TODO confirm against callers.
        self.model = model

    def __call__(self, x):
        # ``Variable`` is the legacy autograd wrapper (deprecated since
        # PyTorch 0.4, where it returns a Tensor); kept so the model receives
        # exactly the same input as before.
        model_input = Variable(x)
        return self.model.transform(model_input)
def get_cifar_dataset(trn_transform=TRN_TRANSFORM, tst_transform=TST_TRANSFORM,
                      root='data/'):
    """Load the CIFAR-100 train and test splits, downloading them if missing.

    Args:
        trn_transform: Transform applied to every training image. Defaults to
            ``TRN_TRANSFORM``.
        tst_transform: Transform applied to every test image. Defaults to
            ``TST_TRANSFORM``.
        root: Directory the dataset is read from / downloaded into.

    Returns:
        ``(trainset, testset)`` as ``torchvision.datasets.CIFAR100`` objects.

    Note:
        For pad-and-random-crop augmentation (as in
        https://github.com/ne7ermore/torch-light/blob/master/DenseNet/train.py),
        pass a ``trn_transform`` such as
        ``transforms.Compose([transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4), transforms.ToTensor()])``.
    """
    trainset = torchvision.datasets.CIFAR100(
        root=root, train=True, download=True, transform=trn_transform)
    testset = torchvision.datasets.CIFAR100(
        root=root, train=False, download=True, transform=tst_transform)
    return trainset, testset
def get_cifar_loader(trainset, testset, batch_size=64, num_workers=2):
    """Wrap CIFAR train/test datasets in ``DataLoader`` objects.

    Args:
        trainset: Training dataset; its loader reshuffles every epoch.
        testset: Evaluation dataset; its loader preserves dataset order.
        batch_size: Samples per batch, used for both loaders.
        num_workers: Loader worker subprocesses per loader. ``0`` loads data
            in the main process, which is useful for debugging.

    Returns:
        ``(trainloader, testloader)``.
    """
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return trainloader, testloader
def _truncate_mnist(dataset, size, legacy_prefix):
    """Keep only the first ``size`` samples of an MNIST dataset, in place.

    Newer torchvision stores samples in ``data`` / ``targets`` and exposes the
    old ``train_data`` / ``train_labels`` (``test_*``) names only as read-only,
    deprecated properties, so assigning to those raises ``AttributeError``.
    Older releases have only the legacy ``<prefix>_data`` / ``<prefix>_labels``
    attributes. Both layouts are handled here.

    Args:
        dataset: A ``torchvision.datasets.MNIST`` instance.
        size: Number of leading samples to keep (``None`` keeps all).
        legacy_prefix: ``'train'`` or ``'test'``; selects the legacy attribute
            names for old torchvision versions.
    """
    if hasattr(dataset, 'targets'):
        dataset.data = dataset.data[:size]
        dataset.targets = dataset.targets[:size]
    else:
        data_attr = legacy_prefix + '_data'
        labels_attr = legacy_prefix + '_labels'
        setattr(dataset, data_attr, getattr(dataset, data_attr)[:size])
        setattr(dataset, labels_attr, getattr(dataset, labels_attr)[:size])


def get_mnist_dataset(trn_size=60000, tst_size=10000, normalize=False,
                      root='../data'):
    """Load MNIST train/test splits, optionally truncated to a prefix.

    Args:
        trn_size: Keep only the first ``trn_size`` training samples.
        tst_size: Keep only the first ``tst_size`` test samples.
        normalize: If True, append ``transforms.Normalize`` with the MNIST
            mean/std after ``ToTensor``. Off by default, matching the
            previous behavior where normalization was commented out.
        root: Directory the dataset is read from / downloaded into.

    Returns:
        ``(trainset, testset)`` as ``torchvision.datasets.MNIST`` objects.
    """
    mnist_mean = (0.1307,)
    mnist_std = (0.3081,)

    def _make_transform():
        # A fresh pipeline per split, so the two datasets never share one
        # Compose object.
        steps = [transforms.ToTensor()]
        if normalize:
            steps.append(transforms.Normalize(mnist_mean, mnist_std))
        return transforms.Compose(steps)

    trainset = torchvision.datasets.MNIST(
        root=root, train=True, download=True, transform=_make_transform())
    _truncate_mnist(trainset, trn_size, 'train')

    testset = torchvision.datasets.MNIST(
        root=root, train=False, download=True, transform=_make_transform())
    _truncate_mnist(testset, tst_size, 'test')

    return trainset, testset
def get_mnist_loader(trainset, testset, batch_size=128, num_workers=2):
    """Wrap MNIST train/test datasets in ``DataLoader`` objects.

    Args:
        trainset: Training dataset; its loader reshuffles every epoch.
        testset: Evaluation dataset; its loader preserves dataset order.
        batch_size: Samples per batch, used for both loaders.
        num_workers: Loader worker subprocesses per loader. ``0`` loads data
            in the main process, which is useful for debugging.

    Returns:
        ``(trainloader, testloader)``.
    """
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return trainloader, testloader