diff --git a/README.md b/README.md index 2de2c53..989b499 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,9 @@ # MNIST CLASSIFIER MNIST classifier from scratch - -* MLP Model Accuracy: 96% -* CNN Model Accuracy: 98% +* Model: CNN +* Accuracy: 98% * Training Notebook: mnist_classifier.ipynb -* Cleaned Python Version: mnist_classifier.py +* Cleaned Python Inference Version: mnist_classifier.py -* MLP Trained Model: mlp_classifier.pkl -* CNN Trained Model: cnn_classifier.pkl \ No newline at end of file +* Model: classifier.pkl \ No newline at end of file diff --git a/classifier.pkl b/classifier.pkl new file mode 100644 index 0000000..8aab6b0 Binary files /dev/null and b/classifier.pkl differ diff --git a/cnn_classifier.pkl b/cnn_classifier.pkl deleted file mode 100644 index d6e231d..0000000 Binary files a/cnn_classifier.pkl and /dev/null differ diff --git a/mlp_classifier.pkl b/mlp_classifier.pkl deleted file mode 100644 index 8e39e47..0000000 Binary files a/mlp_classifier.pkl and /dev/null differ diff --git a/mnist_classifier.ipynb b/mnist_classifier.ipynb index 8ea5471..5625a3f 100644 --- a/mnist_classifier.ipynb +++ b/mnist_classifier.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 60, + "execution_count": 100, "metadata": {}, "outputs": [], "source": [ @@ -23,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 101, "metadata": {}, "outputs": [ { @@ -31,7 +31,7 @@ "output_type": "stream", "text": [ "Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n", - "100%|██████████| 2/2 [00:00<00:00, 71.77it/s]\n" + "100%|██████████| 2/2 [00:00<00:00, 35.54it/s]\n" ] } ], @@ -43,7 +43,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 102, "metadata": {}, "outputs": [ { @@ -70,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 103, "metadata": {}, "outputs": [ { @@ -79,7 +79,7 @@ "(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))" ] }, - "execution_count": 87, + "execution_count": 103, "metadata": {}, "output_type": "execute_result" } @@ -102,7 +102,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 105, "metadata": {}, "outputs": [], "source": [ @@ -117,109 +117,7 @@ }, { "cell_type": "code", - "execution_count": 78, - "metadata": {}, - "outputs": [], - "source": [ - "# model definition\n", - "def linear_classifier():\n", - " return nn.Sequential(\n", - " Reshape((-1, 784)),\n", - " nn.Linear(784, 50),\n", - " nn.ReLU(),\n", - " nn.Linear(50, 50),\n", - " nn.ReLU(),\n", - " nn.Linear(50, 10)\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 79, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "train, epoch:1, loss: 0.3142, accuracy: 0.7951\n", - "eval, epoch:1, loss: 0.2298, accuracy: 0.9048\n", - "train, epoch:2, loss: 0.2198, accuracy: 0.9204\n", - "eval, epoch:2, loss: 0.1663, accuracy: 0.9350\n", - "train, epoch:3, loss: 0.1776, accuracy: 0.9420\n", - "eval, epoch:3, loss: 0.1267, accuracy: 0.9493\n", - "train, epoch:4, loss: 0.1328, accuracy: 0.9568\n", - "eval, epoch:4, loss: 0.0959, accuracy: 0.9598\n", - "train, epoch:5, loss: 0.1038, accuracy: 0.9637\n", - "eval, epoch:5, loss: 0.0913, accuracy: 0.9643\n" - ] - } - ], - "source": [ - "model = linear_classifier()\n", - "lr = 0.1\n", - "max_lr = 0.1\n", - "epochs = 5\n", - "opt = optim.AdamW(model.parameters(), lr=lr)\n", - "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n", - "\n", - "for epoch in range(epochs):\n", - " for train in (True, False):\n", - " accuracy = 0\n", - " dl = dls.train if train else dls.valid\n", - " for xb,yb in dl:\n", - " preds = model(xb)\n", - " loss = F.cross_entropy(preds, yb)\n", - " if train:\n", - " loss.backward()\n", - " opt.step()\n", - " opt.zero_grad()\n", - " with torch.no_grad():\n", - " accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n", - " if train:\n", - " sched.step()\n", - " accuracy /= len(dl)\n", - " print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 81, - "metadata": { - "tags": [ - "exclude" - ] - }, - "outputs": [], - "source": [ - "# with open('./mlp_classifier.pkl', 'wb') as model_file:\n", - "# pickle.dump(model, model_file)" - ] - }, - { - "cell_type": "code", - "execution_count": 82, - "metadata": {}, - "outputs": [], - "source": [ - "# def _conv_block(ni, nf, stride, act=act_gr, norm=None, ks=3):\n", - "# return nn.Sequential(conv(ni, nf, stride=1, act=act, norm=norm, ks=ks),\n", - "# conv(nf, nf, stride=stride, act=None, norm=norm, ks=ks))\n", - "\n", - "# class ResBlock(nn.Module):\n", - "# def __init__(self, ni, nf, stride=1, ks=3, act=act_gr, norm=None):\n", - "# super().__init__()\n", - "# self.convs = _conv_block(ni, nf, stride, act=act, ks=ks, norm=norm)\n", - "# self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, stride=1, act=None)\n", - "# self.pool = fc.noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)\n", - "# self.act = act()\n", - "\n", - "# def forward(self, x): return self.act(self.convs(x) + self.idconv(self.pool(x)))" - ] - }, - { - "cell_type": "code", - "execution_count": 83, + "execution_count": 106, "metadata": {}, "outputs": [], "source": [ @@ -251,7 +149,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 107, "metadata": {}, "outputs": [], "source": [ @@ -264,25 +162,12 @@ " ResBlock(64, 64, norm=nn.BatchNorm2d(64)),\n", " conv(64, 10, act=False),\n", " nn.Flatten(),\n", - " )\n", - "\n", - "\n", - "# def cnn_classifier():\n", - "# return nn.Sequential(\n", - "# ResBlock(1, 16, norm=nn.BatchNorm2d(16)),\n", - "# ResBlock(16, 32, norm=nn.BatchNorm2d(32)),\n", - "# ResBlock(32, 64, norm=nn.BatchNorm2d(64)),\n", - "# ResBlock(64, 128, norm=nn.BatchNorm2d(128)),\n", - "# ResBlock(128, 256, norm=nn.BatchNorm2d(256)),\n", - "# ResBlock(256, 256, norm=nn.BatchNorm2d(256)),\n", - "# conv(256, 10, act=False),\n", - "# nn.Flatten(),\n", - "# )" + " )" ] }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 108, "metadata": {}, "outputs": [], "source": [ @@ -293,26 +178,9 @@ }, { "cell_type": "code", - "execution_count": 94, + "execution_count": 109, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "train, epoch:1, loss: 0.0827, accuracy: 0.9102\n", - "eval, epoch:1, loss: 0.0448, accuracy: 0.9817\n", - "train, epoch:2, loss: 0.0382, accuracy: 0.9835\n", - "eval, epoch:2, loss: 0.0353, accuracy: 0.9863\n", - "train, epoch:3, loss: 0.0499, accuracy: 0.9856\n", - "eval, epoch:3, loss: 0.0300, accuracy: 0.9867\n", - "train, epoch:4, loss: 0.0361, accuracy: 0.9869\n", - "eval, epoch:4, loss: 0.0203, accuracy: 0.9877\n", - "train, epoch:5, loss: 0.0427, accuracy: 0.9846\n", - "eval, epoch:5, loss: 0.0250, accuracy: 0.9866\n" - ] - } - ], + "outputs": [], "source": [ "model = cnn_classifier()\n", "model.apply(kaiming_init)\n", @@ -350,7 +218,7 @@ }, "outputs": [], "source": [ - "# with open('./cnn_classifier.pkl', 'wb') as model_file:\n", + "# with open('./classifier.pkl', 'wb') as model_file:\n", "# pickle.dump(model, model_file)" ] }, diff --git a/mnist_classifier.py b/mnist_classifier.py index d8240a7..c4997e3 100644 --- a/mnist_classifier.py +++ b/mnist_classifier.py @@ -53,59 +53,6 @@ def forward(self, x): return x.reshape(self.dim) -# model definition -def linear_classifier(): - return nn.Sequential( - Reshape((-1, 784)), - nn.Linear(784, 50), - nn.ReLU(), - nn.Linear(50, 50), - nn.ReLU(), - nn.Linear(50, 10) - ) - - -model = linear_classifier() -lr = 0.1 -max_lr = 0.1 -epochs = 5 -opt = optim.AdamW(model.parameters(), lr=lr) -sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs) - -for epoch in range(epochs): - for train in (True, False): - accuracy = 0 - dl = dls.train if train else dls.valid - for xb,yb in dl: - preds = model(xb) - loss = F.cross_entropy(preds, yb) - if train: - loss.backward() - opt.step() - opt.zero_grad() - with torch.no_grad(): - accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean() - if train: - sched.step() - accuracy /= len(dl) - print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}") - - -# def _conv_block(ni, nf, stride, act=act_gr, norm=None, ks=3): -# return nn.Sequential(conv(ni, nf, stride=1, act=act, norm=norm, ks=ks), -# conv(nf, nf, stride=stride, act=None, norm=norm, ks=ks)) - -# class ResBlock(nn.Module): -# def __init__(self, ni, nf, stride=1, ks=3, act=act_gr, norm=None): -# super().__init__() -# self.convs = _conv_block(ni, nf, stride, act=act, ks=ks, norm=norm) -# self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, stride=1, act=None) -# self.pool = fc.noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True) -# self.act = act() - -# def forward(self, x): return self.act(self.convs(x) + self.idconv(self.pool(x))) - - def conv(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None): layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=s, padding=ks//2)] if norm: @@ -144,19 +91,6 @@ def cnn_classifier(): ) -# def cnn_classifier(): -# return nn.Sequential( -# ResBlock(1, 16, norm=nn.BatchNorm2d(16)), -# ResBlock(16, 32, norm=nn.BatchNorm2d(32)), -# ResBlock(32, 64, norm=nn.BatchNorm2d(64)), -# ResBlock(64, 128, norm=nn.BatchNorm2d(128)), -# ResBlock(128, 256, norm=nn.BatchNorm2d(256)), -# ResBlock(256, 256, norm=nn.BatchNorm2d(256)), -# conv(256, 10, act=False), -# nn.Flatten(), -# ) - - def kaiming_init(m): if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): nn.init.kaiming_normal_(m.weight) @@ -187,6 +121,3 @@ def kaiming_init(m): accuracy /= len(dl) print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}") - - -