Skip to content

Commit

Permalink
Add wind datamodule
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Nov 27, 2023
1 parent a256ad7 commit 1953090
Showing 1 changed file with 193 additions and 0 deletions.
193 changes: 193 additions & 0 deletions pvnet/data/wind_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
""" Data module for pytorch lightning """
from datetime import datetime

Check warning on line 2 in pvnet/data/wind_datamodule.py

View check run for this annotation

Codecov / codecov/patch

pvnet/data/wind_datamodule.py#L2

Added line #L2 was not covered by tests

import numpy as np
import torch
from lightning.pytorch import LightningDataModule
from ocf_datapipes.training.windnet import windnet_netcdf_datapipe
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from torch.utils.data import DataLoader
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
import glob

Check warning on line 13 in pvnet/data/wind_datamodule.py

View check run for this annotation

Codecov / codecov/patch

pvnet/data/wind_datamodule.py#L4-L13

Added lines #L4 - L13 were not covered by tests


def copy_batch_to_device(batch, device):
    """Return a copy of a dict-batch with every tensor moved to `device`.

    Args:
        batch: Dictionary mapping keys to values. A value may be a tensor, a nested
            dictionary of the same form, or any other object.
        device: Target device, in any form accepted by `torch.Tensor.to`.

    Returns:
        A new dictionary with the same keys. Tensor values are moved to `device`,
        nested dictionaries are copied recursively, and every other value is
        carried over as the same object. The input dictionary is not modified.
    """
    batch_copy = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch_copy[key] = value.to(device)
        elif isinstance(value, dict):
            # Nested batches would otherwise keep their tensors on the old device
            batch_copy[key] = copy_batch_to_device(value, device)
        else:
            batch_copy[key] = value
    return batch_copy

Check warning on line 24 in pvnet/data/wind_datamodule.py

View check run for this annotation

Codecov / codecov/patch

pvnet/data/wind_datamodule.py#L23-L24

Added lines #L23 - L24 were not covered by tests


def batch_to_tensor(batch):
    """Convert the numeric numpy arrays of a batch to torch tensors, in place.

    Args:
        batch: Dictionary mapping keys to values. A value may be a numpy array, a
            nested dictionary of the same form, or any other object.

    Returns:
        The same dictionary object that was passed in. Numpy arrays whose dtype is
        a sub-dtype of `np.number` are replaced by `torch.as_tensor(array)`; nested
        dictionaries are converted recursively; all other values (including
        non-numeric arrays such as string arrays) are left untouched.
    """
    # Only existing keys are reassigned, so the dict size is stable while iterating
    for key, value in batch.items():
        if isinstance(value, dict):
            batch_to_tensor(value)
        elif isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.number):
            batch[key] = torch.as_tensor(value)
    return batch

Check warning on line 32 in pvnet/data/wind_datamodule.py

View check run for this annotation

Codecov / codecov/patch

pvnet/data/wind_datamodule.py#L29-L32

Added lines #L29 - L32 were not covered by tests


def split_batches(batch):
    """Split a single batch of data into a list of single-example dictionaries.

    The number of examples is read from the leading dimension of
    `batch[BatchKey.sensor]`.

    Args:
        batch: Dictionary whose keys expose a `.name` attribute and which contains
            the key `BatchKey.sensor`.

    Returns:
        A list with one dictionary per example, each holding every key of `batch`.
        Values whose key name contains "idx" or "channel_names" are placed whole
        (same object) in every example; every other value is indexed with the
        example's position along its first dimension.
    """
    n_samples = batch[BatchKey.sensor].shape[0]
    # Decide once which keys are copied whole rather than re-testing per example
    shared_keys = {k for k in batch if ("idx" in k.name) or ("channel_names" in k.name)}
    return [
        {k: (v if k in shared_keys else v[i]) for k, v in batch.items()}
        for i in range(n_samples)
    ]

Check warning on line 47 in pvnet/data/wind_datamodule.py

View check run for this annotation

Codecov / codecov/patch

pvnet/data/wind_datamodule.py#L46-L47

Added lines #L46 - L47 were not covered by tests


@functional_datapipe("split_batches")
class BatchSplitter(IterDataPipe):
    """Pipeline step to split batches of data and yield single examples"""

    def __init__(self, source_datapipe: IterDataPipe):
        """Pipeline step to split batches of data and yield single examples

        Args:
            source_datapipe: Datapipe which yields batches accepted by `split_batches`.
        """
        self.source_datapipe = source_datapipe

    def __iter__(self):
        """Yield the examples of each source batch one at a time, in batch order"""
        for batch in self.source_datapipe:
            yield from split_batches(batch)

Check warning on line 62 in pvnet/data/wind_datamodule.py

View check run for this annotation

Codecov / codecov/patch

pvnet/data/wind_datamodule.py#L60-L62

Added lines #L60 - L62 were not covered by tests


class WindDataModule(LightningDataModule):
    """Datamodule for training on wind data using the windnet pipeline in `ocf_datapipes`."""

    def __init__(
        self,
        configuration=None,
        batch_size=16,
        num_workers=0,
        prefetch_factor=None,
        train_period=(None, None),
        val_period=(None, None),
        test_period=(None, None),
        batch_dir=None,
    ):
        """Datamodule for training on wind data using the windnet pipeline in `ocf_datapipes`.

        Can also be used with pre-made batches if `batch_dir` is set.

        Args:
            configuration: Path to datapipe configuration file.
            batch_size: Batch size.
            num_workers: Number of workers to use in multiprocess batch loading.
            prefetch_factor: Passed to the `DataLoader`; number of batches loaded in
                advance by each worker.
            train_period: `[start, end]` date range filter for train dataloader. Each
                entry is either `None` or a "%Y-%m-%d" string.
            val_period: Date range filter for val dataloader, same format.
            test_period: Date range filter for test dataloader, same format.
            batch_dir: Path to the directory of pre-saved batches. Cannot be used together with
                'train/val/test_period'.

        Raises:
            ValueError: If `batch_dir` is set together with any non-`None` period entry.
        """
        super().__init__()
        self.configuration = configuration
        self.batch_size = batch_size
        self.batch_dir = batch_dir

        if batch_dir is not None:
            # Test the entries themselves so that lists, tuples and other sequences
            # of `None` are all recognised as "not set"
            periods = (train_period, val_period, test_period)
            if any(d is not None for period in periods for d in period):
                raise ValueError("Cannot set `(train/val/test)_period` with presaved batches")

        self.train_period = self._parse_period(train_period)
        self.val_period = self._parse_period(val_period)
        self.test_period = self._parse_period(test_period)

        self._common_dataloader_kwargs = dict(
            shuffle=False,  # shuffled in datapipe step
            batch_size=None,  # batched in datapipe step
            sampler=None,
            batch_sampler=None,
            num_workers=num_workers,
            collate_fn=None,
            pin_memory=False,
            drop_last=False,
            timeout=0,
            worker_init_fn=None,
            prefetch_factor=prefetch_factor,
            persistent_workers=False,
        )

    @staticmethod
    def _parse_period(period):
        """Parse the "%Y-%m-%d" strings of a period into datetimes, keeping `None` entries."""
        return [None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in period]

    def _batch_and_convert(self, data_pipeline):
        """Group examples into batches of `self.batch_size`, stack them and convert to tensors."""
        return (
            data_pipeline.batch(self.batch_size)
            .map(stack_np_examples_into_batch)
            .map(batch_to_tensor)
        )

    def _get_datapipe(self, start_time, end_time):
        """Construct the batched datapipe from the configuration file.

        NOTE(review): `start_time` and `end_time` are accepted but never forwarded to
        `windnet_netcdf_datapipe`, so the `(train/val/test)_period` settings currently
        have no effect on the data loaded here - confirm whether date filtering is
        expected from this pipeline.
        """
        data_pipeline = windnet_netcdf_datapipe(
            self.configuration,
            keys=["sensor", "nwp"],
        )
        return self._batch_and_convert(data_pipeline)

    def _get_premade_batches_datapipe(self, subdir, shuffle=False):
        """Construct the datapipe which reads pre-saved batches from `batch_dir/subdir`.

        Args:
            subdir: Name of the sub-directory of `self.batch_dir` holding the `*.nc` files.
            shuffle: Whether to shuffle the stored batches and the examples split from them.
        """
        # Sorted so that the file order does not depend on directory listing order
        filenames = sorted(glob.glob(f"{self.batch_dir}/{subdir}/*.nc"))
        data_pipeline = windnet_netcdf_datapipe(
            config_filename=self.configuration,
            keys=["sensor", "nwp"],
            filenames=filenames,
        )
        if shuffle:
            data_pipeline = (
                data_pipeline.shuffle(buffer_size=100)
                .sharding_filter()
                # Split the batches and reshuffle them to be combined into new batches
                .split_batches()
                .shuffle(buffer_size=100 * self.batch_size)
            )
        else:
            data_pipeline = (
                data_pipeline.sharding_filter()
                # Split the batches so we can use any batch-size
                .split_batches()
            )

        return self._batch_and_convert(data_pipeline)

    def _make_dataloader(self, subdir, period, shuffle=False):
        """Build a dataloader from pre-made batches if `batch_dir` is set, else from the config."""
        if self.batch_dir is not None:
            datapipe = self._get_premade_batches_datapipe(subdir, shuffle=shuffle)
        else:
            datapipe = self._get_datapipe(*period)
        return DataLoader(datapipe, **self._common_dataloader_kwargs)

    def train_dataloader(self):
        """Construct train dataloader"""
        return self._make_dataloader("train", self.train_period, shuffle=True)

    def val_dataloader(self):
        """Construct val dataloader"""
        return self._make_dataloader("val", self.val_period)

    def test_dataloader(self):
        """Construct test dataloader"""
        return self._make_dataloader("test", self.test_period)

Check warning on line 193 in pvnet/data/wind_datamodule.py

View check run for this annotation

Codecov / codecov/patch

pvnet/data/wind_datamodule.py#L192-L193

Added lines #L192 - L193 were not covered by tests

0 comments on commit 1953090

Please sign in to comment.