Issue/sum read (#81)
* add summation for generation

* add forecast sum function

* blacks

* lint

* add pydantic to requirements

* update poetry lock file

* add tests

* lint
peterdudfield authored Sep 5, 2023
1 parent 24b8c2e commit 6fa4d33
Showing 8 changed files with 698 additions and 381 deletions.
867 changes: 506 additions & 361 deletions poetry.lock

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions pvsite_datamodel/pydantic_models.py
@@ -0,0 +1,20 @@
""" Pydantic models."""
from datetime import datetime

from pydantic import BaseModel, Field


class GenerationSum(BaseModel):
"""Sum of generation."""

power_kw: float = Field(..., description="Summed power in kW")
start_utc: datetime = Field(..., description="Start datetime of this power")
    name: str = Field(..., description="Name of the item summed over")


class ForecastValueSum(BaseModel):
"""Sum of forecast values."""

power_kw: float = Field(..., description="Summed power in kW")
start_utc: datetime = Field(..., description="Start datetime of this power")
    name: str = Field(..., description="Name of the item summed over")
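
A minimal usage sketch of these models (hypothetical values; assumes pydantic v2, which pyproject.toml pins below):

from datetime import datetime

from pvsite_datamodel.pydantic_models import GenerationSum

# one summed row: total power across the grouped sites at a given start time
total = GenerationSum(power_kw=12.5, start_utc=datetime(2023, 9, 5, 12, 0), name="total")
print(total.model_dump())  # pydantic v2 serialisation to a plain dict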
46 changes: 41 additions & 5 deletions pvsite_datamodel/read/generation.py
@@ -2,10 +2,12 @@
import logging
import uuid
from datetime import datetime
from typing import List, Optional, Union

from sqlalchemy import func
from sqlalchemy.orm import Session, contains_eager

from pvsite_datamodel.pydantic_models import GenerationSum
from pvsite_datamodel.sqlmodels import (
GenerationSQL,
SiteGroupSiteSQL,
@@ -68,16 +70,21 @@ def get_pv_generation_by_sites(
start_utc: Optional[datetime] = None,
end_utc: Optional[datetime] = None,
site_uuids: Optional[List[uuid.UUID]] = None,
sum_by: Optional[str] = None,
) -> Union[List[GenerationSQL], List[GenerationSum]]:
"""Get the generation data by site.
:param session: database session
:param start_utc: search filters >= on 'datetime_utc'
    :param end_utc: search filters < on 'datetime_utc'
:param site_uuids: optional list of site uuids
:param sum_by: optional string to sum by. Must be one of ['total', 'dno', 'gsp']
    :return: list of GenerationSQL objects, or a list of GenerationSum objects if sum_by is set
"""
# start main query

if sum_by not in ["total", "dno", "gsp", None]:
raise ValueError(f"sum_by must be one of ['total', 'dno', 'gsp'], not {sum_by}")

query = session.query(GenerationSQL)
query = query.join(SiteSQL)

@@ -100,7 +107,36 @@ def get_pv_generation_by_sites(
# make sure this is all loaded
query = query.options(contains_eager(GenerationSQL.site)).populate_existing()

if sum_by is None:
# get all results
generations: List[GenerationSQL] = query.all()
else:
subquery = query.subquery()

group_by_variables = [subquery.c.start_utc]
if sum_by == "dno":
group_by_variables.append(SiteSQL.dno)
if sum_by == "gsp":
group_by_variables.append(SiteSQL.gsp)
query_variables = group_by_variables.copy()
query_variables.append(func.sum(subquery.c.generation_power_kw))

query = session.query(*query_variables)
query = query.join(SiteSQL)
query = query.group_by(*group_by_variables)
query = query.order_by(*group_by_variables)
generations_raw = query.all()

generations: List[GenerationSum] = []
for generation_raw in generations_raw:
if len(generation_raw) == 2:
generation = GenerationSum(
start_utc=generation_raw[0], power_kw=generation_raw[1], name="total"
)
else:
generation = GenerationSum(
start_utc=generation_raw[0], power_kw=generation_raw[2], name=generation_raw[1]
)
generations.append(generation)

return generations
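
A usage sketch of the new sum_by argument (the session, site_uuids and datetime variables are assumed, not part of this diff):

from pvsite_datamodel.read.generation import get_pv_generation_by_sites

# sum across all requested sites: one GenerationSum per start_utc
totals = get_pv_generation_by_sites(
    session=session, site_uuids=site_uuids, start_utc=start, end_utc=end, sum_by="total"
)
for row in totals:
    print(row.start_utc, row.name, row.power_kw)

# sum_by=None (the default) keeps the previous behaviour: a list of GenerationSQL rows
generations = get_pv_generation_by_sites(session=session, site_uuids=site_uuids)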
63 changes: 52 additions & 11 deletions pvsite_datamodel/read/latest_forecast_values.py
@@ -2,17 +2,21 @@

import datetime as dt
import uuid
from typing import List, Optional, Union

from sqlalchemy import func
from sqlalchemy.orm import Session

from pvsite_datamodel.pydantic_models import ForecastValueSum
from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL, SiteSQL


def get_latest_forecast_values_by_site(
session: Session,
site_uuids: list[uuid.UUID],
start_utc: dt.datetime,
sum_by: Optional[str] = None,
) -> Union[dict[uuid.UUID, list[ForecastValueSQL]], List[ForecastValueSum]]:
"""Get the forecast values by input sites, get the latest value.
Return the forecasts after a given date, but keeping only the latest for a given timestamp.
@@ -36,7 +40,12 @@ def get_latest_forecast_values_by_site(
:param session: The sqlalchemy database session
:param site_uuids: list of site_uuids for which to fetch latest forecast values
:param start_utc: filters on forecast values target_time >= start_utc
:param sum_by: optional, sum the forecast values by this column
"""

if sum_by not in ["total", "dno", "gsp", None]:
raise ValueError(f"sum_by must be one of ['total', 'dno', 'gsp'], not {sum_by}")

query = (
session.query(ForecastValueSQL)
.distinct(
@@ -55,16 +64,48 @@
)
)

if sum_by is None:
# query results
forecast_values = query.all()

output_dict: dict[uuid.UUID, list[ForecastValueSQL]] = {}

for site_uuid in site_uuids:
site_latest_forecast_values: list[ForecastValueSQL] = [
fv for fv in forecast_values if fv.forecast.site_uuid == site_uuid
]

output_dict[site_uuid] = site_latest_forecast_values

return output_dict
else:
subquery = query.subquery()

group_by_variables = [subquery.c.start_utc]
if sum_by == "dno":
group_by_variables.append(SiteSQL.dno)
if sum_by == "gsp":
group_by_variables.append(SiteSQL.gsp)
query_variables = group_by_variables.copy()
query_variables.append(func.sum(subquery.c.forecast_power_kw))

query = session.query(*query_variables)
query = query.join(ForecastSQL, ForecastSQL.forecast_uuid == subquery.c.forecast_uuid)
query = query.join(SiteSQL)
query = query.group_by(*group_by_variables)
query = query.order_by(*group_by_variables)
forecasts_raw = query.all()

forecasts: List[ForecastValueSum] = []
for forecast_raw in forecasts_raw:
if len(forecast_raw) == 2:
generation = ForecastValueSum(
start_utc=forecast_raw[0], power_kw=forecast_raw[1], name="total"
)
else:
generation = ForecastValueSum(
start_utc=forecast_raw[0], power_kw=forecast_raw[2], name=forecast_raw[1]
)
forecasts.append(generation)

return forecasts
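
A usage sketch of the summed forecast read (session, site_uuids and a start datetime are assumed):

from pvsite_datamodel.read.latest_forecast_values import get_latest_forecast_values_by_site

# default: dict of site_uuid -> list[ForecastValueSQL], as before
forecasts_by_site = get_latest_forecast_values_by_site(
    session=session, site_uuids=site_uuids, start_utc=start_utc
)

# with sum_by="dno": a flat list of ForecastValueSum rows, grouped by start_utc and DNO
forecasts_by_dno = get_latest_forecast_values_by_site(
    session=session, site_uuids=site_uuids, start_utc=start_utc, sum_by="dno"
)
for row in forecasts_by_dno:
    print(row.start_utc, row.name, row.power_kw)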
3 changes: 3 additions & 0 deletions pvsite_datamodel/read/site.py
@@ -40,6 +40,9 @@ def get_site_by_client_site_id(session: Session, client_name: str, client_site_i
# start main query
query = session.query(SiteSQL)

    # filter on the client name
query = query.filter(SiteSQL.client_site_name == client_name)

# select the correct client site id
query = query.filter(SiteSQL.client_site_id == client_site_id)

1 change: 1 addition & 0 deletions pyproject.toml
@@ -10,6 +10,7 @@ python = "^3.10"
sqlalchemy = "1.4.46"
psycopg2-binary = "^2.9.5"
pandas = "^1.5.3"
pydantic = "^2.3.0"


[tool.poetry.group.dev.dependencies]
11 changes: 9 additions & 2 deletions tests/conftest.py
@@ -1,5 +1,6 @@
"""Pytest fixtures for tests."""
import datetime as dt
import json
import uuid
from typing import List

@@ -60,26 +61,32 @@ def sites(db_session):
module_capacity_kw=4.3,
created_utc=dt.datetime.now(dt.timezone.utc),
ml_id=i,
dno=json.dumps({"dno_id": str(i), "name": "unknown", "long_name": "unknown"}),
gsp=json.dumps({"gsp_id": str(i), "name": "unknown"}),
)
db_session.add(site)
db_session.commit()

sites.append(site)

# make sure they are in order
sites = db_session.query(SiteSQL).order_by(SiteSQL.site_uuid).all()

return sites


@pytest.fixture()
def generations(db_session, sites):
"""Create some fake generations."""
now = dt.datetime.now(dt.timezone.utc)
start_times = [now - dt.timedelta(minutes=x) for x in range(10)]

all_generations = []
for site in sites:
for i in range(0, 10):
generation = GenerationSQL(
site_uuid=site.site_uuid,
generation_power_kw=10 - i,
start_utc=start_times[i],
end_utc=start_times[i] + dt.timedelta(minutes=5),
)
68 changes: 66 additions & 2 deletions tests/test_read.py
@@ -66,14 +66,18 @@ def test_raises_error_for_nonexistant_site(self, sites, db_session):

def test_get_site_by_client_site_id(self, sites, db_session):
site = get_site_by_client_site_id(
session=db_session, client_name="test_client", client_site_id=1
session=db_session,
client_name=sites[0].client_site_name,
client_site_id=sites[0].client_site_id,
)

assert site == sites[0]

def test_get_site_by_client_site_name(self, sites, db_session):
site = get_site_by_client_site_name(
session=db_session, client_name="test_client", client_site_name="test_site_0"
session=db_session,
client_name="test_client",
client_site_name=sites[0].client_site_name,
)

assert site == sites[0]
@@ -166,6 +170,46 @@ def test_returns_empty_list_for_no_input_sites(self, generations, db_session):

assert len(generations) == 0

def test_gets_generation_for_multiple_sum_total(self, generations, db_session):
query: Query = db_session.query(SiteSQL)
sites: List[SiteSQL] = query.all()

generations = get_pv_generation_by_sites(
session=db_session, site_uuids=[site.site_uuid for site in sites], sum_by="total"
)

assert len(generations) == 10
assert generations[0].power_kw == 4
assert generations[1].power_kw == 8
assert (generations[2].start_utc - generations[1].start_utc).seconds == 60

def test_gets_generation_for_multiple_sum_gsp(self, generations, db_session):
query: Query = db_session.query(SiteSQL)
sites: List[SiteSQL] = query.all()

generations = get_pv_generation_by_sites(
session=db_session, site_uuids=[site.site_uuid for site in sites], sum_by="gsp"
)
assert len(generations) == 10 * len(sites)

def test_gets_generation_for_multiple_sum_dno(self, generations, db_session):
query: Query = db_session.query(SiteSQL)
sites: List[SiteSQL] = query.all()

generations = get_pv_generation_by_sites(
session=db_session, site_uuids=[site.site_uuid for site in sites], sum_by="dno"
)
assert len(generations) == 10 * len(sites)

def test_gets_generation_for_multiple_sum_error(self, generations, db_session):
query: Query = db_session.query(SiteSQL)
sites: List[SiteSQL] = query.all()

with pytest.raises(ValueError): # noqa
_ = get_pv_generation_by_sites(
session=db_session, site_uuids=[site.site_uuid for site in sites], sum_by="blah"
)


class TestGetLatestStatus:
"""Tests for the get_latest_status function."""
Expand Down Expand Up @@ -250,6 +294,26 @@ def test_get_latest_forecast_values(db_session, sites):

assert values_as_tuple == expected[site_uuid]

latest_forecast = get_latest_forecast_values_by_site(
session=db_session, site_uuids=site_uuids, start_utc=d1, sum_by="total"
)
assert len(latest_forecast) == 4

latest_forecast = get_latest_forecast_values_by_site(
session=db_session, site_uuids=site_uuids, start_utc=d1, sum_by="dno"
)
assert len(latest_forecast) == 4 + 2 # 4 from site 1, 2 from site 2

latest_forecast = get_latest_forecast_values_by_site(
session=db_session, site_uuids=site_uuids, start_utc=d2, sum_by="gsp"
)
assert len(latest_forecast) == 3 + 1 # 3 from site 1, 1 from site 2

with pytest.raises(ValueError): # noqa
_ = get_latest_forecast_values_by_site(
session=db_session, site_uuids=site_uuids, start_utc=d2, sum_by="bla"
)


def test_get_site_group_by_name(db_session):
site_group = SiteGroupSQL(site_group_name="test")
