Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ecmwf #48

Merged
merged 5 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added data/nwp_ecmwf_target_coords.nc
Binary file not shown.
File renamed without changes.
28 changes: 16 additions & 12 deletions pvnet_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

This app expects these evironmental variables to be available:
- DB_URL
- NWP_ZARR_PATH
- NWP_UKV_ZARR_PATH
- NWP_ECMWF_ZARR_PATH
- SATELLITE_ZARR_PATH
"""

Expand Down Expand Up @@ -41,7 +42,7 @@
worker_init_fn, populate_data_config_sources, convert_dataarray_to_forecasts, preds_to_dataarray
)
from pvnet_app.data import (
download_sat_data, download_nwp_data, preprocess_sat_data, regrid_nwp_data,
download_all_sat_data, download_all_nwp_data, preprocess_sat_data, preprocess_nwp_data,
)

# ---------------------------------------------------------------------------
Expand All @@ -61,7 +62,7 @@

# Huggingfacehub model repo and commit for PVNet (GSP-level model)
default_model_name = "openclimatefix/pvnet_v2"
default_model_version = "4203e12e719efd93da641c43d2e38527648f4915"
default_model_version = "4491e1ea440ee5f32e5a430391b3d338ff612900"

# Huggingfacehub model repo and commit for PVNet summation (GSP sum to national model)
# If summation_model_name is set to None, a simple sum is computed instead
Expand Down Expand Up @@ -111,7 +112,8 @@ def app(

This app expects these evironmental variables to be available:
- DB_URL
- NWP_ZARR_PATH
- NWP_UKV_ZARR_PATH
- NWP_ECMWF_ZARR_PATH
- SATELLITE_ZARR_PATH
Args:
t0 (datetime): Datetime at which forecast is made
Expand Down Expand Up @@ -180,18 +182,18 @@ def app(

# Download satellite data
logger.info("Downloading satellite data")
download_sat_data()
download_all_sat_data()

# Process the 5/15 minutely satellite data
preprocess_sat_data(t0)

# Download NWP data
logger.info("Downloading NWP data")
download_nwp_data()

# Regrid the NWP data if needed
regrid_nwp_data()
download_all_nwp_data()

# Preprocess the NWP data
preprocess_nwp_data()

# ---------------------------------------------------------------------------
# 2. Set up data loader
logger.info("Creating DataLoader")
Expand All @@ -201,6 +203,7 @@ def app(
model_name,
revision=model_version,
)

# Populate the data config with production data paths
temp_dir = tempfile.TemporaryDirectory()
populated_data_config_filename = f"{temp_dir.name}/data_config.yaml"
Expand Down Expand Up @@ -390,9 +393,10 @@ def app(
da_abs.sum(dim="gsp_id").expand_dims(dim="gsp_id", axis=0).assign_coords(gsp_id=[0])
)
da_abs_all = xr.concat([da_abs_national, da_abs], dim="gsp_id")
logger.info(
f"National forecast is {da_abs.sel(gsp_id=0, output_label='forecast_mw').values}"
)

logger.info(
f"National forecast is {da_abs_all.sel(gsp_id=0, output_label='forecast_mw').values}"
)

if save_gsp_sum:
# Compute the sum if we are logging the sume of GSPs independently
Expand Down
3 changes: 2 additions & 1 deletion pvnet_app/consts.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
sat_path = "sat.zarr"
nwp_path = "nwp.zarr"
nwp_ukv_path = "nwp_ukv.zarr"
nwp_ecmwf_path = "nwp_ecmwf.zarr"
112 changes: 92 additions & 20 deletions pvnet_app/data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
Expand All @@ -7,7 +8,7 @@
from datetime import timedelta
import ocf_blosc2

from pvnet_app.consts import sat_path, nwp_path
from pvnet_app.consts import sat_path, nwp_ukv_path, nwp_ecmwf_path

logger = logging.getLogger(__name__)

Expand All @@ -17,7 +18,7 @@
sat_15_path = "sat_15_min.zarr"


def download_sat_data():
def download_all_sat_data():
"""Download the sat data"""

# Clean out old files
Expand Down Expand Up @@ -76,23 +77,28 @@ def preprocess_sat_data(t0):
ds_sat_15.to_zarr(sat_path)

return use_15_minute


def download_nwp_data():
"""Download the NWP data"""
fs = fsspec.open(os.environ["NWP_ZARR_PATH"]).fs
fs.get(os.environ["NWP_ZARR_PATH"], nwp_path, recursive=True)


def _download_nwp_data(source, destination):
fs = fsspec.open(source).fs
fs.get(source, destination, recursive=True)


def regrid_nwp_data():
"""This function loads the NWP data, then regrids and saves it back out if the data is not on
the same grid as expected. The data is resaved in-place.
def download_all_nwp_data():
"""Download the NWP data"""
_download_nwp_data(os.environ["NWP_UKV_ZARR_PATH"], nwp_ukv_path)
_download_nwp_data(os.environ["NWP_ECMWF_ZARR_PATH"], nwp_ecmwf_path)


def regrid_nwp_data(nwp_zarr, target_coords_path, method):
"""This function loads the NWP data, then regrids and saves it back out if the data is not
on the same grid as expected. The data is resaved in-place.
"""

ds_raw = xr.open_zarr(nwp_path)
ds_raw = xr.open_zarr(nwp_zarr)

# These are the coords we are aiming for
ds_target_coords = xr.load_dataset(f"{this_dir}/../data/nwp_target_coords.nc")
ds_target_coords = xr.load_dataset(target_coords_path)

# Check if regridding step needs to be done
needs_regridding = not (
Expand All @@ -102,23 +108,89 @@ def regrid_nwp_data():
)

if not needs_regridding:
logger.info("No NWP regridding required - skipping this step")
logger.info(f"No NWP regridding required for {nwp_zarr} - skipping this step")
return

logger.info("Regridding NWP to expected grid")
logger.info(f"Regridding NWP {nwp_zarr} to expected grid")

# Pull the raw data into RAM
ds_raw = ds_raw.compute()

# Regrid in RAM efficient way by chunking first. Each step is regridded separately
regridder = xe.Regridder(ds_raw, ds_target_coords, method="bilinear")
regrid_chunk_dict = {
"step": 1,
"latitude": -1,
"longitude": -1,
"x": -1,
"y": -1,
}

regridder = xe.Regridder(ds_raw, ds_target_coords, method=method)
ds_regridded = regridder(
ds_raw.chunk(dict(x=-1, y=-1, step=1))
ds_raw.chunk(
{k: regrid_chunk_dict[k] for k in list(ds_raw.xindexes) if k in regrid_chunk_dict}
)
).compute(scheduler="single-threaded")

# Re-save - including rechunking
os.system(f"rm -fr {nwp_path}")
os.system(f"rm -rf {nwp_zarr}")
ds_regridded["variable"] = ds_regridded["variable"].astype(str)
ds_regridded.chunk(dict(step=12, x=100, y=100)).to_zarr(nwp_path)

return
# Rechunk to these dimensions when saving
save_chunk_dict = {
"step": 5,
"latitude": 100,
"longitude": 100,
"x": 100,
"y": 100,
}

ds_regridded.chunk(
{k: save_chunk_dict[k] for k in list(ds_raw.xindexes) if k in save_chunk_dict}
).to_zarr(nwp_zarr)


def fix_ecmwf_data():

ds = xr.open_zarr(nwp_ecmwf_path).compute()
ds["variable"] = ds["variable"].astype(str)

name_sub = {
"t": "t2m",
"clt": "tcc"
}

if any(v in name_sub for v in ds["variable"].values):
logger.info(f"Renaming the ECMWF variables")
ds["variable"] = np.array([name_sub[v] if v in name_sub else v for v in ds["variable"].values])
else:
logger.info(f"No ECMWF renaming required - skipping this step")

logger.info(f"Extending the ECMWF data to reach the shetlands")
# Thw data must be extended to reach the shetlands. This will fill missing lats with NaNs
# and reflects what the model saw in training
ds = ds.reindex(latitude=np.concatenate([np.arange(62, 60, -0.05), ds.latitude.values]))

# Re-save inplace
os.system(f"rm -rf {nwp_ecmwf_path}")
ds.to_zarr(nwp_ecmwf_path)


def preprocess_nwp_data():

# Regrid the UKV data
regrid_nwp_data(
nwp_zarr=nwp_ukv_path,
target_coords_path=f"{this_dir}/../data/nwp_ukv_target_coords.nc",
method="bilinear"
)

# Regrid the ECMWF data
regrid_nwp_data(
nwp_zarr=nwp_ecmwf_path,
target_coords_path=f"{this_dir}/../data/nwp_ecmwf_target_coords.nc",
method="conservative" # this is needed to avoid zeros around edges of ECMWF data
)

# Names need to be aligned between training and prod, and we need to infill the shetlands
fix_ecmwf_data()
4 changes: 2 additions & 2 deletions pvnet_app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from datetime import timezone, datetime

from pvnet_app.consts import sat_path, nwp_path
from pvnet_app.consts import sat_path, nwp_ukv_path, nwp_ecmwf_path


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -52,7 +52,7 @@ def populate_data_config_sources(input_path, output_path):

production_paths = {
"gsp": os.environ["DB_URL"],
"nwp": {"ukv": nwp_path},
"nwp": {"ukv": nwp_ukv_path, "ecmwf": nwp_ecmwf_path},
"satellite": sat_path,
# TODO: include hrvsatellite
}
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pydantic
pytorch-lightning==2.1.3
torch[cpu]==2.2.0
PVNet-summation==0.1.2
pvnet==2.6.10
ocf_datapipes==3.2.7
PVNet-summation==0.1.3
pvnet==3.0.11
ocf_datapipes==3.2.11
nowcasting_datamodel>=1.5.30
fsspec[s3]
xarray
Expand Down
28 changes: 21 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,9 @@ def db_session(db_connection, engine_url):
s.rollback()


@pytest.fixture
def nwp_data():
def make_nwp_data(shell_path, varname):
# Load dataset which only contains coordinates, but no data
ds = xr.open_zarr(
f"{os.path.dirname(os.path.abspath(__file__))}/test_data/nwp_shell.zarr"
)
ds = xr.open_zarr(shell_path)

# Last init time was at least 2 hours ago and floor to 3-hour interval
t0_datetime_utc = time_before_present(timedelta(hours=2)).floor(timedelta(hours=3))
Expand All @@ -106,15 +103,32 @@ def nwp_data():
ds[v].encoding.clear()

# Add data to dataset
ds["UKV"] = xr.DataArray(
ds[varname] = xr.DataArray(
np.zeros([len(ds[c]) for c in ds.xindexes]),
coords=[ds[c] for c in ds.xindexes],
)

# Add stored attributes to DataArray
ds.UKV.attrs = ds.attrs["_data_attrs"]
ds[varname].attrs = ds.attrs["_data_attrs"]
del ds.attrs["_data_attrs"]

return ds


@pytest.fixture
def nwp_ukv_data():
return make_nwp_data(
shell_path=f"{os.path.dirname(os.path.abspath(__file__))}/test_data/nwp_ukv_shell.zarr",
varname="UKV",
)


@pytest.fixture
def nwp_ecmwf_data():
return make_nwp_data(
shell_path=f"{os.path.dirname(os.path.abspath(__file__))}/test_data/nwp_ecmwf_shell.zarr",
varname="UKV",
)


@pytest.fixture()
Expand Down
21 changes: 17 additions & 4 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
ForecastValueSQL,
)

from pvnet_app.consts import sat_path, nwp_ukv_path, nwp_ecmwf_path
from pvnet_app.data import sat_5_path, sat_15_path

def test_app(db_session, nwp_data, sat_5_data, gsp_yields_and_systems, me_latest):
def test_app(
db_session, nwp_ukv_data, nwp_ecmwf_data, sat_5_data, gsp_yields_and_systems, me_latest
):
# Environment variable DB_URL is set in engine_url, which is called by db_session
# set NWP_ZARR_PATH
# save nwp_data to temporary file, and set NWP_ZARR_PATH
Expand All @@ -22,9 +26,13 @@ def test_app(db_session, nwp_data, sat_5_data, gsp_yields_and_systems, me_latest
with tempfile.TemporaryDirectory() as tmpdirname:
# The app loads sat and NWP data from environment variable
# Save out data, and set paths as environmental variables
temp_nwp_path = f"{tmpdirname}/nwp.zarr"
os.environ["NWP_ZARR_PATH"] = temp_nwp_path
nwp_data.to_zarr(temp_nwp_path)
temp_nwp_path = f"{tmpdirname}/nwp_ukv.zarr"
os.environ["NWP_UKV_ZARR_PATH"] = temp_nwp_path
nwp_ukv_data.to_zarr(temp_nwp_path)

temp_nwp_path = f"{tmpdirname}/nwp_ecmwf.zarr"
os.environ["NWP_ECMWF_ZARR_PATH"] = temp_nwp_path
nwp_ecmwf_data.to_zarr(temp_nwp_path)

# In production sat zarr is zipped
temp_sat_path = f"{tmpdirname}/sat.zarr.zip"
Expand All @@ -41,6 +49,11 @@ def test_app(db_session, nwp_data, sat_5_data, gsp_yields_and_systems, me_latest
from pvnet_app.app import app
app(gsp_ids=list(range(1, 318)), num_workers=2)

os.system(f"rm {sat_5_path}")
os.system(f"rm {sat_15_path}")
os.system(f"rm -r {sat_path}")
os.system(f"rm -r {nwp_ukv_path}")
os.system(f"rm -r {nwp_ecmwf_path}")
# Check forecasts have been made
# (317 GSPs + 1 National + GSP-sum) = 319 forecasts
# Doubled for historic and forecast
Expand Down
Loading
Loading