Skip to content

Commit

Permalink
Merge pull request #225 from bnubald/211_pytorch_dataset
Browse files Browse the repository at this point in the history
Fixes #211: Addition of PyTorch dataset
  • Loading branch information
bnubald authored Feb 29, 2024
2 parents cb1cb78 + 19fc598 commit 65c1271
Showing 1 changed file with 59 additions and 1 deletion.
60 changes: 59 additions & 1 deletion icenet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
import logging
import os

import dask
import numpy as np
import pandas as pd

from icenet.data.datasets.utils import SplittingMixin
from icenet.data.loader import IceNetDataLoaderFactory
from icenet.data.producers import DataCollection
from icenet.utils import setup_logging

"""
Expand Down Expand Up @@ -50,7 +53,7 @@ def __init__(self,
Args:
configuration_path: The path to the JSON configuration file.
*args: Additional positional arguments.
batch_size (optional): The batch size for the data loader. Defaults to 4.
batch_size (optional): How many samples to load per batch. Defaults to 4.
path (optional): The path to the directory where the processed tfrecord
protocol buffer files will be stored. Defaults to './network_datasets'.
shuffling (optional): Flag indicating whether to shuffle the data.
Expand Down Expand Up @@ -317,6 +320,61 @@ def counts(self):
return self._config["counts"]


try:
    from torch.utils.data import Dataset

    class IceNetDataSetPyTorch(IceNetDataSet, Dataset):
        """A PyTorch ``Dataset`` wrapper around :class:`IceNetDataSet`.

        This class is only defined when PyTorch is importable; otherwise a
        warning is logged and the name is simply not created.
        """

        def __init__(
            self,
            configuration_path: str,
            mode: str,
            batch_size: int = 1,
            shuffling: bool = False,
        ):
            """Initialises an instance of the IceNetDataSetPyTorch class.

            Args:
                configuration_path: The path to the JSON configuration file.
                mode: The dataset type, i.e. `train`, `val` or `test`.
                batch_size (optional): How many samples to load per batch. Defaults to 1.
                shuffling (optional): Flag indicating whether to shuffle the data.
                    Defaults to False.

            Raises:
                ValueError: If ``mode`` is not one of 'train', 'val' or 'test'.
            """
            # Fail fast: validate mode before the (potentially expensive)
            # parent initialisation and data-loader construction.
            if mode not in ("train", "val", "test"):
                raise ValueError("mode must be either 'train', 'val', 'test'")
            self._mode = mode

            super().__init__(configuration_path=configuration_path,
                             batch_size=batch_size,
                             shuffling=shuffling)
            self._dl = self.get_data_loader()

            # Dates belonging to the requested split, read from the data
            # loader's configuration.
            self._dates = self._dl._config["sources"]["osisaf"]["dates"][self._mode]

        def __len__(self):
            """Return the number of samples in the selected split."""
            # NOTE(review): self._counts is assumed to be populated by the
            # IceNetDataSet parent class — confirm against its definition.
            return self._counts[self._mode]

        def __getitem__(self, idx):
            """Return a sample from the dataloader for given index.

            The synchronous dask scheduler is forced here so that sample
            generation does not spawn its own parallelism (PyTorch
            DataLoader workers typically provide it instead).
            """
            with dask.config.set(scheduler="synchronous"):
                sample = self._dl.generate_sample(
                    date=pd.Timestamp(self._dates[idx].replace('_', '-')),
                    parallel=False,
                )
            return sample

        @property
        def dates(self):
            """Dates associated with the selected split."""
            return self._dates
except ImportError:
    # ModuleNotFoundError is a subclass of ImportError, so this single
    # handler covers both "torch not installed" and "torch import broken".
    logging.warning("PyTorch import failed - not mandatory if not using PyTorch")


@setup_logging
def get_args() -> object:
"""Parse command line arguments using the argparse module.
Expand Down

0 comments on commit 65c1271

Please sign in to comment.