Skip to content

Commit

Permalink
Add test for windnet
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Nov 15, 2023
1 parent d280ee6 commit 236b602
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ocf_datapipes/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ def combine_to_single_dataset(dataset_dict: dict[str, xr.Dataset]) -> xr.Dataset
Combine multiple datasets into a single dataset
Args:
*datasets: Datasets to combine
dataset_dict: Dictionary of xr.Dataset objects to combine
Returns:
Combined dataset
Expand Down
1 change: 1 addition & 0 deletions ocf_datapipes/validation/check_for_nans.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
source_datapipe: Datapipe emitting Xarray Datasets
dataset_name: Optional name for dataset to check, if None, checks whole dataset
fill_nans: Whether to fill NaNs with 0 or not
fill_value: Value to fill NaNs with
"""
self.source_datapipe = source_datapipe
self.dataset_name = dataset_name
Expand Down
28 changes: 28 additions & 0 deletions tests/training/test_windnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from datetime import datetime

from ocf_datapipes.training.windnet import (
windnet_datapipe,
windnet_netcdf_datapipe,
)
from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset
import pytest


def test_windnet_datapipe(configuration_filename):
start_time = datetime(1900, 1, 1)
end_time = datetime(2050, 1, 1)
dp = windnet_datapipe(
configuration_filename,
start_time=start_time,
end_time=end_time,
)
datasets = next(iter(dp))
dataset = combine_to_single_dataset(datasets)
# Need to serialize attributes to strings
dataset.to_netcdf("test.nc", mode="w", engine="h5netcdf", compute=True)
dp = windnet_netcdf_datapipe(
config_filename=configuration_filename,
filenames=["test.nc"],
keys=["gsp", "nwp", "sat", "pv"],
)
datasets = next(iter(dp))

0 comments on commit 236b602

Please sign in to comment.