
Commit

resblock added
arun477 committed Sep 13, 2023
1 parent 056ab4f commit 8e35bc7
Showing 3 changed files with 178 additions and 94 deletions.
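
For orientation before the diff: this commit replaces the plain Conv2d/BatchNorm/ReLU stack in cnn_classifier with stride-2 residual blocks. The sketch below condenses the conv/ResBlock definitions added in the notebook diff into a self-contained form; nn.Identity() stands in for the fastcore fc.noop helper used in the notebook, and the toy shape check at the end is illustrative only, not part of the commit.

import torch
import torch.nn as nn

def conv(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):
    # conv -> optional norm -> optional activation, with ks//2 padding
    layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=s, padding=ks // 2)]
    if norm: layers.append(norm)
    if act: layers.append(act())
    return nn.Sequential(*layers)

class ResBlock(nn.Module):
    # two 3x3 convs (the second one strided), summed with a skip path that
    # matches channels via a 1x1 conv and spatial size via average pooling
    def __init__(self, ni, nf, s=2, ks=3, act=nn.ReLU, norm=None):
        super().__init__()
        self.convs = nn.Sequential(
            conv(ni, nf, ks=ks, s=1, act=act, norm=norm),
            conv(nf, nf, ks=ks, s=s, act=act, norm=norm),
        )
        self.idconv = nn.Identity() if ni == nf else conv(ni, nf, ks=1, s=1, act=None)
        self.pool = nn.Identity() if s == 1 else nn.AvgPool2d(2, ceil_mode=True)
        self.act = act()

    def forward(self, x):
        return self.act(self.convs(x) + self.idconv(self.pool(x)))

x = torch.randn(2, 1, 28, 28)                    # fake MNIST batch
block = ResBlock(1, 8, norm=nn.BatchNorm2d(8))   # mirrors the call in the diff
print(block(x).shape)                            # torch.Size([2, 8, 14, 14])

Note that, as in the committed code, passing a single BatchNorm2d instance means both convs inside the block share that one norm module.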
Binary file modified mlp_classifier.pkl
Binary file not shown.
188 changes: 117 additions & 71 deletions mnist_classifier.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
@@ -23,15 +23,15 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 61,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
"100%|██████████| 2/2 [00:00<00:00, 69.76it/s]\n"
"100%|██████████| 2/2 [00:00<00:00, 71.77it/s]\n"
]
}
],
@@ -43,7 +43,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 62,
"metadata": {},
"outputs": [
{
@@ -70,7 +70,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 87,
"metadata": {},
"outputs": [
{
@@ -79,7 +79,7 @@
"(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))"
]
},
"execution_count": 4,
"execution_count": 87,
"metadata": {},
"output_type": "execute_result"
}
@@ -89,7 +89,7 @@
"class DataLoaders:\n",
" def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n",
" self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n",
" self.valid = DataLoader(train_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)\n",
" self.valid = DataLoader(valid_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)\n",
"\n",
"def collate_fn(b):\n",
" collate = default_collate(b)\n",
@@ -102,7 +102,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
@@ -117,7 +117,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
@@ -135,23 +135,23 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 79,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train, epoch:1, loss: 0.2640, accuracy: 0.7885\n",
"eval, epoch:1, loss: 0.3039, accuracy: 0.8994\n",
"train, epoch:2, loss: 0.2368, accuracy: 0.9182\n",
"eval, epoch:2, loss: 0.2164, accuracy: 0.9350\n",
"train, epoch:3, loss: 0.1951, accuracy: 0.9402\n",
"eval, epoch:3, loss: 0.1589, accuracy: 0.9498\n",
"train, epoch:4, loss: 0.1511, accuracy: 0.9513\n",
"eval, epoch:4, loss: 0.1388, accuracy: 0.9618\n",
"train, epoch:5, loss: 0.1182, accuracy: 0.9567\n",
"eval, epoch:5, loss: 0.1426, accuracy: 0.9621\n"
"train, epoch:1, loss: 0.3142, accuracy: 0.7951\n",
"eval, epoch:1, loss: 0.2298, accuracy: 0.9048\n",
"train, epoch:2, loss: 0.2198, accuracy: 0.9204\n",
"eval, epoch:2, loss: 0.1663, accuracy: 0.9350\n",
"train, epoch:3, loss: 0.1776, accuracy: 0.9420\n",
"eval, epoch:3, loss: 0.1267, accuracy: 0.9493\n",
"train, epoch:4, loss: 0.1328, accuracy: 0.9568\n",
"eval, epoch:4, loss: 0.0959, accuracy: 0.9598\n",
"train, epoch:5, loss: 0.1038, accuracy: 0.9637\n",
"eval, epoch:5, loss: 0.0913, accuracy: 0.9643\n"
]
}
],
@@ -184,50 +184,105 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 81,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [],
"source": [
"with open('./mlp_classifier.pkl', 'wb') as model_file:\n",
" pickle.dump(model, model_file)"
"# with open('./mlp_classifier.pkl', 'wb') as model_file:\n",
"# pickle.dump(model, model_file)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
"# def _conv_block(ni, nf, stride, act=act_gr, norm=None, ks=3):\n",
"# return nn.Sequential(conv(ni, nf, stride=1, act=act, norm=norm, ks=ks),\n",
"# conv(nf, nf, stride=stride, act=None, norm=norm, ks=ks))\n",
"\n",
"# class ResBlock(nn.Module):\n",
"# def __init__(self, ni, nf, stride=1, ks=3, act=act_gr, norm=None):\n",
"# super().__init__()\n",
"# self.convs = _conv_block(ni, nf, stride, act=act, ks=ks, norm=norm)\n",
"# self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, stride=1, act=None)\n",
"# self.pool = fc.noop if stride==1 else nn.AvgPool2d(2, ceil_mode=True)\n",
"# self.act = act()\n",
"\n",
"# def forward(self, x): return self.act(self.convs(x) + self.idconv(self.pool(x)))"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"def conv(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
" layers = [nn.Conv2d(ni, nf, kernel_size=ks, stride=s, padding=ks//2)]\n",
" if norm:\n",
" layers.append(norm)\n",
" if act:\n",
" layers.append(act())\n",
" return nn.Sequential(*layers)\n",
"\n",
"def _conv_block(ni, nf, ks=3, s=2, act=nn.ReLU, norm=None):\n",
" return nn.Sequential(\n",
" conv(ni, nf, ks=ks, s=1, norm=norm, act=act),\n",
" conv(nf, nf, ks=ks, s=s, norm=norm, act=act),\n",
" )\n",
"\n",
"class ResBlock(nn.Module):\n",
" def __init__(self, ni, nf, s=2, ks=3, act=nn.ReLU, norm=None):\n",
" super().__init__()\n",
" self.convs = _conv_block(ni, nf, s=s, ks=ks, act=act, norm=norm)\n",
" self.idconv = fc.noop if ni==nf else conv(ni, nf, ks=1, s=1, act=None)\n",
" self.pool = fc.noop if s==1 else nn.AvgPool2d(2, ceil_mode=True)\n",
" self.act = act()\n",
" \n",
" def forward(self, x):\n",
" return self.act(self.convs(x) + self.idconv(self.pool(x)))"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
"def cnn_classifier():\n",
" ks,stride = 3,2\n",
" return nn.Sequential(\n",
" nn.Conv2d(1, 8, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.BatchNorm2d(8),\n",
" nn.ReLU(),\n",
" nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.BatchNorm2d(16),\n",
" nn.ReLU(),\n",
" nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.BatchNorm2d(32),\n",
" nn.ReLU(),\n",
" nn.Conv2d(32, 64, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.BatchNorm2d(64),\n",
" nn.ReLU(),\n",
" nn.Conv2d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.BatchNorm2d(64),\n",
" nn.ReLU(),\n",
" nn.Conv2d(64, 10, kernel_size=ks, stride=stride, padding=ks//2),\n",
" ResBlock(1, 8, norm=nn.BatchNorm2d(8)),\n",
" ResBlock(8, 16, norm=nn.BatchNorm2d(16)),\n",
" ResBlock(16, 32, norm=nn.BatchNorm2d(32)),\n",
" ResBlock(32, 64, norm=nn.BatchNorm2d(64)),\n",
" ResBlock(64, 64, norm=nn.BatchNorm2d(64)),\n",
" conv(64, 10, act=False),\n",
" nn.Flatten(),\n",
" )"
" )\n",
"\n",
"\n",
"# def cnn_classifier():\n",
"# return nn.Sequential(\n",
"# ResBlock(1, 16, norm=nn.BatchNorm2d(16)),\n",
"# ResBlock(16, 32, norm=nn.BatchNorm2d(32)),\n",
"# ResBlock(32, 64, norm=nn.BatchNorm2d(64)),\n",
"# ResBlock(64, 128, norm=nn.BatchNorm2d(128)),\n",
"# ResBlock(128, 256, norm=nn.BatchNorm2d(256)),\n",
"# ResBlock(256, 256, norm=nn.BatchNorm2d(256)),\n",
"# conv(256, 10, act=False),\n",
"# nn.Flatten(),\n",
"# )"
]
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
@@ -238,23 +293,23 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 94,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train, epoch:1, loss: 0.1096, accuracy: 0.9145\n",
"eval, epoch:1, loss: 0.1383, accuracy: 0.9774\n",
"train, epoch:2, loss: 0.0487, accuracy: 0.9808\n",
"eval, epoch:2, loss: 0.0715, accuracy: 0.9867\n",
"train, epoch:3, loss: 0.0536, accuracy: 0.9840\n",
"eval, epoch:3, loss: 0.0499, accuracy: 0.9896\n",
"train, epoch:4, loss: 0.0358, accuracy: 0.9842\n",
"eval, epoch:4, loss: 0.0474, accuracy: 0.9893\n",
"train, epoch:5, loss: 0.0514, accuracy: 0.9852\n",
"eval, epoch:5, loss: 0.0579, accuracy: 0.9886\n"
"train, epoch:1, loss: 0.0827, accuracy: 0.9102\n",
"eval, epoch:1, loss: 0.0448, accuracy: 0.9817\n",
"train, epoch:2, loss: 0.0382, accuracy: 0.9835\n",
"eval, epoch:2, loss: 0.0353, accuracy: 0.9863\n",
"train, epoch:3, loss: 0.0499, accuracy: 0.9856\n",
"eval, epoch:3, loss: 0.0300, accuracy: 0.9867\n",
"train, epoch:4, loss: 0.0361, accuracy: 0.9869\n",
"eval, epoch:4, loss: 0.0203, accuracy: 0.9877\n",
"train, epoch:5, loss: 0.0427, accuracy: 0.9846\n",
"eval, epoch:5, loss: 0.0250, accuracy: 0.9866\n"
]
}
],
@@ -266,7 +321,6 @@
"epochs = 5\n",
"opt = optim.AdamW(model.parameters(), lr=lr)\n",
"sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n",
"\n",
"for epoch in range(epochs):\n",
" for train in (True, False):\n",
" accuracy = 0\n",
@@ -283,22 +337,21 @@
" if train:\n",
" sched.step()\n",
" accuracy /= len(dl)\n",
" print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")\n",
" "
" print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 95,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [],
"source": [
"with open('./cnn_classifier.pkl', 'wb') as model_file:\n",
" pickle.dump(model, model_file)"
"# with open('./cnn_classifier.pkl', 'wb') as model_file:\n",
"# pickle.dump(model, model_file)"
]
},
{
@@ -314,7 +367,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 96,
"metadata": {
"tags": [
"exclude"
@@ -325,22 +378,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n"
"[NbConvertApp] Converting notebook mnist_classifier.ipynb to script\n",
"[NbConvertApp] Writing 5934 bytes to mnist_classifier.py\n"
]
}
],
"source": [
"!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n",
"\n"
"!jupyter nbconvert --to script --TagRemovePreprocessor.remove_cell_tags=\"exclude\" --TemplateExporter.exclude_input_prompt=True mnist_classifier.ipynb\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
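As a rough sanity check on the new architecture (five stride-2 ResBlocks, then conv(64, 10) and Flatten), the 28x28 MNIST input is roughly halved at each stage. The arithmetic below assumes the ks=3, stride=2, padding=1 convolutions used in the diff and is illustrative only.

def conv_out(n, ks=3, s=2, p=1):
    # spatial output size of nn.Conv2d with these settings
    return (n + 2 * p - ks) // s + 1

n = 28
for stage in range(6):              # 5 ResBlocks + the final conv(64, 10)
    n = conv_out(n)
    print(f"after stage {stage + 1}: {n}x{n}")
# prints 14, 7, 4, 2, 1, 1 -> Flatten yields (batch, 10) logits
# AvgPool2d(2, ceil_mode=True) on the skip path matches at every stage:
# ceil(28/2)=14, ceil(14/2)=7, ceil(7/2)=4, ceil(4/2)=2, ceil(2/2)=1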
