Skip to content

Commit

Permalink
Update binary_logistic_regression_tf_with_hidden_layers.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
fs446 committed Jun 11, 2024
1 parent 03237fc commit 5762cd8
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions binary_logistic_regression_tf_with_hidden_layers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@
"from sklearn.datasets import make_classification\n",
"from sklearn.model_selection import train_test_split\n",
"import tensorflow as tf\n",
"import tensorflow.keras as keras\n",
"import tensorflow.keras.backend as K\n",
"from tensorflow import keras\n",
"\n",
"\n",
"print(\n",
" \"TF version\",\n",
Expand Down Expand Up @@ -229,8 +229,7 @@
"\n",
"model.compile(optimizer=optimizer, loss=loss, metrics=metrics)\n",
"\n",
"tw = np.sum([K.count_params(w) for w in model.trainable_weights])\n",
"print(\"\\ntrainable_weights:\", tw, \"\\n\")"
"model.summary()"
]
},
{
Expand All @@ -249,7 +248,7 @@
"outputs": [],
"source": [
"model.fit(\n",
" X_train, Y_train, epochs=epochs, batch_size=batch_size, verbose=verbose\n",
" X_train, Y_train[:, None], epochs=epochs, batch_size=batch_size, verbose=verbose\n",
")"
]
},
Expand Down Expand Up @@ -279,9 +278,9 @@
"metadata": {},
"outputs": [],
"source": [
"results = model.evaluate(X_train, Y_train, batch_size=M_train, verbose=False)\n",
"results = model.evaluate(X_train, Y_train[:, None], batch_size=M_train, verbose=False)\n",
"Y_train_pred = model.predict(X_train)\n",
"predict_class(Y_train_pred)"
"predict_class(Y_train_pred[:, None])"
]
},
{
Expand Down Expand Up @@ -333,7 +332,7 @@
"metadata": {},
"outputs": [],
"source": [
"results = model.evaluate(X_test, Y_test, batch_size=M_test, verbose=False)\n",
"results = model.evaluate(X_test, Y_test[:,None], batch_size=M_test, verbose=False)\n",
"Y_test_pred = model.predict(X_test)\n",
"predict_class(Y_test_pred)"
]
Expand Down Expand Up @@ -467,7 +466,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.12.3"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 5762cd8

Please sign in to comment.