From 55abd019d5e90831b6ed7f9aaff778dc98d69949 Mon Sep 17 00:00:00 2001 From: Arun Date: Tue, 12 Sep 2023 20:08:21 +0530 Subject: [PATCH] linear classifier trained --- mnist.ipynb | 203 +++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 201 insertions(+), 2 deletions(-) diff --git a/mnist.ipynb b/mnist.ipynb index 51f13c5..fa6ef89 100644 --- a/mnist.ipynb +++ b/mnist.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 10, + "execution_count": 140, "metadata": {}, "outputs": [], "source": [ @@ -10,9 +10,208 @@ "from torch import nn\n", "import torch.nn.functional as F\n", "from datasets import load_dataset\n", - "import fastcore.all as fc" + "import fastcore.all as fc\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib as mpl\n", + "import torchvision.transforms.functional as TF\n", + "from torch.utils.data import default_collate, DataLoader\n", + "import torch.optim as optim\n", + "import pickle\n", + "%matplotlib inline\n", + "plt.rcParams['figure.figsize'] = [2, 2]" ] }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_nm = 'mnist'\n", + "x,y = 'image', 'label'\n", + "ds = load_dataset(dataset_nm)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAMkAAADICAYAAABCmsWgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAP8UlEQVR4nO3df1CUd34H8PcisKDCg+ixCyPEbc6rTkyhRcAdHWOSrZydOv5qatr7w5g0TnTxBkmbC07UnuMNnnaMkRAzTSOYmRodbqok5kong4pnCmRETMaQI+ZCIj3YNcRhd4Pya/fbP4jb2X4f+bKwuA/4fs08f+xnvyyfx+TNl+fL88MkhBAgonuKiXYDREbHkBApMCRECgwJkQJDQqTAkBApMCRECgwJkQJDQqTAkBApxE7UB1dUVODgwYNwuVzIzs5GeXk58vPzlV8XCATQ2dmJpKQkmEymiWqPHnBCCPh8PmRkZCAmRjFXiAlw8uRJER8fL44dOyY+++wz8fzzz4uUlBThdruVX9vR0SEAcON2X7aOjg7l/5MmISJ/gmNBQQHy8vLw+uuvAxieHTIzM7F9+3a8/PLLI36tx+NBSkoKluGvEIu4SLdGBAAYwiAu4bfo6emBpmkjjo34r1sDAwNobm5GaWlpsBYTEwOHw4GGhgZpfH9/P/r7+4OvfT7fD43FIdbEkNAE+WFqGM2v9BE/cO/u7obf74fFYgmpWywWuFwuaXxZWRk0TQtumZmZkW6JaFyivrpVWloKj8cT3Do6OqLdElGIiP+6NWfOHEybNg1utzuk7na7YbVapfFmsxlmsznSbRBFTMRnkvj4eOTm5qKuri5YCwQCqKurg91uj/S3I5pwE/J3kpKSEmzatAmLFy9Gfn4+Dh8+jN7eXmzevHkivh3RhJqQkGzcuBHffvstdu/eDZfLhZycHNTW1koH80STwYT8nWQ8vF4vNE3DCqzhEjBNmCExiAuogcfjQXJy8ohjo766RWR0DAmRAkNCpMCQECkwJEQKDAmRAkNCpMCQECkwJEQKDAmRAkNCpMCQECkwJEQKDAmRAkNCpMCQECkwJEQKE3YvYBo/U6z8n2faj+aM+3Pb/nGeVPNPD+iOfejhm1Jt+jb9G7q5DsVLtSuLT+mO7fb3SrWC6hd1x/64pFG3fr9wJiFSYEiIFBgSIgWGhEiBISFS4OpWBExbOF+qCbP+PcM6H0uRaneWyCs9AJCqyfXfZeuvFk2U/7ydJNV+/fpPdcc2PXpCqrUP3tEdu9/9l1It43eGugVcEGcSIgWGhEiBISFSYEiIFHjgHgb/ir/QrR+qqpBqP4mTT9EwskHh163vLn9GqsX26h9g26uLpFrSH4d0x5q75QP66ZebRugwejiTECkwJEQKDAmRAkNCpMCQEClwdSsM5rZO3XpzX6ZU+0mcW2fkxHmxa4lU++p7/Qu0qh7+jVTzBPRXrCxH/nt8jd2DMU9A0ceZhEiBISFSYEiIFBgSIgUeuIdhqMulWy//9VNS7Vc/1b9GZNqnM6XaJ9vKR93Dvu4/061/6Zgu1fw9Xbpj/96+Tap9/XP972fDJ6PubariTEKkwJAQKTAkRAoMCZFC2CG5ePEiVq9ejYyMDJhMJpw5cybkfSEEdu/ejfT0dCQmJsLhcOD69euR6pfovgt7dau3txfZ2dl49tlnsX79eun9AwcO4MiRIzh+/DhsNht27dqFwsJCtLa2IiEhISJNG01qZYNU+9H7s3XH+r+7JdUeWfSs7tjPlh+Tau/962O6Y9N6Rn/6iKlBXrGyybtAPwg7JKtWrcKqVat03xNC4PDhw3jllVewZs0aAMA777wDi8WCM2fO4Omnnx5ft0RRENFjkvb2drhcLjgcjmBN0zQUFBSgoUH/R1V/fz+8Xm/IRmQkEQ2JyzX8xzaLxRJSt1gswff+v7KyMmiaFtwyM+UzaomiKeqrW6WlpfB4PMGto6Mj2i0RhYjoaSlWqxUA4Ha7kZ6eHqy73W7k5OTofo3ZbIbZbI5kG4bg7/5u1GMHvaO/s8ojP2vVrX97dJpcDOjfAYXCE9GZxGazwWq1oq6uLljzer1oamqC3W6P5Lcium/Cnkm+//57fPnll8HX7e3tuHr1KlJTU5GVlYXi4mLs27cP8+fPDy4BZ2RkYO3atZHsm+i+CTskly9fxuOPPx58XVJSAgDYtG
kTqqqq8NJLL6G3txdbtmxBT08Pli1bhtra2in7NxKa+sIOyYoVKyDEva9QNplM2Lt3L/bu3TuuxoiMIuqrW0RGx4uuDGDhL77QrW9+9EmpVvlQnc5I4LGnnFIt6VR0H+08VXAmIVJgSIgUGBIiBYaESIEH7gbg7/Ho1r/bulCq3XhP/2m2L+97R6qV/u063bGiRZNqmb+6xwUlIyz3Pyg4kxApMCRECgwJkQJDQqTAkBApcHXLwAKffC7Vnv7lP+mO/fc9/yLVri6RV7wAAPLzfvDIDPnx0gAw/y35fsJDX32t/7lTFGcSIgWGhEiBISFSYEiIFExipMsMo8Dr9ULTNKzAGsSa4qLdzqQhluZIteT9/6M79t0/+a9Rf+6C8/8g1f70l/qn0fivfzXqz422ITGIC6iBx+NBcnLyiGM5kxApMCRECgwJkQJDQqTAkBAp8LSUKcL00VWpdvtv0nTH5m3cLtWafvGa7tjfP/5vUu1n81bqjvUsG6HBSYwzCZECQ0KkwJAQKTAkRAo8cJ/C/O6bunXLEbne99KQ7tjpJvkBQ2/NO6s79q/XFctff7pphA4nB84kRAoMCZECQ0KkwJAQKTAkRApc3ZoiAstypNofntJ/TuWinK+lmt4q1r2U3/pz3fr0msuj/ozJhDMJkQJDQqTAkBApMCRECjxwNzDT4kVS7Yuf6x9gv7X0uFRbnjAw7h76xaBUa7xl0x8ckG+JOhVwJiFSYEiIFBgSIgWGhEghrJCUlZUhLy8PSUlJSEtLw9q1a9HW1hYypq+vD06nE7Nnz8bMmTOxYcMGuN3uiDZNdD+FtbpVX18Pp9OJvLw8DA0NYefOnVi5ciVaW1sxY8YMAMCOHTvwwQcfoLq6GpqmoaioCOvXr8dHH300ITsw2cTaHpJqf9icoTv2nzeelGobZnZHvCcA2OlerFuvf01+4s+s4/d4nPUUFVZIamtrQ15XVVUhLS0Nzc3NWL58OTweD95++22cOHECTzzxBACgsrISCxcuRGNjI5Ys0XnEEpHBjeuYxOMZvrt4amoqAKC5uRmDg4NwOBzBMQsWLEBWVhYaGvR/+vT398Pr9YZsREYy5pAEAgEUFxdj6dKlWLRo+I9eLpcL8fHxSElJCRlrsVjgcrl0P6esrAyapgW3zMzMsbZENCHGHBKn04lr167h5En59+ZwlJaWwuPxBLeOjo5xfR5RpI3ptJSioiKcPXsWFy9exNy5c4N1q9WKgYEB9PT0hMwmbrcbVqtV97PMZjPMZvNY2jCM2HlZUs2Tm647duPeWqn2Qsp/RLwnAHixS/8YsOEN+SA9tepj3bGzAg/WQbqesGYSIQSKiopw+vRpnDt3DjZb6Dk8ubm5iIuLQ11dXbDW1taGGzduwG63R6ZjovssrJnE6XTixIkTqKmpQVJSUvA4Q9M0JCYmQtM0PPfccygpKUFqaiqSk5Oxfft22O12rmzRpBVWSI4ePQoAWLFiRUi9srISzzzzDADg1VdfRUxMDDZs2ID+/n4UFhbijTfeiEizRNEQVkhG8wzShIQEVFRUoKKiYsxNERkJz90iUuBFV/cQmy6vxt06NkN37FZbvVT7u6SJOV+t6I/6T8q5cjRHqs35zTXdsak+rliFgzMJkQJDQqTAkBApMCRECg/UgftAoXw6xsCOW7pjd/74t1JtZWJvxHsCALf/jm59+XsvSrUFr/xed2xqj3wwHhhfW/QDziRECgwJkQJDQqTAkBApMCRECg/U6tbXa+WfCV88Wj3uz63oeViqvVa/UnesyW+Sagv2teuOne+WH+/sD7M3Gj/OJEQKDAmRAkNCpMCQECmYxGguN7yPvF4vNE3DCqxBrCku2u3QFDUkBnEBNfB4PEhOTh5xLGcSIgWGhEiBISFSYEiIFBgSIgWGhEiBISFSYEiIFBgSIgWGhEiBISFSYEiIFBgSIgWGhEiBISFSMNyNIO5e3jKEQcBQV7rQVDKEQQCje3qb4ULi8/kAAJcg34uXKNJ8Ph80TRtxjOGuTAwEAujs7ERSUhJ8Ph8yMzPR0dGhvHpssvF6vdy3KBJCwOfzISMjAzExIx91GG4miYmJwdy5cwEAJtPwPaqSk5MN+489Xty36FHNIHfxwJ1IgSEhUjB0SMxmM/bs2QOz2RztViKO+zZ5GO7AnchoDD2TEBkBQ0KkwJAQKTAkRAqGDklFRQXmzZuHhIQEFBQU4OOPP452S2G7ePEiVq9ejYyMDJhMJpw5cybkfSEEdu/ejfT0dCQmJsLhcOD69evRaTYMZWVlyMvLQ1JSEtLS0rB27Vq0tbWFjOnr64PT6cTs2bMxc+ZMbNiwAW63O0odj51hQ3Lq1CmUlJRgz549uHLlCrKzs1FYWIibN29Gu7Ww9Pb2Ijs7GxUVFbrvHzhwAEeOHMGbb76JpqYmzJgxA4WFhejr67vPnYanvr4eTqcTjY2N+PDDDzE4OIiVK1eit/f/nnW/Y8cOvP/++6iurkZ9fT06Ozuxfv36KHY9RsKg8vPzhdPpDL72+/0iIyNDlJWVRbGr8QEgTp8+HXwdCASE1WoVBw8eDNZ6enqE2WwW7777bhQ6HLubN28KAKK+vl4IMbwfcXFxorq6Ojjm888/FwBEQ0NDtNocE0POJAMDA2hubobD4QjWYmJi4HA40NDQEMXOIqu9vR0ulytkPzVNQ0FBwaTbT4/HAwBITU0FADQ3N2NwcDBk3xYsWICsrKxJt2+GDEl3dzf8fj8sFktI3WKxwOVyRamryLu7L5N9PwOBAIqLi7F06VIsWrQIwPC+xcfHIyUlJWTsZNs3wIBnAdPk43Q6ce3aNVy6dCnarUwIQ84kc+bMwbRp06SVELfbDavVGqWuIu/uvkzm/SwqKsLZs2dx/vz54CUOwPC+DQwMoKenJ2T8ZNq3uwwZkvj4eOTm5qKuri5YCwQCqKurg91uj2JnkWWz2WC1WkP20+v1oqmpyfD7KYRAUVERTp8+jXPnzsFms4W8n5ubi7i4uJB9a2trw40bNwy/b5Jorxzcy8mTJ4XZbBZVVVWitbVVbNmyRaSkpAiXyxXt1sLi8/lES0uLaGlpEQDEoUOHREtLi/jmm2+EEELs379fpKSkiJqaGvHpp5+KNWvWCJvNJu7cuRPlzke2detWoWmauHDhgujq6gput2/fDo554YUXRFZWljh37py4fPmysNvtwm63R7HrsTFsSIQQory8XGRlZYn4+HiRn58vGhsbo91S2M6fPy8wfEuLkG3Tpk1CiOFl4F27dgmLxSLMZrN48sknRVtbW3SbHgW9fQIgKisrg2Pu3Lkjtm3bJmbNmiWmT58u1q1bJ7q6uqLX9BjxVHkiBUMekxAZCUNCpMCQECkwJEQKDAmRAkNCpMCQECkwJEQKDAmRAkNCpMCQECkwJEQK/wtdrB3XtwW1LQAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def transform_ds(b):\n", + " b[x] = [TF.to_tensor(ele) for ele in b[x]]\n", + " return b\n", + "\n", + "dst = ds.with_transform(transform_ds)\n", + "plt.imshow(dst['train'][0]['image'].permute(1,2,0));" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))" + ] + }, + "execution_count": 88, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bs = 1024\n", + "class DataLoaders:\n", + " def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n", + " self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n", + " self.valid = DataLoader(train_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)\n", + "\n", + "def collate_fn(b):\n", + " collate = default_collate(b)\n", + " return (collate[x], collate[y])\n", + "\n", + "dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n", + "xb,yb = next(iter(dls.train))\n", + "xb.shape, yb.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 120, + "metadata": {}, + "outputs": [], + "source": [ + "class Reshape(nn.Module):\n", + " def __init__(self, dim):\n", + " super().__init__()\n", + " self.dim = dim\n", + " \n", + " def forward(self, x):\n", + " return x.reshape(self.dim)" + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "metadata": {}, + "outputs": [], + "source": [ + "def cnn_classifier():\n", + " ks,stride = 3,2\n", + " return nn.Sequential(\n", + " nn.Conv2d(1, 4, kernel_size=ks, stride=stride, padding=ks//2),\n", + " nn.ReLU(),\n", + " nn.Conv2d(4, 8, kernel_size=ks, stride=stride, padding=ks//2),\n", + " nn.ReLU(),\n", + " nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),\n", + " nn.ReLU(),\n", + " nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),\n", + " nn.ReLU(),\n", + " nn.Conv2d(32, 32, kernel_size=ks, stride=stride, padding=ks//2),\n", + " nn.ReLU(),\n", + " nn.Conv2d(32, 10, kernel_size=ks, stride=stride, padding=ks//2),\n", + " nn.Flatten(),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "metadata": {}, + "outputs": [], + "source": [ + "def linear_classifier():\n", + " return nn.Sequential(\n", + " Reshape((-1, 784)),\n", + " nn.Linear(784, 50),\n", + " nn.ReLU(),\n", + " nn.Linear(50, 50),\n", + " nn.ReLU(),\n", + " nn.Linear(50, 10)\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 131, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train, epoch:1, loss: 0.3454, accuracy: 0.7909\n", + "eval, epoch:1, loss: 0.3175, accuracy: 0.9049\n", + "train, epoch:2, loss: 0.2423, accuracy: 0.9222\n", + "eval, epoch:2, loss: 0.2136, accuracy: 0.9385\n", + "train, epoch:3, loss: 0.1425, accuracy: 0.9419\n", + "eval, epoch:3, loss: 0.1797, accuracy: 0.9486\n", + "train, epoch:4, loss: 0.1427, accuracy: 0.9565\n", + "eval, epoch:4, loss: 0.1581, accuracy: 0.9624\n", + "train, epoch:5, loss: 0.1579, accuracy: 0.9620\n", + "eval, epoch:5, loss: 0.0956, accuracy: 0.9681\n" + ] + } + ], + "source": [ + "model = linear_classifier()\n", + "lr = 0.1\n", + "max_lr = 0.1\n", + "epochs = 5\n", + "opt = optim.AdamW(model.parameters(), lr=lr)\n", + "sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n", + "\n", + "for epoch in range(epochs):\n", + " 
for train in (True, False):\n", + " accuracy = 0\n", + " dl = dls.train if train else dls.valid\n", + " for xb,yb in dl:\n", + " preds = model(xb)\n", + " loss = F.cross_entropy(preds, yb)\n", + " if train:\n", + " loss.backward()\n", + " opt.step()\n", + " opt.zero_grad()\n", + " with torch.no_grad():\n", + " accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n", + " if train:\n", + " sched.step()\n", + " accuracy /= len(dl)\n", + " print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 141, + "metadata": {}, + "outputs": [], + "source": [ + "with open('linear_classifier.pkl', 'wb') as model_file:\n", + " pickle.dump(model, model_file)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null,
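A couple of smaller notes on the training and saving cells above. The printed loss is loss.item() from the last batch of the epoch rather than an epoch average, so it is noisier than the accuracy column, and the evaluation forward passes still build autograd graphs because only the accuracy line sits under torch.no_grad(). Also, pickle.dump(model, ...) serializes the whole module, which ties the file to the exact class definition and torch version; saving the state_dict is the more portable option. A sketch along those lines, reusing the notebook's dls, model and linear_classifier, with linear_classifier.pt as a placeholder filename:

    # sketch only: reuses dls, model, F, torch and linear_classifier from the
    # cells above; 'linear_classifier.pt' is a placeholder filename
    with torch.no_grad():                      # evaluate without autograd
        tot_loss, tot_acc = 0.0, 0.0
        for xb, yb in dls.valid:
            preds = model(xb)
            tot_loss += F.cross_entropy(preds, yb).item()
            tot_acc += (preds.argmax(1) == yb).float().mean().item()
    print(f"eval, loss: {tot_loss/len(dls.valid):.4f}, accuracy: {tot_acc/len(dls.valid):.4f}")

    # save the weights only; rebuild the module and load them when needed
    torch.save(model.state_dict(), 'linear_classifier.pt')
    model = linear_classifier()
    model.load_state_dict(torch.load('linear_classifier.pt'))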