Skip to content

Commit

Permalink
Merge pull request #26 from marcpinet/docs-change-examples-to-notebooks
Browse files Browse the repository at this point in the history
Docs change examples to notebooks
  • Loading branch information
marcpinet authored Apr 21, 2024
2 parents c03c91e + 4a02b22 commit d645d63
Show file tree
Hide file tree
Showing 17 changed files with 1,846 additions and 441 deletions.
221 changes: 221 additions & 0 deletions examples/classification-regression/mnist_loading_saved_model.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# MNIST Loading Saved Model"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T12:52:21.706906Z",
"start_time": "2024-04-21T12:52:18.726598200Z"
}
},
"outputs": [],
"source": [
"from tensorflow.keras.datasets import mnist # Dataset for testing\n",
"\n",
"from neuralnetlib.model import Model\n",
"from neuralnetlib.preprocessing import one_hot_encode\n",
"from neuralnetlib.utils import train_test_split\n",
"from neuralnetlib.metrics import accuracy_score, confusion_matrix"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Loading the MNIST dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T12:52:21.915810200Z",
"start_time": "2024-04-21T12:52:21.706906Z"
}
},
"outputs": [],
"source": [
"(x_train, y_train), (x_test, y_test) = mnist.load_data()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Preprocessing"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T12:52:22.072282500Z",
"start_time": "2024-04-21T12:52:21.916810900Z"
}
},
"outputs": [],
"source": [
"x_train = x_train.reshape(-1, 28 * 28) / 255.0\n",
"x_test = x_test.reshape(-1, 28 * 28) / 255.0\n",
"y_train = one_hot_encode(y_train, num_classes=10)\n",
"y_test = one_hot_encode(y_test, num_classes=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Split the training data into training and validation sets"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T12:52:22.233389700Z",
"start_time": "2024-04-21T12:52:22.073284800Z"
}
},
"outputs": [],
"source": [
"_, x_val, _, y_val = train_test_split(x_train, y_train, test_size=0.2, random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Load the model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T12:52:22.258467800Z",
"start_time": "2024-04-21T12:52:22.234388100Z"
}
},
"outputs": [],
"source": [
"model = Model.load('my_mnist_model.npz')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 5. Predict and evaluate on the validation set"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T12:52:22.323518700Z",
"start_time": "2024-04-21T12:52:22.257467100Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Validation Accuracy: 0.899\n"
]
}
],
"source": [
"y_pred_val = model.predict(x_val)\n",
"accuracy_val = accuracy_score(y_pred_val, y_val)\n",
"print(f'Validation Accuracy: {accuracy_val}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 6. Optionally, you can still evaluate on the test set"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-21T12:52:22.393768500Z",
"start_time": "2024-04-21T12:52:22.318518600Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Accuracy: 0.8863\n",
"Confusion Matrix:\n",
"[[ 937 0 0 1 11 7 2 18 1 3]\n",
" [ 0 1097 3 4 0 3 2 4 19 3]\n",
" [ 13 9 858 36 26 1 23 38 16 12]\n",
" [ 8 6 18 899 2 33 2 16 12 14]\n",
" [ 1 0 1 0 944 0 7 2 1 26]\n",
" [ 19 0 0 82 30 701 12 5 23 20]\n",
" [ 18 2 0 0 70 15 849 1 2 1]\n",
" [ 0 9 10 5 15 0 0 945 4 40]\n",
" [ 6 22 3 3 37 26 9 2 803 63]\n",
" [ 3 2 1 11 137 2 0 15 8 830]]\n"
]
}
],
"source": [
"y_pred_test = model.predict(x_test)\n",
"accuracy_test = accuracy_score(y_pred_test, y_test)\n",
"print(f'Test Accuracy: {accuracy_test}')\n",
"print(f'Confusion Matrix:\\n{confusion_matrix(y_pred_test, y_test)}')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
38 changes: 0 additions & 38 deletions examples/classification-regression/mnist_loading_saved_model.py

This file was deleted.

Loading

0 comments on commit d645d63

Please sign in to comment.