From 4ab76414c77d26000b7c6005ac527374bf65c97f Mon Sep 17 00:00:00 2001
From: Jacob Bieker
Date: Fri, 24 Nov 2023 17:18:33 +0000
Subject: [PATCH] Add option to save WindNet batches

---
NOTE(review): line breaks and indentation in this patch were reconstructed
from a whitespace-collapsed copy; hunk line counts were re-checked against
the hunk headers, but confirm indentation against the target files.

NOTE(review): amended in review, so the commit id on the first line and the
post-image blob id on the `index` line of scripts/save_batches.py are stale;
regenerate with `git format-patch` after applying.

Change made in review:
 * `_save_batch_func_factory.__init__` now raises ValueError for an
   unrecognised `output_format`. Previously `__call__` fell through its
   if/elif and silently wrote nothing. This mirrors the ValueError that
   `_get_datapipe` raises for an unknown `renewable`.

Open questions (not changed here; the relevant lines are outside the hunks):
 * No hunk in this patch passes `renewable` to `_get_datapipe`, while
   `main()` picks `output_format` from `config.renewable`. Confirm the
   `_get_datapipe(...)` call sites forward `config.renewable`; otherwise
   the "wind" branch is never taken and `batch.to_netcdf` is called on the
   PV batch type.
 * `renewable: "pv"` is added inside the `defaults:` list. Hydra presumably
   treats entries there as config-group selections, not plain keys, while
   `main()` compares `config.renewable` to a string. Confirm this resolves
   as intended.
 * `import xarray as xr` is not used by any hunk in this patch.

 configs/config.yaml     |  1 +
 scripts/save_batches.py | 32 ++++++++++++++++++++++++--------
 2 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/configs/config.yaml b/configs/config.yaml
index 02d0a8d0..32931fcc 100644
--- a/configs/config.yaml
+++ b/configs/config.yaml
@@ -10,6 +10,7 @@ defaults:
   - experiment: null
   - hparams_search: null
   - hydra: default.yaml
+  - renewable: "pv"
 
   # enable color logging
   # - override hydra/hydra_logging: colorlog
diff --git a/scripts/save_batches.py b/scripts/save_batches.py
index 950ad23e..f95de847 100644
--- a/scripts/save_batches.py
+++ b/scripts/save_batches.py
@@ -26,12 +26,14 @@
 import hydra
 import torch
 from ocf_datapipes.training.pvnet import pvnet_datapipe
+from ocf_datapipes.training.windnet import windnet_datapipe
 from ocf_datapipes.utils.utils import stack_np_examples_into_batch
 from omegaconf import DictConfig, OmegaConf
 from sqlalchemy import exc as sa_exc
 from torch.utils.data import DataLoader
 from torch.utils.data.datapipes.iter import IterableWrapper
 from tqdm import tqdm
+import xarray as xr
 
 from pvnet.data.datamodule import batch_to_tensor
 from pvnet.utils import print_config
@@ -44,16 +46,28 @@
 
 
 class _save_batch_func_factory:
-    def __init__(self, batch_dir):
+    def __init__(self, batch_dir, output_format: str = "torch"):
+        if output_format not in ("torch", "netcdf"):
+            raise ValueError(f"Unknown output_format: {output_format}")
         self.batch_dir = batch_dir
+        self.output_format = output_format
 
     def __call__(self, input):
         i, batch = input
-        torch.save(batch, f"{self.batch_dir}/{i:06}.pt")
-
-
-def _get_datapipe(config_path, start_time, end_time, batch_size):
-    data_pipeline = pvnet_datapipe(
+        if self.output_format == "torch":
+            torch.save(batch, f"{self.batch_dir}/{i:06}.pt")
+        elif self.output_format == "netcdf":
+            batch.to_netcdf(f"{self.batch_dir}/{i:06}.nc", mode="w")
+
+
+def _get_datapipe(config_path, start_time, end_time, batch_size, renewable: str = "pv"):
+    if renewable == "pv":
+        data_pipeline_fn = pvnet_datapipe
+    elif renewable == "wind":
+        data_pipeline_fn = windnet_datapipe
+    else:
+        raise ValueError(f"Unknown renewable: {renewable}")
+    data_pipeline = data_pipeline_fn(
         config_path,
         start_time=start_time,
         end_time=end_time,
@@ -65,8 +79,8 @@ def _get_datapipe(config_path, start_time, end_time, batch_size):
     return data_pipeline
 
 
-def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs):
-    save_func = _save_batch_func_factory(batch_dir)
+def _save_batches_with_dataloader(batch_pipe, batch_dir, num_batches, dataloader_kwargs, output_format: str = "torch"):
+    save_func = _save_batch_func_factory(batch_dir, output_format=output_format)
     filenumber_pipe = IterableWrapper(range(num_batches)).sharding_filter()
     save_pipe = filenumber_pipe.zip(batch_pipe).map(save_func)
 
@@ -126,6 +140,7 @@ def main(config: DictConfig):
             batch_dir=f"{config.batch_output_dir}/val",
             num_batches=config.num_val_batches,
             dataloader_kwargs=dataloader_kwargs,
+            output_format="torch" if config.renewable == "pv" else "netcdf",
         )
 
     if config.num_train_batches > 0:
@@ -142,6 +157,7 @@ def main(config: DictConfig):
             batch_dir=f"{config.batch_output_dir}/train",
             num_batches=config.num_train_batches,
             dataloader_kwargs=dataloader_kwargs,
+            output_format="torch" if config.renewable == "pv" else "netcdf",
         )
 
     print("done")