diff --git a/configs/config.yaml b/configs/config.yaml
index 02d0a8d0..32931fcc 100644
--- a/configs/config.yaml
+++ b/configs/config.yaml
@@ -10,6 +10,7 @@ defaults:
   - experiment: null
   - hparams_search: null
   - hydra: default.yaml
+  - renewable: "pv"
 
 # enable color logging
 # - override hydra/hydra_logging: colorlog
diff --git a/pvnet/data/__init__.py b/pvnet/data/__init__.py
index 87716ddf..5d763df0 100644
--- a/pvnet/data/__init__.py
+++ b/pvnet/data/__init__.py
@@ -1 +1,2 @@
 """Data parts"""
+from .utils import BatchSplitter
diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py
index 39abae5a..35aad131 100644
--- a/pvnet/data/datamodule.py
+++ b/pvnet/data/datamodule.py
@@ -1,65 +1,14 @@
 """ Data module for pytorch lightning """
 from datetime import datetime
 
-import numpy as np
 import torch
 from lightning.pytorch import LightningDataModule
 from ocf_datapipes.training.pvnet import pvnet_datapipe
-from ocf_datapipes.utils.consts import BatchKey
 from ocf_datapipes.utils.utils import stack_np_examples_into_batch
 from torch.utils.data import DataLoader
-from torch.utils.data.datapipes._decorator import functional_datapipe
-from torch.utils.data.datapipes.datapipe import IterDataPipe
 from torch.utils.data.datapipes.iter import FileLister
 
-
-def copy_batch_to_device(batch, device):
-    """Moves a dict-batch of tensors to new device."""
-    batch_copy = {}
-    for k in list(batch.keys()):
-        if isinstance(batch[k], torch.Tensor):
-            batch_copy[k] = batch[k].to(device)
-        else:
-            batch_copy[k] = batch[k]
-    return batch_copy
-
-
-def batch_to_tensor(batch):
-    """Moves numpy batch to a tensor"""
-    for k in list(batch.keys()):
-        if isinstance(batch[k], np.ndarray) and np.issubdtype(batch[k].dtype, np.number):
-            batch[k] = torch.as_tensor(batch[k])
-    return batch
-
-
-def split_batches(batch):
-    """Splits a single batch of data."""
-    n_samples = batch[BatchKey.gsp].shape[0]
-    keys = list(batch.keys())
-    examples = [{} for _ in range(n_samples)]
-    for i in range(n_samples):
-        b = examples[i]
-        for k in keys:
-            if ("idx" in k.name) or ("channel_names" in k.name):
-                b[k] = batch[k]
-            else:
-                b[k] = batch[k][i]
-    return examples
-
-
-@functional_datapipe("split_batches")
-class BatchSplitter(IterDataPipe):
-    """Pipeline step to split batches of data and yield single examples"""
-
-    def __init__(self, source_datapipe: IterDataPipe):
-        """Pipeline step to split batches of data and yield single examples"""
-        self.source_datapipe = source_datapipe
-
-    def __iter__(self):
-        """Opens the NWP data"""
-        for batch in self.source_datapipe:
-            for example in split_batches(batch):
-                yield example
+from pvnet.data.utils import batch_to_tensor
 
 
 class DataModule(LightningDataModule):
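
The helpers removed above move into the new `pvnet/data/utils.py` (next file), unchanged apart from `split_batches` gaining a `splitting_key` parameter, so the PV and wind datamodules can share them. A minimal sketch of the two tensor helpers on a toy batch (keys, shapes, and values here are made up for illustration):

import numpy as np
import torch

from pvnet.data.utils import batch_to_tensor, copy_batch_to_device

batch = {
    "power": np.zeros((4, 16), dtype=np.float32),  # numeric array: becomes a tensor
    "ids": np.array(["a", "b", "c", "d"]),  # non-numeric array: left unchanged
}
batch = batch_to_tensor(batch)
assert isinstance(batch["power"], torch.Tensor)

# Tensors are copied onto the target device; all other values pass through as-is.
batch_on_device = copy_batch_to_device(batch, torch.device("cpu"))
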
diff --git a/pvnet/data/utils.py b/pvnet/data/utils.py
new file mode 100644
index 00000000..be173f60
--- /dev/null
+++ b/pvnet/data/utils.py
@@ -0,0 +1,56 @@
+"""Utils common to the Wind and PV datamodules"""
+import numpy as np
+import torch
+from ocf_datapipes.utils.consts import BatchKey
+from torch.utils.data import IterDataPipe, functional_datapipe
+
+
+def copy_batch_to_device(batch, device):
+    """Returns a copy of a dict-batch with all tensors moved to the given device."""
+    batch_copy = {}
+    for k in list(batch.keys()):
+        if isinstance(batch[k], torch.Tensor):
+            batch_copy[k] = batch[k].to(device)
+        else:
+            batch_copy[k] = batch[k]
+    return batch_copy
+
+
+def batch_to_tensor(batch):
+    """Converts numeric numpy arrays in a batch to torch tensors, in place."""
+    for k in list(batch.keys()):
+        if isinstance(batch[k], np.ndarray) and np.issubdtype(batch[k].dtype, np.number):
+            batch[k] = torch.as_tensor(batch[k])
+    return batch
+
+
+def split_batches(batch, splitting_key=BatchKey.gsp):
+    """Splits a batch into a list of single examples; `splitting_key` sets the batch size."""
+
+    n_samples = batch[splitting_key].shape[0]
+    keys = list(batch.keys())
+    examples = [{} for _ in range(n_samples)]
+    for i in range(n_samples):
+        b = examples[i]
+        for k in keys:
+            if ("idx" in k.name) or ("channel_names" in k.name):
+                b[k] = batch[k]
+            else:
+                b[k] = batch[k][i]
+    return examples
+
+
+@functional_datapipe("split_batches")
+class BatchSplitter(IterDataPipe):
+    """Pipeline step to split batches of data and yield single examples"""
+
+    def __init__(self, source_datapipe: IterDataPipe, splitting_key: BatchKey = BatchKey.gsp):
+        """Sets the source datapipe and the key used to infer the batch size."""
+        self.source_datapipe = source_datapipe
+        self.splitting_key = splitting_key
+
+    def __iter__(self):
+        """Iterates over batches from the source datapipe and yields single examples."""
+        for batch in self.source_datapipe:
+            for example in split_batches(batch, splitting_key=self.splitting_key):
+                yield example
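
Because `BatchSplitter` is registered with `@functional_datapipe("split_batches")`, any `IterDataPipe` gains a `.split_batches()` method once `pvnet.data.utils` is imported; this is what the premade-batches pipeline in the next file relies on. A minimal sketch, assuming batches keyed by `BatchKey` enums (toy shapes):

import numpy as np
from ocf_datapipes.utils.consts import BatchKey
from torch.utils.data.datapipes.iter import IterableWrapper

import pvnet.data.utils  # noqa: F401 -- registers the .split_batches() functional form

# Two batches of 8 samples each; the splitting key carries the batch dimension.
batches = [{BatchKey.gsp: np.zeros((8, 16))} for _ in range(2)]
dp = IterableWrapper(batches).split_batches(splitting_key=BatchKey.gsp)
examples = list(dp)
assert len(examples) == 16
assert examples[0][BatchKey.gsp].shape == (16,)
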
diff --git a/pvnet/data/wind_datamodule.py b/pvnet/data/wind_datamodule.py
new file mode 100644
index 00000000..bb3896f7
--- /dev/null
+++ b/pvnet/data/wind_datamodule.py
@@ -0,0 +1,142 @@
+""" Data module for pytorch lightning """
+import glob
+from datetime import datetime
+
+from lightning.pytorch import LightningDataModule
+from ocf_datapipes.training.windnet import windnet_netcdf_datapipe
+from ocf_datapipes.utils.utils import stack_np_examples_into_batch
+from torch.utils.data import DataLoader
+
+from pvnet.data.utils import batch_to_tensor
+
+
+class WindDataModule(LightningDataModule):
+    """Datamodule for training windnet and using the windnet pipeline in `ocf_datapipes`."""
+
+    def __init__(
+        self,
+        configuration=None,
+        batch_size=16,
+        num_workers=0,
+        prefetch_factor=None,
+        train_period=[None, None],
+        val_period=[None, None],
+        test_period=[None, None],
+        batch_dir=None,
+    ):
+        """Datamodule for training windnet and using the windnet pipeline in `ocf_datapipes`.
+
+        Can also be used with pre-made batches if `batch_dir` is set.
+
+
+        Args:
+            configuration: Path to datapipe configuration file.
+            batch_size: Batch size.
+            num_workers: Number of workers to use in multiprocess batch loading.
+            prefetch_factor: Number of batches loaded in advance by each worker.
+            train_period: Date range filter for train dataloader.
+            val_period: Date range filter for val dataloader.
+            test_period: Date range filter for test dataloader.
+            batch_dir: Path to the directory of pre-saved batches. Cannot be used together with
+                'train/val/test_period'.
+
+        """
+        super().__init__()
+        self.configuration = configuration
+        self.batch_size = batch_size
+        self.batch_dir = batch_dir
+
+        if batch_dir is not None:
+            if any([period != [None, None] for period in [train_period, val_period, test_period]]):
+                raise ValueError("Cannot set `(train/val/test)_period` with presaved batches")
+
+        self.train_period = [
+            None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in train_period
+        ]
+        self.val_period = [
+            None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in val_period
+        ]
+        self.test_period = [
+            None if d is None else datetime.strptime(d, "%Y-%m-%d") for d in test_period
+        ]
+
+        self._common_dataloader_kwargs = dict(
+            shuffle=False,  # shuffled in datapipe step
+            batch_size=None,  # batched in datapipe step
+            sampler=None,
+            batch_sampler=None,
+            num_workers=num_workers,
+            collate_fn=None,
+            pin_memory=False,
+            drop_last=False,
+            timeout=0,
+            worker_init_fn=None,
+            prefetch_factor=prefetch_factor,
+            persistent_workers=False,
+        )
+
+    def _get_datapipe(self, start_time, end_time):
+        data_pipeline = windnet_netcdf_datapipe(
+            self.configuration,
+            keys=["sensor", "nwp"],
+        )
+        # Note: start_time and end_time are not yet passed through to the datapipe.
+        data_pipeline = (
+            data_pipeline.batch(self.batch_size)
+            .map(stack_np_examples_into_batch)
+            .map(batch_to_tensor)
+        )
+        return data_pipeline
+
+    def _get_premade_batches_datapipe(self, subdir, shuffle=False):
+        data_pipeline = windnet_netcdf_datapipe(
+            config_filename=self.configuration,
+            keys=["sensor", "nwp"],
+            filenames=list(glob.glob(f"{self.batch_dir}/{subdir}/*.nc")),
+        )
+        if shuffle:
+            data_pipeline = (
+                data_pipeline.shuffle(buffer_size=100)
+                .sharding_filter()
+                # Split the batches and reshuffle them to be combined into new batches
+                .split_batches()
+                .shuffle(buffer_size=100 * self.batch_size)
+            )
+        else:
+            data_pipeline = (
+                data_pipeline.sharding_filter()
+                # Split the batches so we can use any batch-size
+                .split_batches()
+            )
+
+        data_pipeline = (
+            data_pipeline.batch(self.batch_size)
+            .map(stack_np_examples_into_batch)
+            .map(batch_to_tensor)
+        )
+
+        return data_pipeline
+
+    def train_dataloader(self):
+        """Construct train dataloader"""
+        if self.batch_dir is not None:
+            datapipe = self._get_premade_batches_datapipe("train", shuffle=True)
+        else:
+            datapipe = self._get_datapipe(*self.train_period)
+        return DataLoader(datapipe, **self._common_dataloader_kwargs)
+
+    def val_dataloader(self):
+        """Construct val dataloader"""
+        if self.batch_dir is not None:
+            datapipe = self._get_premade_batches_datapipe("val")
+        else:
+            datapipe = self._get_datapipe(*self.val_period)
+        return DataLoader(datapipe, **self._common_dataloader_kwargs)
+
+    def test_dataloader(self):
+        """Construct test dataloader"""
+        if self.batch_dir is not None:
+            datapipe = self._get_premade_batches_datapipe("test")
+        else:
+            datapipe = self._get_datapipe(*self.test_period)
+        return DataLoader(datapipe, **self._common_dataloader_kwargs)
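
`WindDataModule` mirrors the existing `DataModule` interface, so it drops into the usual Lightning flow. A sketch of both construction modes (the paths below are hypothetical):

from pvnet.data.wind_datamodule import WindDataModule

# Streaming from raw data via the windnet pipeline:
dm = WindDataModule(
    configuration="path/to/wind_data_config.yaml",  # hypothetical path
    batch_size=16,
    num_workers=4,
)

# Or from pre-saved netcdf batches, in which case the period filters must stay unset:
dm = WindDataModule(
    configuration="path/to/wind_data_config.yaml",
    batch_size=16,
    batch_dir="path/to/batches",  # expects train/, val/, test/ subdirs of *.nc files
)
train_loader = dm.train_dataloader()
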
diff --git a/scripts/save_batches.py b/scripts/save_batches.py
index 950ad23e..995f23e1 100644
--- a/scripts/save_batches.py
+++ b/scripts/save_batches.py
@@ -26,6 +26,7 @@
 import hydra
 import torch
 from ocf_datapipes.training.pvnet import pvnet_datapipe
+from ocf_datapipes.training.windnet import windnet_datapipe
 from ocf_datapipes.utils.utils import stack_np_examples_into_batch
 from omegaconf import DictConfig, OmegaConf
 from sqlalchemy import exc as sa_exc
@@ -44,16 +45,26 @@
 
 
 class _save_batch_func_factory:
-    def __init__(self, batch_dir):
+    def __init__(self, batch_dir, output_format: str = "torch"):
         self.batch_dir = batch_dir
+        self.output_format = output_format
 
     def __call__(self, input):
         i, batch = input
-        torch.save(batch, f"{self.batch_dir}/{i:06}.pt")
-
-
-def _get_datapipe(config_path, start_time, end_time, batch_size):
-    data_pipeline = pvnet_datapipe(
+        if self.output_format == "torch":
+            torch.save(batch, f"{self.batch_dir}/{i:06}.pt")
+        elif self.output_format == "netcdf":
+            batch.to_netcdf(f"{self.batch_dir}/{i:06}.nc", mode="w")
+
+
+def _get_datapipe(config_path, start_time, end_time, batch_size, renewable: str = "pv"):
+    if renewable == "pv":
+        data_pipeline_fn = pvnet_datapipe
+    elif renewable == "wind":
+        data_pipeline_fn = windnet_datapipe
+    else:
+        raise ValueError(f"Unknown renewable: {renewable}")
+    data_pipeline = data_pipeline_fn(
         config_path,
         start_time=start_time,
         end_time=end_time,
@@ -65,8 +76,10 @@
     return data_pipeline
 
 
-def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
-    save_func = _save_batch_func_factory(batch_dir)
+def _save_batches_with_dataloader(
+    batch_pipe, batch_dir, num_batches, dataloader_kwargs, output_format: str = "torch"
+):
+    save_func = _save_batch_func_factory(batch_dir, output_format=output_format)
 
     filenumber_pipe = IterableWrapper(range(num_batches)).sharding_filter()
     save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func)
@@ -126,6 +139,7 @@
             batch_dir=f"{config.batch_output_dir}/val",
             num_batches=config.num_val_batches,
             dataloader_kwargs=dataloader_kwargs,
+            output_format="torch" if config.renewable == "pv" else "netcdf",
         )
 
     if config.num_train_batches > 0:
@@ -142,6 +156,7 @@
             batch_dir=f"{config.batch_output_dir}/train",
             num_batches=config.num_train_batches,
             dataloader_kwargs=dataloader_kwargs,
+            output_format="torch" if config.renewable == "pv" else "netcdf",
         )
 
     print("done")
diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py
index 66ebeef9..eea38c94 100644
--- a/tests/data/test_datamodule.py
+++ b/tests/data/test_datamodule.py
@@ -1,4 +1,6 @@
 from pvnet.data.datamodule import DataModule
+from pvnet.data.wind_datamodule import WindDataModule
+import os
 
 
 def test_init():
@@ -13,3 +15,16 @@
         block_nwp_and_sat=False,
         batch_dir="tests/data/sample_batches",
     )
+
+
+def test_wind_init():
+    dm = WindDataModule(
+        configuration=f"{os.path.dirname(os.path.abspath(__file__))}/data/sample_batches/data_configuration.yaml",
+        batch_size=2,
+        num_workers=0,
+        prefetch_factor=None,
+        train_period=[None, None],
+        val_period=[None, None],
+        test_period=[None, None],
+        batch_dir="tests/data/sample_batches",
+    )
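
The `output_format` switch exists because PV batches are dicts of tensors (saved with `torch.save`), while wind batches presumably come out of the windnet pipeline as xarray Datasets, given the `to_netcdf` call above. A toy sketch of the two branches outside the factory (paths and data are made up):

import torch
import xarray as xr

for i, (fmt, batch) in enumerate(
    [
        ("torch", {"x": torch.zeros(2, 3)}),
        ("netcdf", xr.Dataset({"x": ("time", [0.0, 1.0])})),
    ]
):
    if fmt == "torch":
        torch.save(batch, f"/tmp/{i:06}.pt")  # writes /tmp/000000.pt
    elif fmt == "netcdf":
        batch.to_netcdf(f"/tmp/{i:06}.nc", mode="w")  # writes /tmp/000001.nc
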