diff --git a/cnn_classifier.pkl b/cnn_classifier.pkl
new file mode 100644
index 0000000..d6e231d
Binary files /dev/null and b/cnn_classifier.pkl differ
diff --git a/linear_classifier.pkl b/linear_classifier.pkl
deleted file mode 100644
index 1a87837..0000000
Binary files a/linear_classifier.pkl and /dev/null differ
diff --git a/mlp_classifier.pkl b/mlp_classifier.pkl
new file mode 100644
index 0000000..0ff2094
Binary files /dev/null and b/mlp_classifier.pkl differ
diff --git a/mnist.ipynb b/mnist.ipynb
index a4c90fd..264970b 100644
--- a/mnist.ipynb
+++ b/mnist.ipynb
@@ -31,7 +31,7 @@
      "output_type": "stream",
      "text": [
       "Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
-      "100%|██████████| 2/2 [00:00<00:00, 65.21it/s]\n"
+      "100%|██████████| 2/2 [00:00<00:00, 69.76it/s]\n"
      ]
     }
    ],
@@ -117,72 +117,153 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 43,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# model definition\n",
+    "def linear_classifier():\n",
+    "    return nn.Sequential(\n",
+    "        Reshape((-1, 784)),\n",
+    "        nn.Linear(784, 50),\n",
+    "        nn.ReLU(),\n",
+    "        nn.Linear(50, 50),\n",
+    "        nn.ReLU(),\n",
+    "        nn.Linear(50, 10)\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "train, epoch:1, loss: 0.2640, accuracy: 0.7885\n",
+      "eval, epoch:1, loss: 0.3039, accuracy: 0.8994\n",
+      "train, epoch:2, loss: 0.2368, accuracy: 0.9182\n",
+      "eval, epoch:2, loss: 0.2164, accuracy: 0.9350\n",
+      "train, epoch:3, loss: 0.1951, accuracy: 0.9402\n",
+      "eval, epoch:3, loss: 0.1589, accuracy: 0.9498\n",
+      "train, epoch:4, loss: 0.1511, accuracy: 0.9513\n",
+      "eval, epoch:4, loss: 0.1388, accuracy: 0.9618\n",
+      "train, epoch:5, loss: 0.1182, accuracy: 0.9567\n",
+      "eval, epoch:5, loss: 0.1426, accuracy: 0.9621\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = linear_classifier()\n",
+    "lr = 0.1\n",
+    "max_lr = 0.1\n",
+    "epochs = 5\n",
+    "opt = optim.AdamW(model.parameters(), lr=lr)\n",
+    "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train) * epochs)\n",
+    "\n",
+    "for epoch in range(epochs):\n",
+    "    for train in (True, False):\n",
+    "        model.train(train)\n",
+    "        accuracy = 0\n",
+    "        dl = dls.train if train else dls.valid\n",
+    "        for xb,yb in dl:\n",
+    "            preds = model(xb)\n",
+    "            loss = F.cross_entropy(preds, yb)\n",
+    "            if train:\n",
+    "                loss.backward()\n",
+    "                opt.step()\n",
+    "                opt.zero_grad()\n",
+    "                sched.step()\n",
+    "            with torch.no_grad():\n",
+    "                accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
+    "        accuracy /= len(dl)\n",
+    "        print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")\n",
+    "    "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 46,
+   "metadata": {
+    "tags": [
+     "exclude"
+    ]
+   },
+   "outputs": [],
+   "source": [
+    "with open('./mlp_classifier.pkl', 'wb') as model_file:\n",
+    "    pickle.dump(model, model_file)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
    "metadata": {},
    "outputs": [],
    "source": [
     "def cnn_classifier():\n",
     "    ks,stride = 3,2\n",
     "    return nn.Sequential(\n",
-    "        nn.Conv2d(1, 4, kernel_size=ks, stride=stride, padding=ks//2),\n",
-    "        nn.ReLU(),\n",
-    "        nn.Conv2d(4, 8, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.Conv2d(1, 8, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.BatchNorm2d(8),\n",
     "        nn.ReLU(),\n",
     "        nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.BatchNorm2d(16),\n",
     "        nn.ReLU(),\n",
     "        nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.BatchNorm2d(32),\n",
+    "        nn.ReLU(),\n",
+    "        nn.Conv2d(32, 64, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.BatchNorm2d(64),\n",
     "        nn.ReLU(),\n",
-    "        nn.Conv2d(32, 32, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.Conv2d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.BatchNorm2d(64),\n",
     "        nn.ReLU(),\n",
-    "        nn.Conv2d(32, 10, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.Conv2d(64, 10, kernel_size=ks, stride=stride, padding=ks//2),\n",
     "        nn.Flatten(),\n",
     "    )"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 36,
    "metadata": {},
    "outputs": [],
    "source": [
-    "# model definition\n",
-    "def linear_classifier():\n",
-    "    return nn.Sequential(\n",
-    "        Reshape((-1, 784)),\n",
-    "        nn.Linear(784, 50),\n",
-    "        nn.ReLU(),\n",
-    "        nn.Linear(50, 50),\n",
-    "        nn.ReLU(),\n",
-    "        nn.Linear(50, 10)\n",
-    "    )"
+    "def kaiming_init(m):\n",
+    "    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):\n",
+    "        nn.init.kaiming_normal_(m.weight)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 37,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "train, epoch:1, loss: 0.2638, accuracy: 0.8032\n",
-      "eval, epoch:1, loss: 0.2929, accuracy: 0.9011\n",
-      "train, epoch:2, loss: 0.2497, accuracy: 0.9180\n",
-      "eval, epoch:2, loss: 0.2317, accuracy: 0.9312\n",
-      "train, epoch:3, loss: 0.1817, accuracy: 0.9391\n",
-      "eval, epoch:3, loss: 0.1751, accuracy: 0.9496\n",
-      "train, epoch:4, loss: 0.1589, accuracy: 0.9518\n",
-      "eval, epoch:4, loss: 0.1630, accuracy: 0.9638\n",
-      "train, epoch:5, loss: 0.1498, accuracy: 0.9603\n",
-      "eval, epoch:5, loss: 0.1425, accuracy: 0.9655\n"
+      "train, epoch:1, loss: 0.1096, accuracy: 0.9145\n",
+      "eval, epoch:1, loss: 0.1383, accuracy: 0.9774\n",
+      "train, epoch:2, loss: 0.0487, accuracy: 0.9808\n",
+      "eval, epoch:2, loss: 0.0715, accuracy: 0.9867\n",
+      "train, epoch:3, loss: 0.0536, accuracy: 0.9840\n",
+      "eval, epoch:3, loss: 0.0499, accuracy: 0.9896\n",
+      "train, epoch:4, loss: 0.0358, accuracy: 0.9842\n",
+      "eval, epoch:4, loss: 0.0474, accuracy: 0.9893\n",
+      "train, epoch:5, loss: 0.0514, accuracy: 0.9852\n",
+      "eval, epoch:5, loss: 0.0579, accuracy: 0.9886\n"
      ]
     }
    ],
    "source": [
-    "model = linear_classifier()\n",
+    "model = cnn_classifier()\n",
+    "model.apply(kaiming_init)\n",
     "lr = 0.1\n",
-    "max_lr = 0.1\n",
+    "max_lr = 0.3\n",
     "epochs = 5\n",
     "opt = optim.AdamW(model.parameters(), lr=lr)\n",
     "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n",
@@ -209,7 +290,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 41,
    "metadata": {
     "tags": [
      "exclude"
@@ -217,8 +298,8 @@
    },
    "outputs": [],
    "source": [
-    "# with open('./linear_classifier.pkl', 'wb') as model_file:\n",
-    "#     pickle.dump(model, model_file)"
+    "with open('./cnn_classifier.pkl', 'wb') as model_file:\n",
+    "    pickle.dump(model, model_file)"
    ]
   },
   {
diff --git a/mnist.py b/mnist.py
index e7a2d72..45ca3f7 100644
--- a/mnist.py
+++ b/mnist.py
@@ -53,39 +53,77 @@ def forward(self, x):
         return x.reshape(self.dim)
 
 
+# model definition
+def linear_classifier():
+    return nn.Sequential(
+        Reshape((-1, 784)),
+        nn.Linear(784, 50),
+        nn.ReLU(),
+        nn.Linear(50, 50),
+        nn.ReLU(),
+        nn.Linear(50, 10)
+    )
+
+
+model = linear_classifier()
+lr = 0.1
+max_lr = 0.1
+epochs = 5
+opt = optim.AdamW(model.parameters(), lr=lr)
+sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train) * epochs)
+
+for epoch in range(epochs):
+    for train in (True, False):
+        model.train(train)
+        accuracy = 0
+        dl = dls.train if train else dls.valid
+        for xb,yb in dl:
+            preds = model(xb)
+            loss = F.cross_entropy(preds, yb)
+            if train:
+                loss.backward()
+                opt.step()
+                opt.zero_grad()
+                sched.step()
+            with torch.no_grad():
+                accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()
+        accuracy /= len(dl)
+        print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")
+    
+
+
 def cnn_classifier():
     ks,stride = 3,2
     return nn.Sequential(
-        nn.Conv2d(1, 4, kernel_size=ks, stride=stride, padding=ks//2),
-        nn.ReLU(),
-        nn.Conv2d(4, 8, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.Conv2d(1, 8, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.BatchNorm2d(8),
         nn.ReLU(),
         nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.BatchNorm2d(16),
         nn.ReLU(),
         nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.BatchNorm2d(32),
         nn.ReLU(),
-        nn.Conv2d(32, 32, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.Conv2d(32, 64, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.BatchNorm2d(64),
         nn.ReLU(),
-        nn.Conv2d(32, 10, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.Conv2d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.BatchNorm2d(64),
+        nn.ReLU(),
+        nn.Conv2d(64, 10, kernel_size=ks, stride=stride, padding=ks//2),
         nn.Flatten(),
     )
 
 
-# model definition
-def linear_classifier():
-    return nn.Sequential(
-        Reshape((-1, 784)),
-        nn.Linear(784, 50),
-        nn.ReLU(),
-        nn.Linear(50, 50),
-        nn.ReLU(),
-        nn.Linear(50, 10)
-    )
+def kaiming_init(m):
+    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        nn.init.kaiming_normal_(m.weight)
 
 
-model = linear_classifier()
+model = cnn_classifier()
+model.apply(kaiming_init)
 lr = 0.1
-max_lr = 0.1
+max_lr = 0.3
 epochs = 5
 opt = optim.AdamW(model.parameters(), lr=lr)
 sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)