Skip to content

Commit

Permalink
Merge pull request #48 from openclimatefix/add_ecmwf
Browse files Browse the repository at this point in the history
Add ecmwf
  • Loading branch information
dfulu authored Mar 4, 2024
2 parents 09f9bec + 83dc26c commit 7cdc79d
Show file tree
Hide file tree
Showing 70 changed files with 471 additions and 62 deletions.
Binary file added data/nwp_ecmwf_target_coords.nc
Binary file not shown.
File renamed without changes.
28 changes: 16 additions & 12 deletions pvnet_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
This app expects these environmental variables to be available:
- DB_URL
- NWP_ZARR_PATH
- NWP_UKV_ZARR_PATH
- NWP_ECMWF_ZARR_PATH
- SATELLITE_ZARR_PATH
"""

Expand Down Expand Up @@ -41,7 +42,7 @@
worker_init_fn, populate_data_config_sources, convert_dataarray_to_forecasts, preds_to_dataarray
)
from pvnet_app.data import (
download_sat_data, download_nwp_data, preprocess_sat_data, regrid_nwp_data,
download_all_sat_data, download_all_nwp_data, preprocess_sat_data, preprocess_nwp_data,
)

# ---------------------------------------------------------------------------
Expand All @@ -61,7 +62,7 @@

# Huggingfacehub model repo and commit for PVNet (GSP-level model)
default_model_name = "openclimatefix/pvnet_v2"
default_model_version = "4203e12e719efd93da641c43d2e38527648f4915"
default_model_version = "4491e1ea440ee5f32e5a430391b3d338ff612900"

# Huggingfacehub model repo and commit for PVNet summation (GSP sum to national model)
# If summation_model_name is set to None, a simple sum is computed instead
Expand Down Expand Up @@ -111,7 +112,8 @@ def app(
This app expects these environmental variables to be available:
- DB_URL
- NWP_ZARR_PATH
- NWP_UKV_ZARR_PATH
- NWP_ECMWF_ZARR_PATH
- SATELLITE_ZARR_PATH
Args:
t0 (datetime): Datetime at which forecast is made
Expand Down Expand Up @@ -180,18 +182,18 @@ def app(

# Download satellite data
logger.info("Downloading satellite data")
download_sat_data()
download_all_sat_data()

# Process the 5/15 minutely satellite data
preprocess_sat_data(t0)

# Download NWP data
logger.info("Downloading NWP data")
download_nwp_data()

# Regrid the NWP data if needed
regrid_nwp_data()
download_all_nwp_data()

# Preprocess the NWP data
preprocess_nwp_data()

# ---------------------------------------------------------------------------
# 2. Set up data loader
logger.info("Creating DataLoader")
Expand All @@ -201,6 +203,7 @@ def app(
model_name,
revision=model_version,
)

# Populate the data config with production data paths
temp_dir = tempfile.TemporaryDirectory()
populated_data_config_filename = f"{temp_dir.name}/data_config.yaml"
Expand Down Expand Up @@ -390,9 +393,10 @@ def app(
da_abs.sum(dim="gsp_id").expand_dims(dim="gsp_id", axis=0).assign_coords(gsp_id=[0])
)
da_abs_all = xr.concat([da_abs_national, da_abs], dim="gsp_id")
logger.info(
f"National forecast is {da_abs.sel(gsp_id=0, output_label='forecast_mw').values}"
)

logger.info(
f"National forecast is {da_abs_all.sel(gsp_id=0, output_label='forecast_mw').values}"
)

if save_gsp_sum:
        # Compute the sum if we are logging the sum of GSPs independently
Expand Down
3 changes: 2 additions & 1 deletion pvnet_app/consts.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# Local paths the app downloads input data to, and reads it back from.
sat_path = "sat.zarr"  # preprocessed satellite data
nwp_path = "nwp.zarr"
nwp_ukv_path = "nwp_ukv.zarr"  # Met Office UKV NWP data
nwp_ecmwf_path = "nwp_ecmwf.zarr"  # ECMWF NWP data
112 changes: 92 additions & 20 deletions pvnet_app/data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pandas as pd
import xarray as xr
import xesmf as xe
Expand All @@ -7,7 +8,7 @@
from datetime import timedelta
import ocf_blosc2

from pvnet_app.consts import sat_path, nwp_path
from pvnet_app.consts import sat_path, nwp_ukv_path, nwp_ecmwf_path

logger = logging.getLogger(__name__)

Expand All @@ -17,7 +18,7 @@
sat_15_path = "sat_15_min.zarr"


def download_sat_data():
def download_all_sat_data():
"""Download the sat data"""

# Clean out old files
Expand Down Expand Up @@ -76,23 +77,28 @@ def preprocess_sat_data(t0):
ds_sat_15.to_zarr(sat_path)

return use_15_minute


def download_nwp_data():
"""Download the NWP data"""
fs = fsspec.open(os.environ["NWP_ZARR_PATH"]).fs
fs.get(os.environ["NWP_ZARR_PATH"], nwp_path, recursive=True)


def _download_nwp_data(source, destination):
    """Copy a remote zarr store to a local destination path.

    Args:
        source: Remote path/URL of the zarr store (any fsspec-supported protocol).
        destination: Local path to copy the store to.
    """
    filesystem = fsspec.open(source).fs
    filesystem.get(source, destination, recursive=True)


def regrid_nwp_data():
"""This function loads the NWP data, then regrids and saves it back out if the data is not on
the same grid as expected. The data is resaved in-place.
def download_all_nwp_data():
    """Download every NWP source the app needs to its local zarr path."""
    # (environment variable holding the remote path, local destination)
    nwp_sources = (
        ("NWP_UKV_ZARR_PATH", nwp_ukv_path),
        ("NWP_ECMWF_ZARR_PATH", nwp_ecmwf_path),
    )
    for env_var, local_path in nwp_sources:
        _download_nwp_data(os.environ[env_var], local_path)


def regrid_nwp_data(nwp_zarr, target_coords_path, method):
"""This function loads the NWP data, then regrids and saves it back out if the data is not
on the same grid as expected. The data is resaved in-place.
"""

ds_raw = xr.open_zarr(nwp_path)
ds_raw = xr.open_zarr(nwp_zarr)

# These are the coords we are aiming for
ds_target_coords = xr.load_dataset(f"{this_dir}/../data/nwp_target_coords.nc")
ds_target_coords = xr.load_dataset(target_coords_path)

# Check if regridding step needs to be done
needs_regridding = not (
Expand All @@ -102,23 +108,89 @@ def regrid_nwp_data():
)

if not needs_regridding:
logger.info("No NWP regridding required - skipping this step")
logger.info(f"No NWP regridding required for {nwp_zarr} - skipping this step")
return

logger.info("Regridding NWP to expected grid")
logger.info(f"Regridding NWP {nwp_zarr} to expected grid")

# Pull the raw data into RAM
ds_raw = ds_raw.compute()

# Regrid in RAM efficient way by chunking first. Each step is regridded separately
regridder = xe.Regridder(ds_raw, ds_target_coords, method="bilinear")
regrid_chunk_dict = {
"step": 1,
"latitude": -1,
"longitude": -1,
"x": -1,
"y": -1,
}

regridder = xe.Regridder(ds_raw, ds_target_coords, method=method)
ds_regridded = regridder(
ds_raw.chunk(dict(x=-1, y=-1, step=1))
ds_raw.chunk(
{k: regrid_chunk_dict[k] for k in list(ds_raw.xindexes) if k in regrid_chunk_dict}
)
).compute(scheduler="single-threaded")

# Re-save - including rechunking
os.system(f"rm -fr {nwp_path}")
os.system(f"rm -rf {nwp_zarr}")
ds_regridded["variable"] = ds_regridded["variable"].astype(str)
ds_regridded.chunk(dict(step=12, x=100, y=100)).to_zarr(nwp_path)

return
# Rechunk to these dimensions when saving
save_chunk_dict = {
"step": 5,
"latitude": 100,
"longitude": 100,
"x": 100,
"y": 100,
}

ds_regridded.chunk(
{k: save_chunk_dict[k] for k in list(ds_raw.xindexes) if k in save_chunk_dict}
).to_zarr(nwp_zarr)


def fix_ecmwf_data():
    """Align the downloaded ECMWF data with what the model saw in training.

    Applies two in-place fixes to the zarr store at `nwp_ecmwf_path`:
    - renames variables whose names differ between training and production
    - extends the latitude coordinate northwards so the grid reaches the
      Shetlands; the new rows are filled with NaNs, matching training data
    """
    ds = xr.open_zarr(nwp_ecmwf_path).compute()
    ds["variable"] = ds["variable"].astype(str)

    # Production name -> training name
    name_sub = {
        "t": "t2m",
        "clt": "tcc",
    }

    if any(v in name_sub for v in ds["variable"].values):
        # Use lazy %-style args rather than f-strings in logging calls
        logger.info("Renaming the ECMWF variables")
        ds["variable"] = np.array(
            [name_sub.get(v, v) for v in ds["variable"].values]
        )
    else:
        logger.info("No ECMWF renaming required - skipping this step")

    logger.info("Extending the ECMWF data to reach the Shetlands")
    # The data must be extended to reach the Shetlands. This will fill missing
    # lats with NaNs and reflects what the model saw in training
    ds = ds.reindex(latitude=np.concatenate([np.arange(62, 60, -0.05), ds.latitude.values]))

    # Re-save in place
    os.system(f"rm -rf {nwp_ecmwf_path}")
    ds.to_zarr(nwp_ecmwf_path)


def preprocess_nwp_data():
    """Regrid each NWP source onto its expected grid, then clean up ECMWF."""

    # (local zarr store, target-coordinates file, regridding method)
    regrid_jobs = [
        (nwp_ukv_path, f"{this_dir}/../data/nwp_ukv_target_coords.nc", "bilinear"),
        # conservative regridding is needed to avoid zeros around edges of
        # the ECMWF data
        (nwp_ecmwf_path, f"{this_dir}/../data/nwp_ecmwf_target_coords.nc", "conservative"),
    ]

    for zarr_path, coords_path, regrid_method in regrid_jobs:
        regrid_nwp_data(
            nwp_zarr=zarr_path,
            target_coords_path=coords_path,
            method=regrid_method,
        )

    # Names need to be aligned between training and prod, and we need to
    # infill the Shetlands
    fix_ecmwf_data()
4 changes: 2 additions & 2 deletions pvnet_app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from datetime import timezone, datetime

from pvnet_app.consts import sat_path, nwp_path
from pvnet_app.consts import sat_path, nwp_ukv_path, nwp_ecmwf_path


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -52,7 +52,7 @@ def populate_data_config_sources(input_path, output_path):

production_paths = {
"gsp": os.environ["DB_URL"],
"nwp": {"ukv": nwp_path},
"nwp": {"ukv": nwp_ukv_path, "ecmwf": nwp_ecmwf_path},
"satellite": sat_path,
# TODO: include hrvsatellite
}
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pydantic
pytorch-lightning==2.1.3
torch[cpu]==2.2.0
PVNet-summation==0.1.2
pvnet==2.6.10
ocf_datapipes==3.2.7
PVNet-summation==0.1.3
pvnet==3.0.11
ocf_datapipes==3.2.11
nowcasting_datamodel>=1.5.30
fsspec[s3]
xarray
Expand Down
28 changes: 21 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,9 @@ def db_session(db_connection, engine_url):
s.rollback()


@pytest.fixture
def nwp_data():
def make_nwp_data(shell_path, varname):
# Load dataset which only contains coordinates, but no data
ds = xr.open_zarr(
f"{os.path.dirname(os.path.abspath(__file__))}/test_data/nwp_shell.zarr"
)
ds = xr.open_zarr(shell_path)

# Last init time was at least 2 hours ago and floor to 3-hour interval
t0_datetime_utc = time_before_present(timedelta(hours=2)).floor(timedelta(hours=3))
Expand All @@ -106,15 +103,32 @@ def nwp_data():
ds[v].encoding.clear()

# Add data to dataset
ds["UKV"] = xr.DataArray(
ds[varname] = xr.DataArray(
np.zeros([len(ds[c]) for c in ds.xindexes]),
coords=[ds[c] for c in ds.xindexes],
)

# Add stored attributes to DataArray
ds.UKV.attrs = ds.attrs["_data_attrs"]
ds[varname].attrs = ds.attrs["_data_attrs"]
del ds.attrs["_data_attrs"]

return ds


@pytest.fixture
def nwp_ukv_data():
    """Fake UKV NWP dataset built on the UKV test shell's coordinates."""
    test_dir = os.path.dirname(os.path.abspath(__file__))
    return make_nwp_data(
        shell_path=f"{test_dir}/test_data/nwp_ukv_shell.zarr",
        varname="UKV",
    )


@pytest.fixture
def nwp_ecmwf_data():
    """Fake ECMWF NWP dataset built on the ECMWF test shell's coordinates."""
    test_dir = os.path.dirname(os.path.abspath(__file__))
    # NOTE(review): varname "UKV" for the ECMWF shell looks copy-pasted from
    # the UKV fixture - confirm the shell dataset's data variable name
    return make_nwp_data(
        shell_path=f"{test_dir}/test_data/nwp_ecmwf_shell.zarr",
        varname="UKV",
    )


@pytest.fixture()
Expand Down
21 changes: 17 additions & 4 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,12 @@
ForecastValueSQL,
)

from pvnet_app.consts import sat_path, nwp_ukv_path, nwp_ecmwf_path
from pvnet_app.data import sat_5_path, sat_15_path

def test_app(db_session, nwp_data, sat_5_data, gsp_yields_and_systems, me_latest):
def test_app(
db_session, nwp_ukv_data, nwp_ecmwf_data, sat_5_data, gsp_yields_and_systems, me_latest
):
# Environment variable DB_URL is set in engine_url, which is called by db_session
# set NWP_ZARR_PATH
# save nwp_data to temporary file, and set NWP_ZARR_PATH
Expand All @@ -22,9 +26,13 @@ def test_app(db_session, nwp_data, sat_5_data, gsp_yields_and_systems, me_latest
with tempfile.TemporaryDirectory() as tmpdirname:
# The app loads sat and NWP data from environment variable
# Save out data, and set paths as environmental variables
temp_nwp_path = f"{tmpdirname}/nwp.zarr"
os.environ["NWP_ZARR_PATH"] = temp_nwp_path
nwp_data.to_zarr(temp_nwp_path)
temp_nwp_path = f"{tmpdirname}/nwp_ukv.zarr"
os.environ["NWP_UKV_ZARR_PATH"] = temp_nwp_path
nwp_ukv_data.to_zarr(temp_nwp_path)

temp_nwp_path = f"{tmpdirname}/nwp_ecmwf.zarr"
os.environ["NWP_ECMWF_ZARR_PATH"] = temp_nwp_path
nwp_ecmwf_data.to_zarr(temp_nwp_path)

# In production sat zarr is zipped
temp_sat_path = f"{tmpdirname}/sat.zarr.zip"
Expand All @@ -41,6 +49,11 @@ def test_app(db_session, nwp_data, sat_5_data, gsp_yields_and_systems, me_latest
from pvnet_app.app import app
app(gsp_ids=list(range(1, 318)), num_workers=2)

os.system(f"rm {sat_5_path}")
os.system(f"rm {sat_15_path}")
os.system(f"rm -r {sat_path}")
os.system(f"rm -r {nwp_ukv_path}")
os.system(f"rm -r {nwp_ecmwf_path}")
# Check forecasts have been made
# (317 GSPs + 1 National + GSP-sum) = 319 forecasts
# Doubled for historic and forecast
Expand Down
Loading

0 comments on commit 7cdc79d

Please sign in to comment.