Skip to content

Commit

Permalink
full model inference/ui added
Browse files Browse the repository at this point in the history
  • Loading branch information
arun477 committed Sep 14, 2023
1 parent fccfd79 commit c32023c
Show file tree
Hide file tree
Showing 9 changed files with 290 additions and 90 deletions.
Binary file added __pycache__/mnist_classifier.cpython-39.pyc
Binary file not shown.
Binary file added __pycache__/server.cpython-39.pyc
Binary file not shown.
Binary file removed classifier.pkl
Binary file not shown.
Binary file added classifier.pth
Binary file not shown.
180 changes: 142 additions & 38 deletions mnist_classifier.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,31 @@
"import torchvision.transforms.functional as TF\n",
"from torch.utils.data import default_collate, DataLoader\n",
"import torch.optim as optim\n",
"import pickle\n",
"import pickle\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"plt.rcParams['figure.figsize'] = [2, 2]"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [
{
"name": "stderr",
Expand All @@ -43,27 +59,25 @@
},
{
"cell_type": "code",
"execution_count": 102,
"execution_count": 112,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAI4AAACOCAYAAADn/TAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAIsUlEQVR4nO3df2yU9R0H8PfHtrQroFJBVrGjHVRAweHWCASCJBuumiXOLAyYWTbjQiYy58Y2fmzZ5oILJgsJMjSRrCsmig7mAjFsZBIlLkNGdeBgrOWnWqnFwkDmUNrrZ3/0bPu59cfTz3P33NPr+5WQu89zd32+MW+/z/eeu+dzoqogGqgrsj0AGpwYHHJhcMiFwSEXBodcGBxyCRUcEakWkXoROSYiK9M1KIo/8Z7HEZE8AA0A5gNoBLAfwGJV/Wf6hkdxlR/itbcCOKaqJwBARJ4FcBeAXoMzTAq1CMND7JKidhH/blHVManbwwRnHIC3u9WNAGb09YIiDMcM+XyIXVLUXtRtb/a0PUxwpIdt/3fcE5ElAJYAQBGKQ+yO4iTM4rgRQFm3+noAp1OfpKpPqmqVqlYVoDDE7ihOwgRnP4BKEakQkWEAFgHYkZ5hUdy5D1Wq2iYiywDsApAHoEZVD6dtZBRrYdY4UNWdAHamaSw0iPDMMbkwOOTC4JALg0MuDA65MDjkwuCQC4NDLgwOuTA45MLgkAuDQy6hPuQcSiTf/qfKGzM68Gvrf1Bu6kRxu6nHTzhj6uKl9jty764bZurXq54zdUviA1PP2Lq88/7E778aeJwDwRmHXBgccmFwyGXIrHHyplSaWgsLTH36tqtNfWmmXTeUXGXrVz5j1xlh/PG/I0396K+rTb1v2jOmPtl6ydRrm+eb+rpXMt/ziDMOuTA45MLgkEvOrnES8z5r6nW1G019Q4E9NxKlVk2Y+qcbvmnq/A/sGmXW1mWmHvlOm6kLW+yap7huX8gR9o8zDrkwOOTC4JBLzq5xCuvtZeyvfVhm6hsKmtO2r+VNM0194j/2c6zaCdtMfaHdrmHGPvbXUPvPRqdqzjjkwuCQC4NDLjm7xmlretfUGx5dYOpHqu1nT3lvjDD1waUb+vz7a1pu7rx/7Au2YVTifJOpvzZrqalPPWj/VgUO9rmvOOKMQy79BkdEakTkjIgc6ratRET+LCJHk7ejMjtMipsgM04tgOqUbSsB7FbVSgC7kzUNIYH6HItIOYAXVHVqsq4HME9Vm0SkFMDLqjqpv79zpZRoXLqO5o2+xtSJs+dMffKZm019eG6NqW/95Xc671+7Mdx5mDh7Ube9pqpVqdu9a5yxqtoEAMnba8MMjgafjL+rYrva3OSdcZqThygkb8/09kS2q81N3hlnB4BvAFibvN2ethFFJNFyts/HW9/v+/s6N93T9csD7z2RZx9sTyDXBXk7vgXAXgCTRKRRRO5DR2Dmi8hRdPwIyNrMDpPipt8ZR1UX9/JQPN4eUVbwzDG55OxnVWFNWdFg6nun2Qn2t+N3d96/bcED5rGRz2Xmeu044YxDLgwOuTA45MI1Ti8S5y+Y+uz9U0z91o6ua5lWrnnKPLbqq3ebWv9+lanLHtlrd+b8XdRs4oxDLgwOufBQFVD7wSOmXvTwDzvvP/2zX5nHDsy0hy7Yq2dw03B7SW/lJvtV07YTp3yDjBBnHHJhcMiFwSGXQF8dTZc4fXU0nXT2dFNfubbR1Fs+vavP109+6VumnvSwPRWQOHrCP7iQ0v3VURriGBxyYXDIhWucDMgbay/6OL1woqn3rVhv6itS/v+95+Ttpr4wp++vuWYS1ziUVgwOuTA45MLPqjIg0WwvMxv7mK0//JFtN1ss9lKcTeUvmPpLdz9kn/+HzLej7Q9nHHJhcMiFwSEXrnHSoH3OdFMfX1Bk6qnTT5k6dU2TasO5W+zzt9e5x5YpnHHIhcEhFwaHXLjGCUiqppq64cGudcqm2ZvNY3OLLg/ob3+kraZ+9VyFfUK7/U5yHHDGIZcg/XHKROQlETkiIodF5LvJ7WxZO4QFmXHaACxX1SnouNDjARG5EWxZO6QFaazUBODjDqMXReQIgHEA7gIwL/m0zQBeBrAiI6OMQH7FeFMfv/c6U/984bOm/sqIFve+Vjfbr7fsWW8vvBq1OeUS4Rga0Bon2e/4FgD7wJa1Q1rg4IjICAC/B/CQqr4/gNctEZE6EalrxUeeMVIMBQqOiBSgIzRPq+rzyc2BWtayXW1u6neNIyIC4DcAjqjqum4PDaqWtfnlnzL1hc+VmnrhL/5k6m9f/Ty8Un9qce/jdk1TUvs3U49qj/+aJlWQE4CzAXwdwD9E5EBy22p0BOZ3yfa1bwFY0PPLKRcFeVf1FwDSy8O5f8kC9YhnjsklZz6ryi/9pKnP1Qw39f0Ve0y9eGS4n49e9s6czvuvPzHdPDZ62yFTl1wcfGuY/nDGIRcGh1wYHHIZVGucy1/sOh9y+Xv2pxBXT9xp6ts/YX8eeqCaE5dMPXfHclNP/sm/Ou+XnLdrmPZQex4cOOOQC4NDLoPqUHXqy105b5i2dUCv3Xh+gqnX77GtRCRhz3FOXnPS1JXN9rLb3P8NvL5xxiEXBodcGBxyYSs36hNbuVFaMTjkwuCQC4NDLgwOuTA45MLgkAuDQy4MDrkwOOTC4JBLpJ9Vich7AN4EMBqAv09IZnFs1nhVHZO6MdLgdO5UpK6nD87igGMLhocqcmFwyCVbwXkyS/sNgmMLICtrHBr8eKgil0iDIyLVIlIvIsdEJKvtbUWkRkTOiMihbtti0bt5MPSWjiw4IpIHYCOAOwDcCGBxsl9yttQCqE7ZFpfezfHvLa2qkfwDMAvArm71KgCrotp/L2MqB3CoW10PoDR5vxRAfTbH121c2wHMj9P4ojxUjQPwdre6MbktTmLXuzmuvaWjDE5PfQT5lq4P3t7SUYgyOI0AyrrV1wM4HeH+gwjUuzkKYXpLRyHK4OwHUCkiFSIyDMAidPRKjpOPezcDWezdHKC3NJDt3tIRL/LuBNAA4DiAH2d5wbkFHT9u0oqO2fA+ANeg493K0eRtSZbGNgcdh/E3ABxI/rszLuNTVZ45Jh+eOSYXBodcGBxyYXDIhcEhFwaHXBgccmFwyOV/atVD7hyCzrEAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 144x144 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"outputs": [],
"source": [
"def transform_ds(b):\n",
" b[x] = [TF.to_tensor(ele) for ele in b[x]]\n",
" return b\n",
"\n",
" return b"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [],
"source": [
"dst = ds.with_transform(transform_ds)\n",
"plt.imshow(dst['train'][0]['image'].permute(1,2,0));"
]
Expand Down Expand Up @@ -93,8 +107,19 @@
"\n",
"def collate_fn(b):\n",
" collate = default_collate(b)\n",
" return (collate[x], collate[y])\n",
"\n",
" return (collate[x], collate[y])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [],
"source": [
"dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n",
"xb,yb = next(iter(dls.train))\n",
"xb.shape, yb.shape"
Expand Down Expand Up @@ -178,23 +203,27 @@
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"execution_count": 195,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train, epoch:1, loss: 0.1077, accuracy: 0.9104\n",
"eval, epoch:1, loss: 0.0382, accuracy: 0.9791\n",
"train, epoch:2, loss: 0.0410, accuracy: 0.9832\n",
"eval, epoch:2, loss: 0.0221, accuracy: 0.9866\n",
"train, epoch:3, loss: 0.0538, accuracy: 0.9871\n",
"eval, epoch:3, loss: 0.0141, accuracy: 0.9887\n",
"train, epoch:4, loss: 0.0343, accuracy: 0.9858\n",
"eval, epoch:4, loss: 0.0163, accuracy: 0.9871\n",
"train, epoch:5, loss: 0.0390, accuracy: 0.9865\n",
"eval, epoch:5, loss: 0.0169, accuracy: 0.9871\n"
"train, epoch:1, loss: 0.0776, accuracy: 0.9172\n",
"eval, epoch:1, loss: 0.0372, accuracy: 0.9818\n",
"train, epoch:2, loss: 0.0571, accuracy: 0.9828\n",
"eval, epoch:2, loss: 0.0287, accuracy: 0.9863\n",
"train, epoch:3, loss: 0.0425, accuracy: 0.9847\n",
"eval, epoch:3, loss: 0.0256, accuracy: 0.9865\n",
"train, epoch:4, loss: 0.0271, accuracy: 0.9868\n",
"eval, epoch:4, loss: 0.0378, accuracy: 0.9826\n",
"train, epoch:5, loss: 0.0395, accuracy: 0.9844\n",
"eval, epoch:5, loss: 0.0307, accuracy: 0.9873\n"
]
}
],
Expand Down Expand Up @@ -227,16 +256,84 @@
},
{
"cell_type": "code",
"execution_count": 110,
"execution_count": 196,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [],
"source": [
"with open('./classifier.pkl', 'wb') as model_file:\n",
" pickle.dump(model, model_file)"
"torch.save(model.state_dict(), 'classifier.pth')"
]
},
{
"cell_type": "code",
"execution_count": 197,
"metadata": {},
"outputs": [],
"source": [
"loaded_model = cnn_classifier()\n",
"loaded_model.load_state_dict(torch.load('classifier.pth'))\n",
"loaded_model.eval();"
]
},
{
"cell_type": "code",
"execution_count": 206,
"metadata": {},
"outputs": [],
"source": [
"def predict(img):\n",
" with torch.no_grad():\n",
" img = img[None,]\n",
" pred = loaded_model(img)[0]\n",
" pred_probs = F.softmax(pred, dim=0)\n",
" pred = [{\"digit\": i, \"prob\": f'{prob*100:.2f}%', 'logits': pred[i]} for i, prob in enumerate(pred_probs)]\n",
" pred = sorted(pred, key=lambda ele: ele['digit'], reverse=False)\n",
" return pred"
]
},
{
"cell_type": "code",
"execution_count": 204,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(5)\n"
]
},
{
"data": {
"text/plain": [
"[{'digit': 0, 'prob': '21.42%', 'logits': tensor(0.0559)},\n",
" {'digit': 8, 'prob': '19.44%', 'logits': tensor(-0.0408)},\n",
" {'digit': 4, 'prob': '18.08%', 'logits': tensor(-0.1135)},\n",
" {'digit': 9, 'prob': '16.41%', 'logits': tensor(-0.2104)},\n",
" {'digit': 6, 'prob': '12.23%', 'logits': tensor(-0.5049)},\n",
" {'digit': 1, 'prob': '6.87%', 'logits': tensor(-1.0806)},\n",
" {'digit': 7, 'prob': '2.33%', 'logits': tensor(-2.1633)},\n",
" {'digit': 5, 'prob': '1.19%', 'logits': tensor(-2.8386)},\n",
" {'digit': 2, 'prob': '1.06%', 'logits': tensor(-2.9527)},\n",
" {'digit': 3, 'prob': '0.97%', 'logits': tensor(-3.0359)}]"
]
},
"execution_count": 204,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"img = xb[0].reshape(1, 28, 28)\n",
"print(yb[0])\n",
"predict(img)"
]
},
{
Expand All @@ -252,7 +349,7 @@
},
{
"cell_type": "code",
"execution_count": 111,
"execution_count": 205,
"metadata": {
"tags": [
"exclude"
Expand All @@ -264,13 +361,20 @@
"output_type": "stream",
"text": [
"[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
"[NbConvertApp] Writing 3691 bytes to mnist_classifier.py\n"
"[NbConvertApp] Writing 2904 bytes to mnist_classifier.py\n"
]
}
],
"source": [
"!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
54 changes: 16 additions & 38 deletions mnist_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,12 @@
from torch.utils.data import default_collate, DataLoader
import torch.optim as optim
import pickle
get_ipython().run_line_magic('matplotlib', 'inline')
plt.rcParams['figure.figsize'] = [2, 2]


dataset_nm = 'mnist'
x,y = 'image', 'label'
ds = load_dataset(dataset_nm)


def transform_ds(b):
b[x] = [TF.to_tensor(ele) for ele in b[x]]
return b

dst = ds.with_transform(transform_ds)
plt.imshow(dst['train'][0]['image'].permute(1,2,0));


bs = 1024
class DataLoaders:
Expand All @@ -39,10 +29,6 @@ def collate_fn(b):
collate = default_collate(b)
return (collate[x], collate[y])

dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)
xb,yb = next(iter(dls.train))
xb.shape, yb.shape


class Reshape(nn.Module):
def __init__(self, dim):
Expand Down Expand Up @@ -96,28 +82,20 @@ def kaiming_init(m):
nn.init.kaiming_normal_(m.weight)


model = cnn_classifier()
model.apply(kaiming_init)
lr = 0.1
max_lr = 0.3
epochs = 5
opt = optim.AdamW(model.parameters(), lr=lr)
sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)
for epoch in range(epochs):
for train in (True, False):
accuracy = 0
dl = dls.train if train else dls.valid
for xb,yb in dl:
preds = model(xb)
loss = F.cross_entropy(preds, yb)
if train:
loss.backward()
opt.step()
opt.zero_grad()
with torch.no_grad():
accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()
if train:
sched.step()
accuracy /= len(dl)
print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")
loaded_model = cnn_classifier()
loaded_model.load_state_dict(torch.load('classifier.pth'))
loaded_model.eval();


def predict(img):
with torch.no_grad():
img = img[None,]
pred = loaded_model(img)[0]
pred_probs = F.softmax(pred, dim=0)
pred = [{"digit": i, "prob": f'{prob*100:.2f}%', 'logits': pred[i]} for i, prob in enumerate(pred_probs)]
pred = sorted(pred, key=lambda ele: ele['digit'], reverse=False)
return pred




7 changes: 6 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
fastapi==0.68.1
uvicorn==0.15.0
aiofiles
aiofiles
torch
fastcore
torchvision
datasets

24 changes: 21 additions & 3 deletions server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from fastapi import FastAPI
from fastapi import FastAPI, File, UploadFile
import io
from PIL import Image
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pathlib import Path
import torchvision.transforms as transforms
import mnist_classifier

app = FastAPI()

Expand All @@ -10,6 +14,20 @@
async def root():
return FileResponse("static/index.html")

def process_image(file: UploadFile):
image_bytes = file.file.read()
pil_image = Image.open(io.BytesIO(image_bytes))
transform = transforms.Compose([
transforms.Resize((28, 28)),
transforms.Grayscale(num_output_channels=1),
transforms.ToTensor(),
])
tensor_image = transform(pil_image)
return tensor_image

@app.post("/predict")
async def predict():
return {"prediction": "Hello, World!"}
async def predict(image: UploadFile):
tensor_image = process_image(image)
prediction = mnist_classifier.predict(tensor_image)
return {"prediction": prediction}

Loading

0 comments on commit c32023c

Please sign in to comment.