From 2b969e8e4395cfa3ab019540266f0782922d37b5 Mon Sep 17 00:00:00 2001 From: Chris Briggs Date: Thu, 25 Jan 2024 16:49:42 +0000 Subject: [PATCH 1/2] Dockerfile set to write to DB by default --- Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index 85dfb85..7245843 100644 --- a/Dockerfile +++ b/Dockerfile @@ -34,4 +34,4 @@ COPY --from=builder /venv /venv COPY --from=builder /app/dist . RUN . /venv/bin/activate && pip install *.whl -ENTRYPOINT ["app"] \ No newline at end of file +ENTRYPOINT ["app", "--write-to-db"] \ No newline at end of file From a6b0abdb0a9baecfaf942e0b1157f1d2d0d72e9c Mon Sep 17 00:00:00 2001 From: Chris Briggs Date: Thu, 25 Jan 2024 17:25:47 +0000 Subject: [PATCH 2/2] Added support for both pv and wind assets types --- india_forecast_app/app.py | 44 ++++++++++++++++++++++--------------- india_forecast_app/model.py | 12 +++++----- scripts/seed_local_db.py | 26 ++++++++++------------ tests/conftest.py | 37 ++++++++++++++++++++----------- tests/test_app.py | 38 +++++++++++++++++--------------- 5 files changed, 88 insertions(+), 69 deletions(-) diff --git a/india_forecast_app/app.py b/india_forecast_app/app.py index fd46797..72e9a0b 100644 --- a/india_forecast_app/app.py +++ b/india_forecast_app/app.py @@ -11,6 +11,7 @@ import pandas as pd from pvsite_datamodel import DatabaseConnection from pvsite_datamodel.read import get_sites_by_country +from pvsite_datamodel.sqlmodels import SiteSQL from pvsite_datamodel.write import insert_forecast_values from sqlalchemy.orm import Session @@ -19,31 +20,33 @@ log = logging.getLogger(__name__) -def get_site_ids(db_session: Session) -> list[str]: +def get_sites(db_session: Session) -> list[SiteSQL]: """ - Gets all avaiable site_ids in India + Gets all available sites in India Args: db_session: A SQLAlchemy session Returns: - A list of site_ids + A list of SiteSQL objects """ sites = get_sites_by_country(db_session, country="india") + return sites - return 
[s.site_uuid for s in sites] - -def get_model(): +def get_model(asset_type: str): """ Instantiates and returns the forecast model ready for running inference + Args: + asset_type: One of "pv" or "wind" + Returns: A forecasting model """ - model = DummyModel() + model = DummyModel(asset_type) return model @@ -151,17 +154,22 @@ def app(timestamp: dt.datetime | None, write_to_db: bool, log_level: str): # 1. Get sites log.info("Getting sites...") - site_ids = get_site_ids(session) - log.info(f"Found {len(site_ids)} sites") - - # 2. Load model - log.info("Loading model...") - model = get_model() - log.info("Loaded model") - - # 3. Run model for each site - for site_id in site_ids: - log.info(f"Running model for site={site_id}...") + sites = get_sites(session) + log.info(f"Found {len(sites)} sites") + + # 2. Load models + log.info("Loading models...") + pv_model = get_model("pv") + log.info("Loaded PV model") + wind_model = get_model("wind") + log.info("Loaded wind model") + + for site in sites: + # 3. 
Run model for each site + site_id = site.site_uuid + asset_type = site.asset_type.name + log.info(f"Running {asset_type} model for site={site_id}...") + model = wind_model if asset_type == "wind" else pv_model forecast_values = run_model(model=model, site_id=site_id, timestamp=timestamp) if forecast_values is None: diff --git a/india_forecast_app/model.py b/india_forecast_app/model.py index cb3b24f..f4d0898 100644 --- a/india_forecast_app/model.py +++ b/india_forecast_app/model.py @@ -11,15 +11,15 @@ class DummyModel: """ Dummy model that emulates the capabilities expected by a real model """ - + @property def version(self): """Version number""" return "0.0.0" - def __init__(self): + def __init__(self, asset_type: str): """Initializer for the model""" - pass + self.asset_type = asset_type def predict(self, site_id: str, timestamp: dt.datetime): """Make a prediction for the model""" @@ -35,7 +35,9 @@ def _generate_dummy_forecast(self, timestamp: dt.datetime): for i in range(numSteps): time = start + i * step - _yield = _basicSolarYieldFunc(int(time.timestamp())) + gen_func = _basicSolarYieldFunc if self.asset_type == "pv" else _basicWindYieldFunc + _yield = gen_func(int(time.timestamp())) + values.append( { "start_utc": time, @@ -99,6 +101,6 @@ def _basicSolarYieldFunc(timeUnix: int, scaleFactor: int = 10000) -> float: def _basicWindYieldFunc(timeUnix: int, scaleFactor: int = 10000) -> float: """Gets a fake wind yield for the input time.""" - output = min(scaleFactor, scaleFactor * 10 * random.random()) + output = min(scaleFactor, scaleFactor * random.random()) return output diff --git a/scripts/seed_local_db.py b/scripts/seed_local_db.py index a3bf985..bfd6ef0 100644 --- a/scripts/seed_local_db.py +++ b/scripts/seed_local_db.py @@ -7,10 +7,7 @@ from pvsite_datamodel.connection import DatabaseConnection from pvsite_datamodel.sqlmodels import Base from pvsite_datamodel.write.user_and_site import ( - add_site_to_site_group, create_site, - create_site_group, - 
create_user, ) @@ -45,29 +42,28 @@ def seed_db(): Base.metadata.create_all(engine) with db_conn.get_session() as session: - dummy_user = "dummy@openclimatefix.org" print("Seeding database") - site_group = create_site_group(session, site_group_name="dummy_site_group") - - _ = create_user( - session, email=dummy_user, site_group_name=site_group.site_group_name - ) - site, _ = create_site( session, - client_site_id=1234, - client_site_name="dummy_site", + client_site_id=1, + client_site_name="dummy_site_1", latitude=0.0, longitude=0.0, capacity_kw=10.0, + asset_type="pv", country="india", ) - add_site_to_site_group( + site, _ = create_site( session, - site_uuid=site.site_uuid, - site_group_name=site_group.site_group_name, + client_site_id=2, + client_site_name="dummy_site_2", + latitude=0.0, + longitude=0.0, + capacity_kw=10.0, + asset_type="wind", + country="india", ) print("Database successfully seeded") diff --git a/tests/conftest.py b/tests/conftest.py index 159e252..a6af384 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -54,19 +54,30 @@ def db_data(engine): with engine.connect() as connection: with Session(bind=connection) as session: - n_sites = 3 - - # Sites - for i in range(n_sites): - site = SiteSQL( - client_site_id=i + 1, - latitude=51, - longitude=3, - capacity_kw=4, - ml_id=i, - country="india" - ) - session.add(site) + + # PV site + site = SiteSQL( + client_site_id=1, + latitude=20.59, + longitude=78.96, + capacity_kw=4, + ml_id=1, + asset_type="pv", + country="india" + ) + session.add(site) + + # Wind site + site = SiteSQL( + client_site_id=2, + latitude=20.59, + longitude=78.96, + capacity_kw=4, + ml_id=2, + asset_type="wind", + country="india" + ) + session.add(site) session.commit() diff --git a/tests/test_app.py b/tests/test_app.py index c0cae05..5de528d 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -6,43 +6,45 @@ import pytest from pvsite_datamodel.read import get_all_sites -from pvsite_datamodel.sqlmodels import ForecastSQL, 
ForecastValueSQL, SiteSQL +from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL -from india_forecast_app.app import app, get_model, get_site_ids, run_model, save_forecast +from india_forecast_app.app import app, get_model, get_sites, run_model, save_forecast from india_forecast_app.model import DummyModel from ._utils import run_click_script -def test_get_site_ids(db_session): +def test_get_sites(db_session): """Test for correct site ids""" - site_ids = get_site_ids(db_session) + sites = get_sites(db_session) + sites = sorted(sites, key=lambda s: s.client_site_id) - assert len(site_ids) == 3 - for site_id in site_ids: - assert isinstance(site_id, uuid.UUID) + assert len(sites) == 2 + for site in sites: + assert isinstance(site.site_uuid, uuid.UUID) + assert sites[0].asset_type.name == "pv" + assert sites[1].asset_type.name == "wind" -def test_get_model(): +@pytest.mark.parametrize("asset_type", ["pv", "wind"]) +def test_get_model(asset_type): """Test for getting valid model""" - model = get_model() + model = get_model(asset_type) assert hasattr(model, 'version') assert isinstance(model.version, str) assert hasattr(model, 'predict') -def test_run_model(db_session): - """Test for running model""" - - site = db_session.query(SiteSQL).first() - model = DummyModel() +@pytest.mark.parametrize("asset_type", ["pv", "wind"]) +def test_run_model(db_session, asset_type): + """Test for running PV and wind models""" forecast = run_model( - model=model, - site_id=site.site_uuid, + model=DummyModel(asset_type), + site_id=str(uuid.uuid4()), timestamp=dt.datetime.now(tz=dt.UTC) ) @@ -88,8 +90,8 @@ def test_app(write_to_db, db_session): assert result.exit_code == 0 if write_to_db: - assert db_session.query(ForecastSQL).count() == init_n_forecasts + 3 - assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (3 * 192) + assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2 + assert db_session.query(ForecastValueSQL).count() == 
init_n_forecast_values + (2 * 192) else: assert db_session.query(ForecastSQL).count() == init_n_forecasts assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values