diff --git a/mlp_classifier.pkl b/mlp_classifier.pkl
index 0ff2094..8e39e47 100644
Binary files a/mlp_classifier.pkl and b/mlp_classifier.pkl differ
diff --git a/mnist_classifier.ipynb b/mnist_classifier.ipynb
index b1f49b8..de44943 100644
--- a/mnist_classifier.ipynb
+++ b/mnist_classifier.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 60,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -23,7 +23,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 61,
    "metadata": {},
    "outputs": [
     {
@@ -31,7 +31,7 @@
      "output_type": "stream",
      "text": [
       "Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
-      "100%|██████████| 2/2 [00:00<00:00, 69.76it/s]\n"
+      "100%|██████████| 2/2 [00:00<00:00, 71.77it/s]\n"
      ]
     }
    ],
@@ -43,7 +43,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 62,
    "metadata": {},
    "outputs": [
     {
@@ -70,7 +70,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 87,
    "metadata": {},
    "outputs": [
     {
@@ -79,7 +79,7 @@
        "(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))"
       ]
      },
-     "execution_count": 4,
+     "execution_count": 87,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -89,7 +89,7 @@
    "class DataLoaders:\n",
    "    def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n",
    "        self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n",
-    "        self.valid = DataLoader(train_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)\n",
+    "        self.valid = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)\n",
    "\n",
    "def collate_fn(b):\n",
    "    collate = default_collate(b)\n",
@@ -102,7 +102,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 77,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -117,7 +117,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 43,
+   "execution_count": 78,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -135,23 +135,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 44,
+   "execution_count": 79,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "train, epoch:1, loss: 0.2640, accuracy: 0.7885\n",
-      "eval, epoch:1, loss: 0.3039, accuracy: 0.8994\n",
-      "train, epoch:2, loss: 0.2368, accuracy: 0.9182\n",
-      "eval, epoch:2, loss: 0.2164, accuracy: 0.9350\n",
-      "train, epoch:3, loss: 0.1951, accuracy: 0.9402\n",
-      "eval, epoch:3, loss: 0.1589, accuracy: 0.9498\n",
-      "train, epoch:4, loss: 0.1511, accuracy: 0.9513\n",
-      "eval, epoch:4, loss: 0.1388, accuracy: 0.9618\n",
-      "train, epoch:5, loss: 0.1182, accuracy: 0.9567\n",
-      "eval, epoch:5, loss: 0.1426, accuracy: 0.9621\n"
+      "train, epoch:1, loss: 0.3142, accuracy: 0.7951\n",
+      "eval, epoch:1, loss: 0.2298, accuracy: 0.9048\n",
+      "train, epoch:2, loss: 0.2198, accuracy: 0.9204\n",
+      "eval, epoch:2, loss: 0.1663, accuracy: 0.9350\n",
+      "train, epoch:3, loss: 0.1776, accuracy: 0.9420\n",
+      "eval, epoch:3, loss: 0.1267, accuracy: 0.9493\n",
+      "train, epoch:4, loss: 0.1328, accuracy: 0.9568\n",
+      "eval, epoch:4, loss: 0.0959, accuracy: 0.9598\n",
+      "train, epoch:5, loss: 0.1038, accuracy: 0.9637\n",
+      "eval, epoch:5, loss: 0.0913, accuracy: 0.9643\n"
      ]
     }
    ],
@@ -184,7 +184,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 46,
+   "execution_count": 81,
    "metadata": {
     "tags": [
      "exclude"
@@ -192,42 +192,97 @@
    },
    "outputs": [],
    "source": [
-    "with open('./mlp_classifier.pkl', 'wb') as model_file:\n",
-    "    pickle.dump(model, model_file)"
+    "# with open('./mlp_classifier.pkl', 'wb') as model_file:\n",
+    "#     pickle.dump(model, model_file)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 35,
+   "execution_count": 82,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# def _conv_block(ni, nf, stride, act=act_gr, norm=None, ks=3):\n",
+    "#     return nn.Sequential(conv(ni, nf, stride=1, act=act, norm=norm, ks=ks),\n",
+    "#                          conv(nf, nf, stride=stride, act=None, norm=norm, ks=ks))\n",
+    "\n",
+    "# class ResBlock(nn.Module):\n",
+    "#     def __init__(self, ni, nf, stride=1, ks=3, act=act_gr, norm=None):\n",
+    "#         super().__init__()\n",
+    "#         self.convs = _conv_block(ni, nf, stride, act=act, ks=ks, norm=norm)\n",
+    "#         self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, stride=1, act=None)\n",
+    "#         self.pool = fc.noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)\n",
+    "#         self.act = act()\n",
+    "\n",
+    "#     def forward(self, x): return self.act(self.convs(x) + self.idconv(self.pool(x)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 83,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def conv(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
+    "    layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=s, padding=ks//2)]\n",
+    "    if norm:\n",
+    "        layers.append(norm(nf))\n",
+    "    if act:\n",
+    "        layers.append(act())\n",
+    "    return nn.Sequential(*layers)\n",
+    "\n",
+    "def _conv_block(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
+    "    return nn.Sequential(\n",
+    "        conv(ni, nf, ks=ks, s=1, norm=norm, act=act),\n",
+    "        conv(nf, nf, ks=ks, s=s, norm=norm, act=act),\n",
+    "    )\n",
+    "\n",
+    "class ResBlock(nn.Module):\n",
+    "    def __init__(self, ni, nf, s=2, ks=3, act=nn.ReLU, norm=None):\n",
+    "        super().__init__()\n",
+    "        self.convs = _conv_block(ni, nf, s=s, ks=ks, act=act, norm=norm)\n",
+    "        self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, s=1, act=None)\n",
+    "        self.pool = fc.noop if s==1 else nn.AvgPool2d(2, ceil_mode=True)\n",
+    "        self.act = act()\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        return self.act(self.convs(x) + self.idconv(self.pool(x)))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 92,
    "metadata": {},
    "outputs": [],
    "source": [
     "def cnn_classifier():\n",
-    "    ks,stride = 3,2\n",
     "    return nn.Sequential(\n",
-    "        nn.Conv2d(1, 8, kernel_size=ks, stride=stride, padding=ks//2),\n",
-    "        nn.BatchNorm2d(8),\n",
-    "        nn.ReLU(),\n",
-    "        nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),\n",
-    "        nn.BatchNorm2d(16),\n",
-    "        nn.ReLU(),\n",
-    "        nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),\n",
-    "        nn.BatchNorm2d(32),\n",
-    "        nn.ReLU(),\n",
-    "        nn.Conv2d(32, 64, kernel_size=ks, stride=stride, padding=ks//2),\n",
-    "        nn.BatchNorm2d(64),\n",
-    "        nn.ReLU(),\n",
-    "        nn.Conv2d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),\n",
-    "        nn.BatchNorm2d(64),\n",
-    "        nn.ReLU(),\n",
-    "        nn.Conv2d(64, 10, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        ResBlock(1, 8, norm=nn.BatchNorm2d),\n",
+    "        ResBlock(8, 16, norm=nn.BatchNorm2d),\n",
+    "        ResBlock(16, 32, norm=nn.BatchNorm2d),\n",
+    "        ResBlock(32, 64, norm=nn.BatchNorm2d),\n",
+    "        ResBlock(64, 64, norm=nn.BatchNorm2d),\n",
+    "        conv(64, 10, act=None),\n",
     "        nn.Flatten(),\n",
-    "    )"
+    "    )\n",
+    "\n",
+    "\n",
+    "# def cnn_classifier():\n",
+    "#     return nn.Sequential(\n",
+    "#         ResBlock(1, 16, norm=nn.BatchNorm2d),\n",
+    "#         ResBlock(16, 32, norm=nn.BatchNorm2d),\n",
+    "#         ResBlock(32, 64, norm=nn.BatchNorm2d),\n",
+    "#         ResBlock(64, 128, norm=nn.BatchNorm2d),\n",
+    "#         ResBlock(128, 256, norm=nn.BatchNorm2d),\n",
+    "#         ResBlock(256, 256, norm=nn.BatchNorm2d),\n",
+    "#         conv(256, 10, act=None),\n",
+    "#         nn.Flatten(),\n",
+    "#     )"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 36,
+   "execution_count": 93,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -238,23 +293,23 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 37,
+   "execution_count": 94,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "train, epoch:1, loss: 0.1096, accuracy: 0.9145\n",
-      "eval, epoch:1, loss: 0.1383, accuracy: 0.9774\n",
-      "train, epoch:2, loss: 0.0487, accuracy: 0.9808\n",
-      "eval, epoch:2, loss: 0.0715, accuracy: 0.9867\n",
-      "train, epoch:3, loss: 0.0536, accuracy: 0.9840\n",
-      "eval, epoch:3, loss: 0.0499, accuracy: 0.9896\n",
-      "train, epoch:4, loss: 0.0358, accuracy: 0.9842\n",
-      "eval, epoch:4, loss: 0.0474, accuracy: 0.9893\n",
-      "train, epoch:5, loss: 0.0514, accuracy: 0.9852\n",
-      "eval, epoch:5, loss: 0.0579, accuracy: 0.9886\n"
+      "train, epoch:1, loss: 0.0827, accuracy: 0.9102\n",
+      "eval, epoch:1, loss: 0.0448, accuracy: 0.9817\n",
+      "train, epoch:2, loss: 0.0382, accuracy: 0.9835\n",
+      "eval, epoch:2, loss: 0.0353, accuracy: 0.9863\n",
+      "train, epoch:3, loss: 0.0499, accuracy: 0.9856\n",
+      "eval, epoch:3, loss: 0.0300, accuracy: 0.9867\n",
+      "train, epoch:4, loss: 0.0361, accuracy: 0.9869\n",
+      "eval, epoch:4, loss: 0.0203, accuracy: 0.9877\n",
+      "train, epoch:5, loss: 0.0427, accuracy: 0.9846\n",
+      "eval, epoch:5, loss: 0.0250, accuracy: 0.9866\n"
      ]
     }
    ],
@@ -266,7 +321,6 @@
    "epochs = 5\n",
    "opt = optim.AdamW(model.parameters(), lr=lr)\n",
    "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n",
-    "\n",
    "for epoch in range(epochs):\n",
    "    for train in (True, False):\n",
    "        accuracy = 0\n",
@@ -283,13 +337,12 @@
    "            if train:\n",
    "                sched.step()\n",
    "        accuracy /= len(dl)\n",
-    "        print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")\n",
-    "        "
+    "        print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 41,
+   "execution_count": 95,
    "metadata": {
     "tags": [
      "exclude"
@@ -297,8 +350,8 @@
    },
    "outputs": [],
    "source": [
-    "with open('./cnn_classifier.pkl', 'wb') as model_file:\n",
-    "    pickle.dump(model, model_file)"
+    "# with open('./cnn_classifier.pkl', 'wb') as model_file:\n",
+    "#     pickle.dump(model, model_file)"
    ]
   },
   {
@@ -314,7 +367,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 96,
    "metadata": {
     "tags": [
      "exclude"
@@ -325,22 +378,15 @@
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n"
+      "[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
+      "[NbConvertApp] Writing 5934 bytes to mnist_classifier.py\n"
      ]
     }
    ],
    "source": [
-    "!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n",
-    "\n"
+    "!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
-  },
   {
    "cell_type": "code",
    "execution_count": null,
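
A note on the new ResBlock stack before the exported script below: each block's second conv has stride 2, and the identity path uses AvgPool2d(2, ceil_mode=True) so both paths halve the grid with ceil rounding, 28 -> 14 -> 7 -> 4 -> 2 -> 1. The closing conv(64, 10) then maps 64 channels at 1x1 to the 10 class logits that nn.Flatten() exposes. A minimal shape sanity check (a sketch, not part of the commit; it assumes the notebook's imports, including fastcore as fc, and the conv/ResBlock/cnn_classifier definitions above are in scope):

    import torch

    model = cnn_classifier()
    xb = torch.randn(8, 1, 28, 28)   # dummy MNIST-shaped batch
    print(model(xb).shape)           # torch.Size([8, 10])
    # grid per stride-2 stage: 28 -> 14 -> 7 -> 4 -> 2 -> 1,
    # then conv(64, 10) turns the 64 channels at 1x1 into 10 logits.
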
diff --git a/mnist_classifier.py b/mnist_classifier.py
index 079b967..d8240a7 100644
--- a/mnist_classifier.py
+++ b/mnist_classifier.py
@@ -33,7 +33,7 @@ def transform_ds(b):
 class DataLoaders:
     def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):
         self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)
-        self.valid = DataLoader(train_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)
+        self.valid = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)
 
 def collate_fn(b):
     collate = default_collate(b)
@@ -91,29 +91,72 @@ def linear_classifier():
     print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")
 
 
+# def _conv_block(ni, nf, stride, act=act_gr, norm=None, ks=3):
+#     return nn.Sequential(conv(ni, nf, stride=1, act=act, norm=norm, ks=ks),
+#                          conv(nf, nf, stride=stride, act=None, norm=norm, ks=ks))
+
+# class ResBlock(nn.Module):
+#     def __init__(self, ni, nf, stride=1, ks=3, act=act_gr, norm=None):
+#         super().__init__()
+#         self.convs = _conv_block(ni, nf, stride, act=act, ks=ks, norm=norm)
+#         self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, stride=1, act=None)
+#         self.pool = fc.noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)
+#         self.act = act()
+
+#     def forward(self, x): return self.act(self.convs(x) + self.idconv(self.pool(x)))
+
+
+def conv(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):
+    layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=s, padding=ks//2)]
+    if norm:
+        layers.append(norm(nf))
+    if act:
+        layers.append(act())
+    return nn.Sequential(*layers)
+
+def _conv_block(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):
+    return nn.Sequential(
+        conv(ni, nf, ks=ks, s=1, norm=norm, act=act),
+        conv(nf, nf, ks=ks, s=s, norm=norm, act=act),
+    )
+
+class ResBlock(nn.Module):
+    def __init__(self, ni, nf, s=2, ks=3, act=nn.ReLU, norm=None):
+        super().__init__()
+        self.convs = _conv_block(ni, nf, s=s, ks=ks, act=act, norm=norm)
+        self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, s=1, act=None)
+        self.pool = fc.noop if s==1 else nn.AvgPool2d(2, ceil_mode=True)
+        self.act = act()
+
+    def forward(self, x):
+        return self.act(self.convs(x) + self.idconv(self.pool(x)))
+
+
 def cnn_classifier():
-    ks,stride = 3,2
     return nn.Sequential(
-        nn.Conv2d(1, 8, kernel_size=ks, stride=stride, padding=ks//2),
-        nn.BatchNorm2d(8),
-        nn.ReLU(),
-        nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),
-        nn.BatchNorm2d(16),
-        nn.ReLU(),
-        nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),
-        nn.BatchNorm2d(32),
-        nn.ReLU(),
-        nn.Conv2d(32, 64, kernel_size=ks, stride=stride, padding=ks//2),
-        nn.BatchNorm2d(64),
-        nn.ReLU(),
-        nn.Conv2d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),
-        nn.BatchNorm2d(64),
-        nn.ReLU(),
-        nn.Conv2d(64, 10, kernel_size=ks, stride=stride, padding=ks//2),
+        ResBlock(1, 8, norm=nn.BatchNorm2d),
+        ResBlock(8, 16, norm=nn.BatchNorm2d),
+        ResBlock(16, 32, norm=nn.BatchNorm2d),
+        ResBlock(32, 64, norm=nn.BatchNorm2d),
+        ResBlock(64, 64, norm=nn.BatchNorm2d),
+        conv(64, 10, act=None),
         nn.Flatten(),
     )
 
+# def cnn_classifier():
+#     return nn.Sequential(
+#         ResBlock(1, 16, norm=nn.BatchNorm2d),
+#         ResBlock(16, 32, norm=nn.BatchNorm2d),
+#         ResBlock(32, 64, norm=nn.BatchNorm2d),
+#         ResBlock(64, 128, norm=nn.BatchNorm2d),
+#         ResBlock(128, 256, norm=nn.BatchNorm2d),
+#         ResBlock(256, 256, norm=nn.BatchNorm2d),
+#         conv(256, 10, act=None),
+#         nn.Flatten(),
+#     )
+
+
 def kaiming_init(m):
     if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
         nn.init.kaiming_normal_(m.weight)
@@ -126,7 +169,6 @@ def kaiming_init(m):
 epochs = 5
 opt = optim.AdamW(model.parameters(), lr=lr)
 sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)
-
 for epoch in range(epochs):
     for train in (True, False):
         accuracy = 0
@@ -144,10 +186,6 @@ def kaiming_init(m):
                 sched.step()
         accuracy /= len(dl)
         print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")
-
-
-
-
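
Design note on conv's norm parameter: it takes a layer class (call sites pass norm=nn.BatchNorm2d) and instantiates norm(nf) once per conv, rather than accepting a ready-made module. Appending one prebuilt nn.BatchNorm2d(nf) instance into both Sequentials built by _conv_block would register the same object in two places, so the two convs would silently share affine weights and running statistics. A standalone sketch of that pitfall, using only plain PyTorch:

    import torch.nn as nn

    bn = nn.BatchNorm2d(8)  # one instance...
    a = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), bn)
    b = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), bn)
    print(a[1] is b[1])     # True: both layers hold the same module,
    # so every forward pass updates one shared running_mean/running_var
    # with two different activation distributions. Instantiating
    # norm(nf) inside conv() gives each layer its own state.
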