Skip to content

Commit

Permalink
add tests for national
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Sep 26, 2023
1 parent 9b5aac8 commit 08c0c95
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ def get_latest_forecast_values_for_a_specific_gsp_from_database(
created_utc_limit=creation_utc_limit,
)

if len(forecast_values) == 0:
return []

# convert to pydantic objects
if (
isinstance(forecast_values[0], ForecastValueSevenDaysSQL)
Expand Down
6 changes: 5 additions & 1 deletion src/national.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
get_truth_values_for_a_specific_gsp_from_database,
)
from pydantic_models import NationalForecast, NationalForecastValue, NationalYield
from utils import format_plevels
from utils import format_datetime, format_plevels

logger = structlog.stdlib.get_logger()

Expand Down Expand Up @@ -66,6 +66,10 @@ def get_national_forecast(
"""
logger.debug("Get national forecasts")

start_datetime_utc = format_datetime(start_datetime_utc)
end_datetime_utc = format_datetime(end_datetime_utc)
creation_limit_utc = format_datetime(creation_limit_utc)

logger.debug("Getting forecast.")
if include_metadata:
if forecast_horizon_minutes is not None:
Expand Down
61 changes: 61 additions & 0 deletions src/tests/test_national.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from nowcasting_datamodel.fake import make_fake_national_forecast
from nowcasting_datamodel.models import GSPYield, Location, LocationSQL
from nowcasting_datamodel.read.read import get_model
from nowcasting_datamodel.save.save import save_all_forecast_values_seven_days
from nowcasting_datamodel.save.update import update_all_forecast_latest

from database import get_session
Expand Down Expand Up @@ -42,6 +43,66 @@ def test_read_latest_national_values(db_session, api_client):
)


def test_read_latest_national_values_creation_limit(db_session, api_client):
"""Check main solar/GB/national/forecast route works"""

with freeze_time("2023-01-01"):
model = get_model(db_session, name="blend", version="0.0.1")

forecast = make_fake_national_forecast(
session=db_session, t0_datetime_utc=datetime.now(tz=timezone.utc)
)
forecast.model = model
db_session.add(forecast)
update_all_forecast_latest(forecasts=[forecast], session=db_session)
save_all_forecast_values_seven_days(forecasts=[forecast], session=db_session)

with freeze_time("2023-01-02"):
app.dependency_overrides[get_session] = lambda: db_session

response = api_client.get("/v0/solar/GB/national/forecast?creation_limit_utc=2023-01-02")
assert response.status_code == 200

national_forecast_values = [NationalForecastValue(**f) for f in response.json()]
assert len(national_forecast_values) == 16

response = api_client.get("/v0/solar/GB/national/forecast?creation_limit_utc=2022-12-31")
assert response.status_code == 200

national_forecast_values = [NationalForecastValue(**f) for f in response.json()]
assert len(national_forecast_values) == 0


def test_read_latest_national_values_start_and_end_filters(db_session, api_client):
"""Check main solar/GB/national/forecast route works"""

with freeze_time("2023-01-01"):
model = get_model(db_session, name="blend", version="0.0.1")

forecast = make_fake_national_forecast(
session=db_session, t0_datetime_utc=datetime.now(tz=timezone.utc)
)
forecast.model = model
db_session.add(forecast)
update_all_forecast_latest(forecasts=[forecast], session=db_session)

app.dependency_overrides[get_session] = lambda: db_session

response = api_client.get("/v0/solar/GB/national/forecast?start_datetime_utc=2023-01-01")
assert response.status_code == 200

national_forecast_values = [NationalForecastValue(**f) for f in response.json()]
assert len(national_forecast_values) == 16

response = api_client.get(
"/v0/solar/GB/national/forecast?start_datetime_utc=2023-01-01&end_datetime_utc=2023-01-01 04:00"
)
assert response.status_code == 200

national_forecast_values = [NationalForecastValue(**f) for f in response.json()]
assert len(national_forecast_values) == 9


def test_get_national_forecast(db_session, api_client):
"""Check main solar/GB/national/forecast route works"""

Expand Down

0 comments on commit 08c0c95

Please sign in to comment.