Skip to content

Commit

Permalink
add creation filter on gsp forecast route (#293)
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield authored Sep 12, 2023
1 parent bd8a041 commit 33c74ad
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ def get_latest_forecast_values_for_a_specific_gsp_from_database(
forecast_horizon_minutes: Optional[int] = None,
start_datetime_utc: Optional[datetime] = None,
end_datetime_utc: Optional[datetime] = None,
creation_utc_limit: Optional[datetime] = None,
) -> List[ForecastValue]:
"""Get the forecast values for yesterday and today for one gsp
Expand All @@ -207,7 +208,7 @@ def get_latest_forecast_values_for_a_specific_gsp_from_database(

start_datetime = get_start_datetime(start_datetime=start_datetime_utc)

if forecast_horizon_minutes is None:
if (forecast_horizon_minutes is None) and (creation_utc_limit is None):
forecast_values = get_forecast_values_latest(
session=session,
gsp_id=gsp_id,
Expand All @@ -225,6 +226,7 @@ def get_latest_forecast_values_for_a_specific_gsp_from_database(
model_name="blend",
model=ForecastValueSevenDaysSQL,
only_return_latest=True,
created_utc_limit=creation_utc_limit,
)

# convert to pydantic objects
Expand Down
3 changes: 3 additions & 0 deletions src/gsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def get_forecasts_for_a_specific_gsp(
user: Auth0User = Security(get_user()),
start_datetime_utc: Optional[str] = None,
end_datetime_utc: Optional[str] = None,
creation_limit_utc: Optional[str] = None,
) -> Union[Forecast, List[ForecastValue]]:
"""### Get recent forecast values for a specific GSP
Expand All @@ -160,6 +161,7 @@ def get_forecasts_for_a_specific_gsp(
- **forecast_horizon_minutes**: optional forecast horizon in minutes (ex. 60
- **start_datetime_utc**: optional start datetime for the query.
- **end_datetime_utc**: optional end datetime for the query.
- **creation_utc_limit**: optional, only return forecasts made before this datetime.
returns the latest forecast made 60 minutes before the target time)
"""

Expand All @@ -178,6 +180,7 @@ def get_forecasts_for_a_specific_gsp(
forecast_horizon_minutes=forecast_horizon_minutes,
start_datetime_utc=start_datetime_utc,
end_datetime_utc=end_datetime_utc,
creation_utc_limit=creation_limit_utc,
)

logger.debug("Got forecast values for a specific gsp.")
Expand Down
34 changes: 34 additions & 0 deletions src/tests/test_gsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from nowcasting_datamodel.fake import make_fake_forecasts
from nowcasting_datamodel.models import (
ForecastValue,
ForecastValueSevenDaysSQL,
GSPYield,
Location,
LocationSQL,
LocationWithGSPYields,
ManyForecasts,
)
from nowcasting_datamodel.read.read import get_model
from nowcasting_datamodel.save.save import save_all_forecast_values_seven_days
from nowcasting_datamodel.save.update import update_all_forecast_latest

from database import get_session
Expand All @@ -38,6 +40,38 @@ def test_read_latest_one_gsp(db_session, api_client):
_ = [ForecastValue(**f) for f in response.json()]


def test_read_latest_one_gsp_filter_creation_utc(db_session, api_client):
"""Check main solar/GB/gsp/{gsp_id}/forecast route works"""

with freeze_time("2022-01-01"):
forecasts = make_fake_forecasts(
gsp_ids=list(range(0, 2)), session=db_session, model_name="blend", n_fake_forecasts=10
)
db_session.add_all(forecasts)
db_session.commit()
save_all_forecast_values_seven_days(forecasts=forecasts, session=db_session)

with freeze_time("2022-01-02"):
forecasts_2 = make_fake_forecasts(
gsp_ids=list(range(0, 2)), session=db_session, model_name="blend", n_fake_forecasts=10
)
db_session.add_all(forecasts_2)
db_session.commit()
save_all_forecast_values_seven_days(forecasts=forecasts_2, session=db_session)
assert len(db_session.query(ForecastValueSevenDaysSQL).all()) == 2 * 2 * 10

with freeze_time("2022-01-03"):
app.dependency_overrides[get_session] = lambda: db_session

response = api_client.get("/v0/solar/GB/gsp/1/forecast?creation_limit_utc=2022-01-02")

assert response.status_code == 200

f = [ForecastValue(**f) for f in response.json()]
assert len(f) == 10
assert f[0].target_time == forecasts[1].forecast_values[0].target_time


def test_read_latest_all_gsp(db_session, api_client):
"""Check main solar/GB/gsp/forecast/all route works"""

Expand Down

0 comments on commit 33c74ad

Please sign in to comment.