-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathapp.py
204 lines (148 loc) · 5.93 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from flask import Flask, request, jsonify
import torch
import numpy as np
import pandas as pd
import os
from PIL import Image
import torchvision.models as models
from torchvision import models
import io
# Flask application object; the /upload route below is registered on it and
# the __main__ guard at the bottom of the file runs it.
app = Flask(__name__)
# ################################## CHECK FOR FILETYPE ##################################
def is_image_file(filename):
    """Return True if ``filename`` ends in a supported image extension.

    The comparison is case-insensitive and looks only at the final extension
    (``os.path.splitext``). Accepted extensions: .jpg, .jpeg, .png.

    Params
    --------
    filename (str | None): name of the uploaded file

    Returns
    --------
    bool: False for an empty/None filename or any other extension.
    """
    if not filename:
        # Uploads without a filename carry no extension to check.
        return False
    extension = os.path.splitext(filename)[1].lower()
    # PNG is included because the upload failure message advertises
    # "jpeg and png"; only JPEG was accepted before.
    return extension in ('.jpg', '.jpeg', '.png')
# ################################## PRE PROCESS THE IMAGE ##################################
def process_image(image, resize_size=256, crop_size=224):
    """Convert a decoded PIL image into a normalized CHW float32 tensor.

    Steps: convert to RGB, resize to ``resize_size`` x ``resize_size`` (the
    aspect ratio is not preserved), center-crop to ``crop_size`` x
    ``crop_size``, scale pixel values to [0, 1], then standardize each channel
    with the ImageNet means/stds.

    Params
    --------
    image (PIL.Image.Image): the decoded input image (not a file path)
    resize_size (int): side of the square resize, default 256
    crop_size (int): side of the square center crop, default 224

    Returns
    --------
    torch.Tensor of shape (3, crop_size, crop_size), dtype float32
    """
    # Grayscale ('L'), palette ('P') and RGBA images would otherwise give
    # arrays with 1 or 4 channels (or no channel axis at all, which made the
    # transpose below crash); RGB guarantees exactly three.
    img = image.convert('RGB').resize((resize_size, resize_size))

    # Center crop with integer offsets.
    offset = (resize_size - crop_size) // 2
    img = img.crop((offset, offset, offset + crop_size, offset + crop_size))

    # HWC 8-bit -> CHW float in [0, 1]. The max 8-bit value is 255, so divide
    # by 255 (as torchvision's ToTensor does); the old /256 skewed the inputs.
    # NOTE(review): presumably the training pipeline used ToTensor — confirm.
    arr = np.asarray(img, dtype=np.float32).transpose((2, 0, 1)) / 255.0

    # Standardize with the ImageNet channel statistics.
    means = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape((3, 1, 1))
    stds = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape((3, 1, 1))
    arr = (arr - means) / stds

    return torch.from_numpy(np.ascontiguousarray(arr, dtype=np.float32))
# ################################## MODEL LOADING from CHECKPOINT ##################################
def model_loading(path='model/resnet50-transfer.pth'):
    """Rebuild the leaf classifier from a saved transfer-learning checkpoint.

    The architecture is taken from the checkpoint's file name: the part before
    the first '-', or before the first '.' if the name has no '-'. Only
    'resnet50' is supported.

    Params
    --------
    path (str): checkpoint file; a relative path resolves against the current
        working directory. Default 'model/resnet50-transfer.pth'.

    Returns
    --------
    torch.nn.Module mapped to the CPU, with extra attributes ``epochs``,
    ``class_to_idx`` (list of (label, idx) tuples) and ``idx_to_class``
    (list of (idx, label) tuples).

    Raises
    --------
    ValueError: if the architecture encoded in the file name is unsupported
        (previously this surfaced as an UnboundLocalError on ``model``).
    """
    file_name = os.path.basename(path)
    model_name = file_name.split('-')[0] if '-' in file_name else file_name.split('.')[0]
    if model_name != 'resnet50':
        raise ValueError(f"Unsupported model architecture {model_name!r} in {path!r}")

    # The checkpoint stores a whole nn.Module under 'fc', which needs full
    # unpickling. torch >= 2.6 defaults to weights_only=True and would refuse
    # it, so opt out explicitly. Only load checkpoints from a trusted source.
    cpu = torch.device('cpu')
    try:
        checkpoint = torch.load(path, map_location=cpu, weights_only=False)
    except TypeError:  # torch < 1.13 has no `weights_only` parameter
        checkpoint = torch.load(path, map_location=cpu)

    model = models.resnet50(weights=None)
    model.fc = checkpoint['fc']  # classifier head saved alongside the weights
    model.load_state_dict(checkpoint['state_dict'])
    model.epochs = checkpoint['epochs']

    # NOTE(review): this order must match the class indices used at training
    # time (alphabetical if the data was loaded with ImageFolder) — confirm.
    class_labels = [
        'Alpinia Galanga (Rasna)',
        'Amaranthus Viridis (Arive-Dantu)',
        'Artocarpus Heterophyllus (Jackfruit)',
        'Azadirachta Indica (Neem)',
        'Basella Alba (Basale)',
        'Brassica Juncea (Indian Mustard)',
        'Carissa Carandas (Karanda)',
        'Citrus Limon (Lemon)',
        'Ficus Auriculata (Roxburgh fig)',
        'Ficus Religiosa (Peepal Tree)',
        'Hibiscus Rosa-sinensis',
        'Jasminum (Jasmine)',
        'Mangifera Indica (Mango)',
        'Mentha (Mint)',
        'Moringa Oleifera (Drumstick)',
        'Muntingia Calabura (Jamaica Cherry-Gasagase)',
        'Murraya Koenigii (Curry)',
        'Nerium Oleander (Oleander)',
        'Nyctanthes Arbor-tristis (Parijata)',
        'Ocimum Tenuiflorum (Tulsi)',
        'Piper Betle (Betel)',
        'Plectranthus Amboinicus (Mexican Mint)',
        'Pongamia Pinnata (Indian Beech)',
        'Psidium Guajava (Guava)',
        'Punica Granatum (Pomegranate)',
        'Santalum Album (Sandalwood)',
        'Syzygium Cumini (Jamun)',
        'Syzygium Jambos (Rose Apple)',
        'Tabernaemontana Divaricata (Crape Jasmine)',
        'Trigonella Foenum-graecum (Fenugreek)',
    ]
    # The checkpoint's own class mappings are deliberately replaced by these
    # tuple lists: predict() and upload_file() index entries as (idx, label).
    model.class_to_idx = [(label, idx) for idx, label in enumerate(class_labels)]
    model.idx_to_class = list(enumerate(class_labels))
    return model
# ###################################### PREDICTION FUNCTION ######################################
def predict(image, model, topk):
    """Return the ``topk`` most likely classes for ``image``.

    Params
    --------
    image (PIL.Image.Image): decoded input image (not a path); it is
        preprocessed with ``process_image``
    model (torch.nn.Module): trained classifier exposing an ``idx_to_class``
        attribute indexable by class index
    topk (int): number of top predictions to return

    Returns
    --------
    top_p (numpy.ndarray): probabilities of the top classes, highest first
    top_classes (list): the ``model.idx_to_class`` entry for each of them
    """
    # Add the batch dimension: (3, H, W) -> (1, 3, H, W). This follows whatever
    # size process_image produced instead of hard-coding 224x224.
    batch = process_image(image).unsqueeze(0)

    model.eval()  # inference mode for dropout / batch-norm layers
    with torch.no_grad():
        # exp() assumes the classifier head ends in LogSoftmax, i.e. the model
        # emits log-probabilities. NOTE(review): confirm against the checkpoint's fc.
        probs = torch.exp(model(batch))
        top_probs, top_idx = probs.topk(topk, dim=1)

    top_classes = [model.idx_to_class[i] for i in top_idx[0].tolist()]
    top_p = top_probs.cpu().numpy()[0]
    return top_p, top_classes
# ############################### Extract Info from CSV corresponding to the Index ###############################
def extract(index, csv_path="info1.csv"):
    """Return the description stored for class ``index`` in the info CSV.

    Params
    --------
    index (int): class index produced by the model (valid range 0-29)
    csv_path (str): CSV to read; the description is its third column (position
        2). Default "info1.csv", resolved against the working directory.

    Returns
    --------
    str: the description, or 'Not a valid Index' when ``index`` is negative,
    >= 30, or past the last row of the CSV.
    """
    # Reject negatives explicitly: iloc would silently wrap them to the end.
    if not 0 <= index < 30:
        return 'Not a valid Index'
    df = pd.read_csv(csv_path)
    if index >= len(df):
        return 'Not a valid Index'
    # str() so the value is always JSON-serializable; if pandas parsed the
    # column as numeric it would be a numpy scalar (e.g. numpy.int64).
    return str(df.iloc[index, 2])
# ########################################## ROUTES ##########################################
_UPLOAD_FAILED = 'Upload failed. Please check for correct file formats, only jpeg and png are accepted.'

# The classifier is loaded on the first prediction and reused afterwards;
# rebuilding ResNet-50 and reading the checkpoint per request is slow.
_model = None


def _get_model():
    """Return the shared classifier, loading it on first use."""
    global _model
    if _model is None:
        # Concurrent first requests may each load a copy; the last one wins,
        # which is harmless.
        _model = model_loading()
    return _model


# when the user hits submit button
@app.route('/upload', methods=['POST', 'GET'])
def upload_file():
    """Classify an uploaded leaf image and return the result as JSON.

    Expects a multipart form field named 'file'. On success, responds with
    {"prediction": label, "confidence_level": percent, "info": description}.
    A GET request, a missing field, an unsupported extension, or image data
    that cannot be decoded all yield the plain-text failure message.
    """
    if request.method != "POST":
        return _UPLOAD_FAILED

    # .get(): a missing field used to raise a KeyError out of request.files.
    file = request.files.get('file')
    if not file or not is_image_file(file.filename):
        return _UPLOAD_FAILED

    try:
        image = Image.open(io.BytesIO(file.read()))
        image.load()  # decode now so corrupt/truncated data is caught here
    except OSError:  # PIL.UnidentifiedImageError is an OSError subclass
        return _UPLOAD_FAILED

    # Top-1 prediction; each class entry is an (idx, label) tuple.
    probs, classes = predict(image, _get_model(), 1)
    idx, label = classes[0]
    info = extract(idx)
    # float(): probs[0] is a numpy scalar (float32 for a float32 model), which
    # Flask's default JSON provider cannot serialize.
    data = {"prediction": label, "confidence_level": float(probs[0]) * 100, "info": info}
    return jsonify(data)
"""##################################### MAIN APP CALL #########################################"""
# Start Flask's built-in development server when this file is run directly.
# NOTE(review): debug=True turns on the interactive Werkzeug debugger, which
# can execute arbitrary code from the browser; keep it off for any deployment
# reachable by other machines.
if __name__ == "__main__":
    app.run( debug = True )