Skip to content

Commit

Permalink
add test for 15-minute fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Dec 13, 2023
1 parent 698bd2a commit 8d0550b
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 4 deletions.
27 changes: 24 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,8 @@ def sat_data():
f"{os.path.dirname(os.path.abspath(__file__))}/test_data/non_hrv_shell.zarr"
)

# Change times so they lead up to present. Delayed by at most 1 hour
t0_datetime_utc = time_before_present(timedelta(minutes=0)).floor(timedelta(minutes=30))
t0_datetime_utc = t0_datetime_utc - timedelta(minutes=30)
# Change times so they lead up to present. Delayed by 30-60 mins
t0_datetime_utc = time_before_present(timedelta(minutes=30)).floor(timedelta(minutes=30))
ds.time.values[:] = pd.date_range(
t0_datetime_utc - timedelta(minutes=5 * (len(ds.time) - 1)),
t0_datetime_utc,
Expand All @@ -146,6 +145,28 @@ def sat_data():
return ds


@pytest.fixture()
def sat_data_delayed(sat_data):
sat_delayed = sat_data.copy(deep=True)

# Set the most recent timestamp to 2 - 2.5 hours ago
t_most_recent = time_before_present(timedelta(hours=2)).floor(timedelta(minutes=30))
offset = sat_delayed.time.max().values - t_most_recent
sat_delayed.time.values[:] = sat_delayed.time.values - offset
return sat_delayed


@pytest.fixture()
def sat_15_data(sat_data):
freq = timedelta(minutes=15)
times_15 = pd.date_range(
pd.to_datetime(sat_data.time.min().values).ceil(freq),
pd.to_datetime(sat_data.time.max().values).floor(freq),
freq=freq,
)
return sat_data.sel(time=times_15)


@pytest.fixture()
def gsp_yields_and_systems(db_session):
"""Create gsp yields and systems"""
Expand Down
56 changes: 55 additions & 1 deletion tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_app(db_session, nwp_data, sat_data, gsp_yields_and_systems, me_latest):

with tempfile.TemporaryDirectory() as tmpdirname:
# The app loads sat and NWP data from environment variable
# Save out data and set paths
# Save out data, and set paths as environmental variables
temp_nwp_path = f"{tmpdirname}/nwp.zarr"
os.environ["NWP_ZARR_PATH"] = temp_nwp_path
nwp_data.to_zarr(temp_nwp_path)
Expand Down Expand Up @@ -55,3 +55,57 @@ def test_app(db_session, nwp_data, sat_data, gsp_yields_and_systems, me_latest):
assert len(db_session.query(ForecastValueSQL).all()) == 319 * 16
assert len(db_session.query(ForecastValueLatestSQL).all()) == 319 * 16
assert len(db_session.query(ForecastValueSevenDaysSQL).all()) == 319 * 16


def test_app_15(
db_session, nwp_data, sat_data_delayed, sat_15_data, gsp_yields_and_systems, me_latest
):
# Environment variable DB_URL is set in engine_url, which is called by db_session
# set NWP_ZARR_PATH
# save nwp_data to temporary file, and set NWP_ZARR_PATH
# SATELLITE_ZARR_PATH
# save sat_data to temporary file, and set SATELLITE_ZARR_PATH
# GSP data

with tempfile.TemporaryDirectory() as tmpdirname:
# The app loads sat and NWP data from environment variable
# Save out data, and set paths as environmental variables
temp_nwp_path = f"{tmpdirname}/nwp.zarr"
os.environ["NWP_ZARR_PATH"] = temp_nwp_path
nwp_data.to_zarr(temp_nwp_path)

# In production sat zarr is zipped
temp_sat_path = f"{tmpdirname}/sat.zarr.zip"
os.environ["SATELLITE_ZARR_PATH"] = temp_sat_path
store = zarr.storage.ZipStore(temp_sat_path, mode="x")
sat_data_delayed.to_zarr(store)
store.close()

# Save the 15-minute data too
temp_sat_path = f"{tmpdirname}/sat_15.zarr.zip"
store = zarr.storage.ZipStore(temp_sat_path, mode="x")
sat_15_data.to_zarr(store)
store.close()

# Set model version
os.environ["SAVE_GSP_SUM"] = "True"

# Run prediction
# This import needs to come after the environ vars have been set
from pvnet_app.app import app
app(gsp_ids=list(range(1, 318)), num_workers=2)

# Check forecasts have been made
# (317 GSPs + 1 National + GSP-sum) = 319 forecasts
# Doubled for historic and forecast
forecasts = db_session.query(ForecastSQL).all()
assert len(forecasts) == 319 * 2

# Check probabilistic added
assert "90" in forecasts[0].forecast_values[0].properties
assert "10" in forecasts[0].forecast_values[0].properties

# 318 GSPs * 16 time steps in forecast
assert len(db_session.query(ForecastValueSQL).all()) == 319 * 16
assert len(db_session.query(ForecastValueLatestSQL).all()) == 319 * 16
assert len(db_session.query(ForecastValueSevenDaysSQL).all()) == 319 * 16

0 comments on commit 8d0550b

Please sign in to comment.