diff --git a/src/database.py b/src/database.py index 2d7e917..7fd8391 100644 --- a/src/database.py +++ b/src/database.py @@ -195,6 +195,7 @@ def get_latest_forecast_values_for_a_specific_gsp_from_database( forecast_horizon_minutes: Optional[int] = None, start_datetime_utc: Optional[datetime] = None, end_datetime_utc: Optional[datetime] = None, + creation_utc_limit: Optional[datetime] = None, ) -> List[ForecastValue]: """Get the forecast values for yesterday and today for one gsp @@ -207,7 +208,7 @@ def get_latest_forecast_values_for_a_specific_gsp_from_database( start_datetime = get_start_datetime(start_datetime=start_datetime_utc) - if forecast_horizon_minutes is None: + if (forecast_horizon_minutes is None) and (creation_utc_limit is None): forecast_values = get_forecast_values_latest( session=session, gsp_id=gsp_id, @@ -225,6 +226,7 @@ def get_latest_forecast_values_for_a_specific_gsp_from_database( model_name="blend", model=ForecastValueSevenDaysSQL, only_return_latest=True, + created_utc_limit=creation_utc_limit, ) # convert to pydantic objects diff --git a/src/gsp.py b/src/gsp.py index 1907bbb..9d9497f 100644 --- a/src/gsp.py +++ b/src/gsp.py @@ -141,6 +141,7 @@ def get_forecasts_for_a_specific_gsp( user: Auth0User = Security(get_user()), start_datetime_utc: Optional[str] = None, end_datetime_utc: Optional[str] = None, + creation_limit_utc: Optional[str] = None, ) -> Union[Forecast, List[ForecastValue]]: """### Get recent forecast values for a specific GSP @@ -160,6 +161,7 @@ def get_forecasts_for_a_specific_gsp( - **forecast_horizon_minutes**: optional forecast horizon in minutes (ex. 60 - **start_datetime_utc**: optional start datetime for the query. - **end_datetime_utc**: optional end datetime for the query. + - **creation_utc_limit**: optional, only return forecasts made before this datetime. returns the latest forecast made 60 minutes before the target time) """ @@ -178,6 +180,7 @@ def get_forecasts_for_a_specific_gsp( forecast_horizon_minutes=forecast_horizon_minutes, start_datetime_utc=start_datetime_utc, end_datetime_utc=end_datetime_utc, + creation_utc_limit=creation_limit_utc, ) logger.debug("Got forecast values for a specific gsp.") diff --git a/src/tests/test_gsp.py b/src/tests/test_gsp.py index c0dde7f..57f38f7 100644 --- a/src/tests/test_gsp.py +++ b/src/tests/test_gsp.py @@ -5,6 +5,7 @@ from nowcasting_datamodel.fake import make_fake_forecasts from nowcasting_datamodel.models import ( ForecastValue, + ForecastValueSevenDaysSQL, GSPYield, Location, LocationSQL, @@ -12,6 +13,7 @@ ManyForecasts, ) from nowcasting_datamodel.read.read import get_model +from nowcasting_datamodel.save.save import save_all_forecast_values_seven_days from nowcasting_datamodel.save.update import update_all_forecast_latest from database import get_session @@ -38,6 +40,38 @@ def test_read_latest_one_gsp(db_session, api_client): _ = [ForecastValue(**f) for f in response.json()] +def test_read_latest_one_gsp_filter_creation_utc(db_session, api_client): + """Check main solar/GB/gsp/{gsp_id}/forecast route works""" + + with freeze_time("2022-01-01"): + forecasts = make_fake_forecasts( + gsp_ids=list(range(0, 2)), session=db_session, model_name="blend", n_fake_forecasts=10 + ) + db_session.add_all(forecasts) + db_session.commit() + save_all_forecast_values_seven_days(forecasts=forecasts, session=db_session) + + with freeze_time("2022-01-02"): + forecasts_2 = make_fake_forecasts( + gsp_ids=list(range(0, 2)), session=db_session, model_name="blend", n_fake_forecasts=10 + ) + db_session.add_all(forecasts_2) + db_session.commit() + save_all_forecast_values_seven_days(forecasts=forecasts_2, session=db_session) + assert len(db_session.query(ForecastValueSevenDaysSQL).all()) == 2 * 2 * 10 + + with freeze_time("2022-01-03"): + app.dependency_overrides[get_session] = lambda: db_session + + response = api_client.get("/v0/solar/GB/gsp/1/forecast?creation_limit_utc=2022-01-02") + + assert response.status_code == 200 + + f = [ForecastValue(**f) for f in response.json()] + assert len(f) == 10 + assert f[0].target_time == forecasts[1].forecast_values[0].target_time + + def test_read_latest_all_gsp(db_session, api_client): """Check main solar/GB/gsp/forecast/all route works"""