diff --git a/cnn_classifier.pkl b/cnn_classifier.pkl
new file mode 100644
index 0000000..d6e231d
Binary files /dev/null and b/cnn_classifier.pkl differ
diff --git a/linear_classifier.pkl b/linear_classifier.pkl
deleted file mode 100644
index 1a87837..0000000
Binary files a/linear_classifier.pkl and /dev/null differ
diff --git a/mlp_classifier.pkl b/mlp_classifier.pkl
new file mode 100644
index 0000000..0ff2094
Binary files /dev/null and b/mlp_classifier.pkl differ
diff --git a/mnist.ipynb b/mnist.ipynb
index a4c90fd..264970b 100644
--- a/mnist.ipynb
+++ b/mnist.ipynb
@@ -31,7 +31,7 @@
      "output_type": "stream",
      "text": [
       "Found cached dataset mnist (/Users/arun/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)\n",
-      "100%|██████████| 2/2 [00:00<00:00, 65.21it/s]\n"
+      "100%|██████████| 2/2 [00:00<00:00, 69.76it/s]\n"
      ]
     }
    ],
@@ -117,72 +117,153 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 43,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# model definition: a two-hidden-layer MLP\n",
+    "def linear_classifier():\n",
+    "    return nn.Sequential(\n",
+    "        Reshape((-1, 784)),\n",
+    "        nn.Linear(784, 50),\n",
+    "        nn.ReLU(),\n",
+    "        nn.Linear(50, 50),\n",
+    "        nn.ReLU(),\n",
+    "        nn.Linear(50, 10)\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 44,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "train, epoch:1, loss: 0.2640, accuracy: 0.7885\n",
+      "eval, epoch:1, loss: 0.3039, accuracy: 0.8994\n",
+      "train, epoch:2, loss: 0.2368, accuracy: 0.9182\n",
+      "eval, epoch:2, loss: 0.2164, accuracy: 0.9350\n",
+      "train, epoch:3, loss: 0.1951, accuracy: 0.9402\n",
+      "eval, epoch:3, loss: 0.1589, accuracy: 0.9498\n",
+      "train, epoch:4, loss: 0.1511, accuracy: 0.9513\n",
+      "eval, epoch:4, loss: 0.1388, accuracy: 0.9618\n",
+      "train, epoch:5, loss: 0.1182, accuracy: 0.9567\n",
+      "eval, epoch:5, loss: 0.1426, accuracy: 0.9621\n"
+     ]
+    }
+   ],
+   "source": [
+    "model = linear_classifier()\n",
+    "lr = 0.1\n",
+    "max_lr = 0.1\n",
+    "epochs = 5\n",
+    "opt = optim.AdamW(model.parameters(), lr=lr)\n",
+    "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, steps_per_epoch=len(dls.train), epochs=epochs)\n",
+    "\n",
+    "for epoch in range(epochs):\n",
+    "    for train in (True, False):\n",
+    "        accuracy = 0\n",
+    "        dl = dls.train if train else dls.valid\n",
+    "        for xb,yb in dl:\n",
+    "            preds = model(xb)\n",
+    "            loss = F.cross_entropy(preds, yb)\n",
+    "            if train:\n",
+    "                loss.backward()\n",
+    "                opt.step()\n",
+    "                opt.zero_grad()\n",
+    "            with torch.no_grad():\n",
+    "                accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
+    "            if train:\n",
+    "                sched.step()\n",
+    "        accuracy /= len(dl)\n",
+    "        print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")\n",
+    "        "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 46,
+   "metadata": {
+    "tags": [
+     "exclude"
+    ]
+   },
+   "outputs": [],
+   "source": [
+    "with open('./mlp_classifier.pkl', 'wb') as model_file:\n",
+    "    pickle.dump(model, model_file)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 35,
    "metadata": {},
    "outputs": [],
    "source": [
     "def cnn_classifier():\n",
     "    ks,stride = 3,2\n",
     "    return nn.Sequential(\n",
-    "        nn.Conv2d(1, 4, kernel_size=ks, stride=stride, padding=ks//2),\n",
-    "        nn.ReLU(),\n",
-    "        nn.Conv2d(4, 8, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.Conv2d(1, 8, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.BatchNorm2d(8),\n",
     "        nn.ReLU(),\n",
     "        nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.BatchNorm2d(16),\n",
     "        nn.ReLU(),\n",
     "        nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.BatchNorm2d(32),\n",
+    "        nn.ReLU(),\n",
+    "        nn.Conv2d(32, 64, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.BatchNorm2d(64),\n",
     "        nn.ReLU(),\n",
-    "        nn.Conv2d(32, 32, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.Conv2d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.BatchNorm2d(64),\n",
     "        nn.ReLU(),\n",
-    "        nn.Conv2d(32, 10, kernel_size=ks, stride=stride, padding=ks//2),\n",
+    "        nn.Conv2d(64, 10, kernel_size=ks, stride=stride, padding=ks//2),\n",
     "        nn.Flatten(),\n",
     "    )"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 36,
    "metadata": {},
    "outputs": [],
    "source": [
-    "# model definition\n",
-    "def linear_classifier():\n",
-    "    return nn.Sequential(\n",
-    "        Reshape((-1, 784)),\n",
-    "        nn.Linear(784, 50),\n",
-    "        nn.ReLU(),\n",
-    "        nn.Linear(50, 50),\n",
-    "        nn.ReLU(),\n",
-    "        nn.Linear(50, 10)\n",
-    "    )"
+    "def kaiming_init(m):\n",
+    "    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):\n",
+    "        nn.init.kaiming_normal_(m.weight)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 37,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "train, epoch:1, loss: 0.2638, accuracy: 0.8032\n",
-      "eval, epoch:1, loss: 0.2929, accuracy: 0.9011\n",
-      "train, epoch:2, loss: 0.2497, accuracy: 0.9180\n",
-      "eval, epoch:2, loss: 0.2317, accuracy: 0.9312\n",
-      "train, epoch:3, loss: 0.1817, accuracy: 0.9391\n",
-      "eval, epoch:3, loss: 0.1751, accuracy: 0.9496\n",
-      "train, epoch:4, loss: 0.1589, accuracy: 0.9518\n",
-      "eval, epoch:4, loss: 0.1630, accuracy: 0.9638\n",
-      "train, epoch:5, loss: 0.1498, accuracy: 0.9603\n",
-      "eval, epoch:5, loss: 0.1425, accuracy: 0.9655\n"
+      "train, epoch:1, loss: 0.1096, accuracy: 0.9145\n",
+      "eval, epoch:1, loss: 0.1383, accuracy: 0.9774\n",
+      "train, epoch:2, loss: 0.0487, accuracy: 0.9808\n",
+      "eval, epoch:2, loss: 0.0715, accuracy: 0.9867\n",
+      "train, epoch:3, loss: 0.0536, accuracy: 0.9840\n",
+      "eval, epoch:3, loss: 0.0499, accuracy: 0.9896\n",
+      "train, epoch:4, loss: 0.0358, accuracy: 0.9842\n",
+      "eval, epoch:4, loss: 0.0474, accuracy: 0.9893\n",
+      "train, epoch:5, loss: 0.0514, accuracy: 0.9852\n",
+      "eval, epoch:5, loss: 0.0579, accuracy: 0.9886\n"
      ]
     }
    ],
    "source": [
-    "model = linear_classifier()\n",
+    "model = cnn_classifier()\n",
+    "model.apply(kaiming_init)\n",
     "lr = 0.1\n",
-    "max_lr = 0.1\n",
+    "max_lr = 0.3\n",
     "epochs = 5\n",
     "opt = optim.AdamW(model.parameters(), lr=lr)\n",
     "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, steps_per_epoch=len(dls.train), epochs=epochs)\n",
@@ -209,7 +290,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 41,
    "metadata": {
     "tags": [
      "exclude"
@@ -217,8 +298,8 @@
    },
    "outputs": [],
    "source": [
-    "# with open('./linear_classifier.pkl', 'wb') as model_file:\n",
-    "#     pickle.dump(model, model_file)"
+    "with open('./cnn_classifier.pkl', 'wb') as model_file:\n",
+    "    pickle.dump(model, model_file)"
    ]
   },
   {
diff --git a/mnist.py b/mnist.py
index e7a2d72..45ca3f7 100644
--- a/mnist.py
+++ b/mnist.py
@@ -53,39 +53,77 @@
     def forward(self, x):
         return x.reshape(self.dim)
 
 
+# model definition: a two-hidden-layer MLP
+def linear_classifier():
+    return nn.Sequential(
+        Reshape((-1, 784)),
+        nn.Linear(784, 50),
+        nn.ReLU(),
+        nn.Linear(50, 50),
+        nn.ReLU(),
+        nn.Linear(50, 10)
+    )
+
+
+model = linear_classifier()
+lr = 0.1
+max_lr = 0.1
+epochs = 5
+opt = optim.AdamW(model.parameters(), lr=lr)
+sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, steps_per_epoch=len(dls.train), epochs=epochs)
+
+for epoch in range(epochs):
+    for train in (True, False):
+        accuracy = 0
+        dl = dls.train if train else dls.valid
+        for xb,yb in dl:
+            preds = model(xb)
+            loss = F.cross_entropy(preds, yb)
+            if train:
+                loss.backward()
+                opt.step()
+                opt.zero_grad()
+            with torch.no_grad():
+                accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()
+            if train:
+                sched.step()
+        accuracy /= len(dl)
+        print(f"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}")
+
+
 def cnn_classifier():
     ks,stride = 3,2
     return nn.Sequential(
-        nn.Conv2d(1, 4, kernel_size=ks, stride=stride, padding=ks//2),
-        nn.ReLU(),
-        nn.Conv2d(4, 8, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.Conv2d(1, 8, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.BatchNorm2d(8),
         nn.ReLU(),
         nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.BatchNorm2d(16),
         nn.ReLU(),
         nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.BatchNorm2d(32),
         nn.ReLU(),
-        nn.Conv2d(32, 32, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.Conv2d(32, 64, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.BatchNorm2d(64),
         nn.ReLU(),
-        nn.Conv2d(32, 10, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.Conv2d(64, 64, kernel_size=ks, stride=stride, padding=ks//2),
+        nn.BatchNorm2d(64),
+        nn.ReLU(),
+        nn.Conv2d(64, 10, kernel_size=ks, stride=stride, padding=ks//2),
         nn.Flatten(),
     )
 
 
-# model definition
-def linear_classifier():
-    return nn.Sequential(
-        Reshape((-1, 784)),
-        nn.Linear(784, 50),
-        nn.ReLU(),
-        nn.Linear(50, 50),
-        nn.ReLU(),
-        nn.Linear(50, 10)
-    )
+def kaiming_init(m):
+    if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+        nn.init.kaiming_normal_(m.weight)
 
 
-model = linear_classifier()
+model = cnn_classifier()
+model.apply(kaiming_init)
 lr = 0.1
-max_lr = 0.1
+max_lr = 0.3
 epochs = 5
 opt = optim.AdamW(model.parameters(), lr=lr)
 sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, steps_per_epoch=len(dls.train), epochs=epochs)
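
The commit persists both classifiers by pickling the entire nn.Module. A minimal usage sketch, not part of the diff, of reloading a pickled model and scoring one held-out batch: it assumes the notebook's `dls` dataloaders are in scope, and note that unpickling mlp_classifier.pkl additionally requires the custom `Reshape` class to be importable under the module path it was defined in when saved, because pickle stores classes by reference rather than by value.

    import pickle

    import torch

    # Reload the pickled CNN and score a single validation batch.
    with open('./cnn_classifier.pkl', 'rb') as model_file:
        model = pickle.load(model_file)

    model.eval()  # switch BatchNorm layers to their running statistics
    xb, yb = next(iter(dls.valid))  # assumes the notebook's dataloaders
    with torch.no_grad():
        preds = model(xb).argmax(1)
    print(f"batch accuracy: {(preds == yb).float().mean():.4f}")

Calling model.eval() matters here: the new CNN contains BatchNorm layers, which otherwise keep updating their batch statistics during evaluation.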
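
A more refactor-proof alternative to pickling whole modules is to checkpoint only the weights via state_dict and rebuild the architecture from cnn_classifier() in mnist.py. A sketch under the same assumptions; the .pt file name is illustrative, not from the commit.

    import torch

    # Save parameters and buffers only (includes BatchNorm running stats).
    torch.save(model.state_dict(), 'cnn_classifier.pt')

    # Rebuild the module from code, then restore the trained weights.
    model2 = cnn_classifier()
    model2.load_state_dict(torch.load('cnn_classifier.pt'))
    model2.eval()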