diff --git a/ocf_datapipes/training/pvnet_site.py b/ocf_datapipes/training/pvnet_site.py index 33af0256..a0b3e808 100644 --- a/ocf_datapipes/training/pvnet_site.py +++ b/ocf_datapipes/training/pvnet_site.py @@ -2,6 +2,7 @@ import logging from datetime import datetime, timedelta +from functools import partial from typing import List, Optional import xarray as xr @@ -237,12 +238,18 @@ def construct_sliced_data_pipeline( roi_width_pixels=conf_nwp[nwp_key].nwp_image_size_pixels_width, ) # Coarsen the data, if it is separated by 0.05 degrees each - nwp_datapipe = nwp_datapipe.map(potentially_coarsen) + potentially_coarsen_partial = partial( + potentially_coarsen, coarsen_to_deg=conf_nwp[nwp_key].coarsen_to_degrees + ) + nwp_datapipe = nwp_datapipe.map(potentially_coarsen_partial) # Somewhat hacky way for India specifically, need different mean/std for ECMWF data if conf_nwp[nwp_key].nwp_provider in ["ecmwf"]: normalize_provider = "ecmwf_india" + elif conf_nwp[nwp_key].nwp_provider in ["gfs"]: + normalize_provider = "gfs_india" else: normalize_provider = conf_nwp[nwp_key].nwp_provider + nwp_datapipes_dict[nwp_key] = nwp_datapipe.normalize( mean=NWP_MEANS[normalize_provider], std=NWP_STDS[normalize_provider], diff --git a/ocf_datapipes/utils/consts.py b/ocf_datapipes/utils/consts.py index c868b37b..f4f02cd9 100644 --- a/ocf_datapipes/utils/consts.py +++ b/ocf_datapipes/utils/consts.py @@ -39,6 +39,7 @@ def __getitem__(self, key): NWP_PROVIDERS = [ "ukv", "gfs", + "gfs_india", "icon-eu", "icon-global", "ecmwf", @@ -132,17 +133,32 @@ def __getitem__(self, key): UKV_STD = _to_data_array(UKV_STD) UKV_MEAN = _to_data_array(UKV_MEAN) -# These were calculated from 200 random init times (step 0s) from the MO global data +# --- MO Global + MO_GLOBAL_INDIA_MEAN = { - "temperature_sl": 298.2, - "wind_u_component_10m": 0.5732, - "wind_v_component_10m": -0.2831, + "temperature_sl": 295.34392488, + "wind_u_component_10m": 0.83223102, + "wind_v_component_10m": 0.0802083, + "downward_shortwave_radiation_flux_gl": 225.54222068, + "cloud_cover_high": 0.34935897, + "cloud_cover_low": 0.096081, + "cloud_cover_medium": 0.13878676, + "relative_humidity_sl": 69.59633137, + "snow_depth_gl": 3.45158744, + "visibility_sl": 23181.81547681, } MO_GLOBAL_INDIA_STD = { - "temperature_sl": 8.473, - "wind_u_component_10m": 2.599, - "wind_v_component_10m": 2.016, + "temperature_sl": 12.26983825, + "wind_u_component_10m": 3.45169835, + "wind_v_component_10m": 2.9825603, + "downward_shortwave_radiation_flux_gl": 303.85182864, + "cloud_cover_high": 0.40563507, + "cloud_cover_low": 0.18374192, + "cloud_cover_medium": 0.25972151, + "relative_humidity_sl": 21.00264399, + "snow_depth_gl": 30.19116501, + "visibility_sl": 5385.35839715, } @@ -197,6 +213,48 @@ def __getitem__(self, key): GFS_MEAN = _to_data_array(GFS_MEAN) +# ------ GFS +GFS_INDIA_STD_DICT = { + "t": 14.93798, + "prate": 5.965701e-05, + "u10": 3.4826114, + "v10": 3.167296, + "u100": 4.140226, + "v100": 3.984121, + "dlwrf": 79.30329, + "dswrf": 325.58582, + "hcc": 39.91955, + "lcc": 23.208075, + "mcc": 33.283035, + "r": 25.545837, + "sde": 0.10192183, + "tcc": 42.583195, + "vis": 3491.437, +} +GFS_INDIA_MEAN_DICT = { + "t": 298.27713, + "prate": 1.7736e-05, + "u10": 1.5782778, + "v10": 0.09856875, + "u100": 1.4558668, + "v100": -0.28256148, + "dlwrf": 356.57776, + "dswrf": 284.358, + "hcc": 26.965801, + "lcc": 9.2288, + "mcc": 17.2132, + "r": 38.2474, + "sde": 0.02070413, + "tcc": 36.962795, + "vis": 23386.936, +} + + +GFS_INDIA_VARIABLE_NAMES = tuple(GFS_INDIA_MEAN_DICT.keys()) +GFS_INDIA_STD = _to_data_array(GFS_INDIA_STD_DICT) +GFS_INDIA_MEAN = _to_data_array(GFS_INDIA_MEAN_DICT) + + # ------ ECMWF # These were calculated from 100 random init times of UK data from 2020-2023 ECMWF_STD = { @@ -369,6 +427,7 @@ def __getitem__(self, key): NWP_VARIABLE_NAMES = NWPStatDict( ukv=UKV_VARIABLE_NAMES, gfs=GFS_VARIABLE_NAMES, + gfs_india=GFS_INDIA_VARIABLE_NAMES, ecmwf=ECMWF_VARIABLE_NAMES, ecmwf_india=INDIA_ECMWF_VARIABLE_NAMES, excarta=EXCARTA_VARIABLE_NAMES, @@ -379,6 +438,7 @@ def __getitem__(self, key): NWP_STDS = NWPStatDict( ukv=UKV_STD, gfs=GFS_STD, + gfs_india=GFS_INDIA_STD, ecmwf=ECMWF_STD, ecmwf_india=INDIA_ECMWF_STD, excarta=EXCARTA_STD, @@ -389,6 +449,7 @@ def __getitem__(self, key): NWP_MEANS = NWPStatDict( ukv=UKV_MEAN, gfs=GFS_MEAN, + gfs_india=GFS_INDIA_MEAN, ecmwf=ECMWF_MEAN, ecmwf_india=INDIA_ECMWF_MEAN, excarta=EXCARTA_MEAN,