Skip to content

Commit

Permalink
Update read (#39)
Browse files Browse the repository at this point in the history
* make read function more flexible

* lint

* separate out different read functions

* use contain eager
  • Loading branch information
peterdudfield authored Feb 2, 2023
1 parent 88c78eb commit 78bc0a0
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 47 deletions.
2 changes: 1 addition & 1 deletion sdk/python/pvsite_datamodel/read/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

from .generation import get_pv_generation_by_client, get_pv_generation_by_sites
from .latest_forecast_values import get_latest_forecast_values_by_site
from .latest_forecast_values import get_latest_forecast_values_by_site, get_forecast_values_by_site_latest
from .site import (
get_all_sites,
get_site_by_client_site_id,
Expand Down
90 changes: 77 additions & 13 deletions sdk/python/pvsite_datamodel/read/latest_forecast_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,121 @@

import datetime as dt
import uuid
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

from sqlalchemy.orm import Query, Session
from sqlalchemy.orm import Query, Session, contains_eager

from pvsite_datamodel.sqlmodels import DatetimeIntervalSQL, LatestForecastValueSQL
from pvsite_datamodel.sqlmodels import (
DatetimeIntervalSQL,
ForecastValueSQL,
LatestForecastValueSQL,
ForecastSQL,
)


def get_latest_forecast_values_by_site(
    session: Session,
    site_uuids: List[uuid.UUID],
    start_utc: Optional[dt.datetime] = None,
) -> Dict[uuid.UUID, List[LatestForecastValueSQL]]:
    """Get the latest forecast values for the given sites.

    This reads the LatestForecastValueSQL table.

    :param session: The sqlalchemy database session
    :param site_uuids: list of site_uuids for which to fetch latest forecast values
    :param start_utc: filters on forecast values target_time >= start_utc
    :return: dict containing {site_uuid1: List[LatestForecastValueSQL], site_uuid2: ...}.
        Every requested site_uuid is a key; sites without values map to an empty list.
    """
    # Nothing requested: skip the database round trip and the empty IN () predicate
    if not site_uuids:
        return {}

    # start main query
    query: Query = session.query(LatestForecastValueSQL)
    query = query.join(DatetimeIntervalSQL)

    if start_utc is not None:
        query = query.filter(DatetimeIntervalSQL.start_utc >= start_utc)

        # also filter on creation time, to speed up things
        created_utc_filter = start_utc - dt.timedelta(days=1)
        query = query.filter(LatestForecastValueSQL.created_utc >= created_utc_filter)

    # Filter the query on the desired sites
    query = query.where(LatestForecastValueSQL.site_uuid.in_(site_uuids))

    # order by site, target time and created time.
    # Query methods return a new Query, so the result has to be re-assigned,
    # otherwise the ordering is silently dropped.
    query = query.order_by(
        LatestForecastValueSQL.site_uuid,
        DatetimeIntervalSQL.start_utc,
        LatestForecastValueSQL.created_utc,
    )

    # Group the rows per site in a single pass, keeping the query ordering
    output_dict: Dict[uuid.UUID, List[LatestForecastValueSQL]] = {
        site_uuid: [] for site_uuid in site_uuids
    }
    for lfv in query.all():
        site_values = output_dict.get(lfv.site_uuid)
        if site_values is not None:
            site_values.append(lfv)

    return output_dict


def get_forecast_values_by_site_latest(
    session: Session,
    site_uuids: List[uuid.UUID],
    start_utc: Optional[dt.datetime] = None,
) -> Dict[uuid.UUID, List[ForecastValueSQL]]:
    """Get the forecast values by input sites, keeping only the latest value.

    This reads the ForecastValueSQL table. For each (site, target time) pair
    only the most recently created forecast value is returned.

    :param session: The sqlalchemy database session
    :param site_uuids: list of site_uuids for which to fetch latest forecast values
    :param start_utc: filters on forecast values target_time >= start_utc
    :return: dict containing {site_uuid1: List[ForecastValueSQL], site_uuid2: ...}.
        Every requested site_uuid is a key; sites without values map to an empty list.
    """
    # Nothing requested: skip the database round trip and the empty IN () predicate
    if not site_uuids:
        return {}

    # start main query
    query: Query = session.query(ForecastValueSQL)
    query = query.join(DatetimeIntervalSQL)

    # one row per (site, target time); which row survives is decided by the ordering below
    query = query.join(ForecastSQL)
    query = query.distinct(ForecastSQL.site_uuid, DatetimeIntervalSQL.start_utc)

    if start_utc is not None:
        query = query.filter(DatetimeIntervalSQL.start_utc >= start_utc)

        # also filter on creation time, to speed up things
        created_utc_filter = start_utc - dt.timedelta(days=1)
        query = query.filter(ForecastValueSQL.created_utc >= created_utc_filter)

    # Filter the query on the desired sites
    query = query.where(ForecastSQL.site_uuid.in_(site_uuids))

    # order by site, target time and created time desc.
    # Query methods return a new Query, so the result has to be re-assigned.
    # The leading columns match the DISTINCT ON columns, and created_utc
    # descending makes the newest value the one kept per (site, target time).
    query = query.order_by(
        ForecastSQL.site_uuid,
        DatetimeIntervalSQL.start_utc,
        ForecastValueSQL.created_utc.desc(),
    )

    # make sure the joined relationships are loaded from this same query
    query = query.options(
        contains_eager(ForecastValueSQL.forecast),
        contains_eager(ForecastValueSQL.datetime_interval),
    ).populate_existing()

    # Group the rows per site in a single pass, keeping the query ordering
    output_dict: Dict[uuid.UUID, List[ForecastValueSQL]] = {
        site_uuid: [] for site_uuid in site_uuids
    }
    for fv in query.all():
        site_values = output_dict.get(fv.forecast.site_uuid)
        if site_values is not None:
            site_values.append(fv)

    return output_dict
6 changes: 5 additions & 1 deletion sdk/python/pvsite_datamodel/sqlmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,11 @@ class ForecastValueSQL(Base, CreatedMixin):
UUID(as_uuid=True),
sa.ForeignKey("forecasts.forecast_uuid"),
nullable=False,
default=uuid.uuid4,
)

forecast: ForecastSQL = relationship("ForecastSQL", back_populates="forecast_values")
datetime_interval: DatetimeIntervalSQL = relationship(
"DatetimeIntervalSQL", back_populates="forecast_values"
)


Expand Down
37 changes: 37 additions & 0 deletions sdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pvsite_datamodel import (
ClientSQL,
SiteSQL,
ForecastValueSQL,
GenerationSQL,
StatusSQL,
LatestForecastValueSQL,
Expand Down Expand Up @@ -144,6 +145,42 @@ def latestforecastvalues(db_session, sites):
db_session.commit()


@pytest.fixture()
def forecast_values(db_session, sites):
    """Populate the database with fake forecast values.

    For every site one ForecastSQL row is committed, followed by ten
    ForecastValueSQL rows whose target times step back one minute at a time
    from now and whose generation equals their index (0..9).
    """
    # NOTE(review): indentation was lost in the scraped source; the final
    # add_all/commit is assumed to run once, after all sites are processed.
    version: str = "0.0.0"
    start_times = [datetime.today() - timedelta(minutes=offset) for offset in range(10)]
    values = []

    for site in sites:
        site_forecast: ForecastSQL = ForecastSQL(
            forecast_uuid=uuid.uuid4(),
            site_uuid=site.site_uuid,
            forecast_version=version,
        )

        # the forecast row is committed before its values are built
        db_session.add(site_forecast)
        db_session.commit()

        for generation_kw, start_time in enumerate(start_times):
            interval, _ = get_or_else_create_datetime_interval(
                session=db_session, start_time=start_time
            )

            values.append(
                ForecastValueSQL(
                    forecast_value_uuid=uuid.uuid4(),
                    datetime_interval_uuid=interval.datetime_interval_uuid,
                    forecast_generation_kw=generation_kw,
                    forecast_uuid=site_forecast.forecast_uuid,
                )
            )

    db_session.add_all(values)
    db_session.commit()


@pytest.fixture()
def datetimeintervals(db_session):
"""Create fake datetime intervals"""
Expand Down
65 changes: 33 additions & 32 deletions sdk/python/tests/test_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
from sqlalchemy.orm import Query
from typing import List

from pvsite_datamodel import SiteSQL, StatusSQL, ClientSQL, DatetimeIntervalSQL
from pvsite_datamodel import SiteSQL, StatusSQL, ClientSQL, DatetimeIntervalSQL, ForecastValueSQL
from pvsite_datamodel.read import get_all_sites
from pvsite_datamodel.read import get_site_by_uuid
from pvsite_datamodel.read import get_site_by_client_site_id
from pvsite_datamodel.read import get_pv_generation_by_sites
from pvsite_datamodel.read import get_latest_status
from pvsite_datamodel.read import get_pv_generation_by_client
from pvsite_datamodel.read import get_latest_forecast_values_by_site
from pvsite_datamodel.read import (
get_latest_forecast_values_by_site,
get_forecast_values_by_site_latest,
)
from pvsite_datamodel.read.utils import filter_query_by_datetime_interval

import pytest
Expand Down Expand Up @@ -47,19 +50,15 @@ class TestGetSiteByClientSiteID:

def test_gets_site_successfully(self, sites, db_session):
site = get_site_by_client_site_id(
session=db_session,
client_name="testclient_1",
client_site_id=1
session=db_session, client_name="testclient_1", client_site_id=1
)

assert site.client_site_id == 1

def test_raises_exception_when_no_such_site_exists(self, sites, db_session):
with pytest.raises(Exception):
_ = get_site_by_client_site_id(
session=db_session,
client_name="testclient_100",
client_site_id=1
session=db_session, client_name="testclient_100", client_site_id=1
)


Expand All @@ -76,8 +75,7 @@ def test_returns_all_generations_for_input_client(self, generations, db_session)
client: ClientSQL = query.first()

generations = get_pv_generation_by_client(
session=db_session,
client_names=[client.client_name]
session=db_session, client_names=[client.client_name]
)

assert len(generations) == 10
Expand All @@ -92,7 +90,7 @@ def test_returns_all_generations_in_datetime_window(self, generations, db_sessio
session=db_session,
client_names=[client.client_name],
start_utc=window_lower,
end_utc=window_upper
end_utc=window_upper,
)

assert len(generations) == 7
Expand All @@ -105,10 +103,7 @@ def test_gets_generation_for_single_input_site(self, generations, db_session):
query: Query = db_session.query(SiteSQL)
site: SiteSQL = query.first()

generations = get_pv_generation_by_sites(
session=db_session,
site_uuids=[site.site_uuid]
)
generations = get_pv_generation_by_sites(session=db_session, site_uuids=[site.site_uuid])

assert len(generations) == 10
assert generations[0].datetime_interval is not None
Expand All @@ -119,17 +114,13 @@ def test_gets_generation_for_multiple_input_sites(self, generations, db_session)
sites: List[SiteSQL] = query.all()

generations = get_pv_generation_by_sites(
session=db_session,
site_uuids=[site.site_uuid for site in sites]
session=db_session, site_uuids=[site.site_uuid for site in sites]
)

assert len(generations) == 10 * len(sites)

def test_returns_empty_list_for_no_input_sites(self, generations, db_session):
generations = get_pv_generation_by_sites(
session=db_session,
site_uuids=[]
)
generations = get_pv_generation_by_sites(session=db_session, site_uuids=[])

assert len(generations) == 0

Expand All @@ -151,54 +142,64 @@ def test_gets_latest_forecast_values_with_single_site(self, latestforecastvalues
site: SiteSQL = query.first()

latest_forecast_values = get_latest_forecast_values_by_site(
session=db_session,
site_uuids=[site.site_uuid]
session=db_session, site_uuids=[site.site_uuid]
)

assert len(latest_forecast_values) == 1
assert len(latest_forecast_values[site.site_uuid]) == 10
assert latest_forecast_values[site.site_uuid][0].datetime_interval is not None

def test_gets_latest_forecast_values_with_multiple_sites(
self, latestforecastvalues, db_session):
self, latestforecastvalues, db_session
):
query: Query = db_session.query(SiteSQL)
sites: SiteSQL = query.all()

latest_forecast_values = get_latest_forecast_values_by_site(
session=db_session,
site_uuids=[site.site_uuid for site in sites]
session=db_session, site_uuids=[site.site_uuid for site in sites]
)

assert len(latest_forecast_values) == len(sites)

def test_gets_latest_forecast_values_filter_start_utc(
self, latestforecastvalues, db_session):
def test_gets_latest_forecast_values_filter_start_utc(self, latestforecastvalues, db_session):
query: Query = db_session.query(SiteSQL)
site: SiteSQL = query.first()

latest_forecast_values = get_latest_forecast_values_by_site(
session=db_session,
site_uuids=[site.site_uuid],
start_utc=dt.datetime.today() - dt.timedelta(minutes=7)
start_utc=dt.datetime.today() - dt.timedelta(minutes=7),
)
assert len(latest_forecast_values[site.site_uuid]) == 7

latest_forecast_values = get_latest_forecast_values_by_site(
session=db_session,
site_uuids=[site.site_uuid],
start_utc=dt.datetime.today() - dt.timedelta(minutes=5)
start_utc=dt.datetime.today() - dt.timedelta(minutes=5),
)
assert len(latest_forecast_values[site.site_uuid]) == 5

def test_gets_latest_forecast_values_forecast_value(self, forecast_values, db_session):
    """Reading ForecastValueSQL for a single site gives one entry holding 10 values."""
    first_site: SiteSQL = db_session.query(SiteSQL).first()
    wanted_uuid = first_site.site_uuid

    values_by_site = get_forecast_values_by_site_latest(
        session=db_session,
        site_uuids=[wanted_uuid],
    )

    # exactly one key: the requested site
    assert len(values_by_site) == 1

    site_values = values_by_site[wanted_uuid]
    assert len(site_values) == 10
    # the datetime interval relationship must be populated on the returned rows
    assert site_values[0].datetime_interval is not None


class TestFilterQueryByDatetimeInterval:
"""Tests for the filter_query_by_datetime_interval function"""

def test_returns_datetime_intervals_in_filter(self, datetimeintervals, db_session):
query: Query = db_session.query(DatetimeIntervalSQL)
query = filter_query_by_datetime_interval(
query=query,
start_utc=dt.datetime.today() - dt.timedelta(minutes=7)
query=query, start_utc=dt.datetime.today() - dt.timedelta(minutes=7)
)

datetime_intervals: List[DatetimeIntervalSQL] = query.all()
Expand Down

0 comments on commit 78bc0a0

Please sign in to comment.