Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes #211: Addition of PyTorch dataset #225

Merged
merged 5 commits into from
Feb 29, 2024
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion icenet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
import logging
import os

import dask
import numpy as np
import pandas as pd

from icenet.data.datasets.utils import SplittingMixin
from icenet.data.loader import IceNetDataLoaderFactory
from icenet.data.producers import DataCollection
from icenet.utils import setup_logging

"""


Expand Down Expand Up @@ -50,7 +53,7 @@ def __init__(self,
Args:
configuration_path: The path to the JSON configuration file.
*args: Additional positional arguments.
batch_size (optional): The batch size for the data loader. Defaults to 4.
batch_size (optional): How many samples to load per batch. Defaults to 4.
path (optional): The path to the directory where the processed tfrecord
protocol buffer files will be stored. Defaults to './network_datasets'.
shuffling (optional): Flag indicating whether to shuffle the data.
Expand Down Expand Up @@ -317,6 +320,61 @@ def counts(self):
return self._config["counts"]


try:
    from torch.utils.data import Dataset

    class IceNetDataSetPyTorch(IceNetDataSet, Dataset):
        """A PyTorch map-style dataset over an IceNet dataset configuration.

        Wraps :class:`IceNetDataSet` so that samples for a given data split
        (``train``/``val``/``test``) can be consumed by a
        ``torch.utils.data.DataLoader``.
        """

        def __init__(
            self,
            configuration_path: str,
            mode: str,
            batch_size: int = 1,
            shuffling: bool = False,
        ):
            """Initialises an instance of the IceNetDataSetPyTorch class.

            Args:
                configuration_path: The path to the JSON configuration file.
                mode: The dataset type, i.e. `train`, `val` or `test`.
                batch_size (optional): How many samples to load per batch. Defaults to 1.
                shuffling (optional): Flag indicating whether to shuffle the data.
                    Defaults to False.

            Raises:
                ValueError: If ``mode`` is not one of ``train``, ``val`` or ``test``.
            """
            # Validate `mode` up front, before the parent __init__ does the
            # (comparatively expensive) work of parsing the configuration and
            # before the data loader is constructed.
            if mode not in ["train", "val", "test"]:
                raise ValueError("mode must be either 'train', 'val', 'test'")
            self._mode = mode

            super().__init__(configuration_path=configuration_path,
                             batch_size=batch_size,
                             shuffling=shuffling)
            self._dl = self.get_data_loader()

            # Dates for the selected split, as 'YYYY_MM_DD' strings taken from
            # the loader's configuration.
            self._dates = self._dl._config["sources"]["osisaf"]["dates"][self._mode]

        def __len__(self):
            """Return the number of samples in the selected split."""
            return self._counts[self._mode]

        def __getitem__(self, idx):
            """Return a sample from the dataloader for given index."""
            # Force a synchronous dask scheduler: sample generation happens
            # inside PyTorch worker processes, so nested parallelism is
            # undesirable here.
            with dask.config.set(scheduler="synchronous"):
                sample = self._dl.generate_sample(
                    # Dates are stored as 'YYYY_MM_DD'; pd.Timestamp expects
                    # 'YYYY-MM-DD'.
                    date=pd.Timestamp(self._dates[idx].replace('_', '-')),
                    parallel=False,
                )
            return sample

        @property
        def dates(self):
            """Dates ('YYYY_MM_DD' strings) available in the selected split."""
            return self._dates
except ModuleNotFoundError:
    # logging.warn is deprecated (since Python 3.3) in favour of
    # logging.warning.
    logging.warning("PyTorch module not found - not mandatory if not using PyTorch")
except ImportError:
    logging.warning("PyTorch import failed - not mandatory if not using PyTorch")


@setup_logging
def get_args() -> object:
"""Parse command line arguments using the argparse module.
Expand Down