-
Notifications
You must be signed in to change notification settings - Fork 117
/
Copy pathTrainAndTest.py
153 lines (108 loc) · 8.19 KB
/
TrainAndTest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# TrainAndTest.py
import cv2
import numpy as np
import operator
import os
# module level variables ##########################################################################
# Contours whose area is below this value are rejected by ContourWithData.checkIfContourIsValid().
MIN_CONTOUR_AREA = 100
# Every cropped character is resized to WIDTH x HEIGHT in main() and then flattened to a single
# row of WIDTH * HEIGHT values before being handed to the KNN classifier.
RESIZED_IMAGE_WIDTH = 20
RESIZED_IMAGE_HEIGHT = 30
###################################################################################################
class ContourWithData():
    """Plain record tying one contour to its bounding rectangle and area.

    The attributes below are class-level defaults; the caller is expected to
    assign ``npaContour``, ``boundingRect`` and ``fltArea`` on each instance
    and then call ``calculateRectTopLeftPointAndWidthAndHeight()`` to fill in
    the four ``intRect*`` fields.
    """

    # defaults shared by every instance until overwritten #########################################
    npaContour = None           # the contour itself
    boundingRect = None         # bounding rect of the contour, unpacked as (x, y, width, height)
    intRectX = 0                # x of the bounding rect's top left corner
    intRectY = 0                # y of the bounding rect's top left corner
    intRectWidth = 0            # width of the bounding rect
    intRectHeight = 0           # height of the bounding rect
    fltArea = 0.0               # area enclosed by the contour

    def calculateRectTopLeftPointAndWidthAndHeight(self):
        """Split ``self.boundingRect`` into the separate ``intRect*`` attributes."""
        (self.intRectX,
         self.intRectY,
         self.intRectWidth,
         self.intRectHeight) = self.boundingRect

    def checkIfContourIsValid(self):
        """Return False when the contour area is below MIN_CONTOUR_AREA, True otherwise.

        This is deliberately simplistic; a production grade program would need
        much better validity checking.
        """
        return not (self.fltArea < MIN_CONTOUR_AREA)
###################################################################################################
def _pauseForUser():
    """Hold the console open so the user can read an error message.

    ``pause`` is a Windows shell built-in; on other platforms the command does
    not exist, so the call is skipped there instead of producing a shell error.
    """
    if os.name == "nt":
        os.system("pause")


def _loadTrainingArray(strFileName):
    """Read ``strFileName`` with ``np.loadtxt`` as float32; return None if that fails.

    Only I/O errors (missing or unreadable file) and ValueError (contents that
    cannot be parsed as numbers) are treated as "unable to load"; anything else,
    e.g. KeyboardInterrupt, propagates.
    """
    try:
        return np.loadtxt(strFileName, np.float32)
    except (IOError, OSError, ValueError):
        return None


def main():
    """Train a KNN classifier from the saved training data and read the characters in test1.png.

    Steps:
      1. load "classifications.txt" and "flattened_images.txt" and train a k-nearest model,
      2. threshold "test1.png" and find its outermost contours,
      3. keep the contours that pass ``ContourWithData.checkIfContourIsValid()`` and sort them
         left to right,
      4. crop, resize and classify each one (k = 1), drawing a green box around it,
      5. print the recognised string and show the annotated image until a key is pressed.

    On any load failure an error message is printed and the function returns early.
    Returns None in every case.
    """
    npaClassifications = _loadTrainingArray("classifications.txt")        # read in training classifications
    if npaClassifications is None:
        print("error, unable to open classifications.txt, exiting program\n")
        _pauseForUser()
        return
    # end if

    npaFlattenedImages = _loadTrainingArray("flattened_images.txt")       # read in training images
    if npaFlattenedImages is None:
        print("error, unable to open flattened_images.txt, exiting program\n")
        _pauseForUser()
        return
    # end if

    # the responses passed to train() must be a column, one row per training sample
    npaClassifications = npaClassifications.reshape((npaClassifications.size, 1))

    kNearest = cv2.ml.KNearest_create()                                   # instantiate KNN object
    kNearest.train(npaFlattenedImages, cv2.ml.ROW_SAMPLE, npaClassifications)

    imgTestingNumbers = cv2.imread("test1.png")                           # read in testing numbers image
    if imgTestingNumbers is None:                                         # imread returns None on failure
        print("error: image not read from file \n\n")
        _pauseForUser()                                                   # pause so user can see error message
        return
    # end if

    imgGray = cv2.cvtColor(imgTestingNumbers, cv2.COLOR_BGR2GRAY)         # get grayscale image
    imgBlurred = cv2.GaussianBlur(imgGray, (5, 5), 0)                     # blur

    # filter image from grayscale to black and white
    imgThresh = cv2.adaptiveThreshold(imgBlurred,                         # input image
                                      255,                                # make pixels that pass the threshold full white
                                      cv2.ADAPTIVE_THRESH_GAUSSIAN_C,     # use gaussian rather than mean
                                      cv2.THRESH_BINARY_INV,              # invert so foreground is white, background black
                                      11,                                 # size of the pixel neighborhood
                                      2)                                  # constant subtracted from the weighted mean

    # findContours may modify its input, so hand it a copy; imgThresh is still needed for cropping.
    # The number of returned values differs between OpenCV releases (2 or 3), but the contour
    # list is always the second-to-last item, so index from the end instead of unpacking.
    npaContours = cv2.findContours(imgThresh.copy(),
                                   cv2.RETR_EXTERNAL,                     # retrieve the outermost contours only
                                   cv2.CHAIN_APPROX_SIMPLE)[-2]           # keep only segment end points

    validContoursWithData = []
    for npaContour in npaContours:                                        # for each contour
        contourWithData = ContourWithData()
        contourWithData.npaContour = npaContour
        contourWithData.boundingRect = cv2.boundingRect(npaContour)       # get the bounding rect
        contourWithData.calculateRectTopLeftPointAndWidthAndHeight()      # get bounding rect info
        contourWithData.fltArea = cv2.contourArea(npaContour)             # calculate the contour area
        if contourWithData.checkIfContourIsValid():                       # keep only the valid ones
            validContoursWithData.append(contourWithData)
        # end if
    # end for

    validContoursWithData.sort(key=operator.attrgetter("intRectX"))       # sort contours from left to right

    listOfChars = []                                                      # recognised chars, joined once at the end
    for contourWithData in validContoursWithData:                         # for each contour
        intLeft = contourWithData.intRectX
        intTop = contourWithData.intRectY
        intRight = intLeft + contourWithData.intRectWidth
        intBottom = intTop + contourWithData.intRectHeight

        # draw a green rect around the current char on the original testing image
        cv2.rectangle(imgTestingNumbers,
                      (intLeft, intTop),                                  # upper left corner
                      (intRight, intBottom),                              # lower right corner
                      (0, 255, 0),                                        # green
                      2)                                                  # thickness

        imgROI = imgThresh[intTop:intBottom, intLeft:intRight]            # crop char out of threshold image

        # resize to the fixed size used throughout this file, then flatten to one float32 row
        imgROIResized = cv2.resize(imgROI, (RESIZED_IMAGE_WIDTH, RESIZED_IMAGE_HEIGHT))
        npaROIResized = np.float32(imgROIResized.reshape((1, RESIZED_IMAGE_WIDTH * RESIZED_IMAGE_HEIGHT)))

        npaResults = kNearest.findNearest(npaROIResized, k=1)[1]          # second return value holds the results

        listOfChars.append(chr(int(npaResults[0][0])))                    # result is a character code
    # end for

    strFinalString = "".join(listOfChars)

    print("\n" + strFinalString + "\n")                                   # show the full string

    cv2.imshow("imgTestingNumbers", imgTestingNumbers)                    # show input image with green boxes drawn around found digits
    cv2.waitKey(0)                                                        # wait for user key press

    cv2.destroyAllWindows()                                               # remove windows from memory

    return
###################################################################################################
# run main() only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()
# end if