Skip to content

Commit

Permalink
Merge pull request #42 from marcpinet/viz-fixes
Browse files Browse the repository at this point in the history
fix: headless env support and regression style adapted
  • Loading branch information
marcpinet authored May 18, 2024
2 parents 959fd4d + 8e8371e commit 9279a50
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 149 deletions.
126 changes: 60 additions & 66 deletions examples/classification-regression/simple_cancer_binary.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-24T02:45:54.992441700Z",
"start_time": "2024-04-24T02:45:54.949443700Z"
"end_time": "2024-05-18T15:40:22.316823400Z",
"start_time": "2024-05-18T15:40:21.250682300Z"
}
},
"outputs": [],
Expand All @@ -49,11 +49,11 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-24T02:45:55.028448400Z",
"start_time": "2024-04-24T02:45:54.955441200Z"
"end_time": "2024-05-18T15:40:22.332821700Z",
"start_time": "2024-05-18T15:40:22.318819700Z"
}
},
"outputs": [],
Expand All @@ -71,21 +71,19 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-24T02:45:55.029956100Z",
"start_time": "2024-04-24T02:45:54.970441300Z"
"end_time": "2024-05-18T15:40:22.347820300Z",
"start_time": "2024-05-18T15:40:22.333821900Z"
}
},
"outputs": [],
"source": [
"x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)\n",
"scaler = StandardScaler()\n",
"x_train = scaler.fit_transform(x_train)\n",
"x_test = scaler.transform(x_test)\n",
"y_train = y_train.reshape(-1, 1)\n",
"y_test = y_test.reshape(-1, 1)"
"x_test = scaler.transform(x_test)"
]
},
{
Expand All @@ -97,11 +95,11 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-24T02:45:55.031965100Z",
"start_time": "2024-04-24T02:45:54.978440700Z"
"end_time": "2024-05-18T15:40:22.362820500Z",
"start_time": "2024-05-18T15:40:22.348820100Z"
}
},
"outputs": [],
Expand Down Expand Up @@ -134,11 +132,11 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-24T02:45:55.054962900Z",
"start_time": "2024-04-24T02:45:54.986441Z"
"end_time": "2024-05-18T15:40:22.377825Z",
"start_time": "2024-05-18T15:40:22.364821200Z"
}
},
"outputs": [
Expand Down Expand Up @@ -187,56 +185,52 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-24T02:45:56.559338300Z",
"start_time": "2024-04-24T02:45:55.000441Z"
"end_time": "2024-05-18T15:40:45.505483600Z",
"start_time": "2024-05-18T15:40:22.378824500Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[==============================] 100% Epoch 1/40 - loss: 0.6905 - - 0.05s\n",
"[==============================] 100% Epoch 2/40 - loss: 0.6785 - - 0.04s\n",
"[==============================] 100% Epoch 3/40 - loss: 0.6621 - - 0.04s\n",
"[==============================] 100% Epoch 4/40 - loss: 0.6433 - - 0.05s\n",
"[==============================] 100% Epoch 5/40 - loss: 0.6219 - - 0.06s\n",
"[==============================] 100% Epoch 6/40 - loss: 0.5981 - - 0.05s\n",
"[==============================] 100% Epoch 7/40 - loss: 0.5717 - - 0.04s\n",
"[==============================] 100% Epoch 8/40 - loss: 0.5433 - - 0.06s\n",
"[==============================] 100% Epoch 9/40 - loss: 0.5139 - - 0.05s\n",
"[==============================] 100% Epoch 10/40 - loss: 0.4846 - - 0.04s\n",
"[==============================] 100% Epoch 11/40 - loss: 0.4565 - - 0.10s\n",
"[==============================] 100% Epoch 12/40 - loss: 0.4308 - - 0.08s\n",
"[==============================] 100% Epoch 13/40 - loss: 0.4077 - - 0.05s\n",
"[==============================] 100% Epoch 14/40 - loss: 0.3877 - - 0.05s\n",
"[==============================] 100% Epoch 15/40 - loss: 0.3708 - - 0.05s\n",
"[==============================] 100% Epoch 16/40 - loss: 0.3571 - - 0.04s\n",
"[==============================] 100% Epoch 17/40 - loss: 0.3464 - - 0.04s\n",
"[==============================] 100% Epoch 18/40 - loss: 0.3382 - - 0.04s\n",
"[==============================] 100% Epoch 19/40 - loss: 0.3317 - - 0.04s\n",
"[==============================] 100% Epoch 20/40 - loss: 0.3268 - - 0.05s\n",
"[==============================] 100% Epoch 21/40 - loss: 0.3232 - - 0.05s\n",
"[==============================] 100% Epoch 22/40 - loss: 0.3206 - - 0.05s\n",
"[==============================] 100% Epoch 23/40 - loss: 0.3186 - - 0.06s\n",
"[==============================] 100% Epoch 24/40 - loss: 0.3171 - - 0.04s\n",
"[==============================] 100% Epoch 25/40 - loss: 0.3160 - - 0.05s\n",
"[==============================] 100% Epoch 26/40 - loss: 0.3158 - - 0.05s\n",
"[==============================] 100% Epoch 27/40 - loss: 0.3167 - - 0.04s\n",
"[==============================] 100% Epoch 28/40 - loss: 0.3190 - - 0.04s\n",
"[==============================] 100% Epoch 29/40 - loss: 0.3217 - - 0.04s\n",
"[==============================] 100% Epoch 30/40 - loss: 0.3249 - - 0.08s\n",
"Early stopping after 30 epochs.\n"
"[==============================] 100% Epoch 1/40 - loss: 0.6905 - - 0.04s\n",
"[==============================] 100% Epoch 2/40 - loss: 0.6785 - - 0.11s\n",
"[==============================] 100% Epoch 3/40 - loss: 0.6621 - - 0.10s\n",
"[==============================] 100% Epoch 4/40 - loss: 0.6432 - - 0.10s\n",
"[==============================] 100% Epoch 5/40 - loss: 0.6218 - - 0.10s\n",
"[==============================] 100% Epoch 6/40 - loss: 0.5978 - - 0.11s\n",
"[==============================] 100% Epoch 7/40 - loss: 0.5711 - - 0.10s\n",
"[==============================] 100% Epoch 8/40 - loss: 0.5420 - - 0.10s\n",
"[==============================] 100% Epoch 9/40 - loss: 0.5118 - - 0.12s\n",
"[==============================] 100% Epoch 10/40 - loss: 0.4814 - - 0.10s\n",
"[==============================] 100% Epoch 11/40 - loss: 0.4520 - - 0.09s\n",
"[==============================] 100% Epoch 12/40 - loss: 0.4250 - - 0.10s\n",
"[==============================] 100% Epoch 13/40 - loss: 0.4006 - - 0.09s\n",
"[==============================] 100% Epoch 14/40 - loss: 0.3795 - - 0.10s\n",
"[==============================] 100% Epoch 15/40 - loss: 0.3621 - - 0.10s\n",
"[==============================] 100% Epoch 16/40 - loss: 0.3485 - - 0.10s\n",
"[==============================] 100% Epoch 17/40 - loss: 0.3385 - - 0.10s\n",
"[==============================] 100% Epoch 18/40 - loss: 0.3314 - - 0.09s\n",
"[==============================] 100% Epoch 19/40 - loss: 0.3268 - - 0.10s\n",
"[==============================] 100% Epoch 20/40 - loss: 0.3241 - - 0.10s\n",
"[==============================] 100% Epoch 21/40 - loss: 0.3226 - - 0.09s\n",
"[==============================] 100% Epoch 22/40 - loss: 0.3219 - - 0.10s\n",
"[==============================] 100% Epoch 23/40 - loss: 0.3218 - - 0.13s\n",
"[==============================] 100% Epoch 24/40 - loss: 0.3230 - - 0.09s\n",
"[==============================] 100% Epoch 25/40 - loss: 0.3265 - - 0.10s\n",
"[==============================] 100% Epoch 26/40 - loss: 0.3318 - - 0.09s\n",
"Early stopping after 26 epochs.\n"
]
}
],
"source": [
"early_stopping = EarlyStopping(patience=5, min_delta=0.001, restore_best_weights=True)\n",
"\n",
"model.fit(x_train, y_train, epochs=40, batch_size=48, random_state=42,\n",
"model.fit(x_train, y_train, epochs=40, batch_size=48, random_state=42, plot_decision_boundary=True,\n",
" callbacks=[early_stopping]) # Here, the early stopping will stop the training if the loss does not decrease\n",
"\n",
"# You could specify a different metric because loss is the default one\n",
Expand All @@ -253,19 +247,19 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-24T02:45:56.597338Z",
"start_time": "2024-04-24T02:45:56.557338300Z"
"end_time": "2024-05-18T15:40:45.518524100Z",
"start_time": "2024-05-18T15:40:45.504481900Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test loss: 0.3059129577889438\n"
"Test loss: 0.9755719591971315\n"
]
}
],
Expand All @@ -283,11 +277,11 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-24T02:45:56.609337800Z",
"start_time": "2024-04-24T02:45:56.566337400Z"
"end_time": "2024-05-18T15:40:45.547522500Z",
"start_time": "2024-05-18T15:40:45.519524Z"
}
},
"outputs": [],
Expand All @@ -304,11 +298,11 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2024-04-24T02:45:56.610337600Z",
"start_time": "2024-04-24T02:45:56.577339100Z"
"end_time": "2024-05-18T15:40:45.553523700Z",
"start_time": "2024-05-18T15:40:45.536524300Z"
}
},
"outputs": [
Expand All @@ -317,9 +311,9 @@
"output_type": "stream",
"text": [
"Accuracy: 0.9298245614035088\n",
"Precision: 0.9412393162393162\n",
"Recall: 0.9097222222222223\n",
"F1 Score: 0.9252124418791086\n"
"Precision: 0.5\n",
"Recall: 0.5\n",
"F1 Score: 0.5\n"
]
}
],
Expand Down
10 changes: 6 additions & 4 deletions neuralnetlib/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np


class EarlyStopping:
def __init__(self, patience: int = 5, min_delta: float = 0.001, restore_best_weights: bool = True,
start_from_epoch: int = 0, monitor: list = None, mode: str = 'auto', baseline: float = None):
Expand Down Expand Up @@ -40,13 +41,13 @@ def on_epoch_end(self, model, loss, metrics=None):
if self.monitor is None:
current_metric = loss
if (self.mode == 'min' and current_metric < self.best_metric - self.min_delta) or \
(self.mode == 'max' and current_metric > self.best_metric + self.min_delta):
(self.mode == 'max' and current_metric > self.best_metric + self.min_delta):
self.best_metric = current_metric
improved = True
else:
current_metric = metrics[self.monitor[0].__name__]
if (self.mode == 'max' and current_metric > self.best_metric + self.min_delta) or \
(self.mode == 'min' and current_metric < self.best_metric - self.min_delta):
(self.mode == 'min' and current_metric < self.best_metric - self.min_delta):
self.best_metric = current_metric
improved = True

Expand All @@ -66,8 +67,9 @@ def on_epoch_end(self, model, loss, metrics=None):
self.stop_training = True
print(f"\nEarly stopping after {self.epoch} epochs.", end='')
if self.restore_best_weights and self.best_weights is not None:
for layer, best_weights in zip([layer for layer in model.layers if hasattr(layer, 'weights')], self.best_weights):
for layer, best_weights in zip([layer for layer in model.layers if hasattr(layer, 'weights')],
self.best_weights):
layer.weights = best_weights
return True

return False
return False
Loading

0 comments on commit 9279a50

Please sign in to comment.