From 8d0550b4a4d2841b013af125ff120c17b3835b5c Mon Sep 17 00:00:00 2001 From: James Fulton Date: Wed, 13 Dec 2023 10:46:18 +0000 Subject: [PATCH] add test for 15-minute fallback --- tests/conftest.py | 27 ++++++++++++++++++++--- tests/test_app.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 23f1fe2..642e1a8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -124,9 +124,8 @@ def sat_data(): f"{os.path.dirname(os.path.abspath(__file__))}/test_data/non_hrv_shell.zarr" ) - # Change times so they lead up to present. Delayed by at most 1 hour - t0_datetime_utc = time_before_present(timedelta(minutes=0)).floor(timedelta(minutes=30)) - t0_datetime_utc = t0_datetime_utc - timedelta(minutes=30) + # Change times so they lead up to present. Delayed by 30-60 mins + t0_datetime_utc = time_before_present(timedelta(minutes=30)).floor(timedelta(minutes=30)) ds.time.values[:] = pd.date_range( t0_datetime_utc - timedelta(minutes=5 * (len(ds.time) - 1)), t0_datetime_utc, @@ -146,6 +145,28 @@ def sat_data(): return ds +@pytest.fixture() +def sat_data_delayed(sat_data): + sat_delayed = sat_data.copy(deep=True) + + # Set the most recent timestamp to 2 - 2.5 hours ago + t_most_recent = time_before_present(timedelta(hours=2)).floor(timedelta(minutes=30)) + offset = sat_delayed.time.max().values - t_most_recent + sat_delayed.time.values[:] = sat_delayed.time.values - offset + return sat_delayed + + +@pytest.fixture() +def sat_15_data(sat_data): + freq = timedelta(minutes=15) + times_15 = pd.date_range( + pd.to_datetime(sat_data.time.min().values).ceil(freq), + pd.to_datetime(sat_data.time.max().values).floor(freq), + freq=freq, + ) + return sat_data.sel(time=times_15) + + @pytest.fixture() def gsp_yields_and_systems(db_session): """Create gsp yields and systems""" diff --git a/tests/test_app.py b/tests/test_app.py index 41b735b..1191b18 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -21,7 +21,7 @@ def test_app(db_session, nwp_data, sat_data, gsp_yields_and_systems, me_latest): with tempfile.TemporaryDirectory() as tmpdirname: # The app loads sat and NWP data from environment variable - # Save out data and set paths + # Save out data, and set paths as environmental variables temp_nwp_path = f"{tmpdirname}/nwp.zarr" os.environ["NWP_ZARR_PATH"] = temp_nwp_path nwp_data.to_zarr(temp_nwp_path) @@ -55,3 +55,57 @@ def test_app(db_session, nwp_data, sat_data, gsp_yields_and_systems, me_latest): assert len(db_session.query(ForecastValueSQL).all()) == 319 * 16 assert len(db_session.query(ForecastValueLatestSQL).all()) == 319 * 16 assert len(db_session.query(ForecastValueSevenDaysSQL).all()) == 319 * 16 + + +def test_app_15( + db_session, nwp_data, sat_data_delayed, sat_15_data, gsp_yields_and_systems, me_latest +): + # Environment variable DB_URL is set in engine_url, which is called by db_session + # set NWP_ZARR_PATH + # save nwp_data to temporary file, and set NWP_ZARR_PATH + # SATELLITE_ZARR_PATH + # save sat_data to temporary file, and set SATELLITE_ZARR_PATH + # GSP data + + with tempfile.TemporaryDirectory() as tmpdirname: + # The app loads sat and NWP data from environment variable + # Save out data, and set paths as environmental variables + temp_nwp_path = f"{tmpdirname}/nwp.zarr" + os.environ["NWP_ZARR_PATH"] = temp_nwp_path + nwp_data.to_zarr(temp_nwp_path) + + # In production sat zarr is zipped + temp_sat_path = f"{tmpdirname}/sat.zarr.zip" + os.environ["SATELLITE_ZARR_PATH"] = temp_sat_path + store = zarr.storage.ZipStore(temp_sat_path, mode="x") + sat_data_delayed.to_zarr(store) + store.close() + + # Save the 15-minute data too + temp_sat_path = f"{tmpdirname}/sat_15.zarr.zip" + store = zarr.storage.ZipStore(temp_sat_path, mode="x") + sat_15_data.to_zarr(store) + store.close() + + # Set model version + os.environ["SAVE_GSP_SUM"] = "True" + + # Run prediction + # This import needs to come after the environ vars have been set + from pvnet_app.app import app + app(gsp_ids=list(range(1, 318)), num_workers=2) + + # Check forecasts have been made + # (317 GSPs + 1 National + GSP-sum) = 319 forecasts + # Doubled for historic and forecast + forecasts = db_session.query(ForecastSQL).all() + assert len(forecasts) == 319 * 2 + + # Check probabilistic added + assert "90" in forecasts[0].forecast_values[0].properties + assert "10" in forecasts[0].forecast_values[0].properties + + # 318 GSPs * 16 time steps in forecast + assert len(db_session.query(ForecastValueSQL).all()) == 319 * 16 + assert len(db_session.query(ForecastValueLatestSQL).all()) == 319 * 16 + assert len(db_session.query(ForecastValueSevenDaysSQL).all()) == 319 * 16