-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_digits.py
59 lines (43 loc) · 1.91 KB
/
plot_digits.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
import matplotlib.pyplot as plt
def plot_digits(digit_array):
    """Visualizes each example in digit_array, one pane at a time.

    The rows are assumed to be split evenly between two classes: the first
    half belongs to one class and the second half to the other. Each pane
    shows up to five examples of the first class on the top row and the
    corresponding examples of the second class on the bottom row.

    Note: N is the number of examples
    and M is the number of features per example.

    Inputs:
        digit_array: N x M array of pixel intensities. Each row is reshaped
            to 28 x 28 by extract_digit_pixels, so M must be 784. If N is
            odd, the final row is never displayed.
    """
    CLASS_EXAMPLES_PER_PANE = 5
    # Assume two evenly split classes: rows [0, half) and [half, 2 * half).
    # Floor division keeps this an int on both Python 2 and Python 3.
    examples_per_class = digit_array.shape[0] // 2
    # Integer ceiling division: a final, partially filled pane still counts.
    num_panes = ((examples_per_class + CLASS_EXAMPLES_PER_PANE - 1)
                 // CLASS_EXAMPLES_PER_PANE)
    for pane in range(num_panes):
        print("Displaying pane {}/{}".format(pane + 1, num_panes))
        top_start = pane * CLASS_EXAMPLES_PER_PANE
        top_end = min(top_start + CLASS_EXAMPLES_PER_PANE, examples_per_class)
        top_pane_digits = extract_digits(digit_array, top_start, top_end)
        # The matching examples of the second class sit exactly
        # examples_per_class rows further down in the array.
        bottom_start = top_start + examples_per_class
        bottom_end = top_end + examples_per_class
        bottom_pane_digits = extract_digits(digit_array, bottom_start, bottom_end)
        show_pane(top_pane_digits, bottom_pane_digits)
def extract_digits(digit_array, start_index, end_index):
    """Returns a list of 28 x 28 pixel intensity arrays, one per row of
    digit_array from start_index (inclusive) to end_index (exclusive).

    Inputs:
        digit_array: N x M array of pixel intensities (M must be 784).
        start_index: first row to extract.
        end_index: one past the last row to extract; an empty list is
            returned when end_index <= start_index.
    """
    # range (not xrange) works on both Python 2 and Python 3.
    return [extract_digit_pixels(digit_array, index)
            for index in range(start_index, end_index)]
def extract_digit_pixels(digit_array, index):
    """Returns row `index` of digit_array laid out as a 28 x 28 pixel grid.

    The selected row must contain exactly 28 * 28 = 784 values for the
    reshape to succeed.
    """
    row = digit_array[index]
    return row.reshape((28, 28))
def show_pane(top_digits, bottom_digits):
    """Displays two rows of digits on the screen in a single figure.

    Inputs:
        top_digits: list of 2-D pixel arrays drawn left to right on the top row.
        bottom_digits: list of 2-D pixel arrays drawn on the bottom row; the
            layout assumes it has the same length as top_digits.
    """
    all_digits = top_digits + bottom_digits
    # Floor division keeps ncols an int on Python 3 (plt.subplots rejects floats).
    ncols = len(all_digits) // 2
    fig, axes = plt.subplots(nrows=2, ncols=ncols)
    # axes.reshape(-1) flattens the grid row-major, so the top-row digits
    # land in row 0 and the bottom-row digits in row 1.
    for axis, digit in zip(axes.reshape(-1), all_digits):
        # Pass the colormap by name; plt.gray() returns None and would also
        # change the global default colormap as a side effect.
        axis.imshow(digit, interpolation='nearest', cmap='gray')
        axis.axis('off')
    plt.show()