From 08c0c951dd1dd881a892d161f750e3ed380f1ede Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Tue, 26 Sep 2023 17:13:06 +0100 Subject: [PATCH] add tests for national --- src/database.py | 3 ++ src/national.py | 6 +++- src/tests/test_national.py | 61 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+), 1 deletion(-) diff --git a/src/database.py b/src/database.py index 7fd8391..7dca3b1 100644 --- a/src/database.py +++ b/src/database.py @@ -229,6 +229,9 @@ def get_latest_forecast_values_for_a_specific_gsp_from_database( created_utc_limit=creation_utc_limit, ) + if len(forecast_values) == 0: + return [] + # convert to pydantic objects if ( isinstance(forecast_values[0], ForecastValueSevenDaysSQL) diff --git a/src/national.py b/src/national.py index 51b9f1d..72a4473 100644 --- a/src/national.py +++ b/src/national.py @@ -16,7 +16,7 @@ get_truth_values_for_a_specific_gsp_from_database, ) from pydantic_models import NationalForecast, NationalForecastValue, NationalYield -from utils import format_plevels +from utils import format_datetime, format_plevels logger = structlog.stdlib.get_logger() @@ -66,6 +66,10 @@ def get_national_forecast( """ logger.debug("Get national forecasts") + start_datetime_utc = format_datetime(start_datetime_utc) + end_datetime_utc = format_datetime(end_datetime_utc) + creation_limit_utc = format_datetime(creation_limit_utc) + logger.debug("Getting forecast.") if include_metadata: if forecast_horizon_minutes is not None: diff --git a/src/tests/test_national.py b/src/tests/test_national.py index 164e321..fb0ebda 100644 --- a/src/tests/test_national.py +++ b/src/tests/test_national.py @@ -6,6 +6,7 @@ from nowcasting_datamodel.fake import make_fake_national_forecast from nowcasting_datamodel.models import GSPYield, Location, LocationSQL from nowcasting_datamodel.read.read import get_model +from nowcasting_datamodel.save.save import save_all_forecast_values_seven_days from nowcasting_datamodel.save.update import update_all_forecast_latest from database import get_session @@ -42,6 +43,66 @@ def test_read_latest_national_values(db_session, api_client): ) +def test_read_latest_national_values_creation_limit(db_session, api_client): + """Check main solar/GB/national/forecast route works""" + + with freeze_time("2023-01-01"): + model = get_model(db_session, name="blend", version="0.0.1") + + forecast = make_fake_national_forecast( + session=db_session, t0_datetime_utc=datetime.now(tz=timezone.utc) + ) + forecast.model = model + db_session.add(forecast) + update_all_forecast_latest(forecasts=[forecast], session=db_session) + save_all_forecast_values_seven_days(forecasts=[forecast], session=db_session) + + with freeze_time("2023-01-02"): + app.dependency_overrides[get_session] = lambda: db_session + + response = api_client.get("/v0/solar/GB/national/forecast?creation_limit_utc=2023-01-02") + assert response.status_code == 200 + + national_forecast_values = [NationalForecastValue(**f) for f in response.json()] + assert len(national_forecast_values) == 16 + + response = api_client.get("/v0/solar/GB/national/forecast?creation_limit_utc=2022-12-31") + assert response.status_code == 200 + + national_forecast_values = [NationalForecastValue(**f) for f in response.json()] + assert len(national_forecast_values) == 0 + + +def test_read_latest_national_values_start_and_end_filters(db_session, api_client): + """Check main solar/GB/national/forecast route works""" + + with freeze_time("2023-01-01"): + model = get_model(db_session, name="blend", version="0.0.1") + + forecast = make_fake_national_forecast( + session=db_session, t0_datetime_utc=datetime.now(tz=timezone.utc) + ) + forecast.model = model + db_session.add(forecast) + update_all_forecast_latest(forecasts=[forecast], session=db_session) + + app.dependency_overrides[get_session] = lambda: db_session + + response = api_client.get("/v0/solar/GB/national/forecast?start_datetime_utc=2023-01-01") + assert response.status_code == 200 + + national_forecast_values = [NationalForecastValue(**f) for f in response.json()] + assert len(national_forecast_values) == 16 + + response = api_client.get( + "/v0/solar/GB/national/forecast?start_datetime_utc=2023-01-01&end_datetime_utc=2023-01-01 04:00" + ) + assert response.status_code == 200 + + national_forecast_values = [NationalForecastValue(**f) for f in response.json()] + assert len(national_forecast_values) == 9 + + def test_get_national_forecast(db_session, api_client): """Check main solar/GB/national/forecast route works"""