Skip to content

Commit

Permalink
removed batchnorm2d
Browse files Browse the repository at this point in the history
  • Loading branch information
arun477 committed Sep 15, 2023
1 parent 6c5990e commit 338bbe8
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 76 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# <img src="/static/favicon.png" alt="Logo" style="float: left; margin-right: 10px; border-radius:100%;margin-top:5px" /> MNIST CLASSIFIER
MNIST classifier from scratch
* Model: CNN
* Accuracy: 98%
* Accuracy: 94%

* Training Notebook: mnist_classifier.ipynb
* Cleaned Python Inference Version: mnist_classifier.py
Expand Down
Binary file modified __pycache__/mnist_classifier.cpython-39.pyc
Binary file not shown.
Binary file modified __pycache__/server.cpython-39.pyc
Binary file not shown.
Binary file modified classifier.pth
Binary file not shown.
303 changes: 236 additions & 67 deletions mnist_classifier.ipynb

Large diffs are not rendered by default.

26 changes: 18 additions & 8 deletions mnist_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import torchvision.transforms.functional as TF
from torch.utils.data import default_collate, DataLoader
import torch.optim as optim
import pickle


def transform_ds(b):
Expand All @@ -23,7 +22,7 @@ def transform_ds(b):
class DataLoaders:
def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):
self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)
self.valid = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)
self.valid = DataLoader(valid_ds, batch_size=bs, shuffle=False, collate_fn=collate_fn, **kwargs)

def collate_fn(b):
collate = default_collate(b)
Expand Down Expand Up @@ -65,13 +64,24 @@ def forward(self, x):
return self.act(self.convs(x) + self.idconv(self.pool(x)))


# def cnn_classifier():
# return nn.Sequential(
# ResBlock(1, 8, norm=nn.BatchNorm2d(8)),
# ResBlock(8, 16, norm=nn.BatchNorm2d(16)),
# ResBlock(16, 32, norm=nn.BatchNorm2d(32)),
# ResBlock(32, 64, norm=nn.BatchNorm2d(64)),
# ResBlock(64, 64, norm=nn.BatchNorm2d(64)),
# conv(64, 10, act=False),
# nn.Flatten(),
# )

def cnn_classifier():
return nn.Sequential(
ResBlock(1, 8, norm=nn.BatchNorm2d(8)),
ResBlock(8, 16, norm=nn.BatchNorm2d(16)),
ResBlock(16, 32, norm=nn.BatchNorm2d(32)),
ResBlock(32, 64, norm=nn.BatchNorm2d(64)),
ResBlock(64, 64, norm=nn.BatchNorm2d(64)),
ResBlock(1, 8,),
ResBlock(8, 16, ),
ResBlock(16, 32,),
ResBlock(32, 64, ),
ResBlock(64, 64,),
conv(64, 10, act=False),
nn.Flatten(),
)
Expand All @@ -83,7 +93,7 @@ def kaiming_init(m):


loaded_model = cnn_classifier()
loaded_model.load_state_dict(torch.load('classifier.pth'))
loaded_model.load_state_dict(torch.load('classifier.pth'));
loaded_model.eval();


Expand Down
1 change: 1 addition & 0 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pathlib import Path
import torchvision.transforms as transforms
import mnist_classifier
import torch

app = FastAPI()

Expand Down

0 comments on commit 338bbe8

Please sign in to comment.