From 78bc0a022dabf5a980b262b9d05a55993468971f Mon Sep 17 00:00:00 2001 From: Peter Dudfield <34686298+peterdudfield@users.noreply.github.com> Date: Thu, 2 Feb 2023 09:11:24 +0000 Subject: [PATCH] Update read (#39) * make read function more flexible * lint * seperate out different read functions * use contain eager --- sdk/python/pvsite_datamodel/read/__init__.py | 2 +- .../read/latest_forecast_values.py | 90 ++++++++++++++++--- sdk/python/pvsite_datamodel/sqlmodels.py | 6 +- sdk/python/tests/conftest.py | 37 ++++++++ sdk/python/tests/test_read.py | 65 +++++++------- 5 files changed, 153 insertions(+), 47 deletions(-) diff --git a/sdk/python/pvsite_datamodel/read/__init__.py b/sdk/python/pvsite_datamodel/read/__init__.py index 6d91e29..63ce877 100644 --- a/sdk/python/pvsite_datamodel/read/__init__.py +++ b/sdk/python/pvsite_datamodel/read/__init__.py @@ -3,7 +3,7 @@ """ from .generation import get_pv_generation_by_client, get_pv_generation_by_sites -from .latest_forecast_values import get_latest_forecast_values_by_site +from .latest_forecast_values import get_latest_forecast_values_by_site, get_forecast_values_by_site_latest from .site import ( get_all_sites, get_site_by_client_site_id, diff --git a/sdk/python/pvsite_datamodel/read/latest_forecast_values.py b/sdk/python/pvsite_datamodel/read/latest_forecast_values.py index a1ed04c..dff74f2 100644 --- a/sdk/python/pvsite_datamodel/read/latest_forecast_values.py +++ b/sdk/python/pvsite_datamodel/read/latest_forecast_values.py @@ -2,29 +2,34 @@ import datetime as dt import uuid -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union -from sqlalchemy.orm import Query, Session +from sqlalchemy.orm import Query, Session, contains_eager -from pvsite_datamodel.sqlmodels import DatetimeIntervalSQL, LatestForecastValueSQL +from pvsite_datamodel.sqlmodels import ( + DatetimeIntervalSQL, + ForecastValueSQL, + LatestForecastValueSQL, + ForecastSQL, +) def get_latest_forecast_values_by_site( session: Session, site_uuids: List[uuid.UUID], start_utc: Optional[dt.datetime] = None, - model=LatestForecastValueSQL, ) -> Dict[uuid.UUID, List[LatestForecastValueSQL]]: """Get the latest forecast values by input sites. + This reads the LatestForecastValueSQL table + :param session: The sqlalchemy database session :param site_uuids: list of site_uuids for which to fetch latest forecast values :param start_utc: filters on forecast values target_time >= start_utc - :param model, the database table to use, could be 'LatestForecastValueSQL' or ForecastValueSQL :return: dict containing {site_uuid1: List[LatestForecastValueSQL], site_uuid2: ...} """ # start main query - query: Query = session.query(model) + query: Query = session.query(LatestForecastValueSQL) query = query.join(DatetimeIntervalSQL) if start_utc is not None: @@ -32,27 +37,86 @@ def get_latest_forecast_values_by_site( # also filter on creation time, to speed up things created_utc_filter = start_utc - dt.timedelta(days=1) - query = query.filter(model.created_utc >= created_utc_filter) + query = query.filter(LatestForecastValueSQL.created_utc >= created_utc_filter) - output_dict: Dict[uuid.UUID, List[model]] = {} + output_dict: Dict[uuid.UUID, List[LatestForecastValueSQL]] = {} # Filter the query on the desired sites - query = query.where(model.site_uuid.in_(site_uuids)) + query = query.where(LatestForecastValueSQL.site_uuid.in_(site_uuids)) # order by site, target time and created time desc query.order_by( - model.site_uuid, + LatestForecastValueSQL.site_uuid, DatetimeIntervalSQL.start_utc, - model.created_utc, + LatestForecastValueSQL.created_utc, ) - latest_forecast_values: List[model] = query.all() + latest_forecast_values: List[LatestForecastValueSQL] = query.all() for site_uuid in site_uuids: - site_latest_forecast_values: List[model] = [ + site_latest_forecast_values: List[LatestForecastValueSQL] = [ lfv for lfv in latest_forecast_values if lfv.site_uuid == site_uuid ] output_dict[site_uuid] = site_latest_forecast_values return output_dict + + +def get_forecast_values_by_site_latest( + session: Session, + site_uuids: List[uuid.UUID], + start_utc: Optional[dt.datetime] = None, +) -> Dict[uuid.UUID, List[ForecastValueSQL]]: + """Get the forecast values by input sites, get the lastest value + + This reads the ForecastValueSQL table + + :param session: The sqlalchemy database session + :param site_uuids: list of site_uuids for which to fetch latest forecast values + :param start_utc: filters on forecast values target_time >= start_utc + :param model, the database table to use, could be 'LatestForecastValueSQL' or ForecastValueSQL + :return: dict containing {site_uuid1: List[LatestForecastValueSQL], site_uuid2: ...} + """ + # start main query + query: Query = session.query(ForecastValueSQL) + query = query.join(DatetimeIntervalSQL) + + # use distinct if using `ForecastValueSQL` + query = query.join(ForecastSQL) + query = query.distinct(ForecastSQL.site_uuid, DatetimeIntervalSQL.start_utc) + + if start_utc is not None: + query = query.filter(DatetimeIntervalSQL.start_utc >= start_utc) + + # also filter on creation time, to speed up things + created_utc_filter = start_utc - dt.timedelta(days=1) + query = query.filter(ForecastValueSQL.created_utc >= created_utc_filter) + + output_dict: Dict[uuid.UUID, List[ForecastValueSQL]] = {} + + # Filter the query on the desired sites + query = query.where(ForecastSQL.site_uuid.in_(site_uuids)) + + # order by site, target time and created time desc + query.order_by( + ForecastSQL.site_uuid, + DatetimeIntervalSQL.start_utc, + ForecastValueSQL.created_utc, + ) + + # make sure this is all loaded + query = query.options(contains_eager(ForecastValueSQL.forecast)).populate_existing() + query = query.options(contains_eager(ForecastValueSQL.datetime_interval)).populate_existing() + + # query results + forecast_values: List[ForecastValueSQL] = query.all() + + for site_uuid in site_uuids: + site_latest_forecast_values: List[ForecastValueSQL] = [ + fv for fv in forecast_values if fv.forecast.site_uuid == site_uuid + ] + + output_dict[site_uuid] = site_latest_forecast_values + + return output_dict diff --git a/sdk/python/pvsite_datamodel/sqlmodels.py b/sdk/python/pvsite_datamodel/sqlmodels.py index e287fe6..43f7cb4 100644 --- a/sdk/python/pvsite_datamodel/sqlmodels.py +++ b/sdk/python/pvsite_datamodel/sqlmodels.py @@ -141,7 +141,11 @@ class ForecastValueSQL(Base, CreatedMixin): UUID(as_uuid=True), sa.ForeignKey("forecasts.forecast_uuid"), nullable=False, - default=uuid.uuid4, + ) + + forecast: ForecastSQL = relationship("ForecastSQL", back_populates="forecast_values") + datetime_interval: DatetimeIntervalSQL = relationship( + "DatetimeIntervalSQL", back_populates="forecast_values" ) diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py index 1c91dc4..bdbc027 100644 --- a/sdk/python/tests/conftest.py +++ b/sdk/python/tests/conftest.py @@ -8,6 +8,7 @@ from pvsite_datamodel import ( ClientSQL, SiteSQL, + ForecastValueSQL, GenerationSQL, StatusSQL, LatestForecastValueSQL, @@ -144,6 +145,42 @@ def latestforecastvalues(db_session, sites): db_session.commit() +@pytest.fixture() +def forecast_values(db_session, sites): + """Create some fake forecast values""" + + forecast_values = [] + forecast_version: str = "0.0.0" + start_times = [datetime.today() - timedelta(minutes=x) for x in range(10)] + + for site in sites: + forecast: ForecastSQL = ForecastSQL( + forecast_uuid=uuid.uuid4(), + site_uuid=site.site_uuid, + forecast_version=forecast_version, + ) + + db_session.add(forecast) + db_session.commit() + + for i in range(0, 10): + datetime_interval, _ = get_or_else_create_datetime_interval( + session=db_session, start_time=start_times[i] + ) + + forecast_value: ForecastValueSQL = ForecastValueSQL( + forecast_value_uuid=uuid.uuid4(), + datetime_interval_uuid=datetime_interval.datetime_interval_uuid, + forecast_generation_kw=i, + forecast_uuid=forecast.forecast_uuid, + ) + + forecast_values.append(forecast_value) + + db_session.add_all(forecast_values) + db_session.commit() + + @pytest.fixture() def datetimeintervals(db_session): """Create fake datetime intervals""" diff --git a/sdk/python/tests/test_read.py b/sdk/python/tests/test_read.py index 54a4660..d740cac 100644 --- a/sdk/python/tests/test_read.py +++ b/sdk/python/tests/test_read.py @@ -6,14 +6,17 @@ from sqlalchemy.orm import Query from typing import List -from pvsite_datamodel import SiteSQL, StatusSQL, ClientSQL, DatetimeIntervalSQL +from pvsite_datamodel import SiteSQL, StatusSQL, ClientSQL, DatetimeIntervalSQL, ForecastValueSQL from pvsite_datamodel.read import get_all_sites from pvsite_datamodel.read import get_site_by_uuid from pvsite_datamodel.read import get_site_by_client_site_id from pvsite_datamodel.read import get_pv_generation_by_sites from pvsite_datamodel.read import get_latest_status from pvsite_datamodel.read import get_pv_generation_by_client -from pvsite_datamodel.read import get_latest_forecast_values_by_site +from pvsite_datamodel.read import ( + get_latest_forecast_values_by_site, + get_forecast_values_by_site_latest, +) from pvsite_datamodel.read.utils import filter_query_by_datetime_interval import pytest @@ -47,9 +50,7 @@ class TestGetSiteByClientSiteID: def test_gets_site_successfully(self, sites, db_session): site = get_site_by_client_site_id( - session=db_session, - client_name="testclient_1", - client_site_id=1 + session=db_session, client_name="testclient_1", client_site_id=1 ) assert site.client_site_id == 1 @@ -57,9 +58,7 @@ def test_gets_site_successfully(self, sites, db_session): def test_raises_exception_when_no_such_site_exists(self, sites, db_session): with pytest.raises(Exception): _ = get_site_by_client_site_id( - session=db_session, - client_name="testclient_100", - client_site_id=1 + session=db_session, client_name="testclient_100", client_site_id=1 ) @@ -76,8 +75,7 @@ def test_returns_all_generations_for_input_client(self, generations, db_session) client: ClientSQL = query.first() generations = get_pv_generation_by_client( - session=db_session, - client_names=[client.client_name] + session=db_session, client_names=[client.client_name] ) assert len(generations) == 10 @@ -92,7 +90,7 @@ def test_returns_all_generations_in_datetime_window(self, generations, db_sessio session=db_session, client_names=[client.client_name], start_utc=window_lower, - end_utc=window_upper + end_utc=window_upper, ) assert len(generations) == 7 @@ -105,10 +103,7 @@ def test_gets_generation_for_single_input_site(self, generations, db_session): query: Query = db_session.query(SiteSQL) site: SiteSQL = query.first() - generations = get_pv_generation_by_sites( - session=db_session, - site_uuids=[site.site_uuid] - ) + generations = get_pv_generation_by_sites(session=db_session, site_uuids=[site.site_uuid]) assert len(generations) == 10 assert generations[0].datetime_interval is not None @@ -119,17 +114,13 @@ def test_gets_generation_for_multiple_input_sites(self, generations, db_session) sites: List[SiteSQL] = query.all() generations = get_pv_generation_by_sites( - session=db_session, - site_uuids=[site.site_uuid for site in sites] + session=db_session, site_uuids=[site.site_uuid for site in sites] ) assert len(generations) == 10 * len(sites) def test_returns_empty_list_for_no_input_sites(self, generations, db_session): - generations = get_pv_generation_by_sites( - session=db_session, - site_uuids=[] - ) + generations = get_pv_generation_by_sites(session=db_session, site_uuids=[]) assert len(generations) == 0 @@ -151,8 +142,7 @@ def test_gets_latest_forecast_values_with_single_site(self, latestforecastvalues site: SiteSQL = query.first() latest_forecast_values = get_latest_forecast_values_by_site( - session=db_session, - site_uuids=[site.site_uuid] + session=db_session, site_uuids=[site.site_uuid] ) assert len(latest_forecast_values) == 1 @@ -160,36 +150,48 @@ def test_gets_latest_forecast_values_with_single_site(self, latestforecastvalues assert latest_forecast_values[site.site_uuid][0].datetime_interval is not None def test_gets_latest_forecast_values_with_multiple_sites( - self, latestforecastvalues, db_session): + self, latestforecastvalues, db_session + ): query: Query = db_session.query(SiteSQL) sites: SiteSQL = query.all() latest_forecast_values = get_latest_forecast_values_by_site( - session=db_session, - site_uuids=[site.site_uuid for site in sites] + session=db_session, site_uuids=[site.site_uuid for site in sites] ) assert len(latest_forecast_values) == len(sites) - def test_gets_latest_forecast_values_filter_start_utc( - self, latestforecastvalues, db_session): + def test_gets_latest_forecast_values_filter_start_utc(self, latestforecastvalues, db_session): query: Query = db_session.query(SiteSQL) site: SiteSQL = query.first() latest_forecast_values = get_latest_forecast_values_by_site( session=db_session, site_uuids=[site.site_uuid], - start_utc=dt.datetime.today() - dt.timedelta(minutes=7) + start_utc=dt.datetime.today() - dt.timedelta(minutes=7), ) assert len(latest_forecast_values[site.site_uuid]) == 7 latest_forecast_values = get_latest_forecast_values_by_site( session=db_session, site_uuids=[site.site_uuid], - start_utc=dt.datetime.today() - dt.timedelta(minutes=5) + start_utc=dt.datetime.today() - dt.timedelta(minutes=5), ) assert len(latest_forecast_values[site.site_uuid]) == 5 + def test_gets_latest_forecast_values_forecast_value(self, forecast_values, db_session): + query: Query = db_session.query(SiteSQL) + site: SiteSQL = query.first() + + latest_forecast_values = get_forecast_values_by_site_latest( + session=db_session, + site_uuids=[site.site_uuid], + ) + + assert len(latest_forecast_values) == 1 + assert len(latest_forecast_values[site.site_uuid]) == 10 + assert latest_forecast_values[site.site_uuid][0].datetime_interval is not None + class TestFilterQueryByDatetimeInterval: """Tests for the filter_query_by_datetime_interval function""" @@ -197,8 +199,7 @@ class TestFilterQueryByDatetimeInterval: def test_returns_datetime_intervals_in_filter(self, datetimeintervals, db_session): query: Query = db_session.query(DatetimeIntervalSQL) query = filter_query_by_datetime_interval( - query=query, - start_utc=dt.datetime.today() - dt.timedelta(minutes=7) + query=query, start_utc=dt.datetime.today() - dt.timedelta(minutes=7) ) datetime_intervals: List[DatetimeIntervalSQL] = query.all()