diff --git a/ocf_datapipes/load/nwp/nwp.py b/ocf_datapipes/load/nwp/nwp.py index 6cc90ec38..0584b99ae 100644 --- a/ocf_datapipes/load/nwp/nwp.py +++ b/ocf_datapipes/load/nwp/nwp.py @@ -9,6 +9,7 @@ from torchdata.datapipes.iter import IterDataPipe from ocf_datapipes.load.nwp.providers.ecmwf import open_ifs +from ocf_datapipes.load.nwp.providers.gfs import open_gfs from ocf_datapipes.load.nwp.providers.icon import open_icon_eu, open_icon_global from ocf_datapipes.load.nwp.providers.ukv import open_ukv @@ -34,14 +35,16 @@ def __init__( i.e. OSGB for UKV, Lat/Lon for ICON EU, Icoshedral grid for ICON Global """ self.zarr_path = zarr_path - if provider == "ukv": + if provider.lower() == "ukv" or provider == "UKMetOffice": self.open_nwp = open_ukv - elif provider == "icon-eu": + elif provider.lower() == "icon-eu": self.open_nwp = open_icon_eu - elif provider == "icon-global": + elif provider.lower() == "icon-global": self.open_nwp = open_icon_global - elif provider == "ecmwf": + elif provider.lower() == "ecmwf": self.open_nwp = open_ifs + elif provider.lower() == "gfs": + self.open_nwp = open_gfs else: raise ValueError(f"Unknown provider: {provider}") diff --git a/ocf_datapipes/training/common.py b/ocf_datapipes/training/common.py index 2b0e16262..2cbbc83a1 100644 --- a/ocf_datapipes/training/common.py +++ b/ocf_datapipes/training/common.py @@ -81,7 +81,10 @@ def open_and_return_datapipes( if use_nwp: logger.debug("Opening NWP Data") nwp_datapipe = ( - OpenNWP(configuration.input_data.nwp.nwp_zarr_path) + OpenNWP( + configuration.input_data.nwp.nwp_zarr_path, + provider=configuration.input_data.nwp.nwp_provider, + ) .select_channels(configuration.input_data.nwp.nwp_channels) .add_t0_idx_and_sample_period_duration( sample_period_duration=timedelta(hours=1), diff --git a/ocf_datapipes/training/example/nwp_pv.py b/ocf_datapipes/training/example/nwp_pv.py index 63db7630a..716304fe9 100644 --- a/ocf_datapipes/training/example/nwp_pv.py +++ b/ocf_datapipes/training/example/nwp_pv.py @@ -49,7 +49,7 @@ def nwp_pv_datapipe( .fork(2, buffer_size=BUFFER_SIZE) ) - if configuration.input_data.nwp.nwp_provider == "UKMetOffice": + if configuration.input_data.nwp.nwp_provider in [ "UKMetOffice", "ukv"]: nwp_datapipe = OpenNWPID(configuration.input_data.nwp.nwp_zarr_path) elif configuration.input_data.nwp.nwp_provider == "GFS": nwp_datapipe = OpenGFSForecast(configuration.input_data.nwp.nwp_zarr_path) @@ -178,7 +178,10 @@ def nwp_pv_datapipe( forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes), ) - if configuration.input_data.nwp.nwp_provider == "UKMetOffice": + if ( + configuration.input_data.nwp.nwp_provider == "UKMetOffice" + or configuration.input_data.nwp.nwp_provider == "ukv" + ): nwp_datapipe = nwp_datapipe.normalize(mean=NWP_MEAN, std=NWP_STD) else: nwp_datapipe = nwp_datapipe.normalize(mean=NWP_GFS_MEAN, std=NWP_GFS_STD) diff --git a/tests/config/test.yaml b/tests/config/test.yaml index b0ad59069..3aa8bff8a 100644 --- a/tests/config/test.yaml +++ b/tests/config/test.yaml @@ -12,6 +12,7 @@ input_data: nwp_image_size_pixels_height: 2 nwp_image_size_pixels_width: 2 nwp_zarr_path: tests/data/nwp_data/test.zarr + nwp_provider: "ukv" history_minutes: 60 forecast_minutes: 120 time_resolution_minutes: 60