-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
63 lines (50 loc) · 1.89 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Train and test dataset classes for CIFAR-100.
Author: baiyu
"""
import os
import sys
import pickle
from skimage import io
import matplotlib.pyplot as plt
import numpy
import torch
from torch.utils.data import Dataset
class CIFAR100Train(Dataset):
    """CIFAR-100 training split, derived from ``torch.utils.data.Dataset``.

    The pickled ``train`` file under *path* is loaded into memory once, at
    construction time. The unpickled object must be a dict with the bytes
    keys ``b'fine_labels'`` (indexable labels) and ``b'data'`` (a 2-D array
    with one flat 3072-value image per row).

    Each item is a ``(label, image)`` pair. ``label`` is
    ``fine_labels[index]``. ``image`` is a ``(32, 32, 3)`` array built from
    row ``index`` of ``data``, or whatever ``transform`` returns for that
    array when a transform was given.
    """

    # Keys are bytes because the file is unpickled with encoding='bytes'.
    _LABELS_KEY = b'fine_labels'
    _DATA_KEY = b'data'

    def __init__(self, path, transform=None):
        """Load the training split.

        Args:
            path: directory containing the pickled ``train`` file.
            transform: optional callable. When given (truthy), it is applied
                to each ``(32, 32, 3)`` image array in ``__getitem__``.
        """
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only point this at trusted dataset files.
        with open(os.path.join(path, 'train'), 'rb') as cifar100:
            self.data = pickle.load(cifar100, encoding='bytes')
        self.transform = transform

    def __len__(self):
        """Return the number of samples (the length of the label list)."""
        return len(self.data[self._LABELS_KEY])

    def __getitem__(self, index):
        """Return ``(label, image)`` for sample *index*.

        The flat row stores three 1024-value (32x32) channel planes back to
        back, presumably in R, G, B order. They are stacked into a
        height x width x channel array.
        """
        label = self.data[self._LABELS_KEY][index]
        row = self.data[self._DATA_KEY][index]
        # (3072,) -> (3, 32, 32) channel planes -> (32, 32, 3). Copy so that
        # a transform editing the image in place cannot corrupt self.data.
        image = row.reshape(3, 32, 32).transpose(1, 2, 0).copy()
        if self.transform:
            image = self.transform(image)
        return label, image
class CIFAR100Test(Dataset):
    """CIFAR-100 test split exposed as a ``torch.utils.data.Dataset``.

    The pickled ``test`` file found under *path* is unpickled into
    ``self.data`` when the object is constructed. Indexing yields
    ``(label, image)`` tuples.
    """

    def __init__(self, path, transform=None):
        """Unpickle ``<path>/test`` and keep the optional *transform*.

        Args:
            path: directory holding the pickled ``test`` file.
            transform: optional callable. When truthy, it is applied to every
                stacked image array returned by ``__getitem__``.
        """
        # Unpickling can run arbitrary code: use trusted files only.
        source_file = os.path.join(path, 'test')
        with open(source_file, 'rb') as handle:
            loaded = pickle.load(handle, encoding='bytes')
        self.data = loaded
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the ``data`` array."""
        pixel_rows = self.data['data'.encode()]
        return len(pixel_rows)

    def __getitem__(self, index):
        """Return ``(label, image)`` for sample *index*.

        Row *index* of the ``data`` array holds three consecutive
        1024-value planes. Each is reshaped to 32x32, and the three are
        stacked depth-wise into a (32, 32, 3) array. *transform*, when set,
        is applied to that array.
        """
        target = self.data['fine_labels'.encode()][index]
        pixel_rows = self.data['data'.encode()]
        plane_bounds = ((0, 1024), (1024, 2048), (2048, None))
        planes = [
            pixel_rows[index, lo:hi].reshape(32, 32)
            for lo, hi in plane_bounds
        ]
        stacked = numpy.dstack(planes)
        if self.transform:
            stacked = self.transform(stacked)
        return target, stacked