Skip to content

Commit

Permalink
Merge pull request #99 from openclimatefix/jacob/windnet-batch-creation
Browse files Browse the repository at this point in the history
Add option to save WindNet batches
  • Loading branch information
jacobbieker authored Nov 27, 2023
2 parents 7eee4bb + 2fdb5d9 commit a2be583
Show file tree
Hide file tree
Showing 7 changed files with 239 additions and 60 deletions.
1 change: 1 addition & 0 deletions configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ defaults:
- experiment: null
- hparams_search: null
- hydra: default.yaml
- renewable: "pv"

# enable color logging
# - override hydra/hydra_logging: colorlog
Expand Down
1 change: 1 addition & 0 deletions pvnet/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
"""Data parts"""
from .utils import BatchSplitter
53 changes: 1 addition & 52 deletions pvnet/data/datamodule.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,14 @@
""" Data module for pytorch lightning """
from datetime import datetime

import numpy as np
import torch
from lightning.pytorch import LightningDataModule
from ocf_datapipes.training.pvnet import pvnet_datapipe
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from torch.utils.data import DataLoader
from torch.utils.data.datapipes._decorator import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.iter import FileLister


def copy_batch_to_device(batch, device):
"""Moves a dict-batch of tensors to new device."""
batch_copy = {}
for k in list(batch.keys()):
if isinstance(batch[k], torch.Tensor):
batch_copy[k] = batch[k].to(device)
else:
batch_copy[k] = batch[k]
return batch_copy


def batch_to_tensor(batch):
"""Moves numpy batch to a tensor"""
for k in list(batch.keys()):
if isinstance(batch[k], np.ndarray) and np.issubdtype(batch[k].dtype, np.number):
batch[k] = torch.as_tensor(batch[k])
return batch


def split_batches(batch):
"""Splits a single batch of data."""
n_samples = batch[BatchKey.gsp].shape[0]
keys = list(batch.keys())
examples = [{} for _ in range(n_samples)]
for i in range(n_samples):
b = examples[i]
for k in keys:
if ("idx" in k.name) or ("channel_names" in k.name):
b[k] = batch[k]
else:
b[k] = batch[k][i]
return examples


@functional_datapipe("split_batches")
class BatchSplitter(IterDataPipe):
"""Pipeline step to split batches of data and yield single examples"""

def __init__(self, source_datapipe: IterDataPipe):
"""Pipeline step to split batches of data and yield single examples"""
self.source_datapipe = source_datapipe

def __iter__(self):
"""Opens the NWP data"""
for batch in self.source_datapipe:
for example in split_batches(batch):
yield example
from pvnet.data.utils import batch_to_tensor


class DataModule(LightningDataModule):
Expand Down
56 changes: 56 additions & 0 deletions pvnet/data/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Utils common between Wind and PV datamodules"""
import numpy as np
import torch
from ocf_datapipes.utils.consts import BatchKey
from torch.utils.data import IterDataPipe, functional_datapipe


def copy_batch_to_device(batch, device):
"""Moves a dict-batch of tensors to new device."""
batch_copy = {}
for k in list(batch.keys()):
if isinstance(batch[k], torch.Tensor):
batch_copy[k] = batch[k].to(device)
else:
batch_copy[k] = batch[k]
return batch_copy


def batch_to_tensor(batch):
"""Moves numpy batch to a tensor"""
for k in list(batch.keys()):
if isinstance(batch[k], np.ndarray) and np.issubdtype(batch[k].dtype, np.number):
batch[k] = torch.as_tensor(batch[k])
return batch


def split_batches(batch, splitting_key=BatchKey.gsp):
"""Splits a single batch of data."""

n_samples = batch[splitting_key].shape[0]
keys = list(batch.keys())
examples = [{} for _ in range(n_samples)]
for i in range(n_samples):
b = examples[i]
for k in keys:
if ("idx" in k.name) or ("channel_names" in k.name):
b[k] = batch[k]
else:
b[k] = batch[k][i]
return examples


@functional_datapipe("split_batches")
class BatchSplitter(IterDataPipe):
"""Pipeline step to split batches of data and yield single examples"""

def __init__(self, source_datapipe: IterDataPipe, splitting_key: BatchKey = BatchKey.gsp):
"""Pipeline step to split batches of data and yield single examples"""
self.source_datapipe = source_datapipe
self.splitting_key = splitting_key

def __iter__(self):
"""Opens the NWP data"""
for batch in self.source_datapipe:
for example in split_batches(batch, splitting_key=self.splitting_key):
yield example
142 changes: 142 additions & 0 deletions pvnet/data/wind_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
""" Data module for pytorch lightning """
import glob
from datetime import datetime

from lightning.pytorch import LightningDataModule
from ocf_datapipes.training.windnet import windnet_netcdf_datapipe
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from torch.utils.data import DataLoader

from pvnet.data.utils import batch_to_tensor


class WindDataModule(LightningDataModule):
"""Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`."""

def __init__(
self,
configuration=None,
batch_size=16,
num_workers=0,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
batch_dir=None,
):
"""Datamodule for training pvnet and using pvnet pipeline in `ocf_datapipes`.
Can also be used with pre-made batches if `batch_dir` is set.
Args:
configuration: Path to datapipe configuration file.
batch_size: Batch size.
num_workers: Number of workers to use in multiprocess batch loading.
prefetch_factor: Number of data will be prefetched at the end of each worker process.
train_period: Date range filter for train dataloader.
val_period: Date range filter for val dataloader.
test_period: Date range filter for test dataloader.
batch_dir: Path to the directory of pre-saved batches. Cannot be used together with
'train/val/test_period'.
"""
super().__init__()
self.configuration = configuration
self.batch_size = batch_size
self.batch_dir = batch_dir

if batch_dir is not None:
if any([period != [None, None] for period in [train_period, val_period, test_period]]):
raise ValueError("Cannot set `(train/val/test)_period` with presaved batches")

self.train_period = [
None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in train_period
]
self.val_period = [
None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in val_period
]
self.test_period = [
None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in test_period
]

self._common_dataloader_kwargs = dict(
shuffle=False, # shuffled in datapipe step
batch_size=None, # batched in datapipe step
sampler=None,
batch_sampler=None,
num_workers=num_workers,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=prefetch_factor,
persistent_workers=False,
)

def _get_datapipe(self, start_time, end_time):
data_pipeline = windnet_netcdf_datapipe(
self.configuration,
keys=["sensor", "nwp"],
)

data_pipeline = (
data_pipeline.batch(self.batch_size)
.map(stack_np_examples_into_batch)
.map(batch_to_tensor)
)
return data_pipeline

def _get_premade_batches_datapipe(self, subdir, shuffle=False):
data_pipeline = windnet_netcdf_datapipe(
config_filename=self.configuration,
keys=["sensor", "nwp"],
filenames=list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc")),
)
if shuffle:
data_pipeline = (
data_pipeline.shuffle(buffer_size=100)
.sharding_filter()
# Split the batches and reshuffle them to be combined into new batches
.split_batches()
.shuffle(buffer_size=100 * self.batch_size)
)
else:
data_pipeline = (
data_pipeline.sharding_filter()
# Split the batches so we can use any batch-size
.split_batches()
)

data_pipeline = (
data_pipeline.batch(self.batch_size)
.map(stack_np_examples_into_batch)
.map(batch_to_tensor)
)

return data_pipeline

def train_dataloader(self):
"""Construct train dataloader"""
if self.batch_dir is not None:
datapipe = self._get_premade_batches_datapipe("train", shuffle=True)
else:
datapipe = self._get_datapipe(*self.train_period)
return DataLoader(datapipe, **self._common_dataloader_kwargs)

def val_dataloader(self):
"""Construct val dataloader"""
if self.batch_dir is not None:
datapipe = self._get_premade_batches_datapipe("val")
else:
datapipe = self._get_datapipe(*self.val_period)
return DataLoader(datapipe, **self._common_dataloader_kwargs)

def test_dataloader(self):
"""Construct test dataloader"""
if self.batch_dir is not None:
datapipe = self._get_premade_batches_datapipe("test")
else:
datapipe = self._get_datapipe(*self.test_period)
return DataLoader(datapipe, **self._common_dataloader_kwargs)
31 changes: 23 additions & 8 deletions scripts/save_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import hydra
import torch
from ocf_datapipes.training.pvnet import pvnet_datapipe
from ocf_datapipes.training.windnet import windnet_datapipe
from ocf_datapipes.utils.utils import stack_np_examples_into_batch
from omegaconf import DictConfig, OmegaConf
from sqlalchemy import exc as sa_exc
Expand All @@ -44,16 +45,26 @@


class _save_batch_func_factory:
def __init__(self, batch_dir):
def __init__(self, batch_dir, output_format: str = "torch"):
self.batch_dir = batch_dir
self.output_format = output_format

def __call__(self, input):
i, batch = input
torch.save(batch, f"{self.batch_dir}/{i:06}.pt")


def _get_datapipe(config_path, start_time, end_time, batch_size):
data_pipeline = pvnet_datapipe(
if self.output_format == "torch":
torch.save(batch, f"{self.batch_dir}/{i:06}.pt")
elif self.output_format == "netcdf":
batch.to_netcdf(f"{self.batch_dir}/{i:06}.nc", mode="w")


def _get_datapipe(config_path, start_time, end_time, batch_size, renewable: str = "pv"):
if renewable == "pv":
data_pipeline_fn = pvnet_datapipe
elif renewable == "wind":
data_pipeline_fn = windnet_datapipe
else:
raise ValueError(f"Unknown renewable: {renewable}")
data_pipeline = data_pipeline_fn(
config_path,
start_time=start_time,
end_time=end_time,
Expand All @@ -65,8 +76,10 @@ def _get_datapipe(config_path, start_time, end_time, batch_size):
return data_pipeline


def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
save_func = _save_batch_func_factory(batch_dir)
def _save_batches_with_dataloader(
batch_pipe, batch_dir, num_batches, dataloader_kwargs, output_format: str = "torch"
):
save_func = _save_batch_func_factory(batch_dir, output_format=output_format)
filenumber_pipe = IterableWrapper(range(num_batches)).sharding_filter()
save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func)

Expand Down Expand Up @@ -126,6 +139,7 @@ def main(config: DictConfig):
batch_dir=f"{config.batch_output_dir}/val",
num_batches=config.num_val_batches,
dataloader_kwargs=dataloader_kwargs,
output_format="torch" if config.renewable == "pv" else "netcdf",
)

if config.num_train_batches > 0:
Expand All @@ -142,6 +156,7 @@ def main(config: DictConfig):
batch_dir=f"{config.batch_output_dir}/train",
num_batches=config.num_train_batches,
dataloader_kwargs=dataloader_kwargs,
output_format="torch" if config.renewable == "pv" else "netcdf",
)

print("done")
Expand Down
15 changes: 15 additions & 0 deletions tests/data/test_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from pvnet.data.datamodule import DataModule
from pvnet.data.wind_datamodule import WindDataModule
import os


def test_init():
Expand All @@ -13,3 +15,16 @@ def test_init():
block_nwp_and_sat=False,
batch_dir="tests/data/sample_batches",
)


def test_wind_init():
dm = WindDataModule(
configuration=f"{os.path.dirname(os.path.abspath(__file__))}/data/sample_batches/data_configuration.yaml",
batch_size=2,
num_workers=0,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
batch_dir="tests/data/sample_batches",
)

0 comments on commit a2be583

Please sign in to comment.