Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add regridding #150

Merged
merged 13 commits into from
Dec 2, 2024
Binary file added india_forecast_app/data/mo_global/india_coords.nc
Binary file not shown.
62 changes: 62 additions & 0 deletions india_forecast_app/data/nwp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import xarray as xr
import os
import logging

logger = logging.getLogger(__name__)


def regrid_nwp_data(nwp_ds: xr.Dataset, target_coords_path: str) -> xr.Dataset:
    """Regrid NWP data onto the expected latitude/longitude grid.

    The target grid is read from the dataset file at ``target_coords_path``.
    If ``nwp_ds`` already has exactly those latitude and longitude coordinates
    it is returned unchanged. Otherwise the data is put into ascending-latitude
    order, clipped to a fixed box around India, interpolated onto the target
    coordinates and rechunked.

    Nothing is written to disk here; saving the result is left to the caller.

    Args:
        nwp_ds: NWP dataset with ``latitude`` and ``longitude`` coordinates.
            When regridding is needed it must also have a ``variable``
            coordinate, which is cast to ``str``.
        target_coords_path: Path to a dataset file whose ``latitude`` and
            ``longitude`` coordinates define the target grid.

    Returns:
        ``nwp_ds`` itself when it is already on the target grid, otherwise a
        new dataset interpolated onto the target grid.
    """

    logger.info(f"Regridding NWP data to expected grid to {target_coords_path}")

    ds_raw = nwp_ds

    # These are the coords we are aiming for
    ds_target_coords = xr.load_dataset(target_coords_path)

    # Check if regridding step needs to be done
    needs_regridding = not (
        ds_raw.latitude.equals(ds_target_coords.latitude)
        and ds_raw.longitude.equals(ds_target_coords.longitude)
    )

    if not needs_regridding:
        logger.info("No NWP regridding required - skipping this step")
        return ds_raw

    # Flip latitude so it is in ascending order; the ascending label slice
    # used for clipping below relies on this ordering
    if ds_raw.latitude[0] > ds_raw.latitude[-1]:
        ds_raw = ds_raw.reindex(latitude=ds_raw.latitude[::-1])

    # Clip to India coordinates before interpolating
    ds_raw = ds_raw.sel(
        latitude=slice(0, 40),
        longitude=slice(65, 100),
    )

    # Regrid
    logger.info("Regridding NWP to expected grid")
    ds_regridded = ds_raw.interp(
        latitude=ds_target_coords.latitude, longitude=ds_target_coords.longitude
    )

    # Make sure the variable names are plain strings
    ds_regridded["variable"] = ds_regridded["variable"].astype(str)

    # Chunk sizes to use for each dimension, where that dimension is present
    save_chunk_dict = {
        "step": 5,
        "latitude": 100,
        "longitude": 100,
        "x": 100,
        "y": 100,
    }

    # Only rechunk along indexed dimensions that have a configured chunk size
    ds_regridded = ds_regridded.chunk(
        {k: save_chunk_dict[k] for k in ds_raw.xindexes if k in save_chunk_dict}
    )

    return ds_regridded
43 changes: 30 additions & 13 deletions india_forecast_app/models/pvnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
wind_path,
)
from .utils import (
NWPProcessAndCacheConfig,
download_satellite_data,
populate_data_config_sources,
process_and_cache_nwp,
Expand Down Expand Up @@ -214,26 +215,42 @@ def _prepare_data_sources(self):
satellite_source_file_path = os.getenv("SATELLITE_ZARR_PATH", None)

# only load nwp that we need
nwp_paths = []
nwp_source_file_paths = []
nwp_configs = []
nwp_keys = self.config["input_data"]["nwp"].keys()
if "ecmwf" in nwp_keys:
nwp_ecmwf_source_file_path = os.environ["NWP_ECMWF_ZARR_PATH"]
nwp_source_file_paths.append(nwp_ecmwf_source_file_path)
nwp_paths.append(nwp_ecmwf_path)

nwp_configs.append(
NWPProcessAndCacheConfig(
source_nwp_path=os.environ["NWP_ECMWF_ZARR_PATH"],
dest_nwp_path=nwp_ecmwf_path,
source="ecmwf",
)
)

if "gfs" in nwp_keys:
nwp_gfs_source_file_path = os.environ["NWP_GFS_ZARR_PATH"]
nwp_source_file_paths.append(nwp_gfs_source_file_path)
nwp_paths.append(nwp_gfs_path)

nwp_configs.append(
NWPProcessAndCacheConfig(
source_nwp_path=os.environ["NWP_GFS_ZARR_PATH"],
dest_nwp_path=nwp_gfs_path,
source="gfs",
)
)

if "mo_global" in nwp_keys:
nwp_mo_global_source_file_path = os.environ["NWP_MO_GLOBAL_ZARR_PATH"]
nwp_source_file_paths.append(nwp_mo_global_source_file_path)
nwp_paths.append(nwp_mo_global_path)
nwp_configs.append(
NWPProcessAndCacheConfig(
source_nwp_path=os.environ["NWP_MO_GLOBAL_ZARR_PATH"],
dest_nwp_path=nwp_mo_global_path,
source="mo_global",
config=self.config["input_data"]["nwp"]["mo_global"]
)
)

# Remove local cached zarr if already exists
for nwp_source_file_path, nwp_path in zip(nwp_source_file_paths, nwp_paths, strict=False):
for nwp_config in nwp_configs:
# Process/cache remote zarr locally
process_and_cache_nwp(nwp_source_file_path, nwp_path)
process_and_cache_nwp(nwp_config)
if use_satellite and "satellite" in self.config["input_data"].keys():
shutil.rmtree(satellite_path, ignore_errors=True)
download_satellite_data(satellite_source_file_path)
Expand Down
39 changes: 32 additions & 7 deletions india_forecast_app/models/pvnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
import xarray as xr
import yaml
from ocf_datapipes.batch import BatchKey
from ocf_datapipes.config.model import NWP
from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
from pydantic import BaseModel

from india_forecast_app.data.nwp import regrid_nwp_data

from .consts import (
nwp_ecmwf_path,
Expand All @@ -25,6 +29,15 @@
log = logging.getLogger(__name__)


class NWPProcessAndCacheConfig(BaseModel):
    """Settings for processing and caching a single NWP source"""

    # Path of the source NWP zarr that is read
    source_nwp_path: str

    # Path the processed NWP data is saved to
    dest_nwp_path: str

    # Name of the NWP provider, e.g. "ecmwf", "gfs" or "mo_global"
    source: str

    # NWP input-data config for this source; may be left unset
    config: Optional[NWP] = None


def worker_init_fn(worker_id):
"""
Clear reference to the loop and thread.
Expand Down Expand Up @@ -92,11 +105,15 @@ def populate_data_config_sources(input_path, output_path):
return config


def process_and_cache_nwp(source_nwp_path: str, dest_nwp_path: str):
def process_and_cache_nwp(nwp_config: NWPProcessAndCacheConfig):
"""Reads zarr file, renames t variable to t2m and saves zarr to new destination"""

source_nwp_path = nwp_config.source_nwp_path
dest_nwp_path = nwp_config.dest_nwp_path

log.info(
f"Processing and caching NWP data for {source_nwp_path}, " f"and saving to {dest_nwp_path}"
f"Processing and caching NWP data for {source_nwp_path} "
f"and saving to {dest_nwp_path} for {nwp_config.source}"
)

if os.path.exists(dest_nwp_path):
Expand All @@ -115,10 +132,7 @@ def process_and_cache_nwp(source_nwp_path: str, dest_nwp_path: str):
if ds[v].dtype == object:
ds[v].encoding.clear()

is_gfs = "gfs" in source_nwp_path.lower()
is_ecmwf = "ecmwf" in source_nwp_path.lower()

if is_ecmwf:
if nwp_config.source == "ecmwf":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be a separate PR, but it would be good to do this for ECMWF as well, as there is also a difference in spatial resolution between training and live for it. It is worth noting, though, that I think with the current HF config files we use it will error, since we have hacked them to work and would need to change them back to the original number of pixels.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, I was thinking a separate PR for the ECMWF stuff.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now in #153

# Rename t variable to t2m
variables = list(ds.variable.values)
new_variables = []
Expand All @@ -134,12 +148,23 @@ def process_and_cache_nwp(source_nwp_path: str, dest_nwp_path: str):
ds.__setitem__("variable", new_variables)

# Hack to resolve some NWP data format differences between providers
elif is_gfs:
elif nwp_config.source == "gfs":
data_var = ds[list(ds.data_vars.keys())[0]]
# # Use .to_dataset() to split the data variable based on 'variable' dim
ds = data_var.to_dataset(dim="variable")
ds = ds.rename({"t2m": "t"})

if nwp_config.source == "mo_global":

# only select the variables we need
nwp_channels = list(nwp_config.config.nwp_channels)
ds = ds.sel(variable=nwp_channels)

# regrid data
ds = regrid_nwp_data(ds, "india_forecast_app/data/mo_global/india_coords.nc")

# Save destination path
log.info(f"Saving NWP data to {dest_nwp_path}")
ds.to_zarr(dest_nwp_path, mode="a")


Expand Down
30 changes: 30 additions & 0 deletions tests/data/test_nwp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
""" Tests for the nwp regridding module """
import os
import tempfile

import xarray as xr

from india_forecast_app.data.nwp import regrid_nwp_data


def test_regrid_nwp_data(nwp_mo_global_data):
    """Test the regridding of the nwp data"""

    # NOTE(review): the `nwp_mo_global_data` fixture is presumably what writes
    # the MO global zarr and sets this environment variable - confirm in conftest
    nwp_zarr = os.environ["NWP_MO_GLOBAL_ZARR_PATH"]

    # regrid the data
    nwp_xr = xr.open_zarr(nwp_zarr)
    nwp_xr_regridded = regrid_nwp_data(
        nwp_xr, "india_forecast_app/data/mo_global/india_coords.nc"
    )

    # check the data is different in latitude and longitude
    assert not nwp_xr_regridded.latitude.equals(nwp_xr.latitude)
    assert not nwp_xr_regridded.longitude.equals(nwp_xr.longitude)

    # check the regridded data has the size of the target grid
    assert len(nwp_xr_regridded.latitude) == 225
    assert len(nwp_xr_regridded.longitude) == 150
Loading