diff --git a/mnist.ipynb b/mnist.ipynb
index b929ff9..c14a2f2 100644
--- a/mnist.ipynb
+++ b/mnist.ipynb
@@ -2,20 +2,9 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 1,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: 'dlopen(/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so, 0x0006): Symbol not found: __ZN3c106detail19maybe_wrap_dim_slowIxEET_S2_S2_b\n",
-      "  Referenced from: /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torchvision/image.so\n",
-      "  Expected in: /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/torch/lib/libc10.dylib'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?\n",
-      "  warn(\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "import torch\n",
     "from torch import nn\n",
@@ -45,22 +34,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "image/png": "iVBORw0KGgoAAAANSUhEUgAA [... base64 PNG data elided ...] AAAASUVORK5CYII=",
-      "text/plain": [
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "def transform_ds(b):\n", " b[x] = [TF.to_tensor(ele) for ele in b[x]]\n", @@ -72,20 +48,9 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(torch.Size([1024, 1, 28, 28]), torch.Size([1024]))" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "bs = 1024\n", "class DataLoaders:\n", @@ -104,7 +69,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -119,7 +84,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -143,7 +108,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -160,26 +125,9 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "train, epoch:1, loss: 0.4002, accuracy: 0.7806\n", - "eval, epoch:1, loss: 0.2896, accuracy: 0.9007\n", - "train, epoch:2, loss: 0.2815, accuracy: 0.9171\n", - "eval, epoch:2, loss: 0.2144, accuracy: 0.9318\n", - "train, epoch:3, loss: 0.2128, accuracy: 0.9370\n", - "eval, epoch:3, loss: 0.1721, accuracy: 0.9435\n", - "train, epoch:4, loss: 0.1453, accuracy: 0.9489\n", - "eval, epoch:4, loss: 0.1629, accuracy: 0.9590\n", - "train, epoch:5, loss: 0.1110, accuracy: 0.9565\n", - "eval, epoch:5, loss: 0.1162, accuracy: 0.9681\n" - ] - } - ], + "outputs": [], "source": [ "model = linear_classifier()\n", "lr = 0.1\n", @@ -210,17 +158,17 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "# with open('./linear_classifier.pkl', 'wb') as model_file:\n", - "# pickle.dump(model, model_file)" + "with open('./linear_classifier.pkl', 'wb') as model_file:\n", + " pickle.dump(model, model_file)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -234,107 +182,9 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "
\n", - "\n", - "
\n", - "\n", - "\n", - "
\n", - "\n", - "
\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "%%html\n", "