diff --git a/1_pytorch_mnist_unet.ipynb b/1_pytorch_mnist_unet.ipynb deleted file mode 100644 index f21c78f..0000000 --- a/1_pytorch_mnist_unet.ipynb +++ /dev/null @@ -1,315 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/10, Batch 0/750, Loss: 0.4551\n", - "Epoch 1/10, Batch 100/750, Loss: 0.0079\n", - "Epoch 1/10, Batch 200/750, Loss: 0.0026\n", - "Epoch 1/10, Batch 300/750, Loss: 0.0017\n", - "Epoch 1/10, Batch 400/750, Loss: 0.0012\n", - "Epoch 1/10, Batch 500/750, Loss: 0.0009\n", - "Epoch 1/10, Batch 600/750, Loss: 0.0009\n", - "Epoch 1/10, Batch 700/750, Loss: 0.0008\n", - "Validation Loss: 0.0007\n", - "Epoch 2/10, Batch 0/750, Loss: 0.0007\n", - "Epoch 2/10, Batch 100/750, Loss: 0.0007\n", - "Epoch 2/10, Batch 200/750, Loss: 0.0006\n", - "Epoch 2/10, Batch 300/750, Loss: 0.0005\n", - "Epoch 2/10, Batch 400/750, Loss: 0.0005\n", - "Epoch 2/10, Batch 500/750, Loss: 0.0005\n", - "Epoch 2/10, Batch 600/750, Loss: 0.0005\n", - "Epoch 2/10, Batch 700/750, Loss: 0.0007\n", - "Validation Loss: 0.0004\n", - "Epoch 3/10, Batch 0/750, Loss: 0.0005\n", - "Epoch 3/10, Batch 100/750, Loss: 0.0004\n", - "Epoch 3/10, Batch 200/750, Loss: 0.0004\n", - "Epoch 3/10, Batch 300/750, Loss: 0.0004\n", - "Epoch 3/10, Batch 400/750, Loss: 0.0004\n", - "Epoch 3/10, Batch 500/750, Loss: 0.0003\n", - "Epoch 3/10, Batch 600/750, Loss: 0.0004\n", - "Epoch 3/10, Batch 700/750, Loss: 0.0003\n", - "Validation Loss: 0.0003\n", - "Epoch 4/10, Batch 0/750, Loss: 0.0003\n", - "Epoch 4/10, Batch 100/750, Loss: 0.0003\n", - "Epoch 4/10, Batch 200/750, Loss: 0.0004\n", - "Epoch 4/10, Batch 300/750, Loss: 0.0003\n", - "Epoch 4/10, Batch 400/750, Loss: 0.0006\n", - "Epoch 4/10, Batch 500/750, Loss: 0.0003\n", - "Epoch 4/10, Batch 600/750, Loss: 0.0003\n", - "Epoch 4/10, Batch 700/750, Loss: 0.0003\n", - "Validation Loss: 0.0004\n", - "Epoch 5/10, Batch 0/750, Loss: 
0.0004\n", - "Epoch 5/10, Batch 100/750, Loss: 0.0003\n", - "Epoch 5/10, Batch 200/750, Loss: 0.0003\n", - "Epoch 5/10, Batch 300/750, Loss: 0.0003\n", - "Epoch 5/10, Batch 400/750, Loss: 0.0004\n", - "Epoch 5/10, Batch 500/750, Loss: 0.0002\n", - "Epoch 5/10, Batch 600/750, Loss: 0.0002\n", - "Epoch 5/10, Batch 700/750, Loss: 0.0004\n", - "Validation Loss: 0.0003\n", - "Epoch 6/10, Batch 0/750, Loss: 0.0003\n", - "Epoch 6/10, Batch 100/750, Loss: 0.0002\n", - "Epoch 6/10, Batch 200/750, Loss: 0.0004\n", - "Epoch 6/10, Batch 300/750, Loss: 0.0004\n", - "Epoch 6/10, Batch 400/750, Loss: 0.0002\n", - "Epoch 6/10, Batch 500/750, Loss: 0.0002\n", - "Epoch 6/10, Batch 600/750, Loss: 0.0003\n", - "Epoch 6/10, Batch 700/750, Loss: 0.0003\n", - "Validation Loss: 0.0002\n", - "Epoch 7/10, Batch 0/750, Loss: 0.0002\n", - "Epoch 7/10, Batch 100/750, Loss: 0.0002\n", - "Epoch 7/10, Batch 200/750, Loss: 0.0002\n", - "Epoch 7/10, Batch 300/750, Loss: 0.0002\n", - "Epoch 7/10, Batch 400/750, Loss: 0.0002\n", - "Epoch 7/10, Batch 500/750, Loss: 0.0002\n", - "Epoch 7/10, Batch 600/750, Loss: 0.0002\n", - "Epoch 7/10, Batch 700/750, Loss: 0.0003\n", - "Validation Loss: 0.0002\n", - "Epoch 8/10, Batch 0/750, Loss: 0.0002\n", - "Epoch 8/10, Batch 100/750, Loss: 0.0002\n", - "Epoch 8/10, Batch 200/750, Loss: 0.0002\n", - "Epoch 8/10, Batch 300/750, Loss: 0.0002\n", - "Epoch 8/10, Batch 400/750, Loss: 0.0002\n", - "Epoch 8/10, Batch 500/750, Loss: 0.0002\n", - "Epoch 8/10, Batch 600/750, Loss: 0.0002\n", - "Epoch 8/10, Batch 700/750, Loss: 0.0002\n", - "Validation Loss: 0.0002\n", - "Epoch 9/10, Batch 0/750, Loss: 0.0002\n", - "Epoch 9/10, Batch 100/750, Loss: 0.0002\n", - "Epoch 9/10, Batch 200/750, Loss: 0.0002\n", - "Epoch 9/10, Batch 300/750, Loss: 0.0002\n", - "Epoch 9/10, Batch 400/750, Loss: 0.0002\n", - "Epoch 9/10, Batch 500/750, Loss: 0.0001\n", - "Epoch 9/10, Batch 600/750, Loss: 0.0003\n", - "Epoch 9/10, Batch 700/750, Loss: 0.0002\n", - "Validation Loss: 0.0002\n", - "Epoch 
10/10, Batch 0/750, Loss: 0.0002\n", - "Epoch 10/10, Batch 100/750, Loss: 0.0001\n", - "Epoch 10/10, Batch 200/750, Loss: 0.0003\n", - "Epoch 10/10, Batch 300/750, Loss: 0.0001\n", - "Epoch 10/10, Batch 400/750, Loss: 0.0002\n", - "Epoch 10/10, Batch 500/750, Loss: 0.0002\n", - "Epoch 10/10, Batch 600/750, Loss: 0.0002\n", - "Epoch 10/10, Batch 700/750, Loss: 0.0001\n", - "Validation Loss: 0.0002\n", - "Training complete!\n" - ] - } - ], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "import torch.optim as optim\n", - "from torch.utils.data import DataLoader, random_split\n", - "from torchvision import datasets, transforms\n", - "\n", - "# Define U-Net architecture\n", - "class UNet(nn.Module):\n", - " def __init__(self):\n", - " super(UNet, self).__init__()\n", - "\n", - " # Encoder\n", - " self.encoder = nn.Sequential(\n", - " nn.Conv2d(1, 64, kernel_size=3, padding=1),\n", - " nn.ReLU(inplace=True),\n", - " nn.Conv2d(64, 64, kernel_size=3, padding=1),\n", - " nn.ReLU(inplace=True),\n", - " nn.MaxPool2d(kernel_size=2, stride=2)\n", - " )\n", - "\n", - " # Decoder\n", - " self.decoder = nn.Sequential(\n", - " nn.Conv2d(64, 64, kernel_size=3, padding=1),\n", - " nn.ReLU(inplace=True),\n", - " nn.Conv2d(64, 64, kernel_size=3, padding=1),\n", - " nn.ReLU(inplace=True),\n", - " nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2)\n", - " )\n", - "\n", - " def forward(self, x):\n", - " x1 = self.encoder(x)\n", - " x2 = self.decoder(x1)\n", - " return x2\n", - "\n", - "# Set device\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "\n", - "# Load MNIST dataset and split into train, validation, and test sets\n", - "transform = transforms.Compose([transforms.ToTensor()])\n", - "full_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n", - "\n", - "train_size = int(0.8 * len(full_dataset))\n", - "val_size = int(0.1 * len(full_dataset))\n", - "test_size = len(full_dataset) - train_size 
- val_size\n", - "\n", - "train_dataset, val_dataset, test_dataset = random_split(full_dataset, [train_size, val_size, test_size])\n", - "\n", - "train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)\n", - "val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)\n", - "test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)\n", - "\n", - "# Instantiate the U-Net model and move it to the device\n", - "model = UNet().to(device)\n", - "\n", - "# Define loss function and optimizer\n", - "criterion = nn.MSELoss()\n", - "optimizer = optim.Adam(model.parameters(), lr=0.001)\n", - "\n", - "# Training loop\n", - "num_epochs = 10\n", - "for epoch in range(num_epochs):\n", - " model.train()\n", - " for batch_idx, (data, _) in enumerate(train_loader):\n", - " data = data.to(device)\n", - "\n", - " # Forward pass\n", - " output = model(data)\n", - "\n", - " # Compute loss\n", - " loss = criterion(output, data)\n", - "\n", - " # Backward pass and optimization\n", - " optimizer.zero_grad()\n", - " loss.backward()\n", - " optimizer.step()\n", - "\n", - " if batch_idx % 100 == 0:\n", - " print(f\"Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}\")\n", - "\n", - " # Validation loop\n", - " model.eval()\n", - " with torch.no_grad():\n", - " total_loss = 0.0\n", - " for data, _ in val_loader:\n", - " data = data.to(device)\n", - " output = model(data)\n", - " total_loss += criterion(output, data).item()\n", - "\n", - " average_loss = total_loss / len(val_loader)\n", - " print(f\"Validation Loss: {average_loss:.4f}\")\n", - "\n", - "print(\"Training complete!\")\n", - "\n", - "# Save the trained model\n", - "torch.save(model.state_dict(), 'unet_mnist.pth')" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "from PIL import Image\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# Define U-Net architecture (make sure it matches the architecture used during training)\n", - "class UNet(nn.Module):\n", - " def __init__(self):\n", - " super(UNet, self).__init__()\n", - "\n", - " # Encoder\n", - " self.encoder = nn.Sequential(\n", - " nn.Conv2d(1, 64, kernel_size=3, padding=1),\n", - " nn.ReLU(inplace=True),\n", - " nn.Conv2d(64, 64, kernel_size=3, padding=1),\n", - " nn.ReLU(inplace=True),\n", - " nn.MaxPool2d(kernel_size=2, stride=2)\n", - " )\n", - "\n", - " # Decoder\n", - " self.decoder = nn.Sequential(\n", - " nn.Conv2d(64, 64, kernel_size=3, padding=1),\n", - " nn.ReLU(inplace=True),\n", - " nn.Conv2d(64, 64, kernel_size=3, padding=1),\n", - " nn.ReLU(inplace=True),\n", - " nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2)\n", - " )\n", - "\n", - " def forward(self, x):\n", - " x1 = self.encoder(x)\n", - " x2 = self.decoder(x1)\n", - " return x2\n", - "\n", - "# Set device\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "\n", - "# Load the trained model\n", - "model = UNet()\n", - "model.load_state_dict(torch.load('unet_mnist.pth'))\n", - "model.to(device)\n", - "model.eval()\n", - "\n", - "# Transform for input image\n", - "transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),\n", - " transforms.ToTensor()])\n", - "\n", - "# # Load the test dataset\n", - "# test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n", - "# test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)\n", - "\n", - "# Testing loop\n", - "with torch.no_grad():\n", - " for batch_idx, (data, _) in enumerate(test_loader):\n", - " data = data.to(device)\n", - "\n", - " # Forward pass\n", - " output = model(data)\n", - "\n", - " # Visualize the input and output for the first 
batch\n", - " if batch_idx == 0:\n", - " for i in range(min(4, data.size(0))): # Visualize up to 4 samples\n", - " plt.subplot(2, 4, i + 1)\n", - " plt.title('Input')\n", - " plt.imshow(data[i].cpu().squeeze().numpy(), cmap='gray')\n", - "\n", - " plt.subplot(2, 4, i + 5)\n", - " plt.title('Output')\n", - " plt.imshow(output[i].cpu().squeeze().numpy(), cmap='gray')\n", - "\n", - " plt.show()\n", - "\n", - "# You can also calculate metrics like accuracy, precision, recall, etc., for more detailed evaluation." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "pytorch", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.6" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/6_pytorch_dataloading_tests_custom_Unet.ipynb b/6_pytorch_dataloading_tests_custom_Unet.ipynb deleted file mode 100644 index f9746e0..0000000 --- a/6_pytorch_dataloading_tests_custom_Unet.ipynb +++ /dev/null @@ -1,9532 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "import torch.optim as optim\n", - "from torch.utils.data import DataLoader, random_split\n", - "from torchvision import datasets, transforms\n", - "\n", - "# Set device\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "\n", - "# torch.set_float32_matmul_precision('medium')" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import pandas as pd\n", - "import os\n", - "import random\n", - "import matplotlib.pyplot as plt\n", - "\n", - "# We also set the logging level so that we get some feedback from the API\n", - "import 
logging\n", - "logging.basicConfig(level=logging.INFO)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running in /data/hpcdata/users/bryald/git/icenet/notebook-pipeline\n" - ] - } - ], - "source": [ - "# Quick hack to put us in the icenet-pipeline folder,\n", - "# assuming it was created as per 01.cli_demonstration.ipynb\n", - "import os\n", - "if os.path.exists(\"pytorch_example.ipynb\"):\n", - " os.chdir(\"../notebook-pipeline\")\n", - "print(\"Running in {}\".format(os.getcwd()))\n", - "\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "def plot_image(array, title=\"\"):\n", - " fig = plt.figure(figsize=(18, 10))\n", - " plt.imshow(array[0, :, :, 0, 0])\n", - " plt.colorbar(shrink=0.6)\n", - " plt.suptitle(title)\n", - " plt.tight_layout()\n", - "\n", - " iter = 0\n", - " out_dir = \"image-outputs\"\n", - " out_path = os.path.join(out_dir, f\"{iter}_{title}.jpg\")\n", - " if not os.path.isdir(out_dir): os.makedirs(out_dir)\n", - " \n", - " while os.path.exists(out_path):\n", - " iter += 1\n", - " out_path = os.path.join(out_dir, f\"{iter}_{title}.jpg\")\n", - " plt.savefig(out_path)\n", - " plt.close()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-12-30 00:19:41.214859: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-12-30 00:19:41.215362: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-12-30 00:19:41.216466: E 
external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "INFO:root:Loading configuration loader.notebook_api_data.json\n" - ] - } - ], - "source": [ - "from icenet.data.loaders import IceNetDataLoaderFactory\n", - "\n", - "implementation = \"dask\"\n", - "loader_config = \"loader.notebook_api_data.json\"\n", - "dataset_name = \"pytorch_notebook\"\n", - "lag = 1\n", - "\n", - "dl = IceNetDataLoaderFactory().create_data_loader(\n", - " implementation,\n", - " loader_config,\n", - " dataset_name,\n", - " lag,\n", - " n_forecast_days=7,\n", - " north=False,\n", - " south=True,\n", - " output_batch_size=1,\n", - " generate_workers=4)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We generate a config only dataset, which will get saved in `dataset_config.pytorch_notebook.json`." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:root:Writing dataset configuration without data generation\n", - "INFO:root:91 train dates in total, NOT generating cache data.\n", - "INFO:root:21 val dates in total, NOT generating cache data.\n", - "INFO:root:2 test dates in total, NOT generating cache data.\n", - "INFO:root:Writing configuration to ./dataset_config.pytorch_notebook.json\n" - ] - } - ], - "source": [ - "dl.write_dataset_config_only()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can now create the IceNetDataSet object:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "dataset_config = \"dataset_config.pytorch_notebook.json\"" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:root:Loading 
configuration dataset_config.pytorch_notebook.json\n", - "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n" - ] - } - ], - "source": [ - "from icenet.data.dataset import IceNetDataSet\n", - "\n", - "dataset = IceNetDataSet(dataset_config, batch_size=1)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Custom PyTorch Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "from torch.utils.data import Dataset, DataLoader\n", - "\n", - "class IceNetDataSetPyTorch(Dataset):\n", - " def __init__(self,\n", - " configuration_path: str,\n", - " mode: str,\n", - " batch_size: int = 1,\n", - " shuffling: bool = False,\n", - " prediction: bool = False,\n", - " start_dates: object = None,\n", - " ):\n", - " self._ds = IceNetDataSet(configuration_path=configuration_path,\n", - " batch_size=batch_size,\n", - " shuffling=shuffling)\n", - " self._dl = self._ds.get_data_loader()\n", - "\n", - " # check mode option\n", - " if mode not in [\"train\", \"val\", \"test\", \"pred\"]:\n", - " raise ValueError(\"mode must be either 'train', 'val', 'test' or 'pred'\")\n", - " self._mode = mode\n", - "\n", - "\n", - " if mode.casefold() == \"pred\":\n", - " self._prediction = True\n", - " self._dates = start_dates\n", - " else:\n", - " self._prediction = False\n", - " self._dates = self._dl._config[\"sources\"][\"osisaf\"][\"dates\"][self._mode]\n", - "\n", - " def __len__(self):\n", - " if not self._prediction:\n", - " return self._ds._counts[self._mode]\n", - " else:\n", - " return len(self._dates)\n", - " \n", - " def __getitem__(self, idx):\n", - " # return tuple( map(lambda x: torch.from_numpy(x).float().contiguous(), self._dl.generate_sample(date=pd.Timestamp(self._dates[idx].replace('_', '-'))) ) )\n", - " return self._dl.generate_sample(date=pd.Timestamp(self._dates[idx].replace('_', '-')),\n", - " prediction=self._prediction,\n", - " 
parallel=False,\n", - " )\n", - "\n", - " def get_data_loader(self):\n", - " return self._ds.get_data_loader()\n", - " \n", - " @property\n", - " def dates(self):\n", - " return self._dates" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", - "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", - "INFO:root:Loading configuration /data/hpcdata/users/bryald/git/icenet/notebook-pipeline/loader.notebook_api_data.json\n" - ] - }, - { - "data": { - "text/plain": [ - "91" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ds_torch = IceNetDataSetPyTorch(configuration_path=dataset_config,\n", - " mode=\"train\")\n", - "len(ds_torch)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", - "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", - "INFO:root:Loading configuration /data/hpcdata/users/bryald/git/icenet/notebook-pipeline/loader.notebook_api_data.json\n", - "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", - "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", - "INFO:root:Loading configuration /data/hpcdata/users/bryald/git/icenet/notebook-pipeline/loader.notebook_api_data.json\n", - "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", - "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", - "INFO:root:Loading configuration /data/hpcdata/users/bryald/git/icenet/notebook-pipeline/loader.notebook_api_data.json\n" - ] - } - 
], - "source": [ - "train_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode=\"train\")\n", - "val_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode=\"val\")\n", - "test_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode=\"test\")" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [], - "source": [ - "batch_size = 4\n", - "shuffle = False\n", - "persistent_workers=True\n", - "num_workers = 4" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "# train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, persistent_workers=persistent_workers, num_workers=num_workers)\n", - "# val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, persistent_workers=persistent_workers, num_workers=num_workers)\n", - "# test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, persistent_workers=persistent_workers, num_workers=num_workers)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "# train_data = iter(train_dataloader)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "# # Checking inputs X, target, sample_weights\n", - "# for i, data in enumerate(train_dataloader):\n", - "# # X, y, sample_weights\n", - "# print(type(data), len(data))\n", - "# # print(data[0].shape, data[1].shape, data[2].shape)\n", - "# # print(data[0])\n", - "# fmin = torch.min(data[1])\n", - "# fmax = torch.max(data[1])\n", - "# print( f\"Target SIC min: {fmin:.4f}, max: {fmax:.4f}\" )\n", - "# target = torch.round(data[1], decimals=2)\n", - "# # target = data[1]\n", - "# unique_vals, counts = torch.unique(target, return_counts=True)\n", - "# for value, count in zip(unique_vals, counts):\n", - "# print(f\"{value}: {count} times\")\n", - "# if i == 
2:\n", - "# break" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "# for i, data in enumerate(val_dataloader):\n", - "# # X, y, sample_weights\n", - "# print(type(data), len(data))\n", - "# # torch.Size([4, 432, 432, 9]) torch.Size([4, 432, 432, 7, 1]) torch.Size([4, 432, 432, 7, 1])\n", - "# print(data[0].shape, data[1].shape, data[2].shape)\n", - "# # print(data[0])\n", - "# if i == 2:\n", - "# break" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "# for i, data in enumerate(test_dataloader):\n", - "# # X, y, sample_weights\n", - "# print(type(data), len(data))\n", - "# print(data[0].shape, data[1].shape, data[2].shape)\n", - "# # print(data[0])\n", - "# if i == 2:\n", - "# break" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "# val0 = next(train_data)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "# val1 = next(train_data)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "# try:\n", - "# print( iter(train_dataloader) )\n", - "# print(\"train_dataloader is iterable\")\n", - "# except TypeError:\n", - "# print(\"train_dataloader is not iterable\")" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "# for batch_idx, (data, _) in enumerate(train_loader):\n", - "# # print(type(data))\n", - "# # print(data)\n", - "# # print(batch_idx)\n", - "# break" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## IceNet UNet model\n", - "\n", - "As a first attempt to implement a PyTorch example, we adapt code from https://github.com/ampersandmcd/icenet-gan/.\n", - "\n", - "Below is a PyTorch implementation of the UNet architecture." 
- ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [], - "source": [ - "# import torch\n", - "# from torch import nn\n", - "# import torch.nn.functional as F\n", - "\n", - "# class UNet(nn.Module):\n", - "# \"\"\"\n", - "# An implementation of a UNet for pixelwise classification.\n", - "# \"\"\"\n", - " \n", - "# def __init__(self,\n", - "# input_channels, \n", - "# filter_size=3, \n", - "# n_filters_factor=1, \n", - "# n_forecast_days=7, \n", - "# n_output_classes=1,\n", - "# **kwargs):\n", - "# super(UNet, self).__init__()\n", - "\n", - "# self.input_channels = input_channels\n", - "# self.filter_size = filter_size\n", - "# self.n_filters_factor = n_filters_factor\n", - "# self.n_forecast_days = n_forecast_days\n", - "# self.n_output_classes = n_output_classes\n", - "\n", - "# self.conv1a = nn.Conv2d(in_channels=input_channels, \n", - "# out_channels=int(128*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv1b = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", - "# out_channels=int(128*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn1 = nn.BatchNorm2d(num_features=int(128*n_filters_factor))\n", - "\n", - "# self.conv2a = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv2b = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn2 = nn.BatchNorm2d(num_features=int(256*n_filters_factor))\n", - "\n", - "# self.conv3a = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", - "# out_channels=int(512*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv3b = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", - "# 
out_channels=int(512*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn3 = nn.BatchNorm2d(num_features=int(512*n_filters_factor))\n", - "\n", - "# self.conv4a = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", - "# out_channels=int(512*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv4b = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", - "# out_channels=int(512*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn4 = nn.BatchNorm2d(num_features=int(512*n_filters_factor))\n", - "\n", - "# self.conv5a = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", - "# out_channels=int(1024*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv5b = nn.Conv2d(in_channels=int(1024*n_filters_factor),\n", - "# out_channels=int(1024*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn5 = nn.BatchNorm2d(num_features=int(1024*n_filters_factor))\n", - "\n", - "# self.conv6a = nn.Conv2d(in_channels=int(1024*n_filters_factor),\n", - "# out_channels=int(512*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv6b = nn.Conv2d(in_channels=int(1024*n_filters_factor),\n", - "# out_channels=int(512*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv6c = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", - "# out_channels=int(512*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn6 = nn.BatchNorm2d(num_features=int(512*n_filters_factor))\n", - "\n", - "# self.conv7a = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", - "# out_channels=int(512*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv7b = nn.Conv2d(in_channels=int(1024*n_filters_factor),\n", - "# 
out_channels=int(512*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv7c = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", - "# out_channels=int(512*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn7 = nn.BatchNorm2d(num_features=int(512*n_filters_factor))\n", - "\n", - "# self.conv8a = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv8b = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv8c = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn8 = nn.BatchNorm2d(num_features=int(256*n_filters_factor))\n", - "\n", - "# self.conv9a = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", - "# out_channels=int(128*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv9b = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", - "# out_channels=int(128*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv9c = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", - "# out_channels=int(128*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\") # no batch norm on last layer\n", - "\n", - "# self.final_conv = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", - "# out_channels=n_output_classes*n_forecast_days,\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - " \n", - "# def forward(self, x):\n", - "\n", - "# # transpose from shape (b, h, w, c) to (b, c, h, w) for pytorch conv2d layers\n", - "# x = torch.movedim(x, -1, 1) # move c from last to second dim\n", - "\n", - "# # run 
through network\n", - "# conv1 = self.conv1a(x) # input to 128\n", - "# conv1 = F.relu(conv1)\n", - "# conv1 = self.conv1b(conv1) # 128 to 128\n", - "# conv1 = F.relu(conv1)\n", - "# bn1 = self.bn1(conv1)\n", - "# pool1 = F.max_pool2d(bn1, kernel_size=(2, 2))\n", - "\n", - "# conv2 = self.conv2a(pool1) # 128 to 256\n", - "# conv2 = F.relu(conv2)\n", - "# conv2 = self.conv2b(conv2) # 256 to 256\n", - "# conv2 = F.relu(conv2)\n", - "# bn2 = self.bn2(conv2)\n", - "# pool2 = F.max_pool2d(bn2, kernel_size=(2, 2))\n", - "\n", - "# conv3 = self.conv3a(pool2) # 256 to 512\n", - "# conv3 = F.relu(conv3)\n", - "# conv3 = self.conv3b(conv3) # 512 to 512\n", - "# conv3 = F.relu(conv3)\n", - "# bn3 = self.bn3(conv3)\n", - "# pool3 = F.max_pool2d(bn3, kernel_size=(2, 2))\n", - "\n", - "# conv4 = self.conv4a(pool3) # 512 to 512\n", - "# conv4 = F.relu(conv4)\n", - "# conv4 = self.conv4b(conv4) # 512 to 512\n", - "# conv4 = F.relu(conv4)\n", - "# bn4 = self.bn4(conv4)\n", - "# pool4 = F.max_pool2d(bn4, kernel_size=(2, 2))\n", - "\n", - "# conv5 = self.conv5a(pool4) # 512 to 1024\n", - "# conv5 = F.relu(conv5)\n", - "# conv5 = self.conv5b(conv5) # 1024 to 1024\n", - "# conv5 = F.relu(conv5)\n", - "# bn5 = self.bn5(conv5)\n", - "\n", - "# up6 = F.interpolate(bn5, scale_factor=2, mode=\"nearest\")\n", - "# up6 = self.conv6a(up6) # 1024 to 512\n", - "# up6 = F.relu(up6)\n", - "# merge6 = torch.cat([bn4, up6], dim=1) # 512 and 512 to 1024 along c dimension\n", - "# conv6 = self.conv6b(merge6) # 1024 to 512\n", - "# conv6 = F.relu(conv6)\n", - "# conv6 = self.conv6c(conv6) # 512 to 512\n", - "# conv6 = F.relu(conv6)\n", - "# bn6 = self.bn6(conv6)\n", - "\n", - "# up7 = F.interpolate(bn6, scale_factor=2, mode=\"nearest\")\n", - "# up7 = self.conv7a(up7) # 1024 to 512\n", - "# up7 = F.relu(up7)\n", - "# merge7 = torch.cat([bn3, up7], dim=1) # 512 and 512 to 1024 along c dimension\n", - "# conv7 = self.conv7b(merge7) # 1024 to 512\n", - "# conv7 = F.relu(conv7)\n", - "# conv7 = 
self.conv7c(conv7) # 512 to 512\n", - "# conv7 = F.relu(conv7)\n", - "# bn7 = self.bn7(conv7)\n", - "\n", - "# up8 = F.interpolate(bn7, scale_factor=2, mode=\"nearest\")\n", - "# up8 = self.conv8a(up8) # 512 to 256\n", - "# up8 = F.relu(up8)\n", - "# merge8 = torch.cat([bn2, up8], dim=1) # 256 and 256 to 512 along c dimension\n", - "# conv8 = self.conv8b(merge8) # 512 to 256\n", - "# conv8 = F.relu(conv8)\n", - "# conv8 = self.conv8c(conv8) # 256 to 256\n", - "# conv8 = F.relu(conv8)\n", - "# bn8 = self.bn8(conv8)\n", - "\n", - "# up9 = F.interpolate(bn8, scale_factor=2, mode=\"nearest\")\n", - "# up9 = self.conv9a(up9) # 256 to 128\n", - "# up9 = F.relu(up9)\n", - "# merge9 = torch.cat([bn1, up9], dim=1) # 128 and 128 to 256 along c dimension\n", - "# conv9 = self.conv9b(merge9) # 256 to 128\n", - "# conv9 = F.relu(conv9)\n", - "# conv9 = self.conv9c(conv9) # 128 to 128\n", - "# conv9 = F.relu(conv9) # no batch norm on last layer\n", - " \n", - "# final_layer_logits = self.final_conv(conv9)\n", - "\n", - "# # transpose from shape (b, c, h, w) back to (b, h, w, c) to align with training data\n", - "# final_layer_logits = torch.movedim(final_layer_logits, 1, -1) # move c from second to final dim\n", - "# b, h, w, c = final_layer_logits.shape\n", - "\n", - "\n", - "# # unpack c=classes*days dimension into classes, days as separate dimensions\n", - "# final_layer_logits = final_layer_logits.reshape((b, h, w, self.n_output_classes, self.n_forecast_days))\n", - "\n", - "# # output = F.softmax(final_layer_logits, dim=-2) # apply over n_output_classes dimension\n", - "# output = F.sigmoid(final_layer_logits) # Single output class.\n", - "# # output = final_layer_logits\n", - "\n", - "# # print(f\"Final layer shape: {output.shape}\")\n", - "# # print(f\"self.n_output_classes: {self.n_output_classes}\")\n", - "\n", - "# return output # shape (b, h, w, c, t)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "# 
import torch
import torch.nn as nn
import torch.nn.functional as F


class Interpolate(nn.Module):
    """Module wrapper around ``F.interpolate`` so upsampling can be used inside ``nn.Sequential``."""

    def __init__(self, scale_factor, mode):
        """
        :param scale_factor: Spatial scale multiplier forwarded to ``F.interpolate``.
        :param mode: Interpolation mode forwarded to ``F.interpolate`` (e.g. ``"nearest"``).
        """
        super().__init__()
        self.interp = F.interpolate
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        """Return ``x`` resampled by ``scale_factor`` using ``mode``."""
        return self.interp(x, scale_factor=self.scale_factor, mode=self.mode)


class UNet(nn.Module):
    """
    UNet with four encoder stages, a bottleneck and four decoder stages.

    The network takes channels-last input of shape (b, h, w, c) and returns raw
    (un-activated) outputs of shape (b, h, w, n_output_classes, n_forecast_days).

    The encoder halves the spatial size four times with 2x2 max-pooling and the
    decoder doubles it four times before concatenating with the matching encoder
    output, so ``h`` and ``w`` must be divisible by 16 for the skip connections
    to line up.
    """

    def __init__(self,
                 input_channels,
                 filter_size=3,
                 n_filters_factor=1,
                 n_forecast_days=7,
                 n_output_classes=1,
                 **kwargs
                 ):
        """
        :param input_channels: Number of channels in the input tensor.
        :param filter_size: Kernel size of the 'same'-padded convolutions in the
            encoder, bottleneck and decoder conv blocks.
        :param n_filters_factor: Multiplier applied to the base channel widths
            (64, 128, 256, 512); each width is truncated with ``int``.
        :param n_forecast_days: Size of the last output dimension.
        :param n_output_classes: Size of the second-to-last output dimension.
        :param kwargs: Accepted and ignored.
        """
        super(UNet, self).__init__()

        self.input_channels = input_channels
        self.filter_size = filter_size
        self.n_filters_factor = n_filters_factor
        self.n_forecast_days = n_forecast_days
        self.n_output_classes = n_output_classes

        def width(base):
            return int(base * n_filters_factor)

        c64, c128, c256, c512 = width(64), width(128), width(256), width(512)

        # Encoder
        self.conv1 = self.conv_block(input_channels, c64)
        self.conv2 = self.conv_block(c64, c128)
        self.conv3 = self.conv_block(c128, c256)
        self.conv4 = self.conv_block(c256, c256)

        # Bottleneck
        self.conv5 = self.bottleneck_block(c256, c512)

        # Decoder: upsample + 2x2 conv ...
        self.up6 = self.upconv_block(c512, c256)
        self.up7 = self.upconv_block(c256, c256)
        self.up8 = self.upconv_block(c256, c128)
        self.up9 = self.upconv_block(c128, c64)

        # ... followed by a conv block applied to [skip, upsampled] concatenated on channels
        self.up6b = self.conv_block(c512, c256)
        self.up7b = self.conv_block(c512, c256)
        self.up8b = self.conv_block(c256, c128)
        self.up9b = self.conv_block(c128, c64, final=True)

        # Final 1x1 convolution. forward() unpacks the channel axis into
        # (n_output_classes, n_forecast_days), so it must have their product as
        # its width; using n_forecast_days alone only works for a single class.
        self.final_layer = nn.Conv2d(
            c64, n_output_classes * n_forecast_days, kernel_size=1, padding="same"
        )

    def forward(self, x):
        """
        :param x: Input of shape (b, h, w, c).
        :return: Tensor of shape (b, h, w, n_output_classes, n_forecast_days);
            no activation is applied here.
        """
        # transpose from shape (b, h, w, c) to (b, c, h, w) for pytorch conv2d layers
        x = torch.movedim(x, -1, 1)

        # Encoder: keep each pre-pooling output for the skip connections
        skip1 = self.conv1(x)
        skip2 = self.conv2(F.max_pool2d(skip1, kernel_size=2))
        skip3 = self.conv3(F.max_pool2d(skip2, kernel_size=2))
        skip4 = self.conv4(F.max_pool2d(skip3, kernel_size=2))

        # Bottleneck
        bottom = self.conv5(F.max_pool2d(skip4, kernel_size=2))

        # Decoder: upsample, concatenate with the matching encoder output, convolve
        dec = self.up6b(torch.cat([skip4, self.up6(bottom)], dim=1))
        dec = self.up7b(torch.cat([skip3, self.up7(dec)], dim=1))
        dec = self.up8b(torch.cat([skip2, self.up8(dec)], dim=1))
        dec = self.up9b(torch.cat([skip1, self.up9(dec)], dim=1))

        output = self.final_layer(dec)

        # transpose from shape (b, c, h, w) back to (b, h, w, c) to align with training data
        output = torch.movedim(output, 1, -1)
        b, h, w, _ = output.shape

        # unpack c=classes*days dimension into classes, days as separate dimensions
        return output.reshape((b, h, w, self.n_output_classes, self.n_forecast_days))

    def conv_block(self, in_channels, out_channels, kernel_size=3, final=False):
        """
        Build a stack of 'same'-padded convolutions, each followed by ReLU.

        :param in_channels: Channels entering the block.
        :param out_channels: Channels produced by every convolution in the block.
        :param kernel_size: Currently ignored; the convolutions use ``self.filter_size``.
            Kept so existing call sites keep working.
        :param final: If False, two conv+ReLU pairs followed by batch norm.
            If True, three conv+ReLU pairs and no batch norm.
        """
        def conv_relu(c_in):
            return [
                nn.Conv2d(c_in, out_channels, kernel_size=self.filter_size, padding="same"),
                nn.ReLU(inplace=True),
            ]

        layers = conv_relu(in_channels) + conv_relu(out_channels)
        if final:
            layers += conv_relu(out_channels)
        else:
            layers.append(nn.BatchNorm2d(num_features=out_channels))
        return nn.Sequential(*layers)

    def bottleneck_block(self, in_channels, out_channels, kernel_size=3):
        """
        Build the bottleneck: two conv+ReLU pairs followed by batch norm.

        :param kernel_size: Currently ignored; the convolutions use ``self.filter_size``.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=self.filter_size, padding="same"),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=self.filter_size, padding="same"),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=out_channels),
        )

    def upconv_block(self, in_channels, out_channels, kernel_size=2):
        """
        Build an upsampling stage: 2x nearest-neighbour upsample, 2x2 'same' conv, ReLU.

        :param kernel_size: Currently ignored; the convolution kernel is fixed at 2.
        """
        return nn.Sequential(
            Interpolate(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels, out_channels, kernel_size=2, padding="same"),
            nn.ReLU(inplace=True)
        )
filter_size\n", - "# self.n_filters_factor = n_filters_factor\n", - "# self.n_forecast_days = n_forecast_days\n", - "# self.n_output_classes = n_output_classes\n", - "\n", - "# self.conv1a = nn.Conv2d(in_channels=input_channels, \n", - "# out_channels=int(64*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv1b = nn.Conv2d(in_channels=int(64*n_filters_factor),\n", - "# out_channels=int(64*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn1 = nn.BatchNorm2d(num_features=int(64*n_filters_factor))\n", - "\n", - "# self.conv2a = nn.Conv2d(in_channels=int(64*n_filters_factor),\n", - "# out_channels=int(128*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv2b = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", - "# out_channels=int(128*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn2 = nn.BatchNorm2d(num_features=int(128*n_filters_factor))\n", - "\n", - "# self.conv3a = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv3b = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn3 = nn.BatchNorm2d(num_features=int(256*n_filters_factor))\n", - "\n", - "# self.conv4a = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv4b = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn4 = nn.BatchNorm2d(num_features=int(256*n_filters_factor))\n", - "\n", - "# self.conv5a = 
nn.Conv2d(in_channels=int(256*n_filters_factor),\n", - "# out_channels=int(512*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv5b = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", - "# out_channels=int(512*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn5 = nn.BatchNorm2d(num_features=int(512*n_filters_factor))\n", - "\n", - "# self.conv6a = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=2,\n", - "# padding=\"same\")\n", - "# self.conv6b = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv6c = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn6 = nn.BatchNorm2d(num_features=int(256*n_filters_factor))\n", - "\n", - "# self.conv7a = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=2,\n", - "# padding=\"same\")\n", - "# self.conv7b = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv7c = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", - "# out_channels=int(256*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn7 = nn.BatchNorm2d(num_features=int(256*n_filters_factor))\n", - "\n", - "# self.conv8a = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", - "# out_channels=int(128*n_filters_factor),\n", - "# kernel_size=2,\n", - "# padding=\"same\")\n", - "# self.conv8b = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", - "# out_channels=int(128*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# 
padding=\"same\")\n", - "# self.conv8c = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", - "# out_channels=int(128*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.bn8 = nn.BatchNorm2d(num_features=int(128*n_filters_factor))\n", - "\n", - "# self.conv9a = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", - "# out_channels=int(64*n_filters_factor),\n", - "# kernel_size=2,\n", - "# padding=\"same\")\n", - "# self.conv9b = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", - "# out_channels=int(64*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv9c = nn.Conv2d(in_channels=int(64*n_filters_factor),\n", - "# out_channels=int(64*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\")\n", - "# self.conv9d = nn.Conv2d(in_channels=int(64*n_filters_factor),\n", - "# out_channels=int(64*n_filters_factor),\n", - "# kernel_size=filter_size,\n", - "# padding=\"same\") # no batch norm on last layer\n", - "\n", - "# self.final_conv = nn.Conv2d(in_channels=int(64*n_filters_factor),\n", - "# out_channels=n_output_classes*n_forecast_days,\n", - "# kernel_size=1,\n", - "# padding=\"same\")\n", - " \n", - "# def forward(self, x):\n", - "\n", - "# # transpose from shape (b, h, w, c) to (b, c, h, w) for pytorch conv2d layers\n", - "# x = torch.movedim(x, -1, 1) # move c from last to second dim\n", - "\n", - "# # run through network\n", - "# conv1 = self.conv1a(x) # input to 64\n", - "# conv1 = F.relu(conv1)\n", - "# conv1 = self.conv1b(conv1) # 64 to 64\n", - "# conv1 = F.relu(conv1)\n", - "# bn1 = self.bn1(conv1)\n", - "# pool1 = F.max_pool2d(bn1, kernel_size=(2, 2))\n", - "\n", - "# conv2 = self.conv2a(pool1) # 64 to 128\n", - "# conv2 = F.relu(conv2)\n", - "# conv2 = self.conv2b(conv2) # 128 to 128\n", - "# conv2 = F.relu(conv2)\n", - "# bn2 = self.bn2(conv2)\n", - "# pool2 = F.max_pool2d(bn2, kernel_size=(2, 2))\n", - "\n", - "# conv3 = self.conv3a(pool2) # 
128 to 256\n", - "# conv3 = F.relu(conv3)\n", - "# conv3 = self.conv3b(conv3) # 256 to 256\n", - "# conv3 = F.relu(conv3)\n", - "# bn3 = self.bn3(conv3)\n", - "# pool3 = F.max_pool2d(bn3, kernel_size=(2, 2))\n", - "\n", - "# conv4 = self.conv4a(pool3) # 256 to 256\n", - "# conv4 = F.relu(conv4)\n", - "# conv4 = self.conv4b(conv4) # 256 to 256\n", - "# conv4 = F.relu(conv4)\n", - "# bn4 = self.bn4(conv4)\n", - "# pool4 = F.max_pool2d(bn4, kernel_size=(2, 2))\n", - "\n", - "# conv5 = self.conv5a(pool4) # 256 to 512\n", - "# conv5 = F.relu(conv5)\n", - "# conv5 = self.conv5b(conv5) # 512 to 512\n", - "# conv5 = F.relu(conv5)\n", - "# bn5 = self.bn5(conv5)\n", - "\n", - "# up6 = F.interpolate(bn5, scale_factor=2, mode=\"nearest\")\n", - "# up6 = self.conv6a(up6) # 512 to 256\n", - "# up6 = F.relu(up6)\n", - "# merge6 = torch.cat([bn4, up6], dim=1) # 256 and 526 to 512 along c dimension\n", - "# conv6 = self.conv6b(merge6) # 512 to 256\n", - "# conv6 = F.relu(conv6)\n", - "# conv6 = self.conv6c(conv6) # 256 to 256\n", - "# conv6 = F.relu(conv6)\n", - "# bn6 = self.bn6(conv6)\n", - "\n", - "# up7 = F.interpolate(bn6, scale_factor=2, mode=\"nearest\")\n", - "# up7 = self.conv7a(up7) # 512 to 256\n", - "# up7 = F.relu(up7)\n", - "# merge7 = torch.cat([bn3, up7], dim=1) # 256 and 256 to 512 along c dimension\n", - "# conv7 = self.conv7b(merge7) # 512 to 256\n", - "# conv7 = F.relu(conv7)\n", - "# conv7 = self.conv7c(conv7) # 256 to 256\n", - "# conv7 = F.relu(conv7)\n", - "# bn7 = self.bn7(conv7)\n", - "\n", - "# up8 = F.interpolate(bn7, scale_factor=2, mode=\"nearest\")\n", - "# up8 = self.conv8a(up8) # 256 to 128\n", - "# up8 = F.relu(up8)\n", - "# merge8 = torch.cat([bn2, up8], dim=1) # 128 and 128 to 256 along c dimension\n", - "# conv8 = self.conv8b(merge8) # 256 to 128\n", - "# conv8 = F.relu(conv8)\n", - "# conv8 = self.conv8c(conv8) # 128 to 128\n", - "# conv8 = F.relu(conv8)\n", - "# bn8 = self.bn8(conv8)\n", - "\n", - "# up9 = F.interpolate(bn8, scale_factor=2, 
mode=\"nearest\")\n", - "# up9 = self.conv9a(up9) # 128 to 64\n", - "# up9 = F.relu(up9)\n", - "# merge9 = torch.cat([bn1, up9], dim=1) # 64 and 64 to 128 along c dimension\n", - "# conv9 = self.conv9b(merge9) # 128 to 64\n", - "# conv9 = F.relu(conv9)\n", - "# conv9 = self.conv9c(conv9) # 64 to 64\n", - "# conv9 = F.relu(conv9)\n", - "# conv9 = self.conv9d(conv9) # 64 to 64\n", - "# conv9 = F.relu(conv9) # no batch norm on last layer\n", - " \n", - "# final_layer_logits = self.final_conv(conv9)\n", - "\n", - "# # transpose from shape (b, c, h, w) back to (b, h, w, c) to align with training data\n", - "# final_layer_logits = torch.movedim(final_layer_logits, 1, -1) # move c from second to final dim\n", - "# b, h, w, c = final_layer_logits.shape\n", - "\n", - "# # unpack c=classes*months dimension into classes, months as separate dimensions\n", - "# final_layer_logits = final_layer_logits.reshape((b, h, w, self.n_output_classes, self.n_forecast_days))\n", - "\n", - "# return final_layer_logits\n", - "\n", - "# # # output = F.softmax(final_layer_logits, dim=-2) # apply over n_output_classes dimension\n", - "# # output = F.sigmoid(final_layer_logits)\n", - "\n", - "# # return output # shape (b, h, w, c, t)" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [], - "source": [ - "# import torch\n", - "# import torch.nn as nn\n", - "# import torch.nn.functional as F\n", - "\n", - "# # class UNetBatchNorm(nn.Module):\n", - "# class UNet(nn.Module):\n", - "# def __init__(self, input_channels, filter_size=3, n_filters_factor=1, n_forecast_days=6):\n", - "# super(UNet, self).__init__()\n", - "# self.input_channels = input_channels\n", - "# self.filter_size = filter_size\n", - "# self.n_filters_factor = n_filters_factor\n", - "# self.n_output_classes = 1\n", - "# self.n_forecast_days = n_forecast_days\n", - "\n", - "# # Encoder\n", - "# self.conv1 = self.conv_block(input_channels, int(64 * n_filters_factor))\n", - "# self.conv2 = 
self.conv_block(int(64 * n_filters_factor), int(128 * n_filters_factor))\n", - "# self.conv3 = self.conv_block(int(128 * n_filters_factor), int(256 * n_filters_factor))\n", - "# self.conv4 = self.conv_block(int(256 * n_filters_factor), int(256 * n_filters_factor))\n", - "# self.conv5 = self.conv_block(int(256 * n_filters_factor), int(512 * n_filters_factor))\n", - "\n", - "# # Decoder\n", - "# self.up6 = self.upconv_block(int(512 * n_filters_factor), int(256 * n_filters_factor))\n", - "# self.up7 = self.upconv_block(int(256 * n_filters_factor), int(256 * n_filters_factor))\n", - "# self.up8 = self.upconv_block(int(256 * n_filters_factor), int(128 * n_filters_factor))\n", - "# self.up9 = self.upconv_block(int(128 * n_filters_factor), int(64 * n_filters_factor))\n", - "\n", - "# # Final layer\n", - "# self.final_layer = nn.Conv2d(int(64 * n_filters_factor), n_forecast_days, kernel_size=1)\n", - "\n", - "# def conv_block(self, in_channels, out_channels, kernel_size=3):\n", - "# return nn.Sequential(\n", - "# nn.Conv2d(in_channels, out_channels, kernel_size, padding=\"same\"),\n", - "# nn.ReLU(inplace=True),\n", - "# nn.Conv2d(out_channels, out_channels, kernel_size, padding=\"same\"),\n", - "# nn.ReLU(inplace=True),\n", - "# nn.BatchNorm2d(out_channels)\n", - "# )\n", - "\n", - "# def upconv_block(self, in_channels, out_channels, kernel_size=2):\n", - "# return nn.Sequential(\n", - "# nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=2),\n", - "# nn.ReLU(inplace=True)\n", - "# )\n", - "\n", - "# def forward(self, x):\n", - "# n_filters_factor = self.n_filters_factor\n", - "\n", - "# # transpose from shape (b, h, w, c) to (b, c, h, w) for pytorch conv2d layers\n", - "# x = torch.movedim(x, -1, 1) # move c from last to second dim\n", - "\n", - "# # Encoder\n", - "# conv1 = self.conv1(x)\n", - "# pool1 = F.max_pool2d(conv1, 2)\n", - "\n", - "# conv2 = self.conv2(pool1)\n", - "# pool2 = F.max_pool2d(conv2, 2)\n", - "\n", - "# conv3 = self.conv3(pool2)\n", 
from torchmetrics import Metric


class IceNetAccuracy(Metric):
    """Weighted binary accuracy (in percent) for use at multiple leadtimes.

    Predictions and targets are binarised with ``> threshold`` and compared
    element-wise; each comparison is weighted by ``sample_weight``.

    Reference: https://lightning.ai/docs/torchmetrics/stable/pages/implement.html
    """

    # Set class properties
    is_differentiable: bool = False
    higher_is_better: bool = True
    # update() only adds to the running sums and never reads them, so the
    # full accumulated state is not needed when processing a batch.
    full_state_update: bool = False

    def __init__(self, leadtimes_to_evaluate: list, threshold: float = 0.15):
        """Custom metric for binary accuracy in classifying SIC>15% for multiple leadtimes.

        Args:
            leadtimes_to_evaluate: A list of leadtimes to consider
                e.g., [0, 1, 2, 3, 4, 5] to consider first six days in accuracy computation or
                e.g., [0] to only look at the first day's accuracy
                e.g., [5] to only look at the sixth day's accuracy
            threshold: Values strictly greater than this count as the positive
                class, for both predictions and targets. Defaults to 0.15.
        """
        super().__init__()
        self.leadtimes_to_evaluate = leadtimes_to_evaluate
        self.threshold = threshold
        self.add_state("weighted_score", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("possible_score", default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor, sample_weight: torch.Tensor):
        """Accumulate weighted correct counts and total weight for the selected leadtimes.

        All three tensors are indexed as ``[:, :, :, leadtimes]``, i.e. they are
        expected to be 4-dimensional with leadtime on the last axis.
        """
        lead = self.leadtimes_to_evaluate
        pred_positive = (preds > self.threshold).long()
        target_positive = (target > self.threshold).long()
        weights = sample_weight[:, :, :, lead]

        correct = pred_positive[:, :, :, lead] == target_positive[:, :, :, lead]
        self.weighted_score += torch.sum(correct * weights)
        self.possible_score += torch.sum(weights)

    def compute(self):
        """Return weighted accuracy as a percentage.

        If no weight has been accumulated this is 0/0 and evaluates to NaN.
        """
        return self.weighted_score.float() / self.possible_score * 100.0


class SIEError(Metric):
    """
    Sea Ice Extent error metric (in km^2) for use at multiple leadtimes.

    Counts grid cells above ``threshold`` in predictions and targets and reports
    the difference in count multiplied by the area of one grid cell.
    """

    # Set class properties
    is_differentiable: bool = False
    higher_is_better: bool = False
    # update() only adds to the running sums and never reads them.
    full_state_update: bool = False

    def __init__(self, leadtimes_to_evaluate: list, threshold: float = 0.15, cell_size_km: float = 25):
        """Construct an SIE error metric (in km^2) for use at multiple leadtimes.

        Args:
            leadtimes_to_evaluate: A list of leadtimes to consider
                e.g., [0, 1, 2, 3, 4, 5] to consider six days in computation or
                e.g., [0] to only look at the first day
                e.g., [5] to only look at the sixth day
            threshold: Values strictly greater than this count as ice. Defaults to 0.15.
            cell_size_km: Side length of one (square) grid cell in km. Defaults to 25.
        """
        super().__init__()
        self.leadtimes_to_evaluate = leadtimes_to_evaluate
        self.threshold = threshold
        self.cell_size_km = cell_size_km
        self.add_state("pred_sie", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("true_sie", default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor, sample_weight: torch.Tensor):
        """Accumulate the number of above-threshold cells for the selected leadtimes.

        ``sample_weight`` is accepted so the signature matches ``IceNetAccuracy.update``
        but is not used: every cell counts equally.
        """
        lead = self.leadtimes_to_evaluate
        pred_ice = (preds > self.threshold).long()
        target_ice = (target > self.threshold).long()
        self.pred_sie += pred_ice[:, :, :, lead].sum()
        self.true_sie += target_ice[:, :, :, lead].sum()

    def compute(self):
        """Return (predicted - true) ice-cell count times the area of one grid cell."""
        return (self.pred_sie - self.true_sie) * self.cell_size_km**2
def loss_func(output, target, sample_weight):
    """Weighted L2 loss on percentages.

    ``output`` is passed through a sigmoid, the difference to ``target`` is
    scaled by 100, squared, multiplied by ``sample_weight`` and averaged over
    all elements.

    ``output`` has its second-to-last axis moved to position 1 while ``target``
    and ``sample_weight`` have their last axis moved to position 1, so the
    three must be broadcast-compatible after those moves (in this notebook:
    output (b, h, w, classes, days); target / weight (b, h, w, days, classes)).
    """
    y_hat = torch.sigmoid(output)

    error_pct = (y_hat.movedim(-2, 1) - target.movedim(-1, 1)) * 100
    weighted_sq_error = (error_pct ** 2) * sample_weight.movedim(-1, 1)

    return torch.mean(weighted_sq_error)


import lightning.pytorch as pl
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
from torchmetrics import MetricCollection


class LitUNet(pl.LightningModule):
    """
    A LightningModule wrapping the UNet implementation of IceNet.
    """
    def __init__(self,
                 model: nn.Module,
                 criterion: callable,
                 learning_rate: float):
        """
        Construct a UNet LightningModule.
        Note that we keep hyperparameters separate from dataloaders to prevent data leakage at test time.
        :param model: PyTorch model; must expose ``n_output_classes`` and ``n_forecast_days``.
        :param criterion: Stored on the module, but the training/validation/test
            steps currently call the module-level ``loss_func`` instead.
        :param learning_rate: Float learning rate for our optimiser
        """
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.learning_rate = learning_rate
        self.n_output_classes = model.n_output_classes  # this should be a property of the network

        n_days = self.model.n_forecast_days
        self.metrics = self._accuracy_metrics("val", n_days)
        self.test_metrics = self._accuracy_metrics("test", n_days)

        # Save input parameters to __init__ (hyperparams) when checkpointing.
        # Must be called from __init__ itself so the constructor arguments are captured.
        self.save_hyperparameters()

    @staticmethod
    def _accuracy_metrics(prefix: str, n_forecast_days: int) -> MetricCollection:
        """Build ``{prefix}_accuracy`` over all leadtimes plus one ``{prefix}_accuracy_{i}`` per leadtime."""
        metrics = {
            f"{prefix}_accuracy": IceNetAccuracy(leadtimes_to_evaluate=list(range(n_forecast_days))),
        }
        for day in range(n_forecast_days):
            metrics[f"{prefix}_accuracy_{day}"] = IceNetAccuracy(leadtimes_to_evaluate=[day])
        return MetricCollection(metrics)

    def _forward_and_loss(self, batch):
        """Run the model on a batch and return ``(outputs, y, sample_weight, loss)``."""
        x, y, sample_weight = batch
        outputs = self.model(x)
        loss = loss_func(outputs, y, sample_weight)
        return outputs, y, sample_weight, loss

    def forward(self, x):
        """
        Implement forward function.
        :param x: Inputs to model.
        :return: Outputs of model.
        """
        return self.model(x)

    def training_step(self, batch):
        """
        Perform a pass through a batch of training data.
        Apply pixel-weighted loss by manually reducing.
        See e.g. https://discuss.pytorch.org/t/unet-pixel-wise-weighted-loss-function/46689/5.
        :param batch: Batch of input, output, weight triplets
        :return: Dict with the loss from this batch of data for use in backprop
        """
        _, _, _, loss = self._forward_and_loss(batch)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return {"loss": loss}

    def validation_step(self, batch):
        """
        Compute the validation loss for a batch and update the validation metrics.
        :param batch: Batch of input, output, weight triplets
        :return: Dict holding the validation loss for this batch
        """
        # x: (b, h, w, channels), y: (b, h, w, n_forecast_days, classes), sample_weight: (b, h, w, n_forecast_days, classes)
        outputs, y, sample_weight, loss = self._forward_and_loss(batch)
        # y_hat: (b, h, w, classes, n_forecast_days)
        y_hat = torch.sigmoid(outputs)

        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)  # epoch-level loss

        # Drop the class axis so all three are (b, h, w, t) as the metrics index them.
        self.metrics.update(y_hat.squeeze(dim=-2), y.squeeze(dim=-1), sample_weight.squeeze(dim=-1))

        # A dict, not the set literal {"val_loss", loss}.
        return {"val_loss": loss}

    def on_validation_epoch_end(self):
        """Log the accumulated validation metrics and reset them for the next epoch."""
        self.log_dict(self.metrics.compute(), on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)  # epoch-level metrics
        self.metrics.reset()

    def test_step(self, batch):
        """
        Compute the test loss for a batch and update the test metrics.
        :param batch: Batch of input, output, weight triplets
        :return: Loss for this batch
        """
        outputs, y, sample_weight, loss = self._forward_and_loss(batch)
        y_hat = torch.sigmoid(outputs)

        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)  # epoch-level loss

        self.test_metrics.update(y_hat.squeeze(dim=-2), y.squeeze(dim=-1), sample_weight.squeeze(dim=-1))

        return loss

    def on_test_epoch_end(self):
        """Log the accumulated test metrics and reset them."""
        self.log_dict(self.test_metrics.compute(), on_step=False, on_epoch=True, sync_dist=True)  # epoch-level metrics
        self.test_metrics.reset()

    def predict_step(self, batch):
        """
        :param batch: Batch of input, output, weight triplets; only the input is used
        :return: Sigmoid of the model output for the given input.
        """
        x, _, _ = batch
        return torch.sigmoid(self.model(x))

    def configure_optimizers(self):
        """Use Adam over all module parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return {
            "optimizer": optimizer
        }
from lightning.pytorch.callbacks import ModelCheckpoint


def train_icenet(configuration_path,
                 learning_rate,
                 max_epochs,
                 batch_size,
                 n_workers,
                 filter_size,
                 n_filters_factor,
                 seed):
    """
    Train the UNet on an IceNet dataset using PyTorch Lightning.

    :param configuration_path: Dataset configuration handed to ``IceNetDataSetPyTorch``
    :param learning_rate: Learning rate forwarded to ``LitUNet``
    :param max_epochs: Maximum number of epochs given to the Lightning ``Trainer``
    :param batch_size: Batch size of the train and validation dataloaders
    :param n_workers: ``num_workers`` of the train and validation dataloaders
    :param filter_size: ``filter_size`` forwarded to ``UNet``
    :param n_filters_factor: ``n_filters_factor`` forwarded to ``UNet``
    :param seed: Seed passed to ``pl.seed_everything``
    :return: Tuple ``(model, trainer, checkpoint_callback)``, where ``model`` is the
        bare ``UNet`` (not the wrapping ``LitUNet``)
    """
    # init
    pl.seed_everything(seed)

    # configure datasets and dataloaders
    train_dataset = IceNetDataSetPyTorch(configuration_path, mode="train")
    val_dataset = IceNetDataSetPyTorch(configuration_path, mode="val")
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=n_workers,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=n_workers > 0,
        # Neither split is shuffled (unchanged from the original cell).
        shuffle=False,
    )
    train_dataloader = DataLoader(train_dataset, **loader_kwargs)
    val_dataloader = DataLoader(val_dataset, **loader_kwargs)

    # construct unet
    model = UNet(
        input_channels=train_dataset._ds._config["num_channels"],
        filter_size=filter_size,
        n_filters_factor=n_filters_factor,
        n_forecast_days=train_dataset._ds._config["n_forecast_days"]
    )

    criterion = nn.MSELoss(reduction="none")

    # configure PyTorch Lightning module
    lit_module = LitUNet(
        model=model,
        criterion=criterion,
        learning_rate=learning_rate
    )

    # NOTE(review): assumes the module logs a metric named "val_accuracy" — confirm
    # against the MetricCollection configured in LitUNet.
    checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")

    # set up trainer configuration
    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        log_every_n_steps=5,
        max_epochs=max_epochs,
        num_sanity_val_steps=1,
        fast_dev_run=False,  # Runs single batch through train and validation
                             #  when running trainer.test()
                             # Note: Cannot use with automatic best checkpointing
        # Register the callback at construction: appending to trainer.callbacks
        # afterwards leaves the Trainer's default ModelCheckpoint in place as well.
        callbacks=[checkpoint_callback],
    )

    # train model
    print(f"Training {len(train_dataset)} examples / {len(train_dataloader)} batches (batch size {batch_size}).")
    print(f"Validating {len(val_dataset)} examples / {len(val_dataloader)} batches (batch size {batch_size}).")
    trainer.fit(lit_module, train_dataloader, val_dataloader)

    return model, trainer, checkpoint_callback
"/data/hpcdata/users/bryald/miniconda3/envs/pytorch/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.\n", - "/data/hpcdata/users/bryald/miniconda3/envs/pytorch/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'criterion' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['criterion'])`.\n", - "/data/hpcdata/users/bryald/miniconda3/envs/pytorch/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /data/hpcdata/users/bryald/miniconda3/envs/pytorch/l ...\n", - "INFO: GPU available: True (cuda), used: True\n", - "INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True\n", - "INFO: TPU available: False, using: 0 TPU cores\n", - "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", - "INFO: IPU available: False, using: 0 IPUs\n", - "INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs\n", - "INFO: HPU available: False, using: 0 HPUs\n", - "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n", - "INFO: You are using a CUDA device ('NVIDIA A2') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. 
For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", - "INFO:lightning.pytorch.utilities.rank_zero:You are using a CUDA device ('NVIDIA A2') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", - "INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Training 91 examples / 23 batches (batch size 4).\n", - "Validating 21 examples / 6 batches (batch size 4).\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", - "INFO: \n", - " | Name | Type | Params\n", - "--------------------------------------------------\n", - "0 | model | UNet | 11.0 M\n", - "1 | criterion | MSELoss | 0 \n", - "2 | metrics | MetricCollection | 0 \n", - "3 | test_metrics | MetricCollection | 0 \n", - "--------------------------------------------------\n", - "11.0 M Trainable params\n", - "0 Non-trainable params\n", - "11.0 M Total params\n", - "43.817 Total estimated model params size (MB)\n", - "INFO:lightning.pytorch.callbacks.model_summary:\n", - " | Name | Type | Params\n", - "--------------------------------------------------\n", - "0 | model | UNet | 11.0 M\n", - "1 | criterion | MSELoss | 0 \n", - "2 | metrics | MetricCollection | 0 \n", - "3 | test_metrics | MetricCollection | 0 \n", - "--------------------------------------------------\n", - "11.0 M Trainable params\n", - "0 Non-trainable params\n", - "11.0 M Total params\n", - "43.817 Total estimated model params size (MB)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Sanity 
# %%
import matplotlib.pyplot as plt

# `prediction` is not assigned in this cell; it must already exist from an earlier one.
fig = plt.figure(figsize=(18, 10))

plt.imshow(prediction[0, :, :, 0, 0])
plt.xticks([])
plt.yticks([])
plt.colorbar(shrink=0.6)

plt.suptitle("UNet Mean Forecast")
plt.tight_layout()
plt.show()

# %%
fig, ax = plt.subplots(7, 7, figsize=(18, 10))
ax = ax.ravel()

# One panel per (sample, forecast day); the last axis is the forecast-day axis
# (7 days at the time of writing, per the original comment).
panels = [
    prediction[batch, :, :, 0, day]
    for prediction in predictions
    for batch in range(prediction.shape[0])
    for day in range(prediction.shape[-1])
]

# zip() stops at the shorter sequence, so more than 49 panels no longer raises
# IndexError; panels beyond the grid are simply not drawn.
for axis, panel in zip(ax, panels):
    im = axis.imshow(panel, extent=[0, 1, 0, 1])
    axis.axis("off")
    plt.colorbar(im, ax=axis, shrink=0.6)

plt.suptitle("UNet Mean Forecast")
plt.subplots_adjust(wspace=0, hspace=0)
plt.margins(0)
plt.tight_layout()
plt.show()

# %%
# Create prediction output directory
dataset = "pytorch_notebook"
network_name = "api_dataset"
output_name = "example_pytorch_forecast"
output_folder = os.path.join(".", "results", "predict", output_name,
                             "{}.{}".format(network_name, seed))
# exist_ok expects a bool; the path string passed before only worked because it is truthy.
os.makedirs(output_folder, exist_ok=True)

# %%
# Convert and output predictions to numpy files, one file per test date.
idx = 0
for prediction in predictions:
    for batch in range(prediction.shape[0]):
        # Dataset dates use "_" separators; pandas parses the "-" form.
        date = pd.Timestamp(test_dataset.dates[idx].replace('_', '-'))
        output_path = os.path.join(output_folder, date.strftime("%Y_%m_%d.npy"))
        # Move the second-to-last axis of the sample to the front before saving.
        forecast = prediction[batch, :, :, :, :].movedim(-2, 0)
        # forecast_np from the final iteration is read again by the later min/max cell.
        forecast_np = forecast.detach().cpu().numpy()
        np.save(output_path, forecast_np)
        idx += 1
To apply the datum when loading, use the iris.FUTURE.datum_support flag.\n", - " cube = iris.load_cube(path, 'sea_ice_area_fraction')\n", - "/data/hpcdata/users/bryald/git/icenet/icenet/icenet/process/predict.py:58: FutureWarning: Ignoring a datum in netCDF load for consistency with existing behaviour. In a future version of Iris, this datum will be applied. To apply the datum when loading, use the iris.FUTURE.datum_support flag.\n", - " cube = iris.load_cube(path, 'sea_ice_area_fraction')\n", - "/data/hpcdata/users/bryald/git/icenet/icenet/icenet/process/predict.py:58: FutureWarning: Ignoring a datum in netCDF load for consistency with existing behaviour. In a future version of Iris, this datum will be applied. To apply the datum when loading, use the iris.FUTURE.datum_support flag.\n", - " cube = iris.load_cube(path, 'sea_ice_area_fraction')\n", - "[30-12-23 00:24:40 :INFO ] - Post-processing 2020-04-01\n", - "[30-12-23 00:24:40 :INFO ] - Post-processing 2020-04-02\n", - "[30-12-23 00:24:40 :INFO ] - Dataset arr shape: (2, 432, 432, 7, 2)\n", - "[30-12-23 00:24:40 :INFO ] - Applying active grid cell masks\n", - "[30-12-23 00:24:40 :INFO ] - Land masking the forecast output\n", - "[30-12-23 00:24:40 :INFO ] - Applying zeros to land mask\n", - "[30-12-23 00:24:40 :INFO ] - Saving to results/predict/example_pytorch_forecast.nc\n" - ] - } - ], - "source": [ - "!icenet_output -o results/predict example_pytorch_forecast api_dataset testdates.csv -m" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Plotting the forecast" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "metadata": {}, - "outputs": [], - "source": [ - "import xarray as xr\n", - "import datetime as dt\n", - "from IPython.display import HTML" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "xarray.Dataset {\n", - "dimensions:\n", - "\ttime = 2 ;\n", - "\tyc 
= 432 ;\n", - "\txc = 432 ;\n", - "\tleadtime = 7 ;\n", - "\n", - "variables:\n", - "\tint32 Lambert_Azimuthal_Grid() ;\n", - "\t\tLambert_Azimuthal_Grid:grid_mapping_name = lambert_azimuthal_equal_area ;\n", - "\t\tLambert_Azimuthal_Grid:longitude_of_projection_origin = 0.0 ;\n", - "\t\tLambert_Azimuthal_Grid:latitude_of_projection_origin = -90.0 ;\n", - "\t\tLambert_Azimuthal_Grid:false_easting = 0.0 ;\n", - "\t\tLambert_Azimuthal_Grid:false_northing = 0.0 ;\n", - "\t\tLambert_Azimuthal_Grid:semi_major_axis = 6378137.0 ;\n", - "\t\tLambert_Azimuthal_Grid:inverse_flattening = 298.257223563 ;\n", - "\t\tLambert_Azimuthal_Grid:proj4_string = +proj=laea +lon_0=0 +datum=WGS84 +ellps=WGS84 +lat_0=-90.0 ;\n", - "\tfloat32 sic_mean(time, yc, xc, leadtime) ;\n", - "\t\tsic_mean:long_name = mean sea ice area fraction across ensemble runs of icenet model ;\n", - "\t\tsic_mean:standard_name = sea_ice_area_fraction ;\n", - "\t\tsic_mean:short_name = sic ;\n", - "\t\tsic_mean:valid_min = 0 ;\n", - "\t\tsic_mean:valid_max = 1 ;\n", - "\t\tsic_mean:ancillary_variables = sic_stddev ;\n", - "\t\tsic_mean:grid_mapping = Lambert_Azimuthal_Grid ;\n", - "\t\tsic_mean:units = 1 ;\n", - "\tfloat32 sic_stddev(time, yc, xc, leadtime) ;\n", - "\t\tsic_stddev:long_name = total uncertainty (one standard deviation) of concentration of sea ice ;\n", - "\t\tsic_stddev:standard_name = sea_ice_area_fraction standard_error ;\n", - "\t\tsic_stddev:valid_min = 0 ;\n", - "\t\tsic_stddev:valid_max = 1 ;\n", - "\t\tsic_stddev:grid_mapping = Lambert_Azimuthal_Grid ;\n", - "\t\tsic_stddev:units = 1 ;\n", - "\tint64 ensemble_members(time) ;\n", - "\t\tensemble_members:long_name = number of ensemble members used to create this prediction ;\n", - "\t\tensemble_members:short_name = ensemble_members ;\n", - "\tdatetime64[ns] time(time) ;\n", - "\t\ttime:long_name = reference time of product ;\n", - "\t\ttime:standard_name = time ;\n", - "\t\ttime:axis = T ;\n", - "\tint64 leadtime(leadtime) ;\n", - 
"\t\tleadtime:long_name = leadtime of forecast in relation to reference time ;\n", - "\t\tleadtime:short_name = leadtime ;\n", - "\tdatetime64[ns] forecast_date(time, leadtime) ;\n", - "\tfloat64 xc(xc) ;\n", - "\t\txc:long_name = x coordinate of projection (eastings) ;\n", - "\t\txc:standard_name = projection_x_coordinate ;\n", - "\t\txc:units = 1000 meter ;\n", - "\t\txc:axis = X ;\n", - "\tfloat64 yc(yc) ;\n", - "\t\tyc:long_name = y coordinate of projection (northings) ;\n", - "\t\tyc:standard_name = projection_y_coordinate ;\n", - "\t\tyc:units = 1000 meter ;\n", - "\t\tyc:axis = Y ;\n", - "\tfloat32 lat(yc, xc) ;\n", - "\t\tlat:long_name = latitude coordinate ;\n", - "\t\tlat:standard_name = latitude ;\n", - "\t\tlat:units = arc_degree ;\n", - "\tfloat32 lon(yc, xc) ;\n", - "\t\tlon:long_name = longitude coordinate ;\n", - "\t\tlon:standard_name = longitude ;\n", - "\t\tlon:units = arc_degree ;\n", - "\n", - "// global attributes:\n", - "\t:Conventions = CF-1.6 ACDD-1.3 ;\n", - "\t:comments = ;\n", - "\t:creator_email = jambyr@bas.ac.uk ;\n", - "\t:creator_institution = British Antarctic Survey ;\n", - "\t:creator_name = James Byrne ;\n", - "\t:creator_url = www.bas.ac.uk ;\n", - "\t:date_created = 2023-12-30 ;\n", - "\t:geospatial_bounds_crs = EPSG:6932 ;\n", - "\t:geospatial_lat_min = -90.0 ;\n", - "\t:geospatial_lat_max = -16.62393 ;\n", - "\t:geospatial_lon_min = -180.0 ;\n", - "\t:geospatial_lon_max = 180.0 ;\n", - "\t:geospatial_vertical_min = 0.0 ;\n", - "\t:geospatial_vertical_max = 0.0 ;\n", - "\t:history = 2023-12-30 00:24:40.616824 - creation ;\n", - "\t:id = IceNet 0.2.7a1 ;\n", - "\t:institution = British Antarctic Survey ;\n", - "\t:keywords = 'Earth Science > Cryosphere > Sea Ice > Sea Ice Concentration\n", - " Earth Science > Oceans > Sea Ice > Sea Ice Concentration\n", - " Earth Science > Climate Indicators > Cryospheric Indicators > Sea Ice\n", - " Geographic Region > Southern Hemisphere ;\n", - "\t:keywords_vocabulary = GCMD Science 
# %%
from icenet.data.sic.mask import Masks
from icenet.plotting.video import xarray_to_video as xvid

# Open the netCDF written by `icenet_output` and fetch the southern-hemisphere land mask.
forecast_nc = "results/predict/example_pytorch_forecast.nc"
ds = xr.open_dataset(forecast_nc)
land_mask = Masks(south=True, north=False).get_land_mask()
ds.info()

# %%
# First entry of the dataset's time coordinate.
forecast_date = ds["time"].values[0]
print(forecast_date)
- "\n", - "\n", - "\n", - "\n", - "\n", - "
\n", - " \n", - "
\n", - " \n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - "
# %%
# Take the first forecast and relabel its lead-time axis as calendar dates,
# counted in days from forecast_date.
first_forecast = ds.sic_mean.isel(time=0).drop_vars("time")
fc = first_forecast.rename({"leadtime": "time"})
start = pd.to_datetime(forecast_date)
fc["time"] = [start + dt.timedelta(days=int(lead)) for lead in fc.time.values]

anim = xvid(fc, 15, figsize=4, mask=land_mask)
HTML(anim.to_jshtml())

# %%
# Check min/max of predicted SIC fraction on the first forecast day.
day0 = forecast_np[:, :, :, 0]
print( day0.shape )
fmin, fmax = np.min(day0), np.max(day0)
print( f"First forecast day min: {fmin:.4f}, max: {fmax:.4f}" )
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "# Load original input dataset (domain not normalised)\n", - "xr.plot.contourf(xr.open_dataset(\"data/osisaf/south/siconca/2020.nc\").isel(time=91).ice_conc, levels=50)\n", - "\n", - "# Load processed - normalised dataset\n", - "# xr.plot.contourf(xr.open_dataset(\"processed/notebook_data/osisaf/south/siconca/siconca_abs.nc\").isel(time=91).ice_conc, levels=50)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Version\n", - "- IceNet Codebase: v0.2.7_dev" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "pytorch", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.6" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/6_tensorflow_normal_run.ipynb b/6_tensorflow_normal_run.ipynb deleted file mode 100644 index 67e0653..0000000 --- a/6_tensorflow_normal_run.ipynb +++ /dev/null @@ -1,20918 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import pandas as pd\n", - "import xarray as xr\n", - "import datetime as dt\n", - "from IPython.display import HTML\n", - "\n", - "# We also set the logging level so that we get some feedback from the API\n", - "import logging\n", - "logging.basicConfig(level=logging.INFO)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Running in /data/hpcdata/users/bryald/git/icenet/notebook-pipeline\n" - ] - } - ], - "source": [ - "# Quick hack to put us in the icenet-pipeline folder,\n", - "# assuming it was created as per 01.cli_demonstration.ipynb\n", 
- "import os\n", - "if os.path.exists(\"6_tensorflow_normal_run.ipynb\"):\n", - " os.chdir(\"../notebook-pipeline\")\n", - "print(\"Running in {}\".format(os.getcwd()))\n", - "\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Download required datasets" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[31-12-23 20:24:25 :INFO ] - siconca ice_conc_sh_ease2-250_cdr-v2p0_200001021200.nc already exists\n", - "[31-12-23 20:24:26 :INFO ] - Saving ./data/masks/south/masks/active_grid_cell_mask_01.npy\n", - "[31-12-23 20:24:26 :INFO ] - siconca ice_conc_sh_ease2-250_cdr-v2p0_200002021200.nc already exists\n", - "[31-12-23 20:24:26 :INFO ] - Saving ./data/masks/south/masks/active_grid_cell_mask_02.npy\n", - "[31-12-23 20:24:26 :INFO ] - siconca ice_conc_sh_ease2-250_cdr-v2p0_200003021200.nc already exists\n", - "[31-12-23 20:24:26 :INFO ] - Saving ./data/masks/south/masks/active_grid_cell_mask_03.npy\n", - "[31-12-23 20:24:26 :INFO ] - siconca ice_conc_sh_ease2-250_cdr-v2p0_200004021200.nc already exists\n", - "[31-12-23 20:24:26 :INFO ] - Saving ./data/masks/south/masks/active_grid_cell_mask_04.npy\n", - "[31-12-23 20:24:26 :INFO ] - siconca ice_conc_sh_ease2-250_cdr-v2p0_200005021200.nc already exists\n", - "[31-12-23 20:24:26 :INFO ] - Saving ./data/masks/south/masks/active_grid_cell_mask_05.npy\n", - "[31-12-23 20:24:26 :INFO ] - siconca ice_conc_sh_ease2-250_cdr-v2p0_200006021200.nc already exists\n", - "[31-12-23 20:24:26 :INFO ] - Saving ./data/masks/south/masks/active_grid_cell_mask_06.npy\n", - "[31-12-23 20:24:26 :INFO ] - siconca ice_conc_sh_ease2-250_cdr-v2p0_200007021200.nc already exists\n", - "[31-12-23 20:24:26 :INFO ] - Saving ./data/masks/south/masks/active_grid_cell_mask_07.npy\n", - "[31-12-23 20:24:26 :INFO ] - siconca ice_conc_sh_ease2-250_cdr-v2p0_200008021200.nc already exists\n", - 
"[31-12-23 20:24:26 :INFO ] - Saving ./data/masks/south/masks/active_grid_cell_mask_08.npy\n", - "[31-12-23 20:24:26 :INFO ] - siconca ice_conc_sh_ease2-250_cdr-v2p0_200009021200.nc already exists\n", - "[31-12-23 20:24:26 :INFO ] - Saving ./data/masks/south/masks/active_grid_cell_mask_09.npy\n", - "[31-12-23 20:24:26 :INFO ] - siconca ice_conc_sh_ease2-250_cdr-v2p0_200010021200.nc already exists\n", - "[31-12-23 20:24:26 :INFO ] - Saving ./data/masks/south/masks/active_grid_cell_mask_10.npy\n", - "[31-12-23 20:24:26 :INFO ] - siconca ice_conc_sh_ease2-250_cdr-v2p0_200011021200.nc already exists\n", - "[31-12-23 20:24:26 :INFO ] - Saving ./data/masks/south/masks/active_grid_cell_mask_11.npy\n", - "[31-12-23 20:24:26 :INFO ] - siconca ice_conc_sh_ease2-250_cdr-v2p0_200012021200.nc already exists\n", - "[31-12-23 20:24:26 :INFO ] - Saving ./data/masks/south/masks/active_grid_cell_mask_12.npy\n" - ] - } - ], - "source": [ - "# Download masks\n", - "!icenet_data_masks south" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[31-12-23 20:24:28 :INFO ] - ERA5 Data Downloading\n", - "[31-12-23 20:24:28 :WARNING ] - !!! 
Deletions of temp files are switched off: be careful with this, you need to manage your files manually\n", - "[31-12-23 20:24:28 :INFO ] - Building request(s), downloading and daily averaging from ERA5 API\n", - "[31-12-23 20:24:28 :INFO ] - Processing single download for uas @ None with 121 dates\n", - "[31-12-23 20:24:28 :INFO ] - Processing single download for vas @ None with 121 dates\n", - "[31-12-23 20:24:28 :INFO ] - Processing single download for tas @ None with 121 dates\n", - "[31-12-23 20:24:28 :INFO ] - Processing single download for zg @ 500 with 121 dates\n", - "[31-12-23 20:24:28 :INFO ] - Processing single download for zg @ 250 with 121 dates\n", - "[31-12-23 20:24:29 :INFO ] - No requested dates remain, likely already present\n", - "[31-12-23 20:24:29 :INFO ] - No requested dates remain, likely already present\n", - "[31-12-23 20:24:29 :INFO ] - No requested dates remain, likely already present\n", - "[31-12-23 20:24:29 :INFO ] - No requested dates remain, likely already present\n", - "[31-12-23 20:24:29 :INFO ] - No requested dates remain, likely already present\n", - "[31-12-23 20:24:29 :INFO ] - 0 daily files downloaded\n", - "[31-12-23 20:24:29 :INFO ] - No regrid batches to processing, moving on...\n", - "[31-12-23 20:24:29 :INFO ] - Rotating wind data prior to merging\n", - "/data/hpcdata/users/bryald/git/icenet/icenet/icenet/data/interfaces/downloader.py:361: FutureWarning: Ignoring a datum in netCDF load for consistency with existing behaviour. In a future version of Iris, this datum will be applied. To apply the datum when loading, use the iris.FUTURE.datum_support flag.\n", - " iris.load_cube(sic_day_path, 'sea_ice_area_fraction')\n", - "/data/hpcdata/users/bryald/git/icenet/icenet/icenet/data/interfaces/downloader.py:361: FutureWarning: Ignoring a datum in netCDF load for consistency with existing behaviour. In a future version of Iris, this datum will be applied. 
To apply the datum when loading, use the iris.FUTURE.datum_support flag.\n", - " iris.load_cube(sic_day_path, 'sea_ice_area_fraction')\n", - "/data/hpcdata/users/bryald/git/icenet/icenet/icenet/data/interfaces/downloader.py:361: FutureWarning: Ignoring a datum in netCDF load for consistency with existing behaviour. In a future version of Iris, this datum will be applied. To apply the datum when loading, use the iris.FUTURE.datum_support flag.\n", - " iris.load_cube(sic_day_path, 'sea_ice_area_fraction')\n", - "/data/hpcdata/users/bryald/git/icenet/icenet/icenet/data/interfaces/downloader.py:361: FutureWarning: Ignoring a datum in netCDF load for consistency with existing behaviour. In a future version of Iris, this datum will be applied. To apply the datum when loading, use the iris.FUTURE.datum_support flag.\n", - " iris.load_cube(sic_day_path, 'sea_ice_area_fraction')\n", - "[31-12-23 20:24:29 :INFO ] - Rotating wind data in ./data/era5/south/uas ./data/era5/south/vas\n", - "[31-12-23 20:24:29 :INFO ] - 0 files for uas\n", - "[31-12-23 20:24:29 :INFO ] - 0 files for vas\n" - ] - } - ], - "source": [ - "# Download climate data - ERA5 reanalysis data\n", - "!icenet_data_era5 south -d --vars uas,vas,tas,zg --levels ',,,500|250' 2020-1-1 2020-4-30" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[31-12-23 20:24:30 :INFO ] - OSASIF-SIC Data Downloading\n", - "[31-12-23 20:24:30 :INFO ] - Downloading SIC datafiles to .temp intermediates...\n", - "[31-12-23 20:24:31 :INFO ] - Excluding 121 dates already existing from 121 dates requested.\n", - "[31-12-23 20:24:39 :INFO ] - Existing file needs concatenating: ./data/osisaf/south/siconca/2020.nc -> ./data/osisaf/south/siconca/old.2020.nc\n", - "[31-12-23 20:24:40 :INFO ] - Saving ./data/osisaf/south/siconca/2020.nc\n", - "[31-12-23 20:24:41 :INFO ] - Opening for interpolation: 
['./data/osisaf/south/siconca/2020.nc']\n", - "[31-12-23 20:24:41 :INFO ] - Processing 0 missing dates\n" - ] - } - ], - "source": [ - "# Download sea ice concentration (%) from OSI-SAF\n", - "!icenet_data_sic south -d 2020-1-1 2020-4-30" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Process above data downloads - normalising to use as inputs for UNet\n", - "\n", - "This creates loader.{name}.json.\n", - "\n", - "Also creates train, val, test splits based on CLI arguments within the loader json file." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-12-31 20:24:43.308849: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-12-31 20:24:43.308961: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-12-31 20:24:43.312233: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "[31-12-23 20:24:46 :INFO ] - Got 91 dates for train\n", - "[31-12-23 20:24:46 :INFO ] - Got 21 dates for val\n", - "[31-12-23 20:24:46 :INFO ] - Got 2 dates for test\n", - "[31-12-23 20:24:46 :INFO ] - Processing 91 dates for train category\n", - "[31-12-23 20:24:46 :INFO ] - Including lag of 1 days\n", - "[31-12-23 20:24:46 :INFO ] - Including lead of 93 days\n", - "[31-12-23 20:24:46 :INFO ] - No data found for 2019-12-31, outside data boundary perhaps?\n", - "[31-12-23 20:24:46 :INFO ] - Processing 21 dates for val category\n", - "[31-12-23 20:24:46 :INFO ] - Including lag of 1 days\n", - "[31-12-23 20:24:46 :INFO ] - Including lead 
of 93 days\n", - "[31-12-23 20:24:46 :INFO ] - Processing 2 dates for test category\n", - "[31-12-23 20:24:46 :INFO ] - Including lag of 1 days\n", - "[31-12-23 20:24:46 :INFO ] - Including lead of 93 days\n", - "[31-12-23 20:24:46 :INFO ] - Got 1 files for tas\n", - "[31-12-23 20:24:46 :INFO ] - Got 1 files for uas\n", - "[31-12-23 20:24:46 :INFO ] - Got 1 files for vas\n", - "[31-12-23 20:24:46 :INFO ] - Got 1 files for zg250\n", - "[31-12-23 20:24:46 :INFO ] - Got 1 files for zg500\n", - "[31-12-23 20:24:46 :INFO ] - Opening files for uas\n", - "[31-12-23 20:24:46 :INFO ] - Filtered to 121 units long based on configuration requirements\n", - "[31-12-23 20:24:47 :INFO ] - Normalising uas\n", - "[31-12-23 20:24:47 :INFO ] - Opening files for vas\n", - "[31-12-23 20:24:47 :INFO ] - Filtered to 121 units long based on configuration requirements\n", - "[31-12-23 20:24:47 :INFO ] - Normalising vas\n", - "[31-12-23 20:24:48 :INFO ] - Opening files for tas\n", - "[31-12-23 20:24:48 :INFO ] - Filtered to 121 units long based on configuration requirements\n", - "[31-12-23 20:24:48 :INFO ] - Reusing climatology ./processed/notebook_tf_data/era5/south/params/climatology.tas\n", - "[31-12-23 20:24:48 :WARNING ] - We don't have a full climatology (1,2,3) compared with data (1,2,3,4)\n", - "[31-12-23 20:24:48 :INFO ] - Normalising tas\n", - "[31-12-23 20:24:48 :INFO ] - Opening files for zg500\n", - "[31-12-23 20:24:48 :INFO ] - Filtered to 121 units long based on configuration requirements\n", - "[31-12-23 20:24:48 :INFO ] - Reusing climatology ./processed/notebook_tf_data/era5/south/params/climatology.zg500\n", - "[31-12-23 20:24:48 :WARNING ] - We don't have a full climatology (1,2,3) compared with data (1,2,3,4)\n", - "[31-12-23 20:24:48 :INFO ] - Normalising zg500\n", - "[31-12-23 20:24:48 :INFO ] - Opening files for zg250\n", - "[31-12-23 20:24:48 :INFO ] - Filtered to 121 units long based on configuration requirements\n", - "[31-12-23 20:24:48 :INFO ] - Reusing 
climatology ./processed/notebook_tf_data/era5/south/params/climatology.zg250\n", - "[31-12-23 20:24:48 :WARNING ] - We don't have a full climatology (1,2,3) compared with data (1,2,3,4)\n", - "[31-12-23 20:24:48 :INFO ] - Normalising zg250\n", - "[31-12-23 20:24:49 :INFO ] - Loading configuration ./loader.notebook_tf_data.json\n", - "[31-12-23 20:24:49 :INFO ] - Writing configuration to ./loader.notebook_tf_data.json\n", - "2023-12-31 20:24:51.009891: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-12-31 20:24:51.009963: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-12-31 20:24:51.011149: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "[31-12-23 20:24:52 :INFO ] - Got 91 dates for train\n", - "[31-12-23 20:24:52 :INFO ] - Got 20 dates for val\n", - "[31-12-23 20:24:52 :INFO ] - Got 2 dates for test\n", - "[31-12-23 20:24:52 :INFO ] - Processing 91 dates for train category\n", - "[31-12-23 20:24:52 :INFO ] - Including lag of 1 days\n", - "[31-12-23 20:24:52 :INFO ] - Including lead of 93 days\n", - "[31-12-23 20:24:52 :INFO ] - No data found for 2019-12-31, outside data boundary perhaps?\n", - "[31-12-23 20:24:52 :INFO ] - Processing 20 dates for val category\n", - "[31-12-23 20:24:52 :INFO ] - Including lag of 1 days\n", - "[31-12-23 20:24:52 :INFO ] - Including lead of 93 days\n", - "[31-12-23 20:24:52 :INFO ] - Processing 2 dates for test category\n", - "[31-12-23 20:24:52 :INFO ] - Including lag of 1 days\n", - "[31-12-23 20:24:52 :INFO ] - Including lead of 93 days\n", - "[31-12-23 20:24:52 :INFO ] - Got 1 files for siconca\n", - 
"[31-12-23 20:24:52 :INFO ] - Opening files for siconca\n", - "[31-12-23 20:24:53 :INFO ] - Filtered to 121 units long based on configuration requirements\n", - "[31-12-23 20:24:53 :INFO ] - No normalisation for siconca\n", - "[31-12-23 20:24:54 :INFO ] - Loading configuration ./loader.notebook_tf_data.json\n", - "[31-12-23 20:24:54 :INFO ] - Writing configuration to ./loader.notebook_tf_data.json\n", - "2023-12-31 20:24:56.108180: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-12-31 20:24:56.108233: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-12-31 20:24:56.109278: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "[31-12-23 20:24:58 :INFO ] - Loading configuration ./loader.notebook_tf_data.json\n", - "[31-12-23 20:24:58 :INFO ] - Writing configuration to ./loader.notebook_tf_data.json\n" - ] - } - ], - "source": [ - "# Process ERA5\n", - "!icenet_process_era5 notebook_tf_data south \\\n", - " -ns 2020-1-1 -ne 2020-3-31 -vs 2020-4-3 -ve 2020-4-23 -ts 2020-4-1 -te 2020-4-2 \\\n", - " -l 1 --abs uas,vas --anom tas,zg500,zg250\n", - "\n", - "# Process SIC\n", - "!icenet_process_sic notebook_tf_data south \\\n", - " -ns 2020-1-1 -ne 2020-3-31 -vs 2020-4-1 -ve 2020-4-20 -ts 2020-4-1 -te 2020-4-2 \\\n", - " -l 1 --abs siconca\n", - "\n", - "!icenet_process_metadata notebook_tf_data south" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Creates `dataset_config.{name}.json`, and cached tfrecords dataset for training.\n", - "\n", - "If running in config-only model, will create json file, but not cached dataset 
(i.e., for just prediction)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-12-31 20:24:59.608475: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-12-31 20:24:59.608527: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-12-31 20:24:59.609476: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "[31-12-23 20:25:01 :INFO ] - Got 0 dates for train\n", - "[31-12-23 20:25:01 :INFO ] - Got 0 dates for val\n", - "[31-12-23 20:25:01 :INFO ] - Got 0 dates for test\n", - "[31-12-23 20:25:01 :INFO ] - Loading configuration loader.notebook_tf_data.json\n", - "[31-12-23 20:25:02 :INFO ] - Dashboard at localhost:8888\n", - "[31-12-23 20:25:02 :INFO ] - Using dask client \n", - "[31-12-23 20:25:02 :INFO ] - 91 train dates to process, generating cache data.\n", - "2023-12-31 20:25:03.000140: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-12-31 20:25:03.000199: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-12-31 20:25:03.001188: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "2023-12-31 20:25:03.014793: E 
external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-12-31 20:25:03.014856: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-12-31 20:25:03.015754: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "2023-12-31 20:25:03.034417: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-12-31 20:25:03.034417: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-12-31 20:25:03.034471: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-12-31 20:25:03.034504: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-12-31 20:25:03.035564: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "2023-12-31 20:25:03.035564: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "[31-12-23 20:25:17 :INFO ] - Finished output 
./network_datasets/notebook_tf_data/south/train/00000000.tfrecord\n", - "[31-12-23 20:25:17 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000001.tfrecord\n", - "[31-12-23 20:25:17 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000002.tfrecord\n", - "[31-12-23 20:25:17 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000003.tfrecord\n", - "[31-12-23 20:25:17 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000004.tfrecord\n", - "[31-12-23 20:25:17 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000005.tfrecord\n", - "[31-12-23 20:25:17 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000006.tfrecord\n", - "[31-12-23 20:25:17 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000007.tfrecord\n", - "[31-12-23 20:25:29 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000008.tfrecord\n", - "[31-12-23 20:25:29 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000009.tfrecord\n", - "[31-12-23 20:25:29 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000010.tfrecord\n", - "[31-12-23 20:25:29 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000011.tfrecord\n", - "[31-12-23 20:25:29 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000012.tfrecord\n", - "[31-12-23 20:25:29 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000013.tfrecord\n", - "[31-12-23 20:25:29 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000014.tfrecord\n", - "[31-12-23 20:25:29 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000015.tfrecord\n", - "[31-12-23 20:25:41 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000016.tfrecord\n", - "[31-12-23 20:25:41 :INFO ] - Finished 
output ./network_datasets/notebook_tf_data/south/train/00000017.tfrecord\n", - "[31-12-23 20:25:41 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000018.tfrecord\n", - "[31-12-23 20:25:41 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000019.tfrecord\n", - "[31-12-23 20:25:41 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000020.tfrecord\n", - "[31-12-23 20:25:41 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000021.tfrecord\n", - "[31-12-23 20:25:41 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000022.tfrecord\n", - "[31-12-23 20:25:41 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000023.tfrecord\n", - "[31-12-23 20:25:54 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000024.tfrecord\n", - "[31-12-23 20:25:54 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000025.tfrecord\n", - "[31-12-23 20:25:54 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000026.tfrecord\n", - "[31-12-23 20:25:54 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000027.tfrecord\n", - "[31-12-23 20:25:54 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000028.tfrecord\n", - "[31-12-23 20:25:54 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000029.tfrecord\n", - "[31-12-23 20:25:54 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000030.tfrecord\n", - "[31-12-23 20:25:54 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000031.tfrecord\n", - "[31-12-23 20:26:05 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000032.tfrecord\n", - "[31-12-23 20:26:05 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000033.tfrecord\n", - "[31-12-23 20:26:05 :INFO ] - 
Finished output ./network_datasets/notebook_tf_data/south/train/00000034.tfrecord\n", - "[31-12-23 20:26:05 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000035.tfrecord\n", - "[31-12-23 20:26:05 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000036.tfrecord\n", - "[31-12-23 20:26:05 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000037.tfrecord\n", - "[31-12-23 20:26:05 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000038.tfrecord\n", - "[31-12-23 20:26:05 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000039.tfrecord\n", - "[31-12-23 20:26:14 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000040.tfrecord\n", - "[31-12-23 20:26:14 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000041.tfrecord\n", - "[31-12-23 20:26:14 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000042.tfrecord\n", - "[31-12-23 20:26:14 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000043.tfrecord\n", - "[31-12-23 20:26:14 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000044.tfrecord\n", - "[31-12-23 20:26:14 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/train/00000045.tfrecord\n", - "[31-12-23 20:26:14 :INFO ] - 23 val dates to process, generating cache data.\n", - "[31-12-23 20:26:27 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/val/00000000.tfrecord\n", - "[31-12-23 20:26:27 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/val/00000001.tfrecord\n", - "[31-12-23 20:26:27 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/val/00000002.tfrecord\n", - "[31-12-23 20:26:27 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/val/00000003.tfrecord\n", - "[31-12-23 20:26:27 :INFO ] - Finished output 
./network_datasets/notebook_tf_data/south/val/00000004.tfrecord\n", - "[31-12-23 20:26:27 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/val/00000005.tfrecord\n", - "[31-12-23 20:26:27 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/val/00000006.tfrecord\n", - "[31-12-23 20:26:27 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/val/00000007.tfrecord\n", - "[31-12-23 20:26:34 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/val/00000008.tfrecord\n", - "[31-12-23 20:26:34 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/val/00000009.tfrecord\n", - "[31-12-23 20:26:34 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/val/00000010.tfrecord\n", - "[31-12-23 20:26:34 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/val/00000011.tfrecord\n", - "[31-12-23 20:26:34 :INFO ] - 2 test dates to process, generating cache data.\n", - "[31-12-23 20:26:38 :INFO ] - Finished output ./network_datasets/notebook_tf_data/south/test/00000000.tfrecord\n", - "[31-12-23 20:26:38 :INFO ] - Average sample generation time: 4.481075739038402\n", - "[31-12-23 20:26:38 :INFO ] - Writing configuration to ./dataset_config.notebook_tf_data.json\n" - ] - } - ], - "source": [ - "!icenet_dataset_create -l 1 -fd 7 -ob 2 -w 4 notebook_tf_data south" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## IceNet UNet model\n", - "\n", - "Running tensorflow UNet model as normal." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-12-31 20:26:42.163574: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-12-31 20:26:42.163623: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-12-31 20:26:42.164691: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "[31-12-23 20:26:47 :WARNING ] - Setting seed for best attempt at determinism, value 42\n", - "[31-12-23 20:26:47 :INFO ] - Loading configuration dataset_config.notebook_tf_data.json\n", - "[31-12-23 20:26:47 :INFO ] - Training dataset path: ./network_datasets/notebook_tf_data/south/train\n", - "[31-12-23 20:26:47 :INFO ] - Validation dataset path: ./network_datasets/notebook_tf_data/south/val\n", - "[31-12-23 20:26:47 :INFO ] - Test dataset path: ./network_datasets/notebook_tf_data/south/test\n", - "[31-12-23 20:26:47 :WARNING ] - WandB is not available, we will never use it\n", - "[31-12-23 20:26:47 :INFO ] - Adding tensorboard callback\n", - "Model: \"model\"\n", - "__________________________________________________________________________________________________\n", - " Layer (type) Output Shape Param # Connected to \n", - "==================================================================================================\n", - " input_1 (InputLayer) [(None, 432, 432, 9)] 0 [] \n", - " \n", - " conv2d (Conv2D) (None, 432, 432, 19) 1558 ['input_1[0][0]'] \n", - " \n", - " conv2d_1 (Conv2D) (None, 432, 432, 19) 3268 ['conv2d[0][0]'] \n", - " \n", - " 
batch_normalization (Batch (None, 432, 432, 19) 76 ['conv2d_1[0][0]'] \n", - " Normalization) \n", - " \n", - " max_pooling2d (MaxPooling2 (None, 216, 216, 19) 0 ['batch_normalization[0][0]'] \n", - " D) \n", - " \n", - " conv2d_2 (Conv2D) (None, 216, 216, 38) 6536 ['max_pooling2d[0][0]'] \n", - " \n", - " conv2d_3 (Conv2D) (None, 216, 216, 38) 13034 ['conv2d_2[0][0]'] \n", - " \n", - " batch_normalization_1 (Bat (None, 216, 216, 38) 152 ['conv2d_3[0][0]'] \n", - " chNormalization) \n", - " \n", - " max_pooling2d_1 (MaxPoolin (None, 108, 108, 38) 0 ['batch_normalization_1[0][0]'\n", - " g2D) ] \n", - " \n", - " conv2d_4 (Conv2D) (None, 108, 108, 76) 26068 ['max_pooling2d_1[0][0]'] \n", - " \n", - " conv2d_5 (Conv2D) (None, 108, 108, 76) 52060 ['conv2d_4[0][0]'] \n", - " \n", - " batch_normalization_2 (Bat (None, 108, 108, 76) 304 ['conv2d_5[0][0]'] \n", - " chNormalization) \n", - " \n", - " max_pooling2d_2 (MaxPoolin (None, 54, 54, 76) 0 ['batch_normalization_2[0][0]'\n", - " g2D) ] \n", - " \n", - " conv2d_6 (Conv2D) (None, 54, 54, 76) 52060 ['max_pooling2d_2[0][0]'] \n", - " \n", - " conv2d_7 (Conv2D) (None, 54, 54, 76) 52060 ['conv2d_6[0][0]'] \n", - " \n", - " batch_normalization_3 (Bat (None, 54, 54, 76) 304 ['conv2d_7[0][0]'] \n", - " chNormalization) \n", - " \n", - " max_pooling2d_3 (MaxPoolin (None, 27, 27, 76) 0 ['batch_normalization_3[0][0]'\n", - " g2D) ] \n", - " \n", - " conv2d_8 (Conv2D) (None, 27, 27, 152) 104120 ['max_pooling2d_3[0][0]'] \n", - " \n", - " conv2d_9 (Conv2D) (None, 27, 27, 152) 208088 ['conv2d_8[0][0]'] \n", - " \n", - " batch_normalization_4 (Bat (None, 27, 27, 152) 608 ['conv2d_9[0][0]'] \n", - " chNormalization) \n", - " \n", - " up_sampling2d (UpSampling2 (None, 54, 54, 152) 0 ['batch_normalization_4[0][0]'\n", - " D) ] \n", - " \n", - " conv2d_10 (Conv2D) (None, 54, 54, 76) 46284 ['up_sampling2d[0][0]'] \n", - " \n", - " concatenate (Concatenate) (None, 54, 54, 152) 0 ['batch_normalization_3[0][0]'\n", - " , 'conv2d_10[0][0]'] 
\n", - " \n", - " conv2d_11 (Conv2D) (None, 54, 54, 76) 104044 ['concatenate[0][0]'] \n", - " \n", - " conv2d_12 (Conv2D) (None, 54, 54, 76) 52060 ['conv2d_11[0][0]'] \n", - " \n", - " batch_normalization_5 (Bat (None, 54, 54, 76) 304 ['conv2d_12[0][0]'] \n", - " chNormalization) \n", - " \n", - " up_sampling2d_1 (UpSamplin (None, 108, 108, 76) 0 ['batch_normalization_5[0][0]'\n", - " g2D) ] \n", - " \n", - " conv2d_13 (Conv2D) (None, 108, 108, 76) 23180 ['up_sampling2d_1[0][0]'] \n", - " \n", - " concatenate_1 (Concatenate (None, 108, 108, 152) 0 ['batch_normalization_2[0][0]'\n", - " ) , 'conv2d_13[0][0]'] \n", - " \n", - " conv2d_14 (Conv2D) (None, 108, 108, 76) 104044 ['concatenate_1[0][0]'] \n", - " \n", - " conv2d_15 (Conv2D) (None, 108, 108, 76) 52060 ['conv2d_14[0][0]'] \n", - " \n", - " batch_normalization_6 (Bat (None, 108, 108, 76) 304 ['conv2d_15[0][0]'] \n", - " chNormalization) \n", - " \n", - " up_sampling2d_2 (UpSamplin (None, 216, 216, 76) 0 ['batch_normalization_6[0][0]'\n", - " g2D) ] \n", - " \n", - " conv2d_16 (Conv2D) (None, 216, 216, 38) 11590 ['up_sampling2d_2[0][0]'] \n", - " \n", - " concatenate_2 (Concatenate (None, 216, 216, 76) 0 ['batch_normalization_1[0][0]'\n", - " ) , 'conv2d_16[0][0]'] \n", - " \n", - " conv2d_17 (Conv2D) (None, 216, 216, 38) 26030 ['concatenate_2[0][0]'] \n", - " \n", - " conv2d_18 (Conv2D) (None, 216, 216, 38) 13034 ['conv2d_17[0][0]'] \n", - " \n", - " batch_normalization_7 (Bat (None, 216, 216, 38) 152 ['conv2d_18[0][0]'] \n", - " chNormalization) \n", - " \n", - " up_sampling2d_3 (UpSamplin (None, 432, 432, 38) 0 ['batch_normalization_7[0][0]'\n", - " g2D) ] \n", - " \n", - " conv2d_19 (Conv2D) (None, 432, 432, 19) 2907 ['up_sampling2d_3[0][0]'] \n", - " \n", - " concatenate_3 (Concatenate (None, 432, 432, 38) 0 ['conv2d_1[0][0]', \n", - " ) 'conv2d_19[0][0]'] \n", - " \n", - " conv2d_20 (Conv2D) (None, 432, 432, 19) 6517 ['concatenate_3[0][0]'] \n", - " \n", - " conv2d_21 (Conv2D) (None, 432, 432, 19) 3268 
['conv2d_20[0][0]'] \n", - " \n", - " conv2d_22 (Conv2D) (None, 432, 432, 19) 3268 ['conv2d_21[0][0]'] \n", - " \n", - " conv2d_23 (Conv2D) (None, 432, 432, 7) 140 ['conv2d_22[0][0]'] \n", - " \n", - "==================================================================================================\n", - "Total params: 969482 (3.70 MB)\n", - "Trainable params: 968380 (3.69 MB)\n", - "Non-trainable params: 1102 (4.30 KB)\n", - "__________________________________________________________________________________________________\n", - "[31-12-23 20:26:48 :INFO ] - Datasets: 46 train, 12 val and 1 test filenames\n", - "[31-12-23 20:26:48 :INFO ] - Reducing datasets to 1.0 of total files\n", - "[31-12-23 20:26:48 :INFO ] - Reduced: 46 train, 12 val and 1 test filenames\n", - "[31-12-23 20:26:49 :INFO ] - \n", - "Setting learning rate to: 9.999999747378752e-05\n", - "\n", - "Epoch 1/10\n", - "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", - "I0000 00:00:1704054417.983441 38901 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.\n", - "\n", - "Epoch 1: val_rmse improved from inf to 43.05126, saving model to ./results/networks/notebook_tf_testrun/notebook_tf_testrun.network_notebook_tf_data.42.h5\n", - "/data/hpcdata/users/bryald/miniconda3/envs/pytorch/lib/python3.11/site-packages/keras/src/engine/training.py:3103: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. 
`model.save('my_model.keras')`.\n", - " saving_api.save_model(\n", - "23/23 - 56s - loss: 343.9255 - binacc: 25.7934 - mae: 39.2011 - rmse: 43.4948 - mse: 2284.5085 - val_loss: 336.9477 - val_binacc: 36.9813 - val_mae: 40.3724 - val_rmse: 43.0513 - val_mse: 2037.0499 - lr: 1.0000e-04 - 56s/epoch - 2s/step\n", - "[31-12-23 20:27:45 :INFO ] - \n", - "Setting learning rate to: 9.999999747378752e-05\n", - "\n", - "Epoch 2/10\n", - "\n", - "Epoch 2: val_rmse improved from 43.05126 to 40.64370, saving model to ./results/networks/notebook_tf_testrun/notebook_tf_testrun.network_notebook_tf_data.42.h5\n", - "23/23 - 7s - loss: 256.8361 - binacc: 37.4513 - mae: 32.5517 - rmse: 37.5866 - mse: 1888.7924 - val_loss: 300.3152 - val_binacc: 36.9813 - val_mae: 38.6907 - val_rmse: 40.6437 - val_mse: 1913.4229 - lr: 1.0000e-04 - 7s/epoch - 303ms/step\n", - "[31-12-23 20:27:52 :INFO ] - \n", - "Setting learning rate to: 9.999999747378752e-05\n", - "\n", - "Epoch 3/10\n", - "\n", - "Epoch 3: val_rmse improved from 40.64370 to 38.05367, saving model to ./results/networks/notebook_tf_testrun/notebook_tf_testrun.network_notebook_tf_data.42.h5\n", - "23/23 - 7s - loss: 154.3864 - binacc: 54.3259 - mae: 23.8672 - rmse: 29.1413 - mse: 1519.2355 - val_loss: 263.2594 - val_binacc: 38.0947 - val_mae: 35.5122 - val_rmse: 38.0537 - val_mse: 1695.7881 - lr: 1.0000e-04 - 7s/epoch - 304ms/step\n", - "[31-12-23 20:27:59 :INFO ] - \n", - "Setting learning rate to: 9.999999747378752e-05\n", - "\n", - "Epoch 4/10\n", - "\n", - "Epoch 4: val_rmse improved from 38.05367 to 37.05745, saving model to ./results/networks/notebook_tf_testrun/notebook_tf_testrun.network_notebook_tf_data.42.h5\n", - "23/23 - 7s - loss: 53.3989 - binacc: 82.8000 - mae: 11.7304 - rmse: 17.1384 - mse: 1217.8383 - val_loss: 249.6559 - val_binacc: 40.9881 - val_mae: 33.8406 - val_rmse: 37.0574 - val_mse: 1534.1256 - lr: 1.0000e-04 - 7s/epoch - 308ms/step\n", - "[31-12-23 20:28:06 :INFO ] - \n", - "Setting learning rate to: 
9.999999747378752e-05\n", - "\n", - "Epoch 5/10\n", - "\n", - "Epoch 5: val_rmse improved from 37.05745 to 35.79379, saving model to ./results/networks/notebook_tf_testrun/notebook_tf_testrun.network_notebook_tf_data.42.h5\n", - "23/23 - 7s - loss: 33.9674 - binacc: 94.9292 - mae: 7.3163 - rmse: 13.6690 - mse: 1167.5585 - val_loss: 232.9197 - val_binacc: 40.5778 - val_mae: 32.6644 - val_rmse: 35.7938 - val_mse: 1461.6030 - lr: 1.0000e-04 - 7s/epoch - 306ms/step\n", - "[31-12-23 20:28:13 :INFO ] - \n", - "Setting learning rate to: 9.999999747378752e-05\n", - "\n", - "Epoch 6/10\n", - "\n", - "Epoch 6: val_rmse improved from 35.79379 to 33.65832, saving model to ./results/networks/notebook_tf_testrun/notebook_tf_testrun.network_notebook_tf_data.42.h5\n", - "23/23 - 7s - loss: 27.6051 - binacc: 95.6763 - mae: 6.2371 - rmse: 12.3225 - mse: 1156.0563 - val_loss: 205.9566 - val_binacc: 42.4682 - val_mae: 30.7192 - val_rmse: 33.6583 - val_mse: 1368.7559 - lr: 1.0000e-04 - 7s/epoch - 305ms/step\n", - "[31-12-23 20:28:20 :INFO ] - \n", - "Setting learning rate to: 9.999999747378752e-05\n", - "\n", - "Epoch 7/10\n", - "\n", - "Epoch 7: val_rmse improved from 33.65832 to 28.84873, saving model to ./results/networks/notebook_tf_testrun/notebook_tf_testrun.network_notebook_tf_data.42.h5\n", - "23/23 - 7s - loss: 22.3518 - binacc: 96.0258 - mae: 5.5203 - rmse: 11.0882 - mse: 1177.7894 - val_loss: 151.3018 - val_binacc: 49.3492 - val_mae: 25.8698 - val_rmse: 28.8487 - val_mse: 1084.3677 - lr: 1.0000e-04 - 7s/epoch - 308ms/step\n", - "[31-12-23 20:28:27 :INFO ] - \n", - "Setting learning rate to: 9.999999747378752e-05\n", - "\n", - "Epoch 8/10\n", - "\n", - "Epoch 8: val_rmse improved from 28.84873 to 23.89613, saving model to ./results/networks/notebook_tf_testrun/notebook_tf_testrun.network_notebook_tf_data.42.h5\n", - "23/23 - 7s - loss: 18.0997 - binacc: 96.1845 - mae: 4.9427 - rmse: 9.9779 - mse: 1180.0663 - val_loss: 103.8116 - val_binacc: 62.1863 - val_mae: 20.4092 - 
val_rmse: 23.8961 - val_mse: 745.0102 - lr: 1.0000e-04 - 7s/epoch - 306ms/step\n", - "[31-12-23 20:28:34 :INFO ] - \n", - "Setting learning rate to: 9.999999747378752e-05\n", - "\n", - "Epoch 9/10\n", - "\n", - "Epoch 9: val_rmse improved from 23.89613 to 18.21984, saving model to ./results/networks/notebook_tf_testrun/notebook_tf_testrun.network_notebook_tf_data.42.h5\n", - "23/23 - 7s - loss: 15.3763 - binacc: 96.3971 - mae: 4.5406 - rmse: 9.1967 - mse: 1175.2972 - val_loss: 60.3504 - val_binacc: 80.4414 - val_mae: 14.4911 - val_rmse: 18.2198 - val_mse: 495.2992 - lr: 1.0000e-04 - 7s/epoch - 308ms/step\n", - "[31-12-23 20:28:42 :INFO ] - \n", - "Setting learning rate to: 9.999999747378752e-05\n", - "\n", - "Epoch 10/10\n", - "\n", - "Epoch 10: val_rmse improved from 18.21984 to 15.62367, saving model to ./results/networks/notebook_tf_testrun/notebook_tf_testrun.network_notebook_tf_data.42.h5\n", - "23/23 - 7s - loss: 14.0378 - binacc: 96.4605 - mae: 4.3094 - rmse: 8.7873 - mse: 1176.0642 - val_loss: 44.3769 - val_binacc: 87.5845 - val_mae: 11.9019 - val_rmse: 15.6237 - val_mse: 402.5215 - lr: 1.0000e-04 - 7s/epoch - 304ms/step\n", - "[31-12-23 20:28:49 :INFO ] - Saving network to: ./results/networks/notebook_tf_testrun/notebook_tf_testrun.network_notebook_tf_data.42.h5\n", - "[31-12-23 20:28:52 :INFO ] - Running evaluation against test set\n", - "WARNING:tensorflow:Unable to restore custom metric. Please ensure that the layer implements `get_config` and `from_config` when saving. In addition, please use the `custom_objects` arg when calling `load_model()`.\n", - "[31-12-23 20:28:52 :WARNING ] - Unable to restore custom metric. Please ensure that the layer implements `get_config` and `from_config` when saving. In addition, please use the `custom_objects` arg when calling `load_model()`.\n", - "WARNING:tensorflow:Unable to restore custom metric. Please ensure that the layer implements `get_config` and `from_config` when saving. 
In addition, please use the `custom_objects` arg when calling `load_model()`.\n", - "[31-12-23 20:28:52 :WARNING ] - Unable to restore custom metric. Please ensure that the layer implements `get_config` and `from_config` when saving. In addition, please use the `custom_objects` arg when calling `load_model()`.\n", - "WARNING:tensorflow:Unable to restore custom metric. Please ensure that the layer implements `get_config` and `from_config` when saving. In addition, please use the `custom_objects` arg when calling `load_model()`.\n", - "[31-12-23 20:28:52 :WARNING ] - Unable to restore custom metric. Please ensure that the layer implements `get_config` and `from_config` when saving. In addition, please use the `custom_objects` arg when calling `load_model()`.\n", - "WARNING:tensorflow:Unable to restore custom metric. Please ensure that the layer implements `get_config` and `from_config` when saving. In addition, please use the `custom_objects` arg when calling `load_model()`.\n", - "[31-12-23 20:28:52 :WARNING ] - Unable to restore custom metric. Please ensure that the layer implements `get_config` and `from_config` when saving. In addition, please use the `custom_objects` arg when calling `load_model()`.\n", - "[31-12-23 20:28:53 :INFO ] - Datasets: 46 train, 12 val and 1 test filenames\n", - "[31-12-23 20:28:53 :INFO ] - Reducing datasets to 1.0 of total files\n", - "[31-12-23 20:28:53 :INFO ] - Reduced: 46 train, 12 val and 1 test filenames\n", - "[31-12-23 20:28:53 :INFO ] - Using test set for validation\n", - "[31-12-23 20:28:53 :INFO ] - Metric creation for lead time of 7 days\n", - "[31-12-23 20:28:53 :INFO ] - Evaluating... 
\n", - "[31-12-23 20:28:57 :INFO ] - Done in 3.4s\n" - ] - } - ], - "source": [ - "!icenet_train notebook_tf_data notebook_tf_testrun 42 -b 4 --epochs 10 --multiprocessing --max-queue-size 4 --workers 4 --n-filters-factor 0.3 --no-wandb" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2020-04-01\n", - "2020-04-02\n" - ] - } - ], - "source": [ - "!./loader_test_dates.sh notebook_tf_data | tee testdates.csv" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-12-31 20:28:59.964270: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-12-31 20:28:59.964322: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-12-31 20:28:59.965395: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "[31-12-23 20:29:03 :INFO ] - Loading configuration ./dataset_config.notebook_tf_data.json\n", - "[31-12-23 20:29:03 :INFO ] - Training dataset path: ./network_datasets/notebook_tf_data/south/train\n", - "[31-12-23 20:29:03 :INFO ] - Validation dataset path: ./network_datasets/notebook_tf_data/south/val\n", - "[31-12-23 20:29:03 :INFO ] - Test dataset path: ./network_datasets/notebook_tf_data/south/test\n", - "[31-12-23 20:29:03 :INFO ] - Loading configuration /data/hpcdata/users/bryald/git/icenet/notebook-pipeline/loader.notebook_tf_data.json\n", - "[31-12-23 20:29:03 :INFO ] - Loading model from 
./results/networks/notebook_tf_testrun/notebook_tf_testrun.network_notebook_tf_data.42.h5...\n", - "[31-12-23 20:29:04 :INFO ] - Datasets: 46 train, 12 val and 1 test filenames\n", - "[31-12-23 20:29:04 :INFO ] - Processing test batch 1, item 0 (date 2020-04-01)\n", - "[31-12-23 20:29:04 :INFO ] - Running prediction 2020-04-01\n", - "[31-12-23 20:29:09 :WARNING ] - ./results/predict/example_south_tf_forecast/notebook_tf_testrun.42 output already exists\n", - "[31-12-23 20:29:09 :INFO ] - Saving 2020-04-01 - forecast output (1, 432, 432, 7)\n", - "[31-12-23 20:29:09 :INFO ] - Processing test batch 1, item 1 (date 2020-04-02)\n", - "[31-12-23 20:29:09 :INFO ] - Running prediction 2020-04-02\n", - "[31-12-23 20:29:09 :WARNING ] - ./results/predict/example_south_tf_forecast/notebook_tf_testrun.42 output already exists\n", - "[31-12-23 20:29:09 :INFO ] - Saving 2020-04-02 - forecast output (1, 432, 432, 7)\n" - ] - } - ], - "source": [ - "!icenet_predict -n 0.3 -t \\\n", - " notebook_tf_data notebook_tf_testrun example_south_tf_forecast 42 testdates.csv" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2023-12-31 20:29:11.878115: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-12-31 20:29:11.878166: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-12-31 20:29:11.879218: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "[31-12-23 20:29:13 :INFO ] - Loading configuration ./dataset_config.notebook_tf_data.json\n", - "[31-12-23 20:29:13 
:INFO ] - Training dataset path: ./network_datasets/notebook_tf_data/south/train\n", - "[31-12-23 20:29:13 :INFO ] - Validation dataset path: ./network_datasets/notebook_tf_data/south/val\n", - "[31-12-23 20:29:13 :INFO ] - Test dataset path: ./network_datasets/notebook_tf_data/south/test\n", - "/data/hpcdata/users/bryald/git/icenet/icenet/icenet/process/predict.py:58: FutureWarning: Ignoring a datum in netCDF load for consistency with existing behaviour. In a future version of Iris, this datum will be applied. To apply the datum when loading, use the iris.FUTURE.datum_support flag.\n", - " cube = iris.load_cube(path, 'sea_ice_area_fraction')\n", - "/data/hpcdata/users/bryald/git/icenet/icenet/icenet/process/predict.py:58: FutureWarning: Ignoring a datum in netCDF load for consistency with existing behaviour. In a future version of Iris, this datum will be applied. To apply the datum when loading, use the iris.FUTURE.datum_support flag.\n", - " cube = iris.load_cube(path, 'sea_ice_area_fraction')\n", - "/data/hpcdata/users/bryald/git/icenet/icenet/icenet/process/predict.py:58: FutureWarning: Ignoring a datum in netCDF load for consistency with existing behaviour. In a future version of Iris, this datum will be applied. To apply the datum when loading, use the iris.FUTURE.datum_support flag.\n", - " cube = iris.load_cube(path, 'sea_ice_area_fraction')\n", - "/data/hpcdata/users/bryald/git/icenet/icenet/icenet/process/predict.py:58: FutureWarning: Ignoring a datum in netCDF load for consistency with existing behaviour. In a future version of Iris, this datum will be applied. 
To apply the datum when loading, use the iris.FUTURE.datum_support flag.\n", - " cube = iris.load_cube(path, 'sea_ice_area_fraction')\n", - "[31-12-23 20:29:14 :INFO ] - Post-processing 2020-04-01\n", - "[31-12-23 20:29:14 :INFO ] - Post-processing 2020-04-02\n", - "[31-12-23 20:29:14 :INFO ] - Dataset arr shape: (2, 432, 432, 7, 2)\n", - "[31-12-23 20:29:14 :INFO ] - Saving to results/predict/example_south_tf_forecast.nc\n" - ] - } - ], - "source": [ - "# Create netCDF files from npy prediction outputs.\n", - "!icenet_output -o results/predict example_south_tf_forecast notebook_tf_data testdates.csv" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-12-31 20:29:16.433108: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "2023-12-31 20:29:16.433148: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", - "2023-12-31 20:29:16.434205: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "xarray.Dataset {\n", - "dimensions:\n", - "\ttime = 2 ;\n", - "\tyc = 432 ;\n", - "\txc = 432 ;\n", - "\tleadtime = 7 ;\n", - "\n", - "variables:\n", - "\tint32 Lambert_Azimuthal_Grid() ;\n", - "\t\tLambert_Azimuthal_Grid:grid_mapping_name = lambert_azimuthal_equal_area ;\n", - "\t\tLambert_Azimuthal_Grid:longitude_of_projection_origin = 0.0 ;\n", - "\t\tLambert_Azimuthal_Grid:latitude_of_projection_origin = -90.0 ;\n", - "\t\tLambert_Azimuthal_Grid:false_easting = 0.0 ;\n", - 
"\t\tLambert_Azimuthal_Grid:false_northing = 0.0 ;\n", - "\t\tLambert_Azimuthal_Grid:semi_major_axis = 6378137.0 ;\n", - "\t\tLambert_Azimuthal_Grid:inverse_flattening = 298.257223563 ;\n", - "\t\tLambert_Azimuthal_Grid:proj4_string = +proj=laea +lon_0=0 +datum=WGS84 +ellps=WGS84 +lat_0=-90.0 ;\n", - "\tfloat32 sic_mean(time, yc, xc, leadtime) ;\n", - "\t\tsic_mean:long_name = mean sea ice area fraction across ensemble runs of icenet model ;\n", - "\t\tsic_mean:standard_name = sea_ice_area_fraction ;\n", - "\t\tsic_mean:short_name = sic ;\n", - "\t\tsic_mean:valid_min = 0 ;\n", - "\t\tsic_mean:valid_max = 1 ;\n", - "\t\tsic_mean:ancillary_variables = sic_stddev ;\n", - "\t\tsic_mean:grid_mapping = Lambert_Azimuthal_Grid ;\n", - "\t\tsic_mean:units = 1 ;\n", - "\tfloat32 sic_stddev(time, yc, xc, leadtime) ;\n", - "\t\tsic_stddev:long_name = total uncertainty (one standard deviation) of concentration of sea ice ;\n", - "\t\tsic_stddev:standard_name = sea_ice_area_fraction standard_error ;\n", - "\t\tsic_stddev:valid_min = 0 ;\n", - "\t\tsic_stddev:valid_max = 1 ;\n", - "\t\tsic_stddev:grid_mapping = Lambert_Azimuthal_Grid ;\n", - "\t\tsic_stddev:units = 1 ;\n", - "\tint64 ensemble_members(time) ;\n", - "\t\tensemble_members:long_name = number of ensemble members used to create this prediction ;\n", - "\t\tensemble_members:short_name = ensemble_members ;\n", - "\tdatetime64[ns] time(time) ;\n", - "\t\ttime:long_name = reference time of product ;\n", - "\t\ttime:standard_name = time ;\n", - "\t\ttime:axis = T ;\n", - "\tint64 leadtime(leadtime) ;\n", - "\t\tleadtime:long_name = leadtime of forecast in relation to reference time ;\n", - "\t\tleadtime:short_name = leadtime ;\n", - "\tdatetime64[ns] forecast_date(time, leadtime) ;\n", - "\tfloat64 xc(xc) ;\n", - "\t\txc:long_name = x coordinate of projection (eastings) ;\n", - "\t\txc:standard_name = projection_x_coordinate ;\n", - "\t\txc:units = 1000 meter ;\n", - "\t\txc:axis = X ;\n", - "\tfloat64 yc(yc) ;\n", - 
"\t\tyc:long_name = y coordinate of projection (northings) ;\n", - "\t\tyc:standard_name = projection_y_coordinate ;\n", - "\t\tyc:units = 1000 meter ;\n", - "\t\tyc:axis = Y ;\n", - "\tfloat32 lat(yc, xc) ;\n", - "\t\tlat:long_name = latitude coordinate ;\n", - "\t\tlat:standard_name = latitude ;\n", - "\t\tlat:units = arc_degree ;\n", - "\tfloat32 lon(yc, xc) ;\n", - "\t\tlon:long_name = longitude coordinate ;\n", - "\t\tlon:standard_name = longitude ;\n", - "\t\tlon:units = arc_degree ;\n", - "\n", - "// global attributes:\n", - "\t:Conventions = CF-1.6 ACDD-1.3 ;\n", - "\t:comments = ;\n", - "\t:creator_email = jambyr@bas.ac.uk ;\n", - "\t:creator_institution = British Antarctic Survey ;\n", - "\t:creator_name = James Byrne ;\n", - "\t:creator_url = www.bas.ac.uk ;\n", - "\t:date_created = 2023-12-31 ;\n", - "\t:geospatial_bounds_crs = EPSG:6932 ;\n", - "\t:geospatial_lat_min = -90.0 ;\n", - "\t:geospatial_lat_max = -16.62393 ;\n", - "\t:geospatial_lon_min = -180.0 ;\n", - "\t:geospatial_lon_max = 180.0 ;\n", - "\t:geospatial_vertical_min = 0.0 ;\n", - "\t:geospatial_vertical_max = 0.0 ;\n", - "\t:history = 2023-12-31 20:29:14.876693 - creation ;\n", - "\t:id = IceNet 0.2.7a1 ;\n", - "\t:institution = British Antarctic Survey ;\n", - "\t:keywords = 'Earth Science > Cryosphere > Sea Ice > Sea Ice Concentration\n", - " Earth Science > Oceans > Sea Ice > Sea Ice Concentration\n", - " Earth Science > Climate Indicators > Cryospheric Indicators > Sea Ice\n", - " Geographic Region > Southern Hemisphere ;\n", - "\t:keywords_vocabulary = GCMD Science Keywords ;\n", - "\t:license = Open Government Licece (OGL) V3 ;\n", - "\t:naming_authority = uk.ac.bas ;\n", - "\t:platform = BAS HPC ;\n", - "\t:product_version = 0.2.7a1 ;\n", - "\t:project = IceNet ;\n", - "\t:publisher_email = ;\n", - "\t:publisher_institution = British Antarctic Survey ;\n", - "\t:publisher_url = ;\n", - "\t:source = \n", - " IceNet model generation at v0.2.7a1\n", - " ;\n", - "\t:spatial_resolution 
= 25.0 km grid spacing ;\n", - "\t:standard_name_vocabulary = CF Standard Name Table v27 ;\n", - "\t:summary = \n", - " This is an output of sea ice concentration predictions from the\n", - " IceNet run in an ensemble, with postprocessing to determine\n", - " the mean and standard deviation across the runs.\n", - " ;\n", - "\t:time_coverage_start = 2020-04-02T00:00:00 ;\n", - "\t:time_coverage_end = 2020-04-09T00:00:00 ;\n", - "\t:time_coverage_duration = P1D ;\n", - "\t:time_coverage_resolution = P1D ;\n", - "\t:title = Sea Ice Concentration Prediction ;\n", - "}" - ] - } - ], - "source": [ - "from icenet.plotting.video import xarray_to_video as xvid\n", - "from icenet.data.sic.mask import Masks\n", - "\n", - "ds = xr.open_dataset(\"results/predict/example_south_tf_forecast.nc\")\n", - "land_mask = Masks(south=True, north=False).get_land_mask()\n", - "ds.info()" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "2020-04-01T00:00:00.000000000\n" - ] - } - ], - "source": [ - "forecast_date = ds.time.values[0]\n", - "print(forecast_date)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:root:Inspecting data\n", - "INFO:root:Initialising plot\n", - "INFO:root:Animating\n", - "INFO:root:Not saving plot, will return animation\n", - "INFO:matplotlib.animation:Animation.save using \n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
\n", - " \n", - "
\n", - " \n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "fc = ds.sic_mean.isel(time=0).drop_vars(\"time\").rename(dict(leadtime=\"time\"))\n", - "fc['time'] = [pd.to_datetime(forecast_date) \\\n", - " + dt.timedelta(days=int(e)) for e in fc.time.values]\n", - "\n", - "anim = xvid(fc, 15, figsize=4, mask=land_mask)\n", - "HTML(anim.to_jshtml())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Version\n", - "- IceNet Codebase: v0.2.7_dev" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "pytorch", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.6" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -}