Skip to content

Commit

Permalink
Write forecast and forecast values (#97)
Browse files Browse the repository at this point in the history
* Added insert_forecast_values function
* Added fixtures and tests for insert_forecast_values function
* Added assertions for expected input keys for forecast write test
  • Loading branch information
confusedmatrix authored Jan 24, 2024
1 parent decdceb commit 1226dfe
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 19 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Currently available functions accessible via `from pvsite_datamodel.read import
### Write package functions

Currently available write functions accessible via `from pvsite_datamodels.write import <func>`:
- insert_forecast_values
- insert_generation_values
- create_site
- create_site_group
Expand Down
18 changes: 18 additions & 0 deletions pvsite_datamodel/write/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""
Functions for writing to the PVSite database
"""

from .forecast import insert_forecast_values
from .generation import insert_generation_values
from .user_and_site import (
add_site_to_site_group,
change_user_site_group,
create_site,
create_site_group,
create_user,
delete_site,
delete_site_group,
delete_user,
make_fake_site,
update_user_site_group,
)
43 changes: 43 additions & 0 deletions pvsite_datamodel/write/forecast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Write helpers for the Forecast and ForecastValues table.
"""

import logging

import pandas as pd
from sqlalchemy.orm import Session

from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL

_log = logging.getLogger(__name__)


def insert_forecast_values(
session: Session,
forecast_meta: dict,
forecast_values_df: pd.DataFrame,
):
"""Insert a dataframe of forecast values and forecast meta info into the database.
:param session: sqlalchemy session for interacting with the database
:param forecast_meta: Meta info about the forecast values
:param forecast_values_df: dataframe with the data to insert
"""

forecast = ForecastSQL(**forecast_meta)
session.add(forecast)

# Flush to get the Forecast's primary key.
session.flush()

rows = forecast_values_df.to_dict("records")
session.bulk_save_objects(
[
ForecastValueSQL(
**row,
forecast_uuid=forecast.forecast_uuid,
)
for row in rows
]
)
session.commit()
19 changes: 3 additions & 16 deletions pvsite_datamodel/write/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,19 @@
import logging

import pandas as pd
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session

from pvsite_datamodel.sqlmodels import Base, GenerationSQL
from pvsite_datamodel.sqlmodels import GenerationSQL
from pvsite_datamodel.write.utils import _insert_do_nothing_on_conflict

_log = logging.getLogger(__name__)


def _insert_do_nothing_on_conflict(session: Session, table: Base, rows: list[dict]):
"""Upserts rows into table.
This functions checks the primary keys and constraints, and if present, does nothing
:param session: sqlalchemy Session
:param table: the table
:param rows: the rows we are going to update
"""
stmt = postgresql.insert(table.__table__)
stmt = stmt.on_conflict_do_nothing()
session.execute(stmt, rows)


def insert_generation_values(
session: Session,
df: pd.DataFrame,
):
"""Insert a dataframe of forecast values into the database.
"""Insert a dataframe of generation values into the database.
:param session: sqlalchemy session for interacting with the database
:param df: dataframe with the data to insert
Expand Down
21 changes: 21 additions & 0 deletions pvsite_datamodel/write/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""
Useful functions for write operations.
"""

from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session

from pvsite_datamodel.sqlmodels import Base


def _insert_do_nothing_on_conflict(session: Session, table: Base, rows: list[dict]):
"""Upserts rows into table.
This functions checks the primary keys and constraints, and if present, does nothing
:param session: sqlalchemy Session
:param table: the table
:param rows: the rows we are going to update
"""
stmt = postgresql.insert(table.__table__)
stmt = stmt.on_conflict_do_nothing()
session.execute(stmt, rows)
56 changes: 56 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,57 @@ def test_time():
return dt.datetime(2022, 7, 25, 0, 0, 0, 0, dt.timezone.utc)


@pytest.fixture()
def forecast_valid_meta_input(sites):
forecast_meta = {
"site_uuid": sites[0].site_uuid,
"timestamp_utc": dt.datetime.now(tz=dt.UTC),
"forecast_version": "0.0.0",
}

return forecast_meta


@pytest.fixture()
def forecast_valid_values_input():
n = 10 # number of forecast values
step = 15 # in minutes
init_utc = dt.datetime.now(dt.timezone.utc)
start_utc = [init_utc + dt.timedelta(minutes=i * step) for i in range(n)]
end_utc = [d + dt.timedelta(minutes=step) for d in start_utc]
forecast_power_kw = [i * 10 for i in range(n)]
horizon_mins = [int((d - init_utc).seconds / 60) for d in start_utc]
forecast_values = {
"start_utc": start_utc,
"end_utc": end_utc,
"forecast_power_kw": forecast_power_kw,
"horizon_minutes": horizon_mins,
}

return forecast_values


@pytest.fixture()
def forecast_valid_input(forecast_valid_meta_input, forecast_valid_values_input):
return (forecast_valid_meta_input, forecast_valid_values_input)


@pytest.fixture()
def forecast_with_invalid_meta_input(forecast_valid_meta_input, forecast_valid_values_input):
forecast_meta = forecast_valid_meta_input
forecast_meta["site_uuid"] = "not-a-uuid"
return (forecast_meta, forecast_valid_values_input)


@pytest.fixture()
def forecast_with_invalid_values_input(forecast_valid_meta_input, forecast_valid_values_input):
forecast_values = forecast_valid_values_input
forecast_power_kw = forecast_values["forecast_power_kw"]
del forecast_values["forecast_power_kw"]
forecast_values["forecast_power_MW"] = forecast_power_kw
return (forecast_valid_meta_input, forecast_values)


@pytest.fixture()
def forecast_valid_site(sites):
site_uuid = sites[0].site_uuid
Expand Down Expand Up @@ -184,6 +235,11 @@ def forecast_invalid_dataframe():
}


@pytest.fixture()
def forecast_invalid_meta():
return {}


@pytest.fixture()
def generation_valid_site(sites):
site_uuid = sites[0].site_uuid
Expand Down
62 changes: 59 additions & 3 deletions tests/test_write.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
"""Test write functions."""

import datetime
import uuid

import pandas as pd
import pandas.api.types as ptypes
import pytest
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session

from pvsite_datamodel.sqlmodels import GenerationSQL
from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL, GenerationSQL
from pvsite_datamodel.write.forecast import insert_forecast_values
from pvsite_datamodel.write.generation import insert_generation_values

# from pvsite_datamodel.read.user import get_user_by_email
from pvsite_datamodel.write.user_and_site import (
add_site_to_site_group,
change_user_site_group,
Expand All @@ -19,6 +22,59 @@
)


class TestInsertForecastValues:
"""Tests for the insert_forecast_values function."""

def test_insert_forecast_for_existing_site(self, db_session, forecast_valid_input):
"""Test if forecast and forecast values inserted successfully"""
forecast_meta, forecast_values = forecast_valid_input
forecast_values_df = pd.DataFrame(forecast_values)

assert "site_uuid" in forecast_meta
assert "timestamp_utc" in forecast_meta
assert "forecast_version" in forecast_meta

assert isinstance(forecast_meta["site_uuid"], uuid.UUID)
assert isinstance(forecast_meta["timestamp_utc"], datetime.datetime)
assert isinstance(forecast_meta["forecast_version"], str)

assert "start_utc" in forecast_values_df.columns
assert "end_utc" in forecast_values_df.columns
assert "forecast_power_kw" in forecast_values_df.columns
assert "horizon_minutes" in forecast_values_df.columns

assert ptypes.is_datetime64_any_dtype(forecast_values_df["start_utc"])
assert ptypes.is_datetime64_any_dtype(forecast_values_df["end_utc"])
assert ptypes.is_numeric_dtype(forecast_values_df["forecast_power_kw"])
assert ptypes.is_numeric_dtype(forecast_values_df["horizon_minutes"])

insert_forecast_values(db_session, forecast_meta, forecast_values_df)

assert db_session.query(ForecastSQL).count() == 1
assert db_session.query(ForecastValueSQL).count() == 10

def test_invalid_forecast_meta(self, db_session, forecast_with_invalid_meta_input):
"""Test function errors on invalid forecast metadata"""
forecast_meta, forecast_values = forecast_with_invalid_meta_input
forecast_values_df = pd.DataFrame(forecast_values)

with pytest.raises(SQLAlchemyError):
insert_forecast_values(db_session, forecast_meta, forecast_values_df)

def test_invalid_forecast_values_dataframe(
self, db_session, forecast_with_invalid_values_input
):
"""test function errors on invalid forecast values dataframe"""
forecast_meta, forecast_values = forecast_with_invalid_values_input
forecast_values_df = pd.DataFrame(forecast_values)

with pytest.raises(
TypeError,
match=r"^'forecast_power_MW' is an invalid keyword argument for ForecastValueSQL.*",
):
insert_forecast_values(db_session, forecast_meta, forecast_values_df)


class TestInsertGenerationValues:
"""Tests for the insert_generation_values function."""

Expand Down

0 comments on commit 1226dfe

Please sign in to comment.