Skip to content

Commit

Permalink
CNN classifier added; evaluation accuracy is ~98.9%
Browse files Browse the repository at this point in the history
  • Loading branch information
arun477 committed Sep 13, 2023
1 parent 53075d2 commit 5993d2f
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 51 deletions.
Binary file added cnn_classifier.pkl
Binary file not shown.
Binary file removed linear_classifier.pkl
Binary file not shown.
Binary file added mlp_classifier.pkl
Binary file not shown.
149 changes: 115 additions & 34 deletions mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"output_type": "stream",
"text": [
"Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
"100%|██████████| 2/2 [00:00<00:00, 65.21it/s]\n"
"100%|██████████| 2/2 [00:00<00:00, 69.76it/s]\n"
]
}
],
Expand Down Expand Up @@ -117,72 +117,153 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"# model definition\n",
"def linear_classifier():\n",
" return nn.Sequential(\n",
" Reshape((-1, 784)),\n",
" nn.Linear(784, 50),\n",
" nn.ReLU(),\n",
" nn.Linear(50, 50),\n",
" nn.ReLU(),\n",
" nn.Linear(50, 10)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train, epoch:1, loss: 0.2640, accuracy: 0.7885\n",
"eval, epoch:1, loss: 0.3039, accuracy: 0.8994\n",
"train, epoch:2, loss: 0.2368, accuracy: 0.9182\n",
"eval, epoch:2, loss: 0.2164, accuracy: 0.9350\n",
"train, epoch:3, loss: 0.1951, accuracy: 0.9402\n",
"eval, epoch:3, loss: 0.1589, accuracy: 0.9498\n",
"train, epoch:4, loss: 0.1511, accuracy: 0.9513\n",
"eval, epoch:4, loss: 0.1388, accuracy: 0.9618\n",
"train, epoch:5, loss: 0.1182, accuracy: 0.9567\n",
"eval, epoch:5, loss: 0.1426, accuracy: 0.9621\n"
]
}
],
"source": [
"model = linear_classifier()\n",
"lr = 0.1\n",
"max_lr = 0.1\n",
"epochs = 5\n",
"opt = optim.AdamW(model.parameters(), lr=lr)\n",
"sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n",
"\n",
"for epoch in range(epochs):\n",
" for train in (True, False):\n",
" accuracy = 0\n",
" dl = dls.train if train else dls.valid\n",
" for xb,yb in dl:\n",
" preds = model(xb)\n",
" loss = F.cross_entropy(preds, yb)\n",
" if train:\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
" with torch.no_grad():\n",
" accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
" if train:\n",
" sched.step()\n",
" accuracy /= len(dl)\n",
" print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [],
"source": [
"with open('./mlp_classifier.pkl', 'wb') as model_file:\n",
" pickle.dump(model, model_file)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"def cnn_classifier():\n",
" ks,stride = 3,2\n",
" return nn.Sequential(\n",
" nn.Conv2d(1, 4, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.ReLU(),\n",
" nn.Conv2d(4, 8, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.Conv2d(1, 8, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.BatchNorm2d(8),\n",
" nn.ReLU(),\n",
" nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.BatchNorm2d(16),\n",
" nn.ReLU(),\n",
" nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.BatchNorm2d(32),\n",
" nn.ReLU(),\n",
" nn.Conv2d(32, 64, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.BatchNorm2d(64),\n",
" nn.ReLU(),\n",
" nn.Conv2d(32, 32, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.Conv2d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.BatchNorm2d(64),\n",
" nn.ReLU(),\n",
" nn.Conv2d(32, 10, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.Conv2d(64, 10, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.Flatten(),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"# model definition\n",
"def linear_classifier():\n",
" return nn.Sequential(\n",
" Reshape((-1, 784)),\n",
" nn.Linear(784, 50),\n",
" nn.ReLU(),\n",
" nn.Linear(50, 50),\n",
" nn.ReLU(),\n",
" nn.Linear(50, 10)\n",
" )"
"def kaiming_init(m):\n",
" if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):\n",
" nn.init.kaiming_normal_(m.weight)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train, epoch:1, loss: 0.2638, accuracy: 0.8032\n",
"eval, epoch:1, loss: 0.2929, accuracy: 0.9011\n",
"train, epoch:2, loss: 0.2497, accuracy: 0.9180\n",
"eval, epoch:2, loss: 0.2317, accuracy: 0.9312\n",
"train, epoch:3, loss: 0.1817, accuracy: 0.9391\n",
"eval, epoch:3, loss: 0.1751, accuracy: 0.9496\n",
"train, epoch:4, loss: 0.1589, accuracy: 0.9518\n",
"eval, epoch:4, loss: 0.1630, accuracy: 0.9638\n",
"train, epoch:5, loss: 0.1498, accuracy: 0.9603\n",
"eval, epoch:5, loss: 0.1425, accuracy: 0.9655\n"
"train, epoch:1, loss: 0.1096, accuracy: 0.9145\n",
"eval, epoch:1, loss: 0.1383, accuracy: 0.9774\n",
"train, epoch:2, loss: 0.0487, accuracy: 0.9808\n",
"eval, epoch:2, loss: 0.0715, accuracy: 0.9867\n",
"train, epoch:3, loss: 0.0536, accuracy: 0.9840\n",
"eval, epoch:3, loss: 0.0499, accuracy: 0.9896\n",
"train, epoch:4, loss: 0.0358, accuracy: 0.9842\n",
"eval, epoch:4, loss: 0.0474, accuracy: 0.9893\n",
"train, epoch:5, loss: 0.0514, accuracy: 0.9852\n",
"eval, epoch:5, loss: 0.0579, accuracy: 0.9886\n"
]
}
],
"source": [
"model = linear_classifier()\n",
"model = cnn_classifier()\n",
"model.apply(kaiming_init)\n",
"lr = 0.1\n",
"max_lr = 0.1\n",
"max_lr = 0.3\n",
"epochs = 5\n",
"opt = optim.AdamW(model.parameters(), lr=lr)\n",
"sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n",
Expand All @@ -209,16 +290,16 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 41,
"metadata": {
"tags": [
"exclude"
]
},
"outputs": [],
"source": [
"# with open('./linear_classifier.pkl', 'wb') as model_file:\n",
"# pickle.dump(model, model_file)"
"with open('./cnn_classifier.pkl', 'wb') as model_file:\n",
" pickle.dump(model, model_file)"
]
},
{
Expand Down
72 changes: 55 additions & 17 deletions mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,39 +53,77 @@ def forward(self, x):
return x.reshape(self.dim)


# model definition
def linear_classifier():
    """MLP for MNIST: flatten each image to 784 features, two hidden
    ReLU layers of width 50, then a 10-way logit head."""
    hidden = 50
    layers = [
        Reshape((-1, 784)),        # (B, 1, 28, 28) -> (B, 784)
        nn.Linear(784, hidden),
        nn.ReLU(),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Linear(hidden, 10),     # class logits
    ]
    return nn.Sequential(*layers)


# --- Train / evaluate the linear classifier ---
model = linear_classifier()
lr = 0.1
max_lr = 0.1
epochs = 5
opt = optim.AdamW(model.parameters(), lr=lr)
# BUG FIX: OneCycleLR ignores `epochs=` once `total_steps` is given, and
# sched.step() runs once per batch for `epochs` epochs — so the schedule
# must span len(dls.train) * epochs steps, not a single epoch.
sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train) * epochs)

for epoch in range(epochs):
    for train in (True, False):
        model.train(train)  # BUG FIX: toggle train/eval mode (matters for dropout/batchnorm models)
        accuracy = 0
        dl = dls.train if train else dls.valid
        for xb, yb in dl:
            if train:
                preds = model(xb)
                loss = F.cross_entropy(preds, yb)
                loss.backward()
                opt.step()
                opt.zero_grad()
                sched.step()  # one scheduler step per optimizer step
            else:
                # BUG FIX: no autograd bookkeeping needed during evaluation
                with torch.no_grad():
                    preds = model(xb)
                    loss = F.cross_entropy(preds, yb)
            with torch.no_grad():
                accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()
        accuracy /= len(dl)
        print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")



def cnn_classifier():
    """Small all-convolutional MNIST classifier.

    Six stride-2 3x3 conv stages downsample a (B, 1, 28, 28) input to a
    1x1 spatial map with 10 channels; Flatten yields (B, 10) class logits.
    """
    ks, stride = 3, 2
    # BUG FIX: the original interleaved two versions of this network
    # (diff residue), producing channel mismatches — e.g. Conv2d(1, 8)
    # fed the 8-channel output of Conv2d(4, 8), and Conv2d(64, 64) fed
    # the 10-channel output of Conv2d(32, 10). This is the consistent
    # BatchNorm variant: 1 -> 8 -> 16 -> 32 -> 64 -> 64 -> 10.
    return nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=ks, stride=stride, padding=ks//2),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),
        nn.BatchNorm2d(16),
        nn.ReLU(),
        nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=ks, stride=stride, padding=ks//2),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.Conv2d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.Conv2d(64, 10, kernel_size=ks, stride=stride, padding=ks//2),
        nn.Flatten(),
    )


# model definition
# NOTE(review): this re-declares linear_classifier, byte-identical in behavior
# to the earlier definition in this file — likely diff residue; at import time
# this later definition simply rebinds the same name.
def linear_classifier():
    """Fully connected MNIST classifier: 784 -> 50 -> 50 -> 10 logits."""
    in_features, width, n_classes = 784, 50, 10
    return nn.Sequential(
        Reshape((-1, in_features)),      # flatten images to feature vectors
        nn.Linear(in_features, width),
        nn.ReLU(),
        nn.Linear(width, width),
        nn.ReLU(),
        nn.Linear(width, n_classes),
    )
def kaiming_init(m):
    """Apply He (Kaiming) normal initialization to conv-layer weights.

    Intended for `model.apply(...)`: modules that are not Conv1d/2d/3d
    (Linear, BatchNorm, activations, ...) are left untouched.
    """
    conv_types = (nn.Conv1d, nn.Conv2d, nn.Conv3d)
    if not isinstance(m, conv_types):
        return
    nn.init.kaiming_normal_(m.weight)


# --- Train the CNN classifier ---
# BUG FIX: removed stale diff residue — a dead `model = linear_classifier()`
# assignment immediately overwritten below, and a duplicate `max_lr = 0.1`
# binding immediately overwritten by 0.3.
model = cnn_classifier()
model.apply(kaiming_init)  # He init suits the ReLU conv stack
lr = 0.1
max_lr = 0.3
epochs = 5
opt = optim.AdamW(model.parameters(), lr=lr)
# BUG FIX: OneCycleLR ignores `epochs=` when `total_steps` is supplied; the
# schedule must cover every optimizer step: batches-per-epoch * epochs.
sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train) * epochs)
Expand Down

0 comments on commit 5993d2f

Please sign in to comment.