Skip to content

Commit

Permalink
Added remaining tests for app functions
Browse files Browse the repository at this point in the history
  • Loading branch information
confusedmatrix committed Jan 25, 2024
1 parent 720719e commit 81e27cf
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 52 deletions.
6 changes: 4 additions & 2 deletions india_forecast_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import datetime as dt
import logging
import os
import sys

import click
import pandas as pd
Expand Down Expand Up @@ -132,7 +133,7 @@ def app(timestamp: dt.datetime | None, write_to_db: bool, log_level: str):
"""
Main function for running forecasts for sites in India
"""
logging.basicConfig(level=getattr(logging, log_level.upper()))
logging.basicConfig(stream=sys.stdout, level=getattr(logging, log_level.upper()))

if timestamp is None:
timestamp = dt.datetime.now(tz=dt.UTC)
Expand All @@ -142,7 +143,8 @@ def app(timestamp: dt.datetime | None, write_to_db: bool, log_level: str):
timestamp.replace(tzinfo=dt.UTC)

# 0. Initialise DB connection
url = os.getenv("DB_URL")
url = os.environ["DB_URL"]

db_conn = DatabaseConnection(url, echo=False)

with db_conn.get_session() as session:
Expand Down
2 changes: 1 addition & 1 deletion india_forecast_app/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def version(self):
return "0.0.0"

def __init__(self):
"""Initialiser for the model"""
"""Initializer for the model"""
pass

def predict(self, site_id: str, timestamp: dt.datetime):
Expand Down
21 changes: 21 additions & 0 deletions tests/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Testing utils."""

from click.testing import CliRunner


def run_click_script(func, args: list[str], catch_exceptions: bool = False):
"""Util to test click scripts while showing the stdout."""

runner = CliRunner()

# We catch the exception here no matter what, but we'll reraise later if need be.
result = runner.invoke(func, args, catch_exceptions=True)

# Without this the output to stdout/stderr is grabbed by click's test runner.
# print(result.output)

# In case of an exception, raise it so that the test fails with the exception.
if result.exception and not catch_exceptions:
raise result.exception

return result
102 changes: 66 additions & 36 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,58 +3,88 @@
"""


import datetime as dt
import os

import pytest
from pvsite_datamodel.connection import DatabaseConnection
from pvsite_datamodel.sqlmodels import Base, SiteSQL
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from testcontainers.postgres import PostgresContainer


@pytest.fixture(scope="session", autouse=True)
def db_conn():
"""Database engine, this includes the table creation."""
@pytest.fixture(scope="session")
def engine():
"""Database engine fixture."""

with PostgresContainer("postgres:14.5") as postgres:
url = postgres.get_connection_url()

database_connection = DatabaseConnection(url, echo=False)
engine = database_connection.engine

os.environ["DB_URL"] = url
engine = create_engine(url)
Base.metadata.create_all(engine)

yield database_connection

engine.dispose()
yield engine


@pytest.fixture()
def db_session(db_conn):
"""Creates a new database session for a test.
def db_session(engine):
"""Return a sqlalchemy session, which tears down everything properly post-test."""

connection = engine.connect()
# begin the nested transaction
transaction = connection.begin()
# use the connection with the already started transaction

We automatically roll back whatever happens when the test completes.
"""
with Session(bind=connection) as session:
yield session

with db_conn.get_session() as session:
with session.begin():
yield session
session.rollback()
session.close()
# roll back the broader transaction
transaction.rollback()
# put back the connection to the connection pool
connection.close()
session.flush()

engine.dispose()


@pytest.fixture(scope="session", autouse=True)
def db_data(db_conn):
def db_data(engine):
"""Seed some initial data into DB."""

with db_conn.get_session() as session:
n_sites = 3

# Sites
for i in range(n_sites):
site = SiteSQL(
client_site_id=i + 1,
latitude=51,
longitude=3,
capacity_kw=4,
ml_id=i,
country="india"
)
session.add(site)

session.commit()
with engine.connect() as connection:
with Session(bind=connection) as session:
n_sites = 3

# Sites
for i in range(n_sites):
site = SiteSQL(
client_site_id=i + 1,
latitude=51,
longitude=3,
capacity_kw=4,
ml_id=i,
country="india"
)
session.add(site)

session.commit()


@pytest.fixture()
def forecast_values():
"""Dummy forecast values"""

n = 10 # number of forecast values
step = 15 # in minutes
init_utc = dt.datetime.now(dt.timezone.utc)
start_utc = [init_utc + dt.timedelta(minutes=i * step) for i in range(n)]
end_utc = [d + dt.timedelta(minutes=step) for d in start_utc]
forecast_power_kw = [i * 10 for i in range(n)]
forecast_values = {
"start_utc": start_utc,
"end_utc": end_utc,
"forecast_power_kw": forecast_power_kw
}

return forecast_values
64 changes: 51 additions & 13 deletions tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
"""
Tests for functions in app.py
"""

import datetime as dt
import uuid

from pvsite_datamodel.sqlmodels import SiteSQL
import pytest
from pvsite_datamodel.read import get_all_sites
from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL, SiteSQL

from india_forecast_app.app import get_model, get_site_ids, run_model
from india_forecast_app.app import app, get_model, get_site_ids, run_model, save_forecast
from india_forecast_app.model import DummyModel

from ._utils import run_click_script


def test_get_site_ids(db_session):
"""Test for correct site ids"""

site_ids = get_site_ids(db_session)

assert len(site_ids) == 3
Expand All @@ -22,6 +26,7 @@ def test_get_site_ids(db_session):

def test_get_model():
"""Test for getting valid model"""

model = get_model()

assert hasattr(model, 'version')
Expand All @@ -31,6 +36,7 @@ def test_get_model():

def test_run_model(db_session):
"""Test for running model"""

site = db_session.query(SiteSQL).first()
model = DummyModel()

Expand All @@ -39,19 +45,51 @@ def test_run_model(db_session):
site_id=site.site_uuid,
timestamp=dt.datetime.now(tz=dt.UTC)
)

# TODO better assertions against forecast

assert isinstance(forecast, list)
assert len(forecast) == 192 # value for every 15mins over 2 days
assert all([isinstance(value["start_utc"], dt.datetime) for value in forecast])
assert all([isinstance(value["end_utc"], dt.datetime) for value in forecast])
assert all([isinstance(value["forecast_power_kw"], int) for value in forecast])

def test_save_forecast(db_session):

def test_save_forecast(db_session, forecast_values):
"""Test for saving forecast"""

# TODO test for successful and unsuccessful saving of forecast
pass

site = get_all_sites(db_session)[0]

def test_app():
forecast = {
"meta": {
"site_id": site.site_uuid,
"version": "0.0.0",
"timestamp": dt.datetime.now(tz=dt.UTC)
},
"values": forecast_values,
}

save_forecast(db_session, forecast, write_to_db=True)

assert db_session.query(ForecastSQL).count() == 1
assert db_session.query(ForecastValueSQL).count() == 10


@pytest.mark.parametrize("write_to_db", [True, False])
def test_app(write_to_db, db_session):
"""Test for running app from command line"""

# TODO test click app
pass

init_n_forecasts = db_session.query(ForecastSQL).count()
init_n_forecast_values = db_session.query(ForecastValueSQL).count()

args = ["--date", dt.datetime.now(tz=dt.UTC).strftime("%Y-%m-%d-%H-%M")]
if write_to_db:
args.append('--write-to-db')

result = run_click_script(app, args)
assert result.exit_code == 0

if write_to_db:
assert db_session.query(ForecastSQL).count() == init_n_forecasts + 3
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (3 * 192)
else:
assert db_session.query(ForecastSQL).count() == init_n_forecasts
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values

0 comments on commit 81e27cf

Please sign in to comment.