diff --git a/india_forecast_app/data/mo_global/india_coords.nc b/india_forecast_app/data/mo_global/india_coords.nc new file mode 100644 index 0000000..b97874c Binary files /dev/null and b/india_forecast_app/data/mo_global/india_coords.nc differ diff --git a/india_forecast_app/data/nwp.py b/india_forecast_app/data/nwp.py new file mode 100644 index 0000000..3b12107 --- /dev/null +++ b/india_forecast_app/data/nwp.py @@ -0,0 +1,62 @@ +import xarray as xr +import os +import logging + +logger = logging.getLogger(__name__) + + +def regrid_nwp_data(nwp_ds: xr.Dataset, target_coords_path: str) -> xr.Dataset: + """This function loads the NWP data, then regrids and saves it back out if the data is not + on the same grid as expected. The data is resaved in-place. + """ + + logger.info(f"Regridding NWP data to expected grid to {target_coords_path}") + + ds_raw = nwp_ds + + # These are the coords we are aiming for + ds_target_coords = xr.load_dataset(target_coords_path) + + # Check if regridding step needs to be done + needs_regridding = not ( + ds_raw.latitude.equals(ds_target_coords.latitude) + and ds_raw.longitude.equals(ds_target_coords.longitude) + ) + + if not needs_regridding: + logger.info(f"No NWP regridding required - skipping this step") + return ds_raw + + # flip latitude, so its in ascending order + if ds_raw.latitude[0] > ds_raw.latitude[-1]: + ds_raw = ds_raw.reindex(latitude=ds_raw.latitude[::-1]) + + # clip to india coordindates + ds_raw = ds_raw.sel( + latitude=slice(0, 40), + longitude=slice(65, 100), + ) + + # regrid + logger.info(f"Regridding NWP to expected grid") + ds_regridded = ds_raw.interp( + latitude=ds_target_coords.latitude, longitude=ds_target_coords.longitude + ) + + # rechunking + ds_regridded["variable"] = ds_regridded["variable"].astype(str) + + # Rechunk to these dimensions when saving + save_chunk_dict = { + "step": 5, + "latitude": 100, + "longitude": 100, + "x": 100, + "y": 100, + } + + ds_regridded = ds_regridded.chunk( + {k: save_chunk_dict[k] for k in list(ds_raw.xindexes) if k in save_chunk_dict} + ) + + return ds_regridded diff --git a/india_forecast_app/models/pvnet/model.py b/india_forecast_app/models/pvnet/model.py index d6604cc..829eb55 100644 --- a/india_forecast_app/models/pvnet/model.py +++ b/india_forecast_app/models/pvnet/model.py @@ -35,6 +35,7 @@ wind_path, ) from .utils import ( + NWPProcessAndCacheConfig, download_satellite_data, populate_data_config_sources, process_and_cache_nwp, @@ -214,26 +215,42 @@ def _prepare_data_sources(self): satellite_source_file_path = os.getenv("SATELLITE_ZARR_PATH", None) # only load nwp that we need - nwp_paths = [] - nwp_source_file_paths = [] + nwp_configs = [] nwp_keys = self.config["input_data"]["nwp"].keys() if "ecmwf" in nwp_keys: - nwp_ecmwf_source_file_path = os.environ["NWP_ECMWF_ZARR_PATH"] - nwp_source_file_paths.append(nwp_ecmwf_source_file_path) - nwp_paths.append(nwp_ecmwf_path) + + nwp_configs.append( + NWPProcessAndCacheConfig( + source_nwp_path=os.environ["NWP_ECMWF_ZARR_PATH"], + dest_nwp_path=nwp_ecmwf_path, + source="ecmwf", + ) + ) + if "gfs" in nwp_keys: - nwp_gfs_source_file_path = os.environ["NWP_GFS_ZARR_PATH"] - nwp_source_file_paths.append(nwp_gfs_source_file_path) - nwp_paths.append(nwp_gfs_path) + + nwp_configs.append( + NWPProcessAndCacheConfig( + source_nwp_path=os.environ["NWP_GFS_ZARR_PATH"], + dest_nwp_path=nwp_gfs_path, + source="gfs", + ) + ) + if "mo_global" in nwp_keys: - nwp_mo_global_source_file_path = os.environ["NWP_MO_GLOBAL_ZARR_PATH"] - nwp_source_file_paths.append(nwp_mo_global_source_file_path) - nwp_paths.append(nwp_mo_global_path) + nwp_configs.append( + NWPProcessAndCacheConfig( + source_nwp_path=os.environ["NWP_MO_GLOBAL_ZARR_PATH"], + dest_nwp_path=nwp_mo_global_path, + source="mo_global", + config=self.config["input_data"]["nwp"]["mo_global"] + ) + ) # Remove local cached zarr if already exists - for nwp_source_file_path, nwp_path in zip(nwp_source_file_paths, nwp_paths, strict=False): + for nwp_config in nwp_configs: # Process/cache remote zarr locally - process_and_cache_nwp(nwp_source_file_path, nwp_path) + process_and_cache_nwp(nwp_config) if use_satellite and "satellite" in self.config["input_data"].keys(): shutil.rmtree(satellite_path, ignore_errors=True) download_satellite_data(satellite_source_file_path) diff --git a/india_forecast_app/models/pvnet/utils.py b/india_forecast_app/models/pvnet/utils.py index 32b02b2..59e4fb8 100644 --- a/india_forecast_app/models/pvnet/utils.py +++ b/india_forecast_app/models/pvnet/utils.py @@ -9,7 +9,11 @@ import xarray as xr import yaml from ocf_datapipes.batch import BatchKey +from ocf_datapipes.config.model import NWP from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD +from pydantic import BaseModel + +from india_forecast_app.data.nwp import regrid_nwp_data from .consts import ( nwp_ecmwf_path, @@ -25,6 +29,15 @@ log = logging.getLogger(__name__) +class NWPProcessAndCacheConfig(BaseModel): + """Configuration for processing and caching NWP data""" + + source_nwp_path: str + dest_nwp_path: str + source: str + config: Optional[NWP] = None + + def worker_init_fn(worker_id): """ Clear reference to the loop and thread. @@ -92,11 +105,15 @@ def populate_data_config_sources(input_path, output_path): return config -def process_and_cache_nwp(source_nwp_path: str, dest_nwp_path: str): +def process_and_cache_nwp(nwp_config: NWPProcessAndCacheConfig): """Reads zarr file, renames t variable to t2m and saves zarr to new destination""" + source_nwp_path = nwp_config.source_nwp_path + dest_nwp_path = nwp_config.dest_nwp_path + log.info( - f"Processing and caching NWP data for {source_nwp_path}, " f"and saving to {dest_nwp_path}" + f"Processing and caching NWP data for {source_nwp_path} " + f"and saving to {dest_nwp_path} for {nwp_config.source}" ) if os.path.exists(dest_nwp_path): @@ -115,10 +132,7 @@ def process_and_cache_nwp(source_nwp_path: str, dest_nwp_path: str): if ds[v].dtype == object: ds[v].encoding.clear() - is_gfs = "gfs" in source_nwp_path.lower() - is_ecmwf = "ecmwf" in source_nwp_path.lower() - - if is_ecmwf: + if nwp_config.source == "ecmwf": # Rename t variable to t2m variables = list(ds.variable.values) new_variables = [] @@ -134,12 +148,23 @@ def process_and_cache_nwp(source_nwp_path: str, dest_nwp_path: str): ds.__setitem__("variable", new_variables) # Hack to resolve some NWP data format differences between providers - elif is_gfs: + elif nwp_config.source == "gfs": data_var = ds[list(ds.data_vars.keys())[0]] # # Use .to_dataset() to split the data variable based on 'variable' dim ds = data_var.to_dataset(dim="variable") ds = ds.rename({"t2m": "t"}) + + if nwp_config.source == "mo_global": + + # only select the variables we need + nwp_channels = list(nwp_config.config.nwp_channels) + ds = ds.sel(variable=nwp_channels) + + # regrid data + ds = regrid_nwp_data(ds, "india_forecast_app/data/mo_global/india_coords.nc") + # Save destination path + log.info(f"Saving NWP data to {dest_nwp_path}") ds.to_zarr(dest_nwp_path, mode="a") diff --git a/tests/data/test_nwp.py b/tests/data/test_nwp.py new file mode 100644 index 0000000..1e7b794 --- /dev/null +++ b/tests/data/test_nwp.py @@ -0,0 +1,30 @@ +""" Tests for the nwp regridding module """ +import os +import tempfile + +import xarray as xr + +from india_forecast_app.data.nwp import regrid_nwp_data + + +def test_regrid_nwp_data(nwp_mo_global_data): + """Test the regridding of the nwp data""" + + # create a temporary dir + with tempfile.TemporaryDirectory() as temp_dir: + + # save mo data to zarr + nwp_zarr = os.environ["NWP_MO_GLOBAL_ZARR_PATH"] + + # regrid the data + nwp_xr = xr.open_zarr(nwp_zarr) + nwp_xr_regridded = regrid_nwp_data( + nwp_xr, "india_forecast_app/data/mo_global/india_coords.nc" + ) + + # check the data is different in latitude and longitude + assert not nwp_xr_regridded.latitude.equals(nwp_xr.latitude) + assert not nwp_xr_regridded.longitude.equals(nwp_xr.longitude) + + assert len(nwp_xr_regridded.latitude) == 225 + assert len(nwp_xr_regridded.longitude) == 150