# PyTorch is an optional dependency: only the import itself is guarded, so that
# a genuine error raised while building the class is never mistaken for
# "PyTorch is not installed".
try:
    from torch.utils.data import Dataset
except ModuleNotFoundError:
    logging.warning("PyTorch module not found - not mandatory if not using PyTorch")
except ImportError:
    logging.warning("PyTorch import failed - not mandatory if not using PyTorch")
else:

    class IceNetDataSetPyTorch(IceNetDataSet, Dataset):
        """Map-style PyTorch dataset backed by an IceNet data loader.

        Samples are generated on demand, one forecast date at a time, by the
        data loader obtained from ``get_data_loader()``.
        """

        def __init__(
            self,
            configuration_path: str,
            mode: str,
            batch_size: int = 1,
            shuffling: bool = False,
        ):
            """Initialise an instance of the IceNetDataSetPyTorch class.

            Args:
                configuration_path: The path to the JSON configuration file.
                mode: The dataset type, i.e. `train`, `val` or `test`.
                batch_size (optional): How many samples to load per batch.
                    Defaults to 1.
                shuffling (optional): Flag indicating whether to shuffle the
                    data. Defaults to False.

            Raises:
                ValueError: If `mode` is not one of `train`, `val` or `test`.
            """
            # Fail fast: reject a bad mode before the configuration is loaded
            # and the data loader is constructed.
            if mode not in ("train", "val", "test"):
                raise ValueError("mode must be either 'train', 'val', 'test'")

            super().__init__(configuration_path=configuration_path,
                             batch_size=batch_size,
                             shuffling=shuffling)
            self._dl = self.get_data_loader()
            self._mode = mode

            # NOTE(review): reads the loader's private configuration and is
            # hard-wired to the "osisaf" source - confirm that every loader
            # configuration provides this key.
            self._dates = self._dl._config["sources"]["osisaf"]["dates"][self._mode]

        def __len__(self) -> int:
            """Return the configured sample count for the selected mode.

            The value is read from ``self._counts``, which is expected to be
            populated by the parent initialiser.
            """
            # NOTE(review): __getitem__ indexes self._dates; this presumes the
            # configured count equals len(self._dates) - confirm.
            return self._counts[self._mode]

        def __getitem__(self, idx: int):
            """Return a sample from the dataloader for given index.

            Args:
                idx: Position of the forecast date within `dates`.

            Returns:
                Whatever the data loader's ``generate_sample`` returns for
                that date.
            """
            # Stored dates use '_' as separator; pandas needs '-' to parse them.
            forecast_date = pd.Timestamp(self._dates[idx].replace('_', '-'))

            # Force single-threaded dask execution while generating the sample
            # (presumably to avoid nested parallelism inside PyTorch
            # DataLoader workers - TODO confirm).
            with dask.config.set(scheduler="synchronous"):
                sample = self._dl.generate_sample(
                    date=forecast_date,
                    parallel=False,
                )
            return sample

        @property
        def dates(self):
            """The forecast dates available for the selected mode."""
            return self._dates