Skip to content

Commit

Permalink
add regridding
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Nov 29, 2024
1 parent b5b7224 commit d021915
Show file tree
Hide file tree
Showing 5 changed files with 142 additions and 22 deletions.
Binary file added india_forecast_app/data/mo_global/india_coords.nc
Binary file not shown.
57 changes: 57 additions & 0 deletions india_forecast_app/data/nwp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import xarray as xr
import os
import logging

logger = logging.getLogger(__name__)


def regrid_nwp_data(nwp_zarr: str, target_coords_path: str, nwp_zarr_save: str):
"""This function loads the NWP data, then regrids and saves it back out if the data is not
on the same grid as expected. The data is resaved in-place.
method can be 'conservative' or 'bilinear'
"""

logger.info(f"Regridding NWP data {nwp_zarr} to expected grid to {target_coords_path}")

ds_raw = xr.open_zarr(nwp_zarr)

# These are the coords we are aiming for
ds_target_coords = xr.load_dataset(target_coords_path)

# Check if regridding step needs to be done
needs_regridding = not (
ds_raw.latitude.equals(ds_target_coords.latitude)
and ds_raw.longitude.equals(ds_target_coords.longitude)
)

if not needs_regridding:
logger.info(f"No NWP regridding required for {nwp_zarr} - skipping this step")
return

logger.info(f"Regridding NWP {nwp_zarr} to expected grid")

# Pull the raw data into RAM
ds_raw = ds_raw.compute()

# regrid
ds_regridded = ds_raw.interp(
latitude=ds_target_coords.latitude, longitude=ds_target_coords.longitude
)

# Re-save - including rechunking
os.system(f"rm -rf {nwp_zarr_save}")
ds_regridded["variable"] = ds_regridded["variable"].astype(str)

# Rechunk to these dimensions when saving
save_chunk_dict = {
"step": 5,
"latitude": 100,
"longitude": 100,
"x": 100,
"y": 100,
}

ds_regridded.chunk(
{k: save_chunk_dict[k] for k in list(ds_raw.xindexes) if k in save_chunk_dict}
).to_zarr(nwp_zarr_save)
42 changes: 29 additions & 13 deletions india_forecast_app/models/pvnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from .utils import (
download_satellite_data,
NWPProcessAndCacheConfig,
populate_data_config_sources,
process_and_cache_nwp,
save_batch,
Expand Down Expand Up @@ -214,26 +215,41 @@ def _prepare_data_sources(self):
satellite_source_file_path = os.getenv("SATELLITE_ZARR_PATH", None)

# only load nwp that we need
nwp_paths = []
nwp_source_file_paths = []
nwp_configs = []
nwp_keys = self.config["input_data"]["nwp"].keys()
if "ecmwf" in nwp_keys:
nwp_ecmwf_source_file_path = os.environ["NWP_ECMWF_ZARR_PATH"]
nwp_source_file_paths.append(nwp_ecmwf_source_file_path)
nwp_paths.append(nwp_ecmwf_path)

nwp_configs.append(
NWPProcessAndCacheConfig(
source_nwp_path=os.environ["NWP_ECMWF_ZARR_PATH"],
dest_nwp_path=nwp_ecmwf_path,
source="ecmwf",
)
)

if "gfs" in nwp_keys:
nwp_gfs_source_file_path = os.environ["NWP_GFS_ZARR_PATH"]
nwp_source_file_paths.append(nwp_gfs_source_file_path)
nwp_paths.append(nwp_gfs_path)

nwp_configs.append(
NWPProcessAndCacheConfig(
source_nwp_path=os.environ["NWP_GFS_ZARR_PATH"],
dest_nwp_path=nwp_gfs_path,
source="gfs",
)
)

if "mo_global" in nwp_keys:
nwp_mo_global_source_file_path = os.environ["NWP_MO_GLOBAL_ZARR_PATH"]
nwp_source_file_paths.append(nwp_mo_global_source_file_path)
nwp_paths.append(nwp_mo_global_path)
nwp_configs.append(
NWPProcessAndCacheConfig(
source_nwp_path=os.environ["NWP_MO_GLOBAL_ZARR_PATH"],
dest_nwp_path=nwp_mo_global_path,
source="mo_global",
)
)

# Remove local cached zarr if already exists
for nwp_source_file_path, nwp_path in zip(nwp_source_file_paths, nwp_paths, strict=False):
for nwp_config in nwp_configs:
# Process/cache remote zarr locally
process_and_cache_nwp(nwp_source_file_path, nwp_path)
process_and_cache_nwp(nwp_config)
if use_satellite and "satellite" in self.config["input_data"].keys():
shutil.rmtree(satellite_path, ignore_errors=True)
download_satellite_data(satellite_source_file_path)
Expand Down
31 changes: 22 additions & 9 deletions india_forecast_app/models/pvnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import yaml
from ocf_datapipes.batch import BatchKey
from ocf_datapipes.utils.consts import ELEVATION_MEAN, ELEVATION_STD
from pydantic import BaseModel

from india_forecast_app.data.nwp import regrid_nwp_data
from .consts import (
nwp_ecmwf_path,
nwp_gfs_path,
Expand All @@ -25,6 +27,12 @@
log = logging.getLogger(__name__)

Check failure on line 27 in india_forecast_app/models/pvnet/utils.py

View workflow job for this annotation

GitHub Actions / lint_and_test / Lint the code and run the tests

Ruff (I001)

india_forecast_app/models/pvnet/utils.py:2:1: I001 Import block is un-sorted or un-formatted


class NWPProcessAndCacheConfig(BaseModel):

Check failure on line 30 in india_forecast_app/models/pvnet/utils.py

View workflow job for this annotation

GitHub Actions / lint_and_test / Lint the code and run the tests

Ruff (D101)

india_forecast_app/models/pvnet/utils.py:30:7: D101 Missing docstring in public class
source_nwp_path: str
dest_nwp_path: str
source: str


def worker_init_fn(worker_id):
"""
Clear reference to the loop and thread.
Expand Down Expand Up @@ -92,12 +100,13 @@ def populate_data_config_sources(input_path, output_path):
return config


def process_and_cache_nwp(source_nwp_path: str, dest_nwp_path: str):
def process_and_cache_nwp(nwp_config: NWPProcessAndCacheConfig):
"""Reads zarr file, renames t variable to t2m and saves zarr to new destination"""

log.info(
f"Processing and caching NWP data for {source_nwp_path}, " f"and saving to {dest_nwp_path}"
)
source_nwp_path = nwp_config.source_nwp_path
dest_nwp_path = nwp_config.dest_nwp_path

log.info(f"Processing and caching NWP data for {source_nwp_path} and saving to {dest_nwp_path}")

if os.path.exists(dest_nwp_path):
log.info(f"File already exists at {dest_nwp_path}")
Expand All @@ -115,10 +124,7 @@ def process_and_cache_nwp(source_nwp_path: str, dest_nwp_path: str):
if ds[v].dtype == object:
ds[v].encoding.clear()

is_gfs = "gfs" in source_nwp_path.lower()
is_ecmwf = "ecmwf" in source_nwp_path.lower()

if is_ecmwf:
if nwp_config.source == "ecmwf":
# Rename t variable to t2m
variables = list(ds.variable.values)
new_variables = []
Expand All @@ -134,14 +140,21 @@ def process_and_cache_nwp(source_nwp_path: str, dest_nwp_path: str):
ds.__setitem__("variable", new_variables)

# Hack to resolve some NWP data format differences between providers
elif is_gfs:
elif nwp_config.source == "gfs":
data_var = ds[list(ds.data_vars.keys())[0]]
# # Use .to_dataset() to split the data variable based on 'variable' dim
ds = data_var.to_dataset(dim="variable")
ds = ds.rename({"t2m": "t"})

# Save destination path
ds.to_zarr(dest_nwp_path, mode="a")

if nwp_config.source == "mo_global":
# regrid data
regrid_nwp_data(
dest_nwp_path, "india_forecast_app/data/mo_global/india_coords.nc", dest_nwp_path
)


def download_satellite_data(satellite_source_file_path: str) -> None:
"""Download the sat data"""
Expand Down
34 changes: 34 additions & 0 deletions tests/data/test_nwp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
""" Tests for the nwp regridding module """
import os
import tempfile

import xarray as xr

from india_forecast_app.data.nwp import regrid_nwp_data


def test_regrid_nwp_data(nwp_mo_global_data):
"""Test the regridding of the nwp data"""

# create a temporary dir
with tempfile.TemporaryDirectory() as temp_dir:

# save mo data to zarr
nwp_zarr = os.environ["NWP_MO_GLOBAL_ZARR_PATH"]

# regrid the data
nwp_zarr_save = f"{temp_dir}/nwp_regrid.zarr"
regrid_nwp_data(
nwp_zarr, "india_forecast_app/data/mo_global/india_coords.nc", nwp_zarr_save
)

# open the regridded data
nwp_xr = xr.open_zarr(nwp_zarr)
nwp_xr_regridded = xr.open_zarr(nwp_zarr_save)

# check the data is different in latitude and longitude
assert not nwp_xr_regridded.latitude.equals(nwp_xr.latitude)
assert not nwp_xr_regridded.longitude.equals(nwp_xr.longitude)

assert len(nwp_xr_regridded.latitude) == 225
assert len(nwp_xr_regridded.longitude) == 150

0 comments on commit d021915

Please sign in to comment.