From 552946b260a60a5404ef3f043e044130b3700880 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Wed, 20 Nov 2024 21:43:45 +0000 Subject: [PATCH] make tests work for mo_global --- india_forecast_app/adjuster.py | 4 ++-- tests/conftest.py | 18 ++++++++++++++---- tests/test_adjuster.py | 2 +- tests/test_app.py | 10 +++++----- 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/india_forecast_app/adjuster.py b/india_forecast_app/adjuster.py index efda22c..192d242 100644 --- a/india_forecast_app/adjuster.py +++ b/india_forecast_app/adjuster.py @@ -115,9 +115,9 @@ def get_me_values( # currently in 0, 60, 120,... # change to 0, 15, 30, 45, 60, 75, 90, 105, 120, ... me_df = me_df.set_index("horizon_minutes") - me_df = me_df.reindex(range(0, max(me_df.index), 15)).interpolate(limit=3) + me_df = me_df.reindex(range(0, max(me_df.index)+15, 15)).interpolate(limit=3) - # reset indiex + # reset index me_df = me_df.reset_index() # log the maximum and minimum adjuster results diff --git a/tests/conftest.py b/tests/conftest.py index 8f2b6b5..e3f6470 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -350,11 +350,14 @@ def nwp_mo_global_data(tmp_path_factory, time_before_present): f"{os.path.dirname(os.path.abspath(__file__))}/test_data/nwp-no-data_gfs.zarr" ) - # Last t0 to at least 6 hours ago and floor to 3-hour interval + # rename dimension init_time_utc to init_time + ds = ds.rename({"init_time_utc": "init_time"}) + + # Last t0 to at least 4 hours ago and floor to 3-hour interval t0_datetime_utc = time_before_present(dt.timedelta(hours=0)).floor("3h") - t0_datetime_utc = t0_datetime_utc - dt.timedelta(hours=6) - ds.init_time_utc.values[:] = pd.date_range( - t0_datetime_utc - dt.timedelta(hours=12 * (len(ds.init_time_utc) - 1)), + t0_datetime_utc = t0_datetime_utc - dt.timedelta(hours=4) + ds.init_time.values[:] = pd.date_range( + t0_datetime_utc - dt.timedelta(hours=12 * (len(ds.init_time) - 1)), t0_datetime_utc, freq=dt.timedelta(hours=1), ) @@ -371,6 +374,13 @@ def nwp_mo_global_data(tmp_path_factory, time_before_present): if ds[v].dtype == object: ds[v].encoding.clear() + # change variables values to for MO global + ds.variable.values[0:3] = ["temperature_sl", "wind_u_component_10m", "wind_v_component_10m"] + + # interpolate 3 hourly step to 1 hour steps + steps = pd.TimedeltaIndex(np.arange(49) * 3600 * 1e9, freq='infer') + ds = ds.interp(step=steps, method="linear") + ds["mo_global"] = xr.DataArray( np.zeros([len(ds[c]) for c in ds.xindexes]), coords=[ds[c] for c in ds.xindexes], diff --git a/tests/test_adjuster.py b/tests/test_adjuster.py index 6027b99..5c614c7 100644 --- a/tests/test_adjuster.py +++ b/tests/test_adjuster.py @@ -21,7 +21,7 @@ def test_get_me_values(db_session, sites, generation_db_values, forecasts): me_df = get_me_values(db_session, hour, site_uuid=sites[0].site_uuid, ml_model_name="test") assert len(me_df) != 0 - assert len(me_df) == 96 + assert len(me_df) == 97 assert me_df["me_kw"].sum() != 0 assert me_df["horizon_minutes"][0] == 0 assert me_df["horizon_minutes"][1] == 15 diff --git a/tests/test_app.py b/tests/test_app.py index 2f9093d..737eeb5 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -161,9 +161,9 @@ def test_app( assert result.exit_code == 0 if write_to_db: - assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2 * 2 - assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (2 * 2 * 192) - assert db_session.query(MLModelSQL).count() == 2 * 2 + assert db_session.query(ForecastSQL).count() == init_n_forecasts + 3 * 2 + assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (3 * 2 * 192) + assert db_session.query(MLModelSQL).count() == 3 * 2 else: assert db_session.query(ForecastSQL).count() == init_n_forecasts assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values @@ -183,8 +183,8 @@ def test_app_no_pv_data( result = run_click_script(app, args) assert result.exit_code == 0 - assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2 * 2 - assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (2 * 2 * 192) + assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2 * 3 + assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (2 * 3 * 192) @pytest.mark.requires_hf_token