diff --git a/pytorch_example.ipynb b/pytorch_example.ipynb new file mode 100644 index 0000000..90e49eb --- /dev/null +++ b/pytorch_example.ipynb @@ -0,0 +1,2970 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running in /data/hpcdata/users/rychan/notebooks/notebook-pipeline\n" + ] + } + ], + "source": [ + "# Quick hack to put us in the icenet-pipeline folder,\n", + "# assuming it was created as per 01.cli_demonstration.ipynb\n", + "import os\n", + "if os.path.exists(\"pytorch_example.ipynb\"):\n", + " os.chdir(\"../notebook-pipeline\")\n", + "print(\"Running in {}\".format(os.getcwd()))\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import os\n", + "import random\n", + "import torch\n", + "\n", + "# We also set the logging level so that we get some feedback from the API\n", + "import logging\n", + "logging.basicConfig(level=logging.INFO)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "A 2.0.1+cu117\n", + "B True\n", + "C True\n" + ] + } + ], + "source": [ + "print('A', torch.__version__)\n", + "print('B', torch.cuda.is_available())\n", + "print('C', torch.backends.cudnn.enabled)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "D _CudaDeviceProperties(name='NVIDIA A2', major=8, minor=6, total_memory=14938MB, multi_processor_count=10)\n" + ] + } + ], + "source": [ + "device = torch.device('cuda')\n", + "print('D', torch.cuda.get_device_properties(device))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fri Aug 18 17:07:41 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 NVIDIA A2 On | 00000000:98:00.0 Off | 0 |\n", + "| 0% 32C P8 5W / 60W | 2MiB / 15356MiB | 0% Default |\n", + "| | | N/A |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| No running processes found |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dataset creation\n", + "\n", + "Assuming we have run [03.library_usage](03.library_usage.ipynb), the `loader.notebook_api_data.json` file should exist in the current directory."
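, + "\n", + "As a quick illustrative check (an addition for clarity, not part of the original pipeline), we can assert that this file is present before building the dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Illustrative sanity check: the loader configuration written by\n", + "# 03.library_usage.ipynb must exist before we can construct a data loader from it\n", + "import os\n", + "assert os.path.exists(\"loader.notebook_api_data.json\"), \\\n", + " \"loader.notebook_api_data.json not found - run 03.library_usage.ipynb first\""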
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-08-18 17:07:50.427170: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2023-08-18 17:07:51.864480: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "INFO:root:Loading configuration loader.notebook_api_data.json\n" + ] + } + ], + "source": [ + "from icenet.data.loaders import IceNetDataLoaderFactory\n", + "\n", + "implementation = \"dask\"\n", + "loader_config = \"loader.notebook_api_data.json\"\n", + "dataset_name = \"pytorch_notebook\"\n", + "lag = 1\n", + "\n", + "dl = IceNetDataLoaderFactory().create_data_loader(\n", + " implementation,\n", + " loader_config,\n", + " dataset_name,\n", + " lag,\n", + " n_forecast_days=7,\n", + " north=False,\n", + " south=True,\n", + " output_batch_size=4,\n", + " generate_workers=8)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dl" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'sources': {'era5': {'name': 'notebook_api_data',\n", + " 'implementation': 'IceNetERA5PreProcessor',\n", + " 'anom': ['tas', 'zg500', 'zg250'],\n", + " 'abs': ['uas', 'vas'],\n", + " 'dates': {'train': ['2020_01_01',\n", + " '2020_01_02',\n", + " '2020_01_03',\n", + " '2020_01_04',\n", + " '2020_01_05',\n", + " '2020_01_06',\n", + " '2020_01_07',\n", + " '2020_01_08',\n", + " '2020_01_09',\n", + " '2020_01_10',\n", + " '2020_01_11',\n", + " '2020_01_12',\n", + " '2020_01_13',\n", + " '2020_01_14',\n", + " '2020_01_15',\n", + " '2020_01_16',\n", + " '2020_01_17',\n", + " '2020_01_18',\n", + " '2020_01_19',\n", + " '2020_01_20',\n", + " '2020_01_21',\n", + " '2020_01_22',\n", + " '2020_01_23',\n", + " '2020_01_24',\n", + " '2020_01_25',\n", + " '2020_01_26',\n", + " '2020_01_27',\n", + " '2020_01_28',\n", + " '2020_01_29',\n", + " '2020_01_30',\n", + " '2020_01_31',\n", + " '2020_02_01',\n", + " '2020_02_02',\n", + " '2020_02_03',\n", + " '2020_02_04',\n", + " '2020_02_05',\n", + " '2020_02_06',\n", + " '2020_02_07',\n", + " '2020_02_08',\n", + " '2020_02_09',\n", + " '2020_02_10',\n", + " '2020_02_11',\n", + " '2020_02_12',\n", + " '2020_02_13',\n", + " '2020_02_14',\n", + " '2020_02_15',\n", + " '2020_02_16',\n", + " '2020_02_17',\n", + " '2020_02_18',\n", + " '2020_02_19',\n", + " '2020_02_20',\n", + " '2020_02_21',\n", + " '2020_02_22',\n", + " '2020_02_23',\n", + " '2020_02_24',\n", + " '2020_02_25',\n", + " '2020_02_26',\n", + " '2020_02_27',\n", + " '2020_02_28',\n", + " '2020_02_29',\n", + " '2020_03_01',\n", + " '2020_03_02',\n", + " '2020_03_03',\n", + " '2020_03_04',\n", + " '2020_03_05',\n", + " '2020_03_06',\n", + " '2020_03_07',\n", + " '2020_03_08',\n", + " '2020_03_09',\n", + " '2020_03_10',\n", + " '2020_03_11',\n", + " '2020_03_12',\n", + " 
'2020_03_13',\n", + " '2020_03_14',\n", + " '2020_03_15',\n", + " '2020_03_16',\n", + " '2020_03_17',\n", + " '2020_03_18',\n", + " '2020_03_19',\n", + " '2020_03_20',\n", + " '2020_03_21',\n", + " '2020_03_22',\n", + " '2020_03_23',\n", + " '2020_03_24',\n", + " '2020_03_25',\n", + " '2020_03_26',\n", + " '2020_03_27',\n", + " '2020_03_28',\n", + " '2020_03_29',\n", + " '2020_03_30',\n", + " '2020_03_31'],\n", + " 'val': ['2020_04_03',\n", + " '2020_04_04',\n", + " '2020_04_05',\n", + " '2020_04_06',\n", + " '2020_04_07',\n", + " '2020_04_08',\n", + " '2020_04_09',\n", + " '2020_04_10',\n", + " '2020_04_11',\n", + " '2020_04_12',\n", + " '2020_04_13',\n", + " '2020_04_14',\n", + " '2020_04_15',\n", + " '2020_04_16',\n", + " '2020_04_17',\n", + " '2020_04_18',\n", + " '2020_04_19',\n", + " '2020_04_20',\n", + " '2020_04_21',\n", + " '2020_04_22',\n", + " '2020_04_23'],\n", + " 'test': ['2020_04_01', '2020_04_02']},\n", + " 'linear_trends': [],\n", + " 'linear_trend_steps': [1, 2, 3, 4, 5, 6, 7],\n", + " 'meta': [],\n", + " 'var_files': {'uas': ['./processed/notebook_api_data/era5/south/uas/uas_abs.nc'],\n", + " 'vas': ['./processed/notebook_api_data/era5/south/vas/vas_abs.nc'],\n", + " 'tas': ['./processed/notebook_api_data/era5/south/tas/tas_anom.nc'],\n", + " 'zg500': ['./processed/notebook_api_data/era5/south/zg500/zg500_anom.nc'],\n", + " 'zg250': ['./processed/notebook_api_data/era5/south/zg250/zg250_anom.nc']}},\n", + " 'osisaf': {'name': 'notebook_api_data',\n", + " 'implementation': 'IceNetOSIPreProcessor',\n", + " 'anom': [],\n", + " 'abs': ['siconca'],\n", + " 'dates': {'train': ['2020_01_01',\n", + " '2020_01_02',\n", + " '2020_01_03',\n", + " '2020_01_04',\n", + " '2020_01_05',\n", + " '2020_01_06',\n", + " '2020_01_07',\n", + " '2020_01_08',\n", + " '2020_01_09',\n", + " '2020_01_10',\n", + " '2020_01_11',\n", + " '2020_01_12',\n", + " '2020_01_13',\n", + " '2020_01_14',\n", + " '2020_01_15',\n", + " '2020_01_16',\n", + " '2020_01_17',\n", + " '2020_01_18',\n", + " '2020_01_19',\n", + " '2020_01_20',\n", + " '2020_01_21',\n", + " '2020_01_22',\n", + " '2020_01_23',\n", + " '2020_01_24',\n", + " '2020_01_25',\n", + " '2020_01_26',\n", + " '2020_01_27',\n", + " '2020_01_28',\n", + " '2020_01_29',\n", + " '2020_01_30',\n", + " '2020_01_31',\n", + " '2020_02_01',\n", + " '2020_02_02',\n", + " '2020_02_03',\n", + " '2020_02_04',\n", + " '2020_02_05',\n", + " '2020_02_06',\n", + " '2020_02_07',\n", + " '2020_02_08',\n", + " '2020_02_09',\n", + " '2020_02_10',\n", + " '2020_02_11',\n", + " '2020_02_12',\n", + " '2020_02_13',\n", + " '2020_02_14',\n", + " '2020_02_15',\n", + " '2020_02_16',\n", + " '2020_02_17',\n", + " '2020_02_18',\n", + " '2020_02_19',\n", + " '2020_02_20',\n", + " '2020_02_21',\n", + " '2020_02_22',\n", + " '2020_02_23',\n", + " '2020_02_24',\n", + " '2020_02_25',\n", + " '2020_02_26',\n", + " '2020_02_27',\n", + " '2020_02_28',\n", + " '2020_02_29',\n", + " '2020_03_01',\n", + " '2020_03_02',\n", + " '2020_03_03',\n", + " '2020_03_04',\n", + " '2020_03_05',\n", + " '2020_03_06',\n", + " '2020_03_07',\n", + " '2020_03_08',\n", + " '2020_03_09',\n", + " '2020_03_10',\n", + " '2020_03_11',\n", + " '2020_03_12',\n", + " '2020_03_13',\n", + " '2020_03_14',\n", + " '2020_03_15',\n", + " '2020_03_16',\n", + " '2020_03_17',\n", + " '2020_03_18',\n", + " '2020_03_19',\n", + " '2020_03_20',\n", + " '2020_03_21',\n", + " '2020_03_22',\n", + " '2020_03_23',\n", + " '2020_03_24',\n", + " '2020_03_25',\n", + " '2020_03_26',\n", + " '2020_03_27',\n", + " '2020_03_28',\n", + " 
'2020_03_29',\n", + " '2020_03_30',\n", + " '2020_03_31'],\n", + " 'val': ['2020_04_03',\n", + " '2020_04_04',\n", + " '2020_04_05',\n", + " '2020_04_06',\n", + " '2020_04_07',\n", + " '2020_04_08',\n", + " '2020_04_09',\n", + " '2020_04_10',\n", + " '2020_04_11',\n", + " '2020_04_12',\n", + " '2020_04_13',\n", + " '2020_04_14',\n", + " '2020_04_15',\n", + " '2020_04_16',\n", + " '2020_04_17',\n", + " '2020_04_18',\n", + " '2020_04_19',\n", + " '2020_04_20',\n", + " '2020_04_21',\n", + " '2020_04_22',\n", + " '2020_04_23'],\n", + " 'test': ['2020_04_01', '2020_04_02']},\n", + " 'linear_trends': [],\n", + " 'linear_trend_steps': [1, 2, 3, 4, 5, 6, 7],\n", + " 'meta': [],\n", + " 'var_files': {'siconca': ['./processed/notebook_api_data/osisaf/south/siconca/siconca_abs.nc']}},\n", + " 'meta': {'name': 'notebook_api_data',\n", + " 'implementation': 'IceNetMetaPreProcessor',\n", + " 'anom': [],\n", + " 'abs': [],\n", + " 'dates': {'train': [], 'val': [], 'test': []},\n", + " 'linear_trends': [],\n", + " 'linear_trend_steps': [1, 2, 3, 4, 5, 6, 7],\n", + " 'meta': ['sin', 'cos', 'land'],\n", + " 'var_files': {'sin': ['./processed/notebook_api_data/meta/south/sin/sin.nc'],\n", + " 'cos': ['./processed/notebook_api_data/meta/south/cos/cos.nc'],\n", + " 'land': ['./processed/notebook_api_data/meta/south/land/land.nc']}}},\n", + " 'dtype': 'float32',\n", + " 'shape': [432, 432],\n", + " 'missing_dates': []}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dl._config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We generate a config only dataset, which will get saved in `dataset_config.pytorch_notebook.json`." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Writing dataset configuration without data generation\n", + "INFO:root:91 train dates in total, NOT generating cache data.\n", + "INFO:root:21 val dates in total, NOT generating cache data.\n", + "INFO:root:2 test dates in total, NOT generating cache data.\n", + "INFO:root:Writing configuration to ./dataset_config.pytorch_notebook.json\n" + ] + } + ], + "source": [ + "dl.write_dataset_config_only()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now create the IceNetDataSet object:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_config = \"dataset_config.pytorch_notebook.json\"" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", + "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n" + ] + } + ], + "source": [ + "from icenet.data.dataset import IceNetDataSet\n", + "\n", + "dataset = IceNetDataSet(dataset_config, batch_size=4)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'identifier': 'pytorch_notebook',\n", + " 'implementation': 'DaskMultiWorkerLoader',\n", + " 'channels': ['uas_abs_1',\n", + " 'vas_abs_1',\n", + " 'siconca_abs_1',\n", + " 'tas_anom_1',\n", + " 'zg250_anom_1',\n", + " 'zg500_anom_1',\n", + " 'cos_1',\n", + " 'land_1',\n", + " 'sin_1'],\n", + " 'counts': {'train': 91, 'val': 21, 'test': 2},\n", + " 
'dtype': 'float32',\n", + " 'loader_config': '/data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json',\n", + " 'missing_dates': [],\n", + " 'n_forecast_days': 7,\n", + " 'north': False,\n", + " 'num_channels': 9,\n", + " 'shape': [432, 432],\n", + " 'south': True,\n", + " 'dataset_path': False,\n", + " 'loss_weight_days': True,\n", + " 'output_batch_size': 4,\n", + " 'var_lag': 1,\n", + " 'var_lag_override': {}}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset._config" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'/data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset.loader_config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom PyTorch Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import Dataset, DataLoader\n", + "\n", + "class IceNetDataSetPyTorch(Dataset):\n", + " def __init__(self,\n", + " configuration_path: str,\n", + " mode: str,\n", + " batch_size: int = 4,\n", + " shuffling: bool = False):\n", + " self._ds = IceNetDataSet(configuration_path=configuration_path,\n", + " batch_size=batch_size,\n", + " shuffling=shuffling)\n", + " self._dl = self._ds.get_data_loader()\n", + " \n", + " # check mode option\n", + " if mode not in [\"train\", \"val\", \"test\"]:\n", + " raise ValueError(\"mode must be either 'train', 'val' or 'test'\")\n", + " self._mode = mode\n", + " \n", + " self._dates = self._dl._config[\"sources\"][\"osisaf\"][\"dates\"][self._mode]\n", + " \n", + " def __len__(self):\n", + " return self._ds._counts[self._mode]\n", + " \n", + " def __getitem__(self, idx):\n", + " return self._dl.generate_sample(date=pd.Timestamp(self._dates[idx].replace('_', '-')))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", + "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", + "INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json\n" + ] + } + ], + "source": [ + "ds_torch = IceNetDataSetPyTorch(configuration_path=dataset_config,\n", + " mode=\"train\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "91" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_torch.__len__()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'2020_01_01'" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_torch._dates[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([[[ 0.5269795 , 0.49944958, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.5254056 , 0.4970613 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.5229517 , 0.49159 , 0. 
, ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " ...,\n", + " [ 0.45743546, 0.5098583 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.45778623, 0.50784564, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.45920837, 0.5058264 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919]],\n", + " \n", + " [[ 0.5222138 , 0.4954434 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.52211976, 0.49204698, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.5185773 , 0.48746935, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " ...,\n", + " [ 0.44983226, 0.50899804, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.4528191 , 0.506892 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.45470324, 0.5039056 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919]],\n", + " \n", + " [[ 0.5161855 , 0.49079096, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.51489806, 0.48618045, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.5147566 , 0.48799527, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " ...,\n", + " [ 0.4443958 , 0.5084778 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.4447234 , 0.5066612 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.45306128, 0.5051592 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.49564162, 0.58406764, 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " [ 0.4975367 , 0.58466256, 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " [ 0.4994542 , 0.5867455 , 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " ...,\n", + " [ 0.55417156, 0.4807669 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.5548666 , 0.4806072 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.5501333 , 0.47901478, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919]],\n", + " \n", + " [[ 0.49571717, 0.5845726 , 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " [ 0.4949193 , 0.5845744 , 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " [ 0.4962281 , 0.585907 , 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " ...,\n", + " [ 0.55459803, 0.47982672, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.55084467, 0.47981164, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.54948515, 0.4781476 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919]],\n", + " \n", + " [[ 0.49510983, 0.5846764 , 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " [ 0.49457657, 0.58459777, 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " [ 0.49391657, 0.58567566, 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " ...,\n", + " [ 0.5538148 , 0.47739896, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.55086106, 0.47831887, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.5516952 , 0.47874218, 0. , ..., -0.9999424 ,\n", + " 1. 
, -0.01072919]]], dtype=float32),\n", + " array([[[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + 
" [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]]], dtype=float32),\n", + " array([[[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " 
[0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]]], dtype=float32))" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_torch.__getitem__(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([[[ 0.5269795 , 0.49944958, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.5254056 , 0.4970613 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.5229517 , 0.49159 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " ...,\n", + " [ 0.45743546, 0.5098583 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.45778623, 0.50784564, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.45920837, 0.5058264 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919]],\n", + " \n", + " [[ 0.5222138 , 0.4954434 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.52211976, 0.49204698, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.5185773 , 0.48746935, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " ...,\n", + " [ 0.44983226, 0.50899804, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.4528191 , 0.506892 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.45470324, 0.5039056 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919]],\n", + " \n", + " [[ 0.5161855 , 0.49079096, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.51489806, 0.48618045, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.5147566 , 0.48799527, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " ...,\n", + " [ 0.4443958 , 0.5084778 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.4447234 , 0.5066612 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.45306128, 0.5051592 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.49564162, 0.58406764, 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " [ 0.4975367 , 0.58466256, 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " [ 0.4994542 , 0.5867455 , 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " ...,\n", + " [ 0.55417156, 0.4807669 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.5548666 , 0.4806072 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.5501333 , 0.47901478, 0. , ..., -0.9999424 ,\n", + " 1. 
, -0.01072919]],\n", + " \n", + " [[ 0.49571717, 0.5845726 , 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " [ 0.4949193 , 0.5845744 , 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " [ 0.4962281 , 0.585907 , 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " ...,\n", + " [ 0.55459803, 0.47982672, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.55084467, 0.47981164, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.54948515, 0.4781476 , 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919]],\n", + " \n", + " [[ 0.49510983, 0.5846764 , 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " [ 0.49457657, 0.58459777, 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " [ 0.49391657, 0.58567566, 0. , ..., -0.9999424 ,\n", + " -1. , -0.01072919],\n", + " ...,\n", + " [ 0.5538148 , 0.47739896, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.55086106, 0.47831887, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919],\n", + " [ 0.5516952 , 0.47874218, 0. , ..., -0.9999424 ,\n", + " 1. , -0.01072919]]], dtype=float32),\n", + " array([[[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " 
[0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]]], dtype=float32),\n", + " array([[[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " 
...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]]], dtype=float32))" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_torch._dl.generate_sample(date=pd.Timestamp(ds_torch._dates[0].replace('_', '-')))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generating PyTorch DataLoaders" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", + "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", + "INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json\n", + "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", + "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", + "INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json\n", + "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", + "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", + "INFO:root:Loading configuration 
/data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json\n" + ] + } + ], + "source": [ + "train_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode=\"train\")\n", + "val_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode=\"val\")\n", + "test_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode=\"test\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 4\n", + "shuffle = True\n", + "num_workers = 2\n", + "\n", + "train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, persistent_workers=True, num_workers=2)\n", + "val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, persistent_workers=True, num_workers=2)\n", + "test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=shuffle, persistent_workers=True, num_workers=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Iterating through DataLoaders" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "23" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(train_dataloader)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/users/rychan/notebooks/icenet-notebooks/pytorch_example.ipynb Cell 29\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0m train_features, train_labels, sample_weights \u001b[39m=\u001b[39m \u001b[39mnext\u001b[39;49m(\u001b[39miter\u001b[39;49m(train_dataloader))\n", + "File \u001b[0;32m/data/hpcdata/users/rychan/miniconda3/envs/icenet_pytorch/lib/python3.8/site-packages/torch/utils/data/dataloader.py:633\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 630\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_sampler_iter \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 631\u001b[0m \u001b[39m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 632\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_reset() \u001b[39m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 633\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_next_data()\n\u001b[1;32m 634\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m 635\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 636\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m \\\n\u001b[1;32m 637\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_num_yielded \u001b[39m>\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_IterableDataset_len_called:\n", + "File 
\u001b[0;32m/data/hpcdata/users/rychan/miniconda3/envs/icenet_pytorch/lib/python3.8/site-packages/torch/utils/data/dataloader.py:1328\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1325\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_process_data(data)\n\u001b[1;32m 1327\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_shutdown \u001b[39mand\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_tasks_outstanding \u001b[39m>\u001b[39m \u001b[39m0\u001b[39m\n\u001b[0;32m-> 1328\u001b[0m idx, data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_get_data()\n\u001b[1;32m 1329\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_tasks_outstanding \u001b[39m-\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n\u001b[1;32m 1330\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataset_kind \u001b[39m==\u001b[39m _DatasetKind\u001b[39m.\u001b[39mIterable:\n\u001b[1;32m 1331\u001b[0m \u001b[39m# Check for _IterableDatasetStopIteration\u001b[39;00m\n", + "File \u001b[0;32m/data/hpcdata/users/rychan/miniconda3/envs/icenet_pytorch/lib/python3.8/site-packages/torch/utils/data/dataloader.py:1294\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._get_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1290\u001b[0m \u001b[39m# In this case, `self._data_queue` is a `queue.Queue`,. But we don't\u001b[39;00m\n\u001b[1;32m 1291\u001b[0m \u001b[39m# need to call `.task_done()` because we don't use `.join()`.\u001b[39;00m\n\u001b[1;32m 1292\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 1293\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mTrue\u001b[39;00m:\n\u001b[0;32m-> 1294\u001b[0m success, data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_try_get_data()\n\u001b[1;32m 1295\u001b[0m \u001b[39mif\u001b[39;00m success:\n\u001b[1;32m 1296\u001b[0m \u001b[39mreturn\u001b[39;00m data\n", + "File \u001b[0;32m/data/hpcdata/users/rychan/miniconda3/envs/icenet_pytorch/lib/python3.8/site-packages/torch/utils/data/dataloader.py:1132\u001b[0m, in \u001b[0;36m_MultiProcessingDataLoaderIter._try_get_data\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 1119\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_try_get_data\u001b[39m(\u001b[39mself\u001b[39m, timeout\u001b[39m=\u001b[39m_utils\u001b[39m.\u001b[39mMP_STATUS_CHECK_INTERVAL):\n\u001b[1;32m 1120\u001b[0m \u001b[39m# Tries to fetch data from `self._data_queue` once for a given timeout.\u001b[39;00m\n\u001b[1;32m 1121\u001b[0m \u001b[39m# This can also be used as inner loop of fetching without timeout, with\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1129\u001b[0m \u001b[39m# Returns a 2-tuple:\u001b[39;00m\n\u001b[1;32m 1130\u001b[0m \u001b[39m# (bool: whether successfully get data, any: data if successful else None)\u001b[39;00m\n\u001b[1;32m 1131\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m-> 1132\u001b[0m data \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_data_queue\u001b[39m.\u001b[39;49mget(timeout\u001b[39m=\u001b[39;49mtimeout)\n\u001b[1;32m 1133\u001b[0m \u001b[39mreturn\u001b[39;00m (\u001b[39mTrue\u001b[39;00m, data)\n\u001b[1;32m 1134\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mException\u001b[39;00m \u001b[39mas\u001b[39;00m e:\n\u001b[1;32m 1135\u001b[0m \u001b[39m# At timeout and error, we manually check whether any worker 
has\u001b[39;00m\n\u001b[1;32m 1136\u001b[0m \u001b[39m# failed. Note that this is the only mechanism for Windows to detect\u001b[39;00m\n\u001b[1;32m 1137\u001b[0m \u001b[39m# worker failures.\u001b[39;00m\n", + "File \u001b[0;32m/data/hpcdata/users/rychan/miniconda3/envs/icenet_pytorch/lib/python3.8/multiprocessing/queues.py:107\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[39mif\u001b[39;00m block:\n\u001b[1;32m 106\u001b[0m timeout \u001b[39m=\u001b[39m deadline \u001b[39m-\u001b[39m time\u001b[39m.\u001b[39mmonotonic()\n\u001b[0;32m--> 107\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_poll(timeout):\n\u001b[1;32m 108\u001b[0m \u001b[39mraise\u001b[39;00m Empty\n\u001b[1;32m 109\u001b[0m \u001b[39melif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_poll():\n", + "File \u001b[0;32m/data/hpcdata/users/rychan/miniconda3/envs/icenet_pytorch/lib/python3.8/multiprocessing/connection.py:257\u001b[0m, in \u001b[0;36m_ConnectionBase.poll\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_check_closed()\n\u001b[1;32m 256\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_check_readable()\n\u001b[0;32m--> 257\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_poll(timeout)\n", + "File \u001b[0;32m/data/hpcdata/users/rychan/miniconda3/envs/icenet_pytorch/lib/python3.8/multiprocessing/connection.py:424\u001b[0m, in \u001b[0;36mConnection._poll\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 423\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_poll\u001b[39m(\u001b[39mself\u001b[39m, timeout):\n\u001b[0;32m--> 424\u001b[0m r \u001b[39m=\u001b[39m wait([\u001b[39mself\u001b[39;49m], timeout)\n\u001b[1;32m 425\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mbool\u001b[39m(r)\n", + "File \u001b[0;32m/data/hpcdata/users/rychan/miniconda3/envs/icenet_pytorch/lib/python3.8/multiprocessing/connection.py:931\u001b[0m, in \u001b[0;36mwait\u001b[0;34m(object_list, timeout)\u001b[0m\n\u001b[1;32m 928\u001b[0m deadline \u001b[39m=\u001b[39m time\u001b[39m.\u001b[39mmonotonic() \u001b[39m+\u001b[39m timeout\n\u001b[1;32m 930\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mTrue\u001b[39;00m:\n\u001b[0;32m--> 931\u001b[0m ready \u001b[39m=\u001b[39m selector\u001b[39m.\u001b[39;49mselect(timeout)\n\u001b[1;32m 932\u001b[0m \u001b[39mif\u001b[39;00m ready:\n\u001b[1;32m 933\u001b[0m \u001b[39mreturn\u001b[39;00m [key\u001b[39m.\u001b[39mfileobj \u001b[39mfor\u001b[39;00m (key, events) \u001b[39min\u001b[39;00m ready]\n", + "File \u001b[0;32m/data/hpcdata/users/rychan/miniconda3/envs/icenet_pytorch/lib/python3.8/selectors.py:415\u001b[0m, in \u001b[0;36m_PollLikeSelector.select\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 413\u001b[0m ready \u001b[39m=\u001b[39m []\n\u001b[1;32m 414\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m--> 415\u001b[0m fd_event_list \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_selector\u001b[39m.\u001b[39;49mpoll(timeout)\n\u001b[1;32m 416\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mInterruptedError\u001b[39;00m:\n\u001b[1;32m 417\u001b[0m \u001b[39mreturn\u001b[39;00m ready\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "train_features, train_labels, sample_weights = next(iter(train_dataloader))" + ] + }, + { + "cell_type": "code", + 
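"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Note (illustrative addition): IceNetDataSetPyTorch.__getitem__ returns numpy arrays,\n", + "# and the DataLoader's default collate_fn stacks each batch into torch tensors -\n", + "# hence the batch elements below are torch.Tensor objects with a leading batch dimension\n", + "type(train_features), type(train_labels), type(sample_weights)" + ] + }, + { + "cell_type": "code", +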
"execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 432, 432, 9])" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_features.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 432, 432, 93, 1])" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_labels.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 432, 432, 93, 1])" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample_weights.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## IceNet UNet model\n", + "\n", + "As a first attempt to implement a PyTorch example, we adapt code from https://github.com/ampersandmcd/icenet-gan/.\n", + "\n", + "Below is a PyTorch implementation of the UNet architecture." + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "\n", + "class UNet(nn.Module):\n", + " \"\"\"\n", + " An implementation of a UNet for pixelwise classification.\n", + " \"\"\"\n", + " \n", + " def __init__(self,\n", + " input_channels, \n", + " filter_size=3, \n", + " n_filters_factor=1, \n", + " n_forecast_days=6, \n", + " n_output_classes=3,\n", + " **kwargs):\n", + " super(UNet, self).__init__()\n", + "\n", + " self.input_channels = input_channels\n", + " self.filter_size = filter_size\n", + " self.n_filters_factor = n_filters_factor\n", + " self.n_forecast_days = n_forecast_days\n", + " self.n_output_classes = n_output_classes\n", + "\n", + " self.conv1a = nn.Conv2d(in_channels=input_channels, \n", + " out_channels=int(128*n_filters_factor),\n", + " kernel_size=filter_size,\n", + " padding=\"same\")\n", + " self.conv1b = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", + " out_channels=int(128*n_filters_factor),\n", + " kernel_size=filter_size,\n", + " padding=\"same\")\n", + " self.bn1 = nn.BatchNorm2d(num_features=int(128*n_filters_factor))\n", + "\n", + " self.conv2a = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", + " out_channels=int(256*n_filters_factor),\n", + " kernel_size=filter_size,\n", + " padding=\"same\")\n", + " self.conv2b = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", + " out_channels=int(256*n_filters_factor),\n", + " kernel_size=filter_size,\n", + " padding=\"same\")\n", + " self.bn2 = nn.BatchNorm2d(num_features=int(256*n_filters_factor))\n", + "\n", + " # self.conv3a = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv3b = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.bn3 = nn.BatchNorm2d(num_features=int(512*n_filters_factor))\n", + "\n", + " # self.conv4a = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv4b = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # 
out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.bn4 = nn.BatchNorm2d(num_features=int(512*n_filters_factor))\n", + "\n", + " # self.conv5a = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(1024*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv5b = nn.Conv2d(in_channels=int(1024*n_filters_factor),\n", + " # out_channels=int(1024*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.bn5 = nn.BatchNorm2d(num_features=int(1024*n_filters_factor))\n", + "\n", + " # self.conv6a = nn.Conv2d(in_channels=int(1024*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv6b = nn.Conv2d(in_channels=int(1024*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv6c = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.bn6 = nn.BatchNorm2d(num_features=int(512*n_filters_factor))\n", + "\n", + " # self.conv7a = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv7b = nn.Conv2d(in_channels=int(1024*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv7c = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.bn7 = nn.BatchNorm2d(num_features=int(512*n_filters_factor))\n", + "\n", + " # self.conv8a = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(256*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv8b = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(256*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv8c = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", + " # out_channels=int(256*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.bn8 = nn.BatchNorm2d(num_features=int(256*n_filters_factor))\n", + "\n", + " self.conv9a = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", + " out_channels=int(128*n_filters_factor),\n", + " kernel_size=filter_size,\n", + " padding=\"same\")\n", + " self.conv9b = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", + " out_channels=int(128*n_filters_factor),\n", + " kernel_size=filter_size,\n", + " padding=\"same\")\n", + " self.conv9c = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", + " out_channels=int(128*n_filters_factor),\n", + " kernel_size=filter_size,\n", + " padding=\"same\") # no batch norm on last layer\n", + "\n", + " self.final_conv = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", + " out_channels=n_output_classes*n_forecast_days,\n", + " kernel_size=filter_size,\n", + " padding=\"same\")\n", + " \n", + " def forward(self, x):\n", + "\n", + " # transpose from shape (b, h, w, c) to (b, c, h, w) for pytorch conv2d layers\n", + " x = torch.movedim(x, -1, 1) # move c from last to second dim\n", + "\n", + " 
# run through network\n", + " conv1 = self.conv1a(x) # input to 128\n", + " conv1 = F.relu(conv1)\n", + " conv1 = self.conv1b(conv1) # 128 to 128\n", + " conv1 = F.relu(conv1)\n", + " bn1 = self.bn1(conv1)\n", + " pool1 = F.max_pool2d(bn1, kernel_size=(2, 2))\n", + "\n", + " conv2 = self.conv2a(pool1) # 128 to 256\n", + " conv2 = F.relu(conv2)\n", + " conv2 = self.conv2b(conv2) # 256 to 256\n", + " conv2 = F.relu(conv2)\n", + " bn2 = self.bn2(conv2)\n", + " pool2 = F.max_pool2d(bn2, kernel_size=(2, 2)) # only feeds the commented-out deeper layers below\n", + "\n", + " # conv3 = self.conv3a(pool2) # 256 to 512\n", + " # conv3 = F.relu(conv3)\n", + " # conv3 = self.conv3b(conv3) # 512 to 512\n", + " # conv3 = F.relu(conv3)\n", + " # bn3 = self.bn3(conv3)\n", + " # pool3 = F.max_pool2d(bn3, kernel_size=(2, 2))\n", + "\n", + " # conv4 = self.conv4a(pool3) # 512 to 512\n", + " # conv4 = F.relu(conv4)\n", + " # conv4 = self.conv4b(conv4) # 512 to 512\n", + " # conv4 = F.relu(conv4)\n", + " # bn4 = self.bn4(conv4)\n", + " # pool4 = F.max_pool2d(bn4, kernel_size=(2, 2))\n", + "\n", + " # conv5 = self.conv5a(pool4) # 512 to 1024\n", + " # conv5 = F.relu(conv5)\n", + " # conv5 = self.conv5b(conv5) # 1024 to 1024\n", + " # conv5 = F.relu(conv5)\n", + " # bn5 = self.bn5(conv5)\n", + "\n", + " # up6 = F.upsample(bn5, scale_factor=2, mode=\"nearest\")\n", + " # up6 = self.conv6a(up6) # 1024 to 512\n", + " # up6 = F.relu(up6)\n", + " # merge6 = torch.cat([bn4, up6], dim=1) # 512 and 512 to 1024 along c dimension\n", + " # conv6 = self.conv6b(merge6) # 1024 to 512\n", + " # conv6 = F.relu(conv6)\n", + " # conv6 = self.conv6c(conv6) # 512 to 512\n", + " # conv6 = F.relu(conv6)\n", + " # bn6 = self.bn6(conv6)\n", + "\n", + " # up7 = F.upsample(bn6, scale_factor=2, mode=\"nearest\")\n", + " # up7 = self.conv7a(up7) # 1024 to 512\n", + " # up7 = F.relu(up7)\n", + " # merge7 = torch.cat([bn3, up7], dim=1) # 512 and 512 to 1024 along c dimension\n", + " # conv7 = self.conv7b(merge7) # 1024 to 512\n", + " # conv7 = F.relu(conv7)\n", + " # conv7 = self.conv7c(conv7) # 512 to 512\n", + " # conv7 = F.relu(conv7)\n", + " # bn7 = self.bn7(conv7)\n", + "\n", + " # up8 = F.upsample(bn7, scale_factor=2, mode=\"nearest\")\n", + " # up8 = self.conv8a(up8) # 512 to 256\n", + " # up8 = F.relu(up8)\n", + " # merge8 = torch.cat([bn2, up8], dim=1) # 256 and 256 to 512 along c dimension\n", + " # conv8 = self.conv8b(merge8) # 512 to 256\n", + " # conv8 = F.relu(conv8)\n", + " # conv8 = self.conv8c(conv8) # 256 to 256\n", + " # conv8 = F.relu(conv8)\n", + " # bn8 = self.bn8(conv8)\n", + "\n", + " up9 = F.upsample(bn2, scale_factor=2, mode=\"nearest\") # bn2 is the deepest block still active in this reduced network\n", + " up9 = self.conv9a(up9) # 256 to 128\n", + " up9 = F.relu(up9)\n", + " merge9 = torch.cat([bn1, up9], dim=1) # 128 and 128 to 256 along c dimension\n", + " conv9 = self.conv9b(merge9) # 256 to 128\n", + " conv9 = F.relu(conv9)\n", + " conv9 = self.conv9c(conv9) # 128 to 128\n", + " conv9 = F.relu(conv9) # no batch norm on last layer\n", + " \n", + " final_layer_logits = self.final_conv(conv9)\n", + "\n", + " # transpose from shape (b, c, h, w) back to (b, h, w, c) to align with training data\n", + " final_layer_logits = torch.movedim(final_layer_logits, 1, -1) # move c from second to final dim\n", + " b, h, w, c = final_layer_logits.shape\n", + "\n", + " # unpack c=classes*days dimension into classes, forecast days as separate dimensions\n", + " final_layer_logits = final_layer_logits.reshape((b, h, w, self.n_output_classes, self.n_forecast_days))\n", + "\n", + " output = F.softmax(final_layer_logits, dim=-2) # apply over 
n_output_classes dimension\n", + " \n", + " return output # shape (b, h, w, c, t)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Some metrics for evaluating IceNet performance:" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "from torchmetrics import Metric\n", + "\n", + "class IceNetAccuracy(Metric):\n", + " \"\"\"\n", + " Binary accuracy metric for use at multiple leadtimes.\n", + " \"\"\" \n", + "\n", + " # Set class properties\n", + " is_differentiable: bool = False\n", + " higher_is_better: bool = True\n", + " full_state_update: bool = True\n", + "\n", + " def __init__(self, leadtimes_to_evaluate: list):\n", + " \"\"\"\n", + " Construct a binary accuracy metric for use at multiple leadtimes.\n", + " :param leadtimes_to_evaluate: A list of leadtimes to consider\n", + " e.g., [0, 1, 2, 3, 4, 5] to consider all six months in accuracy computation or\n", + " e.g., [0] to only look at the first month's accuracy\n", + " e.g., [5] to only look at the sixth month's accuracy\n", + " \"\"\"\n", + " super().__init__()\n", + " self.leadtimes_to_evaluate = leadtimes_to_evaluate\n", + " self.add_state(\"weighted_score\", default=torch.tensor(0.), dist_reduce_fx=\"sum\")\n", + " self.add_state(\"possible_score\", default=torch.tensor(0.), dist_reduce_fx=\"sum\")\n", + "\n", + " def update(self, preds: torch.Tensor, target: torch.Tensor, sample_weight: torch.Tensor):\n", + " # preds and target are shape (b, h, w, t)\n", + " # sum marginal and full ice for binary eval\n", + " preds = (preds > 0).long()\n", + " target = (target > 0).long()\n", + " base_score = preds[:, :, :, self.leadtimes_to_evaluate] == target[:, :, :, self.leadtimes_to_evaluate]\n", + " self.weighted_score += torch.sum(base_score * sample_weight[:, :, :, self.leadtimes_to_evaluate])\n", + " self.possible_score += torch.sum(sample_weight[:, :, :, self.leadtimes_to_evaluate])\n", + "\n", + " def compute(self):\n", + " return self.weighted_score.float() / self.possible_score\n", + "\n", + "\n", + "class SIEError(Metric):\n", + " \"\"\"\n", + " Sea Ice Extent error metric (in km^2) for use at multiple leadtimes.\n", + " \"\"\" \n", + "\n", + " # Set class properties\n", + " is_differentiable: bool = False\n", + " higher_is_better: bool = False\n", + " full_state_update: bool = True\n", + "\n", + " def __init__(self, leadtimes_to_evaluate: list):\n", + " \"\"\"\n", + " Construct an SIE error metric (in km^2) for use at multiple leadtimes.\n", + " :param leadtimes_to_evaluate: A list of leadtimes to consider\n", + " e.g., [0, 1, 2, 3, 4, 5] to consider all six months in computation or\n", + " e.g., [0] to only look at the first month\n", + " e.g., [5] to only look at the sixth month\n", + " \"\"\"\n", + " super().__init__()\n", + " self.leadtimes_to_evaluate = leadtimes_to_evaluate\n", + " self.add_state(\"pred_sie\", default=torch.tensor(0.), dist_reduce_fx=\"sum\")\n", + " self.add_state(\"true_sie\", default=torch.tensor(0.), dist_reduce_fx=\"sum\")\n", + "\n", + " def update(self, preds: torch.Tensor, target: torch.Tensor, sample_weight: torch.Tensor):\n", + " # preds and target are shape (b, h, w, t)\n", + " # sum marginal and full ice for binary eval\n", + " preds = (preds > 0).long()\n", + " target = (target > 0).long()\n", + " self.pred_sie += preds[:, :, :, self.leadtimes_to_evaluate].sum()\n", + " self.true_sie += target[:, :, :, self.leadtimes_to_evaluate].sum()\n", + "\n", + " def compute(self):\n", + " return (self.pred_sie - 
self.true_sie) * 25**2 # each pixel is 25x25 km" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A _LightningModule_ wrapper for UNet model." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "import lightning.pytorch as pl\n", + "from torchmetrics import MetricCollection\n", + "\n", + "class LitUNet(pl.LightningModule):\n", + " \"\"\"\n", + " A LightningModule wrapping the UNet implementation of IceNet.\n", + " \"\"\"\n", + " def __init__(self,\n", + " model: nn.Module,\n", + " criterion: callable,\n", + " learning_rate: float):\n", + " \"\"\"\n", + " Construct a UNet LightningModule.\n", + " Note that we keep hyperparameters separate from dataloaders to prevent data leakage at test time.\n", + " :param model: PyTorch model\n", + " :param criterion: PyTorch loss function for training instantiated with reduction=\"none\"\n", + " :param learning_rate: Float learning rate for our optimiser\n", + " \"\"\"\n", + " super().__init__()\n", + " self.model = model\n", + " self.criterion = criterion\n", + " self.learning_rate = learning_rate\n", + " self.n_output_classes = model.n_output_classes # this should be a property of the network\n", + "\n", + " metrics = {\n", + " \"val_accuracy\": IceNetAccuracy(leadtimes_to_evaluate=list(range(self.model.n_forecast_days))),\n", + " \"val_sieerror\": SIEError(leadtimes_to_evaluate=list(range(self.model.n_forecast_days)))\n", + " }\n", + " for i in range(self.model.n_forecast_days):\n", + " metrics[f\"val_accuracy_{i}\"] = IceNetAccuracy(leadtimes_to_evaluate=[i])\n", + " metrics[f\"val_sieerror_{i}\"] = SIEError(leadtimes_to_evaluate=[i])\n", + " self.metrics = MetricCollection(metrics)\n", + "\n", + " test_metrics = {\n", + " \"test_accuracy\": IceNetAccuracy(leadtimes_to_evaluate=list(range(self.model.n_forecast_days))),\n", + " \"test_sieerror\": SIEError(leadtimes_to_evaluate=list(range(self.model.n_forecast_days)))\n", + " }\n", + " for i in range(self.model.n_forecast_days):\n", + " test_metrics[f\"test_accuracy_{i}\"] = IceNetAccuracy(leadtimes_to_evaluate=[i])\n", + " test_metrics[f\"test_sieerror_{i}\"] = SIEError(leadtimes_to_evaluate=[i])\n", + " self.test_metrics = MetricCollection(test_metrics)\n", + "\n", + " self.save_hyperparameters(ignore=[\"model\", \"criterion\"])\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\n", + " Implement forward function.\n", + " :param x: Inputs to model.\n", + " :return: Outputs of model.\n", + " \"\"\"\n", + " return self.model(x)\n", + "\n", + " def training_step(self, batch):\n", + " \"\"\"\n", + " Perform a pass through a batch of training data.\n", + " Apply pixel-weighted loss by manually reducing.\n", + " See e.g. 
https://discuss.pytorch.org/t/unet-pixel-wise-weighted-loss-function/46689/5.\n", + " :param batch: Batch of input, output, weight triplets\n", + " :return: Loss from this batch of data for use in backprop\n", + " \"\"\"\n", + " x, y, sample_weight = batch\n", + " y_hat = self.model(x)\n", + " # y and y_hat are shape (b, h, w, c, t) but loss expects (b, c, h, w, t)\n", + " # note that criterion needs reduction=\"none\" for weighting to work\n", + " if isinstance(self.criterion, nn.CrossEntropyLoss): # requires int class encoding\n", + " loss = self.criterion(y_hat.movedim(-2, 1), y.argmax(-2).long())\n", + " else: # requires one-hot encoding\n", + " loss = self.criterion(y_hat.movedim(-2, 1), y.movedim(-2, 1))\n", + " loss = torch.mean(loss * sample_weight.movedim(-2, 1))\n", + " self.log(\"train_loss\", loss, sync_dist=True)\n", + " return loss\n", + "\n", + " def validation_step(self, batch):\n", + " x, y, sample_weight = batch\n", + " y_hat = self.model(x)\n", + " # y and y_hat are shape (b, h, w, c, t) but loss expects (b, c, h, w, t)\n", + " # note that criterion needs reduction=\"none\" for weighting to work\n", + " if isinstance(self.criterion, nn.CrossEntropyLoss): # requires int class encoding\n", + " loss = self.criterion(y_hat.movedim(-2, 1), y.argmax(-2).long())\n", + " else: # requires one-hot encoding\n", + " loss = self.criterion(y_hat.movedim(-2, 1), y.movedim(-2, 1))\n", + " loss = torch.mean(loss * sample_weight.movedim(-2, 1))\n", + " self.log(\"val_loss\", loss, on_step=False, on_epoch=True, sync_dist=True) # epoch-level loss\n", + " y_hat_pred = y_hat.argmax(dim=-2).long() # argmax over c where shape is (b, h, w, c, t)\n", + " self.metrics.update(y_hat_pred, y.argmax(dim=-2).long(), sample_weight.squeeze(dim=-2)) # shape (b, h, w, t)\n", + " return loss\n", + "\n", + " def on_validation_epoch_end(self):\n", + " self.log_dict(self.metrics.compute(), on_step=False, on_epoch=True, sync_dist=True) # epoch-level metrics\n", + " self.metrics.reset()\n", + "\n", + " def test_step(self, batch):\n", + " x, y, sample_weight = batch\n", + " y_hat = self.model(x)\n", + " # y and y_hat are shape (b, h, w, c, t) but loss expects (b, c, h, w, t)\n", + " # note that criterion needs reduction=\"none\" for weighting to work\n", + " if isinstance(self.criterion, nn.CrossEntropyLoss): # requires int class encoding\n", + " loss = self.criterion(y_hat.movedim(-2, 1), y.argmax(-2).long())\n", + " else: # requires one-hot encoding\n", + " loss = self.criterion(y_hat.movedim(-2, 1), y.movedim(-2, 1))\n", + " loss = torch.mean(loss * sample_weight.movedim(-2, 1))\n", + " self.log(\"test_loss\", loss, on_step=False, on_epoch=True, sync_dist=True) # epoch-level loss\n", + " y_hat_pred = y_hat.argmax(dim=-2) # argmax over c where shape is (b, h, w, c, t)\n", + " self.test_metrics.update(y_hat_pred, y.argmax(dim=-2).long(), sample_weight.squeeze(dim=-2)) # shape (b, h, w, t)\n", + " return loss\n", + "\n", + " def on_test_epoch_end(self):\n", + " self.log_dict(self.test_metrics.compute(), on_step=False, on_epoch=True, sync_dist=True) # epoch-level metrics\n", + " self.test_metrics.reset()\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " return {\n", + " \"optimizer\": optimizer\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A function for training the UNet model using PyTorch Lightning." 
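The training, validation and test steps above all rely on the same pixel-weighted reduction: because the criterion is constructed with reduction="none", the loss comes back as a per-pixel map that can be multiplied by the sample weights before being reduced to a scalar. The following is a minimal, self-contained sketch of that pattern; the tensors and shapes (2 samples, an 8x8 grid, 3 classes, 7 leadtimes) are made up for illustration and are not taken from the pipeline.

import torch
from torch import nn

# Illustrative shapes only: b samples, an h x w grid, c classes, t leadtimes.
b, h, w, c, t = 2, 8, 8, 3, 7
logits = torch.randn(b, c, h, w, t)          # channels-first, as nn.CrossEntropyLoss expects
target = torch.randint(0, c, (b, h, w, t))   # integer class per pixel per leadtime
sample_weight = torch.rand(b, h, w, t)       # hypothetical weights, e.g. zero over land

criterion = nn.CrossEntropyLoss(reduction="none")  # keep the per-pixel loss map
loss_map = criterion(logits, target)               # shape (b, h, w, t)
loss = torch.mean(loss_map * sample_weight)        # weight first, then reduce manually

With the default reduction="mean" the criterion would return a scalar straight away, leaving nothing for the sample weights to act on.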
+ ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "from lightning.pytorch.callbacks import ModelCheckpoint\n", + "\n", + "def train_icenet(configuration_path,\n", + " learning_rate,\n", + " max_epochs,\n", + " batch_size,\n", + " n_workers,\n", + " filter_size,\n", + " n_filters_factor,\n", + " seed):\n", + " \"\"\"\n", + " Train the IceNet UNet using PyTorch Lightning.\n", + " :param configuration_path: Path to the dataset configuration JSON\n", + " :param learning_rate: Learning rate for the Adam optimiser\n", + " :param max_epochs: Maximum number of epochs to train for\n", + " :param batch_size: Batch size for the train and validation dataloaders\n", + " :param n_workers: Number of dataloader workers\n", + " :param filter_size: Convolution kernel size for the UNet\n", + " :param n_filters_factor: Multiplier on the number of filters in each layer\n", + " :param seed: Random seed passed to pl.seed_everything\n", + " \"\"\"\n", + " # init\n", + " pl.seed_everything(seed)\n", + " \n", + " # configure datasets and dataloaders\n", + " train_dataset = IceNetDataSetPyTorch(configuration_path, mode=\"train\")\n", + " val_dataset = IceNetDataSetPyTorch(configuration_path, mode=\"val\")\n", + " train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=n_workers,\n", + " persistent_workers=True, shuffle=False) # note: no shuffling, so batches follow date order\n", + " val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=n_workers,\n", + " persistent_workers=True, shuffle=False)\n", + "\n", + " # construct unet\n", + " model = UNet(\n", + " input_channels=len(train_dataset._ds._config[\"channels\"]),\n", + " filter_size=filter_size,\n", + " n_filters_factor=n_filters_factor,\n", + " n_forecast_days=train_dataset._ds._config[\"n_forecast_days\"]\n", + " )\n", + " \n", + " criterion = nn.CrossEntropyLoss(reduction=\"none\")\n", + " \n", + " # configure PyTorch Lightning module\n", + " lit_module = LitUNet(\n", + " model=model,\n", + " criterion=criterion,\n", + " learning_rate=learning_rate\n", + " )\n", + "\n", + " # set up trainer configuration\n", + " trainer = pl.Trainer(\n", + " accelerator=\"auto\",\n", + " devices=1,\n", + " log_every_n_steps=10,\n", + " max_epochs=max_epochs,\n", + " num_sanity_val_steps=1,\n", + " )\n", + " trainer.callbacks.append(ModelCheckpoint(monitor=\"val_accuracy\", mode=\"max\"))\n", + "\n", + " # train model\n", + " print(f\"Training {len(train_dataset)} examples / {len(train_dataloader)} batches (batch size {batch_size}).\")\n", + " print(f\"Validating {len(val_dataset)} examples / {len(val_dataloader)} batches (batch size {batch_size}).\")\n", + " trainer.fit(lit_module, train_dataloader, val_dataloader)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO: Global seed set to 45\n", + "INFO:lightning.fabric.utilities.seed:Global seed set to 45\n", + "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", + "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", + "INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json\n", + "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", + "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", + "INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json\n", + "INFO: GPU available: True (cuda), used: True\n", + "INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True\n", + "INFO: TPU available: False, using: 0 TPU cores\n", + "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO: IPU available: False, using: 0 IPUs\n", + "INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs\n", + 
"INFO: HPU available: False, using: 0 HPUs\n", + "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n", + "INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "INFO: \n", + " | Name | Type | Params\n", + "--------------------------------------------------\n", + "0 | model | UNet | 1.8 M \n", + "1 | criterion | CrossEntropyLoss | 0 \n", + "2 | metrics | MetricCollection | 0 \n", + "3 | test_metrics | MetricCollection | 0 \n", + "--------------------------------------------------\n", + "1.8 M Trainable params\n", + "0 Non-trainable params\n", + "1.8 M Total params\n", + "7.224 Total estimated model params size (MB)\n", + "INFO:lightning.pytorch.callbacks.model_summary:\n", + " | Name | Type | Params\n", + "--------------------------------------------------\n", + "0 | model | UNet | 1.8 M \n", + "1 | criterion | CrossEntropyLoss | 0 \n", + "2 | metrics | MetricCollection | 0 \n", + "3 | test_metrics | MetricCollection | 0 \n", + "--------------------------------------------------\n", + "1.8 M Trainable params\n", + "0 Non-trainable params\n", + "1.8 M Total params\n", + "7.224 Total estimated model params size (MB)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training 91 examples / 23 batches (batch size 4).\n", + "Validating 21 examples / 6 batches (batch size 4).\n", + "Sanity Checking: 0it [00:00, ?it/s]" + ] + } + ], + "source": [ + "seed = 45\n", + "train_icenet(configuration_path=dataset_config,\n", + " learning_rate=1e-4,\n", + " max_epochs=10,\n", + " batch_size=4,\n", + " n_workers=12,\n", + " filter_size=3,\n", + " n_filters_factor=1,\n", + " seed=seed)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "9" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(train_dataset._ds._config[\"channels\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "7" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset._ds._config[\"n_forecast_days\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "icenet", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pytorch_example/icenet_pytorch_dataset.py b/pytorch_example/icenet_pytorch_dataset.py new file mode 100644 index 0000000..47b849c --- /dev/null +++ b/pytorch_example/icenet_pytorch_dataset.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import pandas as pd +from icenet.data.dataset import IceNetDataSet +from torch.utils.data import Dataset + +class IceNetDataSetPyTorch(Dataset): + def __init__(self, + configuration_path: str, + mode: str, + n_forecast_days: int | None = None, + generate_workers: int | None = None): + self._ds = IceNetDataSet(configuration_path=configuration_path) + self._dl = self._ds.get_data_loader( + 
n_forecast_days=n_forecast_days, + generate_workers=generate_workers + ) + + # check mode option + if mode not in ["train", "val", "test"]: + raise ValueError("mode must be either 'train', 'val' or 'test'") + self._mode = mode + + self._dates = [ + x.replace('_', '-') + for x in self._dl._config["sources"]["osisaf"]["dates"][self._mode] + ] + + def __len__(self): + return self._ds._counts[self._mode] + + def __getitem__(self, idx): + return self._dl.generate_sample(date=pd.Timestamp(self._dates[idx])) \ No newline at end of file diff --git a/pytorch_example/icenet_unet_small.py b/pytorch_example/icenet_unet_small.py new file mode 100644 index 0000000..3e55fcd --- /dev/null +++ b/pytorch_example/icenet_unet_small.py @@ -0,0 +1,218 @@ +""" +Taken from Andrew McDonald's https://github.com/ampersandmcd/icenet-gan/blob/main/src/models.py. +""" + +import torch +from torch import nn +import torch.nn.functional as F +import lightning.pytorch as pl +# from torchmetrics import MetricCollection +# from metrics import IceNetAccuracy, SIEError + + +def weighted_mse_loss(input, target, weight): + return torch.sum(weight * (input - target) ** 2) + + +class UNet(nn.Module): + """ + A (small) implementation of a UNet for pixelwise classification. + """ + + def __init__(self, + input_channels: int, + filter_size: int = 3, + n_filters_factor: int = 1, + n_forecast_days: int = 7): + super(UNet, self).__init__() + + self.input_channels = input_channels + self.filter_size = filter_size + self.n_filters_factor = n_filters_factor + self.n_forecast_days = n_forecast_days + + self.conv1a = nn.Conv2d(in_channels=input_channels, + out_channels=int(128*n_filters_factor), + kernel_size=filter_size, + padding="same") + self.conv1b = nn.Conv2d(in_channels=int(128*n_filters_factor), + out_channels=int(128*n_filters_factor), + kernel_size=filter_size, + padding="same") + self.bn1 = nn.BatchNorm2d(num_features=int(128*n_filters_factor)) + + self.conv2a = nn.Conv2d(in_channels=int(128*n_filters_factor), + out_channels=int(256*n_filters_factor), + kernel_size=filter_size, + padding="same") + self.conv2b = nn.Conv2d(in_channels=int(256*n_filters_factor), + out_channels=int(256*n_filters_factor), + kernel_size=filter_size, + padding="same") + self.bn2 = nn.BatchNorm2d(num_features=int(256*n_filters_factor)) + + self.conv9a = nn.Conv2d(in_channels=int(256*n_filters_factor), + out_channels=int(128*n_filters_factor), + kernel_size=filter_size, + padding="same") + self.conv9b = nn.Conv2d(in_channels=int(256*n_filters_factor), + out_channels=int(128*n_filters_factor), + kernel_size=filter_size, + padding="same") + self.conv9c = nn.Conv2d(in_channels=int(128*n_filters_factor), + out_channels=int(128*n_filters_factor), + kernel_size=filter_size, + padding="same") # no batch norm on last layer + + self.final_conv = nn.Conv2d(in_channels=int(128*n_filters_factor), + out_channels=n_forecast_days, + kernel_size=filter_size, + padding="same") + + def forward(self, x): + # transpose from shape (b, h, w, c) to (b, c, h, w) for pytorch conv2d layers + x = torch.movedim(x, -1, 1) # move c from last to second dim + + # run through network + conv1 = self.conv1a(x) # input to 128 + conv1 = F.relu(conv1) + conv1 = self.conv1b(conv1) # 128 to 128 + conv1 = F.relu(conv1) + bn1 = self.bn1(conv1) + pool1 = F.max_pool2d(bn1, kernel_size=(2, 2)) + + conv2 = self.conv2a(pool1) # 128 to 256 + conv2 = F.relu(conv2) + conv2 = self.conv2b(conv2) # 256 to 256 + conv2 = F.relu(conv2) + bn2 = self.bn2(conv2) + + up9 = F.upsample(bn2, scale_factor=2, 
mode="nearest") + up9 = self.conv9a(up9) # 256 to 128 + up9 = F.relu(up9) + merge9 = torch.cat([bn1, up9], dim=1) # 128 and 128 to 256 along c dimension + conv9 = self.conv9b(merge9) # 256 to 128 + conv9 = F.relu(conv9) + conv9 = self.conv9c(conv9) # 128 to 128 + conv9 = F.relu(conv9) # no batch norm on last layer + + final_layer_logits = self.final_conv(conv9) + + # transpose from shape (b, c, h, w) back to (b, h, w, c) to align with training data + final_layer_logits = torch.movedim(final_layer_logits, 1, -1) # move c from second to final dim + + # apply sigmoid + output = F.sigmoid(final_layer_logits) + + return output # shape (b, h, w, c) + + +class LitUNet(pl.LightningModule): + """ + A LightningModule wrapping the UNet implementation of IceNet. + """ + def __init__(self, + model: nn.Module, + criterion: callable, + learning_rate: float): + """ + Construct a UNet LightningModule. + Note that we keep hyperparameters separate from dataloaders to prevent data leakage at test time. + :param model: PyTorch model + :param criterion: PyTorch loss function for training instantiated with reduction="none" + :param learning_rate: Float learning rate for our optimiser + """ + super().__init__() + self.model = model + self.criterion = criterion + self.learning_rate = learning_rate + + # metrics = { + # "val_accuracy": IceNetAccuracy(leadtimes_to_evaluate=list(range(self.model.n_forecast_days))), + # "val_sieerror": SIEError(leadtimes_to_evaluate=list(range(self.model.n_forecast_days))) + # } + # for i in range(self.model.n_forecast_days): + # metrics[f"val_accuracy_{i}"] = IceNetAccuracy(leadtimes_to_evaluate=[i]) + # metrics[f"val_sieerror_{i}"] = SIEError(leadtimes_to_evaluate=[i]) + # self.metrics = MetricCollection(metrics) + + # test_metrics = { + # "test_accuracy": IceNetAccuracy(leadtimes_to_evaluate=list(range(self.model.n_forecast_days))), + # "test_sieerror": SIEError(leadtimes_to_evaluate=list(range(self.model.n_forecast_days))) + # } + # for i in range(self.model.n_forecast_days): + # test_metrics[f"test_accuracy_{i}"] = IceNetAccuracy(leadtimes_to_evaluate=[i]) + # test_metrics[f"test_sieerror_{i}"] = SIEError(leadtimes_to_evaluate=[i]) + # self.test_metrics = MetricCollection(test_metrics) + + self.save_hyperparameters(ignore=["model", "criterion"]) + + def forward(self, x): + """ + Implement forward function. + :param x: Inputs to model. + :return: Outputs of model. + """ + return self.model(x) + + def training_step(self, batch, batch_idx): + """ + Perform a pass through a batch of training data. + Apply pixel-weighted loss by manually reducing. + See e.g. https://discuss.pytorch.org/t/unet-pixel-wise-weighted-loss-function/46689/5. 
+ :param batch: Batch of input, output, weight triplets + :param batch_idx: Index of batch + :return: Loss from this batch of data for use in backprop + """ + print("in training") + x, y, sample_weight = batch + y_hat = self.model(x) + print(f"y.shape: {y.shape}") + print(f"y[:,:,:,:,0].shape: {y[:,:,:,:,0].shape}") + print(f"y_hat.shape: {y_hat.shape}") + print(f"sample_weight[:,:,:,:,0].shape: {sample_weight[:,:,:,:,0].shape}") + # y and sample_weight have shape (b, h, w, c, 1) + # y_hat has shape (b, h, w, c) + loss = self.criterion(y[:,:,:,:,0], y_hat) + loss = torch.mean(loss * sample_weight[:,:,:,:,0]) + self.log("train_loss", loss, sync_dist=True) + return loss + + def validation_step(self, batch, batch_idx): + print("in validation") + x, y, sample_weight = batch + y_hat = self.model(x) + print(f"y.shape: {y.shape}") + print(f"y[:,:,:,:,0].shape: {y[:,:,:,:,0].shape}") + print(f"y_hat.shape: {y_hat.shape}") + print(f"sample_weight[:,:,:,:,0].shape: {sample_weight[:,:,:,:,0].shape}") + # y and sample_weight have shape (b, h, w, c, 1) + # y_hat has shape (b, h, w, c) + loss = self.criterion(y[:,:,:,:,0], y_hat) + loss = torch.mean(loss * sample_weight[:,:,:,:,0]) + self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True) # epoch-level loss + return loss + + # def on_validation_epoch_end(self): + # self.log_dict(self.metrics.compute(), on_step=False, on_epoch=True, sync_dist=True) # epoch-level metrics + # self.metrics.reset() + + def test_step(self, batch, batch_idx): + x, y, sample_weight = batch + y_hat = self.model(x) + # y and sample_weight have shape (b, h, w, c, 1) + # y_hat has shape (b, h, w, c) + loss = self.criterion(y[:,:,:,:,0], y_hat) + loss = torch.mean(loss * sample_weight[:,:,:,:,0]) + self.log("test_loss", loss, on_step=False, on_epoch=True, sync_dist=True) # epoch-level loss + return loss + + # def on_test_epoch_end(self): + # # self.log_dict(self.test_metrics.compute(),on_step=False, on_epoch=True, sync_dist=True) # epoch-level metrics + # self.test_metrics.reset() + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate) + return { + "optimizer": optimizer + } \ No newline at end of file diff --git a/pytorch_example/metrics.py b/pytorch_example/metrics.py new file mode 100644 index 0000000..65cd196 --- /dev/null +++ b/pytorch_example/metrics.py @@ -0,0 +1,79 @@ +""" +Taken from Andrew McDonald's https://github.com/ampersandmcd/icenet-gan/blob/main/src/metrics.py +Adapted from Tom Andersson's https://github.com/tom-andersson/icenet-paper/blob/main/icenet/metrics.py +Modified from Tensorflow to PyTorch and PyTorch Lightning. +Extended to include additional sharpness metrics. +""" +import torch +from torchmetrics import Metric + + +class IceNetAccuracy(Metric): + """ + Binary accuracy metric for use at multiple leadtimes. + """ + + # Set class properties + is_differentiable: bool = False + higher_is_better: bool = True + full_state_update: bool = True + + def __init__(self, leadtimes_to_evaluate: list): + """ + Construct a binary accuracy metric for use at multiple leadtimes. 
+ :param leadtimes_to_evaluate: A list of leadtimes to consider + e.g., [0, 1, 2, 3, 4, 5] to consider all six months in accuracy computation or + e.g., [0] to only look at the first month's accuracy + e.g., [5] to only look at the sixth month's accuracy + """ + super().__init__() + self.leadtimes_to_evaluate = leadtimes_to_evaluate + self.add_state("weighted_score", default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state("possible_score", default=torch.tensor(0.), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor, sample_weight: torch.Tensor): + # preds and target are shape (b, h, w, t) + # sum marginal and full ice for binary eval + preds = (preds > 0).long() + target = (target > 0).long() + base_score = preds[:, :, :, self.leadtimes_to_evaluate] == target[:, :, :, self.leadtimes_to_evaluate] + self.weighted_score += torch.sum(base_score * sample_weight[:, :, :, self.leadtimes_to_evaluate]) + self.possible_score += torch.sum(sample_weight[:, :, :, self.leadtimes_to_evaluate]) + + def compute(self): + return self.weighted_score.float() / self.possible_score + + +class SIEError(Metric): + """ + Sea Ice Extent error metric (in km^2) for use at multiple leadtimes. + """ + + # Set class properties + is_differentiable: bool = False + higher_is_better: bool = False + full_state_update: bool = True + + def __init__(self, leadtimes_to_evaluate: list): + """ + Construct an SIE error metric (in km^2) for use at multiple leadtimes. + :param leadtimes_to_evaluate: A list of leadtimes to consider + e.g., [0, 1, 2, 3, 4, 5] to consider all six months in computation or + e.g., [0] to only look at the first month + e.g., [5] to only look at the sixth month + """ + super().__init__() + self.leadtimes_to_evaluate = leadtimes_to_evaluate + self.add_state("pred_sie", default=torch.tensor(0.), dist_reduce_fx="sum") + self.add_state("true_sie", default=torch.tensor(0.), dist_reduce_fx="sum") + + def update(self, preds: torch.Tensor, target: torch.Tensor, sample_weight: torch.Tensor): + # preds and target are shape (b, h, w, t) + # sum marginal and full ice for binary eval + preds = (preds > 0).long() + target = (target > 0).long() + self.pred_sie += preds[:, :, :, self.leadtimes_to_evaluate].sum() + self.true_sie += target[:, :, :, self.leadtimes_to_evaluate].sum() + + def compute(self): + return (self.pred_sie - self.true_sie) * 25**2 # each pixel is 25x25 km \ No newline at end of file diff --git a/pytorch_example/pytorch_example.ipynb b/pytorch_example/pytorch_example.ipynb new file mode 100644 index 0000000..563dcb3 --- /dev/null +++ b/pytorch_example/pytorch_example.ipynb @@ -0,0 +1,3916 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running in /data/hpcdata/users/rychan/notebooks/notebook-pipeline\n" + ] + } + ], + "source": [ + "# Quick hack to put us in the icenet-pipeline folder,\n", + "# assuming it was created as per 01.cli_demonstration.ipynb\n", + "import os\n", + "if os.path.exists(\"pytorch_example.ipynb\"):\n", + " os.chdir(\"../../notebook-pipeline\")\n", + "print(\"Running in {}\".format(os.getcwd()))\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-08-31 12:01:25.322765: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. 
You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2023-08-31 12:01:25.378600: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "import logging\n", + "\n", + "from icenet.data.loaders import IceNetDataLoaderFactory\n", + "from icenet.data.dataset import IceNetDataSet\n", + "from icenet_pytorch_dataset import IceNetDataSetPyTorch\n", + "\n", + "from train_icenet_unet import train_icenet_unet\n", + "from test_icenet_unet import test_icenet_unet\n", + "\n", + "# We also set the logging level so that we get some feedback from the API\n", + "logging.basicConfig(level=logging.INFO)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "A 2.0.1+cu117\n", + "B True\n", + "C True\n", + "D _CudaDeviceProperties(name='NVIDIA A2', major=8, minor=6, total_memory=14938MB, multi_processor_count=10)\n" + ] + } + ], + "source": [ + "print('A', torch.__version__)\n", + "print('B', torch.cuda.is_available())\n", + "print('C', torch.backends.cudnn.enabled)\n", + "device = torch.device('cuda')\n", + "print('D', torch.cuda.get_device_properties(device))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Thu Aug 31 12:01:29 2023 \n", + "+-----------------------------------------------------------------------------+\n", + "| NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 |\n", + "|-------------------------------+----------------------+----------------------+\n", + "| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n", + "| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n", + "| | | MIG M. |\n", + "|===============================+======================+======================|\n", + "| 0 NVIDIA A2 On | 00000000:98:00.0 Off | 0 |\n", + "| 0% 33C P8 4W / 60W | 2MiB / 15356MiB | 0% Default |\n", + "| | | N/A |\n", + "+-------------------------------+----------------------+----------------------+\n", + " \n", + "+-----------------------------------------------------------------------------+\n", + "| Processes: |\n", + "| GPU GI CI PID Type Process name GPU Memory |\n", + "| ID ID Usage |\n", + "|=============================================================================|\n", + "| No running processes found |\n", + "+-----------------------------------------------------------------------------+\n" + ] + } + ], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dataset creation\n", + "\n", + "Assuming we have run [03.library_usage](03.library_usage.ipynb), the `loader.notebook_api_data.json` file should exist in the current directory." 
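Before building the loaders, it is worth recalling the map-style Dataset contract that icenet_pytorch_dataset.py implements: PyTorch only requires __len__ and __getitem__, and DataLoader handles batching and workers on top. A minimal sketch with a made-up ToyDataset (not part of the icenet API) that mimics the (h, w, channels) layout used here:

import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Stand-in dataset returning (input, target) pairs keyed by integer index."""

    def __init__(self, n_samples: int):
        self.n_samples = n_samples

    def __len__(self):
        # IceNetDataSetPyTorch answers this with its per-mode sample count
        return self.n_samples

    def __getitem__(self, idx):
        x = torch.full((4, 4, 2), float(idx))  # stand-in for an (h, w, channels) input
        y = torch.zeros(4, 4, 1)               # stand-in for a target field
        return x, y

loader = DataLoader(ToyDataset(8), batch_size=4)
for x, y in loader:
    print(x.shape, y.shape)  # torch.Size([4, 4, 4, 2]) torch.Size([4, 4, 4, 1])

The default collate function stacks the per-sample tensors along a new leading batch dimension, which is exactly how the (b, h, w, c) batches seen later in this notebook arise.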
+ ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Loading configuration loader.notebook_api_data.json\n" + ] + } + ], + "source": [ + "implementation = \"dask\"\n", + "loader_config = \"loader.notebook_api_data.json\"\n", + "dataset_name = \"pytorch_notebook\"\n", + "lag = 1\n", + "\n", + "dl = IceNetDataLoaderFactory().create_data_loader(\n", + " implementation,\n", + " loader_config,\n", + " dataset_name,\n", + " lag,\n", + " n_forecast_days=7,\n", + " north=False,\n", + " south=True,\n", + " output_batch_size=4,\n", + " generate_workers=8)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dl.workers" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "7" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dl._n_forecast_days" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'sources': {'era5': {'name': 'notebook_api_data',\n", + " 'implementation': 'IceNetERA5PreProcessor',\n", + " 'anom': ['tas', 'zg500', 'zg250'],\n", + " 'abs': ['uas', 'vas'],\n", + " 'dates': {'train': ['2020_01_01',\n", + " '2020_01_02',\n", + " '2020_01_03',\n", + " '2020_01_04',\n", + " '2020_01_05',\n", + " '2020_01_06',\n", + " '2020_01_07',\n", + " '2020_01_08',\n", + " '2020_01_09',\n", + " '2020_01_10',\n", + " '2020_01_11',\n", + " '2020_01_12',\n", + " '2020_01_13',\n", + " '2020_01_14',\n", + " '2020_01_15',\n", + " '2020_01_16',\n", + " '2020_01_17',\n", + " '2020_01_18',\n", + " '2020_01_19',\n", + " '2020_01_20',\n", + " '2020_01_21',\n", + " '2020_01_22',\n", + " '2020_01_23',\n", + " '2020_01_24',\n", + " '2020_01_25',\n", + " '2020_01_26',\n", + " '2020_01_27',\n", + " '2020_01_28',\n", + " '2020_01_29',\n", + " '2020_01_30',\n", + " '2020_01_31',\n", + " '2020_02_01',\n", + " '2020_02_02',\n", + " '2020_02_03',\n", + " '2020_02_04',\n", + " '2020_02_05',\n", + " '2020_02_06',\n", + " '2020_02_07',\n", + " '2020_02_08',\n", + " '2020_02_09',\n", + " '2020_02_10',\n", + " '2020_02_11',\n", + " '2020_02_12',\n", + " '2020_02_13',\n", + " '2020_02_14',\n", + " '2020_02_15',\n", + " '2020_02_16',\n", + " '2020_02_17',\n", + " '2020_02_18',\n", + " '2020_02_19',\n", + " '2020_02_20',\n", + " '2020_02_21',\n", + " '2020_02_22',\n", + " '2020_02_23',\n", + " '2020_02_24',\n", + " '2020_02_25',\n", + " '2020_02_26',\n", + " '2020_02_27',\n", + " '2020_02_28',\n", + " '2020_02_29',\n", + " '2020_03_01',\n", + " '2020_03_02',\n", + " '2020_03_03',\n", + " '2020_03_04',\n", + " '2020_03_05',\n", + " '2020_03_06',\n", + " '2020_03_07',\n", + " '2020_03_08',\n", + " '2020_03_09',\n", + " '2020_03_10',\n", + " '2020_03_11',\n", + " '2020_03_12',\n", + " '2020_03_13',\n", + " '2020_03_14',\n", + " '2020_03_15',\n", + " '2020_03_16',\n", + " '2020_03_17',\n", + " '2020_03_18',\n", + " '2020_03_19',\n", + " '2020_03_20',\n", + " '2020_03_21',\n", + " '2020_03_22',\n", + " '2020_03_23',\n", + " '2020_03_24',\n", + " '2020_03_25',\n", + " '2020_03_26',\n", + " '2020_03_27',\n", + " '2020_03_28',\n", + " '2020_03_29',\n", + " '2020_03_30',\n", + " '2020_03_31'],\n", + " 'val': 
['2020_04_03',\n", + " '2020_04_04',\n", + " '2020_04_05',\n", + " '2020_04_06',\n", + " '2020_04_07',\n", + " '2020_04_08',\n", + " '2020_04_09',\n", + " '2020_04_10',\n", + " '2020_04_11',\n", + " '2020_04_12',\n", + " '2020_04_13',\n", + " '2020_04_14',\n", + " '2020_04_15',\n", + " '2020_04_16',\n", + " '2020_04_17',\n", + " '2020_04_18',\n", + " '2020_04_19',\n", + " '2020_04_20',\n", + " '2020_04_21',\n", + " '2020_04_22',\n", + " '2020_04_23'],\n", + " 'test': ['2020_04_01', '2020_04_02']},\n", + " 'linear_trends': [],\n", + " 'linear_trend_steps': [1, 2, 3, 4, 5, 6, 7],\n", + " 'meta': [],\n", + " 'var_files': {'uas': ['./processed/notebook_api_data/era5/south/uas/uas_abs.nc'],\n", + " 'vas': ['./processed/notebook_api_data/era5/south/vas/vas_abs.nc'],\n", + " 'tas': ['./processed/notebook_api_data/era5/south/tas/tas_anom.nc'],\n", + " 'zg500': ['./processed/notebook_api_data/era5/south/zg500/zg500_anom.nc'],\n", + " 'zg250': ['./processed/notebook_api_data/era5/south/zg250/zg250_anom.nc']}},\n", + " 'osisaf': {'name': 'notebook_api_data',\n", + " 'implementation': 'IceNetOSIPreProcessor',\n", + " 'anom': [],\n", + " 'abs': ['siconca'],\n", + " 'dates': {'train': ['2020_01_01',\n", + " '2020_01_02',\n", + " '2020_01_03',\n", + " '2020_01_04',\n", + " '2020_01_05',\n", + " '2020_01_06',\n", + " '2020_01_07',\n", + " '2020_01_08',\n", + " '2020_01_09',\n", + " '2020_01_10',\n", + " '2020_01_11',\n", + " '2020_01_12',\n", + " '2020_01_13',\n", + " '2020_01_14',\n", + " '2020_01_15',\n", + " '2020_01_16',\n", + " '2020_01_17',\n", + " '2020_01_18',\n", + " '2020_01_19',\n", + " '2020_01_20',\n", + " '2020_01_21',\n", + " '2020_01_22',\n", + " '2020_01_23',\n", + " '2020_01_24',\n", + " '2020_01_25',\n", + " '2020_01_26',\n", + " '2020_01_27',\n", + " '2020_01_28',\n", + " '2020_01_29',\n", + " '2020_01_30',\n", + " '2020_01_31',\n", + " '2020_02_01',\n", + " '2020_02_02',\n", + " '2020_02_03',\n", + " '2020_02_04',\n", + " '2020_02_05',\n", + " '2020_02_06',\n", + " '2020_02_07',\n", + " '2020_02_08',\n", + " '2020_02_09',\n", + " '2020_02_10',\n", + " '2020_02_11',\n", + " '2020_02_12',\n", + " '2020_02_13',\n", + " '2020_02_14',\n", + " '2020_02_15',\n", + " '2020_02_16',\n", + " '2020_02_17',\n", + " '2020_02_18',\n", + " '2020_02_19',\n", + " '2020_02_20',\n", + " '2020_02_21',\n", + " '2020_02_22',\n", + " '2020_02_23',\n", + " '2020_02_24',\n", + " '2020_02_25',\n", + " '2020_02_26',\n", + " '2020_02_27',\n", + " '2020_02_28',\n", + " '2020_02_29',\n", + " '2020_03_01',\n", + " '2020_03_02',\n", + " '2020_03_03',\n", + " '2020_03_04',\n", + " '2020_03_05',\n", + " '2020_03_06',\n", + " '2020_03_07',\n", + " '2020_03_08',\n", + " '2020_03_09',\n", + " '2020_03_10',\n", + " '2020_03_11',\n", + " '2020_03_12',\n", + " '2020_03_13',\n", + " '2020_03_14',\n", + " '2020_03_15',\n", + " '2020_03_16',\n", + " '2020_03_17',\n", + " '2020_03_18',\n", + " '2020_03_19',\n", + " '2020_03_20',\n", + " '2020_03_21',\n", + " '2020_03_22',\n", + " '2020_03_23',\n", + " '2020_03_24',\n", + " '2020_03_25',\n", + " '2020_03_26',\n", + " '2020_03_27',\n", + " '2020_03_28',\n", + " '2020_03_29',\n", + " '2020_03_30',\n", + " '2020_03_31'],\n", + " 'val': ['2020_04_03',\n", + " '2020_04_04',\n", + " '2020_04_05',\n", + " '2020_04_06',\n", + " '2020_04_07',\n", + " '2020_04_08',\n", + " '2020_04_09',\n", + " '2020_04_10',\n", + " '2020_04_11',\n", + " '2020_04_12',\n", + " '2020_04_13',\n", + " '2020_04_14',\n", + " '2020_04_15',\n", + " '2020_04_16',\n", + " '2020_04_17',\n", + " '2020_04_18',\n", + " 
'2020_04_19',\n", + " '2020_04_20',\n", + " '2020_04_21',\n", + " '2020_04_22',\n", + " '2020_04_23'],\n", + " 'test': ['2020_04_01', '2020_04_02']},\n", + " 'linear_trends': [],\n", + " 'linear_trend_steps': [1, 2, 3, 4, 5, 6, 7],\n", + " 'meta': [],\n", + " 'var_files': {'siconca': ['./processed/notebook_api_data/osisaf/south/siconca/siconca_abs.nc']}},\n", + " 'meta': {'name': 'notebook_api_data',\n", + " 'implementation': 'IceNetMetaPreProcessor',\n", + " 'anom': [],\n", + " 'abs': [],\n", + " 'dates': {'train': [], 'val': [], 'test': []},\n", + " 'linear_trends': [],\n", + " 'linear_trend_steps': [1, 2, 3, 4, 5, 6, 7],\n", + " 'meta': ['sin', 'cos', 'land'],\n", + " 'var_files': {'sin': ['./processed/notebook_api_data/meta/south/sin/sin.nc'],\n", + " 'cos': ['./processed/notebook_api_data/meta/south/cos/cos.nc'],\n", + " 'land': ['./processed/notebook_api_data/meta/south/land/land.nc']}}},\n", + " 'dtype': 'float32',\n", + " 'shape': [432, 432],\n", + " 'missing_dates': []}" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dl._config" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We generate a config only dataset, which will get saved in `dataset_config.pytorch_notebook.json`." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Writing dataset configuration without data generation\n", + "INFO:root:91 train dates in total, NOT generating cache data.\n", + "INFO:root:21 val dates in total, NOT generating cache data.\n", + "INFO:root:2 test dates in total, NOT generating cache data.\n", + "INFO:root:Writing configuration to ./dataset_config.pytorch_notebook.json\n" + ] + } + ], + "source": [ + "dl.write_dataset_config_only()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now create the IceNetDataSet object:" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_config = \"dataset_config.pytorch_notebook.json\"" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", + "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n" + ] + } + ], + "source": [ + "dataset = IceNetDataSet(dataset_config, batch_size=4)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'identifier': 'pytorch_notebook',\n", + " 'implementation': 'DaskMultiWorkerLoader',\n", + " 'channels': ['uas_abs_1',\n", + " 'vas_abs_1',\n", + " 'siconca_abs_1',\n", + " 'tas_anom_1',\n", + " 'zg250_anom_1',\n", + " 'zg500_anom_1',\n", + " 'cos_1',\n", + " 'land_1',\n", + " 'sin_1'],\n", + " 'counts': {'train': 91, 'val': 21, 'test': 2},\n", + " 'dtype': 'float32',\n", + " 'loader_config': '/data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json',\n", + " 'missing_dates': [],\n", + " 'n_forecast_days': 7,\n", + " 'north': False,\n", + " 'num_channels': 9,\n", + " 'shape': [432, 432],\n", + " 'south': True,\n", + " 'dataset_path': False,\n", + " 'generate_workers': 4,\n", + " 'loss_weight_days': True,\n", + " 'output_batch_size': 4,\n", + " 'var_lag': 1,\n", + " 'var_lag_override': {}}" + ] + }, + 
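Since the configuration shown above is also written to disk by dl.write_dataset_config_only(), it can be inspected without instantiating IceNetDataSet at all. A small sketch using only the standard library, under the assumption that the on-disk JSON mirrors the dict printed above:

import json

with open("dataset_config.pytorch_notebook.json") as f:
    config = json.load(f)

print(config["counts"])           # {'train': 91, 'val': 21, 'test': 2}
print(config["num_channels"])     # 9 input channels
print(config["n_forecast_days"])  # 7-day forecast horizon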
"execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset._config" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "7" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset._config[\"n_forecast_days\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'/data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json'" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset.loader_config" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json\n" + ] + } + ], + "source": [ + "dataloader_from_dataset = dataset.get_data_loader()" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['sources', 'dtype', 'shape', 'missing_dates'])" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataloader_from_dataset._config.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "7" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataloader_from_dataset._n_forecast_days" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataloader_from_dataset.workers" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataloader_from_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Custom PyTorch Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", + "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", + "INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json\n" + ] + } + ], + "source": [ + "ds_torch = IceNetDataSetPyTorch(configuration_path=dataset_config, mode=\"train\")" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "91" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_torch.__len__()" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'2020-01-01'" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_torch._dates[0]" + ] + }, + { + "cell_type": "code", + 
"execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "forecast_date: 2020-01-01 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-01 00:00:00'), Timestamp('2020-01-02 00:00:00'), Timestamp('2020-01-03 00:00:00'), Timestamp('2020-01-04 00:00:00'), Timestamp('2020-01-05 00:00:00'), Timestamp('2020-01-06 00:00:00'), Timestamp('2020-01-07 00:00:00')]\n" + ] + } + ], + "source": [ + "first_item = ds_torch.__getitem__(0)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "first_item is a of length 3\n" + ] + } + ], + "source": [ + "print(f\"first_item is a {type(first_item)} of length {len(first_item)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "first_item[i] is a of shape (432, 432, 9)\n", + "first_item[i] is a of shape (432, 432, 7, 1)\n", + "first_item[i] is a of shape (432, 432, 7, 1)\n" + ] + } + ], + "source": [ + "for i in range(len(first_item)):\n", + " print(f\"first_item[i] is a {type(first_item[i])} of shape {first_item[i].shape}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'identifier': 'pytorch_notebook',\n", + " 'implementation': 'DaskMultiWorkerLoader',\n", + " 'channels': ['uas_abs_1',\n", + " 'vas_abs_1',\n", + " 'siconca_abs_1',\n", + " 'tas_anom_1',\n", + " 'zg250_anom_1',\n", + " 'zg500_anom_1',\n", + " 'cos_1',\n", + " 'land_1',\n", + " 'sin_1'],\n", + " 'counts': {'train': 91, 'val': 21, 'test': 2},\n", + " 'dtype': 'float32',\n", + " 'loader_config': '/data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json',\n", + " 'missing_dates': [],\n", + " 'n_forecast_days': 7,\n", + " 'north': False,\n", + " 'num_channels': 9,\n", + " 'shape': [432, 432],\n", + " 'south': True,\n", + " 'dataset_path': False,\n", + " 'generate_workers': 4,\n", + " 'loss_weight_days': True,\n", + " 'output_batch_size': 4,\n", + " 'var_lag': 1,\n", + " 'var_lag_override': {}}" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_torch._ds._config" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_torch._ds" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "forecast_date: 2020-03-21 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-21 00:00:00'), Timestamp('2020-03-22 00:00:00'), Timestamp('2020-03-23 00:00:00'), Timestamp('2020-03-24 00:00:00'), Timestamp('2020-03-25 00:00:00'), Timestamp('2020-03-26 00:00:00'), Timestamp('2020-03-27 00:00:00')]\n" + ] + }, + { + "data": { + "text/plain": [ + "(array([[[ 0.5070619 , 0.507128 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.5067984 , 0.5119969 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.50709283, 0.5130819 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " ...,\n", + " [ 0.49436113, 0.51067567, 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.49760532, 0.5123943 , 0. , ..., -0.18561055,\n", + " 1. 
, -0.9826234 ],\n", + " [ 0.5012724 , 0.5121377 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ]],\n", + " \n", + " [[ 0.5080182 , 0.5084777 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.50926554, 0.51501405, 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.50827605, 0.5093369 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " ...,\n", + " [ 0.4894217 , 0.5110176 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.49363175, 0.51387256, 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.49781784, 0.5137171 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ]],\n", + " \n", + " [[ 0.5070832 , 0.5101056 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.5071541 , 0.5147865 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.51002264, 0.5079214 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " ...,\n", + " [ 0.48842257, 0.5124112 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.49018812, 0.5134328 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.4937481 , 0.5152318 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.5201201 , 0.5939807 , 0. , ..., -0.18561055,\n", + " -1. , -0.9826234 ],\n", + " [ 0.5196981 , 0.5929644 , 0. , ..., -0.18561055,\n", + " -1. , -0.9826234 ],\n", + " [ 0.5230958 , 0.5898119 , 0. , ..., -0.18561055,\n", + " -1. , -0.9826234 ],\n", + " ...,\n", + " [ 0.53455853, 0.4542398 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.5359504 , 0.45726374, 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.53147364, 0.462161 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ]],\n", + " \n", + " [[ 0.52010256, 0.5904959 , 0. , ..., -0.18561055,\n", + " -1. , -0.9826234 ],\n", + " [ 0.5158349 , 0.59004027, 0. , ..., -0.18561055,\n", + " -1. , -0.9826234 ],\n", + " [ 0.5168654 , 0.58813155, 0. , ..., -0.18561055,\n", + " -1. , -0.9826234 ],\n", + " ...,\n", + " [ 0.5326147 , 0.46137762, 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.53047955, 0.46351588, 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.5268573 , 0.46342507, 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ]],\n", + " \n", + " [[ 0.5224085 , 0.5875143 , 0. , ..., -0.18561055,\n", + " -1. , -0.9826234 ],\n", + " [ 0.51823205, 0.58804107, 0. , ..., -0.18561055,\n", + " -1. , -0.9826234 ],\n", + " [ 0.5156312 , 0.5871129 , 0. , ..., -0.18561055,\n", + " -1. , -0.9826234 ],\n", + " ...,\n", + " [ 0.52451056, 0.4664941 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.5236765 , 0.4643196 , 0. , ..., -0.18561055,\n", + " 1. , -0.9826234 ],\n", + " [ 0.5230731 , 0.46157268, 0. , ..., -0.18561055,\n", + " 1. 
, -0.9826234 ]]], dtype=float32),\n", + " array([[[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + 
" [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]]], dtype=float32),\n", + " array([[[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " 
[0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]]], dtype=float32))" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_torch.__getitem__(80)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "forecast_date: 2020-01-01 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-01 00:00:00'), Timestamp('2020-01-02 00:00:00'), Timestamp('2020-01-03 00:00:00'), Timestamp('2020-01-04 00:00:00'), Timestamp('2020-01-05 00:00:00'), Timestamp('2020-01-06 00:00:00'), Timestamp('2020-01-07 00:00:00')]\n", + "##### i: 0\n", + "ds_torch.__getitem__(0) is a of length 3\n", + "ds_torch.__getitem__(0)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(0)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(0)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-02 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-02 00:00:00'), Timestamp('2020-01-03 00:00:00'), Timestamp('2020-01-04 00:00:00'), Timestamp('2020-01-05 00:00:00'), Timestamp('2020-01-06 00:00:00'), Timestamp('2020-01-07 00:00:00'), Timestamp('2020-01-08 00:00:00')]\n", + "##### i: 1\n", + "ds_torch.__getitem__(1) is a of length 3\n", + "ds_torch.__getitem__(1)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(1)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(1)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-03 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-03 00:00:00'), Timestamp('2020-01-04 00:00:00'), Timestamp('2020-01-05 00:00:00'), Timestamp('2020-01-06 00:00:00'), Timestamp('2020-01-07 00:00:00'), Timestamp('2020-01-08 00:00:00'), Timestamp('2020-01-09 00:00:00')]\n", + "##### i: 2\n", + "ds_torch.__getitem__(2) is a of length 3\n", + "ds_torch.__getitem__(2)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(2)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(2)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-04 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-04 00:00:00'), Timestamp('2020-01-05 00:00:00'), Timestamp('2020-01-06 00:00:00'), Timestamp('2020-01-07 00:00:00'), Timestamp('2020-01-08 00:00:00'), Timestamp('2020-01-09 00:00:00'), Timestamp('2020-01-10 00:00:00')]\n", + "##### i: 3\n", + "ds_torch.__getitem__(3) is a of length 3\n", + "ds_torch.__getitem__(3)[0] is a 
of shape (432, 432, 9)\n", + "ds_torch.__getitem__(3)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(3)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-05 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-05 00:00:00'), Timestamp('2020-01-06 00:00:00'), Timestamp('2020-01-07 00:00:00'), Timestamp('2020-01-08 00:00:00'), Timestamp('2020-01-09 00:00:00'), Timestamp('2020-01-10 00:00:00'), Timestamp('2020-01-11 00:00:00')]\n", + "##### i: 4\n", + "ds_torch.__getitem__(4) is a of length 3\n", + "ds_torch.__getitem__(4)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(4)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(4)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-06 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-06 00:00:00'), Timestamp('2020-01-07 00:00:00'), Timestamp('2020-01-08 00:00:00'), Timestamp('2020-01-09 00:00:00'), Timestamp('2020-01-10 00:00:00'), Timestamp('2020-01-11 00:00:00'), Timestamp('2020-01-12 00:00:00')]\n", + "##### i: 5\n", + "ds_torch.__getitem__(5) is a of length 3\n", + "ds_torch.__getitem__(5)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(5)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(5)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-07 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-07 00:00:00'), Timestamp('2020-01-08 00:00:00'), Timestamp('2020-01-09 00:00:00'), Timestamp('2020-01-10 00:00:00'), Timestamp('2020-01-11 00:00:00'), Timestamp('2020-01-12 00:00:00'), Timestamp('2020-01-13 00:00:00')]\n", + "##### i: 6\n", + "ds_torch.__getitem__(6) is a of length 3\n", + "ds_torch.__getitem__(6)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(6)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(6)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-08 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-08 00:00:00'), Timestamp('2020-01-09 00:00:00'), Timestamp('2020-01-10 00:00:00'), Timestamp('2020-01-11 00:00:00'), Timestamp('2020-01-12 00:00:00'), Timestamp('2020-01-13 00:00:00'), Timestamp('2020-01-14 00:00:00')]\n", + "##### i: 7\n", + "ds_torch.__getitem__(7) is a of length 3\n", + "ds_torch.__getitem__(7)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(7)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(7)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-09 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-09 00:00:00'), Timestamp('2020-01-10 00:00:00'), Timestamp('2020-01-11 00:00:00'), Timestamp('2020-01-12 00:00:00'), Timestamp('2020-01-13 00:00:00'), Timestamp('2020-01-14 00:00:00'), Timestamp('2020-01-15 00:00:00')]\n", + "##### i: 8\n", + "ds_torch.__getitem__(8) is a of length 3\n", + "ds_torch.__getitem__(8)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(8)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(8)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-10 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-10 00:00:00'), Timestamp('2020-01-11 00:00:00'), Timestamp('2020-01-12 00:00:00'), Timestamp('2020-01-13 00:00:00'), Timestamp('2020-01-14 00:00:00'), Timestamp('2020-01-15 00:00:00'), Timestamp('2020-01-16 00:00:00')]\n", + "##### i: 9\n", + "ds_torch.__getitem__(9) is a of length 3\n", + "ds_torch.__getitem__(9)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(9)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(9)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-11 
00:00:00\n", + "forecast_dts: [Timestamp('2020-01-11 00:00:00'), Timestamp('2020-01-12 00:00:00'), Timestamp('2020-01-13 00:00:00'), Timestamp('2020-01-14 00:00:00'), Timestamp('2020-01-15 00:00:00'), Timestamp('2020-01-16 00:00:00'), Timestamp('2020-01-17 00:00:00')]\n", + "##### i: 10\n", + "ds_torch.__getitem__(10) is a of length 3\n", + "ds_torch.__getitem__(10)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(10)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(10)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-12 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-12 00:00:00'), Timestamp('2020-01-13 00:00:00'), Timestamp('2020-01-14 00:00:00'), Timestamp('2020-01-15 00:00:00'), Timestamp('2020-01-16 00:00:00'), Timestamp('2020-01-17 00:00:00'), Timestamp('2020-01-18 00:00:00')]\n", + "##### i: 11\n", + "ds_torch.__getitem__(11) is a of length 3\n", + "ds_torch.__getitem__(11)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(11)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(11)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-13 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-13 00:00:00'), Timestamp('2020-01-14 00:00:00'), Timestamp('2020-01-15 00:00:00'), Timestamp('2020-01-16 00:00:00'), Timestamp('2020-01-17 00:00:00'), Timestamp('2020-01-18 00:00:00'), Timestamp('2020-01-19 00:00:00')]\n", + "##### i: 12\n", + "ds_torch.__getitem__(12) is a of length 3\n", + "ds_torch.__getitem__(12)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(12)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(12)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-14 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-14 00:00:00'), Timestamp('2020-01-15 00:00:00'), Timestamp('2020-01-16 00:00:00'), Timestamp('2020-01-17 00:00:00'), Timestamp('2020-01-18 00:00:00'), Timestamp('2020-01-19 00:00:00'), Timestamp('2020-01-20 00:00:00')]\n", + "##### i: 13\n", + "ds_torch.__getitem__(13) is a of length 3\n", + "ds_torch.__getitem__(13)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(13)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(13)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-15 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-15 00:00:00'), Timestamp('2020-01-16 00:00:00'), Timestamp('2020-01-17 00:00:00'), Timestamp('2020-01-18 00:00:00'), Timestamp('2020-01-19 00:00:00'), Timestamp('2020-01-20 00:00:00'), Timestamp('2020-01-21 00:00:00')]\n", + "##### i: 14\n", + "ds_torch.__getitem__(14) is a of length 3\n", + "ds_torch.__getitem__(14)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(14)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(14)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-16 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-16 00:00:00'), Timestamp('2020-01-17 00:00:00'), Timestamp('2020-01-18 00:00:00'), Timestamp('2020-01-19 00:00:00'), Timestamp('2020-01-20 00:00:00'), Timestamp('2020-01-21 00:00:00'), Timestamp('2020-01-22 00:00:00')]\n", + "##### i: 15\n", + "ds_torch.__getitem__(15) is a of length 3\n", + "ds_torch.__getitem__(15)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(15)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(15)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-17 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-17 00:00:00'), Timestamp('2020-01-18 00:00:00'), Timestamp('2020-01-19 00:00:00'), Timestamp('2020-01-20 
00:00:00'), Timestamp('2020-01-21 00:00:00'), Timestamp('2020-01-22 00:00:00'), Timestamp('2020-01-23 00:00:00')]\n", + "##### i: 16\n", + "ds_torch.__getitem__(16) is a of length 3\n", + "ds_torch.__getitem__(16)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(16)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(16)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-18 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-18 00:00:00'), Timestamp('2020-01-19 00:00:00'), Timestamp('2020-01-20 00:00:00'), Timestamp('2020-01-21 00:00:00'), Timestamp('2020-01-22 00:00:00'), Timestamp('2020-01-23 00:00:00'), Timestamp('2020-01-24 00:00:00')]\n", + "##### i: 17\n", + "ds_torch.__getitem__(17) is a of length 3\n", + "ds_torch.__getitem__(17)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(17)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(17)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-19 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-19 00:00:00'), Timestamp('2020-01-20 00:00:00'), Timestamp('2020-01-21 00:00:00'), Timestamp('2020-01-22 00:00:00'), Timestamp('2020-01-23 00:00:00'), Timestamp('2020-01-24 00:00:00'), Timestamp('2020-01-25 00:00:00')]\n", + "##### i: 18\n", + "ds_torch.__getitem__(18) is a of length 3\n", + "ds_torch.__getitem__(18)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(18)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(18)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-20 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-20 00:00:00'), Timestamp('2020-01-21 00:00:00'), Timestamp('2020-01-22 00:00:00'), Timestamp('2020-01-23 00:00:00'), Timestamp('2020-01-24 00:00:00'), Timestamp('2020-01-25 00:00:00'), Timestamp('2020-01-26 00:00:00')]\n", + "##### i: 19\n", + "ds_torch.__getitem__(19) is a of length 3\n", + "ds_torch.__getitem__(19)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(19)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(19)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-21 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-21 00:00:00'), Timestamp('2020-01-22 00:00:00'), Timestamp('2020-01-23 00:00:00'), Timestamp('2020-01-24 00:00:00'), Timestamp('2020-01-25 00:00:00'), Timestamp('2020-01-26 00:00:00'), Timestamp('2020-01-27 00:00:00')]\n", + "##### i: 20\n", + "ds_torch.__getitem__(20) is a of length 3\n", + "ds_torch.__getitem__(20)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(20)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(20)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-22 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-22 00:00:00'), Timestamp('2020-01-23 00:00:00'), Timestamp('2020-01-24 00:00:00'), Timestamp('2020-01-25 00:00:00'), Timestamp('2020-01-26 00:00:00'), Timestamp('2020-01-27 00:00:00'), Timestamp('2020-01-28 00:00:00')]\n", + "##### i: 21\n", + "ds_torch.__getitem__(21) is a of length 3\n", + "ds_torch.__getitem__(21)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(21)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(21)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-23 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-23 00:00:00'), Timestamp('2020-01-24 00:00:00'), Timestamp('2020-01-25 00:00:00'), Timestamp('2020-01-26 00:00:00'), Timestamp('2020-01-27 00:00:00'), Timestamp('2020-01-28 00:00:00'), Timestamp('2020-01-29 00:00:00')]\n", + "##### i: 22\n", + "ds_torch.__getitem__(22) 
is a of length 3\n", + "ds_torch.__getitem__(22)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(22)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(22)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-24 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-24 00:00:00'), Timestamp('2020-01-25 00:00:00'), Timestamp('2020-01-26 00:00:00'), Timestamp('2020-01-27 00:00:00'), Timestamp('2020-01-28 00:00:00'), Timestamp('2020-01-29 00:00:00'), Timestamp('2020-01-30 00:00:00')]\n", + "##### i: 23\n", + "ds_torch.__getitem__(23) is a of length 3\n", + "ds_torch.__getitem__(23)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(23)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(23)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-25 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-25 00:00:00'), Timestamp('2020-01-26 00:00:00'), Timestamp('2020-01-27 00:00:00'), Timestamp('2020-01-28 00:00:00'), Timestamp('2020-01-29 00:00:00'), Timestamp('2020-01-30 00:00:00'), Timestamp('2020-01-31 00:00:00')]\n", + "##### i: 24\n", + "ds_torch.__getitem__(24) is a of length 3\n", + "ds_torch.__getitem__(24)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(24)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(24)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-26 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-26 00:00:00'), Timestamp('2020-01-27 00:00:00'), Timestamp('2020-01-28 00:00:00'), Timestamp('2020-01-29 00:00:00'), Timestamp('2020-01-30 00:00:00'), Timestamp('2020-01-31 00:00:00'), Timestamp('2020-02-01 00:00:00')]\n", + "##### i: 25\n", + "ds_torch.__getitem__(25) is a of length 3\n", + "ds_torch.__getitem__(25)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(25)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(25)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-27 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-27 00:00:00'), Timestamp('2020-01-28 00:00:00'), Timestamp('2020-01-29 00:00:00'), Timestamp('2020-01-30 00:00:00'), Timestamp('2020-01-31 00:00:00'), Timestamp('2020-02-01 00:00:00'), Timestamp('2020-02-02 00:00:00')]\n", + "##### i: 26\n", + "ds_torch.__getitem__(26) is a of length 3\n", + "ds_torch.__getitem__(26)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(26)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(26)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-28 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-28 00:00:00'), Timestamp('2020-01-29 00:00:00'), Timestamp('2020-01-30 00:00:00'), Timestamp('2020-01-31 00:00:00'), Timestamp('2020-02-01 00:00:00'), Timestamp('2020-02-02 00:00:00'), Timestamp('2020-02-03 00:00:00')]\n", + "##### i: 27\n", + "ds_torch.__getitem__(27) is a of length 3\n", + "ds_torch.__getitem__(27)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(27)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(27)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-29 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-29 00:00:00'), Timestamp('2020-01-30 00:00:00'), Timestamp('2020-01-31 00:00:00'), Timestamp('2020-02-01 00:00:00'), Timestamp('2020-02-02 00:00:00'), Timestamp('2020-02-03 00:00:00'), Timestamp('2020-02-04 00:00:00')]\n", + "##### i: 28\n", + "ds_torch.__getitem__(28) is a of length 3\n", + "ds_torch.__getitem__(28)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(28)[1] is a of shape (432, 432, 7, 1)\n", + 
"ds_torch.__getitem__(28)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-30 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-30 00:00:00'), Timestamp('2020-01-31 00:00:00'), Timestamp('2020-02-01 00:00:00'), Timestamp('2020-02-02 00:00:00'), Timestamp('2020-02-03 00:00:00'), Timestamp('2020-02-04 00:00:00'), Timestamp('2020-02-05 00:00:00')]\n", + "##### i: 29\n", + "ds_torch.__getitem__(29) is a of length 3\n", + "ds_torch.__getitem__(29)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(29)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(29)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-01-31 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-31 00:00:00'), Timestamp('2020-02-01 00:00:00'), Timestamp('2020-02-02 00:00:00'), Timestamp('2020-02-03 00:00:00'), Timestamp('2020-02-04 00:00:00'), Timestamp('2020-02-05 00:00:00'), Timestamp('2020-02-06 00:00:00')]\n", + "##### i: 30\n", + "ds_torch.__getitem__(30) is a of length 3\n", + "ds_torch.__getitem__(30)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(30)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(30)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-01 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-01 00:00:00'), Timestamp('2020-02-02 00:00:00'), Timestamp('2020-02-03 00:00:00'), Timestamp('2020-02-04 00:00:00'), Timestamp('2020-02-05 00:00:00'), Timestamp('2020-02-06 00:00:00'), Timestamp('2020-02-07 00:00:00')]\n", + "##### i: 31\n", + "ds_torch.__getitem__(31) is a of length 3\n", + "ds_torch.__getitem__(31)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(31)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(31)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-02 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-02 00:00:00'), Timestamp('2020-02-03 00:00:00'), Timestamp('2020-02-04 00:00:00'), Timestamp('2020-02-05 00:00:00'), Timestamp('2020-02-06 00:00:00'), Timestamp('2020-02-07 00:00:00'), Timestamp('2020-02-08 00:00:00')]\n", + "##### i: 32\n", + "ds_torch.__getitem__(32) is a of length 3\n", + "ds_torch.__getitem__(32)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(32)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(32)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-03 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-03 00:00:00'), Timestamp('2020-02-04 00:00:00'), Timestamp('2020-02-05 00:00:00'), Timestamp('2020-02-06 00:00:00'), Timestamp('2020-02-07 00:00:00'), Timestamp('2020-02-08 00:00:00'), Timestamp('2020-02-09 00:00:00')]\n", + "##### i: 33\n", + "ds_torch.__getitem__(33) is a of length 3\n", + "ds_torch.__getitem__(33)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(33)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(33)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-04 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-04 00:00:00'), Timestamp('2020-02-05 00:00:00'), Timestamp('2020-02-06 00:00:00'), Timestamp('2020-02-07 00:00:00'), Timestamp('2020-02-08 00:00:00'), Timestamp('2020-02-09 00:00:00'), Timestamp('2020-02-10 00:00:00')]\n", + "##### i: 34\n", + "ds_torch.__getitem__(34) is a of length 3\n", + "ds_torch.__getitem__(34)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(34)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(34)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-05 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-05 00:00:00'), 
Timestamp('2020-02-06 00:00:00'), Timestamp('2020-02-07 00:00:00'), Timestamp('2020-02-08 00:00:00'), Timestamp('2020-02-09 00:00:00'), Timestamp('2020-02-10 00:00:00'), Timestamp('2020-02-11 00:00:00')]\n", + "##### i: 35\n", + "ds_torch.__getitem__(35) is a of length 3\n", + "ds_torch.__getitem__(35)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(35)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(35)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-06 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-06 00:00:00'), Timestamp('2020-02-07 00:00:00'), Timestamp('2020-02-08 00:00:00'), Timestamp('2020-02-09 00:00:00'), Timestamp('2020-02-10 00:00:00'), Timestamp('2020-02-11 00:00:00'), Timestamp('2020-02-12 00:00:00')]\n", + "##### i: 36\n", + "ds_torch.__getitem__(36) is a of length 3\n", + "ds_torch.__getitem__(36)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(36)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(36)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-07 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-07 00:00:00'), Timestamp('2020-02-08 00:00:00'), Timestamp('2020-02-09 00:00:00'), Timestamp('2020-02-10 00:00:00'), Timestamp('2020-02-11 00:00:00'), Timestamp('2020-02-12 00:00:00'), Timestamp('2020-02-13 00:00:00')]\n", + "##### i: 37\n", + "ds_torch.__getitem__(37) is a of length 3\n", + "ds_torch.__getitem__(37)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(37)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(37)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-08 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-08 00:00:00'), Timestamp('2020-02-09 00:00:00'), Timestamp('2020-02-10 00:00:00'), Timestamp('2020-02-11 00:00:00'), Timestamp('2020-02-12 00:00:00'), Timestamp('2020-02-13 00:00:00'), Timestamp('2020-02-14 00:00:00')]\n", + "##### i: 38\n", + "ds_torch.__getitem__(38) is a of length 3\n", + "ds_torch.__getitem__(38)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(38)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(38)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-09 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-09 00:00:00'), Timestamp('2020-02-10 00:00:00'), Timestamp('2020-02-11 00:00:00'), Timestamp('2020-02-12 00:00:00'), Timestamp('2020-02-13 00:00:00'), Timestamp('2020-02-14 00:00:00'), Timestamp('2020-02-15 00:00:00')]\n", + "##### i: 39\n", + "ds_torch.__getitem__(39) is a of length 3\n", + "ds_torch.__getitem__(39)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(39)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(39)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-10 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-10 00:00:00'), Timestamp('2020-02-11 00:00:00'), Timestamp('2020-02-12 00:00:00'), Timestamp('2020-02-13 00:00:00'), Timestamp('2020-02-14 00:00:00'), Timestamp('2020-02-15 00:00:00'), Timestamp('2020-02-16 00:00:00')]\n", + "##### i: 40\n", + "ds_torch.__getitem__(40) is a of length 3\n", + "ds_torch.__getitem__(40)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(40)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(40)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-11 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-11 00:00:00'), Timestamp('2020-02-12 00:00:00'), Timestamp('2020-02-13 00:00:00'), Timestamp('2020-02-14 00:00:00'), Timestamp('2020-02-15 00:00:00'), Timestamp('2020-02-16 
00:00:00'), Timestamp('2020-02-17 00:00:00')]\n", + "##### i: 41\n", + "ds_torch.__getitem__(41) is a of length 3\n", + "ds_torch.__getitem__(41)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(41)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(41)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-12 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-12 00:00:00'), Timestamp('2020-02-13 00:00:00'), Timestamp('2020-02-14 00:00:00'), Timestamp('2020-02-15 00:00:00'), Timestamp('2020-02-16 00:00:00'), Timestamp('2020-02-17 00:00:00'), Timestamp('2020-02-18 00:00:00')]\n", + "##### i: 42\n", + "ds_torch.__getitem__(42) is a of length 3\n", + "ds_torch.__getitem__(42)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(42)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(42)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-13 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-13 00:00:00'), Timestamp('2020-02-14 00:00:00'), Timestamp('2020-02-15 00:00:00'), Timestamp('2020-02-16 00:00:00'), Timestamp('2020-02-17 00:00:00'), Timestamp('2020-02-18 00:00:00'), Timestamp('2020-02-19 00:00:00')]\n", + "##### i: 43\n", + "ds_torch.__getitem__(43) is a of length 3\n", + "ds_torch.__getitem__(43)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(43)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(43)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-14 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-14 00:00:00'), Timestamp('2020-02-15 00:00:00'), Timestamp('2020-02-16 00:00:00'), Timestamp('2020-02-17 00:00:00'), Timestamp('2020-02-18 00:00:00'), Timestamp('2020-02-19 00:00:00'), Timestamp('2020-02-20 00:00:00')]\n", + "##### i: 44\n", + "ds_torch.__getitem__(44) is a of length 3\n", + "ds_torch.__getitem__(44)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(44)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(44)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-15 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-15 00:00:00'), Timestamp('2020-02-16 00:00:00'), Timestamp('2020-02-17 00:00:00'), Timestamp('2020-02-18 00:00:00'), Timestamp('2020-02-19 00:00:00'), Timestamp('2020-02-20 00:00:00'), Timestamp('2020-02-21 00:00:00')]\n", + "##### i: 45\n", + "ds_torch.__getitem__(45) is a of length 3\n", + "ds_torch.__getitem__(45)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(45)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(45)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-16 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-16 00:00:00'), Timestamp('2020-02-17 00:00:00'), Timestamp('2020-02-18 00:00:00'), Timestamp('2020-02-19 00:00:00'), Timestamp('2020-02-20 00:00:00'), Timestamp('2020-02-21 00:00:00'), Timestamp('2020-02-22 00:00:00')]\n", + "##### i: 46\n", + "ds_torch.__getitem__(46) is a of length 3\n", + "ds_torch.__getitem__(46)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(46)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(46)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-17 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-17 00:00:00'), Timestamp('2020-02-18 00:00:00'), Timestamp('2020-02-19 00:00:00'), Timestamp('2020-02-20 00:00:00'), Timestamp('2020-02-21 00:00:00'), Timestamp('2020-02-22 00:00:00'), Timestamp('2020-02-23 00:00:00')]\n", + "##### i: 47\n", + "ds_torch.__getitem__(47) is a of length 3\n", + "ds_torch.__getitem__(47)[0] is a of shape 
(432, 432, 9)\n", + "ds_torch.__getitem__(47)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(47)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-18 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-18 00:00:00'), Timestamp('2020-02-19 00:00:00'), Timestamp('2020-02-20 00:00:00'), Timestamp('2020-02-21 00:00:00'), Timestamp('2020-02-22 00:00:00'), Timestamp('2020-02-23 00:00:00'), Timestamp('2020-02-24 00:00:00')]\n", + "##### i: 48\n", + "ds_torch.__getitem__(48) is a of length 3\n", + "ds_torch.__getitem__(48)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(48)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(48)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-19 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-19 00:00:00'), Timestamp('2020-02-20 00:00:00'), Timestamp('2020-02-21 00:00:00'), Timestamp('2020-02-22 00:00:00'), Timestamp('2020-02-23 00:00:00'), Timestamp('2020-02-24 00:00:00'), Timestamp('2020-02-25 00:00:00')]\n", + "##### i: 49\n", + "ds_torch.__getitem__(49) is a of length 3\n", + "ds_torch.__getitem__(49)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(49)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(49)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-20 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-20 00:00:00'), Timestamp('2020-02-21 00:00:00'), Timestamp('2020-02-22 00:00:00'), Timestamp('2020-02-23 00:00:00'), Timestamp('2020-02-24 00:00:00'), Timestamp('2020-02-25 00:00:00'), Timestamp('2020-02-26 00:00:00')]\n", + "##### i: 50\n", + "ds_torch.__getitem__(50) is a of length 3\n", + "ds_torch.__getitem__(50)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(50)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(50)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-21 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-21 00:00:00'), Timestamp('2020-02-22 00:00:00'), Timestamp('2020-02-23 00:00:00'), Timestamp('2020-02-24 00:00:00'), Timestamp('2020-02-25 00:00:00'), Timestamp('2020-02-26 00:00:00'), Timestamp('2020-02-27 00:00:00')]\n", + "##### i: 51\n", + "ds_torch.__getitem__(51) is a of length 3\n", + "ds_torch.__getitem__(51)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(51)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(51)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-22 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-22 00:00:00'), Timestamp('2020-02-23 00:00:00'), Timestamp('2020-02-24 00:00:00'), Timestamp('2020-02-25 00:00:00'), Timestamp('2020-02-26 00:00:00'), Timestamp('2020-02-27 00:00:00'), Timestamp('2020-02-28 00:00:00')]\n", + "##### i: 52\n", + "ds_torch.__getitem__(52) is a of length 3\n", + "ds_torch.__getitem__(52)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(52)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(52)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-23 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-23 00:00:00'), Timestamp('2020-02-24 00:00:00'), Timestamp('2020-02-25 00:00:00'), Timestamp('2020-02-26 00:00:00'), Timestamp('2020-02-27 00:00:00'), Timestamp('2020-02-28 00:00:00'), Timestamp('2020-02-29 00:00:00')]\n", + "##### i: 53\n", + "ds_torch.__getitem__(53) is a of length 3\n", + "ds_torch.__getitem__(53)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(53)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(53)[2] is a of shape (432, 432, 7, 1)\n", + 
"forecast_date: 2020-02-24 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-24 00:00:00'), Timestamp('2020-02-25 00:00:00'), Timestamp('2020-02-26 00:00:00'), Timestamp('2020-02-27 00:00:00'), Timestamp('2020-02-28 00:00:00'), Timestamp('2020-02-29 00:00:00'), Timestamp('2020-03-01 00:00:00')]\n", + "##### i: 54\n", + "ds_torch.__getitem__(54) is a of length 3\n", + "ds_torch.__getitem__(54)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(54)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(54)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-25 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-25 00:00:00'), Timestamp('2020-02-26 00:00:00'), Timestamp('2020-02-27 00:00:00'), Timestamp('2020-02-28 00:00:00'), Timestamp('2020-02-29 00:00:00'), Timestamp('2020-03-01 00:00:00'), Timestamp('2020-03-02 00:00:00')]\n", + "##### i: 55\n", + "ds_torch.__getitem__(55) is a of length 3\n", + "ds_torch.__getitem__(55)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(55)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(55)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-26 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-26 00:00:00'), Timestamp('2020-02-27 00:00:00'), Timestamp('2020-02-28 00:00:00'), Timestamp('2020-02-29 00:00:00'), Timestamp('2020-03-01 00:00:00'), Timestamp('2020-03-02 00:00:00'), Timestamp('2020-03-03 00:00:00')]\n", + "##### i: 56\n", + "ds_torch.__getitem__(56) is a of length 3\n", + "ds_torch.__getitem__(56)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(56)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(56)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-27 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-27 00:00:00'), Timestamp('2020-02-28 00:00:00'), Timestamp('2020-02-29 00:00:00'), Timestamp('2020-03-01 00:00:00'), Timestamp('2020-03-02 00:00:00'), Timestamp('2020-03-03 00:00:00'), Timestamp('2020-03-04 00:00:00')]\n", + "##### i: 57\n", + "ds_torch.__getitem__(57) is a of length 3\n", + "ds_torch.__getitem__(57)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(57)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(57)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-28 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-28 00:00:00'), Timestamp('2020-02-29 00:00:00'), Timestamp('2020-03-01 00:00:00'), Timestamp('2020-03-02 00:00:00'), Timestamp('2020-03-03 00:00:00'), Timestamp('2020-03-04 00:00:00'), Timestamp('2020-03-05 00:00:00')]\n", + "##### i: 58\n", + "ds_torch.__getitem__(58) is a of length 3\n", + "ds_torch.__getitem__(58)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(58)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(58)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-02-29 00:00:00\n", + "forecast_dts: [Timestamp('2020-02-29 00:00:00'), Timestamp('2020-03-01 00:00:00'), Timestamp('2020-03-02 00:00:00'), Timestamp('2020-03-03 00:00:00'), Timestamp('2020-03-04 00:00:00'), Timestamp('2020-03-05 00:00:00'), Timestamp('2020-03-06 00:00:00')]\n", + "##### i: 59\n", + "ds_torch.__getitem__(59) is a of length 3\n", + "ds_torch.__getitem__(59)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(59)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(59)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-01 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-01 00:00:00'), Timestamp('2020-03-02 00:00:00'), Timestamp('2020-03-03 00:00:00'), 
Timestamp('2020-03-04 00:00:00'), Timestamp('2020-03-05 00:00:00'), Timestamp('2020-03-06 00:00:00'), Timestamp('2020-03-07 00:00:00')]\n", + "##### i: 60\n", + "ds_torch.__getitem__(60) is a of length 3\n", + "ds_torch.__getitem__(60)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(60)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(60)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-02 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-02 00:00:00'), Timestamp('2020-03-03 00:00:00'), Timestamp('2020-03-04 00:00:00'), Timestamp('2020-03-05 00:00:00'), Timestamp('2020-03-06 00:00:00'), Timestamp('2020-03-07 00:00:00'), Timestamp('2020-03-08 00:00:00')]\n", + "##### i: 61\n", + "ds_torch.__getitem__(61) is a of length 3\n", + "ds_torch.__getitem__(61)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(61)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(61)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-03 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-03 00:00:00'), Timestamp('2020-03-04 00:00:00'), Timestamp('2020-03-05 00:00:00'), Timestamp('2020-03-06 00:00:00'), Timestamp('2020-03-07 00:00:00'), Timestamp('2020-03-08 00:00:00'), Timestamp('2020-03-09 00:00:00')]\n", + "##### i: 62\n", + "ds_torch.__getitem__(62) is a of length 3\n", + "ds_torch.__getitem__(62)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(62)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(62)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-04 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-04 00:00:00'), Timestamp('2020-03-05 00:00:00'), Timestamp('2020-03-06 00:00:00'), Timestamp('2020-03-07 00:00:00'), Timestamp('2020-03-08 00:00:00'), Timestamp('2020-03-09 00:00:00'), Timestamp('2020-03-10 00:00:00')]\n", + "##### i: 63\n", + "ds_torch.__getitem__(63) is a of length 3\n", + "ds_torch.__getitem__(63)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(63)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(63)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-05 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-05 00:00:00'), Timestamp('2020-03-06 00:00:00'), Timestamp('2020-03-07 00:00:00'), Timestamp('2020-03-08 00:00:00'), Timestamp('2020-03-09 00:00:00'), Timestamp('2020-03-10 00:00:00'), Timestamp('2020-03-11 00:00:00')]\n", + "##### i: 64\n", + "ds_torch.__getitem__(64) is a of length 3\n", + "ds_torch.__getitem__(64)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(64)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(64)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-06 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-06 00:00:00'), Timestamp('2020-03-07 00:00:00'), Timestamp('2020-03-08 00:00:00'), Timestamp('2020-03-09 00:00:00'), Timestamp('2020-03-10 00:00:00'), Timestamp('2020-03-11 00:00:00'), Timestamp('2020-03-12 00:00:00')]\n", + "##### i: 65\n", + "ds_torch.__getitem__(65) is a of length 3\n", + "ds_torch.__getitem__(65)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(65)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(65)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-07 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-07 00:00:00'), Timestamp('2020-03-08 00:00:00'), Timestamp('2020-03-09 00:00:00'), Timestamp('2020-03-10 00:00:00'), Timestamp('2020-03-11 00:00:00'), Timestamp('2020-03-12 00:00:00'), Timestamp('2020-03-13 00:00:00')]\n", + "##### i: 66\n", + 
"ds_torch.__getitem__(66) is a of length 3\n", + "ds_torch.__getitem__(66)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(66)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(66)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-08 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-08 00:00:00'), Timestamp('2020-03-09 00:00:00'), Timestamp('2020-03-10 00:00:00'), Timestamp('2020-03-11 00:00:00'), Timestamp('2020-03-12 00:00:00'), Timestamp('2020-03-13 00:00:00'), Timestamp('2020-03-14 00:00:00')]\n", + "##### i: 67\n", + "ds_torch.__getitem__(67) is a of length 3\n", + "ds_torch.__getitem__(67)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(67)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(67)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-09 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-09 00:00:00'), Timestamp('2020-03-10 00:00:00'), Timestamp('2020-03-11 00:00:00'), Timestamp('2020-03-12 00:00:00'), Timestamp('2020-03-13 00:00:00'), Timestamp('2020-03-14 00:00:00'), Timestamp('2020-03-15 00:00:00')]\n", + "##### i: 68\n", + "ds_torch.__getitem__(68) is a of length 3\n", + "ds_torch.__getitem__(68)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(68)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(68)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-10 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-10 00:00:00'), Timestamp('2020-03-11 00:00:00'), Timestamp('2020-03-12 00:00:00'), Timestamp('2020-03-13 00:00:00'), Timestamp('2020-03-14 00:00:00'), Timestamp('2020-03-15 00:00:00'), Timestamp('2020-03-16 00:00:00')]\n", + "##### i: 69\n", + "ds_torch.__getitem__(69) is a of length 3\n", + "ds_torch.__getitem__(69)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(69)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(69)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-11 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-11 00:00:00'), Timestamp('2020-03-12 00:00:00'), Timestamp('2020-03-13 00:00:00'), Timestamp('2020-03-14 00:00:00'), Timestamp('2020-03-15 00:00:00'), Timestamp('2020-03-16 00:00:00'), Timestamp('2020-03-17 00:00:00')]\n", + "##### i: 70\n", + "ds_torch.__getitem__(70) is a of length 3\n", + "ds_torch.__getitem__(70)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(70)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(70)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-12 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-12 00:00:00'), Timestamp('2020-03-13 00:00:00'), Timestamp('2020-03-14 00:00:00'), Timestamp('2020-03-15 00:00:00'), Timestamp('2020-03-16 00:00:00'), Timestamp('2020-03-17 00:00:00'), Timestamp('2020-03-18 00:00:00')]\n", + "##### i: 71\n", + "ds_torch.__getitem__(71) is a of length 3\n", + "ds_torch.__getitem__(71)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(71)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(71)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-13 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-13 00:00:00'), Timestamp('2020-03-14 00:00:00'), Timestamp('2020-03-15 00:00:00'), Timestamp('2020-03-16 00:00:00'), Timestamp('2020-03-17 00:00:00'), Timestamp('2020-03-18 00:00:00'), Timestamp('2020-03-19 00:00:00')]\n", + "##### i: 72\n", + "ds_torch.__getitem__(72) is a of length 3\n", + "ds_torch.__getitem__(72)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(72)[1] is a of shape (432, 
432, 7, 1)\n", + "ds_torch.__getitem__(72)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-14 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-14 00:00:00'), Timestamp('2020-03-15 00:00:00'), Timestamp('2020-03-16 00:00:00'), Timestamp('2020-03-17 00:00:00'), Timestamp('2020-03-18 00:00:00'), Timestamp('2020-03-19 00:00:00'), Timestamp('2020-03-20 00:00:00')]\n", + "##### i: 73\n", + "ds_torch.__getitem__(73) is a of length 3\n", + "ds_torch.__getitem__(73)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(73)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(73)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-15 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-15 00:00:00'), Timestamp('2020-03-16 00:00:00'), Timestamp('2020-03-17 00:00:00'), Timestamp('2020-03-18 00:00:00'), Timestamp('2020-03-19 00:00:00'), Timestamp('2020-03-20 00:00:00'), Timestamp('2020-03-21 00:00:00')]\n", + "##### i: 74\n", + "ds_torch.__getitem__(74) is a of length 3\n", + "ds_torch.__getitem__(74)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(74)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(74)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-16 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-16 00:00:00'), Timestamp('2020-03-17 00:00:00'), Timestamp('2020-03-18 00:00:00'), Timestamp('2020-03-19 00:00:00'), Timestamp('2020-03-20 00:00:00'), Timestamp('2020-03-21 00:00:00'), Timestamp('2020-03-22 00:00:00')]\n", + "##### i: 75\n", + "ds_torch.__getitem__(75) is a of length 3\n", + "ds_torch.__getitem__(75)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(75)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(75)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-17 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-17 00:00:00'), Timestamp('2020-03-18 00:00:00'), Timestamp('2020-03-19 00:00:00'), Timestamp('2020-03-20 00:00:00'), Timestamp('2020-03-21 00:00:00'), Timestamp('2020-03-22 00:00:00'), Timestamp('2020-03-23 00:00:00')]\n", + "##### i: 76\n", + "ds_torch.__getitem__(76) is a of length 3\n", + "ds_torch.__getitem__(76)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(76)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(76)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-18 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-18 00:00:00'), Timestamp('2020-03-19 00:00:00'), Timestamp('2020-03-20 00:00:00'), Timestamp('2020-03-21 00:00:00'), Timestamp('2020-03-22 00:00:00'), Timestamp('2020-03-23 00:00:00'), Timestamp('2020-03-24 00:00:00')]\n", + "##### i: 77\n", + "ds_torch.__getitem__(77) is a of length 3\n", + "ds_torch.__getitem__(77)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(77)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(77)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-19 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-19 00:00:00'), Timestamp('2020-03-20 00:00:00'), Timestamp('2020-03-21 00:00:00'), Timestamp('2020-03-22 00:00:00'), Timestamp('2020-03-23 00:00:00'), Timestamp('2020-03-24 00:00:00'), Timestamp('2020-03-25 00:00:00')]\n", + "##### i: 78\n", + "ds_torch.__getitem__(78) is a of length 3\n", + "ds_torch.__getitem__(78)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(78)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(78)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-20 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-20 
00:00:00'), Timestamp('2020-03-21 00:00:00'), Timestamp('2020-03-22 00:00:00'), Timestamp('2020-03-23 00:00:00'), Timestamp('2020-03-24 00:00:00'), Timestamp('2020-03-25 00:00:00'), Timestamp('2020-03-26 00:00:00')]\n", + "##### i: 79\n", + "ds_torch.__getitem__(79) is a of length 3\n", + "ds_torch.__getitem__(79)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(79)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(79)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-21 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-21 00:00:00'), Timestamp('2020-03-22 00:00:00'), Timestamp('2020-03-23 00:00:00'), Timestamp('2020-03-24 00:00:00'), Timestamp('2020-03-25 00:00:00'), Timestamp('2020-03-26 00:00:00'), Timestamp('2020-03-27 00:00:00')]\n", + "##### i: 80\n", + "ds_torch.__getitem__(80) is a of length 3\n", + "ds_torch.__getitem__(80)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(80)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(80)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-22 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-22 00:00:00'), Timestamp('2020-03-23 00:00:00'), Timestamp('2020-03-24 00:00:00'), Timestamp('2020-03-25 00:00:00'), Timestamp('2020-03-26 00:00:00'), Timestamp('2020-03-27 00:00:00'), Timestamp('2020-03-28 00:00:00')]\n", + "##### i: 81\n", + "ds_torch.__getitem__(81) is a of length 3\n", + "ds_torch.__getitem__(81)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(81)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(81)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-23 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-23 00:00:00'), Timestamp('2020-03-24 00:00:00'), Timestamp('2020-03-25 00:00:00'), Timestamp('2020-03-26 00:00:00'), Timestamp('2020-03-27 00:00:00'), Timestamp('2020-03-28 00:00:00'), Timestamp('2020-03-29 00:00:00')]\n", + "##### i: 82\n", + "ds_torch.__getitem__(82) is a of length 3\n", + "ds_torch.__getitem__(82)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(82)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(82)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-24 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-24 00:00:00'), Timestamp('2020-03-25 00:00:00'), Timestamp('2020-03-26 00:00:00'), Timestamp('2020-03-27 00:00:00'), Timestamp('2020-03-28 00:00:00'), Timestamp('2020-03-29 00:00:00'), Timestamp('2020-03-30 00:00:00')]\n", + "##### i: 83\n", + "ds_torch.__getitem__(83) is a of length 3\n", + "ds_torch.__getitem__(83)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(83)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(83)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-25 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-25 00:00:00'), Timestamp('2020-03-26 00:00:00'), Timestamp('2020-03-27 00:00:00'), Timestamp('2020-03-28 00:00:00'), Timestamp('2020-03-29 00:00:00'), Timestamp('2020-03-30 00:00:00'), Timestamp('2020-03-31 00:00:00')]\n", + "##### i: 84\n", + "ds_torch.__getitem__(84) is a of length 3\n", + "ds_torch.__getitem__(84)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(84)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(84)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-26 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-26 00:00:00'), Timestamp('2020-03-27 00:00:00'), Timestamp('2020-03-28 00:00:00'), Timestamp('2020-03-29 00:00:00'), Timestamp('2020-03-30 00:00:00'), 
Timestamp('2020-03-31 00:00:00'), Timestamp('2020-04-01 00:00:00')]\n", + "##### i: 85\n", + "ds_torch.__getitem__(85) is a of length 3\n", + "ds_torch.__getitem__(85)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(85)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(85)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-27 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-27 00:00:00'), Timestamp('2020-03-28 00:00:00'), Timestamp('2020-03-29 00:00:00'), Timestamp('2020-03-30 00:00:00'), Timestamp('2020-03-31 00:00:00'), Timestamp('2020-04-01 00:00:00'), Timestamp('2020-04-02 00:00:00')]\n", + "##### i: 86\n", + "ds_torch.__getitem__(86) is a of length 3\n", + "ds_torch.__getitem__(86)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(86)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(86)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-28 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-28 00:00:00'), Timestamp('2020-03-29 00:00:00'), Timestamp('2020-03-30 00:00:00'), Timestamp('2020-03-31 00:00:00'), Timestamp('2020-04-01 00:00:00'), Timestamp('2020-04-02 00:00:00'), Timestamp('2020-04-03 00:00:00')]\n", + "##### i: 87\n", + "ds_torch.__getitem__(87) is a of length 3\n", + "ds_torch.__getitem__(87)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(87)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(87)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-29 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-29 00:00:00'), Timestamp('2020-03-30 00:00:00'), Timestamp('2020-03-31 00:00:00'), Timestamp('2020-04-01 00:00:00'), Timestamp('2020-04-02 00:00:00'), Timestamp('2020-04-03 00:00:00'), Timestamp('2020-04-04 00:00:00')]\n", + "##### i: 88\n", + "ds_torch.__getitem__(88) is a of length 3\n", + "ds_torch.__getitem__(88)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(88)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(88)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-30 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-30 00:00:00'), Timestamp('2020-03-31 00:00:00'), Timestamp('2020-04-01 00:00:00'), Timestamp('2020-04-02 00:00:00'), Timestamp('2020-04-03 00:00:00'), Timestamp('2020-04-04 00:00:00'), Timestamp('2020-04-05 00:00:00')]\n", + "##### i: 89\n", + "ds_torch.__getitem__(89) is a of length 3\n", + "ds_torch.__getitem__(89)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(89)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(89)[2] is a of shape (432, 432, 7, 1)\n", + "forecast_date: 2020-03-31 00:00:00\n", + "forecast_dts: [Timestamp('2020-03-31 00:00:00'), Timestamp('2020-04-01 00:00:00'), Timestamp('2020-04-02 00:00:00'), Timestamp('2020-04-03 00:00:00'), Timestamp('2020-04-04 00:00:00'), Timestamp('2020-04-05 00:00:00'), Timestamp('2020-04-06 00:00:00')]\n", + "##### i: 90\n", + "ds_torch.__getitem__(90) is a of length 3\n", + "ds_torch.__getitem__(90)[0] is a of shape (432, 432, 9)\n", + "ds_torch.__getitem__(90)[1] is a of shape (432, 432, 7, 1)\n", + "ds_torch.__getitem__(90)[2] is a of shape (432, 432, 7, 1)\n" + ] + } + ], + "source": [ + "for i in range(ds_torch.__len__()):\n", + " item = ds_torch.__getitem__(i)\n", + " print(f\"##### i: {i}\")\n", + " print(f\"ds_torch.__getitem__({i}) is a {type(item)} of length {len(item)}\")\n", + " for j in range(len(item)):\n", + " print(f\"ds_torch.__getitem__({i})[{j}] is a {type(item[j])} of shape {item[j].shape}\")" + ] + }, + { + "cell_type": "code", + 
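 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Because `ds_torch` is a map-style dataset (it implements `__len__` and `__getitem__`), it can also be handed straight to `torch.utils.data.DataLoader`: the default collate function converts each numpy array in the sample tuple to a tensor and stacks them along a new batch dimension. A minimal sketch, assuming a batch size of four to match `output_batch_size` above:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Sketch only: batch the map-style dataset with a standard DataLoader.\n", + "# The default collate_fn stacks the numpy arrays of each sample into tensors.\n", + "from torch.utils.data import DataLoader\n", + "\n", + "batch_loader = DataLoader(ds_torch, batch_size=4, shuffle=True)\n", + "x_b, y_b, w_b = next(iter(batch_loader))\n", + "# x_b: (4, 432, 432, 9); y_b and w_b: (4, 432, 432, 7, 1)\n", + "print(x_b.shape, y_b.shape, w_b.shape)" + ] + }, + { + "cell_type": "code",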
"execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "forecast_date: 2020-01-14 00:00:00\n", + "forecast_dts: [Timestamp('2020-01-14 00:00:00'), Timestamp('2020-01-15 00:00:00'), Timestamp('2020-01-16 00:00:00'), Timestamp('2020-01-17 00:00:00'), Timestamp('2020-01-18 00:00:00'), Timestamp('2020-01-19 00:00:00'), Timestamp('2020-01-20 00:00:00')]\n" + ] + }, + { + "data": { + "text/plain": [ + "(array([[[ 0.49167976, 0.45731494, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.49114853, 0.45527366, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.48822644, 0.45367032, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " ...,\n", + " [ 0.49409333, 0.49700168, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.49529564, 0.4982207 , 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.4960581 , 0.49904954, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579]],\n", + " \n", + " [[ 0.485931 , 0.45947248, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.48212218, 0.45727244, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.4829446 , 0.46095508, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " ...,\n", + " [ 0.49355316, 0.49622783, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.4948722 , 0.49675837, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.49594274, 0.49843076, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579]],\n", + " \n", + " [[ 0.48487893, 0.45723182, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.4803935 , 0.45857662, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.4774811 , 0.46765223, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " ...,\n", + " [ 0.49242944, 0.49738374, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.49418244, 0.4987961 , 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.49736902, 0.49927232, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579]],\n", + " \n", + " ...,\n", + " \n", + " [[ 0.50845265, 0.5765008 , 0. , ..., -0.97276926,\n", + " -1. , -0.23177579],\n", + " [ 0.50534934, 0.57553226, 0. , ..., -0.97276926,\n", + " -1. , -0.23177579],\n", + " [ 0.50184083, 0.5744398 , 0. , ..., -0.97276926,\n", + " -1. , -0.23177579],\n", + " ...,\n", + " [ 0.5148928 , 0.5024948 , 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.51618797, 0.500339 , 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.51338226, 0.49549663, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579]],\n", + " \n", + " [[ 0.509647 , 0.5771778 , 0. , ..., -0.97276926,\n", + " -1. , -0.23177579],\n", + " [ 0.5077593 , 0.5757393 , 0. , ..., -0.97276926,\n", + " -1. , -0.23177579],\n", + " [ 0.5046344 , 0.5750293 , 0. , ..., -0.97276926,\n", + " -1. , -0.23177579],\n", + " ...,\n", + " [ 0.5203622 , 0.50123286, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.51786757, 0.4986387 , 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.5171366 , 0.49582803, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579]],\n", + " \n", + " [[ 0.50947356, 0.5785389 , 0. , ..., -0.97276926,\n", + " -1. , -0.23177579],\n", + " [ 0.5100283 , 0.57686365, 0. , ..., -0.97276926,\n", + " -1. , -0.23177579],\n", + " [ 0.5083068 , 0.57562834, 0. , ..., -0.97276926,\n", + " -1. , -0.23177579],\n", + " ...,\n", + " [ 0.523088 , 0.4988309 , 0. , ..., -0.97276926,\n", + " 1. , -0.23177579],\n", + " [ 0.52149713, 0.4976637 , 0. , ..., -0.97276926,\n", + " 1. 
, -0.23177579],\n", + " [ 0.52379906, 0.49838376, 0. , ..., -0.97276926,\n", + " 1. , -0.23177579]]], dtype=float32),\n", + " array([[[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " 
\n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]]], dtype=float32),\n", + " array([[[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " ...,\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " 
[0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]],\n", + " \n", + " \n", + " [[[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " ...,\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]],\n", + " \n", + " [[0.],\n", + " [0.],\n", + " [0.],\n", + " ...,\n", + " [0.],\n", + " [0.],\n", + " [0.]]]], dtype=float32))" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_torch.__getitem__(13)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Generating PyTorch DataLoaders" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", + "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", + "INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json\n", + "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", + "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", + "INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json\n", + "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", + "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", + "INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json\n" + ] + } + ], + "source": [ + "train_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode=\"train\")\n", + "val_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode=\"val\")\n", + "test_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode=\"test\")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "\n", + "batch_size = 4\n", + "shuffle = False # set to False for now\n", + "persistent_workers = False\n", + "num_workers = 0\n", + "\n", + "train_dataloader = DataLoader(train_dataset,\n", + " batch_size=batch_size,\n", + " shuffle=shuffle,\n", + " persistent_workers=persistent_workers,\n", + " num_workers=num_workers)\n", + "val_dataloader = DataLoader(val_dataset,\n", + " batch_size=batch_size,\n", + " shuffle=shuffle,\n", + " persistent_workers=persistent_workers,\n", + " num_workers=num_workers)\n", + 
"test_dataloader = DataLoader(test_dataset,\n", + " batch_size=batch_size,\n", + " shuffle=shuffle,\n", + " persistent_workers=persistent_workers,\n", + " num_workers=num_workers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Iterating through DataLoaders" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "23" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(train_dataloader)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "n_forecast_days: 7\n", + "n_forecast_days: 7\n", + "n_forecast_days: 7\n", + "n_forecast_days: 7\n" + ] + } + ], + "source": [ + "train_features, train_labels, sample_weights = next(iter(train_dataloader))" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 432, 432, 9])" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_features.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 432, 432, 7, 1])" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_labels.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 432, 432, 7, 1])" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample_weights.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "from icenet_unet_small import UNet\n", + "\n", + "unet = UNet(input_channels=9)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/data/hpcdata/users/rychan/miniconda3/envs/icenet_pytorch/lib/python3.8/site-packages/torch/nn/functional.py:3737: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.\n", + " warnings.warn(\"nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.\")\n" + ] + } + ], + "source": [ + "y_hat = unet(train_features)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([4, 432, 432, 3, 6])" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_hat.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## IceNet UNet model\n", + "\n", + "As a first attempt to implement a PyTorch example, we adapt code from https://github.com/ampersandmcd/icenet-gan/.\n", + "\n", + "Below is a PyTorch implementation of the UNet architecture." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "\n", + "class UNet(nn.Module):\n", + " \"\"\"\n", + " An implementation of a UNet for pixelwise classification.\n", + " \"\"\"\n", + " \n", + " def __init__(self,\n", + " input_channels, \n", + " filter_size=3, \n", + " n_filters_factor=1, \n", + " n_forecast_days=6, \n", + " n_output_classes=3,\n", + " **kwargs):\n", + " super(UNet, self).__init__()\n", + "\n", + " self.input_channels = input_channels\n", + " self.filter_size = filter_size\n", + " self.n_filters_factor = n_filters_factor\n", + " self.n_forecast_days = n_forecast_days\n", + " self.n_output_classes = n_output_classes\n", + "\n", + " self.conv1a = nn.Conv2d(in_channels=input_channels, \n", + " out_channels=int(128*n_filters_factor),\n", + " kernel_size=filter_size,\n", + " padding=\"same\")\n", + " self.conv1b = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", + " out_channels=int(128*n_filters_factor),\n", + " kernel_size=filter_size,\n", + " padding=\"same\")\n", + " self.bn1 = nn.BatchNorm2d(num_features=int(128*n_filters_factor))\n", + "\n", + " self.conv2a = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", + " out_channels=int(256*n_filters_factor),\n", + " kernel_size=filter_size,\n", + " padding=\"same\")\n", + " self.conv2b = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", + " out_channels=int(256*n_filters_factor),\n", + " kernel_size=filter_size,\n", + " padding=\"same\")\n", + " self.bn2 = nn.BatchNorm2d(num_features=int(256*n_filters_factor))\n", + "\n", + " # self.conv3a = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv3b = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.bn3 = nn.BatchNorm2d(num_features=int(512*n_filters_factor))\n", + "\n", + " # self.conv4a = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv4b = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.bn4 = nn.BatchNorm2d(num_features=int(512*n_filters_factor))\n", + "\n", + " # self.conv5a = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(1024*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv5b = nn.Conv2d(in_channels=int(1024*n_filters_factor),\n", + " # out_channels=int(1024*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.bn5 = nn.BatchNorm2d(num_features=int(1024*n_filters_factor))\n", + "\n", + " # self.conv6a = nn.Conv2d(in_channels=int(1024*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv6b = nn.Conv2d(in_channels=int(1024*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv6c = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # 
kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.bn6 = nn.BatchNorm2d(num_features=int(512*n_filters_factor))\n", + "\n", + " # self.conv7a = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv7b = nn.Conv2d(in_channels=int(1024*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv7c = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(512*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.bn7 = nn.BatchNorm2d(num_features=int(512*n_filters_factor))\n", + "\n", + " # self.conv8a = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(256*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv8b = nn.Conv2d(in_channels=int(512*n_filters_factor),\n", + " # out_channels=int(256*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.conv8c = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", + " # out_channels=int(256*n_filters_factor),\n", + " # kernel_size=filter_size,\n", + " # padding=\"same\")\n", + " # self.bn8 = nn.BatchNorm2d(num_features=int(256*n_filters_factor))\n", + "\n", + " self.conv9a = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", + " out_channels=int(128*n_filters_factor),\n", + " kernel_size=filter_size,\n", + " padding=\"same\")\n", + " self.conv9b = nn.Conv2d(in_channels=int(256*n_filters_factor),\n", + " out_channels=int(128*n_filters_factor),\n", + " kernel_size=filter_size,\n", + " padding=\"same\")\n", + " self.conv9c = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", + " out_channels=int(128*n_filters_factor),\n", + " kernel_size=filter_size,\n", + " padding=\"same\") # no batch norm on last layer\n", + "\n", + " self.final_conv = nn.Conv2d(in_channels=int(128*n_filters_factor),\n", + " out_channels=n_output_classes*n_forecast_days,\n", + " kernel_size=filter_size,\n", + " padding=\"same\")\n", + " \n", + " def forward(self, x):\n", + "\n", + " # transpose from shape (b, h, w, c) to (b, c, h, w) for pytorch conv2d layers\n", + " x = torch.movedim(x, -1, 1) # move c from last to second dim\n", + "\n", + " # run through network\n", + " conv1 = self.conv1a(x) # input to 128\n", + " conv1 = F.relu(conv1)\n", + " conv1 = self.conv1b(conv1) # 128 to 128\n", + " conv1 = F.relu(conv1)\n", + " bn1 = self.bn1(conv1)\n", + " pool1 = F.max_pool2d(bn1, kernel_size=(2, 2))\n", + "\n", + " conv2 = self.conv2a(pool1) # 128 to 256\n", + " conv2 = F.relu(conv2)\n", + " conv2 = self.conv2b(conv2) # 256 to 256\n", + " conv2 = F.relu(conv2)\n", + " bn2 = self.bn2(conv2)\n", + " pool2 = F.max_pool2d(bn2, kernel_size=(2, 2))\n", + "\n", + " # conv3 = self.conv3a(pool2) # 256 to 512\n", + " # conv3 = F.relu(conv3)\n", + " # conv3 = self.conv3b(conv3) # 512 to 512\n", + " # conv3 = F.relu(conv3)\n", + " # bn3 = self.bn3(conv3)\n", + " # pool3 = F.max_pool2d(bn3, kernel_size=(2, 2))\n", + "\n", + " # conv4 = self.conv4a(pool3) # 512 to 512\n", + " # conv4 = F.relu(conv4)\n", + " # conv4 = self.conv4b(conv4) # 512 to 512\n", + " # conv4 = F.relu(conv4)\n", + " # bn4 = self.bn4(conv4)\n", + " # pool4 = F.max_pool2d(bn4, kernel_size=(2, 2))\n", + "\n", + " # conv5 = self.conv5a(pool4) # 512 to 1024\n", + " # conv5 = F.relu(conv5)\n", + " # conv5 = 
self.conv5b(conv5) # 1024 to 1024\n", + " # conv5 = F.relu(conv5)\n", + " # bn5 = self.bn5(conv5)\n", + "\n", + " # up6 = F.upsample(bn5, scale_factor=2, mode=\"nearest\")\n", + " # up6 = self.conv6a(up6) # 1024 to 512\n", + " # up6 = F.relu(up6)\n", + " # merge6 = torch.cat([bn4, up6], dim=1) # 512 and 512 to 1024 along c dimension\n", + " # conv6 = self.conv6b(merge6) # 1024 to 512\n", + " # conv6 = F.relu(conv6)\n", + " # conv6 = self.conv6c(conv6) # 512 to 512\n", + " # conv6 = F.relu(conv6)\n", + " # bn6 = self.bn6(conv6)\n", + "\n", + " # up7 = F.upsample(bn6, scale_factor=2, mode=\"nearest\")\n", + " # up7 = self.conv7a(up7) # 1024 to 512\n", + " # up7 = F.relu(up7)\n", + " # merge7 = torch.cat([bn3, up7], dim=1) # 512 and 512 to 1024 along c dimension\n", + " # conv7 = self.conv7b(merge7) # 1024 to 512\n", + " # conv7 = F.relu(conv7)\n", + " # conv7 = self.conv7c(conv7) # 512 to 512\n", + " # conv7 = F.relu(conv7)\n", + " # bn7 = self.bn7(conv7)\n", + "\n", + " # up8 = F.upsample(bn7, scale_factor=2, mode=\"nearest\")\n", + " # up8 = self.conv8a(up8) # 512 to 256\n", + " # up8 = F.relu(up8)\n", + " # merge8 = torch.cat([bn2, up8], dim=1) # 256 and 256 to 512 along c dimension\n", + " # conv8 = self.conv8b(merge8) # 512 to 256\n", + " # conv8 = F.relu(conv8)\n", + " # conv8 = self.conv8c(conv8) # 256 to 256\n", + " # conv8 = F.relu(conv8)\n", + " # bn8 = self.bn8(conv8)\n", + "\n", + " # with blocks 3-8 commented out, upsample from bn2 (the last active block);\n", + " # F.interpolate replaces the deprecated F.upsample\n", + " up9 = F.interpolate(bn2, scale_factor=2, mode=\"nearest\")\n", + " up9 = self.conv9a(up9) # 256 to 128\n", + " up9 = F.relu(up9)\n", + " merge9 = torch.cat([bn1, up9], dim=1) # 128 and 128 to 256 along c dimension\n", + " conv9 = self.conv9b(merge9) # 256 to 128\n", + " conv9 = F.relu(conv9)\n", + " conv9 = self.conv9c(conv9) # 128 to 128\n", + " conv9 = F.relu(conv9) # no batch norm on last layer\n", + " \n", + " final_layer_logits = self.final_conv(conv9)\n", + "\n", + " # transpose from shape (b, c, h, w) back to (b, h, w, c) to align with training data\n", + " final_layer_logits = torch.movedim(final_layer_logits, 1, -1) # move c from second to final dim\n", + " b, h, w, c = final_layer_logits.shape\n", + "\n", + " # unpack c=classes*months dimension into classes, months as separate dimensions\n", + " final_layer_logits = final_layer_logits.reshape((b, h, w, self.n_output_classes, self.n_forecast_days))\n", + "\n", + " output = F.softmax(final_layer_logits, dim=-2) # apply over n_output_classes dimension\n", + " \n", + " return output # shape (b, h, w, c, t)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Some metrics for evaluating IceNet performance:" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "from torchmetrics import Metric\n", + "\n", + "class IceNetAccuracy(Metric):\n", + " \"\"\"\n", + " Binary accuracy metric for use at multiple leadtimes.\n", + " \"\"\" \n", + "\n", + " # Set class properties\n", + " is_differentiable: bool = False\n", + " higher_is_better: bool = True\n", + " full_state_update: bool = True\n", + "\n", + " def __init__(self, leadtimes_to_evaluate: list):\n", + " \"\"\"\n", + " Construct a binary accuracy metric for use at multiple leadtimes.\n", + " :param leadtimes_to_evaluate: A list of leadtimes to consider\n", + " e.g., [0, 1, 2, 3, 4, 5] to consider all six months in accuracy computation or\n", + " e.g., [0] to only look at the first month's accuracy\n", + " e.g., [5] to only look at the sixth month's accuracy\n", + " \"\"\"\n", + " super().__init__()\n", + " 
self.leadtimes_to_evaluate = leadtimes_to_evaluate\n", + " self.add_state(\"weighted_score\", default=torch.tensor(0.), dist_reduce_fx=\"sum\")\n", + " self.add_state(\"possible_score\", default=torch.tensor(0.), dist_reduce_fx=\"sum\")\n", + "\n", + " def update(self, preds: torch.Tensor, target: torch.Tensor, sample_weight: torch.Tensor):\n", + " # preds and target are shape (b, h, w, t)\n", + " # sum marginal and full ice for binary eval\n", + " preds = (preds > 0).long()\n", + " target = (target > 0).long()\n", + " base_score = preds[:, :, :, self.leadtimes_to_evaluate] == target[:, :, :, self.leadtimes_to_evaluate]\n", + " self.weighted_score += torch.sum(base_score * sample_weight[:, :, :, self.leadtimes_to_evaluate])\n", + " self.possible_score += torch.sum(sample_weight[:, :, :, self.leadtimes_to_evaluate])\n", + "\n", + " def compute(self):\n", + " return self.weighted_score.float() / self.possible_score\n", + "\n", + "\n", + "class SIEError(Metric):\n", + " \"\"\"\n", + " Sea Ice Extent error metric (in km^2) for use at multiple leadtimes.\n", + " \"\"\" \n", + "\n", + " # Set class properties\n", + " is_differentiable: bool = False\n", + " higher_is_better: bool = False\n", + " full_state_update: bool = True\n", + "\n", + " def __init__(self, leadtimes_to_evaluate: list):\n", + " \"\"\"\n", + " Construct an SIE error metric (in km^2) for use at multiple leadtimes.\n", + " :param leadtimes_to_evaluate: A list of leadtimes to consider\n", + " e.g., [0, 1, 2, 3, 4, 5] to consider all six months in computation or\n", + " e.g., [0] to only look at the first month\n", + " e.g., [5] to only look at the sixth month\n", + " \"\"\"\n", + " super().__init__()\n", + " self.leadtimes_to_evaluate = leadtimes_to_evaluate\n", + " self.add_state(\"pred_sie\", default=torch.tensor(0.), dist_reduce_fx=\"sum\")\n", + " self.add_state(\"true_sie\", default=torch.tensor(0.), dist_reduce_fx=\"sum\")\n", + "\n", + " def update(self, preds: torch.Tensor, target: torch.Tensor, sample_weight: torch.Tensor):\n", + " # preds and target are shape (b, h, w, t)\n", + " # sum marginal and full ice for binary eval\n", + " preds = (preds > 0).long()\n", + " target = (target > 0).long()\n", + " self.pred_sie += preds[:, :, :, self.leadtimes_to_evaluate].sum()\n", + " self.true_sie += target[:, :, :, self.leadtimes_to_evaluate].sum()\n", + "\n", + " def compute(self):\n", + " return (self.pred_sie - self.true_sie) * 25**2 # each pixel is 25x25 km" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A _LightningModule_ wrapper for the UNet model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [], + "source": [ + "import lightning.pytorch as pl\n", + "from torchmetrics import MetricCollection\n", + "\n", + "class LitUNet(pl.LightningModule):\n", + " \"\"\"\n", + " A LightningModule wrapping the UNet implementation of IceNet.\n", + " \"\"\"\n", + " def __init__(self,\n", + " model: nn.Module,\n", + " criterion: callable,\n", + " learning_rate: float):\n", + " \"\"\"\n", + " Construct a UNet LightningModule.\n", + " Note that we keep hyperparameters separate from dataloaders to prevent data leakage at test time.\n", + " :param model: PyTorch model\n", + " :param criterion: PyTorch loss function for training instantiated with reduction=\"none\"\n", + " :param learning_rate: Float learning rate for our optimiser\n", + " \"\"\"\n", + " super().__init__()\n", + " self.model = model\n", + " self.criterion = criterion\n", + " self.learning_rate = learning_rate\n", + " self.n_output_classes = model.n_output_classes # this should be a property of the network\n", + "\n", + " metrics = {\n", + " \"val_accuracy\": IceNetAccuracy(leadtimes_to_evaluate=list(range(self.model.n_forecast_days))),\n", + " \"val_sieerror\": SIEError(leadtimes_to_evaluate=list(range(self.model.n_forecast_days)))\n", + " }\n", + " for i in range(self.model.n_forecast_days):\n", + " metrics[f\"val_accuracy_{i}\"] = IceNetAccuracy(leadtimes_to_evaluate=[i])\n", + " metrics[f\"val_sieerror_{i}\"] = SIEError(leadtimes_to_evaluate=[i])\n", + " self.metrics = MetricCollection(metrics)\n", + "\n", + " test_metrics = {\n", + " \"test_accuracy\": IceNetAccuracy(leadtimes_to_evaluate=list(range(self.model.n_forecast_days))),\n", + " \"test_sieerror\": SIEError(leadtimes_to_evaluate=list(range(self.model.n_forecast_days)))\n", + " }\n", + " for i in range(self.model.n_forecast_days):\n", + " test_metrics[f\"test_accuracy_{i}\"] = IceNetAccuracy(leadtimes_to_evaluate=[i])\n", + " test_metrics[f\"test_sieerror_{i}\"] = SIEError(leadtimes_to_evaluate=[i])\n", + " self.test_metrics = MetricCollection(test_metrics)\n", + "\n", + " self.save_hyperparameters(ignore=[\"model\", \"criterion\"])\n", + "\n", + " def forward(self, x):\n", + " \"\"\"\n", + " Implement forward function.\n", + " :param x: Inputs to model.\n", + " :return: Outputs of model.\n", + " \"\"\"\n", + " return self.model(x)\n", + "\n", + " def training_step(self, batch):\n", + " \"\"\"\n", + " Perform a pass through a batch of training data.\n", + " Apply pixel-weighted loss by manually reducing.\n", + " See e.g. 
https://discuss.pytorch.org/t/unet-pixel-wise-weighted-loss-function/46689/5.\n", + " :param batch: Batch of input, output, weight triplets\n", + " :return: Loss from this batch of data for use in backprop\n", + " \"\"\"\n", + " x, y, sample_weight = batch\n", + " y_hat = self.model(x)\n", + " # y and y_hat are shape (b, h, w, c, t) but loss expects (b, c, h, w, t)\n", + " # note that criterion needs reduction=\"none\" for weighting to work\n", + " if isinstance(self.criterion, nn.CrossEntropyLoss): # requires int class encoding\n", + " loss = self.criterion(y_hat.movedim(-2, 1), y.argmax(-2).long())\n", + " else: # requires one-hot encoding\n", + " loss = self.criterion(y_hat.movedim(-2, 1), y.movedim(-2, 1))\n", + " loss = torch.mean(loss * sample_weight.movedim(-2, 1))\n", + " self.log(\"train_loss\", loss, sync_dist=True)\n", + " return loss\n", + "\n", + " def validation_step(self, batch):\n", + " x, y, sample_weight = batch\n", + " y_hat = self.model(x)\n", + " # y and y_hat are shape (b, h, w, c, t) but loss expects (b, c, h, w, t)\n", + " # note that criterion needs reduction=\"none\" for weighting to work\n", + " if isinstance(self.criterion, nn.CrossEntropyLoss): # requires int class encoding\n", + " loss = self.criterion(y_hat.movedim(-2, 1), y.argmax(-2).long())\n", + " else: # requires one-hot encoding\n", + " loss = self.criterion(y_hat.movedim(-2, 1), y.movedim(-2, 1))\n", + " loss = torch.mean(loss * sample_weight.movedim(-2, 1))\n", + " self.log(\"val_loss\", loss, on_step=False, on_epoch=True, sync_dist=True) # epoch-level loss\n", + " y_hat_pred = y_hat.argmax(dim=-2).long() # argmax over c where shape is (b, h, w, c, t)\n", + " self.metrics.update(y_hat_pred, y.argmax(dim=-2).long(), sample_weight.squeeze(dim=-2)) # shape (b, h, w, t)\n", + " return loss\n", + "\n", + " def on_validation_epoch_end(self):\n", + " self.log_dict(self.metrics.compute(), on_step=False, on_epoch=True, sync_dist=True) # epoch-level metrics\n", + " self.metrics.reset()\n", + "\n", + " def test_step(self, batch):\n", + " x, y, sample_weight = batch\n", + " y_hat = self.model(x)\n", + " # y and y_hat are shape (b, h, w, c, t) but loss expects (b, c, h, w, t)\n", + " # note that criterion needs reduction=\"none\" for weighting to work\n", + " if isinstance(self.criterion, nn.CrossEntropyLoss): # requires int class encoding\n", + " loss = self.criterion(y_hat.movedim(-2, 1), y.argmax(-2).long())\n", + " else: # requires one-hot encoding\n", + " loss = self.criterion(y_hat.movedim(-2, 1), y.movedim(-2, 1))\n", + " loss = torch.mean(loss * sample_weight.movedim(-2, 1))\n", + " self.log(\"test_loss\", loss, on_step=False, on_epoch=True, sync_dist=True) # epoch-level loss\n", + " y_hat_pred = y_hat.argmax(dim=-2) # argmax over c where shape is (b, h, w, c, t)\n", + " self.test_metrics.update(y_hat_pred, y.argmax(dim=-2).long(), sample_weight.squeeze(dim=-2)) # shape (b, h, w, t)\n", + " return loss\n", + "\n", + " def on_test_epoch_end(self):\n", + " self.log_dict(self.test_metrics.compute(), on_step=False, on_epoch=True, sync_dist=True) # epoch-level metrics\n", + " self.test_metrics.reset()\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", + " return {\n", + " \"optimizer\": optimizer\n", + " }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A function for training the UNet model using PyTorch Lightning." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "from lightning.pytorch.callbacks import ModelCheckpoint\n", + "\n", + "def train_icenet(configuration_path,\n", + " learning_rate,\n", + " max_epochs,\n", + " batch_size,\n", + " n_workers,\n", + " filter_size,\n", + " n_filters_factor,\n", + " seed):\n", + " \"\"\"\n", + " Train the IceNet UNet using the given dataset configuration and hyperparameters.\n", + " \"\"\"\n", + " # init\n", + " pl.seed_everything(seed)\n", + " \n", + " # configure datasets and dataloaders\n", + " train_dataset = IceNetDataSetPyTorch(configuration_path, mode=\"train\")\n", + " val_dataset = IceNetDataSetPyTorch(configuration_path, mode=\"val\")\n", + " # note: persistent_workers=True requires n_workers > 0\n", + " train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=n_workers,\n", + " persistent_workers=True, shuffle=False)\n", + " val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=n_workers,\n", + " persistent_workers=True, shuffle=False)\n", + "\n", + " # construct unet\n", + " model = UNet(\n", + " input_channels=len(train_dataset._ds._config[\"channels\"]),\n", + " filter_size=filter_size,\n", + " n_filters_factor=n_filters_factor,\n", + " n_forecast_days=train_dataset._ds._config[\"n_forecast_days\"]\n", + " )\n", + " \n", + " criterion = nn.CrossEntropyLoss(reduction=\"none\")\n", + " \n", + " # configure PyTorch Lightning module\n", + " lit_module = LitUNet(\n", + " model=model,\n", + " criterion=criterion,\n", + " learning_rate=learning_rate\n", + " )\n", + "\n", + " # set up trainer configuration\n", + " trainer = pl.Trainer(\n", + " accelerator=\"auto\",\n", + " devices=1,\n", + " log_every_n_steps=10,\n", + " max_epochs=max_epochs,\n", + " num_sanity_val_steps=1,\n", + " )\n", + " trainer.callbacks.append(ModelCheckpoint(monitor=\"val_accuracy\", mode=\"max\"))\n", + "\n", + " # train model\n", + " print(f\"Training {len(train_dataset)} examples / {len(train_dataloader)} batches (batch size {batch_size}).\")\n", + " print(f\"Validating {len(val_dataset)} examples / {len(val_dataloader)} batches (batch size {batch_size}).\")\n", + " trainer.fit(lit_module, train_dataloader, val_dataloader)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO: Global seed set to 45\n", + "INFO:lightning.fabric.utilities.seed:Global seed set to 45\n", + "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", + "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", + "INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json\n", + "INFO:root:Loading configuration dataset_config.pytorch_notebook.json\n", + "WARNING:root:Running in configuration only mode, tfrecords were not generated for this dataset\n", + "INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json\n", + "INFO: GPU available: True (cuda), used: True\n", + "INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True\n", + "INFO: TPU available: False, using: 0 TPU cores\n", + "INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores\n", + "INFO: IPU available: False, using: 0 IPUs\n", + "INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs\n", + 
"INFO: HPU available: False, using: 0 HPUs\n", + "INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs\n", + "INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "INFO: \n", + " | Name | Type | Params\n", + "--------------------------------------------------\n", + "0 | model | UNet | 1.8 M \n", + "1 | criterion | CrossEntropyLoss | 0 \n", + "2 | metrics | MetricCollection | 0 \n", + "3 | test_metrics | MetricCollection | 0 \n", + "--------------------------------------------------\n", + "1.8 M Trainable params\n", + "0 Non-trainable params\n", + "1.8 M Total params\n", + "7.224 Total estimated model params size (MB)\n", + "INFO:lightning.pytorch.callbacks.model_summary:\n", + " | Name | Type | Params\n", + "--------------------------------------------------\n", + "0 | model | UNet | 1.8 M \n", + "1 | criterion | CrossEntropyLoss | 0 \n", + "2 | metrics | MetricCollection | 0 \n", + "3 | test_metrics | MetricCollection | 0 \n", + "--------------------------------------------------\n", + "1.8 M Trainable params\n", + "0 Non-trainable params\n", + "1.8 M Total params\n", + "7.224 Total estimated model params size (MB)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training 91 examples / 23 batches (batch size 4).\n", + "Validating 21 examples / 6 batches (batch size 4).\n", + "Sanity Checking: 0it [00:00, ?it/s]" + ] + } + ], + "source": [ + "seed = 45\n", + "train_icenet(configuration_path=dataset_config,\n", + " learning_rate=1e-4,\n", + " max_epochs=10,\n", + " batch_size=4,\n", + " n_workers=12,\n", + " filter_size=3,\n", + " n_filters_factor=1,\n", + " seed=seed)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "9" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(train_dataset._ds._config[\"channels\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "7" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_dataset._ds._config[\"n_forecast_days\"]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "icenet", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pytorch_example/pytorch_example.py b/pytorch_example/pytorch_example.py new file mode 100644 index 0000000..22ad11a --- /dev/null +++ b/pytorch_example/pytorch_example.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python +# coding: utf-8 + +import torch +import logging +import os + +from icenet.data.loaders import IceNetDataLoaderFactory + +from train_icenet_unet import train_icenet_unet +from test_icenet_unet import test_icenet_unet + +# Quick hack to put us in the icenet-pipeline folder, +# assuming it was created as per 01.cli_demonstration.ipynb +pipeline_directory = os.path.join(os.path.dirname(__file__), "../../notebook-pipeline") 
+os.chdir(pipeline_directory) +print("Running in {}".format(os.getcwd())) + +logging.getLogger().setLevel(logging.DEBUG) + +print('A', torch.__version__) +print('B', torch.cuda.is_available()) +print('C', torch.backends.cudnn.enabled) +device = torch.device('cuda') +print('D', torch.cuda.get_device_properties(device)) + +# set loader config and dataset names +implementation = "dask" +loader_config = "loader.notebook_api_data.json" +dataset_name = "pytorch_notebook" +lag = 1 + +# create IceNet dataloader +# (assuming notebook 03 has been run and the loader config exists) +# and write a dataset config +dl = IceNetDataLoaderFactory().create_data_loader( + implementation, + loader_config, + dataset_name, + lag, + n_forecast_days=7, + north=False, + south=True, + output_batch_size=4, + generate_workers=8, +) + +# write dataset config +dl.write_dataset_config_only() +dataset_config = f"dataset_config.{dataset_name}.json" + +# train and test model +seed = 42 +batch_size = 4 +shuffle = True +num_workers = 0 +persistent_workers = False + +# train model +lit_unet_module, unet_model = train_icenet_unet( + configuration_path=dataset_config, + learning_rate=1e-4, + max_epochs=100, + seed=seed, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + persistent_workers=persistent_workers, +) + +print("Finished training UNet model") +print(f"UNet model:\n{unet_model}") +print(f"UNet (Lightning Module):\n{lit_unet_module}") + +# test model (test_icenet_unet takes no shuffle argument; its loader is never shuffled) +print("Testing model") +y_hat_unet, y_true = test_icenet_unet( + configuration_path=dataset_config, + lit_module_unet=lit_unet_module, + seed=seed, + batch_size=batch_size, + num_workers=num_workers, + persistent_workers=persistent_workers, +) + +print("Finished testing model") +print(f"y_hat_unet.shape: {y_hat_unet.shape}") +print(f"y_true.shape: {y_true.shape}") diff --git a/pytorch_example/pytorch_example_dev.py b/pytorch_example/pytorch_example_dev.py new file mode 100644 index 0000000..ed4e814 --- /dev/null +++ b/pytorch_example/pytorch_example_dev.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python +# coding: utf-8 + +import pandas as pd +import torch +import logging + +from icenet.data.loaders import IceNetDataLoaderFactory +from torch.utils.data import DataLoader +from icenet_pytorch_dataset import IceNetDataSetPyTorch + +logging.getLogger().setLevel(logging.DEBUG) + +print('A', torch.__version__) +print('B', torch.cuda.is_available()) +print('C', torch.backends.cudnn.enabled) +device = torch.device('cuda') +print('D', torch.cuda.get_device_properties(device)) + +# set loader config and dataset names +implementation = "dask" +loader_config = "loader.notebook_api_data.json" +dataset_name = "pytorch_notebook" +lag = 1 + +# create IceNet dataloader +# (assuming notebook 03 has been run and the loader config exists) +# and write a dataset config +dl = IceNetDataLoaderFactory().create_data_loader( + implementation, + loader_config, + dataset_name, + lag, + n_forecast_days=7, + north=False, + south=True, + output_batch_size=4, + generate_workers=8) + +# write dataset config +dl.write_dataset_config_only() +dataset_config = f"dataset_config.{dataset_name}.json" + +# test creation of a custom PyTorch dataset and obtaining samples from it +ds_torch = IceNetDataSetPyTorch(configuration_path=dataset_config, + mode="train") + +logging.info("Inspecting dataset from torch") +logging.info(ds_torch.__len__()) +logging.info(ds_torch._dates[0]) +#logging.info(ds_torch.__getitem__(0)) 
+logging.info(ds_torch._dl.generate_sample(date=pd.Timestamp(ds_torch._dates[0].replace('_', '-')))) + +# create custom PyTorch datasets for train, validation and test +train_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode="train") +val_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode="val") +test_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode="test") + +# creating PyTorch dataloaders from the datasets +batch_size = 4 +shuffle = True +pworkers = False +num_workers = 0 + +train_dataloader = DataLoader(train_dataset, + batch_size=batch_size, + shuffle=shuffle, + persistent_workers=pworkers, + num_workers=num_workers) +val_dataloader = DataLoader(val_dataset, + batch_size=batch_size, + shuffle=shuffle, + persistent_workers=pworkers, + num_workers=num_workers) +test_dataloader = DataLoader(test_dataset, + batch_size=batch_size, + shuffle=shuffle, + persistent_workers=pworkers, + num_workers=num_workers) + +logging.info("Inspecting dataloader from torch") +logging.info("Iterating over {} training batches".format(len(train_dataloader))) +for i, data in enumerate(iter(train_dataloader)): + logging.info("Train sample: {}".format(i)) + train_features, train_labels, sample_weights = data + logging.info(train_features.shape) + logging.info(train_labels.shape) + logging.info(sample_weights.shape) + + + diff --git a/pytorch_example/test_icenet_unet.py b/pytorch_example/test_icenet_unet.py new file mode 100644 index 0000000..daa2478 --- /dev/null +++ b/pytorch_example/test_icenet_unet.py @@ -0,0 +1,49 @@ +""" +Taken from Andrew McDonald's https://github.com/ampersandmcd/icenet-gan/blob/main/notebooks/4_forecast.ipynb +""" + +from __future__ import annotations + +import numpy as np +import torch +from torch.utils.data import DataLoader + +import lightning.pytorch as pl + +from icenet_pytorch_dataset import IceNetDataSetPyTorch + +def test_icenet_unet( + configuration_path: str, + lit_module_unet: pl.LightningModule, + seed: int, + batch_size: int = 4, + num_workers: int = 0, + persistent_workers: bool = False, +) -> tuple[np.ndarray, np.ndarray]: + pl.seed_everything(seed) + + # configure datasets and dataloaders + test_dataset = IceNetDataSetPyTorch(configuration_path=configuration_path, mode="test") + test_dataloader = DataLoader(test_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + persistent_workers=persistent_workers) + + # pass batches through unet and accumulate into list + y_true = [] + y_hat_unet = [] + # put the module in eval mode so e.g. batch norm uses running statistics + lit_module_unet.eval() + with torch.no_grad(): + for batch in test_dataloader: + x, y, sample_weight = batch + # save ground truth + y_true.extend(y) + # predict using UNet + pred_unet = lit_module_unet(x.to(lit_module_unet.device)).detach().cpu().numpy() + # save prediction + y_hat_unet.extend(pred_unet) + + y_true = np.array(y_true) + y_hat_unet = np.array(y_hat_unet) + + return y_hat_unet, y_true diff --git a/pytorch_example/train_icenet_unet.py b/pytorch_example/train_icenet_unet.py new file mode 100644 index 0000000..67fd7cc --- /dev/null +++ b/pytorch_example/train_icenet_unet.py @@ -0,0 +1,87 @@ +""" +Taken from Andrew McDonald's https://github.com/ampersandmcd/icenet-gan/blob/main/src/train_icenet.py +""" + +from __future__ import annotations + +from torch import nn +from torch.utils.data import DataLoader + +import lightning.pytorch as pl +from lightning.pytorch.callbacks import ModelCheckpoint + +from icenet_unet_small import weighted_mse_loss, UNet, LitUNet +from icenet_pytorch_dataset import IceNetDataSetPyTorch 
+ + +def train_icenet_unet( + configuration_path: str, + learning_rate: float, + max_epochs: int, + filter_size: int = 3, + n_filters_factor: float = 1.0, + seed: int = 42, + batch_size: int = 4, + shuffle: bool = True, + num_workers: int = 0, + persistent_workers: bool = False, +) -> tuple[LitUNet, UNet]: + """ + Train IceNet using the arguments specified. + """ + pl.seed_everything(seed) + + # configure datasets and dataloaders + train_dataset = IceNetDataSetPyTorch(configuration_path, mode="train") + val_dataset = IceNetDataSetPyTorch(configuration_path, mode="val") + train_dataloader = DataLoader(train_dataset, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + persistent_workers=persistent_workers) + # no need to shuffle validation set + val_dataloader = DataLoader(val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + persistent_workers=persistent_workers) + + # construct unet + model = UNet( + input_channels=len(train_dataset._ds._config["channels"]), + filter_size=filter_size, + n_filters_factor=n_filters_factor, + n_forecast_days=train_dataset._ds._config["n_forecast_days"] + ) + + criterion = nn.MSELoss(reduction="none") + + # configure PyTorch Lightning module + lit_module = LitUNet( + model=model, + criterion=criterion, + learning_rate=learning_rate + ) + + # set up trainer configuration + trainer = pl.Trainer( + accelerator="auto", + devices=1, + log_every_n_steps=10, + max_epochs=max_epochs, + num_sanity_val_steps=1, + ) + trainer.callbacks.append(ModelCheckpoint(monitor="val_accuracy", mode="max")) + + # train model + print( + f"Training {len(train_dataset)} examples / {len(train_dataloader)} " + f"batches (batch size {batch_size})." + ) + print( + f"Validating {len(val_dataset)} examples / {len(val_dataloader)} " + f"batches (batch size {batch_size})." + ) + trainer.fit(lit_module, train_dataloader, val_dataloader) + + return lit_module, model
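+ + +# A minimal usage sketch (not part of the original scripts), assuming this file +# sits alongside test_icenet_unet.py and that pytorch_example.py has already +# written dataset_config.pytorch_notebook.json into the working directory. +if __name__ == "__main__": + from test_icenet_unet import test_icenet_unet + + config = "dataset_config.pytorch_notebook.json" + + # train a small UNet for a handful of epochs + lit_module, model = train_icenet_unet( + configuration_path=config, + learning_rate=1e-4, + max_epochs=10, + ) + + # run the held-out test split through the trained module + y_hat, y_true = test_icenet_unet( + configuration_path=config, + lit_module_unet=lit_module, + seed=42, + ) + print(f"y_hat.shape: {y_hat.shape}, y_true.shape: {y_true.shape}")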