From d55da3f56b4890d86989a942cf11c2ff9b46c11d Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Fri, 5 Jan 2024 10:44:46 +0000 Subject: [PATCH 1/4] Dev #211: Add icenet pytorch dataset class --- icenet/data/dataset.py | 44 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/icenet/data/dataset.py b/icenet/data/dataset.py index b0be5698..60007908 100644 --- a/icenet/data/dataset.py +++ b/icenet/data/dataset.py @@ -3,12 +3,19 @@ import logging import os +import dask import numpy as np +import pandas as pd from icenet.data.datasets.utils import SplittingMixin from icenet.data.loader import IceNetDataLoaderFactory from icenet.data.producers import DataCollection from icenet.utils import setup_logging + +try: + from torch.utils.data import Dataset +except ImportError: + logging.info("PyTorch import failed - not mandatory if not using PyTorch") """ @@ -317,6 +324,43 @@ def counts(self): return self._config["counts"] +class IceNetDataSetPyTorch(IceNetDataSet, Dataset): + + def __init__( + self, + configuration_path: str, + mode: str, + batch_size: int = 1, + shuffling: bool = False, + ): + super().__init__(configuration_path=configuration_path, + batch_size=batch_size, + shuffling=shuffling) + self._dl = self.get_data_loader() + + # check mode option + if mode not in ["train", "val", "test"]: + raise ValueError("mode must be either 'train', 'val', 'test'") + self._mode = mode + + self._dates = self._dl._config["sources"]["osisaf"]["dates"][self._mode] + + def __len__(self): + return self._counts[self._mode] + + def __getitem__(self, idx): + with dask.config.set(scheduler="synchronous"): + sample = self._dl.generate_sample( + date=pd.Timestamp(self._dates[idx].replace('_', '-')), + parallel=False, + ) + return sample + + @property + def dates(self): + return self._dates + + @setup_logging def get_args() -> object: """Parse command line arguments using the argparse module. From 015e133b64a2be42cc8304b8bba4ec21c997ca1b Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" Date: Mon, 8 Jan 2024 09:10:13 +0000 Subject: [PATCH 2/4] Dev #211: Update import --- icenet/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icenet/data/dataset.py b/icenet/data/dataset.py index 60007908..8668b008 100644 --- a/icenet/data/dataset.py +++ b/icenet/data/dataset.py @@ -14,7 +14,7 @@ try: from torch.utils.data import Dataset -except ImportError: +except ModuleNotFoundError: logging.info("PyTorch import failed - not mandatory if not using PyTorch") """ From d62f154e33a4fbb50adbc735d44521c3b1663f15 Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" <55503826+bnubald@users.noreply.github.com> Date: Thu, 29 Feb 2024 14:23:30 +0000 Subject: [PATCH 3/4] Fixes 211: Update docstring, capture no pytorch scenario --- icenet/data/dataset.py | 94 ++++++++++++++++++++++++------------------ 1 file changed, 54 insertions(+), 40 deletions(-) diff --git a/icenet/data/dataset.py b/icenet/data/dataset.py index 8668b008..05636848 100644 --- a/icenet/data/dataset.py +++ b/icenet/data/dataset.py @@ -12,10 +12,6 @@ from icenet.data.producers import DataCollection from icenet.utils import setup_logging -try: - from torch.utils.data import Dataset -except ModuleNotFoundError: - logging.info("PyTorch import failed - not mandatory if not using PyTorch") """ @@ -57,7 +53,7 @@ def __init__(self, Args: configuration_path: The path to the JSON configuration file. *args: Additional positional arguments. - batch_size (optional): The batch size for the data loader. Defaults to 4. + batch_size (optional): How many samples to load per batch. Defaults to 4. path (optional): The path to the directory where the processed tfrecord protocol buffer files will be stored. Defaults to './network_datasets'. shuffling (optional): Flag indicating whether to shuffle the data. @@ -324,41 +320,59 @@ def counts(self): return self._config["counts"] -class IceNetDataSetPyTorch(IceNetDataSet, Dataset): - - def __init__( - self, - configuration_path: str, - mode: str, - batch_size: int = 1, - shuffling: bool = False, - ): - super().__init__(configuration_path=configuration_path, - batch_size=batch_size, - shuffling=shuffling) - self._dl = self.get_data_loader() - - # check mode option - if mode not in ["train", "val", "test"]: - raise ValueError("mode must be either 'train', 'val', 'test'") - self._mode = mode - - self._dates = self._dl._config["sources"]["osisaf"]["dates"][self._mode] - - def __len__(self): - return self._counts[self._mode] - - def __getitem__(self, idx): - with dask.config.set(scheduler="synchronous"): - sample = self._dl.generate_sample( - date=pd.Timestamp(self._dates[idx].replace('_', '-')), - parallel=False, - ) - return sample - - @property - def dates(self): - return self._dates +try: + from torch.utils.data import Dataset + class IceNetDataSetPyTorch(IceNetDataSet, Dataset): + """Initialises and configures a PyTorch dataset. + """ + def __init__( + self, + configuration_path: str, + mode: str, + batch_size: int = 1, + shuffling: bool = False, + ): + """Initialises an instance of the IceNetDataSetPyTorch class. + + Args: + configuration_path: The path to the JSON configuration file. + mode: The dataset type, i.e. `train`, `val` or `test`. + batch_size (optional): How many samples to load per batch. Defaults to 1. + shuffling (optional): Flag indicating whether to shuffle the data. + Defaults to False. + """ + super().__init__(configuration_path=configuration_path, + batch_size=batch_size, + shuffling=shuffling) + self._dl = self.get_data_loader() + + # check mode option + if mode not in ["train", "val", "test"]: + raise ValueError("mode must be either 'train', 'val', 'test'") + self._mode = mode + + self._dates = self._dl._config["sources"]["osisaf"]["dates"][self._mode] + + def __len__(self): + return self._counts[self._mode] + + def __getitem__(self, idx): + """Return a sample from the dataloader for given index. + """ + with dask.config.set(scheduler="synchronous"): + sample = self._dl.generate_sample( + date=pd.Timestamp(self._dates[idx].replace('_', '-')), + parallel=False, + ) + return sample + + @property + def dates(self): + return self._dates +except ModuleNotFoundError: + logging.warn("PyTorch module not found - not mandatory if not using PyTorch") +except ImportError: + logging.warn("PyTorch import failed - not mandatory if not using PyTorch") @setup_logging From 19fc59882d8b91de7106815513c3932b83a7a04e Mon Sep 17 00:00:00 2001 From: "Bryn N. Ubald" <55503826+bnubald@users.noreply.github.com> Date: Thu, 29 Feb 2024 22:18:23 +0000 Subject: [PATCH 4/4] Dev #211: Updated old syntax --- icenet/data/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/icenet/data/dataset.py b/icenet/data/dataset.py index 05636848..511979ca 100644 --- a/icenet/data/dataset.py +++ b/icenet/data/dataset.py @@ -370,9 +370,9 @@ def __getitem__(self, idx): def dates(self): return self._dates except ModuleNotFoundError: - logging.warn("PyTorch module not found - not mandatory if not using PyTorch") + logging.warning("PyTorch module not found - not mandatory if not using PyTorch") except ImportError: - logging.warn("PyTorch import failed - not mandatory if not using PyTorch") + logging.warning("PyTorch import failed - not mandatory if not using PyTorch") @setup_logging