Skip to content

Commit

Permalink
add filter for nwp channels + test (#200)
Browse files Browse the repository at this point in the history
* add filter for nwp channels + test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix logic

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* lint and fix

* fix test

* filter wind batches in ocf_datapipes

* fix test

* fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
peterdudfield and pre-commit-ci[bot] authored May 28, 2024
1 parent f6e718b commit 0fdb9a3
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
10 changes: 9 additions & 1 deletion pvnet/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
test_period=[None, None],
batch_dir=None,
shuffle_factor=100,
nwp_channels=None,
):
"""Datamodule for training pvnet architecture.
Expand All @@ -38,17 +39,24 @@ def __init__(
shuffle_factor: Number of presaved batches to be split and reshuffled to create returned
batches. A larger factor means on each epoch the batches will be more diverse but at
the cost of using more RAM.
nwp_channels: NWP channels to use, e.g. a dict mapping each NWP source to a list of channel names. If None, all channels are used
"""
super().__init__()
self.configuration = configuration
self.batch_size = batch_size
self.batch_dir = batch_dir
self.shuffle_factor = shuffle_factor
self.nwp_channels = nwp_channels

if not ((batch_dir is not None) ^ (configuration is not None)):
raise ValueError("Exactly one of `batch_dir` or `configuration` must be set.")

if (nwp_channels is not None) and (batch_dir is None):
raise ValueError(
"In order for 'nwp_channels' to work, we need batch_dir. "
"Otherwise the nwp channels is one in the configuration"
)

if batch_dir is not None:
if any([period != [None, None] for period in [train_period, val_period, test_period]]):
raise ValueError("Cannot set `(train/val/test)_period` with presaved batches")
Expand Down
3 changes: 3 additions & 0 deletions pvnet/data/wind_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class WindDataModule(BaseDataModule):
"""Datamodule for training windnet and using windnet pipeline in `ocf_datapipes`."""

def _get_datapipe(self, start_time, end_time):
# TODO: this is not right, need to load the full windnet pipeline
data_pipeline = windnet_netcdf_datapipe(
self.configuration,
keys=["wind", "nwp", "sensor"],
Expand All @@ -28,7 +29,9 @@ def _get_premade_batches_datapipe(self, subdir, shuffle=False):
data_pipeline = windnet_netcdf_datapipe(
keys=["wind", "nwp", "sensor"],
filenames=filenames,
nwp_channels=self.nwp_channels,
)

data_pipeline = (
data_pipeline.batch(self.batch_size)
.map(stack_np_examples_into_batch)
Expand Down
23 changes: 23 additions & 0 deletions tests/data/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from pvnet.data.wind_datamodule import WindDataModule
from pvnet.data.pv_site_datamodule import PVSiteDataModule
import os
from ocf_datapipes.batch.batches import BatchKey, NWPBatchKey


def test_init():
Expand Down Expand Up @@ -30,6 +31,28 @@ def test_wind_init():
)


def test_wind_init_with_nwp_filter():
    """Load presaved wind batches with an `nwp_channels` filter and check the ECMWF channels.

    The datamodule is built from the sample wind batch directory with
    `nwp_channels={"ecmwf": ["t2m", "v200"]}`; the first training batch must
    list both requested channel names and have size 2 along axis 2 of the
    ECMWF NWP array.
    """
    requested = ["t2m", "v200"]
    datamodule = WindDataModule(
        configuration=None,
        batch_size=2,
        num_workers=0,
        prefetch_factor=None,
        train_period=[None, None],
        val_period=[None, None],
        test_period=[None, None],
        batch_dir="tests/test_data/sample_wind_batches",
        nwp_channels={"ecmwf": list(requested)},
    )

    # Only the first batch from the training dataloader is inspected.
    first_batch = next(iter(datamodule.train_dataloader()))
    ecmwf = first_batch[BatchKey.nwp]["ecmwf"]

    batch_channels = ecmwf[NWPBatchKey.nwp_channel_names]
    print(batch_channels)
    for channel in requested:
        assert channel in batch_channels

    # Axis 2 is presumably the channel axis; its size must match the two requested channels.
    assert ecmwf[NWPBatchKey.nwp].shape[2] == 2


def test_pv_site_init():
dm = PVSiteDataModule(
configuration=f"{os.path.dirname(os.path.abspath(__file__))}/test_data/sample_batches/data_configuration.yaml",
Expand Down

0 comments on commit 0fdb9a3

Please sign in to comment.