Skip to content

Commit

Permalink
Update exercise12_MusicGenreClassification.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
fs446 committed Jun 11, 2024
1 parent 1f44f17 commit 03237fc
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions exercise12_MusicGenreClassification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@
"from sklearn.model_selection import train_test_split\n",
"from sklearn.preprocessing import OneHotEncoder, LabelBinarizer\n",
"import tensorflow as tf\n",
"import tensorflow.keras as keras\n",
"import tensorflow.keras.backend as K\n",
"from tensorflow import keras\n",
"from keras import backend as K\n",
"import time\n",
"\n",
"\n",
Expand Down Expand Up @@ -142,7 +142,9 @@
"print(root_logdir)\n",
"print(kt_logdir) # folder for keras tuner results\n",
"print(tf_kt_logdir) # folder for TF checkpoints while keras tuning\n",
"print(tf_logdir) # folder for TF checkpoint for best model training"
"print(tf_logdir) # folder for TF checkpoint for best model training\n",
"\n",
"os.makedirs(tf_logdir, exist_ok=True)"
]
},
{
Expand Down Expand Up @@ -468,7 +470,7 @@
"metadata": {},
"outputs": [],
"source": [
"encoder = OneHotEncoder(sparse=False)\n",
"encoder = OneHotEncoder(sparse_output=False)\n",
"# we encode as one-hot for TF model\n",
"Y = encoder.fit_transform(Y.reshape(-1, 1))"
]
Expand Down Expand Up @@ -546,7 +548,7 @@
"def build_model(hp): # with hyper parameter ranges\n",
" model = keras.Sequential()\n",
" # input layer\n",
" model.add(keras.Input(shape=nx))\n",
" model.add(keras.Input(shape=(nx,)))\n",
" # hidden layers\n",
" for layer in range(hp.Int(\"no_layers\", 1, 5)):\n",
" model.add(\n",
Expand Down Expand Up @@ -594,7 +596,7 @@
"model = build_model(kt.HyperParameters())\n",
"hptuner = kt.RandomSearch(\n",
" hypermodel=build_model,\n",
" objective=\"val_categorical_accuracy\", # check performance on val data!\n",
" objective=\"val_loss\", # check performance on val data!\n",
" max_trials=max_trials,\n",
" executions_per_trial=executions_per_trial,\n",
" overwrite=True,\n",
Expand Down Expand Up @@ -653,7 +655,7 @@
"# we might check (train) the best XX models in detail\n",
"# for didactical purpose we choose only the very best one, located in [0]:\n",
"model = hptuner.get_best_models(num_models=1)[0]\n",
"model.save(tf_logdir + \"/best_model\")"
"model.save(tf_logdir + \"/best_model.keras\")"
]
},
{
Expand Down Expand Up @@ -690,7 +692,7 @@
"outputs": [],
"source": [
"# load best model and reset weights\n",
"model = keras.models.load_model(tf_logdir + \"/best_model\")\n",
"model = keras.models.load_model(tf_logdir + \"/best_model.keras\")\n",
"reset_weights(model) # start training from scratch\n",
"print(model.summary())"
]
Expand Down Expand Up @@ -725,7 +727,7 @@
" callbacks=[earlystopping_cb, tensorboard_cb],\n",
" verbose=1,\n",
")\n",
"model.save(tf_logdir + \"/trained_best_model\")\n",
"model.save(tf_logdir + \"/trained_best_model.keras\")\n",
"print(model.summary())"
]
},
Expand Down Expand Up @@ -849,7 +851,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 03237fc

Please sign in to comment.