Skip to content

Commit

Permalink
Merge pull request #229 from openclimatefix/jacob/india-pvnet
Browse files Browse the repository at this point in the history
Update PVNet datapipe to support ECMWF
  • Loading branch information
jacobbieker authored Oct 26, 2023
2 parents 429a1a3 + d17bdc7 commit 3cdbfe9
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 7 deletions.
11 changes: 7 additions & 4 deletions ocf_datapipes/load/nwp/nwp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torchdata.datapipes.iter import IterDataPipe

from ocf_datapipes.load.nwp.providers.ecmwf import open_ifs
from ocf_datapipes.load.nwp.providers.gfs import open_gfs
from ocf_datapipes.load.nwp.providers.icon import open_icon_eu, open_icon_global
from ocf_datapipes.load.nwp.providers.ukv import open_ukv

Expand All @@ -34,14 +35,16 @@ def __init__(
i.e. OSGB for UKV, Lat/Lon for ICON EU, Icoshedral grid for ICON Global
"""
self.zarr_path = zarr_path
if provider == "ukv":
if provider.lower() == "ukv" or provider == "UKMetOffice":
self.open_nwp = open_ukv
elif provider == "icon-eu":
elif provider.lower() == "icon-eu":
self.open_nwp = open_icon_eu
elif provider == "icon-global":
elif provider.lower() == "icon-global":
self.open_nwp = open_icon_global
elif provider == "ecmwf":
elif provider.lower() == "ecmwf":
self.open_nwp = open_ifs
elif provider.lower() == "gfs":
self.open_nwp = open_gfs
else:
raise ValueError(f"Unknown provider: {provider}")

Expand Down
5 changes: 4 additions & 1 deletion ocf_datapipes/training/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ def open_and_return_datapipes(
if use_nwp:
logger.debug("Opening NWP Data")
nwp_datapipe = (
OpenNWP(configuration.input_data.nwp.nwp_zarr_path)
OpenNWP(
configuration.input_data.nwp.nwp_zarr_path,
provider=configuration.input_data.nwp.nwp_provider,
)
.select_channels(configuration.input_data.nwp.nwp_channels)
.add_t0_idx_and_sample_period_duration(
sample_period_duration=timedelta(hours=1),
Expand Down
7 changes: 5 additions & 2 deletions ocf_datapipes/training/example/nwp_pv.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def nwp_pv_datapipe(
.fork(2, buffer_size=BUFFER_SIZE)
)

if configuration.input_data.nwp.nwp_provider == "UKMetOffice":
if configuration.input_data.nwp.nwp_provider in [ "UKMetOffice", "ukv"]:
nwp_datapipe = OpenNWPID(configuration.input_data.nwp.nwp_zarr_path)
elif configuration.input_data.nwp.nwp_provider == "GFS":
nwp_datapipe = OpenGFSForecast(configuration.input_data.nwp.nwp_zarr_path)
Expand Down Expand Up @@ -178,7 +178,10 @@ def nwp_pv_datapipe(
forecast_duration=timedelta(minutes=configuration.input_data.nwp.forecast_minutes),
)

if configuration.input_data.nwp.nwp_provider == "UKMetOffice":
if (
configuration.input_data.nwp.nwp_provider == "UKMetOffice"
or configuration.input_data.nwp.nwp_provider == "ukv"
):
nwp_datapipe = nwp_datapipe.normalize(mean=NWP_MEAN, std=NWP_STD)
else:
nwp_datapipe = nwp_datapipe.normalize(mean=NWP_GFS_MEAN, std=NWP_GFS_STD)
Expand Down
1 change: 1 addition & 0 deletions tests/config/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ input_data:
nwp_image_size_pixels_height: 2
nwp_image_size_pixels_width: 2
nwp_zarr_path: tests/data/nwp_data/test.zarr
nwp_provider: "ukv"
history_minutes: 60
forecast_minutes: 120
time_resolution_minutes: 60
Expand Down

0 comments on commit 3cdbfe9

Please sign in to comment.