Commit

mlp classifier removed
arun477 committed Sep 14, 2023
1 parent 4938f38 commit 19d498a
Showing 6 changed files with 18 additions and 221 deletions.
README.md: 10 changes (4 additions & 6 deletions)
@@ -1,11 +1,9 @@
# MNIST CLASSIFIER
MNIST classifier from scratch

-* MLP Model Accuracy: 96%
-* CNN Model Accuracy: 98%
+* Model: CNN
+* Accuracy: 98%

* Training Notebook: mnist_classifier.ipynb
-* Cleaned Python Version: mnist_classifier.py
+* Cleaned Python Inference Version: mnist_classifier.py
-
-* MLP Trained Model: mlp_classifier.pkl
-* CNN Trained Model: cnn_classifier.pkl
+* Model: classifier.pkl
Binary file added: classifier.pkl (not shown)
Binary file removed: cnn_classifier.pkl (not shown)
Binary file removed: mlp_classifier.pkl (not shown)
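With the two pickles consolidated into a single classifier.pkl, loading it for inference might look like the sketch below. This is an assumption-laden illustration, not code from the commit: the matching pickle.dump call appears only commented out in the diffs that follow, and unpickling requires the custom layer classes the model was built from (conv, ResBlock, and friends) to be importable, e.g. from mnist_classifier.py. The 1x28x28 input shape comes from the batch shapes shown in the notebook.

# Hypothetical usage of the new classifier.pkl (not part of this commit).
# Assumes the custom layer classes inside the pickled nn.Sequential are
# importable at load time.
import pickle

import torch

with open("classifier.pkl", "rb") as f:
    model = pickle.load(f)

model.eval()
x = torch.rand(1, 1, 28, 28)  # one MNIST-shaped image, values in [0, 1]
with torch.no_grad():
    digit = model(x).argmax(1).item()
print(digit)  # predicted class, 0-9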
mnist_classifier.ipynb: 160 changes (14 additions & 146 deletions)
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
@@ -23,15 +23,15 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": 101,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
"100%|██████████| 2/2 [00:00<00:00, 71.77it/s]\n"
"100%|██████████| 2/2 [00:00<00:00, 35.54it/s]\n"
]
}
],
@@ -43,7 +43,7 @@
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 102,
"metadata": {},
"outputs": [
{
@@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 87,
"execution_count": 103,
"metadata": {},
"outputs": [
{
@@ -79,7 +79,7 @@
"(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))"
]
},
"execution_count": 87,
"execution_count": 103,
"metadata": {},
"output_type": "execute_result"
}
@@ -102,7 +102,7 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 105,
"metadata": {},
"outputs": [],
"source": [
@@ -117,109 +117,7 @@
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"# model definition\n",
"def linear_classifier():\n",
" return nn.Sequential(\n",
" Reshape((-1, 784)),\n",
" nn.Linear(784, 50),\n",
" nn.ReLU(),\n",
" nn.Linear(50, 50),\n",
" nn.ReLU(),\n",
" nn.Linear(50, 10)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train, epoch:1, loss: 0.3142, accuracy: 0.7951\n",
"eval, epoch:1, loss: 0.2298, accuracy: 0.9048\n",
"train, epoch:2, loss: 0.2198, accuracy: 0.9204\n",
"eval, epoch:2, loss: 0.1663, accuracy: 0.9350\n",
"train, epoch:3, loss: 0.1776, accuracy: 0.9420\n",
"eval, epoch:3, loss: 0.1267, accuracy: 0.9493\n",
"train, epoch:4, loss: 0.1328, accuracy: 0.9568\n",
"eval, epoch:4, loss: 0.0959, accuracy: 0.9598\n",
"train, epoch:5, loss: 0.1038, accuracy: 0.9637\n",
"eval, epoch:5, loss: 0.0913, accuracy: 0.9643\n"
]
}
],
"source": [
"model = linear_classifier()\n",
"lr = 0.1\n",
"max_lr = 0.1\n",
"epochs = 5\n",
"opt = optim.AdamW(model.parameters(), lr=lr)\n",
"sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n",
"\n",
"for epoch in range(epochs):\n",
" for train in (True, False):\n",
" accuracy = 0\n",
" dl = dls.train if train else dls.valid\n",
" for xb,yb in dl:\n",
" preds = model(xb)\n",
" loss = F.cross_entropy(preds, yb)\n",
" if train:\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
" with torch.no_grad():\n",
" accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
" if train:\n",
" sched.step()\n",
" accuracy /= len(dl)\n",
" print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [],
"source": [
"# with open('./mlp_classifier.pkl', 'wb') as model_file:\n",
"# pickle.dump(model, model_file)"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
"# def _conv_block(ni, nf, stride, act=act_gr, norm=None, ks=3):\n",
"# return nn.Sequential(conv(ni, nf, stride=1, act=act, norm=norm, ks=ks),\n",
"# conv(nf, nf, stride=stride, act=None, norm=norm, ks=ks))\n",
"\n",
"# class ResBlock(nn.Module):\n",
"# def __init__(self, ni, nf, stride=1, ks=3, act=act_gr, norm=None):\n",
"# super().__init__()\n",
"# self.convs = _conv_block(ni, nf, stride, act=act, ks=ks, norm=norm)\n",
"# self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, stride=1, act=None)\n",
"# self.pool = fc.noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)\n",
"# self.act = act()\n",
"\n",
"# def forward(self, x): return self.act(self.convs(x) + self.idconv(self.pool(x)))"
]
},
{
"cell_type": "code",
"execution_count": 83,
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
@@ -251,7 +149,7 @@
},
{
"cell_type": "code",
"execution_count": 92,
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
@@ -264,25 +162,12 @@
" ResBlock(64, 64, norm=nn.BatchNorm2d(64)),\n",
" conv(64, 10, act=False),\n",
" nn.Flatten(),\n",
" )\n",
"\n",
"\n",
"# def cnn_classifier():\n",
"# return nn.Sequential(\n",
"# ResBlock(1, 16, norm=nn.BatchNorm2d(16)),\n",
"# ResBlock(16, 32, norm=nn.BatchNorm2d(32)),\n",
"# ResBlock(32, 64, norm=nn.BatchNorm2d(64)),\n",
"# ResBlock(64, 128, norm=nn.BatchNorm2d(128)),\n",
"# ResBlock(128, 256, norm=nn.BatchNorm2d(256)),\n",
"# ResBlock(256, 256, norm=nn.BatchNorm2d(256)),\n",
"# conv(256, 10, act=False),\n",
"# nn.Flatten(),\n",
"# )"
" )"
]
},
{
"cell_type": "code",
"execution_count": 93,
"execution_count": 108,
"metadata": {},
"outputs": [],
"source": [
@@ -293,26 +178,9 @@
},
{
"cell_type": "code",
"execution_count": 94,
"execution_count": 109,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train, epoch:1, loss: 0.0827, accuracy: 0.9102\n",
"eval, epoch:1, loss: 0.0448, accuracy: 0.9817\n",
"train, epoch:2, loss: 0.0382, accuracy: 0.9835\n",
"eval, epoch:2, loss: 0.0353, accuracy: 0.9863\n",
"train, epoch:3, loss: 0.0499, accuracy: 0.9856\n",
"eval, epoch:3, loss: 0.0300, accuracy: 0.9867\n",
"train, epoch:4, loss: 0.0361, accuracy: 0.9869\n",
"eval, epoch:4, loss: 0.0203, accuracy: 0.9877\n",
"train, epoch:5, loss: 0.0427, accuracy: 0.9846\n",
"eval, epoch:5, loss: 0.0250, accuracy: 0.9866\n"
]
}
],
"outputs": [],
"source": [
"model = cnn_classifier()\n",
"model.apply(kaiming_init)\n",
@@ -350,7 +218,7 @@
},
"outputs": [],
"source": [
"# with open('./cnn_classifier.pkl', 'wb') as model_file:\n",
"# with open('./classifier.pkl', 'wb') as model_file:\n",
"# pickle.dump(model, model_file)"
]
},
mnist_classifier.py: 69 changes (0 additions & 69 deletions)
@@ -53,59 +53,6 @@ def forward(self, x):
        return x.reshape(self.dim)


- # model definition
- def linear_classifier():
-     return nn.Sequential(
-         Reshape((-1, 784)),
-         nn.Linear(784, 50),
-         nn.ReLU(),
-         nn.Linear(50, 50),
-         nn.ReLU(),
-         nn.Linear(50, 10)
-     )
-
-
- model = linear_classifier()
- lr = 0.1
- max_lr = 0.1
- epochs = 5
- opt = optim.AdamW(model.parameters(), lr=lr)
- sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)
-
- for epoch in range(epochs):
-     for train in (True, False):
-         accuracy = 0
-         dl = dls.train if train else dls.valid
-         for xb,yb in dl:
-             preds = model(xb)
-             loss = F.cross_entropy(preds, yb)
-             if train:
-                 loss.backward()
-                 opt.step()
-                 opt.zero_grad()
-             with torch.no_grad():
-                 accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()
-         if train:
-             sched.step()
-         accuracy /= len(dl)
-         print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")
-
-
- # def _conv_block(ni, nf, stride, act=act_gr, norm=None, ks=3):
- #     return nn.Sequential(conv(ni, nf, stride=1, act=act, norm=norm, ks=ks),
- #                          conv(nf, nf, stride=stride, act=None, norm=norm, ks=ks))
-
- # class ResBlock(nn.Module):
- #     def __init__(self, ni, nf, stride=1, ks=3, act=act_gr, norm=None):
- #         super().__init__()
- #         self.convs = _conv_block(ni, nf, stride, act=act, ks=ks, norm=norm)
- #         self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, stride=1, act=None)
- #         self.pool = fc.noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)
- #         self.act = act()
-
- #     def forward(self, x): return self.act(self.convs(x) + self.idconv(self.pool(x)))
-
-
def conv(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):
    layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=s, padding=ks//2)]
    if norm:
@@ -144,19 +91,6 @@ def cnn_classifier():
    )


- # def cnn_classifier():
- # return nn.Sequential(
- # ResBlock(1, 16, norm=nn.BatchNorm2d(16)),
- # ResBlock(16, 32, norm=nn.BatchNorm2d(32)),
- # ResBlock(32, 64, norm=nn.BatchNorm2d(64)),
- # ResBlock(64, 128, norm=nn.BatchNorm2d(128)),
- # ResBlock(128, 256, norm=nn.BatchNorm2d(256)),
- # ResBlock(256, 256, norm=nn.BatchNorm2d(256)),
- # conv(256, 10, act=False),
- # nn.Flatten(),
- # )
-
-
def kaiming_init(m):
    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        nn.init.kaiming_normal_(m.weight)
@@ -187,6 +121,3 @@ def kaiming_init(m):
        accuracy /= len(dl)
        print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")

-
-
-
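For context on the architecture this commit keeps: cnn_classifier above is built from ResBlock layers, whose commented-out variant is deleted in this diff. The pattern is the standard identity-shortcut residual block: two stacked convolutions on the main path, a 1x1 convolution on the shortcut when the channel count changes, and average pooling on the shortcut when downsampling. The sketch below is an illustration under stated assumptions, not the repo's exact code: it swaps plain nn.ReLU and nn.Identity for the act_gr and fc.noop helpers and omits the norm argument.

# Minimal sketch of the ResBlock pattern referenced in this diff (assumed
# equivalents: nn.ReLU for act_gr, nn.Identity for fc.noop; norm omitted).
import torch
import torch.nn as nn

def conv_layer(ni, nf, ks=3, stride=1, act=True):
    layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks // 2)]
    if act:
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)

class ResBlock(nn.Module):
    def __init__(self, ni, nf, stride=1, ks=3):
        super().__init__()
        # main path: two convs; the second carries the stride, no activation
        self.convs = nn.Sequential(
            conv_layer(ni, nf, ks=ks, stride=1),
            conv_layer(nf, nf, ks=ks, stride=stride, act=False),
        )
        # shortcut: 1x1 conv if channels change, avg-pool if downsampling
        self.idconv = nn.Identity() if ni == nf else conv_layer(ni, nf, ks=1, act=False)
        self.pool = nn.Identity() if stride == 1 else nn.AvgPool2d(2, ceil_mode=True)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.convs(x) + self.idconv(self.pool(x)))

x = torch.rand(8, 1, 28, 28)               # a batch of MNIST-shaped inputs
print(ResBlock(1, 16, stride=2)(x).shape)  # torch.Size([8, 16, 14, 14])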