-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdataset_wrapper.py
120 lines (99 loc) · 4.12 KB
/
dataset_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Wrapper classes for original and encoded datasets."""
from keras.datasets import mnist, cifar10
import numpy as np
import matplotlib.pyplot as plt
import h5py
import utils
import stl_dataset
class DatasetWrapper(object):
    """Container for the train/test splits of one (original or encoded) dataset.

    Subclasses implement ``load_default`` to fetch and preprocess their data;
    ``load_from_h5`` / ``dump_to_h5`` round-trip any wrapper through HDF5.
    Filtering by class name in ``get_subset`` requires the subclass to define
    a ``cls2idx`` mapping (class name -> label value).
    """

    # Dataset names used inside the HDF5 file, in constructor-argument order.
    _h5_keys = ('train_xs', 'train_ys', 'test_xs', 'test_ys')

    def __init__(self, train_xs, train_ys, test_xs, test_ys):
        """Store the four arrays; inputs are cast to float32, labels kept as-is.

        DO NOT do any normalization in this function -- preprocessing belongs
        in the loaders (e.g. ``load_default``).
        """
        self.train_xs = train_xs.astype(np.float32)
        self.train_ys = train_ys
        self.test_xs = test_xs.astype(np.float32)
        self.test_ys = test_ys

    @property
    def x_shape(self):
        """Shape of a single training example (everything after the batch axis)."""
        return self.train_xs.shape[1:]

    @classmethod
    def load_from_h5(cls, h5_path):
        """Build a wrapper from an HDF5 file laid out the way ``dump_to_h5`` writes it.

        Raises KeyError if any of 'train_xs', 'train_ys', 'test_xs' or
        'test_ys' is missing from the file.
        """
        with h5py.File(h5_path, 'r') as hf:
            # Index with [] rather than .get(): a missing dataset should fail
            # here, naming the key, instead of later as np.array(None).
            # np.array(...) copies the data into memory before the file closes.
            arrays = [np.array(hf[key]) for key in cls._h5_keys]
        print('Dataset loaded from %s' % h5_path)
        return cls(*arrays)

    @classmethod
    def load_default(cls):
        """Load this dataset from its default source; subclasses must override."""
        raise NotImplementedError

    def dump_to_h5(self, h5_path):
        """Write the four arrays to ``h5_path`` (file opened in 'w' mode)."""
        with h5py.File(h5_path, 'w') as hf:
            for key in self._h5_keys:
                hf.create_dataset(key, data=getattr(self, key))
        print('Dataset written to %s' % h5_path)

    def reshape(self, new_shape):
        """Reshape every example of both splits to ``new_shape``.

        new_shape: per-example shape excluding the batch axis; a tuple, a list,
            or a single int. The batch axis of each split is preserved.

        Both splits are reshaped before either is stored, so on failure the
        wrapper is left unchanged.

        Raises ValueError if the resulting per-example shapes of the train and
        test splits differ (possible when ``new_shape`` contains -1).
        """
        if np.isscalar(new_shape):
            new_shape = (new_shape,)
        new_shape = tuple(new_shape)
        train_xs = self.train_xs.reshape((self.train_xs.shape[0],) + new_shape)
        test_xs = self.test_xs.reshape((self.test_xs.shape[0],) + new_shape)
        if train_xs.shape[1:] != test_xs.shape[1:]:
            raise ValueError('train/test example shapes differ after reshape: '
                             '%s vs %s' % (train_xs.shape[1:], test_xs.shape[1:]))
        self.train_xs = train_xs
        self.test_xs = test_xs

    def plot_data_dist(self, fig_path, num_bins=50):
        """Plot a histogram of the input values of both splits combined.

        fig_path: file to save the figure to; if falsy, the figure is shown
            interactively instead.
        num_bins: number of histogram bins.
        """
        # Join along the batch axis. (np.vstack would instead stack 1-D
        # inputs as two rows, requiring equal lengths.)
        xs = np.concatenate((self.train_xs, self.test_xs), axis=0)
        if xs.ndim > 2:
            # plt.hist accepts at most 2-D input; flatten each example.
            xs = xs.reshape((len(xs), -1))
        # NOTE: for 2-D input plt.hist treats each column as its own dataset,
        # i.e. one histogram per feature.
        plt.hist(xs, num_bins)
        if fig_path:
            plt.savefig(fig_path)
            plt.close()
        else:
            plt.show()

    def get_subset(self, subset, subclass=None):
        """Return ``(xs, ys)`` for one split, optionally filtered to one class.

        subset: 'train' selects the training split; any other value selects
            the test split.
        subclass: name of the class of interest, looked up in
            ``self.cls2idx`` (only subclasses defining it, such as
            Cifar10Wrapper, support this); falsy means no filtering.
        """
        is_train = subset == 'train'
        xs = self.train_xs if is_train else self.test_xs
        ys = self.train_ys if is_train else self.test_ys
        assert len(xs) == len(ys)
        if subclass:
            idx = self.cls2idx[subclass]
            # [0] takes the row indices, so ys may be (N,) or (N, 1).
            loc = np.where(ys == idx)[0]
            xs = xs[loc]
            ys = ys[loc]
        return xs, ys
class MnistWrapper(DatasetWrapper):
    """MNIST digits scaled by 1/255 and given a trailing channel axis."""

    @classmethod
    def load_default(cls):
        """Load MNIST via keras and return it as (N, 28, 28, 1) arrays."""
        (train_xs, train_ys), (test_xs, test_ys) = mnist.load_data()

        def _to_nhwc(images):
            # Divide by 255 (presumably raw 8-bit pixels) and add one channel.
            return (images / 255.0).reshape(-1, 28, 28, 1)

        return cls(_to_nhwc(train_xs), train_ys, _to_nhwc(test_xs), test_ys)
class Cifar10Wrapper(DatasetWrapper):
    """CIFAR-10 images, passed through ``utils.preprocess_cifar10``."""

    # Class names in label order: label i corresponds to idx2cls[i].
    idx2cls = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
    # Inverse lookup (class name -> label), used by get_subset(subclass=...).
    cls2idx = dict(zip(idx2cls, range(len(idx2cls))))

    @classmethod
    def load_default(cls):
        """Load CIFAR-10 via keras and preprocess the images of both splits."""
        (train_xs, train_ys), (test_xs, test_ys) = cifar10.load_data()
        return cls(utils.preprocess_cifar10(train_xs), train_ys,
                   utils.preprocess_cifar10(test_xs), test_ys)
class STL10Wrapper(DatasetWrapper):
    """STL-10: images from UNLABELED_DATA_PATH as train, DATA_PATH as test."""

    @classmethod
    def load_default(cls):
        """Read both STL-10 image sets from disk and preprocess them."""
        raw_train = stl_dataset.read_all_images(stl_dataset.UNLABELED_DATA_PATH)
        # The training images come without labels; all-zero uint8 placeholders
        # keep the (xs, ys) interface shared with the other wrappers.
        placeholder_ys = np.zeros(len(raw_train), dtype=np.uint8)
        raw_test = stl_dataset.read_all_images(stl_dataset.DATA_PATH)
        test_ys = stl_dataset.read_labels(stl_dataset.LABEL_PATH)
        return cls(utils.preprocess_stl10(raw_train), placeholder_ys,
                   utils.preprocess_stl10(raw_test), test_ys)
if __name__ == '__main__':
    # Smoke test: load the default MNIST and CIFAR-10 datasets.
    # To inspect value distributions, call e.g.
    # `mnist_dataset.plot_data_dist(None)`, which shows the histogram
    # interactively instead of saving it.
    mnist_dataset, cifar10_dataset = (
        MnistWrapper.load_default(),
        Cifar10Wrapper.load_default(),
    )