From 0fdb9a316f9cc3d97c47dd0ec381afda520c2105 Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Tue, 28 May 2024 12:24:19 +0100 Subject: [PATCH] add filter for nwp channels + test (#200) * add filter for nwp channels + test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix logic * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * lint and fix * fix test * filter wind batches in ocf_datapipes * fix test * fix --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- pvnet/data/base.py | 10 +++++++++- pvnet/data/wind_datamodule.py | 3 +++ tests/data/test_datamodule.py | 23 +++++++++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/pvnet/data/base.py b/pvnet/data/base.py index 9b67b8b3..b53c4ee8 100644 --- a/pvnet/data/base.py +++ b/pvnet/data/base.py @@ -19,6 +19,7 @@ def __init__( test_period=[None, None], batch_dir=None, shuffle_factor=100, + nwp_channels=None, ): """Datamodule for training pvnet architecture. @@ -38,17 +39,24 @@ def __init__( shuffle_factor: Number of presaved batches to be split and reshuffled to create returned batches. A larger factor means on each epoch the batches will be more diverse but at the cost of using more RAM. - + nwp_channels: NWP channels to keep, as a dict mapping each NWP source to a list of channel names. Requires `batch_dir`. If None, all channels are used """ super().__init__() self.configuration = configuration self.batch_size = batch_size self.batch_dir = batch_dir self.shuffle_factor = shuffle_factor + self.nwp_channels = nwp_channels if not ((batch_dir is not None) ^ (configuration is not None)): raise ValueError("Exactly one of `batch_dir` or `configuration` must be set.") + if (nwp_channels is not None) and (batch_dir is None): + raise ValueError( + "In order for 'nwp_channels' to work, we need batch_dir. 
" + "Otherwise the nwp channels is one in the configuration" + ) + if batch_dir is not None: if any([period != [None, None] for period in [train_period, val_period, test_period]]): raise ValueError("Cannot set `(train/val/test)_period` with presaved batches") diff --git a/pvnet/data/wind_datamodule.py b/pvnet/data/wind_datamodule.py index ef8ae3b6..0c11d31d 100644 --- a/pvnet/data/wind_datamodule.py +++ b/pvnet/data/wind_datamodule.py @@ -11,6 +11,7 @@ class WindDataModule(BaseDataModule): """Datamodule for training windnet and using windnet pipeline in `ocf_datapipes`.""" def _get_datapipe(self, start_time, end_time): + # TODO is this is not right, need to load full windnet pipeline data_pipeline = windnet_netcdf_datapipe( self.configuration, keys=["wind", "nwp", "sensor"], @@ -28,7 +29,9 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False): data_pipeline = windnet_netcdf_datapipe( keys=["wind", "nwp", "sensor"], filenames=filenames, + nwp_channels=self.nwp_channels, ) + data_pipeline = ( data_pipeline.batch(self.batch_size) .map(stack_np_examples_into_batch) diff --git a/tests/data/test_datamodule.py b/tests/data/test_datamodule.py index bb3b74f0..ff07fe0f 100644 --- a/tests/data/test_datamodule.py +++ b/tests/data/test_datamodule.py @@ -2,6 +2,7 @@ from pvnet.data.wind_datamodule import WindDataModule from pvnet.data.pv_site_datamodule import PVSiteDataModule import os +from ocf_datapipes.batch.batches import BatchKey, NWPBatchKey def test_init(): @@ -30,6 +31,28 @@ def test_wind_init(): ) +def test_wind_init_with_nwp_filter(): + dm = WindDataModule( + configuration=None, + batch_size=2, + num_workers=0, + prefetch_factor=None, + train_period=[None, None], + val_period=[None, None], + test_period=[None, None], + batch_dir="tests/test_data/sample_wind_batches", + nwp_channels={"ecmwf": ["t2m", "v200"]}, + ) + dataloader = iter(dm.train_dataloader()) + + batch = next(dataloader) + batch_channels = 
batch[BatchKey.nwp]["ecmwf"][NWPBatchKey.nwp_channel_names] + print(batch_channels) + for v in ["t2m", "v200"]: + assert v in batch_channels + assert batch[BatchKey.nwp]["ecmwf"][NWPBatchKey.nwp].shape[2] == 2 + + def test_pv_site_init(): dm = PVSiteDataModule( configuration=f"{os.path.dirname(os.path.abspath(__file__))}/test_data/sample_batches/data_configuration.yaml",