diff --git a/ocf_datapipes/utils/utils.py b/ocf_datapipes/utils/utils.py index abbd5275a..b1c5bc25a 100644 --- a/ocf_datapipes/utils/utils.py +++ b/ocf_datapipes/utils/utils.py @@ -336,7 +336,7 @@ def combine_to_single_dataset(dataset_dict: dict[str, xr.Dataset]) -> xr.Dataset Combine multiple datasets into a single dataset Args: - *datasets: Datasets to combine + dataset_dict: Dictionary of xr.Dataset objects to combine Returns: Combined dataset diff --git a/ocf_datapipes/validation/check_for_nans.py b/ocf_datapipes/validation/check_for_nans.py index 421ce2a83..dfea01198 100644 --- a/ocf_datapipes/validation/check_for_nans.py +++ b/ocf_datapipes/validation/check_for_nans.py @@ -25,6 +25,7 @@ def __init__( source_datapipe: Datapipe emitting Xarray Datasets dataset_name: Optional name for dataset to check, if None, checks whole dataset fill_nans: Whether to fill NaNs with 0 or not + fill_value: Value to fill NaNs with """ self.source_datapipe = source_datapipe self.dataset_name = dataset_name diff --git a/tests/training/test_windnet.py b/tests/training/test_windnet.py new file mode 100644 index 000000000..f9919d90a --- /dev/null +++ b/tests/training/test_windnet.py @@ -0,0 +1,28 @@ +from datetime import datetime + +from ocf_datapipes.training.windnet import ( + windnet_datapipe, + windnet_netcdf_datapipe, +) +from ocf_datapipes.utils.utils import combine_to_single_dataset, uncombine_from_single_dataset +import pytest + + +def test_windnet_datapipe(configuration_filename): + start_time = datetime(1900, 1, 1) + end_time = datetime(2050, 1, 1) + dp = windnet_datapipe( + configuration_filename, + start_time=start_time, + end_time=end_time, + ) + datasets = next(iter(dp)) + dataset = combine_to_single_dataset(datasets) + # Need to serialize attributes to strings + dataset.to_netcdf("test.nc", mode="w", engine="h5netcdf", compute=True) + dp = windnet_netcdf_datapipe( + config_filename=configuration_filename, + filenames=["test.nc"], + keys=["gsp", "nwp", "sat", "pv"], + ) + datasets = next(iter(dp))