forked from KGPML/Hyperspectral
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSpatial_dataset.py
119 lines (76 loc) · 3.02 KB
/
Spatial_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# coding: utf-8
# In[1]:
import tensorflow as tf
import numpy as np
import scipy.io as io
# In[2]:
""" Functions for handling the IndianPines data"""
class DataSet(object):
    """In-memory container of image patches and labels with mini-batch access.

    Images are transposed to channels-last and flattened to 2-D
    ``[num_examples, height * width * channels]`` at construction time.
    ``next_batch`` serves examples sequentially and, each time an epoch is
    exhausted, shuffles images and labels with one shared permutation.
    """

    def __init__(self, images, labels, dtype=tf.float32):
        """Construct a DataSet.

        Args:
            images: 4-D array laid out as
                ``[num_examples, channels, height, width]``. It is transposed
                to ``[num_examples, height, width, channels]`` and then
                flattened per example.
            labels: array whose transpose has ``num_examples`` rows. The
                loader presumably supplies a ``(1, num_examples)`` row vector
                read from a .mat file; TODO confirm.
            dtype: must resolve to ``tf.uint8`` or ``tf.float32``. It is only
                validated: image values are stored unchanged, and no
                rescaling to ``[0, 1]`` is performed.

        Raises:
            TypeError: if ``dtype`` is neither uint8 nor float32.
            AssertionError: if images and labels disagree on the number of
                examples.
        """
        # Channels-first -> channels-last:
        # [num_examples, channels, height, width]
        #   -> [num_examples, height, width, channels]
        images = np.transpose(images, (0, 2, 3, 1))
        labels = np.transpose(labels)

        dtype = tf.as_dtype(dtype).base_dtype
        if dtype not in (tf.uint8, tf.float32):
            raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
                            dtype)
        assert images.shape[0] == labels.shape[0], (
            'images.shape: %s labels.shape: %s' % (images.shape, labels.shape))
        self._num_examples = images.shape[0]

        # Flatten each example:
        # [num_examples, rows, columns, depth] -> [num_examples, rows*columns*depth].
        # The explicit product (rather than -1) keeps this valid for 0 examples.
        images = images.reshape(
            images.shape[0], images.shape[1] * images.shape[2] * images.shape[3])

        self._images = images
        self._labels = labels
        self._epochs_completed = 0
        self._index_in_epoch = 0

    @property
    def images(self):
        """Flattened images, shape ``[num_examples, height * width * channels]``."""
        return self._images

    @property
    def labels(self):
        """Labels as stored after the constructor's transpose (not flattened)."""
        return self._labels

    @property
    def num_examples(self):
        """Number of examples held by this data set."""
        return self._num_examples

    @property
    def epochs_completed(self):
        """How many times ``next_batch`` has wrapped past the end of the data."""
        return self._epochs_completed

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` examples from this data set.

        Examples are served in stored order. When fewer than ``batch_size``
        examples remain in the current epoch, the following happens:

        * the leftovers are skipped;
        * the epoch counter is incremented;
        * images and labels are shuffled with the same permutation;
        * the batch is taken from the start of the new order.

        The first epoch is therefore served unshuffled.

        Args:
            batch_size: number of examples to return. It must not exceed
                ``num_examples``.

        Returns:
            A tuple ``(images, labels)``. ``images`` is a slice of the
            flattened image matrix. ``labels`` is reshaped to 1-D with one
            entry per example, which requires exactly one label value per
            row.

        Raises:
            AssertionError: if ``batch_size`` exceeds ``num_examples``.
        """
        start = self._index_in_epoch
        self._index_in_epoch += batch_size
        if self._index_in_epoch > self._num_examples:
            # Finished epoch
            self._epochs_completed += 1
            # Shuffle images and labels with one shared permutation so the
            # pairs stay aligned.
            perm = np.arange(self._num_examples)
            np.random.shuffle(perm)
            self._images = self._images[perm]
            self._labels = self._labels[perm]
            # Start next epoch
            start = 0
            self._index_in_epoch = batch_size
            assert batch_size <= self._num_examples
        end = self._index_in_epoch
        batch_labels = self._labels[start:end]
        return self._images[start:end], np.reshape(batch_labels, len(batch_labels))
# In[20]:
def read_data_sets(directory, value, dtype=tf.float32):
    """Load one split from a MATLAB .mat file and wrap it in a DataSet.

    Args:
        directory: passed straight to ``scipy.io.loadmat``. Despite the name,
            this is the .mat file path (or a file-like object), not a folder.
        value: prefix of the variable names to read. ``<value>_patch`` holds
            the images and ``<value>_labels`` holds the labels.
        dtype: forwarded unchanged to ``DataSet``.

    Returns:
        A ``DataSet`` built from the two arrays.

    Raises:
        KeyError: if either variable is missing from the file.
    """
    # Parse the .mat file once; the original parsed it once per variable.
    mat = io.loadmat(directory)
    images = mat[value + '_patch']
    labels = mat[value + '_labels']
    return DataSet(images, labels, dtype=dtype)
# In[ ]:
# In[ ]: