diff --git a/mlguess/keras/keras.md b/docs/source/keras.md similarity index 100% rename from mlguess/keras/keras.md rename to docs/source/keras.md diff --git a/docs/source/torch.md b/docs/source/torch.md new file mode 100644 index 0000000..278a9ce --- /dev/null +++ b/docs/source/torch.md @@ -0,0 +1,5 @@ +Welcome to the PyTorch users page. The instructions below outline how to compute various UQ quantities, such as aleatoric and epistemic uncertainty, using different modeling approaches. + +Overall, there are two scripts: (1) one to train regression models and (2) one to train categorical models. Let us review the configuration file first, and then we will train models. + +(1) Currently, for regression problems only the Amini evidential MLP and a standard multi-task MLP (i.e. one that does not predict uncertainties) are supported. Support for the Gaussian model will be added eventually. To train a regression MLP \ No newline at end of file diff --git a/mlguess/torch/torch.md b/mlguess/torch/torch.md deleted file mode 100644 index 4b6f073..0000000 --- a/mlguess/torch/torch.md +++ /dev/null @@ -1 +0,0 @@ -Welcome to the pyTorch users page. The instructions below outline how to compute various UQ quantities like aleatoric and epistemic using different modeling approaches. \ No newline at end of file diff --git a/notebooks/classifier_example_torch.ipynb b/notebooks/classifier_example_torch.ipynb index f6158e5..0e255b9 100644 --- a/notebooks/classifier_example_torch.ipynb +++ b/notebooks/classifier_example_torch.ipynb @@ -5,7 +5,7 @@ "id": "2c7e9fb8-ef7f-4f32-bad4-ffeafbf90cd9", "metadata": {}, "source": [ - "# MILES-GUESS Classification Example Notebook\n", + "# MILES-GUESS Classification Example Notebook (PyTorch)\n", "\n", "John Schreck, David John Gagne, Charlie Becker, Gabrielle Gantos, Dhamma Kimpara, Thomas Martin" ] }, @@ -26,7 +26,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 1, "id": "08d23f15-63a1-484d-b9be-e46b6995bfd3", "metadata": {}, "outputs": [], @@ -39,16 +39,26 @@ "\n", "from mlguess.keras.data import load_ptype_uq, preprocess_data\n", "from mlguess.torch.models import CategoricalDNN\n", + "from mlguess.torch.metrics import MetricsCalculator\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from torch.utils.data import TensorDataset, DataLoader\n", "import numpy as np\n", + "from collections import defaultdict\n", "\n", "from mlguess.torch.class_losses import edl_mse_loss, edl_digamma_loss, edl_log_loss, relu_evidence" ] }, + { + "cell_type": "markdown", + "id": "8b25e71d-5d5c-4491-9b9b-0644650ae6e7", + "metadata": {}, + "source": [ + "### Load a config file" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -70,6 +80,14 @@ " conf = yaml.load(cf, Loader=yaml.FullLoader)" ] }, + { + "cell_type": "markdown", + "id": "f7b9c45c-2e67-46ab-a7d1-4a94a0109005", + "metadata": {}, + "source": [ + "### Load the training splits" + ] + }, { "cell_type": "code", "execution_count": 4, @@ -108,7 +126,28 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 5, + "id": "e0cd02c7-3f02-4bf7-82b6-c783845bb4ba", + "metadata": {}, + "outputs": [], + "source": [ + "def one_hot_embedding(labels, num_classes=10):\n", + " # Convert to One Hot Encoding\n", + " y = torch.eye(num_classes)\n", + " return y[labels]" + ] + }, + { + "cell_type": "markdown", + "id": "58870116-bf9d-4a0f-bc1b-255da8edae01", + "metadata": {}, + "source": [ + "### Convert the pandas DataFrame into torch tensors, then wrap in a Dataset and a DataLoader" + ] + }, + { + "cell_type": "code",
+ "execution_count": 6, "id": "20843b47-26f0-4e0b-8cc1-72a11bdf50f4", "metadata": {}, "outputs": [], @@ -116,14 +155,24 @@ "X_train = torch.FloatTensor(scaled_data[\"train_x\"].values)\n", "y_train = torch.LongTensor(np.argmax(scaled_data[\"train_y\"], axis=1))\n", "\n", + "batch_size = 1024\n", + "\n", "# Create dataset and dataloader\n", "dataset = TensorDataset(X_train, y_train)\n", - "dataloader = DataLoader(dataset, batch_size=128, shuffle=True)" + "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)" + ] + }, + { + "cell_type": "markdown", + "id": "3b0cf1ed-2254-4c57-9541-97cc6ae0e678", + "metadata": {}, + "source": [ + "### First lets train a standard (non-evidential) classifier" ] }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 7, "id": "7132121f-de6e-4091-a36c-bab56f34e4d4", "metadata": {}, "outputs": [], @@ -133,7 +182,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 8, "id": "c3beb428-a58d-4e98-a89f-d80e1f0ef3b5", "metadata": {}, "outputs": [], @@ -143,7 +192,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 9, "id": "013c0efd-8fd3-4b8f-9284-15ebb22ef8d1", "metadata": {}, "outputs": [ @@ -170,7 +219,7 @@ ")" ] }, - "execution_count": 32, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -179,126 +228,30 @@ "mlp" ] }, + { + "cell_type": "markdown", + "id": "a2cf7b16-cb9c-48da-b564-b6cf06a266f0", + "metadata": {}, + "source": [ + "### Train the model" + ] + }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 10, "id": "9dcf127c-9149-463a-bc36-1b6b7ae437e6", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch [1/100], Loss: 0.0021, Accuracy: 0.6578\n", - "Epoch [2/100], Loss: 0.0019, Accuracy: 0.7133\n", - "Epoch [3/100], Loss: 0.0017, Accuracy: 0.8320\n", - "Epoch [4/100], Loss: 0.0017, Accuracy: 0.8141\n", - "Epoch [5/100], Loss: 0.0017, Accuracy: 0.8156\n", - "Epoch [6/100], Loss: 0.0017, Accuracy: 0.8445\n", - "Epoch [7/100], Loss: 0.0016, Accuracy: 0.8359\n", - "Epoch [8/100], Loss: 0.0016, Accuracy: 0.8469\n", - "Epoch [9/100], Loss: 0.0016, Accuracy: 0.8477\n", - "Epoch [10/100], Loss: 0.0016, Accuracy: 0.8453\n", - "Epoch [11/100], Loss: 0.0016, Accuracy: 0.8562\n", - "Epoch [12/100], Loss: 0.0016, Accuracy: 0.8414\n", - "Epoch [13/100], Loss: 0.0016, Accuracy: 0.8625\n", - "Epoch [14/100], Loss: 0.0016, Accuracy: 0.8602\n", - "Epoch [15/100], Loss: 0.0016, Accuracy: 0.8656\n", - "Epoch [16/100], Loss: 0.0016, Accuracy: 0.8539\n", - "Epoch [17/100], Loss: 0.0016, Accuracy: 0.8656\n", - "Epoch [18/100], Loss: 0.0016, Accuracy: 0.8609\n", - "Epoch [19/100], Loss: 0.0016, Accuracy: 0.8617\n", - "Epoch [20/100], Loss: 0.0016, Accuracy: 0.8523\n", - "Epoch [21/100], Loss: 0.0015, Accuracy: 0.8805\n", - "Epoch [22/100], Loss: 0.0016, Accuracy: 0.8500\n", - "Epoch [23/100], Loss: 0.0016, Accuracy: 0.8578\n", - "Epoch [24/100], Loss: 0.0016, Accuracy: 0.8680\n", - "Epoch [25/100], Loss: 0.0016, Accuracy: 0.8484\n", - "Epoch [26/100], Loss: 0.0016, Accuracy: 0.8633\n", - "Epoch [27/100], Loss: 0.0016, Accuracy: 0.8586\n", - "Epoch [28/100], Loss: 0.0016, Accuracy: 0.8641\n", - "Epoch [29/100], Loss: 0.0015, Accuracy: 0.8711\n", - "Epoch [30/100], Loss: 0.0015, Accuracy: 0.8820\n", - "Epoch [31/100], Loss: 0.0015, Accuracy: 0.8648\n", - "Epoch [32/100], Loss: 0.0016, Accuracy: 0.8430\n", - "Epoch [33/100], Loss: 0.0016, Accuracy: 0.8602\n", - "Epoch [34/100], Loss: 0.0015, Accuracy: 
0.8648\n", - "Epoch [35/100], Loss: 0.0015, Accuracy: 0.8602\n", - "Epoch [36/100], Loss: 0.0015, Accuracy: 0.8672\n", - "Epoch [37/100], Loss: 0.0015, Accuracy: 0.8727\n", - "Epoch [38/100], Loss: 0.0015, Accuracy: 0.8773\n", - "Epoch [39/100], Loss: 0.0015, Accuracy: 0.8797\n", - "Epoch [40/100], Loss: 0.0015, Accuracy: 0.8734\n", - "Epoch [41/100], Loss: 0.0015, Accuracy: 0.8711\n", - "Epoch [42/100], Loss: 0.0016, Accuracy: 0.8602\n", - "Epoch [43/100], Loss: 0.0016, Accuracy: 0.8672\n", - "Epoch [44/100], Loss: 0.0015, Accuracy: 0.8703\n", - "Epoch [45/100], Loss: 0.0016, Accuracy: 0.8508\n", - "Epoch [46/100], Loss: 0.0015, Accuracy: 0.8586\n", - "Epoch [47/100], Loss: 0.0015, Accuracy: 0.8688\n", - "Epoch [48/100], Loss: 0.0015, Accuracy: 0.8617\n", - "Epoch [49/100], Loss: 0.0015, Accuracy: 0.8703\n", - "Epoch [50/100], Loss: 0.0015, Accuracy: 0.8820\n", - "Epoch [51/100], Loss: 0.0016, Accuracy: 0.8609\n", - "Epoch [52/100], Loss: 0.0015, Accuracy: 0.8602\n", - "Epoch [53/100], Loss: 0.0015, Accuracy: 0.8641\n", - "Epoch [54/100], Loss: 0.0015, Accuracy: 0.8711\n", - "Epoch [55/100], Loss: 0.0015, Accuracy: 0.8727\n", - "Epoch [56/100], Loss: 0.0015, Accuracy: 0.8664\n", - "Epoch [57/100], Loss: 0.0015, Accuracy: 0.8789\n", - "Epoch [58/100], Loss: 0.0016, Accuracy: 0.8469\n", - "Epoch [59/100], Loss: 0.0016, Accuracy: 0.8617\n", - "Epoch [60/100], Loss: 0.0016, Accuracy: 0.8539\n", - "Epoch [61/100], Loss: 0.0015, Accuracy: 0.8812\n", - "Epoch [62/100], Loss: 0.0015, Accuracy: 0.8828\n", - "Epoch [63/100], Loss: 0.0015, Accuracy: 0.8680\n", - "Epoch [64/100], Loss: 0.0015, Accuracy: 0.8648\n", - "Epoch [65/100], Loss: 0.0015, Accuracy: 0.8750\n", - "Epoch [66/100], Loss: 0.0016, Accuracy: 0.8602\n", - "Epoch [67/100], Loss: 0.0015, Accuracy: 0.8594\n", - "Epoch [68/100], Loss: 0.0015, Accuracy: 0.8656\n", - "Epoch [69/100], Loss: 0.0016, Accuracy: 0.8500\n", - "Epoch [70/100], Loss: 0.0016, Accuracy: 0.8625\n", - "Epoch [71/100], Loss: 0.0016, Accuracy: 0.8633\n", - "Epoch [72/100], Loss: 0.0015, Accuracy: 0.8664\n", - "Epoch [73/100], Loss: 0.0015, Accuracy: 0.8805\n", - "Epoch [74/100], Loss: 0.0015, Accuracy: 0.8758\n", - "Epoch [75/100], Loss: 0.0015, Accuracy: 0.8750\n", - "Epoch [76/100], Loss: 0.0016, Accuracy: 0.8438\n", - "Epoch [77/100], Loss: 0.0015, Accuracy: 0.8805\n", - "Epoch [78/100], Loss: 0.0016, Accuracy: 0.8688\n", - "Epoch [79/100], Loss: 0.0015, Accuracy: 0.8773\n", - "Epoch [80/100], Loss: 0.0015, Accuracy: 0.8719\n", - "Epoch [81/100], Loss: 0.0016, Accuracy: 0.8484\n", - "Epoch [82/100], Loss: 0.0015, Accuracy: 0.8719\n", - "Epoch [83/100], Loss: 0.0015, Accuracy: 0.8742\n", - "Epoch [84/100], Loss: 0.0015, Accuracy: 0.8672\n", - "Epoch [85/100], Loss: 0.0016, Accuracy: 0.8664\n", - "Epoch [86/100], Loss: 0.0015, Accuracy: 0.8664\n", - "Epoch [87/100], Loss: 0.0015, Accuracy: 0.8625\n", - "Epoch [88/100], Loss: 0.0015, Accuracy: 0.8883\n", - "Epoch [89/100], Loss: 0.0015, Accuracy: 0.8812\n", - "Epoch [90/100], Loss: 0.0016, Accuracy: 0.8586\n", - "Epoch [91/100], Loss: 0.0015, Accuracy: 0.8695\n", - "Epoch [92/100], Loss: 0.0015, Accuracy: 0.8625\n", - "Epoch [93/100], Loss: 0.0016, Accuracy: 0.8594\n", - "Epoch [94/100], Loss: 0.0015, Accuracy: 0.8617\n", - "Epoch [95/100], Loss: 0.0015, Accuracy: 0.8773\n", - "Epoch [96/100], Loss: 0.0015, Accuracy: 0.8758\n", - "Epoch [97/100], Loss: 0.0016, Accuracy: 0.8547\n", - "Epoch [98/100], Loss: 0.0015, Accuracy: 0.8742\n", - "Epoch [99/100], Loss: 0.0015, Accuracy: 0.8680\n", - "Epoch [100/100], Loss: 0.0015, 
Accuracy: 0.8812\n" - ] - } - ], + "outputs": [], "source": [ "criterion = nn.CrossEntropyLoss()\n", "optimizer = optim.Adam(mlp.parameters(), lr=0.001)\n", + "metrics = MetricsCalculator(use_uncertainty=False)\n", "\n", "# Training loop\n", "num_epochs = 100\n", "batches_per_epoch = 10\n", + "\n", + "results_dict = defaultdict(list)\n", "for epoch in range(num_epochs):\n", " mlp.train()\n", " total_loss = 0\n", @@ -319,16 +272,45 @@ " # Calculate accuracy\n", " _, predicted = torch.max(outputs.data, 1)\n", " total_predictions += batch_y.size(0)\n", - " correct_predictions += (predicted == batch_y).sum().item()\n", + " correct_predictions += (predicted == batch_y).float().mean().item()\n", + "\n", + " metrics_dict = metrics(one_hot_embedding(batch_y, 4), outputs, split=\"train\")\n", + " for name, value in metrics_dict.items():\n", + " results_dict[name].append(value.item())\n", + " \n", "\n", " if (k + 1) == batches_per_epoch:\n", " break\n", " \n", " # Calculate epoch statistics\n", - " avg_loss = total_loss / len(dataloader)\n", - " accuracy = correct_predictions / total_predictions\n", + " avg_loss = total_loss / batches_per_epoch\n", + " accuracy = correct_predictions / batches_per_epoch\n", " \n", - " print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')" + " #print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')" ] }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b00298b1-83a4-4fa8-a98d-3c1850ea0fe2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train_csi 0.25876237462243046\n", + "train_ave_acc 0.4687493490720337\n", + "train_prec 0.42272548775507235\n", + "train_recall 0.4687493490720337\n", + "train_f1 0.4416411424933151\n", + "train_auc 0.898654333499807\n" + ] + } + ], + "source": [ + "for key, val in results_dict.items():\n", + " print(key, np.mean(val))" + ] + }, { "cell_type": "markdown", @@ -336,12 +318,12 @@ "id": "c8537a3c-ba74-47e5-a12b-49d62485c585", "metadata": {}, "source": [ - "### Use an evidential neural network " + "### Next let's train an evidential classifier" ] }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 12, "id": "ff24ebbf-b04a-41ea-91f0-51ecbcb5b949", "metadata": {}, "outputs": [], @@ -352,7 +334,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 13, "id": "905f3c5f-045c-4441-ba59-db420a9b2264", "metadata": {}, "outputs": [], @@ -362,7 +344,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 14, "id": "7361134b-0860-4692-b57c-68d5f94b2d3b", "metadata": {}, "outputs": [ @@ -388,7 +370,7 @@ ")" ] }, - "execution_count": 62, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ev_mlp" ] }, + { + "cell_type": "markdown", + "id": "63f06fec-b07e-4d08-b55a-12ead26c0cf7", + "metadata": {}, + "source": [ + "### Note that there is no output activation here\n", + "### The other main difference is the choice of loss, shown below" + ] + }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 15, "id": "35b4accb-706d-4bad-9f44-a502b3151b54", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch [1/100]:\n", - " Loss: 1.4121\n", - " Accuracy: 0.6539\n", - " Mean Evidence: 1.1506\n", - " Mean Evidence (Correct): 1.1339\n", - " Mean Evidence (Incorrect): 1.1881\n", - " Mean Uncertainty: 0.7936\n", - "Epoch [2/100]:\n", - " Loss: 1.1648\n", - " Accuracy: 0.6875\n", - " Mean Evidence: 
2.2726\n", - " Mean Evidence (Correct): 2.3175\n", - " Mean Evidence (Incorrect): 2.1707\n", - " Mean Uncertainty: 0.6441\n", - "Epoch [3/100]:\n", - " Loss: 1.0674\n", - " Accuracy: 0.7195\n", - " Mean Evidence: 3.0092\n", - " Mean Evidence (Correct): 3.0867\n", - " Mean Evidence (Incorrect): 2.7428\n", - " Mean Uncertainty: 0.5794\n", - "Epoch [4/100]:\n", - " Loss: 0.9974\n", - " Accuracy: 0.8320\n", - " Mean Evidence: 3.3631\n", - " Mean Evidence (Correct): 3.4676\n", - " Mean Evidence (Incorrect): 2.8431\n", - " Mean Uncertainty: 0.5520\n", - "Epoch [5/100]:\n", - " Loss: 0.9768\n", - " Accuracy: 0.7969\n", - " Mean Evidence: 3.5468\n", - " Mean Evidence (Correct): 3.7274\n", - " Mean Evidence (Incorrect): 2.8339\n", - " Mean Uncertainty: 0.5397\n", - "Epoch [6/100]:\n", - " Loss: 0.9526\n", - " Accuracy: 0.8430\n", - " Mean Evidence: 2.9893\n", - " Mean Evidence (Correct): 3.1542\n", - " Mean Evidence (Incorrect): 2.1033\n", - " Mean Uncertainty: 0.5846\n", - "Epoch [7/100]:\n", - " Loss: 0.9803\n", - " Accuracy: 0.8359\n", - " Mean Evidence: 2.9753\n", - " Mean Evidence (Correct): 3.1467\n", - " Mean Evidence (Incorrect): 2.0912\n", - " Mean Uncertainty: 0.5886\n", - "Epoch [8/100]:\n", - " Loss: 0.9236\n", - " Accuracy: 0.8664\n", - " Mean Evidence: 2.9953\n", - " Mean Evidence (Correct): 3.1496\n", - " Mean Evidence (Incorrect): 1.9765\n", - " Mean Uncertainty: 0.5860\n", - "Epoch [9/100]:\n", - " Loss: 0.9038\n", - " Accuracy: 0.8656\n", - " Mean Evidence: 2.9945\n", - " Mean Evidence (Correct): 3.1563\n", - " Mean Evidence (Incorrect): 1.9593\n", - " Mean Uncertainty: 0.5844\n", - "Epoch [10/100]:\n", - " Loss: 0.9368\n", - " Accuracy: 0.8648\n", - " Mean Evidence: 2.9812\n", - " Mean Evidence (Correct): 3.1247\n", - " Mean Evidence (Incorrect): 2.0530\n", - " Mean Uncertainty: 0.5839\n", - "Epoch [11/100]:\n", - " Loss: 0.9211\n", - " Accuracy: 0.8781\n", - " Mean Evidence: 2.9555\n", - " Mean Evidence (Correct): 3.1113\n", - " Mean Evidence (Incorrect): 1.8509\n", - " Mean Uncertainty: 0.5916\n", - "Epoch [12/100]:\n", - " Loss: 0.9234\n", - " Accuracy: 0.8734\n", - " Mean Evidence: 3.0571\n", - " Mean Evidence (Correct): 3.2371\n", - " Mean Evidence (Incorrect): 1.8241\n", - " Mean Uncertainty: 0.5861\n", - "Epoch [13/100]:\n", - " Loss: 0.9226\n", - " Accuracy: 0.8711\n", - " Mean Evidence: 2.9944\n", - " Mean Evidence (Correct): 3.1705\n", - " Mean Evidence (Incorrect): 1.8045\n", - " Mean Uncertainty: 0.5885\n", - "Epoch [14/100]:\n", - " Loss: 0.9256\n", - " Accuracy: 0.8680\n", - " Mean Evidence: 3.1219\n", - " Mean Evidence (Correct): 3.3142\n", - " Mean Evidence (Incorrect): 1.8663\n", - " Mean Uncertainty: 0.5790\n", - "Epoch [15/100]:\n", - " Loss: 0.9094\n", - " Accuracy: 0.8711\n", - " Mean Evidence: 3.1095\n", - " Mean Evidence (Correct): 3.3135\n", - " Mean Evidence (Incorrect): 1.7177\n", - " Mean Uncertainty: 0.5849\n", - "Epoch [16/100]:\n", - " Loss: 0.9376\n", - " Accuracy: 0.8531\n", - " Mean Evidence: 3.0377\n", - " Mean Evidence (Correct): 3.2655\n", - " Mean Evidence (Incorrect): 1.7089\n", - " Mean Uncertainty: 0.5888\n", - "Epoch [17/100]:\n", - " Loss: 0.9212\n", - " Accuracy: 0.8594\n", - " Mean Evidence: 3.2099\n", - " Mean Evidence (Correct): 3.4423\n", - " Mean Evidence (Incorrect): 1.7998\n", - " Mean Uncertainty: 0.5773\n", - "Epoch [18/100]:\n", - " Loss: 0.9339\n", - " Accuracy: 0.8594\n", - " Mean Evidence: 3.1622\n", - " Mean Evidence (Correct): 3.3764\n", - " Mean Evidence (Incorrect): 1.8101\n", - " Mean Uncertainty: 0.5786\n", - "Epoch 
[19/100]:\n", - " Loss: 0.8685\n", - " Accuracy: 0.8820\n", - " Mean Evidence: 3.2425\n", - " Mean Evidence (Correct): 3.4492\n", - " Mean Evidence (Incorrect): 1.6918\n", - " Mean Uncertainty: 0.5757\n", - "Epoch [20/100]:\n", - " Loss: 0.8929\n", - " Accuracy: 0.8703\n", - " Mean Evidence: 3.2186\n", - " Mean Evidence (Correct): 3.4331\n", - " Mean Evidence (Incorrect): 1.7826\n", - " Mean Uncertainty: 0.5741\n", - "Epoch [21/100]:\n", - " Loss: 0.8966\n", - " Accuracy: 0.8672\n", - " Mean Evidence: 3.4288\n", - " Mean Evidence (Correct): 3.6741\n", - " Mean Evidence (Incorrect): 1.8434\n", - " Mean Uncertainty: 0.5631\n", - "Epoch [22/100]:\n", - " Loss: 0.8907\n", - " Accuracy: 0.8750\n", - " Mean Evidence: 3.1956\n", - " Mean Evidence (Correct): 3.4293\n", - " Mean Evidence (Incorrect): 1.5687\n", - " Mean Uncertainty: 0.5810\n", - "Epoch [23/100]:\n", - " Loss: 0.8836\n", - " Accuracy: 0.8734\n", - " Mean Evidence: 3.4551\n", - " Mean Evidence (Correct): 3.6859\n", - " Mean Evidence (Incorrect): 1.8568\n", - " Mean Uncertainty: 0.5603\n", - "Epoch [24/100]:\n", - " Loss: 0.9116\n", - " Accuracy: 0.8633\n", - " Mean Evidence: 3.2688\n", - " Mean Evidence (Correct): 3.5109\n", - " Mean Evidence (Incorrect): 1.7350\n", - " Mean Uncertainty: 0.5750\n", - "Epoch [25/100]:\n", - " Loss: 0.8749\n", - " Accuracy: 0.8773\n", - " Mean Evidence: 3.3501\n", - " Mean Evidence (Correct): 3.5736\n", - " Mean Evidence (Incorrect): 1.7668\n", - " Mean Uncertainty: 0.5687\n", - "Epoch [26/100]:\n", - " Loss: 0.8799\n", - " Accuracy: 0.8695\n", - " Mean Evidence: 3.4384\n", - " Mean Evidence (Correct): 3.6922\n", - " Mean Evidence (Incorrect): 1.7485\n", - " Mean Uncertainty: 0.5638\n", - "Epoch [27/100]:\n", - " Loss: 0.9192\n", - " Accuracy: 0.8539\n", - " Mean Evidence: 3.2922\n", - " Mean Evidence (Correct): 3.5734\n", - " Mean Evidence (Incorrect): 1.6355\n", - " Mean Uncertainty: 0.5774\n", - "Epoch [28/100]:\n", - " Loss: 0.9234\n", - " Accuracy: 0.8469\n", - " Mean Evidence: 3.2407\n", - " Mean Evidence (Correct): 3.5394\n", - " Mean Evidence (Incorrect): 1.5795\n", - " Mean Uncertainty: 0.5800\n", - "Epoch [29/100]:\n", - " Loss: 0.8884\n", - " Accuracy: 0.8680\n", - " Mean Evidence: 3.1968\n", - " Mean Evidence (Correct): 3.4437\n", - " Mean Evidence (Incorrect): 1.5545\n", - " Mean Uncertainty: 0.5809\n", - "Epoch [30/100]:\n", - " Loss: 0.9134\n", - " Accuracy: 0.8508\n", - " Mean Evidence: 3.3043\n", - " Mean Evidence (Correct): 3.5862\n", - " Mean Evidence (Incorrect): 1.6982\n", - " Mean Uncertainty: 0.5730\n", - "Epoch [31/100]:\n", - " Loss: 0.9268\n", - " Accuracy: 0.8297\n", - " Mean Evidence: 3.5064\n", - " Mean Evidence (Correct): 3.8727\n", - " Mean Evidence (Incorrect): 1.7174\n", - " Mean Uncertainty: 0.5643\n", - "Epoch [32/100]:\n", - " Loss: 0.8848\n", - " Accuracy: 0.8789\n", - " Mean Evidence: 3.3012\n", - " Mean Evidence (Correct): 3.5393\n", - " Mean Evidence (Incorrect): 1.5876\n", - " Mean Uncertainty: 0.5789\n", - "Epoch [33/100]:\n", - " Loss: 0.8797\n", - " Accuracy: 0.8719\n", - " Mean Evidence: 3.4423\n", - " Mean Evidence (Correct): 3.6817\n", - " Mean Evidence (Incorrect): 1.7945\n", - " Mean Uncertainty: 0.5628\n", - "Epoch [34/100]:\n", - " Loss: 0.8954\n", - " Accuracy: 0.8562\n", - " Mean Evidence: 3.5927\n", - " Mean Evidence (Correct): 3.8944\n", - " Mean Evidence (Incorrect): 1.7938\n", - " Mean Uncertainty: 0.5558\n", - "Epoch [35/100]:\n", - " Loss: 0.8858\n", - " Accuracy: 0.8633\n", - " Mean Evidence: 3.5084\n", - " Mean Evidence (Correct): 3.8244\n", 
- " Mean Evidence (Incorrect): 1.5138\n", - " Mean Uncertainty: 0.5685\n", - "Epoch [36/100]:\n", - " Loss: 0.8749\n", - " Accuracy: 0.8680\n", - " Mean Evidence: 3.5462\n", - " Mean Evidence (Correct): 3.8126\n", - " Mean Evidence (Incorrect): 1.7895\n", - " Mean Uncertainty: 0.5589\n", - "Epoch [37/100]:\n", - " Loss: 0.8609\n", - " Accuracy: 0.8828\n", - " Mean Evidence: 3.4395\n", - " Mean Evidence (Correct): 3.6813\n", - " Mean Evidence (Incorrect): 1.6430\n", - " Mean Uncertainty: 0.5661\n", - "Epoch [38/100]:\n", - " Loss: 0.8602\n", - " Accuracy: 0.8695\n", - " Mean Evidence: 3.6949\n", - " Mean Evidence (Correct): 3.9743\n", - " Mean Evidence (Incorrect): 1.8449\n", - " Mean Uncertainty: 0.5470\n", - "Epoch [39/100]:\n", - " Loss: 0.8767\n", - " Accuracy: 0.8617\n", - " Mean Evidence: 3.5608\n", - " Mean Evidence (Correct): 3.8731\n", - " Mean Evidence (Incorrect): 1.6100\n", - " Mean Uncertainty: 0.5639\n", - "Epoch [40/100]:\n", - " Loss: 0.8687\n", - " Accuracy: 0.8781\n", - " Mean Evidence: 3.2983\n", - " Mean Evidence (Correct): 3.5434\n", - " Mean Evidence (Incorrect): 1.5223\n", - " Mean Uncertainty: 0.5770\n", - "Epoch [41/100]:\n", - " Loss: 0.8647\n", - " Accuracy: 0.8727\n", - " Mean Evidence: 3.6384\n", - " Mean Evidence (Correct): 3.9001\n", - " Mean Evidence (Incorrect): 1.8539\n", - " Mean Uncertainty: 0.5514\n", - "Epoch [42/100]:\n", - " Loss: 0.8878\n", - " Accuracy: 0.8617\n", - " Mean Evidence: 3.4478\n", - " Mean Evidence (Correct): 3.7374\n", - " Mean Evidence (Incorrect): 1.6555\n", - " Mean Uncertainty: 0.5662\n", - "Epoch [43/100]:\n", - " Loss: 0.8711\n", - " Accuracy: 0.8633\n", - " Mean Evidence: 3.5809\n", - " Mean Evidence (Correct): 3.8920\n", - " Mean Evidence (Incorrect): 1.6383\n", - " Mean Uncertainty: 0.5605\n", - "Epoch [44/100]:\n", - " Loss: 0.8767\n", - " Accuracy: 0.8641\n", - " Mean Evidence: 3.4832\n", - " Mean Evidence (Correct): 3.7651\n", - " Mean Evidence (Incorrect): 1.6910\n", - " Mean Uncertainty: 0.5634\n", - "Epoch [45/100]:\n", - " Loss: 0.8649\n", - " Accuracy: 0.8758\n", - " Mean Evidence: 3.7363\n", - " Mean Evidence (Correct): 4.0215\n", - " Mean Evidence (Incorrect): 1.7210\n", - " Mean Uncertainty: 0.5525\n", - "Epoch [46/100]:\n", - " Loss: 0.8691\n", - " Accuracy: 0.8688\n", - " Mean Evidence: 3.5992\n", - " Mean Evidence (Correct): 3.9044\n", - " Mean Evidence (Incorrect): 1.5811\n", - " Mean Uncertainty: 0.5623\n", - "Epoch [47/100]:\n", - " Loss: 0.8674\n", - " Accuracy: 0.8711\n", - " Mean Evidence: 3.5894\n", - " Mean Evidence (Correct): 3.8565\n", - " Mean Evidence (Incorrect): 1.7980\n", - " Mean Uncertainty: 0.5556\n", - "Epoch [48/100]:\n", - " Loss: 0.8546\n", - " Accuracy: 0.8750\n", - " Mean Evidence: 3.5388\n", - " Mean Evidence (Correct): 3.8141\n", - " Mean Evidence (Incorrect): 1.6114\n", - " Mean Uncertainty: 0.5632\n", - "Epoch [49/100]:\n", - " Loss: 0.8415\n", - " Accuracy: 0.8758\n", - " Mean Evidence: 3.8513\n", - " Mean Evidence (Correct): 4.1364\n", - " Mean Evidence (Incorrect): 1.8409\n", - " Mean Uncertainty: 0.5406\n", - "Epoch [50/100]:\n", - " Loss: 0.8811\n", - " Accuracy: 0.8672\n", - " Mean Evidence: 3.4773\n", - " Mean Evidence (Correct): 3.7766\n", - " Mean Evidence (Incorrect): 1.5269\n", - " Mean Uncertainty: 0.5713\n", - "Epoch [51/100]:\n", - " Loss: 0.8888\n", - " Accuracy: 0.8625\n", - " Mean Evidence: 3.5878\n", - " Mean Evidence (Correct): 3.8817\n", - " Mean Evidence (Incorrect): 1.7470\n", - " Mean Uncertainty: 0.5576\n", - "Epoch [52/100]:\n", - " Loss: 0.8819\n", - " 
Accuracy: 0.8609\n", - " Mean Evidence: 3.5548\n", - " Mean Evidence (Correct): 3.8599\n", - " Mean Evidence (Incorrect): 1.6647\n", - " Mean Uncertainty: 0.5606\n", - "Epoch [53/100]:\n", - " Loss: 0.8711\n", - " Accuracy: 0.8641\n", - " Mean Evidence: 3.5900\n", - " Mean Evidence (Correct): 3.9110\n", - " Mean Evidence (Incorrect): 1.5457\n", - " Mean Uncertainty: 0.5642\n", - "Epoch [54/100]:\n", - " Loss: 0.8424\n", - " Accuracy: 0.8828\n", - " Mean Evidence: 3.6759\n", - " Mean Evidence (Correct): 3.9679\n", - " Mean Evidence (Incorrect): 1.4717\n", - " Mean Uncertainty: 0.5623\n", - "Epoch [55/100]:\n", - " Loss: 0.8635\n", - " Accuracy: 0.8727\n", - " Mean Evidence: 3.6023\n", - " Mean Evidence (Correct): 3.8592\n", - " Mean Evidence (Incorrect): 1.8282\n", - " Mean Uncertainty: 0.5535\n", - "Epoch [56/100]:\n", - " Loss: 0.8453\n", - " Accuracy: 0.8828\n", - " Mean Evidence: 3.5492\n", - " Mean Evidence (Correct): 3.7920\n", - " Mean Evidence (Incorrect): 1.7206\n", - " Mean Uncertainty: 0.5576\n", - "Epoch [57/100]:\n", - " Loss: 0.8522\n", - " Accuracy: 0.8797\n", - " Mean Evidence: 3.7122\n", - " Mean Evidence (Correct): 3.9698\n", - " Mean Evidence (Incorrect): 1.8450\n", - " Mean Uncertainty: 0.5491\n", - "Epoch [58/100]:\n", - " Loss: 0.8640\n", - " Accuracy: 0.8594\n", - " Mean Evidence: 3.7939\n", - " Mean Evidence (Correct): 4.1643\n", - " Mean Evidence (Incorrect): 1.4793\n", - " Mean Uncertainty: 0.5616\n", - "Epoch [59/100]:\n", - " Loss: 0.8853\n", - " Accuracy: 0.8648\n", - " Mean Evidence: 3.3298\n", - " Mean Evidence (Correct): 3.6446\n", - " Mean Evidence (Incorrect): 1.3268\n", - " Mean Uncertainty: 0.5869\n", - "Epoch [60/100]:\n", - " Loss: 0.8929\n", - " Accuracy: 0.8594\n", - " Mean Evidence: 3.6394\n", - " Mean Evidence (Correct): 3.9429\n", - " Mean Evidence (Incorrect): 1.7875\n", - " Mean Uncertainty: 0.5565\n", - "Epoch [61/100]:\n", - " Loss: 0.8792\n", - " Accuracy: 0.8594\n", - " Mean Evidence: 3.5895\n", - " Mean Evidence (Correct): 3.9246\n", - " Mean Evidence (Incorrect): 1.5427\n", - " Mean Uncertainty: 0.5669\n", - "Epoch [62/100]:\n", - " Loss: 0.8514\n", - " Accuracy: 0.8719\n", - " Mean Evidence: 3.6678\n", - " Mean Evidence (Correct): 3.9958\n", - " Mean Evidence (Incorrect): 1.4369\n", - " Mean Uncertainty: 0.5626\n", - "Epoch [63/100]:\n", - " Loss: 0.8798\n", - " Accuracy: 0.8656\n", - " Mean Evidence: 3.5462\n", - " Mean Evidence (Correct): 3.8270\n", - " Mean Evidence (Incorrect): 1.7424\n", - " Mean Uncertainty: 0.5607\n", - "Epoch [64/100]:\n", - " Loss: 0.8383\n", - " Accuracy: 0.8844\n", - " Mean Evidence: 3.7336\n", - " Mean Evidence (Correct): 3.9889\n", - " Mean Evidence (Incorrect): 1.7703\n", - " Mean Uncertainty: 0.5496\n", - "Epoch [65/100]:\n", - " Loss: 0.8633\n", - " Accuracy: 0.8672\n", - " Mean Evidence: 3.6622\n", - " Mean Evidence (Correct): 3.9541\n", - " Mean Evidence (Incorrect): 1.7330\n", - " Mean Uncertainty: 0.5529\n", - "Epoch [66/100]:\n", - " Loss: 0.8435\n", - " Accuracy: 0.8664\n", - " Mean Evidence: 3.8865\n", - " Mean Evidence (Correct): 4.2192\n", - " Mean Evidence (Incorrect): 1.7134\n", - " Mean Uncertainty: 0.5449\n", - "Epoch [67/100]:\n", - " Loss: 0.8875\n", - " Accuracy: 0.8641\n", - " Mean Evidence: 3.4765\n", - " Mean Evidence (Correct): 3.7726\n", - " Mean Evidence (Incorrect): 1.5957\n", - " Mean Uncertainty: 0.5693\n", - "Epoch [68/100]:\n", - " Loss: 0.8957\n", - " Accuracy: 0.8656\n", - " Mean Evidence: 3.5208\n", - " Mean Evidence (Correct): 3.7956\n", - " Mean Evidence (Incorrect): 
1.7555\n", - " Mean Uncertainty: 0.5662\n", - "Epoch [69/100]:\n", - " Loss: 0.8540\n", - " Accuracy: 0.8750\n", - " Mean Evidence: 3.6098\n", - " Mean Evidence (Correct): 3.8808\n", - " Mean Evidence (Incorrect): 1.6841\n", - " Mean Uncertainty: 0.5590\n", - "Epoch [70/100]:\n", - " Loss: 0.8496\n", - " Accuracy: 0.8734\n", - " Mean Evidence: 3.8521\n", - " Mean Evidence (Correct): 4.1445\n", - " Mean Evidence (Incorrect): 1.8312\n", - " Mean Uncertainty: 0.5420\n", - "Epoch [71/100]:\n", - " Loss: 0.8715\n", - " Accuracy: 0.8641\n", - " Mean Evidence: 3.6149\n", - " Mean Evidence (Correct): 3.9327\n", - " Mean Evidence (Incorrect): 1.5826\n", - " Mean Uncertainty: 0.5609\n", - "Epoch [72/100]:\n", - " Loss: 0.8776\n", - " Accuracy: 0.8617\n", - " Mean Evidence: 3.7521\n", - " Mean Evidence (Correct): 4.0663\n", - " Mean Evidence (Incorrect): 1.8145\n", - " Mean Uncertainty: 0.5522\n", - "Epoch [73/100]:\n", - " Loss: 0.8276\n", - " Accuracy: 0.8789\n", - " Mean Evidence: 3.8382\n", - " Mean Evidence (Correct): 4.1206\n", - " Mean Evidence (Incorrect): 1.8197\n", - " Mean Uncertainty: 0.5426\n", - "Epoch [74/100]:\n", - " Loss: 0.8731\n", - " Accuracy: 0.8625\n", - " Mean Evidence: 3.6833\n", - " Mean Evidence (Correct): 4.0039\n", - " Mean Evidence (Incorrect): 1.6838\n", - " Mean Uncertainty: 0.5571\n", - "Epoch [75/100]:\n", - " Loss: 0.8443\n", - " Accuracy: 0.8719\n", - " Mean Evidence: 3.7829\n", - " Mean Evidence (Correct): 4.0907\n", - " Mean Evidence (Incorrect): 1.6809\n", - " Mean Uncertainty: 0.5488\n", - "Epoch [76/100]:\n", - " Loss: 0.8620\n", - " Accuracy: 0.8680\n", - " Mean Evidence: 3.7752\n", - " Mean Evidence (Correct): 4.0948\n", - " Mean Evidence (Incorrect): 1.6731\n", - " Mean Uncertainty: 0.5513\n", - "Epoch [77/100]:\n", - " Loss: 0.8395\n", - " Accuracy: 0.8773\n", - " Mean Evidence: 3.8637\n", - " Mean Evidence (Correct): 4.1615\n", - " Mean Evidence (Incorrect): 1.7511\n", - " Mean Uncertainty: 0.5448\n", - "Epoch [78/100]:\n", - " Loss: 0.8717\n", - " Accuracy: 0.8539\n", - " Mean Evidence: 3.8003\n", - " Mean Evidence (Correct): 4.1496\n", - " Mean Evidence (Incorrect): 1.8124\n", - " Mean Uncertainty: 0.5471\n", - "Epoch [79/100]:\n", - " Loss: 0.8369\n", - " Accuracy: 0.8734\n", - " Mean Evidence: 3.6910\n", - " Mean Evidence (Correct): 3.9923\n", - " Mean Evidence (Incorrect): 1.5887\n", - " Mean Uncertainty: 0.5544\n", - "Epoch [80/100]:\n", - " Loss: 0.8539\n", - " Accuracy: 0.8758\n", - " Mean Evidence: 3.7360\n", - " Mean Evidence (Correct): 4.0384\n", - " Mean Evidence (Incorrect): 1.5802\n", - " Mean Uncertainty: 0.5591\n", - "Epoch [81/100]:\n", - " Loss: 0.8414\n", - " Accuracy: 0.8734\n", - " Mean Evidence: 3.7859\n", - " Mean Evidence (Correct): 4.1068\n", - " Mean Evidence (Incorrect): 1.5702\n", - " Mean Uncertainty: 0.5527\n", - "Epoch [82/100]:\n", - " Loss: 0.8424\n", - " Accuracy: 0.8773\n", - " Mean Evidence: 3.6932\n", - " Mean Evidence (Correct): 3.9882\n", - " Mean Evidence (Incorrect): 1.6080\n", - " Mean Uncertainty: 0.5567\n", - "Epoch [83/100]:\n", - " Loss: 0.8568\n", - " Accuracy: 0.8750\n", - " Mean Evidence: 3.6629\n", - " Mean Evidence (Correct): 3.9507\n", - " Mean Evidence (Incorrect): 1.6475\n", - " Mean Uncertainty: 0.5560\n", - "Epoch [84/100]:\n", - " Loss: 0.8444\n", - " Accuracy: 0.8672\n", - " Mean Evidence: 3.7511\n", - " Mean Evidence (Correct): 4.0738\n", - " Mean Evidence (Incorrect): 1.6400\n", - " Mean Uncertainty: 0.5518\n", - "Epoch [85/100]:\n", - " Loss: 0.8295\n", - " Accuracy: 0.8688\n", - " Mean 
Evidence: 3.9610\n", - " Mean Evidence (Correct): 4.3027\n", - " Mean Evidence (Incorrect): 1.6890\n", - " Mean Uncertainty: 0.5392\n", - "Epoch [86/100]:\n", - " Loss: 0.8343\n", - " Accuracy: 0.8859\n", - " Mean Evidence: 3.7302\n", - " Mean Evidence (Correct): 4.0008\n", - " Mean Evidence (Incorrect): 1.6380\n", - " Mean Uncertainty: 0.5521\n", - "Epoch [87/100]:\n", - " Loss: 0.8547\n", - " Accuracy: 0.8664\n", - " Mean Evidence: 3.8162\n", - " Mean Evidence (Correct): 4.1548\n", - " Mean Evidence (Incorrect): 1.6117\n", - " Mean Uncertainty: 0.5493\n", - "Epoch [88/100]:\n", - " Loss: 0.8403\n", - " Accuracy: 0.8711\n", - " Mean Evidence: 3.7756\n", - " Mean Evidence (Correct): 4.0888\n", - " Mean Evidence (Incorrect): 1.6541\n", - " Mean Uncertainty: 0.5487\n", - "Epoch [89/100]:\n", - " Loss: 0.8722\n", - " Accuracy: 0.8602\n", - " Mean Evidence: 3.7597\n", - " Mean Evidence (Correct): 4.0929\n", - " Mean Evidence (Incorrect): 1.7088\n", - " Mean Uncertainty: 0.5518\n", - "Epoch [90/100]:\n", - " Loss: 0.8429\n", - " Accuracy: 0.8742\n", - " Mean Evidence: 3.6869\n", - " Mean Evidence (Correct): 3.9855\n", - " Mean Evidence (Incorrect): 1.6244\n", - " Mean Uncertainty: 0.5563\n", - "Epoch [91/100]:\n", - " Loss: 0.8533\n", - " Accuracy: 0.8688\n", - " Mean Evidence: 3.7725\n", - " Mean Evidence (Correct): 4.0969\n", - " Mean Evidence (Incorrect): 1.6331\n", - " Mean Uncertainty: 0.5528\n", - "Epoch [92/100]:\n", - " Loss: 0.8425\n", - " Accuracy: 0.8703\n", - " Mean Evidence: 3.8687\n", - " Mean Evidence (Correct): 4.1700\n", - " Mean Evidence (Incorrect): 1.8608\n", - " Mean Uncertainty: 0.5434\n", - "Epoch [93/100]:\n", - " Loss: 0.8356\n", - " Accuracy: 0.8703\n", - " Mean Evidence: 3.9570\n", - " Mean Evidence (Correct): 4.2969\n", - " Mean Evidence (Incorrect): 1.7033\n", - " Mean Uncertainty: 0.5398\n", - "Epoch [94/100]:\n", - " Loss: 0.8330\n", - " Accuracy: 0.8812\n", - " Mean Evidence: 3.5676\n", - " Mean Evidence (Correct): 3.8478\n", - " Mean Evidence (Incorrect): 1.4983\n", - " Mean Uncertainty: 0.5635\n", - "Epoch [95/100]:\n", - " Loss: 0.8487\n", - " Accuracy: 0.8719\n", - " Mean Evidence: 3.8633\n", - " Mean Evidence (Correct): 4.1326\n", - " Mean Evidence (Incorrect): 2.0192\n", - " Mean Uncertainty: 0.5350\n", - "Epoch [96/100]:\n", - " Loss: 0.8898\n", - " Accuracy: 0.8594\n", - " Mean Evidence: 3.5244\n", - " Mean Evidence (Correct): 3.8448\n", - " Mean Evidence (Incorrect): 1.5558\n", - " Mean Uncertainty: 0.5695\n", - "Epoch [97/100]:\n", - " Loss: 0.8595\n", - " Accuracy: 0.8773\n", - " Mean Evidence: 3.6058\n", - " Mean Evidence (Correct): 3.8790\n", - " Mean Evidence (Incorrect): 1.6440\n", - " Mean Uncertainty: 0.5616\n", - "Epoch [98/100]:\n", - " Loss: 0.8711\n", - " Accuracy: 0.8625\n", - " Mean Evidence: 3.6847\n", - " Mean Evidence (Correct): 4.0155\n", - " Mean Evidence (Incorrect): 1.6133\n", - " Mean Uncertainty: 0.5595\n", - "Epoch [99/100]:\n", - " Loss: 0.8772\n", - " Accuracy: 0.8609\n", - " Mean Evidence: 3.6434\n", - " Mean Evidence (Correct): 3.9542\n", - " Mean Evidence (Incorrect): 1.7052\n", - " Mean Uncertainty: 0.5588\n", - "Epoch [100/100]:\n", - " Loss: 0.8690\n", - " Accuracy: 0.8586\n", - " Mean Evidence: 3.7782\n", - " Mean Evidence (Correct): 4.1299\n", - " Mean Evidence (Incorrect): 1.6347\n", - " Mean Uncertainty: 0.5510\n", - "Training completed!\n" - ] - } - ], + "outputs": [], "source": [ - "import torch\n", - "import torch.optim as optim\n", - "\n", "def one_hot_embedding(labels, num_classes=10):\n", " # Convert to One Hot 
Encoding\n", " y = torch.eye(num_classes)\n", @@ -1122,6 +402,7 @@ "\n", "criterion = edl_digamma_loss\n", "optimizer = optim.Adam(ev_mlp.parameters(), lr=0.001)\n", + "metrics = MetricsCalculator(use_uncertainty=False)\n", "\n", "# Training loop\n", "num_epochs = 100\n", @@ -1172,6 +453,10 @@ " total_evidence_succ += mean_evidence_succ.item()\n", " total_evidence_fail += mean_evidence_fail.item()\n", " total_uncertainty += torch.mean(u).item()\n", + "\n", + " metrics_dict = metrics(one_hot_embedding(batch_y, 4), outputs, split=\"train\")\n", + " for name, value in metrics_dict.items():\n", + " results_dict[name].append(value.item())\n", " \n", " if (k + 1) == batches_per_epoch:\n", " break\n", @@ -1184,24 +469,54 @@ " avg_evidence_fail = total_evidence_fail / batches_per_epoch\n", " avg_uncertainty = total_uncertainty / batches_per_epoch\n", " \n", - " print(f'Epoch [{epoch+1}/{num_epochs}]:')\n", - " print(f' Loss: {avg_loss:.4f}')\n", - " print(f' Accuracy: {avg_acc:.4f}')\n", - " print(f' Mean Evidence: {avg_evidence:.4f}')\n", - " print(f' Mean Evidence (Correct): {avg_evidence_succ:.4f}')\n", - " print(f' Mean Evidence (Incorrect): {avg_evidence_fail:.4f}')\n", - " print(f' Mean Uncertainty: {avg_uncertainty:.4f}')\n", - "\n", - "print(\"Training completed!\")" + " # print(f'Epoch [{epoch+1}/{num_epochs}]:')\n", + " # print(f' Loss: {avg_loss:.4f}')\n", + " # print(f' Accuracy: {avg_acc:.4f}')\n", + " # print(f' Mean Evidence: {avg_evidence:.4f}')\n", + " # print(f' Mean Evidence (Correct): {avg_evidence_succ:.4f}')\n", + " # print(f' Mean Evidence (Incorrect): {avg_evidence_fail:.4f}')\n", + " # print(f' Mean Uncertainty: {avg_uncertainty:.4f}')" ] }, { "cell_type": "code", - "execution_count": null, - "id": "4d73573c-7cf7-4062-8284-b579169f1d14", + "execution_count": 16, + "id": "9bf9b627-4991-4187-8a10-1e0198951f98", "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train_csi 0.25621067497624167\n", + "train_ave_acc 0.467960801452284\n", + "train_prec 0.4292476873831698\n", + "train_recall 0.467960801452284\n", + "train_f1 0.4421940509144526\n", + "train_auc 0.8909479463346074\n" + ] + } + ], + "source": [ + "for key, val in results_dict.items():\n", + " print(key, np.mean(val))" + ] + }, + { + "cell_type": "markdown", + "id": "fe3615ee-7c0a-4b58-94a4-0d7616475be7", + "metadata": {}, + "source": [ + "### Thats it! " + ] + }, + { + "cell_type": "markdown", + "id": "58d14312-6a99-4940-87a8-41eb60c0bdfc", + "metadata": {}, + "source": [ + "### Questions? 
Email John Schreck (schreck@ucar.edu)" + ] } ], "metadata": { diff --git a/notebooks/regression_example_torch.ipynb b/notebooks/regression_example_torch.ipynb index 66b160d..7336eec 100644 --- a/notebooks/regression_example_torch.ipynb +++ b/notebooks/regression_example_torch.ipynb @@ -4,18 +4,20 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# MILES-GUESS Regression Example Notebook\n", + "# MILES-GUESS Regression Example Notebook (PyTorch)\n", "\n", "John Schreck, David John Gagne, Charlie Becker, Gabrielle Gantos, Dhamma Kimpara, Thomas Martin" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "import os, tqdm, yaml\n", + "import os\n", + "import tqdm \n", + "import yaml\n", "import numpy as np\n", "import pandas as pd\n", "#import seaborn as sns\n", @@ -24,18 +26,16 @@ "from sklearn.model_selection import GroupShuffleSplit\n", "from sklearn.preprocessing import MinMaxScaler, RobustScaler\n", "\n", - "from mlguess.torch.models import DNN\n", - "# from mlguess.keras.models import GaussianRegressorDNN, EvidentialRegressorDNN\n", - "# from mlguess.keras.models import BaseRegressor as RegressorDNN\n", - "# from mlguess.keras.callbacks import get_callbacks\n", - "# from mlguess.regression_uq import compute_results\n", - "\n", - "\n", "import torch\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", "from torch.utils.data import Dataset, DataLoader\n", "from sklearn.model_selection import GroupShuffleSplit\n", "from sklearn.preprocessing import RobustScaler, MinMaxScaler\n", - "import yaml" + "\n", + "from mlguess.torch.models import DNN\n", + "from collections import defaultdict\n", + "from torch.utils.data import TensorDataset, DataLoader" ] }, { @@ -49,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -70,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -92,7 +92,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -103,7 +103,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -129,7 +129,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -140,7 +140,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -154,591 +154,259 @@ "y_test = y_scaler.transform(test_data[output_cols])" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 1. 
Deterministic multi-layer perceptron (MLP) to predict some quantity\n", - "\n", - "#### Train the model" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "# TODO: Why overwrite these variables in the config?\n", - "conf[\"model\"][\"epochs\"] = 1\n", - "conf[\"model\"][\"verbose\"] = 1" - ] - }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "model = DNN(\n", - " len(input_cols),\n", - " len(output_cols),\n", - " block_sizes = [1000], \n", - " dr = [0.5], \n", - " batch_norm = True, \n", - " lng = True\n", - ")" + "X = torch.FloatTensor(x_train)\n", + "y = torch.FloatTensor(y_train)\n", + "\n", + "batch_size = 128\n", + "\n", + "# Create dataset and dataloader\n", + "dataset = TensorDataset(X, y)\n", + "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Make a torch dataloader" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "from sklearn.preprocessing import RobustScaler, MinMaxScaler\n", - "from sklearn.model_selection import GroupShuffleSplit\n", - "from torch.utils.data import Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "class SurfaceLayerFluxDataset(Dataset):\n", - " def __init__(self, config, split='train'):\n", - " data_path = config['data_path']\n", - " input_cols = config['input_cols']\n", - " output_cols = config['output_cols']\n", - " flat_seed = config['split_params']['flat_seed']\n", - " data_seed = config['split_params']['data_seed']\n", - " train_size = config['split_ratios']['train_size']\n", - " valid_size = config['split_ratios']['valid_size']\n", - "\n", - " # Load data\n", - " data = pd.read_csv(data_path)\n", - " data[\"day\"] = data[\"Time\"].apply(lambda x: str(x).split(\" \")[0])\n", - " data[\"year\"] = data[\"Time\"].apply(lambda x: str(x).split(\"-\")[0])\n", - "\n", - " # Split data into train and test\n", - " gsp = GroupShuffleSplit(n_splits=1, random_state=flat_seed, train_size=train_size)\n", - " splits = list(gsp.split(data, groups=data[\"year\"]))\n", - " train_index, test_index = splits[0]\n", - " train_data, test_data = data.iloc[train_index].copy(), data.iloc[test_index].copy()\n", - "\n", - " # Split train data into train and validation\n", - " gsp = GroupShuffleSplit(n_splits=1, random_state=flat_seed, train_size=valid_size)\n", - " splits = list(gsp.split(train_data, groups=train_data[\"year\"]))\n", - " train_index, valid_index = splits[data_seed]\n", - " train_data, valid_data = train_data.iloc[train_index].copy(), train_data.iloc[valid_index].copy()\n", - "\n", - " # Initialize scalers\n", - " self.x_scaler = RobustScaler()\n", - " self.y_scaler = MinMaxScaler((0, 1))\n", - "\n", - " # Fit scalers on training data\n", - " self.x_scaler.fit(train_data[input_cols])\n", - " self.y_scaler.fit(train_data[output_cols])\n", - "\n", - " # Transform data\n", - " if split == 'train':\n", - " self.inputs = self.x_scaler.transform(train_data[input_cols])\n", - " self.targets = self.y_scaler.transform(train_data[output_cols])\n", - " elif split == 'valid':\n", - " self.inputs = self.x_scaler.transform(valid_data[input_cols])\n", - " self.targets = self.y_scaler.transform(valid_data[output_cols])\n", - " elif split == 'test':\n", - " self.inputs = 
self.x_scaler.transform(test_data[input_cols])\n", - " self.targets = self.y_scaler.transform(test_data[output_cols])\n", - " else:\n", - " raise ValueError(\"Invalid split value. Choose from 'train', 'valid', 'test'.\")\n", - "\n", - " self.inputs = torch.tensor(self.inputs, dtype=torch.float32)\n", - " self.targets = torch.tensor(self.targets, dtype=torch.float32)\n", - "\n", - " def __len__(self):\n", - " return len(self.inputs)\n", + "### 1. Deterministic multi-layer perceptron (MLP) to predict some quantity\n", "\n", - " def __getitem__(self, idx):\n", - " x = self.inputs[idx]\n", - " y = self.targets[idx]\n", - " return x, y" + "#### Train the model" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "train_dataset = SurfaceLayerFluxDataset(conf['data'], split='train')" + "conf[\"model\"][\"lng\"] = False\n", + "device = \"cuda\"" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ - "x, y = train_dataset.__getitem__(0)" + "model = DNN(**conf[\"model\"]).to(device)" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "torch.Size([1])" + "DNN(\n", + " (fcn): Sequential(\n", + " (0): Linear(in_features=4, out_features=1057, bias=True)\n", + " (1): LeakyReLU(negative_slope=0.01)\n", + " (2): Dropout(p=0.263, inplace=False)\n", + " (3): Linear(in_features=1057, out_features=1057, bias=True)\n", + " (4): LeakyReLU(negative_slope=0.01)\n", + " (5): Dropout(p=0.263, inplace=False)\n", + " (6): Linear(in_features=1057, out_features=1057, bias=True)\n", + " (7): LeakyReLU(negative_slope=0.01)\n", + " (8): Dropout(p=0.263, inplace=False)\n", + " (9): Linear(in_features=1057, out_features=1057, bias=True)\n", + " (10): LeakyReLU(negative_slope=0.01)\n", + " (11): Dropout(p=0.263, inplace=False)\n", + " (12): Linear(in_features=1057, out_features=1057, bias=True)\n", + " (13): LeakyReLU(negative_slope=0.01)\n", + " (14): Dropout(p=0.263, inplace=False)\n", + " (15): Linear(in_features=1057, out_features=1, bias=True)\n", + " )\n", + ")" ] }, - "execution_count": 21, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "y.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# model = RegressorDNN(**conf[\"model\"])\n", - "# model.build_neural_network(x_train.shape[-1], y_train.shape[-1])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# model.fit(x_train,\n", - "# y_train,\n", - "# validation_data=(x_valid, y_valid),\n", - "# callbacks=get_callbacks(conf))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Predict with the model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "y_pred = model.predict(x_test, y_scaler)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mae = np.mean(np.abs(y_pred[:, 0]-test_data[output_cols[0]]))\n", - "mae" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Create a Monte Carlo ensemble" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "monte_carlo_steps = 10" - ] - }, - { - "cell_type": "code", - "execution_count": 
null, - "metadata": {}, - "outputs": [], - "source": [ - "results = model.predict_monte_carlo(x_test, monte_carlo_steps, y_scaler)" + "model" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mu_ensemble = np.mean(results, axis=0)\n", - "var_ensemble = np.var(results, axis=0)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 2. Predict mu and sigma with a \"Gaussian MLP\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(0.0258, device='cuda:0', grad_fn=)\n" + ] + } + ], "source": [ - "config = \"../config/surface_layer/gaussian.yml\"\n", - "with open(config) as cf:\n", - " conf = yaml.load(cf, Loader=yaml.FullLoader)\n", + "results_dict = defaultdict(list)\n", "\n", - "conf[\"model\"][\"epochs\"] = 1\n", - "conf[\"model\"][\"verbose\"] = 1" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "gauss_model = GaussianRegressorDNN(**conf[\"model\"])\n", - "gauss_model.build_neural_network(x_train.shape[-1], y_train.shape[-1])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "gauss_model.fit(\n", - " x_train,\n", - " y_train,\n", - " validation_data=(x_valid, y_valid),\n", - " callbacks=get_callbacks(conf)\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mu, var = gauss_model.predict_uncertainty(x_test, y_scaler)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# compute variance and std from learned parameters\n", - "#mu, var = gauss_model.calc_uncertainties(y_pred, y_scaler)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mae = np.mean(np.abs(mu[:, 0]-test_data[output_cols[0]]))\n", - "print(mae, np.mean(var) ** (1/2))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sns.jointplot(x=test_data[output_cols[0]], y=mu[:, 0], kind='hex')\n", - "plt.xlabel('Target')\n", - "plt.ylabel('Predicted Target')\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sns.jointplot(x=mu[:, 0], y=np.sqrt(var)[:, 0], kind='hex')\n", - "plt.xlabel('Predicted mu')\n", - "plt.ylabel('Predicted sigma')\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### 3. 
Compute mu, aleatoric, and epistemic quantities using the evidential model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "config = \"../config/surface_layer/evidential.yml\"\n", + "criterion = nn.L1Loss()\n", + "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", "\n", - "with open(config) as cf:\n", - " conf = yaml.load(cf, Loader=yaml.FullLoader)\n", + "model.train()\n", + "for i, (x, y) in enumerate(dataloader):\n", + " x = x.to(device)\n", + " y_pred = model(x)\n", + " y = y.to(device=device, dtype=x.dtype)\n", + " loss = criterion(y_pred, y.to(x.dtype)).mean()\n", "\n", - "conf[\"model\"][\"epochs\"] = 5\n", - "conf[\"model\"][\"verbose\"] = 1" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ev_model = EvidentialRegressorDNN(**conf[\"model\"])\n", - "ev_model.build_neural_network(x_train.shape[-1], y_train.shape[-1])" + " # Backward pass and optimize\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + "print(loss)" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "ev_model.fit(\n", - " x_train,\n", - " y_train,\n", - " validation_data=(x_valid, y_valid),\n", - " callbacks=get_callbacks(conf))" + "### Now lets use the evidential regressor" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ - "result = ev_model.predict_uncertainty(x_test, scaler=y_scaler)\n", - "mu, aleatoric, epistemic = result" + "conf[\"model\"][\"lng\"] = True" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ - "mae = np.mean(np.abs(mu[:, 0] - test_data[output_cols[0]]))\n", - "print(mae, np.mean(aleatoric)**(1/2), np.mean(epistemic)**(1/2))" + "model = DNN(**conf[\"model\"]).to(device)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "DNN(\n", + " (fcn): Sequential(\n", + " (0): Linear(in_features=4, out_features=1057, bias=True)\n", + " (1): LeakyReLU(negative_slope=0.01)\n", + " (2): Dropout(p=0.263, inplace=False)\n", + " (3): Linear(in_features=1057, out_features=1057, bias=True)\n", + " (4): LeakyReLU(negative_slope=0.01)\n", + " (5): Dropout(p=0.263, inplace=False)\n", + " (6): Linear(in_features=1057, out_features=1057, bias=True)\n", + " (7): LeakyReLU(negative_slope=0.01)\n", + " (8): Dropout(p=0.263, inplace=False)\n", + " (9): Linear(in_features=1057, out_features=1057, bias=True)\n", + " (10): LeakyReLU(negative_slope=0.01)\n", + " (11): Dropout(p=0.263, inplace=False)\n", + " (12): Linear(in_features=1057, out_features=1057, bias=True)\n", + " (13): LeakyReLU(negative_slope=0.01)\n", + " (14): Dropout(p=0.263, inplace=False)\n", + " (15): LinearNormalGamma(\n", + " (linear): Linear(in_features=1057, out_features=4, bias=True)\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "compute_results(test_data,\n", - " output_cols,\n", - " mu,\n", - " np.sqrt(aleatoric),\n", - " np.sqrt(epistemic))" + "model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### 4. 
Create a deep ensemble with the Gaussian model so that the law of total variance can be applied to compute aleatoric and epistemic" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "config = \"../config/surface_layer/gaussian.yml\"\n", - "with open(config) as cf:\n", - " conf = yaml.load(cf, Loader=yaml.FullLoader)\n", - "\n", - "conf[\"save_loc\"] = \"./\"\n", - "conf[\"model\"][\"epochs\"] = 1\n", - "conf[\"model\"][\"verbose\"] = 0\n", - "n_splits = conf[\"ensemble\"][\"n_splits\"]" + "### Add the training dataset variance to the model class to enable uncertainty calculations after training" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ - "# make save directory for model weights\n", - "os.makedirs(os.path.join(conf[\"save_loc\"], \"cv_ensemble\", \"models\"), exist_ok=True)" + "model.training_var = [np.var(y_train)] # list of length 1 for 1 task " ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ - "data_seed = 0\n", - "gsp = GroupShuffleSplit(n_splits=1, random_state=flat_seed, train_size=0.9)\n", - "splits = list(gsp.split(data, groups=data[\"day\"]))\n", - "train_index, test_index = splits[0]" + "from mlguess.torch.regression_losses import LipschitzMSELoss" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor(-1.8424, device='cuda:0', grad_fn=)\n" + ] + } + ], "source": [ - "ensemble_mu = np.zeros((n_splits, test_data.shape[0], 1))\n", - "ensemble_var = np.zeros((n_splits, test_data.shape[0], 1))\n", + "results_dict = defaultdict(list)\n", "\n", - "for data_seed in tqdm.tqdm(range(n_splits)):\n", - " data = pd.read_csv(fn)\n", - " data[\"day\"] = data[\"Time\"].apply(lambda x: str(x).split(\" \")[0])\n", + "criterion = LipschitzMSELoss(**conf[\"train_loss\"])\n", + "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", "\n", - " # Need the same test_data for all trained models (data and model ensembles)\n", - " flat_seed = 1000\n", - " gsp = GroupShuffleSplit(n_splits=1,\n", - " random_state=flat_seed,\n", - " train_size=0.9)\n", - " splits = list(gsp.split(data, groups=data[\"day\"]))\n", - " train_index, test_index = splits[0]\n", - " train_data, test_data = data.iloc[train_index].copy(), data.iloc[test_index].copy()\n", + "model.train()\n", + "for i, (x, y) in enumerate(dataloader):\n", + " x = x.to(device)\n", + " y_pred = model(x)\n", + " gamma, nu, alpha, beta = y_pred\n", + " y = y.to(device=device, dtype=x.dtype)\n", + " loss = criterion(gamma, nu, alpha, beta, y.to(x.dtype))\n", "\n", - " # Make N train-valid splits using day as grouping variable\n", - " gsp = GroupShuffleSplit(n_splits=n_splits, random_state=flat_seed, train_size=0.885)\n", - " splits = list(gsp.split(train_data, groups=train_data[\"day\"]))\n", - " train_index, valid_index = splits[data_seed]\n", - " train_data, valid_data = train_data.iloc[train_index].copy(), train_data.iloc[valid_index].copy()\n", + " # Predict uncertainties\n", + " y_pred = (_.cpu().detach() for _ in y_pred)\n", + " mu, ale, epi, total = model.predict_uncertainty(y_pred, y_scaler=y_scaler)\n", + " loss = loss.mean()\n", "\n", - " x_scaler, y_scaler = RobustScaler(), MinMaxScaler((0, 1))\n", - " x_train = x_scaler.fit_transform(train_data[input_cols])\n", - " x_valid = 
x_scaler.transform(valid_data[input_cols])\n", - " x_test = x_scaler.transform(test_data[input_cols])\n", + " # Backward pass and optimize\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", "\n", - " y_train = y_scaler.fit_transform(train_data[output_cols])\n", - " y_valid = y_scaler.transform(valid_data[output_cols])\n", - " y_test = y_scaler.transform(test_data[output_cols])\n", - "\n", - " model = GaussianRegressorDNN(**conf[\"model\"])\n", - " model.build_neural_network(x_train.shape[-1], y_train.shape[-1])\n", - "\n", - " model.fit(\n", - " x_train,\n", - " y_train,\n", - " validation_data=(x_valid, y_valid),\n", - " callbacks=get_callbacks(conf))\n", - "\n", - " model.model_name = f\"cv_ensemble/models/model_seed0_split{data_seed}.h5\"\n", - " model.save_model()\n", - "\n", - " # Save the best model\n", - " model.model_name = \"cv_ensemble/models/best.h5\"\n", - " model.save_model()\n", - "\n", - " mu, var = model.predict_uncertainty(x_test, y_scaler)\n", - " mae = np.mean(np.abs(mu[:, 0]-test_data[output_cols[0]]))\n", - "\n", - " ensemble_mu[data_seed] = mu\n", - " ensemble_var[data_seed] = var" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### Use the method predict_ensemble to accomplish the same thing given pretrained models:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "model = GaussianRegressorDNN().load_model(conf)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ensemble_mu, ensemble_var = model.predict_ensemble(x_test, scaler=y_scaler)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "epistemic = np.var(ensemble_mu, axis=0)\n", - "aleatoric = np.mean(ensemble_var, axis=0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(epistemic.mean()**(1/2), aleatoric.mean()**(1/2))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "compute_results(test_data,\n", - " output_cols,\n", - " np.mean(ensemble_mu, axis=0),\n", - " np.sqrt(aleatoric),\n", - " np.sqrt(epistemic))" + "print(loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### 5. Use Monte Carlo dropout with the Gaussian model to compute aleatoric and epistemic uncertainties" + "### Questions? Email John Schreck (schreck@ucar.edu)" ] }, { @@ -746,31 +414,14 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "monte_carlo_steps = 10\n", - "\n", - "ensemble_mu, ensemble_var = model.predict_monte_carlo(x_test,\n", - " monte_carlo_steps,\n", - " scaler=y_scaler)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ensemble_epistemic = np.var(ensemble_mu, axis=0)\n", - "ensemble_aleatoric = np.mean(ensemble_var, axis=0)\n", - "ensemble_mean = np.mean(ensemble_mu, axis=0)" - ] + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "evidential-casper", + "display_name": "credit", "language": "python", - "name": "evidential-casper" + "name": "credit" }, "language_info": { "codemirror_mode": { @@ -782,7 +433,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.16" + "version": "3.11.8" } }, "nbformat": 4,