Skip to content

Commit

Permalink
ci: bump version to 3.3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
marcpinet committed Nov 25, 2024
1 parent a48eefa commit 31041a3
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 47 deletions.
44 changes: 2 additions & 42 deletions examples/generation/gan-image-generation/gan-mnist.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -202,7 +202,6 @@
"source": [
"epochs = 100\n",
"batch_size = 128\n",
"save_interval = 10\n",
"\n",
"history = gan.fit(\n",
" X_reshaped,\n",
Expand All @@ -215,7 +214,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -283,47 +282,8 @@
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
" plt.figure(figsize=(12, 4))\n",
" \n",
" plt.subplot(1, 2, 1)\n",
" plt.plot(history['discriminator_loss'], label='Discriminateur')\n",
" plt.plot(history['generator_loss'], label='Générateur')\n",
" plt.title('Loss evolution')\n",
" plt.xlabel('Epoch')\n",
" plt.ylabel('Loss')\n",
" plt.legend()\n",
" \n",
" plt.subplot(1, 2, 2)\n",
" plt.plot(history['mmd'], label='MMD Score')\n",
" plt.title('MMD Score evolution')\n",
" plt.xlabel('Epoch')\n",
" plt.ylabel('MMD')\n",
" plt.legend()\n",
" \n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"visualize_mnist_samples(gan)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 5))\n",
"plt.plot(history['discriminator_loss'], label='Discriminator Loss')\n",
"plt.plot(history['generator_loss'], label='Generator Loss')\n",
"plt.xlabel('Epoch')\n",
"plt.ylabel('Loss')\n",
"plt.legend()\n",
"plt.savefig('gan_training_history.png')\n",
"plt.close()\n",
"\n",
"print(\"Saving model...\")\n",
"gan.save('mnist_gan_model')"
]
}
],
"metadata": {
Expand Down
4 changes: 0 additions & 4 deletions neuralnetlib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2306,10 +2306,6 @@ def fit(
)
print(val_metrics_str, end='')

if epoch % save_interval == 0:
weights = self.save_weights(epoch)
self.saved_weights_through_epochs.append(weights)

stop_training = False

for callback in callbacks:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='neuralnetlib',
version='3.3.8',
version='3.3.9',
author='Marc Pinet',
description='A flexible deep learning framework built from scratch using only NumPy',
long_description=open('README.md', encoding="utf-8").read(),
Expand Down

0 comments on commit 31041a3

Please sign in to comment.