Skip to content

Commit

Permalink
Add in MO Global (#127)
Browse files Browse the repository at this point in the history
* first try at adding in mo global

* lint

* lint
  • Loading branch information
peterdudfield authored Nov 20, 2024
1 parent 4cbcc37 commit 7ea6c6a
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 24 deletions.
6 changes: 6 additions & 0 deletions india_forecast_app/models/all_models.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ models:
version: ae07c15de064e1d03cf4bc02618b65c6d5b17e8e
client: ruvnl
asset_type: wind
- name: windnet_india_mo
type: pvnet
id: openclimatefix/windnet_india
version: 546baded3d4216736d8ee8d6798d47235bd72b08
client: ruvnl
asset_type: wind
# RU client solar
- name: pvnet_india
type: pvnet
Expand Down
1 change: 1 addition & 0 deletions india_forecast_app/models/pvnet/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
root_data_path = "data"
nwp_path = f"{root_data_path}/nwp.zarr"
nwp_ecmwf_path = f"{root_data_path}/nwp_ecmwf.zarr"
nwp_mo_global_path = f"{root_data_path}/nwp_mo_global.zarr"
nwp_gfs_path = f"{root_data_path}/nwp_gfs.zarr"
wind_path = f"{root_data_path}/wind"
pv_path = f"{root_data_path}/pv"
Expand Down
29 changes: 22 additions & 7 deletions india_forecast_app/models/pvnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from .consts import (
nwp_ecmwf_path,
nwp_gfs_path,
nwp_mo_global_path,
pv_metadata_path,
pv_netcdf_path,
pv_path,
Expand Down Expand Up @@ -78,8 +79,8 @@ def __init__(

# Setup the data, dataloader, and model
self.generation_data = generation_data
self._prepare_data_sources()
self.dataloader = self._create_dataloader()
self._prepare_data_sources()
self.model = self._load_model()

def predict(self, site_id: str, timestamp: dt.datetime):
Expand Down Expand Up @@ -205,14 +206,26 @@ def _prepare_data_sources(self):
pass

# Load remote zarr source
nwp_ecmwf_source_file_path = os.environ["NWP_ECMWF_ZARR_PATH"]
nwp_gfs_source_file_path = os.environ["NWP_GFS_ZARR_PATH"]

use_satellite = os.getenv("USE_SATELLITE", "false").lower() == "true"
satellite_source_file_path = os.getenv("SATELLITE_ZARR_PATH", None)

nwp_source_file_paths = [nwp_ecmwf_source_file_path, nwp_gfs_source_file_path]
nwp_paths = [nwp_ecmwf_path, nwp_gfs_path]
# only load nwp that we need
nwp_paths = []
nwp_source_file_paths = []
nwp_keys = self.config["input_data"]["nwp"].keys()
if "ecmwf" in nwp_keys:
nwp_ecmwf_source_file_path = os.environ["NWP_ECMWF_ZARR_PATH"]
nwp_source_file_paths.append(nwp_ecmwf_source_file_path)
nwp_paths.append(nwp_ecmwf_path)
if "gfs" in nwp_keys:
nwp_gfs_source_file_path = os.environ["NWP_GFS_ZARR_PATH"]
nwp_source_file_paths.append(nwp_gfs_source_file_path)
nwp_paths.append(nwp_gfs_path)
if "mo_global" in nwp_keys:
nwp_mo_global_source_file_path = os.environ["NWP_MO_GLOBAL_ZARR_PATH"]
nwp_source_file_paths.append(nwp_mo_global_source_file_path)
nwp_paths.append(nwp_mo_global_path)

# Remove local cached zarr if already exists
for nwp_source_file_path, nwp_path in zip(nwp_source_file_paths, nwp_paths, strict=False):
# Process/cache remote zarr locally
Expand Down Expand Up @@ -293,7 +306,9 @@ def _create_dataloader(self):
temp_dir = tempfile.TemporaryDirectory()
populated_data_config_filename = f"{temp_dir.name}/data_config.yaml"

populate_data_config_sources(data_config_filename, populated_data_config_filename)
self.config = populate_data_config_sources(
data_config_filename, populated_data_config_filename
)

# Location and time datapipes
gen_sites = self.generation_data["metadata"]
Expand Down
7 changes: 5 additions & 2 deletions india_forecast_app/models/pvnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .consts import (
nwp_ecmwf_path,
nwp_gfs_path,
nwp_mo_global_path,
pv_metadata_path,
pv_netcdf_path,
satellite_path,
Expand Down Expand Up @@ -50,8 +51,8 @@ def populate_data_config_sources(input_path, output_path):
production_paths = {
"wind": {"filename": wind_netcdf_path, "metadata_filename": wind_metadata_path},
"pv": {"filename": pv_netcdf_path, "metadata_filename": pv_metadata_path},
"nwp": {"ecmwf": nwp_ecmwf_path, "gfs": nwp_gfs_path},
"satellite": {"filepath": satellite_path},
"nwp": {"ecmwf": nwp_ecmwf_path, "gfs": nwp_gfs_path, "mo_global": nwp_mo_global_path},
"satellite": {"filepath": satellite_path}
}

if "nwp" in config["input_data"]:
Expand Down Expand Up @@ -85,6 +86,8 @@ def populate_data_config_sources(input_path, output_path):
with open(output_path, "w") as outfile:
yaml.dump(config, outfile, default_flow_style=False)

return config


def process_and_cache_nwp(source_nwp_path: str, dest_nwp_path: str):
"""Reads zarr file, renames t variable to t2m and saves zarr to new destination"""
Expand Down
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pvnet = "3.0.52"
pytz = "^2024.1"
numpy = "^1.26.4"
huggingface-hub = "0.20.3"
ocf-datapipes = "3.3.50"
ocf-datapipes = "3.3.52"
pyogrio = "0.8.0" # 0.9.0 seems to cause an error at the moment
torch = [
{url="https://download.pytorch.org/whl/cpu/torch-2.2.1%2Bcpu-cp311-cp311-linux_x86_64.whl", markers="platform_system == \"Linux\" and platform_machine == \"x86_64\""},
Expand Down
43 changes: 43 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,49 @@ def nwp_gfs_data(tmp_path_factory, time_before_present):
ds.to_zarr(temp_nwp_path_gfs)


@pytest.fixture(scope="session")
def nwp_mo_global_data(tmp_path_factory, time_before_present):
    """Write a dummy MO Global NWP zarr store and expose it via an env var.

    Opens the coordinate-only test dataset ``test_data/nwp-no-data_gfs.zarr``,
    rewrites its init times and latitude/longitude coordinates, adds an
    all-zero ``mo_global`` data variable, saves the result under a temporary
    directory and sets ``NWP_MO_GLOBAL_ZARR_PATH`` to that location.

    Args:
        tmp_path_factory: Factory used to create the temporary ``data`` directory.
        time_before_present: Callable fixture taking a ``timedelta``; its result
            must support ``.floor("3h")`` (presumably a pandas Timestamp —
            confirm against the fixture's definition).

    Returns:
        None. The fixture is used for its side effects only.
    """

    # Load dataset which only contains coordinates, but no data
    ds = xr.open_zarr(
        f"{os.path.dirname(os.path.abspath(__file__))}/test_data/nwp-no-data_gfs.zarr"
    )

    # Latest init time: "now" floored to a 3-hour boundary, then moved back 6 hours
    t0_datetime_utc = time_before_present(dt.timedelta(hours=0)).floor("3h")
    t0_datetime_utc = t0_datetime_utc - dt.timedelta(hours=6)

    # The start is placed 12 hours per init time before t0, so the step must
    # also be 12 hours: that yields exactly len(ds.init_time_utc) timestamps.
    # A 1-hour step only produced the right count for a single init time.
    n_init_times = len(ds.init_time_utc)
    init_step = dt.timedelta(hours=12)
    ds.init_time_utc.values[:] = pd.date_range(
        t0_datetime_utc - init_step * (n_init_times - 1),
        t0_datetime_utc,
        freq=init_step,
    )

    # Force lat and lon to be in 0.1 degree steps
    ds.latitude.values[:] = [35.0 - i * 0.1 for i in range(len(ds.latitude))]
    ds.longitude.values[:] = [65.0 + i * 0.1 for i in range(len(ds.longitude))]

    # This is important to avoid saving errors: drop the stored encoding of
    # every object-dtype coordinate and variable before writing
    for v in list(ds.coords.keys()):
        if ds.coords[v].dtype == object:
            ds[v].encoding.clear()

    for v in list(ds.variables.keys()):
        if ds[v].dtype == object:
            ds[v].encoding.clear()

    # All-zero data variable spanning every indexed dimension of the dataset
    ds["mo_global"] = xr.DataArray(
        np.zeros([len(ds[c]) for c in ds.xindexes]),
        coords=[ds[c] for c in ds.xindexes],
    )

    # As NWP data is loaded by the app from an environment variable,
    # save out the data and set its path as that environment variable
    temp_nwp_path_mo_global = f"{tmp_path_factory.mktemp('data')}/nwp_mo_global.zarr"

    os.environ["NWP_MO_GLOBAL_ZARR_PATH"] = temp_nwp_path_mo_global
    ds.to_zarr(temp_nwp_path_mo_global)


@pytest.fixture(scope="session")
def client_ad():
"""Set ad client env var"""
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
def test_get_all_models():
"""Test for getting all models"""
models = get_all_models()
assert len(models.models) == 4
assert len(models.models) == 5


def test_get_all_models_client():
Expand Down
20 changes: 12 additions & 8 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,14 @@ def test_save_forecast(db_session, sites, forecast_values):
)

assert db_session.query(ForecastSQL).count() == 2
assert db_session.query(ForecastValueSQL).count() == 10*2
assert db_session.query(ForecastValueSQL).count() == 10 * 2
assert db_session.query(MLModelSQL).count() == 2


@pytest.mark.parametrize("write_to_db", [True, False])
def test_app(write_to_db, db_session, sites, nwp_data, nwp_gfs_data, generation_db_values):
def test_app(
write_to_db, db_session, sites, nwp_data, nwp_gfs_data, nwp_mo_global_data, generation_db_values
):
"""Test for running app from command line"""

init_n_forecasts = db_session.query(ForecastSQL).count()
Expand All @@ -159,15 +161,17 @@ def test_app(write_to_db, db_session, sites, nwp_data, nwp_gfs_data, generation_
assert result.exit_code == 0

if write_to_db:
assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2*2
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (2*2 * 192)
assert db_session.query(MLModelSQL).count() == 2*2
assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2 * 2
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (2 * 2 * 192)
assert db_session.query(MLModelSQL).count() == 2 * 2
else:
assert db_session.query(ForecastSQL).count() == init_n_forecasts
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values


def test_app_no_pv_data(db_session, sites, nwp_data, nwp_gfs_data, generation_db_values_only_wind):
def test_app_no_pv_data(
db_session, sites, nwp_data, nwp_gfs_data, nwp_mo_global_data, generation_db_values_only_wind
):
"""Test for running app from command line"""

init_n_forecasts = db_session.query(ForecastSQL).count()
Expand All @@ -179,8 +183,8 @@ def test_app_no_pv_data(db_session, sites, nwp_data, nwp_gfs_data, generation_db
result = run_click_script(app, args)
assert result.exit_code == 0

assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2*2
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (2*2 * 192)
assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2 * 2
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (2 * 2 * 192)


@pytest.mark.requires_hf_token
Expand Down

0 comments on commit 7ea6c6a

Please sign in to comment.