Skip to content

Commit

Permalink
linear classifier trained
Browse files Browse the repository at this point in the history
  • Loading branch information
arun477 authored Sep 12, 2023
1 parent d64050c commit 55abd01
Showing 1 changed file with 201 additions and 2 deletions.
203 changes: 201 additions & 2 deletions mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,216 @@
"cells": [
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 140,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"from datasets import load_dataset\n",
"import fastcore.all as fc"
"import fastcore.all as fc\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib as mpl\n",
"import torchvision.transforms.functional as TF\n",
"from torch.utils.data import default_collate, DataLoader\n",
"import torch.optim as optim\n",
"import pickle\n",
"%matplotlib inline\n",
"plt.rcParams['figure.figsize'] = [2, 2]"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"dataset_nm = 'mnist'\n",
"x,y = 'image', 'label'\n",
"ds = load_dataset(dataset_nm)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAMkAAADICAYAAABCmsWgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAP8UlEQVR4nO3df1CUd34H8PcisKDCg+ixCyPEbc6rTkyhRcAdHWOSrZydOv5qatr7w5g0TnTxBkmbC07UnuMNnnaMkRAzTSOYmRodbqok5kong4pnCmRETMaQI+ZCIj3YNcRhd4Pya/fbP4jb2X4f+bKwuA/4fs08f+xnvyyfx+TNl+fL88MkhBAgonuKiXYDREbHkBApMCRECgwJkQJDQqTAkBApMCRECgwJkQJDQqTAkBApxE7UB1dUVODgwYNwuVzIzs5GeXk58vPzlV8XCATQ2dmJpKQkmEymiWqPHnBCCPh8PmRkZCAmRjFXiAlw8uRJER8fL44dOyY+++wz8fzzz4uUlBThdruVX9vR0SEAcON2X7aOjg7l/5MmISJ/gmNBQQHy8vLw+uuvAxieHTIzM7F9+3a8/PLLI36tx+NBSkoKluGvEIu4SLdGBAAYwiAu4bfo6emBpmkjjo34r1sDAwNobm5GaWlpsBYTEwOHw4GGhgZpfH9/P/r7+4OvfT7fD43FIdbEkNAE+WFqGM2v9BE/cO/u7obf74fFYgmpWywWuFwuaXxZWRk0TQtumZmZkW6JaFyivrpVWloKj8cT3Do6OqLdElGIiP+6NWfOHEybNg1utzuk7na7YbVapfFmsxlmsznSbRBFTMRnkvj4eOTm5qKuri5YCwQCqKurg91uj/S3I5pwE/J3kpKSEmzatAmLFy9Gfn4+Dh8+jN7eXmzevHkivh3RhJqQkGzcuBHffvstdu/eDZfLhZycHNTW1koH80STwYT8nWQ8vF4vNE3DCqzhEjBNmCExiAuogcfjQXJy8ohjo766RWR0DAmRAkNCpMCQECkwJEQKDAmRAkNCpMCQECkwJEQKDAmRAkNCpMCQECkwJEQKDAmRAkNCpMCQECkwJEQKE3YvYBo/U6z8n2faj+aM+3Pb/nGeVPNPD+iOfejhm1Jt+jb9G7q5DsVLtSuLT+mO7fb3SrWC6hd1x/64pFG3fr9wJiFSYEiIFBgSIgWGhEiBISFS4OpWBExbOF+qCbP+PcM6H0uRaneWyCs9AJCqyfXfZeuvFk2U/7ydJNV+/fpPdcc2PXpCqrUP3tEdu9/9l1It43eGugVcEGcSIgWGhEiBISFSYEiIFHjgHgb/ir/QrR+qqpBqP4mTT9EwskHh163vLn9GqsX26h9g26uLpFrSH4d0x5q75QP66ZebRugwejiTECkwJEQKDAmRAkNCpMCQEClwdSsM5rZO3XpzX6ZU+0mcW2fkxHmxa4lU++p7/Qu0qh7+jVTzBPRXrCxH/nt8jd2DMU9A0ceZhEiBISFSYEiIFBgSIgUeuIdhqMulWy//9VNS7Vc/1b9GZNqnM6XaJ9vKR93Dvu4/061/6Zgu1fw9Xbpj/96+Tap9/XP972fDJ6PubariTEKkwJAQKTAkRAoMCZFC2CG5ePEiVq9ejYyMDJhMJpw5cybkfSEEdu/ejfT0dCQmJsLhcOD69euR6pfovgt7dau3txfZ2dl49tlnsX79eun9AwcO4MiRIzh+/DhsNht27dqFwsJCtLa2IiEhISJNG01qZYNU+9H7s3XH+r+7JdUeWfSs7tjPlh+Tau/962O6Y9N6Rn/6iKlBXrGyybtAPwg7JKtWrcKqVat03xNC4PDhw3jllVewZs0aAMA777wDi8WCM2fO4Omnnx5ft0RRENFjkvb2drhcLjgcjmBN0zQUFBSgoUH/R1V/fz+8Xm/IRmQkEQ2JyzX8xzaLxRJSt1gswff+v7KyMmiaFtwyM+UzaomiKeqrW6WlpfB4PMGto6Mj2i0RhYjoaSlWqxUA4Ha7kZ6eHqy73W7k5OTofo3ZbIbZbI5kG4bg7/5u1GMHvaO/s8ojP2vVrX97dJpcDOjfAYXCE9GZxGazwWq1oq6uLljzer1oamqC3W6P5Lcium/Cnkm+//57fPnll8HX7e3tuHr1KlJTU5GVlYXi4mLs27cP8+fPDy4BZ2RkYO3atZHsm+i+CTskly9fxuOPPx58XVJSAgDYtGkTqqqq8NJLL6G3txdbtmxBT08Pli1bhtra2in7NxKa+sIOyYoVKyDEva9QNplM2Lt3L/bu3TuuxoiMIuqrW0RGx4uuDGDhL77QrW9+9EmpVvlQnc5I4LGnnFIt6VR0H+08VXAmIVJgSIgUGBIiBYaESIEH7gbg7/Ho1r/bulCq3XhP/2m2L+97R6qV/u063bGiRZNqmb+6xwUlIyz3Pyg4kxApMCRECgwJkQJDQqTAkBApcHXLwAKffC7Vnv7lP+mO/fc9/yLVri6RV7wAAPLzfvDIDPnx0gAw/y35fsJDX32t/7lTFGcSIgWGhEiBISFSYEiIFExipMsMo8Dr9ULTNKzAGsSa4qLdzqQhluZIteT9/6M79t0/+a9Rf+6C8/8g1f70l/qn0fivfzXqz422ITGIC6iBx+NBcnLyiGM5kxApMCRECgwJkQJDQqTAkBAp8LSUKcL00VWpdvtv0nTH5m3cLtWafvGa7tjfP/5vUu1n81bqjvUsG6HBSYwzCZECQ0KkwJAQKTAkRAo8cJ/C/O6bunXLEbne99KQ7tjpJvkBQ2/NO6s79q/XFctff7pphA4nB84kRAoMCZECQ0KkwJAQKTAkRApc3ZoiAstypNofntJ/TuWinK+lmt4q1r2U3/pz3fr0msuj/ozJhDMJkQJDQqTAkBApMCRECjxwNzDT4kVS7Yuf6x9gv7X0uFRbnjAw7h76xaBUa7xl0x8ckG+JOhVwJiFSYEiIFBgSIgWGhEghrJCUlZUhLy8PSUlJSEtLw9q1a9HW1hYypq+vD06nE7Nnz8bMmTOxYcMGuN3uiDZNdD+FtbpVX18Pp9OJvLw8DA0NYefOnVi5ciVaW1sxY8YMAMCOHTvwwQcfoLq6GpqmoaioCOvXr8dHH300ITsw2cTaHpJqf9icoTv2nzeelGobZnZHvCcA2OlerFuvf01+4s+s4/d4nPUUFVZIamtrQ15XVVUhLS0Nzc3NWL58OTweD95++22cOHECTzzxBACgsrISCxcuRGNjI5Ys0XnEEpHBjeuYxOMZvrt4amoqAKC5uRmDg4NwOBzBMQsWLEBWVhYaGvR/+vT398Pr9YZsREYy5pAEAgEUFxdj6dKlWLRo+I9eLpcL8fHxSElJCRlrsVjgcrl0P6esrAyapgW3zMzMsbZENCHGHBKn04lr167h5En59+ZwlJaWwuPxBLeOjo5xfR5RpI3ptJSioiKcPXsWFy9exNy5c4N1q9WKgYEB9PT0hMwmbrcbVqtV97PMZjPMZvNY2jCM2HlZUs2Tm647duPeWqn2Qsp/RLwnAHixS/8YsOEN+SA9tepj3bGzAg/WQbqesGYSIQSKiopw+vRpnDt3DjZb6Dk8ubm5iIuLQ11dXbDW1taGGzduwG63R6ZjovssrJnE6XTixIkTqKmpQVJSUvA4Q9M0JCYmQtM0PPfccygpKUFqaiqSk5Oxfft22O12rmzRpBVWSI4ePQoAWLFiRUi9srISzzzzDADg1VdfRUxMDDZs2ID+/n4UFhbijTfeiEizRNEQVkhG8wzShIQEVFRUoKKiYsxNERkJz90iUuBFV/cQmy6vxt06NkN37FZbvVT7u6SJOV+t6I/6T8q5cjRHqs35zTXdsak+rliFgzMJkQJDQqTAkBApMCRECg/UgftAoXw6xsCOW7pjd/74t1JtZWJvxHsCALf/jm59+XsvSrUFr/xed2xqj3wwHhhfW/QDziRECgwJkQJDQqTAkBApMCRECg/U6tbXa+WfCV88Wj3uz63oeViqvVa/UnesyW+Sagv2teuOne+WH+/sD7M3Gj/OJEQKDAmRAkNCpMCQECmYxGguN7yPvF4vNE3DCqxBrCku2u3QFDUkBnEBNfB4PEhOTh5xLGcSIgWGhEiBISFSYEiIFBgSIgWGhEiBISFSYEiIFBgSIgWGhEiBISFSYEiIFBgSIgWGhEiBISFSMNyNIO5e3jKEQcBQV7rQVDKEQQCje3qb4ULi8/kAAJcg34uXKNJ8Ph80TRtxjOGuTAwEAujs7ERSUhJ8Ph8yMzPR0dGhvHpssvF6vdy3KBJCwOfzISMjAzExIx91GG4miYmJwdy5cwEAJtPwPaqSk5MN+489Xty36FHNIHfxwJ1IgSEhUjB0SMxmM/bs2QOz2RztViKO+zZ5GO7AnchoDD2TEBkBQ0KkwJAQKTAkRAqGDklFRQXmzZuHhIQEFBQU4OOPP452S2G7ePEiVq9ejYyMDJhMJpw5cybkfSEEdu/ejfT0dCQmJsLhcOD69evRaTYMZWVlyMvLQ1JSEtLS0rB27Vq0tbWFjOnr64PT6cTs2bMxc+ZMbNiwAW63O0odj51hQ3Lq1CmUlJRgz549uHLlCrKzs1FYWIibN29Gu7Ww9Pb2Ijs7GxUVFbrvHzhwAEeOHMGbb76JpqYmzJgxA4WFhejr67vPnYanvr4eTqcTjY2N+PDDDzE4OIiVK1eit/f/nnW/Y8cOvP/++6iurkZ9fT06Ozuxfv36KHY9RsKg8vPzhdPpDL72+/0iIyNDlJWVRbGr8QEgTp8+HXwdCASE1WoVBw8eDNZ6enqE2WwW7777bhQ6HLubN28KAKK+vl4IMbwfcXFxorq6Ojjm888/FwBEQ0NDtNocE0POJAMDA2hubobD4QjWYmJi4HA40NDQEMXOIqu9vR0ulytkPzVNQ0FBwaTbT4/HAwBITU0FADQ3N2NwcDBk3xYsWICsrKxJt2+GDEl3dzf8fj8sFktI3WKxwOVyRamryLu7L5N9PwOBAIqLi7F06VIsWrQIwPC+xcfHIyUlJWTsZNs3wIBnAdPk43Q6ce3aNVy6dCnarUwIQ84kc+bMwbRp06SVELfbDavVGqWuIu/uvkzm/SwqKsLZs2dx/vz54CUOwPC+DQwMoKenJ2T8ZNq3uwwZkvj4eOTm5qKuri5YCwQCqKurg91uj2JnkWWz2WC1WkP20+v1oqmpyfD7KYRAUVERTp8+jXPnzsFms4W8n5ubi7i4uJB9a2trw40bNwy/b5Jorxzcy8mTJ4XZbBZVVVWitbVVbNmyRaSkpAiXyxXt1sLi8/lES0uLaGlpEQDEoUOHREtLi/jmm2+EEELs379fpKSkiJqaGvHpp5+KNWvWCJvNJu7cuRPlzke2detWoWmauHDhgujq6gput2/fDo554YUXRFZWljh37py4fPmysNvtwm63R7HrsTFsSIQQory8XGRlZYn4+HiRn58vGhsbo91S2M6fPy8wfEuLkG3Tpk1CiOFl4F27dgmLxSLMZrN48sknRVtbW3SbHgW9fQIgKisrg2Pu3Lkjtm3bJmbNmiWmT58u1q1bJ7q6uqLX9BjxVHkiBUMekxAZCUNCpMCQECkwJEQKDAmRAkNCpMCQECkwJEQKDAmRAkNCpMCQECkwJEQK/wtdrB3XtwW1LQAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 200x200 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"def transform_ds(b):\n",
" b[x] = [TF.to_tensor(ele) for ele in b[x]]\n",
" return b\n",
"\n",
"dst = ds.with_transform(transform_ds)\n",
"plt.imshow(dst['train'][0]['image'].permute(1,2,0));"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bs = 1024\n",
"class DataLoaders:\n",
" def __init__(self, train_ds, valid_ds, bs, collate_fn, **kwargs):\n",
" self.train = DataLoader(train_ds, batch_size=bs, shuffle=True, collate_fn=collate_fn, **kwargs)\n",
" self.valid = DataLoader(train_ds, batch_size=bs*2, shuffle=False, collate_fn=collate_fn, **kwargs)\n",
"\n",
"def collate_fn(b):\n",
" collate = default_collate(b)\n",
" return (collate[x], collate[y])\n",
"\n",
"dls = DataLoaders(dst['train'], dst['test'], bs=bs, collate_fn=collate_fn)\n",
"xb,yb = next(iter(dls.train))\n",
"xb.shape, yb.shape"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
"class Reshape(nn.Module):\n",
" def __init__(self, dim):\n",
" super().__init__()\n",
" self.dim = dim\n",
" \n",
" def forward(self, x):\n",
" return x.reshape(self.dim)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"def cnn_classifier():\n",
" ks,stride = 3,2\n",
" return nn.Sequential(\n",
" nn.Conv2d(1, 4, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.ReLU(),\n",
" nn.Conv2d(4, 8, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.ReLU(),\n",
" nn.Conv2d(8, 16, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.ReLU(),\n",
" nn.Conv2d(16, 32, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.ReLU(),\n",
" nn.Conv2d(32, 32, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.ReLU(),\n",
" nn.Conv2d(32, 10, kernel_size=ks, stride=stride, padding=ks//2),\n",
" nn.Flatten(),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [],
"source": [
"def linear_classifier():\n",
" return nn.Sequential(\n",
" Reshape((-1, 784)),\n",
" nn.Linear(784, 50),\n",
" nn.ReLU(),\n",
" nn.Linear(50, 50),\n",
" nn.ReLU(),\n",
" nn.Linear(50, 10)\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train, epoch:1, loss: 0.3454, accuracy: 0.7909\n",
"eval, epoch:1, loss: 0.3175, accuracy: 0.9049\n",
"train, epoch:2, loss: 0.2423, accuracy: 0.9222\n",
"eval, epoch:2, loss: 0.2136, accuracy: 0.9385\n",
"train, epoch:3, loss: 0.1425, accuracy: 0.9419\n",
"eval, epoch:3, loss: 0.1797, accuracy: 0.9486\n",
"train, epoch:4, loss: 0.1427, accuracy: 0.9565\n",
"eval, epoch:4, loss: 0.1581, accuracy: 0.9624\n",
"train, epoch:5, loss: 0.1579, accuracy: 0.9620\n",
"eval, epoch:5, loss: 0.0956, accuracy: 0.9681\n"
]
}
],
"source": [
"model = linear_classifier()\n",
"lr = 0.1\n",
"max_lr = 0.1\n",
"epochs = 5\n",
"opt = optim.AdamW(model.parameters(), lr=lr)\n",
"sched = optim.lr_scheduler.OneCycleLR(opt, max_lr, total_steps=len(dls.train), epochs=epochs)\n",
"\n",
"for epoch in range(epochs):\n",
" for train in (True, False):\n",
" accuracy = 0\n",
" dl = dls.train if train else dls.valid\n",
" for xb,yb in dl:\n",
" preds = model(xb)\n",
" loss = F.cross_entropy(preds, yb)\n",
" if train:\n",
" loss.backward()\n",
" opt.step()\n",
" opt.zero_grad()\n",
" with torch.no_grad():\n",
" accuracy += (preds.argmax(1).detach().cpu() == yb).float().mean()\n",
" if train:\n",
" sched.step()\n",
" accuracy /= len(dl)\n",
" print(f\"{'train' if train else 'eval'}, epoch:{epoch+1}, loss: {loss.item():.4f}, accuracy: {accuracy:.4f}\")\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 141,
"metadata": {},
"outputs": [],
"source": [
"with open('linear_classifier.pkl', 'wb') as model_file:\n",
" pickle.dump(model, model_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit 55abd01

Please sign in to comment.