Skip to content

Commit

Permalink
make tests work for mo_global
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Nov 20, 2024
1 parent 759523b commit 552946b
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 12 deletions.
4 changes: 2 additions & 2 deletions india_forecast_app/adjuster.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ def get_me_values(
# currently in 0, 60, 120,...
# change to 0, 15, 30, 45, 60, 75, 90, 105, 120, ...
me_df = me_df.set_index("horizon_minutes")
me_df = me_df.reindex(range(0, max(me_df.index), 15)).interpolate(limit=3)
me_df = me_df.reindex(range(0, max(me_df.index)+15, 15)).interpolate(limit=3)

# reset indiex
# reset index
me_df = me_df.reset_index()

# log the maximum and minimum adjuster results
Expand Down
18 changes: 14 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,14 @@ def nwp_mo_global_data(tmp_path_factory, time_before_present):
f"{os.path.dirname(os.path.abspath(__file__))}/test_data/nwp-no-data_gfs.zarr"
)

# Last t0 to at least 6 hours ago and floor to 3-hour interval
# rename dimension init_time_utc to init_time
ds = ds.rename({"init_time_utc": "init_time"})

# Last t0 to at least 4 hours ago and floor to 3-hour interval
t0_datetime_utc = time_before_present(dt.timedelta(hours=0)).floor("3h")
t0_datetime_utc = t0_datetime_utc - dt.timedelta(hours=6)
ds.init_time_utc.values[:] = pd.date_range(
t0_datetime_utc - dt.timedelta(hours=12 * (len(ds.init_time_utc) - 1)),
t0_datetime_utc = t0_datetime_utc - dt.timedelta(hours=4)
ds.init_time.values[:] = pd.date_range(
t0_datetime_utc - dt.timedelta(hours=12 * (len(ds.init_time) - 1)),
t0_datetime_utc,
freq=dt.timedelta(hours=1),
)
Expand All @@ -371,6 +374,13 @@ def nwp_mo_global_data(tmp_path_factory, time_before_present):
if ds[v].dtype == object:
ds[v].encoding.clear()

# change variables values to for MO global
ds.variable.values[0:3] = ["temperature_sl", "wind_u_component_10m", "wind_v_component_10m"]

# interpolate 3 hourly step to 1 hour steps
steps = pd.TimedeltaIndex(np.arange(49) * 3600 * 1e9, freq='infer')
ds = ds.interp(step=steps, method="linear")

ds["mo_global"] = xr.DataArray(
np.zeros([len(ds[c]) for c in ds.xindexes]),
coords=[ds[c] for c in ds.xindexes],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_adjuster.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_get_me_values(db_session, sites, generation_db_values, forecasts):
me_df = get_me_values(db_session, hour, site_uuid=sites[0].site_uuid, ml_model_name="test")

assert len(me_df) != 0
assert len(me_df) == 96
assert len(me_df) == 97
assert me_df["me_kw"].sum() != 0
assert me_df["horizon_minutes"][0] == 0
assert me_df["horizon_minutes"][1] == 15
Expand Down
10 changes: 5 additions & 5 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ def test_app(
assert result.exit_code == 0

if write_to_db:
assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2 * 2
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (2 * 2 * 192)
assert db_session.query(MLModelSQL).count() == 2 * 2
assert db_session.query(ForecastSQL).count() == init_n_forecasts + 3 * 2
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (3 * 2 * 192)
assert db_session.query(MLModelSQL).count() == 3 * 2
else:
assert db_session.query(ForecastSQL).count() == init_n_forecasts
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values
Expand All @@ -183,8 +183,8 @@ def test_app_no_pv_data(
result = run_click_script(app, args)
assert result.exit_code == 0

assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2 * 2
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (2 * 2 * 192)
assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2 * 3
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (2 * 3 * 192)


@pytest.mark.requires_hf_token
Expand Down

0 comments on commit 552946b

Please sign in to comment.