From 81e27cf22ff4fd5a7113786e3ca12d647b3b4f1b Mon Sep 17 00:00:00 2001 From: Chris Briggs Date: Thu, 25 Jan 2024 14:36:05 +0000 Subject: [PATCH] Added remaining tests for app functions --- india_forecast_app/app.py | 6 ++- india_forecast_app/model.py | 2 +- tests/_utils.py | 21 ++++++++ tests/conftest.py | 102 +++++++++++++++++++++++------------- tests/test_app.py | 64 +++++++++++++++++----- 5 files changed, 143 insertions(+), 52 deletions(-) create mode 100644 tests/_utils.py diff --git a/india_forecast_app/app.py b/india_forecast_app/app.py index 1fd5829..7d30152 100644 --- a/india_forecast_app/app.py +++ b/india_forecast_app/app.py @@ -5,6 +5,7 @@ import datetime as dt import logging import os +import sys import click import pandas as pd @@ -132,7 +133,7 @@ def app(timestamp: dt.datetime | None, write_to_db: bool, log_level: str): """ Main function for running forecasts for sites in India """ - logging.basicConfig(level=getattr(logging, log_level.upper())) + logging.basicConfig(stream=sys.stdout, level=getattr(logging, log_level.upper())) if timestamp is None: timestamp = dt.datetime.now(tz=dt.UTC) @@ -142,7 +143,8 @@ def app(timestamp: dt.datetime | None, write_to_db: bool, log_level: str): timestamp.replace(tzinfo=dt.UTC) # 0. Initialise DB connection - url = os.getenv("DB_URL") + url = os.environ["DB_URL"] + db_conn = DatabaseConnection(url, echo=False) with db_conn.get_session() as session: diff --git a/india_forecast_app/model.py b/india_forecast_app/model.py index f4fc754..cb3b24f 100644 --- a/india_forecast_app/model.py +++ b/india_forecast_app/model.py @@ -18,7 +18,7 @@ def version(self): return "0.0.0" def __init__(self): - """Initialiser for the model""" + """Initializer for the model""" pass def predict(self, site_id: str, timestamp: dt.datetime): diff --git a/tests/_utils.py b/tests/_utils.py new file mode 100644 index 0000000..109e690 --- /dev/null +++ b/tests/_utils.py @@ -0,0 +1,21 @@ +"""Testing utils.""" + +from click.testing import CliRunner + + +def run_click_script(func, args: list[str], catch_exceptions: bool = False): + """Util to test click scripts while showing the stdout.""" + + runner = CliRunner() + + # We catch the exception here no matter what, but we'll reraise later if need be. + result = runner.invoke(func, args, catch_exceptions=True) + + # Without this the output to stdout/stderr is grabbed by click's test runner. + # print(result.output) + + # In case of an exception, raise it so that the test fails with the exception. + if result.exception and not catch_exceptions: + raise result.exception + + return result diff --git a/tests/conftest.py b/tests/conftest.py index e1adee0..159e252 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,58 +3,88 @@ """ +import datetime as dt +import os + import pytest -from pvsite_datamodel.connection import DatabaseConnection from pvsite_datamodel.sqlmodels import Base, SiteSQL +from sqlalchemy import create_engine +from sqlalchemy.orm import Session from testcontainers.postgres import PostgresContainer -@pytest.fixture(scope="session", autouse=True) -def db_conn(): - """Database engine, this includes the table creation.""" +@pytest.fixture(scope="session") +def engine(): + """Database engine fixture.""" + with PostgresContainer("postgres:14.5") as postgres: url = postgres.get_connection_url() - - database_connection = DatabaseConnection(url, echo=False) - engine = database_connection.engine - + os.environ["DB_URL"] = url + engine = create_engine(url) Base.metadata.create_all(engine) - yield database_connection - - engine.dispose() + yield engine @pytest.fixture() -def db_session(db_conn): - """Creates a new database session for a test. +def db_session(engine): + """Return a sqlalchemy session, which tears down everything properly post-test.""" + + connection = engine.connect() + # begin the nested transaction + transaction = connection.begin() + # use the connection with the already started transaction - We automatically roll back whatever happens when the test completes. - """ + with Session(bind=connection) as session: + yield session - with db_conn.get_session() as session: - with session.begin(): - yield session - session.rollback() + session.close() + # roll back the broader transaction + transaction.rollback() + # put back the connection to the connection pool + connection.close() + session.flush() + + engine.dispose() @pytest.fixture(scope="session", autouse=True) -def db_data(db_conn): +def db_data(engine): """Seed some initial data into DB.""" - with db_conn.get_session() as session: - n_sites = 3 - - # Sites - for i in range(n_sites): - site = SiteSQL( - client_site_id=i + 1, - latitude=51, - longitude=3, - capacity_kw=4, - ml_id=i, - country="india" - ) - session.add(site) - - session.commit() + with engine.connect() as connection: + with Session(bind=connection) as session: + n_sites = 3 + + # Sites + for i in range(n_sites): + site = SiteSQL( + client_site_id=i + 1, + latitude=51, + longitude=3, + capacity_kw=4, + ml_id=i, + country="india" + ) + session.add(site) + + session.commit() + + +@pytest.fixture() +def forecast_values(): + """Dummy forecast values""" + + n = 10 # number of forecast values + step = 15 # in minutes + init_utc = dt.datetime.now(dt.timezone.utc) + start_utc = [init_utc + dt.timedelta(minutes=i * step) for i in range(n)] + end_utc = [d + dt.timedelta(minutes=step) for d in start_utc] + forecast_power_kw = [i * 10 for i in range(n)] + forecast_values = { + "start_utc": start_utc, + "end_utc": end_utc, + "forecast_power_kw": forecast_power_kw + } + + return forecast_values diff --git a/tests/test_app.py b/tests/test_app.py index c378dbd..c0cae05 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -1,18 +1,22 @@ """ Tests for functions in app.py """ - import datetime as dt import uuid -from pvsite_datamodel.sqlmodels import SiteSQL +import pytest +from pvsite_datamodel.read import get_all_sites +from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL, SiteSQL -from india_forecast_app.app import get_model, get_site_ids, run_model +from india_forecast_app.app import app, get_model, get_site_ids, run_model, save_forecast from india_forecast_app.model import DummyModel +from ._utils import run_click_script + def test_get_site_ids(db_session): """Test for correct site ids""" + site_ids = get_site_ids(db_session) assert len(site_ids) == 3 @@ -22,6 +26,7 @@ def test_get_site_ids(db_session): def test_get_model(): """Test for getting valid model""" + model = get_model() assert hasattr(model, 'version') @@ -31,6 +36,7 @@ def test_get_model(): def test_run_model(db_session): """Test for running model""" + site = db_session.query(SiteSQL).first() model = DummyModel() @@ -39,19 +45,51 @@ def test_run_model(db_session): site_id=site.site_uuid, timestamp=dt.datetime.now(tz=dt.UTC) ) - - # TODO better assertions against forecast + assert isinstance(forecast, list) + assert len(forecast) == 192 # value for every 15mins over 2 days + assert all([isinstance(value["start_utc"], dt.datetime) for value in forecast]) + assert all([isinstance(value["end_utc"], dt.datetime) for value in forecast]) + assert all([isinstance(value["forecast_power_kw"], int) for value in forecast]) -def test_save_forecast(db_session): + +def test_save_forecast(db_session, forecast_values): """Test for saving forecast""" - - # TODO test for successful and unsuccessful saving of forecast - pass + site = get_all_sites(db_session)[0] -def test_app(): + forecast = { + "meta": { + "site_id": site.site_uuid, + "version": "0.0.0", + "timestamp": dt.datetime.now(tz=dt.UTC) + }, + "values": forecast_values, + } + + save_forecast(db_session, forecast, write_to_db=True) + + assert db_session.query(ForecastSQL).count() == 1 + assert db_session.query(ForecastValueSQL).count() == 10 + + +@pytest.mark.parametrize("write_to_db", [True, False]) +def test_app(write_to_db, db_session): """Test for running app from command line""" - - # TODO test click app - pass + + init_n_forecasts = db_session.query(ForecastSQL).count() + init_n_forecast_values = db_session.query(ForecastValueSQL).count() + + args = ["--date", dt.datetime.now(tz=dt.UTC).strftime("%Y-%m-%d-%H-%M")] + if write_to_db: + args.append('--write-to-db') + + result = run_click_script(app, args) + assert result.exit_code == 0 + + if write_to_db: + assert db_session.query(ForecastSQL).count() == init_n_forecasts + 3 + assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (3 * 192) + else: + assert db_session.query(ForecastSQL).count() == init_n_forecasts + assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values