Skip to content

Commit

Permalink
Merge pull request #8 from openclimatefix/chris/forecast-by-asset-type
Browse files Browse the repository at this point in the history
Forecast by asset type
  • Loading branch information
peterdudfield authored Jan 25, 2024
2 parents 1ade9d6 + a6b0abd commit 32dcf2a
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 70 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ COPY --from=builder /venv /venv
COPY --from=builder /app/dist .
RUN . /venv/bin/activate && pip install *.whl

ENTRYPOINT ["app"]
ENTRYPOINT ["app", "--write-to-db"]
44 changes: 26 additions & 18 deletions india_forecast_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd
from pvsite_datamodel import DatabaseConnection
from pvsite_datamodel.read import get_sites_by_country
from pvsite_datamodel.sqlmodels import SiteSQL
from pvsite_datamodel.write import insert_forecast_values
from sqlalchemy.orm import Session

Expand All @@ -19,31 +20,33 @@
log = logging.getLogger(__name__)


def get_site_ids(db_session: Session) -> list[str]:
def get_sites(db_session: Session) -> list[SiteSQL]:
"""
Gets all avaiable site_ids in India
Gets all available sites in India
Args:
db_session: A SQLAlchemy session
Returns:
A list of site_ids
A list of SiteSQL objects
"""

sites = get_sites_by_country(db_session, country="india")
return sites

return [s.site_uuid for s in sites]


def get_model():
def get_model(asset_type: str):
"""
Instantiates and returns the forecast model ready for running inference
Args:
asset_type: One or "pv" or "wind"
Returns:
A forecasting model
"""

model = DummyModel()
model = DummyModel(asset_type)
return model


Expand Down Expand Up @@ -151,17 +154,22 @@ def app(timestamp: dt.datetime | None, write_to_db: bool, log_level: str):

# 1. Get sites
log.info("Getting sites...")
site_ids = get_site_ids(session)
log.info(f"Found {len(site_ids)} sites")

# 2. Load model
log.info("Loading model...")
model = get_model()
log.info("Loaded model")

# 3. Run model for each site
for site_id in site_ids:
log.info(f"Running model for site={site_id}...")
sites = get_sites(session)
log.info(f"Found {len(sites)} sites")

# 2. Load models
log.info("Loading models...")
pv_model = get_model("pv")
log.info("Loaded PV model")
wind_model = get_model("wind")
log.info("Loaded wind model")

for site in sites:
# 3. Run model for each site
site_id = site.site_uuid
asset_type = site.asset_type.name
log.info(f"Running {asset_type} model for site={site_id}...")
model = wind_model if asset_type == "wind" else pv_model
forecast_values = run_model(model=model, site_id=site_id, timestamp=timestamp)

if forecast_values is None:
Expand Down
12 changes: 7 additions & 5 deletions india_forecast_app/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ class DummyModel:
"""
Dummy model that emulates the capabilities expected by a real model
"""

@property
def version(self):
"""Version number"""
return "0.0.0"

def __init__(self):
def __init__(self, asset_type: str):
"""Initializer for the model"""
pass
self.asset_type = asset_type

def predict(self, site_id: str, timestamp: dt.datetime):
"""Make a prediction for the model"""
Expand All @@ -35,7 +35,9 @@ def _generate_dummy_forecast(self, timestamp: dt.datetime):

for i in range(numSteps):
time = start + i * step
_yield = _basicSolarYieldFunc(int(time.timestamp()))
gen_func = _basicSolarYieldFunc if self.asset_type == "pv" else _basicWindYieldFunc
_yield = gen_func(int(time.timestamp()))

values.append(
{
"start_utc": time,
Expand Down Expand Up @@ -99,6 +101,6 @@ def _basicSolarYieldFunc(timeUnix: int, scaleFactor: int = 10000) -> float:

def _basicWindYieldFunc(timeUnix: int, scaleFactor: int = 10000) -> float:
"""Gets a fake wind yield for the input time."""
output = min(scaleFactor, scaleFactor * 10 * random.random())
output = min(scaleFactor, scaleFactor * random.random())

return output
26 changes: 11 additions & 15 deletions scripts/seed_local_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
from pvsite_datamodel.connection import DatabaseConnection
from pvsite_datamodel.sqlmodels import Base
from pvsite_datamodel.write.user_and_site import (
add_site_to_site_group,
create_site,
create_site_group,
create_user,
)


Expand Down Expand Up @@ -45,29 +42,28 @@ def seed_db():
Base.metadata.create_all(engine)

with db_conn.get_session() as session:
dummy_user = "[email protected]"

print("Seeding database")
site_group = create_site_group(session, site_group_name="dummy_site_group")

_ = create_user(
session, email=dummy_user, site_group_name=site_group.site_group_name
)

site, _ = create_site(
session,
client_site_id=1234,
client_site_name="dummy_site",
client_site_id=1,
client_site_name="dummy_site_1",
latitude=0.0,
longitude=0.0,
capacity_kw=10.0,
asset_type="pv",
country="india",
)

add_site_to_site_group(
site, _ = create_site(
session,
site_uuid=site.site_uuid,
site_group_name=site_group.site_group_name,
client_site_id=2,
client_site_name="dummy_site_2",
latitude=0.0,
longitude=0.0,
capacity_kw=10.0,
asset_type="wind",
country="india",
)

print("Database successfully seeded")
37 changes: 24 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,19 +54,30 @@ def db_data(engine):

with engine.connect() as connection:
with Session(bind=connection) as session:
n_sites = 3

# Sites
for i in range(n_sites):
site = SiteSQL(
client_site_id=i + 1,
latitude=51,
longitude=3,
capacity_kw=4,
ml_id=i,
country="india"
)
session.add(site)

# PV site
site = SiteSQL(
client_site_id=1,
latitude=20.59,
longitude=78.96,
capacity_kw=4,
ml_id=1,
asset_type="pv",
country="india"
)
session.add(site)

# Wind site
site = SiteSQL(
client_site_id=2,
latitude=20.59,
longitude=78.96,
capacity_kw=4,
ml_id=2,
asset_type="wind",
country="india"
)
session.add(site)

session.commit()

Expand Down
38 changes: 20 additions & 18 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,45 @@

import pytest
from pvsite_datamodel.read import get_all_sites
from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL, SiteSQL
from pvsite_datamodel.sqlmodels import ForecastSQL, ForecastValueSQL

from india_forecast_app.app import app, get_model, get_site_ids, run_model, save_forecast
from india_forecast_app.app import app, get_model, get_sites, run_model, save_forecast
from india_forecast_app.model import DummyModel

from ._utils import run_click_script


def test_get_site_ids(db_session):
def test_get_sites(db_session):
"""Test for correct site ids"""

site_ids = get_site_ids(db_session)
sites = get_sites(db_session)
sites = sorted(sites, key=lambda s: s.client_site_id)

assert len(site_ids) == 3
for site_id in site_ids:
assert isinstance(site_id, uuid.UUID)
assert len(sites) == 2
for site in sites:
assert isinstance(site.site_uuid, uuid.UUID)
assert sites[0].asset_type.name == "pv"
assert sites[1].asset_type.name == "wind"


def test_get_model():
@pytest.mark.parametrize("asset_type", ["pv", "wind"])
def test_get_model(asset_type):
"""Test for getting valid model"""

model = get_model()
model = get_model(asset_type)

assert hasattr(model, 'version')
assert isinstance(model.version, str)
assert hasattr(model, 'predict')


def test_run_model(db_session):
"""Test for running model"""

site = db_session.query(SiteSQL).first()
model = DummyModel()
@pytest.mark.parametrize("asset_type", ["pv", "wind"])
def test_run_model(db_session, asset_type):
"""Test for running PV and wind models"""

forecast = run_model(
model=model,
site_id=site.site_uuid,
model=DummyModel(asset_type),
site_id=str(uuid.uuid4()),
timestamp=dt.datetime.now(tz=dt.UTC)
)

Expand Down Expand Up @@ -88,8 +90,8 @@ def test_app(write_to_db, db_session):
assert result.exit_code == 0

if write_to_db:
assert db_session.query(ForecastSQL).count() == init_n_forecasts + 3
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (3 * 192)
assert db_session.query(ForecastSQL).count() == init_n_forecasts + 2
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values + (2 * 192)
else:
assert db_session.query(ForecastSQL).count() == init_n_forecasts
assert db_session.query(ForecastValueSQL).count() == init_n_forecast_values

0 comments on commit 32dcf2a

Please sign in to comment.